<a href="https://colab.research.google.com/github/Jeevesh8/arg_mining/blob/main/experiments/BERT_am.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
#if running on colab, run all the clone/install commands below
!git clone https://github.com/Jeevesh8/arg_mining
!pip install transformers
!pip install seqeval datasets allennlp
!pip install flax
!pip install sentencepiece
#if connected to local runtime, run the next command too
#pip install bs4 tensorflow torch 

In [None]:
#Run to ignore warnings
import warnings
import numpy as np
warnings.filterwarnings('ignore')

### Load Metric

In [None]:
%%capture
from datasets import load_metric
metric = load_metric('seqeval')

### Define & Load Tokenizer, Model, Dataset

In [None]:
import torch
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
device

device(type='cpu')

In [None]:
model_version = 'bert-base-uncased'

In [None]:
%%capture
from transformers import AutoTokenizer, AutoModel
# BERT has no native bos/eos tokens; [CLS]/[SEP] are registered in those roles
# so that tokenizer.eos_token_id is defined for the comment-splitting code.
tokenizer = AutoTokenizer.from_pretrained(
    model_version,
    bos_token="[CLS]",
    eos_token="[SEP]",
)

# Load the pretrained encoder first, then move its weights to the chosen device.
transformer_model = AutoModel.from_pretrained(model_version)
transformer_model = transformer_model.to(device)

In [None]:
import torch.nn as nn

#### To add extra token type embeddings...

In [None]:
def resize_token_type_embeddings(transformer_model, new_size):
    """Replace the model's token-type embedding table with one of ``new_size`` rows.

    A fresh ``nn.Embedding`` is created on the model's device (with the dtype
    of the old table) and every row that exists in both the old and the new
    table is copied over. Rows beyond the old size keep the default
    ``nn.Embedding`` initialisation; when shrinking, only the first
    ``new_size`` rows of the old table are kept.

    Args:
        transformer_model: Model exposing ``embeddings.token_type_embeddings``
                           (an ``nn.Embedding``) and a ``device`` attribute.
        new_size:          Number of token types the new table should hold.

    NOTE:
        Only the embedding module is swapped. ``config.type_vocab_size`` is
        not updated here and has to be changed by the caller.
    """
    old_embeddings = transformer_model.embeddings.token_type_embeddings.weight
    old_size, hidden_dim = old_embeddings.shape

    new_embeddings = nn.Embedding(new_size, hidden_dim,
                                  device=transformer_model.device,
                                  dtype=old_embeddings.dtype)

    # Copy only the overlapping rows: slicing the new table with old_size when
    # new_size < old_size would otherwise raise a shape-mismatch error.
    num_to_copy = min(old_size, new_size)
    with torch.no_grad():
        new_embeddings.weight[:num_to_copy] = old_embeddings[:num_to_copy]

    transformer_model.embeddings.token_type_embeddings = new_embeddings

#resize_token_type_embeddings(transformer_model, 2)
#transformer_model.config.type_vocab_size = 2

#### Load the discourse markers (provide ``Discourse_Markers.txt``)

In [None]:
# Read one discourse marker per line. The encoding is pinned so the result
# does not depend on the platform default, and blank lines are skipped so
# that no empty-string marker ends up in the list.
with open('./Discourse_Markers.txt', encoding='utf-8') as f:
    discourse_markers = [marker for marker in (line.strip() for line in f) if marker]

* If needed, change the ``batch_size`` in ``arg_mining/datasets/cmv_modes/configs.py`` before running the cell below. [Default: 8]

* You can also change ``max_len`` in the same file to suit the maximum input length of your model. All threads will be truncated to ``max_len`` tokens.

In [None]:
%%capture
from arg_mining.datasets.cmv_modes import load_dataset, data_config

In [None]:
# Register the dataset's special tokens with the tokenizer so that each of
# them is treated as a single special token.
tokenizer.add_tokens(data_config["special_tokens"], special_tokens=True)

# Resize the model's input embedding matrix to the enlarged vocabulary size;
# as the last expression of the cell, the resulting embedding module is
# displayed as the cell output.
transformer_model.resize_token_embeddings(len(tokenizer))

Embedding(30537, 768)

### Function to get datasets
* Change split sizes, if needed.

In [None]:
def get_datasets(train_sz: int=50, test_sz: int=50):
    """Build fresh train/validation/test datasets with ``load_dataset``.

    The module-level ``tokenizer`` is used for tokenization and the
    module-level ``discourse_markers`` are passed as ``mask_tokens``.

    Args:
        train_sz:  Forwarded unchanged as ``train_sz`` to ``load_dataset``.
        test_sz:   Forwarded unchanged as ``test_sz`` to ``load_dataset``.

    Returns:
        The tuple ``(train_dataset, valid_dataset, test_dataset)``.
    """
    train_dataset, valid_dataset, test_dataset = load_dataset(tokenizer=tokenizer,
                                                              train_sz=train_sz,
                                                              test_sz=test_sz,
                                                              mask_tokens=discourse_markers,)
    return train_dataset, valid_dataset, test_dataset

### Wrap a dataset in ``get_comment_wise_dataset`` to get a comment-wise dataset

In [None]:
from typing import Iterator, List, Tuple

In [None]:
def split_encoding(tokenized_thread: List[int], 
                   split_on: List[int], 
                   eos_token_id: int) -> List[List[int]]:
    """Splits tokenized_thread into multiple lists at each occurrence of 
    a token_id specified in split_on, stopping at the eos_token_id.
    
    1. The eos_token_id is retained in the last splitted component; tokens
       after it are dropped.
    2. Each matched token_id from split_on is retained in the component that 
       follows it (i.e., it is the first element of that component).
    3. The first component holds everything before the first split token, so
       it is empty when tokenized_thread is empty or starts with a split token.
    """
    # Build the lookup set once so that each membership test is O(1) instead 
    # of a scan over the split_on list for every token.
    split_ids = frozenset(split_on)
    
    splitted = [[]]
    for elem in tokenized_thread:
        if elem in split_ids:
            splitted.append([])
        splitted[-1].append(elem)
        if elem == eos_token_id:
            break
    return splitted

def pad_batch(elems: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Pads all lists in elems to the maximum list length of any list in 
    elems. Pads with pad_token_id.
    
    New lists are returned; the lists in elems are not modified. An empty 
    elems yields an empty list.
    """
    # default=0 keeps max() from raising ValueError on an empty batch.
    max_len = max((len(elem) for elem in elems), default=0)
    return [elem+[pad_token_id]*(max_len-len(elem)) for elem in elems]

def get_comment_wise_dataset(dataset,
                             max_len: int=512,
                             batch_size: int=8,
                             drop_remainder: bool=False) -> Iterator[Tuple[List[List[int]], 
                                                                           List[List[int]], 
                                                                           List[List[int]]]]:
    """Generator that re-batches a thread-wise dataset into comment-wise batches.
    
    Args:
        dataset:        A numpy iterator dataset for threads, as returned from 
                        get_datasets() function above. Each element is unpacked as
                        (tokenized_threads, masked_threads, comp_type_labels, _).
        max_len:        Maximum length at which to truncate any comment.
        batch_size:     Number of comments in a batch.
        drop_remainder: If True, comments left over after the last full batch 
                        are discarded. By default they are yielded as a final, 
                        smaller batch so that no comment is silently lost.
    
    Yields:
        Tuples having batched & padded(to max. length in batch) tokenized threads, 
        masked threads, and component type labels; where each element corresponds
        to a comment in some thread.
    
    Raises:
        ValueError: If batch_size is not positive.
    
    NOTE:
        1. This function removes the extra num_devices dimension from the elements 
           of dataset provided.
        2. dataset is consumed lazily: a batch is yielded as soon as batch_size 
           comments have been collected.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    
    # Token ids of all user tokens; a comment starts at each of them.
    # [1:-1] strips the special tokens the tokenizer adds around the encoding.
    user_token_indices = tokenizer.encode("[UNU]"+"".join([f"[USER{i}]" for i in range(data_config["max_users"])]))[1:-1]
    
    def make_batch(token_rows, mask_rows, label_rows):
        return (pad_batch(token_rows, tokenizer.pad_token_id), 
                pad_batch(mask_rows, tokenizer.pad_token_id),
                pad_batch(label_rows, data_config["pad_for"]["comp_type_labels"]))
    
    cw_tokenized_threads, cw_masked_threads, cw_comp_type_labels = [], [], []

    for (tokenized_threads, masked_threads, comp_type_labels, _ ) in dataset:
        tokenized_threads = np.squeeze(tokenized_threads, axis=0).tolist()
        masked_threads = np.squeeze(masked_threads, axis=0).tolist()
        comp_type_labels = np.squeeze(comp_type_labels, axis=0).tolist()

        for tokenized_thread, masked_thread, comp_type_label in zip(tokenized_threads, masked_threads, comp_type_labels):
            # Position of the current comment inside the thread; the masked 
            # thread and the labels are aligned token-for-token with the 
            # tokenized thread, so the same offsets select their slices.
            start = 0
            for comment in split_encoding(tokenized_thread, user_token_indices, tokenizer.eos_token_id):
                end = start + len(comment)
                
                # An empty component (thread starting with a user token) has 
                # no tokens or labels, so skipping it loses nothing.
                if comment:
                    cw_tokenized_threads.append(comment[:max_len])
                    cw_masked_threads.append(masked_thread[start:end][:max_len])
                    cw_comp_type_labels.append(comp_type_label[start:end][:max_len])
                start = end
                
                if len(cw_tokenized_threads) == batch_size:
                    yield make_batch(cw_tokenized_threads, cw_masked_threads, cw_comp_type_labels)
                    cw_tokenized_threads, cw_masked_threads, cw_comp_type_labels = [], [], []
    
    # Emit the comments left over after the last full batch.
    if cw_tokenized_threads and not drop_remainder:
        yield make_batch(cw_tokenized_threads, cw_masked_threads, cw_comp_type_labels)

### Sample Run for dataset

In [None]:
"""
train_dataset, valid_dataset, test_dataset = get_datasets()
for (tokenized_threads, masked_threads, comp_type_labels) in get_comment_wise_dataset(train_dataset):
    print(len(tokenized_threads[0]))
    print(tokenizer.batch_decode(tokenized_threads))
    print(tokenizer.batch_decode(masked_threads))
    print(comp_type_labels)
    break
"""

[[101, 4642, 2615, 1024, 14048, 2005, 5841, 16582, 2015, 103, 1037, 5860, 20026, 28230, 7337, 2000, 2562, 2769, 1004, 2204, 5841, 2503, 1037, 3327, 2591, 4418, 1010, 1004, 2323, 2022, 22585, 2030, 5892, 103, 1037, 19229, 2005, 6107, 1012, 30527, 30525, 30525, 1005, 1015, 1012, 30525, 14048, 2003, 5860, 20026, 28230, 1012, 30525, 30525, 1005, 1015, 1012, 1015, 1012, 30525, 1037, 1000, 3115, 1000, 14763, 4118, 2052, 5676, 2008, 2296, 23761, 2052, 2022, 16330, 2006, 1996, 2168, 9181, 1011, 1996, 9967, 1997, 5918, 1997, 1996, 2597, 1998, 22617, 7882, 2000, 2008, 2597, 1012, 30525, 30525, 1005, 1015, 1012, 1016, 1012, 30525, 2040, 2019, 23761, 4282, 2003, 5681, 6179, 2005, 2216, 5841, 2029, 5478, 8498, 6550, 2007, 3056, 2111, 1999, 3145, 4460, 1006, 1041, 1012, 1043, 1012, 4341, 1007, 1010, 103, 2005, 1996, 6565, 3484, 1997, 5841, 1010, 2040, 2115, 2814, 2024, 2003, 8360, 2135, 22537, 1012, 30525, 30525, 1005, 1015, 1012, 1017, 1012, 30525, 4352, 2619, 2000, 25022, 11890, 2819, 15338, 1996,

### Define layers for a Linear-Chain-CRF

In [None]:
from allennlp.modules.conditional_random_field import ConditionalRandomField as crf

# Mapping from argument-component tag name (e.g. "B-C", "I-P", "O") to the
# integer id used for that tag in labels and CRF constraints.
ac_dict = data_config["arg_components"]

# (from_tag, to_tag) id pairs handed to the CRF as `constraints`:
#   * a B-* tag may only be followed by its matching I-* tag,
#   * an I-* tag may continue, start a new component (B-C / B-P) or go to O,
#   * O may stay O or start a new component.
# NOTE(review): B-* -> O and B-* -> B-* are not listed, so a single-token
# component is not in this list -- confirm this is intended.
allowed_transitions =([(ac_dict["B-C"], ac_dict["I-C"]), 
                       (ac_dict["B-P"], ac_dict["I-P"])] + 
                      [(ac_dict["I-C"], ac_dict[ct]) 
                        for ct in ["I-C", "B-C", "B-P", "O"]] +
                      [(ac_dict["I-P"], ac_dict[ct]) 
                        for ct in ["I-P", "B-C", "B-P", "O"]] +
                      [(ac_dict["O"], ac_dict[ct]) 
                        for ct in ["O", "B-C", "B-P"]])
                    
# Projects each token's hidden state to one score per tag.
linear_layer = nn.Linear(transformer_model.config.hidden_size,
                         len(ac_dict)).to(device)

# Linear-chain CRF over the tag scores, restricted to allowed_transitions and
# created with include_start_end_transitions=False.
crf_layer = crf(num_tags=len(ac_dict),
                constraints=allowed_transitions,
                include_start_end_transitions=False).to(device)

# Token-level cross entropy with one weight per class (log of the constants
# below); reduction='none' keeps per-token losses so the caller can mask
# padding before summing.
# NOTE(review): the five weights assume len(ac_dict) == 5 and that their order
# follows the tag ids in ac_dict; their origin is not shown here -- TODO confirm.
cross_entropy_layer = nn.CrossEntropyLoss(weight=torch.log(torch.tensor([3.3102, 61.4809, 3.6832, 49.6827, 2.5639], 
                                                                        device=device)), reduction='none')

### Loss and Prediction Function

In [None]:
def compute(batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            preds: bool=False, cross_entropy: bool=True):
    """Runs the transformer + linear + CRF stack on a batch and returns either 
    the loss or the viterbi-decoded predictions.
    
    Args:
        batch:  A tuple (tokenized_threads, token_type_ids, comp_type_labels), 
                in that order, where tokenized_threads and comp_type_labels are 
                of shape [batch_size, seq_len]. The second element is accepted 
                for interface compatibility but is currently NOT passed to the 
                transformer model.
        
        preds:  If True, returns a List(of batch_size size) of Tuples of form 
                (tag_sequence, viterbi_score) where the tag_sequence is the 
                viterbi-decoded sequence, for the corresponding sample in the batch.
        
        cross_entropy:  This argument will only be used if preds=False, i.e., if 
                        loss is being calculated. If True, then cross entropy loss
                        will also be added to the output loss.
    
    Returns:
        Either the predicted sequences with their scores for each element in the batch
        (if preds is True), or the loss value summed over all elements of the batch
        (if preds is False).
    """
    # The middle element is deliberately left unused, see the docstring.
    tokenized_threads, _token_type_ids, comp_type_labels = batch
    
    # 1 for real tokens, 0 for padding; used both as the attention mask of the
    # transformer and as the sequence mask of the CRF.
    pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0)
    
    hidden_states = transformer_model(input_ids=tokenized_threads,
                                      attention_mask=pad_mask,).last_hidden_state
    logits = linear_layer(hidden_states)
    
    if preds:
        return crf_layer.viterbi_tags(logits, pad_mask)
    
    log_likelihood = crf_layer(logits, comp_type_labels, pad_mask)
    
    if not cross_entropy:
        return -log_likelihood
    
    # Flatten to [batch_size*seq_len, ...] for the token-level cross entropy, 
    # then zero out the losses at padded positions before summing.
    token_losses = cross_entropy_layer(logits.reshape(-1, logits.shape[-1]),
                                       comp_type_labels.reshape(-1))
    ce_loss = torch.sum(pad_mask.reshape(-1)*token_losses)
    
    return ce_loss - log_likelihood

### Define optimizer

In [None]:
from itertools import chain

import torch.optim as optim

# A single Adam optimizer updates the encoder, the tag projection layer and
# the CRF together, all with the same learning rate.
optimizer = optim.Adam(
    chain(*(module.parameters()
            for module in (transformer_model, linear_layer, crf_layer))),
    lr=2e-5,
)

### Training And Evaluation Loops

In [None]:
def train(dataset):
    """Runs one training pass over the comment-wise batches of dataset.
    
    Gradients are accumulated over accumulate_over consecutive batches before 
    each optimizer step. The loss of every batch is divided by 
    data_config["batch_size"] and printed.
    
    Args:
        dataset:  A thread-wise dataset as returned from get_datasets(); it is 
                  wrapped in get_comment_wise_dataset() here.
    """
    accumulate_over = 4
    
    optimizer.zero_grad()
    
    # True while there are back-propagated gradients not yet applied by a step.
    pending_grads = False

    for i, (tokenized_threads, masked_threads, comp_type_labels) in enumerate(get_comment_wise_dataset(dataset)):
        
        #Cast to PyTorch tensor
        tokenized_threads = torch.tensor(tokenized_threads, device=device)
        masked_threads = torch.tensor(masked_threads, device=device)
        comp_type_labels = torch.tensor(comp_type_labels, device=device, dtype=torch.long)
        
        loss = compute((tokenized_threads,
                        torch.where(masked_threads==tokenizer.mask_token_id, 1, 0), 
                        comp_type_labels,))/data_config["batch_size"]
        
        # .item() prints the plain number rather than the tensor repr.
        print("Loss: ", loss.item())
        loss.backward()
        pending_grads = True
        
        if i%accumulate_over==accumulate_over-1:
            optimizer.step()
            optimizer.zero_grad()
            pending_grads = False
    
    # Apply the gradients of a trailing, incomplete accumulation group. When 
    # nothing is pending the step is skipped: stepping on freshly zeroed 
    # gradients could still move the weights through Adam's running averages.
    if pending_grads:
        optimizer.step()
        optimizer.zero_grad()

In [None]:
def evaluate(dataset, metric):
    """Decodes every comment-wise batch of dataset and prints the metric scores.
    
    Args:
        dataset:  A thread-wise dataset as returned from get_datasets(). It is 
                  wrapped in get_comment_wise_dataset() here, exactly like in 
                  train(): the raw dataset yields 4-tuples with an extra leading 
                  dimension and cannot be unpacked into the three batch lists 
                  directly.
        metric:   Metric object offering add_batch(predictions=, references=) 
                  and compute(); the result of compute() is printed.
    """
    int_to_labels = {v:k for k, v in ac_dict.items()}
    
    with torch.no_grad():
        for tokenized_threads, masked_threads, comp_type_labels in get_comment_wise_dataset(dataset):
        
            #Cast to PyTorch tensor
            tokenized_threads = torch.tensor(tokenized_threads, device=device)
            masked_threads = torch.tensor(masked_threads, device=device)
            comp_type_labels = torch.tensor(comp_type_labels, device=device, dtype=torch.long)
            
            preds = compute((tokenized_threads,
                            torch.where(masked_threads==tokenizer.mask_token_id, 1, 0), 
                            comp_type_labels,), preds=True)
            
            # Number of non-padding tokens in each comment, as plain ints so 
            # that they can be used directly as slice bounds.
            lengths = torch.sum(tokenized_threads!=tokenizer.pad_token_id, 
                                dim=-1).tolist()
            
            # pred is a (tag_sequence, viterbi_score) tuple; keep the tags only.
            preds = [ [int_to_labels[tag] for tag in pred[0][:lengths[i]]]
                    for i, pred in enumerate(preds)
                    ]
            
            # Cut the padded positions off the reference labels before mapping 
            # the ids back to tag names.
            refs = [ [int_to_labels[ref] for ref in labels[:lengths[i]]]
                    for i, labels in enumerate(comp_type_labels.cpu().tolist())
                ]
            
            metric.add_batch(predictions=preds, 
                            references=refs,)
        
    print(metric.compute())

### Final Training

In [None]:
n_epochs = 35

In [None]:
# Run n_epochs rounds of training followed by evaluation on the test split.
for epoch in range(n_epochs):
    print(f"------------EPOCH {epoch+1}---------------")
    # The datasets are rebuilt at the start of every epoch; the validation
    # split (second return value) is not used here.
    train_dataset, _, test_dataset = get_datasets()
    train(train_dataset)
    # evaluate() prints the scores computed by `metric` for this epoch.
    evaluate(test_dataset, metric)