In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
import seaborn as sns
sns.reset_orig()
import os
import pandas as pd
import json
import numpy as np

import warnings
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)

import torch
import gc
gc.collect()
torch.cuda.empty_cache()

!pip install transformers
!pip install datasets
!pip install sentencepiece
!pip install transformers

from datasets import load_dataset
from tokenizers import Tokenizer

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
from transformers import BertModel

In [None]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
from transformers import BertModel


class WordAttentionPre(torch.nn.Module):
    """Flatten a padded batch of documents into one batch of real sentences, sorted for packing.

    This is the pre-BERT half of the word-level attention.
    - Documents are sorted by decreasing length.
    - Pad-sentences are dropped via `pack_padded_sequence`.
    - The remaining sentences are sorted by decreasing sentence length.

    Every sentence-level tensor that is returned (`sents`, `sent_lengths`, `attn_masks`,
    `token_types`) is in that same sorted order.
    """

    def __init__(self, device):
        super().__init__()
        self._device = device

    def forward(self, docs, doc_lengths, sent_lengths, attention_masks, token_type_ids):
        """
        :param docs: encoded document-level data; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param doc_lengths: unpadded document lengths (number of sentences); LongTensor (num_docs)
        :param sent_lengths: unpadded sentence lengths; tensor (num_docs, padded_doc_length)
        :param attention_masks: BERT attention masks; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param token_type_ids: BERT token type IDs; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :return: tuple of
            sents: (num_sents, padded_sent_length) token ids, sorted by decreasing sentence length;
            sent_lengths: (num_sents,) lengths in the same order;
            attn_masks, token_types: (num_sents, padded_sent_length), row-aligned with `sents`;
            docs_valid_bsz: batch sizes of the packed sentence sequence;
            doc_perm_idx: permutation that sorted the documents by decreasing length;
            sent_perm_idx: permutation that sorted the packed sentences by decreasing length.
        """
        # Sort documents by decreasing length (required by pack_padded_sequence).
        doc_lengths, doc_perm_idx = doc_lengths.sort(dim=0, descending=True)
        docs = docs[doc_perm_idx]
        sent_lengths = sent_lengths[doc_perm_idx]
        attention_masks = attention_masks[doc_perm_idx]
        token_type_ids = token_type_ids[doc_perm_idx]

        # Packing removes pad-sentences:
        # (num_docs, padded_doc_length, ...) -> (num_sents, ...) in packed (time-major) order.
        lengths = doc_lengths.tolist()
        packed_sents = pack_padded_sequence(docs, lengths=lengths, batch_first=True)

        # Effective number of documents at each sentence position.
        docs_valid_bsz = packed_sents.batch_sizes

        sents = packed_sents.data
        sent_lengths = pack_padded_sequence(sent_lengths, lengths=lengths, batch_first=True).data
        attn_masks = pack_padded_sequence(attention_masks, lengths=lengths, batch_first=True).data
        token_types = pack_padded_sequence(token_type_ids, lengths=lengths, batch_first=True).data

        # Sort sentences by decreasing length for the word-level packing that follows BERT.
        sent_lengths, sent_perm_idx = sent_lengths.sort(dim=0, descending=True)
        sents = sents[sent_perm_idx]
        # Masks and token types must follow the same permutation as `sents`; otherwise BERT
        # would receive the attention mask of a different sentence.
        attn_masks = attn_masks[sent_perm_idx]
        token_types = token_types[sent_perm_idx]

        return sents, sent_lengths, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx

In [None]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
from transformers import BertModel


class WordAttentionPost(torch.nn.Module):
    """Word-level attention over BERT token embeddings, producing one vector per sentence.

    Consumes the sorted sentence batch produced by WordAttentionPre, runs BERT on it,
    attends over each sentence's real tokens, and returns the sentence vectors in packed
    (pre-sort) order together with the word attention weights.
    """

    def __init__(
            self,
            device: str,
            recurrent_size: int,
            attention_dim: int,
            bert_version: str
    ):
        """
        :param device: stored for reference; not used by the computation
        :param recurrent_size: size of each token embedding; must match the BERT hidden size,
                               because `word_weight` is applied to BERT's last hidden state
        :param attention_dim: size of the attention projection
        :param bert_version: name or path passed to `BertModel.from_pretrained`
        """
        super().__init__()
        self.attention_dim = attention_dim
        self.recurrent_size = recurrent_size
        self._device = device
        self.bert_model = BertModel.from_pretrained(bert_version)

        # Maps BERT output to an `attention_dim`-sized tensor.
        self.word_weight = nn.Linear(self.recurrent_size, self.attention_dim)

        # Word context vector (u_w) to take the dot-product with.
        self.context_weight = nn.Linear(self.attention_dim, 1)

    def recurrent_size(self):
        # NOTE(review): shadowed by the instance attribute of the same name set in __init__,
        # so `instance.recurrent_size` is the int and this method is unreachable on instances.
        return self.recurrent_size

    def forward(self, sents, sent_lengths, attn_masks, token_types, docs_valid_bsz,  doc_perm_idx, sent_perm_idx):
        """
        :param sents: token ids (num_sents, padded_sent_length), sorted by decreasing length
        :param sent_lengths: (num_sents,) sentence lengths in the same (non-increasing) order
        :param attn_masks: BERT attention masks aligned with `sents`
        :param token_types: BERT token type ids aligned with `sents`
        :param docs_valid_bsz: packed-sentence batch sizes; passed through unchanged
        :param doc_perm_idx: document sort permutation; passed through unchanged
        :param sent_perm_idx: permutation that sorted the sentences; inverted here
        :return: sentence embeddings (num_sents, recurrent_size) in packed (pre-sort) order,
                 doc_perm_idx, docs_valid_bsz, and word attention weights
                 (num_sents, max_sent_len) in the same order, exactly 0 on padding
        """
        output_bert = self.bert_model(sents, attention_mask=attn_masks, token_type_ids=token_types)
        embeddings = output_bert.last_hidden_state

        # Drop padding tokens: the packed data holds only real tokens.
        packed_words = pack_padded_sequence(embeddings, lengths=sent_lengths.tolist(), batch_first=True)

        # Effective number of sentences at each token position.
        sentences_valid_bsz = packed_words.batch_sizes

        u_i = torch.tanh(self.word_weight(packed_words.data))
        u_w = self.context_weight(u_i).squeeze(1)

        # Per-sentence softmax. Padding is filled with -inf so it gets exactly zero weight.
        # Subtracting one global max over the whole batch (the previous approach) can
        # underflow every score of a sentence to 0 and yield 0/0 = NaN.
        scores, _ = pad_packed_sequence(
            PackedSequence(u_w, sentences_valid_bsz), batch_first=True, padding_value=float('-inf')
        )
        att_weights = torch.softmax(scores, dim=1)

        # Re-pad the token embeddings and take the attention-weighted sum per sentence.
        word_embeds, _ = pad_packed_sequence(packed_words, batch_first=True)
        sent_embeds = (word_embeds * att_weights.unsqueeze(2)).sum(dim=1)

        # Undo WordAttentionPre's sentence sort so rows are back in packed-document order.
        _, sent_unperm_idx = sent_perm_idx.sort(dim=0, descending=False)
        sent_embeds = sent_embeds[sent_unperm_idx]
        att_weights = att_weights[sent_unperm_idx]

        return sent_embeds, doc_perm_idx, docs_valid_bsz, att_weights

In [None]:
from re import A
import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, PackedSequence

class SentenceAttention(torch.nn.Module):
    """Sentence-level BiGRU + attention that pools sentence vectors into document vectors."""

    def __init__(
            self,
            device: str,
            dropout: float,
            word_recurrent_size: int,
            recurrent_size: int,
            attention_dim: int,
    ):
        """
        :param device: stored for reference; not used by the computation
        :param dropout: dropout probability applied to the incoming sentence embeddings
        :param word_recurrent_size: size of each incoming sentence embedding
        :param recurrent_size: output size of the bidirectional GRU (must be even)
        :param attention_dim: size of the attention projection
        """
        super().__init__()
        self._device = device
        self.word_recurrent_size = word_recurrent_size
        self.recurrent_size = recurrent_size
        self.attention_dim = attention_dim

        assert self.recurrent_size % 2 == 0

        # GRU-internal dropout only acts between stacked layers. With a single layer it is
        # a no-op that only emits a warning, so dropout is applied via self.dropout instead.
        self.encoder = nn.GRU(
            input_size=self.word_recurrent_size,
            hidden_size=self.recurrent_size // 2,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Maps GRU output to an `attention_dim`-sized tensor.
        self.sentence_weight = nn.Linear(self.recurrent_size, self.attention_dim)

        # Sentence context vector to take the dot-product with.
        self.sentence_context_weight = nn.Linear(self.attention_dim, 1)

    def recurrent_size(self):
        # NOTE(review): shadowed by the instance attribute of the same name set in __init__,
        # so `instance.recurrent_size` is the int and this method is unreachable on instances.
        return self.recurrent_size

    def forward(self, sent_embeddings, doc_perm_idx, doc_valid_bsz, word_att_weights):
        """
        :param sent_embeddings: FloatTensor (num_sents, word_recurrent_size) in packed
                                (time-major) order described by `doc_valid_bsz`
        :param doc_perm_idx: LongTensor (num_docs), permutation that sorted the documents
        :param doc_valid_bsz: LongTensor (max_doc_len), packed batch sizes
        :param word_att_weights: FloatTensor (num_sents, max_sent_len) in the same packed order
        :return: docs embeddings (num_docs, recurrent_size),
                 word attention weights (num_docs, max_doc_len, max_sent_len),
                 sentence attention weights (num_docs, max_doc_len),
                 all in the original (unsorted) document order
        """
        sent_embeddings = self.dropout(sent_embeddings)

        # Sentence-level bidirectional GRU over the packed sentence embeddings.
        packed_sentences, _ = self.encoder(PackedSequence(sent_embeddings, doc_valid_bsz))

        u_i = torch.tanh(self.sentence_weight(packed_sentences.data))
        u_w = self.sentence_context_weight(u_i).squeeze(1)

        # Per-document softmax. Pad-sentences are filled with -inf so they get exactly
        # zero weight. A single global max over the batch can underflow a whole document
        # to 0 and produce 0/0 = NaN.
        scores, _ = pad_packed_sequence(
            PackedSequence(u_w, doc_valid_bsz), batch_first=True, padding_value=float('-inf')
        )
        sent_att_weights = torch.softmax(scores, dim=1)

        # Re-pad as documents and take the attention-weighted sum of sentence states.
        docs, _ = pad_packed_sequence(packed_sentences, batch_first=True)
        docs = (docs * sent_att_weights.unsqueeze(2)).sum(dim=1)

        # Re-pad word attention weights per document (padding rows are zeros).
        word_att_weights, _ = pad_packed_sequence(PackedSequence(word_att_weights, doc_valid_bsz), batch_first=True)

        # Restore the original order of documents (undo the first sorting).
        _, doc_unperm_idx = doc_perm_idx.sort(dim=0, descending=False)
        docs = docs[doc_unperm_idx]
        word_att_weights = word_att_weights[doc_unperm_idx]
        sent_att_weights = sent_att_weights[doc_unperm_idx]

        return docs, word_att_weights, sent_att_weights

In [None]:
import torch
from torch import nn


class HANModel(torch.nn.Module):
    """Hierarchical attention encoder: BERT word attention followed by sentence attention.

    Chains WordAttentionPre -> WordAttentionPost -> SentenceAttention and returns one
    embedding per document plus the word- and sentence-level attention weights.
    """

    def __init__(
        self,
        device: str,
        dropout: float,
        word_embed_dim: int,
        word_att_dim: int,
        sentence_embed_dim: int,
        sentence_att_dim: int,
        bert_version: str,
    ):
        super().__init__()
        self.word_attention_pre = WordAttentionPre(device)
        self.word_attention_post = WordAttentionPost(device, word_embed_dim, word_att_dim, bert_version)
        self.sentence_attention = SentenceAttention(
            device, dropout, word_embed_dim, sentence_embed_dim, sentence_att_dim
        )

    def forward(self, docs, doc_lengths, sent_lengths, attention_masks, token_type_ids):
        """
        :param docs: encoded document-level data; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param doc_lengths: unpadded document lengths; LongTensor (num_docs)
        :param sent_lengths: unpadded sentence lengths; tensor (num_docs, padded_doc_length)
        :param attention_masks: BERT attention masks; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param token_type_ids: BERT token type IDs; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :return: (document embeddings, word attention weights, sentence attention weights)
        """
        # Stage 1: flatten documents into a length-sorted batch of sentences.
        (flat_sents, flat_sent_lengths, flat_masks, flat_types,
         valid_bsz, doc_order, sent_order) = self.word_attention_pre(
            docs, doc_lengths, sent_lengths, attention_masks, token_type_ids
        )

        # Stage 2: BERT + word attention -> one vector per sentence.
        sentence_vectors, doc_order, valid_bsz, word_weights = self.word_attention_post(
            flat_sents, flat_sent_lengths, flat_masks, flat_types, valid_bsz, doc_order, sent_order
        )

        # Stage 3: sentence GRU + attention -> one vector per document.
        doc_vectors, word_weights, sentence_weights = self.sentence_attention(
            sentence_vectors, doc_order, valid_bsz, word_weights
        )
        return doc_vectors, word_weights, sentence_weights

In [None]:
import torch
from torch import nn


class HANClassifierModel(torch.nn.Module):
    """HAN document encoder followed by a two-layer MLP classification head.

    When `task == 'ab'`, the classifier input is the document embedding concatenated with
    `b_labels`. That is why its input size is `sentence_embed_dim + num_classes`.
    """

    def __init__(
            self,
            device: str,
            dropout: float,
            word_embed_dim: int,
            word_att_dim: int,
            sentence_embed_dim: int,
            sentence_att_dim: int,
            classifier_hidden_dim: int,
            num_classes: int,
            bert_version: str,
            task: str = None
            ):
        super(HANClassifierModel, self).__init__()
        # Remembered so forward() knows whether to concatenate b_labels.
        self.task = task

        self.HANmodel = HANModel(device, dropout, word_embed_dim, word_att_dim, sentence_embed_dim, sentence_att_dim, bert_version)

        classifier_in_dim = sentence_embed_dim + num_classes if task == 'ab' else sentence_embed_dim
        self.classifier = nn.Sequential(
            torch.nn.Linear(classifier_in_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            torch.nn.Linear(classifier_hidden_dim, num_classes)
        )

    def forward(self, docs, doc_lengths, sent_lengths, attention_masks, token_type_ids, b_labels=None):
        """
        :param docs: encoded document-level data; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param doc_lengths: unpadded document lengths; LongTensor (num_docs)
        :param sent_lengths: unpadded sentence lengths; tensor (num_docs, padded_doc_length)
        :param attention_masks: BERT attention masks; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param token_type_ids: BERT token type IDs; LongTensor (num_docs, padded_doc_length, padded_sent_length)
        :param b_labels: only used when task == 'ab'. Concatenated to the document embeddings
                         along dim 1; presumably (num_docs, num_classes) — TODO confirm with caller.
        :return: class scores, attention weights of words, attention weights of sentences
        :raises ValueError: if task == 'ab' and b_labels is not given
        """
        doc_embeds, word_att_weights, sentence_att_weights = self.HANmodel(docs, doc_lengths, sent_lengths, attention_masks, token_type_ids)

        if self.task == 'ab':
            if b_labels is None:
                raise ValueError("task 'ab' requires b_labels to be passed to forward()")
            # Cast so integer label tensors can be concatenated with float embeddings.
            doc_embeds = torch.cat((doc_embeds, b_labels.to(doc_embeds.dtype)), 1)

        scores = self.classifier(doc_embeds)

        return scores, word_att_weights, sentence_att_weights

In [None]:
import numpy as np
import torch
from torch.autograd import Function

class GradientReversalFunction(Function):
    """Autograd op: identity going forward, gradient scaled by -lambda_ going backward.

    Placing it between a feature extractor and a discriminator makes the feature extractor
    receive the negated discriminator gradient (used below for the confounder heads).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Keep the scaling factor for backward; it is a plain value, not a differentiable input.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        scale = grads.new_tensor(ctx.lambda_)
        reversed_grads = -(scale * grads)
        # No gradient flows to lambda_.
        return reversed_grads, None


class GradientReversal(torch.nn.Module):
    """Module wrapper that applies GradientReversalFunction with a fixed strength."""

    def __init__(self, lambda_=1):
        super().__init__()
        # Factor whose negation multiplies the gradients in the backward pass.
        self.lambda_ = lambda_

    def forward(self, x):
        reversal_strength = self.lambda_
        return GradientReversalFunction.apply(x, reversal_strength)

# Task J: Data Loading

In [None]:
import os
import re
import torchtext
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler, KBinsDiscretizer

def len_condition(each_text):
  """Return the number of numbered paragraphs (sentences starting with e.g. "12. ").

  If no sentence is numbered, fall back to the total number of sentences.
  """
  numbered = sum(1 for sentence in each_text if re.match(r'^\d+\. ', sentence))
  return numbered or len(each_text)


class ECHRDataset(Dataset):
  """ECHR cases as documents of BERT token-id chunks, plus per-case targets.

  Reads a pickled DataFrame with (at least) columns TEXT (list of sentences), label and
  RESPONDENT. Only rows whose RESPONDENT has exactly one entry are kept.

  Items are tuples of:
    - chunks: list of token-id lists, each at most `max_sent_len` long, at most `max_doc_len` of them
    - label: value from the `label` column
    - respondents: multi-hot respondent vector
    - length bins: one-hot document-length bin as ints
    - vocab flag: 1 if any `vocab_confounder` word appears in the text, else 0
      (always 0 when no vocabulary is given)
  """
  def __init__(self, path, tokenizer_path, max_sent_len, max_doc_len,classes=None,scaler=None,paragraph_num_removal = False, vocab_confounder = None):
    self.max_sent_len = max_sent_len
    self.max_doc_len = max_doc_len
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    self.sep_token_id = self.tokenizer.sep_token_id
    self.pad_token_id = self.tokenizer.pad_token_id

    # NOTE: read_pickle can execute arbitrary code — only load trusted files.
    df = pd.read_pickle(path)
    df = df.loc[df.RESPONDENT.str.len()==1]
    df = df.reset_index()

    self.labels = df["label"].to_numpy()

    # Document length is measured before paragraph numbers are stripped, because it counts them.
    df['len'] = df["TEXT"].apply(len_condition)

    if paragraph_num_removal:
      df["TEXT"] = df["TEXT"].apply(lambda row: [re.sub(r'^\d+\. ', '', sentence) for sentence in row])

    # Tokenise each sentence and drop [SEP]/[CLS], then greedily pack sentences into chunks.
    spl_tokens = [self.tokenizer.sep_token_id, self.tokenizer.cls_token_id]
    text_list = df["TEXT"].apply(
        lambda row: [self.filter(self.tokenizer.encode(sentence), spl_tokens) for sentence in row])
    self.encoding = text_list.apply(self.pack_sentences)

    if vocab_confounder:
      cf_tok = set(vocab_confounder)
      # NOTE(review): matching is case-sensitive on whitespace-split raw text, so tokens such as
      # "Mr." do not match "mr" — confirm this is intended.
      self.vocab_cf = df["TEXT"].apply(lambda x: 0 if cf_tok.isdisjoint(" ".join(x).split()) else 1)
    else:
      # Keep __getitem__ valid when no confounder vocabulary is given.
      self.vocab_cf = [0] * len(df)

    mlb = MultiLabelBinarizer()
    if classes is None:
      self.respondents = mlb.fit_transform(df.pop('RESPONDENT'))
    else:
      mlb.fit([classes])
      self.respondents = mlb.transform(df.pop('RESPONDENT'))

    # Length bins are fitted on the training split; dev/test reuse the fitted scaler.
    if scaler is None:
      scaler = KBinsDiscretizer(n_bins=5, encode='onehot-dense', strategy='kmeans')
      self.length = scaler.fit_transform(df[['len']])
    else:
      self.length = scaler.transform(df[['len']])
    self.scaler = scaler

    self.classes = mlb.classes_
    # Number of examples.
    self.n_examples = len(self.labels)

  def __len__(self):
    return self.n_examples

  def __getitem__(self, item):
    return (self.encoding[item], self.labels[item], self.respondents[item], list(map(int,self.length[item])), self.vocab_cf[item])

  def filter(self, sentence, spl_tokens):
    """Return `sentence` without any token id contained in `spl_tokens`."""
    return [y for y in sentence if y not in spl_tokens]

  def pack_sentences(self, token_encodings):
    """Greedily merge consecutive sentences into chunks of at most `max_sent_len` tokens.

    A sentence longer than `max_sent_len` first flushes the current chunk and is then
    emitted in `max_sent_len`-sized pieces. Its remainder starts the next chunk.
    The result is truncated to `max_doc_len` chunks. `token_encodings` is not modified.
    """
    # Work on copies: the splitting below rewrites entries and later extends them in place.
    encodings = [list(tokens) for tokens in token_encodings]
    lengths = [len(tokens) for tokens in encodings]
    doc = []
    group = []
    running_len = 0

    for i in range(len(encodings)):
      while lengths[i] > self.max_sent_len:
        if group:
          doc.append(group)
        doc.append(encodings[i][:self.max_sent_len])
        lengths[i] -= self.max_sent_len
        encodings[i] = encodings[i][self.max_sent_len:]
        group = []
        running_len = 0
      running_len += lengths[i]
      if running_len <= self.max_sent_len:
        group.extend(encodings[i])
      else:
        doc.append(group)
        group = encodings[i]
        running_len = lengths[i]

    if group:
      doc.append(group)

    return doc[:self.max_doc_len]

In [None]:
class MyCollate:
    """Collate ECHRDataset items into padded LongTensors for the HAN models.

    Each item is (doc, label, country_labels, length_bins, vocab_flag), where `doc` is a
    list of token-id lists (one per sentence chunk). Documents are padded to the longest
    document in the batch, and sentences to the longest sentence.
    """

    def __init__(self, pad_idx, sep_idx):
        self.pad_token_id = pad_idx
        # NOTE(review): stored for reference; not used when collating.
        self.sep_token_id = sep_idx

    def pad_sentence_for_batch(self, tokens_lists, max_len: int):
        """Right-pad every sentence in `tokens_lists` to `max_len`.

        :return: (padded token ids, attention masks with 1 on real tokens,
                  all-zero token type ids), each a list with one entry per sentence
        """
        pad_id = self.pad_token_id
        toks_ids, att_masks, tok_type_lists = [], [], []
        for item_toks in tokens_lists:
            n_pad = max_len - len(item_toks)
            toks_ids.append(item_toks + [pad_id] * n_pad)
            att_masks.append([1] * len(item_toks) + [0] * n_pad)
            tok_type_lists.append([0] * max_len)
        return toks_ids, att_masks, tok_type_lists

    def __call__(self, batch):
        """
        :return: (docs, labels, country_labels, length_bins, vocab_flags, doc_lengths,
                  sent_lengths, attention_masks, token_type_ids) — all LongTensors;
                  docs/masks/type ids are (batch, max_doc_len, max_sent_len) and
                  sent_lengths is (batch, max_doc_len), zero for pad-sentences
        :raises ValueError: if every item in `batch` is None
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            raise ValueError("MyCollate received a batch with no items")
        docs, labels, country_labels, length, vocab_cf = zip(*batch)
        doc_lengths = [len(doc) for doc in docs]
        sent_lengths = [[len(sent) for sent in doc] for doc in docs]

        batch_sz = len(labels)
        batch_max_doc_length = max(doc_lengths)
        batch_max_sent_length = max(max(sl, default=0) for sl in sent_lengths)

        shape = (batch_sz, batch_max_doc_length, batch_max_sent_length)
        docs_tensor = torch.zeros(shape, dtype=torch.long)
        batch_att_mask_tensor = torch.zeros(shape, dtype=torch.long)
        token_type_ids_tensor = torch.zeros(shape, dtype=torch.long)
        # Lengths are integers; a long tensor matches the LongTensor contract of the models.
        sent_lengths_tensor = torch.zeros((batch_sz, batch_max_doc_length), dtype=torch.long)

        for doc_idx, doc in enumerate(docs):
            n_sents = doc_lengths[doc_idx]
            if n_sents == 0:
                continue
            padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sentence_for_batch(doc, batch_max_sent_length)
            # Fill a whole document at once instead of one sentence at a time.
            sent_lengths_tensor[doc_idx, :n_sents] = torch.tensor(sent_lengths[doc_idx], dtype=torch.long)
            docs_tensor[doc_idx, :n_sents] = torch.tensor(padded_token_lists, dtype=torch.long)
            batch_att_mask_tensor[doc_idx, :n_sents] = torch.tensor(att_mask_lists, dtype=torch.long)
            token_type_ids_tensor[doc_idx, :n_sents] = torch.tensor(tok_type_lists, dtype=torch.long)

        return (
            docs_tensor,
            torch.tensor(labels, dtype=torch.long),
            torch.tensor(country_labels, dtype=torch.long),
            torch.tensor(length, dtype=torch.long),
            torch.tensor(vocab_cf, dtype=torch.long),
            torch.tensor(doc_lengths, dtype=torch.long),
            sent_lengths_tensor,
            batch_att_mask_tensor,
            token_type_ids_tensor,
        )

In [None]:
paragraph_num_removal = True
vocab_confounder = ['represented', 'national', 'mr', 'summarised', 'practising', 'lawyer', 'agent', 'paragraph']

TOKENIZER_PATH = "nlpaueb/legal-bert-base-uncased"
MAX_SENT_LEN = 500  # tokens per chunk
MAX_DOC_LEN = 40    # chunks per document

# The training split fits the respondent classes and the length-bin scaler; dev/test reuse them.
# (The former `vocab_depth=vocab_depth` argument referenced an undefined name and is not an
# ECHRDataset parameter, so it was removed.)
train_dataset = ECHRDataset('train.pkl', tokenizer_path=TOKENIZER_PATH, max_sent_len=MAX_SENT_LEN, max_doc_len=MAX_DOC_LEN,
                            paragraph_num_removal=paragraph_num_removal, vocab_confounder=vocab_confounder)
dev_dataset = ECHRDataset('dev.pkl', tokenizer_path=TOKENIZER_PATH, max_sent_len=MAX_SENT_LEN, max_doc_len=MAX_DOC_LEN,
                          classes=train_dataset.classes, scaler=train_dataset.scaler,
                          paragraph_num_removal=paragraph_num_removal, vocab_confounder=vocab_confounder)
test_dataset = ECHRDataset('test.pkl', tokenizer_path=TOKENIZER_PATH, max_sent_len=MAX_SENT_LEN, max_doc_len=MAX_DOC_LEN,
                           classes=train_dataset.classes, scaler=train_dataset.scaler,
                           paragraph_num_removal=paragraph_num_removal, vocab_confounder=vocab_confounder)

In [None]:
# Only the training loader is shuffled; evaluation order does not affect the metrics.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=MyCollate(pad_idx=train_dataset.pad_token_id, sep_idx=train_dataset.sep_token_id))
dev_dataloader = DataLoader(dev_dataset, batch_size=16, shuffle=False,
                            collate_fn=MyCollate(pad_idx=dev_dataset.pad_token_id, sep_idx=dev_dataset.sep_token_id))
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False,
                             collate_fn=MyCollate(pad_idx=test_dataset.pad_token_id, sep_idx=test_dataset.sep_token_id))

## Training the Violation-Prediction-Only Model

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
import itertools


num_epochs = 20
learning_rate = 5e-4
file_prefix = "model_par_rem_"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, 1,"nlpaueb/legal-bert-base-uncased").to(device)
feature_extractor = model.HANmodel
classifier = model.classifier

# Freeze BERT; only the attention layers and the classifier head are optimised.
for name, param in model.named_parameters():
    if "bert" in name:
        param.requires_grad = False

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
cls_criterion = nn.BCEWithLogitsLoss()

column_names = ["epoch", "mean epoch loss", "accuracy","precision","recall","f1"]
train_df = pd.DataFrame(columns = column_names)
dev_df = pd.DataFrame(columns = column_names)


def run_violation_epoch(dataloader, train):
    """Run one pass of the violation classifier over `dataloader`.

    Parameters are updated only when `train` is True; otherwise gradients are disabled.
    Returns [mean batch loss, accuracy, precision, recall, f1], with the metrics in percent,
    all rounded to 2 decimals.
    """
    model.train(train)
    losses, gold, predicted = [], [], []
    with torch.set_grad_enabled(train):
        for batch in tqdm(dataloader, total=len(dataloader), leave=False):
            (docs, labels, country_labels, length, vocab_cf,
             doc_lengths, sent_lengths, att_mask, token_type_id) = [t.to(device) for t in batch]

            doc_embeds, word_att_weights, sentence_att_weights = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)
            cls_scores = classifier(doc_embeds).squeeze(dim=1)
            loss = cls_criterion(cls_scores, labels.float())

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Keep a Python float, not the tensor, so no autograd graph is retained.
            losses.append(loss.item())
            cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()  # threshold probabilities at 0.5
            gold.extend(labels.cpu().numpy())
            predicted.extend(cls_pred_labels.cpu().numpy())

    metrics = [
        accuracy_score(gold, predicted),
        precision_score(gold, predicted, zero_division=0),
        recall_score(gold, predicted, zero_division=0),
        f1_score(gold, predicted, zero_division=0),
    ]
    return [round(float(np.mean(losses)), 2)] + [round(100 * m, 2) for m in metrics]


for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    train_df.loc[len(train_df)] = [epoch] + run_violation_epoch(train_dataloader, train=True)

    model_checkpoint = {'model_state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(model_checkpoint, file_prefix + str(epoch) + ".pth.tar")

    dev_df.loc[len(dev_df)] = [epoch] + run_violation_epoch(dev_dataloader, train=False)

    print(train_df)
    print(dev_df)

## Deconfounding: Country / Length / Vocabulary / All

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
import itertools

num_epochs = 20

# Per-branch learning rates: the violation model uses a smaller one than the discriminators.
violation_learning_rate = 5e-5
country_learning_rate = 5e-4
length_learning_rate = 5e-4
vocab_learning_rate = 5e-4

# Growth rates of the gradient-reversal strength schedule (see get_lambda).
gamma_country = 0.1
gamma_length = 0.1
gamma_vocab = 0.1

# Deconfounding mode: exactly one of these flags must be True.
country_only = False
length_only = False
vocab_only = False
country_and_length_and_vocab = True
file_prefix = 'grad_vocab_all_'

# A chained XOR (`a ^ b ^ c ^ d`) also passes with three flags enabled, so count them instead.
assert sum([country_only, length_only, vocab_only, country_and_length_and_vocab]) == 1, \
    "exactly one deconfounding mode must be enabled"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Size of the document embeddings (the HAN sentence_embed_dim) that every head consumes.
doc_embed_dim = 400
model = HANClassifierModel(device, 0.1, 768, 300, doc_embed_dim, 200, 100, 1,"nlpaueb/legal-bert-base-uncased").to(device)


feature_extractor = model.HANmodel
classifier = model.classifier

# Freeze BERT; only the attention layers, the classifier and the discriminators are trained.
for name, param in model.named_parameters():
    if "bert" in name:
        param.requires_grad = False

params_dict = []
params_dict.append({'params': list(filter(lambda p: p.requires_grad, model.parameters())), 'lr': violation_learning_rate})

country_modules  = []
len_modules = []
vocab_modules = []

# Number of one-hot document-length bins produced by the fitted discretizer.
n_length_bins = train_dataset.length.shape[1]

if country_only or country_and_length_and_vocab:
    country_modules = country_modules + [
                torch.nn.Linear(doc_embed_dim, 100),
                nn.ReLU(),
                nn.Dropout(0.1),
                torch.nn.Linear(100,len(train_dataset.classes))
              ]
    country_discriminator = nn.Sequential(*country_modules).to(device)
    params_dict.append({'params': list(country_discriminator.parameters()), 'lr': country_learning_rate})

if length_only or country_and_length_and_vocab:
    len_modules = len_modules + [
                torch.nn.Linear(doc_embed_dim, 100),
                nn.ReLU(),
                nn.Dropout(0.1),
                torch.nn.Linear(100, n_length_bins)
              ]
    len_discriminator = nn.Sequential(*len_modules).to(device)
    params_dict.append({'params': list(len_discriminator.parameters()), 'lr': length_learning_rate})

if vocab_only or country_and_length_and_vocab:
    # NOTE(review): 8 == len(vocab_confounder), but ECHRDataset.vocab_cf is a single 0/1 flag
    # per document, so an 8-logit BCE head does not match that target — confirm against the loss.
    vocab_modules = vocab_modules + [
                torch.nn.Linear(doc_embed_dim, 100),
                nn.ReLU(),
                nn.Dropout(0.1),
                torch.nn.Linear(100,8)
              ]
    vocab_discriminator = nn.Sequential(*vocab_modules).to(device)
    params_dict.append({'params': list(vocab_discriminator.parameters()), 'lr': vocab_learning_rate})



optimizer = optim.Adam(params_dict)


cls_criterion= nn.BCEWithLogitsLoss()
disc_criterion = nn.CrossEntropyLoss()
len_criterion = nn.CrossEntropyLoss()
vocab_criterion= nn.BCEWithLogitsLoss()

column_names = ["cls_loss", "cls_accuracy","cls_precision","cls_recall","cls_f1","disc_loss", "disc_accuracy","disc_precision","disc_recall","disc_f1","len_loss", "len_accuracy","len_precision","len_recall","len_f1","vocab_loss", "vocab_accuracy","vocab_precision","vocab_recall","vocab_f1"]
train_df = pd.DataFrame(columns = column_names)
dev_df = pd.DataFrame(columns = column_names)

def get_lambda(gamma, p):
    """Gradient-reversal coefficient for training progress `p`.

    Computes 2 / (1 + exp(-gamma * p)) - 1: exactly 0 at p = 0 and, for gamma > 0,
    rising monotonically towards 1 as p grows.
    """
    decay = np.exp(-gamma * p)
    return 2. / (1. + decay) - 1

# Which adversarial heads are active: each *_only flag enables one head,
# country_and_length_and_vocab enables all three.
use_country = country_only or country_and_length_and_vocab
use_length = length_only or country_and_length_and_vocab
use_vocab = vocab_only or country_and_length_and_vocab
# Placeholder cells for a disabled head's five columns in train_df / dev_df.
BLANK_METRICS = ["", "", "", "", ""]


def _new_epoch_stats():
    """Fresh per-epoch accumulators: task -> {'loss': [...], 'gold': [...], 'pred': [...]}."""
    return {task: {'loss': [], 'gold': [], 'pred': []} for task in ('cls', 'disc', 'len', 'vocab')}


def _record(task_stats, task_loss, gold, predicted):
    """Store one batch's loss, gold labels and predicted labels for a task.

    The loss is detached so the epoch-long list does not keep every batch's autograd
    graph alive.
    """
    task_stats['loss'].append(task_loss.detach())
    task_stats['gold'].extend(gold.detach().cpu().numpy())
    task_stats['pred'].extend(predicted.detach().cpu().numpy())


def _set_training(training):
    """Put the model and every enabled discriminator in train (True) or eval (False) mode."""
    model.train(training)
    if use_country:
        country_discriminator.train(training)
    if use_length:
        len_discriminator.train(training)
    if use_vocab:
        vocab_discriminator.train(training)


def _forward_batch(batch, lambdas, stats):
    """Forward one batch through the classifier and every enabled adversarial head.

    Each discriminator sees the document embedding through GradientReversalFunction with
    its own lambda. Losses and predictions are recorded in `stats`.

    :param batch: the 9-tuple yielded by train_dataloader / dev_dataloader
    :param lambdas: (lambda_country, lambda_length, lambda_vocab)
    :return: classification loss plus the enabled adversarial losses
    """
    (docs, labels, country_labels, length, vocab,
     doc_lengths, sent_lengths, att_mask, token_type_id) = [t.to(device) for t in batch]
    lambda_country, lambda_length, lambda_vocab = lambdas

    doc_embeds, _, _ = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)

    cls_scores = classifier(doc_embeds).squeeze(dim=1)
    cls_loss = cls_criterion(cls_scores, labels.float())
    # Out-of-place additions only: an in-place `+=` here would also overwrite cls_loss
    # (the tensor just recorded), corrupting the reported classification loss.
    loss = cls_loss
    _record(stats['cls'], cls_loss, labels, (torch.sigmoid(cls_scores) >= 0.5).long())

    if use_country:
        country_gold = country_labels.argmax(dim=1)
        disc_scores = country_discriminator(GradientReversalFunction.apply(doc_embeds, lambda_country))
        disc_loss = disc_criterion(disc_scores, country_gold)
        loss = loss + disc_loss
        _record(stats['disc'], disc_loss, country_gold, torch.softmax(disc_scores, dim=1).argmax(dim=1))

    if use_length:
        length_gold = length.argmax(dim=1)
        len_scores = len_discriminator(GradientReversalFunction.apply(doc_embeds, lambda_length))
        len_loss = len_criterion(len_scores, length_gold)
        loss = loss + len_loss
        _record(stats['len'], len_loss, length_gold, torch.softmax(len_scores, dim=1).argmax(dim=1))

    if use_vocab:
        vocab_gold = vocab.flatten()
        vocab_scores = vocab_discriminator(GradientReversalFunction.apply(doc_embeds, lambda_vocab)).flatten()
        vocab_loss = vocab_criterion(vocab_scores, vocab_gold.float())
        loss = loss + vocab_loss
        _record(stats['vocab'], vocab_loss, vocab_gold, (torch.sigmoid(vocab_scores) >= 0.5).long())

    return loss


def _task_cells(task_stats, average):
    """Five rounded cells for one task: summed epoch loss, then accuracy/precision/recall/F1 in %."""
    gold, pred = task_stats['gold'], task_stats['pred']
    epoch_loss = torch.stack(task_stats['loss'], dim=0).sum(dim=0).item()
    scores = (accuracy_score(gold, pred),
              precision_score(gold, pred, average=average),
              recall_score(gold, pred, average=average),
              f1_score(gold, pred, average=average))
    return [round(epoch_loss, 2)] + [round(100 * score, 2) for score in scores]


def _metrics_row(stats):
    """One 20-cell row in column_names order; a disabled head contributes blank cells."""
    return (_task_cells(stats['cls'], 'binary')
            + (_task_cells(stats['disc'], 'macro') if use_country else BLANK_METRICS)
            + (_task_cells(stats['len'], 'macro') if use_length else BLANK_METRICS)
            + (_task_cells(stats['vocab'], 'binary') if use_vocab else BLANK_METRICS))


def _checkpoint():
    """State dicts of the model, each enabled discriminator and the optimizer."""
    ckpt = {'model_state_dict': model.state_dict()}
    if use_country:
        ckpt['disc_state_dict'] = country_discriminator.state_dict()
    if use_length:
        ckpt['len_state_dict'] = len_discriminator.state_dict()
    if use_vocab:
        ckpt['vocab_state_dict'] = vocab_discriminator.state_dict()
    ckpt['optimizer'] = optimizer.state_dict()
    return ckpt


for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # ----- training pass -----
    _set_training(True)
    train_stats = _new_epoch_stats()
    num_batches = len(train_dataloader)
    for batch_idx, batch in tqdm(enumerate(train_dataloader), total=num_batches, leave=False):
        # Overall training progress p in (0, 1], measured on the loader actually being iterated;
        # get_lambda maps it to a GRL coefficient that starts at 0 and grows with p.
        p = float(batch_idx + 1 + epoch * num_batches) / num_epochs / num_batches
        lambdas = (get_lambda(gamma_country, p), get_lambda(gamma_length, p), get_lambda(gamma_vocab, p))

        loss = _forward_batch(batch, lambdas, train_stats)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_df.loc[len(train_df)] = _metrics_row(train_stats)
    model_checkpoint = _checkpoint()
    torch.save(model_checkpoint, file_prefix + str(epoch) + ".pth.tar")

    # ----- dev pass -----
    _set_training(False)
    dev_stats = _new_epoch_stats()
    with torch.no_grad():
        for batch in tqdm(dev_dataloader, total=len(dev_dataloader), leave=False):
            # Reuses the last training batch's lambdas; no backward pass runs here.
            _forward_batch(batch, lambdas, dev_stats)

    dev_df.loc[len(dev_df)] = _metrics_row(dev_stats)

    print(train_df)
    print(dev_df)

    gc.collect()
    torch.cuda.empty_cache()

# Test Set

In [None]:
checkpoint_file = 'bert_a_s_t_grl_src9.pth.tar'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, 1, "nlpaueb/legal-bert-base-uncased").to(device)
feature_extractor = model.HANmodel
classifier = model.classifier

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Evaluation mode: otherwise any dropout (e.g. inside BERT) stays active and test
# predictions become stochastic.
model.eval()

cls_predicted, cls_gold = [], []

with torch.no_grad():
    for batch_idx, (docs, labels, country_labels, length, vocab_cf, doc_lengths, sent_lengths, att_mask, token_type_id) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False):
        # Move inputs and targets to the model's device.
        docs, labels, country_labels, length, vocab_cf, doc_lengths, sent_lengths, att_mask, token_type_id = docs.to(device), labels.to(device), country_labels.to(device), length.to(device), vocab_cf.to(device), doc_lengths.to(device), sent_lengths.to(device), att_mask.to(device), token_type_id.to(device)

        doc_embeds, word_att_weights, sentence_att_weights = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)
        cls_scores = classifier(doc_embeds).squeeze(dim=1)

        cls_preds = torch.sigmoid(cls_scores)  # map logits to probabilities
        cls_pred_labels = (cls_preds >= 0.5).long()  # binarize at 0.5
        cls_gold.extend(labels.cpu().detach().numpy())
        cls_predicted.extend(cls_pred_labels.cpu().detach().numpy())

cls_accuracy, cls_precision, cls_recall, cls_f1 = accuracy_score(cls_gold, cls_predicted), precision_score(cls_gold, cls_predicted), recall_score(cls_gold, cls_predicted), f1_score(cls_gold, cls_predicted)
print(f"Test Cls Statistics : Accuracy : {100.0*cls_accuracy:4.2f}%, {100.0*cls_precision:4.2f}%,{100.0*cls_recall:4.2f}%,{100.0*cls_f1:4.2f}%")

# Integrated Gradients (IG) attributions

In [None]:
import pandas as pd
import re
from transformers import AutoTokenizer

# Settings used by the IG preprocessing helpers below.
paragraph_num_removal = True                        # strip leading "N. " paragraph numbers
tokenizer_path = "nlpaueb/legal-bert-base-uncased"
max_sent_len = 500                                  # max tokens per packed chunk
max_doc_len = 40                                    # max packed chunks kept per document

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def spl_filter(sentence, spl_tokens):
    """Return the token ids of `sentence`, in order, without those listed in `spl_tokens`."""
    return list(filter(lambda token: token not in spl_tokens, sentence))

def pack_sentences(token_encodings, max_sent_len, max_doc_len):
    """Greedily pack consecutive tokenized paragraphs into chunks of at most `max_sent_len` tokens.

    Paragraphs are concatenated in order until the next one would not fit; then a new pack is
    started. A paragraph longer than `max_sent_len` is not split: it becomes its own over-long
    pack and a warning is printed. Only the first `max_doc_len` packs are kept.

    :param token_encodings: list of token-id lists, one per paragraph (not modified)
    :param max_sent_len: maximum number of tokens per pack
    :param max_doc_len: maximum number of packs kept
    :return: (doc, mapping). doc is the list of packs. mapping has one dict per kept paragraph,
             where "start"/"end" are the inclusive token offsets of that paragraph inside
             pack number "pack_num".
    """
    doc = []
    mapping = []
    current_pack = []
    pack_len = 0

    for tokens in token_encodings:
        n_tokens = len(tokens)
        if n_tokens > max_sent_len:
            print(f"paragraph of {n_tokens} tokens exceeds max_sent_len={max_sent_len}")

        if pack_len + n_tokens <= max_sent_len:
            start = pack_len
        else:
            # Close the current pack. Never emit an empty pack, which would otherwise happen
            # when the very first paragraph is already longer than max_sent_len.
            if current_pack:
                doc.append(current_pack)
            current_pack = []
            start = 0

        current_pack.extend(tokens)  # copy tokens, never alias the caller's list
        pack_len = start + n_tokens
        mapping.append({"start": start, "end": pack_len - 1, "pack_num": len(doc)})

    if current_pack:
        doc.append(current_pack)

    if len(doc) > max_doc_len:
        doc = doc[:max_doc_len]
        mapping = [m for m in mapping if m["pack_num"] < max_doc_len]

    return doc, mapping

def encode(text):
    """Tokenize a document's paragraphs and pack them with pack_sentences.

    Relies on the module-level tokenizer, paragraph_num_removal, max_sent_len and max_doc_len.

    :param text: sequence of paragraph strings
    :return: (packed token-id lists, paragraph-to-pack mapping) from pack_sentences
    """
    paragraphs = text
    if paragraph_num_removal:
        paragraphs = [re.sub(r'^\d+\. ', '', paragraph) for paragraph in paragraphs]

    # Drop [SEP]/[CLS] that tokenizer.encode adds around every paragraph.
    special_ids = [tokenizer.sep_token_id, tokenizer.cls_token_id]
    tokenized = [spl_filter(tokenizer.encode(paragraph), special_ids) for paragraph in paragraphs]

    return pack_sentences(tokenized, max_sent_len, max_doc_len)


In [None]:
def pad_sentence_for_batch(tokens_lists, max_len: int, pad_token_id):
    """Right-pad each token list to `max_len` with `pad_token_id`.

    :return: (padded token ids, attention masks with 1 on real tokens and 0 on padding,
              all-zero token type ids), each a list with one row per input list
    """
    toks_ids, att_masks, tok_type_lists = [], [], []
    for item_toks in tokens_lists:
        n_real = len(item_toks)
        n_pad = max_len - n_real
        toks_ids.append(item_toks + [pad_token_id] * n_pad)
        att_masks.append([1] * n_real + [0] * n_pad)
        tok_type_lists.append([0] * max_len)
    return toks_ids, att_masks, tok_type_lists


def batchify(encoding, pad_token_id):
    """Turn one packed document into model inputs for a batch of size 1.

    :param encoding: list of packed token-id lists (one document)
    :param pad_token_id: id used to right-pad packs to the longest pack
    :return: (docs, doc_lengths, sent_lengths, attention_masks, token_type_ids). docs, the
             masks and the type ids are LongTensors of shape (1, n_packs, longest_pack);
             doc_lengths is LongTensor([n_packs]); sent_lengths is a float tensor (1, n_packs).
    """
    n_packs = len(encoding)
    pack_lengths = [len(pack) for pack in encoding]
    longest = max(pack_lengths)
    shape = (1, n_packs, longest)

    docs_tensor = torch.zeros(shape, dtype=torch.long)
    att_mask_tensor = torch.zeros(shape, dtype=torch.long)
    token_type_tensor = torch.zeros(shape, dtype=torch.long)
    sent_lengths_tensor = torch.zeros((1, n_packs))

    padded, masks, types = pad_sentence_for_batch(encoding, longest, pad_token_id)
    sent_lengths_tensor[0, :n_packs] = torch.tensor(pack_lengths, dtype=torch.long)

    for row, (tok_row, mask_row, type_row) in enumerate(zip(padded, masks, types)):
        docs_tensor[0, row, :] = torch.tensor(tok_row, dtype=torch.long)
        att_mask_tensor[0, row, :] = torch.tensor(mask_row, dtype=torch.long)
        token_type_tensor[0, row, :] = torch.tensor(type_row, dtype=torch.long)

    return (
        docs_tensor,
        torch.tensor([n_packs], dtype=torch.long),
        sent_lengths_tensor,
        att_mask_tensor,
        token_type_tensor,
    )

## Dev 

In [None]:
# Load the dev split; the IG-generation cell looks documents up in `df` by ITEMID.
# NOTE: read_pickle unpickles arbitrary objects - only load files you trust.
import pandas as pd

df = pd.read_pickle('dev.pkl')

In [None]:
# ITEMIDs of the dev documents to explain with IG (order preserved).
ids = [
    '001-4553', '001-57904', '001-58115', '001-58227', '001-61317',
    '001-67816', '001-69131', '001-73291', '001-76677', '001-77181',
    '001-80083', '001-83482', '001-90165', '001-88523', '001-97402',
    '001-91415', '001-90276', '001-102788', '001-102634', '001-102256',
    '001-101853', '001-107174', '001-108189', '001-108225', '001-106159',
]

## Test

In [None]:
# Load the test split; the IG-generation cell looks documents up in `df` by ITEMID.
# NOTE: read_pickle unpickles arbitrary objects - only load files you trust.
import pandas as pd

df = pd.read_pickle('test.pkl')

In [None]:
# ITEMIDs of the test documents to explain with IG (order preserved).
ids = [
    '001-140004', '001-140021', '001-140754', '001-142196', '001-144107',
    '001-145014', '001-154400', '001-155374', '001-158470', '001-158484',
    '001-161061', '001-162023', '001-164315', '001-166774', '001-167112',
    '001-170052', '001-170360', '001-170858', '001-171971', '001-178750',
    '001-180548', '001-183127', '001-184659', '001-184674', '001-185320',
]

### IG generation

In [None]:
def IG(text):
    """Layer-integrated-gradients attributions for one document.

    Runs the HAN word-attention pre-processing, then attributes `lig`'s target
    (my_forward_func) w.r.t. the hooked layer. The baseline is an all-[PAD] input, and a
    single interpolation step is used.

    :param text: sequence of paragraph strings
    :return: (per-token attributions summed over the embedding dimension and L2-normalised,
              packed encoding, paragraph-to-pack mapping)
    """
    encoding, mapping = encode(text)
    batch = batchify(encoding, tokenizer.pad_token_id)
    docs, doc_lengths, sent_lengths, att_mask, token_type_id = (t.to(device) for t in batch)

    (sents, sent_lengths_pre, attn_masks, token_types,
     docs_valid_bsz, doc_perm_idx, sent_perm_idx) = model.HANmodel.word_attention_pre(
        docs, doc_lengths, sent_lengths, att_mask, token_type_id
    )

    baseline = torch.full(sents.shape, tokenizer.pad_token_id).to(device)
    forward_args = (sent_lengths_pre, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx)
    per_token = lig.attribute(sents, baselines=baseline, additional_forward_args=forward_args, n_steps=1).sum(dim=-1)

    return per_token / torch.norm(per_token), encoding, mapping
      

In [None]:
import pickle


import numpy

!pip install captum
from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients, IntegratedGradients
from captum.attr import configure_interpretable_embedding_layer

import gc

# Release whatever earlier cells left on the GPU before building a fresh model.
gc.collect()
torch.cuda.empty_cache()
# NOTE(review): presumably disabled because cuDNN RNN kernels refuse a backward pass in
# eval mode, and IG back-propagates through the eval-mode HAN - confirm.
torch.backends.cudnn.enabled = False

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = HANClassifierModel(
    device, 0.1, 768, 300, 400, 200, 100, 1, "nlpaueb/legal-bert-base-uncased"
).to(device)

def my_forward_func(sents, sent_lengths, attn_masks, token_types, docs_valid_bsz,  doc_perm_idx, sent_perm_idx):
    """Captum target: word-attention output -> sentence attention -> classifier -> sigmoid.

    Takes the tensors produced by model.HANmodel.word_attention_pre and returns the predicted
    violation probability per document.
    """
    han = model.HANmodel
    sent_embeddings, doc_perm_idx, docs_valid_bsz, word_att_weights = han.word_attention_post(
        sents, sent_lengths, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx
    )
    doc_embeds, _, _ = han.sentence_attention(sent_embeddings, doc_perm_idx, docs_valid_bsz, word_att_weights)
    logits = model.classifier(doc_embeds).squeeze(dim=1)
    return torch.sigmoid(logits)


scores_dict = {}

checkpoint_file = 'grad_cou_2.pth.tar'
model_arch = 'grad_cou'
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Attribute my_forward_func's output to the BERT embedding layer used by word_attention_post.
lig = LayerIntegratedGradients(my_forward_func, model.HANmodel.word_attention_post.bert_model.embeddings)

model.eval()
for item_id in ids:
    model.zero_grad()
    text = df.loc[df['ITEMID'] == item_id, 'TEXT'].to_numpy()[0]
    attributions, encoding, mapping = IG(text)
    # Store detached CPU tensors so the pickle loads without a GPU and holds no autograd graph.
    scores_dict[item_id] = [{model_arch: {"mapping": mapping, "attributions": attributions.detach().cpu()}}]

    gc.collect()
    torch.cuda.empty_cache()

output_path = 'final_lex/grad_cou_test.pkl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump(scores_dict, f)

# LexGLUE

In [None]:
# LexGLUE ECtHR task B, keyed by this notebook's split names ('val' is LexGLUE's 'validation').
lex_b = {
    name: load_dataset("lex_glue", "ecthr_b", split=hf_split)
    for name, hf_split in (('train', 'train'), ('val', 'validation'), ('test', 'test'))
}

In [None]:
import os
import re
import torchtext
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from transformers import BertModel
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler, KBinsDiscretizer
import torch
import json

def len_condition(each_text):
    """Document length in paragraphs.

    Counts paragraphs that start with a "N. " number; if none do, falls back to the total
    number of paragraphs.
    """
    numbered = sum(1 for paragraph in each_text if re.match(r'^\d+\. ', paragraph))
    return numbered or len(each_text)

class LexGlueDataset(Dataset):
    """ECtHR cases read from a JSON-lines file, tokenized and packed for a hierarchical BERT model.

    Each item is the tuple
    (packed token ids, task_a labels, task_b labels, country, length bins,
     vocab_cf_a, vocab_cf_b, vocab_cf_ab, silver_rationales, gold_rationales).

    The fitted label binarizer classes (task_a_classes, task_b_classes, cou_classes) and the
    length discretizer (scaler) are exposed as attributes, so a training split's encoders
    can be passed to the dev/test splits.
    """

    def __init__(self, dataset_path, bert_path, max_sent_len, max_doc_len, task_a_classes=None, task_b_classes=None,
                 cou_classes=None, scaler=None, paragraph_num_removal=False, vocab_confounder_b=None,
                 vocab_confounder_a=None, vocab_confounder_ab=None):
        """
        :param dataset_path: JSON-lines file, one case per line
        :param bert_path: name or path passed to AutoTokenizer.from_pretrained
        :param max_sent_len: max tokens per packed chunk
        :param max_doc_len: max packed chunks kept per case
        :param task_a_classes, task_b_classes, cou_classes: label vocabularies to reuse; fitted here when None
        :param scaler: fitted length discretizer to reuse; a 4-bin k-means KBinsDiscretizer is fitted when None
        :param paragraph_num_removal: strip leading "N. " paragraph numbers before tokenizing
        :param vocab_confounder_b, vocab_confounder_a, vocab_confounder_ab: lists of token groups; each group
               yields a 0/1 feature (does the case text contain any token of the group?)
        """
        self.max_sent_len = max_sent_len
        self.max_doc_len = max_doc_len
        self.encoding = []
        self.labels = []
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.sep_token_id = self.tokenizer.sep_token_id
        self.pad_token_id = self.tokenizer.pad_token_id

        with open(dataset_path, 'r') as json_file:
            data = [json.loads(json_str) for json_str in json_file]

        df = pd.DataFrame(data)
        df = df[['facts', 'defendants', 'allegedly_violated_articles', 'violated_articles',
                 'silver_rationales', 'gold_rationales']]

        # Both tasks are restricted to the same ten articles.
        task_b = {'6', '3', '5', 'P1-1', '8', '2', '10', '11', '14', '9'}
        task_a = {'6', '3', '5', 'P1-1', '8', '2', '10', '11', '14', '9'}

        # Drop cases without a defendant state and keep only the first one listed.
        df['def_len'] = df['defendants'].apply(lambda x: len(x))
        df = df.loc[df['def_len'] != 0]
        df = df.reset_index()
        df['defendants'] = df['defendants'].apply(lambda x: [x[0]])

        # Length is computed before paragraph numbers are stripped: len_condition counts them.
        df['len'] = df["facts"].apply(lambda row: len_condition(row))

        if paragraph_num_removal:
            df["facts"] = df["facts"].apply(lambda row: [re.sub(r'^\d+\. ', '', sentence) for sentence in row])

        # Tokenize each paragraph, drop [SEP]/[CLS], then pack paragraphs into chunks.
        spl_tokens = [self.tokenizer.sep_token_id, self.tokenizer.cls_token_id]
        text_list = df["facts"].apply(
            lambda row: [self.filter(self.tokenizer.encode(sentence), spl_tokens) for sentence in row])
        self.encoding = text_list.apply(lambda row: self.pack_sentences(row))

        df['allegedly_violated_articles'] = df['allegedly_violated_articles'].apply(
            lambda row: [article for article in row if article in task_b])
        df['violated_articles'] = df['violated_articles'].apply(
            lambda row: [article for article in row if article in task_a])

        # Vocabulary confounders: one 0/1 feature per token group, zipped into a per-case tuple.
        if vocab_confounder_b:
            self.vocab_cf_b = self._vocab_features(df["facts"], vocab_confounder_b)
            self.vocab_cf_b_zip = list(zip(*self.vocab_cf_b))
        else:
            self.vocab_cf_b_zip = [0] * len(self.encoding)

        if vocab_confounder_a:
            self.vocab_cf_a = self._vocab_features(df["facts"], vocab_confounder_a)
            self.vocab_cf_a_zip = list(zip(*self.vocab_cf_a))
        else:
            self.vocab_cf_a_zip = [0] * len(self.encoding)

        if vocab_confounder_ab:
            self.vocab_cf_ab = self._vocab_features(df["facts"], vocab_confounder_ab)
            self.vocab_cf_ab_zip = list(zip(*self.vocab_cf_ab))
        else:
            self.vocab_cf_ab_zip = [0] * len(self.encoding)

        mlb_task_b = MultiLabelBinarizer()
        if task_b_classes is None:
            self.task_b_labels = mlb_task_b.fit_transform(df.pop('allegedly_violated_articles'))
        else:
            mlb_task_b.fit([task_b_classes])
            self.task_b_labels = mlb_task_b.transform(df.pop('allegedly_violated_articles'))

        mlb_task_a = MultiLabelBinarizer()
        if task_a_classes is None:
            self.task_a_labels = mlb_task_a.fit_transform(df.pop('violated_articles'))
        else:
            mlb_task_a.fit([task_a_classes])
            self.task_a_labels = mlb_task_a.transform(df.pop('violated_articles'))

        mlb_cou = MultiLabelBinarizer()
        if cou_classes is None:
            self.country = mlb_cou.fit_transform(df.pop('defendants'))
        else:
            mlb_cou.fit([cou_classes])
            self.country = mlb_cou.transform(df.pop('defendants'))

        if scaler is None:
            scaler = KBinsDiscretizer(n_bins=4, encode='onehot-dense', strategy='kmeans')
            self.length = scaler.fit_transform(df[['len']])
        else:
            self.length = scaler.transform(df[['len']])
        # Always expose the discretizer actually used (fitted here or passed in), so dev/test
        # datasets built with a training scaler also have `.scaler`.
        self.scaler = scaler

        self.silver_rationales = df['silver_rationales'].to_numpy()
        self.gold_rationales = df['gold_rationales'].to_numpy()

        self.task_b_classes = mlb_task_b.classes_
        self.task_a_classes = mlb_task_a.classes_
        self.cou_classes = mlb_cou.classes_
        # Number of examples.
        self.n_examples = len(self.task_a_labels)

    @staticmethod
    def _vocab_features(facts, token_groups):
        """One 0/1 Series per token group.

        An entry is 1 when any whitespace token of the case's joined paragraphs belongs to
        the group. Each case's token set is built once and reused for every group.
        """
        doc_tokens = facts.apply(lambda paragraphs: set(" ".join(paragraphs).split()))
        features = []
        for group in token_groups:
            group_set = set(group)
            features.append(doc_tokens.apply(lambda tokens: 1 if tokens & group_set else 0))
        return features

    def __len__(self):
        return self.n_examples

    def __getitem__(self, item):
        return (self.encoding[item], self.task_a_labels[item], self.task_b_labels[item], self.country[item],
                self.length[item], self.vocab_cf_a_zip[item], self.vocab_cf_b_zip[item], self.vocab_cf_ab_zip[item],
                self.silver_rationales[item], self.gold_rationales[item])

    def filter(self, sentence, spl_tokens):
        """Return `sentence` without the token ids listed in `spl_tokens`."""
        return [y for y in sentence if y not in spl_tokens]

    def pack_sentences(self, token_encodings):
        """Greedily pack tokenized paragraphs into chunks of at most max_sent_len tokens.

        A paragraph longer than max_sent_len is split into max_sent_len-sized chunks of its
        own, after flushing the chunk being built; its remainder starts a new chunk. At most
        max_doc_len chunks are kept. The input lists are not modified.
        """
        doc = []
        group = []
        for tokens in token_encodings:
            while len(tokens) > self.max_sent_len:
                if group:
                    doc.append(group)
                doc.append(tokens[:self.max_sent_len])
                tokens = tokens[self.max_sent_len:]
                group = []
            if len(group) + len(tokens) <= self.max_sent_len:
                group.extend(tokens)
            else:
                doc.append(group)
                group = list(tokens)

        if group:
            doc.append(group)

        if len(doc) > self.max_doc_len:
            doc = doc[:self.max_doc_len]
        return doc

In [None]:
class MyCollate:
    """Collate function turning a list of dataset items into padded batch tensors.

    Each item is the 10-tuple returned by the dataset's ``__getitem__``:
    (sentences, task A labels, task B labels, country, length, vocab_cf_a,
    vocab_cf_b, vocab_cf_ab, silver rationales, gold rationales), where
    ``sentences`` is a list of token-id lists. ``None`` items are dropped.
    """

    def __init__(self, pad_idx, sep_idx):
        """
        :param pad_idx: token id used to right-pad sentences.
        :param sep_idx: separator token id; stored as ``sep_token_id`` but not used by this class.
        """
        self.pad_token_id = pad_idx
        self.sep_token_id = sep_idx

    def pad_sentence_for_batch(self, tokens_lists, max_len: int):
        """Right-pad each token list to ``max_len``.

        :return: ``(token_ids, attention_masks, token_type_ids)``, three lists
            parallel to ``tokens_lists``. Padding uses ``pad_token_id``; masks are
            1 on real tokens and 0 on padding; token-type ids are all 0.
        """
        token_ids, attention_masks, token_type_ids = [], [], []
        for tokens in tokens_lists:
            n_real = len(tokens)
            n_pad = max_len - n_real
            token_ids.append(tokens + [self.pad_token_id] * n_pad)
            attention_masks.append([1] * n_real + [0] * n_pad)
            token_type_ids.append([0] * max_len)
        return token_ids, attention_masks, token_type_ids

    def __call__(self, batch):
        """Pad and stack a batch.

        :return: 14-tuple ``(docs, task_a_labels, task_b_labels, country, length,
            vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths,
            attention_mask, token_type_ids, silver, gold)``.
            - ``docs``, ``attention_mask`` and ``token_type_ids`` are long tensors
              of shape (batch, max sentences per doc, max tokens per sentence).
              Sentence slots beyond a document's end are all zeros.
            - ``sent_lengths`` is a default-dtype (float) tensor of shape
              (batch, max sentences per doc).
            - Label-like fields and ``doc_lengths`` are long tensors.
            - ``silver`` and ``gold`` are plain tuples.
        """
        present = [item for item in batch if item is not None]
        (docs, task_a_labels, task_b_labels, country, length,
         vocab_cf_a, vocab_cf_b, vocab_cf_ab, silver, gold) = zip(*present)

        doc_lengths = [len(doc) for doc in docs]
        sent_lengths = [[len(sentence) for sentence in doc] for doc in docs]

        n_docs = len(task_a_labels)
        max_sents = max(doc_lengths)
        max_tokens = max(max(lengths) for lengths in sent_lengths)

        grid = (n_docs, max_sents, max_tokens)
        docs_tensor = torch.zeros(grid, dtype=torch.long)
        att_mask_tensor = torch.zeros(grid, dtype=torch.long)
        token_type_tensor = torch.zeros(grid, dtype=torch.long)
        sent_lengths_tensor = torch.zeros((n_docs, max_sents))

        for d, doc in enumerate(docs):
            ids_rows, mask_rows, type_rows = self.pad_sentence_for_batch(doc, max_tokens)
            sent_lengths_tensor[d, :doc_lengths[d]] = torch.tensor(sent_lengths[d], dtype=torch.long)
            for s, ids_row in enumerate(ids_rows):
                docs_tensor[d, s, :] = torch.tensor(ids_row, dtype=torch.long)
                att_mask_tensor[d, s, :] = torch.tensor(mask_rows[s], dtype=torch.long)
                token_type_tensor[d, s, :] = torch.tensor(type_rows[s], dtype=torch.long)

        return (
            docs_tensor,
            torch.tensor(task_a_labels, dtype=torch.long),
            torch.tensor(task_b_labels, dtype=torch.long),
            torch.tensor(country, dtype=torch.long),
            torch.tensor(length, dtype=torch.long),
            torch.tensor(vocab_cf_a, dtype=torch.long),
            torch.tensor(vocab_cf_b, dtype=torch.long),
            torch.tensor(vocab_cf_ab, dtype=torch.long),
            torch.tensor(doc_lengths, dtype=torch.long),
            sent_lengths_tensor,
            att_mask_tensor,
            token_type_tensor,
            silver,
            gold,
        )

In [None]:
paragraph_num_removal = True
bert_path = "nlpaueb/legal-bert-base-uncased"

# Confounder word lists, one per task.
# NOTE(review): 'hearing' appears twice in the task-B list and 'march' twice in the
# task-AB list. They are left as-is because the confounder vector width presumably
# follows the list length — confirm whether the duplicates are intended.
vocab_confounder_b = ['hearing', 'born', 'adjourned', 'detained', 'noted', 'hearing', 'alleged', 'investigation', 'place', 'question', 'members']
vocab_confounder_a = ['stated', 'could', 'also', 'one', 'detained', 'investigation', 'within', 'due', 'second', 'hearing', 'certain']
vocab_confounder_ab = ['february', 'published', 'november', 'march', 'religious', 'investigation', 'first', 'service', 'letter', 'carried', 'would', 'one', 'submitted', 'head', 'march', 'damage', 'group', 'provided', 'seen']
# The lists used to be bound only to misspelled names, so the dataset calls below
# raised NameError. Keep the old spellings as aliases for any cell that uses them.
vocab_confoudner_a = vocab_confounder_a
vocab_confoudner_ab = vocab_confounder_ab

# Settings shared by all three splits.
dataset_kwargs = dict(
    bert_path=bert_path,
    max_sent_len=500,
    max_doc_len=40,
    paragraph_num_removal=paragraph_num_removal,
    vocab_confounder_b=vocab_confounder_b,
    vocab_confounder_a=vocab_confounder_a,
    vocab_confounder_ab=vocab_confounder_ab,
)

train_dataset = LexGlueDataset(lex_b['train'], **dataset_kwargs)
# Dev/test reuse what was fitted on the training split: the label classes and
# the KBinsDiscretizer for document length (stored as `train_dataset.scaler`).
# This keeps labels and length bins meaning the same thing across splits.
dev_dataset = LexGlueDataset(lex_b['val'], classes=train_dataset.classes, scaler=train_dataset.scaler, **dataset_kwargs)
test_dataset = LexGlueDataset(lex_b['test'], classes=train_dataset.classes, scaler=train_dataset.scaler, **dataset_kwargs)

In [None]:
batch_size = 16

# Training data is reshuffled every epoch. Dev/test keep dataset order, so
# evaluation is deterministic and predictions line up with dataset rows.
# NOTE(review): `sep_idx` receives the pad token id; MyCollate stores it but
# never reads it — confirm whether a real separator id was intended.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=MyCollate(pad_idx=train_dataset.pad_token_id, sep_idx=train_dataset.pad_token_id))
dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=MyCollate(pad_idx=dev_dataset.pad_token_id, sep_idx=dev_dataset.pad_token_id))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=MyCollate(pad_idx=test_dataset.pad_token_id, sep_idx=test_dataset.pad_token_id))

## Baseline: prediction only (no adversarial discriminators)

In [None]:
import numpy as np
from sklearn.metrics import f1_score

def compute_metrics(gold, predicted):
    """Macro/micro F1 for multi-label 0/1 predictions, with an extra "no label" class.

    Each input is turned into an int32 matrix of shape (n_examples, n_labels + 1).
    The appended last column is 1 exactly when the row has no positive label,
    so examples with an empty label set still count towards the scores.

    :param gold: reference label-indicator rows (2-D array-like).
    :param predicted: predicted label-indicator rows (2-D array-like).
    :return: dict with keys ``'macro-f1'`` and ``'micro-f1'``.
    """
    def _append_empty_column(rows):
        arr = np.asarray(rows)
        out = np.zeros((arr.shape[0], arr.shape[1] + 1), dtype=np.int32)
        out[:, :-1] = arr
        out[:, -1] = (arr.sum(axis=1) == 0)
        return out

    y_true = _append_empty_column(gold)
    y_pred = _append_empty_column(predicted)

    scores = {}
    for average in ('macro', 'micro'):
        scores[average + '-f1'] = f1_score(y_true=y_true, y_pred=y_pred, average=average, zero_division=0)
    return scores

In [None]:
import torch.optim as optim
from tqdm import tqdm
import itertools

num_epochs = 30
learning_rate = 1e-4
file_prefix = "a_vanila_"
task = 'a'  # one of 'a', 'b', 'ab'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Targets are task-B labels only for task 'b'; tasks 'a' and 'ab' predict task-A
# labels (see `labels = ...` in the loops), so the classifier head is sized to match.
classes_num = len(train_dataset.task_b_classes) if task == 'b' else len(train_dataset.task_a_classes)

model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, classes_num, "nlpaueb/legal-bert-base-uncased", task).to(device)

feature_extractor = model.HANmodel
classifier = model.classifier

# Freeze all BERT weights; only the non-BERT parameters are optimised.
for name, param in model.named_parameters():
    if "bert" in name:
        param.requires_grad = False

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
cls_criterion = nn.BCEWithLogitsLoss()

column_names = ["epoch", "mean epoch loss", "macro_f1", "micro_f1"]
train_df = pd.DataFrame(columns=column_names)
dev_df = pd.DataFrame(columns=column_names)

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # ---- training pass ----
    model.train()
    cls_predicted, cls_gold, loss_batch = [], [], []
    for batch_idx, (docs, task_a_labels, task_b_labels, country, length, vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths, att_mask, token_type_id, silver, gold) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
        labels = task_b_labels if task == 'b' else task_a_labels
        # Only the tensors used by the forward pass and the loss are moved to the device.
        docs, labels, doc_lengths, sent_lengths, att_mask, token_type_id = (
            docs.to(device), labels.to(device), doc_lengths.to(device),
            sent_lengths.to(device), att_mask.to(device), token_type_id.to(device))

        # The extractor's first output is the document embedding. Indexing it
        # (instead of unpacking a fixed count) works whether 3 or 4 values come back.
        doc_embeds = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)[0]
        if task == 'ab':
            # Condition on the task-B labels by appending them to the embedding.
            doc_embeds = torch.cat((doc_embeds, task_b_labels.to(device=device, dtype=doc_embeds.dtype)), 1)

        cls_scores = classifier(doc_embeds)
        loss = cls_criterion(cls_scores, labels.float())
        # Store a detached copy: keeping the live tensor would hold every batch's
        # autograd graph in memory until the end of the epoch.
        loss_batch.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()  # multi-label decision at p >= 0.5
        cls_gold.extend(labels.cpu().numpy())
        cls_predicted.extend(cls_pred_labels.cpu().numpy())

    # Average over batches, as the "mean epoch loss" column name states.
    mean_epoch_loss = torch.stack(loss_batch, dim=0).mean().item()
    metrics = compute_metrics(cls_gold, cls_predicted)
    train_df.loc[len(train_df)] = [epoch, round(mean_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2)]

    model_checkpoint = {'model_state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(model_checkpoint, file_prefix + str(epoch) + ".pth.tar")

    # ---- evaluation pass on the dev split ----
    model.eval()
    cls_predicted, cls_gold, loss_batch = [], [], []
    with torch.no_grad():
        for batch_idx, (docs, task_a_labels, task_b_labels, country, length, vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths, att_mask, token_type_id, silver, gold) in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader), leave=False):
            labels = task_b_labels if task == 'b' else task_a_labels
            docs, labels, doc_lengths, sent_lengths, att_mask, token_type_id = (
                docs.to(device), labels.to(device), doc_lengths.to(device),
                sent_lengths.to(device), att_mask.to(device), token_type_id.to(device))

            doc_embeds = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)[0]
            if task == 'ab':
                doc_embeds = torch.cat((doc_embeds, task_b_labels.to(device=device, dtype=doc_embeds.dtype)), 1)

            cls_scores = classifier(doc_embeds)
            loss_batch.append(cls_criterion(cls_scores, labels.float()))

            cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()
            cls_gold.extend(labels.cpu().numpy())
            cls_predicted.extend(cls_pred_labels.cpu().numpy())

    mean_epoch_loss = torch.stack(loss_batch, dim=0).mean().item()
    metrics = compute_metrics(cls_gold, cls_predicted)
    dev_df.loc[len(dev_df)] = [epoch, round(mean_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2)]

### Gradient reversal against length / vocab / country / all confounders

In [None]:
import numpy as np
from sklearn.metrics import f1_score

def compute_metrics(gold, predicted):
    """Score multi-label 0/1 predictions with macro and micro F1.

    An extra trailing column is appended to both label matrices. It marks rows
    with no positive label, so "nothing applies" is scored as a class of its own.
    Both matrices are cast to int32 before scoring.

    :param gold: reference label-indicator rows (2-D array-like).
    :param predicted: predicted label-indicator rows (2-D array-like).
    :return: dict ``{'macro-f1': float, 'micro-f1': float}``.
    """
    y_true, y_pred = (
        np.column_stack([arr, arr.sum(axis=1) == 0]).astype(np.int32)
        for arr in (np.array(gold), np.array(predicted))
    )

    macro = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
    micro = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
    return {'macro-f1': macro, 'micro-f1': micro}

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import itertools

num_epochs = 30
violation_learning_rate = 1e-4
country_learning_rate = 5e-5
length_learning_rate = 5e-5
vocab_learning_rate = 5e-5
gamma_country = 0.1
gamma_length = 0.1
gamma_vocab = 0.1
country_only = True
length_only = False
vocab_only = False
country_and_length_and_vocab = False
file_prefix = 'a_grad_cou_'
task = 'a'  # one of 'a', 'b', 'ab'

# Exactly one adversarial configuration must be selected. The flags are counted
# because a chained XOR (a ^ b ^ c ^ d) is also True when three flags are set.
assert sum([country_only, length_only, vocab_only, country_and_length_and_vocab]) == 1, \
    "set exactly one of country_only / length_only / vocab_only / country_and_length_and_vocab"

use_country = country_only or country_and_length_and_vocab
use_length = length_only or country_and_length_and_vocab
use_vocab = vocab_only or country_and_length_and_vocab

# Position of each task's vocab-confounder vector inside a dataset item
# (dataset __getitem__ order: ..., vocab_cf_a, vocab_cf_b, vocab_cf_ab, ...).
vocab_item_index = {'a': 5, 'b': 6, 'ab': 7}[task]
# The vocab loss compares flattened (batch, vocab_classes) scores with flattened
# targets, so the head width is read from the data instead of being hard-coded.
# NOTE(review): the former constants (a=14, b=11, ab=19) did not all match the
# confounder word-list lengths — confirm against the dataset.
vocab_classes = len(train_dataset[0][vocab_item_index])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Targets are task-B labels only for task 'b' (see `labels = ...` below), so the
# classifier head is sized the same way.
classes_num = len(train_dataset.task_b_classes) if task == 'b' else len(train_dataset.task_a_classes)
model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, classes_num, "nlpaueb/legal-bert-base-uncased", task).to(device)

feature_extractor = model.HANmodel
classifier = model.classifier

# Freeze all BERT weights; only the non-BERT parameters are optimised.
for name, param in model.named_parameters():
    if "bert" in name:
        param.requires_grad = False

params_dict = [{'params': [p for p in model.parameters() if p.requires_grad], 'lr': violation_learning_rate}]


def _make_discriminator(n_outputs):
    """Build the MLP head shared by all adversaries: Linear(400, 200) -> ReLU -> Dropout(0.1) -> Linear(200, n_outputs)."""
    return nn.Sequential(
        torch.nn.Linear(400, 200),
        nn.ReLU(),
        nn.Dropout(0.1),
        torch.nn.Linear(200, n_outputs),
    ).to(device)


if use_country:
    country_discriminator = _make_discriminator(len(train_dataset.cou_classes))
    params_dict.append({'params': list(country_discriminator.parameters()), 'lr': country_learning_rate})

if use_length:
    # 4 outputs = number of KBinsDiscretizer bins the dataset uses for document length.
    len_discriminator = _make_discriminator(4)
    params_dict.append({'params': list(len_discriminator.parameters()), 'lr': length_learning_rate})

if use_vocab:
    vocab_discriminator = _make_discriminator(vocab_classes)
    params_dict.append({'params': list(vocab_discriminator.parameters()), 'lr': vocab_learning_rate})

optimizer = optim.Adam(params_dict)

cls_criterion = nn.BCEWithLogitsLoss()
len_criterion = nn.CrossEntropyLoss()
vocab_criterion = nn.BCEWithLogitsLoss()
country_criterion = nn.CrossEntropyLoss()

column_names = ["cls_loss", "cls_macro_f1", "cls_micro_f1", "cou_loss", "cou_accuracy", "cou_precision", "cou_recall", "cou_f1", "len_loss", "len_accuracy", "len_precision", "len_recall", "len_f1", "vocab_loss", "vocab_accuracy", "vocab_precision", "vocab_recall", "vocab_f1"]
train_df = pd.DataFrame(columns=column_names)
dev_df = pd.DataFrame(columns=column_names)


for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # ---- training pass ----
    model.train()
    if use_country:
        country_discriminator.train()
    if use_length:
        len_discriminator.train()
    if use_vocab:
        vocab_discriminator.train()

    cls_predicted, cls_gold, cls_loss_batch = [], [], []
    country_predicted, country_gold, country_loss_batch = [], [], []
    vocab_predicted, vocab_gold, vocab_loss_batch = [], [], []
    len_predicted, len_gold, len_loss_batch = [], [], []
    current_iteration = 0

    for batch_idx, (docs, task_a_labels, task_b_labels, country, length, vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths, att_mask, token_type_id, silver, gold) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):

        current_iteration += 1
        # Overall training progress in [0, 1]; drives the gradient-reversal schedule.
        p = float(current_iteration + epoch * len(train_dataloader)) / num_epochs / len(train_dataloader)
        lambda_country, lambda_length, lambda_vocab = get_lambda(gamma_country, p), get_lambda(gamma_length, p), get_lambda(gamma_vocab, p)

        labels = task_b_labels if task == 'b' else task_a_labels
        # Vocab-confounder targets belonging to the current task.
        vocab_cf = {'a': vocab_cf_a, 'b': vocab_cf_b, 'ab': vocab_cf_ab}[task]

        docs, labels, country, length, vocab, doc_lengths, sent_lengths, att_mask, token_type_id = docs.to(device), labels.to(device), country.to(device), length.to(device), vocab_cf.to(device), doc_lengths.to(device), sent_lengths.to(device), att_mask.to(device), token_type_id.to(device)

        # The first extractor output is the document embedding.
        doc_embeds = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)[0]
        # The discriminators are built for 400-d inputs, so they always see the
        # text embedding. Only the classifier gets the task-B labels appended for 'ab'.
        text_embeds = doc_embeds
        if task == 'ab':
            doc_embeds = torch.cat((doc_embeds, task_b_labels.to(device=device, dtype=doc_embeds.dtype)), 1)

        cls_scores = classifier(doc_embeds)
        cls_loss = cls_criterion(cls_scores, labels.float())
        # Store detached copies: keeping live loss tensors would hold every batch's
        # autograd graph in memory until the end of the epoch.
        cls_loss_batch.append(cls_loss.detach())
        loss = cls_loss

        if use_country:
            country_scores = country_discriminator(GradientReversalFunction.apply(text_embeds, lambda_country))
            country_loss = country_criterion(country_scores, country.argmax(dim=1))
            country_loss_batch.append(country_loss.detach())
            loss = loss + country_loss
        if use_length:
            len_scores = len_discriminator(GradientReversalFunction.apply(text_embeds, lambda_length))
            len_loss = len_criterion(len_scores, length.argmax(dim=1))
            len_loss_batch.append(len_loss.detach())
            loss = loss + len_loss
        if use_vocab:
            vocab_scores = vocab_discriminator(GradientReversalFunction.apply(text_embeds, lambda_vocab))
            vocab_loss = vocab_criterion(vocab_scores.flatten(), vocab.float().flatten())
            vocab_loss_batch.append(vocab_loss.detach())
            loss = loss + vocab_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()  # multi-label decision at p >= 0.5
        cls_gold.extend(labels.cpu().numpy())
        cls_predicted.extend(cls_pred_labels.cpu().numpy())

        # Single-label heads: predict argmax over the class dimension (dim=1).
        # A softmax over dim=0 would normalise across the batch and change the argmax.
        if use_length:
            len_gold.extend(length.argmax(dim=1).cpu().numpy())
            len_predicted.extend(len_scores.detach().argmax(dim=1).cpu().numpy())

        if use_vocab:
            vocab_pred_labels = (torch.sigmoid(vocab_scores) >= 0.5).long()
            vocab_gold.extend(vocab.cpu().numpy())
            vocab_predicted.extend(vocab_pred_labels.cpu().numpy())

        if use_country:
            country_gold.extend(country.argmax(dim=1).cpu().numpy())
            country_predicted.extend(country_scores.detach().argmax(dim=1).cpu().numpy())

    # Epoch losses are sums over batches.
    cls_epoch_loss = torch.stack(cls_loss_batch, dim=0).sum(dim=0).item()
    if use_country:
        country_epoch_loss = torch.stack(country_loss_batch, dim=0).sum(dim=0).item()
    if use_length:
        len_epoch_loss = torch.stack(len_loss_batch, dim=0).sum(dim=0).item()
    if use_vocab:
        vocab_epoch_loss = torch.stack(vocab_loss_batch, dim=0).sum(dim=0).item()

    metrics = compute_metrics(cls_gold, cls_predicted)

    if use_vocab:
        vocab_accuracy, vocab_precision, vocab_recall, vocab_f1 = accuracy_score(vocab_gold, vocab_predicted), precision_score(vocab_gold, vocab_predicted, average='macro'), recall_score(vocab_gold, vocab_predicted, average='macro'), f1_score(vocab_gold, vocab_predicted, average='macro')

    if use_length:
        len_accuracy, len_precision, len_recall, len_f1 = accuracy_score(len_gold, len_predicted), precision_score(len_gold, len_predicted, average='macro'), recall_score(len_gold, len_predicted, average='macro'), f1_score(len_gold, len_predicted, average='macro')

    if use_country:
        country_accuracy, country_precision, country_recall, country_f1 = accuracy_score(country_gold, country_predicted), precision_score(country_gold, country_predicted, average='macro'), recall_score(country_gold, country_predicted, average='macro'), f1_score(country_gold, country_predicted, average='macro')

    if country_only:
        train_df.loc[len(train_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), round(country_epoch_loss, 2), round(100 * country_accuracy, 2), round(100 * country_precision, 2), round(100 * country_recall, 2), round(100 * country_f1, 2), "", "", "", "", "", "", "", "", "", ""]
        model_checkpoint = {'model_state_dict': model.state_dict(), 'cou_state_dict': country_discriminator.state_dict(), 'optimizer': optimizer.state_dict()}

    if length_only:
        train_df.loc[len(train_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), "", "", "", "", "", round(len_epoch_loss, 2), round(100 * len_accuracy, 2), round(100 * len_precision, 2), round(100 * len_recall, 2), round(100 * len_f1, 2), "", "", "", "", ""]
        model_checkpoint = {'model_state_dict': model.state_dict(), 'len_state_dict': len_discriminator.state_dict(), 'optimizer': optimizer.state_dict()}

    if vocab_only:
        train_df.loc[len(train_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), "", "", "", "", "", "", "", "", "", "", round(vocab_epoch_loss, 2), round(100 * vocab_accuracy, 2), round(100 * vocab_precision, 2), round(100 * vocab_recall, 2), round(100 * vocab_f1, 2)]
        model_checkpoint = {'model_state_dict': model.state_dict(), 'vocab_state_dict': vocab_discriminator.state_dict(), 'optimizer': optimizer.state_dict()}

    if country_and_length_and_vocab:
        train_df.loc[len(train_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), round(country_epoch_loss, 2), round(100 * country_accuracy, 2), round(100 * country_precision, 2), round(100 * country_recall, 2), round(100 * country_f1, 2), round(len_epoch_loss, 2), round(100 * len_accuracy, 2), round(100 * len_precision, 2), round(100 * len_recall, 2), round(100 * len_f1, 2), round(vocab_epoch_loss, 2), round(100 * vocab_accuracy, 2), round(100 * vocab_precision, 2), round(100 * vocab_recall, 2), round(100 * vocab_f1, 2)]
        model_checkpoint = {'model_state_dict': model.state_dict(), 'country_state_dict': country_discriminator.state_dict(), 'len_state_dict': len_discriminator.state_dict(), 'vocab_state_dict': vocab_discriminator.state_dict(), 'optimizer': optimizer.state_dict()}

    torch.save(model_checkpoint, file_prefix + str(epoch) + ".pth.tar")

    # ---- evaluation pass on the dev split ----
    model.eval()
    if use_country:
        country_discriminator.eval()
    if use_length:
        len_discriminator.eval()
    if use_vocab:
        vocab_discriminator.eval()
    cls_predicted, cls_gold, cls_loss_batch = [], [], []
    country_predicted, country_gold, country_loss_batch = [], [], []
    vocab_predicted, vocab_gold, vocab_loss_batch = [], [], []
    len_predicted, len_gold, len_loss_batch = [], [], []
    with torch.no_grad():
        for batch_idx, (docs, task_a_labels, task_b_labels, country, length, vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths, att_mask, token_type_id, silver, gold) in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader), leave=False):
            labels = task_b_labels if task == 'b' else task_a_labels
            vocab_cf = {'a': vocab_cf_a, 'b': vocab_cf_b, 'ab': vocab_cf_ab}[task]

            docs, labels, country, length, vocab, doc_lengths, sent_lengths, att_mask, token_type_id = docs.to(device), labels.to(device), country.to(device), length.to(device), vocab_cf.to(device), doc_lengths.to(device), sent_lengths.to(device), att_mask.to(device), token_type_id.to(device)

            doc_embeds = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)[0]
            text_embeds = doc_embeds
            if task == 'ab':
                doc_embeds = torch.cat((doc_embeds, task_b_labels.to(device=device, dtype=doc_embeds.dtype)), 1)

            cls_scores = classifier(doc_embeds)
            cls_loss_batch.append(cls_criterion(cls_scores, labels.float()))

            # Lambdas are the last values from the training pass (no gradients flow here).
            if use_country:
                country_scores = country_discriminator(GradientReversalFunction.apply(text_embeds, lambda_country))
                country_loss_batch.append(country_criterion(country_scores, country.argmax(dim=1)))
            if use_length:
                len_scores = len_discriminator(GradientReversalFunction.apply(text_embeds, lambda_length))
                len_loss_batch.append(len_criterion(len_scores, length.argmax(dim=1)))
            if use_vocab:
                vocab_scores = vocab_discriminator(GradientReversalFunction.apply(text_embeds, lambda_vocab))
                vocab_loss_batch.append(vocab_criterion(vocab_scores.flatten(), vocab.float().flatten()))

            cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()
            cls_gold.extend(labels.cpu().numpy())
            cls_predicted.extend(cls_pred_labels.cpu().numpy())

            if use_length:
                len_gold.extend(length.argmax(dim=1).cpu().numpy())
                len_predicted.extend(len_scores.argmax(dim=1).cpu().numpy())

            if use_vocab:
                vocab_pred_labels = (torch.sigmoid(vocab_scores) >= 0.5).long()
                vocab_gold.extend(vocab.cpu().numpy())
                vocab_predicted.extend(vocab_pred_labels.cpu().numpy())

            if use_country:
                country_gold.extend(country.argmax(dim=1).cpu().numpy())
                country_predicted.extend(country_scores.argmax(dim=1).cpu().numpy())

    cls_epoch_loss = torch.stack(cls_loss_batch, dim=0).sum(dim=0).item()
    if use_country:
        country_epoch_loss = torch.stack(country_loss_batch, dim=0).sum(dim=0).item()
    if use_length:
        len_epoch_loss = torch.stack(len_loss_batch, dim=0).sum(dim=0).item()
    if use_vocab:
        vocab_epoch_loss = torch.stack(vocab_loss_batch, dim=0).sum(dim=0).item()

    metrics = compute_metrics(cls_gold, cls_predicted)

    if use_vocab:
        vocab_accuracy, vocab_precision, vocab_recall, vocab_f1 = accuracy_score(vocab_gold, vocab_predicted), precision_score(vocab_gold, vocab_predicted, average='macro'), recall_score(vocab_gold, vocab_predicted, average='macro'), f1_score(vocab_gold, vocab_predicted, average='macro')

    if use_length:
        len_accuracy, len_precision, len_recall, len_f1 = accuracy_score(len_gold, len_predicted), precision_score(len_gold, len_predicted, average='macro'), recall_score(len_gold, len_predicted, average='macro'), f1_score(len_gold, len_predicted, average='macro')

    if use_country:
        country_accuracy, country_precision, country_recall, country_f1 = accuracy_score(country_gold, country_predicted), precision_score(country_gold, country_predicted, average='macro'), recall_score(country_gold, country_predicted, average='macro'), f1_score(country_gold, country_predicted, average='macro')

    if country_only:
        dev_df.loc[len(dev_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), round(country_epoch_loss, 2), round(100 * country_accuracy, 2), round(100 * country_precision, 2), round(100 * country_recall, 2), round(100 * country_f1, 2), "", "", "", "", "", "", "", "", "", ""]

    if length_only:
        dev_df.loc[len(dev_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), "", "", "", "", "", round(len_epoch_loss, 2), round(100 * len_accuracy, 2), round(100 * len_precision, 2), round(100 * len_recall, 2), round(100 * len_f1, 2), "", "", "", "", ""]

    if vocab_only:
        dev_df.loc[len(dev_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), "", "", "", "", "", "", "", "", "", "", round(vocab_epoch_loss, 2), round(100 * vocab_accuracy, 2), round(100 * vocab_precision, 2), round(100 * vocab_recall, 2), round(100 * vocab_f1, 2)]

    if country_and_length_and_vocab:
        dev_df.loc[len(dev_df)] = [round(cls_epoch_loss, 2), round(100 * metrics['macro-f1'], 2), round(100 * metrics['micro-f1'], 2), round(country_epoch_loss, 2), round(100 * country_accuracy, 2), round(100 * country_precision, 2), round(100 * country_recall, 2), round(100 * country_f1, 2), round(len_epoch_loss, 2), round(100 * len_accuracy, 2), round(100 * len_precision, 2), round(100 * len_recall, 2), round(100 * len_f1, 2), round(vocab_epoch_loss, 2), round(100 * vocab_accuracy, 2), round(100 * vocab_precision, 2), round(100 * vocab_recall, 2), round(100 * vocab_f1, 2)]

    print(train_df)
    print(dev_df)


### Evaluate a saved checkpoint

In [None]:
checkpoint_file = 'bert_a_s_t_grl_src9.pth.tar'
task = 'a'  # must match the task the checkpoint was trained on
# Split to score; set to dev_dataloader to evaluate on the validation split instead.
eval_dataloader = test_dataloader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The head must have the same width as in training, or load_state_dict fails on a
# shape mismatch. Targets are task-B labels only for task 'b'.
classes_num = len(train_dataset.task_b_classes) if task == 'b' else len(train_dataset.task_a_classes)
model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, classes_num, "nlpaueb/legal-bert-base-uncased", task).to(device)
feature_extractor = model.HANmodel
classifier = model.classifier

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: disables dropout so predictions are deterministic.
model.eval()

cls_predicted, cls_gold = [], []

with torch.no_grad():
    for batch_idx, (docs, task_a_labels, task_b_labels, country, length, vocab_cf_a, vocab_cf_b, vocab_cf_ab, doc_lengths, sent_lengths, att_mask, token_type_id, silver, gold) in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), leave=False):
        labels = task_b_labels if task == 'b' else task_a_labels
        # Only the tensors needed for prediction are moved to the device.
        docs, labels, doc_lengths, sent_lengths, att_mask, token_type_id = (
            docs.to(device), labels.to(device), doc_lengths.to(device),
            sent_lengths.to(device), att_mask.to(device), token_type_id.to(device))

        # The first extractor output is the document embedding.
        doc_embeds = feature_extractor(docs, doc_lengths, sent_lengths, att_mask, token_type_id)[0]
        if task == 'ab':
            doc_embeds = torch.cat((doc_embeds, task_b_labels.to(device=device, dtype=doc_embeds.dtype)), 1)

        cls_scores = classifier(doc_embeds)
        cls_pred_labels = (torch.sigmoid(cls_scores) >= 0.5).long()  # multi-label decision at p >= 0.5
        cls_gold.extend(labels.cpu().numpy())
        cls_predicted.extend(cls_pred_labels.cpu().numpy())

metrics = compute_metrics(cls_gold, cls_predicted)
print(metrics)

### IG

In [None]:
import pandas as pd
import re
from transformers import AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer


# Preprocessing settings used by `encode` and `pack_sentences`.
tokenizer_path = "nlpaueb/legal-bert-base-uncased"
paragraph_num_removal = True   # strip a leading "N. " from each paragraph before tokenizing
max_sent_len = 500             # max tokens per packed group
max_doc_len = 40               # max number of groups kept per document
# Article labels the case labels are binarised against.
classes = ['10', '11', '14', '2', '3', '5', '6', '8', '9', 'P1-1']

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def spl_filter(sentence, spl_tokens):
    """Return a new list with every token of `sentence` that is not in `spl_tokens`, in order."""
    kept = []
    for token in sentence:
        if token in spl_tokens:
            continue
        kept.append(token)
    return kept

def pack_sentences(token_encodings, max_sent_len, max_doc_len):
    """Pack tokenized paragraphs into a document of bounded-size token groups.

    Consecutive paragraphs are concatenated into one group while the group stays within
    `max_sent_len` tokens. A paragraph longer than `max_sent_len` is first cut into
    full-size chunks, each stored as its own group, and its remainder starts a new group.
    The document is then truncated to its first `max_doc_len` groups.

    :param token_encodings: list of token-id lists, one per paragraph. Not modified.
    :param max_sent_len: maximum number of tokens per group (must be >= 1).
    :param max_doc_len: maximum number of groups kept.
    :return: (doc, mapping). `doc` is a list of token lists. `mapping` has one dict per
        paragraph: "start"/"end" are the inclusive token offsets of the paragraph's last
        piece inside group "end_doc", and "start_doc" is the index of the group where the
        paragraph begins. After truncation only paragraphs whose "end_doc" is kept remain.
    :raises ValueError: if max_sent_len < 1 (the chunking loop could never terminate).
    """
    if max_sent_len < 1:
        raise ValueError("max_sent_len must be >= 1, got %r" % (max_sent_len,))

    doc = []
    mapping = []
    group = []       # group being filled; it becomes doc[len(doc)] when flushed
    group_len = 0    # always equal to len(group)

    for paragraph in token_encodings:
        pieces = list(paragraph)   # private copy: the caller's lists are never mutated or aliased
        start_doc = len(doc)       # where the paragraph starts if it joins the open group

        # Oversized paragraph: flush the open group, then emit full-size chunks as groups.
        while len(pieces) > max_sent_len:
            if group:
                doc.append(group)
                start_doc = len(doc)   # the first chunk lands right after the flushed group
            doc.append(pieces[:max_sent_len])
            pieces = pieces[max_sent_len:]
            group = []
            group_len = 0

        start = group_len
        if group_len + len(pieces) <= max_sent_len:
            # Fits: append to the open group.
            group.extend(pieces)
            group_len += len(pieces)
            mapping.append({"start": start, "end": group_len - 1, "start_doc": start_doc, "end_doc": len(doc)})
        else:
            # Does not fit: close the open group and start a new one with this paragraph.
            # (Only reachable when no chunking happened above, so the paragraph starts
            # and ends in the new group at index len(doc).)
            doc.append(group)
            group = pieces
            group_len = len(pieces)
            mapping.append({"start": 0, "end": group_len - 1, "start_doc": len(doc), "end_doc": len(doc)})

    if group:
        doc.append(group)

    if len(doc) > max_doc_len:
        doc = doc[:max_doc_len]
        # Keep only paragraphs whose last piece lies in a retained group.
        mapping = [m for m in mapping if m["end_doc"] < max_doc_len]

    return doc, mapping

def encode(text, art_id):
    """Turn one case into model input.

    :param text: list of paragraph strings.
    :param art_id: the case's article labels, binarised against the module-level `classes`.
    :return: (encoding, mapping, b_labels): packed token groups and paragraph mapping from
        `pack_sentences`, and the multi-hot label array produced by `MultiLabelBinarizer`.
    """
    paragraphs = text
    if paragraph_num_removal:
        # Remove a leading paragraph number such as "12. ".
        paragraphs = [re.sub(r'^\d+\. ', '', para) for para in paragraphs]

    # Token ids per paragraph with the tokenizer's [SEP]/[CLS] ids filtered out.
    special_ids = [tokenizer.sep_token_id, tokenizer.cls_token_id]
    token_lists = [spl_filter(tokenizer.encode(para), special_ids) for para in paragraphs]

    # Multi-hot label row over the fixed class list.
    label_encoder = MultiLabelBinarizer()
    label_encoder.fit([[label] for label in classes])
    b_labels = label_encoder.transform([art_id])

    encoding, mapping = pack_sentences(token_lists, max_sent_len, max_doc_len)
    return encoding, mapping, b_labels

In [None]:
import pandas as pd
import re
from transformers import AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer


# NOTE(review): this cell repeats the previous cell verbatim; one copy can be removed.
# Preprocessing settings used by `encode` and `pack_sentences`.
tokenizer_path = "nlpaueb/legal-bert-base-uncased"
paragraph_num_removal = True   # strip a leading "N. " from each paragraph before tokenizing
max_sent_len = 500             # max tokens per packed group
max_doc_len = 40               # max number of groups kept per document
# Article labels the case labels are binarised against.
classes = ['10', '11', '14', '2', '3', '5', '6', '8', '9', 'P1-1']

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

def spl_filter(sentence, spl_tokens):
    """Return a new list with every token of `sentence` that is not in `spl_tokens`, in order."""
    kept = []
    for token in sentence:
        if token in spl_tokens:
            continue
        kept.append(token)
    return kept

def pack_sentences(token_encodings, max_sent_len, max_doc_len):
    """Pack tokenized paragraphs into a document of bounded-size token groups.

    Consecutive paragraphs are concatenated into one group while the group stays within
    `max_sent_len` tokens. A paragraph longer than `max_sent_len` is first cut into
    full-size chunks, each stored as its own group, and its remainder starts a new group.
    The document is then truncated to its first `max_doc_len` groups.

    :param token_encodings: list of token-id lists, one per paragraph. Not modified.
    :param max_sent_len: maximum number of tokens per group (must be >= 1).
    :param max_doc_len: maximum number of groups kept.
    :return: (doc, mapping). `doc` is a list of token lists. `mapping` has one dict per
        paragraph: "start"/"end" are the inclusive token offsets of the paragraph's last
        piece inside group "end_doc", and "start_doc" is the index of the group where the
        paragraph begins. After truncation only paragraphs whose "end_doc" is kept remain.
    :raises ValueError: if max_sent_len < 1 (the chunking loop could never terminate).
    """
    if max_sent_len < 1:
        raise ValueError("max_sent_len must be >= 1, got %r" % (max_sent_len,))

    doc = []
    mapping = []
    group = []       # group being filled; it becomes doc[len(doc)] when flushed
    group_len = 0    # always equal to len(group)

    for paragraph in token_encodings:
        pieces = list(paragraph)   # private copy: the caller's lists are never mutated or aliased
        start_doc = len(doc)       # where the paragraph starts if it joins the open group

        # Oversized paragraph: flush the open group, then emit full-size chunks as groups.
        while len(pieces) > max_sent_len:
            if group:
                doc.append(group)
                start_doc = len(doc)   # the first chunk lands right after the flushed group
            doc.append(pieces[:max_sent_len])
            pieces = pieces[max_sent_len:]
            group = []
            group_len = 0

        start = group_len
        if group_len + len(pieces) <= max_sent_len:
            # Fits: append to the open group.
            group.extend(pieces)
            group_len += len(pieces)
            mapping.append({"start": start, "end": group_len - 1, "start_doc": start_doc, "end_doc": len(doc)})
        else:
            # Does not fit: close the open group and start a new one with this paragraph.
            # (Only reachable when no chunking happened above, so the paragraph starts
            # and ends in the new group at index len(doc).)
            doc.append(group)
            group = pieces
            group_len = len(pieces)
            mapping.append({"start": 0, "end": group_len - 1, "start_doc": len(doc), "end_doc": len(doc)})

    if group:
        doc.append(group)

    if len(doc) > max_doc_len:
        doc = doc[:max_doc_len]
        # Keep only paragraphs whose last piece lies in a retained group.
        mapping = [m for m in mapping if m["end_doc"] < max_doc_len]

    return doc, mapping

def encode(text, art_id):
    """Turn one case into model input.

    :param text: list of paragraph strings.
    :param art_id: the case's article labels, binarised against the module-level `classes`.
    :return: (encoding, mapping, b_labels): packed token groups and paragraph mapping from
        `pack_sentences`, and the multi-hot label array produced by `MultiLabelBinarizer`.
    """
    paragraphs = text
    if paragraph_num_removal:
        # Remove a leading paragraph number such as "12. ".
        paragraphs = [re.sub(r'^\d+\. ', '', para) for para in paragraphs]

    # Token ids per paragraph with the tokenizer's [SEP]/[CLS] ids filtered out.
    special_ids = [tokenizer.sep_token_id, tokenizer.cls_token_id]
    token_lists = [spl_filter(tokenizer.encode(para), special_ids) for para in paragraphs]

    # Multi-hot label row over the fixed class list.
    label_encoder = MultiLabelBinarizer()
    label_encoder.fit([[label] for label in classes])
    b_labels = label_encoder.transform([art_id])

    encoding, mapping = pack_sentences(token_lists, max_sent_len, max_doc_len)
    return encoding, mapping, b_labels

In [None]:
# Load the test split: one JSON object per line (JSON text is UTF-8).
with open('test.jsonl', 'r', encoding='utf-8') as json_file:
    json_list = list(json_file)

# Blank lines (e.g. a stray empty line at EOF) are skipped; json.loads would reject them.
data = [json.loads(json_str) for json_str in json_list if json_str.strip()]

df = pd.DataFrame(data)
# .copy() gives an independent frame so the column assignments below don't act on a slice.
df = df[['case_id', 'facts', 'defendants', 'allegedly_violated_articles', 'violated_articles',
         'silver_rationales', 'gold_rationales']].copy()

# Article labels kept for each task (currently the same 10 articles for both).
task_b = {'6', '3', '5', 'P1-1', '8', '2', '10', '11', '14', '9'}
task_a = {'6', '3', '5', 'P1-1', '8', '2', '10', '11', '14', '9'}

# Keep only cases with at least one gold rationale.
df['rationales_len'] = df['gold_rationales'].apply(len)
df = df.loc[df['rationales_len'] != 0].copy()

# Drop labels outside each task's article set.
df['allegedly_violated_articles'] = df['allegedly_violated_articles'].apply(
    lambda row: [article for article in row if article in task_b])
df['violated_articles'] = df['violated_articles'].apply(
    lambda row: [article for article in row if article in task_a])

In [None]:
def IG(text, b_labels, task, n_steps=1):
    """Token-level Layer Integrated Gradients attributions for one case.

    Relies on module-level state: `model`, `lig` (the LayerIntegratedGradients instance),
    `tokenizer`, `device`, and the `encode` / `batchify` helpers.

    :param text: list of paragraph strings for the case.
    :param b_labels: article labels for the case. `encode` turns them into a multi-hot row,
        which is forwarded to `lig`'s forward function.
    :param task: task identifier forwarded to the forward function.
    :param n_steps: number of integration steps passed to `lig.attribute`. The default of 1
        keeps the previous fixed value. Larger values approximate the IG path integral more
        closely at proportionally higher cost.
    :return: (attributions, encoding, mapping). The attributions are summed over the last
        dimension (presumably the embedding dimension) and scaled to unit L2 norm; they are
        left unscaled when the norm is 0. `encoding` and `mapping` come from `encode`.
    """
    encoding, mapping, b_labels = encode(text, b_labels)
    docs, doc_lengths, sent_lengths, att_mask, token_type_id = batchify(encoding, tokenizer.pad_token_id)
    docs = docs.to(device)
    doc_lengths = doc_lengths.to(device)
    sent_lengths = sent_lengths.to(device)
    att_mask = att_mask.to(device)
    token_type_id = token_type_id.to(device)
    # Float labels so they can be concatenated with the float document embedding.
    b_labels = torch.tensor(b_labels, dtype=torch.float, device=device)

    sents, sent_lengths_pre, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx = (
        model.HANmodel.word_attention_pre(docs, doc_lengths, sent_lengths, att_mask, token_type_id)
    )

    # Baseline: every position set to the pad token id, with the shape/dtype/device of `sents`.
    reference_indices = torch.full_like(sents, tokenizer.pad_token_id)

    attributions_ig = lig.attribute(
        sents,
        baselines=reference_indices,
        additional_forward_args=(sent_lengths_pre, attn_masks, token_types, docs_valid_bsz,
                                 doc_perm_idx, sent_perm_idx, b_labels, task),
        n_steps=n_steps,
    )
    attributions = attributions_ig.sum(dim=-1)
    norm = torch.norm(attributions)
    # Avoid 0/0 -> NaN when every attribution is zero.
    if norm > 0:
        attributions = attributions / norm
    return attributions, encoding, mapping
      

In [None]:
import pickle


import numpy

!pip install captum
from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients, IntegratedGradients
from captum.attr import configure_interpretable_embedding_layer

import gc

# Release memory held over from earlier cells before running attributions.
gc.collect()
torch.cuda.empty_cache()

# NOTE(review): cuDNN is disabled globally, presumably because cuDNN RNN backward is only
# supported in training mode and gradients are needed here with the model in eval mode. Confirm.
torch.backends.cudnn.enabled = False

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")




def my_forward_func(sents, sent_lengths, attn_masks, token_types, docs_valid_bsz,  doc_perm_idx, sent_perm_idx, b_labels, task, han_model=None):
    """Forward function for Captum: run the HAN classifier from the word-attention stage on.

    :param sents, sent_lengths, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx:
        values produced by `HANmodel.word_attention_pre`, passed on to `word_attention_post`.
    :param b_labels: label tensor concatenated to the document embedding when task == 'ab'.
    :param task: task identifier; only 'ab' changes the classifier input.
    :param han_model: model to run; defaults to the module-level `model`.
    :return: sigmoid probabilities, with dim 1 and then dim 0 squeezed when they have size 1.
    """
    if han_model is None:
        han_model = model
    sent_embeddings, doc_perm_idx, docs_valid_bsz, word_att_weights = han_model.HANmodel.word_attention_post(
        sents, sent_lengths, attn_masks, token_types, docs_valid_bsz, doc_perm_idx, sent_perm_idx
    )
    doc_embeds, word_att_weights, sentence_att_weights = han_model.HANmodel.sentence_attention(
        sent_embeddings, doc_perm_idx, docs_valid_bsz, word_att_weights
    )
    # The classifier sees the document embedding alone unless task 'ab' also feeds the labels.
    doc_features = doc_embeds
    if task == 'ab':
        doc_features = torch.cat((doc_embeds, b_labels.to(doc_embeds.dtype)), 1)
    cls_scores = han_model.classifier(doc_features)
    cls_scores = cls_scores.squeeze(dim=1)
    cls_preds = torch.sigmoid(cls_scores)
    return cls_preds.squeeze(0)

scores_dict = {}

# NOTE(review): `i` is not defined in this cell. It must already exist in the kernel
# (presumably a run index from an earlier loop). Define it explicitly so Restart & Run All works.
checkpoint_file = 'final_lex/a_b_grad_all_' + str(i) + '.pth.tar'
task = 'a'
model_arch = 'a_b_grad_all'

# Construct the model first, then restore the trained weights into it. Constructing it
# after load_state_dict would discard the checkpoint and explain an untrained classifier.
model = HANClassifierModel(device, 0.1, 768, 300, 400, 200, 100, 10, "nlpaueb/legal-bert-base-uncased", task).to(device)
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Attribute with respect to the BERT embedding layer used by word_attention_post.
lig = LayerIntegratedGradients(my_forward_func, model.HANmodel.word_attention_post.bert_model.embeddings)

for _, row in df.iterrows():
    model.eval()
    model.zero_grad()
    case_id = row['case_id']
    text = row['facts']
    b_labels = row['allegedly_violated_articles']

    attributions, encoding, mapping = IG(text, b_labels, task)
    # Store on the CPU. Tensors referenced from scores_dict cannot be freed by
    # empty_cache, so GPU copies would accumulate across cases; CPU tensors also let
    # the pickle load without CUDA.
    scores_dict.setdefault(case_id, []).append(
        {model_arch: {"mapping": mapping, "attributions": attributions.detach().cpu()}}
    )
    gc.collect()
    torch.cuda.empty_cache()

with open('final_lex/a_b_grad_all_' + str(i) + '.pkl', 'wb') as f:
    pickle.dump(scores_dict, f)