In [0]:
import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader, random_split
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from torchvision.models import resnet34

from PIL import Image
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import wandb
from datetime import datetime
from utils import Config, test
from gtea_dataset import gtea61

In [0]:
# !wandb login ***

In [0]:
def _common_config():
    """Return the settings shared by both training stages.

    A new dict (with new list objects) is built on every call so the two
    stage configs never alias the same mutable values.
    """
    return {"num_classes": 61,
            "batch_size": 32,
            "lstm_mem_size": 512,
            "optimizer": "adam",
            "decay_factor": 0.1,
            "weight_decay": 5e-5,
            "val_frequency": 3,
            "models_dir": "models",
            "seq_len": 7,
            "training_user_split": [1, 3, 4],
            "val_user_split": [2],
            # The training cells read config.device; set it explicitly instead
            # of relying on Config to supply a default.
            "device": "cuda" if torch.cuda.is_available() else "cpu"}


# Stage 1: only the LSTM cell and the classifier are trained.
config_stage1 = Config({**_common_config(),
                        "stage": 1,
                        "lr": 1e-3,
                        "epochs": 10,  # TODO: 200 for the full schedule
                        # The LR scheduler steps once per epoch, so milestones
                        # beyond `epochs` never fire.
                        "decay_steps": [25, 75, 150]})

# Stage 2: additionally fine-tunes part of the ResNet backbone, starting from
# the best stage-1 checkpoint.
config_stage2 = Config({**_common_config(),
                        "stage": 2,
                        "lr": 1e-4,
                        "epochs": 10,  # TODO: 150 for the full schedule
                        "decay_steps": [25, 75]})

In [0]:
# Fetch the GTEA61 dataset on first run and make sure the checkpoint directory
# exists. subprocess (with check=True) fails loudly if the clone fails, unlike
# a shell `!git clone`.
import subprocess

if not os.path.isdir('./GTEA61'):
    subprocess.run(["git", "clone", "https://github.com/MauriVass/GTEA61"], check=True)

os.makedirs(config_stage1.models_dir, exist_ok=True)

In [0]:
def PrepareTraining(config):
    """Build the attention model and select the parameters to optimise.

    Stage 1 (``config.stage == 1``): every parameter is frozen except those of
    ``model.lstm_cell`` and ``model.classifier``.

    Any other stage: the checkpoint
    ``<models_dir>/best_model_rgb_state_dict.pth`` is loaded first. In
    addition to the LSTM cell and classifier, ``conv1``/``conv2`` of
    ``resNet.layer4[0..2]`` and ``resNet.fc`` are unfrozen and switched to
    train mode.

    Args:
        config: Config providing ``stage``, ``num_classes``,
            ``lstm_mem_size`` and ``models_dir``.

    Returns:
        ``(model, train_params)``: the model (not moved to any device) and
        the list of parameters to pass to the optimizer.
    """
    model = attentionModel(num_classes=config.num_classes, mem_size=config.lstm_mem_size)

    if config.stage != 1:
        stage1_dict = os.path.join(config.models_dir, 'best_model_rgb_state_dict.pth')
        # map_location='cpu' lets a checkpoint written on a GPU machine be
        # loaded anywhere; the caller moves the model to config.device later.
        model.load_state_dict(torch.load(stage1_dict, map_location='cpu'))

    # Freeze everything and put the whole network in eval mode; the parts
    # that should learn are re-enabled below.
    model.train(False)
    for param in model.parameters():
        param.requires_grad = False

    train_params = []

    if config.stage != 1:
        layer4 = model.resNet.layer4
        backbone_modules = [getattr(layer4[block_idx], conv_name)
                            for block_idx in range(3)
                            for conv_name in ('conv1', 'conv2')]
        backbone_modules.append(model.resNet.fc)
        for module in backbone_modules:
            for param in module.parameters():
                param.requires_grad = True
                train_params.append(param)
            module.train(True)

    # The recurrent cell and the classifier head are trained in every stage.
    for module in (model.lstm_cell, model.classifier):
        for param in module.parameters():
            param.requires_grad = True
            train_params.append(param)

    return model, train_params

In [0]:
def _set_trainable_modules_train_mode(model, stage):
    """Switch the modules optimised in ``stage`` back to train mode.

    Only these modules are touched; every other submodule keeps its current
    mode (e.g. eval after a validation pass).
    """
    model.lstm_cell.train(True)
    model.classifier.train(True)
    if stage == 2:
        layer4 = model.resNet.layer4
        for block_idx in range(3):
            layer4[block_idx].conv1.train(True)
            layer4[block_idx].conv2.train(True)
        model.resNet.fc.train(True)


def _train_one_epoch(model, config):
    """Run one optimisation pass over the notebook-level ``train_loader``.

    Uses the notebook-level ``optimizer_fn`` and ``loss_fn``.

    Returns:
        ``(avg_loss, accuracy)``: mean of the per-batch losses and the
        fraction of correctly classified training samples.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    num_batches = 0
    for inputs, labels in train_loader:
        num_batches += 1
        num_samples += inputs.size(0)
        # Swap dims 0 and 1 of the 5-D clip tensor. Presumably the loader
        # yields (batch, seq, C, H, W) and attentionModel expects the time
        # axis first — TODO confirm against objectAttentionModelConvLSTM.
        inputs = inputs.permute(1, 0, 2, 3, 4).to(config.device)
        labels = labels.to(config.device)

        optimizer_fn.zero_grad()
        output_label, _ = model(inputs)
        loss = loss_fn(output_label, labels)
        loss.backward()
        optimizer_fn.step()

        predicted = output_label.detach().argmax(dim=1)
        num_correct += (predicted == labels).sum().item()
        total_loss += loss.item()
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches, num_correct / num_samples


def _evaluate(model, config):
    """Evaluate ``model`` on the notebook-level ``val_loader`` without gradients.

    Leaves the whole model in eval mode.

    Returns:
        ``(avg_loss, accuracy)`` over the validation set.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            num_batches += 1
            num_samples += inputs.size(0)
            inputs = inputs.permute(1, 0, 2, 3, 4).to(config.device)
            labels = labels.to(config.device)
            output_label, _ = model(inputs)
            total_loss += loss_fn(output_label, labels).item()
            predicted = output_label.argmax(dim=1)
            num_correct += (predicted == labels).sum().item()
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return total_loss / num_batches, num_correct / num_samples


def TrainingRGB(model, config):
    """Train ``model`` for ``config.epochs`` epochs, keeping the best checkpoint.

    Relies on notebook-level globals: ``train_loader``, ``val_loader``,
    ``optimizer_fn``, ``loss_fn`` and ``optim_scheduler``. Metrics are
    printed and logged to the active wandb run.

    Validation runs every ``config.val_frequency`` epochs and after the final
    epoch. Whenever validation accuracy improves, the state dict is saved to
    ``<models_dir>/best_model_rgb_state_dict.pth``.

    Args:
        model: model from PrepareTraining, already on ``config.device``.
        config: Config providing ``epochs``, ``stage``, ``val_frequency``,
            ``models_dir`` and ``device``.

    Returns:
        ``(train, val)``: lists of ``(accuracy, clipped_loss)`` tuples, one
        per training epoch and one per validation pass respectively.
    """
    wandb.watch(model, log="all")
    # Logged/returned losses are clipped to this value so one divergent epoch
    # does not dominate the plots.
    max_loss = 6
    best_accuracy = 0
    train = []
    val = []
    save_path_model = os.path.join(config.models_dir, 'best_model_rgb_state_dict.pth')

    for epoch in range(config.epochs):
        epoch_num = epoch + 1
        _set_trainable_modules_train_mode(model, config.stage)
        avg_loss, train_accuracy = _train_one_epoch(model, config)
        optim_scheduler.step()

        print('Train: Epoch = {}/{} | Loss = {} | Accuracy = {}'.format(epoch_num, config.epochs, avg_loss, train_accuracy))
        # NaN also maps to max_loss because `nan < max_loss` is False.
        avg_loss_normalized = avg_loss if avg_loss < max_loss else max_loss
        train.append((train_accuracy, avg_loss_normalized))
        wandb.log({"train_loss": avg_loss_normalized,
                   "train_accuracy": train_accuracy,
                   "epoch": epoch_num})

        # Also validate after the final epoch. The final weights are then always
        # evaluated, and a checkpoint is written even when epochs < val_frequency.
        # Stage 2 loads that checkpoint.
        if epoch_num % config.val_frequency == 0 or epoch_num == config.epochs:
            avg_val_loss, val_accuracy = _evaluate(model, config)
            print('*****  Val: Epoch = {} | Loss {} | Accuracy = {} *****'.format(epoch_num, avg_val_loss, val_accuracy))

            avg_val_loss_normalized = avg_val_loss if avg_val_loss < max_loss else max_loss
            val.append((val_accuracy, avg_val_loss_normalized))
            wandb.log({"valid_loss": avg_val_loss_normalized,
                       "valid_accuracy": val_accuracy,
                       "epoch": epoch_num})

            if val_accuracy > best_accuracy:
                torch.save(model.state_dict(), save_path_model)
                best_accuracy = val_accuracy

    wandb.run.summary["best_valid_accuracy"] = best_accuracy
    return train, val

Prepare Datasets

In [0]:
# Data pipeline: random multi-scale corner crops + horizontal flips for
# training; a fixed centre crop for validation. Both use ImageNet mean/std.
# (gtea61 is already imported in the first cell.)
from spatial_transforms import (CenterCrop, Compose, MultiScaleCornerCrop, Normalize,
                                RandomHorizontalFlip, Scale, ToTensor)
from objectAttentionModelConvLSTM import attentionModel

config = config_stage1
gtea_root = "GTEA61"
# Take the user splits from the config instead of duplicating them here.
training_user_split = config.training_user_split
val_user_split = config.val_user_split
# Pinned host memory only speeds up host-to-GPU copies.
pin_memory = torch.cuda.is_available()

normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

spatial_transform = Compose([Scale(256), RandomHorizontalFlip(),
                             MultiScaleCornerCrop([1, 0.875, 0.75, 0.65625], 224),
                             ToTensor(), normalize])
train_dataset = gtea61("rgb", gtea_root, split="train", user_split=training_user_split,
                       seq_len_rgb=config.seq_len, transform_rgb=spatial_transform, preload=True)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                          num_workers=0, pin_memory=pin_memory)

val_transform = Compose([Scale(256), CenterCrop(224), ToTensor(), normalize])
val_dataset = gtea61("rgb", gtea_root, split="test", user_split=val_user_split,
                     seq_len_rgb=config.seq_len, transform_rgb=val_transform, preload=True)
# No shuffling for validation. The reported loss is a mean of per-batch means,
# which depends on batch composition, so a fixed order keeps it reproducible.
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                        num_workers=0, pin_memory=pin_memory)

Train Stage 1

In [0]:
# Stage 1: train only the LSTM cell and the classifier; the rest is frozen.
config = config_stage1

model, train_params = PrepareTraining(config)
model.to(config.device)

loss_fn = nn.CrossEntropyLoss()
# Note: eps=1e-4 is much larger than Adam's default of 1e-8.
optimizer_fn = torch.optim.Adam(train_params, lr=config.lr, weight_decay=config.weight_decay, eps=1e-4)
optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_fn, milestones=config.decay_steps, gamma=config.decay_factor)

training_time = datetime.now().strftime("%d-%b_%H-%M")
run_name = f"{training_time} Stage1, {config.seq_len}f, T{str(config.training_user_split).replace(' ', '')}"
wandb.init(config=config, group=f"{config.seq_len}f", name=run_name, project="mldl-fpar")

train_rgb, val_rgb = TrainingRGB(model, config)
# Close the run so its summary is flushed before stage 2 starts a new one.
wandb.finish()

Train Stage 2

In [0]:
# Stage 2: start from the best stage-1 checkpoint and also fine-tune the
# layer4 convolutions and resNet.fc.
config = config_stage2

# TrainingRGB switches the trainable modules to train mode at the start of
# every epoch, so no explicit .train(True) calls are needed here.
model, train_params2 = PrepareTraining(config)
model.to(config.device)

loss_fn = nn.CrossEntropyLoss()
optimizer_fn = torch.optim.Adam(train_params2, lr=config.lr, weight_decay=config.weight_decay, eps=1e-4)
optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_fn, milestones=config.decay_steps, gamma=config.decay_factor)

training_time = datetime.now().strftime("%d-%b_%H-%M")
run_name = f"{training_time} Stage2, {config.seq_len}f, T{str(config.training_user_split).replace(' ', '')}"
wandb.init(config=config, group=f"{config.seq_len}f", name=run_name, project="mldl-fpar")

train_rgb, val_rgb = TrainingRGB(model, config)
wandb.finish()