In [None]:
!pip install transformers
!pip install allennlp
!pip install --upgrade google-cloud-storage

In [None]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

# %cd /content/drive/My\ Drive/Legal\ DS/SCRF_RRL/rhetorical-role-baseline/
# Ekstep corpus:
# %cd /content/drive/My\ Drive/Legal\ DS/Paheli_new_corpus/semantic_segmentation/Corpus

import pandas as pd
from itertools import groupby
import warnings
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)

In [None]:
%cd /content/drive/My\ Drive/Legal\ DS/Paheli_new_corpus/semantic-segmentation/Corpus/

In [None]:
import os
import re
#import torchtext
import json
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler, KBinsDiscretizer
import numpy as np
# Pytorch Dataset
class RRDataset(Dataset):
  """Document-level dataset of tokenised sentences with one label per sentence.

  The JSON file at ``path`` must hold a list of documents. Each document is a
  dict with an ``'id'`` and ``'annotations'[0]['result']``: a list of sentence
  records carrying ``'value'['text']`` and ``'value'['labels']`` (only the first
  label of each sentence is used).

  Every item is a tuple ``(sentences, labels, doc_id)`` where ``sentences`` is a
  list of token-id lists (each cut to ``max_len`` ids) and ``labels`` is the
  list of label indices, one per sentence.

  Args:
    path: path of the JSON corpus file.
    tokenizer_path: name or path handed to ``AutoTokenizer.from_pretrained``.
    label_to_ind: mapping from label string to integer index.
    max_len: maximum number of token ids kept per sentence.
  """

  def __init__(self, path, tokenizer_path, label_to_ind, max_len):
    self.encoding = []
    self.labels = []
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    self.sep_token_id = self.tokenizer.sep_token_id
    self.pad_token_id = self.tokenizer.pad_token_id
    self.label_to_ind = label_to_ind
    self.max_len = max_len

    # Context manager so the file handle is closed even if parsing fails.
    with open(path, encoding="utf-8") as f:
      data = json.load(f)

    # parse_doc returns (labels, sentences, doc_id); drop documents that have
    # no sentences / no labels.
    kept = []
    for doc in data:
      doc_labels, doc_text, doc_id = self.parse_doc(doc)
      if len(doc_text) > 0 and len(doc_labels) > 0:
        kept.append((doc_text, doc_labels, doc_id))

    if kept:
      self.text, self.labels, self.id = zip(*kept)
    else:
      # zip(*[]) cannot be unpacked into three names; keep an empty dataset
      # instead of failing with an opaque ValueError.
      self.text, self.labels, self.id = (), (), ()

  def __len__(self):
    """Number of (non-empty) documents."""
    return len(self.labels)

  def __getitem__(self, item):
    """Return ``(sentences, labels, doc_id)`` for document ``item``."""
    return (self.text[item], self.labels[item], self.id[item])

  def parse_doc(self, x):
    """Convert one raw document dict into ``(labels, sentences, doc_id)``.

    Each sentence is encoded with the tokenizer and sliced to ``max_len`` ids;
    each label string is mapped through ``label_to_ind`` (an unknown label
    raises ``KeyError``).
    """
    doc_labels, text = [], []
    doc_id = x['id']
    for sentence in x['annotations'][0]['result']:
      value = sentence['value']
      doc_labels.append(self.label_to_ind[value['labels'][0]])
      # Plain slicing: a truncated sentence loses its trailing special token.
      text.append(self.tokenizer.encode(value['text'])[:self.max_len])
    return (doc_labels, text, doc_id)


In [None]:
from allennlp.data.dataset_readers.dataset_utils import enumerate_spans
import torch

class MyCollate:
    """Collate function turning a list of documents into padded batch tensors.

    Each incoming item is ``(sentences, labels, doc_id)`` where ``sentences`` is
    a list of token-id lists and ``labels`` holds one label index per sentence.

    Args:
        pad_idx: token id used to pad sentences.
        sep_idx: separator token id (stored, not used by the collate logic).
        max_width: maximum segment / span width, counted in sentences.
        label_to_ind: label-to-index mapping; must contain ``'MASK'``, which is
            used to pad label sequences.
    """

    def __init__(self, pad_idx, sep_idx, max_width, label_to_ind):
        self.pad_token_id = pad_idx
        self.sep_token_id = sep_idx
        self.max_width = max_width
        self.label_to_ind = label_to_ind

    def pad_sentence_for_batch(self, tokens_lists, max_len: int):
        """Right-pad every token list to ``max_len``.

        Returns ``(token_ids, attention_masks)``; a mask is 1 on real tokens
        and 0 on padding.
        """
        toks_ids, att_masks = [], []
        for item_toks in tokens_lists:
            n_real = len(item_toks)
            n_pad = max_len - n_real
            toks_ids.append(list(item_toks) + [self.pad_token_id] * n_pad)
            att_masks.append([1] * n_real + [0] * n_pad)
        return toks_ids, att_masks

    def pad_doc_for_batch(self, doc_lengths, labels, segment_ids, max_len):
        """Pad per-document sequences to ``max_len`` sentences.

        Labels are padded with the ``'MASK'`` index, segment flags with 0.
        Returns ``(sentence_masks, padded_labels, padded_segments)``.
        """
        mask_label = self.label_to_ind['MASK']
        sent_masks, lab_masks, seg_masks = [], [], []
        for n_sents, doc_labels, doc_segs in zip(doc_lengths, labels, segment_ids):
            lab_masks.append(list(doc_labels) + [mask_label] * (max_len - len(doc_labels)))
            seg_masks.append(list(doc_segs) + [0] * (max_len - len(doc_segs)))
            sent_masks.append([1] * n_sents + [0] * (max_len - n_sents))
        return sent_masks, lab_masks, seg_masks

    def span_enumeration(self, sent_masks, max_width):
        """Enumerate all spans of up to ``max_width`` sentences per document.

        ``sent_masks`` is a (documents, sentences) tensor. Span lists are
        padded with ``[0, 0]`` so every document has the same number of spans.
        """
        all_span_ids = []
        for doc_mask in sent_masks:
            # Only the non-padded sentences take part in span enumeration.
            real_sentences = doc_mask[doc_mask.nonzero()]
            all_span_ids.append(
                enumerate_spans(real_sentences, offset=0, max_span_width=max_width))

        max_span_len = max(len(spans) for spans in all_span_ids)
        return [spans + [[0, 0]] * (max_span_len - len(spans)) for spans in all_span_ids]

    def seg_mask_fix(self, seg_inds):
        """Force a segment end whenever a run reaches ``max_width`` sentences.

        ``seg_inds`` holds, per document, a 0/1 flag per sentence (1 = last
        sentence of a segment). Flags are OR-ed with a forced cut so that no
        segment is longer than ``self.max_width``.
        """
        max_path = self.max_width
        seg_inds_fix = []
        for sent_inds in seg_inds:
            run_len = 0  # sentences already consumed by the current segment
            new_inds = []
            for flag in sent_inds:
                force_cut = run_len >= max_path - 1
                mask_step = flag | force_cut
                new_inds.append(mask_step)
                run_len = run_len + 1
                # Reset on a segment end; otherwise keep counting.
                run_len = (1 - mask_step) * run_len * (run_len < max_path)
            seg_inds_fix.append(new_inds)
        return seg_inds_fix

    def convert_labels_to_segments(self, labels):
        """Derive segment-end flags from per-sentence labels.

        A sentence ends a segment when the next sentence has a different label
        or when it is the last sentence; the flags are then capped with
        ``seg_mask_fix``.
        """
        seg_ids = []
        for doc_labels in labels:
            flags = [0] * len(doc_labels)
            for pos in range(1, len(doc_labels)):
                if doc_labels[pos - 1] != doc_labels[pos]:
                    flags[pos - 1] = 1
            flags[-1] = 1
            seg_ids.append(flags)
        return self.seg_mask_fix(seg_ids)

    def __call__(self, batch):
        """Build the batch dictionary consumed by the model.

        Returns a dict with ``sentence_mask``, ``input_ids``,
        ``attention_mask``, ``label_ids``, ``segment_mask``, ``span_indices``
        (all ``torch.long`` tensors) and ``doc_name`` (tuple of document ids).

        Raises:
            ValueError: if the batch holds no usable (non-``None``) items.
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            raise ValueError("MyCollate received a batch without any usable documents")

        docs, labels, doc_ids = zip(*batch)
        doc_lengths = [len(doc) for doc in docs]
        batch_sz = len(doc_ids)
        batch_max_doc_length = max(doc_lengths)
        batch_max_sent_length = max(len(sent) for doc in docs for sent in doc)

        segments = self.convert_labels_to_segments(labels)
        padded_sent_mask, padded_label_id, padded_segments = self.pad_doc_for_batch(
            doc_lengths, labels, segments, batch_max_doc_length)

        label_ids = torch.tensor(padded_label_id, dtype=torch.long)
        segment_ids = torch.tensor(padded_segments, dtype=torch.long)
        sent_mask = torch.tensor(padded_sent_mask, dtype=torch.long)
        spans = torch.tensor(self.span_enumeration(sent_mask, self.max_width), dtype=torch.long)

        # Rows belonging to padded (non-existent) sentences stay all-zero.
        tensor_shape = (batch_sz, batch_max_doc_length, batch_max_sent_length)
        docs_tensor = torch.zeros(tensor_shape, dtype=torch.long)
        att_mask = torch.zeros(tensor_shape, dtype=torch.long)
        for doc_idx, doc in enumerate(docs):
            padded_tokens, token_masks = self.pad_sentence_for_batch(doc, batch_max_sent_length)
            n_sents = len(padded_tokens)
            if n_sents:
                docs_tensor[doc_idx, :n_sents] = torch.tensor(padded_tokens, dtype=torch.long)
                att_mask[doc_idx, :n_sents] = torch.tensor(token_masks, dtype=torch.long)

        # Everything is already a tensor: returning them directly avoids the
        # redundant copy (and copy-construct warning) of torch.tensor(tensor).
        return {
            "sentence_mask": sent_mask,
            "input_ids": docs_tensor,
            "attention_mask": att_mask,
            "label_ids": label_ids,
            "segment_mask": segment_ids,
            "span_indices": spans,
            "doc_name": doc_ids
        }
        

In [None]:
# Label inventories for the corpora this notebook has been run on; exactly one
# LABELS definition is active at a time. 'MASK' is the value MyCollate uses to
# pad label sequences, so it must be present in the active list.
# LABELS = ["MASK","PREAMBLE", "NONE", "FAC", "ISSUE", "ARG_RESPONDENT", "ARG_PETITIONER", "ANALYSIS", "PRE_RELIED",
#               "PRE_NOT_RELIED", "STA", "RLC", "RPC", "RATIO"]

# LABELS = ["DEFAULT", 'MASK', "NONE", "Facts", "Argument", "Ratio of the decision", "Statute", "Precedent", "Ruling by Present Court", "Ruling by Lower Court"]

LABELS = ['MASK', 'OBJECTIVE', 'BACKGROUND', 'METHODS', 'RESULTS', 'CONCLUSIONS']

# Working directory holding the JSON files loaded below (Colab Drive mount).
#%cd /content/drive/My\ Drive/Legal\ DS\ backup/Corpus/
%cd /content/drive/My\ Drive/Legal\ DS\ backup/Pubmed/

# Map each label to its position in LABELS, then append the two CRF boundary
# tags; SpanCRF looks up 'START' and 'STOP' in this mapping.
labels_int = range(len(LABELS)) 
label_to_ind = dict( zip(LABELS,labels_int))
#label_to_ind['MASK'] = len(label_to_ind)
label_to_ind['START'] = len(label_to_ind)
label_to_ind['STOP'] = len(label_to_ind)

# Checkpoint name shared by the tokenizer (RRDataset) and, through config, the encoder.
bert_model = "bert-base-uncased"
# bert_model = "zlucia/custom-legalbert"

# max_len caps the number of token ids kept per sentence.
train_dataset = RRDataset('pubmed_train_to_ekstep.json',tokenizer_path = bert_model, label_to_ind = label_to_ind, max_len = 128)
dev_dataset = RRDataset('pubmed_dev_to_ekstep.json',tokenizer_path = bert_model, label_to_ind = label_to_ind, max_len = 128)
#dev_dataset = RRDataset('dev.json',tokenizer_path = bert_model, label_to_ind = label_to_ind, max_len = 128)

In [None]:
len(train_dataset), len(dev_dataset)

In [None]:
#max_width = max span length
#batch_size default 1

# Both loaders share the collate settings. sep_idx now receives the tokenizer's
# separator id (it was previously handed the padding id by mistake).
train_dataloader =  DataLoader(train_dataset, batch_size=30, shuffle=True, collate_fn = MyCollate(pad_idx = train_dataset.pad_token_id, sep_idx = train_dataset.sep_token_id, max_width = 2, label_to_ind=label_to_ind))

dev_dataloader =  DataLoader(dev_dataset, batch_size=30, shuffle=True, collate_fn = MyCollate(pad_idx = train_dataset.pad_token_id, sep_idx = train_dataset.sep_token_id, max_width = 2, label_to_ind=label_to_ind))

In [None]:
# Smoke test: fetch only the first training batch (note the `break`) and print
# the shape of every tensor produced by MyCollate, plus the label ids and
# segment-end flags, to eyeball the padding.
for batch_idx, x in tqdm(enumerate(train_dataloader),total=len(train_dataloader), leave=False):
      #print(x["sentence_mask"], x["input_ids"],x["attention_mask"], x["label_ids"])
      print(x["sentence_mask"].shape, x["input_ids"].shape,x["attention_mask"].shape, x["label_ids"].shape, x["doc_name"],x["span_indices"].shape,x["segment_mask"].shape)
      #print(x['span_indices'])
      print(x['label_ids'])
      print(x["segment_mask"])
      break

In [None]:
import torch 
import torch.nn as nn
import numpy as np
from torch.autograd import Variable


class SpanCRF(nn.Module):
    """Segment-level (semi-Markov) CRF over sequences of up to ``max_path``-wide segments.

    ``logits[b, j, i, t]`` is read as the score of a segment with tag ``t`` that
    ENDS at position ``j`` and is ``i + 1`` positions wide. ``transitions[t, s]``
    is the score of moving from tag ``s`` to tag ``t``. ``label_to_ind`` must
    contain the ``'START'`` and ``'STOP'`` tags.

    All computation happens on the device that holds ``self.transitions``;
    inputs are moved there as needed, so no global ``device`` is required.
    """

    def __init__(self,  label_to_ind, max_path):
        super(SpanCRF, self).__init__()

        self.tag_to_ix = label_to_ind
        self.tagset_size = len(self.tag_to_ix)
        self.max_path = max_path

        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))

    def _forward_alg(self, logits, len_list, is_volatile=False):
        """Log partition function: log-sum of the scores of all segmentations.

        Arguments:
            logits: [batch_size, seq_len, max_path, n_labels] FloatTensor
            len_list: [batch_size] LongTensor with the true sequence lengths
            is_volatile: kept for backward compatibility and ignored (the
                ``volatile`` flag no longer exists in PyTorch).

        Returns:
            [batch_size] FloatTensor.
        """
        device = self.transitions.device
        logits = logits.to(device)
        lens = len_list.to(device)
        batch_size, seq_len, max_path, _ = logits.size()

        start = logits.new_full((batch_size, self.tagset_size), -10000.)
        start[:, self.tag_to_ix['START']] = 0
        # alphas[k][b, t]: log-sum of all prefixes covering k positions that end in tag t.
        alphas = [start]
        trans = self.transitions.unsqueeze(0)  # [1, to_tag, from_tag]

        for j in range(seq_len):
            width_scores = []
            # Only widths that fit in the prefix are valid. Reducing over exactly
            # these avoids counting never-written slots as exp(0) mass.
            for i in range(min(j + 1, max_path)):
                prev = alphas[j - i].unsqueeze(1)      # [batch, 1, from_tag]
                emit = logits[:, j, i].unsqueeze(-1)   # [batch, to_tag, 1]
                mat = prev + emit + trans
                width_scores.append(self.log_sum_exp(mat, 2, keepdim=True).squeeze(2))

            alpha_nxt = self.log_sum_exp(torch.stack(width_scores, dim=1), dim=1, keepdim=True).squeeze(1)

            # Sequences that already ended simply carry their last alpha forward.
            active = (lens > j).to(alpha_nxt.dtype).unsqueeze(-1)
            alphas.append(active * alpha_nxt + (1 - active) * alphas[j])

        final = alphas[-1] + self.transitions[self.tag_to_ix['STOP']].unsqueeze(0)
        return self.log_sum_exp(final, 1).squeeze(-1)

    def viterbi_decode(self, logits, lens):
        """Most probable segmentation and tagging.

        Arguments:
            logits: [batch_size, seq_len, max_path, n_labels] FloatTensor
            lens: [batch_size] LongTensor

        Returns:
            ``(paths, segments)``, two lists with one entry per batch item.
            ``paths[b]`` holds one tag index per position. ``segments[b]``
            holds the (exclusive) end offsets of the decoded segments as
            Python ints, the last one being the sequence length.
        """
        device = self.transitions.device
        logits = logits.to(device)
        batch_size, seq_len, max_path, _ = logits.size()
        n_tags = self.tagset_size

        with torch.no_grad():
            trans = self.transitions.unsqueeze(0)  # [1, to_tag, from_tag]
            stop_row = self.transitions[self.tag_to_ix['STOP']].unsqueeze(0)

            vit = logits.new_full((batch_size, seq_len + 1, n_tags), -10000.)
            vit[:, 0, self.tag_to_ix['START']] = 0
            vit_tag_max = logits.new_full((batch_size, max_path, n_tags), -10000.)
            vit_tag_argmax = torch.full((batch_size, max_path, n_tags), -100,
                                        dtype=torch.long, device=device)
            # pointers[b, j, t] = (best previous tag, best width index) for tag t ending at j.
            pointers = torch.full((batch_size, seq_len, n_tags, 2), -100,
                                  dtype=torch.long, device=device)
            remaining = lens.clone().to(device)

            for j in range(seq_len):
                logit = logits[:, j]
                for i in range(min(j + 1, max_path)):
                    vit_trn_sum = vit[:, j - i, :].clone().unsqueeze(1) + trans
                    vt_max, vt_argmax = vit_trn_sum.max(2)
                    vit_tag_max[:, i, :] = vt_max + logit[:, i]
                    vit_tag_argmax[:, i, :] = vt_argmax

                seg_vt_max, seg_vt_argmax = vit_tag_max.max(1)

                active = (remaining > 0).to(vit.dtype).unsqueeze(-1)
                vit[:, j + 1, :] = active * seg_vt_max + (1 - active) * vit[:, j, :].clone()

                # Add the STOP transition exactly once, on the last real position.
                is_last = (remaining == 1).to(vit.dtype).unsqueeze(-1)
                vit[:, j + 1, :] = vit[:, j + 1, :] + is_last * stop_row

                pointers[:, j, :, 0] = torch.gather(
                    vit_tag_argmax, 1, seg_vt_argmax.unsqueeze(1)).squeeze(1)
                pointers[:, j, :, 1] = seg_vt_argmax

                remaining = remaining - 1

            # Best final tag; the reverse pointers are followed on the host.
            end_max_idx = vit[:, -1, :].max(1)[1].cpu().numpy()
            pointers_rev = np.flip(pointers.cpu().numpy(), 1)

        lens_list = [int(n) for n in lens.tolist()]
        paths = []
        segments = []

        for b in range(batch_size):
            n_real = lens_list[b]
            # Sequences differ in length, so find the start offset in the reversed list.
            start_index = seq_len - n_real
            path = [end_max_idx[b]]
            segment = [n_real]

            if start_index >= seq_len - 1:
                # Length <= 1: nothing to backtrack. Still record the segment so
                # that paths and segments stay aligned per batch item.
                paths.append(path)
                segments.append(segment)
                continue

            prev_tag = end_max_idx[b]
            next_tag, next_jump = pointers_rev[b, start_index, prev_tag]

            for j, argmax in enumerate(pointers_rev[b, start_index + 1:, :]):
                # Still inside the current segment: repeat its tag.
                if next_jump > 0:
                    next_jump -= 1
                    path.insert(0, prev_tag)
                    continue

                # Segment boundary: switch to the stored previous tag.
                segment.insert(0, n_real - j - 1)
                path.insert(0, next_tag)

                prev_tag = next_tag
                next_tag, next_jump = argmax[next_tag]

            segments.append(segment)
            paths.append(path)

        return paths, segments

    def _bilstm_score(self, logits, labels, seg_inds, lens):
        """Emission part of the gold score: sum of the logits of the gold segments.

        Arguments:
            logits: [batch_size, seq_len, max_path, n_labels] FloatTensor
            labels: [seq_len, batch_size] LongTensor (time-major, transposed here)
            seg_inds: [seq_len, batch_size] LongTensor, 1 on the last position
                of every gold segment
            lens: [batch_size] LongTensor

        Returns:
            [batch_size] FloatTensor.
        """
        batch_size, max_len, _, _ = logits.size()

        labels = labels.transpose(1, 0)
        seg_ends = seg_inds.transpose(1, 0).detach().cpu().numpy()

        # mask_seg[b, t, w] = 1 iff a gold segment of width w + 1 ends at position t.
        mask_seg = np.zeros((batch_size, max_len, self.max_path))
        rows = np.arange(batch_size)
        counter = np.zeros((batch_size), dtype=np.int32)
        for t in range(max_len):
            mask_step = seg_ends[:, t]
            mask_seg[rows, t, counter] = mask_step
            counter = counter + 1
            # Reset at a segment end, or when the run exceeds max_path.
            counter = (1 - mask_step) * counter * (counter < self.max_path)

        mask_seg = torch.from_numpy(mask_seg).to(device=logits.device, dtype=logits.dtype)
        sum_cols = (logits * mask_seg.unsqueeze(-1)).sum(dim=2)

        all_scores = torch.gather(sum_cols, 2, labels.unsqueeze(-1)).squeeze(-1)

        # Size the mask from the padded length so it always matches all_scores.
        mask_time = self.sequence_mask(lens, max_len=max_len).to(all_scores.dtype)
        return (all_scores * mask_time).sum(dim=1)

    def score(self, logits, y, seg_inds, lens):
        """Unnormalised score of the gold segmentation (transitions + emissions).

        ``y`` and ``seg_inds`` are time-major: [seq_len, batch_size].
        Returns a [batch_size] FloatTensor.
        """
        device = self.transitions.device
        logits = logits.to(device)
        y, seg_inds, lens = y.to(device), seg_inds.to(device), lens.to(device)

        bilstm_score = self._bilstm_score(logits, y, seg_inds, lens)
        transition_score = self.transition_score(y, lens, seg_inds)
        return transition_score + bilstm_score

    def transition_score(self, labels, lens, mask_seg_idx):
        """Transition part of the gold score.

        Counts START -> first tag, every transition that leaves a segment end,
        and last tag -> STOP.

        Arguments:
            labels: [seq_len, batch_size] LongTensor (time-major)
            lens: [batch_size] LongTensor
            mask_seg_idx: [seq_len, batch_size] LongTensor of segment-end flags

        Returns:
            [batch_size] FloatTensor.
        """
        labels = labels.transpose(1, 0)
        mask_seg_idx = mask_seg_idx.transpose(1, 0)
        batch_size, seq_len = labels.size()
        stop_idx = self.tag_to_ix['STOP']

        # Frame the labels with START ... STOP; everything past the true length
        # (including padded labels) becomes STOP.
        labels_ext = labels.new_full((batch_size, seq_len + 2), stop_idx)
        labels_ext[:, 0] = self.tag_to_ix['START']
        labels_ext[:, 1:-1] = labels
        keep = self.sequence_mask(lens + 1, max_len=seq_len + 2)
        labels_ext = torch.where(keep, labels_ext, torch.full_like(labels_ext, stop_idx))

        # trn_scr[b, k] = transitions[tag at k + 1, tag at k]
        trn_scr = self.transitions[labels_ext[:, 1:], labels_ext[:, :-1]]

        # Mask in time, sized from the padded length so shapes always agree.
        step_mask = self.sequence_mask(lens + 1, max_len=seq_len + 1).to(trn_scr.dtype)
        trn_scr = trn_scr * step_mask

        # The START transition always counts; later ones only at segment ends.
        seg_gate = torch.cat(
            [trn_scr.new_ones(batch_size, 1), mask_seg_idx.to(trn_scr.dtype)], dim=1)
        return (trn_scr * seg_gate).sum(1)

    def loglik(self, logits, y, lens, seg_inds=None):
        """Log-likelihood of the gold segmentation: ``score - log partition``.

        Arguments:
            logits: [batch_size, seq_len, max_path, n_labels] FloatTensor
            y: [seq_len, batch_size] LongTensor (time-major, as for ``score``)
            lens: [batch_size] LongTensor
            seg_inds: optional [seq_len, batch_size] segment-end flags. When
                omitted, every real position is treated as its own segment.
        """
        if seg_inds is None:
            seg_inds = self.sequence_mask(lens, max_len=logits.size(1)).long().transpose(1, 0)
        norm_score = self._forward_alg(logits, lens)
        sequence_score = self.score(logits, y, seg_inds, lens)
        return sequence_score - norm_score

    def log_sum_exp(self,vec, dim=0, keepdim=True):
        """Numerically stable ``log(sum(exp(vec)))`` along ``dim``."""
        return torch.logsumexp(vec, dim, keepdim=keepdim)

    def sequence_mask(self,lens, max_len=None):
        """Boolean [batch_size, max_len] mask, True on positions ``< lens[b]``.

        ``max_len`` defaults to the longest length in ``lens``.
        """
        if max_len is None:
            max_len = lens.max()
        max_len = int(max_len)

        ranges = torch.arange(max_len, device=lens.device).unsqueeze(0)
        return ranges < lens.unsqueeze(1)

In [None]:
from allennlp.common.util import pad_sequence_to_length
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.nn.util import masked_mean, masked_softmax
from allennlp.modules.span_extractors import EndpointSpanExtractor,SelfAttentiveSpanExtractor
import copy

from transformers import BertModel

from allennlp.modules import ConditionalRandomField

import torch
import torch.nn as nn


class CRFOutputLayer(torch.nn.Module):
    """Linear projection to label scores followed by a linear-chain CRF."""

    def __init__(self, in_dim, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(in_dim, self.num_labels)
        self.crf = ConditionalRandomField(self.num_labels)

    def forward(self, x, mask, labels=None):
        """Return the training loss or the decoded labels.

        Args:
            x: features, shape (batch, max_sequence, in_dim).
            mask: shape (batch, max_sequence).
            labels: shape (batch, max_sequence); when given, the negated CRF
                output is returned under ``"loss"``. Otherwise the Viterbi
                paths, each padded to ``max_sequence``, are returned as a
                tensor under ``"predicted_label"``.
        """
        _, max_sequence, _ = x.shape
        logits = self.classifier(x)

        if labels is not None:
            return {"loss": -self.crf(logits, labels, mask)}

        decoded = self.crf.viterbi_tags(logits, mask)
        padded_tags = [
            pad_sequence_to_length(tags, desired_length=max_sequence)
            for tags, _score in decoded
        ]
        return {"predicted_label": torch.tensor(padded_tags)}
        


class AttentionPooling(torch.nn.Module):
    """Pools a token sequence into ``number_context_vectors`` attention-weighted sums.

    The pooled vectors are concatenated, so the output width is
    ``number_context_vectors * in_features`` (exposed as ``output_dim``).
    """

    def __init__(self, in_features, dimension_context_vector_u=200, number_context_vectors=5):
        super().__init__()
        self.dimension_context_vector_u = dimension_context_vector_u
        self.number_context_vectors = number_context_vectors
        self.linear1 = torch.nn.Linear(
            in_features=in_features,
            out_features=self.dimension_context_vector_u,
            bias=True,
        )
        self.linear2 = torch.nn.Linear(
            in_features=self.dimension_context_vector_u,
            out_features=self.number_context_vectors,
            bias=False,
        )
        self.output_dim = self.number_context_vectors * in_features

    def forward(self, tokens, mask):
        """Pool ``tokens`` of shape (batch_size, tokens, in_features).

        Returns a tensor of shape (batch_size, number_context_vectors * in_features).
        """
        # (batch_size, tokens, number_context_vectors) attention scores
        scores = self.linear2(torch.tanh(self.linear1(tokens)))
        # (batch_size, number_context_vectors, tokens), normalised over the tokens
        weights = masked_softmax(scores.transpose(1, 2), mask)
        # one weighted sum of the token vectors per context vector
        pooled = torch.bmm(weights, tokens)
        return pooled.view(tokens.shape[0], -1)



class BertTokenEmbedder(torch.nn.Module):
    """Wraps a pretrained BERT model and returns per-token embeddings.

    ``config`` supplies ``"bert_model"`` (checkpoint name or path) and
    ``"bert_trainable"`` (whether the BERT weights receive gradients).
    """

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained(config["bert_model"])
        self.bert_trainable = config["bert_trainable"]
        self.bert_hidden_size = self.bert.config.hidden_size
        for weight in self.bert.parameters():
            weight.requires_grad = self.bert_trainable

    def forward(self, batch):
        """Embed ``batch["input_ids"]`` of shape (documents, sentences, tokens).

        Returns the first BERT output for the flattened
        (documents * sentences, tokens) input. If the batch already carries
        ``"bert_embeddings"`` that value is returned untouched. When BERT is
        frozen, a CPU copy of the fresh embeddings is stored in the batch under
        that key.
        """
        if "bert_embeddings" in batch:
            return batch["bert_embeddings"]

        _, _, n_tokens = batch["input_ids"].shape
        flat_mask = batch["attention_mask"].view(-1, n_tokens)
        flat_ids = batch["input_ids"].view(-1, n_tokens)

        bert_embeddings = self.bert(input_ids=flat_ids, attention_mask=flat_mask)[0]

        if not self.bert_trainable:
            batch["bert_embeddings"] = bert_embeddings.to("cpu")
        return bert_embeddings

class BertHSLN(torch.nn.Module):
    """Hierarchical sentence labeller with optional CRF heads.

    Pipeline (see ``forward``): BERT token embeddings -> word-level BiLSTM ->
    attention pooling (one vector per sentence) -> sentence-level BiLSTM ->
    a per-sentence CRF head (``config["crf"]``) and/or a span-level CRF head
    (``config["span_crf"]``).

    Config keys read here: ``dropout``, ``word_lstm_hs``, ``att_pooling_dim_ctx``,
    ``att_pooling_num_ctx``, ``sentence_lstm_hs``, ``max_path``, ``label_to_ind``,
    ``span_crf``, ``crf`` and ``span_width_embedding_dim`` (plus the keys read by
    ``BertTokenEmbedder``).
    """
    def __init__(self, config):
        super(BertHSLN, self).__init__()

        self.bert = BertTokenEmbedder(config)
        self.dropout = torch.nn.Dropout(config["dropout"])
        self.word_lstm_hidden_size = config["word_lstm_hs"]
        self.word_lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(input_size=self.bert.bert_hidden_size,
                                  hidden_size=self.word_lstm_hidden_size,
                                  num_layers=1, batch_first=True, bidirectional=True))

        # The word LSTM is bidirectional, hence 2 * hidden size as pooling input.
        self.attention_pooling = AttentionPooling(2 * self.word_lstm_hidden_size,
                                                  dimension_context_vector_u=config["att_pooling_dim_ctx"],
                                                  number_context_vectors=config["att_pooling_num_ctx"])
        
        input_dim = self.attention_pooling.output_dim
        self.sentence_lstm_hidden_size = config["sentence_lstm_hs"]
        self.sentence_lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(input_size=input_dim,
                                  hidden_size=self.sentence_lstm_hidden_size,
                                  num_layers=1, batch_first=True, bidirectional=True))
        


        self.input_dim = self.sentence_lstm_hidden_size * 2
        self.max_path = config["max_path"]
        self.num_labels = len(config['label_to_ind'])
        
        self.span_crf = config["span_crf"]
        self.crf = config["crf"]

        # Note: when enabled, self.crf is rebound from the config flag to the
        # SpanCRF module (max_path 1, i.e. one sentence per segment); later
        # `if self.crf:` checks then test that module object.
        if self.crf:
          self.crf_fc = nn.Linear(self.input_dim, self.num_labels)
          self.crf = SpanCRF(config["label_to_ind"],1)
          #self.crf = CRFOutputLayer(in_dim=self.input_dim, num_labels=self.num_labels)

        if self.span_crf:
          self.endpoint_span_extractor = EndpointSpanExtractor(self.sentence_lstm_hidden_size * 2,
                                                             combination="x,y,x*y,x-y",
                                                             num_width_embeddings=config["max_path"],
                                                             span_width_embedding_dim=config["span_width_embedding_dim"],
                                                             bucket_widths=True)
          # self.endpoint_span_extractor = SelfAttentiveSpanExtractor(self.sentence_lstm_hidden_size * 2,
          #                                                     num_width_embeddings=config["max_path"],
          #                                                     span_width_embedding_dim=config["span_width_embedding_dim"],
          #                                                     bucket_widths=True)
          # Four combined endpoint terms ("x,y,x*y,x-y") plus the width embedding.
          self.span_input_dim = self.sentence_lstm_hidden_size * 2 * 4 + config["span_width_embedding_dim"]
          # self.span_input_dim = self.sentence_lstm_hidden_size * 2  + config["span_width_embedding_dim"]
          
          self.crf_spanfc = nn.Linear(self.span_input_dim, self.num_labels)
          self.spancrf = SpanCRF(config["label_to_ind"],self.max_path)




    def forward(self, batch, labels=None, eval=False):
        """Run the model on one collated batch.

        Args:
            batch: dict with ``input_ids`` and ``attention_mask`` of shape
                (documents, sentences, tokens) and ``sentence_mask``; the span
                head additionally reads ``span_indices`` and ``segment_mask``.
            labels: gold label tensor; it is dereferenced whenever ``eval`` is
                False, so it must be provided for training.
            eval: when False, losses are computed; when True, labels are decoded.

        Returns:
            dict holding ``loss1`` (sentence CRF) and/or ``loss2`` (span CRF)
            when ``eval`` is False, or ``predicted_label1`` and/or
            ``predicted_label2`` when ``eval`` is True.
        """

        documents, sentences, tokens = batch["input_ids"].shape
  
        # shape (documents*sentences, tokens, 768)
        bert_embeddings = self.bert(batch)
        bert_embeddings = self.dropout(bert_embeddings)

        tokens_mask = batch["attention_mask"].view(-1, tokens)

        # shape (documents*sentences, tokens, 2*lstm_hidden_size)
        bert_embeddings_encoded = self.word_lstm(bert_embeddings, tokens_mask)

        #shape (documents*sentences, pooling_out)
        sentence_embeddings = self.attention_pooling(bert_embeddings_encoded, tokens_mask)

        # shape: (documents, sentences, pooling_out)
        sentence_embeddings = sentence_embeddings.view(documents, sentences, -1)
        sentence_embeddings = self.dropout(sentence_embeddings)


        sentence_mask = batch["sentence_mask"]

        # shape: (documents, sentence, 2*lstm_hidden_size)
        sentence_embeddings_encoded = self.sentence_lstm(sentence_embeddings, sentence_mask)
        sentence_embeddings_encoded = self.dropout(sentence_embeddings_encoded)

        # Number of real (non-padded) sentences per document.
        sentence_len = torch.sum(sentence_mask,dim=-1)
        output = {}

        if self.span_crf:
            span_embeddings = self.endpoint_span_extractor(sentence_embeddings_encoded,batch["span_indices"], sentence_mask)
            segment_rep = self.crf_spanfc(span_embeddings)
            _,max_span_len,_ = segment_rep.shape
        
            # Dense (documents, sentences, max_path, num_labels) table of span
            # scores; cells without a matching span stay 0.
            segment_span_feat = torch.zeros(documents, sentences, self.max_path, self.num_labels)

        
            batch_size, max_span_len,_ = batch["span_indices"].shape
            _, max_seq_len, max_path_len, _ = segment_span_feat.shape

            # Scatter every span score to [document, span start, end - start].
            # NOTE(review): SpanCRF appears to read logits[:, j, i] as the segment
            # ENDING at position j, while this table is indexed by the span START
            # -- confirm which layout is intended.
            for i in range(batch_size):
              for j in range(max_span_len):
                start_idx = batch["span_indices"][i][j][0]
                len_idx = batch["span_indices"][i][j][1] - batch["span_indices"][i][j][0]
                segment_span_feat[i,start_idx,len_idx,:] = segment_rep[i][j]
            
            segment_mask = batch["segment_mask"]
            
            if not eval:
                # Negative log-likelihood: log partition minus gold score, batch mean.
                span_forward_var_batch = self.spancrf._forward_alg(segment_span_feat,sentence_len )
                span_gold_score_batch = self.spancrf.score(segment_span_feat, labels.transpose(0,1) , segment_mask.transpose(0,1),sentence_len)
                loss = (span_forward_var_batch-span_gold_score_batch).mean()
                #output['span_crf'] = {"forward_var_batch":span_forward_var_batch , "gold_score_batch" : span_gold_score_batch}
                output['loss2'] = loss

        if self.crf:
            #output = self.crf(sentence_embeddings_encoded, sentence_mask, labels)
            #return output 

            # Sentence-level head: a singleton width axis is inserted so the
            # features fit the (documents, sentences, 1, num_labels) layout.
            segment_feat = sentence_embeddings_encoded.unsqueeze(2)
            segment_feat = self.crf_fc(segment_feat)
            segment_feat = segment_feat.view(documents, sentences, 1, self.num_labels)
            
            if not eval:
                # The sentence mask doubles as the segment-end flags here
                # (every real sentence is its own segment).
                forward_var_batch = self.crf._forward_alg(segment_feat,sentence_len )
                gold_score_batch = self.crf.score(segment_feat, labels.transpose(0,1) , sentence_mask.transpose(0,1),sentence_len)
                loss = (forward_var_batch-gold_score_batch  ).mean()
                output['loss1'] = loss
            

        if eval:
            # Decoded segment boundaries are computed but only the tag
            # sequences are returned.
            if self.crf:
              crf_tag_seqs, crf_segments = self.crf.viterbi_decode(segment_feat,sentence_len )
              #output['crf'] = {"tag_seqs":crf_tag_seqs, "segments":  crf_segments }
              output["predicted_label1"] = crf_tag_seqs
            if self.span_crf:
              span_crf_tag_seqs, span_crf_segments = self.spancrf.viterbi_decode(segment_span_feat,sentence_len )
              #output['span_crf'] = {"tag_seqs":span_crf_tag_seqs, "segments":  span_crf_segments }
              output["predicted_label2"] = span_crf_tag_seqs

        return output

In [None]:
#MAX PATH CHANGED TO MAX WIDTH
# NOTE(review): the key below is still spelled "max_path"; presumably it now
# carries the maximum span width - confirm against the model code.

# Settings dictionary handed to the model constructor (BertHSLN(config)).
# `bert_model` and `label_to_ind` must already exist from earlier cells.
config = {
    "dropout":0.5,
    # NOTE(review): presumably the hidden sizes of the word-/sentence-level
    # LSTMs - confirm in the model definition.
    "word_lstm_hs":758,
    "att_pooling_dim_ctx":200,
    "att_pooling_num_ctx": 15,
    "sentence_lstm_hs":758,
    "bert_model": bert_model,
    # The training cell additionally freezes every parameter whose name
    # contains "bert", independently of this flag.
    "bert_trainable": False,
    "label_to_ind" : label_to_ind,
    "max_path": 10,
    "span_width_embedding_dim" :100,
    # Used as `gamma` of the per-epoch StepLR scheduler in the training cell.
    "lr_epoch_decay":0.9,
    # NOTE(review): presumably select which CRF head(s) the model builds
    # (loss1/predicted_label1 vs. loss2/predicted_label2) - confirm.
    "crf": False,
    "span_crf" : True
}

In [None]:
from sklearn.metrics import precision_recall_fscore_support, classification_report, confusion_matrix, accuracy_score
import numpy as np

def eval_model(model, eval_batches, device,label_to_ind, id):
    """Run ``model`` in eval mode over ``eval_batches`` and score its predictions.

    Args:
        model: callable invoked as ``model(batch=batch, eval=True)``; the
            returned mapping is indexed with ``id``.
        eval_batches: iterable of batch dicts providing at least
            ``doc_name``, ``label_ids`` and ``sentence_mask``.
        device: device the batch tensors are moved to for the forward pass
            (they are moved back to the CPU afterwards).
        label_to_ind: label-name -> index mapping; must contain ``'MASK'``.
        id: key of the prediction entry to read from the model output,
            e.g. ``"predicted_label2"``.

    Returns:
        Tuple ``(metrics, confusion, labels_dict, class_report)`` where
        ``labels_dict`` holds the flat and per-document gold/predicted label
        names plus the document names, and the other three values come from
        ``calc_classification_metrics``.
    """
    model.eval()

    flat_true, flat_pred = [], []
    per_doc_true, per_doc_pred, doc_names = [], [], []

    with torch.no_grad():
        for batch in eval_batches:
            # Tensors have to live on the model's device for the forward pass.
            shift_to_device(batch, device)
            output = model(batch=batch, eval=True)

            for doc_idx, doc_name in enumerate(batch['doc_name']):
                gold_names, pred_names = clear_and_map_padded_values(
                    batch["label_ids"][doc_idx], output[id][doc_idx], label_to_ind)
                # Every unmasked sentence must have received a prediction.
                assert len(batch['sentence_mask'][doc_idx].nonzero()) == len(pred_names)

                per_doc_true.append(gold_names)
                per_doc_pred.append(pred_names)
                doc_names.append(doc_name)
                flat_true.extend(gold_names)
                flat_pred.extend(pred_names)

            # Move the batch back to the CPU once it has been scored.
            shift_to_device(batch, torch.device("cpu"))

    labels_dict = {
        'y_true': flat_true,
        'y_predicted': flat_pred,
        'docwise_y_true': per_doc_true,
        'docwise_y_predicted': per_doc_pred,
        'doc_names': doc_names,
    }
    metrics, confusion, class_report = calc_classification_metrics(
        y_true=flat_true, y_predicted=flat_pred, labels=list(label_to_ind.keys()))
    return metrics, confusion, labels_dict, class_report
    
    
def clear_and_map_padded_values(true_labels, predicted_labels,label_to_ind):
    """Drop padded positions and translate label indices back to label names.

    A position is kept only when its gold label differs from
    ``label_to_ind['MASK']``. Elements of both sequences must expose
    ``.item()`` returning the integer label index.

    Returns:
        ``(cleared_true, cleared_predicted)`` - two equally long lists of
        label names for the kept positions, in input order.
    """
    index_to_name = dict((index, name) for name, index in label_to_ind.items())

    kept_pairs = [
        (index_to_name[gold.item()], index_to_name[guess.item()])
        for gold, guess in zip(true_labels, predicted_labels)
        if gold.item() != label_to_ind['MASK']
    ]

    kept_true = [gold_name for gold_name, _ in kept_pairs]
    kept_predicted = [guess_name for _, guess_name in kept_pairs]
    return kept_true, kept_predicted

def calc_span_idx(labels):
    """Split a label sequence into maximal runs of identical consecutive labels.

    Args:
        labels: indexable sequence of labels comparable with ``==``.

    Returns:
        ``(span_idx, span_labels)`` where ``span_idx`` is a list of
        ``(start, end, label)`` tuples with inclusive indices, and
        ``span_labels`` is the list of the corresponding labels.
        An empty input yields ``([], [])``.
    """
    span_idx = []
    span_labels = []
    total = len(labels)

    start = 0
    while start < total:
        # Extend the run while the next element carries the same label.
        # The explicit bound check makes a one-element sequence terminate;
        # the previous implementation compared labels[0] with labels[-1]
        # in that case and never advanced.
        end = start
        while end + 1 < total and labels[end + 1] == labels[start]:
            end += 1

        span_idx.append((start, end, labels[start]))
        span_labels.append(labels[start])
        start = end + 1

    return span_idx, span_labels


def calc_classification_metrics(y_true, y_predicted, labels):
    """Compute sentence-level and span-level classification metrics.

    Args:
        y_true: flat list of gold label names.
        y_predicted: flat list of predicted label names, aligned with ``y_true``.
        labels: label names fixing the order of the per-label scores and of
            the confusion-matrix rows/columns.

    Returns:
        Tuple ``(metrics, confusion, class_report)``:
        * ``metrics`` - dict with accuracy, macro/micro/weighted and per-label
          precision/recall/F1, the span-level F1 (``"Span-F1"``) and the list
          of exactly matching spans (``"Correct-Spans"``);
        * ``confusion`` - row-normalised confusion matrix in percent, as
          nested lists;
        * ``class_report`` - text report from ``classification_report``.
    """
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(y_true, y_predicted, average='macro')
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(y_true, y_predicted, average='micro')
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(y_true, y_predicted, average='weighted')
    per_label_precision, per_label_recall, per_label_f1, _ = precision_recall_fscore_support(y_true, y_predicted, average=None, labels=labels)

    true_spans, _ = calc_span_idx(y_true)
    pred_spans, _ = calc_span_idx(y_predicted)

    # A predicted span is correct only if start, end and label all match a
    # gold span. Set lookup keeps this linear in the number of spans.
    true_span_set = set(true_spans)
    correct_spans = [span for span in pred_spans if span in true_span_set]

    # F1 is the harmonic mean of span precision (correct / predicted) and
    # span recall (correct / gold), i.e. 2 * correct / (predicted + gold).
    # Dividing by the smaller of the two counts, as was done before, yields
    # max(precision, recall) and fails when one side has no spans.
    total_spans = len(true_spans) + len(pred_spans)
    span_F1 = 2 * len(correct_spans) / total_spans if total_spans else 0.0

    acc = accuracy_score(y_true, y_predicted)

    class_report = classification_report(y_true, y_predicted, digits=4)
    confusion_abs = confusion_matrix(y_true, y_predicted, labels=labels)
    # Normalise each row to percentages. A label from `labels` that never
    # occurs in y_true has an all-zero row; divide it by 1 so it stays 0
    # instead of turning into NaN.
    row_totals = confusion_abs.sum(axis=1)[:, np.newaxis]
    safe_totals = np.where(row_totals == 0, 1, row_totals)
    confusion = np.around(confusion_abs.astype('float') / safe_totals * 100, 2)
    return {"acc": acc,
            "macro-f1": macro_f1,
            "Span-F1": span_F1,
            "macro-precision": macro_precision,
            "macro-recall": macro_recall,
            "micro-f1": micro_f1,
            "micro-precision": micro_precision,
            "micro-recall": micro_recall,
            "weighted-f1": weighted_f1,
            "weighted-precision": weighted_precision,
            "weighted-recall": weighted_recall,
            "labels": labels,
            "per-label-f1": per_label_f1.tolist(),
            "Correct-Spans": correct_spans,
            "per-label-precision": per_label_precision.tolist(),
            "per-label-recall": per_label_recall.tolist(),
            #"confusion_abs": confusion_abs.tolist()
            }, \
           confusion.tolist(), \
           class_report

In [None]:
def shift_to_device(batch,device):
  """Move every tensor value of ``batch`` to ``device``, in place.

  Values that are not tensors (as judged by ``torch.is_tensor``) are left
  untouched. Nothing is returned; ``batch`` itself is updated.
  """
  tensor_keys = [name for name in batch.keys() if torch.is_tensor(batch[name])]
  for name in tensor_keys:
    batch[name] = batch[name].to(device)

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
import itertools
import gc
from torch.optim.lr_scheduler import StepLR

# Release memory left over from earlier runs before building a new model.
gc.collect()
torch.cuda.empty_cache()

num_epochs = 20
learning_rate = 3e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertHSLN(config).to(device)


# Freeze every parameter whose name contains "bert"; the optimizer below is
# built only from the parameters that still require gradients.
for name, param in model.named_parameters():
  if("bert" in name):
    param.requires_grad = False

optimizer = optim.Adam(list(filter(lambda p: p.requires_grad, model.parameters())), lr=learning_rate)
max_grad_norm = 1.0
# Multiply the learning rate by config["lr_epoch_decay"] after every epoch.
epoch_scheduler = StepLR(optimizer, step_size=1, gamma=config["lr_epoch_decay"])

# Per-epoch history consumed by the plotting cells.
accs = []
epochs = []
train_losses = []


for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")
    model.train()
    # Accumulate the loss of every mini-batch so the value recorded for the
    # epoch is the epoch mean rather than whatever the last batch produced.
    running_loss = 0.0
    num_batches = 0
    for batch_idx, batch in tqdm(enumerate(train_dataloader),total=len(train_dataloader), leave=False):
        shift_to_device(batch,device)
        output = model(batch, batch["label_ids"])
        loss = output["loss2"]
        #loss = output["loss1"] + output["loss2"]
        loss = loss.sum()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.item()
        num_batches += 1
        shift_to_device(batch,torch.device("cpu"))

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(num_batches, 1)
    epochs.append(epoch)
    train_losses.append(epoch_loss)
    print(f"Loss in epoch {epoch} : {epoch_loss}")
    epoch_scheduler.step()

    # Evaluate on the dev set after every epoch, reading the
    # "predicted_label2" entry of the model output.
    #test_metrics, test_confusion,labels_dict,_ = eval_model(model, dev_dataloader , device,label_to_ind,"predicted_label1")
    #print(test_metrics)
    test_metrics, test_confusion,labels_dict,_ = eval_model(model, dev_dataloader , device,label_to_ind,"predicted_label2")
    accs.append(test_metrics['acc'])
    print(test_metrics)

In [None]:
#pubmed ep sp2 sF1

In [None]:
#PahPre macroF1 0.32 acc: 0.58 Epoch: 25 span len: 2

import matplotlib.pyplot as plt

# Training-loss curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, train_losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Train loss')
plt.show()

In [None]:
# Dev-set accuracy curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, accs)
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
plt.show()

In [None]:
#Paheli macroF1 0.19777 acc: 0.42 Epoch: 31 span len: 10
#PahPre macroF1 0.19371 acc: 0.42 Epoch: 50 span len: 10

#PahPre macroF1 0.32 acc: 0.58 Epoch: 25 span len: 2

In [None]:
#Ekstep Span 2 atten

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, train_losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Train loss')
plt.show()

In [None]:
# Dev-set accuracy curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, accs)
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
plt.show()

In [None]:
#EKstep legal Span 2 atten

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, train_losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Train loss')
plt.show()

In [None]:
# Dev-set accuracy curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, accs)
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
plt.show()

In [None]:
#Ekstep bert sp 2 atten

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, train_losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Train loss')
plt.show()

In [None]:
# Dev-set accuracy curve: one value per epoch, as appended by the training loop.
ax = plt.gca()
ax.plot(epochs, accs)
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
plt.show()