# Training data: 150 events for each track count n in [1, 25].
# Test data: 25 events for each track count n in [1, 25].

In [None]:
# List the GPUs and their current load; use this to pick the device for CUDA_VISIBLE_DEVICES in the next cell.
!nvidia-smi

In [None]:
# Author: Daniel Zurawski
# Author: Keshav Kapoor
# Organization: Fermilab
# Grammar: Python 3.6.1

# Notebook setup: pin the GPU, choose a matplotlib backend, and import the libraries used below.

import os
# Expose only GPU index 3 to this process (hard-coded; change it to a free device reported by
# nvidia-smi). This is set before `import keras` below so the backend never sees the other GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

### Choose either (1) or (2).
### (1) If you prefer a separate window for plots, uncomment the below.
#import matplotlib
#matplotlib.use('qt5agg')

### (2) If you prefer plots to display within the notebook, uncomment the below.
### WARNING: Plots suffer performance issues and will lag a bit.
%matplotlib notebook

import keras # Neural network models
import pandas as pd # Data frames
import numpy as np  # numerical python
# Project-local package: `loader` reads the .npz dataset and `metrics` draws the evaluation plots.
# `utils` is not used in the visible cells.
from tracker3d import loader, utils, metrics

In [None]:
# Coordinate ordering and noise level; together they select the dataset file loaded below.
order   = ("r", "phi", "z")
n_noise = 0  # presumably the number of noise hits per event (it becomes the "{n}N" file tag)
# File-name tag built from the first letter of each coordinate, e.g. ("r", "phi", "z") -> "RPZ".
code    = "".join(axis[0] for axis in order).upper()

# This notebook only supports loading a dataset previously written with loader.to_file();
# generating fresh data is not implemented here. The flag is kept so that setting it to
# False fails loudly instead of silently loading from the file anyway.
load_from_file = True
if not load_from_file:
    raise NotImplementedError(
        "Generating data is not implemented in this notebook; "
        "set load_from_file = True to load a .npz file written by loader.to_file()."
    )

# Name of the .npz file holding the input data and the target labels.
file = "datasets/npz/UNIF-25T-175E-{0}-{1}N.npz".format(code, n_noise)
data, target = loader.from_file(file)

# Train on the first 6/7 of the events (150 of every 175) and test on the remaining 25 of every 175.
# NOTE(review): this is a positional split. It yields 25 test events per track count only if the
# events in the file are mixed across track counts. If they are grouped by track count, the test
# set will hold only the largest track counts. Confirm against how the file was written.
percent = 6/7  # Fraction of the events used for training.
bound   = int(data.shape[0] * percent)
train_data, train_target = data[:bound], target[:bound]
test_data,  test_target  = data[bound:], target[bound:]
print("Successfully loaded!")
print("train_data shape:   {0},\ntrain_target shape: {1}".format(train_data.shape, train_target.shape))
print("test_data shape:    {0},\ntest_target shape:  {1}".format(test_data.shape, test_target.shape))

In [None]:
# To be used when we define our model.
from keras.layers import TimeDistributed, Dense, LSTM, Activation
from keras.layers import Dropout, GRU, Bidirectional, Conv2D, Conv1D
from keras.layers import MaxPooling1D
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2, l1
from keras.models import Sequential

In [None]:
# Hyperparameters for the model and its training run.
epochs      = 2 ** 7   # passes over the training set
batch_size  = 2 ** 6   # events per gradient update
valsplit    = 0.125    # fraction of train_data that model.fit holds out for validation
# NOTE: newer Keras releases prefer `learning_rate=`; `lr=` is the older keyword.
opt         = keras.optimizers.RMSprop(lr=1e-3)

# Shapes taken from the loaded training data.
input_shape = train_data.shape[1:]   # shape of a single event (every axis after the event axis)
num_classes = train_target.shape[2]  # third axis of the target = width of the softmax output

In [None]:
# Network: input dropout -> three stacked bidirectional GRUs -> per-timestep softmax classifier.
# The layers are created in order, collected in a list, then added to a Sequential model.

# Input layer: during training, randomly drop 1/16 of the input values.
layer_stack = [Dropout(rate=1/16, input_shape=input_shape)]

# Hidden layers: each bidirectional GRU concatenates its forward and backward
# 300-unit outputs, giving 600 features per timestep.
for _ in range(3):
    recurrent_cell = GRU(
        units=300,
        return_sequences=True,  # keep one output per timestep for the next layer
        recurrent_dropout=1/8,
        dropout=1/8,
        implementation=2,
    )
    layer_stack.append(Bidirectional(recurrent_cell, merge_mode="concat"))

# Output layer: the same Dense softmax applied independently at every timestep.
layer_stack.append(TimeDistributed(Dense(
    units=num_classes,
    kernel_initializer="uniform",
    activation="softmax",
)))

model = Sequential()
for layer in layer_stack:
    model.add(layer)

model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=["accuracy"])

# Report the run configuration, then the layer-by-layer summary.
print(f"Epochs: {epochs}, Batch Size: {batch_size}, Validation Split {valsplit * 100}%")
model.summary()

In [None]:
%%time

# Train the network. Keras holds out the last `valsplit` fraction of
# train_data/train_target (taken before shuffling) as the validation set.
hist = model.fit(
    train_data,
    train_target,
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
    validation_split=valsplit,
)

# Create the output directory first. Saving into a missing directory raises an error,
# which would otherwise throw away the result of the whole training run above.
os.makedirs("models", exist_ok=True)
model.save("models/TRAIN-U25T150E-TEST-U25T25E-{0}-{1}N.h5".format(code, n_noise))

In [None]:
# Softmax class probabilities at every timestep of every test event.
guesses = model.predict(x=test_data)

In [None]:
# Re-execute tracker3d.metrics so edits to it are picked up without restarting the
# kernel; the trained model and predictions from earlier cells are kept.
import importlib
importlib.reload(metrics)

In [None]:
# Accuracy-versus-number-of-tracks box plot (per the function name); the noise
# option is enabled only when the dataset was built with noise.
has_noise = n_noise > 0
_ = metrics.accuracy_vs_tracks_boxplot(guesses, test_target, noise=has_noise)

In [None]:
# Certainty thresholds 0.0, 0.1, ..., 1.0.
thresholds = [step / 10 for step in range(11)]
# Variations of the threshold metric:
#   1: Probability that hit was predicted correctly with certainty greater than or equal to threshold.
#   2: Probability that hit was predicted incorrectly with certainty greater than or equal to threshold.
#   3: Probability that hit was predicted to multiple tracks with certainties greater than or equal to threshold.
#   4: Probability that hit was predicted to no track with certainty greater than or equal to threshold.
# NOTE(review): the call below passes variation="none", which is not one of 1-4 above;
# check metrics.threshold_boxplot for what that value selects.
_ = metrics.threshold_boxplot(guesses, test_target, thresholds, variation="none")