In [32]:
import pandas as pd
import numpy as np
import ast

import random
from copy import deepcopy

import torch
from transformers import AutoTokenizer

In [12]:
# Raw training data: one sentence per row with stringified token / BIO-tag lists.
judgement = pd.read_csv(
    'train_judgement_bio.csv',
)
judgement.head(5)

Unnamed: 0,tokens,BIO_tags
0,"['\n\n', '(', '7', ')', 'On', 'specific', 'que...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
1,"['He', 'was', 'also', 'asked', 'whether', 'Agy...","['O', 'O', 'O', 'O', 'O', 'B-OTHER_PERSON', 'O..."
2,"[' \n', '5.2', 'CW3', 'Mr', 'Vijay', 'Mishra',...","['O', 'O', 'O', 'O', 'B-WITNESS', 'I-WITNESS',..."
3,"['You', 'are', 'hereby', 'asked', 'not', 'to',...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
4,"['The', 'pillion', 'rider', 'T.V.', 'Satyanara...","['O', 'O', 'O', 'B-OTHER_PERSON', 'I-OTHER_PER..."


In [13]:
# Reproducibly shuffle all rows and give the tag column a shorter name.
df = (
    judgement
    .sample(frac=1, random_state=42)
    .rename(columns={'BIO_tags': 'tags'})
)
df.head()

Unnamed: 0,tokens,tags
4613,"['Clause', '18(1', ')', ',', '(', '2', ')', 'a...","['B-PROVISION', 'I-PROVISION', 'I-PROVISION', ..."
1103,"['The', 'order', 'can', 'not', 'be', 'said', '...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
5214,"['Ajit', 'Kumar', 'Guha', '(', 'D.', 'W.', '1'...","['B-WITNESS', 'I-WITNESS', 'I-WITNESS', 'O', '..."
3315,"['The', 'purpose', 'of', 'entering', 'into', '...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."
5363,"['It', 'is', 'admitted', 'that', 'the', 'vehic...","['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ..."


In [14]:
def convert_to_list(s):
    """Parse a stringified Python list (e.g. ``"['a', 'b']"``) back into a list.

    ``ast.literal_eval`` only accepts Python literals, so no code in the CSV
    can be executed. Values that are already lists are returned unchanged,
    which makes the cell applying this function safe to re-run.
    """
    if isinstance(s, list):
        return s
    return ast.literal_eval(s)

# Both list columns are stored as strings in the CSV; parse them into lists.
df['tokens'], df['tags'] = (df[column].apply(convert_to_list) for column in ('tokens', 'tags'))
df.head()

Unnamed: 0,tokens,tags
4613,"[Clause, 18(1, ), ,, (, 2, ), and, (, 3, ), \n...","[B-PROVISION, I-PROVISION, I-PROVISION, I-PROV..."
1103,"[The, order, can, not, be, said, to, be, wrong...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
5214,"[Ajit, Kumar, Guha, (, D., W., 1, ), ,, who, w...","[B-WITNESS, I-WITNESS, I-WITNESS, O, O, O, O, ..."
3315,"[The, purpose, of, entering, into, a, contract...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
5363,"[It, is, admitted, that, the, vehicle, bearing...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."


In [15]:
def remove_BIO(tags):
    """Strip the BIO prefix from each tag, keeping only the entity type.

    ``'B-DATE'`` and ``'I-DATE'`` both become ``'DATE'``. Tags without a
    ``B-``/``I-`` prefix (e.g. ``'O'``) are returned unchanged. Only a
    *leading* prefix is removed, so an entity name that itself contains
    ``"B-"`` or ``"I-"`` is not corrupted.

    Args:
        tags: iterable of tag strings for one sentence.

    Returns:
        A new list of tags with the BIO prefixes removed.
    """
    return [tag[2:] if tag.startswith(('B-', 'I-')) else tag for tag in tags]

# Collapse BIO tags to bare entity types (B-X / I-X -> X).
df['tags'] = df['tags'].map(remove_BIO)
df['tags'].head(5)

4613    [PROVISION, PROVISION, PROVISION, PROVISION, P...
1103    [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
5214    [WITNESS, WITNESS, WITNESS, O, O, O, O, O, O, ...
3315    [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
5363    [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
Name: tags, dtype: object

In [16]:
tag_list = df['tags'].values
# Distinct tag values across the corpus ('O' included). Iteration order of a
# set of strings changes between interpreter runs (hash randomisation), so
# sort it to make the printed list reproducible.
labels = sorted({label for tags in tag_list for label in tags})
print(labels)
print(len(labels))

['OTHER_PERSON', 'GPE', 'PROVISION', 'WITNESS', 'O', 'RESPONDENT', 'DATE', 'COURT', 'CASE_NUMBER', 'JUDGE', 'STATUTE', 'PETITIONER', 'ORG', 'PRECEDENT']
14


In [17]:
# Sentences with exactly 5 distinct tags ('O' counts as one of the 5),
# shuffled and re-indexed from 0. Not referenced by the cells below, which
# let FewShotSampler recompute its own n-way pool.
data_5_way = (
    df[df['tags'].apply(lambda tags: len(set(tags)) == 5)]
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
data_5_way

Unnamed: 0,tokens,tags
0,"[Even, if, the, decision, of, the, Bombay, Hig...","[O, O, O, O, O, O, COURT, COURT, COURT, O, O, ..."
1,"[Under, the, aforesaid, deed, ,, which, has, c...","[O, O, O, O, O, O, O, O, O, O, O, DATE, O, O, ..."
2,"[We, have, a, Judgment, of, learned, Single, J...","[O, O, O, O, O, O, O, O, O, O, O, O, O, JUDGE,..."
3,"[By, further, order, dated, November, 1, ,, 19...","[O, O, O, O, DATE, DATE, DATE, DATE, O, O, O, ..."
4,"[Time, must, have, been, taken, by, both, PW14...","[O, O, O, O, O, O, O, O, WITNESS, WITNESS, O, ..."
...,...,...
396,"[The, present, case, has, been, instituted, on...","[O, O, O, O, O, O, O, O, O, O, O, O, PETITIONE..."
397,"[Various, complaints, were, filed, against, th...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, PRO..."
398,"[Application, of, Section, 29, of, the, Contra...","[O, O, PROVISION, PROVISION, O, O, STATUTE, ST..."
399,"[(, ALOK, VERMA, ), Judge, manju, M.Cr, ., C.N...","[O, JUDGE, JUDGE, O, O, O, CASE_NUMBER, CASE_N..."


In [18]:
# Plain-Python rows [tokens, tags]; deepcopy detaches the inner lists from the
# DataFrame cells so later in-place edits cannot leak back into df.
dataset = deepcopy(df.to_numpy().tolist())

# Peek at the first row (prints nothing if dataset is empty).
for data in dataset[:1]:
    print(data[0])
    print(data[1])

['Clause', '18(1', ')', ',', '(', '2', ')', 'and', '(', '3', ')', '\n', '(', 'a', ')', '&', '(', 'b', ')', 'were', 'transposed', 'in', 'Article', '23', 'of', 'the', 'Draft', 'Constitution', 'of', 'India', '.']
['PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'PROVISION', 'O', 'O', 'O', 'PROVISION', 'PROVISION', 'O', 'O', 'STATUTE', 'STATUTE', 'STATUTE', 'STATUTE', 'O']


In [19]:
class FewShotSampler:
    '''
    Build N-way episodes (one query sentence plus a support set) from tagged data.

    Every data item is ``[tokens, tags]``. The "way" of a sentence is the number
    of distinct tags it contains, and 'O' counts as one of them, so ``n_way=5``
    selects sentences with four entity types plus 'O'.

    n_way: number of distinct tags (including 'O') a query sentence must have
    k_shot: the first sampling pass only accepts a support sentence if at least
        one query tag it contains has been counted fewer than k_shot times
    support_size: target number of sentences in every support set
    seed: seed of the private RNG used for sampling; the global ``random``
        state is left untouched
    '''
    def __init__(self, n_way, k_shot=2, support_size=8, seed=42):
        self.n_way = n_way
        self.k_shot = k_shot
        self.support_size = support_size
        self.seed = seed

    def extract_n_way_data(self, data):
        '''Return the items of ``data`` whose tag list has exactly ``n_way`` distinct values.'''
        return [items for items in data if len(set(items[1])) == self.n_way]

    @staticmethod
    def _mask_labels(labels, keep):
        '''Replace every tag that is not in ``keep`` with 'O'.'''
        return [label if label in keep else 'O' for label in labels]

    def sample_data(self, n_way_data, full_data, n_episodes):
        '''
        Draw up to ``n_episodes`` distinct queries and build a support set for each.

        n_way_data: candidate query items (usually from ``extract_n_way_data``)
        full_data: pool the support sentences are drawn from
        n_episodes: maximum number of episodes (capped at ``len(n_way_data)``)

        Support labels are masked so only the query's tags survive (others
        become 'O'); the query sentence itself and exact duplicates are never
        added. A first pass favours sentences that cover under-represented
        query tags; if the support set is still short, a second pass tops it
        up with random sentences. Neither input list is modified, and results
        are deterministic for a given ``seed``.

        Returns a list of dicts
        ``{'query_set': [tokens, tags], 'support_set': [[tokens, tags], ...]}``.
        '''
        rng = random.Random(self.seed)
        episodes = []
        for query in rng.sample(n_way_data, min(n_episodes, len(n_way_data))):
            query_text = query[0]
            query_tags = query[1]
            query_tag_set = set(query_tags)
            # Sentences are compared by their concatenated text to exclude the query.
            query_joined = ''.join(query_text)

            support_data = []
            class_counts = {tag: 0 for tag in query_tag_set}

            # Pass 1: walk a full shuffle of the data, preferring sentences that
            # contain a query tag still below k_shot occurrences.
            for items in rng.sample(full_data, len(full_data)):
                if len(support_data) >= self.support_size:
                    break
                text, labels = items
                new_labels = self._mask_labels(labels, query_tag_set)
                if [text, new_labels] in support_data or ''.join(text) == query_joined:
                    continue
                if len(set(new_labels)) < 2:
                    continue  # a single tag left after masking carries no contrast
                if any(tag in new_labels and class_counts[tag] < self.k_shot for tag in query_tag_set):
                    support_data.append([text, new_labels])
                    for ent in set(new_labels):
                        # .get: 'O' is not a key when the query itself has no 'O'.
                        class_counts[ent] = class_counts.get(ent, 0) + 1

            # Pass 2: top up with random sentences (no coverage requirement).
            if len(support_data) < self.support_size:
                for items in rng.sample(full_data, min(self.support_size, len(full_data))):
                    if len(support_data) >= self.support_size:
                        break
                    text, labels = items
                    new_labels = self._mask_labels(labels, query_tag_set)
                    if [text, new_labels] in support_data or ''.join(text) == query_joined:
                        continue  # skip this candidate, keep filling
                    support_data.append([text, new_labels])

            episodes.append({
                'query_set': [query_text, query_tags],
                'support_set': support_data
            })

        return episodes

    def generate_episodes(self, reference_data, num_episodes):
        '''Select the n-way queries from ``reference_data`` and sample up to ``num_episodes`` episodes.'''
        extracted_data = self.extract_n_way_data(reference_data)
        return self.sample_data(extracted_data, reference_data, num_episodes)

In [20]:
sampler = FewShotSampler(n_way=5)  # 5 distinct tags per query, 'O' included

In [21]:
# Pass a deep copy so the sampler can never alter `dataset`.
episodes = sampler.generate_episodes(reference_data=deepcopy(dataset), num_episodes=500)

In [22]:
def display_episode(episodes, idx):
    """Print one episode: the query sentence, then every support sentence.

    Args:
        episodes: list of dicts with 'query_set' ([tokens, tags]) and
            'support_set' (list of [tokens, tags]).
        idx: position of the episode to show.
    """
    rule = '-' * 300

    print('Query set')
    query = episodes[idx]['query_set']
    query_tokens, query_tags = query[0], query[1]
    print(query_tokens)
    print(query_tags)
    print(set(query_tags))
    print(rule)
    print(rule)

    print('Support Set')
    for text, support_entity in episodes[idx]['support_set']:
        print(f"Text - {text}")
        print(f"Labels - {support_entity}")
        print('SET: ')
        print(set(support_entity))
        print(rule)
        print()

In [23]:
# Sanity check: every episode should have exactly 8 support sentences.
wrong_size = [idx for idx, episode in enumerate(episodes) if len(episode['support_set']) != 8]
assert not wrong_size, f'Support set not equal to 8 in episodes: {wrong_size}'

In [24]:
# Sanity check: the support set must use exactly the query's tag set.
# A query tag missing from the support set would make
# tokenize_and_align_labels fail, since its label map is built from the
# support tags only.
tag_mismatches = []
for idx, episode in enumerate(episodes):
    query_tags = set(episode['query_set'][1])
    support_tags = set().union(*(s_item[1] for s_item in episode['support_set']))
    if query_tags != support_tags:
        tag_mismatches.append((idx, sorted(query_tags ^ support_tags)))
assert not tag_mismatches, f'Query/support tag sets differ (episode, tags): {tag_mismatches[:10]}'

In [41]:
display_episode(episodes, idx=9)  # inspect one sampled episode

Query set
['Earlier', 'the', 'accused', 'was', 'convicted', 'vide', 'judgment', 'dated', '23rd', 'August', ',', '1999', 'delivered', 'in', 'Sessions', 'Trial', 'No.325/1994', 'by', 'Sessions', 'Judge', 'Morena', 'for', 'commission', 'of', 'the', 'offence', 'under', 'section', '302', 'IPC', 'and', 'awarded', 'life', 'imprisonment', '.']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'DATE', 'DATE', 'DATE', 'DATE', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'JUDGE', 'O', 'O', 'O', 'O', 'O', 'O', 'PROVISION', 'PROVISION', 'STATUTE', 'O', 'O', 'O', 'O', 'O']
{'PROVISION', 'DATE', 'O', 'JUDGE', 'STATUTE'}
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

In [28]:
# Hugging Face checkpoint name for the multilingual XLM-RoBERTa base model.
xlmr_model_name = "xlm-roberta-base"
# tokenize_and_align_labels calls .word_ids() on this tokenizer's output,
# so it must be a "fast" tokenizer.
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_model_name)

In [29]:
def _align_word_labels(word_ids, word_label_ids, fill_id):
    """Expand word-level label ids to token level.

    ``word_ids`` is the per-token word index returned by a fast tokenizer's
    ``word_ids()``. Tokens mapped to ``None`` (special tokens) get ``fill_id``;
    every sub-word of a word gets that word's label id.
    """
    return [fill_id if word_idx is None else word_label_ids[word_idx] for word_idx in word_ids]


def tokenize_and_align_labels(support_set, q_tokens, q_tags, eval=False, tokenizer=None):
    """Tokenize one episode and project its word-level tags onto sub-word tokens.

    Tags are mapped to integer ids in the order they first occur across the
    support set. 'O' is also used for special tokens and is appended to the
    map if the support set never contains it.

    Args:
        support_set: list of ``[tokens, tags]`` pairs (word-level, equal length).
        q_tokens: query words.
        q_tags: query word tags; ignored when ``eval`` is True.
        eval: if True, the query labels are left empty. (The name shadows the
            builtin and is kept for backward compatibility.)
        tokenizer: optional fast tokenizer; defaults to the module-level
            ``xlmr_tokenizer``.

    Returns:
        ``(q_tokenized_inputs, tokenized_support_set, num_classes)``, where each
        tokenized encoding carries an extra ``'labels'`` list aligned to its
        tokens and ``num_classes`` is the size of the label map.

    Raises:
        KeyError: if a query tag never occurs in the support set.
    """
    if tokenizer is None:
        tokenizer = xlmr_tokenizer

    # Build the full label map up front; ids follow first-occurrence order
    # across the support set.
    label2idx = {}
    for _, s_tags in support_set:
        for tag in s_tags:
            label2idx.setdefault(tag, len(label2idx))
    o_id = label2idx.setdefault('O', len(label2idx))

    # For support
    tokenized_support_set = []
    for s_tokens, s_tags in support_set:
        s_tokenized = tokenizer(s_tokens, truncation=True, is_split_into_words=True)
        s_label_ids = [label2idx[tag] for tag in s_tags]
        s_tokenized['labels'] = _align_word_labels(s_tokenized.word_ids(), s_label_ids, o_id)
        tokenized_support_set.append(s_tokenized)

    # For query
    q_tokenized_inputs = tokenizer(q_tokens, truncation=True, is_split_into_words=True)
    if eval:
        q_tokenized_inputs['labels'] = []
    else:
        unknown = sorted({tag for tag in q_tags if tag not in label2idx})
        if unknown:
            raise KeyError(f'query tags not present in support set: {unknown}')
        q_label_ids = [label2idx[tag] for tag in q_tags]
        q_tokenized_inputs['labels'] = _align_word_labels(q_tokenized_inputs.word_ids(), q_label_ids, o_id)

    return q_tokenized_inputs, tokenized_support_set, len(label2idx)

In [30]:
def tokenize_episodes(episodes):
    """Tokenize every episode and convert it to PyTorch tensors.

    Args:
        episodes: list of dicts with ``'query_set'`` = ``[tokens, tags]`` and
            ``'support_set'`` = list of ``[tokens, tags]``.

    Returns:
        One dict per input episode, in order:

        - ``'query_set'``: ``[input_ids, attention_mask, labels]``; the first two
          get a leading dimension of size 1 via ``unsqueeze(0)``, ``labels`` does not.
        - ``'support_set'``: one ``[input_ids, attention_mask, labels]`` triple per
          support sentence, shaped like the query's.
        - ``'num_classes'``: size of the episode's label map.
    """
    tokenized_episodes = []

    for episode in episodes:
        q_tokens, q_tags = episode['query_set']
        tokenized_query_set, tokenized_support_set, num_classes = tokenize_and_align_labels(
            episode['support_set'], q_tokens, q_tags)

        final_query_set = [
            torch.tensor(tokenized_query_set['input_ids']).unsqueeze(0),
            torch.tensor(tokenized_query_set['attention_mask']).unsqueeze(0),
            # Default integer dtype (int64). NOTE(review): support labels below
            # use torch.int (int32) — confirm the model expects this mix.
            torch.tensor(tokenized_query_set['labels']),
        ]

        final_support_set = [
            [
                torch.tensor(support_set['input_ids']).unsqueeze(0),
                torch.tensor(support_set['attention_mask']).unsqueeze(0),
                torch.tensor(support_set['labels'], dtype=torch.int),
            ]
            for support_set in tokenized_support_set
        ]

        tokenized_episodes.append({
            'query_set': final_query_set,
            'support_set': final_support_set,
            'num_classes': num_classes
        })

    return tokenized_episodes

In [33]:
tokenized_episodes = tokenize_episodes(episodes=episodes)  # tensors per episode

In [34]:
# tokenize_episodes emits exactly one tokenized episode per sampled episode.
assert len(tokenized_episodes) == len(episodes)
len(tokenized_episodes)

401

In [38]:
# Peek at the tensors of the first support sentence of the first episode
# (prints nothing if either list is empty).
for episode in tokenized_episodes[:1]:
    for s_ii, s_am, s_l in episode['support_set'][:1]:
        print(s_ii)
        print(s_am)
        print(s_l)

tensor([[     0,   5443,      6,      5,    436,      5,   4235,     83,    563,
          30508,  13416,      6,      4,   7440,     73,     70, 181595,  31068,
          13483,   1919,  31330,    607,    297,    953,  28705,  34498, 113771,
            450,     98, 165045,  42276,   3378,   1363,      6,      4,     70,
         121413,  70541,    297,     23,     70,  43824,    126,  94419,    335,
          26038,  24491, 109921,    100,     70, 169424,    111, 185256, 209716,
              7,    111,   5798, 173857,    297, 130090,    111,    233,   6664,
          19441,    615,   5039,   1745,     15,   1632,   6275,    300,     86,
          45234,  33297,   1388,      6, 178851,   1212,      6,      5,      2]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1,

In [50]:
# Support-set size of the first tokenized episode (the sampler targets 8).
len(tokenized_episodes[0]['support_set'])

8