In [None]:
import sys
import numpy as np
import xarray as xr

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D, SeparableConv2D, UpSampling2D, concatenate, Conv2DTranspose

from tensorflow.keras.models import Model, save_model, load_model
from tensorflow.keras.utils import plot_model
#from tensorflow import concat
import matplotlib.pyplot as plt

import datetime
%matplotlib inline

# Register the directory holding the project's local helper modules on the
# import path. The guard keeps re-runs of this cell from adding duplicates.
dirP_str = '../../../library'
if sys.path.count(dirP_str) == 0:
    sys.path.append(dirP_str)

import ml_utils as ml

In [None]:
# --- Data location ----------------------------------------------------------
ds_path = '/scr/sci/mhayman/holodec/holodec-ml-data/'

# Dataset evaluated by this notebook. Alternatives that were used previously:
#   image_data_256x256_50count.nc
#   image_data_256x256_5000count.nc
#   image_data_64x64_5000count.nc
#   random_image_data_64x64_5000count.nc
#   random_image_data_64x64_5000count_v02.nc                     (1 um PSF with 1 cm depth)
#   random_image_data_64x64_5000count_v03.nc                     (1 um PSF with 10 cm depth)
#   random_image_multiplane_data_64x64_5000count.nc              (1 um PSF with 10 cm depth)
#   random_image_multiplane_data_64x64_5000count_1particles.nc   (1 um PSF with 1 cm depth with 1 particles)
ds_file = "random_image_multiplane_data_64x64_5000count_2particles.nc"  # 1 um PSF with 1 cm depth with 2 particles

# --- Trained model to evaluate (looked up under <ds_path>/models/) -----------
model_file = "holodec_UNET_16Filt_5Conv_4Pool_mse_linear_random_image_multiplane_data_64x64_5000count300epochs_run1.h5"

# --- Bookkeeping values that are embedded in the saved figure file names -----
run_num = 0
num_epochs = 300

# Open the selected dataset.
ds = xr.open_dataset(f"{ds_path}{ds_file}")

In [None]:
# Display the opened dataset so its variables, dimensions and coordinates can be inspected.
ds

In [None]:
# Split the holograms along 'hologram_number' into three contiguous ranges:
#   [0, valid_index)            -> validation (first 20 %)
#   [valid_index, split_index)  -> training   (next 50 %)
#   [split_index, end)          -> test       (last 30 %)
# The builtin int() is used because the np.int alias was removed from NumPy.
split_index = int(0.7*ds.sizes['hologram_number'])  # number of training+validation points
valid_index = int(0.2*ds.sizes['hologram_number'])  # number of validation points

# Only the 'amplitude' and 'z' label types are evaluated (in that order
# along the 'type' dimension).
all_labels = ds['labels'].sel(type=['amplitude','z'])

train_labels = all_labels.isel(hologram_number=slice(valid_index,split_index))
test_labels = all_labels.isel(hologram_number=slice(split_index,None))
val_labels = all_labels.isel(hologram_number=slice(None,valid_index))

# The scaler is constructed from the training labels, reducing over every
# dimension except 'type'.
# NOTE(review): fit_transform is called on every subset below. If
# MinMaxScalerX re-fits on each call, the validation/test/all scalings and
# the later inverse_transform would not use the training statistics —
# confirm against ml_utils.
scaler = ml.MinMaxScalerX(train_labels,dim=('hologram_number','xsize','ysize'))
scaled_train_labels = scaler.fit_transform(train_labels)
scaled_val_labels = scaler.fit_transform(val_labels)
scaled_test_labels = scaler.fit_transform(test_labels)
scaled_all_labels = scaler.fit_transform(all_labels)

In [None]:
# The 'image' variable of the dataset is the network input.
in_data = ds['image']

In [None]:
# Display the scaler's `max` attribute.
scaler.max

In [None]:
# Display the index that separates the validation holograms from the training holograms.
valid_index

In [None]:
# The input needs a 'channel' dimension; when the dataset does not provide
# one, insert a length-1 'channel' dimension at axis position 3.
if 'channel' not in in_data.dims:
    in_data = in_data.expand_dims(dim="channel", axis=3)

In [None]:
# Scale the network input by a fixed factor of 1/2.
# NOTE(review): presumably this matches the input scaling used when the model
# was trained — confirm against the training notebook.
scaled_in_data = in_data/2

In [None]:
# Hologram used for a quick visual sanity check of inputs and labels.
plt_index = 25

# Last input channel of the scaled input image.
plt.figure()
plt.imshow(scaled_in_data.values[plt_index, :, :, -1])
plt.colorbar()

# Unscaled labels: type index 0 ('amplitude') on a 0..1 color scale and
# type index 1 ('z') on a 0..0.1 color scale.
for type_idx, color_max in ((0, 1), (1, 1e-1)):
    plt.figure()
    plt.imshow(all_labels[plt_index, :, :, type_idx], vmin=0, vmax=color_max)
    plt.colorbar()

In [None]:
# Log-scale histogram of every scaled training label value for type index 0
# ('amplitude', the first type selected into all_labels).
plt.figure()
amplitude_flat = scaled_train_labels.values[:, :, :, 0].flatten()
plt.hist(amplitude_flat, log=True)

In [None]:
# Display the shape of the scaled input array that is fed to the model.
scaled_in_data.shape

In [None]:
# Display the per-dimension sizes of the scaled training labels.
scaled_train_labels.sizes

In [None]:
# Load the trained Keras model from the "models" sub-directory of the data path.
mod = load_model(f"{ds_path}/models/{model_file}")

In [None]:
# Run the model over every hologram in the dataset and record wall-clock
# timestamps immediately before and after the prediction call.
cnn_start = datetime.datetime.now()
preds_out = mod.predict(scaled_in_data.values, batch_size=64)
cnn_stop = datetime.datetime.now()

# Report the total prediction time and the average time per hologram.
elapsed_seconds = (cnn_stop - cnn_start).total_seconds()
sample_count = scaled_in_data.values.shape[0]
print(f"{sample_count} samples in {elapsed_seconds} seconds")
print(f"for {elapsed_seconds/sample_count} seconds per hologram")

In [None]:
# Display the shape of the raw prediction array returned by the model.
preds_out.shape

In [None]:
# Wrap the raw predictions in a DataArray that reuses the coordinates of the
# labels, so predictions and labels can be selected and compared by name.
preds_out_da = xr.DataArray(
    preds_out,
    dims=('hologram_number', 'xsize', 'ysize', 'type'),
    coords=all_labels.coords,
)

In [None]:
# Map the predictions back through the label scaler's inverse transform so
# they can be compared with the unscaled labels.
preds_original = scaler.inverse_transform(preds_out_da)

In [None]:
# Prediction error on the test holograms (index split_index onward), reduced
# over every dimension except 'type' to give one mean and one standard
# deviation per label type.
test_error = preds_original[split_index:] - test_labels
reduce_dims = ('hologram_number', 'xsize', 'ysize')
mean_error = test_error.mean(dim=reduce_dims)
std_error = test_error.std(dim=reduce_dims)

In [None]:
# Scatter plot of prediction versus label for each label type, restricted to
# the pixels whose predicted amplitude exceeds 0.2.
iscatter = np.nonzero(preds_original.sel(type='amplitude').values.flatten() > 0.2)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for a, clabel in enumerate(all_labels.coords['type'].values):
    ax = axes.ravel()[a]
    label_values = all_labels.sel(type=clabel)
    pred_values = preds_original.sel(type=clabel)
    ax.scatter(label_values.values.flatten()[iscatter], pred_values.values.flatten()[iscatter], 1, 'k')
    # Dashed 1:1 reference line spanning the range of the labels.
    diag = np.linspace(label_values.min(), label_values.max(), 10)
    ax.plot(diag, diag, 'b--')
    ax.set_title(clabel)
    ax.set_xlabel("label")
    ax.set_ylabel("prediction")

# Save once, after every panel has been drawn. Saving inside the loop wrote
# the high-resolution file once per label type, with a partial figure first.
fig.savefig("results/"+model_file.replace(".h5","")+"_ScatterPlot"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Display the mean test error per label type.
mean_error

In [None]:
# Display the standard deviation of the test error per label type.
std_error

In [None]:
# Holograms to render as example figures (selected by 'hologram_number' label).
index_list = [18, 2854, 1247, 858, 3143, 832, 4021, 3921, 222, 2431, 321]

# One 2x3 figure per hologram. Each row is one label type and shows, left to
# right: prediction, label, and prediction minus label (diverging colormap).
for ind in index_list:
    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax = ax.ravel()
    for row, (label_type, color_lim) in enumerate((('amplitude', 1), ('z', 1e-2))):
        predicted = preds_original.sel(type=label_type, hologram_number=ind).values
        truth = all_labels.sel(type=label_type, hologram_number=ind).values
        first = 3*row
        ax[first].imshow(predicted, vmin=0, vmax=color_lim)
        ax[first + 1].imshow(truth, vmin=0, vmax=color_lim)
        ax[first + 2].imshow(predicted - truth, vmin=-color_lim, vmax=color_lim, cmap=plt.get_cmap('seismic'))
    plt.savefig("results/"+model_file.replace(".h5","")+f"_ExampleImage{ind}"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Plot every input channel of each example hologram on a two-row grid:
# channel `ai` goes to row ai % 2, column ai // 2.
channel_number = in_data.sizes['channel']
# Ceil division so an odd channel count still gets a column for its last
# channel, and a single channel still gets one column.
n_cols = max(1, (channel_number + 1)//2)
# index_list = [18]
for ind in index_list:
    # squeeze=False keeps `ax` two-dimensional even when there is one column,
    # so the [row, column] indexing below works for any channel count.
    fig, ax = plt.subplots(2, n_cols, figsize=(channel_number*3, 8), squeeze=False)
    for ai in range(channel_number):
        ax[ai % 2, ai//2].imshow(scaled_in_data.isel(channel=ai,hologram_number=ind),vmin=-1,vmax=1)
    # Hide grid cells that have no channel to show (odd channel counts).
    for spare in range(channel_number, 2*n_cols):
        ax[spare % 2, spare//2].axis('off')
    plt.savefig("results/"+model_file.replace(".h5","")+f"_ExampleInput{ind}"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Display the per-dimension sizes of the (unscaled) input data.
in_data.sizes