# ***Legal BERT - Model Enhancement***

### Global variables

In [None]:
# Global variables

# Mini-batch size passed to the DataLoaders built in this notebook.
BATCH_SIZE = 32
# Hugging Face checkpoint used for both the tokenizer and the classifier.
MODEL_NAME = 'nlpaueb/legal-bert-small-uncased' #'bert-base-uncased'
# Number of fine-tuning epochs for the retrained classifier.
EPOCHS = 3
# Despite its name, this is used below as a token sequence length
# (e.g. max_length = EMBEDDING_SIZE - NUM_TOKENS).
EMBEDDING_SIZE = 512
# Number of output labels of the sequence classifier.
NUM_CLASSES = 2
# Same value as the hard-coded 30522 row count used below to find the word-embedding matrix.
VOCABULARY_SIZE = 30522
# Length of the trigger, in tokens.
NUM_TOKENS = 10
# Token id every trigger slot starts from. The original note says it represents the
# word "the"; that is vocabulary-specific — TODO confirm for MODEL_NAME.
DEFAULT_TOKEN = 207 # What in BERT represents the word "the"
# Candidate ids that the trigger search skips.
LIST_ID_SPECIAL_TOKENS = [0, 101, 102, 103]
# Presumably the textual form of the ids above, in the same order — TODO confirm.
LIST_SPECIAL_TOKENS = ['[PAD]', '[CLS]', '[SEP]', '[MASK]']

### Installation of packages

In [None]:
!pip install transformers
!pip install torch-lr-finder

### Imports

In [None]:
import torch
import os
from transformers import BertTokenizer
from google.colab import drive
from torch.utils.data import TensorDataset, random_split
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
import numpy as np
import time
import datetime
import random
import gc
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from copy import deepcopy
from datetime import datetime
import datetime

### Mount Google Drive

In [None]:
# Mount the user's Google Drive at /content/drive (Colab helper imported above).
drive.mount('/content/drive')

### Device

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

if not use_gpu:
    print('No GPU available, using the CPU instead.')
else:
    gpu_count = torch.cuda.device_count()
    print('There are %d GPU(s) available.' % gpu_count)

    gpu_name = torch.cuda.get_device_name(0)
    print('We will use the GPU:', gpu_name)

### Reading original dataset

In [None]:
def get_sentences(path):
    """Collect every non-blank line from all files in a directory.

    Files are visited in sorted filename order and each file is read line by
    line. Lines are returned verbatim, including their trailing newline;
    lines that are empty after stripping whitespace are skipped.

    Args:
        path: Directory containing the sentence files, with or without a
            trailing path separator.

    Returns:
        A list of the kept lines, in file order then line order.
    """
    sentences = []
    for filename in sorted(os.listdir(path)):
        # os.path.join (rather than `path + filename`) also works when the
        # caller omits the trailing separator.
        with open(os.path.join(path, filename), 'r') as f:
            sentences.extend(line for line in f if line.strip() != '')
    return sentences

def get_labels(path):
    """Collect the integer labels stored one per line in a directory of files.

    Files are visited in sorted filename order, which is the same order
    get_sentences() uses. Lines that are empty after stripping whitespace
    are skipped.

    Args:
        path: Directory containing the label files, with or without a
            trailing path separator.

    Returns:
        A flat list of ints, in file order then line order.

    Raises:
        ValueError: If a non-blank line cannot be parsed by int().
    """
    all_labels = []
    for filename in sorted(os.listdir(path)):
        # os.path.join also works when the caller omits the trailing separator.
        with open(os.path.join(path, filename), 'r') as f:
            all_labels.extend(int(line) for line in f if line.strip() != '')
    return all_labels

In [None]:
# Load the sentences of the Claudette dataset (train/validation portion).
all_sentences = get_sentences("ToS/TrainValSet/Sentences/")

# Load the matching labels and remap -1 to 0 so the classes are 0/1.
# NOTE(review): the trigger search later selects label 1 as the "unfair" class,
# so -1 presumably marks the other (fair) sentences — confirm against the dataset.
all_labels = [
    label if label != -1 else 0
    for label in get_labels("ToS/TrainValSet/Labels/")
]

## **Trigger Generation**

### Bert (Model, Tokenizer and Load)

In [None]:
print('Loading BERT tokenizer...')
# MODEL_NAME is an uncased checkpoint, so the tokenizer is asked to lower-case its input.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)

### Model BertForSequenceClassification (Load model)

In [None]:
# Sequence classifier built on MODEL_NAME with a NUM_CLASSES-way output head.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False,
)

# Move the model to the device selected earlier. Unlike model.cuda(), this
# also works on the CPU fallback when no GPU is available.
model.to(device)

In [None]:
# Restore the fine-tuned weights. map_location remaps the stored tensors onto
# the active device, so a checkpoint saved on a GPU can also be loaded on a CPU.
model.load_state_dict(
    torch.load('Bert4SeqClassif_20220804_1134_wPersistency.pt', map_location=device)
)

### Trigger generation

##### General functions

In [None]:
# Gradients captured by extract_grad_hook(); rebind to a fresh list to clear them.
extracted_grads = []


def extract_grad_hook(module, grad_in, grad_out):
    """Backward hook registered by add_hooks().

    Appends the first entry of `grad_out` to the module-level
    `extracted_grads` list. `module` and `grad_in` are ignored.
    """
    first_output_grad = grad_out[0]
    extracted_grads.append(first_output_grad)

def get_embedding_weight(language_model, vocab_size=None):
    """Return the word-piece embedding weight matrix of a model.

    The matrix is found by scanning the model for the first
    torch.nn.Embedding whose number of rows equals the vocabulary size.

    Args:
        language_model: Module whose sub-modules are searched.
        vocab_size: Number of rows identifying the word-piece embedding.
            Defaults to the notebook-wide VOCABULARY_SIZE.

    Returns:
        The matching weight tensor detached from autograd, or None when no
        embedding layer has `vocab_size` rows.
    """
    if vocab_size is None:
        vocab_size = VOCABULARY_SIZE
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == vocab_size:
            return module.weight.detach()
    return None

def add_hooks(language_model, vocab_size=None, hook=None):
    """Register a backward hook on the word-piece embedding layer(s).

    Every torch.nn.Embedding whose number of rows equals the vocabulary size
    gets `requires_grad` enabled on its weight and a full backward hook
    attached.

    Args:
        language_model: Module whose sub-modules are searched.
        vocab_size: Number of rows identifying the word-piece embedding.
            Defaults to the notebook-wide VOCABULARY_SIZE.
        hook: Backward hook to register. Defaults to extract_grad_hook.

    Returns:
        The list of hook handles. Call `.remove()` on them before registering
        again, otherwise every backward pass is recorded once per registration.
    """
    if vocab_size is None:
        vocab_size = VOCABULARY_SIZE
    if hook is None:
        hook = extract_grad_hook
    handles = []
    for module in language_model.modules():
        if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == vocab_size:
            module.weight.requires_grad = True
            handles.append(module.register_full_backward_hook(hook))
    return handles


def make_target_batch(tokenizer, device, target_texts, max_length=None):
    """Encode target texts into one LongTensor batch, padded with -1.

    Each text is encoded with `tokenizer.encode_plus`. Encodings shorter than
    the longest one are right-padded with -1 (used to mask out the loss).

    Args:
        tokenizer: Object exposing `encode_plus`, whose result has `input_ids`.
        device: Device on which the resulting tensor is created.
        target_texts: Iterable of strings to encode.
        max_length: Maximum encoded length. Defaults to
            EMBEDDING_SIZE - NUM_TOKENS.

    Returns:
        A tensor of shape (len(target_texts), longest_encoding), or None when
        `target_texts` is empty. Rows are in REVERSE order of `target_texts`,
        matching the previous implementation, which prepended each new row.
    """
    if max_length is None:
        max_length = EMBEDDING_SIZE - NUM_TOKENS

    encoded_texts = []
    for target_text in target_texts:
        encoded_target_text = tokenizer.encode_plus(
            target_text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_attention_mask=True,
        )
        encoded_texts.append(list(encoded_target_text.input_ids))

    if not encoded_texts:
        return None

    # The longest length must come from the id lists themselves; len() of the
    # encoding object counts its fields, not its tokens.
    max_len = max(len(ids) for ids in encoded_texts)
    padded = [ids + [-1] * (max_len - len(ids)) for ids in encoded_texts]

    return torch.tensor(padded[::-1], device=device, dtype=torch.long)

# Got from https://github.com/Eric-Wallace/universal-triggers/blob/master/attacks.py
def hotflip_attack(averaged_grad, embedding_matrix, trigger_token_ids,
                   increase_loss=False, num_candidates=1):
    """
    The "Hotflip" attack described in Equation (2) of the paper. This code is heavily inspired by
    the nice code of Paul Michel here https://github.com/pmichel31415/translate/blob/paul/
    pytorch_translate/research/adversarial/adversaries/brute_force_adversary.py
    This function takes in the model's average_grad over a batch of examples, the model's
    token embedding matrix, and the current trigger token IDs. It returns the top token
    candidates for each position.
    If increase_loss=True, the candidates are the tokens whose embeddings have the largest dot
    product with the gradient (first-order increase of the loss). With increase_loss=False the
    sign is flipped, so the candidates are the tokens expected to decrease the loss.

    Args:
        averaged_grad: Tensor of shape (positions, embedding_dim).
        embedding_matrix: Tensor of shape (vocabulary, embedding_dim).
        trigger_token_ids: Current trigger ids. Kept for interface compatibility; the
            scores do not depend on them.
        increase_loss: Direction of the attack, see above.
        num_candidates: Number of candidates to return per position.

    Returns:
        A numpy array of token ids: shape (positions,) when num_candidates == 1,
        otherwise (positions, num_candidates) ordered from best to worst.
    """
    averaged_grad = averaged_grad.cpu().unsqueeze(0)
    embedding_matrix = embedding_matrix.cpu()
    # Score of every vocabulary entry at every position: gradient . embedding.
    gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik",
                                                 (averaged_grad, embedding_matrix))
    if not increase_loss:
        gradient_dot_embedding_matrix *= -1    # lower versus increase the class probability.
    if num_candidates > 1: # get top k options
        _, best_k_ids = torch.topk(gradient_dot_embedding_matrix, num_candidates, dim=2)
        return best_k_ids.detach().cpu().numpy()[0]
    _, best_at_each_step = gradient_dot_embedding_matrix.max(2)
    return best_at_each_step[0].detach().cpu().numpy()

def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max prediction equals the label.

    `preds` holds one row of per-class scores per example; `labels` holds the
    true class ids and is flattened before the comparison.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    true_classes = labels.flatten()
    hits = np.sum(predicted_classes == true_classes)
    return hits / len(true_classes)

In [None]:
def get_input_masks_and_labels_with_tokens(sentences, labels, tokens, position='B'):
    """Tokenize sentences with a trigger text attached at the start or the end.

    Uses the notebook-level `tokenizer`. Every sentence is padded/truncated
    to EMBEDDING_SIZE tokens.

    Args:
        sentences: Iterable of sentence strings.
        labels: Labels aligned with `sentences`; converted to a tensor.
        tokens: Trigger text to attach to every sentence.
        position: 'B' puts the trigger before the sentence, 'E' after it.

    Returns:
        (input_ids, attention_masks, labels, number_of_tokens), where
        `number_of_tokens` is a list with the count of non-zero ids in each
        row. Returns None (after printing a message) when `position` is
        neither 'B' nor 'E'.
    """
    # Validate once, before any tokenization work is done.
    if position not in ('B', 'E'):
        print('Wrong position command, please enter "E" or "B"')
        return

    input_ids = []
    attention_masks = []

    for sent in sentences:
        if position == 'B':
            sent_with_tokens = tokens + " " + sent
        else:
            sent_with_tokens = sent + " " + tokens

        encoded_dict = tokenizer.encode_plus(
            sent_with_tokens,
            add_special_tokens=True,
            max_length=EMBEDDING_SIZE,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels)

    # Count the non-zero ids of each row in one vectorized pass instead of a
    # Python loop over every tensor element.
    number_of_tokens = (input_ids != 0).sum(dim=1).tolist()

    return input_ids, attention_masks, labels, number_of_tokens

In [None]:
def get_loss_and_metrics(model, dataloader, device):
    """Run the model over a dataloader and return its loss and metrics.

    A backward pass is executed for every batch, so any backward hooks
    registered on the model fire once per batch. Model parameters are not
    updated.

    Args:
        model: Classifier called with input ids, attention mask and labels.
        dataloader: Yields batches of (input_ids, attention_mask, labels).
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy, precision, recall, f1). The loss is the mean of
        the per-batch losses; the four metrics are computed once over the
        predictions of the whole dataloader.
    """
    test_preds = []
    test_targets = []
    total_test_loss = 0

    for batch in dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # Clear gradients left over from the previous batch.
        model.zero_grad()

        result = model(b_input_ids,
                       token_type_ids=None,
                       attention_mask=b_input_mask,
                       labels=b_labels,
                       return_dict=True)

        loss = result.loss
        total_test_loss += loss.item()

        # Record each prediction exactly once.
        test_preds.extend(result.logits.argmax(dim=1).cpu().numpy())
        test_targets.extend(batch[2].numpy())

        # The backward pass is what triggers the gradient hooks.
        loss.backward()

    io_avg_test_loss = total_test_loss / len(dataloader)

    # Metrics over the full set, rather than an average of running
    # (cumulative) metrics taken after every batch.
    io_avg_test_acc = accuracy_score(test_targets, test_preds)
    io_avg_test_prec = precision_score(test_targets, test_preds)
    io_avg_test_recall = recall_score(test_targets, test_preds)
    io_avg_test_f1 = f1_score(test_targets, test_preds)

    return io_avg_test_loss, io_avg_test_acc, io_avg_test_prec, io_avg_test_recall, io_avg_test_f1

In [None]:
def change_input_ids_with_candidate_token(input_ids, position, candidate, number_of_tokens, trigger_position='B'):
    """Write `candidate` into one trigger slot of every row of `input_ids`.

    The tensor is modified in place and also returned.

    For trigger_position 'B' the slot is column `position` of every row.
    For 'E' the column is computed per row from that row's token count
    (capped at EMBEDDING_SIZE): count - NUM_TOKENS - 2 + position.
    Any other value prints a message and returns None.
    """
    if trigger_position == 'B':
        input_ids[:, position] = candidate
        return input_ids

    if trigger_position != 'E':
        print('Wrong position command, please enter "E" or "B"')
        return

    for row in range(len(input_ids)):
        token_count = number_of_tokens[row]
        reference = EMBEDDING_SIZE if token_count > EMBEDDING_SIZE else token_count
        input_ids[row, reference-NUM_TOKENS-2+position] = candidate
    return input_ids

##### Execution of the trigger generation

###### Get unfair sentences

In [None]:
# Indices of the sentences labelled 1, the class this notebook treats as "unfair".
positions_unfair = np.where(np.array(all_labels) == 1)[0]

# Sentences and labels at those indices, in dataset order.
target_unfair_sentences = [all_sentences[pos] for pos in positions_unfair]
labels_unfair_sentences = [all_labels[pos] for pos in positions_unfair]

# Start the trigger as NUM_TOKENS copies of DEFAULT_TOKEN and show it as text.
trigger_tokens = np.full(NUM_TOKENS, DEFAULT_TOKEN)
print(tokenizer.decode(trigger_tokens))

###### Addition of hooks and get word embeddings

In [None]:
# Evaluation mode; gradients are still computed because the search calls backward().
model.eval()
model.to(device)

# NOTE(review): re-running this cell registers the hook again on the same layer.
add_hooks(model) # add gradient hooks to embeddings
embedding_weight = get_embedding_weight(model) # save the word embedding matrix

###### Get tensors and creation of dataset and dataloader

In [None]:
## Define at which position we want to have the triggers
position = 'B' # Possible values {B|E} for beginning and end, respectively

input_ids, attention_masks, labels, number_of_tokens = get_input_masks_and_labels_with_tokens(target_unfair_sentences, labels_unfair_sentences, tokenizer.decode(trigger_tokens), position=position)

dataset = TensorDataset(input_ids, attention_masks, labels)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

###### Execution of trigger generation

In [None]:
# Timestamped log file for the trigger search; it is closed at the end of the search cell.
timestamp = datetime.datetime.now().strftime("%y%m%d_%H_%M_%S_%p")
f = open(f"Execution_Pos{position}_NumTokens_{NUM_TOKENS}_{timestamp}.txt", "w")

In [None]:
def _trigger_slot_gradient(batch_grads, token_index):
    """Sum the captured gradients at the sequence slot of one trigger token.

    `batch_grads` holds one tensor per batch, in dataloader order, indexed as
    [example, sequence slot, embedding dim] (the layout the hook output is
    already indexed with). For position 'B' trigger token i is written at
    column i + 1 by the replacement step below, so the gradient is read at that
    same column. For 'E' the column is derived per example from
    `number_of_tokens`, mirroring change_input_ids_with_candidate_token().

    Returns a tensor of shape (1, embedding dim).
    """
    total = None
    row_offset = 0
    for grads in batch_grads:
        batch_len = grads.shape[0]
        if position == 'B':
            slot_grads = grads[:, token_index + 1]
        else:
            lengths = number_of_tokens[row_offset:row_offset + batch_len]
            slots = [min(n, EMBEDDING_SIZE) - NUM_TOKENS - 1 + token_index for n in lengths]
            rows = torch.arange(batch_len, device=grads.device)
            slot_grads = grads[rows, torch.tensor(slots, device=grads.device)]
        batch_sum = slot_grads.detach().sum(dim=0)
        total = batch_sum if total is None else total + batch_sum
        row_offset += batch_len
    if total is None:
        raise RuntimeError('No gradients were captured; were the embedding hooks added?')
    return total.unsqueeze(0)


candidates_dict = {"combinations_ids": [], "combinations_labels": []}
foundX = -1
foundY = -1
candidates_combination = [DEFAULT_TOKEN]*NUM_TOKENS
goalTag = 'Loss'

# Baseline pass with the initial trigger; it also fills extracted_grads via the hooks.
extracted_grads = []
loss_obtained, acc_obtained, prec_obtained, recall_obtained, f1_obtained = get_loss_and_metrics(model, dataloader, device)

candidates_selected = [DEFAULT_TOKEN]*NUM_TOKENS
# The search keeps the candidate that yields the HIGHEST loss seen so far.
curr_best_loss = loss_obtained
curr_best_trigger_tokens = None

baseline_line = f"{position}[{foundX:3},{foundY:3}] TokenID[{candidates_combination}] TokensDesc[{tokenizer.decode(candidates_combination)}] Loss[{loss_obtained:3.5}] Acc[{acc_obtained:3.5}] Prec[{prec_obtained:3.5}] Recall[{recall_obtained:3.5}] F1[{f1_obtained:3.5}] => Worst<<{goalTag}>>[{curr_best_loss:3.5}] Found at [{foundX:3},{foundY:3}]"
print(baseline_line)
f.write(baseline_line + '\n')

try:
    for id_token_to_flip in range(0, NUM_TOKENS):

        # Gradient of the CURRENT trigger at the slot being flipped, summed over all batches.
        averaged_grad = _trigger_slot_gradient(extracted_grads, id_token_to_flip)
        extracted_grads = []

        # Use hotflip (linear approximation) attack to get the top num_candidates.
        # increase_loss=True: propose tokens expected to raise the loss, which is
        # the quantity the selection below maximizes.
        candidates = hotflip_attack(averaged_grad, embedding_weight,
                                    [trigger_tokens[id_token_to_flip]],
                                    increase_loss=True, num_candidates=100)[0]
        print(f'candidates {candidates}')
        f.write(f'{position} candidates {candidates}\n')
        candidates_dict[id_token_to_flip] = candidates

        for index, cand in enumerate(candidates):
            if cand in LIST_ID_SPECIAL_TOKENS:
                continue

            # Gradients captured while scoring candidates are not needed; drop them.
            extracted_grads = []

            input_ids_with_candidate_trigger = change_input_ids_with_candidate_token(deepcopy(input_ids), id_token_to_flip+1, cand, number_of_tokens, trigger_position=position)
            dataset_with_candidate_trigger = TensorDataset(input_ids_with_candidate_trigger, attention_masks, labels)
            dataloader_with_candidate_trigger = torch.utils.data.DataLoader(dataset_with_candidate_trigger, batch_size=BATCH_SIZE)

            current_loss, current_acc, current_prec, current_recall, current_f1 = get_loss_and_metrics(model, dataloader_with_candidate_trigger, device)
            extracted_grads = []

            if curr_best_loss < current_loss:
                curr_best_loss = current_loss
                candidates_selected[id_token_to_flip] = cand

                foundX = id_token_to_flip
                foundY = index
            candidates_combination[id_token_to_flip] = cand
            # Store a copy: the working list keeps being mutated by later candidates.
            candidates_dict["combinations_ids"].append(list(candidates_combination))
            candidates_dict["combinations_labels"].append(tokenizer.decode(candidates_combination))
            status_line = f'[{id_token_to_flip:3},{index:3}] TokenID[{candidates_combination}] TokensDesc[{tokenizer.decode(candidates_combination)}] Loss[{current_loss:3.5}] Acc[{current_acc:3.5}] Prec[{current_prec:3.5}] Recall[{current_recall:3.5}] F1[{current_f1:3.5}] => Worst<<{goalTag}>>[{curr_best_loss:3.5}] Found at [{foundX:3},{foundY:3}]'
            print(status_line)
            f.write(f'{position}{status_line}\n')

            del input_ids_with_candidate_trigger
            del dataset_with_candidate_trigger
            del dataloader_with_candidate_trigger

            gc.collect()
            torch.cuda.empty_cache()

        # Commit the best token found for this slot.
        candidates_combination = deepcopy(candidates_selected)
        input_ids = change_input_ids_with_candidate_token(deepcopy(input_ids), id_token_to_flip+1, candidates_selected[id_token_to_flip], number_of_tokens, trigger_position=position)
        print(f'Worst loss {curr_best_loss} with candidates {candidates_selected} in the {id_token_to_flip}-iteration with tokens [{tokenizer.decode(candidates_selected)}]')
        f.write(f'{position}Worst loss {curr_best_loss} with candidates {candidates_selected}\n')

        # Refresh the gradients for the updated trigger before moving to the next slot.
        if id_token_to_flip + 1 < NUM_TOKENS:
            extracted_grads = []
            dataset = TensorDataset(input_ids, attention_masks, labels)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
            get_loss_and_metrics(model, dataloader, device)
finally:
    # Close the log even when the search is interrupted.
    f.close()

In [None]:
# Display the selected trigger token ids decoded back to text.
tokenizer.decode(candidates_selected)

##### Generate sentences with tokens

In [None]:
# Display how many unfair sentences will be augmented.
len(target_unfair_sentences)

In [None]:
# Store new sentences with the tokens at the beginning
position="B"
# Two alternative trigger texts; one is picked at random for each sentence.
list_tokens_deal_worst_loss = [
    tokenizer.decode([621, 207, 3523]), #unless the everyone
    tokenizer.decode([621, 207, 207]) #unless the the
]

# How many sentences received each of the two triggers.
dict_distribution = {
    0: 0,
    1: 0
}

timestamp = datetime.datetime.now().strftime("%y%m%d_%H_%M_%S_%p")

# The context manager closes the file even if writing fails part-way.
with open(f"ToS/DataAugmentation/Pos{position}_NumTokens{NUM_TOKENS}_{timestamp}.txt", "w") as f:
    for unf_sent in target_unfair_sentences:
        index = 1 if random.uniform(0,1) > 0.5 else 0
        # NOTE(review): unf_sent is written as read; if it still carries its own
        # newline, a blank line follows each record.
        f.write(f'{list_tokens_deal_worst_loss[index]} {unf_sent}\n')
        dict_distribution[index] += 1

print(dict_distribution) # {0: 508, 1: 524} => {0: 471, 1: 520} (96(95-05)-04))

In [None]:
# Store new sentences with the tokens at the end
position="E"
# Two alternative trigger texts; one is picked at random for each sentence.
# NOTE(review): the decoded text of these ids is not shown here — the previous
# annotations were copied from the beginning-position cell; confirm by decoding.
list_tokens_deal_worst_loss = [
    tokenizer.decode([1297, 17560, 1004]),
    tokenizer.decode([1297, 17560, 431])
]

# How many sentences received each of the two triggers.
dict_distribution = {
    0: 0,
    1: 0
}

# Longest tokenized sentence kept in full; presumably leaves room for the
# 3 trigger tokens within the 512-token limit — TODO confirm.
max_sentence_tokens = 507

timestamp = datetime.datetime.now().strftime("%y%m%d_%H_%M_%S_%p")

# The context manager closes the file even if writing fails part-way.
with open(f"ToS/DataAugmentation/Pos{position}_NumTokens{NUM_TOKENS}_{timestamp}.txt", "w") as f:
    for unf_sent in target_unfair_sentences:
        index = 1 if random.uniform(0,1) > 0.5 else 0

        sent_tokenized = tokenizer.encode(unf_sent)

        if len(sent_tokenized) > max_sentence_tokens:
            # Truncate over-long sentences so the trailing trigger is not cut off later.
            # skip_special_tokens keeps the literal "[CLS]" marker out of the stored text.
            str_aux = tokenizer.decode(sent_tokenized[0:max_sentence_tokens], skip_special_tokens=True).replace("\n", "")
        else:
            str_aux = unf_sent.replace("\n", "")
        f.write(f'{str_aux} {list_tokens_deal_worst_loss[index]}\n')
        dict_distribution[index] += 1

print(dict_distribution)

## Training of a new model

### Data split

In [None]:
# Load sentences and labels from the directory that, judging by its name, holds
# the original data plus the augmented sentences.
# NOTE(review): unlike the original-dataset cell, labels are not remapped from
# -1 to 0 here — confirm these label files already use 0/1.
all_sentences_and_augmented = get_sentences("ToS/TrainValSetWithAugmentedData/Sentences/")
all_labels_and_augmented = get_labels("ToS/TrainValSetWithAugmentedData/Labels/")

In [None]:
# Tokenize every sentence to a fixed length of EMBEDDING_SIZE tokens.
input_ids = []
attention_masks = []

for sent in all_sentences_and_augmented:
    encoded_dict = tokenizer.encode_plus(
                        sent,
                        add_special_tokens = True,
                        max_length = EMBEDDING_SIZE,
                        padding = 'max_length',   # replaces the deprecated pad_to_max_length=True
                        truncation = True,
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )

    input_ids.append(encoded_dict['input_ids'])

    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(all_labels_and_augmented)

print('Original: ', all_sentences_and_augmented[0])
print('Token IDs:', input_ids[0])

dataset = TensorDataset(input_ids, attention_masks, labels)

# Stratified 95/5 train/validation split over row indices. The random_state
# makes the split repeatable; the notebook's other seeds are set only later,
# inside the training cell.
train_idx, valid_idx = train_test_split(np.arange(len(labels)), test_size=0.05, shuffle=True, stratify=labels, random_state=42)

# Both loaders read the same dataset; the samplers restrict each to its own indices.
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
validation_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler)

### Training classification model

In [None]:
# Fresh classifier initialized from the pretrained MODEL_NAME checkpoint.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False,
)
# Move to the device selected earlier. Unlike model.cuda(), this also works on
# the CPU fallback when no GPU is available.
model.to(device)

### Optimizer & Learning Rate Scheduler

In [None]:

# PyTorch's AdamW replaces the deprecated transformers.AdamW.
# weight_decay=0.0 keeps the default of the transformers implementation
# (torch.optim.AdamW would otherwise default to 0.01).
optimizer = torch.optim.AdamW(model.parameters(),
                              lr = 2e-5,
                              eps = 1e-8,
                              weight_decay = 0.0,
                             )

In [None]:
epochs = EPOCHS

# The scheduler is stepped once per batch, so the run has this many steps.
total_steps = len(train_dataloader) * epochs

from torch.optim.lr_scheduler import CosineAnnealingLR
MIN_LR = 1e-5
# Anneal from the optimizer's learning rate down to MIN_LR over the whole run.
# With a fixed period shorter than the run, the cosine schedule would start
# increasing the learning rate again after reaching its minimum.
scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min = MIN_LR)

### Training

In [None]:
def flat_accuracy(preds, labels):
    """Return the share of examples whose highest-scoring class equals the label.

    `preds` has one row of per-class scores per example; `labels` is flattened
    before it is compared with the arg-max of each row.
    """
    chosen = np.argmax(preds, axis=1).flatten()
    expected = labels.flatten()
    matches = chosen == expected
    return np.sum(matches) / len(expected)

def format_time(elapsed):
    """Format a duration given in seconds as an h:mm:ss string.

    The duration is rounded to the nearest whole second first.
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

In [None]:
tr_metrics = []
va_metrics = []
tmp_print_flag = True


# This training code is based on the `run_glue.py` script here:
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# One dict per epoch with losses, metrics and timings.
training_stats = []

# Measure the total training time for the whole run.
total_t0 = time.time()

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0
    train_preds = []
    train_targets = []

    # Training mode: dropout is active.
    model.train()

    for step, batch in enumerate(train_dataloader):

        # Progress update every 100 batches.
        if step % 100 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # `batch` holds [0] input ids, [1] attention masks, [2] labels.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # Gradients accumulate by default, so clear them before each backward pass.
        model.zero_grad()

        result = model(b_input_ids,
                       token_type_ids=None,
                       attention_mask=b_input_mask,
                       labels=b_labels,
                       return_dict=True)

        loss = result.loss
        logits = result.logits

        train_preds.extend(logits.argmax(dim=1).cpu().numpy())
        train_targets.extend(batch[2].numpy())

        total_train_loss += loss.item()

        loss.backward()

        # Clip the gradient norm to 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # The learning rate is updated after every batch.
        scheduler.step()

    # Epoch metrics are computed once over all training predictions, instead of
    # averaging the running (cumulative) metrics taken after every batch.
    io_avg_train_acc = accuracy_score(train_targets, train_preds)
    io_avg_train_prec = precision_score(train_targets, train_preds)
    io_avg_train_recall = recall_score(train_targets, train_preds)
    io_avg_train_f1 = f1_score(train_targets, train_preds)
    print(
        f'Epoch {epoch_i+1} : \n\
        Train_acc : {io_avg_train_acc}\n\
        Train_F1 : {io_avg_train_f1}\n\
        Train_precision : {io_avg_train_prec}\n\
        Train_recall : {io_avg_train_recall}'
    )

    # Mean of the per-batch losses.
    avg_train_loss = total_train_loss / len(train_dataloader)

    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    # Evaluation mode: dropout is disabled.
    model.eval()

    valid_preds = []
    valid_targets = []
    total_eval_loss = 0

    for batch in validation_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # No graph is needed because nothing is back-propagated here.
        with torch.no_grad():
            result = model(b_input_ids,
                           token_type_ids=None,
                           attention_mask=b_input_mask,
                           labels=b_labels,
                           return_dict=True)

        valid_preds.extend(result.logits.argmax(dim=1).cpu().numpy())
        valid_targets.extend(batch[2].numpy())

        total_eval_loss += result.loss.item()

    # Validation metrics over the whole validation set.
    io_avg_valid_acc = accuracy_score(valid_targets, valid_preds)
    io_avg_valid_prec = precision_score(valid_targets, valid_preds)
    io_avg_valid_recall = recall_score(valid_targets, valid_preds)
    io_avg_valid_f1 = f1_score(valid_targets, valid_preds)
    print(
            f'Epoch {epoch_i+1} : \n\
            Valid_acc : {io_avg_valid_acc}\n\
            Valid_F1 : {io_avg_valid_f1}\n\
            Valid_precision : {io_avg_valid_prec}\n\
            Valid_recall : {io_avg_valid_recall}'
          )

    # Report the final accuracy for this validation run.
    avg_val_accuracy = io_avg_valid_acc
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    # Mean of the per-batch losses.
    avg_val_loss = total_eval_loss / len(validation_dataloader)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Training Accur.': io_avg_train_acc,
            'Training F1': io_avg_train_f1,
            'Training Precision': io_avg_train_prec,
            'Training Recall': io_avg_train_recall,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Valid. F1': io_avg_valid_f1,
            'Valid. Precision': io_avg_valid_prec,
            'Valid. Recall': io_avg_valid_recall,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

### Analysis

In [None]:
import pandas as pd

# Display floats with two decimal places.
# The fully-qualified option name is used on purpose: the bare 'precision'
# pattern is ambiguous on recent pandas releases (it also matches
# 'styler.format.precision') and raises an OptionError there.
pd.set_option('display.precision', 2)

# Create a DataFrame from our training statistics (one row per entry
# appended to `training_stats`).
df_stats = pd.DataFrame(data=training_stats)

# Use the 'epoch' as the row index.
df_stats = df_stats.set_index('epoch')

# A hack to force the column headers to wrap.
#df = df.style.set_table_styles([dict(selector="th",props=[('max-width', '70px')])])

# Display the table (last expression of the notebook cell).
df_stats

##### Loss per epoch - Training VS Validation

In [None]:
import matplotlib.pyplot as plt
% matplotlib inline

import seaborn as sns

# Use plot styling from seaborn.
sns.set(style='darkgrid')

# Increase the plot size and font size.
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (12,6)

# Plot the learning curve.
plt.plot(df_stats['Training Loss'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Loss'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.xticks([1, 2, 3, 4])

plt.show()

##### Accuracy per epoch - Training VS Validation

In [None]:
# Plot the learning curve: per-epoch accuracy on the training and validation sets.
plt.plot(df_stats['Training Accur.'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Accur.'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
# One tick per recorded epoch (df_stats is indexed by 'epoch'), instead of a
# hard-coded list that goes stale when the number of epochs changes.
plt.xticks(list(df_stats.index))

plt.show()

##### F1 per epoch - Training VS Validation

In [None]:
# Plot the learning curve: per-epoch F1 on the training and validation sets.
plt.plot(df_stats['Training F1'], 'b-o', label="Training")
plt.plot(df_stats['Valid. F1'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation F1")
plt.xlabel("Epoch")
plt.ylabel("F1")
plt.legend()
# One tick per recorded epoch (df_stats is indexed by 'epoch'), instead of a
# hard-coded list that goes stale when the number of epochs changes.
plt.xticks(list(df_stats.index))

plt.show()

##### Recall per epoch - Training VS Validation

In [None]:
# Plot the learning curve: per-epoch recall on the training and validation sets.
plt.plot(df_stats['Training Recall'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Recall'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation Recall")
plt.xlabel("Epoch")
plt.ylabel("Recall")
plt.legend()
# One tick per recorded epoch (df_stats is indexed by 'epoch'), instead of a
# hard-coded list that goes stale when the number of epochs changes.
plt.xticks(list(df_stats.index))

plt.show()

### Save model

In [None]:
#model_name = "Bert4SeqClassif_augm_20220804_1249.pt"
#torch.save(model.state_dict(), model_name)

## Testing the model

In [None]:
# Rebuild the sequence-classification architecture from the pretrained
# checkpoint; its weights are overwritten below by the fine-tuned state dict.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels = NUM_CLASSES,
    output_attentions = False,
    output_hidden_states = False,
)
# Move the model to the device chosen in the "Device" cell rather than
# calling .cuda() unconditionally, so the cell also runs on a CPU-only runtime.
model.to(device)

# map_location remaps tensors saved from a GPU onto the current device, so the
# checkpoint loads even when no GPU is available.
model.load_state_dict(torch.load('Bert4SeqClassif_augm_20220804_1249.pt', map_location=device))

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

<All keys matched successfully>

In [None]:
# Load the test sentences and their labels from disk.
all_sentences = get_sentences("ToS/TestSetWithAugmentedData/SentencesBeginning/")
all_labels = get_labels("ToS/TestSetWithAugmentedData/LabelsBeginning/")

# Per-sentence encodings, concatenated into single tensors further down.
input_ids = []
attention_masks = []
# Debug flag: print the encoding's keys for the first sentence only.
tmp_bool = True

for sent in all_sentences:
    # Tokenize, add the special tokens, then pad/truncate to exactly 512
    # tokens so that every row has the same length for torch.cat below.
    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`;
    # truncation is requested explicitly instead of relying on the implicit
    # fallback triggered by `max_length`.
    encoded_dict = tokenizer.encode_plus(
                        sent,
                        add_special_tokens = True,
                        max_length = 512,
                        padding = 'max_length',
                        truncation = True,
                        return_attention_mask = True,
                        return_tensors = 'pt',
                   )
    if tmp_bool:
      tmp_bool = False
      print(f'Keys of encoded_dict: {encoded_dict.keys()}')

    input_ids.append(encoded_dict['input_ids'])

    attention_masks.append(encoded_dict['attention_mask'])

# Stack the per-sentence tensors along the batch dimension.
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(all_labels)

test_dataset = TensorDataset(input_ids, attention_masks, labels)

# Evaluation does not benefit from shuffling; a sequential sampler keeps the
# batch order (and therefore any per-batch statistic) reproducible.
test_dataloader = DataLoader(
            test_dataset,
            sampler = SequentialSampler(test_dataset),
            batch_size = BATCH_SIZE
        )

In [None]:
# ========================================
#               Test
# ========================================
# Run the model once over the test set and report loss, accuracy, precision,
# recall and F1.

print("Running Testing...")

t0 = time.time()

# Put the model in evaluation mode--the dropout layers behave differently
# during evaluation.
model.eval()

# Predicted classes and gold labels, gathered over the whole test set.
test_preds = []
test_targets = []

# Tracking variables
total_test_accuracy = 0
total_test_loss = 0

# Evaluate data for one pass over the test set
for batch in test_dataloader:

    # Unpack this test batch from our dataloader.
    #
    # As we unpack the batch, we'll also copy each tensor to the device using
    # the `to` method.
    #
    # `batch` contains three pytorch tensors:
    #   [0]: input ids
    #   [1]: attention masks
    #   [2]: labels
    b_input_ids = batch[0].to(device)
    b_input_mask = batch[1].to(device)
    b_labels = batch[2].to(device)

    # Tell pytorch not to bother with constructing the compute graph during
    # the forward pass, since this is only needed for backprop (training).
    with torch.no_grad():

        # Forward pass, calculate logit predictions.
        # token_type_ids is the same as the "segment ids", which
        # differentiates sentence 1 and 2 in 2-sentence tasks.
        result = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels,
                        return_dict=True)

    # Get the loss and "logits" output by the model. The "logits" are the
    # output values prior to applying an activation function like the
    # softmax.
    loss = result.loss
    logits = result.logits

    # Keep the predicted class and the gold label of every example so the
    # metrics can be computed over the full test set after the loop.
    test_preds.extend(logits.argmax(dim=1).cpu().numpy())
    test_targets.extend(batch[2].numpy())

    # Accumulate the test loss.
    total_test_loss += loss.item()

    # Move logits and labels to CPU
    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    # Accumulate the value returned by `flat_accuracy` for this batch
    # (presumably the batch accuracy -- see its definition).
    total_test_accuracy += flat_accuracy(logits, label_ids)

# Compute the metrics ONCE over all collected predictions. Recomputing them
# on the growing lists inside the loop and averaging the results over batches
# reports a mean of running metrics, not the metrics of the test set.
test_acc = accuracy_score(test_targets, test_preds)
test_precision = precision_score(test_targets, test_preds)
test_recall = recall_score(test_targets, test_preds)
test_f1 = f1_score(test_targets, test_preds)

# Report the final metrics for this test run.
# `avg_test_accuracy` is the mean of the per-batch `flat_accuracy` values;
# the `avg_test_*` names below hold the test-set-level metrics.
avg_test_accuracy = total_test_accuracy / len(test_dataloader)
avg_test_acc = test_acc
avg_test_prec = test_precision
avg_test_recall = test_recall
avg_test_f1 = test_f1
print("  =>Accuracy:  {0:.2f}".format(avg_test_acc))
print("  =>Precision: {0:.2f}".format(avg_test_prec))
print("  =>Recall:    {0:.2f}".format(avg_test_recall))
print("  =>F1:        {0:.2f}".format(avg_test_f1))

# Calculate the average loss over all of the batches.
avg_test_loss = total_test_loss / len(test_dataloader)

# Measure how long the test run took.
test_time = format_time(time.time() - t0)

print("  Test Loss: {0:.2f}".format(avg_test_loss))
print("  Test took: {:}".format(test_time))
