In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import vgg16

In [2]:
import sys

# Put the parent directory on the import path so the project package imported
# in the next cell can be resolved (presumably it lives one level up — confirm).
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset, FrameWindowDataset

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Project checkout on the training machine, and the data folder inside it.
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
DATA_ROOT = ROOT / "data"

In [6]:
# Per-channel normalization statistics shared by both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std — confirm they match
# the statistics the backbone was pretrained with.
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: augmentations first, then tensor conversion + normalization.
# NOTE(review): ColorJitter() is constructed with its default arguments; check
# the torchvision docs to confirm whether that applies any jitter at all.
train_transforms = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
])

# Validation pipeline: no augmentation, same tensor conversion + normalization.
valid_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
])

In [7]:
# Options passed unchanged to both splits: items are produced by
# FrameWindowDataset with window_size=3 and overlapping=True.
window_options = dict(
    base_class=FrameWindowDataset,
    window_size=3,
    overlapping=True,
)

# The two splits differ only in their source folder and transform pipeline.
train_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / "train", transform=train_transforms, **window_options
)
valid_ds = FolderOfFrameFoldersDataset(
    DATA_ROOT / "validation", transform=valid_transforms, **window_options
)

In [8]:
import torch
from torch import nn
from video_classification.models.single_image import SingleImageModel
from video_classification.models.mlp import MLP


class MultiImageModel(nn.Module):
    """Classify a window of frames by fusing per-frame features.

    Each frame of a window goes through one shared ``SingleImageModel``. The
    per-frame outputs are concatenated along the feature dimension and passed
    to a joint ``MLP`` that produces the final scores.

    Args:
        window_size: Number of frames per window (``T`` in ``forward``).
        single_mlp_sizes: Sizes handed to ``SingleImageModel``. The last entry
            is used as the per-frame feature width when sizing the joint MLP.
        joint_mlp_sizes: Sizes handed to the joint ``MLP``. Per the original
            notes the last entry should always be 2.
    """

    def __init__(self,
                 window_size=3,
                 single_mlp_sizes=(768, 128),
                 joint_mlp_sizes=(64, 2)):
        super().__init__()
        self.window_size = window_size
        # Copy into fresh lists: the defaults are immutable tuples and a
        # caller-supplied sequence is never aliased, so no two instances can
        # share (or accidentally mutate) the same size list.
        self.single_mlp_sizes = list(single_mlp_sizes)
        self.joint_mlp_sizes = list(joint_mlp_sizes)

        self.single_image_model = SingleImageModel(self.single_mlp_sizes)
        # The joint MLP consumes the concatenation of one feature vector per frame.
        self.in_features = self.single_mlp_sizes[-1] * self.window_size
        self.clf = MLP(self.in_features, self.joint_mlp_sizes)

    def forward(self, x):
        """Score a batch of frame windows.

        Args:
            x: Tensor of size ``[B, T, C, H, W]``, i.e. a batch of windows.

        Returns:
            Tensor of size ``[B, joint_mlp_sizes[-1]]``.
        """
        # Put time first so that iterating yields one [B, C, H, W] batch per frame.
        frames = x.transpose(0, 1)  # -> [T, B, C, H, W]
        # Every frame position goes through the same SingleImageModel; the
        # results are joined along the feature axis.
        features = torch.cat([self.single_image_model(frame) for frame in frames], 1)
        # features is now of size [B, T * single_mlp_sizes[-1]]

        return self.clf(features)

    def freeze_single_image_model(self):
        """Disable gradients for every parameter of the shared per-frame model."""
        for p in self.single_image_model.parameters():
            p.requires_grad = False

    def unfreeze_single_image_model(self):
        """Re-enable gradients for the per-frame model's ``clf`` sub-module only.

        This is deliberately not the inverse of ``freeze_single_image_model``:
        per the original author's note, training the whole VGG is a no-go, so
        only the classifier part is made trainable again.
        """
        for p in self.single_image_model.clf.parameters():
            p.requires_grad = True

In [9]:
# Build the windowed classifier (3 frames per window, layer sizes overriding
# the class defaults) and move it to the selected device in the same statement.
model = MultiImageModel(
    window_size=3,
    single_mlp_sizes=[1024, 256],
    joint_mlp_sizes=[128, 2],
).to(device)

In [10]:
# Hand-build a mini-batch from the first four training items (element 0 of each
# item is used as the model input) and move it to the selected device.
x = torch.stack([train_ds[i][0] for i in range(4)]).to(device)

In [11]:
# Sanity check: run one forward pass on the hand-built batch and display the raw output.
model(x)

tensor([[ 0.1265, -0.1215],
        [-0.2513, -0.5642],
        [ 0.0864,  0.3089],
        [-0.2790,  0.4778]], device='cuda:0', grad_fn=<AddmmBackward>)

In [12]:
from video_classification.trainer import Trainer

# Per-class loss weights: class 0 is weighted 0.3 and class 1 is weighted 1.0,
# so errors on class 1 contribute more to the loss.
class_weight_values = [0.3, 1.0]
classes_weights = torch.tensor(class_weight_values, device=device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [13]:
# Wire the datasets, model and loss into the project Trainer.
# The two string arguments are presumably the run name and the checkpoint
# directory — confirm against Trainer's signature.
run_name = "multi_frame_vgg_from_scratch"
checkpoint_dir = str(ROOT / "checkpoints")

trainer = Trainer(
    train_ds,
    valid_ds,
    model,
    criterion,
    run_name,
    checkpoint_dir,
    device=device,
    amp_opt_level="O1",
)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [14]:
# Train with lr=1e-3, requesting 20 epochs.
# NOTE(review): this cell used to be labelled "First 3 epochs: only joint MLP
# unfrozen, high learning rate", but it requests n_epochs=20 and no call to
# model.freeze_single_image_model() is visible in this notebook — confirm
# which of the two was intended.
train_settings = dict(
    lr=1e-3,
    batch_size=48,
    n_epochs=20,
    gradient_accumulation_steps=8,
    num_workers=8,
    max_gradient_norm=2.0,
)
trainer.train(**train_settings)



HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 1: Avg accuracy: 0.94 |Precision: 0.96, 0.59 |Recall: 0.97, 0.51 | F1: 0.76 | Avg loss: 0.41
Validation Results - Epoch: 1: Avg accuracy: 0.68 |Precision: 0.93, 0.12 |Recall: 0.70, 0.43 | F1: 0.50 | Avg loss: 0.68


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 2: Avg accuracy: 0.95 |Precision: 0.96, 0.87 |Recall: 0.99, 0.44 | F1: 0.78 | Avg loss: 0.31
Validation Results - Epoch: 2: Avg accuracy: 0.91 |Precision: 0.94, 0.49 |Recall: 0.97, 0.31 | F1: 0.66 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 3: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.43 | F1: 0.79 | Avg loss: 0.29
Validation Results - Epoch: 3: Avg accuracy: 0.87 |Precision: 0.93, 0.27 |Recall: 0.92, 0.31 | F1: 0.61 | Avg loss: 0.50


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 4: Avg accuracy: 0.96 |Precision: 0.96, 0.94 |Recall: 1.00, 0.46 | F1: 0.80 | Avg loss: 0.28
Validation Results - Epoch: 4: Avg accuracy: 0.92 |Precision: 0.93, 0.73 |Recall: 0.99, 0.20 | F1: 0.64 | Avg loss: 0.43


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 5: Avg accuracy: 0.95 |Precision: 0.96, 0.85 |Recall: 0.99, 0.47 | F1: 0.79 | Avg loss: 0.27
Validation Results - Epoch: 5: Avg accuracy: 0.84 |Precision: 0.93, 0.21 |Recall: 0.89, 0.31 | F1: 0.58 | Avg loss: 0.54


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 6: Avg accuracy: 0.95 |Precision: 0.96, 0.82 |Recall: 0.99, 0.49 | F1: 0.79 | Avg loss: 0.26
Validation Results - Epoch: 6: Avg accuracy: 0.91 |Precision: 0.93, 0.48 |Recall: 0.97, 0.27 | F1: 0.65 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 7: Avg accuracy: 0.95 |Precision: 0.95, 0.98 |Recall: 1.00, 0.39 | F1: 0.77 | Avg loss: 0.32
Validation Results - Epoch: 7: Avg accuracy: 0.90 |Precision: 0.93, 0.40 |Recall: 0.97, 0.23 | F1: 0.62 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 8: Avg accuracy: 0.96 |Precision: 0.96, 0.84 |Recall: 0.99, 0.53 | F1: 0.81 | Avg loss: 0.25
Validation Results - Epoch: 8: Avg accuracy: 0.77 |Precision: 0.93, 0.16 |Recall: 0.81, 0.36 | F1: 0.54 | Avg loss: 0.68


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
Training Results - Epoch: 9: Avg accuracy: 0.95 |Precision: 0.96, 0.73 |Recall: 0.98, 0.55 | F1: 0.80 | Avg loss: 0.25
Validation Results - Epoch: 9: Avg accuracy: 0.90 |Precision: 0.93, 0.41 |Recall: 0.96, 0.27 | F1: 0.64 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 10: Avg accuracy: 0.96 |Precision: 0.96, 0.90 |Recall: 1.00, 0.46 | F1: 0.79 | Avg loss: 0.28
Validation Results - Epoch: 10: Avg accuracy: 0.90 |Precision: 0.93, 0.37 |Recall: 0.97, 0.21 | F1: 0.61 | Avg loss: 0.56


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 11: Avg accuracy: 0.95 |Precision: 0.96, 0.73 |Recall: 0.98, 0.53 | F1: 0.80 | Avg loss: 0.24
Validation Results - Epoch: 11: Avg accuracy: 0.92 |Precision: 0.93, 0.56 |Recall: 0.98, 0.28 | F1: 0.66 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
Training Results - Epoch: 12: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.47 | F1: 0.80 | Avg loss: 0.25
Validation Results - Epoch: 12: Avg accuracy: 0.91 |Precision: 0.93, 0.47 |Recall: 0.97, 0.25 | F1: 0.64 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 13: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.49 | F1: 0.81 | Avg loss: 0.25
Validation Results - Epoch: 13: Avg accuracy: 0.90 |Precision: 0.93, 0.43 |Recall: 0.96, 0.30 | F1: 0.65 | Avg loss: 0.51


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 14: Avg accuracy: 0.96 |Precision: 0.96, 0.97 |Recall: 1.00, 0.48 | F1: 0.81 | Avg loss: 0.25
Validation Results - Epoch: 14: Avg accuracy: 0.92 |Precision: 0.93, 0.58 |Recall: 0.98, 0.27 | F1: 0.66 | Avg loss: 0.45


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 15: Avg accuracy: 0.96 |Precision: 0.96, 0.95 |Recall: 1.00, 0.53 | F1: 0.83 | Avg loss: 0.23
Validation Results - Epoch: 15: Avg accuracy: 0.92 |Precision: 0.94, 0.54 |Recall: 0.97, 0.31 | F1: 0.67 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 16: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.45 | F1: 0.80 | Avg loss: 0.25
Validation Results - Epoch: 16: Avg accuracy: 0.92 |Precision: 0.93, 0.64 |Recall: 0.99, 0.24 | F1: 0.65 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0
Training Results - Epoch: 17: Avg accuracy: 0.96 |Precision: 0.96, 0.95 |Recall: 1.00, 0.51 | F1: 0.82 | Avg loss: 0.24
Validation Results - Epoch: 17: Avg accuracy: 0.92 |Precision: 0.93, 0.59 |Recall: 0.98, 0.29 | F1: 0.67 | Avg loss: 0.44


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 18: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.50 | F1: 0.82 | Avg loss: 0.24
Validation Results - Epoch: 18: Avg accuracy: 0.93 |Precision: 0.93, 0.79 |Recall: 0.99, 0.26 | F1: 0.68 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

Training Results - Epoch: 19: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.50 | F1: 0.82 | Avg loss: 0.23
Validation Results - Epoch: 19: Avg accuracy: 0.92 |Precision: 0.93, 0.61 |Recall: 0.98, 0.26 | F1: 0.66 | Avg loss: 0.47


HBox(children=(IntProgress(value=0, max=556), HTML(value='')))

KeyboardInterrupt: 

In [16]:
import pandas as pd
# Flatten trainer.epoch_state (outer key -> inner key -> values) into a single
# mapping keyed by (outer key, inner key). After the transpose each such pair
# becomes one row of the frame.
reform = {}
for outerKey, innerDict in trainer.epoch_state.items():
    for innerKey, values in innerDict.items():
        reform[(outerKey, innerKey)] = values
df = pd.DataFrame(reform).T
df

Unnamed: 0,Unnamed: 1,accuracy,f1,nll,precision,recall
1,train,0.937559,0.757145,0.413565,"[0.9616572414069707, 0.5874125874125874]","[0.9713186145432449, 0.5132382892057027]"
1,test,0.679182,0.496446,0.678223,"[0.9276335877862596, 0.12414733969986358]","[0.7029153169828783, 0.4343675417661098]"
2,train,0.954032,0.781115,0.305601,"[0.9573099415204679, 0.8698698698698699]","[0.9947336439133077, 0.4424643584521385]"
2,test,0.910989,0.664824,0.417051,"[0.9350736278447122, 0.4942084942084942]","[0.9696899583526145, 0.3054892601431981]"
3,train,0.956959,0.787474,0.285249,"[0.956835649406102, 0.9605411499436303]","[0.9985821348997367, 0.43380855397148677]"
3,test,0.865007,0.607937,0.503948,"[0.9323626115547206, 0.2712215320910973]","[0.9185562239703841, 0.3126491646778043]"
4,train,0.95816,0.798594,0.275739,"[0.9588459741473291, 0.9398963730569948]","[0.9976503949767065, 0.46181262729124234]"
4,test,0.92259,0.635217,0.431444,"[0.9273827534039334, 0.7280701754385965]","[0.9928273947246645, 0.19809069212410502]"
5,train,0.95467,0.790076,0.27285,"[0.9592347717225461, 0.8474264705882353]","[0.9932752683816083, 0.4694501018329939]"
5,test,0.838009,0.579599,0.544394,"[0.929642166344294, 0.2115702479338843]","[0.8896344285053216, 0.3054892601431981]"
