<a href="https://colab.research.google.com/github/Jeevesh8/arg_mining/blob/main/experiments/long_context_am.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install Dependencies

In [None]:
%%capture
#if running on colab, install below 4
#!git clone https://github.com/Jeevesh8/arg_mining
#!pip install transformers
#!pip install seqeval datasets allennlp
#!pip install flax

#if connected to local runtime, run the next command too
#pip install bs4 tensorflow torch 



*   Update ``arg_mining/datasets/cmv_modes/configs.py`` as per your requirements. All experiments run so far set ``batch_size`` to 2 and leave every other variable at its default value.



In [1]:
# Optional: run this cell to silence all Python warnings for the rest of the
# session (the 'ignore' action is installed without any category/module filter).
import warnings
warnings.filterwarnings('ignore')

### Load Metric

In [2]:
%%capture
# Token-level sequence-labelling metric used by evaluate() further down.
# NOTE(review): `load_metric` may be deprecated/removed in newer `datasets`
# releases -- confirm against the installed version.
from datasets import load_metric
metric = load_metric('seqeval')

### Krippendorff's Alpha Metric

In [3]:
from typing import List, Tuple
import re

class krip_alpha():
    """Sentence-level Krippendorff's alpha for token-level component tags.

    Threads are split into sentences, and for every sentence it is recorded
    whether the prediction and the reference agree on "contains a claim"
    (a 'B-C' tag) and on "contains a premise" (a 'B-P' tag). Must use
    labels ["B-C", "B-P"].

    Sentence splitting (``preprocess``) relies on a module-level ``tokenizer``
    object; the counting and reporting methods do not.
    """

    def __init__(self) -> None:
        """Start with all counters at zero; see ``_reset`` for their meaning."""
        self._reset()

    def _reset(self) -> None:
        """Zero every counter maintained between two calls to ``compute``."""
        # Number of sentences whose prediction / reference contains a claim
        # ('B-C') or a premise ('B-P').
        self.pred_has_claim = 0
        self.ref_has_claim = 0
        self.pred_has_premise = 0
        self.ref_has_premise = 0

        # Agreement: prediction and reference both contain the component or
        # both lack it. Disagreement: exactly one of them contains it.
        self.claim_wise_agreement = 0
        self.premise_wise_agreement = 0

        self.claim_wise_disagreement = 0
        self.premise_wise_disagreement = 0

        self.total_sentences = 0

        # Sentences containing both a claim and a premise / neither of them.
        self.has_both_ref = 0
        self.has_both_pred = 0
        self.has_none_ref = 0
        self.has_none_pred = 0

    @staticmethod
    def _alpha(agreement: int, disagreement: int) -> float:
        """Return ``1 - observed_disagreement / 0.5``; NaN when nothing was counted.

        The divisor 0.5 is a fixed constant here, it is not estimated from
        the observed label distribution.
        """
        total = agreement + disagreement
        if total == 0:
            return float('nan')
        return 1-(disagreement/total)/0.5

    def preprocess(self, threads: List[List[int]]) -> List[List[List[int]]]:
        """Split every thread of token ids into sentences.

        A sentence ends after a token containing '.', '?' or '!' that is not
        followed by another such token, or right before a ``[USERn]`` /
        ``[UNU]`` token. Tokens up to the first pad token are distributed over
        the sentences in order, so the sentence lengths can be used to slice
        per-token label sequences (see ``get_sentence_wise_preds``).

        Args:
            threads:    A list of all threads in a batch. A thread is a list of
                        integers corresponding to token_ids of the tokens in the
                        thread.
        Returns:
            A List with all the threads, where each thread now consists of
            sentence lists. Where, a sentence list in a thread list is the list
            of token_ids corresponding to a sentence in a thread.
        """
        user_token_pattern = re.compile(r"\[USER\d+\]|\[UNU\]")
        space_id = tokenizer.convert_tokens_to_ids('Ġ')
        pad_id = tokenizer.pad_token_id

        def closes_sentence(tok: str) -> bool:
            return tok.count('.')+tok.count('?')+tok.count('!') >= 1

        threads_lis = []

        for thread in threads:
            sentences = []
            threads_lis.append(sentences)
            sentence = []

            for j, token_id in enumerate(thread):
                if token_id == pad_id:
                    break

                sentence.append(token_id)
                token = tokenizer.convert_ids_to_tokens(token_id)

                # The last token of an unpadded thread has no successor; the
                # original check `j==len(thread)` never fired and indexed past
                # the end of the thread.
                if j+1 < len(thread):
                    next_token = tokenizer.convert_ids_to_tokens(thread[j+1])
                else:
                    next_token = 'None'

                if closes_sentence(token) and not closes_sentence(next_token):
                    sentences.append(sentence)
                    sentence = []

                elif user_token_pattern.search(token) is not None:
                    prev_part = tokenizer.decode(sentence[:-1])[1:-1]
                    if re.search(r'[a-zA-Z]', prev_part) is not None:
                        # Real text precedes the user token: close it off as
                        # its own sentence and start a new one at the user token.
                        sentences.append(sentence[:-1])
                        sentence = [sentence[-1]]
                    else:
                        # Only non-alphabetic residue precedes the user token.
                        # Everything up to the last non-space token belongs to
                        # the previous sentence; trailing spaces stay with the
                        # user token.
                        k = len(sentence)-2
                        while k >= 0 and sentence[k] == space_id:
                            k -= 1
                        if sentences:
                            # Slice BEFORE rebinding `sentence`, so no token is
                            # dropped or duplicated.
                            sentences[-1].extend(sentence[:k+1])
                            sentence = sentence[k+1:]
                        # With no previous sentence to attach to, the residue
                        # simply stays at the front of the current sentence.

            # Keep a trailing fragment only if it holds something other than
            # space / end-of-sequence tokens.
            has_rem_token = any(elem != space_id and elem != tokenizer.eos_token_id
                                for elem in sentence)
            if has_rem_token:
                sentences.append(sentence)

        return threads_lis

    def get_sentence_wise_preds(self, threads: List[List[List[int]]],
                                      predictions: List[List[str]]) -> List[List[List[str]]]:
        """Splits the prediction corresponding to each thread, into predictions
        for each sentence in the corresponding thread in "threads" list.
        Args:
            threads:      A list of threads, where each thread consists of further
                          lists corresponding to the various sentences in the
                          thread. [As output by self.preprocess()]
            predictions:  A list of predictions for each thread, in the threads
                          list. Each prediction consists of a list of component
                          types corresponding to each token in a thread.
        Returns:
            The predictions list, with each prediction split into predictions
            corresponding to the sentences in the corresponding thread specified
            in the threads list.
        """
        sentence_wise_preds = []
        for thread, thread_preds in zip(threads, predictions):
            per_sentence = []
            next_sentence_beg = 0
            for sentence in thread:
                sentence_end = next_sentence_beg+len(sentence)
                per_sentence.append(thread_preds[next_sentence_beg:sentence_end])
                next_sentence_beg = sentence_end
            sentence_wise_preds.append(per_sentence)
        return sentence_wise_preds

    def update_state(self, pred_sentence: List[str], ref_sentence: List[str]) -> None:
        """Updates the various information maintained for the computation of
        Krippendorff's alpha, based on the predictions(pred_sentence) and
        references(ref_sentence) provided for a particular sentence, in some
        thread.
        """
        self.total_sentences += 1

        pred_claim, ref_claim = 'B-C' in pred_sentence, 'B-C' in ref_sentence
        pred_premise, ref_premise = 'B-P' in pred_sentence, 'B-P' in ref_sentence

        self.pred_has_claim += int(pred_claim)
        self.ref_has_claim += int(ref_claim)
        if pred_claim == ref_claim:
            self.claim_wise_agreement += 1
        else:
            self.claim_wise_disagreement += 1

        self.pred_has_premise += int(pred_premise)
        self.ref_has_premise += int(ref_premise)
        if pred_premise == ref_premise:
            self.premise_wise_agreement += 1
        else:
            self.premise_wise_disagreement += 1

        self.has_both_pred += int(pred_claim and pred_premise)
        self.has_both_ref += int(ref_claim and ref_premise)
        self.has_none_pred += int(not pred_claim and not pred_premise)
        self.has_none_ref += int(not ref_claim and not ref_premise)

    def add_batch(self, predictions: List[List[str]],
                  references: List[List[str]],
                  tokenized_threads: List[List[int]]) -> None:
        """Add a batch of predictions and references for the computation of
        Krippendorff's alpha.
        Args:
            predictions:       A list of predictions for each thread, in the
                               threads list. Each prediction consists of a list
                               of component types corresponding to each token in
                               a thread.
            references:        Same structure as predictions, but consisting of
                               actual gold labels, instead of predicted ones.
            tokenized_threads: A list of all threads in a batch. A thread is a
                               list of integers corresponding to token_ids of the
                               tokens in the thread.
        """
        threads = self.preprocess(tokenized_threads)

        sentence_wise_preds = self.get_sentence_wise_preds(threads, predictions)
        sentence_wise_refs = self.get_sentence_wise_preds(threads, references)

        for pred_thread, ref_thread in zip(sentence_wise_preds, sentence_wise_refs):
            for pred_sentence, ref_sentence in zip(pred_thread, ref_thread):
                self.update_state(pred_sentence, ref_sentence)

    def compute(self, print_additional: bool=True) -> None:
        """Prints out the metric for the batches added till now, and then
        resets all data being maintained by the metric.

        If no sentence has been added yet, the alphas are printed as ``nan``
        instead of raising ZeroDivisionError.

        Args:
            print_additional:   If True, will print all the data being
                                maintained instead of just the Krippendorff's
                                alphas for claims and premises.
        """
        print("Sentence level Krippendorff's alpha for Claims: ",
              self._alpha(self.claim_wise_agreement, self.claim_wise_disagreement))
        print("Sentence level Krippendorff's alpha for Premises: ",
              self._alpha(self.premise_wise_agreement, self.premise_wise_disagreement))

        if print_additional:
            print("Additional attributes: ")
            print("\tTotal Sentences:", self.total_sentences)
            print("\tPrediction setences having claims:", self.pred_has_claim)
            print("\tPrediction sentences having premises:", self.pred_has_premise)
            print("\tReference setences having claims:", self.ref_has_claim)
            print("\tReference sentences having premises:", self.ref_has_premise)
            print("\n")
            print("\tPrediction Sentence having both claim and premise:", self.has_both_pred)
            print("\tPrediction Sentence having neither claim nor premise:", self.has_none_pred)
            print("\tReference Sentence having both claim and premise:", self.has_both_ref)
            print("\tReference Sentence having neither claim nor premise:", self.has_none_ref)
            print("\n")
            # NOTE: the two "agreement" counters printed below also include
            # sentences where NEITHER side has the component (see update_state).
            print("\tSentences having claim in both reference and prediction:", self.claim_wise_agreement)
            print("\tSentences having claim in only one of reference or prediction:", self.claim_wise_disagreement)
            print("\tSentences having premise in both reference and prediction:", self.premise_wise_agreement)
            print("\tSentences having premise in only one of reference or prediction:", self.premise_wise_disagreement)
        self._reset()

In [None]:
# Optional: rebind `metric` to the sentence-level Krippendorff's alpha metric
# defined above, replacing the seqeval metric bound under the same name earlier.
metric = krip_alpha()

### Define & Load Tokenizer, Model, Dataset

In [3]:
import torch
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
device

device(type='cuda', index=0)

#### Load Model/Tokenizer from HF

In [5]:
# Hugging Face hub identifier used by the "Load Model/Tokenizer from HF" cell.
model_version = 'allenai/longformer-base-4096'

In [None]:
%%capture
# Option 1: load tokenizer and encoder weights for `model_version` and move
# the encoder to `device`.
from transformers import LongformerTokenizer, AutoModel
tokenizer = LongformerTokenizer.from_pretrained(model_version)
transformer_model = AutoModel.from_pretrained(model_version).to(device)

#### Or load them from pretrained files...

In [6]:
# Option 2: load tokenizer and encoder from local checkpoint directories.
# Running this cell rebinds `tokenizer` and `transformer_model`, overriding
# whatever the hub-loading cell above assigned.
from transformers import LongformerTokenizer, LongformerModel

tokenizer = LongformerTokenizer.from_pretrained('./4epoch_complete/tokenizer/')
transformer_model = LongformerModel.from_pretrained('./4epoch_complete/model/').to(device)

Some weights of the model checkpoint at ./4epoch_complete/model/ were not used when initializing LongformerModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerModel were not initialized from the model checkpoint at ./4epoch_complete/model/ and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stre

In [7]:
import torch.nn as nn

#### To add extra token type embeddings...

In [9]:
def resize_token_type_embeddings(transformer_model, new_size):
    """Replace the model's token-type embedding table with one of ``new_size`` rows.

    Rows that exist in both the old and the new table keep their old values;
    any additional rows keep ``nn.Embedding``'s fresh initialisation. Shrinking
    (``new_size`` smaller than the current table) keeps the first ``new_size``
    rows instead of failing on a shape mismatch.

    Args:
        transformer_model:  Object exposing ``embeddings.token_type_embeddings``
                            (an ``nn.Embedding``) and a ``device`` attribute.
        new_size:           Number of token types the new table should hold.

    The model is modified in place; nothing is returned. The model's config is
    not touched here.
    """
    old_embeddings = transformer_model.embeddings.token_type_embeddings.weight
    old_size, hidden_dim = old_embeddings.shape

    # Build the replacement with the same dtype as the table it replaces.
    resized = nn.Embedding(new_size, hidden_dim,
                           device=transformer_model.device,
                           dtype=old_embeddings.dtype)

    rows_to_keep = min(old_size, new_size)
    with torch.no_grad():
        resized.weight[:rows_to_keep] = old_embeddings[:rows_to_keep]

    transformer_model.embeddings.token_type_embeddings = resized

# Give the encoder two token-type rows.
resize_token_type_embeddings(transformer_model, 2)

In [None]:
# Keep the config in sync with the resized token-type embedding table.
transformer_model.config.type_vocab_size = 2

#### Load in discourse markers

In [8]:
# Every line of the file becomes one entry, with surrounding whitespace removed.
with open('./Discourse_Markers.txt') as markers_file:
    discourse_markers = list(map(str.strip, markers_file))

In [9]:
%%capture
from arg_mining.datasets.cmv_modes import load_dataset, data_config

#### Add special tokens to tokenizer and model vocab, if not already there

In [None]:
# Register the dataset's special tokens with the tokenizer ...
tokenizer.add_tokens(data_config["special_tokens"])

# ... and resize the encoder's input embedding matrix to the new vocabulary size.
transformer_model.resize_token_embeddings(len(tokenizer))

#### Function to get train, test data (currently `train_sz=50`, `test_sz=50`)

In [10]:
def get_datasets(train_sz=50, test_sz=50):
    """Build fresh train / validation / test datasets.

    Args:
        train_sz:   Forwarded to ``load_dataset`` as ``train_sz``.
        test_sz:    Forwarded to ``load_dataset`` as ``test_sz``.
                    (Presumably both are percentages of the corpus -- confirm
                    against ``arg_mining.datasets.cmv_modes.load_dataset``.)

    Returns:
        The tuple ``(train_dataset, valid_dataset, test_dataset)``.

    Uses the module-level ``tokenizer`` and passes ``discourse_markers`` as the
    tokens to mask. Calling it without arguments keeps the previous 50/50
    setting.
    """
    train_dataset, valid_dataset, test_dataset = load_dataset(tokenizer=tokenizer,
                                                              train_sz=train_sz,
                                                              test_sz=test_sz,
                                                              mask_tokens=discourse_markers)
    return train_dataset, valid_dataset, test_dataset

### Define layers for a Linear-Chain-CRF

In [11]:
from allennlp.modules.conditional_random_field import ConditionalRandomField as crf

# Mapping from component tag name ("B-C", "I-C", "B-P", "I-P", "O") to its
# integer label id.
ac_dict = data_config["arg_components"]

# (from_tag, to_tag) pairs the CRF is allowed to take:
#   * a B-* tag may only be followed by the matching I-* tag,
#   * an I-* tag may continue, or be followed by any B-* tag or by O,
#   * O may be followed by O or by any B-* tag (never directly by an I-* tag).
allowed_transitions =([(ac_dict["B-C"], ac_dict["I-C"]), 
                       (ac_dict["B-P"], ac_dict["I-P"])] + 
                      [(ac_dict["I-C"], ac_dict[ct]) 
                        for ct in ["I-C", "B-C", "B-P", "O"]] +
                      [(ac_dict["I-P"], ac_dict[ct]) 
                        for ct in ["I-P", "B-C", "B-P", "O"]] +
                      [(ac_dict["O"], ac_dict[ct]) 
                        for ct in ["O", "B-C", "B-P"]])
                    
# Projects encoder hidden states to one logit per component tag.
linear_layer = nn.Linear(transformer_model.config.hidden_size,
                         len(ac_dict)).to(device)

# Linear-chain CRF restricted to the transitions above; no separate
# start/end transition parameters.
crf_layer = crf(num_tags=len(ac_dict),
                constraints=allowed_transitions,
                include_start_end_transitions=False).to(device)

# Per-token (reduction='none') cross entropy with log-scaled class weights.
# NOTE(review): the five constants are presumably inverse class frequencies
# listed in label-id order -- confirm against the dataset statistics.
cross_entropy_layer = nn.CrossEntropyLoss(weight=torch.log(torch.tensor([3.3102, 61.4809, 3.6832, 49.6827, 2.5639], 
                                                                        device=device)), reduction='none')

### Global Attention Mask Utility for Longformer

In [12]:
import numpy as np

def get_global_attention_mask(tokenized_threads: np.ndarray) -> np.ndarray:
    """Return a boolean mask that is True at user tokens and False elsewhere.

    User tokens are ``[UNU]`` and ``[USER{i}]`` for ``i`` in
    ``range(data_config["max_users"])``, looked up through the module-level
    ``tokenizer``.

    Args:
        tokenized_threads:  Array of token ids, of any shape.

    Returns:
        A boolean array with the same shape as ``tokenized_threads``.
    """
    # The special token is spelled "[UNU]" (with brackets) elsewhere in this
    # file; the bare string "UNU" is ordinary text and is not that token.
    user_tokens = ["[UNU]"] + [f"[USER{i}]" for i in range(data_config["max_users"])]

    user_token_ids = []
    for user_token in user_tokens:
        # Strip the begin/end-of-sequence ids that encode() adds.
        ids = tokenizer.encode(user_token)[1:-1]
        # Only a token that maps to exactly one id can be matched position by
        # position; anything else would compare the array against a sequence.
        if len(ids) == 1:
            user_token_ids.append(ids[0])

    return np.isin(tokenized_threads, user_token_ids)

### Loss and Prediction Function

In [13]:
from typing import Tuple

In [14]:
def compute(batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
            preds: bool=False, cross_entropy: bool=True):
    """Run the encoder + linear + CRF stack on one batch.

    Args:
        batch:  A 4-tuple ``(tokenized_threads, token_type_ids,
                comp_type_labels, global_attention_mask)``. The callers in this
                file pass tensors of shape [batch_size, seq_len] for each.
                ``token_type_ids`` is accepted for interface compatibility but
                is currently NOT fed to the encoder.

        preds:  If True, returns a List(of batch_size size) of Tuples of form
                (tag_sequence, viterbi_score) where the tag_sequence is the
                viterbi-decoded sequence, for the corresponding sample in the batch.

        cross_entropy:  This argument will only be used if preds=False, i.e., if
                        loss is being calculated. If True, then cross entropy loss
                        will also be added to the output loss.

    Returns:
        Either the predicted sequences with their scores for each element in the
        batch (if preds is True), or a loss value (if preds is False): the
        negated CRF log-likelihood, plus the cross entropy summed over all
        non-pad tokens when ``cross_entropy`` is True.
    """
    tokenized_threads, _token_type_ids, comp_type_labels, global_attention_mask = batch

    # 1 at real tokens, 0 at padding; used as attention mask, CRF mask and
    # cross-entropy mask.
    pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0)

    encoder_output = transformer_model(input_ids=tokenized_threads,
                                       attention_mask=pad_mask,
                                       global_attention_mask=global_attention_mask)
    logits = linear_layer(encoder_output.last_hidden_state)

    if preds:
        return crf_layer.viterbi_tags(logits, pad_mask)

    log_likelihood = crf_layer(logits, comp_type_labels, pad_mask)

    if not cross_entropy:
        return -log_likelihood

    # Flatten batch and sequence axes so every token is one CE sample.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_mask = pad_mask.reshape(-1)
    flat_labels = comp_type_labels.reshape(-1)

    ce_loss = torch.sum(flat_mask*cross_entropy_layer(flat_logits, flat_labels))

    return ce_loss - log_likelihood

### Define optimizer

In [15]:
from itertools import chain

import torch.optim as optim

# A single Adam optimizer, with one shared learning rate, over the encoder,
# the tag projection layer and the CRF.
trainable_params = [
    *transformer_model.parameters(),
    *linear_layer.parameters(),
    *crf_layer.parameters(),
]
optimizer = optim.Adam(trainable_params, lr=2e-5)

### Training And Evaluation Loops

In [16]:
def train(dataset, accumulate_over=4):
    """Run one pass over ``dataset`` with gradient accumulation.

    Args:
        dataset:          Iterable of ``(tokenized_threads, masked_threads,
                          comp_type_labels, _)`` numpy batches whose first axis
                          is a device axis of size 1 (it is squeezed away here).
        accumulate_over:  Number of batches whose gradients are accumulated
                          before each optimizer step.

    The optimizer is stepped every ``accumulate_over`` batches, and once more
    at the end only if there are gradients that have not been applied yet.
    """
    optimizer.zero_grad()
    has_pending_grads = False

    for i, (tokenized_threads, masked_threads, comp_type_labels, _ ) in enumerate(dataset):
        global_attention_mask = torch.tensor(get_global_attention_mask(tokenized_threads),
                                             device=device, dtype=torch.int32)

        #Remove Device Axis and cast to PyTorch tensor
        tokenized_threads = torch.tensor(np.squeeze(tokenized_threads, axis=0),
                                         device=device)
        masked_threads = torch.tensor(np.squeeze(masked_threads, axis=0),
                                      device=device)
        comp_type_labels = torch.tensor(np.squeeze(comp_type_labels, axis=0),
                                        device=device, dtype=torch.long)

        global_attention_mask = torch.squeeze(global_attention_mask, dim=0)

        # Second tuple element: 1 where the masked copy holds the mask token.
        loss = compute((tokenized_threads,
                        torch.where(masked_threads==tokenizer.mask_token_id, 1, 0),
                        comp_type_labels,
                        global_attention_mask))/data_config["batch_size"]

        print("Loss: ", loss)
        loss.backward()
        has_pending_grads = True

        if i%accumulate_over==accumulate_over-1:
            optimizer.step()
            optimizer.zero_grad()
            has_pending_grads = False

    # Flush a partially filled accumulation window. Stepping unconditionally
    # here would step on already-cleared gradients whenever the number of
    # batches is a multiple of `accumulate_over` (or the dataset is empty).
    if has_pending_grads:
        optimizer.step()
        optimizer.zero_grad()

In [17]:
def evaluate(dataset, metric):
    """Decode every batch of ``dataset`` and report ``metric``.

    Args:
        dataset:  Iterable of ``(tokenized_threads, masked_threads,
                  comp_type_labels, _)`` numpy batches whose first axis is a
                  device axis of size 1 (it is squeezed away here).
        metric:   Object with ``add_batch(predictions=..., references=...)``
                  and ``compute()``. If its ``add_batch`` also declares a
                  ``tokenized_threads`` parameter (as ``krip_alpha`` does), the
                  token ids of each batch are passed along as well.

    Prints whatever ``metric.compute()`` returns.
    """
    import inspect

    int_to_labels = {v:k for k, v in ac_dict.items()}

    # Decide once whether this metric needs the token ids to split sentences.
    wants_token_ids = "tokenized_threads" in inspect.signature(metric.add_batch).parameters

    # Pure inference: no autograd graph is needed for Viterbi decoding.
    with torch.no_grad():
        for tokenized_threads, masked_threads, comp_type_labels, _ in dataset:

            global_attention_mask = torch.tensor(get_global_attention_mask(tokenized_threads),
                                                 device=device)

            #Remove Device Axis and cast to PyTorch tensor
            tokenized_threads = torch.tensor(np.squeeze(tokenized_threads, axis=0),
                                             device=device)
            masked_threads = torch.tensor(np.squeeze(masked_threads, axis=0),
                                          device=device)
            comp_type_labels = torch.tensor(np.squeeze(comp_type_labels, axis=0),
                                            device=device)
            global_attention_mask = torch.squeeze(global_attention_mask, dim=0)

            preds = compute((tokenized_threads,
                             torch.where(masked_threads==tokenizer.mask_token_id, 1, 0),
                             comp_type_labels,
                             global_attention_mask),
                            preds=True)

            # Number of non-pad tokens in each thread, as plain ints.
            lengths = torch.sum(torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0),
                                axis=-1).tolist()

            # Each element of `preds` is (tag_sequence, viterbi_score).
            preds = [ [int_to_labels[tag] for tag in tagged[0][:lengths[i]]]
                      for i, tagged in enumerate(preds)
                    ]

            refs = [ [int_to_labels[ref] for ref in labels[:lengths[i]]]
                     for i, labels in enumerate(comp_type_labels.cpu().tolist())
                   ]

            if wants_token_ids:
                metric.add_batch(predictions=preds,
                                 references=refs,
                                 tokenized_threads=tokenized_threads.cpu().tolist())
            else:
                metric.add_batch(predictions=preds,
                                 references=refs)

    print(metric.compute())

### Final Training

In [18]:
# Number of train + evaluate rounds performed by the loop below.
n_epochs = 35

In [19]:
# One epoch = rebuild the datasets, one training pass, one evaluation pass on
# the test split. The validation split returned by get_datasets() is unused.
for epoch in range(n_epochs):
    print(f"------------EPOCH {epoch+1}---------------")
    # Datasets are re-created at the start of every epoch.
    train_dataset, _, test_dataset = get_datasets()
    train(train_dataset)
    evaluate(test_dataset, metric)

------------EPOCH 1---------------
Loss:  tensor(2662.7983, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1993.1859, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1798.4768, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2255.1416, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(3383.7368, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2618.2593, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(3064.5889, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(3202.7256, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2894.2588, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2760.7642, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2424.2334, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2671.1470, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(3606.2119, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2277.6997, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2838.8088, device='cuda:0',

### Rough -- Checking dataset

In [None]:
# Scratch check: split() + join collapses runs of whitespace and trims the ends.
" ".join(" mY name is ".split())

'mY name is'

In [None]:
def get_datasets(train_sz=80, test_sz=20):
    """Build fresh train / validation / test datasets (scratch-section variant).

    Redefines ``get_datasets`` for the dataset checks below; calling it
    without arguments keeps this cell's previous 80/20 setting.

    Args:
        train_sz:   Forwarded to ``load_dataset`` as ``train_sz``.
        test_sz:    Forwarded to ``load_dataset`` as ``test_sz``.
                    (Presumably both are percentages of the corpus -- confirm
                    against ``arg_mining.datasets.cmv_modes.load_dataset``.)

    Returns:
        The tuple ``(train_dataset, valid_dataset, test_dataset)``.
    """
    train_dataset, valid_dataset, test_dataset = load_dataset(tokenizer=tokenizer,
                                                              train_sz=train_sz,
                                                              test_sz=test_sz,
                                                              mask_tokens=discourse_markers)
    return train_dataset, valid_dataset, test_dataset

In [None]:
# Build the splits for the manual inspection below; the validation split is discarded.
train_dataset, _, test_dataset = get_datasets()

In [None]:
# Debug helper: for the first thread of the first test batch, print its labels
# and tokens and then every labelled span together with its masked counterpart.
for tokenized_threads, masked_threads, comp_type_labels, _ in test_dataset:
    # Drop the leading (device) axis of each array.
    tokenized_threads, masked_threads, comp_type_labels = tokenized_threads[0], masked_threads[0], comp_type_labels[0]
    for tokenized_thread, masked_thread, comp_type_label in zip(tokenized_threads, masked_threads, comp_type_labels):
        print(comp_type_label[:100])
        print(tokenized_thread[:100])
        print(tokenizer.decode(tokenized_thread[:500]))
        # [start, end] is the inclusive token span of the segment being
        # collected; prev_type is that segment's type ("other" = outside any
        # component).
        start, end = 0, 0
        prev_type = "other"
        i = 0
        while i<tokenized_thread.shape[0]:
            if comp_type_label[i]==ac_dict["O"]:
                if prev_type=="other":
                    # Still outside a component: extend the current segment.
                    end += 1
                else:
                    # An O tag right after a component: print the finished
                    # component and start a new "other" segment at i.
                    print("Component: ", tokenizer.decode(tokenized_thread[start:end+1]), " of type: ", prev_type, tokenized_thread[start:end+1])
                    print("Masked Component: ", tokenizer.decode(masked_thread[start:end+1]), " of type: ", prev_type, masked_thread[start:end+1])
                    start = i
                    end = i
                    prev_type="other"

            if comp_type_label[i] in [ac_dict["B-C"], ac_dict["B-P"]]:
                # A B-* tag always closes whatever segment came before it and
                # opens a new claim/premise segment at i.
                print("Component: ", tokenizer.decode(tokenized_thread[start:end+1]), " of type: ", prev_type, tokenized_thread[start:end+1])
                print("Masked Component: ", tokenizer.decode(masked_thread[start:end+1]), " of type: ", prev_type, masked_thread[start:end+1])
                start = i
                end = i
                prev_type = "Claim" if comp_type_label[i]==ac_dict["B-C"] else "Premise"

            if comp_type_label[i] in [ac_dict["I-C"], ac_dict["I-P"]]:
                # Inside a component: extend the current segment.
                end += 1

            i+=1
        # Only the first thread of the first batch is inspected. The segment
        # still open when the loop ends is not printed.
        break
    break

[0 1 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 2
 2 2 2 2 0 0 0 0 1 2 2 2 2 2 2 2 2 2 2 2 2 3 4 4 4 4]
[    0 18814   846    35  7978     9  1901    16   145   551   350   444
 50270 50268 50268  1121     5    94   367   688    52   348    56    80
  1307  1061  1369    11     5   232     6   258     9    61    58  1726
    30  3510  8941     7    22  3519     9  1901     4    22    20    78
   145     5 11597     9  6366    81    20 21902     6     8   452     5
  1094    23     5  4088     9    10 33937  4320    11  2201     4 50268
   100  1819   923    84   481  1901    53     7   162  1437  8585    16
    10   699   516   227 20203   110    78  8322   235    36  1437    22
   270  1284 29384   328]
<s>CMV: Freedom of speech is being taken too far [USER0] [NEWLINE] [NEWLINE] In the last few weeks we've had two huge events happen in the world, both of which were caused by matters rel

In [None]:
import re
# Scratch check: drop whitespace before a closing </claim> tag and insert a
# single space after it when it is directly followed by a non-space character.
re.sub(r"\s*</claim>([^\s])", r"</claim> \1", "<claim>my name is </claim>jeevesh.")

'<claim>my name is</claim> jeevesh.'