# August 11 - Interpolating between samples

## Added support for interpolating between two random samples

In [None]:
# Imports
import math
import os
import sys
import pandas as pd
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.mplot3d import Axes3D

# Make the parent directory importable (only once, even if this cell is re-run)
# so that the project-local `plot_utils` package can be found from here.
par_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if par_dir not in sys.path:
    sys.path.append(par_dir)

# Project-local plotting helpers; these imports depend on the path tweak above.
from plot_utils import plot_utils
import random
import torch
from plot_utils import notebook_utils_2

# Integer class label -> human-readable particle name.
label_dict = dict(enumerate(["gamma", "e", "mu"]))

In [None]:
# Build the path to the interpolation archive dumped by this run.
run_id = "20190811_222844"
dump_dir = "/home/akajal/WatChMaL/VAE/dumps/{}/".format(run_id)
model_status = "trained"
np_arr_path = "{}{}_interpolations.npz".format(dump_dir, model_status)

# Open the archive and list the arrays it holds. It is deliberately left open
# here because the following cell reads its arrays through `np_arr`.
np_arr = np.load(np_arr_path)
print(list(np_arr.keys()))

In [None]:
# Original events sampled from the dataset, with their labels and energies.
np_events = np_arr["events"]
np_labels = np_arr["labels"]
np_energies = np_arr["energies"]

# Interpolated samples, with the model's predicted labels and energies.
np_samples = np_arr["samples"]
np_pred_labels = np_arr["predicted_labels"]
np_pred_energies = np_arr["predicted_energies"]

# Quick sanity check of the array shapes.
print(np_events.shape, np_labels.shape, np_energies.shape)
print(np_samples.shape, np_pred_labels.shape, np_pred_energies.shape)

In [None]:
# Show the energies of the original events, then the particle names of the
# first two events.
print(np_energies)
print(*(label_dict[np_labels[k]] for k in (0, 1)))

## First, plot the two original events sampled from the dataset

In [None]:
# Compare the two original events side by side. The first event's label name
# and energy are passed through to the plotting helper.
plot_utils.plot_actual_vs_recon(
    np_events[0],
    np_events[1],
    label_dict[np_labels[0]],
    np_energies[0],
    show_plot=True,
)
# Charge histograms of the same two events.
plot_utils.plot_charge_hist(np_events[0], np_events[1], 0, num_bins=200)

## Next, plot the reconstructions of the two original events sampled from the dataset

In [None]:
# Plot the first and last interpolation samples. The last index is derived
# from the data instead of being hard-coded (it is 10 for an 11-sample run),
# so runs with a different number of steps work too.
# NOTE(review): assumes samples are ordered from one endpoint to the other and
# that the predicted-label/energy arrays are indexed like `np_samples`.
last_idx = np_samples.shape[0] - 1
plot_utils.plot_actual_vs_recon(np_samples[last_idx], np_samples[0], label_dict[np_pred_labels[last_idx]],
                                np_pred_energies[last_idx].item(), show_plot=True)

plot_utils.plot_charge_hist(np_samples[0], np_samples[last_idx], 0, num_bins=200)

## Now loop over the interpolated samples and plot them in pairs

In [None]:
# Walk the samples from the last index down in steps of two. Each step
# compares sample i with its neighbour i-1: pairs (n-1, n-2), (n-3, n-4), ...
# Sample 0 only appears (as the i-1 partner) when the sample count is even.
for i in range(np_samples.shape[0] - 1, 0, -2):
    plot_utils.plot_actual_vs_recon(np_samples[i], np_samples[i - 1], label_dict[np_pred_labels[i]],
                                    np_pred_energies[i].item(), show_plot=True)

    # Histogram the same (i, i-1) pair that was just plotted, not sample i twice.
    plot_utils.plot_charge_hist(np_samples[i], np_samples[i - 1], 0, num_bins=200)

## Another sample interpolation

In [None]:
# Load, summarize and plot the interpolation dumped by run 20190823_072818.
run_id = "20190823_072818"
dump_dir = "/home/akajal/WatChMaL/VAE/dumps/" + run_id + "/"
model_status = "trained"
np_arr_path = dump_dir + model_status + "_interpolations.npz"

# Open the .npz archive in a `with` block so its file handle is closed.
# Indexing the archive reads each array into memory, so the extracted arrays
# remain usable after the block exits.
with np.load(np_arr_path) as np_arr:
    print(list(np_arr.keys()))

    # Original events sampled from the dataset, with their labels and energies.
    np_events, np_labels, np_energies = np_arr["events"], np_arr["labels"], np_arr["energies"]
    # Interpolated samples, with the model's predicted labels and energies.
    np_samples, np_pred_labels, np_pred_energies = np_arr["samples"], np_arr["predicted_labels"], np_arr["predicted_energies"]

print(np_events.shape, np_labels.shape, np_energies.shape)
print(np_samples.shape, np_pred_labels.shape, np_pred_energies.shape)

print(np_energies)
print(label_dict[np_labels[0]], label_dict[np_labels[1]])

# The two original events side by side, plus their charge histograms.
plot_utils.plot_actual_vs_recon(np_events[0], np_events[1], label_dict[np_labels[0]], np_energies[0], show_plot=True)
plot_utils.plot_charge_hist(np_events[0], np_events[1], 0, num_bins=200)

# Walk the samples from the last index down in steps of two, comparing each
# sample i with its neighbour i-1.
for i in range(np_samples.shape[0] - 1, 0, -2):
    plot_utils.plot_actual_vs_recon(np_samples[i], np_samples[i - 1], label_dict[np_pred_labels[i]],
                                    np_pred_energies[i].item(), show_plot=True)

    # Histogram the same (i, i-1) pair that was just plotted, not sample i twice.
    plot_utils.plot_charge_hist(np_samples[i], np_samples[i - 1], 0, num_bins=200)

## Another sample interpolation

In [None]:
# Load, summarize and plot the interpolation dumped by run 20190823_073139.
run_id = "20190823_073139"
dump_dir = "/home/akajal/WatChMaL/VAE/dumps/" + run_id + "/"
model_status = "trained"
np_arr_path = dump_dir + model_status + "_interpolations.npz"

# Open the .npz archive in a `with` block so its file handle is closed.
# Indexing the archive reads each array into memory, so the extracted arrays
# remain usable after the block exits.
with np.load(np_arr_path) as np_arr:
    print(list(np_arr.keys()))

    # Original events sampled from the dataset, with their labels and energies.
    np_events, np_labels, np_energies = np_arr["events"], np_arr["labels"], np_arr["energies"]
    # Interpolated samples, with the model's predicted labels and energies.
    np_samples, np_pred_labels, np_pred_energies = np_arr["samples"], np_arr["predicted_labels"], np_arr["predicted_energies"]

print(np_events.shape, np_labels.shape, np_energies.shape)
print(np_samples.shape, np_pred_labels.shape, np_pred_energies.shape)

print(np_energies)
print(label_dict[np_labels[0]], label_dict[np_labels[1]])

# The two original events side by side, plus their charge histograms.
plot_utils.plot_actual_vs_recon(np_events[0], np_events[1], label_dict[np_labels[0]], np_energies[0], show_plot=True)
plot_utils.plot_charge_hist(np_events[0], np_events[1], 0, num_bins=200)

# Walk the samples from the last index down in steps of two, comparing each
# sample i with its neighbour i-1.
for i in range(np_samples.shape[0] - 1, 0, -2):
    plot_utils.plot_actual_vs_recon(np_samples[i], np_samples[i - 1], label_dict[np_pred_labels[i]],
                                    np_pred_energies[i].item(), show_plot=True)

    # Histogram the same (i, i-1) pair that was just plotted, not sample i twice.
    plot_utils.plot_charge_hist(np_samples[i], np_samples[i - 1], 0, num_bins=200)

In [None]:

# Load, summarize and plot the interpolation dumped by run 20190823_080112.
run_id = "20190823_080112"
dump_dir = "/home/akajal/WatChMaL/VAE/dumps/" + run_id + "/"
model_status = "trained"
np_arr_path = dump_dir + model_status + "_interpolations.npz"

# Open the .npz archive in a `with` block so its file handle is closed.
# Indexing the archive reads each array into memory, so the extracted arrays
# remain usable after the block exits.
with np.load(np_arr_path) as np_arr:
    print(list(np_arr.keys()))

    # Original events sampled from the dataset, with their labels and energies.
    np_events, np_labels, np_energies = np_arr["events"], np_arr["labels"], np_arr["energies"]
    # Interpolated samples, with the model's predicted labels and energies.
    np_samples, np_pred_labels, np_pred_energies = np_arr["samples"], np_arr["predicted_labels"], np_arr["predicted_energies"]

print(np_events.shape, np_labels.shape, np_energies.shape)
print(np_samples.shape, np_pred_labels.shape, np_pred_energies.shape)

print(np_energies)
print(label_dict[np_labels[0]], label_dict[np_labels[1]])

# The two original events side by side, plus their charge histograms.
plot_utils.plot_actual_vs_recon(np_events[0], np_events[1], label_dict[np_labels[0]], np_energies[0], show_plot=True)
plot_utils.plot_charge_hist(np_events[0], np_events[1], 0, num_bins=200)

# Walk the samples from the last index down in steps of two, comparing each
# sample i with its neighbour i-1.
for i in range(np_samples.shape[0] - 1, 0, -2):
    plot_utils.plot_actual_vs_recon(np_samples[i], np_samples[i - 1], label_dict[np_pred_labels[i]],
                                    np_pred_energies[i].item(), show_plot=True)

    # Histogram the same (i, i-1) pair that was just plotted, not sample i twice.
    plot_utils.plot_charge_hist(np_samples[i], np_samples[i - 1], 0, num_bins=200)