In [1]:
# default_exp train

# Train

> Model training

In [2]:
#hide
%reload_ext autoreload
%autoreload 2
from nbdev.showdoc import *
import warnings
warnings.filterwarnings("ignore")

In [3]:
# export
import os
import sys
import argparse
import time

import numpy as np
import json
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ti.dataloader import DatasetTraj, zero_padding, getSTW, splitData, file_dir
from ti.prep import Transformer
from ti.model import Difference, TrajectoryDN, TrajectorySN, ContrastiveLoss

In [15]:
# export
torch.cuda.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 4
EPOCHS = 10
mode = 'sim'
run_name = f'run_simulatedsubtraj_Batch{BATCH_SIZE}_EPOCHS{EPOCHS}'
run_name

'run_simulatedsubtraj_Batch4_EPOCHS10'

In [16]:
# export
def get_data_and_model(params, model='DN', mode='sim'):
    stw = getSTW(mode)
    tr_range, val_range, ts_range = splitData(len(stw))
    # Partitions
    partition = {
        'train': tr_range,
        'validation': val_range
    }
    # Generators
    training_set = DatasetTraj(partition['train'], stw, mode=mode)
    train_g = DataLoader(training_set, **params)
    validation_set = DatasetTraj(partition['validation'], stw, mode=mode)
    val_g = DataLoader(validation_set, **params)
    transformer = Transformer()
    if model == 'DN':
        diff_net = Difference(mode='simple')
        net = TrajectoryDN(diff_net, n_features=(len(transformer.features_traj))*2) # 2x for org and dest 
    else:
        net = TrajectorySN(n_features=len(transformer.features_traj)) 
    net = nn.DataParallel(net)
    net.to(device)
    return train_g, val_g, net

In [17]:
# slow
params = {
    'batch_size': 16,
    'shuffle': True,
    'collate_fn': zero_padding
}
train_g, val_g, net = get_data_and_model(params, mode=mode)
count = 0
for i, (x1, x2, y, x_seq_lens, max_seq_len) in enumerate(train_g):
    print(f'Batch: {i}')
    print(x1.shape)
    print(x2[0].shape, x2[0].shape)
    print(y)
    print(x_seq_lens)
    print(max_seq_len)
    if count >=2:
        break
    count += 1

Batch: 0
torch.Size([9, 9, 4])
torch.Size([9, 9, 4]) torch.Size([9, 9, 4])
[1, 1, 1, 1, 0, 1, 0, 1, 0]
[3, 4, 2, 9, 4, 6, 5, 2, 5]
9


In [18]:
# export
def get_prf(y, y_p):
    p = round(precision_score(y, y_p, average='weighted'), 2)
    r = round(recall_score(y, y_p, average='weighted'), 2)
    f = round(f1_score(y, y_p, average='weighted'), 2)
    return p,r,f

def get_metric(metric, i, loss, acc, p, r, f, time):
    metric['iter'].append(i)
    metric['loss'].append(loss)
    metric['acc'].append(acc)
    metric['p'].append(p); metric['r'].append(r); metric['f'].append(f)
    metric['time_mins'].append(time)
    return metric

def write_metric(tm, vm, mtype, epoch, batch):
    file_tm = f'train_{mtype}_{epoch}_{batch}.json'
    runs_dir = os.path.join(file_dir, '../runs', run_name)
    if not os.path.exists(runs_dir):
        os.mkdir(runs_dir)
    with open (os.path.join(runs_dir, file_tm), 'w') as f:
        json.dump(tm, f)
    file_vm = f'val_{mtype}_{epoch}_{batch}.json'
    with open (os.path.join(runs_dir, file_vm), 'w') as f:
        json.dump(vm, f)

def train(model, train_g, val_g, optimizer, criterion, threshold, epoch, print_at, mtype='DN'):
    model.train()
    y_true = []
    y_pred = []
    epoch_running_loss = 0.0
    tic = time.time()
    iterations = 0
    train_metric = {
        'iter': [],
        'loss': [],
        'acc': [],
        'p': [],
        'r': [],
        'f': [],
        'time_mins': []
    }
    val_metric = {
        'iter': [],
        'loss': [],
        'acc': [],
        'p': [],
        'r': [],
        'f': [],
        'time_mins': []
    }
    for i in range(epoch):
        for x1, x2, y, x_seq_lens, max_seq_len in train_g:
            x1, y, x_seq_lens = torch.Tensor(x1).to(device), torch.Tensor(y).to(device), torch.Tensor(x_seq_lens).to(device)
            org = x2[0]
            dst = x2[1]
            org = torch.Tensor(org).to(device)
            dst = torch.Tensor(dst).to(device)
            x2 = [org, dst]
            # y = torch.Tensor(y)
            # Zero your gradients for every batch!
            optimizer.zero_grad()
            #FORWARD PASS
            if mtype == 'DN':
                output = model(x1, x2, x_seq_lens)
                output = torch.squeeze(output)
                if len(output.shape) == 0:
                    output = output.unsqueeze(0)
                loss = criterion(output, y) 
                predicted_vals = (output > threshold)*1
            else:
                out1, out2 = model(x1, x2, x_seq_lens)
                loss = criterion(out1, out2, y)
                predicted_vals = (torch.pairwise_distance(out1, out2) > threshold)*1
            # Compute the loss and its gradients
            loss.backward()
            # Adjust learning weights
            optimizer.step()
            epoch_running_loss += loss.item()
            # Store Predictions
            y_pred.extend(predicted_vals.tolist())
            y_true.extend(y.tolist())
            iterations += 1
        MODEL_SAVE_PATH = os.path.join(file_dir, '../runs', run_name, f'Epoch{i}.pth')
        if not os.path.exists(os.path.join(file_dir, '../runs', run_name)):
            os.mkdir(os.path.join(file_dir, '../runs', run_name))
        torch.save({
            'epoch': i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, MODEL_SAVE_PATH)
        if i % print_at == print_at-1:
            train_time = round((time.time()-tic)/60.0, 2)
            print(f'Epoch Time (min): {train_time}')
            train_loss = round(epoch_running_loss / iterations, 4) # avg loss
            train_acc = round(accuracy_score(y_true, y_pred), 2)
            tp, tr, tf = get_prf(y_true, y_pred)
            train_metric = get_metric(train_metric, i+1, train_loss, train_acc, tp, tr, tf, train_time)
            tic = time.time()
            val_loss, val_acc, vp, vr, vf = test(model, val_g, criterion, threshold, mtype)
            val_time = round((time.time()-tic)/60.0, 2)
            val_metric = get_metric(val_metric, i+1, val_loss, val_acc, vp, vr, vf, val_time)
            print(f'Prediction Time (min): {val_time}')
            print(
                f'Epoch {i + 1}, Loss (Train, Val) : {train_loss}, {val_loss}, Accuracy (Train, Val): {train_acc}, {val_acc}, PRF (Val): {vp},{vr},{vf}'
            )
            print("***********************************************************")
            write_metric(train_metric, val_metric, mtype, epoch, train_g.batch_size)
            tic = time.time()
            model.train()
    return train_metric, val_metric
                          
def test(model, test_loader, criterion, threshold, mtype='DN'):
    #model in eval mode skips Dropout etc
    model.eval()
    y_true = []
    y_pred = []
    running_loss = 0.0
    # set the requires_grad flag to false as we are in the test mode
    with torch.no_grad():
        for x1, x2, y, x_seq_lens, max_seq_len in test_loader:
            x1, y, x_seq_lens = torch.Tensor(x1).to(device), torch.Tensor(y).to(device), torch.Tensor(x_seq_lens).to(device)
            org = x2[0]
            dst = x2[1]
            org = torch.Tensor(org).to(device)
            dst = torch.Tensor(dst).to(device)
            x2 = [org, dst]
            # y = torch.Tensor(y)
            # the model on the data
            if mtype == 'DN':
                output = model(x1, x2, x_seq_lens)
                output = torch.squeeze(output)
                if len(output.shape) == 0:
                    output = output.unsqueeze(0)
                loss = criterion(output, y) 
                pred = np.array((output.cpu() > threshold)*1)
            else:
                out1, out2 = model(x1, x2, x_seq_lens)
                loss = criterion(out1, out2, y)
                pred = np.array((torch.pairwise_distance(out1, out2) > threshold)*1)
            target = y.float()
            running_loss += loss.item()
            y_true.extend(target.tolist()) 
            y_pred.extend(pred.reshape(-1).tolist())
    avg_loss = round(running_loss / len(test_loader), 4)
    acc = round(accuracy_score(y_true, y_pred), 2)
    p,r,f = get_prf(y_true, y_pred)
    return avg_loss, acc, p,r,f

In [19]:
# export
def train_DN(train_g, val_g, net, lr, threshold, num_epochs, print_at):
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_metric, val_metric = train(
        net, train_g, val_g, optimizer, criterion, threshold, num_epochs, print_at, mtype='DN'
    )
    return train_metric, val_metric

In [20]:
# export
def train_SN(train_g, val_g, net, lr, threshold, num_epochs, print_at):
    criterion = ContrastiveLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_metric, val_metric = train(
        net, train_g, val_g, optimizer, criterion, threshold, num_epochs, print_at, mtype='SN'
    )
    return train_metric, val_metric

In [22]:
# slow
# usage: train_DN
params = {
    'batch_size': BATCH_SIZE,
    'shuffle': True,
    'collate_fn': zero_padding
}
lr = 1e-3
threshold = 0.5
num_epochs = EPOCHS
print_at = 1
train_g, val_g, net = get_data_and_model(params, model='DN', mode=mode)
MODEL_SAVE_PATH = os.path.join(file_dir, '../runs', run_name, f'Epoch{998}.pth')
if os.path.exists(MODEL_SAVE_PATH):
    print('Loading pre-trained model and optimizer weights from ', MODEL_SAVE_PATH)
    checkpoint = torch.load(MODEL_SAVE_PATH)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    net.train()
else:
    print('Did not find pre-trained weights at %s, starting training without them'%(MODEL_SAVE_PATH))
_, _ = train_DN(train_g, val_g, net, lr, threshold, num_epochs, print_at)

Did not find pre-trained weights at /Users/arunabha/fourkites/codebase/a3/paper/DifferenceNet/ti/ti/../runs/run_simulatedsubtraj_Batch4_EPOCHS10/Epoch998.pth, starting training without them
Epoch Time (min): 0.0
Prediction Time (min): 0.0
Epoch 1, Loss (Train, Val) : 0.6667, 0.7961, Accuracy (Train, Val): 0.56, 0.0, PRF (Val): 0.0,0.0,0.0
***********************************************************
Epoch Time (min): 0.0
Prediction Time (min): 0.0
Epoch 2, Loss (Train, Val) : 0.7069, 0.6046, Accuracy (Train, Val): 0.44, 1.0, PRF (Val): 1.0,1.0,1.0
***********************************************************
Epoch Time (min): 0.0
Prediction Time (min): 0.0
Epoch 3, Loss (Train, Val) : 0.6937, 0.6061, Accuracy (Train, Val): 0.48, 1.0, PRF (Val): 1.0,1.0,1.0
***********************************************************
Epoch Time (min): 0.0
Prediction Time (min): 0.0
Epoch 4, Loss (Train, Val) : 0.7062, 0.6083, Accuracy (Train, Val): 0.44, 1.0, PRF (Val): 1.0,1.0,1.0
**************************

In [23]:
# export
if __name__ == '__main__' and '__file__' in globals():
    args = sys.argv[1:]
    parser = argparse.ArgumentParser()
    models = ['SN', 'DN']
    parser.add_argument(
        '-m', '--model',
        help="""Choose model, SN : Siamese Net, DN : Difference Net""",
        choices=models
    )
    parser.add_argument(
        '-e', '--epoch',
        help="""Number of Epochs"""
    )
    parser.add_argument(
        '-p', '--print',
        help="""Print at every p step, p must not be greater than e"""
    )
    results = parser.parse_args(args)
    params = {
        'batch_size': 4,
        'shuffle': True,
        'collate_fn': zero_padding
    }
    lr = 1e-3
    threshold = 0.5
    num_epochs = int(results.epoch) if results.epoch else 10
    print_at = int(results.print) if results.print else 1
    if print_at > num_epochs:
        raise ValueError('p must not be greater than e')
    mtype = results.model if results.model else 'DN'
    print(f'Training starting with: model={mtype}, epoch={num_epochs}, print every={print_at}')
    train_g, val_g, net = get_data_and_model(params, mtype)
    if mtype == 'DN':
        train_DN(train_g, val_g, net, lr, threshold, num_epochs, print_at)
    else:
        train_SN(train_g, val_g, net, lr, threshold, num_epochs, print_at)

In [24]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_prep.ipynb.
Converted 01_dataloader.ipynb.
Converted 02_model.ipynb.
Converted 03_train.ipynb.
Converted index.ipynb.
