In [None]:
from __future__ import print_function, division
import sys
sys.path.append("../")

In [None]:
import numpy as np
import os
import pathlib
import pickle
import torch
import torchvision
import tqdm

import dsbtorch

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Switch torchvision's image-loading backend to accimage.
torchvision.set_image_backend('accimage')

In [None]:
# Root directory of the raw dataset; each split lives in a sub-directory
# named after the split.
data_dir = "/home/ubuntu/data/dataset/"

# Splits to process, in processing order.
dataset_names = [
    'train',
    'dev',
    'test',
]

In [None]:
# Load the pickled scan result of every split into `scanned_datasets`,
# keyed by split name.
scanned_datasets = {}
for split in dataset_names:
    # Keep the string concatenation: data_dir is an absolute path, so
    # pathlib.Path("scans") / data_dir would discard the "scans" prefix.
    sddir = pathlib.Path("scans/" + os.path.join(data_dir, split))
    sdfile = sddir.joinpath("ws1-psr1-nsr1.pkl")
    print(sdfile)
    if sdfile.exists():
        # pickle.load can execute arbitrary code contained in the file; only
        # load scan files that were produced locally.
        with sdfile.open('rb') as f:
            scanned_datasets[split] = pickle.load(f)
    else:
        raise ValueError("You need to use the ScanDatasets notebook first to scan & pickle the dataset.")

In [None]:
# One sliding-window dataset per split, built from the scanned index and the
# library's default transform.
datasets = {
    split: dsbtorch.VideoSlidingWindowDataset(
        scanned_datasets[split],
        dsbtorch.DEFAULT_TRANSFORM,
    )
    for split in dataset_names
}

In [None]:
# Root directory under which the encoded (.emb / .lbl) arrays are written;
# the encoding loop mirrors each path's location relative to data_dir below it.
output_dir = "/home/ubuntu/data/encoded_dataset/"

In [None]:
# Build the CNN encoder from the saved weights file and move it to `device`.
weights_path = "/home/ubuntu/data/DeepSponsorBlock/results/rescnn.weights"
model = dsbtorch.ResCNN(weights_path).to(device)

# The encoder is used for inference only: freeze every parameter.
for weight in model.parameters():
    weight.requires_grad_(False)

# Switch to evaluation mode (kept as the last expression so the notebook
# still displays the returned value).
model.eval()

In [None]:
# Number of dataset items per batch.
batch_size = 1024

# One loader per split. `shuffle` is left at its default (False): the encoding
# loop slices the output stream per video and therefore needs dataset order.
dataloaders = {
    split: torch.utils.data.DataLoader(
        datasets[split],
        batch_size=batch_size,
        num_workers=6,
        pin_memory=True,
    )
    for split in dataset_names
}

# Number of items in each split.
dataset_sizes = {split: len(datasets[split]) for split in dataset_names}

In [None]:
# Encode every split with the frozen CNN and write, for each entry of
# `sd.cumulative_dirs`, one array of encoder outputs and one array of labels.
#
# Batches arrive in dataset order, so encoder outputs and labels are buffered
# until at least as many rows are pending as the next entry needs; the buffer
# is then split at that boundary, the head is saved and the tail is kept.
with torch.no_grad():
    for split in dataset_names:
        sd = scanned_datasets[split]

        # Output path stem per entry: same location relative to data_dir, but
        # rooted at output_dir (str / Path works through Path's reflected "/").
        out_files = [output_dir / dir_list[0].parent.relative_to(data_dir) for dir_list in sd.cumulative_dirs]
        # Number of dataset indices belonging to each entry, as plain ints.
        lengths = np.diff(np.array(sd.cumulative_indices + [sd.n_indices])).tolist()

        # Reverse them to use as a stack.
        out_files.reverse()
        lengths.reverse()

        encoder_outputs = []
        acc_labels = []
        n_pending = 0  # rows currently buffered in encoder_outputs
        for imgs, lbls in tqdm.tqdm(dataloaders[split]):
            # Flatten any leading dimensions into one batch of frames.
            # NOTE(review): assumes 3x144x256 frames — confirm against the
            # transform used to build the datasets.
            imgs = torch.reshape(imgs, (-1, 3, 144, 256)).to(device)
            batch_outputs = model(imgs).cpu()
            encoder_outputs.append(batch_outputs)
            acc_labels.append(torch.reshape(lbls, (-1, )))
            n_pending += batch_outputs.shape[0]

            # A single batch may complete several entries, hence the loop.
            while lengths and n_pending >= lengths[-1]:
                out_path = out_files.pop()
                length = lengths.pop()
                out_path.parent.mkdir(parents=True, exist_ok=True)

                combined_encoder_outputs = torch.cat(encoder_outputs)
                combined_labels = torch.cat(acc_labels)

                cnn_outputs = combined_encoder_outputs[:length]
                labels = combined_labels[:length]

                # np.save appends ".npy", so the files end up as
                # "<stem>.emb.npy" and "<stem>.lbl.npy".
                # NOTE(review): with_suffix replaces an existing suffix, so a
                # stem containing "." would lose its last dotted component —
                # confirm directory names never contain dots.
                np.save(out_path.with_suffix('.emb'), cnn_outputs.numpy())
                np.save(out_path.with_suffix('.lbl'), labels.numpy())

                # Keep whatever belongs to the following entries.
                encoder_outputs = [combined_encoder_outputs[length:]]
                acc_labels = [combined_labels[length:]]
                n_pending -= length

        # Every entry must have been written and no row may be left over;
        # otherwise the scan and the encoded stream disagree and the files
        # written for this split cannot be trusted to be aligned.
        if lengths or n_pending:
            raise RuntimeError(
                "Encoded row count for split '{}' does not match the scanned dataset: "
                "{} entries were not written and {} encoded rows were left over.".format(
                    split, len(lengths), n_pending
                )
            )