# PyTorch Classification 

This notebook is a basic example for training and testing LSTM model with PyTorch.

## Defining the PyTorch model and runner

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import json
import numpy as np
import datetime
from tqdm.notebook import tqdm

In [2]:
class LSTMnet(nn.Module):
    """
    LSTM for predicting task id from step scores
    """
    def __init__(self, timesteps, input_dim, output_dim, hidden_dim, bidirectional=False):
        super(LSTMnet, self).__init__()
        self.timesteps = timesteps
        self.hidden_dim = hidden_dim

        # The LSTM takes step scores as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, dropout=0.2, bidirectional=bidirectional)

        # The linear layer that maps from hidden state space to task space
        insize = hidden_dim * 2 if bidirectional else hidden_dim
        self.hidden2out = nn.Linear(insize, output_dim)

    def forward(self, input_feat):
        # input_feat shape is (batch_size, feature_dim)
        # this is reshaped into shape (batch_size, timesteps, feature_dim // timesteps)
        # then converted to (timesteps, batch_size, feature_dim // timesteps)
        batch_size = input_feat.size(0)
        input_feat = input_feat.reshape(batch_size, self.timesteps, input_feat.size(1) // self.timesteps).transpose(0, 1)
        lstm_out, _ = self.lstm(input_feat)
        return self.hidden2out(lstm_out[-1])

In [3]:
class Trainer:
    def __init__ (self, timesteps=10, input_feat_size=30, num_classes=2, hidden_feat_size=128):
        """
        Trainer for the LSTM
        """
        self.net = LSTMnet(timesteps, input_feat_size, num_classes, hidden_feat_size, True)
        
        self.cuda_flag = torch.cuda.is_available()
        if self.cuda_flag:
            self.net = self.net.cuda()

    def train(self, X, Y, epochs=10, lr=0.001, batch_size=64, decay=5000, logging=True, log_file=None):
        """
        Train the LSTM

        Params:
        @X: Training data - input of the model
        @Y: Training labels
        @logging: True for printing the training progress after each epoch
        @log_file: Path of log file
        """        
        inputs = torch.FloatTensor(X)
        labels = torch.LongTensor(Y)
        
        train_dataset = TensorDataset(inputs, labels)
        trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        self.net.train()
        criterion = nn.CrossEntropyLoss()

        for epoch in range(1, epochs+1): # loop over data multiple times
            # Decreasing the learning rate
            if (epoch % decay == 0):
                lr /= 3
                
            optimizer = optim.SGD(self.net.parameters(), lr=lr, momentum=0.9)
            
            tot_loss = 0.0
            for data in tqdm(trainloader):
                # get the inputs
                inputs, labels = data
                if self.cuda_flag:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                o = self.net(inputs)
                loss = criterion(o, labels)

                loss.backward()
                optimizer.step()
                
                tot_loss += loss.item()
                
            tot_loss /= len(trainloader)

            # logging statistics
            timestamp = str(datetime.datetime.now()).split('.')[0]
            log = json.dumps({
                'timestamp': timestamp,
                'epoch': epoch,
                'loss': float('%.7f' % tot_loss),
                'lr': float('%.6f' % lr)
            })
            if logging:
                print (log)

            if log_file is not None:
                with open(log_file, 'a') as f:
                    f.write("{}\n".format(log))
            
        print ('Finished Training')

    def predict(self, inputs):
        """
        Predict the task labels corresponding to the input images
        """
        inputs = torch.FloatTensor(inputs)
        if self.cuda_flag:
            inputs = inputs.cuda()

        self.net.eval()
        with torch.no_grad():
            labels = self.net(inputs).cpu().numpy()
            
        return np.argmax(labels, axis=1)

    def score(self, X, Y):
        """
        Score the model -- compute accuracy
        """
        pred = self.predict(X)
        acc = np.sum(pred == Y) / len(Y)
        return float(acc)

    def save_model(self, checkpoint_path, model=None):
        if model is None: model = self.net
        torch.save(model.state_dict(), checkpoint_path)
    
    def load_model(self, checkpoint_path, model=None):
        if model is None: model = self.net
        if self.cuda_flag:
            model.load_state_dict(torch.load(checkpoint_path))
        else:
            model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))


## Get the Data

In [4]:
import pickle
import os
from subprocess import call

%pip install root_numpy
from root_numpy import root2array

Note: you may need to restart the kernel to use updated packages.
Welcome to JupyROOT 6.21/01


In [5]:
def get_dataset(data_file, test_split=0.2):
    """
    @data_file: path of the root file
    @test_split: fraction of data to be used as test set
    """
    assert 0 <= test_split <= 0.3

    # Load data
    signal = root2array(data_file, 'sgn')
    background = root2array(data_file, 'bkg')

    tree_data = [
        np.array([np.array(list(img)).reshape(-1) for img in signal]),
        np.array([np.array(list(img)).reshape(-1) for img in background])
    ]

    X_train = []
    Y_train = []
    X_test = []
    Y_test = []

    # Deterministic Random
    np.random.seed(0)

    for label, data in enumerate(tree_data):
        np.random.shuffle(data)

        test_size = int(len(data) * test_split)
        X_train.append(data[:-test_size])
        X_test.append(data[-test_size:])
        Y_train.append([label] * (len(data) - test_size))
        Y_test.append([label] * test_size)

    X_train = np.concatenate(X_train, axis=0)
    X_test = np.concatenate(X_test, axis=0)
    Y_train = np.concatenate(Y_train)
    Y_test = np.concatenate(Y_test)

    assert len(Y_train) == len(X_train)
    assert len(Y_test) == len(X_test)

    # Shuffling the data
    train_perm = np.random.permutation(len(X_train))
    X_train = X_train[train_perm]
    Y_train = Y_train[train_perm]
    
    test_perm = np.random.permutation(len(X_test))
    X_test = X_test[test_perm]
    Y_test = Y_test[test_perm]
    
    return (X_train, Y_train), (X_test, Y_test)

In [6]:
# Load data
inputFileName = "sample_timedata_t10_d30.root"

if not os.path.isfile(inputFileName):
    call(['curl', '-o', inputFileName, 'https://cernbox.cern.ch/index.php/s/Pc8SBmHfvU1X0mb/download'])

(X_train, Y_train), (X_test, Y_test) = get_dataset(inputFileName, test_split=0.2)

## Train the model

In [7]:
# Model config
lr = 0.001
num_epochs = 30
batch_size = 64

In [8]:
trainer = Trainer(10, 30, 2, 256)
trainer.train(X_train, Y_train, lr=lr, epochs=num_epochs, batch_size=batch_size, decay=10)
trainer.save_model('model_LSTM.pth')

HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:19:44", "epoch": 1, "loss": 0.6636981, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:20:02", "epoch": 2, "loss": 0.5358954, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:20:21", "epoch": 3, "loss": 0.4596487, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:20:40", "epoch": 4, "loss": 0.4607354, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:20:59", "epoch": 5, "loss": 0.4407651, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:21:18", "epoch": 6, "loss": 0.4324736, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:21:36", "epoch": 7, "loss": 0.4196834, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:21:55", "epoch": 8, "loss": 0.411794, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:22:14", "epoch": 9, "loss": 0.4052742, "lr": 0.001}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:22:32", "epoch": 10, "loss": 0.3849149, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:22:51", "epoch": 11, "loss": 0.3784899, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:23:10", "epoch": 12, "loss": 0.3717675, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:23:29", "epoch": 13, "loss": 0.3687891, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:23:48", "epoch": 14, "loss": 0.360205, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:24:07", "epoch": 15, "loss": 0.3629168, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:24:25", "epoch": 16, "loss": 0.3604168, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:24:44", "epoch": 17, "loss": 0.359501, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:25:03", "epoch": 18, "loss": 0.3546436, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:25:22", "epoch": 19, "loss": 0.3557539, "lr": 0.000333}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:25:41", "epoch": 20, "loss": 0.3451105, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:26:00", "epoch": 21, "loss": 0.3433382, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:26:18", "epoch": 22, "loss": 0.3453094, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:26:37", "epoch": 23, "loss": 0.3416147, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:26:56", "epoch": 24, "loss": 0.3444362, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:27:15", "epoch": 25, "loss": 0.3437231, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:27:34", "epoch": 26, "loss": 0.3422212, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:27:53", "epoch": 27, "loss": 0.3406938, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:28:12", "epoch": 28, "loss": 0.3401629, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:28:31", "epoch": 29, "loss": 0.3397181, "lr": 0.000111}


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


{"timestamp": "2020-03-17 21:28:50", "epoch": 30, "loss": 0.3385941, "lr": 3.7e-05}
Finished Training


## Test the model

In [9]:
print ('Evaluating...')
print ('Training Acc.: {:.4f} %'.format(trainer.score(X_train, Y_train) * 100))
print ('Test Acc.    : {:.4f} %'.format(trainer.score(X_test, Y_test) * 100))

Evaluating...
Training Acc.: 86.0500 %
Test Acc.    : 83.9500 %
