In [1]:
import os
from collections import Counter
from glob import glob

import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix

from nn_helpers import ParticleDS, train_loop, test_loop

In [2]:
# Load the data in using our custom dataset class (which is defined in the nn_helpers file)
batch_size = 64

#predictors = ['p', 'theta', 'beta', 'nphe', 'ein', 'eout']
predictors = ['p_scaled', 'theta_scaled', 'beta_scaled', 'nphe_scaled', 'ein_scaled', 'eout_scaled']
outcome = 'id'

ds_size = '500k'
train_ds = ParticleDS(f'../data/pid_{ds_size}_train_balanced.csv', predictors, outcome)
test_ds = ParticleDS(f'../data/pid_{ds_size}_test.csv', predictors, outcome)

print(Counter(train_ds.y))
print(Counter(test_ds.y))

train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

Counter({2: 1189, 1: 1189, 3: 1189, 0: 1189})
Counter({1: 56238, 3: 38815, 0: 4650, 2: 297})


In [3]:
# A neural net is a sequence of layers with different numbers of parameters
# This one has 4 layers
# - an input linear layer with 6 inputs (the number of predictor variables) and 16 outputs (this is arbitrary)
# - an Rectified Linear Unit layer which adjusts the results of the first layer to be 0 below 0
# - a second linear layer with 16 inputs and 4 outputs (the number of particle types)
# - a final layer that gets the most likely output of the 4 previous output values to give a final prediction
class TinyModel(nn.Module):
    def __init__(self, n_hidden=16, n_middle=0):
        super(TinyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features=6, out_features=n_hidden),
            nn.ReLU(),
        )
        for _ in range(n_middle):
            self.model.append(nn.Linear(in_features=n_hidden, out_features=n_hidden))
            self.model.append(nn.ReLU())
        self.model.append(nn.Linear(in_features=n_hidden, out_features=4))
        self.model.append(nn.Sigmoid())

    def forward(self, x):
        return self.model(x)

In [4]:
# Initialising the model here will start the training fresh each time
# Skip it if you just want to continue on training the model as is
nn_model = TinyModel(n_hidden=12, n_middle=0)

In [10]:
# Learning rate is how fast the model learns and epochs are the number of times the full training data are passed through the model
# Experimenting with these to get the best model is called hyperparameter tuning
learning_rate = .1 # a good learning rate allows the training and test loss to come down in step with one another but not too slowly
# set the learning rate higher when you have less data and lower when you have more to avoid overfitting
epochs = 1000 # set epochs as high as you like for this task, it's an upper limit and we can continue as long as the test set accuracy is increasing (and/or test set loss is decreasing)

# Cross Entropy Loss is a popular way to measure the difference between predicted categories and actual categories (aka loss)
loss_fn = nn.CrossEntropyLoss()
# Stochastic Gradient Descent is the method that updates the model parameters (aka how it learns)
optimiser = torch.optim.SGD(nn_model.parameters(), lr=learning_rate)

# The train loop and test loop are defined in our nn_helpers file - have a look at those to see what's happening in each loop
accuracies = [0]
for t in range(epochs):
    print(('-' * 30) + f'\nEpoch {t+1}')
    train_loss = train_loop(train_dataloader, nn_model, loss_fn, optimiser, batch_size, print_freq=0)
    accuracy, test_loss, avg_acc = test_loop(test_dataloader, nn_model, loss_fn)
    print(f'loss: {train_loss:.3f} | {test_loss:.3f}, acc: {accuracy:.4f}, avg acc: {avg_acc:.4f}')
    if avg_acc >= .75:
        model_path = f'../models/nn_{str(avg_acc)[2:5]}_{ds_size.lower()}.pt'
        if not os.path.exists(model_path):
            torch.save(nn_model.state_dict(), model_path)
    # break after avg_acc (weighted accuracy) on test set drops significantly
    if len(accuracies) > 1:
        if avg_acc/max(accuracies) <= .9:
            print("Early stop - test accuracy drop")
            break
    accuracies.append(avg_acc)
print("Done!")

------------------------------
Epoch 1
loss: 0.750 | 0.823, acc: 0.9108, avg acc: 0.9333
------------------------------
Epoch 2
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9335
------------------------------
Epoch 3
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9336
------------------------------
Epoch 4
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9336
------------------------------
Epoch 5
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9336
------------------------------
Epoch 6
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9336
------------------------------
Epoch 7
loss: 0.750 | 0.823, acc: 0.9103, avg acc: 0.9336
------------------------------
Epoch 8
loss: 0.750 | 0.823, acc: 0.9104, avg acc: 0.9337
------------------------------
Epoch 9
loss: 0.750 | 0.823, acc: 0.9104, avg acc: 0.9336
------------------------------
Epoch 10
loss: 0.750 | 0.823, acc: 0.9104, avg acc: 0.9337
------------------------------
Epoch 11
loss: 0.750 | 0.823, acc: 0.9104, avg acc: 0.9337
-------------------

KeyboardInterrupt: 

In [15]:
# Get the final performance metrics on the test set
best_model_path = sorted(glob('../models/*_500k.pt'))[-1]
print(best_model_path)
weights_biases = torch.load(best_model_path)
n_hidden = weights_biases['model.0.bias'].shape[0]
n_middle = len(weights_biases)//2 - 2

nn_model = TinyModel(n_hidden, n_middle)
nn_model.load_state_dict(weights_biases)
nn_model.eval()

pred_y = torch.argmax(nn_model(test_ds.X), dim=1).detach().numpy()
print(classification_report(test_ds.y, pred_y))
accuracy = accuracy_score(test_ds.y, pred_y)

conf_mat = confusion_matrix(test_ds.y, pred_y)
print(conf_mat)
class_acc = (conf_mat.diagonal()/conf_mat.sum(1))
print(' | '.join([str(round(acc, 3)) for acc in class_acc]))
print(f'avg. accuracy: {class_acc.mean():.3f}')

../models/nn_791_5m.pt
              precision    recall  f1-score   support

           0       0.35      0.96      0.52      4650
           1       0.99      0.86      0.92     56238
           2       0.09      0.38      0.14       297
           3       0.98      0.95      0.97     38815

    accuracy                           0.90    100000
   macro avg       0.61      0.79      0.64    100000
weighted avg       0.96      0.90      0.92    100000

[[ 4482    75    12    81]
 [ 6211 48341  1131   555]
 [    0   184   113     0]
 [ 1979     1    10 36825]]
0.964 | 0.86 | 0.38 | 0.949
avg. accuracy: 0.788
