In [None]:
!wget https://storage.googleapis.com/indianlegalbert/OPEN_SOURCED_FILES/NER/NER_TRAIN.zip
!wget https://storage.googleapis.com/indianlegalbert/OPEN_SOURCED_FILES/NER/NER_DEV.zip

In [None]:
!unzip NER_DEV.zip
!unzip NER_TRAIN.zip

In [None]:
import numpy as np
import pandas as pd
import json
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from transformers.tokenization_utils_base import TokenSpan
from tqdm import tqdm
from datasets import DatasetDict, Dataset, load_dataset
import pandas as pd

In [None]:
# Checkpoint identifier; the same name is reused later when the model is loaded,
# so tokenizer and model always come from the same checkpoint.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Entity types; each one contributes a "B-" (first token) and an "I-" (following token) tag.
ENTITY_TYPES = [
    "COURT", "PETITIONER", "RESPONDENT", "JUDGE", "LAWYER", "DATE", "ORG",
    "GPE", "STATUTE", "PROVISION", "PRECEDENT", "CASE_NUMBER", "WITNESS", "OTHER_PERSON",
]
# Full tag inventory: "O" first, then every B- tag, then every I- tag.
# The position in this list is the integer id used for training.
LABELS = [
    "O",
    *(f"B-{entity_type}" for entity_type in ENTITY_TYPES),
    *(f"I-{entity_type}" for entity_type in ENTITY_TYPES),
]
id2label = dict(enumerate(LABELS))
label2id = {tag: index for index, tag in id2label.items()}

In [None]:
def tokenize_and_ner(text, named_entities):
  """Tokenize ``text`` and attach one tag id per produced token.

  Args:
    text: Raw string that the entity character offsets refer to.
    named_entities: Iterable of dicts, each with ``"start"`` / ``"end"``
      character offsets into ``text`` (``"end"`` is treated as exclusive) and
      a ``"labels"`` list whose first item is an entity type, so that
      ``"B-" + type`` and ``"I-" + type`` are keys of ``label2id``.

  Returns:
    The tokenizer encoding (still containing ``offset_mapping``) with an added
    ``"labels"`` list, aligned with ``input_ids``. Special tokens get -100,
    tokens overlapping an entity get its B-/I- id, everything else gets "O".

  Raises:
    KeyError: If an entity type has no entry in ``label2id``.
  """
  # truncation=True caps the sequence at the tokenizer's maximum model length,
  # so over-long texts cannot exceed what the model accepts.
  tokens = tokenizer(text, return_offsets_mapping=True, truncation=True)
  offsets = tokens['offset_mapping']
  # word_ids() yields None for special tokens such as [CLS] / [SEP].
  word_ids = tokens.word_ids()

  # -100 is the value compute_metrics filters out and the loss ignores, so
  # special tokens no longer count as trivially correct "O" predictions.
  ignore_id = -100
  outside_id = label2id["O"]
  labels = [ignore_id if word_id is None else outside_id for word_id in word_ids]

  for named_entity in named_entities:
    start_char = named_entity["start"]
    end_char = named_entity["end"]
    entity_type = named_entity['labels'][0]

    first_token = True
    for i, (start_offset, end_offset) in enumerate(offsets):
      if word_ids[i] is None:
        continue
      # Tag every token whose character span overlaps the entity span. Unlike
      # requiring the entity start/end to fall inside a token, this still works
      # when an annotation boundary lands on whitespace between tokens.
      if start_offset < end_char and end_offset > start_char:
        prefix = "B-" if first_token else "I-"
        labels[i] = label2id[prefix + entity_type]
        first_token = False

  tokens['labels'] = labels
  return tokens

In [None]:
def transform_dataset_entry(entry):
  """Turn one raw dataset record into a tokenized, labelled example.

  Reads the text from ``entry['data']['text']`` and the entity spans from the
  ``'value'`` field of every result in the record's first annotation, then
  passes both to ``tokenize_and_ner`` and returns its result.
  """
  text = entry['data']['text']
  first_annotation = entry['annotations'][0]
  spans = []
  for result in first_annotation['result']:
    spans.append(result['value'])
  return tokenize_and_ner(text, spans)

In [None]:
# Load the raw annotated records for both splits.
# encoding="utf-8" is explicit so decoding does not depend on the platform's
# locale default (JSON text is UTF-8).
with open("./NER_TRAIN_JUDGEMENT.json", 'r', encoding="utf-8") as f:
  train_data = json.load(f)

with open("./NER_DEV/NER_DEV_JUDGEMENT.json", 'r', encoding="utf-8") as f:
  valid_data = json.load(f)

In [None]:
# Group the raw record lists by split name.
dataset_dict = dict(train=train_data, valid=valid_data)

In [None]:
# Tokenize and label every record, producing one DataFrame per split.
# NOTE: this rebinds dataset_dict from raw records to DataFrames, so the cell
# cannot be re-run on its own; re-run the cell that builds the raw dict first.
dataset_dict = {
    split: pd.DataFrame(list(map(transform_dataset_entry, records)))
    for split, records in dataset_dict.items()
}

In [None]:
# Convert each split's DataFrame into a datasets.Dataset and bundle them.
dataset_dict = DatasetDict({
    split: Dataset.from_pandas(dataset_dict[split])
    for split in ('train', 'valid')
})


In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
from datasets import load_metric

import evaluate

# Entity-level sequence-labelling metric, used by compute_metrics.
seqeval = evaluate.load(path="seqeval")

In [None]:
from transformers import DataCollatorForTokenClassification

# Collator that batches the tokenized examples for token classification.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
# The size of the classification head is derived from the label map instead of
# a hard-coded 29, so it cannot drift if the tag inventory changes.
model = AutoModelForTokenClassification.from_pretrained(
    model_name, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

In [None]:
def compute_metrics(p):
    """Compute overall seqeval scores for one evaluation pass.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` holds per-class
            scores and is reduced with argmax over axis 2; ``labels`` holds
            the reference tag ids, with -100 marking positions to skip.

    Returns:
        Dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` taken
        from seqeval's ``overall_*`` results.
    """
    scores, references = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, reference_row in zip(predicted_ids, references):
        # Keep only positions whose reference label is not the ignore value.
        kept = [
            (predicted_id, reference_id)
            for predicted_id, reference_id in zip(predicted_row, reference_row)
            if reference_id != -100
        ]
        true_predictions.append([id2label[predicted_id] for predicted_id, _ in kept])
        true_labels.append([id2label[reference_id] for _, reference_id in kept])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        metric: results["overall_" + metric]
        for metric in ("precision", "recall", "f1", "accuracy")
    }

In [None]:
import numpy as np
import torch
# Release cached GPU memory left over from earlier cells before training.
torch.cuda.empty_cache()
training_args = TrainingArguments(
    # Where checkpoints are written.
    output_dir="my_awesome_law_ner_model",
    # Optimisation settings.
    num_train_epochs=10,
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=48,
    per_device_eval_batch_size=48,
    # Evaluate, log and checkpoint once per epoch; matching evaluation and save
    # schedules let the best checkpoint be restored when training finishes.
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)


In [None]:
# Wire the model, data, batching and metrics together.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["valid"],
    compute_metrics=compute_metrics,
)


In [None]:
# Run fine-tuning; the cell displays the value returned by train().
trainer.train()
