In [None]:
from esper.prelude import *
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.temporal_predicates import overlaps
from esper.rekall import *
import matplotlib.pyplot as plt
import cv2

# Load up Ground Truth

In [None]:
# Load up small ground truth set for training.
# Selects shots whose labeler name contains 'manual' from five hand-picked
# videos. Video 123 is additionally restricted to shots whose max_frame is at
# most 16560; the other four videos contribute all of their manual shots.
shots_gt_training_qs = Shot.objects.filter(
    Q(video_id=123, labeler__name__contains='manual', max_frame__lte=16560) | # easy
    Q(video_id=172, labeler__name__contains='manual') | # hard
    Q(video_id=179, labeler__name__contains='manual') | # easy
    Q(video_id=104, labeler__name__contains='manual') |
    Q(video_id=148, labeler__name__contains='manual')
)

In [None]:
# Every manually labeled shot in the database, including the training shots;
# the training shots are subtracted when the collection is built below.
shots_gt_test_qs = Shot.objects.filter(labeler__name__contains='manual')

In [None]:
# Materialize the training queryset as a VideoIntervalCollection.
shots_gt_training = VideoIntervalCollection.from_django_qs(shots_gt_training_qs)

In [None]:
# Held-out ground truth: all manual shots minus the training shots.
# NOTE(review): presumably minus() removes the temporal extent of the training
# intervals rather than matching whole shots -- confirm against rekall.
shots_gt_test = VideoIntervalCollection.from_django_qs(shots_gt_test_qs).minus(shots_gt_training)

In [None]:
# Visualize the ground truth (training shots only) in the esper widget,
# with captions disabled.
esper_widget(intrvllists_to_result(shots_gt_training), jupyter_keybindings=True, disable_captions=True)

# Evaluate Baselines

## Load up Shots from Heuristics

In [None]:
# Figure out temporal extents of the clips that were labeled
# NOTE(review): dilate(1).coalesce().dilate(-1) presumably merges shots that
# touch or are one frame apart into a single clip and then restores the
# original outer extents -- confirm against the rekall documentation.
clips_training = shots_gt_training.dilate(1).coalesce().dilate(-1)
clips_test = shots_gt_test.dilate(1).coalesce().dilate(-1)

In [None]:
# Load every Shot row flagged cinematic=True into a collection. No video
# filter is applied here, so this covers the whole Shot table.
cinematic_shots_qs = Shot.objects.filter(cinematic=True).all()
cinematic_shots = VideoIntervalCollection.from_django_qs(
    cinematic_shots_qs,
    progress = True
)

In [None]:
# Keep only the heuristic shots that satisfy overlaps() against a labeled
# clip, separately for the training and test clips.
cinematic_shots_training = cinematic_shots.filter_against(
    clips_training,
    predicate=overlaps()
)
cinematic_shots_test = cinematic_shots.filter_against(
    clips_test,
    predicate=overlaps()
)

In [None]:
# Reduce each shot to a boundary: an interval that starts and ends on the
# shot's first frame. The payload is carried over unchanged.
cinematic_shot_boundaries_training = cinematic_shots_training.map(lambda i: (i.start, i.start, i.payload))
cinematic_shot_boundaries_test = cinematic_shots_test.map(lambda i: (i.start, i.start, i.payload))
gt_shot_boundaries_training = shots_gt_training.map(lambda i: (i.start, i.start, i.payload))
gt_shot_boundaries_test = shots_gt_test.map(lambda i: (i.start, i.start, i.payload))

In [None]:
def size(interval_collection):
    """Return the total number of intervals across all videos in the collection.

    Args:
        interval_collection: Object exposing get_allintervals() (iterable of
            video ids) and get_intervallist(video_id), whose result has size().

    Returns:
        Sum of the per-video interval list sizes; 0 for an empty collection.
    """
    return sum(
        interval_collection.get_intervallist(vid).size()
        for vid in interval_collection.get_allintervals()
    )

In [None]:
def print_per_video_precision_recall(gt_shot_boundaries, eval_shot_boundaries):
    """Print shot-boundary precision and recall for every ground-truth video.

    For each video id in ``gt_shot_boundaries``:
      * precision = evaluated boundaries that overlap a ground-truth boundary,
        divided by all evaluated boundaries;
      * recall = ground-truth boundaries that overlap an evaluated boundary,
        divided by all ground-truth boundaries.

    A ratio whose denominator is zero is reported as 0.0 instead of raising.

    Args:
        gt_shot_boundaries: VideoIntervalCollection of ground-truth boundaries.
        eval_shot_boundaries: VideoIntervalCollection of boundaries to score.
            It is queried with get_intervallist() for every ground-truth video id.
    """
    def ratio(numerator, denominator):
        # A video with no predictions (or no labels) must not abort the report.
        return numerator / denominator if denominator else 0.

    for video_id in gt_shot_boundaries.get_allintervals():
        print("Video {}: ".format(video_id))
        cine_sb = VideoIntervalCollection({
            video_id: eval_shot_boundaries.get_intervallist(video_id)
        })
        gt_sb = VideoIntervalCollection({
            video_id: gt_shot_boundaries.get_intervallist(video_id)
        })

        # Predicted boundaries that hit some ground-truth boundary (precision).
        accurate_sb = cine_sb.filter_against(gt_sb, predicate=overlaps())
        # Ground-truth boundaries that were hit by some prediction (recall).
        # The two counts differ whenever matches are not one-to-one.
        found_human_sb = gt_sb.filter_against(cine_sb, predicate=overlaps())

        num_eval = size(cine_sb)
        num_gt = size(gt_sb)
        num_accurate = size(accurate_sb)
        num_found = size(found_human_sb)

        print("Precision: {}, {} out of {}".format(
            ratio(num_accurate, num_eval),
            num_accurate,
            num_eval))
        print("Recall: {}, {} out of {}".format(
            ratio(num_found, num_gt),
            num_found,
            num_gt))

In [None]:
# Score the heuristic ("cinematic") shot boundaries against the training ground truth.
print_per_video_precision_recall(gt_shot_boundaries_training, cinematic_shot_boundaries_training)

In [None]:
# Same comparison on the held-out ground truth.
print_per_video_precision_recall(gt_shot_boundaries_test, cinematic_shot_boundaries_test)

In [None]:
# Visualize the discrepancies. Ground truth is in red, heuristic results are in blue.
# Only the training clips are shown here.
result = intrvllists_to_result(shots_gt_training, color='red')
add_intrvllists_to_result(result, cinematic_shots_training, color='blue')
esper_widget(result, jupyter_keybindings=True, disable_captions=True)

## Machine Learning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import time
import datetime
from tqdm import tqdm
import copy
import scannertools as st
import random

In [None]:
# Point scannertools at the storage bucket named by the BUCKET environment
# variable; raises KeyError if BUCKET is not set.
# NOTE(review): `os` is not imported explicitly in this notebook -- presumably
# it arrives through one of the star imports; confirm.
st.init_storage(os.environ['BUCKET'])

In [None]:
class ShotDetectionDataset(Dataset):
    """Frame windows labeled by whether the center frame is the start of a shot.

    Every frame covered by an input interval becomes one item. All frames
    needed for the windows are decoded and transformed once, in the
    constructor, and kept in memory.
    """

    def __init__(self, shots, window_size=1, height=224):
        """Constructor for ShotDetectionDataset.

        Args:
            shots: VideoIntervalCollection of all the intervals to get frames from. If the payload is -1,
                then the interval is not an actual shot and just needs to be included in the dataset.
            window_size: Number of frames taken on each side of the labeled frame.
            height: Currently unused; frames are always passed through
                ``transforms.Resize((100, 224))``.

        Notes:
            * Items are ``(video_id, frame, label)`` tuples kept in a set, so a
              frame covered by several intervals can appear more than once with
              different labels.
            * NOTE(review): ``intrvl.start - window_size`` can be negative for an
              interval at the very start of a video -- confirm how the frame
              loader handles that.
        """
        self.window_size = window_size
        items = set()
        frame_nums = {}

        for video_id in shots.get_allintervals():
            frame_nums[video_id] = set()
            for intrvl in shots.get_intervallist(video_id).get_intervals():
                if intrvl.end < intrvl.start:
                    # An empty interval contributes neither items nor frames.
                    continue
                # Frames needed for every window of this interval; added once
                # per interval rather than once per frame of the interval.
                frame_nums[video_id].update(
                    range(intrvl.start - window_size, intrvl.end + window_size + 1)
                )
                for f in range(intrvl.start, intrvl.end + 1):
                    items.add((
                        video_id,
                        f,
                        1 if f == intrvl.start and intrvl.payload != -1 else 0
                    ))
        self.items = sorted(items)

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((100, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Load frames into memory
        self.frames = {}
        for video_id in tqdm(frame_nums):
            sorted_frame_nums = sorted(frame_nums[video_id])
            video = Video.objects.get(id=video_id).for_scannertools()
            self.frames[video_id] = {
                'frame_nums': sorted_frame_nums,
                'frames': [self.transform(f) for f in video.frames(sorted_frame_nums)]
            }

    def __len__(self):
        """Number of labeled frames in the dataset."""
        return len(self.items)

    def __getitem__(self, idx):
        """
        Indexed by video ID, then frame number
        Returns self.window_size frames before the indexed frame to self.window_size
            frames after the indexed frame

        Returns:
            Tuple ``(img_tensors, label)`` where ``img_tensors`` is a list of
            ``2 * window_size + 1`` transformed frames and ``label`` is 0 or 1.
        """
        video_id, frame_num, label = self.items[idx]

        # The constructor stores every frame of the window for each item, so
        # the window occupies consecutive positions in the sorted frame list.
        start_index = self.frames[video_id]['frame_nums'].index(frame_num - self.window_size)
        img_tensors = self.frames[video_id]['frames'][start_index:start_index + 2*self.window_size + 1]

        return img_tensors, label

In [None]:
# construct a training set with good class balance
# Positives: one zero-length interval at the first frame of every training shot.
shot_boundaries = shots_gt_training.map(
    lambda intrvl: (intrvl.start, intrvl.start, intrvl.payload)
)
# Candidate negatives: every frame of each shot except its first frame.
shots_without_boundaries = shots_gt_training.map(
    lambda intrvl: (intrvl.start + 1, intrvl.end, intrvl.payload)
).get_allintervals()
non_boundary_frames = [
    (video_id, f)
    for video_id in shots_without_boundaries
    for intrvl in shots_without_boundaries[video_id].get_intervals()
    for f in range(intrvl.start, intrvl.end + 1)
]
random.seed(0)
random.shuffle(non_boundary_frames) # seed of 0 for reproducibility
# Take as many random negatives as there are positives, then group the
# (video_id, frame) tuples by tup[0], i.e. by video id.
# NOTE(review): `collect` comes from a star import; presumably it returns a
# dict keyed by the lambda's result -- confirm.
chosen_frames = collect(non_boundary_frames[:size(shot_boundaries)], lambda tup: tup[0])

# Final training set = positives
#   + the random negatives (payload -1)
#   + the last frame of every shot (payload -1)
#   + the frame right after every shot start (payload -1).
# A payload of -1 makes ShotDetectionDataset label the frame 0.
training_frames = shot_boundaries.set_union(
    VideoIntervalCollection({
        video_id: [
            (frame, frame, -1)
            for vid, frame in chosen_frames[video_id]
        ]
        for video_id in chosen_frames
    })
).set_union(
    shots_gt_training.map(
        lambda intrvl: (intrvl.end, intrvl.end, -1)
    )
).set_union(
    shots_gt_training.map(
        lambda intrvl: (intrvl.start+1, intrvl.start+1, -1)
    )
)

In [None]:
# Class-balanced subset of the training clips; this is what the model is fit on.
dataset_training = ShotDetectionDataset(training_frames)

In [None]:
# Shuffled mini-batches of 8 for training.
dataloader_training = DataLoader(dataset_training, batch_size=8, shuffle=True, num_workers=0)

In [None]:
# Every frame of every training shot (not just the balanced subset), used to
# check how the trained model behaves on the full training clips.
dataset_training_test = ShotDetectionDataset(shots_gt_training)

In [None]:
# Not shuffled, so results line up with dataset_training_test.items.
dataloader_training_test = DataLoader(dataset_training_test, batch_size=8, shuffle=False, num_workers=0)

In [None]:
class VideoNet(nn.Module):
    """Shot-boundary classifier over a window of three consecutive frames.

    Each frame goes through its own pretrained ResNet-18. The three outputs are
    stacked into a single-channel 2D "embedding image", passed through one
    3x3 convolution, globally average-pooled, and mapped to one raw logit by a
    linear layer. No sigmoid is applied in forward().
    """
    def __init__(self, window_size=1):
        """Build the three ResNet-18 branches and the small classification head.

        Args:
            window_size: Accepted but not used; the network is hard-wired to
                exactly three input frames.
        """
        super(VideoNet, self).__init__()
#         self.resnet1 = models.ResNet(models.resnet.BasicBlock, [1, 1, 1, 1], num_classes=128)
#         self.resnet2 = models.ResNet(models.resnet.BasicBlock, [1, 1, 1, 1], num_classes=128)
#         self.resnet3 = models.ResNet(models.resnet.BasicBlock, [1, 1, 1, 1], num_classes=128)
        # Three separate backbones: weights are not shared between frames.
        self.resnet1 = models.resnet18(pretrained=True)
        self.resnet2 = models.resnet18(pretrained=True)
        self.resnet3 = models.resnet18(pretrained=True)
    
        # Replace pooling layer with Adaptive Pooling
        self.resnet1.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.resnet2.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.resnet3.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
#         self.embeddingpool = nn.MaxPool1d(5, stride=3)
        
#         self.rnfc1 = nn.Linear(1000, 128)
#         self.rnfc2 = nn.Linear(1000, 128)
#         self.rnfc3 = nn.Linear(1000, 128)
        
        self.embeddingconv = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 1)
        # Not used by forward(), which returns raw logits.
        self.sigmoid = nn.Sigmoid()
        
    def init_weights(self):
        """Re-initialize the head only (embeddingconv and fc); the ResNets are left as loaded."""
        nn.init.kaiming_normal_(self.embeddingconv.weight, mode='fan_out', nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc.weight)
        
    def forward(self, image1, image2, image3):
        """Return one raw logit per sample for the frame triple.

        Args:
            image1, image2, image3: Batches for the previous, center and next
                frame. NOTE(review): assumed to be (batch, 3, H, W) image
                tensors -- confirm against the dataset transform.

        Returns:
            Tensor with one column (the output of ``nn.Linear(64, 1)``); a
            value above 0 is treated as "shot boundary" by the training code.
        """
        # assumes each ResNet returns (batch, D); unsqueeze gives (batch, 1, D)
        image1embedding = self.resnet1(image1).unsqueeze(1)
        image2embedding = self.resnet2(image2).unsqueeze(1)
        image3embedding = self.resnet3(image3).unsqueeze(1)
        
#         print(image1embedding.size())
        
#         embedding_image = torch.cat(
#             (self.embeddingpool(image1embedding),
#              self.embeddingpool(image2embedding),
#              self.embeddingpool(image3embedding)),
#             dim=1
#         )
        
        # Stack the three embeddings as rows: (batch, 3, D).
        embedding_image = torch.cat(
            (image1embedding,
             image2embedding,
             image3embedding),
            dim=1
        )
        
#         print(embedding_image.size())
        
        # Add the single input channel expected by embeddingconv: (batch, 1, 3, D).
        embedding_image = embedding_image.unsqueeze(1)
        
#         print(embedding_image.size())
        out = self.embeddingconv(embedding_image)
#         print(out.size())
        out = self.relu(out)
#         print(out.size())
        out = self.avgpool(out)
        # Flatten the pooled (batch, 64, 1, 1) map to (batch, 64).
        out = out.view(out.size(0), -1)
        out = self.fc(out)
#         out = nn.LogSoftmax(1)(out)
#         out = self.sigmoid(out)
#         out = F.softmax(out, dim=1)
        
        return out
    
#     def parameters(self):
#         return [self.embeddingconv.parameters(), self.fc.parameters()]

In [None]:
# Instantiate the network; this loads three pretrained ResNet-18 backbones.
vnet = VideoNet()

In [None]:
# Initialize only the new head layers (embeddingconv and fc).
vnet.init_weights()

In [None]:
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
vnet = vnet.to(device)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, dataloader=None):
    """Train ``model`` on frame triples and return it with its best weights loaded.

    Each batch is ``(inputs, labels)`` where ``inputs`` is a sequence of three
    frame batches and ``labels`` holds 0/1 targets. The model output is treated
    as a raw logit: a value above 0 counts as a predicted shot boundary.

    Args:
        model: Module called as ``model(inputs[0], inputs[1], inputs[2])``.
        criterion: Loss called on ``(1, batch)``-shaped logits and float labels.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``dataloader``.
        dataloader: Iterable of training batches. Defaults to the
            notebook-level ``dataloader_training``, looked up at call time.

    Returns:
        The same ``model`` object, loaded with the weights from the epoch that
        had the highest accuracy. That accuracy is measured on the training
        batches themselves (there is no validation phase), even though the
        final log line says "val".

    Uses the notebook-level ``device`` and ``tqdm``.
    """
    if dataloader is None:
        dataloader = dataloader_training

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    phase = 'train'  # only a training phase exists; kept for the log prefix

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        model.train()  # Set model to training mode

        running_loss = 0.0
        total_inputs = 0.0

        true_positive = 0.
        false_positive = 0.
        true_negative = 0.
        false_negative = 0.

        # Iterate over data. tqdm wraps the loader itself so it can show the total.
        for idx, (inputs, labels) in enumerate(tqdm(dataloader)):
            inputs = [i.to(device) for i in inputs]
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs = model(inputs[0], inputs[1], inputs[2])
                batch_size = labels.size(0)
                loss = criterion(outputs.view(1, batch_size), labels.float().view(1, batch_size))
                if idx == 0:
                    # One sample of raw outputs per epoch, for eyeballing progress.
                    print(outputs, labels, loss)

                loss.backward()
                optimizer.step()

            # Confusion counts for the whole batch at once, instead of one
            # host/device round trip per sample.
            predicted_positive = outputs.detach().view(-1) > 0.
            actual_positive = labels.view(-1) == 1
            batch_tp = (predicted_positive & actual_positive).sum().item()
            batch_fp = predicted_positive.sum().item() - batch_tp
            batch_fn = actual_positive.sum().item() - batch_tp
            true_positive += batch_tp
            false_positive += batch_fp
            false_negative += batch_fn
            true_negative += batch_size - batch_tp - batch_fp - batch_fn
            total_inputs += batch_size

            # statistics
            running_loss += loss.item() * inputs[0].size(0)

        # Step the schedule after the epoch's optimizer updates, so the first
        # epoch trains at the initial learning rate.
        scheduler.step()

        running_corrects = true_positive + true_negative
        # An empty loader yields zeros rather than a ZeroDivisionError.
        epoch_loss = running_loss / total_inputs if total_inputs else 0.
        epoch_acc = running_corrects / total_inputs if total_inputs else 0.
        if true_positive + false_positive != 0:
            precision = true_positive / (true_positive + false_positive)
        else:
            precision = 0.
        if true_positive + false_negative != 0:
            recall = true_positive / (true_positive + false_negative)
        else:
            recall = 0.

        print('{} Loss: {:.4f} Acc: {:.4f} Precision: {:.4f} Recall: {:.4f}'.format(
            phase, epoch_loss, epoch_acc, precision, recall))
        print('TP: {} TN: {} FP: {} FN: {}'.format(
            true_positive, true_negative, false_positive, false_negative
        ))

        # deep copy the model
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Loss on the raw logit. Positive examples are weighted 3x (pos_weight=3).
# The commented lines are earlier alternatives that were tried.
# criterion = nn.CrossEntropyLoss(weight=torch.tensor([.1, 1.]).to(device))
# criterion = nn.CrossEntropyLoss()
# criterion = nn.NLLLoss(weight=torch.tensor([.01, .99]).to(device))
# criterion = nn.NLLLoss()
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.]).to(device))

In [None]:
# SGD over all parameters, including the three pretrained ResNets.
optimizer = optim.SGD(vnet.parameters(), lr=0.01, momentum=0.9)

In [None]:
# Multiply the learning rate by 0.1 every 30 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [None]:
# train_model returns the model it was given, so `model` and `vnet` refer to
# the same object after this cell.
model = train_model(vnet, criterion, optimizer, exp_lr_scheduler, num_epochs=100)

In [None]:
def test_model(model, criterion, dataloader=dataloader_test):
    since = time.time()

    model.eval()   # Set model to evaluate mode

    running_loss = 0.0
    running_corrects = 0
    total_inputs = 0.0

    true_positive = 0.
    false_positive = 0.
    true_negative = 0.
    false_negative = 0.
    
    results = []

    # Iterate over data.
    for idx, (inputs, labels) in tqdm(enumerate(dataloader)):
        inputs = [i.to(device) for i in inputs]
        labels = labels.to(device)
                
        with torch.set_grad_enabled(False):
            outputs = model(inputs[0], inputs[1], inputs[2])
            batch_size = labels.size(0)
#             _, preds = torch.max(outputs, 1)
#             loss=criterion(outputs, labels)
            loss=criterion(outputs.view(1, batch_size), labels.float().view(1, batch_size))
#             if False:
            if idx == 0:
                print(outputs, labels, loss)
#             print(labels)
#             print(loss)
                
                
#             for p, l in zip(preds, labels):
#                 if p.item() == l.item():
#                     if l.item() == 1:
#                         true_positive += 1.
#                     else:
#                         true_negative += 1.
#                 else:
#                     if p.item() == 1:
#                         false_positive += 1.
#                     else:
#                         false_negative += 1.
#                 total_inputs += 1

            for o, l in zip(outputs, labels):
                if o.item() > 0.:
                    if l.item() == 1:
                        true_positive += 1.
                    else:
                        false_positive += 1.
                else:
                    if l.item() == 1:
                        false_negative += 1.
                    else:
                        true_negative += 1.
                total_inputs += 1
                results.append((o.item(), l.item()))

        # statistics
        running_loss += loss.item() * inputs[0].size(0)
        running_corrects = true_positive + true_negative
#     print(running_corrects, true_positive, true_negative, total_inputs)

    epoch_loss = running_loss / total_inputs #/ len(dataset)
    epoch_acc = running_corrects / total_inputs #/ len(dataset)
    if true_positive + false_positive != 0:
        precision = true_positive / (true_positive + false_positive)
    else:
        precision = 0.
    if true_positive + false_negative != 0:
        recall = true_positive / (true_positive + false_negative)
    else:
        recall = 0.

    print('Loss: {:.4f} Acc: {:.4f} Precision: {:.4f} Recall: {:.4f}'.format(
        epoch_loss, epoch_acc, precision, recall))
    print('TP: {} TN: {} FP: {} FN: {}'.format(
        true_positive, true_negative, false_positive, false_negative
    ))
    
    return results

In [None]:
# Held-out frames. The dataset constructor decodes every frame it needs into
# memory, so this cell can be slow and memory-hungry.
dataset_test = ShotDetectionDataset(shots_gt_test)

In [None]:
# Not shuffled, so results line up with dataset_test.items.
dataloader_test = DataLoader(dataset_test, batch_size=8, shuffle=False, num_workers=0)

In [None]:
# Evaluate on the held-out set (test_model's default loader).
test_results = test_model(model, criterion)

In [None]:
# Evaluate on every frame of the training clips.
training_test_results = test_model(model, criterion, dataloader_training_test)

In [None]:
# Split the positive predictions on the training clips into true and false
# positives. Results and dataset items are in the same order because the
# loader is not shuffled. A prediction is positive when the raw output is
# strictly above 0, the same threshold test_model uses for its counts.
true_positives = []
false_positives = []
for (output, label), item in zip(training_test_results, dataset_training_test.items):
    if output > 0 and label == 1:
        true_positives.append((output, label, item))
    if output > 0 and label == 0:
        false_positives.append((output, label, item))

In [None]:
# Group true positives by video id (tup[2] is the dataset item, whose first
# element is the video id) and turn each into a zero-length interval at its
# frame number (item[1]) with payload 0.
tp_collected = collect(true_positives, lambda tup: tup[2][0])
true_positive_intrvls = VideoIntervalCollection({
    video_id: [
        (item[1], item[1], 0)
        for output, label, item in tp_collected[video_id]
    ]
    for video_id in tp_collected
})

In [None]:
# Same construction for the false positives.
fp_collected = collect(false_positives, lambda tup: tup[2][0])
false_positive_intrvls = VideoIntervalCollection({
    video_id: [
        (item[1], item[1], 0)
        for output, label, item in fp_collected[video_id]
    ]
    for video_id in fp_collected
})

In [None]:
# Browse the true positives; the callback supplies no objects to draw.
esper_widget(
    intrvllists_to_result_with_objects(true_positive_intrvls, lambda a, b: []),
    jupyter_keybindings=True
)

In [None]:
# Browse the false positives.
# NOTE(review): this call passes display_captions=False while earlier cells in
# this notebook pass disable_captions=True -- confirm which keyword
# esper_widget actually honors.
esper_widget(
    intrvllists_to_result_with_objects(false_positive_intrvls, lambda a, b: []),
    jupyter_keybindings=True,
    display_captions=False
)

In [None]:
# Checkpoints from four separate experiments (see the Notes section below).
# Each cell was run once, after training with the configuration named in the
# file. Running all four against the same `model` writes the same weights
# under four names, so run only the cell that matches the current training
# configuration.
# torch.save(model, ...) pickles the whole module object, not just a state_dict.
torch.save(model, '2-5-19_529pm_videonet_1to1classbalance_bcewithlogitsloss.pth')

In [None]:
torch.save(model, '2-6-19_948am_videonet_10to1classbalance_bcewithlogitsloss.pth')

In [None]:
torch.save(model, '2-6-19_1016am_videonet_2to1classbalance_bcewithlogitsloss.pth')

In [None]:
torch.save(model, '2-6-19_5pm_videonet_3to1classbalance_bcewithlogitsloss.pth')

# Notes

## Model/loss: raw output of last FC layer to BCEWithLogitsLoss

Training with perfectly balanced classes - selected 58 positive examples from training dataset and randomly selected 58 negative examples:
* Achieved 100% accuracy on train.
* On test, precision/recall at 26.7%/24.4%. Confusion matrix `TP: 139.0 TN: 52399.0 FP: 381.0 FN: 430.0`. Output of model had absolute value < 0.5.
* Saved in `2-5-19_529pm_videonet_1to1classbalance_bcewithlogitsloss.pth`.

Training with 10:1 class imbalance - 58 positive examples, 580 randomly selected negative examples:
* Achieved 100% accuracy on train.
* On test, precision/recall at 57.9%/1.9%. Confusion matrix `TP: 11.0 TN: 52772.0 FP: 8.0 FN: 558.0`. Output of model had absolute value around 5-10.
* Saved in `2-6-19_948am_videonet_10to1classbalance_bcewithlogitsloss.pth`.

Training with 2:1 class imbalance - 58 positive examples, 58 randomly selected negative examples, 58 examples from the end of shots:
* Achieved 100% accuracy on train.
* On test, precision/recall at 18.2%/3.9%. Confusion matrix `TP: 22.0 TN: 52681.0 FP: 99.0 FN: 547.0`. Output of model had absolute value < 2.
* Saved in `2-6-19_1016am_videonet_2to1classbalance_bcewithlogitsloss.pth`.

Issue: if you train on a subset of frames from the training clips, you won't do well on the full range of frames from those clips. For example, if you train on all the shot transitions plus some randomly selected non-transition frames, you'll be able to identify all the shot transitions in your training clips, but you'll also get a lot of false positives.

Training with a 3:1 class imbalance - 97 positive examples, 97 randomly selected negative examples, 97 examples from the end of shots, and 97 examples from the frame right after each shot transition:
* 100% accuracy on the training set.
* On the full set of training clips, 100% recall with 66% precision.
* On the set of training clips, hallucinating that many frames in a row are shot boundaries. Confusion matrix `TP: 97.0 TN: 7155.0 FP: 49.0 FN: 0.0`.
* Saved in `2-6-19_5pm_videonet_3to1classbalance_bcewithlogitsloss.pth`.

### Scratchpad

In [None]:
# Peek at the raw model outputs for a single batch. Uses the training loader:
# the scratch `dataloader` is only created further down the notebook, so
# referring to it here fails on a fresh top-to-bottom run.
for inputs, labels in dataloader_training:
    inputs = [i.to(device) for i in inputs]
    labels = labels.to(device)
    outputs = vnet(inputs[0], inputs[1], inputs[2])
    print(outputs, labels)
    break

In [None]:
# Scratch: loss for a 4x2 score matrix against class-index targets.
# NOTE(review): this call shape (2-column scores, integer class targets)
# matches a CrossEntropyLoss, not the BCEWithLogitsLoss assigned to
# `criterion` earlier in the notebook. It presumably ran after `criterion`
# had been rebound (see the next cell). Confirm, or drop these scratch cells.
criterion(
    torch.tensor([
        [0.3, 0.7],
        [0.7, 0.3],
        [0.7, 0.3],
        [0.7, 0.3]
    ]),
    torch.tensor([
        1, 0, 0, 0
    ])
)

In [None]:
# WARNING: rebinds the global `criterion` used for training and evaluation
# above. Re-running earlier cells after this one will use this loss instead.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([.01, 1.]))

In [None]:
# Scratch: weighted cross entropy for a single two-class score row with target class 1.
nn.CrossEntropyLoss(weight=torch.tensor([1., .5]).to(device))(
    torch.tensor(
        [[-0.9855, 1.1573]]
    ).to(device),
    torch.tensor([1]).to(device)
)

In [None]:
# Scratch: same targets as the first cell, with identical scores on every row.
criterion(
    torch.tensor([
        [0.8, 0.2],
        [0.8, 0.2],
        [0.8, 0.2],
        [0.8, 0.2]
    ]),
    torch.tensor([
        1, 0, 0, 0
    ])
)

In [None]:
#tenlayer_resnet = models.ResNet(models.resnet.BasicBlock, [1, 1, 1, 1], num_classes=128)

In [None]:
# Replace the avgpool layer with an AdaptiveAvgPool so we don't have to worry about input size
#tenlayer_resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))

In [None]:
#print(tenlayer_resnet)

In [None]:
#params = list(tenlayer_resnet.parameters())

In [None]:
#len(params)

In [None]:
#params[-1].size()

In [None]:
# Load up an image and run it through the network
# Takes the first frame of the first shot of the first training video.
# (`shots_gt` is not defined anywhere in this notebook; the training ground
# truth collection is used instead.)
vid_id = list(shots_gt_training.get_allintervals().keys())[0]
frame = shots_gt_training.get_intervallist(vid_id).get_intervals()[0].start
img = cv2.cvtColor(load_frame(Video.objects.get(id=vid_id), frame, []), cv2.COLOR_BGR2RGB)
plt.imshow(img)

In [None]:
#img_tensor = transforms.ToTensor()(img)

In [None]:
#tenlayer_resnet(img_tensor.unsqueeze(0))

In [None]:
# Three consecutive frames (14454, 14455, 14456) of video 123, converted
# from BGR to RGB for display.
imgs = [
   cv2.cvtColor(load_frame(Video.objects.get(id=123), f, []), cv2.COLOR_BGR2RGB)
   for f in range(14455-1, 14455+2)
]

In [None]:
# Previous frame.
plt.imshow(imgs[0])

In [None]:
# Center frame.
plt.imshow(imgs[1])

In [None]:
# Next frame.
plt.imshow(imgs[2])

In [None]:
# Scratch preprocessing. Note that this uses Resize(224), whereas the
# in-memory ShotDetectionDataset used for training resizes to (100, 224).
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [None]:
# One single-image batch per frame, on the model's device.
img_tensors = [
    transform(npimg).unsqueeze(0).to(device)
    for npimg in imgs
]

In [None]:
img_tensors[0]

In [None]:
# Raw logit for the frame triple.
o = vnet(img_tensors[0], img_tensors[1], img_tensors[2])

In [None]:
o

In [None]:
# `model` is whatever train_model returned; it is the same object as vnet.
model(img_tensors[0], img_tensors[1], img_tensors[2])

In [None]:
# The network emits a single column, so the max over dim 1 is that value at index 0.
torch.max(o, 1)

In [None]:
class ShotDetectionDataset(Dataset):
    """Scratch variant of the dataset that loads each frame window on demand.

    NOTE: this definition has the same name as the in-memory
    ShotDetectionDataset defined earlier in the notebook and replaces it once
    this cell runs. The two differ: this one stores its items in
    ``self.frames``, labels every interval start as 1 regardless of payload,
    and resizes with ``Resize(224)``.
    """

    def __init__(self, shots, window_size=1, height=224):
        """Constructor for ShotDetectionDataset.

        Args:
            shots: VideoIntervalCollection of all the intervals to get frames from.
            window_size: Number of frames taken on each side of the labeled frame.
            height: Currently unused; frames are resized with ``transforms.Resize(224)``.
        """
        self.window_size = window_size
        frames = set()

        for video_id in shots.get_allintervals():
            for intrvl in shots.get_intervallist(video_id).get_intervals():
                for f in range(intrvl.start, intrvl.end + 1):
                    frames.add((video_id, f, 1 if f == intrvl.start else 0))
        # Sorted (video_id, frame, label) tuples.
        self.frames = sorted(frames)

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        """Number of labeled frames in the dataset."""
        return len(self.frames)

    def __getitem__(self, idx):
        """
        Indexed by video ID, then frame number
        Returns self.window_size frames before the indexed frame to self.window_size
            frames after the indexed frame

        Returns:
            Tuple ``(img_tensors, label)`` with ``2 * window_size + 1``
            transformed frames and the 0/1 label of the center frame.
        """
        video_id, frame_num, label = self.frames[idx]
        npimgs = [
            cv2.cvtColor(load_frame(Video.objects.get(id=video_id), f, []), cv2.COLOR_BGR2RGB)
            for f in range(frame_num-self.window_size, frame_num+self.window_size + 1)
        ]
        # Transform the frames loaded for THIS item (not the notebook-level
        # `imgs` list, which would return the same stale frames every time).
        img_tensors = [
            self.transform(npimg)
            for npimg in npimgs
        ]

        return img_tensors, label

In [None]:
# Scratch dataset over the training ground truth (`shots_gt` is not defined
# anywhere in this notebook).
dataset = ShotDetectionDataset(shots_gt_training)

In [None]:
# Number of labeled frames in the scratch dataset.
len(dataset)

In [None]:
# Inspect the first four samples. Each print dumps full image tensors, so the
# output is large.
for i in range(len(dataset)):
    sample = dataset[i]
    
    print(i, sample)
    
    if i == 3:
        break

In [None]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

In [None]:
# Inspect the first four batches. After the loop, `sample` holds the list of
# frame batches and `label` the labels from the last batch visited.
for i_batch, sample_batched in enumerate(dataloader):
    sample, label = sample_batched
    print(i_batch, len(sample))
    print(sample[0].size())
    print(label)
    if i_batch == 3:
        break

In [None]:
vnet(sample_batched[0], sample_batched[1], sample_batched[2])

In [None]:
vnet.train()

In [None]:
outs = vnet(sample_batched[0], sample_batched[1], sample_batched[2])

In [None]:
torch.max(outs, 1)

In [None]:
len(list(vnet.modules()))