In [None]:
!git clone https://github.com/Senyu-T/unifiedqa

In [None]:
!nvidia-smi -L

In [None]:
cd unifiedqa/bart

In [None]:
!chmod +x download_data.sh; ./download_data.sh

In [8]:
cd data/natural_questions_with_dpr_para/

/content/unifiedqa/bart/data/natural_questions_with_dpr_para


In [11]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import os
os.chdir("/content/drive/MyDrive/NQ")

In [16]:
import string
import re

def normalize_answer(s):
  def remove_articles(text):
    return re.sub(r'\b(a|an|the)\b', ' ', text)
  def white_space_fix(text):
    return ' '.join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()
  return white_space_fix(remove_articles(remove_punc(lower(s))))

def remove_spc_token(s):
  s = s.replace(' \\\'\\\'', ' \'\'')   # double quotation
  s = s.replace('\\\'', '\'')
  s = s.replace(' \'s', '\'s')    # 's
  s = s.replace(' ,', ',')
  return s

In [17]:
# read file, parse into context / question / answer for further data analysis
def read_files(file_name):
  answers = []
  questions = []
  contexts = []
  with open(file_name, 'rb') as inference_in:
    lines = inference_in.readlines()
    for i in range(len(lines)):
      sep = str(lines[i]).split('\\n') 
      questions.append(sep[0][2:-1])
      ans = (sep[1].split('\\t')[-1]).lower()
      ans = normalize_answer(remove_spc_token(ans))  # normalize answers
      answers.append(ans)
      contexts.append(sep[1].split('\\t')[0])
  return answers, questions, contexts

In [18]:
tr_answers, tr_questions, tr_contexts = read_files("/content/unifiedqa/bart/data/natural_questions_with_dpr_para/train.tsv")
val_answers, val_questions, val_contexts = read_files("/content/unifiedqa/bart/data/natural_questions_with_dpr_para/dev.tsv")

In [None]:
!pip install transformers

In [22]:
import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoModel, BertTokenizerFast, BertTokenizer, AutoTokenizer, BertModel
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import csv
import numpy as np

In [23]:
class BiLSTM(nn.Module):
    def __init__(self,embedding_dim, hidden_dim, num_layers, num_classes):
        super(BiLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=embedding_dim,
                    hidden_size=hidden_dim,
                    num_layers=num_layers,
                    batch_first=False,
                    bidirectional=True)
        self.fc1 = torch.nn.Linear(2*hidden_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, num_classes)
    
    
    def forward(self, embeddings):
        self.lstm.flatten_parameters()
        lstm_output, _ = self.lstm(embeddings)
        output = lstm_output[:,-1,:]
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        return output

class BertBiLSTM(nn.Module):
    def __init__(self, num_layers, num_classes, embedding_dim = 768, hidden_dim=128):
        super(BertBiLSTM, self).__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        for param in self.bert.parameters():
          param.requires_grad = False
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        
        self.classifier = BiLSTM(self.embedding_dim, hidden_dim, num_layers, num_classes)
       
    def forward(self, input_ids, attention_mask):
        text_embeddings = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_embeddings = text_embeddings[0]
        
        output = self.classifier(text_embeddings)
        return output

In [28]:
MAX_LEN = 64

def preprocessing_for_bert(tokenizer, sentences):
    input_ids = []
    attention_masks = []
    num_sentences = len(sentences)

    for i, sent in enumerate(sentences):
      encoded_sent = tokenizer.encode_plus(text=sent, 
                          add_special_tokens=True,        
                          max_length=MAX_LEN,               
                          padding='max_length',         
                          return_attention_mask=True, 
                          truncation=True)     
      input_ids.append(encoded_sent.get('input_ids'))
      attention_masks.append(encoded_sent.get('attention_mask'))

    return input_ids, attention_masks


In [43]:
def get_data(file_path, tsv_file, tokenizer, save_location, label2index=None, index2label=None):
    answers, questions, contexts = read_files(file_path)

    tsv_file = open(tsv_file)
    read_tsv = csv.reader(tsv_file)
    labels = [row[0] for row in read_tsv]

    if label2index is None:
      label_set = set(labels)
      label2index = {label:index for index, label in enumerate(label_set)}
      index2label = {index:label for index, label in enumerate(label_set)}

    labels = [label2index[label] for label in labels]

    input_ids, attention_masks = preprocessing_for_bert(tokenizer, questions)
    np.savez(save_location, input_ids=input_ids, attention_masks=attention_masks, labels=labels, label2index=label2index, index2label=index2label)
    return len(label2index)

In [44]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [45]:
train_path = "/content/unifiedqa/bart/data/natural_questions_with_dpr_para/train.tsv"
val_path = "/content/unifiedqa/bart/data/natural_questions_with_dpr_para/dev.tsv"
tr_tags = "/content/drive/MyDrive/NQ/data/answer_tags/train_soft_tag.tsv"
val_tags = "/content/drive/MyDrive/NQ/data/answer_tags/dev_soft_tag.tsv"

In [46]:
os.chdir("/content/drive/MyDrive/NQ")

In [47]:
num_classes = get_data(train_path, tr_tags, tokenizer, "data/tr_tokenized.npz")

In [None]:
num_classes = get_data(val_path, val_tags, tokenizer, "data/val_tokenized.npz")

In [40]:
def load_dataset(location):
    data = dict(np.load(location,allow_pickle=True))
    for key, elem in data.items():
      if key=='label2index' or key=='index2label':
        continue
      data[key] = torch.tensor(elem)
    dataset = TensorDataset(data['input_ids'].squeeze(), data['attention_masks'].squeeze(), data['labels'])
    return dataset 

In [59]:
train_path = "data/tr_tokenized.npz"
train_dataset = load_dataset(train_path)
val_path = "data/val_tokenized.npz"
val_dataset = load_dataset(val_path)

In [63]:
def evaluate(network, loader, loss_fn, data_size):
    loss = 0.0
    acc = 0.0
    with torch.no_grad():
      for i, (input_ids, masks, labels) in tqdm(enumerate(loader),total=len(loader),position=0, leave=True):
          input_ids = input_ids.cuda()
          masks = masks.cuda()
          labels = labels.cuda()
            
          output = network(input_ids, masks)
          loss += loss_fn(output,labels)
          preds = torch.argmax(output,dim=1)
          #output = network(input_ids, masks, labels=labels)
          #preds = torch.argmax(output.logits, dim=1)
          #loss = output.loss

          acc += torch.eq(preds, labels).sum().item()
    return loss.item() / data_size, acc / data_size

In [61]:
def train(directory, network, loss_fn, train_dataset, test_dataset, optimizer, scheduler, batch_size, num_epochs, verbose=True, val_freq=5):
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    val_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)
    
    train_loss, train_acc = torch.zeros(num_epochs), torch.zeros(num_epochs)
    test_loss, test_acc = torch.zeros(num_epochs//val_freq + 1), torch.zeros(num_epochs//val_freq + 1)
    val_best_acc, train_best_acc = 0.0, 0.0
    os.makedirs(directory, exist_ok=True)
    for epoch in range(num_epochs):
        if epoch % val_freq == 0:
          network.eval()
          idx = epoch//val_freq
          test_loss[idx], test_acc[idx] = evaluate(network, val_dataloader, loss_fn, len(test_dataset))
          if test_acc[idx] > val_best_acc:
            val_best_acc = test_acc[idx]
            torch.save(network.state_dict(), f"{directory}/snapshot_val_best")
          if verbose:
            print(f"epoch:{epoch:3d}, test_loss: {test_loss[idx]:3.6f}, test_acc: {test_acc[idx]:3.5f}")

        network.train()
        train_epoch_loss, train_epoch_acc = 0.0, 0.0
        for i, (input_ids, masks, labels) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), position=0, leave=True):
            input_ids = input_ids.cuda()
            masks = masks.cuda()
            labels = labels.cuda()              
            optimizer.zero_grad() 
            
            output = network(input_ids, masks)
            pred = torch.argmax(output, dim=1)
            train_epoch_acc += torch.eq(pred, labels).sum().item()
            loss = loss_fn(output,labels)/batch_size

            # commented lines are codes for BertSeqeucen
            #output = network(input_ids, masks, labels=labels)       
            #pred = torch.argmax(output.logits, dim=1)
            #loss = output.loss
            train_epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_acc[epoch] = train_epoch_acc / len(train_dataset)
        train_loss[epoch] = train_epoch_loss / len(train_dataloader)
        print(f"train_loss:{train_loss[epoch]:.6f}, train_acc:{train_acc[epoch]:.6f}")
        if (train_acc[epoch] > train_best_acc):
          train_best_acc = train_acc[epoch]
          torch.save(network.state_dict(), f"{directory}/snapshot_train_best")

        scheduler.step(test_acc[epoch//val_freq])

    network.eval()   
    test_loss[-1], test_acc[-1] = evaluate(network, val_dataloader, loss_fn, len(test_dataset))    
    torch.save(train_loss, f"{directory}/train_loss")
    torch.save(test_loss, f"{directory}/test_loss")
    torch.save(train_acc, f"{directory}/train_acc")
    torch.save(test_acc, f"{directory}/test_acc")
    torch.save(network.state_dict(), f"{directory}/snapshot_final")


In [54]:
def get_path(loss, opt, lr, batch_size, epoch):
    return f"{PATH}/lr_{lr}_bs_{batch_size}_loss_{loss}_epoch_{epoch}_opt_{opt}"

In [49]:
loss = 'ce'
opt = 'adam'
batch_size = 32
lr = 1e-5
num_epochs = 50

In [55]:
PATH = "bert_bilstm_classifier"
directory = f"{get_path(loss, opt, lr, batch_size, num_epochs)}"
os.makedirs(directory, exist_ok=True)

torch.manual_seed(11747)
num_layers = 2

In [57]:
network = BertBiLSTM(num_layers, num_classes)
network = network.cuda()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




In [58]:
optimizer = torch.optim.Adam(network.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.7, patience=2, verbose=True, mode='max')

In [64]:
train(directory, network, loss_fn, train_dataset, val_dataset, optimizer, scheduler, batch_size, num_epochs)

100%|██████████| 335/335 [00:13<00:00, 24.48it/s]
  0%|          | 1/3022 [00:00<08:59,  5.60it/s]

epoch:  0, test_loss: 0.093335, test_acc: 0.00047


100%|██████████| 3022/3022 [06:41<00:00,  7.52it/s]


train_loss:0.031963, train_acc:0.746566


100%|██████████| 3022/3022 [06:42<00:00,  7.52it/s]


train_loss:0.014974, train_acc:0.879298


100%|██████████| 3022/3022 [06:41<00:00,  7.52it/s]


train_loss:0.012851, train_acc:0.888928


100%|██████████| 3022/3022 [06:42<00:00,  7.51it/s]


train_loss:0.011471, train_acc:0.895920


  0%|          | 1/3022 [00:00<06:48,  7.39it/s]

Epoch     4: reducing learning rate of group 0 to 7.0000e-06.


100%|██████████| 3022/3022 [06:42<00:00,  7.51it/s]


train_loss:0.010109, train_acc:0.905809


100%|██████████| 335/335 [00:13<00:00, 24.60it/s]
  0%|          | 1/3022 [00:00<06:58,  7.23it/s]

epoch:  5, test_loss: 0.012866, test_acc: 0.88067


100%|██████████| 3022/3022 [06:41<00:00,  7.53it/s]


train_loss:0.009207, train_acc:0.912957


100%|██████████| 3022/3022 [06:39<00:00,  7.57it/s]


train_loss:0.008374, train_acc:0.919060


100%|██████████| 3022/3022 [06:38<00:00,  7.58it/s]


train_loss:0.007648, train_acc:0.925204


100%|██████████| 3022/3022 [06:40<00:00,  7.54it/s]


train_loss:0.007006, train_acc:0.931038


  0%|          | 1/3022 [00:00<06:48,  7.39it/s]

Epoch     9: reducing learning rate of group 0 to 4.9000e-06.


100%|██████████| 3022/3022 [06:40<00:00,  7.55it/s]


train_loss:0.006229, train_acc:0.939189


100%|██████████| 335/335 [00:13<00:00, 24.59it/s]
  0%|          | 1/3022 [00:00<07:00,  7.18it/s]

epoch: 10, test_loss: 0.014961, test_acc: 0.87721


 14%|█▎        | 415/3022 [00:54<05:44,  7.58it/s]

KeyboardInterrupt: ignored

Get the tags

In [65]:
network = BertBiLSTM(num_layers, num_classes)
network.load_state_dict(torch.load(f"{directory}/snapshot_val_best"))
network.cuda()

BertBiLSTM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

Inference and save preds / logits

In [71]:

def get_labels(tsv_file):
  index2label, label2index = None, None
  tsv_file = open(tsv_file)
  read_tsv = csv.reader(tsv_file)
  labels = [row[0] for row in read_tsv]
  if label2index is None:
    label_set = set(labels)
    label2index = {label:index for index, label in enumerate(label_set)}
    index2label = {index:label for index, label in enumerate(label_set)}
  return label2index, index2label
  

In [72]:
tr_tags = "/content/drive/MyDrive/NQ/data/answer_tags/train_soft_tag.tsv"
val_tags = "/content/drive/MyDrive/NQ/data/answer_tags/dev_soft_tag.tsv"
tr_i2l, tr_l2i = get_labels(tr_tags)
print(tr_i2l)

{'LANGUAGE': 0, 'WORK_OF_ART': 1, 'EVENT': 2, 'GPE': 3, 'PERCENT': 4, 'FAC': 5, 'ORG': 6, 'ORDINAL': 7, 'TIME': 8, 'NORP': 9, 'MONEY': 10, 'LOC': 11, 'LAW': 12, 'PRODUCT': 13, 'PERSON': 14, 'DATE': 15, 'OTHERS': 16, 'QUANTITY': 17, 'CARDINAL': 18}


In [114]:
def inference(network, dataset, batch_size):
  logits = []
  preds = []
  loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
  network.eval()
  with torch.no_grad():
    for i, (input_ids, masks, labels) in tqdm(enumerate(loader),total=len(loader),position=0, leave=True):
      input_ids = input_ids.cuda()
      masks = masks.cuda()
      labels = labels.cuda()
            
      output = network(input_ids, masks)
      logits.append(output)
      pred = torch.argmax(output,dim=1)
      preds.extend([tr_l2i[pred[i].item()] for i in range(len(pred))])
      #output = network(input_ids, masks, labels=labels)
      #preds = torch.argmax(output.logits, dim=1)
      #loss = output.loss

    
  return torch.cat(logits), preds

In [115]:
val_logits, val_preds = inference(network, val_dataset, 32)

100%|██████████| 335/335 [00:13<00:00, 24.46it/s]


In [117]:
tr_logits, tr_preds = inference(network, train_dataset, 32)

100%|██████████| 3022/3022 [02:03<00:00, 24.40it/s]


In [116]:
torch.save(val_logits, f"{directory}/val_logits")

In [118]:
torch.save(tr_logits, f"{directory}/tr_logits")

In [121]:
with open(f"{directory}/val_preds.tsv", 'w') as v_f:
  v_f.write('\n'.join(val_preds))

In [124]:
with open(f"{directory}/tr_preds.tsv", 'w') as t_f:
  t_f.write('\n'.join(tr_preds))