## Encoder-Decoder Performance

written by Isobel Mawby (i.mawby1@lancaster.ac.uk)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Imports
</div>

In [None]:
import numpy as np
import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import time
import sys

import Datasets
import TrainingMetrics
import Models

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set device
</div>

In [None]:
# Prefer GPU index 1 (the original choice) but clamp to the GPUs that actually exist,
# so single-GPU machines get cuda:0 instead of an invalid cuda:1; otherwise use the CPU.
GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Config
</div>

In [None]:
TRAINING_FRACTION = 0.75   # passed to Datasets.get_split_point_datasets (presumably the train share) — TODO confirm
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
N_EPOCHS = 1
ALPHA = 2.0     # Loss scaling: weight of the over-prediction penalty in loss_function
SEED = 42       # seeds the global torch / numpy RNGs so Restart & Run All is reproducible

torch.manual_seed(SEED)
np.random.seed(SEED)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Model output path
</div>

In [None]:
# Stem for saved checkpoints; epoch/alpha suffixes and file extensions are appended when saving.
modelPath = '{}/files/SplitPointModel_UVW'.format(sys.path[0])

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Load the train/test datasets
</div>

In [None]:
train_dataset, test_dataset = Datasets.get_split_point_datasets(device, TRAINING_FRACTION)

# Report tensor shapes for each split; a blank line separates train from test.
for tag, split_dataset in (('train', train_dataset), ('test', test_dataset)):
    if tag == 'test':
        print('')
    print(f'Input({tag}):', split_dataset.input.shape)
    print(f'Truth({tag}):', split_dataset.labels.shape)
    print(f'Contaminated({tag}):', split_dataset.is_contaminated.shape)

In [None]:
# Number of training entries per is_contaminated code (0, 1, 2).
for count_name, code in (('n_background', 0), ('n_signal', 1), ('n_showers', 2)):
    n_matching = torch.count_nonzero(train_dataset.is_contaminated == code).item()
    print(f'{count_name}:', n_matching)

<div class="alert alert-block alert-info" style="font-size: 18px;">
     Exclude showers (is_contaminated == 2) from training and testing; they confuse the model
</div>

In [None]:
# Remove shower entries (is_contaminated == 2) from both datasets. Both masks are computed
# before either dataset is modified.
# NOTE(review): this cell is not idempotent — re-running it unsqueezes the tensors again.
mask_train = train_dataset.is_contaminated != 2
mask_test = test_dataset.is_contaminated != 2


def _keep_entries(dataset, keep_mask):
    """Insert a size-1 dim 1 into input/labels, keep `keep_mask` rows, reshape is_contaminated to (-1, 1)."""
    dataset.input = dataset.input.unsqueeze(1)[keep_mask]
    dataset.labels = dataset.labels.unsqueeze(1)[keep_mask]
    dataset.is_contaminated = dataset.is_contaminated[keep_mask].reshape(-1, 1)


for split_dataset, split_mask in ((train_dataset, mask_train), (test_dataset, mask_test)):
    _keep_entries(split_dataset, split_mask)

for tag, split_dataset in (('train', train_dataset), ('test', test_dataset)):
    if tag == 'test':
        print('')
    print(f'Input({tag}):', split_dataset.input.shape)
    print(f'Truth({tag}):', split_dataset.labels.shape)
    print(f'Contaminated({tag}):', split_dataset.is_contaminated.shape)

In [None]:
# Settings shared by both loaders. The explicit CPU generators are kept from the original —
# presumably needed because the default tensor device may be changed elsewhere (TODO confirm).
common_loader_args = {'batch_size': BATCH_SIZE, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True,
                              generator=torch.Generator(device='cpu'), **common_loader_args)
test_dataloader = DataLoader(test_dataset, shuffle=False, drop_last=False,
                             generator=torch.Generator(device='cpu'), **common_loader_args)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Class Weights
</div>

In [None]:
# Inverse-frequency class weights scaled so the more common label gets weight 1.0:
# class_weights = [max/nFalse, max/nTrue] for labels 0 and 1 respectively.
nFalse, nTrue = (torch.count_nonzero(train_dataset.labels == value).item() for value in (0, 1))
maxValue = max(nTrue, nFalse)

class_weights = torch.tensor([float(maxValue) / float(n) for n in (nFalse, nTrue)])

for label_text, shown in (('nTrue:', nTrue), ('nFalse:', nFalse), ('weights:', class_weights)):
    print(label_text, shown)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Setup training
</div>

In [None]:
# Model, optimiser and per-epoch exponential LR decay.
# Unpacking requires the (reshaped) input tensor to be exactly 3-D; its last dim is the feature count.
_, _, n_features = train_dataset.input.shape
model = Models.ConvEncoderDecoder(device, num_features=n_features)

LR_DECAY = 0.97  # multiplicative LR factor applied on every scheduler.step()
optimiser = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimiser, gamma=LR_DECAY)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Let's start training!
</div>

In [None]:
def loss_function(prediction, truth, class_weights, alpha=None) :
    """Class-weighted BCE plus a penalty for predicting more split points than exist.

    Parameters
    ----------
    prediction : torch.Tensor
        Raw logits with the same shape as `truth`.
    truth : torch.Tensor
        0/1 targets; must be 3-D (b, l, k). Dim 2 is squeezed, so it is presumably size 1 — TODO confirm.
    class_weights : indexable of two weights
        class_weights[0] multiplies the loss of entries with truth == 0, class_weights[1] those with truth == 1.
    alpha : float, optional
        Scale of the over-prediction penalty. Defaults to the notebook-level ALPHA when None
        (looked up at call time, so changing ALPHA later still takes effect).

    Returns
    -------
    torch.Tensor
        Scalar: batch mean of (weighted BCE averaged over dim 1 + alpha * relu(n_pred - n_true)).
    """
    if alpha is None:
        alpha = ALPHA

    b, l, _ = truth.shape

    # Per-element loss (no reduction) so each entry can be re-weighted by its class.
    loss_func = torch.nn.BCEWithLogitsLoss(reduction='none')
    bce = loss_func(prediction, truth.float())
    bce[truth == 0] *= class_weights[0]
    bce[truth == 1] *= class_weights[1]
    bce_per_sample = bce.squeeze(2).sum(dim=1) / l

    # Penalise only when the expected number of predicted split points exceeds the true count.
    probabilities = torch.sigmoid(prediction)
    n_true_splits = truth.squeeze(2).sum(dim=1).float()
    n_pred_splits = probabilities.squeeze(2).sum(dim=1).float()
    over_prediction = torch.relu(n_pred_splits - n_true_splits)

    return (bce_per_sample + alpha * over_prediction).sum() / b

In [None]:
def _safe_fraction(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when the denominator is zero."""
    return float(numerator) / float(denominator) if denominator > 0 else float('nan')


def evaluate_split_point_model(net, dataloader, desc):
    """Evaluate `net` on every batch of `dataloader` with gradients disabled.

    Returns (loss_sum, n_correct, n_total):
      loss_sum  -- sum over batches of loss_function(...); divide by len(dataloader) for the mean
      n_correct -- n_correct[i] = entries with label i whose rounded sigmoid score equals the label
      n_total   -- n_total[i]   = entries with label i
    """
    loss_sum = 0.0
    n_correct = [0, 0]
    n_total = [0, 0]
    with torch.no_grad():
        for x, labels, _ in tqdm(dataloader, desc=desc, leave=True):
            logits = net(x)
            loss_sum += loss_function(logits, labels, class_weights).item()

            # Hard 0/1 predictions from rounding the sigmoid score, flattened for per-class counting.
            labels_flat = labels.reshape(-1)
            pred_flat = torch.round(torch.sigmoid(logits)).reshape(-1)
            is_correct = torch.isclose(pred_flat, labels_flat.float())
            for i in range(2):
                is_class = labels_flat == i
                n_correct[i] += torch.count_nonzero(torch.logical_and(is_correct, is_class)).item()
                n_total[i] += torch.count_nonzero(is_class).item()
    return loss_sum, n_correct, n_total


# Initialize progress bar for tracking epochs
pbar = trange(0, N_EPOCHS, leave=True, desc="Epoch")

# Loggers for training and testing (one entry per epoch; read by the plotting cells below)
training_av_loss_logger = []
test_av_loss_logger = []
train_acc_logger = [[], []]
test_acc_logger = [[], []]

for epoch in pbar:

    # --- Training pass ---
    model.train()
    t0 = time.time()

    for x_train, label_train, _ in tqdm(train_dataloader, desc="Training", leave=True):
        pred = model(x_train)
        loss = loss_function(pred, label_train, class_weights)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    tt = time.time()

    # Update learning rate (ExponentialLR, stepped once per epoch)
    before_lr = optimiser.param_groups[0]["lr"]
    scheduler.step()
    after_lr = optimiser.param_groups[0]["lr"]
    print("Epoch %d: Adam lr %.3e -> %.3e" % (epoch, before_lr, after_lr))

    # --- Evaluation passes on both splits (no gradients) ---
    model.eval()
    train_loss_count, train_acc_count, counts_train = evaluate_split_point_model(
        model, train_dataloader, "Validation(Train)")
    test_loss_count, test_acc_count, counts_test = evaluate_split_point_model(
        model, test_dataloader, "Validation(Test)")
    tv = time.time()

    # Prints for Epoch
    training_av_loss_logger.append(train_loss_count / len(train_dataloader))
    test_av_loss_logger.append(test_loss_count / len(test_dataloader))
    print(f'Training Time: {tt-t0:.3f} s')
    print(f'Validation Time: {tv-tt:.3f} s')
    print('')
    for i in range(2):
        train_acc = _safe_fraction(train_acc_count[i], counts_train[i])
        test_acc = _safe_fraction(test_acc_count[i], counts_test[i])

        print(f'train/test accuracy for class {i}: {train_acc*100.0:.2f}%, {test_acc*100.0:.2f}%')
        train_acc_logger[i].append(train_acc)
        test_acc_logger[i].append(test_acc)

    # Save a TorchScript (.pt) and a state_dict (.pkl) checkpoint for each epoch, from the CPU.
    save_stem = f"{modelPath}_alpha_{ALPHA}_epoch_{epoch}"
    model_cpu = model.to('cpu').eval()
    torch.jit.script(model_cpu).save(save_stem + ".pt")
    torch.save(model_cpu.state_dict(), save_stem + ".pkl")
    model = model.to(device)
    # test_acc here is the class-1 value from the last iteration of the loop above.
    print(f"Saved model at epoch {epoch} (class 1 test accuracy {test_acc:.4f})")

In [None]:
# Average per-batch loss on the train and test sets, one point per epoch.
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
epochs_axis = range(1, N_EPOCHS + 1)
for curve in (training_av_loss_logger, test_av_loss_logger):
    ax_loss.plot(epochs_axis, curve)
ax_loss.legend(["Train", "Test"])
_ = ax_loss.set(title="Training Vs Test Av. Loss", xlabel="Epochs", ylabel="Loss")

In [None]:
# One figure per label value (0 and 1): per-epoch accuracy on the train and test sets.
# The class index is in the title so the two otherwise-identical figures can be told apart.
epochs_axis = range(1, N_EPOCHS + 1)
for i in range(2):
    fig_acc, ax_acc = plt.subplots(figsize=(10, 5))
    ax_acc.plot(epochs_axis, train_acc_logger[i])
    ax_acc.plot(epochs_axis, test_acc_logger[i])
    ax_acc.legend(["Train", "Test"])
    _ = ax_acc.set(title=f"Training Vs Test Accuracy (class {i})", xlabel="Epochs", ylabel="Accuracy")

In [None]:
# Epoch whose TorchScript checkpoint is analysed below. The preferred epoch is clamped to the
# last epoch trained in this run, so a fresh Restart & Run All never tries to load a checkpoint
# that was not written (the original hard-coded 4 while N_EPOCHS = 1).
PREFERRED_EPOCH = 4
chosen_epoch = min(PREFERRED_EPOCH, N_EPOCHS - 1)

modelPath_split_point = f"{modelPath}_alpha_{ALPHA}_epoch_{chosen_epoch}.pt"

print(modelPath_split_point)

# Checkpoints were saved from the CPU; load them there explicitly for the CPU-side analysis below.
venusaurus_split_point = torch.jit.load(modelPath_split_point, map_location='cpu')


<div class="alert alert-block alert-info" style="font-size: 18px;">
    Make some post-training performance plots
</div>

In [None]:
def collect_split_point_scores(net, dataloader):
    """Run `net` over `dataloader` on the CPU and gather sigmoid scores as Python lists.

    Returns (all_scores, all_labels, label1_scores, label0_scores), where every batch is
    flattened before it is appended.
    """
    all_scores, all_labels, label1_scores, label0_scores = [], [], [], []
    with torch.no_grad():
        for x, labels, _ in dataloader:
            scores = torch.sigmoid(net(x.to('cpu'))).reshape(-1)
            labels = labels.to('cpu').reshape(-1)

            label1_scores.extend(scores[labels == 1].tolist())
            label0_scores.extend(scores[labels == 0].tolist())
            all_scores.extend(scores.tolist())
            all_labels.extend(labels.tolist())
    return all_scores, all_labels, label1_scores, label0_scores


venusaurus_split_point.eval()

# NOTE(review): train_dataloader shuffles and uses drop_last=True, so the last partial
# training batch is not included in these train-set scores.
pred_final_train, truth_train, true_scores_train, false_scores_train = collect_split_point_scores(
    venusaurus_split_point, train_dataloader)
pred_final_val, truth_val, true_scores_val, false_scores_val = collect_split_point_scores(
    venusaurus_split_point, test_dataloader)


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def normalised_hist_weights(scores):
    """Return per-entry weights 1/len(scores) (shape of `scores`) so its histogram sums to 1.

    An empty input gives an empty weight array instead of raising ZeroDivisionError.
    """
    scores = np.asarray(scores)
    return np.full(scores.shape, 1.0 / max(scores.shape[0], 1))


def plot_scores(true_scores_train, false_scores_train, true_scores_val, false_scores_val) :
    """Overlay unit-normalised, log-scale score histograms for each label and split.

    Label-1 scores are drawn in blue and label-0 scores in red. Train curves are solid and
    test curves are dashed. Each curve is normalised by its own entry count.
    """
    series = [
        (true_scores_train, 'blue', 'signal_train', 'solid'),
        (false_scores_train, 'red', 'background_train', 'solid'),
        (true_scores_val, 'blue', 'signal_test', 'dashed'),
        (false_scores_val, 'red', 'background_test', 'dashed'),
    ]
    for scores, colour, curve_label, line_style in series:
        plt.hist(scores, bins=50, range=(0, 1.0), color=colour, label=curve_label,
                 weights=normalised_hist_weights(scores), histtype='step', linestyle=line_style)

    plt.yscale("log")
    plt.xlabel('Classification Score')
    # Showers were filtered out earlier; the y axis is the fraction of entries in each curve.
    plt.ylabel('Proportion of entries')
    plt.legend(loc='best')
    plt.show()
    
    
    
def normalise_confusion(conf_matrix):
    """Return (row-normalised, column-normalised) float copies of a confusion matrix.

    Rows (true class) or columns (predicted class) that sum to zero are left as zeros.
    """
    conf = np.asarray(conf_matrix, dtype=float)
    true_sums = conf.sum(axis=1, keepdims=True)
    pred_sums = conf.sum(axis=0, keepdims=True)
    true_norm = np.divide(conf, true_sums, out=np.zeros_like(conf), where=true_sums > 0)
    pred_norm = np.divide(conf, pred_sums, out=np.zeros_like(conf), where=pred_sums > 0)
    return true_norm, pred_norm


def draw_confusion_with_threshold(pred, labels, threshold):
    """Binarise `pred` at `threshold` (strictly greater => 1) and plot two normalised confusion matrices.

    The first figure is normalised per true class (rows) and the second per predicted class (columns).
    """
    hard_pred = (np.asarray(pred) > threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even when only one class appears in the data.
    confMatrix = confusion_matrix(labels, hard_pred, labels=[0, 1])
    trueNormalised, predNormalised = normalise_confusion(confMatrix)

    displayTrueNorm = ConfusionMatrixDisplay(confusion_matrix=trueNormalised, display_labels=["False", "True"])
    displayTrueNorm.plot()
    displayTrueNorm.ax_.set_title('Normalised by true class')

    displayPredNorm = ConfusionMatrixDisplay(confusion_matrix=predNormalised, display_labels=["False", "True"])
    displayPredNorm.plot()
    displayPredNorm.ax_.set_title('Normalised by predicted class')



# Score distributions for both labels on the train and test splits.
score_lists = (true_scores_train, false_scores_train, true_scores_val, false_scores_val)
plot_scores(*[np.array(scores) for scores in score_lists])


In [None]:
# Confusion matrices on the test split, counting a score above 0.95 as a positive prediction.
draw_confusion_with_threshold(np.array(pred_final_val), np.array(truth_val), threshold=0.95)

In [None]:
# Class-0 accuracy on the test split: the fraction of label-0 entries whose sigmoid score
# rounds to 0. Uses new names so pred_final_val stays as raw scores for other cells
# (the original overwrote it with rounded values, changing any re-run of the confusion cell).
pred_labels_val = np.round(np.asarray(pred_final_val))
truth_val = np.asarray(truth_val)

is_class0 = truth_val == 0
n_class0_correct = np.count_nonzero(np.logical_and(np.isclose(pred_labels_val, truth_val), is_class0))
n_class0 = np.count_nonzero(is_class0)

print(n_class0_correct / n_class0 if n_class0 > 0 else float('nan'))