In [4]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import numpy as np
import pandas as pd
import math
import glob
import time
import pickle

import steering
import speeding

import logging

In [5]:
PROJECT_ROOT = '../../../..'
LEARNED_DRIVER = 'snakeoil_miner/data'
DIFFICULTY = ['easy', 'medium', 'hard']

# COMMAND = 'steering'
COMMAND = 'speeding'
logging.basicConfig(filename='logs/{}/training-{}.log'.format(COMMAND, time.time()),level=logging.DEBUG)

NUM_EPOCHS = 200

TRAINING_FILES = []
for d in DIFFICULTY:
    TRAINING_FILES.extend(glob.glob('/'.join([PROJECT_ROOT, LEARNED_DRIVER, d, '*.csv'])))

TRAINING_DATA = {}
for FILE in TRAINING_FILES:
    DF = pd.read_csv(FILE, index_col=False)
    TRAINING_DATA[FILE] = DF.values
    
CUDA = torch.cuda.is_available()
if CUDA:
    DTYPE = torch.cuda.FloatTensor
else:
    DTYPE = torch.FloatTensor

In [6]:
def save_checkpoint(state, is_best, filepath='latest_checkpoint.tar'):
    torch.save(state, 'split_checkpoints/' + COMMAND + '/' + filepath)
    if is_best:
        torch.save(state, 'split_checkpoints/' + COMMAND + '/' + 'best_checkpoint.tar')

In [4]:
def train(training_data, model, criterion, optimzier):
    loss = 0
    model.train(mode=True)
    for key in training_data:
        logging.info('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        print('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        
        model.hidden_state = model.init_hidden()
        optimizer.zero_grad()
        track_sequence = training_data[key]

        if COMMAND == 'steering':
            targets = track_sequence[:, 2:3]
        elif COMMAND == 'speeding':
            targets = track_sequence[:, 0:2]
        
        inputs = track_sequence[:, 3:]

        targets_variable = autograd.Variable(torch.Tensor(targets)).type(DTYPE)
        inputs_variable = autograd.Variable(torch.Tensor(inputs),  requires_grad=True).type(DTYPE)

        outputs_variable = model(inputs_variable)

        track_loss = criterion(outputs_variable, targets_variable)

        track_loss.backward()
        optimizer.step()

        loss += track_loss.data[0]
    return loss

In [5]:
def validate(validation_data, model, criterion):
    loss = 0
    model.train(mode=False)
    for key in validation_data:
        logging.info('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        print('--- Parsing track {}-{}'.format(key.split('/')[-2], key.split('/')[-1]))
        
        model.hidden_state = model.init_hidden()
        track_sequence = validation_data[key]

        if COMMAND == 'steering':
            targets = track_sequence[:, 2:3]
        elif COMMAND == 'speeding':
            targets = track_sequence[:, 0:2]
            
        inputs = track_sequence[:, 3:]

        targets_variable = autograd.Variable(torch.Tensor(targets), volatile=True).type(DTYPE)
        inputs_variable = autograd.Variable(torch.Tensor(inputs), volatile=True).type(DTYPE)

        outputs_variable = model(inputs_variable)

        track_loss = criterion(outputs_variable, targets_variable)

        loss += track_loss.data[0]
    return loss

In [None]:
if COMMAND == 'steering':
    model = steering.Steering(steering.HYPERPARAMS.INPUT_SIZE,
                               steering.HYPERPARAMS.LSTM_HIDDEN_SIZE,
                               steering.HYPERPARAMS.HIDDEN_LAYER_SIZE,
                               steering.HYPERPARAMS.DROPOUT_PROB,
                               steering.HYPERPARAMS.NUM_LAYERS,
                               steering.HYPERPARAMS.TARGET_SIZE,
                               steering.HYPERPARAMS.BATCH_SIZE)
    optimizer = optim.Adam(model.parameters(), lr=steering.HYPERPARAMS.LEARNING_RATE)
elif COMMAND == 'speeding':
    model = speeding.Speeding(speeding.HYPERPARAMS.INPUT_SIZE,
                               speeding.HYPERPARAMS.LSTM_HIDDEN_SIZE,
                               speeding.HYPERPARAMS.HIDDEN_LAYER_SIZE,
                               speeding.HYPERPARAMS.DROPOUT_PROB,
                               speeding.HYPERPARAMS.NUM_LAYERS,
                               speeding.HYPERPARAMS.TARGET_SIZE,
                               speeding.HYPERPARAMS.BATCH_SIZE)
    optimizer = optim.Adam(model.parameters(), lr=speeding.HYPERPARAMS.LEARNING_RATE)
if CUDA:
    model.cuda()
        
criterion = nn.MSELoss().cuda()

scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)

min_loss = math.inf
losses = {
  'training': [],
  'validation': []
}

In [None]:
logging.info('Training %s...' % COMMAND)
print('Training %s...' % COMMAND)

for epoch in np.arange(NUM_EPOCHS):
    logging.info('Epoch [%d/%d]' %(epoch+1, NUM_EPOCHS))
    print('Epoch [%d/%d]' %(epoch+1, NUM_EPOCHS))
    
    is_best = False

    training_loss = train(TRAINING_DATA, model, criterion, optimizer)
    logging.info('--- TRAINING LOSS: %f' % training_loss)
    print('--- TRAINING LOSS: %f' % training_loss)

#     validation_loss = validate(VALIDATION_DATA, model, criterion)
#     logging.info('--- VALIDATION LOSS: %f' % validation_loss)
#     print('--- VALIDATION LOSS: %f' % validation_loss)

    if training_loss < min_loss:
        logging.info('--- --- best model found so far: %f' % training_loss)
        print('--- --- best model found so far: %f' % training_loss)
        min_loss = training_loss
        is_best = True

    losses['training'].append(training_loss)
#     losses['validation'].append(validation_loss)

    save_checkpoint({
          'epoch': epoch + 1,
          'state_dict': model.state_dict(),
          'min_loss': min_loss,
          'optimizer' : optimizer.state_dict(),
      }, is_best)

    scheduler.step(training_loss)
    logging.info('-------------------------------------------------------')
    print('-------------------------------------------------------')

Training speeding...
Epoch [1/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 1.540664
--- --- best model found so far: 1.540664
-------------------------------------------------------
Epoch [2/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Pa

--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.745814
-------------------------------------------------------
Epoch [24/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.729313
-------------------------------------------------------
Epoch [25/200]
--- Parsing track medium-race_10.csv
--- Parsing trac

--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.703903
-------------------------------------------------------
Epoch [39/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.690243
-------------------------------------------------------
Epoch [40/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track 

--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.722613
-------------------------------------------------------
Epoch [55/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.698056
------------------------------------------------------

--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.592681
--- --- best model found so far: 0.592681
-------------------------------------------------------
Epoch [69/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.589296
--- --- best model found so far: 0.589296
-------------------------------------------------------
Epoch [70/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- 

--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.558566
--- --- best model found so far: 0.558566
-------------------------------------------------------
Epoch [84/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.556867
--- --- best model found so far: 0.556867
----------------------------------------

--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.520258
-------------------------------------------------------
Epoch [113/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.527088
-------------------------------------------------------
Epoch [114/200]
--- Parsing track medium-race_10.csv
--- Parsing tr

--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.504801
-------------------------------------------------------
Epoch [128/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.497204
--- --- best model found so far: 0.497204
-------------------------------------------------------
Epoch [129/200]
--- Parsing track medium-race_10.csv
--- Par

--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.467756
--- --- best model found so far: 0.467756
-------------------------------------------------------
Epoch [143/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.475183
-----------

--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.475330
-------------------------------------------------------
Epoch [158/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.477480
-------------------------------------------------------
Epoch [159/200]
--- Parsing track medium-race_10.csv
--- Parsing tr

--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.446817
--- --- best model found so far: 0.446817
-------------------------------------------------------
Epoch [173/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.448982
-------------------------------------------------------
Epoch [174/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsi

--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.421538
--- --- best model found so far: 0.421538
-------------------------------------------------------
Epoch [196/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsing track hard-race_21.csv
--- Parsing track medium-race_13.csv
--- Parsing track easy-race_01.csv
--- Parsing track hard-race_20.csv
--- Parsing track medium-race_12.csv
--- Parsing track hard-race_23.csv
--- Parsing track medium-race_11.csv
--- TRAINING LOSS: 0.424913
-------------------------------------------------------
Epoch [197/200]
--- Parsing track medium-race_10.csv
--- Parsing track easy-race_03.csv
--- Parsing track hard-race_22.csv
--- Parsing track easy-race_00.csv
--- Parsing track easy-race_02.csv
--- Parsi

In [10]:
# dump losses to file
with open('split_checkpoints/{}/losses.pkl'.format(COMMAND), 'wb') as file:
    pickle.dump(losses, file)

In [12]:
# test dump worked
with open('split_checkpoints/{}/losses.pkl'.format(COMMAND), 'rb') as file:
    hello = pickle.load(file)
hello

{'training': [1.540663681924343,
  0.9666658975183964,
  0.9044640175998211,
  0.8651743708178401,
  0.8714125948026776,
  0.8324881340377033,
  0.8452917858958244,
  0.8028058307245374,
  0.7838735990226269,
  0.7680966397747397,
  0.8074919544160366,
  0.7939923061057925,
  0.7470022384077311,
  0.786412313580513,
  0.761294384021312,
  0.7697959402576089,
  0.74113236553967,
  0.7234220542013645,
  0.7413505865260959,
  0.7791558029130101,
  0.7567182113416493,
  0.7626897236332297,
  0.7458144295960665,
  0.7293128417804837,
  0.7302228212356567,
  0.7030009124428034,
  0.6947828317061067,
  0.7311645993031561,
  0.7146716946735978,
  0.7119283801876009,
  0.7436360297724605,
  0.7190381116233766,
  0.7080178428441286,
  0.7035957584157586,
  0.6820943844504654,
  0.6926703187637031,
  0.7108126860111952,
  0.7039025644771755,
  0.6902434686198831,
  0.6777229667641222,
  0.7046507289633155,
  0.7249881140887737,
  0.681992347817868,
  0.6813712902367115,
  0.6581161804497242,
  0.