# Test increased input size

So far we've been stuck with an input size of 1024 tokens because of memory constraints.

Let's try to increase this by shrinking other parts of the model, using multiple GPUs, or using a long-input variant of the model.

### Longformer
Let's use the `Longformer` model from prior work, which relies mainly on local (sliding-window) self-attention rather than full all-pairs attention to reduce memory use and training time.

Details [here](https://huggingface.co/transformers/model_doc/longformer.html).
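
To see why local attention helps with memory, here is a rough back-of-the-envelope sketch that counts attention scores per head for full attention versus a sliding window. The helper `attention_score_counts` is hypothetical, and the window size of 512 is an assumption about the base checkpoint (check `model.config.attention_window` for the actual value). It ignores global-attention tokens and implementation overhead.

```python
def attention_score_counts(seq_len, window):
    """Count attention scores per head for full vs. sliding-window attention.

    Hypothetical helper for a rough estimate only.
    """
    # full attention: every token attends to every other token
    full = seq_len * seq_len
    # sliding window: every token attends to at most `window` neighbours
    local = seq_len * min(window, seq_len)
    return full, local


# window=512 is an assumed value, not read from the loaded model
for seq_len in (1024, 4096):
    full, local = attention_score_counts(seq_len, window=512)
    print(f"seq_len={seq_len}: full={full:,} local={local:,} ratio={full / local:.1f}x")
```

The full-attention count grows quadratically with sequence length, while the windowed count grows linearly, which is what makes 4096-token inputs feasible.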

In [2]:
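from transformers import LongformerConfig, LongformerModel
# default config; from_pretrained below loads the checkpoint's own config instead
config = LongformerConfig()
# cache downloaded weights locally so that repeated runs don't re-download them
cache_dir = '../../data/longformer_cache/'
model = LongformerModel.from_pretrained('allenai/longformer-base-4096', cache_dir=cache_dir)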

In [3]:
from transformers import LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', cache_dir=cache_dir)
test_input_str = ['this is an input sentence', 'this is another input sentence']
# encode_plus returns a dict with both the token IDs and the attention mask
test_input_tokens = tokenizer.encode_plus(test_input_str[0])
print(test_input_tokens)

{'input_ids': [0, 9226, 16, 41, 8135, 3645, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}


In [7]:
## test max document length
import numpy as np
import torch
np.random.seed(123)
test_vocab = ['testing', 'words', 'sentence', 'language']
max_doc_length = 4096
# sample 4096 random words; truncation=True trims the encoding to the tokenizer's maximum length
test_input_str = ' '.join(np.random.choice(test_vocab, max_doc_length, replace=True))
test_input_tokens = tokenizer.encode_plus(test_input_str, truncation=True)
device_name = 'cuda:2'
model.to(device_name)
with torch.no_grad():
    # add a batch dimension and move the inputs to the GPU
    test_input_ids = torch.LongTensor(test_input_tokens['input_ids']).reshape(1,-1).to(device_name)
    test_attention = torch.LongTensor(test_input_tokens['attention_mask']).reshape(1,-1).to(device_name)
    test_output = model(test_input_ids, attention_mask=test_attention)
# move the inputs back to the CPU to free GPU memory
test_input_ids = test_input_ids.to('cpu')
test_attention = test_attention.to('cpu')
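
The forward pass above uses only local attention. Longformer also accepts a `global_attention_mask` argument (1 for tokens that attend globally, 0 for local attention), which we may want later for tokens such as the leading `<s>`. Below is a minimal sketch of how such a mask could be built; the helper name `build_global_attention_mask` and the choice of position 0 are assumptions for illustration, not something this notebook has tested.

```python
def build_global_attention_mask(attention_mask, global_positions=(0,)):
    """Return a list with 1 at positions that should attend globally, 0 elsewhere.

    Hypothetical helper: `attention_mask` is a flat list of 0/1 values,
    as returned by the tokenizer for a single sequence.
    """
    mask = [0] * len(attention_mask)
    for pos in global_positions:
        # only mark real (non-padding) tokens
        if attention_mask[pos] == 1:
            mask[pos] = 1
    return mask


# toy attention mask with two padding positions at the end
example_attention = [1, 1, 1, 1, 1, 0, 0]
print(build_global_attention_mask(example_attention))
```

To use it with the model, the list would need to be wrapped the same way as the other inputs, e.g. `torch.LongTensor(mask).reshape(1,-1)`, and passed as `global_attention_mask=`.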

In [8]:
print(test_output)

(tensor([[[-0.0524, -0.0313,  0.0290,  ..., -0.0564, -0.0237, -0.0572],
         [ 0.0336, -0.0289,  0.1894,  ..., -0.5832,  0.0433, -0.0138],
         [ 0.1830, -0.1202,  0.2508,  ..., -0.6626,  0.1989,  0.0652],
         ...,
         [-0.2211, -0.0821,  0.0413,  ..., -0.3656,  0.0050, -0.0762],
         [-0.2310, -0.2436,  0.0702,  ..., -0.0624, -0.0860,  0.0031],
         [-0.1774, -0.1190,  0.0915,  ..., -0.2293,  0.1176, -0.0809]]],
       device='cuda:2'), tensor([[ 0.1938, -0.2920,  0.0522, -0.0186,  0.3591, -0.1527, -0.5196, -0.3524,
         -0.0942, -0.3459, -0.4731, -0.1311,  0.2271, -0.3579,  0.0459, -0.2032,
         -0.3294, -0.1121, -0.1357, -0.0892, -0.0480,  0.0020, -0.2277,  0.0582,
         -0.5146, -0.0252, -0.1562, -0.3015,  0.2656,  0.2639, -0.0479, -0.0453,
         -0.2961,  0.0293,  0.0439,  0.2718, -0.2025,  0.1744, -0.4228,  0.2750,
         -0.0817,  0.1218,  0.2675, -0.0756,  0.1055,  0.1678, -0.2062, -0.2978,
         -0.1758,  0.4407,  0.3613,  0.3274,  

In [5]:
test_output[0].shape

torch.Size([1, 4096, 768])

It seems like this worked: the model returns a 768-dimensional vector for each of the 4096 input positions. Can it handle the news articles in our data?

In [9]:
## load sample data
import pandas as pd
NYT_article_data = pd.read_csv('../../data/nyt_comments/NYT_question_data_train_data.csv', sep=',', index_col=False)
display(NYT_article_data.head())

| Unnamed: 0 | source_text | target_text | article_id |
|---|---|---|---|
| 0 | WASHINGTON -- President Trump's advisers have ... | Where is our supine U. S. Congress? | 5ad09d04068401528a2a8848 |
| 1 | A federal judge in Manhattan indicated on Mond... | would that apply to me if I were in those circ... | 5ad49614068401528a2a8e81 |
| 2 | President Trump recently tweeted, “ The United... | @ Sarah : I ask, in all sincerity, what good t... | 5add197f068401528a2aa147 |
| 3 | I came to America from India at age 23. That w... | Why aren't Indian citizens and other immigrant... | 5ad75687068401528a2a95e6 |
| 4 | No president in my lifetime has made me think ... | But what about at home? | 5ac4059c068401528a2a1c89 |


In [None]:
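# count the tokens per article (tokenizer.tokenize returns the token list, so take its length)
NYT_article_data = NYT_article_data.assign(**{
    'source_text_len' : NYT_article_data.loc[:, 'source_text'].apply(lambda x: len(tokenizer.tokenize(x)))
})
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.hist(NYT_article_data.loc[:, 'source_text_len'])
plt.yscale('log')
plt.xlabel('tokens per article')
plt.ylabel('article count')
plt.show()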
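
Besides the histogram, a direct count of how many articles fit within the 4096-token limit may be useful. The sketch below assumes a list of per-article token counts (for example `len(tokenizer.tokenize(text))` for each article). The helper `summarize_lengths` and the toy numbers are hypothetical. It reserves 2 positions for the `<s>` and `</s>` tokens that `encode_plus` adds, as seen in the 5-word example above that produced 7 IDs.

```python
def summarize_lengths(token_lengths, max_length=4096, n_special_tokens=2):
    """Summarize how many documents fit in the model's input window.

    Hypothetical helper: `token_lengths` holds token counts WITHOUT special tokens.
    """
    lengths = list(token_lengths)
    if not lengths:
        raise ValueError("token_lengths is empty")
    # a document fits if its tokens plus the special tokens stay within max_length
    n_fit = sum(1 for n in lengths if n + n_special_tokens <= max_length)
    return {
        'n_docs': len(lengths),
        'max_len': max(lengths),
        'share_fit': n_fit / len(lengths),
    }


# toy token counts for illustration only, not taken from the NYT data
toy_lengths = [350, 900, 1200, 4096, 5000]
print(summarize_lengths(toy_lengths))
```

If a large share of articles turn out to exceed the limit, truncation (as in the `truncation=True` call earlier) would silently drop the end of those articles, so it is worth checking this number before training.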