In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import resnet101

In [2]:
import sys
sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset, FrameWindowDataset

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
DATA_ROOT = Path(ROOT/"data")

In [6]:
train_transforms = transforms.Compose([
    torchvision.transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

In [7]:
train_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'train', 
                                       transform=train_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=5,
                                       overlapping=True,)
valid_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'validation', 
                                       transform=valid_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=5,
                                       overlapping=True,)

In [8]:
from torch import nn
from torchvision.models import resnet101
from video_classification.models.mlp import MLP


class SingleImageResNetModel(nn.Module):
    def __init__(self, mlp_sizes=[768, 128, 2]):
        super().__init__()
        resnet = resnet101(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.clf = MLP(2048, mlp_sizes)
        self.freeze_resnet()

    def forward(self, x):
        x = self.resnet(x).squeeze()
        x = self.clf(x)
        return x

    def freeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = False

    def unfreeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = True


In [9]:
import torch
from torch import nn
from video_classification.models.mlp import MLP


class MultiImageModel(nn.Module):
    def __init__(self,
                 window_size=3,
                 single_mlp_sizes=[768, 128],
                 joint_mlp_sizes=[64, 2]):
        super().__init__()
        self.window_size = window_size
        self.single_mlp_sizes = single_mlp_sizes
        self.joint_mlp_sizes = joint_mlp_sizes
        
        self.single_image_model = SingleImageResNetModel(self.single_mlp_sizes)
        self.in_features = self.single_mlp_sizes[-1] * self.window_size
        self.clf = MLP(self.in_features, joint_mlp_sizes)

    def forward(self, x):
        # x is of size [B, T, C, H, W]. In other words, a batch of windows.
        # each img for the same window goes through SingleImageModel
        x = x.transpose(0, 1)  # -> [T, B, C, H, W]
        x = torch.cat([self.single_image_model(window) for window in x], 1)
        # x is now of size [B, T * single_mlp_sizes[-1]]
        
        x = self.clf(x)
        # Now size is [B, joint_mlp_sizes[-1]] which should always be 2

        return x
    
    def freeze_single_image_model(self):
        # Freeze the VGG classifier
        for p in self.single_image_model.parameters():
            p.requires_grad = False
            
    def unfreeze_single_image_model(self):
        # Unfreeze the VGG classifier. Training the whole VGG is a no-go, so we only train the classifier part.
        for p in self.single_image_model.clf.parameters():
            p.requires_grad = True 

In [10]:
model = MultiImageModel(
                 window_size=5,
                 single_mlp_sizes=[1024, 256],
                 joint_mlp_sizes=[128, 2])

model = model.to(device)

In [11]:
x = torch.stack([train_ds[0][0], train_ds[1][0], train_ds[2][0], train_ds[3][0]]).to(device)

In [12]:
model(x)

tensor([[-0.2204, -0.2248],
        [ 0.1438,  0.3746],
        [ 0.2079, -0.2529],
        [ 0.1585,  0.2678]], device='cuda:0', grad_fn=<AddmmBackward>)

In [13]:
from video_classification.trainer import Trainer

classes_weights = torch.Tensor([0.3, 1.0]).to(device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [14]:
trainer = Trainer(train_ds, 
                  valid_ds, 
                  model, 
                  criterion,
                  "multi_frame_resnet101_from_scratch",
                  str(ROOT/'checkpoints'),
                  device=device,
                  amp_opt_level="O1",
                  cycle_mult=0.9,
                 )

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
trainer.train(lr=1e-3, 
              batch_size=64, 
              n_epochs=20,
              gradient_accumulation_steps=4,
              num_workers=8,
              max_gradient_norm=2.0,
             )



HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 1: Avg accuracy: 0.68 |Precision: 0.16 |Recall: 0.64 | F1: 0.26 | Avg loss: 0.69


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 2: Avg accuracy: 0.89 |Precision: 0.41 |Recall: 0.55 | F1: 0.47 | Avg loss: 0.37


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 3: Avg accuracy: 0.70 |Precision: 0.15 |Recall: 0.51 | F1: 0.23 | Avg loss: 0.63


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 4: Avg accuracy: 0.90 |Precision: 0.43 |Recall: 0.44 | F1: 0.44 | Avg loss: 0.36


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 5: Avg accuracy: 0.94 |Precision: 0.70 |Recall: 0.58 | F1: 0.63 | Avg loss: 0.29


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 6: Avg accuracy: 0.85 |Precision: 0.28 |Recall: 0.45 | F1: 0.34 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 7: Avg accuracy: 0.86 |Precision: 0.29 |Recall: 0.41 | F1: 0.34 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 8: Avg accuracy: 0.93 |Precision: 0.74 |Recall: 0.41 | F1: 0.53 | Avg loss: 0.35


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 9: Avg accuracy: 0.83 |Precision: 0.28 |Recall: 0.58 | F1: 0.38 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 10: Avg accuracy: 0.89 |Precision: 0.43 |Recall: 0.57 | F1: 0.49 | Avg loss: 0.34


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 11: Avg accuracy: 0.92 |Precision: 0.57 |Recall: 0.40 | F1: 0.47 | Avg loss: 0.44


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 12: Avg accuracy: 0.94 |Precision: 0.82 |Recall: 0.38 | F1: 0.52 | Avg loss: 0.38


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 13: Avg accuracy: 0.94 |Precision: 0.72 |Recall: 0.46 | F1: 0.56 | Avg loss: 0.35


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 14: Avg accuracy: 0.92 |Precision: 0.56 |Recall: 0.50 | F1: 0.53 | Avg loss: 0.36


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 15: Avg accuracy: 0.90 |Precision: 0.44 |Recall: 0.42 | F1: 0.43 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 16: Avg accuracy: 0.91 |Precision: 0.46 |Recall: 0.35 | F1: 0.40 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 17: Avg accuracy: 0.93 |Precision: 0.64 |Recall: 0.47 | F1: 0.54 | Avg loss: 0.37


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 18: Avg accuracy: 0.93 |Precision: 0.64 |Recall: 0.39 | F1: 0.49 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0
Validation Results - Epoch: 19: Avg accuracy: 0.93 |Precision: 0.68 |Recall: 0.43 | F1: 0.53 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

Validation Results - Epoch: 20: Avg accuracy: 0.92 |Precision: 0.53 |Recall: 0.35 | F1: 0.42 | Avg loss: 0.47


In [16]:
import pandas as pd
reform = {(outerKey, innerKey): values for outerKey, innerDict in trainer.epoch_state.items() for innerKey, values in innerDict.items()}
df = pd.DataFrame(reform).T
df

ValueError: If using all scalar values, you must pass an index