In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from pathlib import Path
import h5py
import pickle
from tqdm.notebook import tqdm
import cv2
# import diplib as dip

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
# from torchsummary import summary

In [None]:
def brightness_reconstruction(img, max_value=255):
    """Brightness-reconstruction contrast enhancement (doi: 10.1109/TPS.2018.2828863).

    Each pixel ``p`` (image normalised by ``max_value``) is weighted by
    ``log(p + 1) * (p - mean(p))``. The response is rescaled so that its
    maximum becomes 1, negative responses are clipped to zero, and the result
    is scaled back to ``[0, max_value]``.

    Parameters
    ----------
    img : array_like
        Input image. Assumed to lie in ``[0, max_value]``; NOTE(review): callers
        that pass frames already divided by 255 should pass ``max_value=1``.
    max_value : float, optional
        Full-scale value of the input/output range (default 255, i.e. 8-bit).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``img`` and values in ``[0, max_value]``.
        If no pixel produces a positive response (e.g. a uniform image), an
        all-zero array is returned instead of dividing by zero (which would
        yield NaNs) or by a negative peak (which would invert the image).
    """
    im_norm = np.asarray(img) / max_value
    mean = np.average(im_norm, axis=None)
    response = np.log(im_norm + 1) * (im_norm - mean)
    peak = np.max(response)
    if peak <= 0:
        # Nothing lies above the mean: there is no bright structure to rescale.
        return np.zeros_like(response)
    response = response / peak
    # Keep only above-mean (bright) responses.
    response = np.where(response < 0, 0, response)
    return response * max_value

def process_image(img, kernel_size, sigma, threshold, erode_kernel):
    """Denoise a frame and apply brightness reconstruction.

    Pipeline: a Gaussian blur with a square ``kernel_size`` x ``kernel_size``
    kernel (sigmaX=0, so OpenCV derives the sigma from the kernel size),
    followed by ``brightness_reconstruction``.

    ``sigma``, ``threshold`` and ``erode_kernel`` are accepted but currently
    unused. They parameterised pipeline stages that are disabled here:
    a diplib matched-filter line detector (doi: 10.1109/42.34715, taking
    ``sigma``), binarisation at ``threshold``, and ``cv2.erode`` with an
    ``erode_kernel`` x ``erode_kernel`` structuring element.
    """
    ksize = (kernel_size,) * 2
    smoothed = cv2.GaussianBlur(img, ksize, 0)
    return brightness_reconstruction(smoothed)

In [None]:
# Synthetic camera images.
file_name = 'cam_geo/s_outs_v3_limited.h5'
with h5py.File(file_name, 'r') as f:
    print(list(f.keys()))
    synthetic_images = f['image'][:]

# Real TV camera frames and the associated 'points' dataset.
file_name_2 = 'outputs/hdf5/x_outer_radiation.hdf5'
with h5py.File(file_name_2, 'r') as f:
    print(list(f.keys()))
    points = f['points'][:]
    tv_images = f['tv_images'][:]

# Scale to [0, 1] and flip every frame vertically in one vectorised step.
# cv2.flip(frame, 0) reverses a frame's rows; assuming tv_images is a stack of
# frames shaped (N, H, W[, C]) — TODO confirm — that is axis 1 of the stack.
tv_images = np.flip(tv_images, axis=1) / 255

In [None]:
# Frame to inspect; also used by the residual cells further down.
idx = 1840

# Side-by-side comparison: real TV frame vs synthetic frame at the same index.
fig, (ax_real, ax_syn) = plt.subplots(1, 2)

im_real = ax_real.imshow(tv_images[idx], origin='lower')
fig.colorbar(im_real, ax=ax_real, orientation='horizontal')
ax_real.set_title('Real Image')

im_syn = ax_syn.imshow(synthetic_images[idx], origin='lower')
fig.colorbar(im_syn, ax=ax_syn, orientation='horizontal')
ax_syn.set_title(f'Synthetic Images: idx = {idx}')

fig.tight_layout()
plt.show()

In [None]:
# Preprocessing parameters. Only kernel_size currently affects process_image;
# sigma, threshold and erode_kernel feed stages that are commented out there.
kernel_size = 5  # Gaussian kernel width (cv2.GaussianBlur needs odd, positive)
sigma = 0.1
threshold = .1
erode_kernel = 1

# NOTE(review): tv_images were divided by 255 when loaded, but
# brightness_reconstruction divides by 255 again — confirm this is intended.
process_tv_image = process_image(tv_images[idx], kernel_size, sigma, threshold, erode_kernel)

fig, ax = plt.subplots()
ax.imshow(process_tv_image, origin='lower')
ax.set_title(f'Processed real image: idx = {idx}')
plt.show()

In [None]:
# Residual of every real frame against synthetic frame `idx`, weighted by that
# synthetic frame. Broadcasts over the frame axis of tv_images (assumes each
# real frame has the same shape as synthetic_images[idx] — TODO confirm).
test = (tv_images - synthetic_images[idx]) * synthetic_images[idx]
plt.imshow(test[idx], origin='lower')
plt.title(f'Synthetic-weighted residual (real - synthetic) for idx = {idx}')
plt.show()

In [None]:
# Per-frame L2 (Frobenius) norm of the weighted residual, reducing both image axes.
test_scalar = np.sqrt(np.square(test).sum(axis=(1, 2)))

In [None]:
# One residual norm per frame; numpy's print options summarise very long arrays.
print(f'{test_scalar}')

In [None]:
fig, ax = plt.subplots()
ax.plot(test_scalar)
ax.axvline(x=idx, color='red', linestyle='--', label=f'idx = {idx}')
ax.set_xlabel('Frame index')
ax.set_ylabel('L2 norm of weighted residual')
ax.set_title('Most likely indices (by residual L2 norm)')
ax.legend()
plt.show()

# The residual (real - synthetic) * synthetic is small when a real frame agrees
# with synthetic frame `idx`, so the best match is the argmin; the argmax
# (largest residual, i.e. worst match) is reported alongside it.
print('argmin (best match):', np.argmin(test_scalar))
print('argmax (largest residual):', np.argmax(test_scalar))
print(f'residual at idx = {idx}:', test_scalar[idx])