In [None]:
import sys
import numpy as np
import xarray as xr

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D, SeparableConv2D, UpSampling2D, concatenate, Conv2DTranspose

from tensorflow.keras.models import Model, save_model, load_model
from tensorflow.keras.utils import plot_model
#from tensorflow import concat
import matplotlib.pyplot as plt

import datetime
%matplotlib inline

# Location of the project-local helper library that is imported as `ml` below.
dirP_str = '../../../library'
# Only add it once, so re-running the cell does not accumulate duplicates.
missing_from_path = dirP_str not in sys.path
if missing_from_path:
    sys.path.append(dirP_str)

import ml_utils as ml

In [None]:
# Directory holding the dataset file and the "models" subdirectory.
ds_path='/scr/sci/mhayman/holodec/holodec-ml-data/'

# Trained model to evaluate (single assignment; the earlier, immediately
# overwritten alternative is kept here for reference only):
# 'holodec_UNET_16Filt_5Conv_4Pool_4Layers_mse_linear_random_image_multiplane_data_256x256_1000count_1particles_v05_150epochs_run1.h5'
model_file = 'holodec_UNET_16Filt_5Conv_4Pool_5Layers_mse_linear_random_image_multiplane_data_256x256_1000count_1particles_v05_150epochs_run1.h5'

# Hologram dataset the model is evaluated on.
ds_file = 'random_image_multiplane_data_256x256_1000count_1particles_v05.nc'


In [None]:
ds = xr.open_dataset(ds_path+ds_file)

# Training-run tags used in every saved figure file name. They are parsed from
# model_file (e.g. "..._150epochs_run1.h5") so they always match the loaded
# model and are defined on a fresh kernel.
import re
_run_match = re.search(r'_(\d+)epochs_run(\d+)', model_file)
if _run_match is None:
    raise ValueError(f"cannot parse '<N>epochs_run<M>' from model_file {model_file!r}")
num_epochs = int(_run_match.group(1))
run_num = int(_run_match.group(2))

In [None]:
# select holograms to evaluate: positions along `hologram_number` in ds
# (used with isel when selecting both the input images and the labels)
# index_list = [18,854,247,858,143,832,21,921,222,431,321]
index_list = [10]

In [None]:
# Inspect the loaded dataset (variables, dimensions, attributes).
ds

In [None]:
# Positional split of the holograms: the first `valid_index` are validation,
# [valid_index, split_index) are training. Built-in int() replaces np.int,
# which was removed in NumPy 1.24 (np.int was only an alias of int).
n_holograms = ds.sizes['hologram_number']
split_index = int(0.7*n_holograms)  # number of training+validation points
valid_index = int(0.2*n_holograms)  # number of validation points
all_labels = ds['labels']

train_labels = all_labels.isel(hologram_number=slice(valid_index,split_index))
test_labels = all_labels.isel(hologram_number=index_list)

scaler = ml.MinMaxScalerX(train_labels,dim=('hologram_number','xsize','ysize'))
# NOTE(review): if MinMaxScalerX.fit_transform re-fits (sklearn semantics), the
# scaler's min/max now come from the few test holograms, not the training set,
# and inverse_transform of the predictions later uses those test statistics.
# Confirm against ml_utils whether a transform-only call was intended here.
scaled_test_labels = scaler.fit_transform(test_labels)
# scaled_all_labels = scaler.fit_transform(all_labels)

In [None]:
# Input images of the holograms being evaluated, selected by position.
in_data = ds['image'].isel(indexers={'hologram_number': index_list})

In [None]:
# Ensure the input has a 'channel' dimension; when missing, insert a
# length-1 one at axis 3 (after hologram_number and the two image axes).
if 'channel' not in in_data.dims:
    in_data = in_data.expand_dims(dim="channel", axis=3)

In [None]:
# No input scaling is applied: the network is fed the raw images. The alias
# keeps the `scaled_` name used by the prediction and input-plotting cells.
scaled_in_data = in_data

In [None]:
# load the trained CNN model from the "models" subdirectory of ds_path.
# os.path.join avoids the doubled separator that string concatenation produced
# (ds_path already ends in '/').
import os
model_path = os.path.join(ds_path, "models", model_file)
mod = load_model(model_path)

In [None]:
# Run batched inference on the selected holograms and report the wall-clock cost.
n_samples = scaled_in_data.values.shape[0]
cnn_start = datetime.datetime.now()
preds_out = mod.predict(scaled_in_data.values, batch_size=64)
cnn_stop = datetime.datetime.now()
elapsed_seconds = (cnn_stop - cnn_start).total_seconds()
print(f"{n_samples} samples in {elapsed_seconds} seconds")
print(f"for {elapsed_seconds/n_samples} seconds per hologram")

In [None]:
# Shape of the raw network output; the next cell labels its 4 axes as
# (hologram_number, xsize, ysize, layer).
preds_out.shape

In [None]:
# Inspect the label coordinates; the test subset (test_labels) supplies the
# coordinates attached to the prediction DataArray.
all_labels.coords

In [None]:
# Wrap the raw network output in a labelled DataArray sharing the test labels'
# coordinates, so it can be inverse-scaled and selected by layer like the labels.
preds_out_da = xr.DataArray(
    data=preds_out,
    coords=test_labels.coords,
    dims=['hologram_number', 'xsize', 'ysize', 'layer'],
)

In [None]:
# Undo the label min/max scaling so predictions are comparable with test_labels.
# NOTE(review): uses whatever state `scaler` holds after the label-scaling cell,
# which calls fit_transform on test_labels — confirm the intended scale.
preds_original = scaler.inverse_transform(preds_out_da)

In [None]:
# Display the de-scaled predictions.
preds_original

In [None]:
# Label-vs-prediction scatter for every output layer, restricted to particle pixels.
# Pixels whose predicted amplitude exceeds this value are included.
SCATTER_AMPLITUDE = 0.2

# The predictions/labels are indexed by a 'layer' dimension (there is no 'type'
# dimension); amplitude layers are recognised by name, as in the comparison cell.
layer_names = list(test_labels.coords['layer'].values)
amp_layers = [name for name in layer_names if 'amplitude' in name]
if not amp_layers:
    raise ValueError(f"no amplitude layer among {layer_names}")

# Particle mask over all evaluated holograms: predicted amplitude above the
# threshold in any amplitude layer. Flattened in (hologram_number, xsize, ysize)
# order, the same order each single layer is flattened in below.
amp_pred = preds_original.sel(layer=amp_layers).max(dim='layer')
iscatter = np.nonzero(amp_pred.values.flatten() > SCATTER_AMPLITUDE)

n_layers = len(layer_names)
# One panel per layer; squeeze=False keeps `axes` 2-D even for a single layer.
fig, axes = plt.subplots(1, n_layers, figsize=(4*n_layers, 4), squeeze=False)
for a, clabel in enumerate(layer_names):
    ax = axes.ravel()[a]
    label_vals = test_labels.sel(layer=clabel).values.flatten()
    pred_vals = preds_original.sel(layer=clabel).values.flatten()
    ax.scatter(label_vals[iscatter], pred_vals[iscatter], 1, 'k')
    # 1:1 reference line spanning the full label range of this layer
    diag = np.linspace(label_vals.min(), label_vals.max(), 10)
    ax.plot(diag, diag, 'b--')
    ax.set_title(clabel)
    ax.set_xlabel('label')
    ax.set_ylabel('prediction')
# Save once, after all panels are drawn.
fig.savefig("results/"+model_file.replace(".h5","")+f"_SampleScatterPlot"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Number of output layers (rows in the per-hologram comparison figure).
preds_original.sizes['layer']

In [None]:
import copy

# Colormaps that draw masked (NaN) pixels in gray. Copies are modified so the
# globally registered 'seismic' / 'viridis' colormaps are left untouched.
diff_cmap = copy.copy(plt.get_cmap('seismic'))
diff_cmap.set_bad(color='gray')

z_cmap = copy.copy(plt.get_cmap('viridis'))
z_cmap.set_bad(color='gray')

# A pixel with amplitude below this in BOTH prediction and label is background
# and is masked (NaN) in the difference and z panels.
BACKGROUND_AMPLITUDE = 0.1


def _layer_image(da, ch_layer, position):
    """Return one layer of one hologram of `da` as a 2-D numpy array.

    `position` is the position along `hologram_number` (not a hologram label).
    It lines up with `index_list` because the inputs and `test_labels` were both
    selected with `isel(hologram_number=index_list)`.
    """
    return da.sel(layer=ch_layer).isel(hologram_number=position).values


n_layers = preds_original.sizes['layer']
for ai, ind in enumerate(index_list):
    # One row per layer: prediction | label | prediction - label.
    # squeeze=False keeps `ax` 2-D even when there is a single layer.
    fig, ax = plt.subplots(n_layers, 3, figsize=(12, 4*n_layers), squeeze=False)
    fig.suptitle(f"hologram {ind}")

    # Reset per hologram so a z layer never reuses another hologram's mask.
    nan_mask = None
    for iax, ch_layer in enumerate(test_labels.coords['layer'].values):
        pred_img = _layer_image(preds_original, ch_layer, ai)
        label_img = _layer_image(test_labels, ch_layer, ai)
        if 'amplitude' in ch_layer:
            nan_mask = np.ones(pred_img.shape)
            nan_mask[(pred_img < BACKGROUND_AMPLITUDE) & (label_img < BACKGROUND_AMPLITUDE)] = np.nan
            ax[iax, 0].imshow(pred_img, vmin=0, vmax=1)
            ax[iax, 1].imshow(label_img, vmin=0, vmax=1)
            ax[iax, 2].imshow((pred_img - label_img)*nan_mask, vmin=-1, vmax=1, cmap=diff_cmap)
        elif 'z' in ch_layer:
            # z is masked with the most recent amplitude layer's background mask;
            # if no amplitude layer has been seen yet, nothing is masked.
            z_mask = np.ones(pred_img.shape) if nan_mask is None else nan_mask
            ax[iax, 0].imshow(pred_img*z_mask, vmin=0, vmax=1e-2, cmap=z_cmap)
            ax[iax, 1].imshow(label_img*z_mask, vmin=0, vmax=1e-2, cmap=z_cmap)
            ax[iax, 2].imshow((pred_img - label_img)*z_mask, vmin=-1e-2, vmax=1e-2, cmap=diff_cmap)
        for col, panel in enumerate(('prediction', 'label', 'difference')):
            ax[iax, col].set_title(f"{ch_layer} {panel}")

In [None]:
# Layer names of the labels and predictions, displayed directly rather than via
# a loop variable left over from an earlier cell.
test_labels.coords['layer'].values

In [None]:
# Plot every input channel of each evaluated hologram on a 2-row grid:
# even-numbered channels in the top row, odd-numbered ones in the bottom row.
channel_number = in_data.sizes['channel']
# ceil(channel_number / 2) columns, so an odd channel count (or a single channel)
# still gets a panel for every channel instead of an IndexError.
n_cols = (channel_number + 1) // 2
for ind, hologram_index in enumerate(index_list):
    # squeeze=False keeps `ax` 2-D even with a single column.
    fig, ax = plt.subplots(2, n_cols, figsize=(channel_number*3, 8), squeeze=False)
    for ai in range(channel_number):
        panel = ax[ai % 2, ai // 2]
        # `ind` is the position along hologram_number, hence isel
        panel.imshow(scaled_in_data.isel(channel=ai, hologram_number=ind).values, vmin=-0.25, vmax=0.25)
        panel.set_title(f"channel {ai}")
    if channel_number % 2:
        # the bottom-right panel has no channel when the count is odd
        ax[1, -1].axis('off')
    plt.savefig("results/"+model_file.replace(".h5","")+f"_ExampleInput{hologram_index}"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Dimension sizes of the network input, including 'channel'.
in_data.sizes

In [None]:
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Per-pixel x/y coordinate grids aligned with the (xsize, ysize) image layout, so
# xg[i, j] is the x of pixel (i, j). indexing='ij' is required: the default 'xy'
# puts x along the second axis, which swaps x and y when indexing these images.
xg,yg = np.meshgrid(preds_original['xsize'].values,preds_original['ysize'].values,indexing='ij')

In [None]:
# 3-D scatter (z, x, y) of particle pixels: predictions colored by predicted
# amplitude, labels in black, one figure per evaluated hologram.
# Minimum amplitude for a pixel to count as part of a particle.
PARTICLE_AMPLITUDE = 0.2

# The data are indexed by a 'layer' dimension (there is no 'type' dimension);
# amplitude and z layers are recognised by name, as in the comparison cell.
layer_names = list(test_labels.coords['layer'].values)
amp_layers = [name for name in layer_names if 'amplitude' in name]
z_layers = [name for name in layer_names if 'amplitude' not in name and 'z' in name]
# NOTE(review): with several amplitude/z layers they are paired in order of
# appearance — confirm this matches how the label layers were built.

for ind, indact in enumerate(index_list):
    fig = plt.figure(figsize=(10,6))
    ax = fig.add_subplot(111, projection='3d')
    for amp_layer, z_layer in zip(amp_layers, z_layers):
        # `ind` is the position along hologram_number, hence isel (not sel)
        amp_pred = preds_original.sel(layer=amp_layer).isel(hologram_number=ind).values
        amp_label = test_labels.sel(layer=amp_layer).isel(hologram_number=ind).values
        ipart = np.nonzero(amp_pred > PARTICLE_AMPLITUDE)
        ipart_label = np.nonzero(amp_label > PARTICLE_AMPLITUDE)
        z_p = preds_original.sel(layer=z_layer).isel(hologram_number=ind).values[ipart]
        z_l = test_labels.sel(layer=z_layer).isel(hologram_number=ind).values[ipart_label]
        ax.scatter(z_p, xg[ipart], yg[ipart], c=amp_pred[ipart], vmin=0, vmax=1, s=1)
        ax.scatter(z_l, xg[ipart_label], yg[ipart_label], c='k', s=1)
    ax.set_xlim([ds.attrs['zmin'],ds.attrs['zmax']])
    ax.set_ylim([preds_original['xsize'].values[0],preds_original['xsize'].values[-1]])
    ax.set_zlim([preds_original['ysize'].values[0],preds_original['ysize'].values[-1]])
    ax.set_xlabel('z')
    ax.set_ylabel('x')
    ax.set_zlabel('y')
    fig.savefig("results/"+model_file.replace(".h5","")+f"_Scatter3D{indact}"+f"_{num_epochs}epochs_run{run_num}_"+ds_file.replace(".nc","")+".png",dpi=300)

In [None]:
# Leftover check: first x coordinate (the lower y-axis limit of the 3-D scatter).
preds_original['xsize'].values[0]