In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [3]:
%cd drive/MyDrive/Bert_Lime

/content/drive/MyDrive/Bert_Lime


In [3]:
!pip install --upgrade google-cloud-storage



In [4]:
%%capture
#if running on colab, install below 4
!git clone https://github.com/Jeevesh8/arg_mining
!git clone https://github.com/chridey/change-my-view-modes
!pip install transformers
!pip install seqeval datasets allennlp
!pip install flax
!pip install sentencepiece
#if connected to local runtime, run the next command too
# !pip install bs4 tensorflow torch 

In [5]:
!pip install bertviz



In [6]:
#Run to ignore warnings
import warnings
import numpy as np
warnings.filterwarnings('ignore')
import pickle

### Load Metric

In [7]:
%%capture
from datasets import load_metric
metric = load_metric('seqeval')

In [None]:
!pip install wandb



In [None]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33manunay[0m (use `wandb login --relogin` to force relogin)


### Define & Load Tokenizer, Model, Dataset

In [8]:
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
device

device(type='cuda', index=0)

In [10]:
model_version = 'bert-base-cased'

In [11]:
%%capture
from transformers import AutoTokenizer, AutoModel
# Load tokenizer and encoder from the same checkpoint name so that changing
# `model_version` can never leave the two out of sync.
tokenizer = AutoTokenizer.from_pretrained(model_version,
                                          bos_token = "[CLS]",
                                          eos_token = "[SEP]")
# Hidden states and attention maps are requested because later cells read them.
transformer_model = AutoModel.from_pretrained(model_version,output_hidden_states=True, output_attentions = True)

In [12]:
transformer_model = transformer_model.to(device)

In [13]:
import torch.nn as nn

In [None]:
import wandb
wandb.init(project="bert_explain", entity="anunay")

<IPython.core.display.Javascript object>

In [None]:
# Update the active run's config in place. Rebinding `wandb.config` to a plain
# dict would replace the config object that `wandb.init` created.
# NOTE(review): these values are recorded for the run only; no other cell reads
# them back - confirm they match the settings actually used for training.
wandb.config.update({
  "learning_rate": 2e-5,
  "epochs": 35,
  "batch_size": 2,
  "num_devices": 1,
  "max_len": 4096,
  "max_comps": 128,
  "omit_filenames": True 
})

In [None]:
wandb.watch(transformer_model, log='all', log_freq=50)

[]

#### To add extra token type embeddings...

In [14]:
def resize_token_type_embeddings(transformer_model, new_size):
    """Replace the model's token-type embedding table with one of ``new_size`` rows.

    Args:
        transformer_model: Object exposing ``embeddings.token_type_embeddings``
                           (an ``nn.Embedding``) and a ``device`` attribute.
        new_size:          Number of rows in the new embedding table.

    The first ``min(old_size, new_size)`` rows of the previous table are copied
    into the new one; any additional rows keep their fresh initialisation.
    The model is modified in place and nothing is returned.
    """
    old_embeddings = transformer_model.embeddings.token_type_embeddings.weight
    old_size, hidden_dim = old_embeddings.shape
    # When shrinking, only the rows that still fit can be carried over.
    rows_to_copy = min(old_size, new_size)
    new_embeddings = nn.Embedding(new_size, hidden_dim,
                                  device=transformer_model.device,
                                  dtype=old_embeddings.dtype)
    with torch.no_grad():
        new_embeddings.weight[:rows_to_copy] = old_embeddings[:rows_to_copy]
    transformer_model.embeddings.token_type_embeddings = new_embeddings

#resize_token_type_embeddings(transformer_model, 2)
#transformer_model.config.type_vocab_size = 2

In [15]:
%cd Argument\ Mining\ BTP
%ls

[Errno 2] No such file or directory: 'Argument Mining BTP'
/content/drive/MyDrive/Bert_Lime
[0m[01;34marg_mining[0m/                        Drinventor_linear_layer.pt
[01;34mchange-my-view-modes[0m/              Drinventor_tokenizer_pre.pkl
[01;34mcompiled_corpus[0m/                   Drinventor_transformer_layer.pt
compiled_corpus.zip                layer_wise_analysis.pkl
crf_layer.pkl                      linear_layer.pt
cross_entropy_layer.pt             [01;34mModel[0m/
data_dict.pkl                      [01;34mnaacl18-multitask_argument_mining[0m/
data_runner.pkl                    [01;34mtemp[0m/
Discourse_Markers.txt              tokenizer_pre.pkl
Drinventor_crf_layer.pkl           transformer_layer.pt
Drinventor_cross_entropy_layer.pt  [01;34mwandb[0m/


#### Load the discourse markers (provide ``Discourse_Markers.txt``)

In [16]:
# One discourse marker per line; surrounding whitespace is stripped.
with open('./Discourse_Markers.txt') as f:
    discourse_markers = list(map(str.strip, f))

* Change the ``batch_size`` in ``arg_mining/datasets/cmv_modes/configs.py`` before running below cell, as needed. [By default: 8]

* Can also change ``max_len`` in the same file to suit the maximum length of your model. All threads will be truncated at ``max_len`` length. 

In [17]:
%%capture
from arg_mining.datasets.cmv_modes import load_dataset, data_config

In [18]:
tokenizer.add_tokens(data_config["special_tokens"], special_tokens=True)

transformer_model.resize_token_embeddings(len(tokenizer))

Embedding(29011, 768)

### Function to get datasets
* Change split sizes, if needed.

In [19]:
def get_datasets():
    """Load the three dataset splits through ``load_dataset``.

    ``load_dataset`` is called with the notebook-level ``tokenizer``,
    ``train_sz=50``, ``test_sz=50`` and the discourse markers as ``mask_tokens``.

    Returns:
        The tuple ``(train_dataset, valid_dataset, test_dataset)``.
    """
    splits = load_dataset(
        tokenizer=tokenizer,
        train_sz=50,
        test_sz=50,
        mask_tokens=discourse_markers,
    )
    train_split, valid_split, test_split = splits
    return train_split, valid_split, test_split

### Wrap the dataset in ``get_comment_wise_dataset`` if you want a comment-wise dataset

In [20]:
from typing import Generator, List, Tuple

In [21]:
def split_encoding(tokenized_thread: List[int], 
                   split_on: List[int], 
                   eos_token_id: int) -> List[List[int]]:
    """Cut ``tokenized_thread`` into consecutive chunks.

    A new chunk is opened at every token id found in ``split_on``; that token
    becomes the first element of the chunk it opens. Scanning stops right
    after the first ``eos_token_id``, which is kept as the last element of the
    final chunk. Tokens after it are discarded.

    The first chunk holds whatever precedes the first split token and is
    therefore empty when the thread starts with one (or is empty itself).
    """
    current = []
    chunks = [current]
    for token_id in tokenized_thread:
        if token_id in split_on:
            # Start a fresh chunk; the split token itself belongs to it.
            current = []
            chunks.append(current)
        current.append(token_id)
        if token_id == eos_token_id:
            return chunks
    return chunks

def pad_batch(elems: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Right-pad every list in ``elems`` with ``pad_token_id`` to a common length.

    The common length is that of the longest list in ``elems``. New lists are
    returned; the inputs are not modified. An empty ``elems`` yields ``[]``.
    """
    # default=0 keeps an empty batch from raising inside max().
    max_len = max((len(elem) for elem in elems), default=0)
    return [elem+[pad_token_id]*(max_len-len(elem)) for elem in elems]

def get_comment_wise_dataset(dataset,
                             max_len: int=512,
                             batch_size: int=8) -> Generator[Tuple[List[List[int]],
                                                                   List[List[int]],
                                                                   List[List[int]]],
                                                             None, None]:
    """
    Splits every thread of ``dataset`` into comments and yields them in batches.

    Args:
        dataset:     A numpy iterator dataset for threads, as returned from 
                     get_datasets() function above.
        max_len:     Maximum length at which to truncate any comment.
        batch_size:  Number of comments in a batch
    
    Yields:
        Tuples of batched & padded (to max. length in batch) tokenized threads, 
        masked threads, and component type labels; where each element corresponds
        to a comment in some thread. The last batch may hold fewer than
        ``batch_size`` comments, so that no comment is dropped.
    
    NOTE:
        This function removes the extra num_devices dimension from the elements 
        of dataset provided. The whole dataset is read before the first batch
        is yielded.
    """
    # Token ids that open a new comment: "[UNU]" and one "[USER{i}]" per user.
    # The [1:-1] slice drops the first and last ids produced by encode().
    user_token_indices = tokenizer.encode("[UNU]"+"".join([f"[USER{i}]" for i in range(data_config["max_users"])]))[1:-1]

    # One (tokens, masked tokens, labels) triple per comment, already truncated.
    comments = []
    for (tokenized_threads, masked_threads, comp_type_labels, _ ) in dataset:
        tokenized_threads = np.squeeze(tokenized_threads, axis=0).tolist()
        masked_threads = np.squeeze(masked_threads, axis=0).tolist()
        comp_type_labels = np.squeeze(comp_type_labels, axis=0).tolist()

        for tokenized_thread, masked_thread, comp_type_label in zip(tokenized_threads, masked_threads, comp_type_labels):
            for elem in split_encoding(tokenized_thread, user_token_indices, tokenizer.eos_token_id):
                # The masked thread and the labels are aligned with the tokens,
                # so they are consumed in chunks of the same length.
                comments.append((elem[:max_len],
                                 masked_thread[:len(elem)][:max_len],
                                 comp_type_label[:len(elem)][:max_len]))
                masked_thread, comp_type_label = masked_thread[len(elem):], comp_type_label[len(elem):]

    for start in range(0, len(comments), batch_size):
        batch = comments[start:start+batch_size]
        yield (pad_batch([tokens for tokens, _, _ in batch], tokenizer.pad_token_id),
               pad_batch([masked for _, masked, _ in batch], tokenizer.pad_token_id),
               pad_batch([labels for _, _, labels in batch], data_config["pad_for"]["comp_type_labels"]))

In [22]:
from typing import Generator

In [23]:
import random

In [24]:
l = [1,2,3]
print(random.shuffle(l))
print(l)

None
[2, 1, 3]


In [25]:
def get_masked_data_lists(dataset,
                          left: bool=True) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Builds one-sided masked copies of every thread in ``dataset``.

    Args:
        dataset:    A python generator that yields tuples of tokenized_threads,
                    masked_threads and component type labels, each a list of
                    per-thread lists (the per-thread items are sliced and
                    concatenated as Python lists).
        left:       If true, left side of components are masked. Otherwise right
                    side is masked.
    Returns:
        A tuple of two lists consisting of samples from entire dataset:
            final_threads:  A list of lists of int. Where each internal list corresponds
                            to a thread masked on one side.
            final_labels:   A list of lists of int. Where each internal list corresponds
                            to a masked component type labels.
    NOTE:
        A left masked sample consists of a tokenized thread whose all tokens before
        the beginning of some argumentative component (a "B-C"/"B-P" label) are
        [MASK] and the corresponding component type labels are 0.
        A right masked sample keeps everything up to and including the last
        "I-C"/"I-P" token of a component and replaces every later position
        (whatever it held) with [MASK] / label 0. A component whose final token
        is not labelled "I-C"/"I-P" produces no right masked sample, and neither
        does a component that ends on the very last position.
    """
    final_threads, final_labels = [], []
    arg_components = data_config["arg_components"]
    begin_labels = (arg_components["B-C"], arg_components["B-P"])
    inside_labels = (arg_components["I-C"], arg_components["I-P"])
    mask_id = tokenizer.mask_token_id

    for (tokenized_threads, _masked_threads, comp_type_labels) in dataset:
        for (tokenized_thread, comp_type_label) in zip(tokenized_threads, comp_type_labels):
            if left:
                # One sample per component start: mask the i tokens before it.
                for i, (_token, label) in enumerate(zip(tokenized_thread, comp_type_label)):
                    if label in begin_labels:
                        final_threads.append([mask_id]*i + tokenized_thread[i:])
                        final_labels.append([0]*i + comp_type_label[i:])
            else:
                # Walk the thread backwards so that "everything after the end
                # of a component" becomes a prefix of length i.
                reversed_thread = tokenized_thread[::-1]
                reversed_labels = comp_type_label[::-1]
                follower_is_outside = False
                for i, (_token, label) in enumerate(zip(reversed_thread, reversed_labels)):
                    is_inside = label in inside_labels
                    # An inside token whose successor (in reading order) is
                    # not an inside token marks the end of a component.
                    if is_inside and follower_is_outside:
                        final_threads.append(([mask_id]*i + reversed_thread[i:])[::-1])
                        final_labels.append(([0]*i + reversed_labels[i:])[::-1])
                    follower_is_outside = not is_inside

    return final_threads, final_labels

def get_masked_dataset(dataset,
                       left:bool = True,
                       shuffle:bool = True,
                       batch_size: int = 10) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
    """
    Batches the one-sided masked samples produced by get_masked_data_lists().

    Args:
        dataset:    Same as in get_masked_data_lists()
        left:       Same as in get_masked_data_lists()
        shuffle:    Whether to shuffle around the elements corresponding to masking
                    of various threads. If True, batch will consist of random left/right
                    masked samples from different tokenized_threads, rather than same one.
        batch_size: Number of elements to put in a batch. 
    
    Yields:
        A batch consisting of a tuple of np.array's corresponding to left/right masked
        tokenized_threads, and comp_type_labels. Both are padded to the longest
        thread of the batch (threads with the tokenizer's pad token id, labels
        with ``data_config["pad_for"]["comp_type_labels"]``). The last batch may
        hold fewer than ``batch_size`` samples.
    """
    masked_threads, labels_for_masked_threads = get_masked_data_lists(dataset, left)
    samples = list(zip(masked_threads, labels_for_masked_threads))
    if shuffle:
        random.shuffle(samples)

    thread_pad = tokenizer.pad_token_id
    label_pad = data_config["pad_for"]["comp_type_labels"]

    for start in range(0, len(samples), batch_size):
        batch = samples[start:start+batch_size]
        max_len = max(len(thread) for thread, _ in batch)
        # Each padding amount comes from the sequence's own length, and new
        # lists are built so the underlying samples are left untouched.
        batch_threads = [thread + [thread_pad]*(max_len-len(thread)) for thread, _ in batch]
        batch_labels = [label + [label_pad]*(max_len-len(label)) for _, label in batch]
        yield np.array(batch_threads), np.array(batch_labels)

### Sample Run for dataset

In [None]:
"""
train_dataset, valid_dataset, test_dataset = get_datasets()
for (tokenized_threads, masked_threads, comp_type_labels) in get_comment_wise_dataset(train_dataset):
    print(len(tokenized_threads[0]))
    print(tokenizer.batch_decode(tokenized_threads))
    print(tokenizer.batch_decode(masked_threads))
    print(comp_type_labels)
    break
"""

'\ntrain_dataset, valid_dataset, test_dataset = get_datasets()\nfor (tokenized_threads, masked_threads, comp_type_labels) in get_comment_wise_dataset(train_dataset):\n    print(len(tokenized_threads[0]))\n    print(tokenizer.batch_decode(tokenized_threads))\n    print(tokenizer.batch_decode(masked_threads))\n    print(comp_type_labels)\n    break\n'

In [26]:
from allennlp.modules.conditional_random_field import ConditionalRandomField as crf

### Define layers for a Linear-Chain-CRF

In [27]:

from allennlp.modules.conditional_random_field import ConditionalRandomField as crf

# Mapping from component tag name ("B-C", "I-C", "B-P", "I-P", "O") to tag id.
ac_dict = data_config["arg_components"]

# (from_tag, to_tag) id pairs the CRF is allowed to take: a B-* tag may only be
# followed by its own I-* tag; I-* tags and "O" may be followed by the tags
# listed next to them.
allowed_transitions = [
    (ac_dict[source], ac_dict[target])
    for source, targets in (
        ("B-C", ("I-C",)),
        ("B-P", ("I-P",)),
        ("I-C", ("I-C", "B-C", "B-P", "O")),
        ("I-P", ("I-P", "B-C", "B-P", "O")),
        ("O", ("O", "B-C", "B-P")),
    )
    for target in targets
]

# Projects each encoder hidden state to one score per tag.
linear_layer = nn.Linear(in_features=transformer_model.config.hidden_size,
                         out_features=len(ac_dict))
linear_layer = linear_layer.to(device)

crf_layer = crf(num_tags=len(ac_dict),
                constraints=allowed_transitions,
                include_start_end_transitions=False)
crf_layer = crf_layer.to(device)

# Per-token loss (reduction='none') with log-scaled per-class weights.
# NOTE(review): the five raw weights are hard-coded; their origin is not shown
# in this notebook - presumably inverse class frequencies, confirm.
class_weights = torch.log(torch.tensor([3.3102, 61.4809, 3.6832, 49.6827, 2.5639],
                                       device=device))
cross_entropy_layer = nn.CrossEntropyLoss(weight=class_weights, reduction='none')

In [28]:
linear_path = "linear_layer.pt"
cross_path = "cross_entropy_layer.pt"
crf_path = "crf_layer.pkl"
tokenizer_path = "tokenizer_pre.pkl"
transformer_path = "transformer_layer.pt"

In [29]:
import pickle

### Loss and Prediction Function

In [30]:
def compute(batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            preds: bool=False, cross_entropy: bool=True):
    """
    Runs encoder + linear layer + CRF on one batch.

    Args:
        batch:  A tuple ``(tokenized_threads, token_type_ids, comp_type_labels)``.
                The first and last are expected to be of shape
                [batch_size, seq_len]. The middle element is unpacked but not
                used by this function.
        
        preds:  If True, the viterbi-decoded output of the CRF layer is returned
                instead of a loss: a List (of batch_size size) of Tuples of form
                (tag_sequence, viterbi_score).
        
        cross_entropy:  Only used if preds=False. If True, the padding-masked,
                        class-weighted cross entropy summed over all tokens is
                        added to the CRF loss.
    
    Returns:
        Either the predicted sequences with their scores for each element in the batch
        (if preds is True), or the loss value summed over all elements of the batch
        (if preds is False).
    """
    input_ids, _token_type_ids, gold_labels = batch

    # 1 for real tokens, 0 for padding; used as attention mask and CRF mask.
    keep_mask = torch.where(input_ids!=tokenizer.pad_token_id, 1, 0)

    encoder_output = transformer_model(input_ids=input_ids,
                                       attention_mask=keep_mask,)
    emissions = linear_layer(encoder_output.last_hidden_state)

    if preds:
        return crf_layer.viterbi_tags(emissions, keep_mask)

    log_likelihood = crf_layer(emissions, gold_labels, keep_mask)

    if not cross_entropy:
        return -log_likelihood

    # Flatten to one row per token so the per-token loss can be masked.
    flat_emissions = emissions.reshape(-1, emissions.shape[-1])
    flat_mask, flat_labels = keep_mask.reshape(-1), gold_labels.reshape(-1)
    ce_loss = torch.sum(flat_mask*cross_entropy_layer(flat_emissions, flat_labels))

    return ce_loss - log_likelihood

### Define optimizer

In [31]:
from itertools import chain

import torch.optim as optim

optimizer = optim.Adam(params = chain(transformer_model.parameters(),
                                      linear_layer.parameters(),
                                      crf_layer.parameters()),
                       lr = 2e-5,)

### Training And Evaluation Loops

In [32]:
values_weight = [[] for i in range(5)]
values_bias = [[] for i in range(5)]
def train_left_right(dataset, left: bool=True):
    """Runs one pass of training on one-sided masked samples.

    Args:
        dataset:  A thread dataset as returned by get_datasets(). It is first
                  split into comments with get_comment_wise_dataset() and then
                  turned into masked batches with get_masked_dataset().
        left:     Forwarded to get_masked_dataset() to choose which side of
                  the components is masked.

    Gradients are accumulated over ``accumulate_over`` batches before each
    optimizer step; a final step is taken only if some batches are still
    pending when the data runs out.

    NOTE(review): the loss is divided by ``data_config["batch_size"]`` although
    the batches come from get_masked_dataset() with its own default batch
    size - confirm this scaling is intended.
    """
    accumulate_over = 4
    pending_batches = 0

    optimizer.zero_grad()

    masked_batches = get_masked_dataset(get_comment_wise_dataset(dataset), left=left)
    for tokenized_threads, comp_type_labels in masked_batches:

        #Cast to PyTorch tensor
        tokenized_threads = torch.tensor(tokenized_threads, device=device)

        # Pad the label rows to a common length with 0 so they can be stacked
        # even if they arrive with differing lengths.
        max_size = max(len(labels) for labels in comp_type_labels)
        padded_labels = [[int(label) for label in labels] + [0]*(max_size - len(labels))
                         for labels in comp_type_labels]
        comp_type_labels = torch.tensor(padded_labels, device=device, dtype=torch.int64)

        loss = compute((tokenized_threads,
                        torch.where(tokenized_threads==tokenizer.mask_token_id, 1, 0), 
                        comp_type_labels,))/data_config["batch_size"]

        print("Loss: ", loss)

        loss.backward()
        pending_batches += 1

        if pending_batches == accumulate_over:
            optimizer.step()
            optimizer.zero_grad()
            pending_batches = 0

    # Flush whatever is left over; skip the step when nothing is pending so
    # the optimizer is never stepped on freshly zeroed gradients.
    if pending_batches:
        optimizer.step()

In [33]:
def evaluate(dataset, metric):
    """Decodes every comment-wise batch of ``dataset`` and prints the metric.

    Args:
        dataset:  A thread dataset; it is batched through
                  get_comment_wise_dataset().
        metric:   Object with ``add_batch(predictions=..., references=...)``
                  and ``compute()``; predictions and references are passed as
                  lists of tag-name sequences.

    For each batch the viterbi tags from compute(..., preds=True) and the gold
    labels are cut to the number of non-padding tokens and converted from tag
    ids to tag names. The result of ``metric.compute()`` is printed; nothing
    is returned.
    """
    id_to_tag = {tag_id: tag_name for tag_name, tag_id in ac_dict.items()}
    print('ENTER')

    with torch.no_grad():
        for batch in get_comment_wise_dataset(dataset):
            #Cast to PyTorch tensor
            tokenized_threads, masked_threads, comp_type_labels = (
                torch.tensor(part, device=device) for part in batch
            )

            mask_positions = torch.where(masked_threads==tokenizer.mask_token_id, 1, 0)
            decoded = compute((tokenized_threads, mask_positions, comp_type_labels,),
                              preds=True)

            # Number of non-padding tokens in each row.
            is_token = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0)
            lengths = torch.sum(is_token, axis=-1)

            predictions = []
            for row, decoded_row in enumerate(decoded):
                # decoded_row[0] is the tag id sequence of this row.
                tag_ids = decoded_row[0][:lengths[row]]
                predictions.append([id_to_tag[tag_id] for tag_id in tag_ids])

            references = []
            for row, gold_row in enumerate(comp_type_labels.cpu().tolist()):
                references.append([id_to_tag[tag_id] for tag_id in gold_row[:lengths[row]]])

            metric.add_batch(predictions=predictions,
                             references=references,)

    print(metric.compute())

In [34]:

def train(dataset):
    """Runs one pass of training over the comment-wise batches of ``dataset``.

    Args:
        dataset:  A thread dataset as returned by get_datasets(); it is
                  batched through get_comment_wise_dataset().

    The loss from compute() is divided by ``data_config["batch_size"]``,
    printed, logged to wandb as ``train_loss`` and back-propagated. Gradients
    are accumulated over ``accumulate_over`` batches before each optimizer
    step; a final step is taken only if some batches are still pending when
    the data runs out.
    """
    accumulate_over = 4
    pending_batches = 0

    optimizer.zero_grad()

    for tokenized_threads, masked_threads, comp_type_labels in get_comment_wise_dataset(dataset):

        #Cast to PyTorch tensor
        tokenized_threads = torch.tensor(tokenized_threads, device=device)
        masked_threads = torch.tensor(masked_threads, device=device)
        comp_type_labels = torch.tensor(comp_type_labels, device=device, dtype=torch.long)

        loss = compute((tokenized_threads,
                        torch.where(masked_threads==tokenizer.mask_token_id, 1, 0), 
                        comp_type_labels,))/data_config["batch_size"]

        print("Loss: ", loss)

        wandb.log({'train_loss': loss.item() })
        loss.backward()
        pending_batches += 1

        if pending_batches == accumulate_over:
            optimizer.step()
            optimizer.zero_grad()
            pending_batches = 0

    # Flush whatever is left over; skip the step when nothing is pending so
    # the optimizer is never stepped on freshly zeroed gradients.
    if pending_batches:
        optimizer.step()

### Final Training

In [None]:

for name, param in transformer_model.named_parameters():
  # if("encoder.layer.11" in name):  
  print(name, param.shape)

In [None]:
for name, param in linear_layer.named_parameters():
  # if("encoder.layer.11" in name):  
  print(name, param.shape)

weight torch.Size([5, 768])
bias torch.Size([5])


In [None]:
print(tokenizer.mask_token_id)

103


In [None]:
# Main training loop: one training pass followed by one evaluation per epoch.
n_epochs = 35
for epoch in range(n_epochs):
    print(f"------------EPOCH {epoch+1}---------------")
    # The splits are rebuilt at the start of every epoch.
    # NOTE(review): presumably because the dataset iterators are single-pass - confirm.
    train_dataset, _, test_dataset = get_datasets()
    train(train_dataset)
    # Evaluation uses the test split (the validation split is discarded above)
    # and reuses the same `metric` object each epoch.
    evaluate(test_dataset, metric)

------------EPOCH 1---------------
Loss:  tensor(1034.2817, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(887.2109, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2408.0747, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1494.1326, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1873.7246, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2329.7798, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1234.6335, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(3179.7261, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2652.5676, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2226.3198, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2417.5398, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(2614.9922, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1398.1238, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1732.7836, device='cuda:0', grad_fn=<DivBackward0>)
Loss:  tensor(1246.1193, device='cuda:0', 

In [None]:
# linear_path = "Roberta_linear_layer.pt"
# cross_path = "Roberta_cross_entropy_layer.pt"
# crf_path = "Roberta_crf_layer.pkl"
# tokenizer_path = "Roberta_tokenizer_pre.pkl"
# transformer_path = "Roberta_transformer_layer.pt"
# torch.save(linear_layer.state_dict(), linear_path)
# torch.save(cross_entropy_layer.state_dict(), cross_path)
# torch.save(transformer_model.state_dict(), transformer_path)
# with open(crf_path, "wb") as f:
#   pickle.dump(crf_layer, f)
# with open(tokenizer_path, "wb") as f:
#   pickle.dump(tokenizer, f)

In [35]:
linear_path = "Model/linear_layer.pt"
cross_path = "Model/cross_entropy_layer.pt"
crf_path = "Model/crf_layer.pkl"
tokenizer_path = "Model/tokenizer_pre.pkl"
transformer_path = "Model/transformer_layer.pt"

In [36]:
%ls

[0m[01;34marg_mining[0m/                        Drinventor_linear_layer.pt
[01;34mchange-my-view-modes[0m/              Drinventor_tokenizer_pre.pkl
[01;34mcompiled_corpus[0m/                   Drinventor_transformer_layer.pt
compiled_corpus.zip                layer_wise_analysis.pkl
crf_layer.pkl                      linear_layer.pt
cross_entropy_layer.pt             [01;34mModel[0m/
data_dict.pkl                      [01;34mnaacl18-multitask_argument_mining[0m/
data_runner.pkl                    [01;34mtemp[0m/
Discourse_Markers.txt              tokenizer_pre.pkl
Drinventor_crf_layer.pkl           transformer_layer.pt
Drinventor_cross_entropy_layer.pt  [01;34mwandb[0m/


In [37]:
# Restore the saved weights into the layers built earlier in the notebook.
linear_layer.load_state_dict(torch.load(linear_path, map_location=device))
cross_entropy_layer.load_state_dict(torch.load(cross_path, map_location=device))
transformer_model.load_state_dict(torch.load(transformer_path , map_location=device))
# SECURITY: pickle.load executes arbitrary code embedded in the file. Only load
# these two files if they were produced by this project and are trusted.
# The CRF layer and the tokenizer are replaced wholesale by the unpickled objects.
with open(crf_path, "rb") as f:
  crf_layer = pickle.load(f)
with open(tokenizer_path, "rb") as f:
  tokenizer = pickle.load(f)

In [38]:
from itertools import chain

import torch.optim as optim

optimizer = optim.Adam(params = chain(transformer_model.parameters(),
                                      linear_layer.parameters(),
                                      crf_layer.parameters()),
                       lr = 2e-5,)

In [39]:
graph_attention = []
for i in range(12):
  graph_attention.append(dict())
mapping_ind = {}

In [40]:
cnt = 0
threshold = 0.01

In [41]:

def attention_graph(dataset):
    """Accumulates token-to-token attention weights into ``graph_attention``.

    Args:
        dataset:  A thread dataset; it is batched through
                  get_comment_wise_dataset().

    For every batch the encoder is run without gradient tracking and
    ``output[-1][0]`` is taken as the attention tensor. Every distinct token
    string receives an integer node id in the global ``mapping_ind`` (the
    global ``cnt`` is the next free id). For each of the 12 slices along
    dimension 1 of the attention tensor and for each token pair whose weight
    is at least ``threshold`` times the batch maximum, the weight is added as
    a plain float to ``graph_attention[slice][node_1][node_2]``. Pairs
    involving "[NEWLINE]" or "[PAD]" are skipped.

    After each batch ``cnt`` and ``cnt**2`` are printed. Nothing is returned.

    NOTE(review): ``[-1][0]`` selects only the first element of the model's
    last output. If that output is the per-layer attention tuple, dimension 1
    enumerates the heads of the first layer rather than the layers - confirm
    which of the two is intended.
    """
    global cnt
    skipped_tokens = ("[NEWLINE]", "[PAD]")
    num_slices = 12

    for tokenized_threads, _masked_threads, _comp_type_labels in get_comment_wise_dataset(dataset):

        #Cast to PyTorch tensor
        tokenized_threads = torch.tensor(tokenized_threads, device=device)
        pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0)

        # No gradients are needed here; tracking them would keep the whole
        # autograd graph alive through every weight stored below.
        with torch.no_grad():
            attention = transformer_model(input_ids=tokenized_threads,
                                          attention_mask=pad_mask,)[-1][0]
        # Move to host memory once instead of indexing the device tensor
        # element by element.
        attention = attention.cpu().numpy()
        maximum_attention = np.max(attention)
        cutoff = threshold*maximum_attention

        for j, tokenized_thread in enumerate(tokenized_threads):
            tokens = tokenizer.convert_ids_to_tokens(tokenized_thread.tolist())
            for tok in tokens:
                if tok not in mapping_ind:
                    mapping_ind[tok] = cnt
                    cnt += 1
            # Node id per position; None marks positions to skip.
            node_ids = [None if tok in skipped_tokens else mapping_ind[tok]
                        for tok in tokens]

            for lay in range(num_slices):
                weights = attention[j, lay, :, :]
                slice_graph = graph_attention[lay]
                # Row-major walk over the pairs that are not below the cutoff.
                for ind_1, ind_2 in np.argwhere(~(weights < cutoff)):
                    node_1, node_2 = node_ids[ind_1], node_ids[ind_2]
                    if node_1 is None or node_2 is None:
                        continue
                    weight = float(weights[ind_1, ind_2])
                    neighbours = slice_graph.setdefault(node_1, {})
                    neighbours[node_2] = neighbours.get(node_2, 0.0) + weight
        print(cnt, cnt**2)


In [None]:
train_dataset, _, test_dataset = get_datasets()
attention_graph(train_dataset)

265 70225
347 120409
784 614656
877 769129
997 994009
1260 1587600
1350 1822500
1619 2621161
1836 3370896
1979 3916441
2110 4452100
2161 4669921
2293 5257849
2393 5726449
2444 5973136


In [None]:
graph_path = "attention_graph.pkl"
with open(graph_path, "wb") as f:
  pickle.dump(graph_attention, f)

In [46]:
k = 0
for i in graph_attention[0]:
  k += len(graph_attention[0][i])

91
91
91
91
1383
113
6686
2091
1684
2523
14973
1243
171
4785
16063
1316
1351
548
8282
7723
620
548
10465
548
13969
548
548
548
561
548
1048
2301
6924
1644
1096
626
1371
13796
871
880
548
20045
548
548
1194
1242
1096
1048
1048
3416
6967
548
1548
1048
548
899
3170
1463
548
548
548
548
548
1048
10368
548
548
548
1542
1048
665
1133
548
548
3136
5532
777
1378
655
2963
871
10622
1093
1839
2663
3226
1478
548
693
1670
2551
1526
548
1169
2578
1243
1096
548
548
2227
655
655
548
990
795
41
255
1752
148
148
41
41
541
41
41
41
41
41
41
887
539
41
1648
1041
41
144
2520
107
107
336
107
107
107
107
277
1296
832
723
4000
454
454
434
607
351
1811
1946
2262
107
107
454
107
107
107
107
332
1430
107
229
107
107
430
107
1107
257
2163
107
107
836
769
122
445
567
244
244
154
490
1517
122
707
347
347
538
538
538
592
122
563
563
848
122
554
1623
122
725
347
847
122
122
244
122
1122
122
122
122
122
122
122
328
725
310
1372
225
578
78
78
168
181
78
225
225
225
225
1467
225
548
225
225
225
225
519
309
147
147
147


In [38]:
from bertviz import model_view


In [53]:
# Sanity check: print the largest and smallest value in the first attention
# tensor returned for a single example sentence.
tokenized_threads = tokenizer.encode("I point out that many Christians follow the bible which has numerous examples of sexism, but in application, there are numerous branches of Christianity that are no more sexist than secular groups. For example, Congregationalists and Universaliists.", return_tensors='pt')
d_tokenized_threads = tokenized_threads.to(device)
# print(tokenized_threads)
# The mask is built on the CPU copy and then moved to the model's device.
pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0).to(device)
# print(pad_mask)
# [-1] picks the last element of the model output.
# NOTE(review): presumably the attentions tuple, since the model was loaded
# with output_attentions=True - confirm.
output = transformer_model(input_ids=d_tokenized_threads,
                                        attention_mask=pad_mask,)[-1]
print(np.max(output[0].cpu().detach().numpy()))
print(np.min(output[0].cpu().detach().numpy()))

0.99528
4.3162334e-08


In [None]:
# Run the encoder on one example sentence and inspect the returned attention.
tokenized_threads = tokenizer.encode("I point out that many Christians follow the bible which has numerous examples of sexism, but in application, there are numerous branches of Christianity that are no more sexist than secular groups. For example, Congregationalists and Universaliists.", return_tensors='pt')
d_tokenized_threads = tokenized_threads.to(device)
# print(tokenized_threads)
# The mask is built on the CPU copy and then moved to the model's device.
pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0).to(device)
# print(pad_mask)
output = transformer_model(input_ids=d_tokenized_threads,
                                        attention_mask=pad_mask,)
# output = transformer_model(tokenized_threads)
# The last element of the model output is used as the attention.
attention = output[-1]
print(len(attention), attention[0].shape)
tokens = tokenizer.convert_ids_to_tokens(tokenized_threads[0]) 
print(len(tokens), attention[0].shape)
# Dumps every attention tensor; the output is very large.
print(attention)


Output hidden; open in https://colab.research.google.com to view.

In [None]:
model_view(attention, tokens)

In [None]:
# Same sentences as in the previous example but with the last sentence moved
# to the front; the attention is rendered with bertviz's model_view.
tokenized_threads = tokenizer.encode("For example, Congregationalists and Universaliists. I point out that many Christians follow the bible which has numerous examples of sexism, but in application, there are numerous branches of Christianity that are no more sexist than secular groups. ", return_tensors='pt')
d_tokenized_threads = tokenized_threads.to(device)
# print(tokenized_threads)
# The mask is built on the CPU copy and then moved to the model's device.
pad_mask = torch.where(tokenized_threads!=tokenizer.pad_token_id, 1, 0).to(device)
# print(pad_mask)
output = transformer_model(input_ids=d_tokenized_threads,
                                        attention_mask=pad_mask,)
# output = transformer_model(tokenized_threads)
# The last element of the model output is used as the attention.
attention = output[-1]
print(len(attention), attention[0].shape)
tokens = tokenizer.convert_ids_to_tokens(tokenized_threads[0]) 
print(len(tokens), attention[0].shape)
model_view(attention, tokens)

In [None]:
class wrapper():
  """Bundles an encoder and its tokenizer behind a ``predict_proba`` method.

  ``predict_proba`` also relies on two notebook-level names: ``preprocess``
  and ``linear_layer``.
  """
  # Class-level placeholders; every instance overwrites them in __init__.
  model = -1
  tokenizer = -1

  def __init__(self, model, tokenizer):
    """Stores the encoder ``model`` and the ``tokenizer`` on the instance."""
    self.model = model
    self.tokenizer = tokenizer

  def predict_proba(self, text_sample):
    """Returns the tag scores of ``linear_layer`` for ``text_sample``.

    NOTE(review): ``preprocess`` is not defined in the visible notebook; it is
    expected to return ``(tokenized_threads, token_type_ids)`` - confirm.
    NOTE(review): despite the method name, the raw logits are returned and no
    normalisation to probabilities is applied.
    """
    batch = preprocess(text_sample)
    tokenized_threads, token_type_ids = batch

    # 1 for real tokens, 0 for padding.
    pad_mask = torch.where(tokenized_threads!=self.tokenizer.pad_token_id, 1, 0)

    logits = linear_layer(self.model(input_ids=tokenized_threads,
                                     attention_mask=pad_mask,).last_hidden_state)

    return logits

In [None]:
import seaborn as sns
%matplotlib inline
from collections import OrderedDict
from lime.lime_text import LimeTextExplainer

# Explain one text sample with LIME and plot the per-word weights.
# NOTE(review): `class_names`, `text_sample`, `pipeline`, `idx`, `pd` and `plt`
# are not defined or imported anywhere in the visible notebook - confirm they
# exist before running this cell.
explainer = LimeTextExplainer(class_names=class_names)
explanation = explainer.explain_instance(text_sample, pipeline.predict_proba, num_features=10)

# Keep the (word, weight) pairs in the order LIME returned them.
weights = OrderedDict(explanation.as_list())
lime_weights = pd.DataFrame({'words': list(weights.keys()), 'weights': list(weights.values())})

sns.barplot(x="words", y="weights", data=lime_weights);
plt.xticks(rotation=45)
plt.title('Sample {} features weights given by LIME'.format(idx));