<a href="https://colab.research.google.com/github/SahilDhull/emphasis_selection/blob/master/model/final_largebert_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: `transformers` provides the BERT tokenizer/model, AdamW and the LR scheduler used below.
# NOTE(review): unpinned install. This notebook imports `AdamW`/`WEIGHTS_NAME` from transformers and
# indexes model outputs positionally, so newer releases may break it; consider pinning the version it
# was developed against (TODO confirm which).
!pip install transformers
# NOTE(review): nothing in this notebook imports a `config` package; this install looks unnecessary — verify.
!pip install config



In [2]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

from transformers import BertTokenizer, BertConfig , BertForMaskedLM , BertModel
# from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
# from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
# from transformers import XLNetConfig, XLNetModel , XLNetTokenizer
from transformers import WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
from transformers import PreTrainedModel, PreTrainedTokenizer , BertPreTrainedModel

from tqdm import tqdm, trange
import pandas as pd
import io
import numpy as np
import matplotlib.pyplot as plt
import codecs
from torch.nn.utils.rnn import pack_padded_sequence
import os

Using TensorFlow backend.


In [3]:
# Use the GPU when one is visible; later cells move batches and the model to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
# Display which accelerator we got. `get_device_name(0)` raises on CPU-only runtimes, so guard it.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"

'Tesla T4'

In [4]:
# Mount Google Drive when running on Colab. Outside Colab the import fails; the files are then
# expected to already exist at the relative paths below.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')

# Relative paths: they resolve through the mount only if the working directory is /content
# (Colab's default) — adjust when running elsewhere.
train_file = 'drive/My Drive/datasets/train.txt'
dev_file = 'drive/My Drive/datasets/dev.txt'
test_file = 'drive/My Drive/datasets/test.txt'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Cased WordPiece tokenizer for the same 'bert-large-cased' checkpoint the model is built from;
# do_lower_case=False keeps capitalisation intact for the cased vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-large-cased', do_lower_case = False)

In [0]:
def read_token_map(file,word_index = 1,prob_index = 4, caseless = False):
  """Read a labelled word-per-line file and build BERT inputs for every sentence.

  Sentences are separated by whitespace-only lines. Every other line is split on
  whitespace: column `word_index` is the word, column `prob_index` its emphasis
  probability. Words are tokenized with the module-level `tokenizer`.

  Args:
    file: path of the UTF-8 input file.
    word_index: column holding the word.
    prob_index: column holding the probability (parsed with `float`).
    caseless: lower-case each word before tokenizing.

  Returns:
    (tokenized_texts, token_map, token_labels, sent_length), one entry per sentence:
      tokenized_texts[k]: WordPiece tokens wrapped in "[CLS]" ... "[SEP]".
      token_map[k][i]: index in tokenized_texts[k] of word i's first sub-token.
      token_labels[k][i]: probability of word i.
      sent_length[k]: number of words in sentence k.
  """
  with codecs.open(file, 'r', 'utf-8') as f:
    lines = f.readlines()

  tokenized_texts = []
  token_map = []
  token_labels = []
  sent_length = []

  bert_tokens = ["[CLS]"]
  orig_to_tok_map = []
  labels = []

  for line in lines:
    if not line.isspace():
      feats = line.strip().split()
      word = feats[word_index].lower() if caseless else feats[word_index]
      labels.append(float(feats[prob_index]))
      # The word is represented by its first sub-token, which lands at the next index.
      orig_to_tok_map.append(len(bert_tokens))

      # The data appears to split contractions ("do" + "n't", "wo" + "n't"). Re-glue them
      # as "don" + "'t" / "won" + "'t" so BERT sees familiar pieces.
      # NOTE(review): bert_tokens[-1] is the previous word's *last* WordPiece, so the glued
      # token may be out of vocabulary — verify.
      if word == "n't":
        word = "'t"
        # len > 1 guard: never turn "[CLS]" into "[CLS]n" when a sentence starts with "n't".
        if len(bert_tokens) > 1 and bert_tokens[-1] != "won":
          bert_tokens[-1] = bert_tokens[-1] + "n"
      if word == "wo":
        word = "won"

      bert_tokens.extend(tokenizer.tokenize(word))

    elif orig_to_tok_map:
      # A blank line closes the current sentence.
      bert_tokens.append("[SEP]")
      tokenized_texts.append(bert_tokens)
      token_map.append(orig_to_tok_map)
      token_labels.append(labels)
      sent_length.append(len(labels))
      bert_tokens = ["[CLS]"]
      orig_to_tok_map = []
      labels = []

  # Flush the final sentence when the file does not end with a blank line.
  if orig_to_tok_map:
    bert_tokens.append("[SEP]")
    tokenized_texts.append(bert_tokens)
    token_map.append(orig_to_tok_map)
    token_labels.append(labels)
    sent_length.append(len(labels))

  return tokenized_texts, token_map, token_labels, sent_length

In [0]:
def read_test_token_map(file,word_index = 1, caseless = False):
  """Read an unlabelled word-per-line file and build BERT inputs for every sentence.

  Sentences are separated by whitespace-only lines; column `word_index` of every other
  line is the word. Words are tokenized with the module-level `tokenizer`.

  Args:
    file: path of the UTF-8 input file.
    word_index: column holding the word.
    caseless: lower-case each word before tokenizing.

  Returns:
    (tokenized_texts, token_map, sent_length), one entry per sentence:
      tokenized_texts[k]: WordPiece tokens wrapped in "[CLS]" ... "[SEP]".
      token_map[k][i]: index in tokenized_texts[k] of word i's first sub-token.
      sent_length[k]: number of words in sentence k.
  """
  with codecs.open(file, 'r', 'utf-8') as f:
    lines = f.readlines()

  tokenized_texts = []
  token_map = []
  sent_length = []

  bert_tokens = ["[CLS]"]
  orig_to_tok_map = []

  for line in lines:
    if not line.isspace():
      feats = line.strip().split()
      word = feats[word_index].lower() if caseless else feats[word_index]
      # The word is represented by its first sub-token, which lands at the next index.
      orig_to_tok_map.append(len(bert_tokens))

      # The data appears to split contractions ("do" + "n't", "wo" + "n't"). Re-glue them
      # as "don" + "'t" / "won" + "'t" so BERT sees familiar pieces.
      # NOTE(review): bert_tokens[-1] is the previous word's *last* WordPiece, so the glued
      # token may be out of vocabulary — verify.
      if word == "n't":
        word = "'t"
        # len > 1 guard: never turn "[CLS]" into "[CLS]n" when a sentence starts with "n't".
        if len(bert_tokens) > 1 and bert_tokens[-1] != "won":
          bert_tokens[-1] = bert_tokens[-1] + "n"
      if word == "wo":
        word = "won"

      bert_tokens.extend(tokenizer.tokenize(word))

    elif orig_to_tok_map:
      # A blank line closes the current sentence.
      bert_tokens.append("[SEP]")
      tokenized_texts.append(bert_tokens)
      token_map.append(orig_to_tok_map)
      sent_length.append(len(orig_to_tok_map))
      bert_tokens = ["[CLS]"]
      orig_to_tok_map = []

  # Flush the final sentence when the file does not end with a blank line.
  if orig_to_tok_map:
    bert_tokens.append("[SEP]")
    tokenized_texts.append(bert_tokens)
    token_map.append(orig_to_tok_map)
    sent_length.append(len(orig_to_tok_map))

  return tokenized_texts, token_map, sent_length

In [8]:
# Parse the three splits and spot-check one sentence from each (train #100, dev #0, test #50).
t_tokenized_texts, t_token_map, t_token_label, t_sent_length = read_token_map(train_file)
for field in (t_tokenized_texts, t_token_map, t_token_label, t_sent_length):
  print(field[100])

d_tokenized_texts, d_token_map, d_token_label, d_sent_length = read_token_map(dev_file)
for field in (d_tokenized_texts, d_token_map, d_token_label, d_sent_length):
  print(field[0])

f_tokenized_texts, f_token_map, f_sent_length = read_test_token_map(test_file)
for field in (f_tokenized_texts, f_token_map, f_sent_length):
  print(field[50])

['[CLS]', 'Happiness', 'consists', 'in', 'realizing', 'it', 'is', 'all', 'a', 'great', 'strange', 'dream', '.', '[SEP]']
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
[0.6666666666666666, 0.1111111111111111, 0.0, 0.2222222222222222, 0.0, 0.1111111111111111, 0.1111111111111111, 0.0, 0.2222222222222222, 0.3333333333333333, 0.3333333333333333, 0.1111111111111111]
12
['[CLS]', 'Life', 'is', 'defined', 'more', 'by', 'its', 'risks', 'than', 'by', 'its', 'same', '##ness', '##es', '.', '[SEP]']
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14]
[0.4444444444444444, 0.1111111111111111, 0.2222222222222222, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 1.0, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.7777777777777778, 0.1111111111111111]
12
['[CLS]', 'In', 'the', 'practice', 'of', 'tolerance', ',', 'one', "'", 's', 'enemy', 'is', 'the', 'best', 'teacher', '.', '[SEP]']
[1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15]
14


In [9]:
MAX_LEN = 72  # fixed length (in WordPiece tokens) every model input is padded/truncated to

# WordPiece tokens -> vocabulary ids.
t_input_ids = list(map(tokenizer.convert_tokens_to_ids, t_tokenized_texts))

# Right-pad with zeros and right-truncate ids, word->token maps and labels to MAX_LEN.
# NOTE(review): sequences longer than MAX_LEN lose their tail, including the trailing "[SEP]".
t_input_ids = pad_sequences(t_input_ids, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")
t_token_map = pad_sequences(t_token_map, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")
t_token_label = pad_sequences(t_token_label, maxlen=MAX_LEN, padding="post", truncating="post", dtype="float")

for padded in (t_input_ids, t_token_map, t_token_label):
  print(padded[100])

[  101 25410  2923  1107 10459  1122  1110  1155   170  1632  4020  4185
   119   102     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0]
[ 1  2  3  4  5  6  7  8  9 10 11 12  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
[0.66666667 0.11111111 0.         0.22222222 0.         0.11111111
 0.11111111 0.         0.22222222 0.33333333 0.33333333 0.11111111
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.     

In [10]:
# Dev split: WordPiece tokens -> vocabulary ids.
d_input_ids = list(map(tokenizer.convert_tokens_to_ids, d_tokenized_texts))

# Right-pad with zeros and right-truncate ids, word->token maps and labels to MAX_LEN.
d_input_ids = pad_sequences(d_input_ids, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")
d_token_map = pad_sequences(d_token_map, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")
d_token_label = pad_sequences(d_token_label, maxlen=MAX_LEN, padding="post", truncating="post", dtype="float")

for padded in (d_input_ids, d_token_map, d_token_label):
  print(padded[0])

[  101  2583  1110  3393  1167  1118  1157 11040  1190  1118  1157  1269
  1757  1279   119   102     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0]
[ 1  2  3  4  5  6  7  8  9 10 11 14  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
[0.44444444 0.11111111 0.22222222 0.11111111 0.11111111 0.11111111
 1.         0.11111111 0.11111111 0.11111111 0.77777778 0.11111111
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.     

In [11]:
# Test split (no labels): WordPiece tokens -> vocabulary ids.
f_input_ids = list(map(tokenizer.convert_tokens_to_ids, f_tokenized_texts))

# Right-pad with zeros and right-truncate ids and word->token maps to MAX_LEN.
f_input_ids = pad_sequences(f_input_ids, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")
f_token_map = pad_sequences(f_token_map, maxlen=MAX_LEN, padding="post", truncating="post", dtype="long")

for padded in (f_input_ids, f_token_map):
  print(padded[50])

[  101  1130  1103  2415  1104 15745   117  1141   112   188  3437  1110
  1103  1436  3218   119   102     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0]
[ 1  2  3  4  5  6  7  8 10 11 12 13 14 15  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]


In [12]:
# Attention masks as nested Python lists: 1.0 where the token id is > 0, 0.0 on padding (id 0).
t_attention_masks = (np.asarray(t_input_ids) > 0).astype(float).tolist()
print(t_attention_masks[100])

d_attention_masks = (np.asarray(d_input_ids) > 0).astype(float).tolist()
print(d_attention_masks[0])

f_attention_masks = (np.asarray(f_input_ids) > 0).astype(float).tolist()
print(f_attention_masks[50])

[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.

In [0]:
# Convert every padded split to torch tensors.
t_input_ids, t_token_map, t_token_label, t_attention_masks, t_sent_length = map(
    torch.tensor, (t_input_ids, t_token_map, t_token_label, t_attention_masks, t_sent_length))
d_input_ids, d_token_map, d_token_label, d_attention_masks, d_sent_length = map(
    torch.tensor, (d_input_ids, d_token_map, d_token_label, d_attention_masks, d_sent_length))
f_input_ids, f_token_map, f_attention_masks, f_sent_length = map(
    torch.tensor, (f_input_ids, f_token_map, f_attention_masks, f_sent_length))

batch_size = 32  # sentences per batch, shared by all three loaders

# Field order per example (the training/evaluation loops index batches by position):
#   train/dev: (ids, token map, labels, attention mask, sentence length)
#   test:      (ids, token map, attention mask, sentence length)
train_data = TensorDataset(t_input_ids, t_token_map, t_token_label, t_attention_masks, t_sent_length)
validation_data = TensorDataset(d_input_ids, d_token_map, d_token_label, d_attention_masks, d_sent_length)
test_data = TensorDataset(f_input_ids, f_token_map, f_attention_masks, f_sent_length)

# Shuffle training batches; keep dev/test in file order so predictions line up with word ids.
train_sampler = RandomSampler(train_data)
validation_sampler = SequentialSampler(validation_data)
test_sampler = SequentialSampler(test_data)

train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size, shuffle = False)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size, shuffle = False)

In [0]:
def read_for_output(file,word_index = 1):
  """Collect each sentence's words and word ids, for writing predictions back out.

  Column 0 of every non-blank line is the word id and column `word_index` the word;
  whitespace-only lines separate sentences.

  Returns:
    (words_lsts, word_ids_lsts): parallel lists holding one list per sentence.
  """
  with codecs.open(file, 'r', 'utf-8') as handle:
    raw_lines = handle.readlines()

  words_lsts, word_ids_lsts = [], []
  cur_words, cur_ids = [], []

  for raw in raw_lines:
    if raw.isspace():
      # Sentence boundary: store the sentence collected so far, if any.
      if cur_words:
        words_lsts.append(cur_words)
        word_ids_lsts.append(cur_ids)
        cur_words, cur_ids = [], []
      continue
    columns = raw.strip().split()
    cur_words.append(columns[word_index])
    cur_ids.append(columns[0])

  # The last sentence may not be followed by a blank line.
  if cur_words:
    words_lsts.append(cur_words)
    word_ids_lsts.append(cur_ids)

  return words_lsts , word_ids_lsts

In [15]:
dev_words, dev_word_ids = read_for_output(dev_file)
test_words, test_word_ids = read_for_output(test_file)

# Spot-check the first dev sentence and test sentence #50: words, then their ids.
for shown in (dev_words[0], dev_word_ids[0], test_words[50], test_word_ids[50]):
  print(shown)

['Life', 'is', 'defined', 'more', 'by', 'its', 'risks', 'than', 'by', 'its', 'samenesses', '.']
['Q_0_0', 'Q_0_1', 'Q_0_2', 'Q_0_3', 'Q_0_4', 'Q_0_5', 'Q_0_6', 'Q_0_7', 'Q_0_8', 'Q_0_9', 'Q_0_10', 'Q_0_11']
['In', 'the', 'practice', 'of', 'tolerance', ',', 'one', "'s", 'enemy', 'is', 'the', 'best', 'teacher', '.']
['Q_50_0', 'Q_50_1', 'Q_50_2', 'Q_50_3', 'Q_50_4', 'Q_50_5', 'Q_50_6', 'Q_50_7', 'Q_50_8', 'Q_50_9', 'Q_50_10', 'Q_50_11', 'Q_50_12', 'Q_50_13']


In [0]:
def intersection(lst1, lst2):
    """Return the items of `lst1` that also occur in `lst2`, in `lst1` order (duplicates kept)."""
    return list(filter(lambda value: value in lst2, lst1))

def fix_padding(scores_numpy, label_probs,  mask_numpy):
    """Trim padded per-sentence rows down to each sentence's real length.

    Args:
        scores_numpy: per-sentence predicted scores (padded rows).
        label_probs: per-sentence gold probabilities, padded the same way.
        mask_numpy: per-sentence lengths — despite the name these are lengths, each
            converted with `int`, not 0/1 masks.

    Returns:
        (scores, labels): lists holding row[:length] for every sentence.
    """
    lengths = [int(n) for n in mask_numpy]
    trimmed_scores = [scores_numpy[row][:n] for row, n in enumerate(lengths)]
    trimmed_labels = [label_probs[row][:n] for row, n in enumerate(lengths)]

    assert len(trimmed_scores) == len(trimmed_labels)
    return trimmed_scores, trimmed_labels

def match_M(batch_scores_no_padd, batch_labels_no_pad, top_m = (1, 2, 3, 4)):
    """Accumulate Match_m overlap statistics for one batch.

    For each m in `top_m`, every sentence with more than m words contributes the fraction
    of its m highest-scoring predicted words that are also among its top gold words.
    Gold ties at the m-th position are all kept (the gold set may grow up to
    len(sentence) - 1 words), so an arbitrary tie-break is not penalised.

    Args:
        batch_scores_no_padd: per-sentence predicted scores, already trimmed to length.
        batch_labels_no_pad: per-sentence gold probabilities, aligned with the scores.
        top_m: the m values to evaluate (defaults to 1..4).

    Returns:
        (batch_num_m, batch_score_m): per m, the number of sentences scored and the sum
        of their overlap fractions.
    """
    batch_num_m = []
    batch_score_m = []
    for m in top_m:
        # Predicted top-m word indices (stable sort, so ties are broken by position).
        score_lst = []
        for s in batch_scores_no_padd:
            if len(s) <= m:
                continue
            s = np.asarray(s)
            score_lst.append(sorted(range(len(s)), key=lambda idx: s[idx])[-m:])

        # Gold top-m word indices, widened to include every value tied with the m-th.
        label_lst = []
        for l in batch_labels_no_pad:
            if len(l) <= m:
                continue
            l = np.asarray(l)
            order = np.argsort(l)  # computed once rather than twice per loop iteration
            h = m
            while h < len(l) - 1 and l[order[-h]] == l[order[-(h + 1)]]:
                h += 1
            label_lst.append(order[-h:])

        # Both lists skip exactly the same (short) sentences only if the inputs are aligned.
        assert len(score_lst) == len(label_lst), "scores and labels are not aligned"

        overlaps = []
        for predicted, gold in zip(score_lst, label_lst):
            hits = intersection(predicted, gold)
            overlaps.append(len(hits) / min(m, len(predicted)))

        batch_num_m.append(len(score_lst))
        batch_score_m.append(sum(overlaps))
    return batch_num_m, batch_score_m

In [0]:
def test(model):
  """Predict emphasis scores for the test split and render them as text.

  Iterates `test_dataloader` in order, so the k-th prediction row lines up with
  `test_words[k]` / `test_word_ids[k]`.

  Returns:
    A string with one "<word_id>\\t<word>\\t<score>\\t" line per word; each sentence is
    preceded by an empty line and one extra newline is appended after every batch.
  """
  print("")
  print("Running test...")

  model.eval()

  out = ""
  sent_idx = 0

  for batch in test_dataloader:
    batch = tuple(t.to(device) for t in batch)
    # Test dataset order: ids, token map, attention mask, sentence length.
    ids, token_starts, mask, lengths = batch[:4]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
      output = model(ids, mask, token_starts, lengths)
    scores = output[1].detach().cpu().numpy()

    for row in range(ids.size()[0]):
      rows = ["{}\t{}\t{}\t".format(test_word_ids[sent_idx][j], test_words[sent_idx][j], scores[row][j]) + "\n"
              for j in range(len(test_words[sent_idx]))]
      if rows:
        out = out + "\n" + "".join(rows)
      sent_idx += 1
    out = out + "\n"

  print("testing complete\n")
  return out

In [0]:
def validation(model):
  """Score the dev split with Match_m and render per-word predictions.

  Iterates `validation_dataloader` in order, so the k-th prediction row lines up with
  `dev_words[k]` / `dev_word_ids[k]`.

  Returns:
    (v_score, s): v_score is the mean of the per-m Match scores (buckets with no
    eligible sentence are NaN and excluded from the mean); s has one
    "<id>\\t<word>\\t<pred>\\t<gold>" line per word, each sentence preceded by an empty
    line and one extra newline after every batch.
  """
  print("")
  print("Running Validation...")

  model.eval()

  num_m = [0, 0, 0, 0]
  score_m = [0, 0, 0, 0]

  sent_idx = 0
  s = ""

  for batch in validation_dataloader:
    batch = tuple(t.to(device) for t in batch)
    # Dev dataset order: ids, token map, labels, attention mask, sentence length.
    v_input_ids, v_token_starts, v_labels, v_input_mask, v_sent_length = batch[:5]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
      output = model(v_input_ids, v_input_mask, v_token_starts, v_sent_length, v_labels)

    pred_labels = output[1].detach().cpu().numpy()
    v_labels = v_labels.to('cpu').numpy()

    for i in range(v_input_ids.size()[0]):
      rows = ["{}\t{}\t{}\t{}".format(dev_word_ids[sent_idx][j], dev_words[sent_idx][j], pred_labels[i][j], v_labels[i][j]) + "\n"
              for j in range(len(dev_words[sent_idx]))]
      if rows:
        s = s + "\n" + "".join(rows)
      sent_idx += 1
    s = s + "\n"

    pred_labels, v_labels = fix_padding(pred_labels, v_labels, v_sent_length)
    batch_num_m, batch_score_m = match_M(pred_labels, v_labels)
    num_m = [a + b for a, b in zip(num_m, batch_num_m)]
    score_m = [a + b for a, b in zip(score_m, batch_score_m)]

  # A bucket where no sentence has more than m words has no score: report NaN instead of
  # raising ZeroDivisionError, and average only the defined buckets.
  m_score = [total / count if count else float('nan') for total, count in zip(score_m, num_m)]

  print("Validation Accuracy: ")
  print(m_score)
  defined = [x for x in m_score if not np.isnan(x)]
  v_score = np.mean(defined) if defined else float('nan')
  print(v_score)

  return v_score, s

In [0]:
max_accuracy = 0   # best dev score seen so far; updated by `train`
val_out = ""       # dev prediction text captured at that best score
test_out = ""      # test prediction text captured at that best score

def train(model,  optimizer, scheduler, tokenizer, max_epochs, save_path, device, val_freq = 10):
  """Fine-tune `model`, validating every `val_freq` batches and keeping the best outputs.

  Whenever the dev score beats the global `max_accuracy`, the dev text and fresh test
  predictions are stored in the globals `val_out` / `test_out`.

  Args:
    model: module returning (loss, predictions) when called with labels.
    optimizer, scheduler: both stepped once per training batch.
    tokenizer: unused; kept for call compatibility.
    max_epochs: number of passes over `train_dataloader`.
    save_path: directory created for checkpoints (saving is currently disabled).
    device: device the batches are moved to.
    val_freq: validate every `val_freq` batches, including step 0 of each epoch;
      0 or None disables periodic validation.
  """
  global max_accuracy, val_out, test_out

  os.makedirs(save_path, exist_ok=True)

  for epoch_i in range(max_epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, max_epochs))
    print('Training...')

    total_loss = 0

    for step, batch in enumerate(train_dataloader):
      print("batch", step, "out of", len(train_dataloader))
      # Train dataset order: ids, token map, labels, attention mask, sentence length.
      b_input_ids = batch[0].to(device)
      b_token_starts = batch[1].to(device)
      b_labels = batch[2].to(device)
      b_input_mask = batch[3].to(device)
      b_sent_length = batch[4]

      model.zero_grad()
      # `validation`/`test` switch the model to eval mode, so re-enable training mode each step.
      model.train()

      output = model(b_input_ids, b_input_mask, b_token_starts, b_sent_length, b_labels)
      loss = output[0]
      total_loss += loss.item()

      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      optimizer.step()
      scheduler.step()

      # Honour `val_freq` (previously a hard-coded 10 made the parameter a no-op).
      if val_freq and step % val_freq == 0:
        accuracy, outs = validation(model)
        if accuracy > max_accuracy:
          max_accuracy = accuracy
          val_out = outs
          test_out = test(model)
          # Checkpointing is disabled; `model.save_pretrained(save_path)` would persist this best model.

    avg_train_loss = total_loss / max(len(train_dataloader), 1)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))

  print("")
  print("Training complete!")

In [0]:
class transformer_model(nn.Module):
  """BERT encoder plus a 3-layer MLP head that scores each word's emphasis probability.

  The hidden states returned by BERT (embeddings and every layer) are concatenated per
  token and passed through fc1 -> fc2 -> fc3 with dropout and ReLU; a sigmoid maps the
  output to (0, 1). Each word takes the score of its first sub-token.
  """

  def __init__(self, model_name, drop_prob = 0.3):
    super(transformer_model, self).__init__()

    config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
    self.bert = BertForMaskedLM.from_pretrained(model_name, config = config)

    # To freeze encoder layers, set requires_grad = False on the parameters of the
    # desired children of `self.bert.bert`.

    # One hidden state per layer plus the embeddings are concatenated
    # (25 * 1024 for bert-large, the previously hard-coded value).
    bert_dim = (config.num_hidden_layers + 1) * config.hidden_size
    hidden_dim1 = 950
    hidden_dim2 = 40
    final_size = 1

    self.fc1 = nn.Linear(bert_dim, hidden_dim1)
    self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
    self.fc3 = nn.Linear(hidden_dim2, final_size)
    self.dropout = nn.Dropout(p=drop_prob)

  def avg(self, a, st, end):
    """Mean of a[st], ..., a[end - 1] along dim 0 (sub-token averaging helper; unused)."""
    return torch.mean(torch.stack([a[i] for i in range(st, end)]), dim=0)

  def save_pretrained(self, output_dir):
    """Save the BERT weights/config and the MLP head weights to `output_dir`."""
    self.bert.save_pretrained(output_dir)
    # The head is not part of the Hugging Face model, so store it alongside.
    torch.save({name: getattr(self, name).state_dict() for name in ("fc1", "fc2", "fc3")},
               os.path.join(output_dir, "head.pt"))

  def forward(self, bert_ids, bert_mask, bert_token_starts, lm_lengths = None, labels = None):
    """Score every word of every sentence in the batch.

    Args:
      bert_ids: token ids, (batch, pad_size).
      bert_mask: attention mask, same shape as `bert_ids`.
      bert_token_starts: (batch, n_words) index of each word's first sub-token; 0 = padding.
      lm_lengths: (batch,) words per sentence, used only with `labels`; defaults to the
        number of non-zero entries of `bert_token_starts`. May live on any device.
      labels: optional (batch, n_words) gold probabilities.

    Returns:
      (loss, pred_labels): pred_labels is float64 with the shape of `bert_token_starts`,
      0 for padded words and for words whose first sub-token was truncated away
      (index >= pad_size). loss is the mean BCE over the first lm_lengths[b] words of
      each sentence, or 0.0 when `labels` is None.
    """
    pad_size = bert_ids.size(1)

    output = self.bert(bert_ids, attention_mask = bert_mask)
    # output[1] is the tuple of hidden states, each (batch, pad_size, hidden).
    bert_out = torch.cat(tuple(output[1]), dim=2)

    pred_logits = torch.relu(self.fc1(self.dropout(bert_out)))
    pred_logits = torch.relu(self.fc2(self.dropout(pred_logits)))
    pred_logits = torch.sigmoid(self.fc3(self.dropout(pred_logits)))
    pred_logits = torch.squeeze(pred_logits, 2)   # (batch, pad_size)

    # Gather each word's first-sub-token score in one vectorized step. Out-of-range starts
    # are redirected to index 0 for the gather and then zeroed by `valid`.
    starts = bert_token_starts.long()
    valid = (starts != 0) & (starts < pad_size)
    picked = pred_logits.gather(1, torch.where(valid, starts, torch.zeros_like(starts))).double()
    pred_labels = torch.where(valid, picked, torch.zeros_like(picked))

    if labels is None:
      return 0.0, pred_labels

    if lm_lengths is None:
      lm_lengths = (starts != 0).sum(dim=1)

    # Select the first lm_lengths[b] words of each sentence: the same elements
    # pack_padded_sequence picked, without needing sorted CPU lengths.
    lengths = torch.as_tensor(lm_lengths, device=pred_labels.device)
    positions = torch.arange(pred_labels.size(1), device=pred_labels.device)
    keep = positions.unsqueeze(0) < lengths.unsqueeze(1)

    loss = nn.functional.binary_cross_entropy(pred_labels[keep], labels[keep].to(pred_labels.dtype))
    return loss, pred_labels

In [0]:
# Build the BERT-large-cased encoder + MLP head and move all its parameters to `device`.
model = transformer_model('bert-large-cased').to(device)

In [0]:
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. weight_decay=0.0 matches
# the transformers default (torch's default would be 0.01); both apply bias correction.
# The only difference is where eps enters the update, which is negligible at eps=1e-8.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)

epochs = 30
total_steps = len(train_dataloader) * epochs  # the scheduler is stepped once per training batch

# Linear decay of the learning rate to 0 over all training steps, with no warm-up.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

In [23]:
# Output directory for checkpoints; `train` creates it if it does not exist.
save_path = 'drive/My Drive/datasets/results/bert_large/'
train(model,  optimizer, scheduler, tokenizer, epochs, save_path, device)


Training...
batch 0 out of 86

Running Validation...
Validation Accuracy: 
[0.4005102040816326, 0.5464190981432361, 0.6268939393939393, 0.6727272727272727]
0.5616376285865202

Running test...
testing complete
batch 1 out of 86
batch 2 out of 86
batch 3 out of 86
batch 4 out of 86
batch 5 out of 86
batch 6 out of 86
batch 7 out of 86
batch 8 out of 86
batch 9 out of 86
batch 10 out of 86

Running Validation...
Validation Accuracy: 
[0.4897959183673469, 0.7082228116710876, 0.7803030303030304, 0.8083333333333333]
0.6966637734186996

Running test...
testing complete
batch 11 out of 86
batch 12 out of 86
batch 13 out of 86
batch 14 out of 86
batch 15 out of 86
batch 16 out of 86
batch 17 out of 86
batch 18 out of 86
batch 19 out of 86
batch 20 out of 86

Running Validation...
Validation Accuracy: 
[0.5229591836734694, 0.7148541114058355, 0.7888257575757577, 0.8196969696969697]
0.7115840055880082

Running test...
testing complete
batch 21 out of 86
batch 22 out of 86
batch 23 out of 86
batc

KeyboardInterrupt: ignored

In [24]:
# Best dev score reached, then the dev and test prediction texts captured at that point.
print(max_accuracy ,val_out, test_out)

0.7896076009376503 
Q_0_0	Life	0.3863700330257416	0.4444444444444444
Q_0_1	is	0.05237560346722603	0.1111111111111111
Q_0_2	defined	0.21343442797660828	0.2222222222222222
Q_0_3	more	0.06619858741760254	0.1111111111111111
Q_0_4	by	0.0400894470512867	0.1111111111111111
Q_0_5	its	0.10197929292917252	0.1111111111111111
Q_0_6	risks	0.7404477000236511	1.0
Q_0_7	than	0.045847706496715546	0.1111111111111111
Q_0_8	by	0.027160512283444405	0.1111111111111111
Q_0_9	its	0.087177574634552	0.1111111111111111
Q_0_10	samenesses	0.3688516318798065	0.7777777777777778
Q_0_11	.	0.17192846536636353	0.1111111111111111

S_1_0	There	0.055736299604177475	0.2222222222222222
S_1_1	is	0.10482457280158997	0.2222222222222222
S_1_2	magic	0.7936208248138428	0.8888888888888888
S_1_3	in	0.249949112534523	0.3333333333333333
S_1_4	the	0.2007422298192978	0.3333333333333333
S_1_5	night	0.40127649903297424	0.4444444444444444
S_1_6	when	0.05057312548160553	0.1111111111111111
S_1_7	pumpkins	0.5406239628791809	0.3333333333333333