In [1]:
import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import xarray as xr
import yaml
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.optimizers import Adam

from holodecml.data import load_raw_datasets, load_unet_datasets, load_unet_datasets_xy
from holodecml.losses import unet_loss, unet_loss_xy
from holodecml.models import custom_unet, custom_jnet, custom_jnet_full


In [2]:
# Data location and which dataset variant to load
path_data = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/"
num_particles = "medium"
subset = False
output_cols = ["x", "y", "z", "d", "hid"]

# Preprocessing settings passed to the dataset loaders below
scaler_out = MinMaxScaler()
bin_factor = 10

# Index of the hologram examined in the plots below
h = 0


In [None]:
# Load the raw (unbinned) training split.
train_inputs_raw, train_outputs_raw = load_raw_datasets(
    path_data, num_particles, 'train', output_cols, subset
)


In [None]:
# Load the raw (unbinned) validation split.
valid_inputs_raw, valid_outputs_raw = load_raw_datasets(
    path_data, num_particles, 'valid', output_cols, subset
)


In [None]:
# Count the particle rows whose hid equals h + 1, then show raw hologram h in grayscale.
print(np.count_nonzero(valid_outputs_raw["hid"] == h + 1))
fig, ax = plt.subplots(figsize=(18, 12))
ax.imshow(valid_inputs_raw[h].T, cmap='gray', vmin=0, vmax=255)
ax.set(xticks=[], yticks=[])
plt.show()


In [None]:
def plot_hologram_xy(h, inputs, outputs, x_extent=888, y_extent=592):
    """
    Plot one hologram as a pseudocolor image with its particles' (x, y) positions overlaid.

    Args:
        h: (int) hologram index into ``inputs``. Particles drawn are the rows of
            ``outputs`` whose "hid" equals ``h + 1``.
        inputs: array indexed as ``inputs[h, :, :]`` to give a 2D hologram; its
            first axis is spread along the horizontal plot axis, its second along
            the vertical one.
        outputs: (pd df) particle table with at least "hid", "x" and "y" columns.
        x_extent: half-width of the horizontal axis in µm; image columns are
            spaced evenly over [-x_extent, x_extent]. Default 888.
        y_extent: half-height of the vertical axis in µm. Default 592.

    Returns:
        None. Draws into a new matplotlib figure.
    """
    image = inputs[h, :, :]
    x_vals = np.linspace(-x_extent, x_extent, image.shape[0])
    y_vals = np.linspace(-y_extent, y_extent, image.shape[1])

    # A boolean mask selects rows by label alignment, so it works for any
    # DataFrame index. The earlier np.where(...)[0] positions passed to .loc
    # were only correct when the index was a default RangeIndex.
    in_hologram = outputs["hid"] == h + 1
    particles = outputs.loc[in_hologram, ["x", "y"]]
    n_particles = int(in_hologram.sum())

    plt.figure(figsize=(12, 8))
    plt.pcolormesh(x_vals, y_vals, image.T, cmap="RdBu_r")
    # One vectorized scatter call instead of one call per particle.
    plt.scatter(particles["x"], particles["y"], c="b", s=100)
    plt.xlabel("horizontal particle position (µm)", fontsize=16)
    plt.ylabel("vertical particle position (µm)", fontsize=16)
    plt.title(f"Hologram and particle positions plotted in two dimensions: {n_particles} particles", fontsize=20, pad=20)


In [None]:
plot_hologram_xy(h=h, inputs=valid_inputs_raw, outputs=valid_outputs_raw)


In [None]:
# Load train/valid inputs and their binned U-Net target fields.
(train_inputs, train_outputs,
 valid_inputs, valid_outputs) = load_unet_datasets(
    path_data, num_particles, output_cols, scaler_out, subset, bin_factor
)

In [None]:
# Load the J-Net configuration and pull out run settings.
with open("../../config/jnet_xy.yml") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# NOTE(review): path_data, num_particles and output_cols are reassigned here,
# AFTER the datasets above were already loaded with the earlier settings-cell
# values, so these config values do not affect the loaded data.
path_data = config["path_data"]
path_save = config["path_save"]
os.makedirs(path_save, exist_ok=True)
num_particles = config["num_particles"]
output_cols = config["output_cols"]
seed = config["random_seed"]

# random_seed was previously read but never applied. Seed NumPy and TensorFlow
# before the model is built so that weight initialisation is repeatable.
# Full determinism may still depend on GPU kernels.
np.random.seed(seed)
tf.random.set_seed(seed)


In [None]:
# Build the full J-Net. The input shape is one sample's shape with a trailing
# channel axis of size 1 appended.
model = custom_jnet_full(
    np.expand_dims(train_inputs, axis=-1).shape[1:],
    **config["unet"]
)
# `lr` is a deprecated alias for `learning_rate`; the tf.keras optimizers
# introduced in TF 2.11 (and Keras 3) reject it.
model.compile(optimizer=Adam(learning_rate=config["train"]['learning_rate']),
              loss=unet_loss)
model.summary()


In [None]:
# Train with a trailing channel axis on both training and validation inputs.
train_cfg = config["train"]
x_train = np.expand_dims(train_inputs, axis=-1)
x_valid = np.expand_dims(valid_inputs, axis=-1)
hist = model.fit(
    x_train,
    train_outputs,
    validation_data=(x_valid, valid_outputs),
    batch_size=train_cfg['batch_size'],
    epochs=train_cfg['epochs'],
    verbose=train_cfg["verbose"],
)

In [None]:
# Load saved validation predictions as a NumPy array (the first data variable
# of the file). The context manager closes the netCDF file once `.values` has
# read the data into memory; the original left the handle open.
# NOTE(review): this path points at a separate experiment directory
# (jnet_10_dz), not at path_save, so these predictions may not come from the
# model trained above — confirm.
with xr.open_dataset("/glade/p/cisl/aiml/ggantos/holodec/unet/jnet_10_dz/valid_outputs_pred.nc") as pred_ds:
    valid_outputs_pred = pred_ds.to_array().values[0]


In [None]:
# Predicted (channel 0) and true probability fields for hologram h.
image_pred = valid_outputs_pred[h, :, :, 0]
image_true = valid_outputs[h, :, :, 0]

coords_true = np.where(image_true > 0)

# Heuristic particle count: the number of jumps larger than 1e-4 between
# consecutive sorted predicted values. idx is also reused in a later plot title.
flat_pred = image_pred.flatten()
idx = np.argwhere(np.diff(np.sort(flat_pred)) > .0001) + 1
n_pred = idx.shape[0]

# Flat indices of the n_pred highest-valued pixels, highest first. Slicing the
# reversed argsort avoids pred_argsort[-0:], which returned every pixel when
# n_pred == 0.
pred_argsort = flat_pred.argsort()
top_flat = pred_argsort[::-1][:n_pred]
# unravel_index maps each flat index to its own (row, col). The previous
# np.where(image_pred == value) lookup returned the first pixel with an equal
# value, duplicating coordinates on ties, and rescanned the image per pixel.
coords_pred = np.stack(np.unravel_index(top_flat, image_pred.shape), axis=-1)


In [None]:
# Log of the predicted probability field for hologram h with the true particle
# pixels overlaid; also prints the field's sum, min and max.
prob_pred = valid_outputs_pred[h, :, :, 0]
true_x, true_y = np.where(image_true > 0)

fig = plt.figure(figsize=(12, 8))
# Index with h (previously hard-coded 0) so the field matches the overlay and title.
# np.log yields -inf (with a RuntimeWarning) wherever the value is exactly 0.
plt.pcolormesh(np.log(prob_pred).T, cmap="RdBu_r")
plt.colorbar()
plt.scatter(true_x, true_y, color='blue', s=100, label="True")
print(np.sum(prob_pred))
print(np.min(prob_pred))
print(np.max(prob_pred))
plt.title(f'Log of probability field for validation hologram {h}', fontsize=20)
plt.legend(fontsize=20)
plt.xticks([])
plt.yticks([])
plt.savefig("./prob_field_log.png", dpi=200, bbox_inches="tight")


In [None]:
# Gaps between consecutive sorted predicted values for hologram h.
# Red: every gap. Blue: only gaps above the 1e-4 threshold used to count
# predicted particles, drawn at their actual positions along the sorted axis
# (previously they were re-indexed 0..k-1 and did not line up with the red curve).
sorted_gaps = np.diff(np.sort(valid_outputs_pred[h, :, :, 0].flatten()))
big_gap_pos = np.flatnonzero(sorted_gaps > .0001)
plt.plot(sorted_gaps, color='red')
plt.plot(big_gap_pos, sorted_gaps[big_gap_pos], color='blue', marker='o', linestyle='none')

In [None]:
# Overlay true particle pixels (blue) and the gap-count top predicted pixels
# (red) on input hologram h. The input image is stretched over a 0-60 by 0-40
# grid. NOTE(review): presumably this is the output-field grid size so the
# scatter indices line up — confirm against valid_outputs.shape[1:3].
hologram = valid_inputs[h, :, :]
true_x, true_y = np.where(image_true > 0)

plt.figure(figsize=(12, 8))
x_vals = np.linspace(0, 60, hologram.shape[0])
y_vals = np.linspace(0, 40, hologram.shape[1])
plt.xticks([])
plt.yticks([])
plt.pcolormesh(x_vals, y_vals, hologram.T, cmap="RdBu_r")
plt.scatter(true_x, true_y, color='blue', s=100, label="True", zorder=2)
plt.scatter(coords_pred[:, 0], coords_pred[:, 1], color='red', s=100, label="Predicted", zorder=1)
plt.legend(fontsize=20)
plt.title(f'{int(np.sum(image_true))} True vs Top {idx.shape[0]} Predicted Particles for validation hologram {h}', fontsize=20)
plt.savefig("./true_vs_pred_diff.png", dpi=200, bbox_inches="tight")


In [None]:
# Take as many top-valued predicted pixels as int(sum of the true field),
# presumably the true particle count for hologram h.
n_true = int(np.sum(image_true))
pred_argsort = valid_outputs_pred[h, :, :, 0].flatten().argsort()
# Reversed-argsort slice, highest first; avoids pred_argsort[-0:] returning
# every pixel when n_true == 0.
top_flat = pred_argsort[::-1][:n_true]
# unravel_index gives each pixel its own (row, col). The previous equality
# lookup returned the first pixel with an equal value, duplicating coordinates on ties.
coords_pred = np.stack(np.unravel_index(top_flat, valid_outputs_pred[h, :, :, 0].shape), axis=-1)


In [None]:
# Overlay true particle pixels (blue) and as many top predicted pixels (red)
# on input hologram h. The input image is stretched over a 0-60 by 0-40 grid.
# NOTE(review): confirm these match the output-field grid size.
hologram = valid_inputs[h, :, :]
true_x, true_y = np.where(image_true > 0)
n_true = int(np.sum(image_true))

plt.figure(figsize=(12, 8))
x_vals = np.linspace(0, 60, hologram.shape[0])
y_vals = np.linspace(0, 40, hologram.shape[1])
plt.xticks([])
plt.yticks([])
plt.pcolormesh(x_vals, y_vals, hologram.T, cmap="RdBu_r")
plt.scatter(true_x, true_y, color='blue', s=100, label="True", zorder=2)
plt.scatter(coords_pred[:, 0], coords_pred[:, 1], color='red', s=100, label="Predicted", zorder=1)
plt.legend(fontsize=20)
plt.title(f'{n_true} True vs Top {n_true} Predicted Particles for validation hologram {h}', fontsize=20)
plt.savefig("./true_vs_pred_toptrue.png", dpi=200, bbox_inches="tight")


In [None]:
# Ground-truth probability field (channel 0) for hologram h; display range fixed to [0, 1].
prob_true = valid_outputs[h, :, :, 0]
fig = plt.figure(figsize=(12, 8))
plt.imshow(prob_true.T, interpolation='bilinear', cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
# This is the TRUE field, so the max is a true value (title previously said "predicted").
plt.title(f'True probability field for validation hologram {h}\nSum of non-zero values: {np.sum(prob_true):.2f}\nMax true value: {np.max(prob_true):.2f}', fontsize=20)
plt.savefig("./prob_true.png", dpi=200, bbox_inches="tight")


In [None]:
# Predicted probability field (channel 0) for hologram h; display range fixed to [0, 1].
prob_pred = valid_outputs_pred[h, :, :, 0]
pred_title = (f'Predicted probability field for validation hologram {h}\n'
              f'Sum of non-zero values: {np.sum(prob_pred):.2f}\n'
              f'Max predicted value: {np.max(prob_pred):.2f}')
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(prob_pred.T, interpolation='bilinear', cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
ax.set_title(pred_title, fontsize=20)
fig.savefig("./prob_pred.png", dpi=200, bbox_inches="tight")


In [None]:
# Save the predicted probability channel (channel 0) for every validation hologram.
prob_pred_all = xr.DataArray(valid_outputs_pred[:, :, :, 0])
prob_pred_all.to_netcdf(path='holo_all.nc')

In [None]:
# Predicted z field (channel 1) for hologram h; display range fixed to [0, 1].
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(valid_outputs_pred[h, :, :, 1].T, interpolation='bilinear',
          cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
ax.set_title(f'Predicted Z-coordinate field for validation hologram {h}', fontsize=20)
fig.savefig("./z_pred.png", dpi=200, bbox_inches="tight")


In [None]:
# True z field (channel 1) for hologram h; display range fixed to [0, 1].
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(valid_outputs[h, :, :, 1].T, interpolation='bilinear',
          cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
ax.set_title(f'True Z-coordinate field for validation hologram {h}', fontsize=20)
fig.savefig("./z_true.png", dpi=200, bbox_inches="tight")


In [None]:
# Predicted diameter field (channel 2) for hologram h; display range fixed to [0, 1].
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(valid_outputs_pred[h, :, :, 2].T, interpolation='bilinear',
          cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
ax.set_title(f'Predicted Diameter field for validation hologram {h}', fontsize=20)
fig.savefig("./d_pred.png", dpi=200, bbox_inches="tight")


In [None]:
# True diameter field (channel 2) for hologram h; display range fixed to [0, 1].
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(valid_outputs[h, :, :, 2].T, interpolation='bilinear',
          cmap=plt.cm.gray, aspect='auto', vmin=0, vmax=1)
ax.set_title(f'True Diameter field for validation hologram {h}', fontsize=20)
fig.savefig("./d_true.png", dpi=200, bbox_inches="tight")


In [None]:
# NOTE(review): spot check of one hard-coded pixel (55, 19) in channel 2 (the
# channel plotted as the diameter field); looks like leftover debugging and
# raises IndexError if the field is smaller — confirm or remove.
valid_outputs[h, 55, 19, 2]

In [None]:
# (row, col) index arrays of pixels with a non-zero true diameter in hologram h.
np.nonzero(valid_outputs[h, :, :, 2])


In [None]:
# True vs predicted diameter at pixels where hologram h's true diameter is non-zero.
d_true_h = valid_outputs[h, :, :, 2]
has_particle = np.nonzero(d_true_h)
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(d_true_h[has_particle], valid_outputs_pred[h, :, :, 2][has_particle])
ax.set_title(f'Diameter field for validation hologram {h}', fontsize=20)
fig.savefig("./d_scatter.png", dpi=200, bbox_inches="tight")


In [None]:
# True vs predicted z at pixels where hologram h's true z field is non-zero.
z_true_h = valid_outputs[h, :, :, 1]
has_particle = np.nonzero(z_true_h)
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(z_true_h[has_particle], valid_outputs_pred[h, :, :, 1][has_particle])
ax.set_title(f'Z-coordinate field for validation hologram {h}', fontsize=20)
fig.savefig("./z_scatter.png", dpi=200, bbox_inches="tight")


In [None]:
# True vs predicted z over ALL validation holograms, at pixels where the true
# z field is non-zero. The title previously named a single hologram {h},
# which was wrong for this pooled plot.
z_true_all = valid_outputs[:, :, :, 1]
has_particle = np.nonzero(z_true_all)
fig = plt.figure(figsize=(12, 8))
plt.scatter(z_true_all[has_particle], valid_outputs_pred[:, :, :, 1][has_particle])
plt.xlabel('True Z-coordinate value', fontsize=16)
plt.ylabel('Predicted Z-coordinate value', fontsize=16)
plt.title('Z-coordinate field for all validation holograms', fontsize=20)
plt.savefig("./z_scatter_all.png", dpi=200, bbox_inches="tight")


In [None]:
# True vs predicted diameter over ALL validation holograms, at pixels where the
# true diameter field is non-zero. The title previously named a single
# hologram {h}, which was wrong for this pooled plot.
d_true_all = valid_outputs[:, :, :, 2]
has_particle = np.nonzero(d_true_all)
fig = plt.figure(figsize=(12, 8))
plt.scatter(d_true_all[has_particle], valid_outputs_pred[:, :, :, 2][has_particle])
plt.xlabel('True diameter value', fontsize=16)
plt.ylabel('Predicted diameter value', fontsize=16)
plt.title('Diameter field for all validation holograms', fontsize=20)
plt.savefig("./d_scatter_all.png", dpi=200, bbox_inches="tight")
