# TEST

In [1]:
from barbar import Bar
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import utils

## 1. Split train / validation

In [2]:
split = 0.8
all_train_videos = np.array(utils.get_train_test_video_names()[0])

In [3]:
np.random.seed(0)
train_videos = np.array(all_train_videos)[np.random.choice(len(all_train_videos), int(0.8 * len(all_train_videos)), replace=False)]
validation_videos = np.setdiff1d(all_train_videos, train_videos, assume_unique=False)
train_videos.sort()
validation_videos.sort()

In [4]:
train_data = utils.HernitiaDataset(videonames = train_videos, resize = (224,224), mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
validation_data = utils.HernitiaDataset(videonames = validation_videos, resize = (224,224), mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])

## 2. Model construction

In [5]:
class MyFC(torch.nn.Module):
    def __init__(self, num_classes):
        super(MyFC, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256,num_classes),
        )

    def forward(self, x):
        return self.classifier(x)

resnet18 = models.resnet18(pretrained=True)
resnet18.fc = MyFC(utils.num_classes)
resnet18 = resnet18.to(utils.device)

for name, param in resnet18.named_parameters():
    if 'fc' not in name:
        param.requires_grad=False

In [6]:
# training parameters
LEARNING_RATE = 0.001
EPOCHS = 5
BATCH_SIZE = 32

# instantiate data loaders
train_loader = utils.DataLoader(dataset=train_data, batch_size=1, shuffle=False)
validation_loader = utils.DataLoader(dataset=validation_data, batch_size=1, shuffle=False)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(resnet18.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = "min", patience = 1, verbose = True)

## 3. Useful functions

In [7]:
def one_hot_acc(y_pred, y_true):
    """
    Description
    -------------
    Computes accuracy of prediction vs. one hot encoded multi-class label for one batch

    Parameters
    -------------
    y_pred         : list of names of the video to include in the dataset
    y_true         : tuple, shape of the resized frame (default = (224,224))

    Returns
    -------------
    Accuracy over the batch
    """
    y_pred_tag = torch.argmax(y_pred, 1)
    y_true_tag = torch.argmax(y_true, 1)

    correct_results_sum = (y_pred_tag == y_true_tag).sum().float()
    acc = correct_results_sum/y_true.shape[0]
    acc = torch.round(acc * 100)
    
    return acc

In [8]:
# Training
def train(model, epoch):
    print('\nEpoch: %d' % epoch)
    print('Training ...')
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    sum_num_frames = 0

    for video_idx, (frames, labels) in enumerate(Bar(train_loader)):
        frames, labels = frames[0].to(utils.device), labels[0].to(utils.device)
        sum_num_frames += frames.shape[0]
        num_batches = frames.shape[0] // BATCH_SIZE

        for batch_idx in range(num_batches):

            # last batch might not be of size BATCH_SIZE
            frames_batch = frames[int(batch_idx*BATCH_SIZE):int((batch_idx+1)*BATCH_SIZE)]
            labels_batch = labels[int(batch_idx*BATCH_SIZE):int((batch_idx+1)*BATCH_SIZE)]
            actual_batch_size = frames_batch.shape[0]

            optimizer.zero_grad()
            outputs = model(frames_batch)

            loss = criterion(outputs, labels_batch.float())
            acc = one_hot_acc(outputs, labels_batch.float())
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item() * actual_batch_size
            epoch_acc += acc.item() * actual_batch_size
             
    epoch_loss /= sum_num_frames
    epoch_acc /= sum_num_frames
    print(f'Training loss: {epoch_loss:.5f} | Training acc: {epoch_acc:.3f}')
    return epoch_loss, epoch_acc

In [9]:
# Validation
def validate(model, epoch, model_name = 'ResNet18'):
    print('Validation ...')
    global best_acc
    model.eval()
    validation_loss = 0
    validation_acc = 0
    sum_num_frames = 0

    with torch.no_grad():
        for video_idx, (frames, labels) in enumerate(Bar(validation_loader)):
            frames, labels = frames[0].to(utils.device), labels[0].to(utils.device)
            sum_num_frames += frames.shape[0]
            num_batches = frames.shape[0] // BATCH_SIZE

            for batch_idx in range(num_batches):

                # last batch might not be of size BATCH_SIZE
                frames_batch = frames[int(batch_idx*BATCH_SIZE):int((batch_idx+1)*BATCH_SIZE)]
                labels_batch = labels[int(batch_idx*BATCH_SIZE):int((batch_idx+1)*BATCH_SIZE)]
                actual_batch_size = frames_batch.shape[0]

                outputs = model(frames_batch)

                loss = criterion(outputs, labels_batch.float())
                acc = one_hot_acc(outputs, labels_batch.float())

                validation_loss += loss.item()
                validation_acc += acc.item() * actual_batch_size

        validation_loss /= sum_num_frames
        validation_acc /= sum_num_frames
        print(f'Validation loss: {validation_loss:.5f} | Validation acc: {validation_acc:.3f}')

    # Save checkpoint.
    acc = validation_acc
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/' + model_name + '.pth')
        best_acc = acc
    
    return validation_loss, validation_acc

In [10]:
def plot_performance(model_name, training_loss, training_acc, validation_loss, validation_acc):

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18,6))

    # subplot for losses
    ax1.grid()
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss')
    ax1.plot(training_loss, color = 'blue', label = "training loss")
    ax1.plot(validation_loss, color = 'green', label = "validation loss")
    ax1.legend()

    # subplot for losses
    ax2.grid()
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('accuracy')
    ax2.plot(training_acc, color = 'blue', label = "training accuracy")
    ax2.plot(validation_acc, color = 'green', label = "validation accuracy")
    ax2.legend()

    plt.subplots_adjust(wspace=0.5)
    plt.suptitle(model_name)
    plt.savefig('figures/' + model_name + '.png')
    plt.show()

## 4. Train

In [11]:
train_loss, validation_loss = [], []
train_acc, validation_acc = [], []
best_acc = 0

for epoch in range(EPOCHS):
    trnl, trnacc = train(resnet18, epoch)
    tstl, tstacc = validate(resnet18, epoch, 'ResNet18')
    train_loss.append(trnl)
    train_acc.append(trnacc)
    validation_loss.append(tstl)
    validation_acc.append(tstacc)
    scheduler.step(tstl)


Epoch: 0
Training ...
Training loss: 0.24851 | Training acc: 24.771
Validation ...
Validation loss: 0.00861 | Validation acc: 24.483
Saving..

Epoch: 1
Training ...
Training loss: 0.22028 | Training acc: 26.234
Validation ...
Validation loss: 0.00649 | Validation acc: 24.483

Epoch: 2
Training ...
Training loss: 0.20452 | Training acc: 26.237
Validation ...
Validation loss: 0.00644 | Validation acc: 24.483

Epoch: 3
Training ...
 3/56: [=>..............................] - ETA 809.9s

KeyboardInterrupt: 