In [1]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import numpy as np
import pandas as pd
import math
import glob

from model import LSTMDriver
from model import HYPERPARAMS

import logging
logging.basicConfig(filename='training.log',level=logging.DEBUG)

In [8]:
PROJECT_ROOT = '../../../..'
LEARNED_DRIVER = 'snakeoil_miner/data'
DIFFICULTY = 'easy'


TRAINING_FILES = glob.glob('/'.join([PROJECT_ROOT, LEARNED_DRIVER, DIFFICULTY, '*.csv']))
TRAINING_DATA = {}
for FILE in TRAINING_FILES:
    DF = pd.read_csv(FILE, index_col=False)
    TRAINING_DATA[FILE] = DF.values

VALIDATION_FILES = glob.glob('/'.join([PROJECT_ROOT, LEARNED_DRIVER, 'validation', '*.csv']))
VALIDATION_DATA = {}
for FILE in VALIDATION_FILES:
    DF = pd.read_csv(FILE, index_col=False)
    VALIDATION_DATA[FILE] = DF.values
    
CUDA = torch.cuda.is_available()
if CUDA:
    DTYPE = torch.cuda.FloatTensor
else:
    DTYPE = torch.FloatTensor

   0.0    1  -0.0025721498952690908  0.0.1  0.333331  -1.74846e-07  7.07109  \
0  1.0  0.0               -0.002572    0.0  0.333331 -1.748460e-07  7.07109   
1  1.0  0.0               -0.002572    0.0  0.333331 -1.748460e-07  7.07109   
2  1.0  0.0               -0.002572    0.0  0.333331 -1.748460e-07  7.07109   
3  1.0  0.0               -0.002572    0.0  0.333331 -1.748460e-07  7.07109   
4  1.0  0.0               -0.002572    0.0  0.333331 -1.748460e-07  7.07109   

   15.3578  24.0487  41.0277   ...     200.0.2  200.0.3  200.0.4  200.0.5  \
0  15.3578  24.0487  41.0277   ...       200.0    200.0    200.0    200.0   
1  15.3578  24.0487  41.0277   ...       200.0    200.0    200.0    200.0   
2  15.3578  24.0487  41.0277   ...       200.0    200.0    200.0    200.0   
3  15.3578  24.0487  41.0277   ...       200.0    200.0    200.0    200.0   
4  15.3578  24.0487  41.0277   ...       200.0    200.0    200.0    200.0   

   200.0.6  143.356  82.0552  48.0972  30.7155  14.1421  
0   

In [3]:
def save_checkpoint(state, is_best, filepath='latest_checkpoint.pth.tar'):
    torch.save(state, 'checkpoints/' + filepath)
    if is_best:
        torch.save(state, 'checkpoints/best_checkpoint.pth.tar')

In [4]:
def train(training_data, model, criterion):
    loss = 0
    for key in training_data:
        logging.info('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        print('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        
        model.hidden = model.init_hidden()
        track_sequence = training_data[key]

        targets = track_sequence[:, 0:3]
        inputs = track_sequence[:, 3:]

        targets_variable = autograd.Variable(torch.Tensor(targets)).type(DTYPE)
        inputs_variable = autograd.Variable(torch.Tensor(inputs),  requires_grad=True).type(DTYPE)

        outputs_variable = model(inputs_variable)

        track_loss = criterion(outputs_variable, targets_variable)

        track_loss.backward()
        optimizer.step()

        loss += track_loss.data[0]
    return loss

In [5]:
def validate(validation_data, model, criterion):
    loss = 0
    for key in validation_data:
        logging.info('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        print('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        
        model.hidden = model.init_hidden()
        track_sequence = validation_data[key]

        targets = track_sequence[:, 0:3]
        inputs = track_sequence[:, 3:]

        targets_variable = autograd.Variable(torch.Tensor(targets), volatile=True).type(DTYPE)
        inputs_variable = autograd.Variable(torch.Tensor(inputs), volatile=True).type(DTYPE)

        outputs_variable = model(inputs_variable)

        track_loss = criterion(outputs_variable, targets_variable)

        loss += track_loss.data[0]
    return loss

In [6]:
if CUDA:
    model = LSTMDriver(HYPERPARAMS.INPUT_SIZE,
                       HYPERPARAMS.HIDDEN_SIZE,
                       HYPERPARAMS.NUM_LAYERS,
                       HYPERPARAMS.TARGET_SIZE,
                       HYPERPARAMS.BATCH_SIZE).cuda()
    criterion = nn.MSELoss().cuda()
else:
    model = LSTMDriver(HYPERPARAMS.INPUT_SIZE,
                       HYPERPARAMS.HIDDEN_SIZE,
                       HYPERPARAMS.NUM_LAYERS,
                       HYPERPARAMS.TARGET_SIZE,
                       HYPERPARAMS.BATCH_SIZE)
    criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=HYPERPARAMS.LEARNING_RATE)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)

min_loss = math.inf
losses = {
  'training': [],
  'validation': []
}

for epoch in np.arange(HYPERPARAMS.NUM_EPOCHS):
    if CUDA:
        model.cuda()
        criterion.cuda()
    
    logging.info('Epoch [%d/%d]' %(epoch+1, HYPERPARAMS.NUM_EPOCHS))
    print('Epoch [%d/%d]' %(epoch+1, HYPERPARAMS.NUM_EPOCHS))
    
    is_best = False

    training_loss = train(TRAINING_DATA, model, criterion)
    logging.info('--- TRAINING LOSS: %f' % training_loss)
    print('--- TRAINING LOSS: %f' % training_loss)

    validation_loss = validate(VALIDATION_DATA, model, criterion)
    logging.info('--- VALIDATION LOSS: %f' % validation_loss)
    print('--- VALIDATION LOSS: %f' % validation_loss)

    if validation_loss < min_loss:
        logging.info('--- --- best model found so far: %f' % validation_loss)
        print('--- --- best model found so far: %f' % validation_loss)
        min_loss = validation_loss
        is_best = True

    losses['training'].append(training_loss)
    losses['validation'].append(validation_loss)

    model.cpu()
    save_checkpoint({
          'epoch': epoch + 1,
          'state_dict': model.state_dict(),
          'min_loss': min_loss,
          'optimizer' : optimizer.state_dict(),
      }, is_best)

    scheduler.step(validation_loss)
    logging.info('-------------------------------------------------------')
    print('-------------------------------------------------------')

Epoch [1/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.800636
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.486004
--- --- best model found so far: 0.486004
-------------------------------------------------------
Epoch [2/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.350198
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.244157
--- --- best model found so far: 0.244157
-------------------------------------------------------
Epoch [3/100]
--- Parsing track easy-race_01.csv
---

--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.233894
--- --- best model found so far: 0.233894
-------------------------------------------------------
Epoch [20/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.201153
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.229652
--- --- best model found so far: 0.229652
-------------------------------------------------------
Epoch [21/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.197181
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsi

--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.233377
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.269284
-------------------------------------------------------
Epoch [39/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.232600
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.268494
-------------------------------------------------------
Epoch [40/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.231698
-

--- VALIDATION LOSS: 0.261652
-------------------------------------------------------
Epoch [57/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224767
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261637
-------------------------------------------------------
Epoch [58/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224751
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261621
-------------------------------------------------------
Epoch [59/100]
--- Parsing track easy-race_01.cs

--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261474
-------------------------------------------------------
Epoch [76/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224590
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261472
-------------------------------------------------------
Epoch [77/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224588
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261470
----------------------

--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261440
-------------------------------------------------------
Epoch [95/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224557
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261438
-------------------------------------------------------
Epoch [96/100]
--- Parsing track easy-race_01.csv
--- Parsing track easy-race_02.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_03.csv
--- TRAINING LOSS: 0.224555
--- Parsing track validation-race_31.csv
--- Parsing track validation-race_30.csv
--- Parsing track validation-race_32.csv
--- Parsing track validation-race_33.csv
--- VALIDATION LOSS: 0.261437
----------------------

In [11]:
model2 = LSTMDriver(HYPERPARAMS.INPUT_SIZE,
                       HYPERPARAMS.HIDDEN_SIZE,
                       HYPERPARAMS.NUM_LAYERS,
                       HYPERPARAMS.TARGET_SIZE,
                       HYPERPARAMS.BATCH_SIZE)
checkpoint = torch.load('checkpoints/best_checkpoint.pth.tar')
model2.load_state_dict(checkpoint['state_dict'])
model2.cpu()
save_checkpoint({
          'epoch': checkpoint['epoch'],
          'state_dict': model2,
          'min_loss': checkpoint['min_loss'],
          'optimizer' : checkpoint['optimizer'],
      }, is_best)