In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform
from torch.autograd import Variable
import numpy as np
#from dataproc import extract_wvs

from math import floor
import random
import sys
import time
import os
from collections import defaultdict

from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, BertModel
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers.optimization import AdamW, get_linear_schedule_with_warmup


In [2]:
data_dir = "./mimic-iii"
files_dir = "./dump_files"
model_dir = "./models"

In [3]:
from gensim.models import Word2Vec
w2v_model = Word2Vec.load(os.path.join(model_dir, 'w2v-nurse_phys-100-5-3.bin'))



In [4]:
class BaseModel(nn.Module):
    def __init__(self, numClasses, embed_file=None,  dropout=0.2, embed_size=50, vocab_size=500):
        super(BaseModel, self).__init__()
        #torch.manual_seed(1337)
        self.numClasses=numClasses
        self.embed_size = embed_size
        self.dropout = nn.Dropout(p=dropout)
        #make embedding layer
        
        print(vocab_size,embed_size)
        #weights = w2v_model.wv
        #self.embed = nn.Embedding.from_pretrained(weights)
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        nn.init.xavier_uniform_(self.embed.weight)
        self.embed_size=self.embed.weight.size()[1]

class BERTLATA(BaseModel):
    def __init__(self, numClasses, embed_file, num_filter_maps=256, embed_size=128, dropout=0.2,vocab_size=500):
        super(BERTLATA, self).__init__( numClasses=numClasses, embed_file=embed_file,dropout=dropout, embed_size=embed_size,vocab_size=vocab_size)
        numberofLayers=4
        self.config = BertConfig(vocab_size=self.embed.weight.size()[0], 
        hidden_size=self.embed.weight.size()[1],
        max_position_embeddings=MAX_LEN,
        num_hidden_layers=numberofLayers,
        num_attention_heads=numberofLayers,
        intermediate_size=num_filter_maps,
        output_hidden_states=True)
        
        self.encoder = BertModel(config=self.config)
        self.U = nn.ModuleList([nn.Linear(self.embed.weight.size()[1], self.numClasses) for i in range(numberofLayers+2)])
        self.final=nn.Linear(self.embed.weight.size()[1]*(numberofLayers+2),self.numClasses)

    def forward(self, x,labels=None):
#         print('x', x.shape)
        #x = self.embed(x)
        #print('embed x', x.shape)
        outputs = self.encoder(input_ids=x)
        print()
        attentionOutput=[]
        if(len(outputs)==3):
                attentionOutput=outputs[2]
        elif(len(outputs)==2 ):
                attentionOutput=outputs[1]
        m=[]
        Attentions=[]
        for i, output in enumerate(attentionOutput):
            alpha = F.softmax(self.U[i].weight.matmul(output.transpose(1,2)), dim=2)
            Attentions.append(alpha)
            m1 = alpha.matmul(output)
            m1 = self.dropout(m1)
            m.append(m1)
        alpha = F.softmax(self.U[-1].weight.matmul(outputs[0].transpose(1,2)), dim=2)
        Attentions.append(alpha)
        m1 = alpha.matmul(outputs[0])
        m1 = self.dropout(m1)
        m.append(m1)
        m=torch.cat(m,-1)
        y=self.final.weight.mul(m).sum(dim=2).add(self.final.bias)
        yhat = y
        return yhat, Attentions

class ENCAML(BaseModel):
    def __init__(self,numClasses, embed_file,num_filter_maps=256,embed_size=50, dropout=0.2,vocab_size=500):
        super(ENCAML, self).__init__( numClasses=numClasses, embed_file=embed_file,dropout=dropout, embed_size=embed_size,vocab_size=vocab_size)
        kernel_size=[3,5,7,9]
        self.conv=nn.ModuleList([nn.Conv1d(self.embed_size, num_filter_maps, kernel_size=i, padding=int(floor(i/2))) for i in kernel_size])
        self.U = nn.ModuleList([nn.Linear(num_filter_maps, self.numClasses) for i in range(len(kernel_size))])
        self.final=nn.Linear(num_filter_maps*len(kernel_size),self.numClasses)
        
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        x = x.transpose(1, 2)
        m=[]
        Attentions=[]
        for i, conv in enumerate(self.conv):
            x1 = torch.tanh(conv(x).transpose(1,2))
            alpha = F.softmax(self.U[i].weight.matmul(x1.transpose(1,2)), dim=2)
            Attentions.append(alpha)
            m1 = alpha.matmul(x1)
            m1 = self.dropout(m1)
            m.append(m1)
        m=torch.cat(m,-1)
        y=self.final.weight.mul(m).sum(dim=2).add(self.final.bias)
        yhat = y
        return yhat, Attentions

        
#Create a vocabulary file
#criterion=nn.MultiLabelSoftMarginLoss()
#ind2w = {i+2:w for i,w in enumerate(sorted(vocab))}
#ind2w[0]=PAD_TOKEN
#ind2w[1]=UNK_TOKEN
#w2ind = {w:i for i,w in ind2w.items()}
#model=BERTLATA(numClasses,None,64,128,0.2,vocabsize)
#generate batchwise data and target( multi label), in the loop of epochs train model
#outputs, attentions = model(data)
#loss = criterion(outputs, target)


In [5]:
df = pd.read_csv(os.path.join(files_dir, 'nurse_phys.csv'))

In [6]:
tz = BertTokenizer.from_pretrained("bert-base-uncased")
tz.vocab_size

30522

In [7]:
class NotesDataset(Dataset):

    def __init__(self, notes, targets, tokenizer, max_len):
        self.notes = notes
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.notes)

    def __getitem__(self, item):
        review = str(self.notes[item])
        target = self.targets[item]

        encoding = self.tokenizer.encode_plus(
          review,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
        )
        
        print(encoding.keys())

        return {
          'notes_text': review,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'targets': torch.tensor(target, dtype=torch.long)
        }

In [8]:
RANDOM_SEED = 42
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)

In [9]:
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = NotesDataset(
    notes=df.NOTES.to_numpy(),
    targets=df.iloc[:,2:].to_numpy(),
    tokenizer=tokenizer,
    max_len=max_len
    )

    return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=4
    )

In [10]:
BATCH_SIZE = 10
MAX_LEN = 128

train_data_loader = create_data_loader(df_train, tz, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tz, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tz, MAX_LEN, BATCH_SIZE)

In [11]:
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model.to(device)
    model = model.train()
    losses = []
    tr_loss = 0

    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["targets"].to(device)
        
#         print('input', input_ids.shape)
#         print(input_ids)
#         print("target", targets.shape)
#         print(targets)
        outputs, _ = model(input_ids)

        loss = loss_fn(outputs, targets)
        tr_loss += loss.item()

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return tr_loss / n_examples

In [12]:
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model.to(device)
    model = model.eval()
    tr_loss = 0


    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)

            outputs, _ = model(
            input_ids=input_ids)

            loss = loss_fn(outputs, targets)
            tr_loss += loss.item()

    return tr_loss / n_examples

In [13]:
def get_predictions(model, data_loader):
    model.to(device)
    model = model.eval()

    predictions = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["notes_text"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["targets"].to(device)

            outputs, _ = model(
            input_ids=input_ids)

            probs = F.sigmoid(outputs)

            predictions.extend(preds)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    real_values = torch.stack(real_values).cpu()
    return predictions, real_values

In [14]:
numClasses = 6984
EPOCHS = 10
device = 'cuda:1'
model=BERTLATA(numClasses, None, 64, 128, 0.2, 30522)

30522 128


In [15]:
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

loss_fn = nn.MultiLabelSoftMarginLoss()

In [None]:
%%time

history = defaultdict(list)
best_loss = float('inf')

for epoch in range(EPOCHS):

    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_loss = train_epoch(
    model,
    train_data_loader,    
    loss_fn, 
    optimizer, 
    device, 
    scheduler, 
    len(df_train)
    )

    print(f'Train loss {train_loss}')

    val_loss = eval_model(
    model,
    val_data_loader,
    loss_fn, 
    device, 
    len(df_val)
    )

    print(f'Val loss {val_loss}')
    print()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    if val_loss < best_loss:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_loss = val_loss

Epoch 1/1
----------


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-s

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3429,  2095,  2214,  3287,  2381,  9820,  3729,  2549,  3728,
         11889,  2349,  7473,  2361,  3591,  3968,  1054, 12205, 20228, 11236,
         18291,  3255,  3138,  2784,  3052, 19340,  1039,  2595,  2099,  3662,
          4800, 14876,  9289,  5443,  7473,  2361,  2363,  3158,  3597,  8292,
          7959,  8197,  4168,  8670,  6593, 20026, 23310, 20784, 20023,  3653,
          2094,  8977,  5643, 11937, 11714,  2361,  2638,  2594,  3054,  1051,
          2475,  2938,  2015,  3968,  2566,  3189,  2872,  1051,  2475,  6948,
          2213,  2588,  5508,  2938,  2015,  2282,  2250, 20228, 11236, 18291,
          3255, 21454,  4728,  3255,  8048,  1060,  8242, 10791, 11265,  4904,
         18981, 18595,  2278,  4530,  7485,  3653,  3540, 13700,  1054, 19857,
         16

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
tensor([[  101,  1052, 12170, 20051, 15321, 22684, 20389,  2100,  7163, 15079,
          7709,  3291,  6412,  7928,  7667, 24978,  2099,  2659,  2927,  5670,
         14925, 14399,  2100,  2358,  6678, 10791,  8242, 10124, 14262,  8820,
          3070, 14931,  2852,  3070,  2895,  3802,  2102,  4654, 28251,  4383,
         22822, 20738, 17153,  9365,  2140,  3255, 10290,  4921,  8840, 20110,
          2953,  1060, 13433, 29454, 10711,  4371,  2213, 17678, 10354, 16515,
          2638, 25606,  3433, 17850, 29454,  2102,  8840, 20110,  2953,  3255,
          2000,  3917,  3085, 13866,  2478,  9174, 28879, 19875,  2000,  3917,
          5844, 10268,  1044,  2475,  2080,  2933,  1051, 16429,  3242,  2572,
         28507,  2618,  2651,  9530,  2102, 17153,  9365,  2140,  3255, 13866,
          2163,  2092,  1059,  2566,  3597,  3401,  2102, 19029,  3046,  2178,
          8700, 20302,  8449,  2594,  9530,  

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6522,  2072,  2720,  2095,  2214,  2158,  2381,  1044,  2102,
          2078, 25125,  3742,  7632,  8017, 16298, 29477, 16921,  7175, 22110,
          2278,  4914,  4015, 23025,  2226,  2729,  3322,  3591,  9808,  2232,
          4788,  5582,  1040,  7274,  2361, 22084,  1060,  2420,  1040,  7274,
          2361, 22084,  2195,  2706,  7552, 16007, 28207,  5666,  2147,  6279,
          5994,  8948,  7722,  2195,  3134,  5776,  2179, 24978, 18532,  2072,
          9808,  2232, 19817,  7361,  2733,  3283, 12098,  2546, 13675,  2445,
         17306,  2174,  2445,  2865, 16643, 15460,  3597,  7685,  4015, 15321,
          6305,  6558,  9345,  2140,  3188,  3225,  2002, 19362,  2378, 14181,
          2102,  3968,  3988, 17850, 17531,  10

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 10930,  2931,  2146,  2627, 19960,  1044,  2595, 22822, 17062,
         24552, 28667, 29264,  6090,  8713,  2884, 15859,  7315, 21877,  9152,
         14141,  2213,  9808,  2050,  1044,  2102,  2078, 21183,  2072, 10089,
          2019, 17577, 28697,  2497,  8872,  2094,  2175,  4904,  8670, 21162,
         17577,  9052,  3669, 13250, 11265,  2290,  2358,  9331,  2232, 11888,
          3255,  9808,  2618, 10441, 15265, 14778,  2483,  2234,  1041,  2860,
          1039, 17540, 19372,  2512, 21572, 26638, 19340,  1060,  2549,  3134,
          2036,  3264,  3255,  1054,  2217,  6090,  8977,  9145,  3543, 13866,
          2318,  3158,  3597, 18856, 23938,  2651,  2061, 12274, 16136,  2938,
          7510,  1039,  2692,  2475,  6887,  1044, 22571, 11636, 17577,  7667,
         13866,  4654,  2361,  1059, 21030, 11254,  2802,  4530,  2938,  2015,
          1048,  3081, 12170,  4502,  2361, 23675,  2015, 234

dict_keys(['input_ids', 'attention_mask'])
tensor([[0, 0, 0,  ..., 0, 0, 0]], device='cuda:1')
x torch.Size([1, 128])

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4780, 10930,  2213,  1059,  7610,  2232,  3278,  8909, 22117,
         11463,  6761,  2863, 18804, 16677,  5490,  6692, 27711,  3526,  2482,
         21081,  2863,  4914,  3110, 14849,  1044, 22571, 11636,  2594,  4909,
         18178,  5302,  3988,  2132, 14931,  3968,  7175,  1044, 22571, 10244,
          3619,  3012,  2306,  8292,  2890, 17327,  2819,  1039,  2595,  2099,
          3662,  2187, 20228, 11236,  2389,  1041,  4246, 14499,  3378,  8823,
          2571, 25572,  6190,  2174,  2004, 24335, 13876,  9626,  4588,  6380,
         18133, 17540, 14978,  4914,  2723,  5776,  2179,  4895,  6072, 26029,
         12742, 26034,  3642,  2630,  2170,  9353,  4877,  7531,  2363,  4958,
          3170

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 21887,  2854, 16749, 11826, 22160,  2102,  9298,  2290,  2595,
          2549,  7667, 13866,  2363, 20014, 19761,  3064,  7367, 13701,  2094,
         17678,  9253, 14181,  3215, 13373, 14573,  2121,  7712, 13823, 17917,
          2080,  2475, 25022, 19610,  7716, 18279, 22924,  6540,  3108, 10868,
         19689, 14262,  8820,  3070, 23894,  3697,  2079,  9397,  3917,  2556,
          3096, 10109,  1058, 14666,  4987,  1057,  7361, 11706,  2895,  3872,
         24501,  2271, 26243,  4383,  9253, 14841,  6494,  3064,  5441, 24829,
          2361, 22597, 14181,  2102,  2318,  2566,  1039,  8778,  1048, 17250,
          2015, 16360, 19738,  3064,  1058, 14666,  3168,  5425, 23894,  3697,
          2079,  9397,  3917, 24829, 10318, 13823, 17531, 25022,  3333, 16731,
          2102, 28667,  5369, 18141,  3816,  2310, 18674, 24501,  4765,  9253,
          3445,  2

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2720,  2095,  2214,  3287,  4914,  2420,  3283, 10886,  6544,
          2046,  9048, 10719, 10792, 19583,  3905,  2991,  9601,  2030, 10128,
          4015,  4200,  6544, 10534,  5845, 11085,  3527,  2140,  2012, 11444,
          2078,  2815, 23025,  2226,  6544, 10534, 11585,  3564,  3242,  8329,
          2276,  6589,  2718,  2132,  3300,  2723,  2132,  1039,  8560, 14931,
          2015,  4663,  8669,  2828,  7939,  2015, 19583,  4942, 25148,  3370,
         27723, 29248, 16330,  2030,  2705,  2080,  8560,  6749,  2524,  9127,
         27937,  2080,  2825, 11707,  8830, 23025,  2226,  8089, 23760, 29048,
         28378,  7667, 23760, 25808,  3512,  4088,  5670, 13433,  3229,  4039,
          2202, 13433, 19960,  2015,  2895,  4921, 26018,  3669, 21254, 11460,
          1

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 18583, 17341,  2902,  3734,  2512, 12436,  2361,  7667, 21358,
         15878, 15928,  2063,  1048,  2015,  8579,  4244,  7888, 15911,  2187,
          2217,  2938,  2015, 18629,  2938,  2015,  4654,  8743,  3258,  2844,
         19340, 26478, 17944,  2498, 19055,  2358,  2890,  2361,  1052,  2532,
          2668,  9832, 27263,  2278,  9108,  1048, 12273, 13866,  2632,  8569,
         27833, 11265,  5910, 14163,  9006, 27268,  1999,  2232,  6487,  4183,
         17854,  2378, 19340,  2895,  1039,  2595,  2099, 11703, 12083, 13706,
          2589,  8627,  8627, 22975, 18939,  3108, 13866,  2589, 11113,  2595,
          2445,  2566,  4449,  2668,  1039,  2595,  2741, 27263,  2278,  3433,
         13866,  3464, 18629,  3464, 21358, 15878, 15928,  2063,  5410,  2588,
          2

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 26226,  2050,  6909, 18439,  1999, 14971,  7542,  7667,  7247,
          3461, 12170, 20051, 11530,  3671,  3997, 10862,  4076, 10954,  8048,
          1060,  2509,  1059,  6413,  4512, 23439,  3255, 15035,  2791,  4487,
         21748, 11638,  3370,  3048, 23263,  9174,  1059,  4740,  2893,  1051,
         16429,  5061, 17996,  3264, 18856, 12868,  2895,  9530,  2102, 10856,
          9956,  2140,  1053,  2575,  8093,  2015,  4921,  2002, 19362,  2378,
         14181,  2102,  2218,  1060,  8117, 13866,  2102, 23123,  8319,  2417,
          2527,  7962,  1050, 15088,  1053,  2487,  8093,  3433, 11265, 10976,
          3570, 10109, 13866,  2102,  2002, 19362,  2378, 14181,  2102,  9530,
          2102,  1057, 17850,  2933,  9530,  2102,  8080,  1050, 15088,  1053,
          2487,  8093,  2689,  1053,  2475,  8093,  2015,  3464,  6540,  9530,
          2102,  4

dict_keys(['input_ids', 'attention_mask'])

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  8776,  5177,  3570,  3972, 15735,  2819,  7667, 13866,  7367,
         13701,  2094, 17678, 11253,  4747, 13866,  2036, 21713,  5794,  8516,
         27304,  2927,  5670, 13866, 21568,  3048,  4654,  7913, 22930,  3111,
          2793,  8783,  2132, 10005, 23760, 25808,  3512,  7391,  5020, 28022,
          2135, 22643,  3048,  4654,  7913, 22930,  3111,  2793,  2895,  7367,
         20207,  2357, 11265, 10976, 13869,  7367, 20207,  3255,  4200, 14841,
          6494,  3064,  3020, 13004,  3905,  4852, 22356, 10975,  2078,  7893,
          2094,  3641, 23025,  2226,  6319,  3433, 13866,  2506, 21568,  3484,
          5670, 23025,  2226,  2136, 19488, 13866,  3048,  4654,  7913, 22930,
          3111,  2793,  2206, 10954, 23025,  2226, 19488,  2047,  4449,  2804,
          7893,  2094,  5776,  2625, 21568,  2625, 23

input torch.Size([1, 128])
tensor([[  101, 12603,  1052,  7667, 27011, 14931,  2132,  2589,  2197,  2305,
         14931,  3463,  3461,  2187,  7652,  5670,  5373,  3054,  4179,  5090,
          3621,  5301, 10124,  4942,  7011, 15472,  3170,  2014,  6200,  3508,
          4895,  9289,  9099,  6528, 29469,  2389,  2014,  6200,  3508,  4642,
         23760,  4181,  3366,  1042, 10085,  2072,  2157, 19124, 21833,  4193,
          1044, 22571, 10244,  3619,  3012,  3497,  3968, 14545,  3621,  3445,
          6540, 23760,  4181,  3366,  2157, 18309,  3012,  4942, 24979,  2389,
         19610, 10610,  2863,  2047,  2752, 19610,  2953, 25032,  4270, 27011,
         11325, 28828, 15321,  6305,  2594, 11320, 11201,  2099,  8560, 28929,
         24960,  3968, 14545, 11601,  3968, 14545, 26721, 10760,  9289, 19470,
          3012,  3671, 11320, 11201,  2099,  9113,  6147,  2895,  5776,  4654,
         28251,  4383,  5852,  2589, 25775,  2092, 10699, 10109,  1048,  4190,
         11251,  2583, 23

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  1044,  1044, 22571, 12184,  3619,  3258,  5213,  7667, 17531,
         25353, 16033, 10415,  2707,  5670,  6360,  2318,  9874,  2075,  2659,
         26226,  2361,  6348,  3464, 10882,  2497,  3446,  8138, 26189,  2278,
          2895,  8945,  7393,  2098,  2561,  1048, 20989,  4189,  3466, 13866,
         25194, 10514,  7542,  2075, 21713,  5794,  8516, 14181,  2102,  3030,
          3433,  4189,  3433,  8331,  8945,  7393, 17531,  2815,  2659,  3232,
          2847,  6360,  9874,  2098,  2747,  3054,  2659, 25353, 16033, 10415,
          1057,  2307,  7620,  8331,  4119,  8498,  2105, 17850,  2348, 13675,
          5243,  7629,  3170,  3464,  6706,  2933, 13866,  4150,  1044, 22571,
         12184,  3619,  3512,  8945,  7393, 20989, 27263,  2278, 12889,  2651,
          2933,  4826, 16464,  4945, 11325, 12098,  51

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2720, 28353,  1052,  9298,  2290, 10381,  2546,  1041,  2546,
          3591,  7473,  2361,  2436,  2651,  3653,  6728,  9312,  3239,  5970,
          5776, 10865,  3445, 17540,  2627,  2733,  8030,  2030,  2705,  7361,
         22084, 16021,  5358,  6200,  3445,  2896,  4654,  7913, 16383,  3968,
         14545,  5508,  2436,  2179,  1044, 22571, 11636,  2594, 10958, 12170,
         22083, 11733,  2099,  8579,  4244,  2896,  4654,  7913, 16383,  3968,
         14545,  2741,  9413,  2147,  6279,  3968,  3988,  8995,  2015, 17212,
          2497,  2445, 17306, 11460,  9152, 13181, 25643, 17119,  2378,  5869,
          7646, 11460,  4921, 10507, 17996,  6434, 23969,  2290,  3662,  4030,
          2012, 14482, 10882, 23736, 20382, 17850,  3322,  5776,  7722,  3370,
          3570,  5301,  4487, 14900,  2483,  7722,  3370,  5301,  2659,  1048,
         13316,  3

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3806, 13181, 18447, 19126, 19501,  3356, 11463,  8189, 21025,
         19501, 21025,  2497,  7667,  2363,  3197,  7473,  3197, 21461,  2361,
          1041,  2860,  1999,  2099, 16731,  2102, 16731,  2102,  3445,  2318,
          3131,  7473, 11463,  8189, 11585, 15488, 14644,  2304, 14708, 28667,
         11667,  2895,  2156,  3433, 16731,  2102,  5838, 23263,  2933,  2507,
          3197,  7473,  4638,  2695, 16731,  2102, 25125,  4945, 11888, 11888,
         25125,  4945, 13675,  2546, 11888, 14234,  4295,  7667,  1057, 10507,
         10507, 17850,  2895,  2363, 10507, 24978,  8945,  7393, 10507,  1048,
          2099,  8945,  7393,  2036,  2318,  4921,  2546, 10507, 17850,  3433,
          1057,  3445,  2067, 10507, 17850,  3254, 10548, 10507, 17850,  2933,
          3613,  3422,  1057,  4876, 28767,  2164, 21901,  4800,  1999, 14971,
          6593,  7

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4868, 10930,  3287,  2381,  9820,  3729,  2549,  1058,  2140,
          1044, 17341,  2273,  2075, 13706,  1060,  2475,  2381, 18439,  2000,
          2595,  7361,  8523, 15530,  2483, 25750,  3591,  3968, 29031,  4895,
          6072, 26029, 12742, 29031,  3322,  2245, 23060,  2100,  8663,  7629,
         26641,  2588,  3228,  6583, 18992,  2078, 16464,  3446,  2253,  3968,
          5443,  1052, 17531,  1052,  1054,  8776,  4337,  3512,  1044, 22571,
         11636,  2594,  2659, 10958, 25478, 20014, 19761,  3064,  2179,  2440,
         16405,  6820, 16136,  3595,  8496,  2363,  1048,  4921, 24978,  3158,
          2278, 14931,  2595, 23310, 11253,  4135, 18684, 15459,  2132, 14931,
          3662,  1044, 22571, 10244,  3619,  30

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  5776, 10930,  3287, 18804, 16677, 25125,  3526,  6187,  1039,
          1038,  2310, 18674, 16215, 21716, 15853,  2483, 10886,  2139,  8569,
         13687,  2075,  5970, 25125, 12818, 16215, 21716, 15853,  2483,  8208,
          5776, 15330,  7709,  3278,  3872,  3223,  3517,  2668,  3279,  5776,
         24582,  2226, 20014, 19761,  3064,  2729,  3291, 25125, 16007, 28207,
          5666,  7667,  5776,  2253, 11322,  3512, 11707,  8830,  4504,  2187,
         11265,  8458,  2890,  6593, 16940,  2895,  2234,  3131, 17324,  2072,
          5082,  2144,  1039,  3433,  2933, 23760, 29048, 28378,  7667,  2895,
          3433,  2933,  5776, 10930,  3287, 18804, 16677, 25125,  3526,  6187,
          1039,  1038,  2310, 18674, 16215, 21716, 15853,  2483, 10886,  2139,
          8569, 13687,  2075,  5970, 25125, 12818, 16215, 21716, 15853,  2483,
          8208,  5776, 15330,  7709,  3278,  3872,  3223,  35

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  5776, 10930,  7610,  2232,  2595,  9033,  2290,  5490,  6692,
         27711,  3526,  2482, 21081,  2863,  4416, 15253, 11192, 10335,  6700,
          1044,  4903,  2094,  3322,  4015,  1044, 22571, 11636,  2594, 16464,
          4945,  3322,  4914, 16464,  3570,  5301, 18798,  2483, 11265,  2290,
          1040,  2615,  2102,  1038,  2140, 14931,  2050, 11265,  2290, 21877,
          2723,  5845,  5292,  2361,  1059,  3158,  2278,  8292,  7959,  8197,
          4168,  2036,  5845,  4921, 10343, 14089,  3593, 26788,  2618, 23760,
          9289,  3401, 10092, 16007, 28207,  5666,  4921,  5210,  8516, 13433,
          3158,  2278,  1039,  4487,  4246,  2678, 10577,  9033,  2290, 20298,
    

dict_keys(['input_ids', 'attention_mask'])
x torch.Size([1, 128])

input torch.Size([1, 128])
tensor([[  101,  6640,  2308,  1052,  2991,  2217,  2793,  2188,  7483,  2851,
          4285,  8768, 23439,  2132, 12603,  2132, 14931,  3936, 11325, 11888,
          3816,  4304,  2187, 19124,  4942, 24979,  2389, 19610, 10610,  2863,
          3742,  3466, 11457, 18834,  7277,  2571,  3602, 13866, 24735,  2991,
         17371,  2232,  7842,  2232, 13866,  4914, 24529,  2594,  2226,  8822,
          2741,  2651, 13982, 19610, 10610,  2863,  2852,  4942, 24979,  2389,
         19610,  2953, 25032,  4270, 17371,  2232,  7667, 13866, 23130,  2135,
         10109, 23130, 15074,  2015,  3264, 17531,  6540,  2187,  1050,  2497,
          2361,  4862,  2638,  2872,  2019, 25344,  3653,  6728,  3936,  3461,
         25619,  3020, 25353, 16033, 10415, 15324,  2895, 11265, 10976, 14148,
          1053,  2475,  8093, 13866, 13675,  7088, 11439,  8029, 13982, 13866,
          2445, 26018,  3669, 21254, 

dict_keys(['input_ids', 'attention_mask'])
tensor([[  101,  4914,  3674,  4408,  2417,  2668,  6812,  2884,  5750, 16844,
          2891,  3597,  7685,  2589,  4760,  2312,  8310,  2668, 27345,  2594,
         15859,  3120,  9524,  4453,  4015,  3197, 10975,  9818,  3131,  5127,
         13461, 16731,  2102, 19067, 24529,  2594,  2226,  2729, 25930,  2850,
          7610,  2232,  1044,  2102,  2078,  1044, 22571, 14573, 12541,  9314,
          2152, 16480,  4244, 27833,  4013, 20528,  2618,  6187,  8249,  7242,
          2067,  6181,  5970,  3806, 13181, 18447, 19126, 19501,  2896, 19610,
         10610,  5403, 12871,  7987,  2497, 18098, 21025, 19501, 21025,  2497,
          7667,  9499,  8048,  2335,  5443,  2015, 21358, 15878, 15928,  2063,
         16731,  2102, 10821,  4487, 29212,  7987,  2497,  2566, 28667, 11667,
          2895, 23675, 24978,  8331,  8945,  7393,  3131, 10975,  9818, 16371,
          2278,  4200,  6415,  2417,  3526,  2817,  3433,  5443,  3961,  6540,
         

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 10184, 11522,  2401,  7667,  2895,  3433,  2933, 10184, 11522,
          2401,  7667, 17850,  5034,  7842,  1059,  1051,  9468, 26189,  6169,
          2659,  1047,  2504,  4178, 10184, 11522,  2401,  2895, 16360, 19738,
          3064,  1059,  2033,  4160, 21117,  2140,  3433, 10548, 26189,  6169,
          3944,  2933,  9377,  1048, 17250,  2015, 27937,  2080,  3054,  3490,
          2618,  4568,  6393,  8571, 21887,  2854, 16749,  4295, 28353,  2003,
          5403,  7712,  2540,  4295,  7667,  1039,  3108,  3255,  2651,  1044,
         17540, 22939,  8458, 16610,  2483,  1059, 23969,  2290,  3431,  2358,
          6245,  2015,  2895,  2445, 15050, 19960,  2015,  3641,  1039,  4178,
         17540, 22939,  8458, 16610,  2483,  3264, 14925, 11714, 15530,  2483,
         19387, 20092, 15704,  3433, 17531,  2253,  2933,  8080, 18133, 23969,
          2290,  1

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 13866,  2095,  2214, 10170,  1044, 28353,  1040,  2213,  2475,
          1044,  2102,  2078, 28389,  1039, 26180, 15965, 11265, 10976, 20166,
          6583, 29566,  4588, 18642, 11888,  2659,  2067,  3255,  3591,  2420,
          4788,  5582, 13318, 19340,  2460,  2791,  3052, 10720,  2015,  3108,
          3255, 21454, 22939, 12171, 20192,  3968,  1039,  2595,  2099,  3662,
          1052,  2532, 13866,  2445, 23310, 20784, 20023,  1060,  2487,  2318,
         17212,  2497,  4015, 23025,  2226, 11460,  2102,  5776,  2170, 13866,
          2179,  2723,  1052,  2004,  8197, 15172,  6097, 10464,  3490,  8808,
          3264,  4078, 19321,  2075,  28

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6421,  1042,  1052, 19330,  2102, 19888, 24278, 25022, 12171,
         25229,  1039,  1038, 13866,  6392,  1052,  1054, 24494, 12997,  2546,
          1040,  2595,  2706,  3283,  2188,  2651,  4178, 24471,  3981,  2854,
          4297, 12162, 21820,  3401,  2036,  3264,  4450,  2098,  5505,  2348,
         10599,  3251,  2367, 26163,  3393,  3968, 14545, 12422,  2367,  2030,
          2705,  7361, 22084,  1052,  4859,  3591,  3968,  8030,  3968, 13114,
         15928,  2063,  9634, 17531, 17850, 10548,  1048,  4921,  2546,  1039,
          2595,  2099, 29543,  2363,  3158,  9006,  2100, 15459, 23310, 11253,
          4135, 18684, 15459,  4914, 23025,  2226,  5142, 19802,  6190,  1044,
         14671, 11463, 15909,  2271,  1040,  2213,  2828,  7667,  2895,  3433,
          2


input torch.Size([1, 128])
tensor([[  101, 13866,  2464,  3968,  2335,  2197,  2733,  4958, 11921,  9048,
          2015,  9229,  3322, 15099, 14743,  3591,  3968,  3892,  2714,  8030,
         15219, 19501, 15099, 14743, 11937, 11714, 16731,  2102,  4530,  2685,
          2144,  2420,  3283, 15219,  8966,  1048,  2491,  9524,  9524,  3264,
         18749, 20026,  2389, 23245,  2464,  4372,  2102,  2741, 23025,  2226,
          2250,  4576,  3860,  5776,  5373,  2004, 18098,  2378,  1039,  2051,
          4958, 11921,  9048,  2015,  4451,  3468,  2098,  7667,  5776,  2363,
          3903,  4451,  8966, 13866,  1039,  2235,  3815, 19501,  2067,  3759,
         11937, 11714, 11522,  2594, 23760, 25808,  3512,  2124, 23760, 25808,
          3512,  2363, 13433, 19960,  2015, 16731,  2102,  9634,  3968,  2250,
          4576,  3647,  2938,  2015,  2282,  2250,  2895,  5776,  2363,  9410,
         13433, 19960,  2015, 17850,  1038,  1052,  4372,  2102,  2234,  3319,
          3718,  5308, 1

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3963, 10930,  2546,  7388,  6187,  2149, 11631,  2154,  9634,
          2179,  2684,  2188,  6037,  4895,  4381, 11937, 11714,  2361, 22084,
          3147,  3543,  2170, 29031,  3264,  1044, 22571, 12184,  3619,  3258,
         14412,  2361, 22939,  8458,  5686,  4588, 25269,  9687,  9413,  9312,
          3936, 11937, 11714,  2361, 22084,  4039,  6855,  7509,  2475,  2872,
          2512,  2128, 13578,  8988,  2121,  7308,  3264, 24829,  2361,  3461,
         25619,  4658,  3543,  8915,  8737,  4895,  2890, 27108, 20782, 18749,
          4588,  5648,  4921,  3229,  3081,  1041,  3501,  4039,  6855, 15965,
         10768, 22049,  2430,  2240,  2872,  2363,  4921, 20989,  1060,  1048,
         17537,  3424,  7712,  3217, 21102,  2015, 23310,  2080,  1062,  2891,
          6038,  3158,  2278,  5142,  2825,  2004, 167

input torch.Size([1, 128])
tensor([[  101,  1044, 16571, 22939, 20915,  2594,  2002, 24952,  2278, 19686,
         16571,  7667, 13866,  3464,  4895,  6072, 26029, 12742,  2053, 25171,
         22239,  2330,  2159, 11867, 12162,  2929,  3272,  2026, 10085,  7811,
          2271,  2929,  2132, 15190, 10632,  2053, 25171, 22239,  7247,  3461,
          2305,  5670, 12653,  7391,  2512, 22643,  7391, 22643,  5670,  3602,
         13866,  2363, 13938,  2015,  2561, 10856,  9956,  2140,  3497,  4374,
         20372, 14223, 20194,  9808,  5302,  2132, 14931,  2589,  2651, 15497,
         25212,  2290,  9808,  2232,  3464,  3929, 18834,  2098,  4642,  2615,
          2174,  2058, 13578, 22314, 18834, 17531,  2213, 13866,  9885,  7509,
          2475, 10514,  7542,  2098,  1053,  2487, 17850,  2015,  1048, 16523,
          2572,  3215, 16405, 13728,  3968, 14545,  6037, 11867,  4904,  2819,
          1054,  8747,  7474,  3560,  3052,  4165,  4249, 13675, 13699,  4183,
          2271,  2371,  3

input torch.Size([1, 128])
tensor([[  101,  2382,  1061,  2099,  2214,  1042,  1052,  2991,  5108,  2761,
          4914,  6441,  2421,  2312, 11290, 18749,  2235, 17371,  2232,  1054,
          6130, 18749,  1048, 10792, 23292,  3191, 22930,  3064, 11325,  1051,
          2475,  4078,  4017, 18924, 24501,  2361, 12893, 10507,  2575, 20228,
         11236,  2389,  1041,  4246, 14499,  2179,  2157, 11192, 20228, 11236,
          2389,  1041,  4246, 14499, 11325,  7667, 13866,  2282,  2250,  1051,
          2475,  2938,  2015,  3054,  2152,  2157, 15219,  3108,  7270,  3461,
         10514,  7542,  8777, 14262,  8820,  3070, 20023,  3560, 11987,  1048,
          2015,  3154, 15911,  2157, 11192,  2918,  2895,  3108,  7270,  2904,
          2300,  7744,  3729,  2497,  6628,  1051, 16429,  3242,  3433, 13866,
         24501,  2361,  3570, 15704,  2933,  4651,  2723,  3613,  8627, 21908,
         11848,  3436,  3255,  2491, 11325,  3255, 11888,  3255,  7667, 13866,
         17612,  2015,  3

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  5213, 17419,  2594,  7667,  2895,  3433,  2933, 23760, 29048,
         28378,  7667,  2895,  3433,  2933, 25125,  4945, 11325, 11325, 25125,
          4945, 12098,  2546,  7667,  2895,  3433,  2933, 10930,  3287,  4914,
         18699,  8915,  8737, 22260, 11937, 11714,  2361,  2638,  2594,  2938,
         10958,  1057, 17106,  2872, 10507,  6703, 17996,  1041,  2860, 17531,
         13675,  5243,  2102, 25610,  2278,  4996,  2445, 23675,  2015,  4921,
          2546, 24479, 18749, 12259, 18856,  2872,  5213, 17419,  2594,  7667,
         13866,  5478,  2811,  2953,  2295,  2342,  8331,  8945,  7393,  2229,
          2668, 18856,  2595,  2234,  2067,  4013,  2618,  2271, 18062, 27965,
          4607, 24163, 27631, 10768,  9289,  2483, 17996,  4013,  2618,  2271,
          4

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2720,  2095,  2214,  2158,  2381, 23760, 29048, 26180, 11441,
          2095,  3283,  1043, 17298,  9006,  2050, 10886,  2420,  3147,  2066,
          8030, 19340,  5776,  4311,  5156,  2110,  2740,  2627,  5958,  2211,
          2514, 26478, 17944,  2318, 19340, 13318, 26916,  2317, 11867,  4904,
          2819,  5776,  4311,  4692,  2202,  1999, 15238,  2099,  2081,  2514,
          2422,  3753,  2295,  2052,  8143,  3030,  2635, 23439,  3108,  3255,
         14412, 23270, 10708, 19029, 24780, 22939, 12171, 20192, 26351, 29477,
          2140,  4178,  5776,  2381, 15050,  3471, 11888,  2966,  3785,  5057,
          2533,  5776,  2363, 13004, 14017, 17897, 22196,  2140, 11265,  8569,
         28863,  2015, 23310, 11253,  4135, 18684, 15459,  3108,  1060,  4097,
          3936,  3160,  2157,  2896, 21833, 29543,  29

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 19962, 22599,  9108,  3602,  5796,  2095,  2214,  2450, 28767,
         14609, 27937,  2232,  3322,  4914, 11265, 10976, 26210,  4590,  2100,
          2326,  5142, 10372, 21210, 18454,  3372,  3322,  3591,  2195,  2420,
         19029, 24780, 21419,  3255,  4487, 16173, 10992, 14931,  7645,  3350,
          7704,  2235,  6812,  2884, 27208,  2092,  8331,  3074, 21419,  5955,
         21210, 18454,  3372, 20581, 21210, 18454,  3372, 10410,  3652,  5796,
          3736,  6948,  7645, 25610,  6169,  3226,  4997,  2404,  3158,  9006,
          2100, 15459,  8292,  7011,  6844,  4115,  2101,  8061,  8292,  7011,
          6844,  4115,  4015, 14609, 18454,  3372,  8985,  2851,  2253,  8208,
         21210, 18454,  3372, 26721, 25918,  8082,  2135, 16405,  2015,  3264,
          4

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 26094,  3012,  5920,  3535,  7667,  2895,  3433,  2933,  1044,
         22571, 12184,  3619,  3258,  5213,  7667, 17531,  1057, 10507, 17850,
          1048, 13433,  2015,  2651,  2895,  8822, 17531,  2802,  5670, 25606,
         23310,  7361,  9072,  1999, 20523, 17531,  2318,  2659, 13004,  3433,
          2933,  3291,  6412,  7928,  7667,  4663,  6228, 18834, 10906,  4076,
          9353,  3052,  4165,  3154,  3356, 24423, 11737,  7888, 10514,  7542,
          2075,  4317,  2317,  3756,  3595,  8496,  2235,  8777,  2572,  3215,
          2566,  3802,  2102,  8700, 17790,  1051,  2475,  2938,  2895, 10514,
          7542,  2075, 17850,  2015,  2734, 10882,  2080,  2475,  2357,  3433,
          2933,  2591,  2814, 18196,  7907,  2156, 13866,  2651,  2566,  7907,
         13866,  2155, 10402, 13732,  3569,  2814,  77

dict_keys(['input_ids', 'attention_mask'])

input torch.Size([1, 128])
tensor([[  101, 10930,  2931, 28767,  6556,  1040,  2615,  2102, 26226,  2050,
          4914,  4718,  2157, 19395,  3255, 17540, 14931,  2050,  2864,  3968,
          3662, 17758, 21908,  7861, 14956,  2072,  2312, 12279,  7861, 14956,
          2271,  2157,  2364, 21908, 16749,  6903,  2389,  2222,  2140,  5628,
          2318,  2002, 19362,  2378, 14181,  2102,  4651,  2098,  4200,  2723,
         13330,  2061,  2213, 11231, 22717,  2572,  2015,  2179, 11113,  2290,
          1048, 13316,  1042,  2015, 28093,  2132, 14931,  2864,  1051,  2475,
         10548,  1048, 13316,  2462, 12273, 16416,  3366, 24501,  2361,  3298,
          4015, 23025,  2226, 23760, 10010, 13592, 24501,  2361,  4945,  2825,
          2342, 12170,  4502,  2361,  4487, 13102,  2080,  1040, 16118,  1040,
          3490,  3229, 14255,  2615, 25930,  2850, 23760,  9289,  3401, 10092,
          2152, 13853,  7667, 13371,  6187,  2895,  7531, 24

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6273,  2546,  2381,  3802, 11631, 25022, 12171, 25229, 21025,
          2497, 11290,  4945, 28697,  2497,  3591,  3968,  2601, 22939, 12171,
         20192,  7977,  2668, 19029, 24780,  2512,  6703,  2512,  4157,  2598,
          7861, 19009, 16731,  2102,  6449,  2174, 13866, 19610,  7716, 18279,
          7712,  3973,  6540, 12835,  2102, 13697,  3351,  3154,  2318,  4921,
          4903,  2072,  2363,  1048,  4921,  2546,  2464, 21025,  4081,  1041,
          2290,  2094,  3131,  2363,  2561,  3197, 10975,  9818,  3161,  5751,
          9524,  9531,  4760,  3806, 18886,  7315,  3161, 19501, 10507,  2581,
          6812,  2884,  5750, 16731,  2102,  2211,  9885, 24529,  2594,  2226,
          2896, 21025,  2817,  3201,  6845, 21025,  3507,  6380,  2342, 28155,
          5620,  3806, 13181, 18447, 19126, 19501, 21025, 19501, 21025,  2497,
          7667,  5034,  7275, 26189,  2278,  3264, 24829,  23

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6070, 28353,  1052, 24978, 18532,  2072,  2033,  2818, 20704,
          2099,  3356, 21025, 19501,  1052,  1041,  2290,  2094,  2302,  3154,
          3120,  4015, 24582,  2226, 17890,  2002, 19362,  2378, 14181,  2102,
         28667,  3126, 24413,  6703,  7861, 19009, 21887,  2854, 16749,  4295,
         28353,  2003,  5403,  7712,  2540,  4295,  7667, 13866,  5157,  2003,
          5403, 10092,  4292, 21025, 19501,  9377, 18133,  2243, 19817,  7361,
         10698,  2078, 13866,  6348,  2904, 28697,  2497,  3446,  1050,  2497,
          2361,  3333,  2506, 10882,  2497,  2895, 10507, 24978,  8945,  7393,
          1060,  1050,  2497,  2361,  2659, 28697,  2497, 13463, 10441, 20793,
    

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  5764, 10930,  1059, 28353,  1040,  2213,  1052,  9298,  2290,
          4914,  1059, 24978, 18532,  2072, 17359, 19357,  3064, 26261, 27109,
         14804,  4937,  2232,  1059, 13866,  3540,  4319,  3449, 24539, 26261,
          3372,  2179,  4895,  6072, 26029, 12742,  2723,  1054,  3657, 24353,
          1048, 19610, 11514, 23115,  2401,  1048, 11536, 19046,  3642,  6909,
          2170, 14931,  2050,  2132,  1051,  9468, 24117,  1052, 16474, 17076,
          3695, 12170, 27942, 10719,  1051,  9468, 24117,  1059,  2146, 16215,
         21716,  8286,  2363, 24264,  1056,  4502,  1059, 21442,  6895,  6228,
         26384,  7279, 25438,  2527,  6228, 26384, 14516,  1051,  9468, 24117,
         28549,  3540, 21961, 28549,  3540, 10514,  2361,  4487,  2615, 12403,
          6491,  5845,  1059, 24264,  9152, 13181, 256

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 10930,  3728, 19857,  2066,  8030,  4914,  3968,  4359, 13433,
         13822,  9016, 28935,  5562, 13866,  2179,  1044, 22571, 12184,  3619,
          3512, 11303,  1042,  2497,  2561,  5507,  1038,  6392,  8578,  2741,
          2179, 12098,  2546, 10819,  5358,  7662,  2100,  4359, 13433, 13822,
          8319, 15868,  2278,  8319, 15050, 16285, 13866,  4015, 14841, 23025,
          2226,  2968,  1044, 22571, 12184,  3619,  3258,  5213,  7667, 13866,
         17531, 25353, 16033, 10415,  2895,  8331,  8945,  7393,  2561,  5507,
          2164,  3968, 13866,  2036,  2363, 12170, 10010,  2497,  3433, 17531,
         25353, 16033, 10415,  2933,  8080, 17531,  1042,  2497,  2734,  1044,
         22571, 14573,  2121, 10092,  7667,  89

input torch.Size([1, 128])
tensor([[  101,  4413,  2546, 26226,  3593,  4921,  8004,  2197, 13004,  2078,
          3134,  3283,  4895,  7913,  4383,  2002, 15042,  2828,  1040,  2213,
         17359, 19357,  6024, 27441,  7315,  1061,  2869,  3283,  2197, 17748,
          5845,  3653,  2094,  8977,  8462,  3522, 19888,  2891, 17822, 27184,
          8985,  4852,  5285, 12717, 18674,  2512, 26682,  2100, 22939, 12171,
         20192,  1038,  5244,  3679,  4852, 28105,  2512, 12173, 15370,  4629,
         21419,  3255,  2599,  2175,  2902,  2445, 20989,  2741,  2188,  5776,
          2170,  9349,  2851, 21419,  3255, 14412, 19312,  9285, 22939, 12171,
         20192,  2409,  2175,  3968,  8030,  2651,  2815,  6540, 10766,  8953,
          9016,  3968,  2179, 13114, 15928,  2063, 10958,  4647,  2741,  2318,
          6420, 11113,  2445,  4921,  2546,  2741, 10507,  2226, 23025,  2226,
          3675,  8902,  9232,  8822,  3891,  6812,  2884,  2566, 29278,  3370,
          5970,  2206, 21

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2824, 18834,  3431, 11585, 25269,  2152, 13114, 15928,  2063,
          2047,  8578,  3497, 11113,  2595, 22701,  2075,  2651,  9016, 23760,
         12399, 10092,  1052, 16363, 14787,  9016,  4242,  4761,  7667,  4098,
          2895, 11460,  9078, 15464,  5740,  8458,  2368,  3433,  2783, 13433,
          2933,  9530,  2102,  2783, 18834, 10906,  2047,  8578,  2871,  2095,
          2214, 11573,  3322,  4015,  2147,  6279,  4397,  1040,  2595,  9353,
         21716, 29107,  2135,  4330, 14412,  6508, 24471,  3981,  2854, 20125,
         26721, 23576, 11937, 11714, 11522,  2401,  3393, 11251,  3525,  2179,
         13656,  1048, 20960,  8715, 10599,  3802, 20569,  2132, 14931,  3662,
          2312,  5923, 24960,  1038,  2595,  1039,  1059, 16007, 27881,  1048,
         24

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4583,  7677,  2158,  1044,  2595,  9686,  4103, 10751,  1053,
          2213,  2860,  2546,  1044,  2595, 28353,  1040,  2213,  2019, 17577,
          4015, 11707,  2326,  8822,  1052,  2566,  7442, 25572,  2140,  2566,
          7442, 25572,  2140,  3255, 14689,  9623,  2015,  2464, 11360,  1038,
          6392,  1039,  2595,  4567,  2445, 25022, 21572,  5210,  8516, 13866,
          2579, 11987,  2815, 21358, 15878, 17531,  2938,  3436,  3054,  1048,
          5142, 19802,  6190,  4015, 24582,  2226,  3255, 20302, 28667,  9080,
          7667,  2566,  7442, 25572,  2140,  3255, 14397,  2226, 13866,  2363,
         11460, 22822, 20738,  1060,  2475,  3255,  2491,  2566,  7442, 25572,
          2140, 14743, 23489,  5508,  2723,  2895, 24582,  2226, 13866,  5517,
          3255,  2488, 10256, 15013,  3637,  2187,  2894, 11013, 28667,  9080,
         14743,  2157,  2187,  2217, 28667, 11667, 10124,  10

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 16464,  4945, 11325, 12098,  5104,  7667,  2895,  3433,  2933,
          2012, 14482, 10882, 23736, 20382, 28697,  2497,  7667,  2895,  3433,
          2933, 16464,  4945, 11325, 12098,  5104,  7667,  1051,  2475,  2938,
          2015,  7510, 11192,  4165,  1059, 21030,  6774,  8579,  4244,  2802,
          1039,  3108,  3255, 20228, 11236, 18291,  2895, 11265,  2497, 19067,
          2445,  2872, 12170,  4502,  2361,  8827, 10790, 21392,  2361,  2629,
          2318,  1060, 26915, 10288, 19960, 22822, 20738, 11460, 15050, 16285,
          4663, 19960,  5869,  7646, 11460,  1060,  2487,  3433, 13866,  3544,
          6625,  1039,  1052, 10395,  1051,  2475,  2938,  2015, 10882,  2080,
    

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 13866, 10930,  2931,  7610,  2232,  8872,  2094,  1048,  2188,
          1051,  2475,  9808,  2050, 12170,  4502,  2361,  2305, 10381,  2546,
         28353,  1052,  6258,  9298,  2290,  1040,  4328,  2072,  8040,  4048,
          2480,  8458,  7389,  2401,  1048,  2509, 19583,  1044, 22571, 14573,
         12541,  9314, 25353, 27718,  9626,  4588, 28879, 24582,  2094, 27159,
          2566, 13866,  2684,  7534,  2188,  2420,  4852, 17540,  4078,  5844,
          2188,  2292,  8167, 12863,  3445,  2147,  5505, 11113,  2290,  3968,
         17212,  2497, 13866,  2872, 18133,  9331,  2741, 23025,  2226,  2968,
         11888, 27885,  3367,  6820, 15277, 21908,  4295,  8872,  2094, 22953,
         12680, 13706,  7861, 21281,  3366,  28

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3108,  3255,  7667, 23439, 18133, 19817,  7361, 10698,  2078,
         23616, 11265,  2290,  1060,  2475, 17850,  5034, 17531,  2002, 19362,
          2378, 14181,  2102,  9530,  7629,  2895,  2435,  8840, 20110,  2953,
         11460, 13433,  2847,  2197, 13004, 17306,  2156, 28286,  2290, 14181,
          2102,  2057,  7231,  2094,  3433, 13866,  8271, 19029,  2235,  2572,
          2102,  3154,  7861, 19009,  2435,  1062, 11253,  5521,  1060,  2487,
          2204,  3466, 23969,  2290,  2689, 28286,  2290, 25606,  8319, 17531,
          2057,  7231,  2094, 17531, 27937,  2080,  4937,  2232,  2651, 17261,
         13866,  2102, 18856,  4140,  2741, 22861,  9634,  2933,  4937,  2232,
          2651,  2707,  4921,  2546,  2566,  2344,  8080, 13866,  2102,  2002,
         19362,  2378,  9530,  7629,  8840, 20110,  2953, 14841,  2094, 17306,
          4078,  6132, 25090, 24872,  7667,  1044,  2595,  10

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  1061,  2099,  2214,  3287,  4015,  9808,  2232,  3680,  2050,
         19802,  6190, 12456,  5796, 11290,  4945,  9686,  4103, 17540, 18479,
         27844,  8528,  2004, 17847,  2015,  9617, 10286,  3597,  3593, 17053,
          4588,  2566,  6342,  2063, 11290, 14234, 22291,  1059,  1057, 13866,
          2440,  3642,  3680,  2050,  3967, 29361,  2173, 13866,  5305, 11290,
         22291,  3568, 22291,  2862,  6594, 13866,  2564, 13866,  2081,  1040,
         16118,  3613,  2783,  2966,  3949, 19802,  6190,  2302,  5812, 28466,
          7667,  2895,  3433,  2933, 26014, 14266,  7667,  2895,  3433,  2933,
         20694, 27520,  7667,  2155,  3116,  7531,  2564,  2365, 11290,  2136,
         23025,  2226,  2136,  2852, 29300,  2591,  2147,  2895,  3091,  2135,
          6594, 13866,  4013, 26745,  6190,  2933,  2729,  3522,  2824,  2092,
          3642,  3

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4583,  2095,  2214,  3287,  2381, 28353,  9298,  2290,  5490,
          6692, 27711,  3526,  2482, 21081,  2863,  2474, 18143,  2595,  1052,
          1060,  5339,  6245,  1040,  4328,  2072, 16216,  4103,  4015,  9808,
          2232, 11265, 10976,  6483,  2147,  5573, 26287,  5776,  3591,  9808,
          2232,  2026,  2389, 10440, 19029, 24780, 11251,  7518,  5845, 17214,
         10128,  7630, 19857,  2066,  8030,  4997, 24442,  1038,  7483,  2851,
          8271,  4039,  2693,  4654,  7913, 22930,  3111,  4937, 13594, 27011,
          4997,  6948,  2589,  2349,  5776,  2004,  8197,  6657, 20228, 18891,
          2595,  2293,  3630,  2595, 21358, 15878, 15928,  2063, 19610,  7716,
         18279,  7712,  3973,  6540,  1051,  2475,  2938,  2015,  2282,  2250,
         14978,  3300,  3255, 24185, 18752, 14045,  1043, 27421,  5443,  2273,
          2075, 13

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 11325,  3255,  7667, 19935,  3255,  7478,  3255, 13675, 16613,
          2100, 19935, 12170, 20051,  3903,  4039,  3602, 16792,  2504,  3905,
          4297,  5644, 27870,  9407,  5776, 11360,  2895,  9706,  2015,  3967,
          2802,  2154,  5776,  8945,  7393,  2098,  1060,  2475,  3446,  3445,
          2566,  9706,  2015,  3433,  5776,  2747, 10124,  3255,  5790,  5776,
         25775,  1051, 16429,  1060,  2475,  2235,  4084,  1060,  2509,  2519,
          5776,  1039,  3255,  1039, 16962,  2933,  9530,  2102,  4958,  3593,
         11137, 10507,  3178,  3967,  9706,  2015,  3255,  3314, 23760, 12902,
         17577,  2152, 18044, 23760, 11008, 27241, 10092,  7667,  1047, 23969,
          2290,  3264,  6601,  5975,  2895, 23969,  2290,  2589, 22597,  1057,
         19723,  2445, 20647, 13181,  3366, 13938, 23969,  2290,  2358,  6245,
          2695, 22597, 20647, 13181,  3366, 23969,  2290,  55

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  1044, 22571, 12184,  3619,  3258,  5213,  7667, 17531,  4634,
          4949,  2895,  2445,  2561, 10507, 24978,  8331,  8945,  7393,  4949,
          2145,  4949, 23310,  7361,  9072,  2933,  3613, 14841,  6494,  2618,
         23310,  7361,  9072,  5441,  4949,  3582,  1039,  2595,  3463, 24479,
          3641, 21419,  3255,  2164, 21419, 24605,  7667, 13866,  1039, 19935,
          3255,  4078, 29234,  2015,  3255, 10634,  4629,  3806,  6508, 23852,
          8616, 14412,  2361, 19935, 13866,  2036,  1044, 22571, 12184,  3619,
          3512, 19029,  2895,  2445,  1062, 11253,  5521, 19029,  2204,  3466,
         19935, 14931,  2864,  1054,  2003,  5403,  7712,  6812,  2884,  7642,
         18749, 12259,  2015, 17785,  3433, 18749, 12259,  9874,  2075, 19935,
         14931,  2302,  3278, 19314,  9108,  2933,  36


input torch.Size([1, 128])
tensor([[  101, 18234,  3096, 11109,  7667,  2522,  9468, 17275, 11225, 10109,
          2895, 16360, 19234,  2098,  1053,  2475,  8093,  2015,  3433,  8153,
         10109,  2933,  3613,  2783,  2933,  4942, 24979,  2389, 19610,  2953,
         25032,  4270, 17371,  2232,  7667, 13866, 12098, 15441,  2015,  2376,
          5681,  7480,  2159,  3094,  7391,  1054,  3461,  1048,  2509,  7382,
         23667, 17701,  2232,  3048, 12170, 20051,  1057,  2063, 27491,  1054,
          2571,  3048,  2793,  5681,  2222,  2063, 10047,  5302, 14454, 17629,
          1999, 24759, 10732,  2512,  2128, 13578,  8988,  2121,  3464,  2895,
         11265, 10976, 11360,  1053,  2487,  8093,  3433, 11265, 10976,  3621,
          5301,  2144,  2689,  5670,  2933,  3613,  2783,  2933,  6832,  2490,
          2155, 13866, 18234,  3096, 11109,  7667,  2522,  9468, 17275, 11225,
         10109,  2895, 16360, 19234,  2098,  1053,  2475,  8093,  2015,  3433,
          8153, 10109,  

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3680,  2095,  2214,  2931,  3268,  2188,  2894,  2035, 24395,
          2572, 11636, 28775, 21202,  2440,  3642,  2381,  6909,  1060,  3522,
          3204,  3283,  4613,  7669, 21961,  2157, 11536, 11251, 28697,  2497,
         29610,  4295,  2179,  2155,  2197,  4092,  2305,  3188, 29031, 19960,
         14931,  2099,  1999,  2099,  3525,  2363,  3197, 21461,  2361,  6819,
          2102,  1047, 14931,  2132,  3662,  2157,  4015,  3968,  1999,  2099,
          2445, 11268, 18622,  2638,  3197,  9377,  2132, 14931,  3662,  2157,
          4921,  2232, 16428, 18834,  7277,  2571, 14931,  2036,  3662,  1048,
         11968,  2666,  9080,  3742,  1048,  7412,  4649,  3258,  2273,  2075,
         18994,  2050, 27011,  2949, 24582,  2226,  9377,  2132, 14931, 24582,
          2226, 13483,  3431,  5776,  3402,  2253, 10184, 11522,  2594,  9885,
          2668,  3778, 11460,  2012, 18981,  3170,  2540,  34

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2708, 12087, 10930,  1042,  1044,  2595,  8872,  2094, 28353,
         26189,  2094, 13866,  4914, 17890, 27144,  8040, 15472,  3949,  8552,
         11937, 11714, 11522,  2401,  1044, 22571, 11636,  2401,  8776,  5177,
          3570, 13114, 15928,  2063, 11265,  4904, 18981, 19825, 11498, 23585,
         24759, 20875,  8715, 13866,  4914, 24582,  2226,  2420,  3283,  6555,
          1044, 22571, 11636,  2401,  8776,  5177,  3570,  3178,  2824,  5776,
          3030,  5155, 17996,  7483,  1057,  7361,  5301,  2197,  2847,  6986,
          5776,  1044, 22571, 12184,  3619,  3512,  2172,  2154,  2684,  2228,
          2388,  2052,  2215,  2430,  2240,  2872,  2224,  2811,  5668,  5776,
          2081,  4642,  2080,  2155,  7483, 228

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  9253, 24759,  3022,  2213, 28378, 28378, 13656,  7667, 13866,
          1052, 24501, 18491,  6770, 14663,  5649,  3742,  2566, 12190,  2292,
          8167, 12863,  8048,  4076, 10954, 22816,  4249, 10109,  2895,  3433,
          2933,  3255,  2491, 11325,  3255, 11888,  3255,  7667, 13866,  1039,
          3255,  2105,  2227,  2159, 14978,  3255,  4094, 13866, 29454, 19513,
          3593,  2188, 11888, 14978,  2015, 11460, 13433,  2048,  2093,  2335,
          2154,  2895, 21497, 29454, 19513,  3593, 13433,  5939,  7770,  4747,
          4658,  8416,  2159,  3433, 13866,  2145, 13417, 10256, 17964,  2000,
          3917,  3085,  8345,  2154,  2933,  9530,  2102,  8080,  7438,  3255,
         10975,  2078,  3255,  2491, 11325,  3255, 11888,  3255,  7667, 17491,
          1052,  9099,  8458, 16515, 16975, 24501, 18491,  6770, 14663,  5649,
         13656, 13

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  5764,  3287,  1044, 28353, 19610, 11663, 21716, 10610,  6190,
         25022, 12171, 25229,  1044, 13075, 23522, 21908, 10882, 12618,  6190,
         11888, 26261, 29514, 23852,  2188,  7722,  4651,  2098,  9808,  2232,
         11463,  8189,  2019, 17577, 13866,  3264,  3110, 14849,  5410,  2195,
          2420,  2304, 14708, 22939, 12171, 20192,  2195,  2420,  3964,  2627,
          2048,  2420,  2478,  2788,  3594,  1048, 13316,  2305,  2478,  2154,
          2391,  4039,  3328,  5723,  2302,  3352, 17540, 16342,  2094,  5410,
          3591,  9808,  2232, 19739, 27131,  2278, 14708, 17531, 13866,  2155,
         18292, 13866, 17531,  2015,  5373,  2288,  1057, 10975,  9818,  2015,
          3197, 10975,  9818,  2015,  3968,  2893,  3197, 21461,  2361,  2092,
         16731,  2102,  3968,  1057, 10975,  9818,  2015,  9808,  2232,  2566,
          3189, 16

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  2381,  3680, 10930,  1042,  7610,  2232,  2754,  3523, 11192,
          2512,  2235,  3526, 11192,  4456, 14996, 27144,  2197,  3949,  3570,
          2695, 11192, 24501, 18491,  3384, 22953, 12680,  4818, 26261,  3372,
          2754, 24264,  7570,  2094,  2290,  4939,  1040,  2595,  1044, 22571,
         14573, 12541,  9314,  2964,  8872,  2094,  3680,  2050, 18404, 21781,
          2015,  1052,  2532,  4311,  2698,  2420,  4788,  5582,  1040,  7274,
          2361, 22084,  2179, 24501,  2361, 12893, 11429,  2851,  6449,  3322,
          2579, 19067,  3158,  9006,  2100, 15459,  9587,  9048, 10258, 11636,
         28775,  2078,  2695, 27885,  3367,  6820, 15277,  1052,  2532,  2092,
         14017, 17897, 22196,  2140,  2036,  21

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  3963,  2546,  1044,  1044,  2102,  2078,  7534,  2847,  2460,
          2791,  3052,  5776,  4311,  3155,  2733, 14699,  3759,  2026,  2389,
         10440,  2015, 10548, 18923, 13433, 13822,  2036,  3532, 21418, 22753,
          3264,  5776,  1038,  2213,  3134,  3968,  2938,  2015, 10958,  5301,
         17212,  2497,  1039,  2595,  2099,  3662, 21908,  3968, 14545,  5776,
          2445,  5869,  7646,  1060,  2487,  4921, 11113,  2595,  4015,  2968,
          5508,  3131, 17212,  2497,  1059,  2938,  2015, 11937, 11714,  2361,
          2638,  2594,  2152, 11480,  1039,  2067,  3255,  1048,  2015, 15911,
          1059, 12170, 22083, 11733,  2099,  8579,  4244,  5869,  7646,  1060,
         11460,  1060,  2487, 22822, 20738, 114

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6522,  2072, 10930,  2556, 19935,  3255,  2179,  2566,  2546,
          6829,  4181,  2389, 17359, 17119,  1052,  8503,  8983,  6829,  4181,
          2389, 17359, 17119,  2708, 12087,  2566,  2546,  6829,  4181,  2389,
         17359, 17119, 11937, 11714, 11522,  2401,  7667,  3446,  4758,  2058,
          3630,  2278, 17850,  2358,  2139, 17603,  4757,  1050,  2094,  3014,
          2540,  3796,  9530,  2102,  8138, 14397, 26189,  2278,  2895, 13853,
          6887,  2891, 16360, 19738,  3064,  2927,  5670,  3433,  9530,  2102,
         12052, 17850, 14925, 14399,  2100,  4895,  6072, 16116,  1048, 17250,
          6110,  2933, 11937, 11714, 11522,  2401,  7667, 13866,  4247,  2358,
          3014,  2540,  3796, 17850,  8138, 14397,  2015, 26189,  6169,  4771,
         10299,  3264, 11937, 11714, 11522,  2594,  23

tensor([[0, 0, 0,  ..., 0, 0, 0]], device='cuda:1')
x torch.Size([1, 128])

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101, 13866,  4914, 14397,  2226,  1052, 11290, 24501, 18491,  2381,
         16844,  6187, 15253, 23760, 29048,  5099, 23936,  1044, 21833,  6593,
         16940, 17632, 24501, 18491,  7667, 13866,  5777, 23852,  2135,  2305,
          4089, 12098,  3560,  3085,  8048,  1060,  8948,  3154,  2174, 13866,
          5505,  8467,  2335,  2349,  3255,  3571, 21454, 17531, 16922,  1057,
          7361,  2036, 16922, 10507,  2048,  5486,  2847, 16545,  2015,  1060,
         29215,  9333,  2566, 22291,  2136,  3313,  7039, 24582,  2226,  6319,
          2895, 13866,  6628,  7197,  2296,  2847,  2036,  6628, 19340,  2784,
          7200,  2201,  2378,  2445,  1060,  2659,  1057,  7361, 17531,  3433,
          8948,  3961,  3154, 13866,  2583, 19340,  6274,  2505,  1051,  2475,
          2938,  6540

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4293,  2095,  2214,  2450,  2381,  1054,  5369, 12248,  3406,
          3593, 27641, 11320, 12207, 23760, 29048, 10886, 17076,  3695, 14728,
          2863, 26261, 13027, 23971,  3593,  2224, 13866, 19077,  2135, 20014,
         19761,  3064,  6540,  4650,  2506,  7620,  3968, 14545,  2607,  8552,
         12436,  2361,  2358,  9331,  2232,  8740, 23446,  8670, 21162, 17577,
         14246,  2278,  7689, 12906,  4607, 24163, 16665,  2140, 21183,  2072,
          4654, 28251,  3370, 23760, 29048, 28378,  7667, 24829,  2361,  2651,
          4921,  6845, 12870,  4135,  2140, 26018,  2721, 21254,  2895,  2979,
          4613, 10577,  9312,  2583,  2707, 13433, 19960,  2015,  2777,  7361,
         13153,  4747, 10514, 14693, 12556, 11460, 13433,  1053,  2154,  9152,
         25

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6275,  2095,  2214,  2931,  1044,  2595,  1044,  2102,  2078,
         27345,  2594, 15859,  7315,  9808,  2483,  1052,  9033, 21693,  9314,
          5624,  6593, 16940,  1048,  5856,  2497,  2761,  3591,  2902,  4178,
          6065, 14708,  7987,  2497, 18098,  2318,  3110,  5410,  5458,  2170,
          3788,  6411,  4682,  8143,  2098,  7861,  3215,  3369,  2179, 14412,
          4502,  3468,  8187,  3697,  6855, 17531,  2132, 12603, 12603, 10751,
          6540,  3968,  5508,  9808,  2232,  4914,  2966,  2326,  3988, 20228,
          3215,  1999,  2099,  2363,  3197, 10975,  9818,  2015,  7620,  8030,
          9524,  3322,  3030, 16844,  2891,  3597,  7685,  2589,  3041,  2651,
          3662, 27345,  2594, 18845,  6190,  2302,  3161,  9524,  6366,  3461,
         26572,  2361, 18323, 16844,  3461, 26572,  2361,  8292, 24894,  2174,
          2211,  3413,  2312,  8310,  2668, 10975,  3225,  31

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  4185,  2095,  2214,  2931,  4914,  7610,  2232, 23852, 23666,
         18629,  3586,  2791,  1040,  7274, 21890, 10440,  2627,  2195,  2086,
          9298,  2290,  1060,  2302,  4335,  3264, 20118, 13320,  2157,  4012,
         24128,  2157,  2364, 13473,  2213, 22953, 12680,  2271,  2482,  4140,
          3593,  5970,  2187,  4942, 20464, 21654,  2706,  3283,  2695,  6728,
          6110,  2157, 11536, 15127, 15321,  6305,  2594, 20118, 13320,  3461,
         21500,  8545,  3726, 22160,  2102, 11320, 11201,  2099, 12475,  9253,
          4958,  2072, 17678, 11253,  4747, 22597,  7610,  2232,  9298,  2290,
          1060,  2487,  7532,  3512,  8153,  8761, 22889,  2063,  1055,  5558,
    

dict_keys(['input_ids', 'attention_mask'])
input torch.Size([1, 128])
tensor([[  101,  6522,  2072,  1042,  2179,  3129,  1054,  3239, 14925, 11714,
         15530,  2483, 13576,  5248,  9741, 18634,  3129,  2170, 29031, 13866,
          4895,  6072, 26029, 12742,  5508,  2579,  2902,  4997,  2132,  1039,
          8560, 14931, 20014, 19761,  3064,  7861, 19009,  8776,  5177,  3570,
          4015,  5729,  1044, 22571,  7856,  7913, 10092,  6583,  3322, 14931,
          2132, 15190,  4997,  1044, 22571,  7856,  7913, 10092,  2659, 13365,
          1044, 22571,  2891,  5302, 13837,  3012,  7667, 18634,  4023,  7391,
          3461, 28022,  2135, 22643,  4076, 10954,  2334, 10057,  2376, 11530,
          2844, 18201, 19340,  1048,  2290,  8310,  3154,  3595,  8496, 17678,
         11253,  4747, 14841,  6494,  3064, 22356, 17531,  2491,  3544,  7327,
          6767, 16930,  2594, 26226,  2361,  2312,  8310,  3154,  3756, 17996,
          4997,  1048,  2947,  2521,  3173, 16731,  5753,  69

In [None]:
y_pred, y_test = get_predictions(model, test_data_loader)