In [1]:
from random import shuffle

import os
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm import tqdm, tqdm_notebook
from toolz import functoolz

from data_utils import ImageClipDataset, split_clips_dataset
from models import residual_attention_network

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [17]:
# Tunable settings for the whole notebook.
NUM_THREADS = 4              # worker processes used by each DataLoader
DATA_PATH = './data/cifar/'  # root directory handed to torchvision's CIFAR10
num_epochs = 50              # passes over the training split
BATCH_SIZE = 16              # samples per mini-batch
batch_size = BATCH_SIZE      # alias: cells in this notebook use both spellings
val_split = 0.2              # fraction of the training data held out for validation

In [18]:
# Per-split preprocessing pipelines, keyed by split name.
data_transforms = {}

# Training: random affine augmentation (rotation up to 15 degrees, translation
# up to 20% of the image size on each axis), then conversion to a tensor.
data_transforms['train'] = transforms.Compose([
    transforms.RandomAffine(degrees=15, translate=(0.2, 0.2)),
    transforms.ToTensor(),
])

# Test: tensor conversion only, no augmentation.
data_transforms['test'] = transforms.Compose([transforms.ToTensor()])

In [23]:
# Two views of the same CIFAR-10 training images: one with the augmenting
# 'train' transform and one with the plain 'test' transform. Splitting a single
# augmented dataset with random_split would also augment the validation
# samples, which distorts the validation metrics.
# NOTE: this holds the training images in memory twice.
full_train_set = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, transform=data_transforms['train'])
full_val_set = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, transform=data_transforms['test'])

val_len = int(len(full_train_set) * val_split)
train_len = len(full_train_set) - val_len

# Seeded permutation so the train/validation partition is identical on every
# Restart & Run All; the two subsets are disjoint slices of the same shuffle.
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(len(full_train_set), generator=split_generator).tolist()
train_set = torch.utils.data.Subset(full_train_set, shuffled_indices[:train_len])
val_set = torch.utils.data.Subset(full_val_set, shuffled_indices[train_len:])

test_set = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, transform=data_transforms['test'])

In [24]:
def _make_loader(dataset, shuffle_each_epoch):
    """Build a DataLoader over `dataset` with the notebook-wide settings.

    Args:
        dataset: dataset to batch.
        shuffle_each_epoch: whether the loader reshuffles the samples every epoch.

    Returns:
        A `torch.utils.data.DataLoader` using `batch_size` samples per batch and
        `NUM_THREADS` workers; memory is pinned only when CUDA is available.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,  # the config cell defines `batch_size`; `BATCH_SIZE` was undefined
        shuffle=shuffle_each_epoch,
        num_workers=NUM_THREADS,
        pin_memory=torch.cuda.is_available(),
    )


# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = _make_loader(train_set, True)
val_loader = _make_loader(val_set, False)
test_loader = _make_loader(test_set, False)

In [10]:
class AttentionNetwork56(nn.Module):
    """Residual attention classifier.

    Layout: a 7x7 stride-2 stem convolution, three (residual layer, attention
    module) stages, a final residual layer, global average pooling and a linear
    classification head. All channel widths are half of the values written in
    the constructor (every width is spelled as `N//2`).

    NOTE(review): the original author observed that the residual layers do not
    halve the spatial size as expected; the explicit max-pool calls in
    `forward` do the downsampling instead.

    Args:
        num_classes: number of output logits. Defaults to 11, the value that was
            previously hard-coded. CIFAR-10 has 10 classes, so pass 10 when
            training on it.
    """

    def __init__(self, num_classes=11):
        super().__init__()
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        # Global average pooling to 1x1. The previous AvgPool2d(7, 1) only gave
        # a 1x1 map when the final feature map was exactly 7x7; for that case
        # the adaptive pool gives the same result, and it also accepts other
        # input resolutions (e.g. 32x32 CIFAR images).
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3)

        # make_residual_layer / AttentionModule are project helpers; their
        # positional arguments are kept exactly as they were.
        # presumably (block, in_channels, planes, num_blocks) — TODO confirm
        self.residual1 = residual_attention_network.make_residual_layer(
            torchvision.models.resnet.Bottleneck,
            64,
            64//2,
            1
        )
        self.attention1 = residual_attention_network.AttentionModule(256//2, 128//2, 1, 2, 1)
        self.residual2 = residual_attention_network.make_residual_layer(
            torchvision.models.resnet.Bottleneck,
            512//2,
            128//2,
            1
        )
        self.attention2 = residual_attention_network.AttentionModule(512//2, 256//2, 1, 2, 1)
        self.residual3 = residual_attention_network.make_residual_layer(
            torchvision.models.resnet.Bottleneck,
            1024//2,
            256//2,
            1
        )
        self.attention3 = residual_attention_network.AttentionModule(1024//2, 512//2, 1, 2, 1)
        self.residual4 = residual_attention_network.make_residual_layer(
            torchvision.models.resnet.Bottleneck,
            2048//2,
            512//2,
            3
        )
        self.clf = nn.Linear(in_features=2048//2, out_features=num_classes)

    def forward(self, input):
        """Compute class logits for a batch of images.

        Args:
            input: image batch with 3 channels (required by `conv1`).

        Returns:
            Tensor with one row per sample and `num_classes` logits per row.
        """
        # Stem: strided convolution followed by max-pooling.
        x = self.conv1(input)
        x = self.maxpool(x)

        # Stage 1 (no extra pooling between residual and attention).
        x = self.residual1(x)
        x = self.attention1(x)

        # Stages 2 and 3: residual layer, downsample, attention.
        x = self.residual2(x)
        x = self.maxpool(x)
        x = self.attention2(x)

        x = self.residual3(x)
        x = self.maxpool(x)
        x = self.attention3(x)

        # Final residual stage, downsample, then collapse the spatial dims.
        x = self.residual4(x)
        x = self.maxpool(x)
        x = self.avgpool(x)

        # Flatten everything but the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.clf(x)

        return x

In [11]:
# Instantiate the network and move its parameters to the selected device.
att_net = AttentionNetwork56().to(device)

In [12]:
# Cross-entropy loss over the class logits, optimised with Adam at its
# default hyper-parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=att_net.parameters())

In [None]:
def _run_epoch(loader, train, desc):
    """Run one full pass over `loader` and return its summed loss and hit count.

    Args:
        loader: iterable of batches whose first two items are (images, targets).
            Extra items, such as the video/clip ids of the clip dataset, are
            ignored, so both 2-tuple and 4-tuple batches work.
        train: when True the model is put in training mode and the optimizer is
            stepped on every batch; when False the model is put in eval mode
            and gradients are disabled.
        desc: label shown on the progress bar.

    Returns:
        (total_loss, total_hits): the sum of the per-batch loss values (float)
        and the number of correctly classified samples (int).
    """
    # Switch the mode explicitly instead of leaving the model in whatever mode
    # it was last set to (matters for mode-dependent layers such as BatchNorm).
    att_net.train(train)

    total_loss = 0.0
    total_hits = 0
    with torch.set_grad_enabled(train):
        for batch in tqdm_notebook(loader, leave=False, desc=desc):
            images, targets = batch[0].to(device), batch[1].to(device)

            if train:
                optimizer.zero_grad()
            output = att_net(images)
            loss = criterion(output, targets)
            if train:
                loss.backward()
                optimizer.step()

            predictions = output.argmax(dim=1)
            # .item() keeps plain Python numbers in the histories rather than
            # (possibly GPU-resident) 0-d tensors.
            total_hits += (predictions == targets).sum().item()
            total_loss += loss.item()

    return total_loss, total_hits


# Per-epoch histories: summed batch loss and number of correct predictions.
loss_history = []
hit_history = []
val_loss_history = []
val_hit_history = []
for epoch in tqdm_notebook(range(0, num_epochs), desc='Epochs'):
    epoch_loss, hits = _run_epoch(train_loader, True, 'Training Batches')

    # Clear gradients left over from the last training step before validating.
    optimizer.zero_grad()

    loss_history.append(epoch_loss)
    hit_history.append(hits)

    val_loss, val_hits = _run_epoch(val_loader, False, 'Validation Batches')
    val_loss_history.append(val_loss)
    val_hit_history.append(val_hits)

HBox(children=(IntProgress(value=0, description='Epochs', max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Training Batches', max=1021), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Validation Batches', max=254), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Training Batches', max=1021), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Validation Batches', max=254), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Training Batches', max=1021), HTML(value='')))

In [None]:
# Training loss per epoch; each point is the sum of that epoch's batch losses.
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(title='Training loss', xlabel='Epoch', ylabel='Summed batch loss');

In [None]:
# Training accuracy per epoch: correct predictions / number of training samples.
# `train_set` is a Subset (it has no `.samples` attribute), so use len().
# float() accepts both plain ints and 0-d tensors in the history.
acc = torch.tensor([float(h) for h in hit_history]) / len(train_set)

In [None]:
# Training accuracy per epoch.
fig, ax = plt.subplots()
ax.plot(acc.tolist())
ax.set(title='Training accuracy', xlabel='Epoch', ylabel='Accuracy');

In [None]:
# Validation accuracy per epoch: correct predictions / number of validation samples.
# The validation split is named `val_set` (no `validation_set` is defined in the
# visible notebook), and as a Subset it has no `.samples` attribute, so use len().
# float() accepts both plain ints and 0-d tensors in the history.
val_acc = torch.tensor([float(h) for h in val_hit_history]) / len(val_set)

In [None]:
# Validation accuracy per epoch.
fig, ax = plt.subplots()
ax.plot(val_acc.tolist())
ax.set(title='Validation accuracy', xlabel='Epoch', ylabel='Accuracy');

In [None]:
# Display the per-epoch training accuracy tensor.
acc

In [None]:
# Display the per-epoch validation accuracy tensor.
val_acc

## TODO
1. Regularizar (L1 para esparsificar)
2. Mais videos e mais frames por video
3. Classificação por video (maioria no vídeo)
4. Estado de atenção