In [1]:
import warnings
warnings.filterwarnings("ignore")

import copy
import optuna
import random
import joblib
import logging
import traceback
import numpy as np

from holodecml.torch.utils import *
from holodecml.torch.losses import *
from holodecml.torch.visual import *
from holodecml.torch.models import *
from holodecml.torch.trainers import *
from holodecml.torch.transforms import *
from holodecml.torch.optimizers import *
from holodecml.torch.data_loader import *
from holodecml.torch.beam_search import *

from aimlutils.hyper_opt.base_objective import *
from aimlutils.torch.checkpoint import *
#from aimlutils.torch.losses import *

from torch import nn
from torch.optim.lr_scheduler import *
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple

import matplotlib.pyplot as plt
from collections import defaultdict

In [2]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")

In [16]:
with open("results/test/model.yml") as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

In [4]:
# conf["validation_data"]["path_data"] = ['/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_12-25particle_gamma_600x400_training.nc']
# conf["validation_data"]["maxnum_particles"] = 25

In [4]:
if device == "cpu":
    conf["train_iterator"]["pin_memory"] = False
    conf["valid_iterator"]["pin_memory"] = False

In [5]:
# Load the image and data transformations
train_transform = LoadTransformations(conf["train_transforms"], device = device)
valid_transform = LoadTransformations(conf["validation_transforms"], device = device)

scaler_path = os.path.join(conf["trainer"]["path_save"], "scalers.save")

In [6]:
# Load the readers

train_gen = LoadReader(
    transform = train_transform,
    scaler = joblib.load(scaler_path) if os.path.isfile(scaler_path) else True,
    config = conf["train_data"]
)

if not os.path.isfile(scaler_path):
    joblib.dump(train_gen.scaler, scaler_path)

valid_gen = LoadReader(
    transform = valid_transform, 
    scaler = train_gen.scaler,
    config = conf["validation_data"]
)

Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}
Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}


In [7]:
# Load data iterators 

train_dataloader = DataLoader(
    train_gen,
    **conf["train_iterator"]
)

valid_dataloader = DataLoader(
    valid_gen,
    **conf["valid_iterator"]
)

### Load models 

In [8]:
PAD_token = 0
SOS_token = 1
EOS_token = 2

In [18]:
vae_conf = conf["vae"]

In [19]:
decoder_conf = conf["decoder"]

In [20]:
regressor_conf = conf["regressor"]

### Load the VAE

In [22]:
vae = LoadModel(vae_conf)
vae.build()
vae = vae.to(device)

In [23]:
# model_dict = torch.load(
#     vae_model_weights,
#     map_location=lambda storage, loc: storage
# )
# vae.load_state_dict(model_dict["model_state_dict"])

### Load RNN and regression models 

In [15]:
decoder_conf["output_size"] = len(train_gen.token_lookup) + 3
decoder = DecoderRNN(**decoder_conf).to(device)
print(decoder_conf["output_size"])

124


In [16]:
decoder_model_weights = conf["callbacks"]["MetricsLogger"]["path_save"] + "/best_rnn.pt"
model_dict = torch.load(
    decoder_model_weights,
    map_location=lambda storage, loc: storage
)
decoder.load_state_dict(model_dict["model_state_dict"])

<All keys matched successfully>

In [17]:
regressor = DenseNet2(**conf["regressor"])

In [18]:
regressor.build(vae_conf["z_dim"] + decoder_conf["hidden_size"] + 1250)
regressor = regressor.to(device)

In [19]:
linear_model_weights = conf["callbacks"]["MetricsLogger"]["path_save"] + "/best_linear.pt"
model_dict = torch.load(
    linear_model_weights,
    map_location=lambda storage, loc: storage
)
regressor.load_state_dict(model_dict["model_state_dict"])

<All keys matched successfully>

### Predict with the models

In [20]:
# Set up the beam-search object

beam_search = BeamSearch(
    end_index = EOS_token, 
    max_steps = valid_gen.maxnum_particles, 
    beam_size = 10
)
_bleu = BLEU(exclude_indices={PAD_token, EOS_token, SOS_token})

In [21]:
def predict(epoch = 0):
    
    vae.eval()
    decoder.eval()
    regressor.eval()
    
    with torch.no_grad():
    
        batch_size = conf["valid_iterator"]["batch_size"]
        batches_per_epoch = int(np.ceil(valid_gen.__len__() / batch_size))

#         batch_group_generator = tqdm(
#             enumerate(valid_dataloader), 
#             total=batches_per_epoch, 
#             leave=True
#         )

        epoch_losses = {"mse": [], "cce": [], "accuracy": [], "stop_accuracy": [], "bleu": []}
        for idx, (images, y_out, w_out) in enumerate(valid_dataloader):
            images = images.to(device)
            y_out = {task: value.to(device) for task, value in y_out.items()}
            w_out = w_out.to(device)

            with torch.no_grad():
                # 1. Predict the latent vector and image reconstruction
                z, mu, logvar, encoder_att = vae.encode(images)
                image_pred, decoder_att = vae.decode(z)
                
                combined_att = torch.cat([
                    encoder_att[2].flatten(start_dim = 1),
                    decoder_att[0].flatten(start_dim = 1)
                ], 1)
                combined_att = combined_att.clone()
                
                encoder_att = [x.detach().cpu().numpy() for x in encoder_att]
                decoder_att = [x.detach().cpu().numpy() for x in decoder_att]

                if vae.out_image_channels > 1:
                    z_real = np.sqrt(0.5) * image_pred[:,0,:,:]
                    z_imag = image_pred[:,1,:,:]
                    image_pred = torch.square(z_real) + torch.square(z_imag)
                    image_pred = torch.unsqueeze(image_pred, 1)

            # 2. Predict the number of particles
            decoder_input = torch.LongTensor([SOS_token] * w_out.shape[0]).to(device)
            encoded_image = z.to(device)
            decoder_hidden = encoded_image.clone().reshape((1, w_out.shape[0], z.size(-1)))
            
            n_dims = 2 if decoder.bidirectional else 1
            n_dims *= decoder.n_layers
            if n_dims > 1:
                decoder_hidden = torch.cat([decoder_hidden for k in range(n_dims)])

            target_tensor = w_out.long()
            target_length = w_out.shape[1]
            seq_lens = w_out.max(axis = 1)[0] + 1
            class_weights = torch.ones(w_out.shape).to(device)

            hidden_vectors = []
            bleu, accuracy, stop_accuracy, rnn_loss = [], [], [], []

            # Use beam search to get predictions
            predictions, probabilities = beam_search.search(
                decoder_input, decoder_hidden, decoder
            )

            # Validate on top-1 most likely sequence
            top_preds = predictions[:, 0, :]

            # Compute bleu metric for each sequence in the batch
            for pred, true in zip(top_preds, target_tensor):
                _bleu(pred.unsqueeze(0), true.unsqueeze(0))
            epoch_losses["bleu"].append(_bleu.get_metric(reset=False)["BLEU"])

            # Reshape the predicted tensor to match with the target_tensor
            ## This will work only if limit the beam search = target size
            B, T = target_tensor.size()
            _, t = top_preds.size()
            if t < T:
                reshaped_preds = torch.zeros(B, T)
                reshaped_preds[:, :t] = top_preds
                reshaped_preds = reshaped_preds.long().to(device)
            else:
                reshaped_preds = top_preds

            # Use greedy evaluation to get the loss
            di = 0
            while (di < 101):
            #for di in range(target_length + 1):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, seq_lens)
                topv, topi = decoder_output.topk(1)
                #if batch_size > 1:
                #decoder_input = topi.squeeze().detach()  # detach from history as input
                decoder_input = reshaped_preds[:, di].detach()
                c1 = (reshaped_preds[:, di] != PAD_token)
                c2 = (reshaped_preds[:, di] != EOS_token)
                condition = c1 & c2
                real_plus_stop = torch.where(c1)
                real_particles = torch.where(condition)
                stop_token = torch.where(~c2)

                if real_plus_stop[0].size(0) == 0:
                    break

                rnn_loss.append(
                    nn.NLLLoss()(
                        decoder_output[real_plus_stop], 
                        target_tensor[:, di][real_plus_stop]
                    )
                )
                accuracy += [
                    int(i.item()==j.item())
                    for i, j in zip(reshaped_preds[:, di][real_particles], 
                                    target_tensor[:, di][real_particles])
                ]

                if stop_token[0].size(0) > 0:
                    stop_accuracy += [
                        int(i.item()==j.item()) 
                        for i, j in zip(reshaped_preds[:, di][stop_token], 
                                        target_tensor[:, di][stop_token])
                    ]

                if real_particles[0].size(0) > 0:
                    token_input = reshaped_preds[:, di] # topi.squeeze()
                    embedding = decoder.embed(token_input).squeeze(0)
                    hidden_vectors.append([real_particles, embedding])
                    
                di += 1
                
            accuracy = np.mean(accuracy)
            epoch_losses["accuracy"].append(accuracy)
            epoch_losses["stop_accuracy"].append(np.mean(stop_accuracy))

            rnn_loss = torch.mean(torch.stack(rnn_loss))
            epoch_losses["cce"].append(rnn_loss.item())

            if len(hidden_vectors) == 0:
                continue

            particle_results = {"x": [], "y": [], "z": [], "d": [], "tokens": []}
            particle_true = {"x": [], "y": [], "z": [], "d": []}
            
            # 3. Use particle embeddings to predict (x,y,z,d)
            real_parts = []
            regressor_loss = []
            for di in range(len(hidden_vectors)):
                real_particles, h_vecs = hidden_vectors[di]
                x_input = torch.cat([h_vecs.detach(), encoded_image, combined_att], axis = 1)
                particle_attributes = regressor(x_input[real_particles])
                loss = []
                for task in ["x", "y", "z", "d"]:
                    _loss = nn.L1Loss()(
                        particle_attributes[task].squeeze(1), 
                        y_out[task][:, di][real_particles].float()
                    )
                    loss.append(_loss)
                    particle_results[task].append(particle_attributes[task].cpu().numpy().squeeze(1))
                    particle_true[task].append(y_out[task][:, di].float().cpu().numpy())
                real_parts.append(real_particles[0].cpu().numpy()) 
                regressor_loss.append(torch.mean(torch.stack(loss)))
                particle_results["tokens"].append(real_particles)
            regressor_loss = torch.mean(torch.stack(regressor_loss))
            epoch_losses["mse"].append(regressor_loss.item())
    
            result = {
                    "image_true": images.cpu().numpy(),
                    "image_pred": image_pred.cpu().numpy(),
                    "particle_true": particle_true,
                    "particle_pred": particle_results,
                    "accuracy": accuracy,
                    "stop_accuracy": np.mean(stop_accuracy),
                    "cce": rnn_loss.item(),
                    "rmae": regressor_loss.item(),
                    "real_particles": real_parts,
                    "encoder_att": encoder_att,
                    "decoder_att": decoder_att
                }
            
            yield result

In [22]:
def val_loss(val_gen):
    batch_size = conf["valid_iterator"]["batch_size"]
    batches_per_epoch = int(np.ceil(valid_gen.__len__() / batch_size))
    batch_group_generator = tqdm(
        enumerate(val_gen), 
        total=batches_per_epoch, 
        leave=True
    )
#     results = {
#         "accuracy": [],
#         "stop_accuracy": [],
#         "cce": [],
#         "mae": []
#     }
    results = defaultdict(list)
    for result in batch_group_generator:
        for key, val in result.items():
            results[key].append(val)
    return results 

In [43]:
def plot_hologram(inputs, outputs, particles, n = 0, scaler = None, true = True, vmin = None, vmax = None):
    """
    Given a hologram number, plot hologram and particle point
    
    Args: 
        h: (int) hologram index
        inputs: (pd df) input images
        outputs: (pd df) output x, y, z, and d values by hid
    
    Returns:
        print of pseudocolor plot of hologram and hologram particles
    """    
    #outputs = {x: y[n] for x,y in outputs.items()}
    
    N = 0
    for k in range(len(particles)):
        real_particles = particles[k]
        if n in real_particles:
            N += 1
    
    preds = range(N)
    
    _outputs = {"x": np.zeros(N), "y": np.zeros(N), "z": np.zeros(N), "d": np.zeros(N)}
    
    for k in range(len(particles)):
        real_particles = particles[k]
        if n in real_particles:
            listed = np.where(real_particles == n)[0]
            K = listed[0] if not true else n
            for task in _outputs:
                _outputs[task][k] = outputs[task][k][K]
    
    outputs = _outputs

    inputs = inputs[n].squeeze(0)

    x_vals = np.linspace(-888, 888, inputs.shape[0])
    y_vals = np.linspace(-592, 592, inputs.shape[1])

    plt.figure(figsize=(12, 8))    
    plt.pcolormesh(x_vals, y_vals, inputs.T, cmap="RdBu_r")
    
    if isinstance(scaler, dict):
        shape = outputs["x"].shape[0]
        outputs["x"] = scaler["x"].inverse_transform(outputs["x"].reshape(-1, shape))[0]
        outputs["y"] = scaler["y"].inverse_transform(outputs["y"].reshape(-1, shape))[0]
        outputs["z"] = scaler["z"].inverse_transform(outputs["z"].reshape(-1, shape))[0]
        outputs["d"] = scaler["d"].inverse_transform(outputs["d"].reshape(-1, shape))[0]
    
    print(f"The model predicts that there are {N} particles in the hologram")
    
    if vmax is None:
        vmax = outputs["z"].max()
    if vmin is None:
        vmin = outputs["z"].min()
    
    for h in preds: 
        plt.scatter(outputs["x"][h],
                    outputs["y"][h],
                    outputs["d"][h] ** 2,
                    outputs["z"][h],
                    vmin=vmin,
                    vmax=vmax,
                    cmap="cool")
        plt.annotate(f"d: {(outputs['d'][h]):.1f} µm",
                     (outputs["x"][h], outputs["y"][h]))
    
    plt.xlabel("horizontal particle position (µm)", fontsize=16)
    if true:
        plt.ylabel("vertical particle position (µm)", fontsize=16)
    if true:
        plt.title("True", fontsize=20, pad=20)
    else:
        plt.title("Predicted", fontsize=20, pad=20)
    plt.colorbar().set_label(label="z-axis particle position (µm)", size=16)

In [38]:
class Point:
    def __init__(self, coordinates):
        x, y, z, d = coordinates
        self.x = x
        self.y = y
        self.z = z
        self.d = d
        
def distance(p1, p2):
     return ((p1.x - p2.x) ** 2 + (p1.y - p2.y) ** 2) ** 0.5

In [39]:
results_generator = predict()

In [40]:
next_result = next(results_generator)

In [44]:
# Model returns batch results, so use N here to pick single examples

def batch_plot(result, N):

    ### Make the plot -- use the same z scale in both images
    z_dist_true = [train_gen.scaler["z"].inverse_transform([dd[N]])[0] for dd in result['particle_true']["z"]]
    z_dist_pred = [train_gen.scaler["z"].inverse_transform([dd[N]])[0] for dd in result['particle_pred']["z"]]

    vmin = min(min(z_dist_true), min(z_dist_pred))
    vmax = max(max(z_dist_true), max(z_dist_pred))

    plot_hologram(result["image_true"], result['particle_true'], result['real_particles'], 
                  n = N, scaler = train_gen.scaler, true = True, vmin = vmin, vmax = vmax)
    plot_hologram(result["image_pred"], result['particle_pred'], result['real_particles'], 
                  n = N, scaler = train_gen.scaler, true = False, vmin = vmin, vmax = vmax)

    ### Sort the particles by distance to each other and compute the loss 
    pred_part = []
    true_part = []
    for particle in range(3):
        cx1, cx2 = [], []
        for task in ["x", "y", "z", "d"]:
            x1 = train_gen.scaler[task].inverse_transform([result["particle_true"][task][particle][N]])[0]
            x2 = train_gen.scaler[task].inverse_transform([result["particle_pred"][task][particle][N]])[0]
            cx1.append(x1)
            cx2.append(x2)
        true_part.append(Point(cx1))
        pred_part.append(Point(cx2))

    tosort = []
    for p1 in range(3):
        for p2 in range(3):
            tosort.append([distance(pred_part[p1], true_part[p2]), p1, p2])

    paired = []
    seen_true = []
    seen_pred = []
    for (a,b,c) in sorted(tosort): # sort by distance
        if b not in seen_pred and c not in seen_true:
            paired.append([c, b])
            seen_true.append(c)
            seen_pred.append(b)

    this_error = defaultdict(list)
    for (x,y) in paired:
        xe = abs(true_part[x].x - pred_part[y].x)
        ye = abs(true_part[x].y - pred_part[y].y)
        ze = abs(true_part[x].z - pred_part[y].z)
        de = abs(true_part[x].d - pred_part[y].d)

        this_error["x"].append(xe)
        this_error["y"].append(ye)
        this_error["z"].append(ze)
        this_error["d"].append(de)

    for key, val in this_error.items():
        print(key, np.mean(val))
        
    #paired

In [45]:
batch_plot(next_result, 11)

IndexError: list index out of range

### Compute errors for the validation data

In [29]:
batch_size = conf["valid_iterator"]["batch_size"]
batches_per_epoch = int(np.ceil(valid_gen.__len__() / batch_size))

batch_group_generator = tqdm(
    enumerate(predict()), 
    total=batches_per_epoch, 
    leave=True
)

errors = defaultdict(list)

for m, result in batch_group_generator:
    for N in range(len(result["particle_true"]["x"][0])):
        pred_part = []
        true_part = []
        for particle in range(3):
            cx1, cx2 = [], []
            for task in ["x", "y", "z", "d"]:
                x1 = train_gen.scaler[task].inverse_transform([result["particle_true"][task][particle][N]])[0]
                x2 = train_gen.scaler[task].inverse_transform([result["particle_pred"][task][particle][N]])[0]
                cx1.append(x1)
                cx2.append(x2)
            true_part.append(Point(cx1))
            pred_part.append(Point(cx2))

        tosort = []
        for p1 in range(3):
            for p2 in range(3):
                tosort.append([distance(pred_part[p1], true_part[p2]), p1, p2])

        paired = []
        seen_true = []
        seen_pred = []
        for (a,b,c) in sorted(tosort):
            if b not in seen_pred and c not in seen_true:
                paired.append([c, b])
                seen_true.append(c)
                seen_pred.append(b)

        for (x,y) in paired:
            xe = abs(true_part[x].x - pred_part[y].x)
            ye = abs(true_part[x].y - pred_part[y].y)
            ze = abs(true_part[x].z - pred_part[y].z)
            de = abs(true_part[x].d - pred_part[y].d)

            errors["x"].append(xe)
            errors["y"].append(ye)
            errors["z"].append(ze)
            errors["d"].append(de)

  0%|          | 9/2500 [00:21<1:40:18,  2.42s/it]


IndexError: index 31 is out of bounds for axis 0 with size 31

In [None]:
for task, values in errors.items():
    print(task, np.mean(values), np.std(values) )

In [None]:
#plt.imshow(np.log(result['encoder_att'][0][0]).T)

In [None]:
#plt.imshow(result['decoder_att'][0][0].T)