In [1]:
import torch
from torch import nn

import torchvision
from torchvision.datasets import ImageFolder

from torchvision import transforms

from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.models import vgg16

In [2]:
import sys
sys.path.append("..")

In [3]:
from video_classification.datasets import FolderOfFrameFoldersDataset

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [5]:
# Project checkout on the training machine; the data directory is resolved relative to it.
ROOT = Path("/home/ubuntu/SupervisedVideoClassification")
# The "/" operator on a Path already returns a Path, so no extra wrapping is needed.
DATA_ROOT = ROOT / "data"

In [6]:
# Per-channel statistics passed to Normalize in both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std that go with
# pretrained torchvision models — confirm.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random augmentation, then tensor conversion and normalization.
_train_steps = [
    # NOTE(review): ColorJitter is called with no arguments; check torchvision's
    # defaults, this presumably applies no jitter at all — confirm.
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: no augmentation, only tensor conversion and normalization.
_valid_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
valid_transforms = transforms.Compose(_valid_steps)

In [7]:
# One dataset per split, each with its own transform pipeline.
train_dir = DATA_ROOT / 'train'
valid_dir = DATA_ROOT / 'validation'
train_ds = FolderOfFrameFoldersDataset(train_dir, transform=train_transforms)
valid_ds = FolderOfFrameFoldersDataset(valid_dir, transform=valid_transforms)

In [8]:
train_ds

FolderOfFrameFoldersDataset with 26711 samples.
	Overall data distribution: {'negative': 24747, 'positive': 1964}

In [9]:
valid_ds

FolderOfFrameFoldersDataset with 4751 samples.
	Overall data distribution: {'negative': 4332, 'positive': 419}

In [10]:
class SingleImageModel(nn.Module):
    """Single-frame classifier: a pretrained VGG16 backbone followed by an MLP head.

    The ImageNet output layer of VGG16 is removed, and the remaining network
    feeds ``self.clf``, a stack of ``nn.Linear`` layers. All VGG parameters
    start frozen; call :meth:`unfreeze_vgg` to fine-tune part of the VGG
    classifier.

    Args:
        mlp_sizes: Output width of each ``nn.Linear`` in the head, in order.
            The last entry is the number of output logits. Every linear layer
            except the last is followed by ReLU -> BatchNorm1d -> Dropout(p=0.3).

    Raises:
        ValueError: If ``mlp_sizes`` is empty.
    """

    VGG_FEATURE_DIM = 4096  # vgg feats: input width of the first head layer

    def __init__(self, mlp_sizes=(768, 128, 2)):
        super().__init__()
        mlp_sizes = list(mlp_sizes)
        if not mlp_sizes:
            raise ValueError("mlp_sizes must contain at least one layer size")

        self.vgg = vgg16(pretrained=True)
        # Remove imagenet output layer. The sliced classifier stays wrapped in an
        # outer nn.Sequential on purpose: that nesting fixes the parameter names
        # ("vgg.classifier.0.<idx>...") that saved state dicts of this model use.
        self.vgg.classifier = nn.Sequential(self.vgg.classifier[:-1])

        layers = []
        in_features = self.VGG_FEATURE_DIM
        last_index = len(mlp_sizes) - 1
        for index, out_features in enumerate(mlp_sizes):
            layers.append(nn.Linear(in_features, out_features))
            if index != last_index:
                # Hidden layers only: the final Linear emits raw logits.
                layers.append(nn.ReLU())
                layers.append(nn.BatchNorm1d(out_features))
                layers.append(nn.Dropout(p=0.3))
            in_features = out_features

        self.clf = nn.Sequential(*layers)
        self.freeze_vgg()

    def forward(self, x):
        """Run ``x`` through the VGG backbone and then the MLP head; returns the head's output."""
        x = self.vgg(x)
        x = self.clf(x)
        return x

    def freeze_vgg(self):
        """Disable gradients for every parameter of the VGG backbone."""
        for p in self.vgg.parameters():
            p.requires_grad = False

    def unfreeze_vgg(self):
        """Enable gradients for the VGG classifier layers after the first one.

        Training the whole VGG is a no-go, so only the classifier part is
        trained. ``self.vgg.classifier`` is an outer ``nn.Sequential`` whose
        single child (index 0) holds the actual VGG layers, so the layers must
        be reached through ``classifier[0]``. Slicing the outer container with
        ``[1:]`` yields an empty module and would leave everything frozen.
        """
        vgg_layers = self.vgg.classifier[0]
        for p in vgg_layers[1:].parameters():
            p.requires_grad = True

In [11]:
# Build the classifier with head sizes [1024, 256, 2] and move it onto `device` in one step.
model = SingleImageModel(mlp_sizes=[1024, 256, 2]).to(device)

In [12]:
from video_classification.trainer import Trainer

# Per-class loss weights, indexed by class id.
# NOTE(review): presumably index 0 is 'negative' and index 1 is 'positive', so the
# majority class is down-weighted — confirm against the dataset's class ordering.
_weights_cpu = torch.Tensor([0.3, 1.0])
classes_weights = _weights_cpu.to(device)
criterion = nn.CrossEntropyLoss(weight=classes_weights)

In [14]:
# NOTE(review): the two string arguments are presumably the experiment name and the
# checkpoint directory — confirm against Trainer's signature.
experiment_name = "single_frame_vgg"
checkpoint_dir = str(ROOT/'checkpoints')

trainer = Trainer(
    train_ds,
    valid_ds,
    model,
    criterion,
    experiment_name,
    checkpoint_dir,
    device=device,
    amp_opt_level="O1",
)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [15]:
# Hyperparameters for this run, gathered in one place so they are easy to tweak.
train_config = {
    "lr": 1e-3,
    "batch_size": 48,
    "n_epochs": 7,
    "gradient_accumulation_steps": 4,
    "num_workers": 8,
    "max_gradient_norm": 2.0,
}
trainer.train(**train_config)



HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 1: Avg accuracy: 0.93 |Precision: 0.96, 0.50 |Recall: 0.96, 0.45 | F1: 0.72 | Avg loss: 0.35
Validation Results - Epoch: 1: Avg accuracy: 0.82 |Precision: 0.93, 0.20 |Recall: 0.86, 0.36 | F1: 0.58 | Avg loss: 0.56


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Training Results - Epoch: 2: Avg accuracy: 0.95 |Precision: 0.95, 0.92 |Recall: 1.00, 0.38 | F1: 0.76 | Avg loss: 0.31
Validation Results - Epoch: 2: Avg accuracy: 0.88 |Precision: 0.93, 0.32 |Recall: 0.93, 0.33 | F1: 0.63 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 3: Avg accuracy: 0.95 |Precision: 0.95, 0.96 |Recall: 1.00, 0.36 | F1: 0.75 | Avg loss: 0.31
Validation Results - Epoch: 3: Avg accuracy: 0.92 |Precision: 0.93, 0.58 |Recall: 0.98, 0.24 | F1: 0.65 | Avg loss: 0.41


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 4: Avg accuracy: 0.95 |Precision: 0.95, 0.97 |Recall: 1.00, 0.37 | F1: 0.75 | Avg loss: 0.32
Validation Results - Epoch: 4: Avg accuracy: 0.91 |Precision: 0.93, 0.49 |Recall: 0.98, 0.19 | F1: 0.61 | Avg loss: 0.48


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 5: Avg accuracy: 0.95 |Precision: 0.96, 0.82 |Recall: 0.99, 0.44 | F1: 0.77 | Avg loss: 0.29
Validation Results - Epoch: 5: Avg accuracy: 0.91 |Precision: 0.93, 0.44 |Recall: 0.97, 0.27 | F1: 0.64 | Avg loss: 0.46


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 6: Avg accuracy: 0.96 |Precision: 0.96, 0.96 |Recall: 1.00, 0.42 | F1: 0.78 | Avg loss: 0.28
Validation Results - Epoch: 6: Avg accuracy: 0.91 |Precision: 0.93, 0.48 |Recall: 0.97, 0.29 | F1: 0.66 | Avg loss: 0.42


HBox(children=(IntProgress(value=0, max=557), HTML(value='')))

Training Results - Epoch: 7: Avg accuracy: 0.95 |Precision: 0.96, 0.94 |Recall: 1.00, 0.41 | F1: 0.77 | Avg loss: 0.29
Validation Results - Epoch: 7: Avg accuracy: 0.91 |Precision: 0.93, 0.51 |Recall: 0.98, 0.24 | F1: 0.64 | Avg loss: 0.45


In [16]:
import pandas as pd

# Flatten the nested {outer: {inner: values}} mapping in trainer.epoch_state into
# {(outer, inner): values}. The tuple keys become a two-level column index, and
# the transpose turns each (outer, inner) pair into one row.
reform = dict(
    ((outer_key, inner_key), values)
    for outer_key, inner_dict in trainer.epoch_state.items()
    for inner_key, values in inner_dict.items()
)
pd.DataFrame(reform).T

Unnamed: 0,Unnamed: 1,accuracy,f1,nll,precision,recall
1,train,0.926547,0.715988,0.349815,"[0.9563754356447542, 0.5005720823798627]","[0.964722996726876, 0.4455193482688391]"
1,test,0.817512,0.577148,0.557422,"[0.9330167458135467, 0.20133333333333334]","[0.8617266851338874, 0.360381861575179]"
2,train,0.952005,0.75577,0.310784,"[0.9528698807272166, 0.9241293532338308]","[0.9975350547541116, 0.37830957230142565]"
2,test,0.879815,0.629151,0.47519,"[0.9347976878612717, 0.3215962441314554]","[0.9332871652816251, 0.3269689737470167]"
3,train,0.951818,0.749244,0.311438,"[0.9516402279377791, 0.9580514208389715]","[0.9987473229078272, 0.3604887983706721]"
3,test,0.917912,0.647611,0.409342,"[0.9303493449781659, 0.5847953216374269]","[0.9836103416435826, 0.2386634844868735]"
4,train,0.952641,0.753875,0.321788,"[0.9521333949476278, 0.9703903095558546]","[0.9991110033539419, 0.36710794297352345]"
4,test,0.910966,0.614743,0.481181,"[0.9262813522355507, 0.4879518072289157]","[0.9803785780240074, 0.19331742243436753]"
5,train,0.95163,0.773135,0.29255,"[0.9570877343415053, 0.818785578747628]","[0.9922818927546774, 0.4394093686354379]"
5,test,0.905915,0.642591,0.463544,"[0.9319546364242829, 0.4448818897637795]","[0.9674515235457064, 0.26968973747016706]"


In [16]:
trainer.evaluator.state.metrics

{'accuracy': 0.7371079772679436,
 'nll': 0.5739234951709953,
 'precision': [0.9280755345737295, 0.1391304347826087],
 'recall': [0.7714681440443213, 0.3818615751789976],
 'f1': 0.5232539857186665}

In [17]:
trainer.evaluator.state.metrics

{'accuracy': 0.9122290044201221,
 'nll': 0.661812626883949,
 'precision': [0.9121920404295641, 1.0],
 'recall': [1.0, 0.00477326968973747],
 'f1': 0.48179056739542064}

In [1]:
!nvidia-smi

Sun Jul  7 18:21:20 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.40.04    Driver Version: 418.40.04    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   41C    P0    27W / 300W |      0MiB / 16130MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

# Error analysis

In [None]:
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
best_checkpoint_path = ROOT/'checkpoints/single_frame_vgg_SingleImageModel_6_f1=0.6585821.pth'
# map_location remaps the stored tensors onto the device selected for this session,
# so a checkpoint written on a GPU machine still loads when only the CPU is available.
best_model_weights = torch.load(str(best_checkpoint_path), map_location=device)
model.load_state_dict(best_model_weights)

In [27]:
from tqdm import tqdm
import numpy as np

# shuffle=False keeps predictions aligned with dataset indices, which the
# error analysis below relies on to look samples up by position.
valid_loader = DataLoader(valid_ds, batch_size=48, shuffle=False)
y_probs = []
y_true = []

# Put the model in evaluation mode for inference: the head contains Dropout and
# BatchNorm1d layers, which behave differently (random masking, batch statistics)
# in training mode. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for frames, labels in tqdm(valid_loader):
            logits = model(frames.to(device))
            # Softmax over the class dimension turns logits into per-class probabilities.
            y_probs.extend(torch.softmax(logits, dim=-1).cpu().tolist())
            y_true.extend(labels.tolist())
finally:
    model.train(was_training)

y_probs = np.array(y_probs)
y_true = np.array(y_true)
# Predicted class = index of the highest probability in each row.
y_pred = np.argmax(y_probs, 1)

100%|██████████| 99/99 [01:23<00:00,  1.18it/s]


In [44]:
np.stack([y_pred, y_true], 1).shape

(4751, 2)

array([0, 0, 0, ..., 0, 0, 0])

In [68]:
# Sample indices grouped by confusion-matrix cell (class 1 is treated as "positive").
false_positives, false_negatives = [], []
true_positives, true_negatives = [], []

# (prediction, label) -> list that collects the indices with that outcome.
_bucket_for = {
    (1, 1): true_positives,
    (0, 0): true_negatives,
    (1, 0): false_positives,
    (0, 1): false_negatives,
}

for idx, outcome in enumerate(zip(y_pred, y_true)):
    bucket = _bucket_for.get(outcome)
    # Any (prediction, label) pair outside the four cells is ignored.
    if bucket is not None:
        bucket.append(idx)

In [53]:
# Rebuild the validation set with no transform so indexing returns the sample
# without tensor conversion or normalization (presumably a viewable image — confirm).
# NOTE(review): this rebinds `valid_ds`; re-running the prediction cell after this
# one would feed the model untransformed samples.
raw_validation_dir = DATA_ROOT / 'validation'
valid_ds = FolderOfFrameFoldersDataset(raw_validation_dir, transform=None)

In [72]:
true_positives

[1444,
 1448,
 1596,
 1647,
 1648,
 1649,
 1650,
 1651,
 1652,
 1653,
 1654,
 1655,
 1656,
 1657,
 1658,
 1659,
 1660,
 1661,
 1662,
 1663,
 1664,
 1665,
 1666,
 1667,
 1668,
 1669,
 1670,
 1671,
 1672,
 1673,
 1674,
 1675,
 1676,
 1677,
 1678,
 1679,
 1680,
 1681,
 1682,
 1683,
 1684,
 1685,
 1686,
 1687,
 1688,
 1689,
 1690,
 1691,
 1692,
 1693,
 1694,
 1695,
 1696,
 1697,
 1698,
 1699,
 1700,
 1701,
 1702,
 1703,
 1704,
 1705,
 1706,
 1707,
 1708,
 1709,
 1710,
 1711,
 1712,
 1713,
 1714,
 1715,
 1716,
 1717,
 1718,
 1719,
 1720,
 1721,
 1722,
 1723,
 1724,
 1725,
 1726,
 1727,
 1728,
 1729,
 1730,
 1731,
 1732,
 1733,
 1734,
 1735,
 1736,
 1737,
 1738,
 1739,
 1740,
 1741,
 1742,
 1743,
 1744,
 1745,
 1746,
 1747,
 1748,
 1749,
 1750,
 1751,
 1752,
 1753,
 1754,
 1755,
 1756,
 1757,
 1758,
 1759,
 2001,
 2002,
 2003,
 2004,
 2005,
 2006,
 2007,
 2008,
 2174,
 2175,
 2176,
 2177,
 2178,
 2179,
 2180,
 2269,
 2270,
 2271,
 2272,
 2273,
 2274,
 2275,
 2276,
 2277,
 2278]

In [None]:
valid_ds[2274][0]

In [85]:
false_negatives

[157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 382,
 383,
 384,
 385,
 386,
 387,
 388,
 389,
 390,
 391,
 392,
 393,
 394,
 395,
 998,
 999,
 1000,
 1001,
 1002,
 1003,
 1004,
 1005,
 1006,
 1007,
 1008,
 1009,
 1162,
 1163,
 1164,
 1165,
 1166,
 1167,
 1168,
 1169,
 1170,
 1171,
 1372,
 1373,
 1374,
 1375,
 1376,
 1377,
 1378,
 1379,
 1443,
 1445,
 1446,
 1447,
 1449,
 1450,
 1451,
 1591,
 1592,
 1593,
 1594,
 1595,
 1597,
 1598,
 2606,
 2607,
 2608,
 2609,
 2610,
 2611,
 2612,
 2613,
 2614,
 2615,
 2616,
 2617,
 2618,
 2619,
 2620,
 2621,
 2622,
 2623,
 2624,
 2625,
 2626,
 2627,
 2628,
 2629,
 2630,
 2631,
 2632,
 2633,
 2634,
 2635,
 2636,
 2637,
 2638,
 2639,
 2640,
 2641,
 2642,
 2643,
 2644,
 2645,
 2973,
 2974,
 2975,
 2976,
 2977,
 2978,
 2979,
 2980,
 2981,
 2982,
 2983,
 2984,
 2985,
 2986,
 2987,
 2988,
 2989,
 2990,
 2991,
 2992,
 2993,
 2994,
 2995,
 2996,
 2997,
 2998,
 2999,
 3000,
 3001,
 3002,
 3003,
 3004,
 3005,
 3006,
 3

In [None]:
valid_ds[163][0]

In [None]:
valid_ds[1000][0]

In [None]:
valid_ds[1163][0]

In [None]:
valid_ds[1448][0]