# BERT-based Classifier with BERT Fine-tuning

This notebook builds a BERT model that feeds into a token classifier, and fine-tunes the whole stack (BERT plus classifier) jointly.

In [34]:
import torch
import random
import os
import nltk
import time
import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from nltk import tokenize
from transformers import BertTokenizer, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
from transformers import BertModel, AdamW, BertConfig
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

****

# Data Preparation and Cleaning

This section focuses on taking in data and cleaning it. Notable functions are those that load in raw data, the option to delete sentences that don't have any non-outside tokens, and the option to delete all sentences above a certain sentence length. 

In [2]:
def _clean_tag_column(series, fill_value, n_tokens):
    """Return the first n_tokens cells of a tag column as upper-cased strings
    with missing cells replaced by fill_value and all spaces removed."""
    cleaned = []
    for value in series.values[:n_tokens]:
        if pd.isna(value):
            value = fill_value
        cleaned.append(value.replace(" ", "").upper())
    return cleaned


def _combine_tags(besio_tag, mol_tag, io_tag):
    """Merge the BESIO / MOL-or-PRO / inorganic-or-organic tags of one token
    into exactly one label string."""
    # Outside tokens never carry a type, whatever the other columns say.
    if besio_tag == 'O':
        return 'O'
    # Properties have no organic/inorganic flavour: <BESIO>-PRO
    if mol_tag == 'PRO':
        if io_tag in ('I', 'O'):
            print("Weird. This Property is organic or inorganic? LOL")
        return besio_tag + '-' + mol_tag
    # Everything else needs the organic/inorganic flag: <BESIO>-<MOL>-<I|O>
    if io_tag in ('I', 'O'):
        return besio_tag + '-' + mol_tag + '-' + io_tag
    # Anything left is a labelling error in the sheet. Emit 'O' so tokens and
    # labels stay aligned, and tell the user which combination was rejected.
    print("Unrecognised label combination (BESIO=%r, MOL=%r, I/O=%r); using 'O'"
          % (besio_tag, mol_tag, io_tag))
    return 'O'


def extract_xy(df):
    """
    This method extracts and correctly arranges the NER training x-values (tokens)
    and y-values (BESIO labels) from a pandas dataframe containing labeled NER
    data

    Every column whose header starts with 'name' and whose first cell is a
    string marks one document; the four columns to its right are read as
    tokens, BESIO tag, MOL/PRO tag and inorganic/organic flag respectively.

    Parameters:
        df (pandas DataFrame, required): Dataframe loaded via pd.read_excel() on
            a labeled NER dataset. It is not modified.

    Returns:
        two lists of identical shape. One contains all the tokens for labeling
        and the other contains all the labels of those tokenized words (exactly
        one label per token). Note that these tokens are not BERT tokens as yet.
        They still need to be fed into a BERT tokenizer.
    """
    columns = df.columns
    all_tokens = []
    label_values = []

    for idx, column in enumerate(columns):
        # find every column that starts with 'name'
        if not (isinstance(column, str) and column.startswith('name')):
            continue
        # check if the entry in 'name' cell is a str
        if not isinstance(df[column][0], str):
            continue

        tokens = df[columns[idx + 1]].values
        # Token columns are NaN-padded below the last token; the number of str
        # entries is used as the document length.
        n_tokens = sum(1 for entry in tokens if type(entry) == str)
        all_tokens.append(tokens[:n_tokens])

        # Work on cleaned copies so the caller's dataframe is left untouched.
        besio = _clean_tag_column(df[columns[idx + 2]], 'O', n_tokens)
        mol = _clean_tag_column(df[columns[idx + 3]], '', n_tokens)
        in_or_org = _clean_tag_column(df[columns[idx + 4]], '', n_tokens)

        label_values.append([
            _combine_tags(besio_tag, mol_tag, io_tag)
            for besio_tag, mol_tag, io_tag in zip(besio, mol, in_or_org)
        ])

    return all_tokens, label_values

In [3]:
def tokenized_to_string(token_list):
    """Helper that turns each document's token list back into a single string.

    Every token is converted with str() and followed by one space, so each
    returned string ends with a trailing space (an empty document gives "").
    Returns one string per document, in the same order as token_list."""
    return [
        "".join(str(token) + " " for token in paper_tokens)
        for paper_tokens in token_list
    ]

In [4]:
def labeled_sheets_to_listed_tokens(directory_url):
    """Load every labeled NER excel sheet in a directory and split it into sentences.

    Parameters:
        directory_url (str, required): directory holding the labeled sheets.
            A trailing path separator is optional.

    Returns:
        token_list: one entry per document, each a list of sentence strings
            (as produced by NLTK's sent_tokenize).
        sent_labels: same nesting as token_list; for every sentence, the list
            of labels of its whitespace-separated words.

    Files are read in sorted name order so the result is reproducible. Hidden
    files ('.DS_Store' etc.), Excel lock files ('~$...') and sub-directories
    are skipped, since pd.read_excel cannot open them.
    """
    token_list = []
    label_list = []
    for file in sorted(os.listdir(directory_url)):
        path = os.path.join(directory_url, file)
        if file.startswith(('.', '~$')) or not os.path.isfile(path):
            continue
        df = pd.read_excel(path)
        token, label = extract_xy(df)
        token_list += tokenized_to_string(token)
        label_list += label

    #Now we tokenize each paper by sentences using NLTK:
    #we will also restructure labels to be ordered by sentences.
    sent_labels = []
    for i in range(len(token_list)):
        sentences = tokenize.sent_tokenize(token_list[i])
        token_list[i] = sentences
        short_term_labels = []
        # Walk through the document's flat label list, handing each sentence
        # as many labels as it has whitespace-separated words.
        offset = 0
        for sentence in sentences:
            length = len(sentence.split())
            short_term_labels.append(label_list[i][offset:offset + length])
            offset += length
        sent_labels.append(short_term_labels)
    return token_list, sent_labels

In [12]:
def drop_empty_sentences(token_list, label_list, label_dict):
    """Undersample by removing every sentence that carries none of the non-'O' labels.

    Whole sentences are dropped (rather than single tokens) because BERT needs
    full sentences for context. label_dict supplies the label vocabulary; all of
    its keys except 'O' count as labels worth keeping a sentence for.

    token_list and label_list are modified in place and also returned. The list
    of removed [document, sentence] index pairs and the total are printed."""
    wanted_labels = list(label_dict.keys())
    if 'O' in wanted_labels:
        wanted_labels.remove('O')

    # First pass: collect the positions of sentences with nothing worth keeping.
    doomed = []
    for doc_idx, doc_tokens in enumerate(token_list):
        for sent_idx in range(len(doc_tokens)):
            sentence_labels = label_list[doc_idx][sent_idx]
            if not any(tag in sentence_labels for tag in wanted_labels):
                doomed.append([doc_idx, sent_idx])
    print(doomed)

    # Second pass: delete back-to-front so earlier indices stay valid.
    removed = 0
    for doc_idx, sent_idx in reversed(doomed):
        del label_list[doc_idx][sent_idx]
        del token_list[doc_idx][sent_idx]
        removed += 1
    print("Total deleted sentences = " + str(removed))
    return token_list, label_list

In [13]:
def max_encoded_length(token_list, tokenizer = None):
    """Determine the BERT-encoded length of every sentence in the corpus, so we
    can set an appropriate maximum length for our BERT model.

    Parameters:
        token_list: list of documents, each a list of sentence strings.
        tokenizer (optional): object with an encode(text, add_special_tokens=...)
            method. Defaults to the 'bert-base-uncased' BertTokenizer.

    Returns:
        (max_len, len_list): the longest encoded length (0 for an empty corpus)
        and the encoded length of every sentence in corpus order. Lengths
        include the special tokens the tokenizer adds.
    """
    if tokenizer is None:
        #Instantiate a tokenizer from the BERT Tokenizer class
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)

    len_list = [
        len(tokenizer.encode(sentence, add_special_tokens = True))
        for papers in token_list
        for sentence in papers
    ]
    max_len = max(len_list, default = 0)
    return max_len, len_list

In [14]:
def delete_sentences(token_list, label_list, print_pop = False, max_length = 150, tokenizer = None):
    """Delete every sentence whose BERT-encoded length is greater than max_length.

    Parameters:
        token_list: list of documents, each a list of sentence strings.
        label_list: per-sentence labels with the same nesting as token_list.
        print_pop (bool): if True, print each removed sentence and its labels.
        max_length (int): longest allowed encoded length, special tokens included.
        tokenizer (optional): object with an encode(text, add_special_tokens=...)
            method. Defaults to the 'bert-base-uncased' BertTokenizer.

    Returns:
        (token_list, label_list): the same list objects that were passed in,
        with the over-long sentences removed in place.
    """
    if tokenizer is None:
        #Instantiate a tokenizer from the BERT Tokenizer class
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)

    for doc_idx, doc_tokens in enumerate(token_list):
        too_long = []
        for sent_idx, sentence in enumerate(doc_tokens):
            encoded_length = len(tokenizer.encode(sentence, add_special_tokens = True))
            if encoded_length > max_length:
                print("Found item length: " + str(encoded_length))
                too_long.append(sent_idx)
        #Gotta count backwards so we don't disturb the list structure
        for sent_idx in reversed(too_long):
            removed_tokens = doc_tokens.pop(sent_idx)
            removed_labels = label_list[doc_idx].pop(sent_idx)
            if print_pop:
                print(removed_tokens)
                print(removed_labels)
    return token_list, label_list

In [15]:
def tokenize_and_align_lables(list_of_tokens, list_of_labels, encoding_dict, max_sent_length = 155, tokenizer = None):
    """BERT-tokenize every sentence while keeping the word labels aligned.

    BERT breaks words into subwords, so each word's label is repeated once per
    subword. Every label row starts with an integer 0 for the [CLS] position
    and is padded with 'PAD' (which also covers [SEP]) up to max_sent_length.
    Rows that would be longer are cut so they match the encoded length.

    Parameters:
        list_of_tokens: list of documents, each a list of sentence strings.
        list_of_labels: same nesting; one label per whitespace-separated word.
        encoding_dict (dict): maps label strings (including 'PAD') to integers.
        max_sent_length (int): padded/truncated length of every encoded sentence.
        tokenizer (optional): object providing encode_plus(...) and
            tokenize(word). Defaults to the 'bert-base-uncased' BertTokenizer.

    Returns:
        (input_ids, attention_masks, labels): three tensors with one row per
        sentence.

    Raises:
        ValueError: if a label (or 'PAD') is missing from encoding_dict.
    """
    if tokenizer is None:
        #Instantiate a tokenizer from the BERT Tokenizer class
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)

    def _encode_label(label):
        # Known labels go through the lookup table; the integer [CLS]
        # placeholder (and any other int) passes through unchanged.
        if label in encoding_dict:
            return encoding_dict[label]
        if isinstance(label, int):
            return label
        raise ValueError("Label %r is missing from encoding_dict" % (label,))

    input_ids = []
    attention_masks = []
    label_shapes = []

    for abstracts, abst_labels in zip(list_of_tokens, list_of_labels):
        for sentences, sent_labels in zip(abstracts, abst_labels):
            encoded_dict = tokenizer.encode_plus(
                                        sentences,
                                        add_special_tokens = True,
                                        max_length = max_sent_length,
                                        pad_to_max_length = True,
                                        return_attention_mask = True,
                                        return_tensors = 'pt'
            )
            #The CLS position gets an int 0, to keep label
            #length matching consistent with the tokenized sentence
            extend_sent_labels = [0]

            #Labels are per whitespace-separated word, so the sentence string
            #must be split into words before pairing it with its labels.
            for word, label in zip(sentences.split(), sent_labels):
                #Find out how many chunks each word gets broken into
                n_subwords = len(tokenizer.tokenize(word))
                #Put label in brackets so it knows you want n_subwords entries
                #of label, not label times n_subwords
                extend_sent_labels.extend([label]*n_subwords)

            #Keep at most max_sent_length - 1 labels so the final slot (SEP on
            #a truncated sentence) is always a PAD label, then pad to full length.
            extend_sent_labels = extend_sent_labels[:max_sent_length - 1]
            extend_sent_labels.extend(['PAD']*(max_sent_length - len(extend_sent_labels)))

            #Use the dictionary lookup to turn every label into its number.
            encoded_labels = [_encode_label(label) for label in extend_sent_labels]

            #Build our attention mask, labels, and input ids of each item.
            label_shapes.append(torch.tensor([encoded_labels]))
            input_ids.append(encoded_dict['input_ids'])
            attention_masks.append(encoded_dict['attention_mask'])

    #Make lists we just built into tensors
    input_ids = torch.cat(input_ids, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)
    labels = torch.cat(label_shapes, dim = 0)

    print("Original Sentence: ", list_of_tokens[0][0])
    print("Tokenized IDs: ", input_ids[0])
    print("Extended Labels: ", labels[0])

    return input_ids, attention_masks, labels

## Calling the above functions:

The next set of cells calls all of the functions above, taking the data to the point where it is ready to be fed into a BERT model for training.

In [10]:
#Load all of the data into two lists, one for tokens, one for labels.
#The sheet directory can be overridden with the LABELED_SHEETS_DIR environment
#variable so the notebook is not tied to one machine; the original location
#stays as the default. os.path.join(..., '') guarantees a trailing separator,
#so the path works even where directory and filename are simply concatenated.
dir_url = os.path.join(
    os.environ.get('LABELED_SHEETS_DIR',
                   '/Users/Jonathan/Desktop/LabeledChemEData/Labeled_Sheets/'),
    ''
)
list_o_tokens, list_o_labels = labeled_sheets_to_listed_tokens(dir_url)

Weird. This Property is organic or inorganic? LOL
Weird. This Property is organic or inorganic? LOL


In [16]:
#Compute the longest BERT-encoded sentence length in the corpus, plus the
#encoded length of every individual sentence. The full list of lengths helps
#when choosing the max_length cut-off for the delete_sentences function.
max_len_sent, list_of_all_lengths = max_encoded_length(list_o_tokens)

(509,
 [51,
  27,
  25,
  6,
  30,
  19,
  32,
  136,
  28,
  34,
  52,
  10,
  47,
  27,
  22,
  48,
  10,
  46,
  30,
  38,
  32,
  33,
  45,
  32,
  65,
  19,
  18,
  40,
  49,
  26,
  37,
  15,
  27,
  3,
  43,
  181,
  18,
  28,
  34,
  39,
  26,
  30,
  9,
  39,
  28,
  50,
  29,
  40,
  26,
  78,
  32,
  56,
  46,
  35,
  35,
  11,
  51,
  20,
  24,
  22,
  21,
  30,
  33,
  30,
  17,
  12,
  20,
  30,
  34,
  40,
  45,
  21,
  34,
  14,
  13,
  17,
  30,
  41,
  48,
  11,
  18,
  28,
  31,
  31,
  10,
  36,
  14,
  26,
  46,
  23,
  31,
  33,
  33,
  35,
  4,
  35,
  35,
  38,
  26,
  47,
  43,
  10,
  34,
  27,
  41,
  41,
  10,
  53,
  35,
  33,
  40,
  13,
  18,
  26,
  16,
  3,
  49,
  38,
  53,
  69,
  52,
  28,
  37,
  31,
  33,
  56,
  20,
  29,
  42,
  23,
  83,
  41,
  29,
  23,
  10,
  29,
  34,
  25,
  34,
  30,
  33,
  87,
  22,
  35,
  48,
  40,
  57,
  37,
  38,
  33,
  51,
  46,
  49,
  49,
  7,
  83,
  26,
  29,
  28,
  57,
  60,
  28,
  3,
  45,
  30,
  58,
  4

In [None]:
#Quick plot to show approx dist of sentence lengths in our corpus.
#Title and axis labels let the figure stand on its own; the trailing
#semicolon suppresses the text repr of the last call.
fig, axs = plt.subplots()
axs.hist(list_of_all_lengths, bins = 20)
axs.set_title("Distribution of BERT-encoded sentence lengths")
axs.set_xlabel("Encoded sentence length (tokens)")
axs.set_ylabel("Number of sentences");

In [17]:
#Delete sentences that are too long (default cut-off: max_length = 150).
#NOTE: delete_sentences pops from the lists it is given and returns those same
#lists, so short_tokens / short_labels are list_o_tokens / list_o_labels
#themselves after pruning, not copies.
short_tokens, short_labels = delete_sentences(list_o_tokens, list_o_labels)

Found item length: 181
Found item length: 189
Found item length: 172
Found item length: 270
Found item length: 178
Found item length: 191
Found item length: 171
Found item length: 151
Found item length: 235
Found item length: 161
Found item length: 509
Found item length: 174
Found item length: 207
Found item length: 248
Found item length: 180
Found item length: 155
Found item length: 165
Found item length: 173


In [18]:
#Coarse label scheme: every molecule tag collapses to class 1 and every property
#tag to class 2, regardless of BESIO position or organic/inorganic flavour.
#Long term we should absolutely consider going back to the individual labels
#(that might need more data than we currently have), and eventually separate
#organic and inorganic chemicals once we get a high-performing model.
small_label_mapping = {'O': 3}
#Molecules: organic ('O') first, then inorganic ('I'), each in B/I/E/S order
small_label_mapping.update(
    {position + "-MOL-" + flavour: 1 for flavour in "OI" for position in "BIES"}
)
#Properties in B/I/E/S order
small_label_mapping.update({position + "-PRO": 2 for position in "BIES"})
#Padding gets class 0
small_label_mapping['PAD'] = 0

In [20]:
#BERT-tokenize the pruned sentences and build the matching attention masks and
#integer label rows, using the coarse label mapping defined above.
input_tokens, attention_mask, input_labels = tokenize_and_align_lables(short_tokens, short_labels, small_label_mapping)

Original Sentence:  In the interaction between gas molecules with single-walled carbon nanotube (SWCNT) we show that as a result of collisions the gas scattering contributes with an important background signal and should be considered in SWCNT-based gas sensors.
Tokenized IDs:  tensor([  101,  1999,  1996,  8290,  2090,  3806, 10737,  2007,  2309,  1011,
        17692,  6351, 28991, 28251,  2063,  1006, 25430,  2278,  3372,  1007,
         2057,  2265,  2008,  2004,  1037,  2765,  1997, 28820,  1996,  3806,
        17501, 16605,  2007,  2019,  2590,  4281,  4742,  1998,  2323,  2022,
         2641,  1999, 25430,  2278,  3372,  1011,  2241,  3806, 13907,  1012,
          102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     

In [22]:
#With the tokenized tensors in hand, wrap them in a dataset and serve them in
#batches to save on memory. If your machine struggles with this, drop the batch
#size to 16: performance will likely suffer, but that beats not running at all!
batch_size = 32
dataset = TensorDataset(input_tokens, attention_mask, input_labels)
#shuffle = True makes the loader draw samples through a RandomSampler
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True)

In [28]:
dataset = TensorDataset(input_tokens, attention_mask, input_labels)

#Do our 90/10 training/validation split
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

#Seed torch's global RNG first so the random split (and therefore every
#downstream metric) is reproducible from a fresh kernel.
torch.manual_seed(42)

#Now do a train_val split, randomly
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))

1,144 training samples
  128 validation samples


In [31]:
#Batch size for both loaders. The paper suggests 16 or 32; go bigger if the
#machine can take it.
batch_size = 32

#Training batches are drawn in random order (shuffle = True -> RandomSampler)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

#Validation order is irrelevant, so read it sequentially
#(shuffle = False -> SequentialSampler)
validation_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)


****

# BERT and Classifier Model Construction

Here we build up a BERT model from the Hugging Face BERT base checkpoint, stacking it together with a neural-network classifier. We'll start with a single linear layer just to get going; a 3-layer classifier seems a likely place for us to end up.

In [38]:
#Load the bare pretrained BERT encoder (no classification head) for inspection.
model = BertModel.from_pretrained(
        "bert-base-uncased",
        output_attentions = False, #Whether model returns attention weights
        output_hidden_states = False, #Whether model outputs all hidden states
        )

In [39]:
#Display the module tree of the loaded encoder
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [77]:
class BERTplus1Layer(nn.Module):
    """Pretrained BERT encoder followed by dropout and one linear token classifier.

    The whole stack is trainable, so calling backward() on the returned loss
    fine-tunes BERT together with the classifier.

    There's a more general version of this on the huggingface website:
    https://huggingface.co/transformers/_modules/transformers/modeling_bert.html#BertForTokenClassification

    Parameters:
        p (float): dropout probability applied to BERT's token embeddings
            before classification.
        num_labels (int): number of output classes. Class 0 is treated as the
            padding class and never contributes to the loss.
    """
    def __init__(self, p = 0.1, num_labels = 4):
        super(BERTplus1Layer, self).__init__()
        self.bert_segment = BertModel.from_pretrained(
            'bert-base-uncased',
            output_attentions = False,
            output_hidden_states = False
        )

        self.classifier_dropout = nn.Dropout(p = p)
        #Input width follows the encoder's hidden size (768 for bert-base)
        self.classifier1 = nn.Linear(self.bert_segment.config.hidden_size, num_labels)

    def _weighted_token_loss(self, logits, labels, attention_mask):
        """Class-balanced cross entropy over the attended, non-padding tokens.

        Class 0 gets weight 0; every other class present in the batch gets a
        weight inversely proportional to its token count, so rare classes are
        not drowned out by the dominant one. Returns a scalar tensor; it is a
        zero that still supports backward() when no token qualifies."""
        num_labels = logits.size(-1)
        flat_logits = logits.reshape(-1, num_labels)
        flat_labels = labels.reshape(-1)

        #Only positions the attention mask marks as real tokens are scored
        if attention_mask is not None:
            active = attention_mask.reshape(-1) == 1
            flat_logits = flat_logits[active]
            flat_labels = flat_labels[active]

        counts = torch.bincount(flat_labels, minlength = num_labels).to(logits.dtype)
        present = counts > 0
        present[0] = False  #never learn about the padding class
        if not present.any():
            #Nothing to learn from in this batch; avoid a 0/0 = NaN loss
            return logits.sum() * 0.0

        weights = torch.zeros(counts.size(0), dtype = logits.dtype, device = logits.device)
        #"Balanced" weighting: total / (n_present_classes * class_count)
        weights[present] = counts[present].sum() / (present.sum() * counts[present])
        return F.cross_entropy(flat_logits, flat_labels, weight = weights)

    def forward(self, input_ids = None, attention_mask = None, labels = None):
        """Run BERT and the classifier; optionally compute the training loss.

        Parameters:
            input_ids: token id tensor, one row per sentence.
            attention_mask: tensor of the same shape; 1 for real tokens, 0 for padding.
            labels (optional): integer class per token, same shape as input_ids.

        Returns:
            (logits, loss): per-token class scores, and the weighted loss as a
            scalar tensor, or None when labels is not given.
        """
        bert_outputs = self.bert_segment(input_ids, attention_mask = attention_mask)
        #First output is the final hidden state of every token
        sequence_output = self.classifier_dropout(bert_outputs[0])
        logits = self.classifier1(sequence_output)

        loss = None
        #If we have labels, we can calculate losses.
        if labels is not None:
            loss = self._weighted_token_loss(logits, labels, attention_mask)

        return logits, loss

In [78]:
#Instantiate the BERT + classifier stack; this rebinds `model`, replacing the
#bare BertModel loaded earlier in the notebook.
model = BERTplus1Layer()

In [79]:
#Display the module tree of the combined model
model

BERTplus1Layer(
  (bert_segment): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [53]:
#NOTE: no parentheses, so this displays the bound method's repr (which embeds
#the model repr) rather than calling it.
model.modules

<bound method Module.modules of BERTplus1Layer(
  (bert_segment): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((

It doesn't matter if the classifier never learns about padding tokens. The only important thing is that the classifier never guesses that anything is a padding token. If we set a weight of 0 to the padding token classifier output, then we can turn that off permanently. And, since we'll always know what is and isn't a padding token via the attention mask, then we're golden. It doesn't matter what it predicts them as, just that it never predicts anything that isn't a padding token as being a pad. 

As such, we can get away with only including non-pads in the loss, and never letting the model learn about pads. 

In [76]:
#Scratch check: len(set(...)) counts distinct values (the repeated 1 is counted once)
len(set([1,3,5,6,7,8,1]))

6

In [None]:
class Net(nn.Module):
    """Stand-alone classification head: dropout followed by one linear layer
    mapping 768 input features to 4 class scores.

    Parameters:
        p (float): dropout probability.
    """
    def __init__(self, p = 0.1):
        super().__init__()
        self.classifier_dropout = nn.Dropout(p = p)
        self.classifier = nn.Linear(768, 4)

    def forward(self, x):
        """Apply dropout, then the linear classifier, and return the class scores."""
        return self.classifier(self.classifier_dropout(x))

In [48]:
# Note: this AdamW is the class imported from the huggingface transformers
# library (as opposed to the one in pytorch).
# I believe the 'W' stands for 'Weight Decay fix'.
# The optimizer is bound to the parameters of whatever `model` refers to when
# this cell runs, so re-run it after re-creating the model.
optimizer = AdamW(model.parameters(),
                  lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8.
                )

In [49]:
# Number of training epochs. The BERT authors recommend between 2 and 4.
# We use 5 here: in an earlier 20-epoch run the model saturated around
# 5 epochs with a 2e-5 learning rate. Lower this if your machine struggles.
epochs = 5

# Total number of training steps is [number of batches] x [number of epochs]. 
# (Note that this is not the same as the number of training samples).
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler (linear schedule, no warmup steps).
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

In [50]:
# Device handle for the CPU
device = torch.device("cpu")