In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path

In [2]:
import sys

# Make the parent directory importable so the `video_classification` package
# can be imported from this notebook. Guarded so that re-running the cell does
# not keep appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Project root and the data directory beneath it.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# only run unchanged on a host with this directory layout.
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
DATA_ROOT = ROOT / "data"

In [6]:
# Tensor conversion followed by per-channel normalisation; this tail is
# identical in both pipelines, so it is built once and reused.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training pipeline: augmentations first, then the shared tensor/normalise tail.
# NOTE(review): ColorJitter() is constructed without arguments; per the
# torchvision documentation its brightness/contrast/saturation/hue defaults
# are all 0, which would make it a no-op. Confirm whether jitter strengths
# were intended here.
train_transforms = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    *_tensor_and_normalize,
])

# Validation pipeline: no augmentation, only tensor conversion + normalisation.
valid_transforms = transforms.Compose(list(_tensor_and_normalize))

In [7]:
# Build the training and validation datasets from their respective
# sub-directories of DATA_ROOT, each with its own transform pipeline.
train_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / "train",
    transform=train_transforms,
)
valid_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / "validation",
    transform=valid_transforms,
)

In [8]:
# Display the training dataset's repr (sample count and class distribution).
train_ds

FolderOfFrameFoldersDataset with 26711 samples.
	Overall data distribution: {'negative': 24747, 'positive': 1964}

In [9]:
# Display the validation dataset's repr (sample count and class distribution).
valid_ds

FolderOfFrameFoldersDataset with 4751 samples.
	Overall data distribution: {'negative': 4332, 'positive': 419}

In [10]:
from torch import nn
from torchvision.models import resnet101
from video_classification.models.mlp import MLP


class SingleImageResNetModel(nn.Module):
    """Single-frame classifier: a ResNet-101 feature extractor plus an MLP head.

    The torchvision ``resnet101`` is loaded with pretrained weights and its
    last child module is dropped; the remaining layers form the backbone whose
    output is flattened and fed to ``MLP(2048, mlp_sizes)``. The backbone
    starts frozen; call ``unfreeze_resnet()`` to fine-tune it.

    Args:
        mlp_sizes: Sizes forwarded to ``MLP`` as its second argument
            (presumably the layer widths, ending in the number of classes —
            confirm against ``MLP``). Defaults to ``[768, 128, 2]``.
    """

    def __init__(self, mlp_sizes=None):
        super().__init__()
        # Build a fresh list per instance so that neither a shared default nor
        # the caller's own list can be mutated through this model.
        sizes = [768, 128, 2] if mlp_sizes is None else list(mlp_sizes)

        resnet = resnet101(pretrained=True)
        # Keep every child except the last one, so the backbone returns
        # features instead of class scores.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.clf = MLP(2048, sizes)
        self.freeze_resnet()

    def forward(self, x):
        """Run the backbone on a batch of images and classify its features.

        Args:
            x: Batched image tensor accepted by the ResNet backbone.

        Returns:
            The output of the MLP head, one row per input sample.
        """
        # flatten(1) keeps the batch dimension. A bare squeeze() would also
        # remove it for a batch of size 1, handing the head a 1-D tensor.
        x = self.resnet(x).flatten(1)
        return self.clf(x)

    def freeze_resnet(self):
        """Disable gradients for every backbone parameter.

        NOTE(review): this only toggles ``requires_grad``; it does not put the
        backbone in eval mode, so BatchNorm running statistics may still be
        updated while the model is in training mode — confirm this is intended.
        """
        for p in self.resnet.parameters():
            p.requires_grad = False

    def unfreeze_resnet(self):
        """Re-enable gradients for every backbone parameter."""
        for p in self.resnet.parameters():
            p.requires_grad = True


In [11]:
# Instantiate the single-frame model with a [1024, 256, 2] head and move it
# to the selected device in one step.
model = SingleImageResNetModel(mlp_sizes=[1024, 256, 2]).to(device)

In [12]:
from video_classification.trainer import Trainer

# Per-class loss weights, created directly on the target device.
# NOTE(review): the training set shown above is roughly 12.6:1
# negative:positive, while these weights are about 3.3:1 in favour of the
# second class — confirm this ratio is intentional.
classes_weights = torch.tensor([0.3, 1.0], device=device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [13]:
# Directory under the project root that is handed to the Trainer as a string.
checkpoint_dir = str(ROOT / "checkpoints")

# Positional order is kept exactly as in the Trainer call signature used here:
# datasets, model, loss, then a name string and the checkpoint directory.
# NOTE(review): "single_frame_resnet" is presumably the run/experiment name and
# cycle_mult presumably a learning-rate-cycle multiplier — confirm against
# Trainer. amp_opt_level="O1" selects the mixed-precision level reported in
# the cell output.
trainer = Trainer(
    train_ds,
    valid_ds,
    model,
    criterion,
    "single_frame_resnet",
    checkpoint_dir,
    device=device,
    amp_opt_level="O1",
    cycle_mult=0.9,
)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [14]:
# Hyper-parameters for this run, gathered in one place and unpacked into
# trainer.train().
# NOTE(review): with batch_size=128 and gradient_accumulation_steps=2 the
# effective batch is presumably 256 — confirm how Trainer applies accumulation.
train_kwargs = dict(
    lr=1e-3,
    batch_size=128,
    n_epochs=20,
    gradient_accumulation_steps=2,
    num_workers=8,
    max_gradient_norm=2.0,
)
trainer.train(**train_kwargs)



HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 1: Avg accuracy: 0.95 |Precision: 0.95, 0.88 |Recall: 1.00, 0.35 | F1: 0.74 | Avg loss: 0.34
Validation Results - Epoch: 1: Avg accuracy: 0.89 |Precision: 0.93, 0.37 |Recall: 0.95, 0.30 | F1: 0.64 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 2: Avg accuracy: 0.88 |Precision: 0.98, 0.34 |Recall: 0.89, 0.74 | F1: 0.70 | Avg loss: 0.34
Validation Results - Epoch: 2: Avg accuracy: 0.79 |Precision: 0.94, 0.21 |Recall: 0.82, 0.50 | F1: 0.59 | Avg loss: 0.52


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 3: Avg accuracy: 0.85 |Precision: 0.98, 0.29 |Recall: 0.86, 0.73 | F1: 0.67 | Avg loss: 0.36
Validation Results - Epoch: 3: Avg accuracy: 0.75 |Precision: 0.94, 0.17 |Recall: 0.78, 0.47 | F1: 0.55 | Avg loss: 0.55


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 4: Avg accuracy: 0.95 |Precision: 0.97, 0.74 |Recall: 0.98, 0.57 | F1: 0.81 | Avg loss: 0.24
Validation Results - Epoch: 4: Avg accuracy: 0.89 |Precision: 0.94, 0.37 |Recall: 0.94, 0.33 | F1: 0.64 | Avg loss: 0.44


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 5: Avg accuracy: 0.94 |Precision: 0.98, 0.58 |Recall: 0.96, 0.70 | F1: 0.80 | Avg loss: 0.23
Validation Results - Epoch: 5: Avg accuracy: 0.84 |Precision: 0.94, 0.27 |Recall: 0.88, 0.47 | F1: 0.63 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 6: Avg accuracy: 0.95 |Precision: 0.97, 0.64 |Recall: 0.97, 0.65 | F1: 0.81 | Avg loss: 0.24
Validation Results - Epoch: 6: Avg accuracy: 0.89 |Precision: 0.94, 0.39 |Recall: 0.94, 0.37 | F1: 0.66 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 7: Avg accuracy: 0.96 |Precision: 0.96, 0.93 |Recall: 1.00, 0.49 | F1: 0.81 | Avg loss: 0.25
Validation Results - Epoch: 7: Avg accuracy: 0.94 |Precision: 0.94, 0.95 |Recall: 1.00, 0.29 | F1: 0.71 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 8: Avg accuracy: 0.93 |Precision: 0.98, 0.54 |Recall: 0.95, 0.78 | F1: 0.80 | Avg loss: 0.22
Validation Results - Epoch: 8: Avg accuracy: 0.85 |Precision: 0.95, 0.30 |Recall: 0.89, 0.49 | F1: 0.65 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 9: Avg accuracy: 0.95 |Precision: 0.98, 0.68 |Recall: 0.97, 0.72 | F1: 0.84 | Avg loss: 0.20
Validation Results - Epoch: 9: Avg accuracy: 0.91 |Precision: 0.94, 0.47 |Recall: 0.96, 0.34 | F1: 0.67 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 10: Avg accuracy: 0.95 |Precision: 0.98, 0.62 |Recall: 0.96, 0.77 | F1: 0.83 | Avg loss: 0.19
Validation Results - Epoch: 10: Avg accuracy: 0.90 |Precision: 0.94, 0.41 |Recall: 0.95, 0.39 | F1: 0.67 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 11: Avg accuracy: 0.96 |Precision: 0.97, 0.83 |Recall: 0.99, 0.61 | F1: 0.84 | Avg loss: 0.21
Validation Results - Epoch: 11: Avg accuracy: 0.93 |Precision: 0.94, 0.69 |Recall: 0.98, 0.36 | F1: 0.72 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 12: Avg accuracy: 0.96 |Precision: 0.98, 0.77 |Recall: 0.98, 0.69 | F1: 0.85 | Avg loss: 0.18
Validation Results - Epoch: 12: Avg accuracy: 0.93 |Precision: 0.95, 0.66 |Recall: 0.98, 0.45 | F1: 0.75 | Avg loss: 0.37


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 13: Avg accuracy: 0.96 |Precision: 0.98, 0.76 |Recall: 0.98, 0.74 | F1: 0.86 | Avg loss: 0.18
Validation Results - Epoch: 13: Avg accuracy: 0.91 |Precision: 0.94, 0.51 |Recall: 0.96, 0.39 | F1: 0.70 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 14: Avg accuracy: 0.95 |Precision: 0.98, 0.63 |Recall: 0.96, 0.78 | F1: 0.83 | Avg loss: 0.19
Validation Results - Epoch: 14: Avg accuracy: 0.88 |Precision: 0.95, 0.39 |Recall: 0.92, 0.55 | F1: 0.70 | Avg loss: 0.39


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 15: Avg accuracy: 0.96 |Precision: 0.98, 0.68 |Recall: 0.97, 0.79 | F1: 0.85 | Avg loss: 0.17
Validation Results - Epoch: 15: Avg accuracy: 0.91 |Precision: 0.94, 0.51 |Recall: 0.97, 0.37 | F1: 0.69 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 16: Avg accuracy: 0.96 |Precision: 0.98, 0.72 |Recall: 0.98, 0.77 | F1: 0.86 | Avg loss: 0.16
Validation Results - Epoch: 16: Avg accuracy: 0.91 |Precision: 0.95, 0.49 |Recall: 0.95, 0.47 | F1: 0.71 | Avg loss: 0.40


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 17: Avg accuracy: 0.96 |Precision: 0.98, 0.76 |Recall: 0.98, 0.76 | F1: 0.87 | Avg loss: 0.17
Validation Results - Epoch: 17: Avg accuracy: 0.91 |Precision: 0.94, 0.52 |Recall: 0.97, 0.36 | F1: 0.69 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 18: Avg accuracy: 0.97 |Precision: 0.98, 0.80 |Recall: 0.99, 0.71 | F1: 0.87 | Avg loss: 0.17
Validation Results - Epoch: 18: Avg accuracy: 0.92 |Precision: 0.94, 0.59 |Recall: 0.97, 0.38 | F1: 0.71 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 19: Avg accuracy: 0.96 |Precision: 0.98, 0.72 |Recall: 0.98, 0.80 | F1: 0.87 | Avg loss: 0.15
Validation Results - Epoch: 19: Avg accuracy: 0.91 |Precision: 0.94, 0.51 |Recall: 0.96, 0.41 | F1: 0.70 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=209), HTML(value='')))

Training Results - Epoch: 20: Avg accuracy: 0.97 |Precision: 0.98, 0.79 |Recall: 0.98, 0.76 | F1: 0.88 | Avg loss: 0.16
Validation Results - Epoch: 20: Avg accuracy: 0.93 |Precision: 0.95, 0.72 |Recall: 0.98, 0.42 | F1: 0.75 | Avg loss: 0.44


In [15]:
import pandas as pd
# Flatten the nested trainer.epoch_state mapping ({outer_key: {inner_key:
# metrics}}) into a single dict keyed by (outer_key, inner_key) tuples. In the
# displayed table these are the epoch number and the 'train'/'test' split.
reform = {
    (epoch, split): metrics
    for epoch, per_split in trainer.epoch_state.items()
    for split, metrics in per_split.items()
}
# Transpose so each (epoch, split) pair becomes a row and each metric a column.
pd.DataFrame(reform).T

Unnamed: 0,Unnamed: 1,accuracy,f1,nll,precision,recall
1,train,0.948935,0.7389,0.3364,"[0.9510435554183866, 0.879746835443038]","[0.9961611508465673, 0.35386965376782076]"
1,test,0.894338,0.636654,0.460383,"[0.9332579185520362, 0.37462235649546827]","[0.9522160664819944, 0.29594272076372313]"
2,train,0.87698,0.699279,0.336004,"[0.9769747077388096, 0.3431419079259611]","[0.8881480583505071, 0.7362525458248472]"
2,test,0.790149,0.586061,0.521678,"[0.9440745672436751, 0.20983935742971888]","[0.8183287165281625, 0.4988066825775656]"
3,train,0.849163,0.665361,0.361623,"[0.9760569852941177, 0.29145627146031106]","[0.8582454438921889, 0.734725050916497]"
3,test,0.751842,0.551804,0.553447,"[0.9387698302254384, 0.17184801381692574]","[0.778624192059095, 0.47494033412887826]"
4,train,0.95354,0.808503,0.244142,"[0.9662038873462911, 0.7408394403730846]","[0.9842809229401543, 0.5661914460285132]"
4,test,0.890339,0.644035,0.435605,"[0.9359414321665522, 0.36578947368421055]","[0.9443674976915974, 0.3317422434367542]"
5,train,0.940998,0.802461,0.233504,"[0.9761229605885012, 0.5815811606391926]","[0.9597931062350992, 0.7041751527494908]"
5,test,0.840244,0.625963,0.466015,"[0.9449564134495642, 0.26902173913043476]","[0.8758079409048938, 0.47255369928400953]"
