In [None]:
import numpy as np
from pathlib import Path
from scipy.io import readsav
import h5py
import pickle
import cv2
# import diplib as dip
from IPython.display import clear_output
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from tqdm.notebook import tqdm

In [None]:
def brightness_reconstruction(img): # doi: 10.1109/TPS.2018.2828863.
    """Emphasise pixels brighter than the image mean and suppress the rest.

    The image is scaled by 1/255, every pixel ``p`` is weighted by
    ``log(1 + p) * (p - mean)``, the result is divided by its maximum,
    negative values are clipped to zero and the output is scaled by 255.

    Parameters
    ----------
    img : array_like
        Image intensities; the 1/255 scaling presumes a 0-255 range.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``img``. A flat frame (every pixel
        equal) yields all zeros instead of the NaN a 0/0 division would give.
    """
    im_norm = np.asarray(img) / 255

    # A flat frame has nothing brighter than its own mean. Test for it on the
    # raw values: the computed mean can differ from the pixel value by a
    # rounding error, which normalisation would otherwise blow up to 255.
    if im_norm.max() == im_norm.min():
        return np.zeros_like(im_norm)

    mean_level = np.average(im_norm, axis=None)
    weighted = np.log(im_norm + 1) * (im_norm - mean_level)

    peak = np.max(weighted)
    if peak <= 0:
        # No positive response to normalise by; dividing would give NaN/inf
        # (peak == 0) or flip the sign of every pixel (peak < 0).
        return np.zeros_like(weighted)

    weighted = weighted / peak
    weighted = np.where(weighted < 0, 0, weighted)
    return weighted * 255

def process_image(img, kernel_size, sigma, threshold, erode_kernel):
    """Convert a camera frame into a binary (0/1) ``uint8`` mask.

    Runs ``brightness_reconstruction``, rescales the result to peak at 255,
    runs ``brightness_reconstruction`` again and thresholds the outcome.

    Parameters
    ----------
    img : array_like
        Input frame (see ``brightness_reconstruction``).
    kernel_size, sigma, erode_kernel
        Currently unused: they belong to the disabled blur, matched-filter and
        erosion steps that are kept below as comments.
    threshold : float
        Pixels whose reconstructed value is below this become 0, others 1.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of zeros and ones with the same shape as ``img``.
    """
    # img = cv2.GaussianBlur(img,(kernel_size, kernel_size),0)
    img = brightness_reconstruction(img)
    # img = np.array(dip.MatchedFiltersLineDetector2D(img, sigma = sigma)) # 10.1109/42.34715
    peak = img.max()
    if peak != 0:
        # Skip the rescale for an all-zero frame: 255 / 0 would turn every
        # pixel into NaN, which the threshold below maps to 1.
        img *= 255.0 / peak
    img = brightness_reconstruction(img)
    img = np.where(img < threshold, 0, 1).astype('uint8')
    # img = cv2.erode(img, np.ones((erode_kernel,erode_kernel), np.uint8), iterations=1)
    return img

def canny(img, kernel_size=11, low_threshold=10, high_threshold=20):
    """Blur an image with a Gaussian filter and return its Canny edge map.

    Parameters
    ----------
    img : numpy.ndarray
        Image handed to ``cv2.GaussianBlur`` and then ``cv2.Canny``.
        NOTE(review): ``cv2.Canny`` expects an 8-bit image; the disabled line
        below shows one way to rescale arbitrary data to ``uint8`` first.
    kernel_size : int, optional
        Side length of the square Gaussian kernel (default 11, the value that
        used to be hard-coded). OpenCV requires it to be odd and positive.
    low_threshold, high_threshold : float, optional
        Hysteresis thresholds passed to ``cv2.Canny`` (defaults 10 and 20,
        the values that used to be hard-coded).

    Returns
    -------
    numpy.ndarray
        Edge map produced by ``cv2.Canny``.
    """
    # gray=(255-255*(img-np.min(img))/(np.max(img)-np.min(img))).astype('uint8')

    # Reduce the noise with a Gaussian filter; sigma=0 lets OpenCV derive it
    # from the kernel size.
    blur_gray = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)

    # Apply the Canny edge detector.
    edges = cv2.Canny(blur_gray, low_threshold, high_threshold)

    return edges

### Workflow

1. Loop across all files
1. Loop across all indices in each file
1. Get the processed image and the r, l data points
1. Append to the 3 arrays
1. After each full run, save the processed image array and the r, l data points to HDF5

In [None]:
# Input / output locations used by the dataset-building cells.
tv_image_path = Path('tv_images')
inversion_data_path = Path('outputs/inversion_data')
hdf5_path = Path('outputs/hdf5')

# Sorted so that the frame offsets computed below are reproducible.
files = sorted(tv_image_path.glob('*.sav'))
if not files:
    # Without this guard the failure is a bare IndexError on files[0] below.
    raise FileNotFoundError(f"No .sav files found in '{tv_image_path}'")

# Number of frames per file, and the row offset of each file inside the
# concatenated datasets (one leading 0, then the running totals).
file_lengths = [len(readsav(str(file))['emission_structure'][0][3]) for file in files]
cumulative_lengths = np.insert(np.cumsum(file_lengths), 0, 0)

# Parse the first file once and take both per-frame shapes from it.
_first_structure = readsav(str(files[0]))['emission_structure'][0]
tv_dim = _first_structure[7][0].shape
inversion_dim = _first_structure[0][0].shape
del _first_structure  # do not keep the parsed file alive in the kernel

In [None]:
# Show the per-frame shapes taken from the first .sav file, one per line.
for dim in (tv_dim, inversion_dim):
    print(dim)

In [None]:
# Settings for the processed-image dataset.
# crop_dim sizes the stored frames as (rows, cols).
crop_dim = (240, 480)
# Arguments forwarded to process_image(); only `threshold` affects its result
# while the blur, matched-filter and erosion steps stay commented out.
kernel_size, sigma, threshold, erode_kernel = 5, 1, 4, 4
# Not referenced by any cell visible in this notebook.
aspect_num = 0.5

In [None]:
# Write every raw TV frame and its l/r location label into one HDF5 file.
hdf5_file_name = hdf5_path / 'tv_raw.hdf5'

# The context manager closes the file even if a .sav or .pkl read fails
# part-way through, so no half-written handle stays open in the kernel.
with h5py.File(hdf5_file_name, 'w') as hf:
    tv_dataset = hf.create_dataset("tv_images", shape=(np.sum(file_lengths), tv_dim[0], tv_dim[1]), dtype='uint8')
    points_dataset = hf.create_dataset("points", shape=(np.sum(file_lengths), 4), dtype='float32')

    for idx, file in enumerate(files):
        # Parse each .sav file once and reuse the structure for both fields.
        emission = readsav(file)['emission_structure'][0]
        frames = emission[3].astype(int)
        tv_image_process = np.asarray(emission[7][frames])

        pkl_path = (inversion_data_path / file.stem).with_suffix('.pkl')
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(pkl_path, 'rb') as pkl_file:
            label_info = pickle.load(pkl_file)
        points = np.concatenate((label_info['l_location'], label_info['r_location']), 1)

        n_frames = file_lengths[idx]
        if n_frames == 0:
            continue
        # One block write per file instead of one HDF5 write per frame.
        start = int(cumulative_lengths[idx])
        stop = start + n_frames
        tv_dataset[start:stop] = tv_image_process[:n_frames]
        points_dataset[start:stop] = points[:n_frames]

In [None]:
# Write the cropped, binarised TV frames and their l/r location labels.
hdf5_file_name = hdf5_path / 'tv_process_simple.hdf5'

# Crop window applied to every frame. Its size follows crop_dim so it always
# matches the dataset shape; the column offset of 240 is the original value.
crop_rows = slice(0, crop_dim[0])
crop_cols = slice(240, 240 + crop_dim[1])

# The context manager closes the file even if processing fails part-way.
with h5py.File(hdf5_file_name, 'w') as hf:
    tv_dataset = hf.create_dataset("tv_images", shape=(np.sum(file_lengths), crop_dim[0], crop_dim[1]), dtype='uint8')
    points_dataset = hf.create_dataset("points", shape=(np.sum(file_lengths), 4), dtype='float32')

    for idx, file in enumerate(files):
        print(f"{idx+1} of {len(files)}")
        # Parse each .sav file once and reuse the structure for both fields.
        emission = readsav(file)['emission_structure'][0]
        frames = emission[3].astype(int)
        tv_image = emission[7][frames]

        pkl_path = (inversion_data_path / file.stem).with_suffix('.pkl')
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(pkl_path, 'rb') as pkl_file:
            label_info = pickle.load(pkl_file)
        points = np.concatenate((label_info['l_location'], label_info['r_location']), 1)

        start = int(cumulative_lengths[idx])
        for i in range(file_lengths[idx]):
            tv_dataset[start + i] = np.asarray(
                process_image(tv_image[i, crop_rows, crop_cols], kernel_size, sigma, threshold, erode_kernel)
            )
            points_dataset[start + i] = points[i]
        clear_output()

In [None]:
# Write only the inversion labels (l/r locations and intensities), no images.
hdf5_file_name = hdf5_path / 'compiled_inversion_no_image.hdf5'

# The context manager closes the file even if a .pkl read fails part-way.
with h5py.File(hdf5_file_name, 'w') as hf:
    rz_dataset = hf.create_dataset("rz", shape=(np.sum(file_lengths), 4), dtype='float32')
    intensity_dataset = hf.create_dataset("intensity", shape=(np.sum(file_lengths), 2), dtype='float32')

    for idx, file in enumerate(files):
        pkl_path = (inversion_data_path / file.stem).with_suffix('.pkl')
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(pkl_path, 'rb') as pkl_file:
            label_info = pickle.load(pkl_file)
        points = np.concatenate((label_info['l_location'], label_info['r_location']), 1)
        # Pair the two intensities per frame as columns. Concatenating along
        # axis 0 put every r_intensity row after all l_intensity rows, so
        # indexing by frame never reached it.
        # NOTE(review): assumes one intensity value per frame for each of
        # l and r (shape (n,) or (n, 1)) - confirm against the .pkl schema.
        points_i = np.column_stack((label_info['l_intensity'], label_info['r_intensity']))

        n_frames = file_lengths[idx]
        if n_frames == 0:
            continue
        # One block write per file instead of one HDF5 write per frame.
        start = int(cumulative_lengths[idx])
        stop = start + n_frames
        rz_dataset[start:stop] = points[:n_frames]
        intensity_dataset[start:stop] = points_i[:n_frames]

In [None]:
# Write raw TV frames together with the x/r location and intensity labels.
hdf5_file_name = hdf5_path / 'x_outer_radiation.hdf5'
print(files)

# The context manager closes the file even if a .sav or .pkl read fails
# part-way through.
with h5py.File(hdf5_file_name, 'w') as hf:
    tv_dataset = hf.create_dataset("tv_images", shape=(np.sum(file_lengths), tv_dim[0], tv_dim[1]), dtype='uint8')
    points_dataset = hf.create_dataset("points", shape=(np.sum(file_lengths), 4), dtype='float32')
    intensity_dataset = hf.create_dataset("intensity", shape=(np.sum(file_lengths), 2), dtype='float32')

    for idx, file in enumerate(files):
        # Parse each .sav file once and reuse the structure for both fields.
        emission = readsav(file)['emission_structure'][0]
        frames = emission[3].astype(int)
        tv_image_process = np.asarray(emission[7][frames])

        pkl_path = (inversion_data_path / file.stem).with_suffix('.pkl')
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(pkl_path, 'rb') as pkl_file:
            label_info = pickle.load(pkl_file)
        points = np.concatenate((label_info['x_location'], label_info['r_location']), 1)
        # Pair the two intensities per frame as columns. Concatenating along
        # axis 0 put every r_intensity row after all x_intensity rows, so
        # indexing by frame never reached it.
        # NOTE(review): assumes one intensity value per frame for each of
        # x and r (shape (n,) or (n, 1)) - confirm against the .pkl schema.
        points_i = np.column_stack((label_info['x_intensity'], label_info['r_intensity']))

        n_frames = file_lengths[idx]
        if n_frames == 0:
            continue
        # One block write per file instead of one HDF5 write per frame.
        start = int(cumulative_lengths[idx])
        stop = start + n_frames
        tv_dataset[start:stop] = tv_image_process[:n_frames]
        points_dataset[start:stop] = points[:n_frames]
        intensity_dataset[start:stop] = points_i[:n_frames]

### Split HDF5 into Train/Test/Validation

In [None]:
# Split settings. `randomization` is not consulted by the code in this cell.
randomization = False
tts_percent = 0.6
tvs_percent = 0.8

# Inputs: synthetic images and the TV-frame file; outputs go under `out_path`.
file_name_1 = Path('outputs/hdf5/s_outs_v3_limited.h5')
file_name_2 = Path('outputs/hdf5/x_outer_radiation.hdf5')
out_path = Path('outputs')

# Load everything into memory, applying the affine rescale x * 2 - 1 to the
# synthetic images and x / 127.5 - 1 to the TV frames.
with h5py.File(file_name_1, 'r') as f:
    synthetic_images = f['image'][:] * 2 - 1

with h5py.File(file_name_2, 'r') as f:
    points = f['points'][:]
    tv_images = f['tv_images'][:] / 127.5 - 1

# The number of label rows decides how many image pairs are resized.
file_len = len(points)

crop_synthetic, crop_tv = (np.zeros((file_len, 256, 256)) for _ in range(2))

for i in tqdm(range(file_len)):
    crop_synthetic[i] = resize(synthetic_images[i], (256, 256))
    # TV frames are additionally flipped along their first (row) axis.
    crop_tv[i] = resize(tv_images[i], (256, 256))[::-1]

In [None]:
# Two-stage split with a fixed seed: first set aside the "val" share
# (1 - tts_percent), then divide the remainder into train and test.
synth_dat, synth_val, tv_dat, tv_val = train_test_split(
    crop_synthetic, crop_tv, train_size=tts_percent, random_state=42
)
synth_train, synth_test, tv_train, tv_test = train_test_split(
    synth_dat, tv_dat, train_size=tvs_percent, random_state=42
)

# Dataset name -> array, in the order the datasets are created.
# NOTE(review): 'synthval' lacks the underscore that 'tv_val' has; the key is
# kept as-is because readers of img2img.h5 may depend on it.
split_arrays = {
    'synth_train': synth_train,
    'synth_test': synth_test,
    'synthval': synth_val,
    'tv_train': tv_train,
    'tv_test': tv_test,
    'tv_val': tv_val,
}

with h5py.File(out_path / 'img2img.h5', 'w') as f:
    for dataset_name, array in split_arrays.items():
        f.create_dataset(dataset_name, data=array)

In [None]:
# Visual spot check: one resized TV frame (left) next to one resized synthetic
# image (right), both drawn with the origin in the lower-left corner.
axs1, axs2 = (plt.subplot(1, 2, position) for position in (1, 2))
axs1.imshow(crop_tv[211], origin='lower')
axs2.imshow(crop_synthetic[10], origin='lower')
plt.show()

In [None]:
# Manual review: show each TV/synthetic pair and collect the indices that the
# reviewer rejects by typing 'n' (any case).
skip_indices = []

for i, (tv_frame, synthetic_frame) in enumerate(zip(crop_tv, crop_synthetic)):
    clear_output()
    axs1, axs2 = (plt.subplot(1, 2, position) for position in (1, 2))
    axs1.imshow(tv_frame, origin='lower')
    axs2.imshow(synthetic_frame, origin='lower')
    plt.show()
    print(f"Index: {i}")
    response = input("Press 'n' if invalid index: ")
    if response.lower() == 'n':
        skip_indices.append(i)
    clear_output()
