In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import resnet101

In [2]:
import sys
sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset, FrameWindowDataset

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
DATA_ROOT = Path(ROOT/"data")

In [6]:
train_transforms = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

In [7]:
train_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'train', 
                                       transform=train_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=3,
                                       overlapping=True,)
valid_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'validation', 
                                       transform=valid_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=3,
                                       overlapping=True,)

In [8]:
from torch import nn
from torchvision.models import resnet101
from video_classification.models.mlp import MLP


class SingleImageResNetModel(nn.Module):
    def __init__(self, mlp_sizes=[768, 128, 2]):
        super().__init__()
        resnet = resnet101(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.clf = MLP(2048, mlp_sizes)
        self.freeze_resnet()

    def forward(self, x):
        x = self.resnet(x).squeeze()
        x = self.clf(x)
        return x

    def freeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = False

    def unfreeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = True


In [9]:
import torch
from torch import nn
from video_classification.models.mlp import MLP


class MultiImageModel(nn.Module):
    def __init__(self,
                 window_size=3,
                 single_mlp_sizes=[768, 128],
                 joint_mlp_sizes=[64, 2]):
        super().__init__()
        self.window_size = window_size
        self.single_mlp_sizes = single_mlp_sizes
        self.joint_mlp_sizes = joint_mlp_sizes
        
        self.single_image_models = nn.ModuleList([SingleImageResNetModel(self.single_mlp_sizes) for _ in range(window_size)])
        self.in_features = self.single_mlp_sizes[-1] * self.window_size
        self.clf = MLP(self.in_features, joint_mlp_sizes)

    def forward(self, x):
        # x is of size [B, T, C, H, W]. In other words, a batch of windows.
        # each img for the same window goes through SingleImageModel
        x = x.transpose(0, 1)  # -> [T, B, C, H, W]
        encoded_windows = [m(window) for m, window in zip(self.single_image_models, x)]  # List of len T, each elem of size [B, single_mlp_sizes[-1]]
        x = torch.cat(encoded_windows, dim=1)
        # x is now of size [B, T * single_mlp_sizes[-1]]
        
        x = self.clf(x)
        # Now size is [B, joint_mlp_sizes[-1]] which should always be 2

        return x
    
    def freeze_single_image_model(self):
        # Freeze the VGG classifier
        for p in self.single_image_model.parameters():
            p.requires_grad = False
            
    def unfreeze_single_image_model(self):
        # Unfreeze the VGG classifier. Training the whole VGG is a no-go, so we only train the classifier part.
        for p in self.single_image_model.clf.parameters():
            p.requires_grad = True 

In [10]:
model = MultiImageModel(
                 window_size=3,
                 single_mlp_sizes=[1024, 256],
                 joint_mlp_sizes=[128, 2])

model = model.to(device)

In [11]:
x = torch.stack([train_ds[0][0], train_ds[1][0], train_ds[2][0], train_ds[3][0]]).to(device)

In [12]:
model(x)

tensor([[ 1.0122,  0.4045],
        [ 0.1436, -0.0671],
        [ 0.1872, -0.1722],
        [-1.1248,  0.1971]], device='cuda:0', grad_fn=<AddmmBackward>)

In [13]:
from video_classification.trainer import Trainer

classes_weights = torch.Tensor([0.3, 1.0]).to(device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [14]:
trainer = Trainer(train_ds, 
                  valid_ds, 
                  model, 
                  criterion,
                  "multi_frame_resnet101_differentMLP",
                  str(ROOT/'checkpoints'),
                  device=device,
                  amp_opt_level="O1",
                  cycle_mult=0.9,
                 )

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
from tqdm.autonotebook import tqdm



In [16]:
trainer.train(lr=1e-3, 
              batch_size=128, 
              n_epochs=40,
              gradient_accumulation_steps=2,
              num_workers=8,
              max_gradient_norm=2.0,
             )

HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 1: Avg accuracy: 0.62 |Precision: 0.15 |Recall: 0.74 | F1: 0.25 | Avg loss: 0.77


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 2: Avg accuracy: 0.76 |Precision: 0.19 |Recall: 0.53 | F1: 0.28 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 3: Avg accuracy: 0.89 |Precision: 0.36 |Recall: 0.32 | F1: 0.34 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 4: Avg accuracy: 0.93 |Precision: 0.67 |Recall: 0.34 | F1: 0.45 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 5: Avg accuracy: 0.92 |Precision: 0.55 |Recall: 0.37 | F1: 0.44 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 6: Avg accuracy: 0.67 |Precision: 0.15 |Recall: 0.58 | F1: 0.24 | Avg loss: 0.67


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 7: Avg accuracy: 0.90 |Precision: 0.47 |Recall: 0.59 | F1: 0.52 | Avg loss: 0.36


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 8: Avg accuracy: 0.90 |Precision: 0.44 |Recall: 0.38 | F1: 0.41 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 9: Avg accuracy: 0.92 |Precision: 0.54 |Recall: 0.40 | F1: 0.46 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 10: Avg accuracy: 0.94 |Precision: 0.87 |Recall: 0.31 | F1: 0.46 | Avg loss: 0.54


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 11: Avg accuracy: 0.83 |Precision: 0.28 |Recall: 0.59 | F1: 0.38 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 12: Avg accuracy: 0.93 |Precision: 0.67 |Recall: 0.40 | F1: 0.50 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 13: Avg accuracy: 0.93 |Precision: 0.65 |Recall: 0.41 | F1: 0.50 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 14: Avg accuracy: 0.93 |Precision: 0.63 |Recall: 0.53 | F1: 0.58 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 15: Avg accuracy: 0.87 |Precision: 0.33 |Recall: 0.47 | F1: 0.39 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 16: Avg accuracy: 0.94 |Precision: 0.74 |Recall: 0.47 | F1: 0.57 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 17: Avg accuracy: 0.93 |Precision: 0.61 |Recall: 0.52 | F1: 0.56 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 18: Avg accuracy: 0.94 |Precision: 0.67 |Recall: 0.52 | F1: 0.58 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 19: Avg accuracy: 0.91 |Precision: 0.49 |Recall: 0.47 | F1: 0.48 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 20: Avg accuracy: 0.93 |Precision: 0.69 |Recall: 0.42 | F1: 0.52 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 21: Avg accuracy: 0.92 |Precision: 0.60 |Recall: 0.46 | F1: 0.52 | Avg loss: 0.50


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 22: Avg accuracy: 0.94 |Precision: 0.72 |Recall: 0.50 | F1: 0.59 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 23: Avg accuracy: 0.93 |Precision: 0.70 |Recall: 0.42 | F1: 0.53 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 24: Avg accuracy: 0.93 |Precision: 0.71 |Recall: 0.40 | F1: 0.51 | Avg loss: 0.53


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 25: Avg accuracy: 0.93 |Precision: 0.71 |Recall: 0.39 | F1: 0.51 | Avg loss: 0.53


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 26: Avg accuracy: 0.92 |Precision: 0.59 |Recall: 0.46 | F1: 0.52 | Avg loss: 0.50


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 27: Avg accuracy: 0.93 |Precision: 0.66 |Recall: 0.43 | F1: 0.52 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 28: Avg accuracy: 0.93 |Precision: 0.62 |Recall: 0.45 | F1: 0.52 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 29: Avg accuracy: 0.94 |Precision: 0.71 |Recall: 0.46 | F1: 0.56 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0
Validation Results - Epoch: 30: Avg accuracy: 0.93 |Precision: 0.69 |Recall: 0.40 | F1: 0.51 | Avg loss: 0.57


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 31: Avg accuracy: 0.93 |Precision: 0.65 |Recall: 0.44 | F1: 0.52 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 32: Avg accuracy: 0.93 |Precision: 0.69 |Recall: 0.42 | F1: 0.53 | Avg loss: 0.53


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 33: Avg accuracy: 0.93 |Precision: 0.64 |Recall: 0.46 | F1: 0.53 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 34: Avg accuracy: 0.93 |Precision: 0.68 |Recall: 0.42 | F1: 0.52 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 35: Avg accuracy: 0.94 |Precision: 0.72 |Recall: 0.44 | F1: 0.55 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 36: Avg accuracy: 0.94 |Precision: 0.73 |Recall: 0.44 | F1: 0.55 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 37: Avg accuracy: 0.93 |Precision: 0.68 |Recall: 0.43 | F1: 0.53 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 38: Avg accuracy: 0.93 |Precision: 0.71 |Recall: 0.42 | F1: 0.53 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 39: Avg accuracy: 0.93 |Precision: 0.67 |Recall: 0.42 | F1: 0.51 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Validation Results - Epoch: 40: Avg accuracy: 0.93 |Precision: 0.67 |Recall: 0.46 | F1: 0.54 | Avg loss: 0.50


In [17]:
import pandas as pd
df = pd.DataFrame(trainer.epoch_state).T
df

Unnamed: 0,accuracy,f1,nll,precision,recall
1,0.61738,0.254725,0.773548,0.153846,0.739857
2,0.759755,0.27957,0.546239,0.190189,0.527446
3,0.889475,0.34005,0.433296,0.36,0.322196
4,0.92702,0.454259,0.388901,0.669767,0.343675
5,0.917106,0.439372,0.419865,0.546099,0.367542
6,0.672643,0.236971,0.673203,0.149226,0.575179
7,0.90424,0.520085,0.357694,0.466793,0.587112
8,0.903185,0.409266,0.505089,0.444134,0.379475
9,0.916895,0.461749,0.446874,0.539936,0.403341
10,0.935035,0.459649,0.540156,0.86755,0.312649


In [18]:
df['f1'].argmax()

The current behaviour of 'Series.argmax' is deprecated, use 'idxmax'
instead.
The behavior of 'argmax' will be corrected to return the positional
maximum in the future. For now, use 'series.values.argmax' or
'np.argmax(np.array(values))' to get the position of the maximum
row.
  if __name__ == '__main__':


22