# Extract features

In [1]:
cd ..

/home/beckmann/fairseq


In [2]:
import os
# GPU selection via environment variables. This cell runs before the cell
# that imports torch, so the values are in place when CUDA is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1,2",
})

In [3]:
import torch
from glob import glob
import pickle
import numpy as np
from IPython.display import Image, display
import mlflow
from fairseq.data import dictionary
from sklearn.metrics import pairwise_distances_argmin

In [4]:
from fairseq.data import (
    data_utils,
    Dictionary,
    PadDataset,
    PrependTokenDataset,
    TokenBlockDataset,
    EmbeddingDataset
)
from fairseq.data import iterators

In [5]:
from fairseq.models.roberta import RobertaModel
# Load the trained checkpoint. Positional arguments: checkpoint directory,
# checkpoint file name, and the data directory.
roberta = RobertaModel.from_pretrained(
    'checkpoints/',
    'checkpoint_best.pt',
    '/mnt/tamedia/video_concierge/new_imnet_10k',
)
roberta.cuda()
# Evaluation mode disables dropout; without it the extracted features differ
# from run to run and nearest-neighbour results are not reproducible.
roberta.eval()
assert isinstance(roberta.model, torch.nn.Module)

In [6]:
def load_datasets(split_path, thresh=100, max_positions=512):
    """Load the token, embedding and count datasets for one data split.

    Args:
        split_path: Path prefix of the split (e.g. ``.../valid``). ``dict.txt``
            is read from the directory containing the split, and
            ``<split_path>.features`` / ``<split_path>.counts`` are read
            alongside the indexed token data.
        thresh: Upper bound for count values; larger counts are clipped to it.
        max_positions: Maximum block length including the prepended ``<s>``
            token (token blocks themselves hold ``max_positions - 1`` tokens).

    Returns:
        Tuple ``(token_dataset, embedding_dataset, count_dataset)``.
    """
    # TOKEN DATASET
    # os.path.dirname handles paths without a directory part (or directly under
    # the root) correctly, unlike splitting on '/' by hand.
    vocab = Dictionary.load(os.path.join(os.path.dirname(split_path), 'dict.txt'))
    token_dataset = data_utils.load_indexed_dataset(
        split_path,
        vocab,
        None,
        combine=False,
    )
    # create continuous blocks of tokens
    token_dataset = TokenBlockDataset(
        token_dataset,
        token_dataset.sizes,
        max_positions - 1,  # one less for <s>
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode='eos',
    )
    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    token_dataset = PrependTokenDataset(token_dataset, vocab.bos())
    token_dataset = PadDataset(token_dataset, pad_idx=vocab.pad(), left_pad=False)

    # EMBEDDING DATASET
    # NOTE(review): torch.load unpickles the file -- only use trusted feature files.
    embs = torch.load(split_path + '.features')
    embedding_dataset = EmbeddingDataset(embs, pad_idx=0, left_pad=False)

    # COUNT DATASET
    # One line per item, space-separated integers. Each value is clipped to
    # `thresh` and the sequence is wrapped in a leading and trailing 0
    # (presumably lining up with the <s> and end-of-sequence positions of the
    # token dataset -- TODO confirm).
    counts = []
    with open(split_path + '.counts') as count_file:
        for line in count_file:
            clipped = [min(int(el), thresh) for el in line.rstrip().split(' ')]
            counts.append(torch.LongTensor([0] + clipped + [0]))
    count_dataset = PadDataset(counts, pad_idx=0, left_pad=False)

    return token_dataset, embedding_dataset, count_dataset

In [7]:
# Load the three datasets (tokens, embeddings, counts) of the validation split.
valid_split_path = '/mnt/tamedia/video_concierge/new_imnet_10k/valid'
datasets = load_datasets(valid_split_path)
token_dataset, embedding_dataset, count_dataset = datasets

In [27]:
epoch_size = len(token_dataset)

# batch sampler: consecutive, non-overlapping index batches that cover every
# item exactly once. The last batch is shorter when epoch_size is not a
# multiple of batch_size, and the list is empty for an empty dataset.
batch_size = 1
batch_sampler = [
    list(range(start, min(start + batch_size, epoch_size)))
    for start in range(0, epoch_size, batch_size)
]

# iterators
def _make_epoch_iterator(dataset):
    """Return an unshuffled epoch iterator over `dataset` driven by the shared batch_sampler."""
    epoch_batch_iterator = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
    )
    return epoch_batch_iterator.next_epoch_itr(shuffle=False)


# All three iterators use the same batch_sampler, so they yield batches for
# the same indices in the same order.
token_iterator = _make_epoch_iterator(token_dataset)
embedding_iterator = _make_epoch_iterator(embedding_dataset)
count_iterator = _make_epoch_iterator(count_dataset)

In [15]:
bert_features = []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for token, embedding, count in zip(token_iterator, embedding_iterator, count_iterator):
        features = roberta.extract_features(token, count, embedding)
        # Mean over dim 0, then keep index 0 of the result.
        # NOTE(review): assumes features is (batch, seq, dim); with batch size 1
        # this keeps the vector at the first (<s>) position -- confirm intent.
        bert_features.append(torch.mean(features, dim=0)[0].detach().cpu().numpy())

# Visualize close videos

In [16]:
def show_keyframes_for_vid(video_path):
    """Display every file found directly inside `video_path`, in sorted path order.

    Args:
        video_path: Directory path (string) holding the keyframe images of one video.
    """
    frame_paths = glob(video_path + '/*')
    frame_paths.sort()
    for frame_path in frame_paths:
        keyframe = Image(filename=frame_path)
        display(keyframe)

In [17]:
def pairwise_distances_argmin_gpu(x, y, bsz=10000, cuda_device='cuda:0'):
    """For each row of `x`, return the index of the closest row of `y`.

    Distances are computed with `torch.cdist` (its default p=2 norm) on the
    given device, processing `x` in chunks of `bsz` rows to bound memory use.

    Args:
        x: 2-D array-like of query vectors, shape (n_x, dim).
        y: 2-D array-like of candidate vectors, shape (n_y, dim).
        bsz: Number of rows of `x` processed per chunk.
        cuda_device: Device string passed to `torch.device`.

    Returns:
        Integer numpy array of length n_x; entry i is the row index in `y`
        nearest to `x[i]`.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    device = torch.device(cuda_device)
    # Integer dtype: the results are indices and can be used for indexing directly.
    argmins = np.zeros(len(x), dtype=np.int64)
    yy = torch.from_numpy(y).float().to(device)
    for start in range(0, len(x), bsz):
        xx = torch.from_numpy(x[start:start + bsz, :]).float().to(device)
        nearest = torch.argmin(torch.cdist(xx, yy), dim=1).cpu().numpy()
        argmins[start:start + len(nearest)] = nearest
        # Release the chunk before the next one is moved to the device.
        del xx
    return argmins

In [18]:
# Collect the entries directly under the validation keyframe directory, sorted by path.
# NOTE(review): later cells index this list with the same position as the
# feature lists -- presumably the sorted order matches the dataset order; confirm.
path = '/mnt/tamedia/video_concierge/bert_data/kf_data/valid'
videos = glob('{}/*'.format(path))
videos.sort()

In [None]:
# Show a query video, then the video whose BERT feature is closest to it.
vidnum = 26
show_keyframes_for_vid(videos[vidnum])

# The query is removed from the candidates so it cannot match itself.
query_feature = np.expand_dims(bert_features[vidnum], axis=0)
candidate_features = np.delete(bert_features, vidnum, axis=0)
argmins = pairwise_distances_argmin_gpu(query_feature, candidate_features, bsz=1)

print('_____________________________________________________________________________________________________________')
nearest_idx = int(argmins[0])
print(nearest_idx)
# Index into the video list with the query removed, mirroring candidate_features.
candidate_videos = np.delete(videos, vidnum)
show_keyframes_for_vid(candidate_videos[nearest_idx])

In [28]:
# compare with averaged features
# Build a fresh iterator here: the embedding iterator created earlier may
# already have been consumed by the feature-extraction loop, in which case
# iterating it again would yield nothing on a top-to-bottom run.
fresh_embedding_iterator = iterators.EpochBatchIterator(
    dataset=embedding_dataset,
    collate_fn=embedding_dataset.collater,
    batch_sampler=batch_sampler,
).next_epoch_itr(shuffle=False)

averaged_features = []
for embs in fresh_embedding_iterator:
    # Mean over dim 1, then keep the first entry of the batch.
    # NOTE(review): assumes embs is (batch, seq, dim) -- confirm.
    averaged_features.append(torch.mean(embs, dim=1)[0].detach().cpu().numpy())

In [None]:
# Same nearest-neighbour lookup as for the BERT features, using the averaged embeddings.
vidnum = 26
show_keyframes_for_vid(videos[vidnum])

# The query is removed from the candidates so it cannot match itself.
avg_query = np.expand_dims(averaged_features[vidnum], axis=0)
avg_candidates = np.delete(averaged_features, vidnum, axis=0)
argmins = pairwise_distances_argmin_gpu(avg_query, avg_candidates, bsz=1)

print('_____________________________________________________________________________________________________________')
nearest_avg_idx = int(argmins[0])
print(nearest_avg_idx)
# Index into the video list with the query removed, mirroring avg_candidates.
remaining_videos = np.delete(videos, vidnum)
show_keyframes_for_vid(remaining_videos[nearest_avg_idx])