## Train Primary Track Networks

written by Isobel Mawby (i.mawby1@lancaster.ac.uk)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Imports
</div>

In [None]:
import sys
import os
sys.path.insert(0, os.getcwd()[0:len(os.getcwd()) - 11])
sys.path.insert(1, os.getcwd()[0:len(os.getcwd()) - 11] + '/Metrics')

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import itertools

import Models
import Datasets
import TrainingMetrics

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Put the path to the primary track training file (created by WritePrimaryTierFile.ipynb with isTrackMode == True) and set the output file names
</div>

In [None]:
# All locations are resolved relative to the first sys.path entry
# Input: training file
trainFileName = f"{sys.path[0]}/files/hierarchy_TRAIN_track.npz"
# Outputs: where the trained networks are written
branchModelPath = f"{sys.path[0]}/models/primary_track_branch_model"
classifierModelPath = f"{sys.path[0]}/models/primary_track_classifier_model"

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set hyperparameters
</div>

In [None]:
# Number of passes over the training set
N_EPOCHS = 5
# Entries per batch (the training cell skips batches that are not exactly this size)
BATCH_SIZE = 64
# Learning rate handed to the Adam optimiser
LEARNING_RATE = 0.001
# Dropout rate handed to the branch model
DROPOUT_RATE = 0.5

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Get data from file
</div>

In [None]:
# Open the training archive; each array below is read from it by key
data = np.load(trainFileName)

# Variables
variables_train, variables_test = data['variables_train'], data['variables_test']

# Training cut
trainingCutDCA_train, trainingCutDCA_test = data['trainingCutDCA_train'], data['trainingCutDCA_test']

# Truth
y_train, y_test = data['y_train'], data['y_test']
isTruePrimaryLink_train, isTruePrimaryLink_test = data['isTruePrimaryLink_train'], data['isTruePrimaryLink_test']

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set multiplicity variables
</div>

In [None]:
# Second-axis sizes of the variable array and of the truth array
nVariables, nLinks = variables_train.shape[1], y_train.shape[1]

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Check shapes
</div>

In [None]:
# Report the shape of every array read from the training file
for _label, _array in (
    ('variables_train', variables_train),
    ('variables_test', variables_test),
    ('y_train', y_train),
    ('y_test', y_test),
    ('trainingCutDCA_train', trainingCutDCA_train),
    ('trainingCutDCA_test', trainingCutDCA_test),
    ('isTruePrimaryLink_train', isTruePrimaryLink_train),
    ('isTruePrimaryLink_test', isTruePrimaryLink_test),
):
    print(_label + '.shape:', _array.shape)

# Entry counts (size of the first axis)
print('')
print('ntrain:', variables_train.shape[0])
print('ntest:', variables_test.shape[0])

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Apply training cut mask
</div>

In [None]:
# training cut threshold
MAX_TRAINING_CUT_DCA = 50.0

# An entry is kept when its training-cut DCA is below the threshold.
# The resulting boolean mask has one element per entry, so it is used directly
# as a row mask on every array. This selects the same elements, in the same
# order, as tiling the mask across the columns, without building an
# (nEntries x nColumns) index first.

######################
# training set first
######################
# Make mask
passTrainingCutDCA_train = trainingCutDCA_train < MAX_TRAINING_CUT_DCA
passTrainingCuts_train = passTrainingCutDCA_train

# Mask the 1D variables... shape=(nEntries, )
isTruePrimaryLink_train = isTruePrimaryLink_train[passTrainingCuts_train]

# Mask the truth... shape=(nEntries, nLinks)
y_train = y_train[passTrainingCuts_train].reshape(-1, nLinks)

# Mask the variable... shape=(nEntries, nVariables)
variables_train = variables_train[passTrainingCuts_train].reshape(-1, nVariables)

######################
# now test set
######################
# Make mask
passTrainingCutDCA_test = trainingCutDCA_test < MAX_TRAINING_CUT_DCA
passTrainingCuts_test = passTrainingCutDCA_test

# Mask the 1D variables... shape=(nEntries, )
isTruePrimaryLink_test = isTruePrimaryLink_test[passTrainingCuts_test]

# Mask the truth... shape=(nEntries, nLinks)
y_test = y_test[passTrainingCuts_test].reshape(-1, nLinks)

# Mask the variable... shape=(nEntries, nVariables)
variables_test = variables_test[passTrainingCuts_test].reshape(-1, nVariables)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Check shapes after training cut application
</div>

In [None]:
# Report the array shapes again, now that the training cut has been applied
for _label, _array in (
    ('variables_train', variables_train),
    ('variables_test', variables_test),
    ('y_train', y_train),
    ('y_test', y_test),
    ('isTruePrimaryLink_train', isTruePrimaryLink_train),
    ('isTruePrimaryLink_test', isTruePrimaryLink_test),
):
    print(_label + '.shape:', _array.shape)

# Entry counts (size of the first axis)
print('')
print('ntrain:', variables_train.shape[0])
print('ntest:', variables_test.shape[0])

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Define class weights for branch model
</div>

In [None]:
# Count how often each truth label (0 = background, 1 = true, 2 = wrong orientation)
# appears anywhere in y_train
nBackground, nTrue, nWrongOrientation = (np.count_nonzero(y_train == label) for label in (0, 1, 2))
maxLinks = max(nTrue, nBackground, nWrongOrientation)

# Weight of a class = (count of the most frequent class) / (count of this class),
# stored in label order [0, 1, 2]
classWeights_branch = torch.tensor([float(maxLinks) / float(count) for count in (nBackground, nTrue, nWrongOrientation)])

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Define class weights for classifier model
</div>

In [None]:
# Population of each classifier class in the training set
n_false_primary_train = np.count_nonzero(isTruePrimaryLink_train == False)
n_true_primary_train = np.count_nonzero(isTruePrimaryLink_train == True)

maxCounts_train = max(n_true_primary_train, n_false_primary_train)

# Weight of a class = (count of the larger class) / (count of this class)
classWeights_classifier = dict(
    true_primary_train=maxCounts_train / n_true_primary_train,
    false_primary_train=maxCounts_train / n_false_primary_train,
)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Build the training input for each edge. This is a concatenation of the variable tensor of this edge and those of all other edges, such that the variables for the edge in question are first.
</div>

In [None]:
# Same preparation call for both samples; only the variable array differs
branch_model_input_train, branch_model_input_test = (
    Models.PrepareBranchModelInput(
        nLinks,
        Models.primary_tier_n_orientation_indep_vars,
        Models.primary_tier_n_orientation_dep_vars,
        sample_variables,
    )
    for sample_variables in (variables_train, variables_test)
)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Prepare Dataset objects
</div>

In [None]:
# Training sample: per-edge inputs, per-edge truth, link truth, then an all-zero
# array with the same shape as the link truth
dataset_train = Datasets.TwoEdgeDataset(
    branch_model_input_train[0],
    branch_model_input_train[1],
    y_train[:,0],
    y_train[:,1],
    isTruePrimaryLink_train,
    np.zeros(isTruePrimaryLink_train.shape),
)
loader_train = Datasets.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

# Test sample, built the same way
dataset_test = Datasets.TwoEdgeDataset(
    branch_model_input_test[0],
    branch_model_input_test[1],
    y_test[:,0],
    y_test[:,1],
    isTruePrimaryLink_test,
    np.zeros(isTruePrimaryLink_test.shape),
)
loader_test = Datasets.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Define branch and classifier models
</div>

In [None]:
# Branch network: built from the number of input variables and the dropout rate.
# It is applied to each edge in turn by the training/validation functions.
branch_model = Models.OrientationModel(nVariables, dropoutRate=DROPOUT_RATE)
# Classifier network: built from the number of links; it is fed the concatenated
# branch-model outputs of all edges.
# NOTE(review): the two models are constructed in this order; presumably the order
# affects random weight initialisation under a fixed seed, so keep it stable.
classifier_model = Models.ClassifierModel(nLinks)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Define loss functions for training to implement custom weighting
</div>

In [None]:
def loss_function_branch(pred, target, classWeights) :
    """Return the class-weighted cross-entropy loss of the branch model.

    pred         : branch model output scores, one column per class
    target       : true class index of each entry
    classWeights : tensor holding one weight per class
    """
    return torch.nn.CrossEntropyLoss(weight=classWeights)(pred, target)

def loss_function_classifier(pred, target, classWeights) :
    """Return the class-weighted binary cross-entropy loss of the classifier.

    pred         : classifier scores in [0, 1], same shape as target
    target       : truth values; entries below 0.5 are treated as false links,
                   entries above 0.5 as true links
    classWeights : dict with keys 'false_primary_train' and 'true_primary_train'
                   giving the weight of each class
    """
    # Per-entry weights. The tensor is created on the target's device so that the
    # mask assignments below also work when the batch is not on the default device.
    # Entries with target exactly 0.5 keep a weight of 1.
    weight = torch.ones(target.shape, dtype=pred.dtype, device=target.device)
    weight[target < 0.5] = classWeights['false_primary_train']
    weight[target > 0.5] = classWeights['true_primary_train']

    # Use BCE loss
    loss_func = torch.nn.BCELoss(weight=weight)

    # Calculate loss
    return loss_func(pred, target)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Training/validation loop functions.
</div>

In [None]:
def RunTrainingLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier) :
    """Run both networks on one batch and return the summed training loss.

    nLinks                  : number of edges; the batch is read under keys "edge0" ... "edge<nLinks-1>"
    dataset_batch           : mapping where dataset_batch["edge<i>"][0] is the branch model input and
                              dataset_batch["edge<i>"][1] the branch truth; dataset_batch["truth_link"]
                              is the classifier truth
    branch_model            : network applied to every edge
    classifier_model        : network applied to the concatenated branch outputs
    classWeights_branch     : class weights passed to loss_function_branch
    classWeights_classifier : class weights passed to loss_function_classifier

    Returns the sum of all branch losses plus the classifier loss.
    """
    total_loss = 0
    edge_preds = []

    for i in range(nLinks) :
        edge_data = dataset_batch["edge" + str(i)]
        pred = branch_model(edge_data[0])
        total_loss += loss_function_branch(pred, edge_data[1], classWeights_branch)
        edge_preds.append(pred)

    # Join the per-edge outputs side by side in a single concatenation rather than
    # growing a tensor from an empty seed on every iteration
    classifier_input = torch.cat(edge_preds, dim=1) if edge_preds else torch.empty(0,)

    classifier_target = dataset_batch["truth_link"].reshape(-1,1)
    classifier_pred = classifier_model(classifier_input)
    total_loss += loss_function_classifier(classifier_pred, classifier_target, classWeights_classifier)

    return total_loss

def RunValidationLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics) :
    """Run both networks on one batch and record losses and predictions in linkMetrics.

    nLinks                  : number of edges; the batch is read under keys "edge0" ... "edge<nLinks-1>"
    dataset_batch           : mapping where dataset_batch["edge<i>"][0] is the branch model input and
                              dataset_batch["edge<i>"][1] the branch truth; dataset_batch["truth_link"]
                              is the classifier truth
    branch_model            : network applied to every edge
    classifier_model        : network applied to the concatenated branch outputs
    classWeights_branch     : class weights passed to loss_function_branch
    classWeights_classifier : class weights passed to loss_function_classifier
    linkMetrics             : object whose edge_metrics[i] and classifier_metrics are filled
                              with (loss, prediction, truth)

    Returns nothing; the results are accumulated in linkMetrics.
    """
    edge_preds = []

    for i in range(nLinks) :
        edge_data = dataset_batch["edge" + str(i)]
        pred = branch_model(edge_data[0])
        branch_loss = loss_function_branch(pred, edge_data[1], classWeights_branch)
        linkMetrics.edge_metrics[i].Fill(branch_loss, pred, edge_data[1])
        edge_preds.append(pred)

    # Join the per-edge outputs side by side in a single concatenation rather than
    # growing a tensor from an empty seed on every iteration
    classifier_input = torch.cat(edge_preds, dim=1) if edge_preds else torch.empty(0,)

    classifier_target = dataset_batch["truth_link"].reshape(-1,1)
    classifier_pred = classifier_model(classifier_input)
    classifier_loss = loss_function_classifier(classifier_pred, classifier_target, classWeights_classifier)
    linkMetrics.classifier_metrics.Fill(classifier_loss, classifier_pred, classifier_target)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Training/testing loops
</div>

In [None]:
# Optimiser: one Adam instance updates the parameters of both networks
optimiser = torch.optim.Adam(itertools.chain(branch_model.parameters(), classifier_model.parameters()), lr=LEARNING_RATE)

# Put here some metrics (one entry appended per epoch)
epochs_metrics = []
training_link_metrics = []
testing_link_metrics = []

for epoch in range(N_EPOCHS):

    # Begin training mode
    branch_model.train()
    classifier_model.train()

    for dataset_batch in loader_train:

        # Skip incomplete batches
        if (dataset_batch["truth_link"].shape[0] != BATCH_SIZE) :
            continue

        # Run training loop
        total_loss = RunTrainingLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier)

        # Update model parameters
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()

    # Evaluate the epoch on both samples without tracking gradients
    with torch.no_grad():

        # Begin testing mode
        branch_model.eval()
        classifier_model.eval()

        # Initialise metrics
        linkMetrics_train = TrainingMetrics.LinkMetrics(nLinks)
        linkMetrics_test = TrainingMetrics.LinkMetrics(nLinks)

        for dataset_batch_train in loader_train:

            # Skip incomplete batches
            if (dataset_batch_train["truth_link"].shape[0] != BATCH_SIZE) :
                continue

            # Run validation loop
            RunValidationLoop(nLinks, dataset_batch_train, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics_train)

        for dataset_batch_test in loader_test:

            # Skip incomplete batches
            if (dataset_batch_test["truth_link"].shape[0] != BATCH_SIZE) :
                continue

            # Run validation loop
            RunValidationLoop(nLinks, dataset_batch_test, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics_test)

    ##########################
    # Calc metrics for epoch
    ##########################
    # Find threshold
    optimal_threshold_train, maximum_accuracy_train = TrainingMetrics.calculate_accuracy(linkMetrics_train)
    optimal_threshold_test, maximum_accuracy_test = TrainingMetrics.calculate_accuracy(linkMetrics_test)

    # Calculate metrics
    linkMetrics_train.Evaluate(optimal_threshold_train)
    linkMetrics_test.Evaluate(optimal_threshold_test)

    # Add to our lists
    epochs_metrics.append(epoch)
    training_link_metrics.append(linkMetrics_train)
    testing_link_metrics.append(linkMetrics_test)

    # Do some prints
    print('----------------------------------------')
    print('Epoch:', epoch)
    print('----------------------------------------')
    # The same summary is printed for the training and then the test sample.
    # Each label is paired with the attribute of the same name (negative_as_negative
    # with neg_as_neg_frac, negative_as_positive with neg_as_pos_frac).
    for _loss_prefix, _suffix, _link_metrics, _threshold, _accuracy in (
        ('training', 'train', linkMetrics_train, optimal_threshold_train, maximum_accuracy_train),
        ('testing', 'test', linkMetrics_test, optimal_threshold_test, maximum_accuracy_test),
    ):
        _summary = _link_metrics.classifier_metrics
        print(_loss_prefix + '_classification_loss:', round(_summary.av_loss, 2))
        print('----')
        print('optimal_threshold_' + _suffix + ':', _threshold)
        print('accuracy_' + _suffix + ':', str(round(_accuracy.item(), 2)) +'%')
        print('positive_as_positive_fraction_' + _suffix + ':', str(round(_summary.pos_as_pos_frac * 100.0, 2)) + '%')
        print('positive_as_negative_fraction_' + _suffix + ':', str(round(_summary.pos_as_neg_frac * 100.0, 2)) + '%')
        print('negative_as_negative_fraction_' + _suffix + ':', str(round(_summary.neg_as_neg_frac * 100.0, 2)) + '%')
        print('negative_as_positive_fraction_' + _suffix + ':', str(round(_summary.neg_as_pos_frac * 100.0, 2)) + '%')
        print('----')

    # Score distributions for this epoch: branch scores (first index fixed to 0,
    # second index running over 0, 1, 2), then the classifier score
    for i in [0, 1, 2] :
        TrainingMetrics.plot_scores_branch(linkMetrics_train, linkMetrics_test, 0, i)

    TrainingMetrics.plot_scores_classifier(linkMetrics_train, linkMetrics_test)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Plot metrics associated with training 
</div>

In [None]:
# Loss evolution of the branch model, once per edge index
for _edge in (0, 1):
    TrainingMetrics.plot_branch_loss_evolution(epochs_metrics, training_link_metrics, testing_link_metrics, _edge, 'Loss - branch_model_' + str(_edge))

# Loss evolution of the classifier
TrainingMetrics.plot_classifier_loss_evolution(epochs_metrics, training_link_metrics, testing_link_metrics, 'Loss - classifier')

# Edge rate, first with the flag set to True and then to False
for _flag in (True, False):
    TrainingMetrics.plot_edge_rate(epochs_metrics, training_link_metrics, testing_link_metrics, _flag)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Show the ROC curve and confusion matrices; for the latter you can choose the threshold cut used
</div>

In [None]:
with torch.no_grad():
    # Begin testing mode
    branch_model.eval()
    classifier_model.eval()

    # Get predictions: branch model on each of the two edges, then the classifier
    # on the two branch outputs placed side by side
    pred_0_test = branch_model(torch.tensor(branch_model_input_test[0], dtype=torch.float))
    pred_1_test = branch_model(torch.tensor(branch_model_input_test[1], dtype=torch.float))
    classifier_pred_test = classifier_model(torch.cat((pred_0_test, pred_1_test), dim=1)).reshape(-1)

    # Split the classifier scores by truth label
    scores_final_test = np.array(classifier_pred_test.tolist())
    neg_scores_final_test = scores_final_test[isTruePrimaryLink_test == 0].reshape(-1)
    pos_scores_final_test = scores_final_test[isTruePrimaryLink_test == 1].reshape(-1)

    # ROC curve, and confusion matrix at a threshold of 0.5
    TrainingMetrics.plot_roc_curve(torch.tensor(pos_scores_final_test), torch.tensor(neg_scores_final_test))
    TrainingMetrics.draw_confusion_with_threshold(classifier_pred_test, isTruePrimaryLink_test, 0.5)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Save the model
</div>

In [None]:
# Convert the branch network to TorchScript and write it to its output path
sm_branch = torch.jit.script(branch_model)
torch.jit.save(sm_branch, branchModelPath)

# Same for the classifier network
sm_classifier = torch.jit.script(classifier_model)
torch.jit.save(sm_classifier, classifierModelPath)