In [None]:
import os
import sys

import numpy as np
import xarray as xr
import pandas as pd
import datetime
import matplotlib.pyplot as plt

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D, SeparableConv2D, AveragePooling2D, concatenate,Reshape
from tensorflow.keras.models import Model, save_model, load_model
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras import regularizers

# from sklearn.preprocessing import MinMaxScaler

# set path to local libraries
dirP_str = '../../../library'
if dirP_str not in sys.path:
    sys.path.append(dirP_str)

import ml_utils as ml


%matplotlib inline

In [None]:
import FourierOpticsLib as FO
import MieLibrary as mie

In [None]:
# --- Run configuration -------------------------------------------------------
# Directory and file name of the NetCDF dataset read by this notebook.
ds_path = "/scr/sci/mhayman/holodec/holodec-ml-data/"
ds_file = "synthetic_holograms_v03_svd_multipartamplitude_d_float.nc"
# Earlier dataset files, kept for reference:
# ds_file = "synthethic_holograms_v0_svd_ac_amplitude_float.nc"
# ds_file = "synthetic_holograms_v02_svd_multipartamplitude_d_float.nc"

channel_name = "channels"
data_rescale = 2

# Number of passes over the training set used by mod.fit.
num_epochs = 30

# Open the dataset (opened once; later cells read their variables from `ds`).
ds = xr.open_dataset(f"{ds_path}{ds_file}")  # file with mean (DC) value removed

In [None]:
# Run counter embedded in the names of saved figures/models; the training cell
# increments it after every call to mod.fit.
run_num=1

In [None]:
# Optical constants passed to the Fourier-optics helpers.
# NOTE(review): values look like SI meters (3 um pixels, 355 nm laser) — confirm.
pixwid = 3e-6  # size of pixels
wavelength = 355e-9  # laser wavelength

In [None]:
# split_index = 7000  # number of training+validation points
# valid_index = 2000  # number of validation points
# all_labels = ds[["d"]].to_dataframe()
# train_labels = all_labels.iloc[valid_index:split_index]
# test_labels = all_labels.iloc[split_index:]
# val_labels = all_labels.iloc[:valid_index]
# scaler = MinMaxScaler()
# scaled_train_labels = pd.DataFrame(scaler.fit_transform(train_labels), index=train_labels.index, columns=train_labels.columns)
# scaled_val_labels = pd.DataFrame(scaler.fit_transform(val_labels), index=val_labels.index, columns=val_labels.columns)
# scaled_test_labels = pd.DataFrame(scaler.transform(test_labels), index=test_labels.index, columns=test_labels.columns)

In [None]:
# Partition the holograms by position along hologram_number:
#   [0, valid_index)            -> validation
#   [valid_index, split_index)  -> training
#   [split_index, end)          -> test
split_index = 7000  # number of training+validation points
valid_index = 2000  # number of validation points
# Put hologram_number first so the leading axis is the per-sample axis used by mod.fit.
all_labels = ds['particle_data'].transpose('hologram_number','particle_property','particle_number')

train_labels = all_labels.isel(hologram_number=slice(valid_index,split_index))
test_labels = all_labels.isel(hologram_number=slice(split_index,None))
val_labels = all_labels.isel(hologram_number=slice(None,valid_index))

# The scaler is constructed from the training labels only, with the reduction
# dimensions ('hologram_number', 'particle_number') — presumably so that each
# particle_property gets its own min/max; confirm in ml_utils.
# NOTE(review): every split below goes through fit_transform. If
# MinMaxScalerX.fit_transform re-fits (as sklearn's does), validation/test/all
# labels are each scaled with their own statistics and the scaler ends up fitted
# on all_labels before inverse_transform is used later — confirm that it only
# applies the statistics captured in the constructor.
scaler = ml.MinMaxScalerX(train_labels,dim=('hologram_number','particle_number'))
scaled_train_labels = scaler.fit_transform(train_labels)
scaled_val_labels = scaler.fit_transform(val_labels)
scaled_test_labels = scaler.fit_transform(test_labels)
scaled_all_labels = scaler.fit_transform(all_labels)

In [None]:
# Show the stored dimension order of the SVD-filtered images.
ds['image_svd'].dims

In [None]:
# Reorder to (hologram_number, filter_number, channels) so that holograms are the
# leading axis, matching the label arrays.
in_data = ds['image_svd'].transpose('hologram_number','filter_number','channels')

In [None]:
# if not "channel" in ds["image_ft"].dims:
#     in_data = ds["image_ft"].transpose("hologram_number", "xsize", 'ysize').expand_dims("channel", 3)
# else:
#     in_data = ds["image_ft"].transpose("hologram_number", "xsize", 'ysize',"channel")

In [None]:
# Display the dataset summary (dimensions, coordinates, variables).
ds

In [None]:
# Number of entries along the xsize coordinate.
ds['xsize'].size

In [None]:
# Build a coordinate grid from the hologram pixel counts (xsize, ysize) and the pixel size.
# NOTE(review): image_grid is only referenced by the commented-out PCA cells in this
# notebook — confirm whether this cell is still needed.
image_grid = FO.Coordinate_Grid(((ds['xsize'].size,ds['ysize'].size),(pixwid,pixwid)),inputType='ccd')

In [None]:
# check that the svd filters were created correctly
# (plots the filter stored at filter_number index 4 as a visual spot check)
ds['filter_set'].isel(filter_number=4).plot()

In [None]:
# Plot the filter_number=0 slice of the inputs (label-based selection).
in_data.sel(filter_number=0).plot()

In [None]:
# Largest input value; a reference point for the hard-coded input scaling further down.
in_data.max()

In [None]:
# # Perform PCA to prefilter input data
# max_angle = np.sqrt(np.max(image_grid.fx**2+image_grid.fy**2))*wavelength
# ang_grid = np.linspace(0,max_angle,500)

# particle_range = np.linspace(5,100,100)*1e-6

# scat_data = np.zeros((ang_grid.size,particle_range.size))
# for ir,r in enumerate(particle_range):
#     scat_data[:,ir] = np.abs(mie.Mie_PhaseMatrix(1.3,2*np.pi*r/wavelength,ang_grid)[0,:])
#     scat_data[:,ir] = scat_data[:,ir]/np.sum(scat_data[:,ir])  # normalize the area under the curve

In [None]:
# max_angle

In [None]:
# plt.figure()
# plt.plot(ang_grid,scat_data)
# plt.xlabel('scattering angle [radians]')
# plt.ylabel('amplitude')
# plt.grid(b=True)

In [None]:
# pca_data = scat_data.copy()
# pca_mean = np.mean(pca_data,axis=0,keepdims=True)
# pca_data = pca_data-pca_mean

# u,s,v = np.linalg.svd(pca_data.T)

In [None]:
# plt.figure()
# plt.plot(s)
# plt.yscale('log')
# plt.xlabel('Principal Component')
# plt.ylabel('Magnitude')
# plt.grid(b=True)

In [None]:
# # based on the above plot, decide where to truncate the PCA basis set
# itrunc = 20
# vtrunc = v[:itrunc,:]
# utrunc = u[:,:itrunc]

In [None]:
# plt.figure()
# plt.plot(ang_grid,vtrunc.T)
# plt.xlabel('scattering angle [radians]')
# plt.grid(b=True)

# plt.figure()
# plt.plot(particle_range*1e6,utrunc)
# plt.xlabel('particle radius [$\mu m$]')
# plt.grid(b=True)

In [None]:
# plt.figure()
# plt.plot(filter_set[:,:,5].values.flatten(),'.',markersize=1)

# plt.figure()
# plt.plot(ang_grid,vtrunc[5,:])
# plt.plot(grid_set.flatten(),filter_set.values[:,:,5].flatten(),'.')

# plt.figure()
# plt.imshow(filter_set[:,:,5])
# plt.colorbar()

In [None]:
# in_data = xr.DataArray(np.zeros((in_data0.coords['hologram_number'].size,itrunc,
#                                  in_data0.coords['channel'].size)),
#                       dims=('hologram_number','filter_number','channel'))


# # in_data = (in_data0*filter_set).sum(dim=('xsize','ysize',))
# for ai in in_data0.coords['hologram_number'].values:
#     for bi,ch in enumerate(in_data0.coords['channel'].values):
#         in_data.values[ai,:,bi] = ((in_data0.isel(hologram_number=ai,channel=bi)*filter_set).sum(dim=('xsize','ysize',))).values

In [None]:
# (in_data0.isel(hologram_number=60,channel=0)*filter_set).sum(dim=('xsize','ysize',))

In [None]:
# in_data.sizes['hologram_number']

In [None]:
# in_data.dims

In [None]:
# in_data.coords['channel'].size

In [None]:
# ds["image_ft"].max()

In [None]:
# plt.figure(); plt.plot(scaled_train_labels.values[10,:])

In [None]:
# Shift and scale the inputs with fixed constants before feeding the network.
# NOTE(review): 160 and 250 are hard-coded and look like empirical bounds for this
# dataset — confirm against in_data.min()/max(). `data_rescale` from the config
# cell is not used here.
scaled_in_data = (in_data+160) / 250

In [None]:
# Tag used in the file names of every artifact saved for this architecture.
nn_descript = 'DenseNN256_SVD_RegularizedAll'

# Output geometry is taken from the labels: everything after the hologram axis.
label_shape = all_labels.shape[1:]
n_outputs = np.prod(label_shape)


def _hidden_layer(units, incoming):
    """Attach a ReLU Dense layer with an L1 (1e-3) kernel penalty to `incoming`."""
    return Dense(units, activation="relu", kernel_regularizer=regularizers.l1(1e-3))(incoming)


# Flatten the (filter_number, channels) input and pass it through 256 -> 128 -> 64.
filter_input = Input(shape=scaled_in_data.shape[1:])
flat = Flatten()(filter_input)
dense_1_ch1 = _hidden_layer(256, flat)
dense_2_ch1 = _hidden_layer(128, dense_1_ch1)
dense_3_ch1 = _hidden_layer(64, dense_2_ch1)

# One output per label entry, then fold back into the label shape.
flat_out = Dense(n_outputs, activation='relu')(dense_3_ch1)
out = Reshape(label_shape, input_shape=(n_outputs,))(flat_out)

mod = Model(filter_input, out)
mod.compile(optimizer="adam", loss="mae", metrics=['acc'])
mod.summary()

In [None]:
# Save a diagram of the network (with tensor shapes) under results/.
# Create the output directory first so this cell, and the figure-saving cells that
# also write to results/, do not fail on a fresh checkout where it does not exist.
os.makedirs("results", exist_ok=True)
plot_model(mod,show_shapes=True,to_file="results/holodec_"+nn_descript+'_'+ds_file.replace(".nc","")+".png")

In [None]:
# Train on holograms [valid_index, split_index) and validate on [0, valid_index);
# the positional slices mirror the ones used to build the label splits.
history = mod.fit(scaled_in_data[valid_index:split_index].values, scaled_train_labels.values, 
                  batch_size=16, epochs=num_epochs, verbose=1,
                  validation_data=(scaled_in_data[:valid_index].values,scaled_val_labels.values))
# NOTE(review): run_num is bumped after each fit, so artifacts saved by later cells
# carry the incremented value (2 after the first fit on a fresh kernel), and
# re-running this cell continues training the same `mod` — confirm this is intended.
run_num+=1

In [None]:
# Plot the per-epoch history returned by mod.fit: loss first, then the 'acc' metric.
# Epochs are numbered from 1 on the x axis.
epochs = np.arange(len(history.history['acc']))+1

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(epochs,history.history['loss'],'bo-',alpha=0.5,label='Training')
ax.plot(epochs,history.history['val_loss'],'rs-',alpha=0.5,label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
# Pass the flag positionally: the `b=` keyword was renamed to `visible=` in
# Matplotlib 3.5 and later removed, while the positional form works on all versions.
ax.grid(True)
ax.legend()
fig.savefig("results/LossHistory_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs_run{run_num}"+".png", dpi=200, bbox_inches="tight")

fig, bx = plt.subplots(1, 1, figsize=(8, 4))
bx.plot(epochs,history.history['acc'],'bo-',alpha=0.5,label='Training')
bx.plot(epochs,history.history['val_acc'],'rs-',alpha=0.5,label='Validation')
bx.set_xlabel('Epoch')
bx.set_ylabel('Accuracy')
bx.grid(True)
bx.legend()
fig.savefig("results/AccuracyHistory_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs_run{run_num}"+".png", dpi=200, bbox_inches="tight")


In [None]:
# can skip the training process and just load the CNN model
#mod = load_model("holodec_ft_cnn.h5")

In [None]:
# Run the network over every hologram in scaled_in_data and report wall-clock cost.
cnn_start = datetime.datetime.now()
preds_out = mod.predict(scaled_in_data.values, batch_size=64)
cnn_stop = datetime.datetime.now()

elapsed_seconds = (cnn_stop-cnn_start).total_seconds()
n_holograms = scaled_in_data.values.shape[0]
print(f"{n_holograms} samples in {elapsed_seconds} seconds")
print(f"for {elapsed_seconds/n_holograms} seconds per hologram")

In [None]:
# Save the trained network as HDF5 in a models/ folder under the dataset directory.
# os.path.join avoids the doubled separator that plain concatenation produced
# (ds_path already ends with "/"), and the directory is created if it is missing
# so the write does not fail on a fresh data directory.
model_dir = os.path.join(ds_path, "models")
os.makedirs(model_dir, exist_ok=True)
model_file = "holodec_histogram_"+nn_descript+'_'+ds_file.replace(".nc","")+f"{num_epochs}epochs_run{run_num}"+".h5"
save_model(mod, os.path.join(model_dir, model_file), save_format="h5")

In [None]:
# Shape of the raw network output; it must match all_labels for the DataArray wrap that follows.
preds_out.shape

In [None]:
# Wrap the raw predictions with the label dims/coords so they can be selected by
# particle_property and hologram_number, then map them back through the scaler.
# NOTE(review): inverse_transform uses whatever state `scaler` holds at this point;
# if fit_transform re-fits, that is the fit on all_labels — confirm in ml_utils.
preds_out_da = xr.DataArray(preds_out,dims=all_labels.dims,
                            coords=all_labels.coords)
preds_original = scaler.inverse_transform(preds_out_da)

In [None]:
# Residual of the inverse-transformed predictions on the test holograms, summarized
# per particle_property by reducing over holograms and particles.
test_residual = preds_original[split_index:] - test_labels
reduce_dims = ('hologram_number','particle_number')
mean_error = test_residual.mean(dim=reduce_dims)
std_error = test_residual.std(dim=reduce_dims)

In [None]:
# Scatter predicted vs. true values on the test holograms, one panel per particle_property,
# annotated with the mean +/- std of the residual computed in the previous cell.
validation_data = {}
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax=axes.ravel()[a]
    truth = test_labels.sel(particle_property=clabel)
    predicted = preds_original.sel(particle_property=clabel,hologram_number=slice(split_index,None))
    ax.scatter(truth, predicted, 1, 'k')
    # Dashed 1:1 reference line spanning the range of the true values.
    diag = np.linspace(truth.min(), truth.max(), 10)
    ax.plot(diag, diag, 'b--' )
    ax.set_title(clabel)
    # ax.text (not plt.text) so the annotation belongs to this panel rather than to
    # whichever axes is current; raw f-string so \pm and \mu reach mathtext without
    # relying on invalid Python escape sequences.
    ax.text(0.1,0.9,rf"${mean_error.sel(particle_property=clabel).values:.1f} \pm {std_error.sel(particle_property=clabel).values:.1f} \mu m$",ha='left',va='top',transform=ax.transAxes)
#     validation_data[test_labels.columns[a]] = test_labels.iloc[:, a]
#     validation_data[test_labels.columns[a]+'_pred'] = preds_original[split_index:, a]
fig.savefig("results/error_scatter_"+nn_descript+'_'+ds_file.replace(".nc","")+f"{num_epochs}epochs_run{run_num}"+".png", dpi=200, bbox_inches="tight")
# validation_data_df = pd.DataFrame(validation_data)
# validation_data_df.to_csv('results/validation_data_denseNN_MultiIn_'+ds_file.replace(".nc","_")+''.join(all_labels.columns)+'.txt')
    

In [None]:
# Histogram of the test residual for each particle_property, expressed as a
# fraction of that property's range (max - min) over the test labels.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax=axes.ravel()[a]
    truth = test_labels.sel(particle_property=clabel).values
    predicted = preds_original.sel(particle_property=clabel,hologram_number=slice(split_index,None)).values
    label_span = truth.max() - truth.min()
    relative_error = (predicted - truth).flatten() / label_span
    ax.hist(relative_error, bins=20)
    ax.set_yscale("log")  # log counts so the sparse tails stay visible
    ax.set_xlabel("Error in "+clabel)
plt.savefig("results/relative_error_histogram"+nn_descript+'_'+ds_file.replace(".nc","")+f"{num_epochs}epochs_run{run_num}"+".png", dpi=200, bbox_inches="tight")

In [None]:

# Joint distribution of true (x) vs. predicted (y) values on the test holograms,
# one panel per particle_property, drawn as a normalized 100x100 2-D histogram.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax=axes.ravel()[a]
    truth = test_labels.sel(particle_property=clabel)
    predicted = preds_original.sel(particle_property=clabel,hologram_number=slice(split_index,None))
    hprop = np.histogram2d(truth.values.flatten(),predicted.values.flatten(),bins=100)
    counts, truth_edges, pred_edges = hprop
    total = np.sum(counts)
    # histogram2d indexes counts as [x, y]; pcolor wants [row=y, col=x], hence the transpose.
    im = ax.pcolor(truth_edges, pred_edges, counts.T/total)
    diag = np.linspace(truth.min(), truth.max(), 10)
    #ax.plot(diag, diag, 'w--' )
    # Cap the color scale at the second-largest bin so one dominant bin does not wash out the rest.
    second_largest = sorted((counts.flatten()/total))[-2]
    im.set_clim([0,second_largest])
    ax.set_title(clabel)

plt.savefig("results/histogram2D_"+nn_descript+'_'+ds_file.replace(".nc","")+f"{num_epochs}epochs_run{run_num}"+".png", dpi=200, bbox_inches="tight")

In [None]:
# Show the scaled test labels and the raw network output for the test holograms as images.
# NOTE(review): imshow expects a 2-D (or RGB/RGBA) array, and the axis labels refer to a
# histogram index. With the 3-D particle_data labels built earlier this cell looks like a
# leftover from a histogram-label variant of the notebook — confirm it still applies.
fig,axes = plt.subplots(1,1,figsize=(12,4))
axes.imshow(scaled_test_labels.values.T)
axes.set_xlabel('Hologram Number')
axes.set_ylabel('Histogram Index')
axes.set_title('Test Labels')
plt.savefig("results/Label_Histogram_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs"+".png", dpi=200, bbox_inches="tight")

# Same view for the network output; the color scale is clipped to [0, 0.3].
fig,axes = plt.subplots(1,1,figsize=(12,4))
axes.imshow(preds_out[split_index:,:].T,vmin=0,vmax=0.3)
axes.set_xlabel('Hologram Number')
axes.set_ylabel('Histogram Index')
axes.set_title('Test Output')
plt.savefig("results/Output_Histogram_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs"+".png", dpi=200, bbox_inches="tight")


In [None]:
# Overlay the scaled label and the raw network output for a few test holograms.
# NOTE(review): this cell reads ds['particle_histogram'] and treats labels/predictions as
# one row per hologram; it looks like it targets a histogram-label dataset — confirm it
# applies to the dataset loaded in this notebook.
hindices = [10,234,500,1293]  # positions within the test split
fig,axes = plt.subplots(2,2,figsize=(8,8))
axes=axes.ravel()
bin_centers = ds['particle_histogram'].z_bin_centers.values
for hi,holo_index in enumerate(hindices):
    # Index through .values (as the other cells do) instead of .iloc, which only exists
    # on pandas objects and raises AttributeError on an xarray DataArray.
    axes[hi].plot(bin_centers,scaled_test_labels.values[holo_index,:],label=f'Test Label {holo_index}')
    # Predictions cover every hologram, so offset by split_index to reach the test split.
    axes[hi].plot(bin_centers,preds_out[split_index+holo_index,0:],label=f'Test Output {holo_index}')
    # Raw string so \mu reaches mathtext without an invalid Python escape sequence.
    axes[hi].set_xlabel(r'Particle Diameter [$\mu m$]')

    #axes[hi].plot(test_labels.iloc[holo_index,:],label=f'Test Label {holo_index}')
    #axes[hi].plot(preds_out[split_index+holo_index,0:],label=f'Test Output {holo_index}')

    axes[hi].set_ylabel('Probability')

    #axes[hi].set_title(f'Hologram {holo_index}')
    # Positional flag: the `b=` keyword was renamed to `visible=` in Matplotlib 3.5 and later removed.
    axes[hi].grid(True)
    axes[hi].legend()

fig.savefig("results/Example_Hists_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs"+".png", dpi=200, bbox_inches="tight")

In [None]:
def _normalized_moment(weights, bin_centers, order):
    """Return, for each row of `weights`, sum(w * c**order) / sum(w) taken along axis 1.

    weights     -- 2-D array, one row per hologram
    bin_centers -- 1-D array of bin-center values, broadcast across rows
    order       -- integer power applied to the bin centers
    """
    return np.sum(weights*bin_centers[np.newaxis,:]**order,axis=1)/np.sum(weights,axis=1)


# Compare the first four normalized moments of the label and predicted distributions
# on the test split.
# NOTE(review): the axis=1 reductions assume one row of bin weights per hologram, i.e. a
# histogram-label dataset — confirm this cell applies to the dataset loaded above.
bin_centers = ds['particle_histogram'].z_bin_centers.values
label_weights = scaled_test_labels.values
pred_weights = preds_out[split_index:,:]

part_indices = np.nonzero(label_weights)
pred_values = pred_weights[part_indices]

# Index k holds the moment of order k+1.
label_moments = [_normalized_moment(label_weights, bin_centers, order) for order in range(1, 5)]
pred_moments = [_normalized_moment(pred_weights, bin_centers, order) for order in range(1, 5)]
std_pred = np.sqrt(pred_moments[1]-pred_moments[0]**2)

# NOTE: these rebind mean_error/std_error (xarray objects in the earlier error cell) to lists.
mean_error = []
std_error = []
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ai,ax in enumerate(axes.ravel()):
    # Take the (ai+1)-th root so every panel is in the units of the bin centers.
    root = 1/(1.0+ai)
    moment_diff = pred_moments[ai]-label_moments[ai]
    if ai == 0:
        # First moment: keep the sign of the bias.
        mean_error.append(np.mean(moment_diff))
    else:
        # Higher moments: the root needs a non-negative base, so use the magnitude.
        mean_error.append(np.abs(np.mean(moment_diff))**root)
    std_error.append(np.std(moment_diff)**root)
    ax.scatter(label_moments[ai]**root,pred_moments[ai]**root, 1, 'k')
    # Dashed 1:1 reference line over the range of the label moments.
    diag = np.linspace(label_moments[ai].min()**root, label_moments[ai].max()**root, 10)
    ax.plot(diag, diag, 'b--' )
    # ax.text (not plt.text) so the annotation belongs to this panel; raw f-string so
    # \pm and \mu reach mathtext without invalid Python escape sequences.
    ax.text(0.1,0.9,rf"${mean_error[ai]:.1f} \pm {std_error[ai]:.1f} \mu m$",ha='left',va='top',transform=ax.transAxes)
    # Positional flag: the `b=` keyword was renamed to `visible=` in Matplotlib 3.5 and later removed.
    ax.grid(True)
    ax.set_title('moment %d'%(ai+1))
fig.savefig("results/Moment_Scatter_"+nn_descript+'_'+ds_file.replace(".nc","")+f"_{num_epochs}epochs"+".png", dpi=200, bbox_inches="tight")


In [None]:
# Mean signed difference between predicted and label first moments (bias).
np.mean(pred_moments[0]-label_moments[0])

In [None]:
# Mean squared difference between predicted and label first moments.
np.mean((pred_moments[0]-label_moments[0])**2)

In [None]:
# Average variance of the predicted distributions: second moment minus squared first moment.
np.mean(pred_moments[1]-pred_moments[0]**2)