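# NLP with RoBERTa

Encode each movie's combined TMDB overview and OMDb plot with the pretrained `roberta.base` model, average the token features into one vector per movie, and save the result to `data/engineering/roberta.csv`.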

In [1]:
import torch
import json
import torch.nn.functional as F
import pandas as pd
from fairseq.data.data_utils import collate_tokens
from tqdm.auto import tqdm
import numpy as np

In [2]:
# Load the parsed OMDb and TMDB dumps. Context managers close the file
# handles as soon as parsing is done instead of leaking them to the GC.
with open("../../../../data/parsed/omdb.json", "r") as omdb_file:
    omdb = json.load(omdb_file)
with open("../../../../data/parsed/tmdb.json", "r") as tmdb_file:
    tmdb = json.load(tmdb_file)

In [3]:
# Number of texts encoded per forward pass.
batch_size = 4
# Device used for the model and token tensors. The name `cuda` is kept for the
# cells below, but fall back to the CPU when no GPU is available so the
# notebook still runs (slowly) on such machines.
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Build one text per movie: the TMDB overview followed by the OMDb plot.
# Each entry is an (id, text, character length) tuple.
plots = []
for movie_id in tmdb.keys():
    # `.get(..., '')` only covers a missing key; a JSON null comes back as
    # None and would break the concatenation, hence the extra `or ''`.
    omdb_plot = omdb[movie_id]['omdb'].get('Plot', '') or ''
    tmdb_plot = tmdb[movie_id]['tmdb'].get('overview', '') or ''
    plot = tmdb_plot + ' ' + omdb_plot
    plots.append((movie_id, plot, len(plot)))

# Keep only texts longer than 4 characters, ordered by ascending length
# (the sort is stable, so ties keep their original order).
# NOTE(review): the threshold 4 presumably matches len(' N/A'), i.e. an empty
# overview plus an OMDb 'N/A' placeholder — confirm.
plots = sorted(
    (entry for entry in plots if entry[2] > 4),
    key=lambda entry: entry[2],
)

def chunks(l, n):
    """Lazily yield consecutive slices of `l`, each at most `n` items long.

    The final slice is shorter when len(l) is not a multiple of `n`.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))

# Separate the (id, text, length) tuples into two parallel lists and cut both
# into batches of `batch_size`, so ids[k] lines up with plots[k].
# `ids` must be derived first, because `plots` is overwritten right after.
# NOTE: this cell is not idempotent; it expects `plots` to still hold tuples.
ids = list(chunks([entry[0] for entry in plots], batch_size))
plots = list(chunks([entry[1] for entry in plots], batch_size))

In [5]:
# Fetch the pretrained roberta.base model through torch.hub and move it to
# the configured device.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
roberta = roberta.to(cuda)
# Put the model in evaluation mode.
roberta.eval()
# Ending the cell with print() avoids echoing the value of the previous line.
print()

Using cache found in /home/dev/.cache/torch/hub/pytorch_fairseq_master


loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz from cache at /home/dev/.cache/torch/pytorch_fairseq/37d2bc14cf6332d61ed5abeb579948e6054e46cc724c7d23426382d11a31b2d6.ae5852b4abc6bf762e0b6b30f19e741aa05562471e9eb8f4a6ae261f04f9b350
| dictionary: 50264 types



In [6]:
# Maps each id to its pooled feature vector (a numpy array); filled in by
# extract_features.
fs = {}

def extract_features(batch, ids):
    """Encode a batch of texts with RoBERTa and store one pooled vector per id.

    Args:
        batch: sequence of text strings to encode.
        ids: sequence of keys aligned with `batch`; the vector computed for
            batch[k] is stored under ids[k].

    Returns:
        None. Results are written into the module-level dict `fs`.
    """
    tokens = collate_tokens([roberta.encode(sent) for sent in batch], pad_idx=1).to(cuda)
    # Keep at most the first 512 token positions
    # (presumably the model's maximum input length — confirm).
    tokens = tokens[:, :512]
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        # assumes features is (batch, seq, dim) — TODO confirm
        features = roberta.extract_features(tokens)
        # Average over axis 1; padded positions are part of that average.
        # squeeze(1) removes only the pooled axis. A bare squeeze() would also
        # drop the batch axis when the batch holds a single text, and the loop
        # below would then iterate over features instead of rows.
        pooled_features = F.avg_pool2d(features, (features.size(1), 1)).squeeze(1)
    # One device-to-host copy per batch instead of one per row.
    pooled_features = pooled_features.cpu().numpy()
    for i in range(pooled_features.shape[0]):
        fs[ids[i]] = pooled_features[i]

In [7]:
# Walk the batches in reverse order; the lists are sorted by ascending text
# length, so the longest texts are encoded first.
# The loop variables deliberately do not reuse the name `ids`: that would
# overwrite the global list of id batches with a single batch, so re-running
# this cell would pair the plots with the wrong ids.
for batch_plots, batch_ids in tqdm(zip(reversed(plots), reversed(ids)), total=len(plots)):
    extract_features(batch_plots, batch_ids)

HBox(children=(IntProgress(value=0, max=6779), HTML(value='')))




In [10]:
# pd.DataFrame(fs) puts one column per id; transposing gives one row per id
# with the feature dimensions as columns.
transformed = pd.DataFrame(fs).transpose()

In [11]:
# Preview the first five rows of the feature table.
transformed.head(n=5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
94314,-0.003021,0.042953,0.041073,-0.070172,0.041283,-0.022052,0.042337,0.061337,0.05694,-0.042828,...,-0.019826,-0.057223,-0.013783,-0.021975,0.013032,0.071649,-0.133055,-0.137458,0.038325,-0.021989
108947,0.005153,0.085032,0.004828,-0.130353,0.182427,0.118248,-0.002422,0.0791,-0.001717,-0.008349,...,-0.071296,-0.025807,0.000577,0.027177,0.099686,0.111841,-0.128115,-0.121794,0.03955,-0.009374
116851,-0.026172,-0.006734,0.027957,-0.177829,0.177694,-0.001653,-0.037998,0.140078,0.003422,0.011289,...,-0.067595,-0.061443,-0.047832,-0.00123,0.085797,0.107863,-0.048223,-0.033383,0.032022,-0.0061
3373,-0.042934,0.04532,0.030206,-0.081032,0.100015,0.094492,-0.050629,0.014997,0.050944,-0.083253,...,-0.039333,-0.119122,-0.069987,0.0243,-0.018089,0.065861,-0.136187,-0.23549,0.026261,-0.034703
102154,-0.065819,0.067415,0.064005,-0.045871,0.119665,0.039735,-0.02474,0.00926,0.051017,-0.023419,...,-0.030069,-0.032542,-0.062198,-0.025295,0.028602,0.085711,-0.192206,-0.226985,0.032895,-0.035906


In [12]:
# Persist the feature table as CSV.
# NOTE(review): index=False omits the row index, which holds the ids the
# vectors were stored under — confirm downstream code does not need them.
transformed.to_csv(path_or_buf='../../../../data/engineering/roberta.csv', index=False)