In [None]:
import torch
import torch.optim as optim
from torch import nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import json
import os
import re
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import pickle
import pdb
import os
from torch.utils.data import DataLoader
# Number GPUs by PCI bus id so indices match nvidia-smi (see issue #152).
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID"})
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
#os.environ["CUDA_VISIBLE_DEVICES"]="2,3,4,5,6,7"

# Loading data

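Data format: each JSON file holds `{"article": ..., "options": [[...], ...], "answers": [...]}` — `options` has one list of candidate words per blank and `answers` one letter per blank. Blanks in the article are marked with "_".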

In [None]:
def loadData(folder, suffix=None):
    """Load every file under ``folder`` as parsed JSON (nested dicts/lists).

    Args:
        folder: root directory, walked recursively.
        suffix: if given, only files whose directory path contains this substring are
            loaded (e.g. 'train', 'dev', 'test'). ``None`` loads everything.

    Returns:
        List of parsed JSON objects, ordered by file path. Sorting makes the order
        reproducible across filesystems (``os.walk`` order is arbitrary), which matters
        because later cells key predictions by list position.

    Raises:
        ValueError: if a loaded item has an empty ``options`` list.
    """
    paths = []
    for root, _dirs, files in os.walk(folder):
        if 'ipynb' in root:
            continue  # jupyter tmp file (.ipynb_checkpoints)
        if suffix is None or suffix in root:
            paths.extend(os.path.join(root, name) for name in files)

    lst = []
    for path in sorted(paths):
        with open(path, encoding='utf-8') as f:
            item = json.load(f)
        if not item['options']:
            raise ValueError("item without options: %s" % path)
        lst.append(item)
    print(folder, suffix, len(lst))
    return lst

train_lst = loadData('ELE', 'train')
val_lst = loadData('ELE', 'dev')
test_lst = loadData('ELE', 'test')
cloth_lst = loadData('CLOTH')


def _options_key(item):
    """Return a hashable fingerprint of an item's ``options`` for duplicate detection."""
    return json.dumps(item['options'], sort_keys=True)


# Remove CLOTH passages whose option lists duplicate an ELE passage (same criterion as
# comparing ``options`` with ==). The ELE test split is included too, so no test passage can
# leak into training through CLOTH. A set lookup avoids the quadratic pairwise comparison.
ele_keys = {_options_key(item) for item in train_lst + val_lst + test_lst}
clean_lst = [item for item in cloth_lst if _options_key(item) not in ele_keys]

train_lst = train_lst + clean_lst
sample_item = train_lst[0]
print("%d from cloth" % len(clean_lst))
print(sample_item)

# Preprocessing
`[CLS]` and `[SEP]` are added automatically by the tokenizer (for BERT, token ids 101 and 102).

In [None]:
# Token ids / settings for the uncased BERT tokenizer used by ClozeDataset.
MODEL_NAME: str = 'bert-base-uncased'  # tokenizer checkpoint; NOTE(review): Model loads bert-large-uncased, presumably same vocab — confirm
BLANK_ID: int = 1035   # id of "_", the blank marker in raw articles
MASK_ID: int = 103     # [MASK]; every blank is rewritten to this id
SEP_TOKEN: int = 102   # [SEP]; forced onto the last position of truncated articles
MAX_LEN: int = 512     # maximum article length in word-pieces


from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

def foo(x):
    """Interpret the ASCII bytes of ``x`` as a little-endian unsigned integer.

    For a single character this is its code point, so ``foo(c) - foo('A')`` maps answer
    letters 'A', 'B', ... to option indices 0, 1, ....

    Raises:
        UnicodeEncodeError: if ``x`` contains non-ASCII characters.
    """
    raw = bytes(x, 'ascii')
    return int.from_bytes(raw, 'little')

class ClozeDataset(Dataset):
    """Tokenised cloze passages in the simplest format: {article, options, answers}.

    Per item (all tensors are built once in ``__init__``):
      - article: 1-D long tensor of ids for the lower-cased passage. [CLS]/[SEP] are added
        by the tokenizer. Passages longer than MAX_LEN are cut to MAX_LEN here (not
        discarded) with the last id forced to SEP_TOKEN. Every BLANK_ID ("_") becomes
        MASK_ID ([MASK]).
      - options: (n_blanks, n_options) long tensor holding only the *first* word-piece id of
        each option; options that span several tokens are truncated to that initial token.
      - answers: 1-D long tensor of option indices ('A' -> 0, 'B' -> 1, ...). It is empty
        when the source item lists no answers.
    Only blanks that survive truncation keep their options/answers. ``self.meta`` stores
    ``n_blanks_before`` / ``n_blanks_truncated`` (ints) and ``article_length``, so that
    predictions for dropped blanks can be filled in later (they are filled with A).

    When batched with ``collate_fn``, article and options are zero padded and answers are
    -1 padded.

    Author's notes on the data:
     94 articles are longer than 512 tokens.
     5677 answers contain more than 1 BERT token. NOTE(review): an earlier count of "2 that
     cannot be distinguished by the initial token" came from a check that compared whole
     token lists; re-run to refresh it with the first-token check below.
    """
    def __init__(self, data_list):
        super().__init__()
        self.data = []
        self.meta = []
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        n_multi_token = 0  # answers whose text spans more than one word-piece
        n_too_long = 0     # articles longer than MAX_LEN
        n_ambiguous = 0    # answers whose first word-piece also starts another option
        for item in tqdm(data_list):
            # --- article: tokenise, truncate, turn blanks into [MASK]
            article = torch.tensor(tokenizer.encode(item['article'].lower()), dtype=torch.long)
            length = len(article)
            n_blanks_before = int((article == BLANK_ID).sum())
            if length > MAX_LEN:
                n_too_long += 1
                article = article[:MAX_LEN]
                article[-1] = SEP_TOKEN  # keep the truncated sequence terminated by [SEP]
            n_blanks = int((article == BLANK_ID).sum())
            article = article.masked_fill(article == BLANK_ID, MASK_ID)

            # --- answers: letters -> option indices, only for blanks that survived truncation
            answer_ids = [foo(letter) - foo('A') for letter in item['answers']][:n_blanks]

            # --- options: word-pieces without [CLS]/[SEP]; keep only the first piece of each
            first_pieces = []
            for i, line in enumerate(item['options'][:n_blanks]):
                pieces = [tokenizer.encode(word)[1:-1] for word in line]
                heads = [p[0] for p in pieces]
                if i < len(answer_ids):
                    ans = answer_ids[i]
                    if len(pieces[ans]) > 1:
                        n_multi_token += 1
                    # the model only sees first pieces, so a shared head is indistinguishable
                    if heads[ans] in heads[:ans] + heads[ans + 1:]:
                        n_ambiguous += 1
                first_pieces.append(heads)

            self.data.append({"article": article,
                              "options": torch.tensor(first_pieces, dtype=torch.long),
                              "answers": torch.tensor(answer_ids, dtype=torch.long)})
            self.meta.append({"n_blanks_before": n_blanks_before,
                              "n_blanks_truncated": n_blanks,
                              "article_length": length})

        print("%d answers contains multiple tokens" % (n_multi_token))
        print("%d articles exceeds max length" % (n_too_long))
        print("%d answers cannot be decided using the initial token" % (n_ambiguous))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def collate_fn(data_list):
    """Batch ClozeDataset items by right-padding every field along its first dimension.

    Args:
        data_list: list of dicts sharing the same keys; each value is a tensor.

    Returns:
        dict mapping each key to a ``(batch, max_len, ...)`` tensor. ``answers`` is padded
        with -1 (so padded blanks can be dropped with ``answers >= 0``); every other field
        is padded with 0. An empty ``data_list`` yields an empty dict.
    """
    if not data_list:
        return {}
    batch = {}
    for key in data_list[0]:
        padding_value = -1 if key == 'answers' else 0
        batch[key] = pad_sequence([item[key] for item in data_list],
                                  batch_first=True, padding_value=padding_value)
    return batch

In [None]:
# Load pre-trained model tokenizer (vocabulary)
train_set = ClozeDataset(train_lst)
print(len(train_set))
with open("train_set", 'wb') as f:
    pickle.dump(train_set, f)

## Or load the preprocessed dataset

In [None]:
# Reload the cached training set; only unpickle files you produced yourself.
with open("train_set", mode="rb") as cache_file:
    train_set = pickle.load(cache_file)

# Model

In [None]:
class Model(nn.Module):
    """Score multiple-choice cloze options with a BERT masked-LM head.

    Blanks must already be [MASK] (MASK_ID) in ``article``. The MLM logits at those
    positions are compared against the first word-piece id of each option.
    """
    def __init__(self, model_name='bert-large-uncased'):
        super().__init__()
        # NOTE(review): the data is tokenised with MODEL_NAME; this checkpoint must use the
        # same vocabulary. The default is the checkpoint this class always used.
        self.bert = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, article, options, answers=None):
        """Score the options of every [MASK] in the batch.

        Args:
            article: (batch, seq) tensor of token ids, zero padded.
            options: (batch, n_blanks, 4) tensor of first word-piece ids per option;
                padded entries are 0.
            answers: optional (batch, n_blanks) tensor of correct option indices, padded
                with negative values.

        Returns:
            With ``answers``: per-blank cross-entropy over the full vocabulary for the
            correct option's first word-piece, shape (n_blanks_total,).
            Without: index of the highest-scoring option per blank, shape (n_blanks_total,).
            Blanks are ordered row-major over (batch, position).
        """
        # Masks id 0 (padding) and the other ids <= 99 (presumably unused vocab slots).
        attention_mask = (article > 99).long()
        # No `labels`: the built-in MLM loss would be computed over every position and then
        # discarded; the loss is computed on the blanks only, below.
        logit = self.bert(article, attention_mask=attention_mask)[0]
        # (batch, seq, vocab) -> (n_blanks_total, vocab)
        logit = logit[article == MASK_ID]

        # Drop the zero entries that padding adds; each blank has 4 options.
        options = options[options > 0].view(-1, 4)

        if answers is not None:
            answers = answers.reshape(-1)
            answers = answers[answers >= 0]  # removes the padding answers
            # id of the correct option's first word-piece, shape (n_blanks_total,)
            answer_token = torch.gather(options, 1, answers.long().unsqueeze(1)).view(-1)
            return nn.functional.cross_entropy(logit, answer_token, reduction='none')

        option_score = torch.gather(logit, 1, options)
        return torch.argmax(option_score, dim=1).view(-1)

# Training

In [None]:
# Reload the cached training set (only unpickle files you produced yourself).
with open("train_set", mode="rb") as cache_file:
    train_set = pickle.load(cache_file)
# Reshuffled each epoch; collate_fn pads variable-length articles and blank lists.
train_loader = DataLoader(train_set, batch_size=24, shuffle=True, collate_fn=collate_fn)

In [None]:
# Move the weights to the default GPU before wrapping them for multi-GPU data parallelism.
model = nn.DataParallel(Model().cuda())
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [None]:
N_EPOCHS = 10
ACCUM_STEPS = 10   # batches whose gradients are accumulated per optimizer step
PLOT_EVERY = 100   # batches between loss-curve redraws

losses = []  # unscaled per-batch training loss
model.train()
optimizer.zero_grad()  # start from clean gradients even if this cell is re-run
for epoch in range(N_EPOCHS):
    pending = 0  # batches accumulated since the last optimizer step
    for i, data in enumerate(tqdm(train_loader)):
        article, options, answers = data['article'].cuda(), data['options'].cuda(), data['answers'].cuda()
        loss = model(article, options, answers).mean()
        losses.append(loss.item())
        # Average (rather than sum) gradients over the accumulation window.
        (loss / ACCUM_STEPS).backward()
        pending += 1
        if pending == ACCUM_STEPS:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0
        if i % PLOT_EVERY == 0:
            plt.plot(losses)
            plt.show()
    # Apply the final partial window so its gradients do not leak into the next epoch.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    # checkpoint_k holds the state after k completed epochs (the last epoch is saved too).
    torch.save({'model_dict': model.module.state_dict(),
                'optimizer_dict': optimizer.state_dict()},
               "./checkpoint_" + str(epoch + 1))

In [None]:
# Persist the per-batch training losses for later plotting.
with open("losses", mode="wb") as loss_file:
    pickle.dump(losses, loss_file)

# Eval

In [None]:
# Validation data: ELE dev split, evaluated one passage at a time in file order.
val_lst = loadData('ELE', 'dev')
val_set = ClozeDataset(val_lst)
val_loader = DataLoader(val_set, collate_fn=collate_fn, shuffle=False, batch_size=1)

In [None]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in tqdm(val_loader):
        article, options, answers = data['article'].cuda(), data['options'].cuda(), data['answers'].cuda()
        pred = model.module(article, options)
        answers = answers.view(-1)
        answers = answers[answers >= 0]  # drop the -1 padding
        # .item() keeps the running totals as Python numbers instead of CUDA tensors
        correct += (pred == answers).sum().item()
        total += pred.shape[0]
print("acc:", correct / total if total else float('nan'))

# Test

In [None]:
test_lst = loadData('ELE', 'test')
test_set = ClozeDataset(test_lst)
# Same collate_fn as the validation loader so padding stays consistent if batch_size grows;
# shuffle=False keeps loader index i aligned with test_set.meta[i].
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [None]:
model.eval()
results = {}
with torch.no_grad():  # inference only: do not build an autograd graph
    for i, data in enumerate(tqdm(test_loader)):
        article, options = data['article'].cuda(), data['options'].cuda()
        prediction = model.module(article, options).tolist()
        n_blanks = int(test_set.meta[i]['n_blanks_before'])
        # One letter per blank of the untruncated article; blanks cut off by truncation
        # have no prediction and are filled with 'A'.
        result = [chr(ord('A') + p) for p in prediction[:n_blanks]]
        result += ['A'] * (n_blanks - len(result))
        # NOTE(review): the key assumes test_lst is ordered like the test file numbering — confirm.
        results["test%04d" % i] = result

with open("results.json", "w") as f:
    json.dump(results, f)

# Save / load checkpoints

In [None]:
# Snapshot the unwrapped model weights plus optimizer state.
# NOTE(review): ./checkpoint_0 may collide with checkpoints written during training.
checkpoint = {
    'model_dict': model.module.state_dict(),
    'optimizer_dict': optimizer.state_dict(),
}
torch.save(checkpoint, "./checkpoint_0")

In [None]:
# Restore model weights only; the saved 'optimizer_dict' is not applied here.
checkpoint = torch.load("./checkpoint_7")
model.module.load_state_dict(checkpoint['model_dict'])