<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#DSM-Shot-Detection" data-toc-modified-id="DSM-Shot-Detection-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>DSM Shot Detection</a></span></li><li><span><a href="#Part-1:-Adaptive-Filtering" data-toc-modified-id="Part-1:-Adaptive-Filtering-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Part 1: Adaptive Filtering</a></span><ul class="toc-item"><li><span><a href="#SqueezeNet" data-toc-modified-id="SqueezeNet-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>SqueezeNet</a></span></li><li><span><a href="#Compute-SqueezeNet-embeddings" data-toc-modified-id="Compute-SqueezeNet-embeddings-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Compute SqueezeNet embeddings</a></span></li><li><span><a href="#Shot-Candidates-from-Differences-in-SqueezeNet-Embeddings" data-toc-modified-id="Shot-Candidates-from-Differences-in-SqueezeNet-Embeddings-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Shot Candidates from Differences in SqueezeNet Embeddings</a></span></li><li><span><a href="#Experiment:-Find-Best-Window-Size" data-toc-modified-id="Experiment:-Find-Best-Window-Size-2.4"><span class="toc-item-num">2.4&nbsp;&nbsp;</span>Experiment: Find Best Window Size</a></span></li></ul></li><li><span><a href="#Part-2:-Hard-Cut-Detection" data-toc-modified-id="Part-2:-Hard-Cut-Detection-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Part 2: Hard Cut Detection</a></span><ul class="toc-item"><li><span><a href="#Adaptive-filtering-for-candidates" data-toc-modified-id="Adaptive-filtering-for-candidates-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Adaptive filtering for candidates</a></span></li><li><span><a href="#Load-Hard-Cut-Prediction-Model" data-toc-modified-id="Load-Hard-Cut-Prediction-Model-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Load Hard Cut Prediction Model</a></span></li><li><span><a href="#ClipShots-dataloader" 
data-toc-modified-id="ClipShots-dataloader-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>ClipShots dataloader</a></span></li><li><span><a href="#Test-Model" data-toc-modified-id="Test-Model-3.4"><span class="toc-item-num">3.4&nbsp;&nbsp;</span>Test Model</a></span></li></ul></li><li><span><a href="#Part-4:-Gradual-Cut-Detection" data-toc-modified-id="Part-4:-Gradual-Cut-Detection-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Part 4: Gradual Cut Detection</a></span></li></ul></div>

In [None]:
import os
import torch
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
from tqdm import tqdm
from scipy.spatial.distance import cosine
from collections import OrderedDict

# DSM Shot Detection
In this notebook we'll try to recreate the DSM (Deep Structured Models) shot-transition detection pipeline as outlined in [this paper](https://arxiv.org/pdf/1808.04234.pdf), using the ClipShots dataset.

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Part 1: Adaptive Filtering

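The first stage of the DSM pipeline is adaptive filtering: every frame is embedded with a SqueezeNet pretrained on ImageNet, and positions where the cosine distance between consecutive embeddings exceeds an adaptive (window-dependent) threshold become shot-boundary candidates.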

In [None]:
# First, load up the results of adaptive filtering from when they did it
items = []
with open('/app/data/dsm_cut_detection/data/data_list/cut_train_data_video.txt', 'r') as f:
    for line in f.readlines():
        boundary = line.strip().split(' ')
        if len(boundary) == 5:
            [_, path, start_frame, end_frame, label] = boundary
        else:
            [path, start_frame, end_frame, label] = boundary
        path = os.path.basename(path)
        start_frame = int(start_frame)
        end_frame = int(end_frame)
        label = int(label)
        items.append((path, start_frame, end_frame, label))

In [None]:
# Group the DSM boundaries by video: vids[video_name] -> list of (start, end, label).
vids = {}
for item in items:
    vids.setdefault(item[0], []).append((item[1], item[2], item[3]))

In [None]:
# Use the video of the first DSM entry as the running example for this section.
test_vid = items[0][0]

## SqueezeNet

In [None]:
# ImageNet-pretrained SqueezeNet 1.1 in inference mode, placed on `device`.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of `weights=`.
squeezenet = models.squeezenet1_1(pretrained=True).eval()
squeezenet = squeezenet.to(device)

## Compute SqueezeNet embeddings

In [None]:
import random
import math
import numbers
import collections
import numpy as np
import torch
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None


class Compose(object):
    """Chain several transforms into one callable.

    Args:
        transforms (list): transforms applied in list order; each must be
            callable and provide ``randomize_parameters()``.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def randomize_parameters(self):
        # Forward the re-randomisation request to every wrapped transform.
        for transform in self.transforms:
            transform.randomize_parameters()


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL.Image or numpy.ndarray (H x W x C) to a torch.FloatTensor of
    shape (C x H x W) divided by ``norm_value`` (range [0.0, 1.0] for uint8 input
    with the default ``norm_value=255``; raw 0-255 values with ``norm_value=1``).
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array: H x W x C -> C x H x W
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            # NOTE(review): this path does not divide by norm_value, unlike the others.
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        # copy=True: under NumPy >= 2, copy=False raises ValueError whenever a copy
        # is required (e.g. the dtype conversion for 'I;16'); copying also avoids
        # handing torch a read-only buffer.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=True))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=True))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(self.norm_value)
        else:
            return img

    def randomize_parameters(self):
        pass


class Normalize(object):
    """Normalize a tensor image channel-wise, in place.

    Each channel c is replaced by ``(c - mean[c]) / std[c]``; channels are paired
    with ``mean``/``std`` entries by position (zip, so extra entries are ignored).

    Args:
        mean (sequence): per-channel means for R, G, B respectively.
        std (sequence): per-channel standard deviations for R, G, B respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: the same tensor object, normalized in place.
        """
        # TODO: make efficient
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return tensor

    def randomize_parameters(self):
        pass


class Scale(object):
    """Rescale the input PIL.Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number,
            i.e. if height > width, the image will be rescaled to
            (w, h) = (size, size * height / width).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            # Already at the target short-edge length: nothing to do.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)

    def randomize_parameters(self):
        pass

def get_mean(norm_value=255):
    """Return the per-channel (R, G, B) pixel means, each divided by ``norm_value``."""
    rgb_means = (114.7748, 107.7354, 99.4750)
    return [channel_mean / norm_value for channel_mean in rgb_means]

def get_test_spatial_transform(opt):
    """Build the evaluation transform: resize to a square, tensorize, subtract the mean.

    Args:
        opt: object with ``spatial_size`` (output side length) and ``norm_value``
            (divisor passed to ``ToTensor`` and ``get_mean``).
    Returns:
        Compose: Scale -> ToTensor -> Normalize (std fixed at 1 per channel).
    """
    side = opt.spatial_size
    steps = [
        Scale((side, side)),
        ToTensor(opt.norm_value),
        Normalize(get_mean(opt.norm_value), [1, 1, 1]),
    ]
    return Compose(steps)

In [None]:
def pil_loader(path):
    """Load the image at ``path`` as RGB.

    Returns:
        PIL.Image in RGB mode, or None when the file is missing or cannot be decoded.
    """
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')
    except Exception:
        # Best effort: unreadable/corrupt frames are reported as None. Catching
        # Exception instead of a bare except keeps KeyboardInterrupt/SystemExit working.
        return None

In [None]:
class Opt:
    """Options for get_test_spatial_transform: 224x224 frames, no pixel rescaling (norm_value=1)."""

    def __init__(self):
        self.spatial_size = 224
        self.norm_value = 1

In [None]:
class VidDataset(Dataset):
    """All frames of one video, in order, as transformed tensors.

    Item ``idx`` is the file ``image_{idx + 1}.jpg`` inside ``path`` (frame files are
    numbered from 1), passed through the test-time spatial transform.

    Args:
        path (str): directory holding the extracted frames.
        frame_num (int): number of frames, i.e. the dataset length.
    """

    def __init__(self, path, frame_num):
        self.path = path
        self.frame_num = frame_num
        self.transform = get_test_spatial_transform(Opt())

    def __len__(self):
        return self.frame_num

    def __getitem__(self, idx):
        frame_file = os.path.join(self.path, 'image_{}.jpg'.format(idx + 1))
        img = pil_loader(frame_file)
        if img is None:
            # pil_loader signals missing/unreadable files with None; fail with the
            # offending path instead of an opaque error from inside the transform.
            raise OSError('could not load frame {}'.format(frame_file))
        return self.transform(img)

In [None]:
# ClipShots training annotations, and the extracted-frame directory of the example video.
annotations_path = '/app/data/ClipShots/annotations/train.json'
frame_path = os.path.join('/app/data/ClipShots/frames/train', test_vid)

In [None]:
# Per-video ClipShots annotations; each entry's 'frame_num' sizes the frame datasets below.
with open(annotations_path, 'r') as f:
    train_annotations = json.load(f)

In [None]:
# Dataset over every frame of the running-example video (frame count from the annotations).
test_vid_dataset = VidDataset(frame_path, int(train_annotations[test_vid]['frame_num']))

In [None]:
# Frames in temporal order (no shuffling), 8 per batch, loaded in the main process.
test_vid_dataloader = DataLoader(
    test_vid_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Per-frame accumulators filled by the embedding loop below.
squeezenet_features, squeezenet_classifications = [], []

In [None]:
# Embed every frame of the running-example video.
# The lists are reset here so re-running this cell does not append a second copy.
squeezenet_features = []
squeezenet_classifications = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for images in tqdm(test_vid_dataloader):
        images = images.to(device)
        # Run the convolutional trunk once and apply the classifier head to the same
        # activations (torchvision's SqueezeNet.forward is features -> classifier -> flatten),
        # instead of running the trunk a second time via squeezenet(images).
        feature_maps = squeezenet.features(images)
        logits = torch.flatten(squeezenet.classifier(feature_maps), 1)
        squeezenet_classifications += logits.cpu().numpy().tolist()
        squeezenet_features += feature_maps.cpu().numpy().tolist()

In [None]:
# One (1, D) tensor per frame: the flattened SqueezeNet feature map used for cosine distances.
squeezenet_features_flat = list(map(lambda feat: torch.tensor(feat).view(1, -1), squeezenet_features))

## Shot Candidates from Differences in SqueezeNet Embeddings

In [None]:
# Raw (start, end, label) boundaries listed for the running-example video.
print(vids[test_vid])

In [None]:
# print out the ground truth again
print(sorted([t1 for t1, t2, gt in vids[test_vid]]))

In [None]:
# Ground-truth cut positions: start frames of boundaries whose label is 1.
true_cuts = {start for start, _end, label in vids[test_vid] if label == 1}

In [None]:
# Number of listed boundaries (all labels) for this video.
print(len(vids[test_vid]))

In [None]:
# Temporal strides (powers of two) at which consecutive embeddings are compared.
scales = [2 ** power for power in range(6)]
# Adaptive-threshold window sizes to sweep.
window_sizes = list(range(1, 13))

In [None]:
SIGMA = 0.05
T = 0.5
def get_candidates_from_scale(vectors, scale, window_size, sigma=None, t=None):
    """Adaptive-threshold boundary candidates at a single temporal scale.

    Every ``scale``-th vector is kept and the cosine distance between consecutive
    kept vectors is computed. Position i is a candidate when
    ``diffs[i] > t + (sigma / len(window)) * sum(window)`` with
    ``window = diffs[i - window_size : i + window_size]``.

    Args:
        vectors (sequence): per-frame embeddings accepted by scipy's ``cosine``.
        scale (int): temporal stride.
        window_size (int): window extent; must be >= 1.
        sigma (float, optional): weight of the local mean; defaults to the global SIGMA
            (read at call time).
        t (float, optional): base threshold; defaults to the global T (read at call time).
    Returns:
        list[int]: candidate positions ``i * scale`` in increasing order.
    Raises:
        ValueError: if ``window_size`` < 1 (the window would be empty).
    """
    if sigma is None:
        sigma = SIGMA
    if t is None:
        t = T
    if window_size < 1:
        raise ValueError('window_size must be >= 1, got {}'.format(window_size))

    vec_scaled = vectors[::scale]
    diffs = [
        cosine(vec_scaled[i - 1], vec_scaled[i])
        for i in range(1, len(vec_scaled))
    ]

    n = len(diffs)
    candidate_boundaries = []
    # The loop bounds keep the window inside [0, n), so no clamping is needed.
    for i in range(window_size, n - window_size):
        window = diffs[i - window_size:i + window_size]
        threshold = t + (sigma / len(window)) * np.sum(window)
        if diffs[i] > threshold:
            candidate_boundaries.append(i * scale)

    return candidate_boundaries

In [None]:
def get_all_candidates(vectors, window_size):
    """Merge adaptive-filtering candidates over every stride in the global ``scales``.

    Scales are visited in the order of ``scales``. A candidate found at stride s is
    dropped when an already-accepted candidate lies within s frames of it.

    Returns:
        list[int]: accepted candidate positions, sorted ascending.
    """
    per_scale = [(scale, get_candidates_from_scale(vectors, scale, window_size)) for scale in scales]

    accepted = []
    for scale, found in per_scale:
        for candidate in found:
            if any(abs(candidate - kept) <= scale for kept in accepted):
                continue
            accepted.append(candidate)

    accepted.sort()
    return accepted

In [None]:
# Sweep window sizes on the running-example video and report candidate precision/recall.
# Empty candidate sets or videos without label-1 cuts report 0 instead of dividing by zero.
for window_size in window_sizes:
    candidate_list = set(get_all_candidates(squeezenet_features_flat, window_size))
    hits = len(candidate_list.intersection(true_cuts))
    print('Window size {} has {} candidates; {} precision and {} recall'.format(
        window_size,
        len(candidate_list),
        0 if len(candidate_list) == 0 else hits / len(candidate_list),
        0 if len(true_cuts) == 0 else hits / len(true_cuts)
    ))

## Experiment: Find Best Window Size

In [None]:
import random

In [None]:
# Seed Python's RNG so the shuffled experiment split below is reproducible.
random.seed(0)

In [None]:
# Deterministic sample of 10 videos: sort for a stable base order, then shuffle with the seeded RNG.
all_vids = sorted(vids)
random.shuffle(all_vids)
test_vids = all_vids[0:10]

In [None]:
# Display the sampled experiment videos.
test_vids

In [None]:
# For each experiment video: embed all frames, then sweep the window sizes and report
# candidate precision/recall against that video's label-1 boundaries.
# Per-video values use vid_* names so the running-example globals from the cells above
# (true_cuts, frame_path, squeezenet_features, ...) are not silently overwritten.
for vid in test_vids:
    print(vid)
    vid_true_cuts = set([t1 for t1, t2, gt in vids[vid] if gt == 1])
    print('{} has {} true transitions'.format(vid, len(vid_true_cuts)))
    vid_frame_path = os.path.join('/app/data/ClipShots/frames/train', vid)
    vid_dataset = VidDataset(vid_frame_path, int(train_annotations[vid]['frame_num']))
    vid_dataloader = DataLoader(vid_dataset, shuffle=False, batch_size=8, num_workers=0)

    # Only the convolutional features are used below, so the classifier head is not run.
    vid_features = []
    with torch.no_grad():
        for images in vid_dataloader:
            images = images.to(device)
            vid_features += squeezenet.features(images).cpu().numpy().tolist()

    vid_features_flat = [
        torch.tensor(f).view(1, -1)
        for f in vid_features
    ]

    for window_size in window_sizes:
        candidate_list = set(get_all_candidates(vid_features_flat, window_size))
        hits = len(candidate_list.intersection(vid_true_cuts))
        print('Window size {} has {} candidates; {} precision and {} recall'.format(
            window_size,
            len(candidate_list),
            0 if len(candidate_list) == 0 else hits / len(candidate_list),
            0 if len(vid_true_cuts) == 0 else hits / len(vid_true_cuts)
        ))

# Part 2: Hard Cut Detection
Use adaptive filtering with a window size of 6 (picked after the window-size experiment above) to generate candidates, and then run hard-cut detection on those candidates.

## Adaptive filtering for candidates

In [None]:
# Adaptive-filtering window size used to generate candidates for hard-cut detection.
WINDOW_SIZE = 6

In [None]:
# Adaptive-filtering candidates (window size WINDOW_SIZE) for every experiment video.
# The loop variable is `vid`, not `test_vid`, so the running-example video chosen earlier
# is not replaced by the last experiment video.
shot_candidates = {}
for vid in tqdm(test_vids):
    frame_path = os.path.join('/app/data/ClipShots/frames/train', vid)
    vid_dataset = VidDataset(frame_path, int(train_annotations[vid]['frame_num']))
    vid_dataloader = DataLoader(vid_dataset, shuffle=False, batch_size=8, num_workers=0)

    squeezenet_features = []
    with torch.no_grad():  # inference only
        for images in vid_dataloader:
            images = images.to(device)
            squeezenet_features += squeezenet.features(images).cpu().numpy().tolist()

    squeezenet_features_flat = [
        torch.tensor(f).view(1, -1)
        for f in squeezenet_features
    ]

    candidate_list = set(get_all_candidates(squeezenet_features_flat, WINDOW_SIZE))

    shot_candidates[vid] = candidate_list

In [None]:
# Store candidates in a local file for the dataloader
with open('/app/data/cut_candidates.txt', 'w') as f:
    for video in shot_candidates:
        gt = sorted([t1 for t1, t2, gt in vids[video]])
        for candidate in sorted(shot_candidates[video]):
            f.write('{} {} {} {} {}\n'.format(
                video, video, candidate, candidate + 1, 1 if candidate in gt else 0
            ))

## Load Hard Cut Prediction Model

In [None]:
# Model for 3D Resnet
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch

# Public names of this model cell.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


# torchvision ImageNet checkpoints loaded by the factories below when pretrained=True.
# NOTE(review): the 'resnet50' entry is not used by resnet50() here, which takes an `opt`
# object and an optional local pretrain_path instead.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding 1 (keeps the spatial size when stride is 1)."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Two-convolution residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, add shortcut, ReLU.

    Args:
        inplanes (int): input channels.
        planes (int): output channels (``expansion`` is 1).
        stride (int): stride of the first convolution.
        downsample (nn.Module or None): applied to the input to form the shortcut
            when shapes differ; identity shortcut when None.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)

# BatchNorm momentum shared by Bottleneck and the ResNet stem.
bn_momentum = 0.1
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    Args:
        inplanes (int): input channels.
        planes (int): width of the inner convolutions.
        stride (int): stride of the 3x3 convolution.
        downsample (nn.Module or None): shortcut projection; identity when None.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet backbone with a configurable number of input channels.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` sets the width of the final feature vector.
        layers (list of 4 ints): number of blocks in each of the four stages.
        input_channel (int): channels of the input tensor. Defaults to 3 so that the
            resnet18/34/101/152 factories, which do not pass it, can build a model.
        num_classes (int): output size of the final linear layer.
        no_fc (bool): if True, ``forward`` returns the pooled features instead of logits.
    """

    def __init__(self, block, layers, input_channel=3, num_classes=1000, no_fc=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # NOTE(review): total stride is 32, so this fixed 4x4 pool assumes 4x4 final maps
        # (e.g. 128x128 inputs); other input sizes change the flattened width.
        self.avgpool = nn.AvgPool2d(4, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.no_fc = no_fc

        # He-style normal init for convolutions; BatchNorm starts as the identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_momentum),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if not self.no_fc:
            x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: ``BasicBlock`` stages of depth [2, 2, 2, 2].

    Args:
        pretrained (bool): if True, load the torchvision ImageNet checkpoint
            listed in ``model_urls['resnet18']``.
        **kwargs: forwarded unchanged to ``ResNet``.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return net


def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34: ``BasicBlock`` stages of depth [3, 4, 6, 3].

    Args:
        pretrained (bool): if True, load the torchvision ImageNet checkpoint
            listed in ``model_urls['resnet34']``.
        **kwargs: forwarded unchanged to ``ResNet``.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return net


def resnet50(opt, **kwargs):
    """Construct the 2-class Bottleneck ResNet-50 used for hard-cut prediction.

    Args:
        opt: options object; reads ``img_concat``, ``image_num``, ``center_crop``,
            ``no_fc`` and ``pretrain_path``.
        **kwargs: accepted for signature compatibility; ignored.
    Returns:
        ResNet: with ``num_classes=2``. When ``opt.img_concat`` is set, the stem takes
        ``image_num * 6`` channels if ``center_crop`` else ``image_num * 3``; otherwise 3.
        If ``opt.pretrain_path`` is set, every matching checkpoint tensor except the stem
        conv (``conv1``) and classifier (``fc``) is copied in.
    """
    if opt.img_concat:
        input_channel = opt.image_num * (6 if opt.center_crop else 3)
    else:
        input_channel = 3

    model = ResNet(Bottleneck, [3, 4, 6, 3], input_channel, 2, no_fc=opt.no_fc)
    if opt.pretrain_path:
        pretrain_weights = torch.load(opt.pretrain_path, map_location=lambda storage, loc: storage)
        current_param = model.state_dict()
        # conv1/fc are skipped: presumably their shapes depend on input_channel and
        # num_classes and so differ from the checkpoint's.
        pretrained = {
            k: v for k, v in pretrain_weights.items()
            if k in current_param and k.split('.')[0] not in ('conv1', 'fc')
        }
        print(pretrained.keys())
        current_param.update(pretrained)
        model.load_state_dict(current_param)
    return model


def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101: ``Bottleneck`` stages of depth [3, 4, 23, 3].

    Args:
        pretrained (bool): if True, load the torchvision ImageNet checkpoint
            listed in ``model_urls['resnet101']``.
        **kwargs: forwarded unchanged to ``ResNet``.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return net


def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152: ``Bottleneck`` stages of depth [3, 8, 36, 3].

    Args:
        pretrained (bool): if True, load the torchvision ImageNet checkpoint
            listed in ``model_urls['resnet152']``.
        **kwargs: forwarded unchanged to ``ResNet``.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return net

In [None]:
from types import SimpleNamespace

# Options read by resnet50(): 6 frames concatenated as input channels (3 each, no
# center-crop copies), classifier kept, no extra pretraining checkpoint.
# (The original `Object()` is not defined anywhere and raised NameError.)
opt = SimpleNamespace(
    pretrain_path=None,
    img_concat=True,
    center_crop=False,
    no_fc=False,
    image_num=6,
)

In [None]:
# 2-class hard-cut ResNet-50 built from `opt`; its weights are loaded in the next cell.
cut_model = resnet50(opt)

In [None]:
# Load the epoch-19 checkpoint into `cut_model`.
load_path = os.path.join(
    "/app/data/img_concat_6_frames_resnet",
    'resnet50_epoch_{}.pth'.format(19)
)
# map_location='cpu' lets the checkpoint load on machines without the GPU it was saved
# from; load_state_dict copies the tensors into the model's parameters either way.
states = torch.load(load_path, map_location='cpu')['state_dict']
new_resnet_state_dict = OrderedDict()
for k, v in states.items():
    # Strip the 'module.' prefix (presumably from nn.DataParallel) only when present,
    # instead of blindly dropping the first 7 characters of every key.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_resnet_state_dict[name] = v

cut_model.load_state_dict(new_resnet_state_dict)

In [None]:
# Second 2-class ResNet-50 from `opt`, for the epoch-5 checkpoint loaded in the next cell.
cut_model_5_epochs = resnet50(opt)

In [None]:
# Load the epoch-5 checkpoint into `cut_model_5_epochs`.
load_path = os.path.join(
    "/app/data/img_concat_6_frames_resnet",
    'resnet50_epoch_{}.pth'.format(5)
)
# map_location='cpu' lets the checkpoint load on machines without the GPU it was saved
# from; load_state_dict copies the tensors into the model's parameters either way.
states = torch.load(load_path, map_location='cpu')['state_dict']
new_resnet_state_dict = OrderedDict()
for k, v in states.items():
    # Strip the 'module.' prefix (presumably from nn.DataParallel) only when present,
    # instead of blindly dropping the first 7 characters of every key.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_resnet_state_dict[name] = v

cut_model_5_epochs.load_state_dict(new_resnet_state_dict)

## ClipShots dataloader

In [None]:
import torch
import torch.utils.data as data
from PIL import Image
import os
import math
import functools
import json
import copy
import numpy as np
import random
import cv2

def pil_loader(path):
    """Load the image at ``path`` as RGB, or return None if it is missing or unreadable."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')
    except Exception:
        # Best effort: corrupt frames become None (callers pad or retry). Catching
        # Exception instead of a bare except keeps KeyboardInterrupt/SystemExit working.
        return None

def video_loader(video_dir_path, frame_indices, sample_duration):
    """Load ``sample_duration`` consecutive frames starting at frame number ``frame_indices``.

    Frames are read with ``pil_loader`` from ``image_{i}.jpg`` in ``video_dir_path``.
    Loading stops at the first missing file, and the frames loaded so far are returned
    (possibly fewer than ``sample_duration``).
    """
    frames = []
    for frame_no in range(frame_indices, frame_indices + sample_duration):
        image_path = os.path.join(video_dir_path, 'image_{}.jpg'.format(frame_no))
        if not os.path.exists(image_path):
            return frames
        frames.append(pil_loader(image_path))
    assert len(frames) == sample_duration
    return frames

def numpy_loader(numpy_path):
    """Load a cached ``.npy`` clip and return its frames as RGB PIL images.

    Returns:
        list of PIL.Image (one per entry along axis 0), or None when the file is
        missing or cannot be decoded.
    """
    if not os.path.exists(numpy_path):
        return None
    try:
        video = np.load(numpy_path)
        # One array per frame along axis 0; squeeze drops the resulting length-1 axis.
        frames = np.split(video, video.shape[0])
        # Frames are converted BGR -> RGB, i.e. assumed to be stored in OpenCV channel order.
        return [Image.fromarray(cv2.cvtColor(np.squeeze(frame), cv2.COLOR_BGR2RGB)).convert('RGB')
                for frame in frames]
    except Exception:
        # Best effort, as before, but without swallowing KeyboardInterrupt/SystemExit.
        return None

def get_numpy_path(root_path, video_path, begin):
    """Cached-clip path ``<root>/numpy_img_cut/<first char>/<video name>/<begin>.npy``.

    The video name is the last '/'-separated component of ``video_path``.
    Returns None when ``video_path`` is None.
    """
    if video_path is not None:
        video_name = video_path.rsplit("/", 1)[-1]
        return os.path.join(root_path, 'numpy_img_cut', video_name[0], video_name, "{}.npy".format(str(begin)))
    return None

def get_default_video_loader():
    """Return the default per-frame loader (``pil_loader``) used by ``DataSet``."""
    return pil_loader


def make_dataset(root_path, image_list_path, label_in):
    """Parse a boundary list file into per-boundary info dicts with label ``label_in``.

    Each line is either ``<path> <idx1> <idx2> <label>`` or
    ``<video_path> <path> <idx1> <idx2> <label>``. ``idx1``/``idx2`` are swapped if
    out of order, and entries with a negative ``idx1`` or a different label are dropped.
    Blank or malformed lines are skipped.

    Returns:
        list[dict]: keys root_path, idx1, idx2, path (basename), label, video_path
        (None for 4-field lines) and numpy_path (from get_numpy_path, begin=idx1-2).
    """
    image_list = []
    with open(image_list_path, 'r') as f:
        for line in f:
            # strip() so trailing whitespace/newlines cannot add an extra empty field
            words = line.strip().split(' ')
            if len(words) == 4:
                path, idx1, idx2, label = words
                video_path = None
            elif len(words) == 5:
                video_path, path, idx1, idx2, label = words
            else:
                # Blank or malformed line: skip it rather than reusing the previous
                # line's fields (or hitting a NameError on the first line).
                continue
            path = os.path.basename(path)
            idx1, idx2 = int(idx1), int(idx2)
            if idx1 > idx2:
                idx1, idx2 = idx2, idx1
            if idx1 < 0:
                continue
            if int(label) != label_in:
                continue
            info = {"root_path": root_path, "idx1": idx1, "idx2": idx2, "path": path,
                    'label': int(label), 'video_path': video_path,
                    'numpy_path': get_numpy_path(root_path, video_path, idx1 - 2)}
            image_list.append(info)

    return image_list


class DataSet(data.Dataset):
    """Frame stacks around the boundaries listed in a data-list file.

    Entries come from ``make_dataset`` (only those whose label equals ``label``).
    Each item holds ``image_num // 2`` frames ending at ``idx1`` followed by
    ``image_num // 2`` frames starting at ``idx2``. Each loaded frame is passed through
    ``spatial_transform`` and/or ``center_crop_transform``. Items are returned as
    ``(list_of_frame_tensors, float32 tensor [label])``.
    """

    def __init__(self, root_path, image_list_path,balance_pos_neg,label,center_crop_transform=None,image_num=4,
                 spatial_transform=None,get_loader=get_default_video_loader):
        self.image_list = make_dataset(root_path, image_list_path, label)
        assert len(self.image_list) > 0
        self.spatial_transform = spatial_transform
        self.loader = get_loader()
        self.index_shuffled = list(range(len(self.image_list)))

        self.balance_pos_neg = balance_pos_neg
        self.pos_cnt = 0
        self.neg_cnt = 0
        self.image_num = image_num
        self.center_crop_transform = center_crop_transform
        # pos_index / neg_index were read by get_pos_index/get_neg_index but never
        # created, so balance_pos_neg=True raised AttributeError.
        self._split_pos_neg()

    def _split_pos_neg(self):
        """Build pos_index / neg_index (positions in image_list) from entry labels."""
        # NOTE(review): treats any non-zero label as positive - confirm against the label scheme.
        self.pos_index = [i for i, info in enumerate(self.image_list) if info['label'] != 0]
        self.neg_index = [i for i, info in enumerate(self.image_list) if info['label'] == 0]

    def get_pos_index(self):
        """Next positive entry (round-robin); falls back to negatives when there are none."""
        if len(self.pos_index) == 0:
            if len(self.neg_index) == 0:
                raise IndexError('DataSet has neither positive nor negative entries')
            return self.get_neg_index()

        index = self.pos_index[self.pos_cnt]
        self.pos_cnt += 1
        self.pos_cnt %= len(self.pos_index)
        return index

    def get_neg_index(self):
        """Next negative entry (round-robin); falls back to positives when there are none."""
        if len(self.neg_index) == 0:
            if len(self.pos_index) == 0:
                raise IndexError('DataSet has neither positive nor negative entries')
            return self.get_pos_index()

        index = self.neg_index[self.neg_cnt]
        self.neg_cnt += 1
        self.neg_cnt %= len(self.neg_index)
        return index

    def get_index(self, index):
        """Map a sampler index to an image_list position.

        Without balancing this is index_shuffled[index]; with balancing, even indices
        draw the next positive entry and odd indices the next negative entry.
        """
        if not self.balance_pos_neg:
            index = self.index_shuffled[index]
        else:
            if index % 2 == 0:
                index = self.get_pos_index()
            else:
                index = self.get_neg_index()
        return index

    def get_imgs(self,root_path,idx1,idx2,path):
        """Load and transform the frames around one boundary.

        Returns None if the first frame of either half cannot be loaded. A later missing
        frame is replaced by repeating the last appended output(s): two entries when
        center_crop_transform is set, otherwise one.
        """
        imgs=[]
        half_len=int(self.image_num/2)
        # Frames idx1, idx1-1, ..., loaded backwards and reversed below.
        for i in range(idx1,idx1-half_len,-1):
            img=self.loader(os.path.join(root_path,path,"image_{}.jpg".format(i)))
            if img is None:
                if len(imgs)>0:
                    if self.center_crop_transform is not None:
                        imgs+=imgs[-2:]
                    else:
                        imgs+=imgs[-1:]
                else:
                    return
                continue
            else:
                if self.spatial_transform is not None:
                    img_t=self.spatial_transform(img)
                    imgs.append(img_t)
                if self.center_crop_transform is not None:
                    img_c=self.center_crop_transform(img)
                    imgs.append(img_c)

        imgs=list(reversed(imgs))

        # Frames idx2, idx2+1, ..., loaded forwards.
        for i in range(idx2,idx2+half_len):
            img=self.loader(os.path.join(root_path,path,"image_{}.jpg".format(i)))

            if img is None:
                if len(imgs)>0:
                    if self.center_crop_transform is not None:
                        imgs+=imgs[-2:]
                    else:
                        imgs+=imgs[-1:]
                else:
                    return
                continue
            else:
                if self.spatial_transform is not None:
                    img_t=self.spatial_transform(img)
                    imgs.append(img_t)
                if self.center_crop_transform is not None:
                    img_c=self.center_crop_transform(img)
                    imgs.append(img_c)

        return imgs

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (list of frame tensors, float32 tensor holding [label]).
        """
        imgs = None
        if self.spatial_transform is not None:
            self.spatial_transform.randomize_parameters()
        while imgs is None:
            index = self.get_index(index)
            info = self.image_list[index]

            root_path = info['root_path']
            idx1 = info['idx1']
            idx2 = info['idx2']
            path = info['path']
            label = info['label']
            imgs = self.get_imgs(root_path, idx1, idx2, path)
            if imgs is None:
                # Unloadable entry: retry with a random index.
                index = random.randint(0, len(self.image_list) - 1)

        return imgs, torch.from_numpy(np.array([label], dtype=np.float32))

    def __len__(self):
        return len(self.image_list)

In [None]:
import random
import math
import numbers
import collections
import numpy as np
import torch
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None


class Compose(object):
    """Chain several transforms and apply them in order.

    Every member must also provide ``randomize_parameters()``; calling
    ``randomize_parameters`` on the composition forwards the call to each one.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, first to last.
    Example:
        >>> Compose([
        >>>     CenterCrop(10),
        >>>     ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def randomize_parameters(self):
        for transform in self.transforms:
            transform.randomize_parameters()


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` (H x W x C) to a C x H x W tensor.

    * numpy arrays are converted to float and divided by ``norm_value``.
    * accimage images are copied into a float32 array (no division).
    * PIL images with byte pixels are converted to float and divided by
      ``norm_value`` (so [0, 255] maps to [0.0, 1.0] with the default).
    * PIL images in mode 'I' (int32), 'I;16' (int16) or 'F' (float32) keep
      their values and dtype.

    Args:
        norm_value (number): divisor applied to numpy and byte-valued inputs.
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array: HWC -> CHW
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image.
        # np.array's default copy=True is deliberate: with NumPy >= 2,
        # copy=False raises ValueError whenever the dtype cast needs a copy
        # (e.g. 'I;16' is stored as uint16), and a zero-copy view of PIL's
        # read-only buffer makes torch.from_numpy warn.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        elif pic.mode == 'F':
            # 32-bit float pixels: the raw-bytes path below would yield
            # 4 bytes per pixel and fail in view().
            img = torch.from_numpy(np.array(pic, np.float32))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(self.norm_value)
        return img

    def randomize_parameters(self):
        pass


class Normalize(object):
    """Standardize a tensor image channel by channel, in place.

    Each channel ``c`` becomes ``(c - mean[c]) / std[c]``. The input tensor is
    modified in place and is also returned.

    Args:
        mean (sequence): Per-channel means (e.g. R, G, B).
        std (sequence): Per-channel standard deviations (e.g. R, G, B).
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Normalize ``tensor`` of size (C, H, W) in place and return it."""
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.sub_(mu)
            channel.div_(sigma)
        return tensor

    def randomize_parameters(self):
        pass


class Scale(object):
    """Rescale the input PIL.Image.

    Args:
        size (sequence or int): If a sequence ``(w, h)``, the image is resized
            to exactly that size. If an int, the shorter edge is resized to
            ``size`` and the longer edge is scaled to keep the aspect ratio
            (truncated to int); e.g. if height > width the output ``(w, h)``
            is ``(size, int(size * height / width))``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The Iterable ABC lives in collections.abc; the collections.Iterable
        # alias was removed in Python 3.10.
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            # Already at the requested short-edge size: nothing to do.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)

    def randomize_parameters(self):
        pass


class CenterCrop(object):
    """Crop a PIL.Image around its center.

    Args:
        size (sequence or int): Output size given as ``(h, w)``. A number ``n``
            means a square ``(n, n)`` crop.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: The centered ``(h, w)`` crop.
        """
        width, height = img.size
        crop_h, crop_w = self.size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)

    def randomize_parameters(self):
        pass


class CornerCrop(object):
    """Crop a ``size`` x ``size`` square from a corner or the center of a PIL.Image.

    Args:
        size (int): side length of the square crop.
        crop_position (str or None): one of 'c', 'tl', 'tr', 'bl', 'br'. If
            ``None``, a position is drawn on every ``randomize_parameters()``
            call, which must therefore happen before the first ``__call__``.

    Raises:
        ValueError: from ``__call__`` when the current position is not one of
            the supported values (including ``None``, i.e. never randomized).
    """

    def __init__(self, size, crop_position=None):
        self.size = size
        self.randomize = crop_position is None
        self.crop_position = crop_position
        self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br']

    def __call__(self, img):
        image_width = img.size[0]
        image_height = img.size[1]
        side = self.size
        position = self.crop_position

        if position == 'c':
            x1 = int(round((image_width - side) / 2.))
            y1 = int(round((image_height - side) / 2.))
            box = (x1, y1, x1 + side, y1 + side)
        elif position == 'tl':
            box = (0, 0, side, side)
        elif position == 'tr':
            box = (image_width - side, 0, image_width, side)
        elif position == 'bl':
            box = (0, image_height - side, side, image_height)
        elif position == 'br':
            box = (image_width - side, image_height - side, image_width, image_height)
        else:
            # Without this branch an unset/unknown position surfaced as an
            # UnboundLocalError on the crop box.
            raise ValueError(
                'CornerCrop: unknown crop_position {!r}; expected one of {} '
                '(call randomize_parameters() first when crop_position=None)'
                .format(position, self.crop_positions))

        return img.crop(box)

    def randomize_parameters(self):
        if self.randomize:
            self.crop_position = self.crop_positions[
                random.randint(0, len(self.crop_positions) - 1)]


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image with probability 0.5.

    The random draw happens in ``randomize_parameters()``, not per call, so
    one decision can be reused for several images. Before the first
    ``randomize_parameters()`` call the transform leaves images unchanged
    (previously ``__call__`` raised AttributeError in that state).
    """

    def __init__(self):
        # Deterministic "do not flip" until randomize_parameters() draws a value;
        # no RNG is consumed here, so the random stream is unchanged.
        self.p = 1.0

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Randomly flipped image.
        """
        if self.p < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img

    def randomize_parameters(self):
        self.p = random.random()


class MultiScaleCornerCrop(object):
    """Crop a square at a randomly chosen scale and position, then resize it.

    Each ``randomize_parameters()`` call draws a scale from ``scales`` and a
    position from the four corners plus the center. ``__call__`` crops a
    square of side ``int(min(width, height) * scale)`` at that position and
    resizes it to ``(size, size)``. ``randomize_parameters()`` must be called
    before the first ``__call__``.

    Args:
        scales: cropping scales relative to the shorter image edge
        size: side length of the (square) resized output
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, scales, size, interpolation=Image.BILINEAR):
        self.scales = scales
        self.size = size
        self.interpolation = interpolation

        self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br']

    def __call__(self, img):
        min_length = min(img.size[0], img.size[1])
        crop_size = int(min_length * self.scale)

        image_width = img.size[0]
        image_height = img.size[1]

        if self.crop_position == 'c':
            # The box spans 2 * (crop_size // 2), one pixel short of crop_size
            # when it is odd; the final resize absorbs the difference.
            center_x = image_width // 2
            center_y = image_height // 2
            box_half = crop_size // 2
            x1 = center_x - box_half
            y1 = center_y - box_half
            x2 = center_x + box_half
            y2 = center_y + box_half
        elif self.crop_position == 'tl':
            x1 = 0
            y1 = 0
            x2 = crop_size
            y2 = crop_size
        elif self.crop_position == 'tr':
            x1 = image_width - crop_size
            y1 = 0
            x2 = image_width
            y2 = crop_size
        elif self.crop_position == 'bl':
            x1 = 0
            y1 = image_height - crop_size
            x2 = crop_size
            y2 = image_height
        elif self.crop_position == 'br':
            x1 = image_width - crop_size
            y1 = image_height - crop_size
            x2 = image_width
            y2 = image_height
        else:
            raise ValueError(
                'MultiScaleCornerCrop: unknown crop_position {!r}; expected one of {}'
                .format(self.crop_position, self.crop_positions))

        img = img.crop((x1, y1, x2, y2))

        return img.resize((self.size, self.size), self.interpolation)

    def randomize_parameters(self):
        self.scale = self.scales[random.randint(0, len(self.scales) - 1)]
        # Index by the number of positions, not scales: using len(self.scales)
        # skipped positions when there were < 5 scales and raised IndexError
        # when there were > 5.
        self.crop_position = self.crop_positions[random.randint(0, len(self.crop_positions) - 1)]

class RandomGrayscale(object):
    """Convert an image to grayscale with probability ``p`` (default 0.1).

    The random value is drawn at construction and on every
    ``randomize_parameters()`` call, not on each ``__call__``.

    Args:
        p (float): probability that the image is converted to grayscale.
        num_output_channels (int): 1 for a single-channel 'L' image, 3 for an
            'RGB' image whose three channels are equal.
    Returns:
        PIL Image: the grayscale image when the draw is below ``p``, otherwise
        the input unchanged.
    Raises:
        ValueError: when a conversion happens and ``num_output_channels`` is
            neither 1 nor 3.
    """

    def __init__(self, p=0.1, num_output_channels=1):
        self.p = p
        self.num_output_channels = num_output_channels
        self.v = random.random()

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.
        Returns:
            PIL Image: Randomly grayscaled image.
        """
        if not self.v < self.p:
            return img
        if self.num_output_channels == 1:
            return img.convert('L')
        if self.num_output_channels == 3:
            gray = np.array(img.convert('L'), dtype=np.uint8)
            return Image.fromarray(np.dstack([gray, gray, gray]), 'RGB')
        raise ValueError('num_output_channels should be either 1 or 3')

    def randomize_parameters(self):
        self.v = random.random()

In [None]:
# Every frame is resized to exactly 128x128 and byte pixels are scaled to [0, 1].
transform = Compose([Scale((128, 128)), ToTensor(255)])

# Settings shared by the positive and negative datasets; only `label` differs.
# NOTE(review): both datasets read the same candidate file; presumably `label`
# selects which candidates each one yields -- confirm against DataSet.
dataset_kwargs = dict(
    image_num=6,
    balance_pos_neg=False,
    center_crop_transform=None,
    spatial_transform=transform,
)
pos_data = DataSet('/app/data/ClipShots/frames/train', '/app/data/cut_candidates.txt',
                   label=1, **dataset_kwargs)
neg_data = DataSet('/app/data/ClipShots/frames/train', '/app/data/cut_candidates.txt',
                   label=0, **dataset_kwargs)

# Evaluation loaders: fixed order, loaded in the main process, 128 samples per batch.
loader_kwargs = dict(shuffle=False, num_workers=0, batch_size=128)
pos_test_dataloader = DataLoader(pos_data, **loader_kwargs)
neg_test_dataloader = DataLoader(neg_data, **loader_kwargs)

## Test Model

In [None]:
# Mean cross-entropy over the batch (the PyTorch default reduction, spelled out).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
def calculate_accuracy(outputs, targets):
    """Return the fraction of samples whose top-1 prediction matches the target.

    Args:
        outputs (Tensor): scores with classes along dim 1.
        targets (Tensor): ground-truth class indices, one per sample.

    Returns:
        float: correct predictions divided by ``targets.size(0)``.
    """
    top1 = outputs.topk(1, 1, True)[1].t()
    hits = top1.eq(targets.view(1, -1)).float().sum().item()
    return hits / targets.size(0)

def prf1_confusion_matrix(outputs, targets):
    """Precision, recall, F1 and confusion counts for one batch, class 1 positive.

    Correct predictions count as TP (predicted 1) or TN (otherwise); wrong
    predictions count as FP (predicted 1) or FN (otherwise).

    Args:
        outputs (Tensor): scores with classes along dim 1; the top-1 index
            along dim 1 is the prediction.
        targets (Tensor): ground-truth class indices, one per sample (any
            shape; it is flattened).

    Returns:
        tuple: ``(precision, recall, f1, tp, tn, fp, fn)``. The counts are
        floats; each ratio is 0.0 when its denominator is zero.
    """
    # reshape(-1) rather than squeeze(): squeeze() turns a single-sample batch
    # into a 0-d tensor whose tolist() is a bare int, which zip() rejects.
    preds = outputs.topk(1, 1, True)[1].reshape(-1).detach().cpu().tolist()
    gt = targets.reshape(-1).detach().cpu().tolist()

    tp = tn = fp = fn = 0.
    for truth, pred in zip(gt, preds):
        if pred == 1:
            if truth == pred:
                tp += 1.
            else:
                fp += 1.
        else:
            if truth == pred:
                tn += 1.
            else:
                fn += 1.

    precision = tp / (tp + fp) if tp + fp != 0 else 0.0
    recall = tp / (tp + fn) if tp + fn != 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0.0

    return (precision, recall, f1, tp, tn, fp, fn)

In [None]:
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is zero."""
    return numerator / denominator if denominator else 0


def test_model(model, pos_data_loader, neg_data_loader, criterion):
    """Evaluate ``model`` on the positive and negative candidate loaders.

    Prints per-batch loss, accuracy, precision, recall, F1 and confusion
    counts, then the aggregate statistics over all batches.

    Args:
        model: classifier; moved to the global ``device`` and set to eval mode.
        pos_data_loader: loader yielding ``(imgs, targets)``; the tensors in
            ``imgs`` are concatenated along dim 1 to form the model input.
        neg_data_loader: same as ``pos_data_loader``, evaluated second.
        criterion: loss called as ``criterion(outputs, targets)``.

    Returns:
        tuple: ``(all_outputs, all_targets)``, lists holding the outputs and
        targets of the first evaluated batch only (for spot checks).
    """
    print('testing model')
    model = model.to(device)
    model = model.eval()

    num_batches = 0
    total_batches = len(pos_data_loader) + len(neg_data_loader)
    total_loss = 0
    # NOTE: total_acc sums per-batch accuracies, so the average below weights
    # a smaller final batch the same as a full one.
    total_acc = 0
    total_tp = 0
    total_tn = 0
    total_fp = 0
    total_fn = 0

    all_outputs = []
    all_targets = []

    # Evaluation never back-propagates: disabling autograd saves memory and
    # keeps the stored outputs free of computation graphs.
    with torch.no_grad():
        for loader in (pos_data_loader, neg_data_loader):
            for imgs, targets in loader:
                img_concat = torch.cat(list(imgs), 1).to(device)
                targets = targets.long().view(-1).to(device)

                outputs = model(img_concat)

                # Keep only the first batch for later inspection.
                if not all_outputs:
                    all_outputs.append(outputs)
                    all_targets.append(targets)

                loss = criterion(outputs, targets)
                acc = calculate_accuracy(outputs, targets)
                precision, recall, f1, tp, tn, fp, fn = prf1_confusion_matrix(outputs, targets)

                print('Batch: [{0}/{1}]\t'
                      'Loss_conf {loss_c:.4f}\t'
                      'acc {acc:.4f}\t'
                      'pre {pre:.4f}\t'
                      'rec {rec:.4f}\t'
                      'f1 {f1: .4f}\t'
                      'TP {tp} '
                      'TN {tn} '
                      'FP {fp} '
                      'FN {fn} '
                      .format(
                          num_batches + 1, total_batches, loss_c=loss.item(), acc=acc,
                          pre=precision, rec=recall, f1=f1,
                          tp=tp, tn=tn, fp=fp, fn=fn))

                total_loss += loss.item()
                total_acc += acc
                total_tp += tp
                total_tn += tn
                total_fp += fp
                total_fn += fn
                num_batches += 1

    # Average over the batches actually seen (equal to total_batches for
    # ordinary loaders). Every ratio is guarded so empty loaders or a model
    # that never predicts the positive class do not raise ZeroDivisionError.
    avg_loss = _safe_div(total_loss, num_batches)
    avg_acc = _safe_div(total_acc, num_batches)

    final_precision = _safe_div(total_tp, total_tp + total_fp)
    final_recall = _safe_div(total_tp, total_tp + total_fn)
    final_f1 = _safe_div(2 * (final_precision * final_recall), final_precision + final_recall)

    print('Final stats\t'
          'Loss {loss:.4f}\t'
          'acc {acc:.4f}\t'
          'pre {pre:.4f}\t'
          'rec {rec:.4f}\t'
          'f1 {f1: .4f}\t'
          'TP {tp} '
          'TN {tn} '
          'FP {fp} '
          'FN {fn} '
          .format(loss=avg_loss, acc=avg_acc, pre=final_precision,
              rec=final_recall, f1=final_f1, tp=total_tp, tn=total_tn, fp=total_fp,
              fn=total_fn))

    return all_outputs, all_targets

In [None]:
# Evaluate the hard-cut model; the returned lists hold the first batch's
# outputs and targets for spot checks.
outputs, targets = test_model(
    model=cut_model,
    pos_data_loader=pos_test_dataloader,
    neg_data_loader=neg_test_dataloader,
    criterion=criterion,
)

In [None]:
# Preview only the first 10 entries (along dim 0) of the first evaluated batch
# instead of dumping the whole batch; Tensor.tolist() avoids a NumPy round-trip.
outputs[0][:10].detach().cpu().tolist()

In [None]:
# Matching preview: the first 10 targets of the first evaluated batch.
targets[0][:10].detach().cpu().tolist()

In [None]:
# Bind the returned first-batch outputs/targets to names rather than letting
# the cell echo the whole tuple of tensors.
outputs_5_epochs, targets_5_epochs = test_model(
    cut_model_5_epochs, pos_test_dataloader, neg_test_dataloader, criterion)

# Part 4: Gradual Cut Detection

TBD — gradual-transition detection has not been implemented yet.