In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import XLMRobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel

from datasets import Dataset

from tqdm import tqdm
from copy import deepcopy
import pandas as pd
import numpy as np
import random
import joblib
import ast
import gc

In [2]:
device = 'mps'

In [3]:
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
  config_class = XLMRobertaConfig
  def __init__(self, config):
    super().__init__(config)

    # load model body
    self.roberta = RobertaModel(config, add_pooling_layer=False)

    # load and initialize weights
    self.init_weights()

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):

    # use model body to get encoder representation
    outputs = self.roberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **kwargs)

    return outputs

In [4]:
class PrototypicalNetworksNER(nn.Module):
    def __init__(self, encoder, hidden_size, prototype_size):
        super(PrototypicalNetworksNER, self).__init__()

        # self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.prototype_size = prototype_size
        self.encoder = encoder
        self.fc1 = nn.Linear(hidden_size, prototype_size)


    def build_prototypes(self, support_set, num_classes, get_prototypes=False):

      prototypes = torch.zeros(num_classes, self.prototype_size).to(device)
      all_prototypes = [[] for clas in range(num_classes)]
      # count of each class
      class_counts = torch.zeros(num_classes).to(device)
      for hidden_states, labels in support_set:
        for hidden_state, label in zip(hidden_states.squeeze(), labels):

          prototype = self.fc1(hidden_state)
          prototypes[label] += prototype.to(device)
          class_counts[label]+=1

          if get_prototypes:
            all_prototypes[label].append(prototype)
      # mean
      prototypes = prototypes/class_counts.unsqueeze(1)
      prototypes = torch.nan_to_num(prototypes, 0.)
      if get_prototypes:
        return prototypes, all_prototypes
      return prototypes


    def predict_query_set(self, query_set, prototypes):
        distances = []
        for hidden_states in query_set:
          hidden_states = hidden_states.squeeze()
          q_prototype = self.fc1(hidden_states)
          distance = torch.cdist(q_prototype, prototypes)
          # To get the closest prototype
          # predicted_labels = F.softmax(distances, dim=1)
          distances.append(distance)
        return distances

    def forward(self, support_set, query_set, num_classes, get_prototypes=False):
        '''
        input:
            Tokenized Support Set
            Tokenized Query Set

        Output of encoder (Roberta Model)
            sequence output: (1, 256, 768)

        Final Outputs:
            Distances
        '''

        hidden_support_set, hidden_query_set = [], []
        for item in support_set:
          s_input_ids, s_attention_mask, s_labels = item
          s_hidden_states = self.encoder(input_ids=s_input_ids.to(device), attention_mask=s_attention_mask.to(device))['last_hidden_state']
          hidden_support_set.append([s_hidden_states, s_labels])
        if get_prototypes:
          prototypes, all_protos = self.build_prototypes(hidden_support_set, num_classes, get_prototypes)
        else:
          prototypes = self.build_prototypes(hidden_support_set, num_classes)

        q_input_ids, q_attention_mask, q_labels = query_set
        q_hidden_states = self.encoder(input_ids=q_input_ids.to(device), attention_mask=q_attention_mask.to(device))['last_hidden_state']
        hidden_query_set.append(q_hidden_states)

        predictions = self.predict_query_set(hidden_query_set, prototypes)

        del hidden_support_set
        del hidden_query_set
        del s_input_ids
        del s_attention_mask
        del s_labels
        del s_hidden_states
        del q_input_ids
        del q_attention_mask
        gc.collect()

        if get_prototypes:
          return predictions, all_protos

        return predictions

In [5]:
# device = 'cpu'
HIDDEN_SIZE = 768
PROTOTYPE_SIZE = 256

xlmr_model_name = "xlm-roberta-base"
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_model_name)
xlmr_config = AutoConfig.from_pretrained(xlmr_model_name)
xlmr_model = (XLMRobertaForTokenClassification.from_pretrained(xlmr_model_name, config=xlmr_config).to(device))
model = PrototypicalNetworksNER(xlmr_model, HIDDEN_SIZE, PROTOTYPE_SIZE).to(device)

In [6]:
model.load_state_dict(torch.load('10_epoch_state_dict', map_location=torch.device('mps')))
model.eval()

PrototypicalNetworksNER(
  (encoder): XLMRobertaForTokenClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(250002, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
            

In [81]:
eval_episodes = joblib.load('eval_episodes')

In [29]:
class FewShotSampler:
    '''
    Samples the data from csv to episodic data with query and support set in N-way K-shot
    n_wa: numeber of classes
    k_shot: sample for each class
    '''
    def __init__(self, k_shot, n_way):
        self.k_shot = k_shot
        self.n_way = n_way

    def extract_n_way_data(self, data):
        n_way_data = []
        for items in data:
            if len(set(items[1]))==self.n_way:
                n_way_data.append(items)

        return n_way_data

    def sample_data(self, n_way_data, full_data, n_episodes):
        episodes = []
        random.seed(42)
        for query in random.sample(n_way_data, min(n_episodes, len(n_way_data))):

            query_text = query[0]
            query_tags = query[1]

            query_tag_set = set(query_tags)

            final_query_set = [query_text, query_tags]

            support_data = []
            class_counts = {k: 0 for k in query_tag_set}

            for items in random.sample(full_data, len(full_data)):
                text, labels = items
                new_labels = []
                for i in range(len(labels)):
                    if labels[i] not in query_tag_set:
                        new_labels.append('O')
                    else:
                        new_labels.append(labels[i])


                tag_count_greater_k = False
                for tag in query_tag_set:
                    if [text, new_labels] in support_data:
                        break
                    if len(set(new_labels))<=1:
                        break
                    if len(support_data) >= self.n_way*self.k_shot:
                        break
                    if tag in new_labels:
                        for ent in new_labels:
                            if class_counts[ent]>=self.k_shot and ent!='O':
                                tag_count_greater_k = True
                        if not tag_count_greater_k:
                            support_data.append([text, new_labels])
                            for ent in set(new_labels):
                                class_counts[ent]+=1

            episodes.append({
                'query_set': final_query_set,
                'support_set': support_data
            })

            n_way_data.remove(query)

        return episodes

    def generate_episodes(self, reference_data, num_episodes):
        extracted_data = self.extract_n_way_data(reference_data)
        episodes = self.sample_data(extracted_data, reference_data, num_episodes)

        return episodes

In [32]:
judgement = pd.read_csv('train_judgement_bio.csv')
df = judgement.sample(frac=1, random_state=42).copy()
df = df.rename({'BIO_tags': 'tags'}, axis=1)

def convert_to_list(s):
    return ast.literal_eval(s)

df['tokens'] = df['tokens'].apply(convert_to_list)
df['tags'] = df['tags'].apply(convert_to_list)

def remove_BIO(tags):
  new_tags = []
  for tag in tags:
    new_tags.append(tag.replace('I-', '').replace('B-', ''))

  return new_tags

df['tags'] = df['tags'].apply(remove_BIO)

tag_list = df['tags'].values
labels = [label for tags in tag_list  for label in tags]
labels = list(set(labels))
print(labels)
print(len(labels))

['PETITIONER', 'PROVISION', 'DATE', 'JUDGE', 'GPE', 'WITNESS', 'ORG', 'CASE_NUMBER', 'O', 'OTHER_PERSON', 'RESPONDENT', 'PRECEDENT', 'STATUTE', 'COURT']
14


In [33]:
dataset = df.values.tolist()
for data in dataset:
    print(data[0])
    print(data[1])
    break

['Clause', '18(1', ')', ',', '(', '2', ')', 'and', '(', '3', ')', '\n', '(', 'a', ')', '&', '(', 'b', ')', 'were', 'transposed', 'in', 'Article', '23', 'of', 'the', 'Draft', 'Constitution', 'of', 'India', '.']
['PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'O', 'O', 'O', 'PROVISION', 'PROVISION', 'O', 'O', 'STATUTE', 'STATUTE', 'STATUTE', 'STATUTE', 'O']


In [35]:
def tokenize_and_align_labels(support_set, q_tokens, q_tags, eval=False):

  # For support

  tokenized_support_set = []
  label2idx = {}
  idx = 0
  for items in support_set:
    s_tokens, s_tags = items
    s_tokenized = xlmr_tokenizer(s_tokens, truncation=True, is_split_into_words=True)

    for tag in s_tags:
      if tag not in label2idx:
        label2idx[tag] = idx
        idx+=1

    labels = [label2idx[tag] for tag in s_tags]
    word_ids = s_tokenized.word_ids()

    label_ids = []
    for word_idx in word_ids:
      if word_idx==None:
        label_ids.append(label2idx['O'])
      else:
        label_ids.append(labels[word_idx])

    s_tokenized['labels'] = label_ids

    tokenized_support_set.append(s_tokenized)


  # For query
  q_tokenized_inputs = xlmr_tokenizer(q_tokens, truncation=True, is_split_into_words=True)

  if not eval:
    q_word_ids = q_tokenized_inputs.word_ids()
    q_labels = [label2idx[tag] for tag in q_tags]
    q_label_ids = []
    for word_idx in q_word_ids:
      if word_idx==None:
        q_label_ids.append(label2idx['O'])
      else:
        q_label_ids.append(q_labels[word_idx])

    q_tokenized_inputs['labels'] = q_label_ids
  else:
    q_tokenized_inputs['labels'] = []

  return q_tokenized_inputs, tokenized_support_set, label2idx

In [36]:
def tokenize_episodes(episodes, eval=False):

    tokenized_episodes = []
     
    for episode in episodes:

        final_support_set = []
        final_query_set = []
        query = episode['query_set']
        support = episode['support_set']

        q_tokens, q_tags = query

        tokenized_query_set, tokenized_support_set, label2idx = tokenize_and_align_labels(support, q_tokens, q_tags, eval=eval)

        if len(final_query_set)<1:
            q_ii = torch.tensor(tokenized_query_set['input_ids']).unsqueeze(0)
            q_am = torch.tensor(tokenized_query_set['attention_mask']).unsqueeze(0)
            q_l = torch.tensor(tokenized_query_set['labels'])
            final_query_set.extend([q_ii, q_am, q_l])

        for support_set in tokenized_support_set:
            s_ii = torch.tensor(support_set['input_ids']).unsqueeze(0)
            s_am = torch.tensor(support_set['attention_mask']).unsqueeze(0)
            s_l = torch.tensor(support_set['labels'], dtype=torch.int)
            final_support_set.append([s_ii, s_am, s_l])

        tokenized_episodes.append({
            'query_set': final_query_set,
            'support_set': final_support_set,
            'label2idx': label2idx
        })

    return tokenized_episodes

In [70]:
eval_sampler = FewShotSampler(1, 3)
eval_episodes = eval_sampler.generate_episodes(deepcopy(dataset), 401)

In [71]:
len(eval_episodes)

401

In [82]:
tokenized_eval_episodes = tokenize_episodes(eval_episodes)

In [83]:
import gc
import torch

In [84]:
def eval(eval_eps):
  y_true = []
  y_preds = []
  all_i2l = []
  model.eval()
  for i, episode in tqdm(enumerate(eval_eps), total=len(eval_eps)):
    query_set = episode['query_set']
    support_set = episode['support_set']
    num_classes = len(episode['label2idx'])
    i2l = {i: l for l, i in episode['label2idx'].items()}
    _, _, query_labels = query_set

    # Calculate predictions
    eval_distances = model(support_set, query_set, num_classes)[0]
    eval_preds = torch.argmin(eval_distances, dim=1).detach().cpu().numpy()

    y_true.append(query_labels.tolist())
    y_preds.append(eval_preds.tolist())
    all_i2l.append(i2l)

    del query_set
    del support_set
    del query_labels
    del num_classes
    gc.collect()

  return y_true, y_preds, all_i2l

In [85]:
y_test, y_preds, i2l = eval(tokenized_eval_episodes)
len(y_test), len(y_preds)

100%|██████████| 401/401 [05:42<00:00,  1.17it/s]


(401, 401)

In [87]:
joblib.dump(i2l, 'index2label')

['index2label']

In [86]:
joblib.dump(y_test, 'y_test')
joblib.dump(y_preds, 'y_pred')

['y_pred']

In [76]:
labels = ['OTHER_PERSON', 'GPE', 'PROVISION', 'WITNESS', 'O', 'RESPONDENT', 'DATE', 'COURT', 'CASE_NUMBER', 'JUDGE', 'STATUTE', 'PETITIONER', 'ORG', 'PRECEDENT']

In [77]:
from sklearn.metrics import f1_score, precision_score, recall_score

In [78]:
def find_scores_per_class(y_test, y_preds, i2l, labels=labels, metrics='f1'):
  fscores = {l: [] for l in labels}
  pscores = {l: [] for l in labels}
  rscores = {l: [] for l in labels}
  for true, pred, idx2label in zip(y_test, y_preds, i2l):
    f_scores = f1_score(true, pred, average=None)
    p_scores = precision_score(true, pred, average=None, zero_division=0.0)
    r_scores = recall_score(true, pred, average=None, zero_division=0.0)
    for i in range(len(f_scores)):
      lab = idx2label[i]
      fscores[lab].append(f_scores[i])
      pscores[lab].append(p_scores[i])
      rscores[lab].append(r_scores[i])

  final_scores = {l: [] for l in labels}
  for k in fscores:
    final_scores[k].append(np.mean(fscores[k]))
    final_scores[k].append(np.mean(pscores[k]))
    final_scores[k].append(np.mean(rscores[k]))

  return pd.DataFrame(final_scores, index=['Average f1 scores', 'Average precision scores', 'Average recall scores'], 
                      columns=['COURT', 'JUDGE', 'WITNESS', 'STATUTE', 'PETITIONER', 'DATE', 'OTHER_PERSON', 'PRECEDENT', 'O', 'RESPONDENT', 'GPE', 'CASE_NUMBER', 'PROVISION', 'ORG'])

In [79]:
all_scores = find_scores_per_class(y_test, y_preds, i2l, metrics='f1').T
all_scores

Unnamed: 0,Average f1 scores,Average precision scores,Average recall scores
COURT,0.880783,0.892522,0.923586
JUDGE,0.819643,0.802827,0.859375
WITNESS,0.746781,0.700647,0.892677
STATUTE,0.946223,0.936043,0.975976
PETITIONER,0.562844,0.516026,0.714286
DATE,0.97787,0.974856,0.988223
OTHER_PERSON,0.813244,0.826266,0.844412
PRECEDENT,0.909851,0.86326,0.985839
O,0.975036,0.986303,0.968496
RESPONDENT,0.599435,0.550476,0.7


In [80]:
import pandas as pd

 
exel = pd.ExcelWriter('outputs.xlsx')
all_scores.to_excel(exel)
 
exel.close()

In [69]:
torch.mps.current_allocated_memory(), torch.mps.driver_allocated_memory()

(1573310464, 4379525120)