In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import resnet101

In [2]:
import sys
sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset, FrameWindowDataset

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
DATA_ROOT = Path(ROOT/"data")

In [6]:
train_transforms = transforms.Compose([
    torchvision.transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

In [7]:
train_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'train', 
                                       transform=train_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=3,
                                       overlapping=True,)
valid_ds = FolderOfFrameFoldersDataset(DATA_ROOT/'validation', 
                                       transform=valid_transforms, 
                                       base_class=FrameWindowDataset,
                                       window_size=3,
                                       overlapping=True,)

In [8]:
from torch import nn
from torchvision.models import resnet101
from video_classification.models.mlp import MLP


class SingleImageResNetModel(nn.Module):
    def __init__(self, mlp_sizes=[768, 128, 2]):
        super().__init__()
        resnet = resnet101(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.clf = MLP(2048, mlp_sizes)
        self.freeze_resnet()

    def forward(self, x):
        x = self.resnet(x).squeeze()
        x = self.clf(x)
        return x

    def freeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = False

    def unfreeze_resnet(self):
        for p in self.resnet.parameters():
            p.requires_grad = True


In [9]:
from torch import nn


class AverageImagesModel(nn.Module):
    def __init__(self, mlp_sizes=[768, 128, 2]):
        super().__init__()
        self.single_image_model = SingleImageResNetModel(mlp_sizes)

    def forward(self, x):
        # x is of size (B, T, C, H, W)
        x = x.mean(1)  # We average all images in axis T
        x = self.single_image_model(x)  # and then it's business as usual
        return x

    def freeze_vgg(self):
        # Freeze the VGG classifier
        self.single_image_model.freeze_vgg()

    def unfreeze_vgg(self):
        self.single_image_model.unfreeze_vgg()


In [10]:
model = AverageImagesModel(
                 mlp_sizes=[1024, 256, 2]
)

model = model.to(device)

In [11]:
x = torch.stack([train_ds[0][0], train_ds[1][0], train_ds[2][0], train_ds[3][0]]).to(device)

In [12]:
model(x)

tensor([[-0.0064,  0.2145],
        [ 0.3605,  0.3722],
        [-0.1629,  0.2921],
        [-0.4763,  0.1611]], device='cuda:0', grad_fn=<AddmmBackward>)

In [13]:
from video_classification.trainer import Trainer

classes_weights = torch.Tensor([0.3, 1.0]).to(device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [14]:
trainer = Trainer(train_ds, 
                  valid_ds, 
                  model, 
                  criterion,
                  "multi_frame_resnet101_from_scratch",
                  str(ROOT/'checkpoints'),
                  device=device,
                  amp_opt_level="O1",
                  cycle_mult=0.9,
                 )

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
trainer.train(lr=1e-3, 
              batch_size=64, 
              n_epochs=20,
              gradient_accumulation_steps=4,
              num_workers=8,
              max_gradient_norm=2.0,
             )



HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 1: Avg accuracy: 0.91 |Precision: 0.96, 0.42 |Recall: 0.94, 0.56 | F1: 0.72 | Avg loss: 0.38
Validation Results - Epoch: 1: Avg accuracy: 0.73 |Precision: 0.93, 0.15 |Recall: 0.77, 0.41 | F1: 0.53 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 2: Avg accuracy: 0.86 |Precision: 0.97, 0.30 |Recall: 0.87, 0.71 | F1: 0.67 | Avg loss: 0.38
Validation Results - Epoch: 2: Avg accuracy: 0.52 |Precision: 0.95, 0.12 |Recall: 0.50, 0.73 | F1: 0.43 | Avg loss: 0.78


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 3: Avg accuracy: 0.95 |Precision: 0.96, 0.82 |Recall: 0.99, 0.42 | F1: 0.77 | Avg loss: 0.29
Validation Results - Epoch: 3: Avg accuracy: 0.93 |Precision: 0.93, 0.79 |Recall: 0.99, 0.27 | F1: 0.69 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 4: Avg accuracy: 0.94 |Precision: 0.96, 0.59 |Recall: 0.97, 0.51 | F1: 0.76 | Avg loss: 0.29
Validation Results - Epoch: 4: Avg accuracy: 0.87 |Precision: 0.93, 0.28 |Recall: 0.92, 0.32 | F1: 0.61 | Avg loss: 0.50


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 5: Avg accuracy: 0.94 |Precision: 0.97, 0.62 |Recall: 0.97, 0.57 | F1: 0.78 | Avg loss: 0.26
Validation Results - Epoch: 5: Avg accuracy: 0.88 |Precision: 0.94, 0.32 |Recall: 0.93, 0.33 | F1: 0.63 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 6: Avg accuracy: 0.95 |Precision: 0.97, 0.70 |Recall: 0.98, 0.58 | F1: 0.80 | Avg loss: 0.25
Validation Results - Epoch: 6: Avg accuracy: 0.93 |Precision: 0.93, 0.72 |Recall: 0.99, 0.28 | F1: 0.68 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 7: Avg accuracy: 0.94 |Precision: 0.97, 0.58 |Recall: 0.96, 0.63 | F1: 0.78 | Avg loss: 0.25
Validation Results - Epoch: 7: Avg accuracy: 0.91 |Precision: 0.93, 0.48 |Recall: 0.97, 0.29 | F1: 0.66 | Avg loss: 0.54


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 65536.0
Training Results - Epoch: 8: Avg accuracy: 0.93 |Precision: 0.97, 0.54 |Recall: 0.95, 0.66 | F1: 0.78 | Avg loss: 0.25
Validation Results - Epoch: 8: Avg accuracy: 0.84 |Precision: 0.94, 0.23 |Recall: 0.88, 0.37 | F1: 0.60 | Avg loss: 0.65


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 9: Avg accuracy: 0.95 |Precision: 0.97, 0.63 |Recall: 0.97, 0.65 | F1: 0.81 | Avg loss: 0.24
Validation Results - Epoch: 9: Avg accuracy: 0.91 |Precision: 0.94, 0.51 |Recall: 0.97, 0.32 | F1: 0.67 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 10: Avg accuracy: 0.94 |Precision: 0.97, 0.60 |Recall: 0.96, 0.67 | F1: 0.80 | Avg loss: 0.23
Validation Results - Epoch: 10: Avg accuracy: 0.90 |Precision: 0.93, 0.38 |Recall: 0.95, 0.30 | F1: 0.64 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 11: Avg accuracy: 0.94 |Precision: 0.97, 0.59 |Recall: 0.96, 0.68 | F1: 0.80 | Avg loss: 0.23
Validation Results - Epoch: 11: Avg accuracy: 0.92 |Precision: 0.94, 0.54 |Recall: 0.97, 0.32 | F1: 0.68 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 12: Avg accuracy: 0.96 |Precision: 0.97, 0.75 |Recall: 0.98, 0.60 | F1: 0.82 | Avg loss: 0.22
Validation Results - Epoch: 12: Avg accuracy: 0.92 |Precision: 0.94, 0.54 |Recall: 0.97, 0.32 | F1: 0.68 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 13: Avg accuracy: 0.95 |Precision: 0.97, 0.74 |Recall: 0.98, 0.59 | F1: 0.82 | Avg loss: 0.23
Validation Results - Epoch: 13: Avg accuracy: 0.92 |Precision: 0.93, 0.68 |Recall: 0.99, 0.27 | F1: 0.67 | Avg loss: 0.57


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 14: Avg accuracy: 0.95 |Precision: 0.98, 0.67 |Recall: 0.97, 0.70 | F1: 0.83 | Avg loss: 0.21
Validation Results - Epoch: 14: Avg accuracy: 0.91 |Precision: 0.94, 0.52 |Recall: 0.97, 0.31 | F1: 0.67 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 15: Avg accuracy: 0.95 |Precision: 0.97, 0.66 |Recall: 0.97, 0.67 | F1: 0.82 | Avg loss: 0.21
Validation Results - Epoch: 15: Avg accuracy: 0.88 |Precision: 0.93, 0.30 |Recall: 0.93, 0.31 | F1: 0.62 | Avg loss: 0.58


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 16: Avg accuracy: 0.95 |Precision: 0.98, 0.63 |Recall: 0.97, 0.69 | F1: 0.81 | Avg loss: 0.21
Validation Results - Epoch: 16: Avg accuracy: 0.89 |Precision: 0.94, 0.37 |Recall: 0.95, 0.33 | F1: 0.64 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 17: Avg accuracy: 0.96 |Precision: 0.97, 0.74 |Recall: 0.98, 0.65 | F1: 0.84 | Avg loss: 0.20
Validation Results - Epoch: 17: Avg accuracy: 0.92 |Precision: 0.93, 0.56 |Recall: 0.98, 0.30 | F1: 0.67 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
Training Results - Epoch: 18: Avg accuracy: 0.96 |Precision: 0.97, 0.71 |Recall: 0.98, 0.68 | F1: 0.84 | Avg loss: 0.20
Validation Results - Epoch: 18: Avg accuracy: 0.92 |Precision: 0.94, 0.62 |Recall: 0.98, 0.30 | F1: 0.68 | Avg loss: 0.49


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 19: Avg accuracy: 0.96 |Precision: 0.97, 0.73 |Recall: 0.98, 0.68 | F1: 0.84 | Avg loss: 0.20
Validation Results - Epoch: 19: Avg accuracy: 0.91 |Precision: 0.94, 0.51 |Recall: 0.97, 0.33 | F1: 0.68 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=417), HTML(value='')))

Training Results - Epoch: 20: Avg accuracy: 0.95 |Precision: 0.97, 0.70 |Recall: 0.98, 0.67 | F1: 0.83 | Avg loss: 0.20
Validation Results - Epoch: 20: Avg accuracy: 0.91 |Precision: 0.93, 0.51 |Recall: 0.97, 0.30 | F1: 0.67 | Avg loss: 0.56


In [16]:
import pandas as pd
reform = {(outerKey, innerKey): values for outerKey, innerDict in trainer.epoch_state.items() for innerKey, values in innerDict.items()}
df = pd.DataFrame(reform).T
df

Unnamed: 0,Unnamed: 1,accuracy,f1,nll,precision,recall
1,train,0.910991,0.716367,0.37695,"[0.9641000041599068, 0.4218390804597701]","[0.938869758962933, 0.5605906313645621]"
1,test,0.734655,0.528013,0.549095,"[0.9308211473565804, 0.1459915611814346]","[0.7658491439148543, 0.4128878281622912]"
2,train,0.857893,0.672003,0.379128,"[0.9743938981204031, 0.30283365779796667]","[0.8694348794814665, 0.7128309572301426]"
2,test,0.521831,0.434401,0.778723,"[0.9500657030223391, 0.1240846216436127]","[0.501850994909764, 0.7279236276849642]"
3,train,0.950805,0.766638,0.289522,"[0.9558467899212107, 0.8229475766567754]","[0.992748632772939, 0.42362525458248473]"
3,test,0.929551,0.685175,0.430452,"[0.9338555265448216, 0.7931034482758621]","[0.993058769088385, 0.2744630071599045]"
4,train,0.937709,0.757485,0.288317,"[0.961663391747203, 0.5887850467289719]","[0.9714806562689893, 0.5132382892057027]"
4,test,0.868171,0.614446,0.502802,"[0.9334114888628371, 0.28361344537815125]","[0.9211013419713096, 0.3221957040572792]"
5,train,0.942212,0.781067,0.261921,"[0.9661255890764088, 0.6163556531284303]","[0.9716832084261697, 0.5717922606924644]"
5,test,0.878296,0.629911,0.469749,"[0.9351615152219381, 0.319634703196347]","[0.9310504396112911, 0.3341288782816229]"
