In [None]:
###########################################################
# Imports
###########################################################
import sys
import os
sys.path.insert(0, os.getcwd()[0:len(os.getcwd()) - 11])
sys.path.insert(1, os.getcwd()[0:len(os.getcwd()) - 11] + '/Metrics')

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import itertools

import Models
import Datasets
import TrainingMetrics

In [None]:
###########################################################
# Define input file and model output paths
###########################################################

# sys.path[0] is the directory the imports cell inserted at the front of
# sys.path; every path below is resolved relative to it.

# Input: training sample
trainFileName = f"{sys.path[0]}/files/hierarchy_TRAIN_track.npz"

# Outputs: where the two trained models are written
branchModelPath = f"{sys.path[0]}/models/primary_track_branch_model"
classifierModelPath = f"{sys.path[0]}/models/primary_track_classifier_model"

In [None]:
###########################################################
# Hyperparameters
###########################################################

# Regularisation: dropout rate handed to the branch model
DROPOUT_RATE = 0.5

# Optimisation
LEARNING_RATE = 0.001   # step size given to the Adam optimiser
BATCH_SIZE = 64         # batch size for both loaders; incomplete batches are skipped
N_EPOCHS = 5            # number of passes through the training loop

In [None]:
###########################################################
# Get data from file
###########################################################

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Use it as a context manager so the handle is released once every array has
# been read; the arrays pulled out below stay valid after the archive closes.
with np.load(trainFileName) as data:
    # Variables
    variables_train = data['variables_train']
    variables_test = data['variables_test']
    # Training cut
    trainingCutDCA_train = data['trainingCutDCA_train']
    trainingCutDCA_test = data['trainingCutDCA_test']
    # Truth
    y_train = data['y_train']
    y_test = data['y_test']
    isTruePrimaryLink_train = data['isTruePrimaryLink_train']
    isTruePrimaryLink_test = data['isTruePrimaryLink_test']

In [None]:
###########################################################
# Check shapes
###########################################################

# Print the shape of every array read from the input file
for _label, _array in (
    ('variables_train.shape:', variables_train),
    ('variables_test.shape:', variables_test),
    ('y_train.shape:', y_train),
    ('y_test.shape:', y_test),
    ('trainingCutDCA_train.shape:', trainingCutDCA_train),
    ('trainingCutDCA_test.shape:', trainingCutDCA_test),
    ('isTruePrimaryLink_train.shape:', isTruePrimaryLink_train),
    ('isTruePrimaryLink_test.shape:', isTruePrimaryLink_test),
):
    print(_label, _array.shape)

# Dimensions reused by later cells: rows are entries, columns are variables/links
ntrain, nVariables = variables_train.shape[:2]
nLinks = y_train.shape[1]
ntest = variables_test.shape[0]

print('')
print('ntrain:', ntrain)
print('ntest:', ntest)

In [None]:
###########################################################
# Apply training cut topology mask
###########################################################

# Entries whose training-cut DCA is at or above this value are dropped.
# The value was chosen by hand-scanning :'(
MAX_TRAINING_CUT_DCA = 50.0

# A 1D boolean mask with one element per entry selects whole rows when used
# to index the first axis, so the same mask works for the 1D arrays and for
# the (nEntries, nLinks) / (nEntries, nVariables) arrays alike. This avoids
# building an nEntries x nVariables nested Python list and does not assume a
# particular number of links.

######################
# training set first
######################
# Make mask
passTrainingCutDCA_train = trainingCutDCA_train < MAX_TRAINING_CUT_DCA
passTrainingCuts_train = passTrainingCutDCA_train

# Mask the 1D variables... shape=(nEntries, )
trainingCutDCA_train = trainingCutDCA_train[passTrainingCuts_train]
isTruePrimaryLink_train = isTruePrimaryLink_train[passTrainingCuts_train]

# Mask the truth... shape=(nEntries, nLinks)
y_train = y_train[passTrainingCuts_train]

# Mask the variable... shape=(nEntries, nVariables)
variables_train = variables_train[passTrainingCuts_train]

# Every training array must still describe the same entries
assert len(variables_train) == len(y_train) == len(isTruePrimaryLink_train) == len(trainingCutDCA_train)

######################
# now test set
######################
# Make mask
passTrainingCutDCA_test = trainingCutDCA_test < MAX_TRAINING_CUT_DCA
passTrainingCuts_test = passTrainingCutDCA_test

# Mask the 1D variables... shape=(nEntries, )
trainingCutDCA_test = trainingCutDCA_test[passTrainingCuts_test]
isTruePrimaryLink_test = isTruePrimaryLink_test[passTrainingCuts_test]

# Mask the truth... shape=(nEntries, nLinks)
y_test = y_test[passTrainingCuts_test]

# Mask the variable... shape=(nEntries, nVariables)
variables_test = variables_test[passTrainingCuts_test]

# Every test array must still describe the same entries
assert len(variables_test) == len(y_test) == len(isTruePrimaryLink_test) == len(trainingCutDCA_test)

In [None]:
###########################################################
# Check shapes (after the training cut)
###########################################################

# Shapes of the arrays as they stand now, in a fixed print order
_shapes_after_cut = [
    ('variables_train.shape:', variables_train.shape),
    ('variables_test.shape:', variables_test.shape),
    ('y_train.shape:', y_train.shape),
    ('y_test.shape:', y_test.shape),
    ('trainingCutDCA_train.shape:', trainingCutDCA_train.shape),
    ('trainingCutDCA_test.shape:', trainingCutDCA_test.shape),
    ('isTruePrimaryLink_train.shape:', isTruePrimaryLink_train.shape),
    ('isTruePrimaryLink_test.shape:', isTruePrimaryLink_test.shape),
]
for _label, _shape in _shapes_after_cut:
    print(_label, _shape)

# Refresh the dimensions used downstream; the entry counts are re-read here
nVariables = variables_train.shape[1]
nLinks = y_train.shape[1]
ntrain = variables_train.shape[0]
ntest = variables_test.shape[0]

print('')
print('ntrain:', ntrain)
print('ntest:', ntest)

In [None]:
###########################################################
# Define branch weights
###########################################################

# Number of link labels of each kind in the training truth:
# 0 = background, 1 = true, 2 = wrong orientation
nBackground, nTrue, nWrongOrientation = (
    np.count_nonzero(y_train == label) for label in (0, 1, 2)
)

maxLinks = max(nTrue, nBackground, nWrongOrientation)

# Weight for label i sits at index i; the most populated label gets weight 1
# and the others are scaled up by how much rarer they are.
classWeights_branch = torch.tensor(
    [float(maxLinks) / float(nOfLabel) for nOfLabel in (nBackground, nTrue, nWrongOrientation)]
)

print('nTrue:', nTrue)
print('nBackground:', nBackground)
print('nWrongOrientation:', nWrongOrientation)
print('class weights:', classWeights_branch)

In [None]:
###########################################################
# Work out classifier weights
###########################################################

# Training entries flagged / not flagged as a true primary link
n_false_primary_train = np.count_nonzero(isTruePrimaryLink_train == False)
n_true_primary_train = np.count_nonzero(isTruePrimaryLink_train == True)

maxCounts_train = max(n_true_primary_train, n_false_primary_train)

# The larger class gets weight 1, the smaller one is scaled up to match it
classWeights_classifier = {}
classWeights_classifier['true_primary_train'] = maxCounts_train / n_true_primary_train
classWeights_classifier['false_primary_train'] = maxCounts_train / n_false_primary_train

# Alternative normalisation, currently unused:
# logitsWeights_neg = float(n_true_primary_train + n_false_primary_train) / (n_false_primary_train * 2)
# logitsWeights_pos = float(n_true_primary_train + n_false_primary_train) / (n_true_primary_train * 2)
# logitsWeights = torch.tensor([logitsWeights_neg, logitsWeights_pos])

print('n_true_primary_train:', n_true_primary_train)
print('n_false_primary_train:', n_false_primary_train)
print('classWeights_classifier:', classWeights_classifier)

In [None]:
###########################################################
# Build the two input orderings
###########################################################

# Column layout assumed for the variables arrays:
#   [ global vars | link vars ... | link vars ]
n_links = 2
n_global_vars = 1
n_link_vars = 8

# Column boundaries (all derived from the training array's width)
i_end_of_global_vars = n_global_vars
i_target_start = variables_train.shape[1] - n_link_vars
i_rest_start = n_global_vars
i_rest_end = n_global_vars + (n_link_vars * (n_links - 1))


def _reordered_links(variables):
    """Return a copy of `variables` with the last link's columns moved to sit
    directly after the global columns, followed by the remaining link columns."""
    global_part = variables[:, 0:i_end_of_global_vars]
    last_link_part = variables[:, i_target_start:]
    other_links_part = variables[:, i_rest_start:i_rest_end]
    return np.concatenate((global_part, last_link_part, other_links_part), axis=1)


# Training: input_0 is the array as stored, input_1 is the reordered copy
input_0_train_temp = variables_train
input_1_train_temp = _reordered_links(input_0_train_temp)

# One truth column per input ordering
y_0_train_temp = y_train[:,0]
y_1_train_temp = y_train[:,1]

isTruePrimaryLink_train_temp = isTruePrimaryLink_train

# Test: same construction
input_0_test_temp = variables_test
input_1_test_temp = _reordered_links(input_0_test_temp)

y_0_test_temp = y_test[:,0]
y_1_test_temp = y_test[:,1]

isTruePrimaryLink_test_temp = isTruePrimaryLink_test

In [None]:
# Wrap the arrays in datasets and batch them. Each sample carries both input
# orderings, the branch truth for each ordering, and the primary-link flag.

# Training set
dataset_train = Datasets.NuToTrackDataset(
    input_0_train_temp,
    input_1_train_temp,
    y_train[:,0],
    y_train[:,1],
    isTruePrimaryLink_train,
)
loader_train = Datasets.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

# Test set (also shuffled)
dataset_test = Datasets.NuToTrackDataset(
    input_0_test_temp,
    input_1_test_temp,
    y_test[:,0],
    y_test[:,1],
    isTruePrimaryLink_test,
)
loader_test = Datasets.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
###########################################################
# Define the models
###########################################################

# Classifier input: the branch model's output for every link ordering,
# concatenated. The factor 3 matches the three branch class weights.
classifier_model = Models.ClassifierModel(3 * nLinks)

# Branch model: applied to each input ordering in turn
branch_model = Models.OrientationModel(nVariables, dropoutRate=DROPOUT_RATE)

In [None]:
###########################################################
# Loss functions for training
###########################################################
    
# Class-weighted cross-entropy for the branch model
def loss_function_branch(pred, target, classWeights) :
    """Return the class-weighted cross-entropy between `pred` and `target`.

    pred         : per-class scores, one row per sample
    target       : integer class index per sample
    classWeights : 1D tensor holding one weight per class
    """
    criterion = torch.nn.CrossEntropyLoss(weight=classWeights)
    return criterion(pred, target)

# Per-sample weighted binary cross-entropy for the classifier
def loss_function_classifier(pred, target, classWeights) :
    """Return the weighted binary cross-entropy between `pred` and `target`.

    pred         : probabilities in [0, 1] (BCELoss expects probabilities, not logits)
    target       : 0/1 truth with the same shape as `pred`
    classWeights : dict with the keys 'false_primary_train' and
                   'true_primary_train', giving the weight applied to samples
                   whose target is below / above 0.5
    """
    # Work on pred's device so the masks, weights and loss inputs all agree
    target = target.to(device=pred.device)

    # Build the per-sample weights with pred's dtype and device rather than
    # the global defaults; a target of exactly 0.5 keeps weight 1.
    weight = torch.ones(target.shape, dtype=pred.dtype, device=pred.device)
    weight[target < 0.5] = classWeights['false_primary_train']
    weight[target > 0.5] = classWeights['true_primary_train']

    # BCELoss needs a floating-point target that matches pred
    target = target.to(dtype=pred.dtype)

    # Calculate loss
    loss_func = torch.nn.BCELoss(weight=weight)
    return loss_func(pred, target)


In [None]:
######################
# Training and testing
######################

# A single Adam optimiser updates the branch model and the classifier together
optimiser = torch.optim.Adam(
    itertools.chain(branch_model.parameters(), classifier_model.parameters()),
    lr=LEARNING_RATE,
)

# Learning-rate scheduler (currently disabled)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.1, patience=3, threshold=0.0001, threshold_mode='abs')

# Per-epoch history measured on the training set (nine independent lists)
(
    training_epoch,
    training_loss_0,
    training_loss_1,
    training_classification_loss,
    training_accuracy,
    training_positive_as_positive_rate,
    training_positive_as_negative_rate,
    training_negative_as_negative_rate,
    training_negative_as_positive_rate,
) = ([] for _ in range(9))

# Per-epoch history measured on the test set (nine independent lists)
(
    testing_epoch,
    testing_loss_0,
    testing_loss_1,
    testing_classification_loss,
    testing_accuracy,
    testing_positive_as_positive_rate,
    testing_positive_as_negative_rate,
    testing_negative_as_negative_rate,
    testing_negative_as_positive_rate,
) = ([] for _ in range(9))

# Truth label carried by each branch class
_BRANCH_LABELS = (('background', 0), ('true', 1), ('wrong_orientation', 2))


def _forward_with_losses(features_0, features_1, target_0, target_1, target):
    """Run one batch through both models and compute the three losses.

    The branch model is applied to each input ordering; its two outputs are
    concatenated column-wise and passed to the classifier.

    Returns a dict with:
      'pred'              : (pred_0, pred_1) branch outputs
      'target'            : (target_0, target_1) cast to long
      'classifier_pred'   : classifier output
      'classifier_target' : `target` cast to long, reshaped to a column
      'loss'              : (loss_0, loss_1) weighted branch losses
      'classifier_loss'   : weighted classifier loss
    """
    pred_0 = branch_model(features_0)
    pred_1 = branch_model(features_1)
    classifier_pred = classifier_model(torch.concatenate((pred_0, pred_1), axis=1))

    target_0 = target_0.to(torch.long)
    target_1 = target_1.to(torch.long)
    classifier_target = target.to(torch.long).reshape(-1, 1)

    return {
        'pred': (pred_0, pred_1),
        'target': (target_0, target_1),
        'classifier_pred': classifier_pred,
        'classifier_target': classifier_target,
        'loss': (loss_function_branch(pred_0, target_0, classWeights_branch),
                 loss_function_branch(pred_1, target_1, classWeights_branch)),
        'classifier_loss': loss_function_classifier(classifier_pred, classifier_target, classWeights_classifier),
    }


def _evaluate(loader):
    """Accumulate losses and scores over every complete batch of `loader`.

    Intended to be called with both models in eval mode inside torch.no_grad().

    Returns a dict with:
      'n_batches'           : number of complete batches processed
      'loss'                : [sum of loss_0, sum of loss_1] over those batches
      'classification_loss' : sum of the classifier loss over those batches
      'branch_scores'       : one dict per input ordering, mapping each branch
                              class name to a list of branch-output rows for
                              samples whose truth is that class
      'positive_scores'     : classifier scores of samples with truth 1
      'negative_scores'     : classifier scores of samples with truth 0
    """
    summary = {
        'n_batches': 0,
        'loss': [0, 0],
        'classification_loss': 0,
        'branch_scores': [{name: [] for name, _ in _BRANCH_LABELS} for _ in range(2)],
        'positive_scores': [],
        'negative_scores': [],
    }

    for features_0, features_1, target_0, target_1, target in loader:

        # Skip incomplete batches
        if target.shape[0] != BATCH_SIZE:
            continue

        out = _forward_with_losses(features_0, features_1, target_0, target_1, target)
        summary['n_batches'] += 1

        for link in range(2):
            summary['loss'][link] += out['loss'][link].item()
            pred_np = np.array(out['pred'][link].tolist())
            target_np = out['target'][link].numpy()
            # A 1D row mask keeps whole output rows, whatever their width
            for name, label in _BRANCH_LABELS:
                summary['branch_scores'][link][name].extend(pred_np[target_np == label])

        summary['classification_loss'] += out['classifier_loss'].item()
        scores_np = np.array(out['classifier_pred'].tolist())
        truth_np = out['classifier_target'].numpy()
        summary['positive_scores'].extend(scores_np[truth_np == 1].reshape(-1))
        summary['negative_scores'].extend(scores_np[truth_np == 0].reshape(-1))

    return summary


def _safe_fraction(count, total):
    """Return count / total as a float, or NaN when total is zero."""
    return float(count) / float(total) if total else float('nan')


def _edge_rates(positive_scores, negative_scores, threshold):
    """Classify scores against `threshold` and return the four rates
    (positive_as_positive, positive_as_negative,
     negative_as_positive, negative_as_negative).

    Scores exactly equal to the threshold are counted in neither category.
    """
    positives = np.array(positive_scores)
    negatives = np.array(negative_scores)
    pos_as_pos = np.count_nonzero(positives > threshold)
    pos_as_neg = np.count_nonzero(positives < threshold)
    neg_as_pos = np.count_nonzero(negatives > threshold)
    neg_as_neg = np.count_nonzero(negatives < threshold)
    return (_safe_fraction(pos_as_pos, pos_as_pos + pos_as_neg),
            _safe_fraction(pos_as_neg, pos_as_pos + pos_as_neg),
            _safe_fraction(neg_as_pos, neg_as_pos + neg_as_neg),
            _safe_fraction(neg_as_neg, neg_as_pos + neg_as_neg))


for epoch in range(N_EPOCHS):

    ######################
    # Training
    ######################
    branch_model.train()
    classifier_model.train()

    # Iterate in batches over the training dataset.
    for features_0, features_1, target_0, target_1, target in loader_train:

        # Skip incomplete batches
        if target.shape[0] != BATCH_SIZE:
            continue

        step = _forward_with_losses(features_0, features_1, target_0, target_1, target)
        total_loss = step['loss'][0] + step['loss'][1] + step['classifier_loss']

        # Update model parameters
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()

    ######################
    # Validation metrics
    ######################
    with torch.no_grad():

        # Begin testing mode
        branch_model.eval()
        classifier_model.eval()

        # Re-evaluate the training set with the updated weights
        metrics_train = _evaluate(loader_train)
        training_epoch.append(epoch)

        # Evaluate the test set
        metrics_test = _evaluate(loader_test)
        testing_epoch.append(epoch)

    positive_scores_train = metrics_train['positive_scores']
    negative_scores_train = metrics_train['negative_scores']
    positive_scores_test = metrics_test['positive_scores']
    negative_scores_test = metrics_test['negative_scores']

    ##########################
    # Calc metrics for epoch
    ##########################
    # Threshold and accuracy are computed separately for train and test
    optimal_threshold_train, maximum_accuracy_train = TrainingMetrics.calculate_accuracy(
        torch.tensor(positive_scores_train), torch.tensor(negative_scores_train))
    optimal_threshold_test, maximum_accuracy_test = TrainingMetrics.calculate_accuracy(
        torch.tensor(positive_scores_test), torch.tensor(negative_scores_test))

    (positive_as_positive_fraction_train,
     positive_as_negative_fraction_train,
     negative_as_positive_fraction_train,
     negative_as_negative_fraction_train) = _edge_rates(positive_scores_train, negative_scores_train, optimal_threshold_train)

    (positive_as_positive_fraction_test,
     positive_as_negative_fraction_test,
     negative_as_positive_fraction_test,
     negative_as_negative_fraction_test) = _edge_rates(positive_scores_test, negative_scores_test, optimal_threshold_test)

    # Add to our metrics (losses are averaged over the processed batches)
    training_loss_0.append(float(metrics_train['loss'][0]) / metrics_train['n_batches'])
    training_loss_1.append(float(metrics_train['loss'][1]) / metrics_train['n_batches'])
    training_classification_loss.append(float(metrics_train['classification_loss']) / metrics_train['n_batches'])
    training_accuracy.append(maximum_accuracy_train)
    training_positive_as_positive_rate.append(positive_as_positive_fraction_train)
    training_positive_as_negative_rate.append(positive_as_negative_fraction_train)
    training_negative_as_negative_rate.append(negative_as_negative_fraction_train)
    training_negative_as_positive_rate.append(negative_as_positive_fraction_train)

    testing_loss_0.append(float(metrics_test['loss'][0]) / metrics_test['n_batches'])
    testing_loss_1.append(float(metrics_test['loss'][1]) / metrics_test['n_batches'])
    testing_classification_loss.append(float(metrics_test['classification_loss']) / metrics_test['n_batches'])
    # Record the test accuracy too, so testing_accuracy mirrors training_accuracy
    testing_accuracy.append(maximum_accuracy_test)
    testing_positive_as_positive_rate.append(positive_as_positive_fraction_test)
    testing_positive_as_negative_rate.append(positive_as_negative_fraction_test)
    testing_negative_as_negative_rate.append(negative_as_negative_fraction_test)
    testing_negative_as_positive_rate.append(negative_as_positive_fraction_test)

    # Do some prints
    print('----------------------------------------')
    print('Epoch:', epoch)
    print('----------------------------------------')
    print('training_classification_loss:', round(training_classification_loss[-1], 2))
    print('----')
    print('optimal_threshold_train:', optimal_threshold_train)
    print('accuracy_train:', str(round(maximum_accuracy_train.item(), 2)) +'%')
    print('positive_as_positive_fraction_train:', str(round(positive_as_positive_fraction_train * 100.0, 2)) + '%')
    print('positive_as_negative_fraction_train:', str(round(positive_as_negative_fraction_train * 100.0, 2)) + '%')
    print('negative_as_negative_fraction_train:', str(round(negative_as_negative_fraction_train * 100.0, 2)) + '%')
    print('negative_as_positive_fraction_train:', str(round(negative_as_positive_fraction_train * 100.0, 2)) + '%')
    print('----')
    print('testing_classification_loss:', round(testing_classification_loss[-1], 2))
    print('----')
    print('optimal_threshold_test:', optimal_threshold_test)
    print('accuracy_test:', str(round(maximum_accuracy_test.item(), 2)) +'%')
    print('positive_as_positive_fraction_test:', str(round(positive_as_positive_fraction_test * 100.0, 2)) + '%')
    print('positive_as_negative_fraction_test:', str(round(positive_as_negative_fraction_test * 100.0, 2)) + '%')
    print('negative_as_negative_fraction_test:', str(round(negative_as_negative_fraction_test * 100.0, 2)) + '%')
    print('negative_as_positive_fraction_test:', str(round(negative_as_positive_fraction_test * 100.0, 2)) + '%')
    print('----')

    # Branch score plots for the first input ordering, once per class index.
    # The second ordering is not plotted; use metrics_*['branch_scores'][1]
    # in the same call to inspect it.
    branch_scores_train = metrics_train['branch_scores'][0]
    branch_scores_test = metrics_test['branch_scores'][0]
    for class_index in range(3):
        TrainingMetrics.plot_scores_branch(
            branch_scores_train['background'], branch_scores_train['true'], branch_scores_train['wrong_orientation'],
            branch_scores_test['background'], branch_scores_test['true'], branch_scores_test['wrong_orientation'],
            class_index)

    TrainingMetrics.plot_scores_classifier(positive_scores_train, negative_scores_train, positive_scores_test, negative_scores_test)

In [None]:
# Plot how the per-epoch metrics evolved

# Branch losses, one plot per input ordering
# NOTE(review): both branch plots carry the same title, so they are told
# apart only by their order.
TrainingMetrics.plot_loss_evolution(
    training_epoch, training_loss_0, testing_loss_0, 'Loss - branch_model')
TrainingMetrics.plot_loss_evolution(
    training_epoch, training_loss_1, testing_loss_1, 'Loss - branch_model')

# Classifier loss
TrainingMetrics.plot_loss_evolution(
    training_epoch, training_classification_loss, testing_classification_loss, 'Loss - classifier')

# Rates for true-positive samples
TrainingMetrics.plot_edge_rate(
    training_epoch,
    training_positive_as_positive_rate, training_positive_as_negative_rate,
    testing_positive_as_positive_rate, testing_positive_as_negative_rate,
    True)

# Rates for true-negative samples (x axis taken from testing_epoch, which is
# filled once per epoch just like training_epoch)
TrainingMetrics.plot_edge_rate(
    testing_epoch,
    training_negative_as_negative_rate, training_negative_as_positive_rate,
    testing_negative_as_negative_rate, testing_negative_as_positive_rate,
    False)

In [None]:
######################
# ROC curve and confusion matrices on the full test set
######################
with torch.no_grad():

    # Inference mode for both networks
    branch_model.eval()
    classifier_model.eval()

    # Score every test entry in one pass (no batching, nothing skipped)
    pred_0_test = branch_model(torch.tensor(input_0_test_temp, dtype=torch.float))
    pred_1_test = branch_model(torch.tensor(input_1_test_temp, dtype=torch.float))
    _branch_outputs = torch.concatenate((pred_0_test, pred_1_test), axis=1)
    classifier_pred_test = classifier_model(_branch_outputs).reshape(-1)

    # Split the classifier scores by truth
    _scores_np = np.array(classifier_pred_test.tolist())
    neg_scores_final_test = _scores_np[isTruePrimaryLink_test == 0].reshape(-1)
    pos_scores_final_test = _scores_np[isTruePrimaryLink_test == 1].reshape(-1)

    TrainingMetrics.plot_roc_curve(torch.tensor(pos_scores_final_test), torch.tensor(neg_scores_final_test))

    # Confusion matrix at each threshold of interest
    for _threshold in (0.5, 0.60):
        TrainingMetrics.draw_confusion_with_threshold(classifier_pred_test, isTruePrimaryLink_test, _threshold)

In [None]:
######################
# Save the models as TorchScript
######################
# Branch model: compile, then write to disk
sm_branch = torch.jit.script(branch_model)
torch.jit.save(sm_branch, branchModelPath)

# Classifier: compile, then write to disk
sm_classifier = torch.jit.script(classifier_model)
torch.jit.save(sm_classifier, classifierModelPath)