## A test for applying BERT to cloze task

According to the answer from github [How can I apply Bert to a cloze task](https://github.com/huggingface/transformers/issues/80#issuecomment-444445782)

In [42]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = 'From Monday to Friday most people are busy working or studying, '\
       'but in the evenings and weekends they are free and _ themselves.'
tokenized_text = tokenizer.tokenize(text)

masked_index = tokenized_text.index('_')
tokenized_text[masked_index] = '[PAD]'

candidates = ['love', 'work', 'enjoy', 'play']
candidates_ids = tokenizer.convert_tokens_to_ids(candidates)

indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

segments_ids = [0] * len(tokenized_text)

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

language_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
language_model.eval()

predictions = language_model(tokens_tensor, segments_tensors)
predictions_candidates = predictions[0, masked_index, candidates_ids]
answer_idx = torch.argmax(predictions_candidates).item()

print(f'The most likely word is "{candidates[answer_idx]}".')

The most likely word is "enjoy".


## Experiment of applying AUM to Bert based method

In [1]:
import torch
from pytorch_pretrained_bert import BertForSequenceClassification
from pytorch_pretrained_bert import BertConfig, BertAdam
# import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

In [2]:
import os
import numpy as np
import json

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import pickle
import time

from pytorch_pretrained_bert import BertTokenizer


class MultiNLIDataset(Dataset):
    def __init__(self,root='/media/felicia/Data/multinli',matched=True,tokenized=True,max_length=12):
        super(MultiNLIDataset,self).__init__()
        self.root=root
        self.matched=matched
        self.tokenized=tokenized
        self.max_length=max_length
        self.jsonfile="multinli_1.0_train.jsonl" if self.matched else "multinli_1.0_dev_mismatched.jsonl"
        self.filename=os.path.join(self.root,self.jsonfile)


        self.num_labels = 2
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.LABEL_MAP = {
                "entailment": 0,
                "neutral": 1,
                "contradiction": 2,
                "hidden": 0
            }

        self.data=[]
        self.sentences=[]
        self.labels=[]
        self.load_data()
        self.random_flip = torch.randint(392702, (200, ))
#         self.random_flip = torch.randint(800, (20, ))

    def load_data(self):
        with open(self.filename) as f:
            for idx, line in enumerate(f):
#                 if idx > 800:
#                     continue
                example=json.loads(line) # dict
                self.data.append(example)
        if self.tokenized:
            self.tokenize()

    
    def tokenize(self):
        for i, text in enumerate(self.data):
            sent=text["sentence1"]
            label=text["gold_label"]
            if label not in self.LABEL_MAP:
                continue

            tokenized_text = self.tokenizer.tokenize(sent)
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
            if len(indexed_tokens)<self.max_length:
                indexed_tokens+=[0]*(self.max_length-len(indexed_tokens))
            else:
                indexed_tokens=indexed_tokens[:self.max_length]
            indexed_tokens=np.array(indexed_tokens)

            self.sentences.append(indexed_tokens)
            self.labels.append(self.LABEL_MAP[label])
    
    def __getitem__(self,index):
        sent = self.sentences[index]
        if index in self.random_flip:
            label = self.labels[index]+int(2*(torch.rand(1)>0.5)-1)
            if label>2:
                label = 0
            elif label<0:
                label = 2
        else:
            label = self.labels[index]

        return index,sent,label

    def __len__(self):
        return len(self.sentences)
#         return 800

In [3]:
# def Trainer(batch_size=4, epoch = 50, n_classes = 3, vocab_size = 30522):
batch_size=4
epoch = 50
n_classes = 3
vocab_size = 30522
writer = SummaryWriter('runs/NLP_AUM_First')

config = BertConfig(vocab_size)
model = BertForSequenceClassification(config, n_classes)
model.to("cuda")
dataset = MultiNLIDataset("./multinli_1.0", max_length=217)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle=True)
criterion = torch.nn.CrossEntropyLoss()
optimizer = BertAdam(model.parameters(), lr = 0.000001)
aum = torch.zeros([epoch, len(dataset), n_classes])

t_total value of -1 results in schedule not being applied


In [4]:
for e in range(epoch):
    logits_store = torch.zeros([len(dataset), n_classes]).cuda()
    running_loss = 0.0
    for idx, batch in enumerate(dataloader):
        optimizer.zero_grad()
        sample_idx, text_in, labels = batch
        logits = model(text_in.cuda())

        logits_store[sample_idx] = logits

        loss = criterion(logits, labels.cuda())
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        if idx % 100 == 99:
            print(running_loss/100)
            running_loss = 0.0
    logits_topk, logits_topk_ind = torch.topk(logits_store, 2, 0)
    aum[e] = logits_store - logits_topk[0][None]
    aum[e, logits_topk_ind[0,0], 0] = logits_store[logits_topk_ind[0,0], 0] - logits_topk[1, 0]
    aum[e, logits_topk_ind[0,1], 1] = logits_store[logits_topk_ind[0,1], 1] - logits_topk[1, 1]

    attention_index = dataset.random_flip#Need the attention index to point out the flipped sample

#     for ind in attention_index:
#         writer.add_scalar("Sentence {}".format(ind), aum[e, ind, dataset[ind][2]], e)
    torch.save({aumaum, "AUM.pth")

1.1456522166728973
1.1209135419130325
1.0955419850349426
1.1125767797231674
1.096754828095436
1.0886997503042222
1.0857273465394974
1.0801356768608092
1.0894037210941314
1.0827822196483612
1.0733021980524062
1.0736334884166718
1.065121978521347
1.0574870485067367
1.0595620465278626
1.0459759545326233
1.0344018590450288
1.037968230843544


KeyboardInterrupt: 

In [11]:
(logits_store - logits_topk[0][None]).shape

torch.Size([800, 3])

In [5]:
labels = []
for i in dataset:
    labels.append(i[2])

In [6]:
np.unique(labels)

array([-1,  0,  1,  2])

In [13]:
aum = torch.zeros([epoch, len(dataset), n_classes])
aum[e]=logits_store - logits_topk[0][None]

In [16]:
logits[logits_topk_ind[0,0], 0].shape
# - logits_topk[1, 0]

IndexError: index 760 is out of bounds for dimension 0 with size 4

In [6]:
with open("./multinli_1.0/multinli_1.0_train.jsonl") as f:
    for idx, line in enumerate(f):
        if idx>10:
            continue
        print(f)

<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode='r' encoding='UTF-8'>
<_io.TextIOWrapper name='./multinli_1.0/multinli_1.0_train.jsonl' mode