### 1 - Imports

In [43]:
import torch
import os
import time
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from data_loader import VideoFolder
from torchvision.transforms import Compose
from DemoModel import FullModel
import transforms as t
import utils
from tensorboardX import SummaryWriter
from IPython.core.display import HTML
import json
import math

%matplotlib inline

In [44]:
# TensorBoard event writer; train() logs its loss/accuracy scalars through it.
writer = SummaryWriter()

# Dataset folders and CSV file names are read from a JSON config next to the notebook.
with open('./configs.json') as data_file:
    config = json.load(data_file)

In [45]:
# Folder where this run's checkpoints are written and later looked up.
curr_folder = 'full_net_10'
# exist_ok=True avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(curr_folder, exist_ok=True)

In [46]:
batch_size = 2  # samples per batch for both DataLoaders; also passed to MyModel
steps_before_print = 2  # train() runs a validation pass every this many training steps
num_workers = 0  # DataLoader worker processes; 0 loads data in the main process
step_size = 2  # forwarded to VideoFolder as step_size (presumably the frame sampling stride — confirm in data_loader)
num_frames = 32 // step_size  # clip_size given to VideoFolder (16 with step_size = 2)

### 2 - Setting up Data Loaders

In [47]:
# Per-channel statistics handed to t.GroupNormalize in both transform pipelines.
mean = [0.4377, 0.4047, 0.3925]
std = [0.2674, 0.2676, 0.2648]

In [48]:
# Training-time preprocessing, applied in the order listed.
# NOTE(review): the Group* transforms come from the local `transforms` module; presumably each
# one applies the same operation to every frame of a clip — confirm there.
# NOTE(review): the resize size (100, 160) and the crop size (140, 100) look like they use
# opposite (h, w) / (w, h) orders — confirm against the transforms module.
transform = Compose([
    t.GroupResize((100, 160)),
    t.GroupRandomCrop((140, 100)),
    t.GroupRandomRotation(18),  # the only step that is not also in transform_validation
    t.GroupCenterCrop((96, 96)),
    t.GroupToTensor(),
    t.GroupNormalize(std=std, mean=mean),
])

In [49]:
# Validation-time preprocessing: the same steps as `transform` minus the random rotation.
# NOTE(review): GroupRandomCrop is still present here, so validation inputs are presumably
# not deterministic between runs — confirm this is intended.
transform_validation = Compose([
    t.GroupResize((100, 160)),
    t.GroupRandomCrop((140, 100)),
    t.GroupCenterCrop((96, 96)),
    t.GroupToTensor(),
    t.GroupNormalize(std=std, mean=mean),
])

In [50]:
# Training dataset built from the folder and CSV files named in the config; every sample is
# passed through the training `transform`.
train_data = VideoFolder(
    root=config['train_data_folder'],
    csv_file_input=config['full_train_data_csv'],
    csv_file_labels=config['full_labels_csv'],
    clip_size=num_frames,
    nclips=1,
    step_size=step_size,
    is_val=False,
    transform=transform,
)

In [51]:
# Batches the training set. shuffle=True reshuffles the samples every epoch and
# drop_last=True discards a final incomplete batch.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False,
    drop_last=True)

In [52]:
# Validation dataset: same label CSV as the training set, but its own data folder and input
# CSV, and the validation transform.
# NOTE(review): is_val=False is passed for the validation split as well — confirm against
# VideoFolder whether this should be True.
validation_data = VideoFolder(
    root=config['validation_data_folder'],
    csv_file_input=config['full_validation_data_csv'],
    csv_file_labels=config['full_labels_csv'],
    clip_size=num_frames,
    nclips=1,
    step_size=step_size,
    is_val=False,
    transform=transform_validation,
)

In [53]:
# Batches the validation set with the same batch size and worker settings as train_loader.
# NOTE(review): shuffle=True on a validation loader is unusual; train() evaluates with
# stop_at=1200, so presumably a different subset is scored each time — confirm this is intended.
validation_loader = torch.utils.data.DataLoader(
    validation_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False,  # changed
    drop_last=True
)

In [54]:
def save_model(model, use_ts=False):
    """Write ``model.state_dict()`` to a ``.ckp`` file inside ``curr_folder``.

    With ``use_ts`` the file is named after the current UTC time (minute
    resolution); otherwise it is always ``best_model.ckp``, replacing any
    previous file of that name.
    """
    stem = time.strftime("%d_%b_%Y_%Hh%Mm", time.gmtime()) if use_ts else 'best_model'
    torch.save(model.state_dict(), curr_folder + '/{}.ckp'.format(stem))

### 3 - Model definition

In [55]:
class MyModel(nn.Module):
    """FullModel backbone followed by a trainable 27 -> 27 linear layer.

    Besides the layers, the instance carries bookkeeping attributes used by the
    notebook's training code: ``steps``, ``epochs`` and ``best_valdiation_loss``
    (spelling kept because ``train`` reads and writes that attribute name).
    """

    def __init__(self, batch_size, seq_lenght=8):
        """Build the backbone and the linear head.

        Args:
            batch_size: forwarded to ``FullModel`` and stored on the instance.
            seq_lenght: forwarded to ``FullModel`` and stored on the instance
                (parameter name kept as-is for keyword callers).
        """
        super().__init__()
        self.fm = FullModel(batch_size, seq_lenght)
        self.fc = nn.Linear(27, 27)
        self.batch_size = batch_size
        self.seq_lenght = seq_lenght
        self.steps = 0
        self.epochs = 0
        self.best_valdiation_loss = math.inf

    def forward(self, x):
        """Run ``x`` through the backbone, then through the linear head."""
        # Call the submodule itself rather than .forward() so registered hooks run.
        return self.fc(self.fm(x))

In [56]:
# NOTE(review): seq_lenght is left at its default of 8 while num_frames (the clip_size given
# to VideoFolder) is 32 // step_size = 16 — confirm FullModel does not need them to match.
model = MyModel(batch_size=batch_size)

In [57]:
# Print the module tree to inspect the assembled architecture.
print(model)

MyModel(
  (fm): FullModel(
    (rgb2d): CNN2D(
      (conv1): Sequential(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv3): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (conv4): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True,

In [58]:
# Pick the ".ckp" file in curr_folder whose "./<name>" string compares greatest (plain string
# comparison, not modification time) and load it into the FullModel backbone only.
most_recent_file = ''
for file in os.listdir(curr_folder):
    if not file.endswith(".ckp"):
        continue
    file = os.path.join(".", file)
    most_recent_file = max(most_recent_file, file)

if most_recent_file == '':
    print('No model loaded.')
else:
    print('Model LOADED: ', curr_folder + '/' + most_recent_file)
    loaded_dict = torch.load(curr_folder + '/' + most_recent_file)
    model.fm.load_state_dict(loaded_dict)

Model LOADED:  full_net_10/./demo.ckp


In [59]:
# Classification loss; nn.CrossEntropyLoss expects raw (unnormalised) class scores and
# integer class labels.
criterion = nn.CrossEntropyLoss()

In [60]:
# Freeze the FullModel backbone so that only the remaining parameters (the fc layer)
# receive gradient updates.
for param in model.fm.parameters():
    param.requires_grad = False

In [61]:
def train(epochs):
    """Train the global ``model`` for ``epochs`` passes over ``train_loader``.

    Relies on the notebook-level ``optimizer``, ``criterion``, ``writer``,
    ``validation_loader`` and ``steps_before_print``. Every
    ``steps_before_print`` training steps the model is evaluated through
    ``utils.calculate_loss_and_accuracy``; whenever the validation loss beats
    ``model.best_valdiation_loss`` the model is saved via ``save_model``.

    ``model.steps`` is advanced once per training step and used as the x-axis
    for every TensorBoard scalar, so curves continue across successive calls.

    Args:
        epochs: number of epochs to run.
    """
    print("Training is about to start...")

    for epoch in range(epochs):
        step = 0
        epoch_loss = 0
        epoch_acc = 0
        times_calculated = 0
        total_size = len(train_loader)
        for images, labels in train_loader:
            # Re-enter training mode each step: the validation pass below switches to eval().
            model.train()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            writer.add_scalar('training_loss', loss.item(), model.steps)
            loss.backward()
            optimizer.step()

            step += 1
            epoch_loss += loss.item()
            if step % 10 == 0:
                print(f'step {step} of {total_size}')

            if step % steps_before_print == 0:
                # Calculate Accuracy
                model.eval()
                validation_loss, accuracy = utils.calculate_loss_and_accuracy(validation_loader, model, criterion,
                                                                              stop_at=1200)
                writer.add_scalar('validation_loss', validation_loss, model.steps)
                writer.add_scalar('accuracy', accuracy, model.steps)
                epoch_acc += accuracy
                times_calculated += 1
                # Print Loss
                print('Iteration: {}/{} - ({:.2f}%). Loss: {}. Accuracy: {}'.format(step, total_size,
                                                                                    step * 100 / total_size,
                                                                                    loss.item(), accuracy))
                if validation_loss < model.best_valdiation_loss:
                    model.best_valdiation_loss = validation_loss
                    print('Saving best model')
                    save_model(model)
                del validation_loss

            # Advance the global step only after all scalars for this step are logged,
            # so training and validation points share the same x value.
            model.steps += 1
            del loss, outputs, images, labels

        model.epochs += 1

        # Guard the averages: the loader may be empty, or the epoch may be shorter
        # than steps_before_print so that no validation pass ran.
        avg_loss = epoch_loss / step if step else float('nan')
        avg_acc = epoch_acc / times_calculated if times_calculated else float('nan')
        print('Epoch({}) avg loss: {} avg acc: {}'.format(epoch, avg_loss, avg_acc))
        print('Epoch ', epoch)

In [62]:
# Initial learning rate; the training cell divides it by 10 between stages.
learning_rate = 0.001
# Only parameters that still require gradients are handed to SGD (model.fm was frozen in an earlier cell).
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, momentum=0.9)

In [63]:
# Staged training: 10 epochs with the optimizer built above, then three 5-epoch
# stages, each with the learning rate divided by 10. A timestamped checkpoint
# is saved after every stage.
for stage, stage_epochs in enumerate((10, 5, 5, 5)):
    if stage > 0:
        learning_rate = learning_rate / 10
        # Rebuild the optimizer over trainable parameters only, matching the
        # optimizer used for the first stage.
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=learning_rate, momentum=0.9)
    train(stage_epochs)
    save_model(model, use_ts=True)

Training is about to start...
Iteration: 2/25210 - (0.01%). Loss: 8.35986328125. Accuracy: 5.666666666666667
Saving best model


KeyboardInterrupt: 

In [None]:
# Saves the model under a timestamped file name (prevents overwriting an earlier checkpoint).
save_model(model, use_ts=True)

In [None]:
# Check accuracy for all saved checkpoints.
# Checkpoints whose keys do not match MyModel's state dict (for example a backbone-only
# file meant for model.fm) are skipped instead of aborting the loop with a load error.

expected_keys = set(model.state_dict().keys())
for file in sorted(os.listdir(curr_folder)):
    if not file.endswith(".ckp"):
        continue
    checkpoint_path = curr_folder + '/' + file
    loaded_dict = torch.load(checkpoint_path)
    if set(loaded_dict.keys()) != expected_keys:
        print('Skipping (state dict keys do not match MyModel): ', checkpoint_path)
        continue
    print('Model LOADED: ', checkpoint_path)
    model.load_state_dict(loaded_dict)
    model.eval()
    # Keep the two losses under separate names so neither overwrites the other.
    validation_loss, accuracy = utils.calculate_loss_and_accuracy(validation_loader, model, criterion, 1500)
    train_loss, train_accuracy = utils.calculate_loss_and_accuracy(train_loader, model, criterion, 1500)
    print('Validation Acc: {} \t Train Acc: {}'.format(accuracy, train_accuracy))
