## Encoder-Decoder Training

written by Isobel Mawby (i.mawby1@lancaster.ac.uk)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Imports
</div>

In [None]:
import numpy as np
import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import time

import Datasets
import TrainingMetrics
import Models

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Set device
</div>

In [None]:
# Prefer the second GPU (index 1) when the machine has one; otherwise fall back
# to the first GPU, and to the CPU when CUDA is unavailable. Requesting "cuda:1"
# on a single-GPU host would fail as soon as a tensor is placed on the device.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Config
</div>

In [None]:
# Fraction handed to Datasets.get_split_point_datasets when building the train/test split
TRAINING_FRACTION: float = 0.75

# Number of samples per batch for both the train and test DataLoader
BATCH_SIZE: int = 64

# Initial learning rate given to the Adam optimiser
LEARNING_RATE: float = 1e-4

# Number of passes over the training DataLoader
N_EPOCHS: int = 1

# Loss scaling: multiplies the excess-split-point penalty term in loss_function
ALPHA: float = 2.0

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Model output path
</div>

In [None]:
# `sys` is not among the notebook's top-level imports, so import it here;
# without it this cell raises NameError.
import sys

# Path prefix for saved models; the training loop appends the alpha value,
# epoch number and file extension to it.
model_path = sys.path[0] + '/models/SplitPointModel_UVW'

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Load the training and test datasets
</div>

In [None]:
# Build the train/test datasets on the chosen device
train_dataset, test_dataset = Datasets.get_split_point_datasets(device, TRAINING_FRACTION)

# Report the shape of each stored tensor, train split first, with a blank
# line separating the two splits
for _split, _data in (('train', train_dataset), ('test', test_dataset)):
    if _split == 'test':
        print('')
    print(f'Input({_split}):', _data.input.shape)
    print(f'Truth({_split}):', _data.labels.shape)
    print(f'Contaminated({_split}):', _data.is_contaminated.shape)

<div class="alert alert-block alert-info" style="font-size: 18px;">
     Do not train on showers; they confuse the model
</div>

In [None]:
# Keep every entry whose contamination flag is not 2 (per the heading above,
# class 2 is presumably the shower class -- confirm against TrainingMetrics)
mask_train = train_dataset.is_contaminated != 2
mask_test = test_dataset.is_contaminated != 2

for _data, _keep in ((train_dataset, mask_train), (test_dataset, mask_test)):
    # Insert a new axis at position 1, then drop the masked-out entries
    _data.input = _data.input.unsqueeze(1)[_keep]
    _data.labels = _data.labels.unsqueeze(1)[_keep]
    # Flags become a column vector
    _data.is_contaminated = _data.is_contaminated[_keep].reshape(-1, 1)

# Report the shapes after filtering, train split first, blank line in between
for _split, _data in (('train', train_dataset), ('test', test_dataset)):
    if _split == 'test':
        print('')
    print(f'Input({_split}):', _data.input.shape)
    print(f'Truth({_split}):', _data.labels.shape)
    print(f'Contaminated({_split}):', _data.is_contaminated.shape)

In [None]:
# Print, for every contamination class, how many training label entries equal its index.
# NOTE(review): this counts `train_dataset.labels`, not `train_dataset.is_contaminated`,
# while iterating over contamination classes -- confirm that is the intended tensor.
for class_index in TrainingMetrics.contamination_labels:
    class_name = TrainingMetrics.contamination_strings[class_index]
    n_entries = torch.count_nonzero(train_dataset.labels == class_index).item()
    print(f'{class_name}: {n_entries}')

In [None]:
# Training batches: reshuffled every epoch, incomplete final batch dropped
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    generator=torch.Generator(device='cpu'),
)

# Test batches: fixed order, every sample kept
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    generator=torch.Generator(device='cpu'),
)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Class Weights
</div>

In [None]:
# Count the 0-valued and 1-valued entries of the training labels
nFalse, nTrue = (
    torch.count_nonzero(train_dataset.labels == value).item() for value in (0, 1)
)
maxValue = max(nTrue, nFalse)

# A class with no entries would give a division by zero below
if not (nFalse and nTrue):
    raise Exception("Training class has zero samples!")

# Weight each class by (largest class count / own count): the majority class
# gets weight 1, the minority class a weight above 1
class_weights = torch.tensor([maxValue / nFalse, maxValue / nTrue])

for tag, value in (('nTrue:', nTrue), ('nFalse:', nFalse), ('weights:', class_weights)):
    print(tag, value)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Setup training
</div>

In [None]:
# The model is sized from the last of the three input dimensions
_, _, n_features = train_dataset.input.shape
model = Models.ConvEncoderDecoder(device, num_features=n_features)

# Adam over all model parameters, starting at the configured learning rate
optimiser = optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
)

# Each scheduler step multiplies the learning rate by gamma
scheduler = optim.lr_scheduler.ExponentialLR(
    optimizer=optimiser,
    gamma=0.97,
)

<div class="alert alert-block alert-info" style="font-size: 18px;">
    Let's start training!
</div>

In [None]:
def loss_function(prediction, truth, class_weights, alpha=None) :
    """Class-weighted BCE loss plus a penalty for predicting too many split points.

    Args:
        prediction: Raw model outputs (logits); the sigmoid is applied inside this
            function. Presumably the same shape as `truth` -- TODO confirm.
        truth: 3-D tensor of 0/1 targets. The first two dimensions are treated as
            batch and length; dimension 2 is squeezed (assumed to have size 1).
        class_weights: Indexable pair of weights; element 0 scales the loss of
            entries with truth == 0, element 1 those with truth == 1.
        alpha: Scale of the excess-split-point penalty. When None (the default)
            the module-level ALPHA is used, matching the previous behaviour.

    Returns:
        Scalar tensor: the per-sample losses summed and divided by the batch size.
    """
    if alpha is None:
        alpha = ALPHA

    batch_size, length, _ = truth.shape

    # Keep the per-element losses (no reduction) so each entry can be weighted by its class
    loss_1 = nn.functional.binary_cross_entropy_with_logits(prediction, truth.float(), reduction='none')
    loss_1[truth == 0] *= class_weights[0]
    loss_1[truth == 1] *= class_weights[1]

    # Average the weighted loss over the length dimension
    loss_1 = loss_1.squeeze(2).sum(dim=1) / length

    # Want to make sure that model doesn't produce an excessive number of split points:
    # compare the summed predicted probabilities with the true count and penalise
    # only an excess (ReLU zeroes the term when the model predicts too few)
    n_true_splits = truth.squeeze(2).sum(dim=1).float()
    n_pred_splits = torch.sigmoid(prediction).squeeze(2).sum(dim=1).float()
    loss_2 = torch.relu(n_pred_splits - n_true_splits)

    loss_combined = loss_1 + (loss_2 * alpha)

    return loss_combined.sum() / batch_size

In [None]:
def _evaluate(model, dataloader, description):
    """Run one gradient-free pass over `dataloader` and tally loss and accuracy.

    Args:
        model: Network to evaluate (the caller is responsible for `model.eval()`).
        dataloader: Yields (input, label, extra) batches; the third item is ignored.
        description: Text shown on the progress bar.

    Returns:
        (loss_sum, n_correct, n_seen): the summed per-batch loss, and two
        two-element lists holding, for label values 0 and 1, the number of
        correctly predicted entries and the total number of entries.
    """
    loss_sum = 0
    n_correct = [0, 0]
    n_seen = [0, 0]

    # Evaluation needs no autograd graph
    with torch.no_grad():
        for x, label, _ in tqdm(dataloader, desc=description, leave=True):
            # Make prediction
            pred = model(x)

            # Same loss as used for training
            loss_sum += loss_function(pred, label, class_weights).item()

            # Apply sigmoid for inference, then round to a hard 0/1 decision
            pred = torch.round(torch.sigmoid(pred)).reshape(-1)
            label = label.reshape(-1)
            is_correct = torch.isclose(pred, label.float())

            # Per-class accuracy tallies
            for i in range(2):
                is_class = label == i
                n_correct[i] += torch.count_nonzero(torch.logical_and(is_correct, is_class))
                n_seen[i] += torch.count_nonzero(is_class)

    return loss_sum, n_correct, n_seen


# Initialize progress bar for tracking epochs
pbar = trange(0, N_EPOCHS, leave=True, desc="Epoch")

# Loggers for training and testing
train_av_loss_logger = []
test_av_loss_logger = []
train_acc_logger = [[], []]
test_acc_logger = [[], []]

# Loop over each epoch
for epoch in pbar:
    # Set the model to training mode
    model.train()
    t0 = time.time()

    # Loop over each batch in the training dataset
    for x_train, label_train, _ in tqdm(train_dataloader, desc="Training", leave=True):
        # Make prediction
        pred = model(x_train)

        # Compute the loss
        loss = loss_function(pred, label_train, class_weights)

        # Backpropagation and optimization step
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    tt = time.time()

    # Update learning rate; scientific notation keeps small rates distinguishable
    before_lr = optimiser.param_groups[0]["lr"]
    scheduler.step()
    after_lr = optimiser.param_groups[0]["lr"]
    print(f"Epoch {epoch}: lr {before_lr:.3e} -> {after_lr:.3e}")

    # Set the model to evaluation mode and score both splits without gradients
    model.eval()
    train_loss_count, train_acc_count, counts_train = _evaluate(model, train_dataloader, "Validation(Train)")
    test_loss_count, test_acc_count, counts_test = _evaluate(model, test_dataloader, "Validation(Test)")

    tv = time.time()

    # Prints for Epoch: average loss per batch, timings, per-class accuracy
    train_av_loss_logger.append(train_loss_count / len(train_dataloader))
    test_av_loss_logger.append(test_loss_count / len(test_dataloader))
    print(f'Training Time: {tt-t0:.3f} s')
    print(f'Validation Time: {tv-tt:.3f} s')
    print('')
    for i in range(2):
        train_acc = (train_acc_count[i] / counts_train[i]).item()
        test_acc = (test_acc_count[i] / counts_test[i]).item()

        print(f'train/test accuracy for class {i}: {train_acc*100.0:.2f}%, {test_acc*100.0:.2f}%')
        train_acc_logger[i].append(train_acc)
        test_acc_logger[i].append(test_acc)

    # Save for each epoch: move the model to the CPU for export (TorchScript
    # file and state dict), then move it back to the training device
    model_cpu = model.to('cpu').eval()
    sm = torch.jit.script(model_cpu)
    sm.save(f"{model_path}_alpha_{ALPHA}_epoch_{epoch}.pt")
    torch.save(model_cpu.state_dict(), f"{model_path}_alpha_{ALPHA}_epoch_{epoch}.pkl")
    model = model.to(device)
    print(f"Saved model at epoch {epoch}")

In [None]:
# Average loss per epoch for both splits; epochs are numbered from 1 on the x axis.
# Results are assigned to `_` so the notebook does not echo matplotlib objects.
epoch_axis = range(1, N_EPOCHS + 1)

_ = plt.figure(figsize=(10, 5))
for curve_name, curve in (("Train", train_av_loss_logger), ("Test", test_av_loss_logger)):
    _ = plt.plot(epoch_axis, curve, label=curve_name)
_ = plt.legend()
_ = plt.title("Training Vs Test Av. Loss")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Loss")