In [None]:
import random
from esper.prelude import *
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.temporal_predicates import *
from esper.rekall import *
import matplotlib.pyplot as plt
import cv2
import pickle

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler

from collections import OrderedDict
import scannertools as st

import esper.shot_detection_torch.models.deepsbd_resnet as deepsbd_resnet
import esper.shot_detection_torch.models.deepsbd_alexnet as deepsbd_alexnet
import esper.shot_detection_torch.dataloaders.movies_deepsbd as movies_deepsbd_data

In [None]:
st.init_storage(os.environ['BUCKET'])

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Construct five folds

In [None]:
# Load up all manually annotated shots
shots_qs = Shot.objects.filter(labeler__name__contains='manual')

In [None]:
shots = VideoIntervalCollection.from_django_qs(shots_qs)

In [None]:
video_ids = sorted(list(shots.get_allintervals().keys()))

In [None]:
random.seed(0)

In [None]:
# randomly shuffle video IDs
random.shuffle(video_ids)

In [None]:
# construct five folds: walk the shuffled videos and close the current fold
# whenever adding the next video would push the running shot count past the
# next multiple of total_shots / NUM_FOLDS, so folds hold roughly equal
# numbers of shots while never splitting a video across folds.
NUM_FOLDS = 5
total_shots = shots_qs.count()
folds = []
num_shots_in_folds = 0
cur_fold = []
for video_id in video_ids:
    video_num_shots = shots.get_intervallist(video_id).size()
    fold_limit = (len(folds) + 1) * total_shots / NUM_FOLDS
    # The `cur_fold` guard prevents emitting an empty fold when the very first
    # video alone is larger than one fold's share of the shots.
    if cur_fold and num_shots_in_folds + video_num_shots > fold_limit:
        folds.append(cur_fold)
        cur_fold = []
    num_shots_in_folds += video_num_shots
    cur_fold.append(video_id)
folds.append(cur_fold)

# Every video lands in exactly one fold.
assert sorted(v for fold_videos in folds for v in fold_videos) == sorted(video_ids)

In [None]:
# store folds
# NOTE(review): this overwrites the saved split every time the cell runs; if
# the folds above have changed, the split behind earlier results is lost.
with open('/app/data/shot_detection_folds.pkl', 'wb') as f:
    pickle.dump(folds, f)

In [None]:
# or load folds from disk
# pickle.load can execute code embedded in the file; only use it on this
# locally written, trusted file.
with open('/app/data/shot_detection_folds.pkl', 'rb') as f:
    folds = pickle.load(f)

In [None]:
# store shot intervals in pickle file
# Writes {video_id: [(start, end, payload), ...]}, one tuple per manually
# annotated shot interval of that video.
with open('/app/data/manually_annotated_shots.pkl', 'wb') as f:
    pickle.dump({
        video_id: [
            (interval.start, interval.end, interval.payload)
            for interval in shots.get_intervallist(video_id).get_intervals()
        ]
        for video_id in shots.get_allintervals()
    }, f)

# Heuristic Evaluation

In [None]:
clips = shots.dilate(1).coalesce().dilate(-1)

In [None]:
cinematic_shots_qs = Shot.objects.filter(cinematic=True, video_id__in=video_ids).all()
cinematic_shots = VideoIntervalCollection.from_django_qs(
    cinematic_shots_qs,
    progress = True
).filter_against(clips, predicate=overlaps())

In [None]:
cinematic_shot_boundaries = cinematic_shots.map(lambda i: (i.start, i.start, i.payload)).set_union(
    cinematic_shots.map(lambda i: (i.end + 1, i.end + 1, i.payload))
).coalesce()
gt_shot_boundaries = shots.map(lambda i: (i.start, i.start, i.payload)).set_union(
    shots.map(lambda i: (i.end + 1, i.end + 1, i.payload))
).coalesce()

In [None]:
# Score the cinematic-shot heuristic against the manual annotations, per fold.
#   precision = predicted boundaries overlapping a human boundary / all predicted
#   recall    = human boundaries overlapped by a prediction / all human boundaries
heuristic_fold_f1s = []
for fold_idx, fold in enumerate(folds, start=1):
    tp = 0     # predicted boundaries that match a human boundary
    fp = 0     # predicted boundaries with no matching human boundary
    found = 0  # human boundaries matched by some prediction
    fn = 0     # human boundaries that no prediction matches

    for video_id in fold:
        cine_sb = cinematic_shot_boundaries.get_intervallist(video_id)
        gt_sb = gt_shot_boundaries.get_intervallist(video_id)

        accurate_sb = cine_sb.filter_against(gt_sb, predicate=overlaps())
        inaccurate_sb = cine_sb.minus(accurate_sb)

        found_human_sb = gt_sb.filter_against(cine_sb, predicate=overlaps())
        missed_human_sb = gt_sb.minus(found_human_sb)

        tp += accurate_sb.size()
        fp += inaccurate_sb.size()
        found += found_human_sb.size()
        fn += missed_human_sb.size()

    # Recall is counted on the human side (found / (found + fn)) so it stays
    # correct even if one predicted interval overlaps several human ones.
    # Empty denominators give 0 instead of raising ZeroDivisionError.
    precision = tp / (tp + fp) if tp + fp else 0.
    recall = found / (found + fn) if found + fn else 0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.
    heuristic_fold_f1s.append(f1)

    print('Fold {}'.format(fold_idx))
    print('Precision: {}, {} out of {}'.format(precision, tp, tp + fp))
    print('Recall: {}, {} out of {}'.format(recall, found, found + fn))
    print('F1: {}'.format(f1))
    print()

if heuristic_fold_f1s:
    print('Average F1: {}'.format(sum(heuristic_fold_f1s) / len(heuristic_fold_f1s)))

Heuristics:
* Fold 1:
  * Precision: 0.8512396694214877, 103 out of 121
  * Recall: 0.8373983739837398, 103 out of 123
  * F1: 0.8442622950819672
* Fold 2:
  * Precision: 0.948051948051948, 73 out of 77
  * Recall: 0.7448979591836735, 73 out of 98
  * F1: 0.8342857142857143
* Fold 3:
  * Precision: 0.8829787234042553, 166 out of 188
  * Recall: 0.9431818181818182, 166 out of 176
  * F1: 0.9120879120879122
* Fold 4:
  * Precision: 0.8571428571428571, 78 out of 91
  * Recall: 0.7878787878787878, 78 out of 99
  * F1: 0.8210526315789474
* Fold 5:
  * Precision: 0.9090909090909091, 110 out of 121
  * Recall: 0.8396946564885496, 110 out of 131
  * F1: 0.873015873015873

Average F1: .857

In [None]:
# Heuristic, window version
# Windows of `window_size` frames are laid out every `stride` frames.
stride = 8
window_size = 16
# Merge adjacent shots into clips, then widen each clip onto the stride grid:
#   start -> rounded down to a multiple of `stride`, then one more stride earlier
#   end   -> the next multiple of `stride` strictly above end
# (For stride > 0, Python's (x - stride) % stride equals x % stride.)
# The trailing dilate/coalesce/dilate merges widened clips that now overlap.
# NOTE(review): a clip starting before frame `stride` gets a negative start
# (e.g. start=3 -> -8), which yields windows at negative frame indices —
# confirm whether that is intended or should be clamped to 0.
clips_window = shots.dilate(1).coalesce().dilate(-1).map(
    lambda intrvl: (
        intrvl.start - stride - ((intrvl.start - stride) % stride),
        intrvl.end + stride - ((intrvl.end - stride) % stride),
        intrvl.payload
    )
).dilate(1).coalesce().dilate(-1)

In [None]:
# Build overlapping windows (f, f + window_size) every `stride` frames inside
# each widened clip, all initially labelled 0 (no boundary).
items_intrvls = {}
for video_id in clips_window.get_allintervals():
    items_intrvls[video_id] = []
    for intrvl in clips_window.get_intervallist(video_id).get_intervals():
        items_intrvls[video_id] += [
            (f, f + window_size, 0)
            for f in range(intrvl.start, intrvl.end - stride, stride)
        ]
items_col = VideoIntervalCollection(items_intrvls)

# Windows that (presumably, per rekall's during_inv) contain a ground-truth
# boundary get label 2.
items_w_gt_boundaries = items_col.filter_against(
    gt_shot_boundaries,
    predicate=during_inv()
).map(
    lambda intrvl: (intrvl.start, intrvl.end, 2)
)

# Replace each positive window's label-0 copy with its label-2 copy, so every
# window appears once with label 0 or 2.
items_w_gt_labels = items_col.minus(
    items_w_gt_boundaries, predicate=equal()
).set_union(items_w_gt_boundaries)

# Same labelling, using the heuristic (cinematic) boundaries instead.
items_w_cinematic_boundaries = items_col.filter_against(
    cinematic_shot_boundaries,
    predicate=during_inv()
).map(
    lambda intrvl: (intrvl.start, intrvl.end, 2)
)

items_w_cinematic_labels = items_col.minus(
    items_w_cinematic_boundaries, predicate=equal()
).set_union(items_w_cinematic_boundaries)

In [None]:
# Score the heuristic on the shared window grid, per fold. Windows carry
# label 2 when they contain a shot boundary and 0 otherwise; 2 is positive.
window_fold_f1s = []
for fold_idx, fold in enumerate(folds, start=1):
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    for video_id in fold:
        cine_items = items_w_cinematic_labels.get_intervallist(video_id).get_intervals()
        gt_items = items_w_gt_labels.get_intervallist(video_id).get_intervals()

        # The two label sets are compared position by position. They are built
        # from the same window grid, but zip() would silently truncate or
        # misalign if that ever stopped holding, so check it explicitly.
        assert len(cine_items) == len(gt_items), video_id
        for cine_item, gt_item in zip(cine_items, gt_items):
            assert (cine_item.start, cine_item.end) == (gt_item.start, gt_item.end), video_id
            if cine_item.payload == gt_item.payload:
                if cine_item.payload == 2:
                    tp += 1
                else:
                    tn += 1
            else:
                if cine_item.payload == 2:
                    fp += 1
                else:
                    fn += 1

    # Empty denominators give 0 instead of raising ZeroDivisionError.
    precision = tp / (tp + fp) if tp + fp else 0.
    recall = tp / (tp + fn) if tp + fn else 0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.
    window_fold_f1s.append(f1)

    print('Fold {}'.format(fold_idx))
    print('Precision: {}, {} out of {}'.format(precision, tp, tp + fp))
    print('Recall: {}, {} out of {}'.format(recall, tp, tp + fn))
    print('F1: {}'.format(f1))
    print('TP: {} TN: {} FP: {} FN: {}'.format(tp, tn, fp, fn))
    print()

if window_fold_f1s:
    print('Average F1: {}'.format(sum(window_fold_f1s) / len(window_fold_f1s)))

```
Precision: 0.8916666666666667, 214 out of 240
Recall: 0.8629032258064516, 214 out of 248
F1: 0.8770491803278689
TP: 214 TN: 1321 FP: 26 FN: 34

Precision: 0.972027972027972, 139 out of 143
Recall: 0.7473118279569892, 139 out of 186
F1: 0.844984802431611
TP: 139 TN: 328 FP: 4 FN: 47

Precision: 0.8919667590027701, 322 out of 361
Recall: 0.9817073170731707, 322 out of 328
F1: 0.9346879535558781
TP: 322 TN: 2297 FP: 39 FN: 6

Precision: 0.8802395209580839, 147 out of 167
Recall: 0.8032786885245902, 147 out of 183
F1: 0.84
TP: 147 TN: 900 FP: 20 FN: 36

Precision: 0.9184549356223176, 214 out of 233
Recall: 0.852589641434263, 214 out of 251
F1: 0.8842975206611571
TP: 214 TN: 1148 FP: 19 FN: 37

Average F1: 0.876
```

# DeepSBD Evaluation

In [None]:
# helper functions for deepsbd testing
def calculate_accuracy(outputs, targets):
    """Top-1 accuracy of a batch.

    Args:
        outputs: per-row class scores; the top-scoring column is the prediction.
        targets: class index per row.

    Returns:
        Fraction (float) of rows whose top-1 column equals the target.
    """
    top1 = outputs.topk(1, 1, True)[1].view(-1)
    hits = top1.eq(targets.view(-1)).float().sum().item()
    return hits / targets.size(0)

def prf1_array(pos_label, neg_label, gt, preds):
    """Binary precision / recall / F1 of `preds` against `gt` for one class.

    An element is positive when it equals ``pos_label``; every other value
    (including labels that are neither ``pos_label`` nor ``neg_label``, e.g. a
    third class) counts as negative. This avoids counting e.g. truth=1,
    pred=0 as a false negative of ``pos_label``. ``neg_label`` is kept for
    backward compatibility and is not otherwise used.

    Args:
        pos_label: label value treated as the positive class.
        neg_label: unused.
        gt: iterable of ground-truth labels.
        preds: iterable of predicted labels, paired with `gt` by position.

    Returns:
        (precision, recall, f1, tp, tn, fp, fn). Counts are floats; each ratio
        is 0 when its denominator is 0.
    """
    tp = 0.
    fp = 0.
    tn = 0.
    fn = 0.

    for truth, pred in zip(gt, preds):
        truth_pos = truth == pos_label
        pred_pos = pred == pos_label
        if truth_pos and pred_pos:
            tp += 1.
        elif pred_pos:
            fp += 1.
        elif truth_pos:
            fn += 1.
        else:
            tn += 1.

    precision = tp / (tp + fp) if tp + fp != 0 else 0
    recall = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0

    return (precision, recall, f1, tp, tn, fp, fn)

def get_label(res_tensor):
    """Argmax of each row of a score tensor, as a list of numpy integers.

    The tensor is detached and copied to host memory before the argmax.
    """
    scores = res_tensor.data.cpu().numpy()
    return [np.argmax(scores_row) for scores_row in scores]

def test_deepsbd(model, dataloader):
    """Run `model` over `dataloader` and print cut-detection P/R/F1.

    Class 2 is scored as the positive class; every other predicted class is
    collapsed to 0 before scoring. Inference runs under ``torch.no_grad()`` so
    no autograd graph is built for the evaluation pass. The caller is
    responsible for putting `model` in eval mode and on ``device``.

    Args:
        model: network called as ``model(clip_tensor.to(device))``.
        dataloader: yields ``(clip_tensor, labels, _)`` batches with integer
            class labels.

    Returns:
        (preds, labels, outputs): collapsed predicted labels, ground-truth
        labels, and the raw model output rows as nested Python lists.
    """
    preds = []
    labels = []
    outputs = []
    with torch.no_grad():
        for clip_tensor, l, _ in tqdm(dataloader):
            o = model(clip_tensor.to(device))

            preds += get_label(o)
            labels += l.data.numpy().tolist()
            outputs += o.cpu().data.numpy().tolist()

    preds = [2 if p == 2 else 0 for p in preds]

    precision, recall, f1, tp, tn, fp, fn = prf1_array(2, 0, labels, preds)
    print("Precision: {}, Recall: {}, F1: {}".format(precision, recall, f1))
    print("TP: {}, TN: {}, FP: {}, FN: {}".format(tp, tn, fp, fn))

    return preds, labels, outputs

In [None]:
# Load DeepSBD datasets for each fold: one dataset per fold, built from that
# fold's manually annotated shots.
deepsbd_datasets = []
for fold_video_ids in folds:
    fold_shots = VideoIntervalCollection.from_django_qs(
        Shot.objects.filter(labeler__name__contains='manual', video_id__in=fold_video_ids)
    )
    deepsbd_datasets.append(
        movies_deepsbd_data.DeepSBDDataset(fold_shots, verbose=True)
    )

In [None]:
# dataset to hold multiple folds
class DeepSBDTrainDataset(Dataset):
    """Concatenation of several per-fold DeepSBD datasets into one Dataset.

    Index ``i`` addresses the wrapped datasets back to back, in list order.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        """Return item ``idx`` of the concatenation.

        Negative indices count from the end. Out-of-range indices raise
        IndexError. Returning None here instead would make plain iteration
        (``for item in dataset``) run forever, because Python's sequence
        iteration stops only on IndexError.
        """
        if idx < 0:
            idx += len(self)
        if idx < 0:
            raise IndexError('DeepSBDTrainDataset index out of range')
        for d in self.datasets:
            if idx < len(d):
                return d[idx]
            idx -= len(d)
        raise IndexError('DeepSBDTrainDataset index out of range')

    def weights_for_balanced_classes(self):
        """Per-item sampling weights that equalize the classes.

        An item's class is ``item[3]`` of each dataset's ``items``, and its
        weight is ``total_items / items_in_its_class``. Weights are returned in
        concatenated ``items`` order (assumes ``d[i]`` corresponds to
        ``d.items[i]`` — TODO confirm).
        """
        labels = [
            item[3]
            for d in self.datasets
            for item in d.items
        ]

        class_counts = {}
        for label in labels:
            class_counts[label] = class_counts.get(label, 0) + 1

        weights_per_class = {
            label: len(labels) / count
            for label, count in class_counts.items()
        }

        return [weights_per_class[label] for label in labels]

In [None]:
# models: DeepSBD with an AlexNet backbone, and DeepSBD with a ResNet-18
# backbone configured with num_classes=3, sample_size=128, sample_duration=16.
deepsbd_alexnet_model = deepsbd_alexnet.deepSBD()
deepsbd_resnet_model = deepsbd_resnet.resnet18(
    num_classes=3,
    sample_size=128,
    sample_duration=16,
)

In [None]:
# alexnet deepSBD pre-trained on ClipShots
# The checkpoint keys carry a 'module.' prefix (presumably from saving an
# nn.DataParallel wrapper). Strip it only where present; slicing 7 characters
# off every key unconditionally would mangle keys without the prefix.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies values onto the model's own device either way.
alexnet_state_dict = torch.load(
    'models/ClipShots-DeepSBD-Alexnet-final.pth', map_location='cpu')['state_dict']
new_state_dict = OrderedDict()
for k, v in alexnet_state_dict.items():
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
deepsbd_alexnet_model.load_state_dict(new_state_dict)

In [None]:
# resnet deepSBD pre-trained on ClipShots
# Strip the 'module.' key prefix (presumably from an nn.DataParallel save)
# only from keys that actually carry it. map_location='cpu' lets a GPU-saved
# checkpoint load without CUDA; load_state_dict copies values onto the
# model's own device.
resnet_state_dict = torch.load(
    'models/ClipShots-DeepSBD-Resnet-18-final.pth', map_location='cpu')['state_dict']
new_state_dict = OrderedDict()
for k, v in resnet_state_dict.items():
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
deepsbd_resnet_model.load_state_dict(new_state_dict)
deepsbd_resnet_model = deepsbd_resnet_model.train()

In [None]:
# resnet deepSBD pre-trained on Kinetics
# Same ResNet-18 configuration as the ClipShots model, initialized from the
# Kinetics checkpoint via the project's load_weights helper.
deepsbd_resnet_model_no_clipshots = deepsbd_resnet.resnet18(
    num_classes=3,
    sample_size=128,
    sample_duration=16
)
deepsbd_resnet_model_no_clipshots.load_weights('models/resnet-18-kinetics.pth')

In [None]:
# alexnet deepSBD
# NOTE(review): randomly initialized AlexNet DeepSBD; nothing in this
# notebook trains or evaluates it afterwards.
deepsbd_alexnet_model_no_clipshots = deepsbd_alexnet.deepSBD()

In [None]:
# Move the Kinetics-initialized ResNet to the compute device in training mode.
deepsbd_resnet_model_no_clipshots = deepsbd_resnet_model_no_clipshots.to(device).train()

In [None]:
# Training data = folds 1-4 (deepsbd_datasets[:4]), i.e. fold 5 is held out.
# NOTE(review): the `fold1` names below do not match the held-out fold.
training_dataset_fold1 = DeepSBDTrainDataset(deepsbd_datasets[:4])

In [None]:
# One sampling weight per training item, inversely proportional to the
# frequency of that item's class.
fold1_weights = torch.DoubleTensor(training_dataset_fold1.weights_for_balanced_classes())

In [None]:
# Class-balanced sampler drawing len(fold1_weights) indices per epoch.
fold1_sampler = torch.utils.data.sampler.WeightedRandomSampler(fold1_weights, len(fold1_weights))

In [None]:
# shuffle=False because ordering comes from the sampler.
training_dataloader_fold1 = DataLoader(
    training_dataset_fold1,
    num_workers=0,
    shuffle=False,
    batch_size=16,
    sampler=fold1_sampler
)

In [None]:
criterion = nn.CrossEntropyLoss()

In [None]:
optimizer = optim.SGD(deepsbd_resnet_model.parameters(), 
                      lr=.001, momentum=.9, weight_decay=1e-3)

In [None]:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=60000)

In [None]:
def train_epoch(epoch, training_dataloader, model, criterion, optimizer, scheduler,
                fold_num=5,
                save_dir='/app/notebooks/learning/models/deepsbd_resnet_clipshots_pretrain_train_on_folds'):
    """Train `model` for one epoch on class-index targets, then checkpoint it.

    After every batch, prints the loss, top-1 accuracy and batch-level
    precision/recall/F1 (class 2 positive; other predicted classes are
    collapsed to 0).

    Args:
        epoch: epoch index, used in log lines and in the checkpoint filename.
        training_dataloader: yields ``(clip_tensor, targets, _)`` batches.
        model: network being trained; inputs and targets go to ``device``.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: stepped once per batch.
        scheduler: accepted for interface compatibility but never stepped
            here. NOTE(review): call ``scheduler.step(...)`` if LR scheduling
            is wanted.
        fold_num: fold number written into the checkpoint filename. Defaults
            to 5 to keep the existing 'fold5_<epoch>_epoch.pth' names.
        save_dir: checkpoint directory; created if it does not exist.

    Side effects:
        Saves ``{save_dir}/fold{fold_num}_{epoch}_epoch.pth`` holding the next
        epoch index, the model state dict and the optimizer state dict.
    """
    iter_len = len(training_dataloader)

    for i, (clip_tensor, targets, _) in enumerate(training_dataloader):
        outputs = model(clip_tensor.to(device))
        targets = targets.to(device)

        loss = criterion(outputs, targets)
        # Batch metrics use the outputs computed before this step's update.
        acc = calculate_accuracy(outputs, targets)
        preds = get_label(outputs)
        preds = [2 if p == 2 else 0 for p in preds]
        precision, recall, f1, tp, tn, fp, fn = prf1_array(
            2, 0, targets.cpu().data.numpy().tolist(), preds)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss_conf {loss_c:.4f}\t'
              'acc {acc:.4f}\t'
              'pre {pre:.4f}\t'
              'rec {rec:.4f}\t'
              'f1 {f1: .4f}\t'
              'TP {tp} '
              'TN {tn} '
              'FP {fp} '
              'FN {fn} '
              .format(
                  epoch, i + 1, iter_len, loss_c=loss.item(), acc=acc,
                  pre=precision, rec=recall, f1=f1,
                  tp=tp, tn=tn, fp=fp, fn=fn))

    os.makedirs(save_dir, exist_ok=True)
    save_file_path = os.path.join(
        save_dir,
        'fold{}_{}_epoch.pth'.format(fold_num, epoch)
    )
    states = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(states, save_file_path)

In [None]:
state = torch.load('models/deepsbd_resnet_train_on_folds/fold4_4_epoch.pth')

In [None]:
deepsbd_resnet_model_no_clipshots.load_state_dict(state['state_dict'])

In [None]:
for i in range(5):
    train_epoch(i, training_dataloader_fold1, deepsbd_resnet_model, criterion, optimizer, scheduler)

In [None]:
# specialize pre-trained model

In [None]:
# test models on splits: evaluate the ClipShots-pretrained ResNet on every
# fold; each entry is the (preds, labels, outputs) tuple from test_deepsbd.
model = deepsbd_resnet_model.to(device).eval()
per_fold_preds_labels_outputs = [
    test_deepsbd(
        model,
        DataLoader(fold_dataset, batch_size=8, shuffle=False, num_workers=0),
    )
    for fold_dataset in deepsbd_datasets
]

In [None]:
# test models on splits: evaluate the ClipShots-pretrained AlexNet on every fold.
model = deepsbd_alexnet_model.to(device).eval()
per_fold_preds_labels_outputs_alexnet = []
for fold_dataset in deepsbd_datasets:
    dataloader = DataLoader(fold_dataset, batch_size=8, shuffle=False, num_workers=0)
    preds, labels, outputs = test_deepsbd(model, dataloader)

    # Store in the AlexNet list; appending to per_fold_preds_labels_outputs
    # would mix AlexNet results into the ResNet results.
    per_fold_preds_labels_outputs_alexnet.append((preds, labels, outputs))

In [None]:
# Evaluate the fold-trained ResNet on the held-out fold 5 (deepsbd_datasets[4:]).
# test_deepsbd sends every batch to `device`, so the model must be there too.
model = deepsbd_resnet_model.to(device).eval()
per_fold_preds_labels_outputs_fold_training_only = []
for fold_dataset in deepsbd_datasets[4:]:
    dataloader = DataLoader(fold_dataset, batch_size=8, shuffle=False, num_workers=0)
    preds, labels, outputs = test_deepsbd(model, dataloader)

    per_fold_preds_labels_outputs_fold_training_only.append((preds, labels, outputs))

In [None]:
# NOTE(review): `model` is whatever the previous cell bound (the fine-tuned
# deepsbd_resnet_model if it ran). load_weights presumably overwrites its
# weights in place with the Kinetics checkpoint, discarding the fine-tuned
# weights. This cell also resets
# per_fold_preds_labels_outputs_fold_training_only and evaluates fold 1 only
# (deepsbd_datasets[:1]). Confirm this is intended.
model.load_weights('models/resnet-18-kinetics.pth')
per_fold_preds_labels_outputs_fold_training_only = []
for fold_dataset in deepsbd_datasets[:1]:
    dataloader = DataLoader(fold_dataset, batch_size=8, shuffle=False, num_workers=0)
    preds, labels, outputs = test_deepsbd(model, dataloader)
    
    per_fold_preds_labels_outputs_fold_training_only.append((preds, labels, outputs))

DeepSBD, ResNet18 backbone trained on ClipShots:
* Fold 1
  * Precision: 0.8636363636363636, Recall: 0.9620253164556962, F1: 0.9101796407185629
  * TP: 228.0, TN: 1322.0, FP: 36.0, FN: 9.0
* Fold 2
  * Precision: 0.8934010152284264, Recall: 0.9617486338797814, F1: 0.9263157894736842
  * TP: 176.0, TN: 314.0, FP: 21.0, FN: 7.0
* Fold 3
  * Precision: 0.7666666666666667, Recall: 0.8263473053892215, F1: 0.7953890489913544
  * TP: 276.0, TN: 2246.0, FP: 84.0, FN: 58.0
* Fold 4
  * Precision: 0.8960396039603961, Recall: 1.0, F1: 0.9451697127937337
  * TP: 181.0, TN: 901.0, FP: 21.0, FN: 0.0
* Fold 5
  * Precision: 0.8571428571428571, Recall: 0.9831932773109243, F1: 0.9158512720156555
  * TP: 234.0, TN: 1141.0, FP: 39.0, FN: 4.0

Average F1: .899

DeepSBD, AlexNet backbone trained on ClipShots:
* Fold 1
  * Precision: 0.8507462686567164, Recall: 0.9620253164556962, F1: 0.902970297029703
  * TP: 228.0, TN: 1318.0, FP: 40.0, FN: 9.0
* Fold 2
  * Precision: 0.912568306010929, Recall: 0.912568306010929, F1: 0.912568306010929
  * TP: 167.0, TN: 319.0, FP: 16.0, FN: 16.0
* Fold 3
  * Precision: 0.7818696883852692, Recall: 0.8263473053892215, F1: 0.8034934497816594
  * TP: 276.0, TN: 2253.0, FP: 77.0, FN: 58.0
* Fold 4
  * Precision: 0.9782608695652174, Recall: 0.994475138121547, F1: 0.9863013698630136
  * TP: 180.0, TN: 918.0, FP: 4.0, FN: 1.0
* Fold 5
  * Precision: 0.8669201520912547, Recall: 0.957983193277311, F1: 0.9101796407185628
  * TP: 228.0, TN: 1145.0, FP: 35.0, FN: 10.0
  
Average F1: .903
  
DeepSBD, ResNet18 backbone trained on folds only:
* Fold 1
  * Precision: 0.7737226277372263, Recall: 0.8945147679324894, F1: 0.8297455968688846
  * TP: 212.0, TN: 1296.0, FP: 62.0, FN: 25.0
* Fold 2
  * Precision: 0.8165680473372781, Recall: 0.7540983606557377, F1: 0.7840909090909091
  * TP: 138.0, TN: 304.0, FP: 31.0, FN: 45.0
* Fold 3
  * Precision: 0.7407407407407407, Recall: 0.718562874251497, F1: 0.7294832826747719
  * TP: 240.0, TN: 2246.0, FP: 84.0, FN: 94.0
* Fold 4
  * Precision: 0.7990196078431373, Recall: 0.9005524861878453, F1: 0.8467532467532468
  * TP: 163.0, TN: 881.0, FP: 41.0, FN: 18.0
* Fold 5
  * Precision: 0.8057851239669421, Recall: 0.819327731092437, F1: 0.8125
  * TP: 195.0, TN: 1133.0, FP: 47.0, FN: 43.0
  
Average F1: .801

DeepSBD, ResNet18 backbone pre-trained on ClipShots, and then trained on folds:
* Fold 1
  * Precision: 0.7482758620689656, Recall: 0.9156118143459916, F1: 0.823529411764706
  * TP: 217.0, TN: 1285.0, FP: 73.0, FN: 20.0
* Fold 2
  * Precision: 0.8685714285714285, Recall: 0.8306010928961749, F1: 0.8491620111731845
  * TP: 152.0, TN: 312.0, FP: 23.0, FN: 31.0
* Fold 3
  * Precision: 0.8092105263157895, Recall: 0.7365269461077845, F1: 0.7711598746081504
  * TP: 246.0, TN: 2272.0, FP: 58.0, FN: 88.0
* Fold 4
  * Precision: 0.9344262295081968, Recall: 0.9447513812154696, F1: 0.9395604395604397
  * TP: 171.0, TN: 910.0, FP: 12.0, FN: 10.0
* Fold 5
  * Precision: 0.8771186440677966, Recall: 0.8697478991596639, F1: 0.8734177215189872
  * TP: 207.0, TN: 1151.0, FP: 29.0, FN: 31.0
  
Average F1: .851

## Weak Labels

### K folds

In [None]:
# Load DeepSBD datasets for each fold, this time with preload=True and
# logits=True, built from each fold's manually annotated shots.
deepsbd_datasets_logits = []
for fold_video_ids in folds:
    fold_shots = VideoIntervalCollection.from_django_qs(
        Shot.objects.filter(labeler__name__contains='manual', video_id__in=fold_video_ids)
    )
    deepsbd_datasets_logits.append(
        movies_deepsbd_data.DeepSBDDataset(fold_shots, verbose=True, preload=True, logits=True)
    )

In [None]:
# Peek at the first few items of fold 1 instead of dumping the whole list.
deepsbd_datasets_logits[0].items[:10]

In [None]:
# load weak labels. Rows hold nested Python objects (indexed below as
# row[0][0], row[0][1], row[0][2] and row[1]), i.e. presumably an object
# array. np.load refuses object arrays unless allow_pickle=True (the default
# has been False since NumPy 1.16.3). Unpickling can execute code from the
# file, so only do this for this trusted, locally generated file.
with open('/app/data/shot_detection_weak_labels/noisy_labels_all_windows.npy', 'rb') as f:
    weak_labels_windows = np.load(f, allow_pickle=True)

In [None]:
# Peek at the first few weak-label rows.
weak_labels_windows[:10]

In [None]:
# First field of the first row: the grouping key used below (row[0][0],
# presumably the video id).
weak_labels_windows[0][0][0]

In [None]:
# Group weak-label rows by row[0][0] (presumably the video id). `collect`
# comes from a star import; assumed to return {key: [rows with that key]}.
weak_labels_collected = collect(
    weak_labels_windows,
    lambda row: row[0][0]
)

In [None]:
# Per-video interval collection: interval [row[0][1], row[0][2]] carrying
# row[1] (presumably the weak label scores) as its payload.
weak_labels_col = VideoIntervalCollection({
    video_id: [
        (row[0][1] ,row[0][2], row[1])
        for row in weak_labels_collected[video_id]
    ]
    for video_id in tqdm(list(weak_labels_collected.keys()))
})

In [None]:
def weak_payload_to_logits(weak_payload):
    """Expand a two-entry weak-label payload into a 3-class target tuple.

    Payload entry 1 becomes class 0, entry 0 becomes class 2, and class 1 is
    always 0.
    """
    class_0 = weak_payload[1]
    class_2 = weak_payload[0]
    middle = 0.
    return class_0, middle, class_2

In [None]:
# Relabel each fold's windows with weak labels.
# NOTE(review): `dataset` is the same object held in deepsbd_datasets_logits,
# and its `items` are replaced in place, so afterwards both lists alias the
# same weakly labelled datasets. `items_col` here also shadows the window
# collection of the same name from the heuristic section.
deepsbd_datasets_weak = []
for dataset in deepsbd_datasets_logits:
    # Group items (video_id, start, end, label) by video id.
    items_collected = collect(
        dataset.items,
        lambda item: item[0]
    )
    items_col = VideoIntervalCollection({
        video_id: [
            (item[1], item[2], item[3])
            for item in items_collected[video_id]
        ]
        for video_id in items_collected
    })
    
    # Keep only windows that have a weak-label interval with exactly the same
    # bounds (predicate=equal()); merge_op keeps the weak interval, so the
    # resulting payload is the weak label.
    new_items = weak_labels_col.join(
        items_col,
        predicate=equal(),
        working_window=1,
        merge_op = lambda weak, item: [weak]
    )
    
    # Rebuild items in sorted video-id order, expanding the weak payload into
    # a 3-class target tuple.
    dataset.items = [
        (video_id, intrvl.start, intrvl.end, weak_payload_to_logits(intrvl.payload))
        for video_id in sorted(list(new_items.get_allintervals().keys()))
        for intrvl in new_items.get_intervallist(video_id).get_intervals()
    ]
    deepsbd_datasets_weak.append(dataset)

In [None]:
# dataset to hold multiple folds for weak data
class DeepSBDWeakTrainDataset(Dataset):
    """Concatenation of several weak-label DeepSBD datasets into one Dataset.

    Index ``i`` addresses the wrapped datasets back to back, in list order.
    Each item's ``item[3]`` is a tuple of per-class scores, not a class index.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        """Return item ``idx`` of the concatenation.

        Negative indices count from the end. Out-of-range indices raise
        IndexError, which is what Python's sequence iteration relies on to
        stop. Returning None here would make ``for item in dataset`` never
        terminate.
        """
        if idx < 0:
            idx += len(self)
        if idx < 0:
            raise IndexError('DeepSBDWeakTrainDataset index out of range')
        for d in self.datasets:
            if idx < len(d):
                return d[idx]
            idx -= len(d)
        raise IndexError('DeepSBDWeakTrainDataset index out of range')

    def weights_for_balanced_classes(self):
        """Per-item sampling weights that equalize the (argmax) weak classes.

        An item's class is the argmax of its ``item[3]`` score tuple, and its
        weight is ``total_items / items_in_its_class``. Only classes that occur
        are counted, so this does not depend on knowing the number of classes.
        (A count table sized by ``len(items[0])`` measures the item tuple, not
        the class count.) Weights are returned in concatenated ``items`` order
        (assumes ``d[i]`` corresponds to ``d.items[i]`` — TODO confirm).
        """
        labels = [
            np.argmax(item[3])
            for d in self.datasets
            for item in d.items
        ]

        class_counts = {}
        for label in labels:
            class_counts[label] = class_counts.get(label, 0) + 1

        weights_per_class = {
            label: len(labels) / count
            for label, count in class_counts.items()
        }

        return [weights_per_class[label] for label in labels]

In [None]:
# resnet deepSBD pre-trained on Kinetics
# ResNet-18 DeepSBD with num_classes=3, sample_size=128, sample_duration=16.
deepsbd_resnet_model_no_clipshots = deepsbd_resnet.resnet18(
    num_classes=3,
    sample_size=128,
    sample_duration=16
)

In [None]:
# Initialize from the Kinetics checkpoint via the project's load_weights helper.
deepsbd_resnet_model_no_clipshots.load_weights('models/resnet-18-kinetics.pth')

In [None]:
# Move to the compute device in training mode.
deepsbd_resnet_model_no_clipshots = deepsbd_resnet_model_no_clipshots.to(device).train()

In [None]:
# Single-fold setup: train on weak folds 2-5 (deepsbd_datasets_weak[1:]),
# holding out fold 1. NOTE(review): the K-fold cell further down rebuilds all
# of these objects per fold.
training_dataset_fold1 = DeepSBDWeakTrainDataset(deepsbd_datasets_weak[1:])

In [None]:
# One sampling weight per item, inversely proportional to its class frequency.
fold1_weights = torch.DoubleTensor(training_dataset_fold1.weights_for_balanced_classes())

In [None]:
# Class-balanced sampler drawing len(fold1_weights) indices per epoch.
fold1_sampler = torch.utils.data.sampler.WeightedRandomSampler(fold1_weights, len(fold1_weights))

In [None]:
# shuffle=False because ordering comes from the sampler.
training_dataloader_fold1 = DataLoader(
    training_dataset_fold1,
    num_workers=0,
    shuffle=False,
    batch_size=16,
    sampler=fold1_sampler
)

In [None]:
# Weak targets are per-class score tuples, hence a logits/BCE loss rather
# than cross-entropy on class indices.
criterion = nn.BCEWithLogitsLoss()

In [None]:
optimizer = optim.SGD(deepsbd_resnet_model_no_clipshots.parameters(), 
                      lr=.001, momentum=.9, weight_decay=1e-3)

In [None]:
# NOTE(review): never stepped by train_epoch, so this has no effect as written.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=60000)

In [None]:
# helper functions for deepsbd testing
def calculate_accuracy_logits(outputs, targets):
    """Top-1 accuracy when targets are per-class scores rather than indices.

    Both `outputs` and `targets` are reduced to their top-scoring column per
    row, and the fraction (float) of rows where the two agree is returned.
    """
    predicted = outputs.topk(1, 1, True)[1].view(-1)
    expected = targets.topk(1, 1, True)[1].view(-1)
    matches = (predicted == expected).float().sum().item()
    return matches / targets.size(0)

def prf1_array(pos_label, neg_label, gt, preds):
    """Binary precision / recall / F1 of `preds` against `gt` for one class.

    An element is positive when it equals ``pos_label``; every other value
    (including labels that are neither ``pos_label`` nor ``neg_label``) counts
    as negative, so e.g. truth=1, pred=0 is a true negative, not a false
    negative. ``neg_label`` is kept for backward compatibility and is not
    otherwise used.

    Args:
        pos_label: label value treated as the positive class.
        neg_label: unused.
        gt: iterable of ground-truth labels.
        preds: iterable of predicted labels, paired with `gt` by position.

    Returns:
        (precision, recall, f1, tp, tn, fp, fn). Counts are floats; each ratio
        is 0 when its denominator is 0.
    """
    tp = 0.
    fp = 0.
    tn = 0.
    fn = 0.

    for truth, pred in zip(gt, preds):
        truth_pos = truth == pos_label
        pred_pos = pred == pos_label
        if truth_pos and pred_pos:
            tp += 1.
        elif pred_pos:
            fp += 1.
        elif truth_pos:
            fn += 1.
        else:
            tn += 1.

    precision = tp / (tp + fp) if tp + fp != 0 else 0
    recall = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0

    return (precision, recall, f1, tp, tn, fp, fn)

def get_label(res_tensor):
    """Row-wise argmax of a score tensor, returned as a list of numpy integers.

    The tensor is detached and copied to host memory first.
    """
    host_scores = res_tensor.data.cpu().numpy()
    return [np.argmax(one_row) for one_row in host_scores]

def test_deepsbd(model, dataloader):
    """Run `model` over a weak-label `dataloader` and print cut-detection P/R/F1.

    Targets arrive as per-class score tuples and are reduced to their argmax
    class. Class 2 is scored as positive; every other predicted class is
    collapsed to 0. Inference runs under ``torch.no_grad()`` so no autograd
    graph is built during evaluation. The caller is responsible for putting
    `model` in eval mode and on ``device``.

    Returns:
        (preds, labels, outputs): collapsed predicted labels, argmax target
        labels, and the raw model output rows as nested Python lists.
    """
    preds = []
    labels = []
    outputs = []
    with torch.no_grad():
        for clip_tensor, l, _ in tqdm(dataloader):
            o = model(clip_tensor.to(device))
            # Targets are collated into a per-class sequence of batch tensors
            # (presumably); stack + transpose gives one row per sample.
            l = torch.transpose(torch.stack(l).to(device), 0, 1).float()

            preds += get_label(o)
            labels += get_label(l)
            outputs += o.cpu().data.numpy().tolist()

    preds = [2 if p == 2 else 0 for p in preds]

    precision, recall, f1, tp, tn, fp, fn = prf1_array(2, 0, labels, preds)
    print("Precision: {}, Recall: {}, F1: {}".format(precision, recall, f1))
    print("TP: {}, TN: {}, FP: {}, FN: {}".format(tp, tn, fp, fn))

    return preds, labels, outputs

In [None]:
def train_epoch(epoch, training_dataloader, model, criterion, optimizer, scheduler, fold_num=1,
                save_dir='/app/notebooks/learning/models/deepsbd_resnet_clipshots_pretrain_train_on_folds_weak'):
    """Train `model` for one epoch on weak (per-class score) targets, then checkpoint it.

    After every batch, prints the loss, top-1 accuracy and batch-level
    precision/recall/F1 (class 2 positive; other classes collapsed to 0).

    Args:
        epoch: epoch index, used in log lines and in the checkpoint filename.
        training_dataloader: yields ``(clip_tensor, targets, _)`` batches whose
            targets are per-class score tuples.
        model: network being trained; inputs and targets go to ``device``.
        criterion: loss called as ``criterion(outputs, targets)`` with float
            (batch, classes) targets.
        optimizer: stepped once per batch.
        scheduler: accepted for interface compatibility but never stepped here.
        fold_num: fold number written into the checkpoint filename.
        save_dir: checkpoint directory; created if it does not exist.
            NOTE(review): the default name says "clipshots_pretrain", but the
            K-fold loop trains a Kinetics-initialized model; it is kept
            unchanged so existing checkpoint paths still resolve.

    Side effects:
        Saves ``{save_dir}/fold{fold_num}_{epoch}_epoch.pth`` holding the next
        epoch index, the model state dict and the optimizer state dict.
    """
    iter_len = len(training_dataloader)

    for i, (clip_tensor, targets, _) in enumerate(training_dataloader):
        outputs = model(clip_tensor.to(device))
        # Targets are collated into a per-class sequence of batch tensors
        # (presumably); stack + transpose gives a (batch, classes) float tensor.
        targets = torch.transpose(torch.stack(targets).to(device), 0, 1).float()

        loss = criterion(outputs, targets)
        acc = calculate_accuracy_logits(outputs, targets)
        preds = get_label(outputs)
        preds = [2 if p == 2 else 0 for p in preds]
        target_preds = get_label(targets)
        precision, recall, f1, tp, tn, fp, fn = prf1_array(
            2, 0, target_preds, preds)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss_conf {loss_c:.4f}\t'
              'acc {acc:.4f}\t'
              'pre {pre:.4f}\t'
              'rec {rec:.4f}\t'
              'f1 {f1: .4f}\t'
              'TP {tp} '
              'TN {tn} '
              'FP {fp} '
              'FN {fn} '
              .format(
                  epoch, i + 1, iter_len, loss_c=loss.item(), acc=acc,
                  pre=precision, rec=recall, f1=f1,
                  tp=tp, tn=tn, fp=fp, fn=fn))

    os.makedirs(save_dir, exist_ok=True)
    save_file_path = os.path.join(
        save_dir,
        'fold{}_{}_epoch.pth'.format(fold_num, epoch)
    )
    states = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(states, save_file_path)

In [None]:
# train K folds: for each fold i, train a fresh Kinetics-initialized ResNet on
# the weak labels of all other folds, checkpointing after every epoch.
NUM_EPOCHS = 5
num_folds = len(deepsbd_datasets_weak)
for i in range(num_folds):
    training_datasets = DeepSBDWeakTrainDataset(
        deepsbd_datasets_weak[:i] + deepsbd_datasets_weak[i+1:])
    fold_weights = torch.DoubleTensor(training_datasets.weights_for_balanced_classes())
    fold_sampler = torch.utils.data.sampler.WeightedRandomSampler(fold_weights, len(fold_weights))

    training_dataloader = DataLoader(
        training_datasets,
        num_workers=0,
        shuffle=False,
        batch_size=16,
        sampler=fold_sampler
    )

    criterion = nn.BCEWithLogitsLoss()

    # reset model: build a brand-new network for every fold instead of only
    # re-loading the Kinetics weights into the previous fold's network. Any
    # parameter the Kinetics checkpoint does not overwrite (NOTE(review):
    # presumably at least the 3-class output layer — confirm what
    # load_weights restores) would otherwise keep weights trained on data
    # that includes this fold's held-out videos.
    deepsbd_resnet_model_no_clipshots = deepsbd_resnet.resnet18(
        num_classes=3,
        sample_size=128,
        sample_duration=16
    )
    deepsbd_resnet_model_no_clipshots.load_weights('models/resnet-18-kinetics.pth')
    deepsbd_resnet_model_no_clipshots = deepsbd_resnet_model_no_clipshots.to(device).train()
    optimizer = optim.SGD(deepsbd_resnet_model_no_clipshots.parameters(),
                          lr=.001, momentum=.9, weight_decay=1e-3)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=60000)

    for epoch in range(NUM_EPOCHS):
        train_epoch(
            epoch, training_dataloader,
            deepsbd_resnet_model_no_clipshots,
            criterion, optimizer, scheduler, fold_num=i + 1)

In [None]:
per_fold_preds_labels_outputs_fold_training_only = []
# Index of the last checkpointed epoch of the 5-epoch K-fold training (0-4).
FINAL_EPOCH = 4
for i in range(len(deepsbd_datasets_weak)):
    # load fold i+1's final checkpoint. map_location keeps this working on a
    # CPU-only machine even if the checkpoint holds GPU tensors.
    # NOTE(review): this relative path assumes the working directory is
    # /app/notebooks/learning, where the training checkpoints are written.
    weights = torch.load(
        os.path.join(
            'models/deepsbd_resnet_clipshots_pretrain_train_on_folds_weak',
            'fold{}_{}_epoch.pth'.format(i + 1, FINAL_EPOCH)),
        map_location=device)['state_dict']
    deepsbd_resnet_model_no_clipshots.load_state_dict(weights)
    deepsbd_resnet_model_no_clipshots = deepsbd_resnet_model_no_clipshots.to(device).eval()
    # Evaluate only on the fold that was held out of this checkpoint's training.
    test_dataset = deepsbd_datasets_weak[i]
    dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)
    preds, labels, outputs = test_deepsbd(deepsbd_resnet_model_no_clipshots, dataloader)

    per_fold_preds_labels_outputs_fold_training_only.append((preds, labels, outputs))

```
Precision: 0.7669491525423728, Recall: 0.8190045248868778, F1: 0.7921225382932167
TP: 181.0, TN: 1319.0, FP: 55.0, FN: 40.0

Precision: 0.45294117647058824, Recall: 0.8369565217391305, F1: 0.5877862595419847
TP: 77.0, TN: 333.0, FP: 93.0, FN: 15.0

Precision: 0.7121771217712177, Recall: 0.6225806451612903, F1: 0.6643717728055077
TP: 193.0, TN: 2276.0, FP: 78.0, FN: 117.0

Precision: 0.7078651685393258, Recall: 0.7455621301775148, F1: 0.7262247838616714
TP: 126.0, TN: 882.0, FP: 52.0, FN: 43.0

Precision: 0.7053140096618358, Recall: 0.7564766839378239, F1: 0.73
TP: 146.0, TN: 1164.0, FP: 61.0, FN: 47.0

Average F1: 0.70
```

### Whole movies

In [None]:
# same as above, except train on whole movies

### 100 movies

In [None]:
# train on 100 movies

### All movies

In [None]:
# train on all movies

# DSM Evaluation

In [None]:
# adaptive filtering

In [None]:
# dataloaders

In [None]:
# model

In [None]:
# load pre-loaded model

In [None]:
# train from scratch

In [None]:
# specialize pre-trained model