## Train Later Tier Networks

written by Isobel Mawby (i.mawby1@lancaster.ac.uk)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Imports
</div>

In [None]:
import sys
import os
sys.path.insert(0, os.getcwd()[0:len(os.getcwd()) - 10])
sys.path.insert(1, os.getcwd()[0:len(os.getcwd()) - 10] + '/Metrics')

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import itertools

import Models
import Datasets
import TrainingMetrics

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Train later tier track-track model (isTrackMode == True) or later tier track-shower model (isTrackMode == False)?
</div>

In [None]:
# Mode switch for the whole notebook:
#   True  -> train the later tier track-track networks
#   False -> train the later tier track-shower networks
isTrackMode = False

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Put the path to the later tier link training file for the chosen mode (created by WriteLaterTierFile.ipynb with the same isTrackMode value) and set the output model file names
</div>

In [None]:
# Resolve the input file and the two model output paths for the selected mode.
# sys.path[0] is the directory that the imports cell inserted at the front of sys.path.
_base_dir = sys.path[0]

_train_suffix, _branch_suffix, _classifier_suffix = (
    ('/files/hierarchy_TRAIN_later_tier_track.npz',
     '/models/track_track_branch_model',
     '/models/track_track_classifier_model')
    if isTrackMode else
    ('/files/hierarchy_TRAIN_later_tier_shower.npz',
     '/models/track_shower_branch_model',
     '/models/track_shower_classifier_model')
)

trainFileName = _base_dir + _train_suffix
branchModelPath = _base_dir + _branch_suffix
classifierModelPath = _base_dir + _classifier_suffix

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set hyperparameters
</div>

In [None]:
# Number of passes over the training set made by the training loop
N_EPOCHS = 5
# Batch size handed to both data loaders; the training loop skips any batch that is not exactly this size
BATCH_SIZE = 64
# Learning rate passed to the Adam optimiser
LEARNING_RATE = 1e-3
# Passed to the branch model as its dropoutRate
DROPOUT_RATE = 0.5

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Get data from file
</div>

In [None]:
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Pull every array out inside a context manager so the handle is released as
# soon as the arrays are in memory (`data` must not be indexed after this cell).
with np.load(trainFileName) as data:

    # Variables
    variables_train = data['variables_train']
    variables_test = data['variables_test']

    # Truth
    y_train = data['y_train']
    y_test = data['y_test']
    trueParentChildLink_train = data['trueParentChildLink_train']
    trueParentChildLink_test = data['trueParentChildLink_test']
    trueChildVisibleGeneration_train = data['trueChildVisibleGeneration_train']
    trueChildVisibleGeneration_test = data['trueChildVisibleGeneration_test']

    # Quantities used only to build the training cut mask
    trainingCutSep_train = data['trainingCutSep_train']
    trainingCutSep_test = data['trainingCutSep_test']
    trainingCutDoesConnect_train = data['trainingCutDoesConnect_train']
    trainingCutDoesConnect_test = data['trainingCutDoesConnect_test']
    trainingCutL_train = data['trainingCutL_train']
    trainingCutL_test = data['trainingCutL_test']
    trainingCutT_train = data['trainingCutT_train']
    trainingCutT_test = data['trainingCutT_test']

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set multiplicity variables
</div>

In [None]:
# Number of columns in the variable array (second axis of variables_train)
nVariables = variables_train.shape[1]
# Number of columns in the truth array (second axis of y_train); one column per link/edge
nLinks = y_train.shape[1]

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Check shapes
</div>

In [None]:
# Report the shape of every array read from the training file, in a fixed order.
_shape_report = (
    ('variables_train.shape:', variables_train),
    ('variables_test.shape:', variables_test),
    ('y_train.shape:', y_train),
    ('y_test.shape:', y_test),
    ('trueParentChildLink_train.shape:', trueParentChildLink_train),
    ('trueParentChildLink_test.shape:', trueParentChildLink_test),
    ('trueChildVisibleGeneration_train.shape:', trueChildVisibleGeneration_train),
    ('trueChildVisibleGeneration_test.shape:', trueChildVisibleGeneration_test),
    ('trainingCutSep_train.shape:', trainingCutSep_train),
    ('trainingCutSep_test.shape:', trainingCutSep_test),
    ('trainingCutDoesConnect_train.shape:', trainingCutDoesConnect_train),
    ('trainingCutDoesConnect_test.shape:', trainingCutDoesConnect_test),
    ('trainingCutL_train.shape:', trainingCutL_train),
    ('trainingCutL_test.shape:', trainingCutL_test),
    ('trainingCutT_train.shape:', trainingCutT_train),
    ('trainingCutT_test.shape:', trainingCutT_test),
)
for _label, _array in _shape_report:
    print(_label, _array.shape)

# Entry counts are the length of the first axis
print('')
print('ntrain:', variables_train.shape[0])
print('ntest:', variables_test.shape[0])

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Apply training cut mask
</div>

In [None]:
# training cut threshold
MAX_TRAINING_CUT_SEP = 20.0
MIN_TRAINING_CUT_L = -100.0
MAX_TRAINING_CUT_L = 100.0
MAX_TRAINING_CUT_T = 40.0

# An entry is kept if ANY of the following holds:
#   - its separation is below MAX_TRAINING_CUT_SEP
#   - its doesConnect flag equals 1
#   - its L lies strictly inside (MIN_TRAINING_CUT_L, MAX_TRAINING_CUT_L) AND its T is below MAX_TRAINING_CUT_T
#
# The resulting mask has one boolean per entry, so indexing an array with it
# selects whole rows. That is all that is needed for the 2D arrays as well:
# there is no need to tile the mask across the columns and reshape afterwards.

######################
# training set first
######################
# Make mask
passTrainingCutSep_train = trainingCutSep_train < MAX_TRAINING_CUT_SEP
passTrainingCutDoesConnect_train = trainingCutDoesConnect_train == 1
passTrainingCutL_train = np.logical_and(trainingCutL_train > MIN_TRAINING_CUT_L, trainingCutL_train < MAX_TRAINING_CUT_L)
passTrainingCutT_train = trainingCutT_train < MAX_TRAINING_CUT_T
passTrainingCuts_train = np.logical_or(passTrainingCutSep_train, np.logical_or(passTrainingCutDoesConnect_train, np.logical_and(passTrainingCutL_train, passTrainingCutT_train)))

# Mask the 1D variables... shape=(nEntries, )
trueChildVisibleGeneration_train = trueChildVisibleGeneration_train[passTrainingCuts_train]
trueParentChildLink_train = trueParentChildLink_train[passTrainingCuts_train]

# Mask the truth... shape=(nEntries, nLinks)
y_train = y_train[passTrainingCuts_train]

# Mask the variable... shape=(nEntries, nVariables)
variables_train = variables_train[passTrainingCuts_train]

######################
# now test set
######################
# Make mask
passTrainingCutSep_test = trainingCutSep_test < MAX_TRAINING_CUT_SEP
passTrainingCutDoesConnect_test = trainingCutDoesConnect_test == 1
passTrainingCutL_test = np.logical_and(trainingCutL_test > MIN_TRAINING_CUT_L, trainingCutL_test < MAX_TRAINING_CUT_L)
passTrainingCutT_test = trainingCutT_test < MAX_TRAINING_CUT_T
passTrainingCuts_test = np.logical_or(passTrainingCutSep_test, np.logical_or(passTrainingCutDoesConnect_test, np.logical_and(passTrainingCutL_test, passTrainingCutT_test)))

# Mask the 1D variables... shape=(nEntries, )
trueChildVisibleGeneration_test = trueChildVisibleGeneration_test[passTrainingCuts_test]
trueParentChildLink_test = trueParentChildLink_test[passTrainingCuts_test]

# Mask the truth... shape=(nEntries, nLinks)
y_test = y_test[passTrainingCuts_test]

# Mask the variable... shape=(nEntries, nVariables)
variables_test = variables_test[passTrainingCuts_test]

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Check shapes after training cut application
</div>

In [None]:
# Report the shapes of the arrays that were filtered by the training cuts.
_post_cut_report = (
    ('variables_train.shape:', variables_train),
    ('variables_test.shape:', variables_test),
    ('y_train.shape:', y_train),
    ('y_test.shape:', y_test),
    ('trueParentChildLink_train.shape:', trueParentChildLink_train),
    ('trueParentChildLink_test.shape:', trueParentChildLink_test),
    ('trueChildVisibleGeneration_train.shape:', trueChildVisibleGeneration_train),
    ('trueChildVisibleGeneration_test.shape:', trueChildVisibleGeneration_test),
)
for _label, _array in _post_cut_report:
    print(_label, _array.shape)

# Entry counts are the length of the first axis
print('')
print('ntrain:', variables_train.shape[0])
print('ntest:', variables_test.shape[0])

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Define class weights for the branch model, want to weight signal/background/wrong orientation classes so they're equal <b>and</b> equalise the secondary and 'higher' tier contributions in each class
</div>

In [None]:
# Truth labels used in y_train: 1 = true link, 0 = background, 2 = wrong orientation.

# Class-level weights: scale every class up to the size of the most populated one
nTrue = np.count_nonzero(y_train == 1)
nBackground = np.count_nonzero(y_train == 0)
nWrongOrientation = np.count_nonzero(y_train == 2)
maxLinks = max(nTrue, nBackground, nWrongOrientation)

true_branch_weight = float(maxLinks)/float(nTrue)
background_branch_weight = float(maxLinks)/float(nBackground)
wrong_orientation_branch_weight = float(maxLinks)/float(nWrongOrientation)


def _count_branch_links(label, generation_mask) :
    """Count links with truth `label`, summed over all link columns, for entries selected by `generation_mask`."""
    return sum(np.count_nonzero(np.logical_and(y_train[:, link] == label, generation_mask))
               for link in range(nLinks))


# Child tier masks: visible generation 3 is the 'secondary' tier, anything above is 'higher'
_secondary_child = trueChildVisibleGeneration_train == 3
_higher_child = trueChildVisibleGeneration_train > 3

# Link counts per (child tier, class)
n_secondary_child_true_branch = _count_branch_links(1, _secondary_child)
n_secondary_child_background_branch = _count_branch_links(0, _secondary_child)
n_secondary_child_wrong_orientation_branch = _count_branch_links(2, _secondary_child)
n_higher_child_true_branch = _count_branch_links(1, _higher_child)
n_higher_child_background_branch = _count_branch_links(0, _higher_child)
n_higher_child_wrong_orientation_branch = _count_branch_links(2, _higher_child)

# Within each class, give the secondary and higher tiers half of the class total each
secondary_child_true_branch_weight = (nTrue * 0.5) / n_secondary_child_true_branch
secondary_child_background_branch_weight = (nBackground * 0.5) / n_secondary_child_background_branch
secondary_child_wrong_orientation_branch_weight = (nWrongOrientation * 0.5) / n_secondary_child_wrong_orientation_branch
higher_child_true_branch_weight = (nTrue * 0.5) / n_higher_child_true_branch
higher_child_background_branch_weight = (nBackground * 0.5) / n_higher_child_background_branch
higher_child_wrong_orientation_branch_weight = (nWrongOrientation * 0.5) / n_higher_child_wrong_orientation_branch

# Final weight = tier weight * class weight
classWeights_branch = {}
classWeights_branch['secondary_child_true_branch_weight'] = secondary_child_true_branch_weight * true_branch_weight
classWeights_branch['secondary_child_background_branch_weight'] = secondary_child_background_branch_weight * background_branch_weight
classWeights_branch['secondary_child_wrong_orientation_branch_weight'] = secondary_child_wrong_orientation_branch_weight * wrong_orientation_branch_weight
classWeights_branch['higher_child_true_branch_weight'] = higher_child_true_branch_weight * true_branch_weight
classWeights_branch['higher_child_background_branch_weight'] = higher_child_background_branch_weight * background_branch_weight
classWeights_branch['higher_child_wrong_orientation_branch_weight'] = higher_child_wrong_orientation_branch_weight * wrong_orientation_branch_weight

print('classWeights_branch:', classWeights_branch)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Check that we end up with the same number of signal, background and wrong orientation links <b>AND</b> the same number of secondary child and higher child links
</div>

In [None]:
def _weighted_link_total(contributions) :
    """
    Sum count * weight over all link columns.

    `contributions` is a sequence of (truth label, entry mask, classWeights_branch key).
    For every link column, each contribution is added in the order given.
    """
    total = 0
    for link in range(nLinks) :
        for label, entry_mask, weight_key in contributions :
            n_links = np.count_nonzero(np.logical_and(y_train[:, link] == label, entry_mask))
            total += (n_links * classWeights_branch[weight_key])
    return total


# Child tier masks: visible generation 3 is 'secondary', anything above is 'higher'
_secondary_child = trueChildVisibleGeneration_train == 3
_higher_child = trueChildVisibleGeneration_train > 3

# Weighted totals per class (these three should come out equal)
this_signal = _weighted_link_total([
    (1, _secondary_child, 'secondary_child_true_branch_weight'),
    (1, _higher_child, 'higher_child_true_branch_weight'),
])
this_background = _weighted_link_total([
    (0, _secondary_child, 'secondary_child_background_branch_weight'),
    (0, _higher_child, 'higher_child_background_branch_weight'),
])
this_wrong_orientation = _weighted_link_total([
    (2, _secondary_child, 'secondary_child_wrong_orientation_branch_weight'),
    (2, _higher_child, 'higher_child_wrong_orientation_branch_weight'),
])

# Weighted totals per child tier (these two should come out equal)
this_secondary = _weighted_link_total([
    (1, _secondary_child, 'secondary_child_true_branch_weight'),
    (0, _secondary_child, 'secondary_child_background_branch_weight'),
    (2, _secondary_child, 'secondary_child_wrong_orientation_branch_weight'),
])
this_higher = _weighted_link_total([
    (1, _higher_child, 'higher_child_true_branch_weight'),
    (0, _higher_child, 'higher_child_background_branch_weight'),
    (2, _higher_child, 'higher_child_wrong_orientation_branch_weight'),
])

print('this_signal:', this_signal)
print('this_background:', this_background)
print('this_wrong_orientation:', this_wrong_orientation)
print('')
print('this_secondary:', this_secondary)
print('this_higher:', this_higher)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Define class weights for the classifier model, want to weight the true/false link classes so they're equal <b>and</b> equalise the secondary and 'higher' tier contributions in each class
</div>

In [None]:
# Entry masks reused throughout this cell.
# The comparisons against True/False are elementwise numpy comparisons, so they must stay as '=='.
_is_true_link = trueParentChildLink_train == True
_is_false_link = trueParentChildLink_train == False
_is_secondary_child = trueChildVisibleGeneration_train == 3
_is_higher_child = trueChildVisibleGeneration_train > 3

_secondary_true = np.logical_and(_is_true_link, _is_secondary_child)
_secondary_false = np.logical_and(_is_false_link, _is_secondary_child)
_higher_true = np.logical_and(_is_true_link, _is_higher_child)
_higher_false = np.logical_and(_is_false_link, _is_higher_child)

# Class-level weights: scale the smaller of the true/false classes up to the larger one
n_true_hierarchy_train = np.count_nonzero(_is_true_link)
n_false_hierarchy_train = np.count_nonzero(_is_false_link)
maxCounts_train = max(n_true_hierarchy_train, n_false_hierarchy_train)

true_classifier_weight = float(maxCounts_train)/float(n_true_hierarchy_train)
false_classifier_weight = float(maxCounts_train)/float(n_false_hierarchy_train)

# Entry counts per (child tier, class)
n_secondary_child_true_links_train = np.count_nonzero(_secondary_true)
n_secondary_child_false_links_train = np.count_nonzero(_secondary_false)
n_higher_child_true_links_train = np.count_nonzero(_higher_true)
n_higher_child_false_links_train = np.count_nonzero(_higher_false)

# Within each class, give the secondary and higher tiers half of the class total each
secondary_child_true_links_train_weight = (n_true_hierarchy_train * 0.5) / n_secondary_child_true_links_train
secondary_child_false_links_train_weight = (n_false_hierarchy_train * 0.5) / n_secondary_child_false_links_train
higher_child_true_links_train_weight = (n_true_hierarchy_train * 0.5) / n_higher_child_true_links_train
higher_child_false_links_train_weight = (n_false_hierarchy_train * 0.5) / n_higher_child_false_links_train

# Final weight = tier weight * class weight
classWeights_classifier = {}
classWeights_classifier['secondary_child_true_links_train_weight'] = secondary_child_true_links_train_weight * true_classifier_weight
classWeights_classifier['secondary_child_false_links_train_weight'] = secondary_child_false_links_train_weight * false_classifier_weight
classWeights_classifier['higher_child_true_links_train_weight'] = higher_child_true_links_train_weight * true_classifier_weight
classWeights_classifier['higher_child_false_links_train_weight'] = higher_child_false_links_train_weight * false_classifier_weight

print('classWeights_classifier:', classWeights_classifier)

# Per-entry weight array; entries matching none of the four masks keep a weight of 1
classifier_weight = np.ones(trueParentChildLink_train.shape)
for _entry_mask, _weight_key in (
        (_secondary_true, 'secondary_child_true_links_train_weight'),
        (_secondary_false, 'secondary_child_false_links_train_weight'),
        (_higher_true, 'higher_child_true_links_train_weight'),
        (_higher_false, 'higher_child_false_links_train_weight')) :
    classifier_weight[_entry_mask] = classWeights_classifier[_weight_key]

# Cross-check: summed weight of the secondary tier, then of the higher tier
print(np.sum(classifier_weight[np.logical_or(_secondary_true, _secondary_false)]))

print(np.sum(classifier_weight[np.logical_or(_higher_true, _higher_false)]))

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Build the training input for each edge. This is a concatenation of the variable tensor of this edge and those of all other edges, such that the variables for the edge in question are first.
</div>

In [None]:
# Build the branch model inputs for the training and test sets.
# The results are indexed per edge further down (branch_model_input_*[i]).
# NOTE(review): the exact layout of each per-edge input is decided inside Models.PrepareBranchModelInput - confirm there.
branch_model_input_train = Models.PrepareBranchModelInput(nLinks, Models.later_tier_n_orientation_indep_vars, Models.later_tier_n_orientation_dep_vars, variables_train)
branch_model_input_test = Models.PrepareBranchModelInput(nLinks, Models.later_tier_n_orientation_indep_vars, Models.later_tier_n_orientation_dep_vars, variables_test)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Prepare Dataset objects
</div>

In [None]:
# Track-track mode uses four edges per entry, track-shower mode uses two.
if isTrackMode :
    _dataset_class = Datasets.FourEdgeDataset
    _edge_indices = range(4)
else :
    _dataset_class = Datasets.TwoEdgeDataset
    _edge_indices = range(2)

# Positional argument order: per-edge inputs, per-edge truth columns, true link flag, true child generation
dataset_train = _dataset_class(
    *(branch_model_input_train[edge] for edge in _edge_indices),
    *(y_train[:, edge] for edge in _edge_indices),
    trueParentChildLink_train,
    trueChildVisibleGeneration_train,
)
dataset_test = _dataset_class(
    *(branch_model_input_test[edge] for edge in _edge_indices),
    *(y_test[:, edge] for edge in _edge_indices),
    trueParentChildLink_test,
    trueChildVisibleGeneration_test,
)

# Both loaders shuffle and use the same batch size
loader_train = Datasets.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_test = Datasets.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Define branch and classifier models
</div>

In [None]:
# A single branch model instance is shared by every edge: the loops below call it once per edge.
branch_model = Models.OrientationModel(nVariables, dropoutRate=DROPOUT_RATE)
# The classifier consumes the concatenated branch model outputs of all nLinks edges.
classifier_model = Models.ClassifierModel(nLinks)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Define loss functions for training to implement custom weighting
</div>

In [None]:
def loss_function_branch(pred, target, true_gen, weight_dict) :
    """
    Weighted cross-entropy loss for one edge of the branch model.

    Each entry's loss is scaled by a weight chosen from its truth class
    (1 = true, 0 = background, 2 = wrong orientation) and its child tier
    (true_gen == 3 is 'secondary', true_gen > 3 is 'higher'). Entries with any
    other true_gen value keep a weight of 1.

    Args:
        pred: class scores, passed unchanged to torch.nn.CrossEntropyLoss.
        target: per-entry class labels, same length as true_gen.
        true_gen: per-entry true visible generation of the child.
        weight_dict: the six '<tier>_child_<class>_branch_weight' entries.

    Returns:
        Scalar tensor: the sum of the weighted per-entry losses divided by the number of entries.
    """
    # (class label, secondary tier key, higher tier key)
    weight_keys = (
        (1, 'secondary_child_true_branch_weight', 'higher_child_true_branch_weight'),
        (0, 'secondary_child_background_branch_weight', 'higher_child_background_branch_weight'),
        (2, 'secondary_child_wrong_orientation_branch_weight', 'higher_child_wrong_orientation_branch_weight'),
    )

    is_secondary = true_gen == 3
    is_higher = true_gen > 3

    # Create the weights on the same device as the predictions, otherwise the
    # multiplication below fails as soon as the batch lives on a GPU
    weights = torch.ones(true_gen.shape, device=pred.device)
    for label, secondary_key, higher_key in weight_keys :
        has_label = target == label
        weights[torch.logical_and(has_label, is_secondary)] = weight_dict[secondary_key]
        weights[torch.logical_and(has_label, is_higher)] = weight_dict[higher_key]

    # reduction='none' keeps one loss per entry so that the weights can be applied
    loss_func = torch.nn.CrossEntropyLoss(reduction='none')
    loss = loss_func(pred, target) * weights

    # Normalise by the number of entries, not by the sum of the weights
    return torch.sum(loss) / loss.shape[0]

def loss_function_classifier(pred, target, true_gen, weight_dict) :
    """
    Weighted binary cross-entropy loss for the classifier model.

    Each entry is weighted by its truth (1 = true link, 0 = false link) and its
    child tier (true_gen == 3 is 'secondary', true_gen > 3 is 'higher'). Entries
    with any other true_gen value keep a weight of 1.

    Args:
        pred: classifier scores; flattened to 1D before use.
        target: true link flags; flattened to 1D before use.
        true_gen: per-entry true visible generation of the child.
        weight_dict: the four '<tier>_child_<class>_links_train_weight' entries.

    Returns:
        Scalar tensor returned by torch.nn.BCELoss with the per-entry weights.
    """
    # Flatten so that pred, target and the per-entry weights all line up
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    is_true_link = target == 1
    is_false_link = target == 0
    is_secondary = true_gen == 3
    is_higher = true_gen > 3

    # Create the weights on the same device as the predictions, otherwise
    # BCELoss fails as soon as the batch lives on a GPU
    weights = torch.ones(true_gen.shape, device=pred.device)
    weights[torch.logical_and(is_true_link, is_secondary)] = weight_dict['secondary_child_true_links_train_weight']
    weights[torch.logical_and(is_false_link, is_secondary)] = weight_dict['secondary_child_false_links_train_weight']
    weights[torch.logical_and(is_true_link, is_higher)] = weight_dict['higher_child_true_links_train_weight']
    weights[torch.logical_and(is_false_link, is_higher)] = weight_dict['higher_child_false_links_train_weight']

    # Use BCE loss
    loss_func = torch.nn.BCELoss(weight=weights)

    return loss_func(pred, target)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Training/validation loop functions.
</div>

In [None]:
def RunTrainingLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier) : 
    """
    Compute the combined training loss for ONE batch.

    For each edge i, dataset_batch["edge<i>"][0] is fed to the branch model and
    scored against dataset_batch["edge<i>"][1]. The branch predictions of all
    edges are then concatenated along axis 1 and fed to the classifier, which is
    scored against dataset_batch["truth_link"].

    Args:
        nLinks: number of edges to read from the batch.
        dataset_batch: mapping with keys "edge0".."edge<nLinks-1>", "truth_gen" and "truth_link".
        branch_model: model applied to every edge.
        classifier_model: model applied to the concatenated branch predictions.
        classWeights_branch: weights forwarded to loss_function_branch.
        classWeights_classifier: weights forwarded to loss_function_classifier.

    Returns:
        The sum of the nLinks branch losses and the classifier loss.
    """
    total_loss = 0
    edge_preds = []

    for i in range(nLinks) :
        edge_name = "edge" + str(i)
        pred = branch_model(dataset_batch[edge_name][0])
        total_loss += loss_function_branch(pred, dataset_batch[edge_name][1], dataset_batch["truth_gen"], classWeights_branch) 
        edge_preds.append(pred)

    # Concatenate once, rather than growing from a torch.empty(0,) seed, which
    # leans on torch.cat's legacy empty-tensor special case and re-copies every iteration
    classifier_input = torch.cat(edge_preds, dim=1)

    classifier_target = dataset_batch["truth_link"].reshape(-1,1)
    classifier_pred = classifier_model(classifier_input)  
    total_loss += loss_function_classifier(classifier_pred, classifier_target, dataset_batch["truth_gen"], classWeights_classifier)

    return total_loss

def RunValidationLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics) : 
    """
    Evaluate ONE batch and record the results in linkMetrics.

    For each edge i, the branch loss and prediction are filled into
    linkMetrics.edge_metrics[i]. The classifier loss and prediction, computed
    from the concatenated branch predictions, are filled into
    linkMetrics.classifier_metrics. Nothing is returned.

    Args:
        nLinks: number of edges to read from the batch.
        dataset_batch: mapping with keys "edge0".."edge<nLinks-1>", "truth_gen" and "truth_link".
        branch_model: model applied to every edge.
        classifier_model: model applied to the concatenated branch predictions.
        classWeights_branch: weights forwarded to loss_function_branch.
        classWeights_classifier: weights forwarded to loss_function_classifier.
        linkMetrics: metrics object exposing edge_metrics[i].Fill and classifier_metrics.Fill.
    """
    edge_preds = []

    for i in range(nLinks) :
        edge_name = "edge" + str(i)
        pred = branch_model(dataset_batch[edge_name][0])
        branch_loss = loss_function_branch(pred, dataset_batch[edge_name][1], dataset_batch["truth_gen"], classWeights_branch) 
        linkMetrics.edge_metrics[i].Fill(branch_loss, pred, dataset_batch[edge_name][1])      
        edge_preds.append(pred)

    # Concatenate once, rather than growing from a torch.empty(0,) seed, which
    # leans on torch.cat's legacy empty-tensor special case and re-copies every iteration
    classifier_input = torch.cat(edge_preds, dim=1)

    classifier_target = dataset_batch["truth_link"].reshape(-1,1)
    classifier_pred = classifier_model(classifier_input)  
    classifier_loss = loss_function_classifier(classifier_pred, classifier_target, dataset_batch["truth_gen"], classWeights_classifier)
    linkMetrics.classifier_metrics.Fill(classifier_loss, classifier_pred, classifier_target) 

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Training/testing loops
</div>

In [None]:
# Optimiser
# A single Adam instance updates the branch and classifier parameters together
optimiser = torch.optim.Adam(itertools.chain(branch_model.parameters(), classifier_model.parameters()), lr=LEARNING_RATE)

# Put here some metrics
epochs_metrics = []
training_link_metrics = []
testing_link_metrics = []

for epoch in range(N_EPOCHS):

    # Begin training mode
    branch_model.train()
    classifier_model.train()
    
    for dataset_batch in loader_train :  
        
        # Skip incomplete batches
        if (dataset_batch["truth_link"].shape[0] != BATCH_SIZE) :
            continue        
            
        # Run training loop
        total_loss = RunTrainingLoop(nLinks, dataset_batch, branch_model, classifier_model, classWeights_branch, classWeights_classifier)
        
        # Update model parameters
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()
        
    # Re-evaluate both datasets with the weights reached at the end of this epoch
    with torch.no_grad():
        
        # Begin testing mode
        branch_model.eval()
        classifier_model.eval()
        
        # Initialise metrics            
        linkMetrics_train = TrainingMetrics.LinkMetrics(nLinks)
        linkMetrics_test = TrainingMetrics.LinkMetrics(nLinks)        
                
        # Iterate in batches over the training dataset.                        
        for dataset_batch_train in loader_train :   

            # Skip incomplete batches
            if (dataset_batch_train["truth_link"].shape[0] != BATCH_SIZE) :
                continue        

            # Run validation loop
            RunValidationLoop(nLinks, dataset_batch_train, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics_train)
            
        # Iterate in batches over the testing dataset.  
        for dataset_batch_test in loader_test :   

            # Skip incomplete batches
            if (dataset_batch_test["truth_link"].shape[0] != BATCH_SIZE) :
                continue        

            # Run validation loop
            RunValidationLoop(nLinks, dataset_batch_test, branch_model, classifier_model, classWeights_branch, classWeights_classifier, linkMetrics_test)            
            
        epochs_metrics.append(epoch)   
    
    ##########################
    # Calc metrics for epoch 
    ##########################   
    # Find threshold
    optimal_threshold_train, maximum_accuracy_train = TrainingMetrics.calculate_accuracy(linkMetrics_train)
    optimal_threshold_test, maximum_accuracy_test = TrainingMetrics.calculate_accuracy(linkMetrics_test)

    # Calculate metrics
    linkMetrics_train.Evaluate(optimal_threshold_train)
    linkMetrics_test.Evaluate(optimal_threshold_test)
    
    # Add to our lists
    training_link_metrics.append(linkMetrics_train)
    testing_link_metrics.append(linkMetrics_test) 
    
    # Do some prints
    # Each printed label is paired with the attribute of the same name
    # (the negative_as_* labels were previously attached to the opposite attributes)
    print('----------------------------------------')
    print('Epoch:', epoch)
    print('----------------------------------------')
    print('training_classification_loss:', round(linkMetrics_train.classifier_metrics.av_loss, 2))
    print('----')
    print('optimal_threshold_train:', optimal_threshold_train)
    print('accuracy_train:', str(round(maximum_accuracy_train.item(), 2)) +'%')
    print('positive_as_positive_fraction_train:', str(round(linkMetrics_train.classifier_metrics.pos_as_pos_frac * 100.0, 2)) + '%')
    print('positive_as_negative_fraction_train:', str(round(linkMetrics_train.classifier_metrics.pos_as_neg_frac * 100.0, 2)) + '%')
    print('negative_as_negative_fraction_train:', str(round(linkMetrics_train.classifier_metrics.neg_as_neg_frac * 100.0, 2)) + '%')
    print('negative_as_positive_fraction_train:', str(round(linkMetrics_train.classifier_metrics.neg_as_pos_frac * 100.0, 2)) + '%')
    print('----')
    print('testing_classification_loss:', round(linkMetrics_test.classifier_metrics.av_loss, 2))
    print('----')
    print('optimal_threshold_test:', optimal_threshold_test)
    print('accuracy_test:', str(round(maximum_accuracy_test.item(), 2)) +'%')
    print('positive_as_positive_fraction_test:', str(round(linkMetrics_test.classifier_metrics.pos_as_pos_frac * 100.0, 2)) + '%')
    print('positive_as_negative_fraction_test:', str(round(linkMetrics_test.classifier_metrics.pos_as_neg_frac * 100.0, 2)) + '%')
    print('negative_as_negative_fraction_test:', str(round(linkMetrics_test.classifier_metrics.neg_as_neg_frac * 100.0, 2)) + '%')
    print('negative_as_positive_fraction_test:', str(round(linkMetrics_test.classifier_metrics.neg_as_pos_frac * 100.0, 2)) + '%')
    print('----')
    
    # Branch score plots: third argument fixed at 0, fourth argument runs over 0, 1, 2
    for i in [0, 1, 2] :
        TrainingMetrics.plot_scores_branch(linkMetrics_train, linkMetrics_test, 0, i)

    TrainingMetrics.plot_scores_classifier(linkMetrics_train, linkMetrics_test)    

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Plot metrics associated with training 
</div>

In [None]:
# Branch loss against epoch, drawn for index 0 and index 1 only.
# NOTE(review): the index is presumably the edge number; track-track mode builds four edges, so edges 2 and 3 would not be plotted - confirm whether that is intended.
TrainingMetrics.plot_branch_loss_evolution(epochs_metrics, training_link_metrics, testing_link_metrics, 0, 'Loss - branch_model_0')
TrainingMetrics.plot_branch_loss_evolution(epochs_metrics, training_link_metrics, testing_link_metrics, 1, 'Loss - branch_model_1')
# Classifier loss against epoch
TrainingMetrics.plot_classifier_loss_evolution(epochs_metrics, training_link_metrics, testing_link_metrics, 'Loss - classifier')
# Edge rate plots, once with the final flag True and once with it False
# NOTE(review): the meaning of the boolean is defined in TrainingMetrics.plot_edge_rate - confirm there.
TrainingMetrics.plot_edge_rate(epochs_metrics, training_link_metrics, testing_link_metrics, True)
TrainingMetrics.plot_edge_rate(epochs_metrics,  training_link_metrics, testing_link_metrics, False)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Show the ROC curve and confusion matrices; for the latter you can decide the threshold cut used
</div>

In [None]:
# Score the full (post-cut) test set in one pass with the trained models
with torch.no_grad():
    branch_model.eval()
    classifier_model.eval()
    
    # Branch model prediction for every edge
    edge_preds = []
    for i in range(nLinks) :
        pred = branch_model(torch.tensor(branch_model_input_test[i], dtype=torch.float))
        edge_preds.append(pred)

    # Concatenate once, rather than growing from a torch.empty(0,) seed, which
    # leans on torch.cat's legacy empty-tensor special case and re-copies every iteration
    classifier_input = torch.cat(edge_preds, dim=1)

    classifier_pred_test = classifier_model(classifier_input).reshape(-1)

    # Convert the scores to numpy once, then split them by the true link flag
    scores_final_test = np.array(classifier_pred_test.tolist())
    neg_scores_final_test = scores_final_test[trueParentChildLink_test == 0].reshape(-1)
    pos_scores_final_test = scores_final_test[trueParentChildLink_test == 1].reshape(-1)
    
    TrainingMetrics.plot_roc_curve(torch.tensor(pos_scores_final_test), torch.tensor(neg_scores_final_test))
    # The last argument is the threshold cut used for the confusion matrices
    TrainingMetrics.draw_confusion_with_threshold(classifier_pred_test, trueParentChildLink_test, 0.5)

<div class="alert alert-block alert-info" style="font-size: 18px;">
   Save the model
</div>

In [None]:
# Compile each trained model with TorchScript and write it to the path chosen for the current mode
sm_branch = torch.jit.script(branch_model)
sm_branch.save(branchModelPath)

sm_classifier = torch.jit.script(classifier_model)
sm_classifier.save(classifierModelPath)