In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import resnet101

In [2]:
import sys
# Make the parent directory importable so the local `video_classification` package
# resolves. Guarded so re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset, FrameWindowDataset

In [4]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
import os

# Project root. Set SUPERVISED_VIDEO_CLASSIFICATION_ROOT to point elsewhere instead
# of editing this machine-specific default.
ROOT = Path(os.environ.get("SUPERVISED_VIDEO_CLASSIFICATION_ROOT",
                           "/home/ubuntu/SupervisedVideoClassification"))
# ROOT is already a Path; joining yields a Path without re-wrapping it.
DATA_ROOT = ROOT / "data"

In [6]:
# Standard ImageNet channel statistics (what torchvision's pretrained models expect).
# Defined once so the train and validation normalization cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light augmentation, then tensor conversion + normalization.
# NOTE(review): ColorJitter() with all-default arguments (brightness=contrast=
# saturation=hue=0) applies no jitter at all — pass explicit ranges if colour
# augmentation is actually intended.
train_transforms = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic — only tensor conversion + the same normalization.
valid_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [7]:
# Both splits use the same windowing (FrameWindowDataset, window_size=3,
# overlapping=True); only the directory and the transform pipeline differ.
_shared_window_kwargs = dict(
    base_class=FrameWindowDataset,
    window_size=3,
    overlapping=True,
)
train_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / 'train',
    transform=train_transforms,
    **_shared_window_kwargs,
)
valid_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / 'validation',
    transform=valid_transforms,
    **_shared_window_kwargs,
)

In [8]:
from torch import nn
from torchvision.models import resnet101
from video_classification.models.mlp import MLP


class SingleImageResNetModel(nn.Module):
    """Per-frame model: a pretrained ResNet-101 trunk followed by an MLP head.

    The ResNet's final fully-connected layer is dropped, so the trunk ends at the
    global average pool and produces 2048 features per image, which
    ``MLP(2048, mlp_sizes)`` maps to ``mlp_sizes[-1]`` outputs. The trunk is
    frozen on construction.

    Args:
        mlp_sizes: layer sizes handed to ``MLP`` after the 2048-d input. A copy is
            passed on, so the shared default list is never handed out.
    """

    def __init__(self, mlp_sizes=[768, 128, 2]):
        super().__init__()
        resnet = resnet101(pretrained=True)
        # Drop the last child (the fc classifier); keep everything through avgpool.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.clf = MLP(2048, list(mlp_sizes))
        self.freeze_resnet()

    def forward(self, x):
        """Map a batch of images to MLP outputs.

        Args:
            x: image batch, presumably [B, C, H, W] normalized RGB — TODO confirm.

        Returns:
            Tensor of shape [B, mlp_sizes[-1]].
        """
        # The pooled trunk output is [B, 2048, 1, 1]. flatten(1) keeps the batch
        # dim even when B == 1 (a bare .squeeze() would drop it too).
        x = self.resnet(x).flatten(1)
        return self.clf(x)

    def freeze_resnet(self):
        """Disable gradient updates for every ResNet trunk parameter.

        NOTE(review): this only sets requires_grad=False. BatchNorm layers in the
        trunk still normalize with batch statistics and update their running stats
        whenever the module is in train() mode; call ``self.resnet.eval()`` after
        ``model.train()`` if a truly fixed feature extractor is intended.
        """
        for p in self.resnet.parameters():
            p.requires_grad = False

    def unfreeze_resnet(self):
        """Re-enable gradient updates for every ResNet trunk parameter."""
        for p in self.resnet.parameters():
            p.requires_grad = True


In [9]:
import torch
from torch import nn
from video_classification.models.mlp import MLP


class MultiImageModel(nn.Module):
    """Classify a window of T consecutive frames jointly.

    Every frame of a window goes through one shared ``SingleImageResNetModel``
    (frozen ResNet-101 trunk + per-frame MLP). The T per-frame embeddings are
    concatenated in frame order and fed to a joint ``MLP`` head.

    Args:
        window_size: number of frames T per window. It fixes the joint head's input
            width, so ``forward`` only accepts windows of exactly this length.
        single_mlp_sizes: layer sizes of the per-frame MLP; the last entry is the
            per-frame embedding width.
        joint_mlp_sizes: layer sizes of the joint MLP; the last entry is the number
            of output logits.
    """

    def __init__(self,
                 window_size=3,
                 single_mlp_sizes=[768, 128],
                 joint_mlp_sizes=[64, 2]):
        super().__init__()
        self.window_size = window_size
        # Copy so the shared default lists are never stored (and mutable) on instances.
        self.single_mlp_sizes = list(single_mlp_sizes)
        self.joint_mlp_sizes = list(joint_mlp_sizes)

        self.single_image_model = SingleImageResNetModel(self.single_mlp_sizes)
        self.in_features = self.single_mlp_sizes[-1] * self.window_size
        self.clf = MLP(self.in_features, self.joint_mlp_sizes)

    def forward(self, x):
        """Classify a batch of frame windows.

        Args:
            x: tensor of shape [B, T, C, H, W] with T == self.window_size.

        Returns:
            Tensor of shape [B, joint_mlp_sizes[-1]].

        Raises:
            ValueError: if ``x`` is not 5-D or its window length differs from
                ``self.window_size`` (the joint head's input width would not match).
        """
        if x.dim() != 5 or x.size(1) != self.window_size:
            raise ValueError(
                f"expected input of shape [B, {self.window_size}, C, H, W], "
                f"got {tuple(x.shape)}"
            )
        # Walk the time axis: each slice is a [B, C, H, W] batch holding the t-th
        # frame of every window, embedded by the shared per-frame model.
        frames_first = x.transpose(0, 1)  # -> [T, B, C, H, W]
        embeddings = [self.single_image_model(frames) for frames in frames_first]
        # Concatenate in frame order -> [B, T * single_mlp_sizes[-1]].
        joint = torch.cat(embeddings, 1)
        # -> [B, joint_mlp_sizes[-1]]
        return self.clf(joint)

    def freeze_single_image_model(self):
        """Freeze the whole per-frame model: ResNet trunk and its MLP head."""
        for p in self.single_image_model.parameters():
            p.requires_grad = False

    def unfreeze_single_image_model(self):
        """Unfreeze only the per-frame MLP head.

        Deliberately not the inverse of ``freeze_single_image_model``: the ResNet
        trunk's requires_grad flags are left untouched, because training the whole
        ResNet is not intended here — only its classifier head is trained.
        """
        for p in self.single_image_model.clf.parameters():
            p.requires_grad = True

In [10]:
# Three-frame windows (the datasets above also use window_size=3).
# Per-frame head sizes [1024, 256] on top of the 2048-d ResNet features;
# joint head sizes [128, 2] on top of the concatenated 3 * 256 embeddings.
model = MultiImageModel(
    window_size=3,
    single_mlp_sizes=[1024, 256],
    joint_mlp_sizes=[128, 2],
).to(device)

In [11]:
# Sanity batch: the first element ([0]) of the first four training samples,
# stacked along a new batch dimension and moved to the device.
x = torch.stack([train_ds[idx][0] for idx in range(4)]).to(device)

In [12]:
# Forward-only smoke test of the untrained model on the sanity batch. No gradients
# are needed, so skip building an autograd graph, and check the logits' shape.
with torch.no_grad():
    sanity_logits = model(x)
assert sanity_logits.shape == (x.size(0), model.joint_mlp_sizes[-1]), sanity_logits.shape
sanity_logits

tensor([[-0.1784,  0.2266],
        [-0.7380, -0.4095],
        [ 0.2799, -0.2015],
        [ 0.2494, -0.8661]], device='cuda:0', grad_fn=<AddmmBackward>)

In [13]:
from video_classification.trainer import Trainer

# Weighted cross-entropy: errors on class 0 count 0.3, errors on class 1 count 1.0,
# presumably to up-weight the rarer class 1 — TODO confirm the class frequencies.
# torch.tensor(..., device=...) builds the weights directly on the target device
# instead of going through the legacy torch.Tensor constructor plus a copy.
classes_weights = torch.tensor([0.3, 1.0], device=device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [14]:
# NOTE(review): the run name says "from_scratch", yet the ResNet-101 trunk is
# loaded with pretrained weights and frozen. The string is kept unchanged because
# Trainer presumably uses it to name checkpoints — TODO confirm before renaming.
run_name = "multi_frame_resnet101_from_scratch"
checkpoint_dir = str(ROOT / 'checkpoints')
trainer = Trainer(
    train_ds,
    valid_ds,
    model,
    criterion,
    run_name,
    checkpoint_dir,
    device=device,
    amp_opt_level="O1",  # mixed-precision opt level (see the O1 log output below)
    cycle_mult=0.9,
)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
# Hyper-parameters for this run. With gradient accumulation, each optimizer step
# presumably sees batch_size * gradient_accumulation_steps samples — TODO confirm
# against Trainer.train.
train_config = dict(
    lr=1e-3,
    batch_size=128,
    n_epochs=20,
    gradient_accumulation_steps=2,
    num_workers=8,
    max_gradient_norm=2.0,
)
trainer.train(**train_config)



HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 1: Avg accuracy: 0.94 |Precision: 0.95, 0.61 |Recall: 0.98, 0.40 | F1: 0.72 | Avg loss: 0.38
Validation Results - Epoch: 1: Avg accuracy: 0.93 |Precision: 0.93, 0.90 |Recall: 1.00, 0.28 | F1: 0.70 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 2: Avg accuracy: 0.95 |Precision: 0.96, 0.84 |Recall: 0.99, 0.47 | F1: 0.79 | Avg loss: 0.29
Validation Results - Epoch: 2: Avg accuracy: 0.93 |Precision: 0.94, 0.74 |Recall: 0.99, 0.30 | F1: 0.70 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 3: Avg accuracy: 0.93 |Precision: 0.97, 0.55 |Recall: 0.96, 0.63 | F1: 0.78 | Avg loss: 0.27
Validation Results - Epoch: 3: Avg accuracy: 0.88 |Precision: 0.94, 0.34 |Recall: 0.93, 0.35 | F1: 0.64 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 4: Avg accuracy: 0.88 |Precision: 0.98, 0.35 |Recall: 0.89, 0.72 | F1: 0.70 | Avg loss: 0.35
Validation Results - Epoch: 4: Avg accuracy: 0.74 |Precision: 0.95, 0.18 |Recall: 0.76, 0.56 | F1: 0.56 | Avg loss: 0.70


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 5: Avg accuracy: 0.96 |Precision: 0.97, 0.78 |Recall: 0.99, 0.67 | F1: 0.85 | Avg loss: 0.21
Validation Results - Epoch: 5: Avg accuracy: 0.90 |Precision: 0.94, 0.44 |Recall: 0.95, 0.37 | F1: 0.67 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 6: Avg accuracy: 0.96 |Precision: 0.96, 0.86 |Recall: 0.99, 0.54 | F1: 0.82 | Avg loss: 0.26
Validation Results - Epoch: 6: Avg accuracy: 0.93 |Precision: 0.94, 0.67 |Recall: 0.98, 0.35 | F1: 0.71 | Avg loss: 0.44


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 7: Avg accuracy: 0.92 |Precision: 0.98, 0.49 |Recall: 0.94, 0.76 | F1: 0.78 | Avg loss: 0.23
Validation Results - Epoch: 7: Avg accuracy: 0.91 |Precision: 0.96, 0.49 |Recall: 0.94, 0.55 | F1: 0.73 | Avg loss: 0.35


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 8: Avg accuracy: 0.96 |Precision: 0.97, 0.88 |Recall: 0.99, 0.59 | F1: 0.84 | Avg loss: 0.22
Validation Results - Epoch: 8: Avg accuracy: 0.92 |Precision: 0.95, 0.59 |Recall: 0.97, 0.45 | F1: 0.73 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 9: Avg accuracy: 0.96 |Precision: 0.98, 0.69 |Recall: 0.97, 0.72 | F1: 0.84 | Avg loss: 0.19
Validation Results - Epoch: 9: Avg accuracy: 0.91 |Precision: 0.95, 0.51 |Recall: 0.96, 0.47 | F1: 0.72 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 10: Avg accuracy: 0.94 |Precision: 0.99, 0.58 |Recall: 0.95, 0.84 | F1: 0.83 | Avg loss: 0.18
Validation Results - Epoch: 10: Avg accuracy: 0.92 |Precision: 0.95, 0.52 |Recall: 0.96, 0.50 | F1: 0.73 | Avg loss: 0.38


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 11: Avg accuracy: 0.95 |Precision: 0.98, 0.63 |Recall: 0.96, 0.81 | F1: 0.84 | Avg loss: 0.17
Validation Results - Epoch: 11: Avg accuracy: 0.84 |Precision: 0.95, 0.28 |Recall: 0.87, 0.53 | F1: 0.64 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 12: Avg accuracy: 0.92 |Precision: 0.98, 0.49 |Recall: 0.94, 0.76 | F1: 0.78 | Avg loss: 0.24
Validation Results - Epoch: 12: Avg accuracy: 0.93 |Precision: 0.96, 0.62 |Recall: 0.96, 0.63 | F1: 0.79 | Avg loss: 0.33


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 13: Avg accuracy: 0.97 |Precision: 0.98, 0.78 |Recall: 0.98, 0.74 | F1: 0.87 | Avg loss: 0.17
Validation Results - Epoch: 13: Avg accuracy: 0.92 |Precision: 0.95, 0.55 |Recall: 0.96, 0.45 | F1: 0.72 | Avg loss: 0.44


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 14: Avg accuracy: 0.97 |Precision: 0.98, 0.83 |Recall: 0.99, 0.69 | F1: 0.87 | Avg loss: 0.18
Validation Results - Epoch: 14: Avg accuracy: 0.93 |Precision: 0.94, 0.72 |Recall: 0.98, 0.41 | F1: 0.74 | Avg loss: 0.50


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 15: Avg accuracy: 0.97 |Precision: 0.98, 0.77 |Recall: 0.98, 0.81 | F1: 0.89 | Avg loss: 0.14
Validation Results - Epoch: 15: Avg accuracy: 0.92 |Precision: 0.96, 0.58 |Recall: 0.96, 0.53 | F1: 0.76 | Avg loss: 0.38


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 16: Avg accuracy: 0.97 |Precision: 0.98, 0.83 |Recall: 0.99, 0.77 | F1: 0.89 | Avg loss: 0.14
Validation Results - Epoch: 16: Avg accuracy: 0.94 |Precision: 0.95, 0.70 |Recall: 0.98, 0.50 | F1: 0.77 | Avg loss: 0.38


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 17: Avg accuracy: 0.97 |Precision: 0.98, 0.86 |Recall: 0.99, 0.76 | F1: 0.89 | Avg loss: 0.14
Validation Results - Epoch: 17: Avg accuracy: 0.93 |Precision: 0.96, 0.65 |Recall: 0.97, 0.56 | F1: 0.78 | Avg loss: 0.38


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 18: Avg accuracy: 0.96 |Precision: 0.99, 0.71 |Recall: 0.97, 0.85 | F1: 0.88 | Avg loss: 0.13
Validation Results - Epoch: 18: Avg accuracy: 0.92 |Precision: 0.95, 0.55 |Recall: 0.96, 0.46 | F1: 0.73 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 19: Avg accuracy: 0.97 |Precision: 0.99, 0.78 |Recall: 0.98, 0.84 | F1: 0.90 | Avg loss: 0.12
Validation Results - Epoch: 19: Avg accuracy: 0.91 |Precision: 0.96, 0.50 |Recall: 0.95, 0.57 | F1: 0.74 | Avg loss: 0.40


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 20: Avg accuracy: 0.95 |Precision: 0.99, 0.64 |Recall: 0.96, 0.85 | F1: 0.85 | Avg loss: 0.14
Validation Results - Epoch: 20: Avg accuracy: 0.90 |Precision: 0.96, 0.45 |Recall: 0.93, 0.62 | F1: 0.73 | Avg loss: 0.42


In [16]:
import pandas as pd
# Flatten trainer.epoch_state ({epoch: {split: metrics}}) into
# {(epoch, split): metrics}; after transposing, pandas turns the tuple keys
# into a two-level (epoch, split) row index with one column per metric.
reform = {}
for epoch, split_metrics in trainer.epoch_state.items():
    for split, metrics in split_metrics.items():
        reform[(epoch, split)] = metrics
df = pd.DataFrame(reform).T
df

Unnamed: 0,Unnamed: 1,accuracy,f1,nll,precision,recall
1,train,0.936958,0.723321,0.381039,"[0.953193333595997, 0.61198738170347]","[0.9800688677334414, 0.395112016293279]"
1,test,0.93398,0.698448,0.455625,"[0.9349099587763072, 0.9015151515151515]","[0.9969921332716335, 0.2840095465393795]"
2,train,0.954407,0.788925,0.288283,"[0.9591158059467919, 0.8438934802571166]","[0.9931132266558639, 0.4679226069246436]"
2,test,0.929129,0.696361,0.387364,"[0.9361050328227571, 0.7426900584795322]","[0.989819527996298, 0.3031026252983294]"
3,train,0.934519,0.776292,0.26769,"[0.9705447981621267, 0.5481742190937088]","[0.9583957869151306, 0.6344195519348269]"
3,test,0.881881,0.640437,0.470022,"[0.9370353159851301, 0.33867276887871856]","[0.9331328088847756, 0.3532219570405728]"
4,train,0.880296,0.700706,0.347542,"[0.9754058477462733, 0.3483424047501237]","[0.8932955235973263, 0.7169042769857433]"
4,test,0.739085,0.557308,0.702977,"[0.9461961238067689, 0.18146417445482865]","[0.7568255437297547, 0.5560859188544153]"
5,train,0.962062,0.851589,0.206908,"[0.974314794037506, 0.7814530419373893]","[0.9850111403686449, 0.6736252545824847]"
5,test,0.901919,0.672527,0.46011,"[0.9395942557556417, 0.4350282485875706]","[0.9537251272559001, 0.36754176610978523]"
