# Conditional Neural Processes (CNP) for 1D regression.
[Conditional Neural Processes](https://arxiv.org/pdf/1807.01613.pdf) (CNPs) were
introduced as a continuation of
[Generative Query Networks](https://deepmind.com/blog/neural-scene-representation-and-rendering/)
(GQN) to extend its training regime to tasks beyond scene rendering, e.g. to
regression and classification.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import datetime
import numpy as np
import torchsnooper
import os
import plotting_utils_cnp as plotting
import data_generator as data
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
import dask.dataframe as dd
import import_ipynb
import conditional_neural_process_model as cnp


<img src="../utilities/concept.png" alt="drawing" width="500"/>

## Running Conditional Neural Processes

With the dataset and the model defined (in `data_generator` and
`conditional_neural_process_model`), we can set up training. Before we get
started we need to set some variables:

*   **`TRAINING_ITERATIONS`** - a scalar that describes the number of iterations
    for training. At each iteration we load a new batch from the training
    dataset (read from the simulation files), pick some of the points as our
    context points **(x, y)<sub>C</sub>** and some points as our target points
    **(x, y)<sub>T</sub>**. We predict the mean and variance at the target
    points given the context and use the negative log likelihood of the ground
    truth targets as our loss to update the model.
*   **`MAX_CONTEXT_POINTS`** - a scalar that sets the maximum number of context
    points used during training. The number of context points will then be a
    value between 3 and `MAX_CONTEXT_POINTS` that is sampled at random for every
    iteration.
*   **`PLOT_AFTER`** - a scalar that regulates how often we plot the
    intermediate results.

In [None]:
# Training configuration.
# Total number of training points: TRAINING_ITERATIONS * BATCH_SIZE * MAX_CONTEXT_POINTS
TRAINING_ITERATIONS = int(3601)
MAX_CONTEXT_POINTS = 1000  # other values used: 2000, 4000
MAX_TARGET_POINTS = 2000  # other values used: 4000, 8000
CONTEXT_IS_SUBSET = True  # forwarded to DataGeneration.get_data()
BATCH_SIZE = 1  # number of simulation configurations per batch
CONFIG_WISE = False
PLOT_AFTER = int(200)  # evaluate on the test set and plot every PLOT_AFTER iterations
torch.manual_seed(0)  # fix the torch RNG for reproducibility

# All available x config / physics parameters are:
# ["radius","thickness","npanels","theta","length","height","z_offset","volume","nC_Ge77","time_0[ms]","x_0[m]","y_0[m]","z_0[m]","px_0[m]","py_0[m]","pz_0[m]","ekin_0[eV]","edep_0[eV]","time_t[ms]","x_t[m]","y_t[m]","z_t[m]","px_t[m]","py_t[m]","pz_t[m]","ekin_t[eV]","edep_t[eV]","nsec"]
# If using data version v1.1 for training, "radius", "thickness", "npanels",
# "theta", "length" are probably necessary.
names_x = ["radius", "thickness", "npanels", "theta", "length", "r_0[m]", "z_0[m]", "time_t[ms]", "r_t[m]", "z_t[m]", "L_t[m]", "ln(E0vsET)", "edep_t[eV]", "nsec"]
name_y = 'total_nC_Ge77[cts]'
x_size = len(names_x)
# name_y may be a single column name (one output) or a list of column names.
y_size = 1 if isinstance(name_y, str) else len(name_y)

RATIO_TESTING_VS_TRAINING = 1 / 40
version = "v1.3"
path_to_files = f"../simulation/out/LF/{version}/tier2/"
path_out = './out/'
# f_out is a bare file stem: output files are written as f'{path_out}{f_out}_...',
# so the directory must not be embedded here as well (that produced './out/./out/...').
f_out = f'CNPGauss_{version}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}'

Data augmentation methods used:

<img src="../utilities/data_augmentation.png" alt="drawing" width="800"/>

In [None]:
# Data-augmentation settings.
# USE_DATA_AUGMENTATION: "mixup", "smote" or False (no augmentation).
USE_DATA_AUGMENTATION = "mixup"
# Beta(a, b) parameters of the mixup weight; they select the tier3 sub-directory.
# Shapes listed originally: U-shape [0.1, 0.1], uniform [1., 1.],
# falling [0.2, 0.5], rising [0.2, 0.5].
# NOTE(review): "falling" and "rising" list identical parameters; one is
# presumably swapped — confirm. The original note also allowed None for
# "uniform", but None fails on the USE_BETA[0] indexing below.
USE_BETA = [0.1, 0.1]
SIGNAL_TO_BACKGROUND_RATIO = ""  # e.g. "_1to4"; used for smote augmentation

if USE_DATA_AUGMENTATION:
    # Augmented runs write their outputs to a per-method sub-directory.
    path_out = f'./out/{USE_DATA_AUGMENTATION}/'
    if USE_DATA_AUGMENTATION == "mixup":
        # Mixup reads pre-augmented tier3 files for the chosen Beta parameters.
        path_to_files = f"../simulation/out/LF/{version}/tier3/beta_{USE_BETA[0]}_{USE_BETA[1]}/"
        f_out = f'CNPGauss_{version}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}_beta_{USE_BETA[0]}_{USE_BETA[1]}'
    else:
        f_out = f'CNPGauss_{version}_{TRAINING_ITERATIONS}_c{MAX_CONTEXT_POINTS}_t{MAX_TARGET_POINTS}_{USE_DATA_AUGMENTATION}{SIGNAL_TO_BACKGROUND_RATIO}'
        # Only config-wise smote switches to pre-augmented tier3 files;
        # otherwise path_to_files keeps its previous (tier2) value.
        if USE_DATA_AUGMENTATION == "smote" and CONFIG_WISE == True:
            path_to_files = f"../simulation/out/LF/{version}/tier3/smote{SIGNAL_TO_BACKGROUND_RATIO}/"
        

In [None]:
# Training dataset (augmented data if enabled above).
dataset_train = data.DataGeneration(
    num_iterations=TRAINING_ITERATIONS,
    num_context_points=MAX_CONTEXT_POINTS,
    num_target_points=MAX_TARGET_POINTS,
    batch_size=BATCH_SIZE,
    config_wise=CONFIG_WISE,
    path_to_files=path_to_files,
    x_size=x_size,
    y_size=y_size,
    mode="training",
    ratio_testing=RATIO_TESTING_VS_TRAINING,
    sig_bkg_ratio=SIGNAL_TO_BACKGROUND_RATIO,
    use_data_augmentation=USE_DATA_AUGMENTATION,
    names_x=names_x,
    name_y=name_y,
)
# Adopt the iteration count reported by the generator (it may differ from the request).
TRAINING_ITERATIONS = dataset_train._num_iterations

# Testing dataset: always read from the non-augmented tier2 files with batch
# size 1, sized for one evaluation every PLOT_AFTER training steps plus 5 spare.
dataset_testing = data.DataGeneration(
    num_iterations=int(np.round(TRAINING_ITERATIONS / PLOT_AFTER)) + 5,
    num_context_points=MAX_CONTEXT_POINTS,
    num_target_points=MAX_TARGET_POINTS,
    batch_size=1,
    config_wise=False,
    path_to_files=f"../simulation/out/LF/{version}/tier2/",
    x_size=x_size,
    y_size=y_size,
    mode="testing",
    ratio_testing=RATIO_TESTING_VS_TRAINING,
    sig_bkg_ratio=SIGNAL_TO_BACKGROUND_RATIO,
    # NOTE(review): this is the *string* "None" (truthy), not None; presumably
    # data_generator compares against the string — confirm.
    use_data_augmentation="None",
    names_x=names_x,
    name_y=name_y,
)
TRAINING_ITERATIONS = min(TRAINING_ITERATIONS, dataset_train._num_iterations)

# If the testing dataset is too short for an evaluation every PLOT_AFTER steps,
# raise PLOT_AFTER to the smallest multiple of 5 that is large enough.
min_plot_after = int(np.ceil(TRAINING_ITERATIONS / (dataset_testing._num_iterations - 2)))
if PLOT_AFTER < min_plot_after:
    PLOT_AFTER = int(5 * np.ceil(min_plot_after / 5))


We can now build the model and its Adam optimizer and run the training loop,
evaluating on the testing set every `PLOT_AFTER` iterations.

In [None]:

# Network layout: the encoder input is an (x, y) pair of size x_size + y_size,
# the decoder input is [representation, x], and the decoder emits
# d_out = y_size + 1 values (presumably the mean(s) plus one scale parameter —
# TODO confirm in conditional_neural_process_model).
d_x, d_in, representation_size, d_out = x_size, x_size + y_size, 32, y_size + 1
encoder_sizes = [d_in, 32, 64, 128, 128, 128, 64, 48, representation_size]
decoder_sizes = [representation_size + d_x, 32, 64, 128, 128, 128, 64, 48, d_out]

model = cnp.DeterministicModel(encoder_sizes, decoder_sizes)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Binary cross-entropy is only reported as a diagnostic next to the training
# loss (negative log-likelihood); it is never back-propagated.
bce = nn.BCELoss()
iter_testing = 0


def _bce_or_sentinel(pred, target):
    """Return BCE(pred, target) as a diagnostic, or -1. if pred leaves [0, 1].

    nn.BCELoss requires every prediction to lie in [0, 1], so the range is
    checked over the whole detached tensor (all batch elements and all output
    columns) before calling it.
    """
    pred = pred.detach()
    if pred.min() >= 0 and pred.max() <= 1:
        return bce(pred, target)
    return -1.


def _timestamp():
    """Return the current local time formatted for the training log."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


run_prefix = f'{path_out}{f_out}'
# open() and PdfPages fail if the output directory does not exist yet.
os.makedirs(os.path.dirname(run_prefix) or '.', exist_ok=True)
last_it = TRAINING_ITERATIONS - 1

# Context managers close the log and the PDF even if training raises.
with open(f'{run_prefix}_training.txt', "w") as fout, PdfPages(f'{run_prefix}_training.pdf') as pdf:
    for it in range(TRAINING_ITERATIONS):
        # Load the training batch for this iteration.
        data_train = dataset_train.get_data(it, CONTEXT_IS_SUBSET)

        # Log-likelihood of the training targets and the predicted mean.
        log_prob, mu, _ = model(data_train.query, data_train.target_y)

        # Gradient step on the negative mean log-likelihood.
        loss = -log_prob.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_bce = _bce_or_sentinel(mu, data_train.target_y)
        mu = mu[0].detach().numpy()  # first batch element, used for plotting

        if it % 100 == 0:
            train_msg = 'Iteration: {}/{}, train loss: {:.4f} (vs BCE {:.4f})'.format(it, TRAINING_ITERATIONS, loss, loss_bce)
            print('{} {}'.format(_timestamp(), train_msg))
            fout.write(train_msg + '\n')

        is_last = it == last_it
        if it % PLOT_AFTER == 0 or is_last:
            data_testing = dataset_testing.get_data(iter_testing, CONTEXT_IS_SUBSET)
            # Evaluation only: no autograd graph is needed for the test batch.
            with torch.no_grad():
                log_prob_testing, mu_testing, _ = model(data_testing.query, data_testing.target_y)
            loss_testing = -log_prob_testing.mean()
            loss_bce_testing = _bce_or_sentinel(mu_testing, data_testing.target_y)
            mu_testing = mu_testing[0].detach().numpy()

            test_msg = "{}, Iteration: {}, test loss: {:.4f} (vs BCE {:.4f})".format(_timestamp(), it, loss_testing, loss_bce_testing)
            print(test_msg)
            fout.write(test_msg + "\n")

            target_train = data_train.target_y[0].detach().numpy()
            target_test = data_testing.target_y[0].detach().numpy()
            if isinstance(name_y, str):
                figs = [plotting.plot(mu, target_train, f'{loss:.2f}', mu_testing, target_test, f'{loss_testing:.2f}', it)]
            else:
                # One figure per output column; all of them are kept and saved.
                figs = [
                    plotting.plot(mu[:, k], target_train[:, k], f'{loss:.2f}', mu_testing[:, k], target_test[:, k], f'{loss_testing:.2f}', it)
                    for k in range(y_size)
                ]

            # Every evaluation is shown, but only every 5th one (and the final
            # one) is archived in the PDF. Parenthesised on purpose:
            # `it % PLOT_AFTER*5` parses as `(it % PLOT_AFTER) * 5`.
            if it % (PLOT_AFTER * 5) == 0 or is_last:
                for fig in figs:
                    pdf.savefig(fig)
            plt.show()
            # Close the figures so they do not accumulate across evaluations.
            for fig in figs:
                plt.close(fig)

            iter_testing += 1

# NOTE(review): the model is saved under ./out/ rather than path_out, unlike
# the training log and PDF; kept unchanged in case other code loads it from here.
torch.save(model.state_dict(), f'./out/{f_out}_model.pth')

