In [1]:
import numpy as np
import xarray as xr
import pandas as pd
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Activation, MaxPool2D
from tensorflow.keras.models import Model, save_model, load_model
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import datetime
%matplotlib inline

In [2]:
ds_path = "/scr/sci/mhayman/holodec/holodec-ml-data/"
# ds_path = "../../"
ds_name = "synthethic_holograms_ft_ac_complex_v0.nc"
# ds_name = "synthethic_holograms_ft_ac_v0.nc"
ds = xr.open_dataset(ds_path+ds_name)  # file with mean (DC) value removed

In [3]:
split_index = 7000
all_labels = ds[["d"]].to_dataframe()
train_labels = all_labels.iloc[:split_index]
test_labels = all_labels.iloc[split_index:]
scaler = MinMaxScaler()
scaled_train_labels = pd.DataFrame(scaler.fit_transform(train_labels), index=train_labels.index, columns=train_labels.columns)
scaled_test_labels = pd.DataFrame(scaler.transform(test_labels), index=test_labels.index, columns=test_labels.columns)

In [4]:
in_data = ds["image_ft"].transpose("hologram_number", "ysize", 'xsize').expand_dims("channel", 3)

ValueError: axes don't match array

In [None]:
ds["image_ft"].dims

In [None]:
ds["image_ft"].max()

In [None]:
scaled_in_data = in_data / 255

In [None]:
conv_input = Input(shape=(400, 600, 1))
conv_1 = Conv2D(8, (5, 5), padding="same")(conv_input)
act_1 = Activation("relu")(conv_1)
pool_1 = MaxPool2D(pool_size=(4, 4))(act_1)
conv_2 = Conv2D(16, (5, 5), padding="same")(pool_1)
act_2 = Activation("relu")(conv_2)
pool_2 = MaxPool2D(pool_size=(4, 4))(act_2)
conv_3 = Conv2D(32, (5, 5), padding="same")(pool_2)
act_3 = Activation("relu")(conv_3)
pool_3 = MaxPool2D(pool_size=(4, 4))(act_3)
flat = Flatten()(pool_3)
dense_1 = Dense(64, activation="relu")(flat)
dense_2 = Dense(32, activation="relu")(dense_1)
out = Dense(all_labels.shape[1])(dense_2)  # number of outputs determined by the parameters we are training to
mod = Model(conv_input, out)
mod.compile(optimizer="adam", loss="mae")
mod.summary()

In [None]:
mod.fit(scaled_in_data[:split_index].values, scaled_train_labels.values, batch_size=16, epochs=30, verbose=1)

In [None]:
# can skip the training process and just load the CNN model
mod = load_model("holodec_ft_cnn.h5")

In [None]:
cnn_start = datetime.datetime.now()
preds_out = mod.predict(scaled_in_data.values, batch_size=64)
cnn_stop = datetime.datetime.now()
print(f"{scaled_in_data.values.shape[0]} samples in {(cnn_stop-cnn_start).total_seconds()} seconds")
print(f"for {(cnn_stop-cnn_start).total_seconds()/scaled_in_data.values.shape[0]} seconds per hologram")

In [None]:
save_model(mod, "holodec_ft_cnn.h5", save_format="h5")

In [None]:
preds_original = scaler.inverse_transform(preds_out)

In [None]:
(test_labels.max() - test_labels.min())

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 8))
axes.scatter(test_labels, preds_original[split_index:], 1, 'k')
diag = np.linspace(test_labels.min(), test_labels.max(), 10)
axes.plot(diag, diag, 'b--' )
axes.set_title(test_labels.columns[0])
plt.savefig("error_hist_fft_"+ds_name.replace(".nc",".png"), dpi=200, bbox_inches="tight")
# for a, ax in enumerate(axes.ravel()):
    

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 8))
axes.hist( (preds_original[split_index:] - test_labels.values)/ (test_labels.values.max() - test_labels.values.min()) ,
           bins=20)
axes.set_yscale("log")
axes.set_xlabel("Error in "+test_labels.columns[0])
plt.savefig("relative_error_histogram_fft_"+ds_name.replace(".nc",".png"), dpi=200, bbox_inches="tight")

In [None]:
np.mean(np.abs(preds_original[split_index:] - test_labels.values), axis=0)

In [None]:
np.mean(np.abs(preds_out[split_index:] - scaled_test_labels.values), axis=0)

In [None]:
test_labels["d"].max()

In [None]:
ds["xsize"]