In [None]:
import os
import sys

import numpy as np
import xarray as xr
import pandas as pd
import datetime

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D, SeparableConv2D, AveragePooling2D, concatenate,add,Reshape
from tensorflow.keras.models import Model, save_model, load_model
from tensorflow.keras.utils import plot_model
from tensorflow import concat
from tensorflow.keras import regularizers

import matplotlib.pyplot as plt

# set path to local libraries
dirP_str = '../../../library'
if dirP_str not in sys.path:
    sys.path.append(dirP_str)

import ml_utils as ml

%matplotlib inline

run_num = 0

In [None]:
ds_path = "/scr/sci/mhayman/holodec/holodec-ml-data/"

ds_file = "synthetic_holograms_v03_z_real_imag_amplitude_float32_histogram_count.nc" # 3 particle data
data_rescale = 255

In [None]:
num_epochs: int = 15  # training epochs; also embedded in saved figure/model names

In [None]:
run_num: int = 0  # reset the training-run counter used in output file names

In [None]:
# Open the hologram dataset (original note: file with mean (DC) value removed).
dataset_file_path = ds_path + ds_file
ds = xr.open_dataset(dataset_file_path)

In [None]:
# Inspect the count label of one example hologram (label-based selection on hologram_number).
ds['particle_count'].loc[{'hologram_number': 80}]

In [None]:
# Partition along hologram_number:
#   [0, valid_index)            -> validation
#   [valid_index, split_index)  -> training
#   [split_index, end)          -> test
split_index = 7000  # number of training+validation points
valid_index = 2000  # number of validation points
all_labels = ds['particle_count'].transpose('hologram_number', 'count')

val_labels = all_labels.isel(hologram_number=slice(None, valid_index))
train_labels = all_labels.isel(hologram_number=slice(valid_index, split_index))
test_labels = all_labels.isel(hologram_number=slice(split_index, None))

# The scaler is constructed from the training labels only.
# NOTE(review): fit_transform is applied to every split; this is only correct if
# ml.MinMaxScalerX.fit_transform reuses the range captured at construction rather than
# refitting on its argument — confirm against ml_utils.
scaler = ml.MinMaxScalerX(train_labels, dim=('hologram_number', 'count'))
scaled_train_labels, scaled_val_labels, scaled_test_labels, scaled_all_labels = (
    scaler.fit_transform(split_labels)
    for split_labels in (train_labels, val_labels, test_labels, all_labels)
)

In [None]:
# Display the scaled training labels for a quick sanity check of the scaler output.
scaled_train_labels

In [None]:
# Label array shape, ordered (hologram_number, count) by the transpose in the split cell.
all_labels.shape


In [None]:
# Disabled alternative loader for the Fourier-transformed images ("image_ft"): adds a single
# channel axis when channel_name is None, otherwise uses the dataset's own channel dimension.
# NOTE(review): channel_name is not defined anywhere in this notebook; define it before enabling.
# if channel_name is None:
#     in_data = ds["image_ft"].transpose("hologram_number", "ysize", 'xsize').expand_dims("channel", 3)
# else:
#     in_data = ds["image_ft"].transpose("hologram_number", "ysize", "xsize",channel_name)

In [None]:
in_data = ds["image"].transpose("hologram_number", "ysize", "xsize").expand_dims("channel", 3)

In [None]:
# Divide raw image values by data_rescale (255).
# Presumably the images are 8-bit, so this maps them to roughly [0, 1] — confirm against the dataset.
scaled_in_data = in_data / data_rescale

In [None]:
# Input tensor shape: (hologram_number, ysize, xsize, channel).
scaled_in_data.shape

In [None]:
# One (ysize, xsize, channel) hologram image per sample.
cnn_input = Input(shape=scaled_in_data.shape[1:])

nn_descript = '3CNN_3Dense_softmax'

# Three conv -> ReLU -> 4x4 max-pool stages progressively downsample the image.
conv_1 = Conv2D(32, (5, 5), padding="same")(cnn_input)
act_1 = Activation("relu")(conv_1)
pool_1 = MaxPool2D(pool_size=(4, 4))(act_1)

conv_2 = Conv2D(16, (5, 5), padding="same")(pool_1)
act_2 = Activation("relu")(conv_2)
pool_2 = MaxPool2D(pool_size=(4, 4))(act_2)

conv_3 = Conv2D(32, (5, 5), padding="same")(pool_2)
act_3 = Activation("relu")(conv_3)
pool_3 = MaxPool2D(pool_size=(4, 4))(act_3)

# Optional fourth conv stage (disabled):
# conv_4 = Conv2D(64, (15, 15), padding="same")(pool_3)
# act_4 = Activation("relu")(conv_4)
# pool_4 = MaxPool2D(pool_size=(4, 4))(act_4)

flat_c = Flatten()(pool_3)
dense_1 = Dense(32, activation="relu")(flat_c)
dense_2 = Dense(16, activation="relu")(dense_1)
# Softmax output: one probability per count bin of the label vector.
dense_3 = Dense(np.prod(all_labels.shape[1:]), activation="softmax")(dense_2)
mod = Model(cnn_input, dense_3)
# A softmax over mutually exclusive count bins pairs with categorical cross-entropy.
# Binary cross-entropy would score every bin as an independent yes/no outcome, which does
# not match the single-class (argmax) interpretation used in the evaluation cells.
mod.compile(optimizer="adam", loss="categorical_crossentropy", metrics=['acc'])
mod.summary()
run_num = 0  # a freshly built model starts its run counter at zero

In [None]:
# Render the network graph (with layer shapes) to a PNG alongside the other result figures.
model_diagram_path = "results/holodec_" + nn_descript + '_' + ds_file.replace(".nc", "") + ".png"
plot_model(mod, show_shapes=True, to_file=model_diagram_path)

In [None]:
# Training images are positions [valid_index, split_index); validation images are [0, valid_index),
# matching the label split.
x_train = scaled_in_data[valid_index:split_index].values
y_train = scaled_train_labels.values
x_val = scaled_in_data[:valid_index].values
y_val = scaled_val_labels.values

history = mod.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=num_epochs,
    verbose=1,
    validation_data=(x_val, y_val),
)
run_num += 1  # tag outputs of this training run

In [None]:
# Training curves: one figure each for loss and accuracy, training vs. validation per epoch.
epochs = np.arange(len(history.history['loss'])) + 1
run_tag = nn_descript + '_' + ds_file.replace(".nc", "") + f"_{num_epochs}epochs_run{run_num}"
for metric, ylabel, file_prefix in (('loss', 'Loss', 'LossHistory'),
                                    ('acc', 'Accuracy', 'AccuracyHistory')):
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.plot(epochs, history.history[metric], 'bo-', alpha=0.5, label='Training')
    ax.plot(epochs, history.history['val_' + metric], 'rs-', alpha=0.5, label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    # The `b=` keyword of Axes.grid was renamed `visible` and later removed;
    # the positional form works on every matplotlib version.
    ax.grid(True)
    ax.legend()
    fig.savefig("results/" + file_prefix + "_" + run_tag + ".png", dpi=200, bbox_inches="tight")


In [None]:
# Optional: skip training and load a previously saved CNN model instead.
# NOTE(review): this file name does not match the one written by the save_model cell
# ("holodec_histogram_..."); update the path before enabling.
# mod = load_model("holodec_ft_dxyz_cnn.h5")

In [None]:
# Predict on every hologram and report wall-clock inference time overall and per hologram.
n_holograms = scaled_in_data.values.shape[0]

cnn_start = datetime.datetime.now()
preds_out = mod.predict(scaled_in_data.values, batch_size=64)
cnn_stop = datetime.datetime.now()

elapsed_s = (cnn_stop - cnn_start).total_seconds()
print(f"{n_holograms} samples in {elapsed_s} seconds")
print(f"for {elapsed_s / n_holograms} seconds per hologram")

In [None]:
# Persist the trained model in HDF5 format under the dataset directory.
# NOTE(review): ds_path already ends in "/", so this path contains "//models", and there is no "_"
# before the epoch count (unlike the figure names); kept as-is so existing loaders still find the file.
model_file = (
    ds_path + "/models/holodec_histogram_" + nn_descript + '_'
    + ds_file.replace(".nc", "") + f"{num_epochs}epochs_run{run_num}" + ".h5"
)
save_model(mod, model_file, save_format="h5")

In [None]:
# Prediction shape: one softmax vector (length = number of count bins) per hologram.
preds_out.shape

In [None]:
# Compare the predicted count probabilities with the label vector for one example hologram
# (positional index into the full, unsplit arrays).
plt_index = 4440
fig, ax = plt.subplots()
ax.plot(preds_out[plt_index, :], label='Predicted')
ax.plot(all_labels.values[plt_index, :], label='Actual')
ax.set_xlabel('Count bin index')
ax.set_ylabel('Probability')
ax.set_title(f'Hologram index {plt_index}')
ax.legend();

In [None]:
# Confusion-style 2D histogram of the most likely count: actual (x) vs. predicted (y).
# Counts are argmax index + 1, following the original convention (presumably index 0 means one
# particle — confirm), so they run 1..n_bins. The bin edges are derived from the number of classes
# so each count gets its own unit-wide cell centred on it. A fixed 0..9 edge list would merge
# counts 8 and 9 and drop anything larger.
n_bins = preds_out.shape[1]
count_edges = np.arange(n_bins + 1) + 0.5
actual_count = np.argmax(all_labels.values, axis=1) + 1
predicted_count = np.argmax(preds_out, axis=1) + 1
hcount = np.histogram2d(actual_count, predicted_count, bins=[count_edges] * 2)

fig, ax = plt.subplots(1, 1, figsize=(10, 8))
im = ax.pcolor(hcount[1], hcount[2], hcount[0].T)
ax.set_xlabel('Actual Count')
ax.set_ylabel('Predicted Count')
plt.colorbar(im, ax=ax)
plt.savefig("results/MostLikelyCountHist_" + nn_descript + '_' + ds_file.replace(".nc", "") + f"_{num_epochs}epochs_run{run_num}" + ".png", dpi=200, bbox_inches="tight")

In [None]:
# Probability the model assigns to the true count bin, per hologram, against the uniform-guess level.
# NOTE(review): the x-axis is hologram number only if each label row has exactly one nonzero entry.
p_actual = preds_out[np.nonzero(all_labels.values)]
chance_level = 1 / preds_out.shape[1]

fig, ax = plt.subplots(1, 1, figsize=(12, 8))
ax.plot(p_actual, '.')
ax.plot([0, preds_out.shape[0]], np.full(2, chance_level), 'k--')
ax.set_xlabel('Hologram Number')
ax.set_ylabel('Probability of Actual')
plt.savefig(
    "results/CorrectProbability_" + nn_descript + '_' + ds_file.replace(".nc", "")
    + f"_{num_epochs}epochs_run{run_num}" + ".png",
    dpi=200,
    bbox_inches="tight",
)

In [None]:
# Distribution of the probability assigned to the true count bin, over all holograms.
p_actual = preds_out[np.nonzero(all_labels.values)]

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.hist(p_actual)
ax.set_xlabel('Probability of Actual')
ax.set_ylabel('Count')
plt.savefig(
    "results/CorrectHistProbability_" + nn_descript + '_' + ds_file.replace(".nc", "")
    + f"_{num_epochs}epochs_run{run_num}" + ".png",
    dpi=200,
    bbox_inches="tight",
)

In [None]:
# Wrap the raw network output in a DataArray labelled exactly like the labels, so it can go
# through the label scaler and be aligned against the label arrays.
# preds_out is 2-D (one vector per hologram), matching all_labels' dims; declaring three dims
# here would raise a dimension-count mismatch.
preds_out_da = xr.DataArray(preds_out, dims=all_labels.dims, coords=all_labels.coords)

In [None]:
# Map the network outputs back to label units using the same scaler applied to the labels
# (assumes inverse_transform is the exact inverse of fit_transform — confirm in ml_utils).
preds_original = scaler.inverse_transform(preds_out_da)

In [None]:
# Overall range of the test labels (max minus min over every element).
test_labels.max() - test_labels.min()

In [None]:
# Display the predictions in label units.
preds_original

In [None]:
# Test-set prediction error (prediction minus label) summarised over holograms.
# The label arrays have dims ('hologram_number', 'count') and no 'particle_number' dim,
# so reduce over holograms only and keep one mean/std value per count bin.
test_error = preds_original.isel(hologram_number=slice(split_index, None)) - test_labels
mean_error = test_error.mean(dim='hologram_number')
std_error = test_error.std(dim='hologram_number')

In [None]:
# Scatter predicted vs. actual test values per label property, annotated with mean ± std error.
# NOTE(review): this cell indexes a 'particle_property' coordinate, but all_labels here is
# ds['particle_count'] with dims ('hologram_number', 'count'); it appears carried over from a
# particle-property regression notebook and will fail until adapted — confirm intent.
validation_data = {}
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax = axes.ravel()[a]
    actual = test_labels.sel(particle_property=clabel)
    predicted = preds_original.sel(particle_property=clabel, hologram_number=slice(split_index, None))
    ax.scatter(actual, predicted, 1, 'k')
    diag = np.linspace(actual.min(), actual.max(), 10)
    ax.plot(diag, diag, 'b--')
    ax.set_title(clabel)
    # Raw f-string: \pm and \mu are LaTeX commands, not (invalid) Python escape sequences.
    # Draw on this panel's axes; plt.text would attach the text to pyplot's current axes instead.
    ax.text(0.1, 0.9,
            rf"${mean_error.sel(particle_property=clabel).values:.1f} \pm {std_error.sel(particle_property=clabel).values:.1f} \mu m$",
            ha='left', va='top', transform=ax.transAxes)
#     validation_data[test_labels.columns[a]] = test_labels.iloc[:, a]
#     validation_data[test_labels.columns[a]+'_pred'] = preds_original[split_index:, a]
plt.savefig("results/error_scatter_" + nn_descript + '_' + ds_file.replace(".nc", "") + f"_{num_epochs}epochs_run{run_num}" + ".png", dpi=200, bbox_inches="tight")
# validation_data_df = pd.DataFrame(validation_data)
# validation_data_df.to_csv('results/validation_data_denseNN_MultiIn_'+ds_file.replace(".nc","_")+''.join(all_labels.columns)+'.txt')

In [None]:
# Mean test error for one property.
# NOTE(review): `clabel` is whatever value was left over from the last loop in the scatter cell
# (hidden state), and this relies on a 'particle_property' coordinate — set clabel explicitly.
mean_error.sel(particle_property=clabel)

In [None]:
# Histogram of test-set prediction errors per label property, normalised by that property's range.
# NOTE(review): relies on a 'particle_property' coordinate that all_labels (dims 'hologram_number',
# 'count') does not have — this cell appears carried over from another notebook; confirm intent.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax = axes.ravel()[a]
    actual = test_labels.sel(particle_property=clabel).values
    predicted = preds_original.sel(particle_property=clabel, hologram_number=slice(split_index, None)).values
    # NOTE: a constant-valued property would make this range zero; not guarded here.
    value_range = actual.max() - actual.min()
    ax.hist((predicted - actual).flatten() / value_range, bins=20)
    ax.set_yscale("log")
    ax.set_xlabel("Error in " + clabel + " (fraction of range)")
    ax.set_ylabel("Count")
plt.savefig("results/relative_error_histogram_" + nn_descript + '_' + ds_file.replace(".nc", "") + f"_{num_epochs}epochs_run{run_num}" + ".png", dpi=200, bbox_inches="tight")

In [None]:

# 2D histogram (normalised to a frequency) of actual vs. predicted values per label property.
# NOTE(review): relies on a 'particle_property' coordinate that all_labels (dims 'hologram_number',
# 'count') does not have — this cell appears carried over from another notebook; confirm intent.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for a, clabel in enumerate(all_labels.coords['particle_property'].values):
    ax = axes.ravel()[a]
    actual = test_labels.sel(particle_property=clabel)
    predicted = preds_original.sel(particle_property=clabel, hologram_number=slice(split_index, None))
    counts, x_edges, y_edges = np.histogram2d(actual.values.flatten(), predicted.values.flatten(), bins=100)
    frequency = counts / np.sum(counts)
    im = ax.pcolor(x_edges, y_edges, frequency.T)
    diag = np.linspace(actual.min(), actual.max(), 10)
    ax.plot(diag, diag, 'w--')
    # Cap the colour scale at the second-largest cell so one dominant cell does not wash out the rest.
    im.set_clim([0, np.sort(frequency.flatten())[-2]])
    ax.set_title(clabel)

plt.savefig("results/histogram2D_" + nn_descript + '_' + ds_file.replace(".nc", "") + f"_{num_epochs}epochs_run{run_num}" + ".png", dpi=200, bbox_inches="tight")