In [41]:
import ast
import json
import os
from collections import defaultdict
from typing import List, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import spacy
from spacy.tokens import DocBin
# from datasets import load_dataset
from datasets import Dataset, DatasetDict

In [2]:
categories = ['CASE_NUMBER', 'COURT', 'DATE', 'GPE', 'JUDGE', 'LAWYER', 'ORG', 'OTHER_PERSON', 'PETITIONER', 'PRECEDENT', 'PROVISION', 'RESPONDENT', 'STATUTE', 'WITNESS']

# BIO tag vocabulary: 'O' gets id 0, then each category contributes a B-/I- pair,
# so every B- tag has an odd id and its matching I- tag the next (even) id.
B_PREFIX = 'B-'
I_PREFIX = 'I-'
O_TAG = 'O'
tag_names = [O_TAG] + [
    prefix + category
    for category in categories
    for prefix in (B_PREFIX, I_PREFIX)
]
label2id = {tag: tag_id for tag_id, tag in enumerate(tag_names)}
id2label = dict(enumerate(tag_names))

### Creating token-level data (skip this section if `./data/roberta/{train,dev,test}.csv` already exist)

In [3]:
raw_csv_paths = ("./data/finetuning/train.csv", "./data/finetuning/dev.csv", "./data/finetuning/test.csv")
train, dev, test = (pd.read_csv(path) for path in raw_csv_paths)

In [4]:
# Preview the raw test split (columns include sentence, raw_entities, entities_dict, text).
test

Unnamed: 0,sentence,raw_entities,entities_dict,text
0,$~40 * In The High Court Of Delhi At New Delhi...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""['High Court O...","{'CASE_NUMBER': '[]', 'COURT': ""['High Court O...",<s> [INST] You are solving the NER problem in ...
1,1 Reportable In The Supreme Court Of India Civ...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""['Supreme Cour...","{'CASE_NUMBER': '[]', 'COURT': ""['Supreme Cour...",<s> [INST] You are solving the NER problem in ...
2,R/Scr.A/9089/2017 Judgment In The High Court O...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""['High Court O...","{'CASE_NUMBER': '[]', 'COURT': ""['High Court O...",<s> [INST] You are solving the NER problem in ...
3,High Court Of Judicature For Rajasthan Bench A...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""['High Court O...","{'CASE_NUMBER': '[]', 'COURT': ""['High Court O...",<s> [INST] You are solving the NER problem in ...
4,1 In The High Court Of Judicature At Madras Da...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""['High Court O...","{'CASE_NUMBER': '[]', 'COURT': ""['High Court O...",<s> [INST] You are solving the NER problem in ...
...,...,...,...,...
1069,"Apparently, Channaraddi set up his daughters G...","{""CASE_NUMBER"": ""['O.S.No.31/2009']"", ""COURT"":...","{'CASE_NUMBER': ""['O.S.No.31/2009']"", 'COURT':...",<s> [INST] You are solving the NER problem in ...
1070,After the dismissal of the petition for annulm...,"{""CASE_NUMBER"": ""['F.C.O.P.No.41 of 2012']"", ""...","{'CASE_NUMBER': ""['F.C.O.P.No.41 of 2012']"", '...",<s> [INST] You are solving the NER problem in ...
1071,"On 12.07.2018, a letter was received from the ...","{""CASE_NUMBER"": ""['Special Case (NDPS) No.17 o...","{'CASE_NUMBER': ""['Special Case (NDPS) No.17 o...",<s> [INST] You are solving the NER problem in ...
1072,The date on which the measurements were record...,"{""CASE_NUMBER"": ""[]"", ""COURT"": ""[]"", ""DATE"": ""...","{'CASE_NUMBER': '[]', 'COURT': '[]', 'DATE': '...",<s> [INST] You are solving the NER problem in ...


In [5]:
# Full spaCy English pipeline; tokenize_and_tag only reads token.text from its output.
# NOTE(review): since only tokenization is used, disabling the other pipeline
# components would likely speed this up considerably — confirm before changing.
nlp = spacy.load('en_core_web_sm')
def tokenize_and_tag(df: pd.DataFrame, categories: List[str], tokenizer=None, label_map=None) -> pd.DataFrame:
    """Convert sentence-level entity annotations into token-level BIO tags.

    Each row of ``df`` must provide a ``sentence`` string and an
    ``entities_dict`` string. ``entities_dict`` is parsed with
    ``ast.literal_eval`` into a dict mapping category name -> string, and each
    of those strings is parsed again into a list of entity surface forms.

    Every occurrence of an entity (split on whitespace) as a contiguous run of
    sentence tokens is tagged ``B-<category>`` on its first token and
    ``I-<category>`` on the rest; all other tokens stay ``O``. When entity
    occurrences overlap, the one processed later overwrites earlier tags.
    Empty or whitespace-only entity strings are ignored.

    Args:
        df: Input frame with ``sentence`` and ``entities_dict`` columns.
        categories: Currently unused; kept for call-site compatibility.
        tokenizer: Callable mapping a sentence to an iterable of objects with a
            ``.text`` attribute. Defaults to the module-level spaCy pipeline ``nlp``.
        label_map: Mapping from tag string to integer id. Defaults to the
            module-level ``label2id``. Tags missing from it become ``None``.

    Returns:
        DataFrame with columns ``tokens`` (list of str), ``ner_tags`` (list of
        ids) and ``ner_tags_str`` (list of tag strings), one row per input row.
        An empty input yields an empty frame with those columns.
    """
    if tokenizer is None:
        tokenizer = nlp
    if label_map is None:
        label_map = label2id

    B_PREFIX = 'B-'
    I_PREFIX = 'I-'
    O_TAG = 'O'

    all_tokens = []
    all_tags = []

    for _, row in df.iterrows():
        tokens = [token.text for token in tokenizer(row['sentence'])]
        tags = [O_TAG] * len(tokens)

        entities = ast.literal_eval(row['entities_dict'])
        for category, entity_list in entities.items():
            for entity in ast.literal_eval(entity_list):
                entity_tokens = entity.split()
                # An empty entity would match the empty slice at every position
                # and tag the whole sentence, so skip it.
                if not entity_tokens:
                    continue
                width = len(entity_tokens)
                # NOTE(review): entities are split on whitespace while sentences go
                # through `tokenizer`, so entities it splits differently (e.g.
                # attached punctuation) never match.
                for i in range(len(tokens) - width + 1):
                    if tokens[i:i + width] == entity_tokens:
                        tags[i] = B_PREFIX + category
                        for j in range(i + 1, i + width):
                            tags[j] = I_PREFIX + category

        all_tokens.append(tokens)
        all_tags.append(tags)

    # Build the frame once, after every row has been tagged.
    data = pd.DataFrame({'tokens': all_tokens, 'ner_tags': all_tags})
    data['ner_tags_str'] = data['ner_tags']
    data['ner_tags'] = data['ner_tags'].apply(lambda tag_list: [label_map.get(tag) for tag in tag_list])
    return data


In [6]:
train_data, dev_data, test_data = (
    tokenize_and_tag(split_df, categories) for split_df in (train, dev, test)
)

In [8]:
for frame, csv_path in zip(
    (train_data, dev_data, test_data),
    ("./data/roberta/train.csv", "./data/roberta/dev.csv", "./data/roberta/test.csv"),
):
    frame.to_csv(csv_path, index=False)

In [39]:
# Preview the tagged training data: tokens, integer ner_tags and their string form ner_tags_str.
train_data

Unnamed: 0,tokens,ner_tags,ner_tags_str
0,"[(, 7, ), On, specific, query, by, the, Bench,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
1,"[He, was, also, asked, whether, Agya, <, span,...","[0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[O, O, O, O, O, B-OTHER_PERSON, O, O, O, O, O,..."
2,"[5.2, CW3, Mr, Vijay, Mishra, ,, Deputy, Manag...","[0, 0, 0, 27, 28, 0, 0, 0, 0, 13, 14, 0, 0, 0,...","[O, O, O, B-WITNESS, I-WITNESS, O, O, O, O, B-..."
3,"[The, pillion, rider, T.V., Satyanarayana, Mur...","[0, 0, 0, 15, 16, 16, 0, 0, 0, 0]","[O, O, O, B-OTHER_PERSON, I-OTHER_PERSON, I-OT..."
4,"[,, if, the, argument, of, the, learned, couns...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
...,...,...,...
9890,"[1, ®, In, The, High, Court, Of, Karnataka, At...","[0, 0, 0, 0, 3, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, ...","[O, O, O, O, B-COURT, I-COURT, I-COURT, I-COUR..."
9891,"[They, had, admittedly, left, India, after, th...","[0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[O, O, O, O, B-GPE, O, O, O, O, O, O, O, O, O,..."
9892,"[Non, -, applicant, produced, witnesses, NAW, ...","[0, 0, 0, 0, 0, 0, 0, 27, 28, 0, 0, 0, 0, 0, 2...","[O, O, O, O, O, O, O, B-WITNESS, I-WITNESS, O,..."
9893,"[No, doubt, ,, civil, and, criminal, jurisdict...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."


### Model Building

In [26]:
def transform_columns_to_list(df, columns=None):
    """Parse stringified Python literals (e.g. "['a', 'b']") back into objects.

    Writing list-valued cells to CSV stores their ``repr``; this reverses that
    with ``ast.literal_eval``. Cells that are not strings are returned
    unchanged, so running this on an already-parsed frame is a no-op instead
    of raising.

    Args:
        df: Frame whose cells hold literal strings.
        columns: Columns to convert; defaults to every column of ``df``.

    Returns:
        A new DataFrame with the selected columns converted. ``df`` itself is
        not modified.
    """
    converted = df.copy()
    for col in (df.columns if columns is None else columns):
        converted[col] = converted[col].apply(
            lambda value: ast.literal_eval(value) if isinstance(value, str) else value
        )
    return converted

In [29]:
roberta_csv_paths = ("./data/roberta/train.csv", "./data/roberta/dev.csv", "./data/roberta/test.csv")
train_data, dev_data, test_data = (pd.read_csv(path) for path in roberta_csv_paths)

In [30]:
train_data, dev_data, test_data = (
    transform_columns_to_list(frame) for frame in (train_data, dev_data, test_data)
)

In [42]:
# Wrap each pandas split as a Hugging Face Dataset and bundle them in a DatasetDict.
train_dataset, dev_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_data, dev_data, test_data)
)
data = DatasetDict(
    {'train': train_dataset, 'validation': dev_dataset, 'test': test_dataset}
)


In [43]:
# Inspect split sizes and columns of the assembled DatasetDict.
data

DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags', 'ner_tags_str'],
        num_rows: 9895
    })
    validation: Dataset({
        features: ['tokens', 'ner_tags', 'ner_tags_str'],
        num_rows: 1100
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'ner_tags_str'],
        num_rows: 1074
    })
})

#### Tokenization

In [65]:
from transformers import AutoTokenizer

# Pretrained checkpoint whose (fast) tokenizer is used for all sub-word tokenization below.
# NOTE(review): the model cell loads weights from ./distilbert-finetuned-ner/...; presumably
# that run was fine-tuned from this same checkpoint — confirm the tokenizers match.
model_checkpoint = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

In [66]:
# BatchEncoding.word_ids(), used for label alignment below, requires a fast tokenizer.
tokenizer.is_fast

True

In [67]:
# Tokenize the first training sentence (already split into words) and show its sub-word pieces.
sample_words = data['train'][0]['tokens']
inputs = tokenizer(sample_words, is_split_into_words=True)
print(inputs.tokens())

['<s>', '▁(', '▁7', '▁)', '▁On', '▁specific', '▁que', 'ry', '▁by', '▁the', '▁Ben', 'ch', '▁about', '▁an', '▁entry', '▁of', '▁Rs', '▁', '.', '▁1,3', '1', ',', '37', ',', '500', '▁on', '▁deposit', '▁side', '▁of', '▁Hongkong', '▁Bank', '▁account', '▁of', '▁which', '▁a', '▁photo', '▁copy', '▁is', '▁appear', 'ing', '▁at', '▁p', '.', '▁40', '▁of', '▁assess', 'ee', "▁'", 's', '▁paper', '▁book', '▁', ',', '▁learned', '▁author', 'ised', '▁representativ', 'e', '▁submitted', '▁that', '▁it', '▁was', '▁related', '▁to', '▁loan', '▁from', '▁broker', '▁', ',', '▁Rahul', '▁&', '▁Co', '.', '▁on', '▁the', '▁basis', '▁of', '▁his', '▁sub', 'mission', '▁a', '▁necessary', '▁mark', '▁is', '▁put', '▁by', '▁us', '▁on', '▁that', '▁photo', '▁copy', '▁', '.', '</s>']


In [68]:
# Index of the source word for each sub-word token; None marks special tokens such as <s> and </s>.
print(inputs.word_ids())

[None, 0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13, 14, 14, 15, 15, 15, 15, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 29, 30, 31, 31, 32, 33, 34, 34, 35, 35, 36, 37, 38, 38, 39, 40, 40, 41, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 51, 52, 53, 54, 54, 55, 56, 57, 58, 59, 60, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 72, None]


In [69]:
def align_labels_with_tokens(labels, word_ids):
    """Project word-level NER label ids onto sub-word tokens.

    Args:
        labels: Label id for each word of the original (pre-split) sentence.
        word_ids: For each sub-word token, the index of the word it came from,
            or ``None`` for special tokens (as returned by ``BatchEncoding.word_ids()``).

    Returns:
        One label per sub-word token:
        - ``-100`` for special tokens, so they are ignored.
        - The word's label for the first sub-word of each word.
        - For further sub-words of the same word, the label with any odd id
          bumped to the next even id. With ``label2id`` built as O=0 followed by
          B-/I- pairs, this turns ``B-X`` into ``I-X`` on continuation pieces.
    """
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(-100)
        elif word_id == previous:
            label = labels[word_id]
            aligned.append(label + 1 if label % 2 == 1 else label)
        else:
            aligned.append(labels[word_id])
        previous = word_id
    return aligned

In [70]:
# Word-level label ids of the sample sentence next to its sub-word -> word mapping.
labels, word_ids = data['train'][0]['ner_tags'], inputs.word_ids()
print(labels, word_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [None, 0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13, 14, 14, 15, 15, 15, 15, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 29, 30, 31, 31, 32, 33, 34, 34, 35, 35, 36, 37, 38, 38, 39, 40, 40, 41, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 51, 52, 53, 54, 54, 55, 56, 57, 58, 59, 60, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 72, None]


In [71]:
# Sanity check: word labels projected onto the sample's sub-word tokens (-100 on special tokens).
print(align_labels_with_tokens(labels, word_ids))

[-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100]


In [72]:
def tokenize_and_align_labels(examples):
    """Batched ``Dataset.map`` function: tokenize pre-split words and attach aligned labels.

    Args:
        examples: A batch dict with ``'tokens'`` (list of word lists) and
            ``'ner_tags'`` (list of per-word label-id lists).

    Returns:
        The tokenizer's encoding of the batch (with truncation enabled) plus a
        ``'labels'`` entry: one sub-word-aligned label list per example,
        produced by ``align_labels_with_tokens``.
    """
    encoded = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
    encoded['labels'] = [
        align_labels_with_tokens(word_labels, encoded.word_ids(batch_index))
        for batch_index, word_labels in enumerate(examples['ner_tags'])
    ]
    return encoded

In [73]:
# Drop the original text/tag columns so only model inputs and aligned labels remain.
original_columns = data['train'].column_names
tokenized_datasets = data.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=original_columns,
)

Map:   0%|          | 0/9895 [00:00<?, ? examples/s]

Map:   0%|          | 0/1100 [00:00<?, ? examples/s]

Map:   0%|          | 0/1074 [00:00<?, ? examples/s]

In [74]:
# Confirm only model inputs (input_ids, attention_mask) and aligned labels remain.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9895
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1100
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1074
    })
})

#### Data collation

In [75]:
from transformers import DataCollatorForTokenClassification

# Pads each batch dynamically; labels are padded too (label_pad_token_id, -100 by default),
# so padded positions are ignored by the loss and filtered out in compute_metrics.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [76]:
# Collate the first two tokenized training examples to inspect the padded batch.
first_two_examples = [tokenized_datasets['train'][i] for i in range(2)]
batch = data_collator(first_two_examples)
print(batch)

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[     0,     15,    361,   1388,   2161,  29458,     41,   1294,    390,
             70,   3419,    206,   1672,    142,  42805,    111, 115034,      6,
              5,  46963,    418,      4,  10945,      4,   4283,     98,  40370,
           5609,    111, 185934,   4932,  15426,    111,   3129,     10,  16186,
          43658,     83, 108975,    214,     99,    915,      5,   1112,    111,
         202120,   7039,    242,      7,  15122,  12877,      6,      4,  97384,
          42179,  52021,  99638,     13, 230121,    450,    442,    509,  62548,
             47, 111628,   1295, 115835,      6,      4, 191367,    619,   1311,
              5,     98,     70,  18231,    111,   1919,   1614,  21150,     10,
          63559,  16188,     83,   3884,    390,   1821,     98,    450,  16186,
          43658,      6,      5,      2],
        [     0,   1529,    509,   2843,  37170,  36766,  12342,    395,   4426,
          27734,  18507,  22422,  15080,    555,    4

#### Metrics

In [96]:
import evaluate
from seqeval.scheme import IOB2

# seqeval metric loaded through `evaluate`; compute_metrics reads its overall_* scores.
# NOTE(review): the `IOB2` class imported above is unused — compute() receives the scheme as the string "IOB2".
metric = evaluate.load('seqeval')

In [99]:
import numpy as np

def compute_metrics(eval_preds):
    """Trainer ``compute_metrics`` hook returning overall seqeval scores.

    Args:
        eval_preds: ``(logits, labels)`` pair. The last axis of ``logits``
            indexes label ids; positions whose label is ``-100`` are ignored.

    Returns:
        Dict with overall ``precision``, ``recall``, ``f1`` and ``accuracy``
        (strict IOB2 matching, zero division scored as 0).
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    def _decode(ids, reference):
        # Map ids to tag strings, keeping only positions whose reference label is real.
        return [id2label[i] for i, ref in zip(ids, reference) if ref != -100]

    true_labels = [_decode(label_row, label_row) for label_row in labels]
    true_predictions = [
        _decode(pred_row, label_row)
        for pred_row, label_row in zip(predictions, labels)
    ]

    scores = metric.compute(
        predictions=true_predictions,
        references=true_labels,
        scheme="IOB2",
        mode="strict",
        zero_division=0,
    )
    return {
        "precision": scores['overall_precision'],
        "recall": scores['overall_recall'],
        "f1": scores['overall_f1'],
        "accuracy": scores['overall_accuracy'],
    }

### Model training

In [100]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

In [101]:
# Continue fine-tuning from a previous run's checkpoint when it exists on disk; otherwise
# start from the pretrained base checkpoint so the notebook also runs on a fresh machine.
# NOTE(review): the output directory says "distilbert", but model_checkpoint is "xlm-roberta-base".
resume_checkpoint = "./distilbert-finetuned-ner/checkpoint-3711"
model_source = resume_checkpoint if os.path.isdir(resume_checkpoint) else model_checkpoint
model = AutoModelForTokenClassification.from_pretrained(
    model_source,
    id2label=id2label,
    label2id=label2id,
)

In [102]:
# Sanity check: should equal len(id2label) — O plus a B-/I- pair for each of the 14 categories.
model.config.num_labels

29

In [103]:
# Evaluate and checkpoint once per epoch; outputs go to ./distilbert-finetuned-ner.
args = TrainingArguments(
    output_dir="distilbert-finetuned-ner",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
)

In [104]:
# Wire model, data, collator and metric hook together; validation split drives per-epoch evaluation.
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics,
)

In [105]:
# Run fine-tuning; per `args`, evaluation and checkpointing happen at the end of every epoch.
trainer.train()

  0%|          | 0/3711 [00:00<?, ?it/s]

{'loss': 0.0248, 'learning_rate': 1.7305308542171922e-05, 'epoch': 0.4}
{'loss': 0.0233, 'learning_rate': 1.4610617084343843e-05, 'epoch': 0.81}


  0%|          | 0/138 [00:00<?, ?it/s]

{'eval_loss': 0.10630408674478531, 'eval_precision': 0.8756989247311828, 'eval_recall': 0.8806228373702422, 'eval_f1': 0.878153978865646, 'eval_accuracy': 0.9812913556408506, 'eval_runtime': 72.6488, 'eval_samples_per_second': 15.141, 'eval_steps_per_second': 1.9, 'epoch': 1.0}
{'loss': 0.0186, 'learning_rate': 1.1915925626515765e-05, 'epoch': 1.21}
{'loss': 0.0154, 'learning_rate': 9.221234168687686e-06, 'epoch': 1.62}


  0%|          | 0/138 [00:00<?, ?it/s]

{'eval_loss': 0.1097690686583519, 'eval_precision': 0.8632193494578816, 'eval_recall': 0.8953287197231834, 'eval_f1': 0.8789808917197451, 'eval_accuracy': 0.9801139708479525, 'eval_runtime': 71.9323, 'eval_samples_per_second': 15.292, 'eval_steps_per_second': 1.918, 'epoch': 2.0}
