In [None]:
from __future__ import print_function, division
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import pathlib
import pickle
import random
import scipy.stats
import torch
from torch import nn
import torchvision
import tqdm

import dsbfetch
import dsbtorch

# Turn on matplotlib's interactive mode for the rest of the session.
plt.ion()   # interactive mode

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Directory of pre-embedded samples to evaluate (the "test" subdirectory).
dataset_dir = pathlib.Path("/home/ubuntu/data/encoded_dataset_511/test")

# Fail fast with a readable message when the hard-coded path is missing on
# this machine, rather than relying on however the dataset class reacts.
if not dataset_dir.is_dir():
    raise FileNotFoundError("Embedded dataset directory not found: %s" % dataset_dir)

dataset = dsbtorch.PreEmbeddedDataset(dataset_dir)

In [None]:
def collate_fn(batches):
    """Split a sequence of samples into two parallel lists.

    Each sample is indexed at positions 0 and 1; the elements are gathered
    into plain Python lists without any stacking or padding.

    Args:
        batches: Sequence of indexable samples (each with at least two items).

    Returns:
        A tuple ``(first_items, second_items)`` of two lists, in input order.
    """
    first_items = [sample[0] for sample in batches]
    second_items = [sample[1] for sample in batches]
    return first_items, second_items

# Number of samples handed to the decoder per step.
batch_size = 64

# shuffle=False is the DataLoader default; it is spelled out because a later
# cell pairs the evaluation results with `dataset.videos` by position.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=6,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Total number of samples in the dataset being evaluated.
dataset_sizes = len(dataset)

In [None]:
def eval_model(rnn_decoder):
    """Run `rnn_decoder` over the module-level `dataloader` and score it.

    The decoder is put in eval mode and run with gradients disabled. For each
    sample, the predicted start/end positions are the argmax (over the
    sequence axis) of the decoder's two outputs; the reference positions are
    the argmax of the labels returned by `dsbtorch.get_start_and_end_labels`.

    Args:
        rnn_decoder: Module called with a PackedSequence; it must return a
            `(start, end)` pair of PackedSequences.

    Returns:
        A tuple `(ranges, ious)`:
          * `ranges` - list of `(start_index, end_index)` int pairs, one per
            sample, in dataloader order.
          * `ious` - list extended with whatever
            `dsbtorch.compute_IOU_from_indices` returns for each batch
            (presumably one IOU per sample - confirm against dsbtorch).
    """
    rnn_decoder.eval()

    ranges = []
    ious = []

    with torch.no_grad():
        for cnn_outputs, labels in tqdm.tqdm(dataloader):
            # Samples have different lengths, so pack them for the RNN.
            cnn_outputs = nn.utils.rnn.pack_sequence(cnn_outputs, enforce_sorted=False).to(device)

            # Get the RNN labels
            padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
            padded_start_labels, padded_end_labels = dsbtorch.get_start_and_end_labels(padded_labels)

            # NOTE(review): the labels are never moved to `device` here, so
            # these indices stay wherever the dataloader produced them while
            # the predictions below live on `device`; this assumes
            # compute_IOU_from_indices copes with that - confirm.
            start_idxs = torch.argmax(padded_start_labels, dim=-1)
            end_idxs = torch.argmax(padded_end_labels, dim=-1)

            start_probs_ps, end_probs_ps = rnn_decoder(cnn_outputs)

            # Pad with -inf (instead of the default 0) so that a padded time
            # step can never win the argmax below, even if the decoder emits
            # non-positive scores. The padded tensors are only used for argmax.
            start_probs_all, start_lengths = nn.utils.rnn.pad_packed_sequence(
                start_probs_ps, batch_first=True, padding_value=float("-inf"))
            end_probs_all, end_lengths = nn.utils.rnn.pad_packed_sequence(
                end_probs_ps, batch_first=True, padding_value=float("-inf"))

            # Drop the trailing dimension (only removed if it has size 1).
            start_probs_all = torch.squeeze(start_probs_all, dim=-1)
            end_probs_all = torch.squeeze(end_probs_all, dim=-1)

            # Both outputs come from the same packed input, so their
            # per-sample lengths must agree.
            assert torch.all(start_lengths == end_lengths)

            start_preds_idx = torch.argmax(start_probs_all, dim=-1)
            end_preds_idx = torch.argmax(end_probs_all, dim=-1)
            ranges.extend(zip(start_preds_idx.tolist(), end_preds_idx.tolist()))
            ious.extend(dsbtorch.compute_IOU_from_indices(start_preds_idx, end_preds_idx, start_idxs, end_idxs))

    return ranges, ious

In [None]:
# Build the decoder with its saved weights and move it to `device`.
# The positional 2048 is presumably the embedding feature size - confirm
# against dsbtorch.PreprocessedEncoderDecoder.
decoder = dsbtorch.PreprocessedEncoderDecoder(
    2048,
    weights_path="/home/ubuntu/data/DeepSponsorBlock/results/preprocessed_encoder_decoder.weights",
).to(device)

In [None]:
# Evaluate the decoder over the whole dataloader: `ranges` holds the predicted
# (start, end) index pair per sample, `ious` the corresponding IOU scores.
ranges, ious = eval_model(decoder)

In [None]:
# Read the SponsorBlock segments file and index each entry's segments by its
# video id, so they can be looked up per evaluated video.
labeled_videos = dsbfetch.load_segments("../segments.csv")
labels = {
    entry.video_id: entry.segments
    for entry in labeled_videos
}

In [None]:
# Compile the videos we evaluated.
# The video id is the embedding file's name up to its first '.'.
videos = [emb_file.stem.split('.')[0] for emb_file, _ in dataset.videos]

# zip() silently truncates to the shortest input, which would mis-report
# results without any error; make a length mismatch loud instead.
assert len(videos) == len(ranges) == len(ious), (
    "videos/ranges/ious lengths differ: %d / %d / %d" % (len(videos), len(ranges), len(ious))
)
evaluated = list(zip(videos, ranges, ious))

# Inspect at most 50 videos: random.sample raises ValueError when asked for
# more items than the population holds. A dedicated seeded generator keeps the
# selection identical across re-runs without touching the global random state.
sample_size = min(50, len(evaluated))
pairs = sorted(random.Random(0).sample(evaluated, sample_size), key=lambda item: -item[2])

In [None]:
# Get YouTube links: one (embed URL, IOU, labelled segments) tuple per sampled
# video. The URL's `end` parameter is the predicted end index plus one.
links = [
    (
        "https://www.youtube.com/embed/%s?start=%d&end=%d" % (video_id, first, last + 1),
        score,
        labels[video_id],
    )
    for video_id, (first, last), score in pairs
]
# Print one tuple per line.
print(*links, sep="\n")

In [None]:
# Summary statistics of the IOU scores over every evaluated sample
# (all of `ious`, not just the sampled `pairs`).
print("Mean:", np.mean(ious))
print("Median:", np.median(ious))

In [None]:
# Plot the histogram.
# Apply the style *before* creating the figure: style settings that are read
# at figure creation only affect figures made afterwards, so with the style
# applied after plt.figure() the first run and a re-run of this cell could
# render differently.
plt.style.use('grayscale')
plt.figure(dpi=300)

plt.hist(ious, bins=25)

title = "Number of Videos by IOU"
plt.title(title)
# Label the axes so the figure can be read on its own.
plt.xlabel("IOU")
plt.ylabel("Number of videos")

plt.show()