In [None]:
import sys
import numpy as np
import xarray as xr

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D, SeparableConv2D, UpSampling2D, concatenate, Conv2DTranspose, Lambda, Reshape, Layer
# from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model, save_model, load_model
from tensorflow.keras.utils import plot_model
import tensorflow.keras.backend as K
import tensorflow.keras.metrics

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import datetime
%matplotlib inline

# Make the shared project library folder importable; the membership check
# keeps the entry from being appended again when the cell is re-run.
dirP_str = '../../../library'
already_on_path = dirP_str in sys.path
if not already_on_path:
    sys.path.append(dirP_str)

import ml_utils as ml
import ml_defs as mldef

In [None]:
# --- Run configuration ------------------------------------------------------
# Directory and file name of the hologram dataset opened with xarray below.
paths = dict(data='/scr/sci/mhayman/holodec/holodec-ml-data/')
fn = 'synthetic_holograms_v02.nc'

# Run counter and epoch count (both are embedded in output file names later on;
# num_epochs is also the `epochs` argument of the supervised fit).
run_num, num_epochs = 0, 150

# Name of the dataset variable that is fed to the network.
input_variable = 'image'

# Fractions of the `hologram_number` axis that locate the split points:
#   [0, valid_fraction)              -> validation
#   [valid_fraction, split_fraction) -> training
#   [split_fraction, 1]              -> test
split_fraction, valid_fraction = 0.8, 0.2

# Divisor applied to the raw image values before they reach the network.
image_rescale = 255

In [None]:
# VAE model hyper-parameters, read by the model-building cell below.
settings = dict(
    # Number of conv/conv/max-pool stages in the encoder (the decoder mirrors it).
    # NOTE(review): the previous comment said this "must be greater than 1
    # because the latent layer counts", yet the value is 1 - confirm the
    # intended constraint.
    n_layers=1,
    # Filter count of the first convolution stage (doubled after every stage).
    n_filters=2,
    # Convolution kernel size (square kernel).
    nConv=4,
    # Max-pool size, also used as the stride of the transposed convolutions.
    nPool=4,
    # Convolution activation.
    activation='relu',
    kernel_initializer="he_normal",
    # Width of the dense layers that produce z_mean / z_log_var.
    latent_dim=32,
    # Number of dense layers in the bottom (latent) part of the network.
    n_dense_layers=2,
    # Training loss function.
    loss_fun='mse',
    # Activation of the final 1x1 output convolution.
    out_act='linear',
)


In [None]:
# Open the hologram dataset (chunked one hologram at a time), report its
# metadata, and build the scaled validation / training / test splits.
with xr.open_dataset(paths['data']+fn,chunks={'hologram_number':1}) as ds:
    print(ds.data_vars)
    print('Training dataset attributes')
    for att in ds.attrs:
        print('  '+att+': '+str(ds.attrs[att]))

    print('   max particle size: %d'%ds['d'].values.max())
    print('   min particle size: %d'%ds['d'].values.min())
    print()

    # Split points along hologram_number. The builtin int() is used because
    # the np.int alias was deprecated in NumPy 1.20 and removed in 1.24.
    # Layout: [0, valid_index) validation, [valid_index, split_index) training,
    # [split_index, end) test.
    split_index = int(split_fraction*ds.sizes['hologram_number'])  # end of training+validation points
    valid_index = int(valid_fraction*ds.sizes['hologram_number'])  # number of validation points

    # Labels are slices of the same variable that is used as input.
    train_labels = ds[input_variable].isel(hologram_number=slice(valid_index,split_index))
    test_labels = ds[input_variable].isel(hologram_number=slice(split_index,None))
    val_labels = ds[input_variable].isel(hologram_number=slice(None,valid_index))

    scaled_train_labels = train_labels/image_rescale
    scaled_val_labels = val_labels/image_rescale
    scaled_test_labels = test_labels/image_rescale

    # Alternative min/max scaling, currently disabled.
    # NOTE(review): later cells still reference `scaler` and `all_labels`,
    # which are only defined by these commented-out lines.
#     scaler = ml.MinMaxScalerX(train_labels,dim=('hologram_number','xsize','ysize'))
#     scaled_train_labels = scaler.fit_transform(train_labels)
#     scaled_val_labels = scaler.fit_transform(val_labels)
#     scaled_test_labels = scaler.fit_transform(test_labels)
#     scaled_all_labels = scaler.fit_transform(all_labels)

    # setup the input to be used
    scaled_in_data = ds[input_variable]/image_rescale
    print('\ninput dimensions:')
    print(scaled_in_data.dims)
    print(scaled_in_data.shape)
    print()
    print('split index: %d'%split_index)
    print('valid index: %d'%valid_index)

    # The convolution layers need a trailing channel axis; add one (at axis 3)
    # when the variable does not already carry a 'channel' dimension.
    if 'channel' not in scaled_in_data.dims:
        scaled_in_data = scaled_in_data.expand_dims("channel", 3)
    scaled_in_train = scaled_in_data.isel(hologram_number=slice(valid_index,split_index))
    scaled_in_valid = scaled_in_data.isel(hologram_number=slice(None,valid_index))
    scaled_in_test = scaled_in_data.isel(hologram_number=slice(split_index,None))
        

In [None]:
# Training loss name taken from the VAE settings dictionary.
loss_fun = settings['loss_fun']

In [None]:
# Build the convolutional VAE graph:
#   encoder (conv/conv/pool stages) -> dense z_mean / z_log_var branches
#   -> sampled latent z -> dense + reshape -> decoder (transposed conv stages)
#   -> 1x1 output convolution.
n_filters = settings['n_filters']
nConv = settings['nConv']
nPool = settings['nPool']
# The convolution activation now honours the configured value instead of a
# hard-coded "relu" (the current setting is 'relu', so the graph is unchanged).
conv_activation = settings['activation']
kernel_initializer = settings['kernel_initializer']
latent_dim = settings['latent_dim']
n_dense_layers = settings['n_dense_layers']

# create the model; the input shape is the data shape without its leading axis
input_node = Input(shape=scaled_in_data.shape[1:])
next_input_node = input_node
for _ in range(settings['n_layers']):
    # define the down sampling layer
    conv_1d = Conv2D(n_filters, (nConv, nConv), padding="same", kernel_initializer = kernel_initializer)(next_input_node)
    act_1d = Activation(conv_activation)(conv_1d)
    conv_2d = Conv2D(n_filters, (nConv, nConv), padding="same", kernel_initializer = kernel_initializer)(act_1d)
    act_2d = Activation(conv_activation)(conv_2d)
    next_input_node = MaxPool2D(pool_size=(nPool, nPool))(act_2d)
    n_filters = n_filters*2

# The loop doubles once more after the last stage; undo that so n_filters
# matches the filter count of the deepest encoder stage.
n_filters = n_filters//2

# Remember the pre-flatten shape so the decoder can restore it.
input_shape = K.int_shape(next_input_node)
zinput = Flatten()(next_input_node)
z_mean = Dense(latent_dim,activation='relu')(zinput)
z_log_var = Dense(latent_dim,activation='relu')(zinput)
for _ in range(max(n_dense_layers-2,0)):
    # represent the mean and variance branches
    # with separate dense networks
    z_mean = Dense(latent_dim,activation='relu')(z_mean)
    z_log_var = Dense(latent_dim,activation='relu')(z_log_var)
# Final linear layers so mean and log-variance are not clipped by ReLU.
z_mean = Dense(latent_dim,activation='linear')(z_mean)
z_log_var = Dense(latent_dim,activation='linear')(z_log_var)

# Sampling is delegated to the project helper in ml_defs
# (presumably the reparameterisation trick - verify in ml_defs.vae_sample).
z = Lambda(mldef.vae_sample)([z_mean,z_log_var])

# Expand the latent vector back to the flattened encoder size (cast to a plain
# int for the Dense `units` argument) and restore the spatial layout.
x1 = Dense(int(np.prod(input_shape[1:])),activation='relu')(z)
return_node = Reshape(input_shape[1:])(x1)

# # add VAE with convolution inputs and outputs
# unet_out = mldef.add_unet_vae(cnn_input,settings['n_layers'],settings['n_filters'],nConv=settings['nConv'],
#             nPool=settings['nPool'],activation=settings['activation'],
#             kernel_initializer = settings['kernel_initializer'],
#             latent_dim=settings['latent_dim'],n_dense_layers=settings['n_dense_layers'],)

for _ in range(max(settings['n_layers'],0)):
    # define the up sampling and feed forward layer; the transposed convolution
    # uses the pool size as stride to undo one max-pool stage
    upsamp_1u = Conv2DTranspose(n_filters, (nConv,nConv), strides=(nPool,nPool),padding="same")(return_node)
    conv_1u = Conv2D(n_filters,(nConv,nConv),padding="same",kernel_initializer = kernel_initializer)(upsamp_1u)
    act_1u = Activation(conv_activation)(conv_1u)
    conv_2u = Conv2D(n_filters,(nConv,nConv),padding="same",kernel_initializer = kernel_initializer)(act_1u)
    return_node = Activation(conv_activation)(conv_2u)
    n_filters = n_filters // 2

# add the output layer (single channel, 1x1 kernel)
z_decoded = Conv2D(1,(1,1),padding="same",activation=settings['out_act'])(return_node)



In [None]:
# Display the encoder output shape that was recorded just before flattening.
input_shape

In [None]:
# Display the shape of the decoder output tensor.
z_decoded.shape

In [None]:
class CustomVariationalLayer(Layer):
    """Pass-through layer that attaches the VAE training loss to the model.

    Called with ``[x, z_decoded, z_mean, z_log_var]``. It registers the loss
    through ``add_loss`` and returns ``z_decoded`` unchanged, so the model can
    be compiled with ``loss=None``.

    Parameters
    ----------
    beta : float, optional
        Weight of the KL term relative to the reconstruction error.
        Defaults to 5e-4, the value that used to be hard-coded.
    **kwargs
        Forwarded to ``Layer.__init__``.
    """

    def __init__(self, beta=5e-4, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta

    def vae_loss(self,x,z_decoded,z_mean,z_log_var):
        """Return the scalar loss: per-sample MSE plus beta-weighted KL term,
        averaged over the batch."""
        # Reconstruction error averaged over axes 1-3 of each sample.
        # (A preceding tensorflow.keras.metrics.mse call was removed: its
        # result was overwritten by this line before being used.)
        x_mse_loss = K.mean(K.square(x-z_decoded),axis=[1,2,3])
        kl_loss = -self.beta*K.mean(1+z_log_var-K.square(z_mean) - K.exp(z_log_var),axis=-1)
        return K.mean(x_mse_loss+ kl_loss)

    def call(self,inputs):
        """Register the VAE loss for ``inputs`` and return the reconstruction."""
        x = inputs[0]
        z_decoded = inputs[1]
        z_mean = inputs[2]
        z_log_var = inputs[3]
        loss = self.vae_loss(x,z_decoded,z_mean,z_log_var)
        self.add_loss(loss,inputs=inputs)
        return z_decoded

    def get_config(self):
        """Include ``beta`` so the layer can be serialised and re-created."""
        config = super().get_config()
        config['beta'] = self.beta
        return config

# Wrap the decoder output so that building the model also wires in the loss.
y = CustomVariationalLayer()([input_node,z_decoded,z_mean,z_log_var])

In [None]:
# Assemble the VAE from the input node to the loss-carrying output layer.
# No loss is passed to compile() because the output layer registers it itself
# through add_loss.
mod = Model(inputs=input_node, outputs=y)
mod.compile(loss=None, optimizer="adam", metrics=['acc'])
mod.summary()

In [None]:
# Render the layer graph of `mod`, annotated with tensor shapes.
plot_model(mod,show_shapes=True)

In [None]:
# Train the VAE. No targets are supplied (y=None) because the model carries
# its own loss; the validation tuple follows the same convention.
train_images = scaled_in_train.values
valid_images = scaled_in_valid.values
history = mod.fit(
    train_images,
    y=None,
    batch_size=16,
    epochs=20,
    verbose=1,
    validation_data=(valid_images, None),
)

In [None]:
# Plot training and validation loss against the (1-based) epoch number.
epochs = np.arange(len(history.history['loss']))+1
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(epochs,history.history['loss'],'bo-',alpha=0.5,label='Training')
ax.plot(epochs,history.history['val_loss'],'rs-',alpha=0.5,label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
# Positional form works on every Matplotlib version; the `b=` keyword was
# renamed to `visible` in 3.5 and later removed.
ax.grid(True)
plt.legend()


In [None]:
# Materialise the test inputs once. The dataset was opened with `chunks=`, so
# every `.values` access evaluates the lazy array again; the original did so
# three times, twice only to read the sample count.
test_inputs = scaled_in_test.values
n_test = test_inputs.shape[0]

# Time the prediction step only.
cnn_start = datetime.datetime.now()
preds_out = mod.predict(test_inputs, batch_size=64)
cnn_stop = datetime.datetime.now()
elapsed_seconds = (cnn_stop-cnn_start).total_seconds()
print(f"{n_test} samples in {elapsed_seconds} seconds")
print(f"for {elapsed_seconds/n_test} seconds per hologram")

In [None]:
# Wrap the raw prediction array in a DataArray that reuses the coordinates of
# the test inputs, so it can be indexed by name like the inputs.
pred_dims = ('hologram_number','xsize','ysize','channel')
preds_out_da = xr.DataArray(
    preds_out,
    dims=pred_dims,
    coords=scaled_in_test.coords,
)

In [None]:
# For a few test holograms show input, reconstruction and their difference
# side by side, each panel with its own colour bar.
for im in [19, 206, 432, 543]:
    fig_obj, ax_obj_lst = plt.subplots(1, 3, figsize=(3*7, 4))
    true_image = scaled_in_test.isel(hologram_number=im,channel=0)
    recon_image = preds_out_da.isel(hologram_number=im,channel=0)
    panels = (
        ('True image', true_image),
        ('Reconstructed Image', recon_image),
        ('Difference', true_image-recon_image),
    )
    for ax_obj, (panel_title, panel_data) in zip(ax_obj_lst, panels):
        im_obj = ax_obj.imshow(panel_data)
        plt.colorbar(im_obj, ax=ax_obj)
        ax_obj.set_title(panel_title)
    

In [None]:
# define a UNET for image processing
# NOTE: this cell rebinds `mod`, `nPool`, `nConv` and `loss_fun` that the VAE
# cells above also use.
nFilters = 16
nPool = 2
nConv = 7
nLayers = 3
loss_fun = "mse" #,"mae" #"binary_crossentropy"
out_act = "linear" # "sigmoid"
nn_descript = f'UNET_{nFilters}Filt_{nConv}Conv_{nPool}Pool_{nLayers}Layers_'+loss_fun+'_'+out_act

# define the input based on input data dimensions
cnn_input = Input(shape=scaled_in_data.shape[1:])

# create the unet
# NOTE(review): `add_conv_layers` is not defined or imported anywhere in this
# notebook - confirm where it lives (possibly a helper in ml_defs).
unet_out = add_conv_layers(cnn_input,nLayers,nFilters,nConv=nConv,nPool=nPool,activation="relu")

# add the output layer, one channel per entry of the labels' 'type' dimension
# NOTE(review): assumes the labels have a 'type' dimension - confirm for the
# configured input_variable.
out = Conv2D(scaled_train_labels.sizes['type'],(1,1),padding="same",activation=out_act)(unet_out)

# build and compile the model on the output layer; the original passed
# `unet_out`, which left `out` outside the model
mod = Model(cnn_input, out)
mod.compile(optimizer="adam", loss=loss_fun,metrics=['acc'])
mod.summary()
run_num=0

In [None]:
# Render the layer graph of the current `mod`, annotated with tensor shapes.
plot_model(mod,show_shapes=True)

In [None]:
# Hand-written UNET with three pooling stages and skip connections, built from
# separable convolutions.
# NOTE: rebinds `n_filters`, `nPool`, `nConv`, `loss_fun`, `out_act` and
# `nn_descript` that earlier cells also set.
n_filters = 16
nPool = 4
nConv = 5
loss_fun = "mse" #,"mae" #"binary_crossentropy"
out_act = "linear" # "sigmoid"
nn_descript = f'UNET_{n_filters}Filt_{nConv}Conv_{nPool}Pool_'+loss_fun+'_'+out_act
cnn_input2 = Input(shape=scaled_in_data.shape[1:])  # input

# --- encoder stage 1 ---
# The first layer is attached to this cell's own input; the original used
# `cnn_input` from another cell, which disconnected `cnn_input2` from `out2`.
conv_1a = SeparableConv2D(n_filters*1, (nConv, nConv), padding="same", kernel_initializer = "he_normal")(cnn_input2)
act_1a = Activation("relu")(conv_1a)
conv_1b = SeparableConv2D(n_filters*1, (nConv, nConv), padding="same", kernel_initializer = "he_normal")(act_1a)
act_1b = Activation("relu")(conv_1b)
pool_1 = MaxPool2D(pool_size=(nPool, nPool))(act_1b)

# --- encoder stage 2 ---
conv_2a = SeparableConv2D(n_filters*2,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(pool_1)
act_2a = Activation("relu")(conv_2a)
conv_2b = SeparableConv2D(n_filters*2,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(act_2a)
act_2b = Activation("relu")(conv_2b)
pool_2 = MaxPool2D(pool_size=(nPool, nPool))(act_2b)

# --- encoder stage 3 ---
conv_3a = SeparableConv2D(n_filters*4,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(pool_2)
act_3a = Activation("relu")(conv_3a)
conv_3b = SeparableConv2D(n_filters*4,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(act_3a)
act_3b = Activation("relu")(conv_3b)
pool_3 = MaxPool2D(pool_size=(nPool, nPool))(act_3b)

# --- bottom of the U ---
conv_4a = SeparableConv2D(n_filters*8,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(pool_3)
act_4a = Activation("relu")(conv_4a)

conv_4b = SeparableConv2D(n_filters*8,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(act_4a)
act_4b = Activation("relu")(conv_4b)

# --- decoder stage matching encoder stage 3 (skip connection from act_3b) ---
# upsamp_5 = UpSampling2D((4,4))(act_4b)
upsamp_5 = Conv2DTranspose(n_filters*4, (nConv,nConv), strides=(nPool,nPool),padding="same")(act_4b)
concat_5 = concatenate([upsamp_5,act_3b],axis=3)
conv_5a = SeparableConv2D(n_filters*4,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(concat_5)
act_5a = Activation("relu")(conv_5a)
conv_5b = SeparableConv2D(n_filters*4,(nConv,nConv),padding="same", kernel_initializer = "he_normal")(act_5a)
act_5b = Activation("relu")(conv_5b)

# --- decoder stage matching encoder stage 2 (skip connection from act_2b) ---
# upsamp_6 = UpSampling2D((4,4))(act_5b)
upsamp_6 = Conv2DTranspose(n_filters*2, (nConv,nConv), strides=(nPool,nPool),padding="same")(act_5b)
concat_6 = concatenate([upsamp_6,act_2b],axis=3)
conv_6a = SeparableConv2D(n_filters*2,(nConv,nConv),padding="same",kernel_initializer = "he_normal")(concat_6)
act_6a = Activation("relu")(conv_6a)
conv_6b = SeparableConv2D(n_filters*2,(nConv,nConv),padding="same",kernel_initializer = "he_normal")(act_6a)
act_6b = Activation("relu")(conv_6b)

# --- decoder stage matching encoder stage 1 (skip connection from act_1b) ---
# upsamp_7 = UpSampling2D((4,4))(act_6b)
upsamp_7 = Conv2DTranspose(n_filters, (nConv,nConv), strides=(nPool,nPool),padding="same")(act_6b)
concat_7 = concatenate([upsamp_7,act_1b],axis=3)
conv_7a = SeparableConv2D(n_filters,(nConv,nConv),padding="same",kernel_initializer = "he_normal")(concat_7)
act_7a = Activation("relu")(conv_7a)
conv_7b = SeparableConv2D(n_filters,(nConv,nConv),padding="same",kernel_initializer = "he_normal")(act_7a)
act_7b = Activation("relu")(conv_7b)

# Output layer, one channel per entry of the labels' 'type' dimension.
# NOTE(review): assumes the labels have a 'type' dimension - confirm for the
# configured input_variable.
out2 = Conv2D(scaled_train_labels.sizes['type'],(1,1),padding="same",activation=out_act)(act_7b)


mod2 = Model(cnn_input2, out2)
mod2.compile(optimizer="adam", loss=loss_fun,metrics=['acc'])
mod2.summary()
run_num=0

In [None]:
# Render the layer graph with tensor shapes.
# NOTE(review): this plots `mod`, not the `mod2` built in the preceding cell -
# confirm which model is intended.
plot_model(mod,show_shapes=True)
# plot_model(mod,show_shapes=True,to_file="results/holodec_"+nn_descript+'_'+ds_file.replace(".nc","")+".png")

In [None]:
# Display the dimension sizes of the scaled training labels.
scaled_train_labels.sizes

In [None]:
# Supervised training of `mod`: inputs are sliced positionally along the first
# axis with the same split indices that were used to build the labels.
fit_inputs = scaled_in_data[valid_index:split_index].values
fit_targets = scaled_train_labels.values
validation_pair = (scaled_in_data[:valid_index].values, scaled_val_labels.values)
history = mod.fit(
    fit_inputs,
    fit_targets,
    batch_size=64,
    epochs=num_epochs,
    verbose=1,
    validation_data=validation_pair,
)
# Count this training run (used in the output file names below).
run_num = run_num + 1

In [None]:
# Plot and save the loss and accuracy histories of the last training run.
epochs = np.arange(len(history.history['loss']))+1
# Shared suffix of both output file names. The dataset file name is `fn`;
# the `ds_file` name used previously is never defined in this notebook.
history_tag = nn_descript+'_'+fn.replace(".nc","")+f"_{num_epochs}epochs_run{run_num}"

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(epochs,history.history['loss'],'bo-',alpha=0.5,label='Training')
ax.plot(epochs,history.history['val_loss'],'rs-',alpha=0.5,label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
# Positional form works on every Matplotlib version; the `b=` keyword was
# renamed to `visible` in 3.5 and later removed.
ax.grid(True)
plt.legend()
plt.savefig("results/LossHistory_"+history_tag+".png", dpi=200, bbox_inches="tight")

fig, bx = plt.subplots(1, 1, figsize=(8, 4))
bx.plot(epochs,history.history['acc'],'bo-',alpha=0.5,label='Training')
bx.plot(epochs,history.history['val_acc'],'rs-',alpha=0.5,label='Validation')
bx.set_xlabel('Epoch')
bx.set_ylabel('Accuracy')
bx.grid(True)
plt.legend()
plt.savefig("results/AccuracyHistory_"+history_tag+".png", dpi=200, bbox_inches="tight")


In [None]:
# Save the trained model in HDF5 format.
# `ds_path` / `ds_file` were never defined in this notebook; the configured
# data directory and file name are used instead.
# NOTE(review): confirm that models belong in a "models" folder under
# paths['data'] and that the folder exists.
save_model(mod, paths['data'].rstrip('/')+"/models/holodec_"+nn_descript+'_'+fn.replace(".nc","")+f"{num_epochs}epochs_run{run_num}"+".h5", save_format="h5")

In [None]:
# can skip the training process and just load the CNN model
# mod = load_model(ds_path+"/models/holodec_UNET_image_data_256x256_5000count30epochs_run1.h5")

In [None]:
# Materialise the full input array once. The dataset was opened with
# `chunks=`, so every `.values` access evaluates the lazy array again; the
# original did so three times, twice only to read the sample count.
all_inputs = scaled_in_data.values
n_all = all_inputs.shape[0]

# Time the prediction step only.
cnn_start = datetime.datetime.now()
preds_out = mod.predict(all_inputs, batch_size=64)
cnn_stop = datetime.datetime.now()
elapsed_seconds = (cnn_stop-cnn_start).total_seconds()
print(f"{n_all} samples in {elapsed_seconds} seconds")
print(f"for {elapsed_seconds/n_all} seconds per hologram")

In [None]:
# Display the shape of the raw prediction array.
preds_out.shape

In [None]:
# Wrap the predictions in a DataArray whose last dimension is named 'type',
# reusing the coordinates of the labels.
# NOTE(review): `all_labels` is not defined by any active code in this
# notebook (it only appears in a commented-out line) - confirm its source.
preds_out_da = xr.DataArray(preds_out,dims=('hologram_number','xsize','ysize','type'),
                            coords=all_labels.coords)

In [None]:
# Map the predictions back to the original label scale.
# NOTE(review): `scaler` is only created in commented-out code in the data
# loading cell, so this fails on a fresh kernel - confirm which scaling applies.
preds_original = scaler.inverse_transform(preds_out_da)

In [None]:
# Scatter predicted against true values, one panel per label type, restricted
# to the pixels whose predicted amplitude exceeds 0.2.
# NOTE(review): `all_labels` and `preds_original` depend on names that are not
# defined by active code in this notebook - confirm their source.
iscatter = np.nonzero(preds_original.sel(type='amplitude').values.flatten() > 0.2)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for a, clabel in enumerate(all_labels.coords['type'].values):
    ax=axes.ravel()[a]
    ax.scatter(all_labels.sel(type=clabel).values.flatten()[iscatter], preds_original.sel(type=clabel).values.flatten()[iscatter], 1, 'k')
    # 1:1 reference line spanning the range of the true values.
    diag = np.linspace(all_labels.sel(type=clabel).min(), all_labels.sel(type=clabel).max(), 10)
    ax.plot(diag, diag, 'b--' )
    ax.set_title(clabel)
# Save once, after every panel is drawn (the original re-saved the same file
# on each loop iteration). The dataset file name is `fn`; `ds_file` was never
# defined in this notebook.
plt.savefig("results/"+nn_descript+"_ScatterPlot"+f"_{num_epochs}epochs_run{run_num}_"+fn.replace(".nc","")+".png",dpi=300)



In [None]:
# Mean and standard deviation of (prediction - label) on the test portion,
# reduced over every dimension listed below.
reduce_dims = ('hologram_number','xsize','ysize')
test_error = preds_original[split_index:] - test_labels
mean_error = test_error.mean(dim=reduce_dims)
std_error = test_error.std(dim=reduce_dims)

In [None]:
# Display the mean prediction error on the test portion.
mean_error

In [None]:
# Display the standard deviation of the prediction error on the test portion.
std_error

In [None]:
# Holograms to visualise (selected through the hologram_number index).
index_list = [18,2854,1247,858,3143,832,4021,3921,222,2431,321]

# Diverging colour map for difference panels; NaN pixels are drawn black.
diff_cmap = plt.get_cmap('seismic')
diff_cmap.set_bad(color='black')

# For each hologram draw a 2x3 grid:
#   row 0: predicted amplitude, true amplitude, amplitude difference
#   row 1: predicted z, true z, z difference (masked where there is no particle)
# NOTE(review): `all_labels` and `preds_original` depend on names that are not
# defined by active code in this notebook - confirm their source.
for ind in index_list:
    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax = ax.ravel()
    amp_pred = preds_original.sel(type='amplitude',hologram_number=ind).values
    amp_true = all_labels.sel(type='amplitude',hologram_number=ind).values
    z_pred = preds_original.sel(type='z',hologram_number=ind).values
    z_true = all_labels.sel(type='z',hologram_number=ind).values
    # Multiplicative mask: NaN where both predicted and true amplitude are
    # below 0.1, otherwise 1.
    nan_mask = np.ones(amp_pred.shape)
    nan_mask[(amp_pred < 0.1) & (amp_true < 0.1)] = np.nan
    ax[0].imshow(amp_pred,vmin=0,vmax=1)
    ax[1].imshow(amp_true,vmin=0,vmax=1)
    ax[2].imshow(amp_pred-amp_true,vmin=-1,vmax=1,cmap=diff_cmap)
    ax[3].imshow(z_pred*nan_mask,vmin=ds.attrs['zmin'],vmax=ds.attrs['zmax'])
    # `vmax=` was missing here, which made the whole cell a SyntaxError
    # (positional argument after a keyword argument).
    ax[4].imshow(z_true*nan_mask,vmin=ds.attrs['zmin'],vmax=ds.attrs['zmax'])
    ax[5].imshow((z_pred-z_true)*nan_mask,vmin=-1e-2,vmax=1e-2,cmap=diff_cmap)
    # The dataset file name is `fn`; `ds_file` was never defined in this notebook.
    plt.savefig("results/"+nn_descript+f"_ExampleImage{ind}"+f"_{num_epochs}epochs_run{run_num}_"+fn.replace(".nc","")+".png",dpi=300)

In [None]:
# Plot every input channel of the selected holograms on a two-row grid.
# The channel count comes from `scaled_in_data`, the array that is actually
# plotted; the `in_data` name used previously is never defined here.
channel_number = scaled_in_data.sizes['channel']
# Channels fill the grid column by column (two per column), so round the
# column count up and keep at least one column.
n_cols = max((channel_number+1)//2, 1)
# index_list = [18]
for ind in index_list:
    # squeeze=False keeps `ax` two-dimensional even for a single column, so
    # the [row, col] indexing below works for any channel count.
    fig, ax = plt.subplots(2, n_cols, squeeze=False, figsize=(np.minimum(channel_number*3,12), 8))
    for ai in range(channel_number):
        ax[np.mod(ai,2),ai//2].imshow(scaled_in_data.isel(channel=ai,hologram_number=ind),vmin=-2,vmax=2)
    if channel_number % 2:
        # An odd channel count leaves the bottom-right panel empty; hide it.
        ax[1,-1].axis('off')
    # The dataset file name is `fn`; `ds_file` was never defined in this notebook.
    plt.savefig("results/"+nn_descript+f"_ExampleInput{ind}"+f"_{num_epochs}epochs_run{run_num}_"+fn.replace(".nc","")+".png",dpi=300)

In [None]:
# Display the dimension sizes of the input data.
# NOTE(review): `in_data` is not defined anywhere in this notebook; the scaled
# inputs are held in `scaled_in_data` - confirm which one is meant.
in_data.sizes

In [None]:
from mpl_toolkits.mplot3d import Axes3D
# Coordinate grids used to look up the x / y position of selected pixels.
# 'ij' indexing gives both grids the shape (len(xsize), len(ysize)), matching
# the ('xsize', 'ysize') axis order the predictions are created with; the
# default 'xy' indexing yields (len(ysize), len(xsize)) grids, which swaps x
# and y when indexed with (xsize, ysize) pixel indices.
# NOTE(review): assumes `preds_original` keeps that axis order - confirm.
xg,yg = np.meshgrid(preds_original['xsize'].values,preds_original['ysize'].values,indexing='ij')

In [None]:
# 3-D scatter of the pixels whose predicted amplitude exceeds 0.2, positioned
# at (z, x, y) and coloured by amplitude, one figure per selected hologram.
for ind in index_list:
    ipart = np.nonzero(preds_original.sel(type='amplitude',hologram_number=ind).values > 0.2)
    amp_p = preds_original.sel(type='amplitude',hologram_number=ind).values[ipart]
    z_p = preds_original.sel(type='z',hologram_number=ind).values[ipart]
    x_p = xg[ipart]
    y_p = yg[ipart]

    fig = plt.figure(figsize=(10,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(z_p,x_p,y_p,c=amp_p,vmin=0,vmax=1)
    ax.set_xlim([ds.attrs['zmin'],ds.attrs['zmax']])
    # Span the full coordinate range (first to last value); the original used
    # the first two values, which limited each axis to a single grid step.
    ax.set_ylim([preds_original['xsize'].values[0],preds_original['xsize'].values[-1]])
    ax.set_zlim([preds_original['ysize'].values[0],preds_original['ysize'].values[-1]])
    ax.set_xlabel('z')
    ax.set_ylabel('x')
    ax.set_zlabel('y')