In [None]:
import torch

# Fail fast if PyTorch cannot see a CUDA-capable GPU.
assert torch.cuda.is_available()

# Report how many GPUs are attached and the name of the current one.
n_gpu = torch.cuda.device_count()
device_name = torch.cuda.get_device_name()
print(f"Found device: {device_name}, n_gpu: {n_gpu}")

# Handle to the default CUDA device.
device = torch.device("cuda")

Found device: Tesla T4, n_gpu: 1


In [None]:
!pip install torch
!pip install transformers
!pip install -U -q PyDrive
!pip install datasets
!pip install seqeval
!pip install ray[tune]
!pip install transformers accelerate
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import ast
import itertools
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from datasets import load_metric
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tokenizer of the legalBERT checkpoint; tokenize_all_labels reads this module-level `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
# NER assigns one label per token, so load the token-classification head
# (the same class the training cell uses) instead of a sequence-classification head.
model = AutoModelForTokenClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=15) # num_labels = 14 + 1


Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification wer

# Importing Datasets

In [None]:
import pandas as pd
import json
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/Colab Notebooks/

train_df_judgement = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_TRAIN_JUDGEMENT_PREPROCESSED.json')
train_df_preamble = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_TRAIN_PREAMBLE_PREPROCESSED.json')

test_df_judgement = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_DEV_JUDGEMENT_PREPROCESSED.json')
test_df_preamble = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_DEV_PREAMBLE_PREPROCESSED.json')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Colab Notebooks


In [None]:
# Row counts of the four gold files (train judgement/preamble, then dev judgement/preamble).
print(len(train_df_judgement), len(train_df_preamble))
print(len(test_df_judgement), len(test_df_preamble))
# Total number of gold rows across all four files; the split cell uses it to size the 70% train share.
total_data = sum(
    len(frame)
    for frame in (train_df_judgement, train_df_preamble, test_df_judgement, test_df_preamble)
)

8494 553
839 28


# Train : Validation : Test split (70:20:10)

In [None]:
# Concatenate the judgement + preamble training data into one frame with a fresh 0..n-1 index.
df = pd.concat([train_df_judgement, train_df_preamble], ignore_index=True)
print(total_data)

# Fixed seed so the train/validation split is identical on every run.
SPLIT_SEED = 42

# Train split: 70% of ALL gold rows (train + dev files). The rows can only be drawn from `df`,
# so the target count is converted into a fraction of `df`.
split = int(total_data * 0.7)
split_70 = split / len(df)
train_df = df.sample(frac=split_70, random_state=SPLIT_SEED)
# Validation split (approx. 20% of the total data): every row of `df` not sampled for training.
val_df = df.drop(train_df.index)
# Test split (approx. 10% of the total data): the dev judgement + preamble files.
test_df = pd.concat([test_df_judgement, test_df_preamble])

9914


In [None]:
# Sanity check: number of rows sampled into the training split.
len(train_df)

6939

In [None]:
# Sanity check: number of rows left over for the validation split.
len(val_df)

2108

In [None]:
# Wrap each pandas split in a `datasets.Dataset` so it can be tokenized with `.map`.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, val_df, test_df)
)

# Tokenization

In [None]:
# Entity label names, indexed by integer tag id (id 0 is the non-entity class).
# NOTE(review): every name starts with a space — presumably so seqeval does not read the
# first letter as an IOB prefix; confirm before stripping it.
labels_list = [
    " OTHERS", " PETITIONER", " COURT", " RESPONDENT", " JUDGE",
    " OTHER_PERSON", " LAWYER", " DATE", " ORG", " GPE",
    " STATUTE", " PROVISION", " PRECEDENT", " CASE_NUMBER", " WITNESS",
]
# Identity mapping from raw tag id to model label id.
label_encoding_dict = {tag_id: tag_id for tag_id in range(len(labels_list))}
# Tag id -> label name lookup.
label_list_encoding_dict = dict(enumerate(labels_list))

In [None]:
def tokenize_all_labels(rows):
    """Tokenize a batch of pre-split sentences and align word-level tags to the tokens.

    Args:
        rows: a batch with a "tokens" entry (one list of words per example) and a
            "ner_tags" entry (one tag per word), as passed by `Dataset.map(batched=True)`.

    Returns:
        The output of the module-level `tokenizer` with an added "labels" entry:
        one list of label ids per example, aligned to the tokenizer's tokens.
        Positions with no source word get -100; every piece of a split word
        receives that word's label.
    """
    encoded = tokenizer(list(rows["tokens"]), truncation = True, is_split_into_words = True)
    # When True, continuation pieces of a word repeat the word's label instead of -100.
    propagate_to_subwords = True
    aligned_labels = []
    for batch_pos, word_tags in enumerate(rows["ner_tags"]):
        previous_word = None
        row_labels = []
        for word_pos in encoded.word_ids(batch_index = batch_pos):
            if word_pos is None:
                # No source word behind this position.
                piece_label = -100
            elif word_tags[word_pos] == '0':
                # Tag given as the string '0' maps straight to label id 0.
                piece_label = 0
            elif word_pos != previous_word or propagate_to_subwords:
                piece_label = label_encoding_dict[word_tags[word_pos]]
            else:
                piece_label = -100
            row_labels.append(piece_label)
            previous_word = word_pos
        aligned_labels.append(row_labels)

    encoded["labels"] = aligned_labels
    return encoded

In [None]:
# Run the tokenization + label alignment over the train and validation splits, in batches.
train_dataset_tokenized, val_dataset_tokenized = (
    dataset.map(tokenize_all_labels, batched=True)
    for dataset in (train_dataset, val_dataset)
)

Map:   0%|          | 0/6939 [00:00<?, ? examples/s]

Map:   0%|          | 0/2108 [00:00<?, ? examples/s]

# Training

In [None]:
import wandb
# Log in to Weights & Biases; the Trainer below is configured with report_to='wandb'.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mastha[0m ([33m685_data_augmentation_legal_ner[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import warnings
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# hyperparameters
learning_rate = 2.910635913133073e-05
epoch = 5
batch_size = 5

# seqeval metric used by evaluate_metrics
metric = load_metric("seqeval")

# legalBERT tokenizer and a token-classification head with one output per entry of labels_list
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = AutoModelForTokenClassification.from_pretrained(
    "nlpaueb/legal-bert-base-uncased",
    num_labels=len(labels_list),
)

# Evaluate every N steps, log to W&B under the run name 'gold', and reload the
# checkpoint with the best 'accuracy' once training ends.
training_arguments = TrainingArguments(
    output_dir="eval_indian_legal_ner",
    run_name='gold',
    report_to='wandb',
    evaluation_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    num_train_epochs=epoch,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=1e-5,
)

# collator that batches the tokenized examples for token classification
data_collator = DataCollatorForTokenClassification(tokenizer)

# Highest overall accuracy seen so far, and the full seqeval results of that evaluation.
best_accuracy = 0
best_results = 0

def evaluate_metrics(pred_tuple):
    """Compute seqeval metrics for one evaluation pass and remember the best pass.

    Args:
        pred_tuple: pair `(predictions, labels)`; `predictions` holds per-token
            class scores (the class axis is axis 2) and `labels` holds the gold
            label ids, with -100 marking positions to ignore.

    Returns:
        dict with the overall "precision", "recall", "f1" and "accuracy".

    Side effects:
        When this pass has the highest overall accuracy so far, stores that
        accuracy in the global `best_accuracy` and the full results dict in the
        global `best_results`.
    """
    global best_accuracy
    global best_results
    logits, label_ids = pred_tuple
    predicted_ids = np.argmax(logits, axis=2)

    # Convert ids to label names, skipping every position whose gold label is -100.
    actual_predictions = [
        [labels_list[pred_id] for pred_id, gold_id in zip(pred_row, gold_row) if gold_id != -100]
        for pred_row, gold_row in zip(predicted_ids, label_ids)
    ]
    actual_labels = [
        [labels_list[gold_id] for gold_id in gold_row if gold_id != -100]
        for gold_row in label_ids
    ]
    results = metric.compute(predictions=actual_predictions, references=actual_labels)
    if best_accuracy < results["overall_accuracy"]:
        # Record the new best score too; without it every pass would overwrite best_results.
        best_accuracy = results["overall_accuracy"]
        best_results = results
    return {"precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]}
    
# Fine-tune on the training split, evaluating on the validation split with evaluate_metrics.
trainer = Trainer(
    model=model,
    args=training_arguments,
    compute_metrics=evaluate_metrics,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset_tokenized,
    eval_dataset=val_dataset_tokenized,
)
trainer.train()
# One more evaluation pass on the validation split after training.
trainer.evaluate()
# Save the model to disk; the demo cells reload it from this directory.
trainer.save_model('indian_legal_ner.model')
# Close the W&B run that logged the training.
wandb.finish()

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initia

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
500,0.3847,0.19002,0.686918,0.786138,0.733186,0.939663
1000,0.173,0.145514,0.781769,0.831832,0.806024,0.955173
1500,0.1408,0.142305,0.792157,0.849846,0.819988,0.958632
2000,0.1,0.135533,0.811741,0.852043,0.831404,0.962142
2500,0.0943,0.124521,0.81908,0.864785,0.841312,0.963207
3000,0.0772,0.143217,0.832803,0.862808,0.84754,0.965013
3500,0.0575,0.137559,0.83636,0.865114,0.850494,0.965956
4000,0.0563,0.130354,0.825854,0.863137,0.844084,0.966524
4500,0.0383,0.146624,0.831652,0.869288,0.850054,0.966514
5000,0.0333,0.150912,0.835692,0.869288,0.852159,0.967224


0,1
eval/accuracy,▁▅▆▇▇▇▇███████
eval/f1,▁▅▆▆▇▇█▇██████
eval/loss,█▃▃▂▁▃▂▂▃▄▃▄▅▃
eval/precision,▁▅▆▇▇▇█▇▇███▇█
eval/recall,▁▅▆▆▇▇▇▇██████
eval/runtime,▅▁▁█▃▁▇▃▄▃▅▁▃▂
eval/samples_per_second,▃██▁▆█▂▆▅▆▄█▆▇
eval/steps_per_second,▃██▁▆█▂▆▅▆▄█▆▇
train/epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇████
train/global_step,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇████

0,1
eval/accuracy,0.96822
eval/f1,0.85884
eval/loss,0.14412
eval/precision,0.84563
eval/recall,0.87247
eval/runtime,15.2576
eval/samples_per_second,138.161
eval/steps_per_second,27.658
train/epoch,5.0
train/global_step,6940.0


In [None]:
# Display the per-entity seqeval results that evaluate_metrics stored in the global `best_results`.
best_results

{'CASE_NUMBER': {'precision': 0.73,
  'recall': 0.7891891891891892,
  'f1': 0.7584415584415585,
  'number': 185},
 'COURT': {'precision': 0.8422818791946308,
  'recall': 0.8807017543859649,
  'f1': 0.8610634648370498,
  'number': 285},
 'DATE': {'precision': 0.9522546419098143,
  'recall': 0.967654986522911,
  'f1': 0.9598930481283422,
  'number': 371},
 'GPE': {'precision': 0.7611940298507462,
  'recall': 0.7183098591549296,
  'f1': 0.7391304347826085,
  'number': 284},
 'JUDGE': {'precision': 0.8165137614678899,
  'recall': 0.89,
  'f1': 0.8516746411483254,
  'number': 100},
 'LAWYER': {'precision': 0.8823529411764706,
  'recall': 0.8823529411764706,
  'f1': 0.8823529411764706,
  'number': 17},
 'ORG': {'precision': 0.631578947368421,
  'recall': 0.6832740213523132,
  'f1': 0.6564102564102564,
  'number': 281},
 'OTHERS': {'precision': 0.8685096579833422,
  'recall': 0.8880231926073564,
  'f1': 0.8781580361942304,
  'number': 5519},
 'OTHER_PERSON': {'precision': 0.7995444191343963,


# Evaluating Test Dataset

In [None]:
# Start a new W&B run named "gold-test" for the gold test-set evaluation below.
wandb.init(name = "gold-test")

In [None]:
# Tokenize the gold test split with the same function as the training data, then score it.
test_dataset_tokenized = test_dataset.map(function=tokenize_all_labels, batched=True)
test_results = trainer.evaluate(eval_dataset=test_dataset_tokenized)
test_results

Map:   0%|          | 0/867 [00:00<?, ? examples/s]

{'eval_loss': 0.13911709189414978,
 'eval_precision': 0.8537941034316094,
 'eval_recall': 0.882808595702149,
 'eval_f1': 0.8680589680589681,
 'eval_accuracy': 0.9711757155403428,
 'eval_runtime': 8.5001,
 'eval_samples_per_second': 101.999,
 'eval_steps_per_second': 20.47,
 'epoch': 5.0}

In [None]:
# Close the "gold-test" W&B run.
wandb.finish()

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.97118
eval/f1,0.86806
eval/loss,0.13912
eval/precision,0.85379
eval/recall,0.88281
eval/runtime,8.5001
eval/samples_per_second,101.999
eval/steps_per_second,20.47
train/epoch,5.0
train/global_step,6940.0


# Evaluating legalBART Samples

In [None]:
# Start a new W&B run named "legalbart-test" for the legalBART evaluation below.
wandb.init(name = "legalbart-test")

In [None]:
# loading legalBART samples
legalBARTSamples = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/legal_BART_Samples.csv')
# The list columns come back from the CSV as strings. ast.literal_eval parses Python
# literals only, so unlike eval() it cannot execute code embedded in the file.
legalBARTSamples['tokens'] = legalBARTSamples['tokens'].map(ast.literal_eval)
legalBARTSamples['ner_tags'] = legalBARTSamples['ner_tags'].map(ast.literal_eval)

In [None]:
# Display the parsed legalBART samples (pandas truncates the rendering of long frames).
legalBARTSamples

Unnamed: 0,tokens,ner_tags
0,"[On, specific, query, by, the, Bench, ,, it, w...","[0, 0, 0, 0, 5, 1, 5, 7, 6, 0, 0, 0, 0, 5, 0, ..."
1,"[According, to, the, Agya, ,, span, class, '',...","[0, 5, 5, 3, 5, 0, 2, 0, 0, 0, 0, 0, 0, 0, 7, ..."
2,"[PW3, Vijay, Mishra, ,, Deputy, Manager, ,, HD...","[0, 7, 7, 5, 2, 0, 5, 0, 2, 5, 5, 5, 0, 0, 0, ..."
3,"[He, was, asked, to, come, and, carry, out, th...","[0, 6, 0, 5, 0, 5, 0, 0, 5, 1, 0, 0, 0, 0, 6, ..."
4,"[The, pillion, rider, ,, Satyanarayana, Murthy...","[5, 0, 0, 5, 3, 3, 6, 0, 0, 5, 5, 0, 5, 5, 0, ..."
...,...,...
8389,"[The, Honble, Sri, Justice, Nooty, Ramamohana,...","[2, 0, 0, 2, 2, 2, 2, 0, 0, 0, 1, 0, 0, 0, 1, ..."
8390,"[Petitioner, :, Indira, Nehru, Gandhi, Respond...","[0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, ..."
8391,"[In, The, Court, Of, Sh, ., R.K, ., Sharma, ,,...","[0, 2, 2, 1, 0, 0, 2, 2, 1, 0, 0, 0, 2, 0, 0, ..."
8392,"[Petitioner, :, Bhagwant, Singh, Vs, ., Respon...","[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, ..."


In [None]:
# Concatenating gold test samples + legalBART samples.
# NOTE(review): this rebinds `legalBARTSamples`, so re-running the cell alone adds the gold test rows again.
legalBARTSamples = pd.concat(objs=[test_df, legalBARTSamples])

# Same pipeline as the gold test set: wrap as a Dataset, tokenize, evaluate with the trained model.
legalBARTSamplesDataset = Dataset.from_pandas(df=legalBARTSamples)
legalBARTSamples_tokenized = legalBARTSamplesDataset.map(function=tokenize_all_labels, batched=True)
legalBARTSamples_results = trainer.evaluate(eval_dataset=legalBARTSamples_tokenized)
legalBARTSamples_results

Map:   0%|          | 0/9261 [00:00<?, ? examples/s]

{'eval_loss': 5.691403388977051,
 'eval_precision': 0.06644750821844181,
 'eval_recall': 0.014915615661957605,
 'eval_f1': 0.024362523143480412,
 'eval_accuracy': 0.4084948905561926,
 'eval_runtime': 90.5602,
 'eval_samples_per_second': 102.263,
 'eval_steps_per_second': 20.462,
 'epoch': 5.0}

In [None]:
# Close the "legalbart-test" W&B run.
wandb.finish()

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.40849
eval/f1,0.02436
eval/loss,5.6914
eval/precision,0.06645
eval/recall,0.01492
eval/runtime,90.5602
eval/samples_per_second,102.263
eval/steps_per_second,20.462
train/epoch,5.0
train/global_step,6940.0


# Evaluating DAGA Samples

In [None]:
# Start a new W&B run named "daga-test" for the DAGA evaluation below.
wandb.init(name = "daga-test")

In [None]:
# loading DAGA samples
DAGASamples = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/daga_samples_postprocessed.csv')

# The list columns come back from the CSV as strings. ast.literal_eval parses Python
# literals only, so unlike eval() it cannot execute code embedded in the file.
DAGASamples['tokens'] = DAGASamples['tokens'].map(ast.literal_eval)
DAGASamples['ner_tags'] = DAGASamples['ner_tags'].map(ast.literal_eval)

In [None]:
# Concatenating gold test samples + DAGA samples.
# NOTE(review): this rebinds `DAGASamples`, so re-running the cell alone adds the gold test rows again.
DAGASamples = pd.concat(objs=[test_df, DAGASamples])

# Same pipeline as the gold test set: wrap as a Dataset, tokenize, evaluate with the trained model.
DAGASamplesDataset = Dataset.from_pandas(df=DAGASamples)
DAGASamples_tokenized = DAGASamplesDataset.map(function=tokenize_all_labels, batched=True)
DAGASamples_results = trainer.evaluate(eval_dataset=DAGASamples_tokenized)
DAGASamples_results

Map:   0%|          | 0/13332 [00:00<?, ? examples/s]

{'eval_loss': 1.0615456104278564,
 'eval_precision': 0.6063009001285898,
 'eval_recall': 0.5157311052031903,
 'eval_f1': 0.5573606311106003,
 'eval_accuracy': 0.860454895662648,
 'eval_runtime': 95.4831,
 'eval_samples_per_second': 139.627,
 'eval_steps_per_second': 27.932,
 'epoch': 5.0}

In [None]:
# Close the "daga-test" W&B run.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.536906…

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.86045
eval/f1,0.55736
eval/loss,1.06155
eval/precision,0.6063
eval/recall,0.51573
eval/runtime,95.4831
eval/samples_per_second,139.627
eval/steps_per_second,27.932
train/epoch,5.0
train/global_step,6940.0


# Evaluating Synonym/Mention Replacement Samples

In [None]:
# Start a new W&B run named "synmen-test" for the synonym/mention evaluation below.
wandb.init(name = "synmen-test")

In [None]:
# loading synonym mention samples
synmenSamples_train = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/train_postprocessed.csv')
synmenSamples_test = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/test_postprocessed.csv')
synmenSamples_dev = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/dev_postprocessed.csv')

# Stack train + dev, then test, into one frame with a fresh 0..n-1 index.
# pd.concat replaces DataFrame.append, which no longer exists in pandas 2.x.
synmenSamples_train_val = pd.concat([synmenSamples_train, synmenSamples_dev], ignore_index=True)
synmenSamples = pd.concat([synmenSamples_train_val, synmenSamples_test], ignore_index=True)

# The list columns come back from the CSV as strings. ast.literal_eval parses Python
# literals only, so unlike eval() it cannot execute code embedded in the file.
synmenSamples['tokens'] = synmenSamples['tokens'].map(ast.literal_eval)
synmenSamples['ner_tags'] = synmenSamples['ner_tags'].map(ast.literal_eval)

In [None]:
# Display the combined synonym/mention samples (pandas truncates the rendering of long frames).
synmenSamples

Unnamed: 0,tokens,ner_tags
0,"[that, extent.]","[0, 0]"
1,"[The, jurisdiction, granted, to, the, Tribunal...","[0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 0, 0, 0, ..."
2,"[In, the, said, case,, Kailasm, ,, J.,, while,...","[0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"[In, Bhinka, v., Charan, Singh,, (AIR, 1959, S...","[0, 12, 12, 12, 12, 12, 12, 12, 12, 0, 0, 0, 0..."
4,"[Thereafter,, he, shouted, and, his, aunt, Gau...","[0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 5, 0, 0, ..."
...,...,...
11111,"[Carbp-196, of, 2016.odt, rrpillai, In, The, H...","[0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, ..."
11112,"[dss, 1, Judgment-wp-10835-18-g.doc, In, The, ...","[0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, ..."
11113,"[In, The, High, Court, Of, Kerala, At, Ernakul...","[0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 4, 4, ..."
11114,"[High, Court, Of, Judicature, At, Allahabad,, ...","[2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, ..."


In [None]:
# Concatenating gold test samples + synonym/mention samples.
# NOTE(review): this rebinds `synmenSamples`, so re-running the cell alone adds the gold test rows again.
synmenSamples = pd.concat(objs=[test_df, synmenSamples])

# Same pipeline as the gold test set: wrap as a Dataset, tokenize, evaluate with the trained model.
synmenSamplesDataset = Dataset.from_pandas(df=synmenSamples)
synmenSamples_tokenized = synmenSamplesDataset.map(function=tokenize_all_labels, batched=True)
synmenSamples_results = trainer.evaluate(eval_dataset=synmenSamples_tokenized)
synmenSamples_results

Map:   0%|          | 0/11983 [00:00<?, ? examples/s]

{'eval_loss': 0.16805464029312134,
 'eval_precision': 0.825912219967371,
 'eval_recall': 0.7981746631971595,
 'eval_f1': 0.8118065786214358,
 'eval_accuracy': 0.9494854443721775,
 'eval_runtime': 141.3006,
 'eval_samples_per_second': 84.805,
 'eval_steps_per_second': 16.964,
 'epoch': 5.0}

In [None]:
# Close the "synmen-test" W&B run.
wandb.finish()

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.94949
eval/f1,0.81181
eval/loss,0.16805
eval/precision,0.82591
eval/recall,0.79817
eval/runtime,141.3006
eval/samples_per_second,84.805
eval/steps_per_second,16.964
train/epoch,5.0
train/global_step,6940.0


# Extras: Demonstrating Results on an Indian Legal Example

In [None]:
# Reload the fine-tuned model and its tokenizer from the directory written by trainer.save_model.
model = AutoModelForTokenClassification.from_pretrained(
    './indian_legal_ner.model/',
    num_labels=len(labels_list),
)
tokenizer = AutoTokenizer.from_pretrained('./indian_legal_ner.model/')
# Show the module structure of the loaded model.
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [None]:
# Example sentence to tag with the fine-tuned model.
sentence = '''Further, section 19(3) of the act specifies that unless a different intention appears, rules contained in section 20 to 24 rules for ascertaining the intention of the parties as to the time at which the property and goods is to passed to the buyer.'''
tokenizer = AutoTokenizer.from_pretrained('./indian_legal_ner.model/')
# Tokenize once and read both fields from the same encoding.
encoding = tokenizer(sentence)
input_tensor = encoding['input_ids']
attention_tensor = encoding['attention_mask']

model = AutoModelForTokenClassification.from_pretrained('./indian_legal_ner.model/', num_labels=len(labels_list))

# Inference only: no autograd graph is needed. Calling the module (not .forward) lets hooks run.
with torch.no_grad():
    predictions = model(input_ids=torch.tensor(input_tensor).unsqueeze(0), attention_mask=torch.tensor(attention_tensor).unsqueeze(0))
# squeeze(0) removes only the batch axis added above, then take the highest-scoring label per token.
predictions = torch.argmax(predictions.logits.squeeze(0), dim=1)
entities = [labels_list[i] for i in predictions]

# One decoded string per input id, so tokens line up with the predictions.
tokens = tokenizer.batch_decode(input_tensor)
for tag, entity, token in zip(predictions, entities, tokens):
  print(f"{token:<12}{tag:<12}{entity}")

[CLS]       0            OTHERS
further     0            OTHERS
,           0            OTHERS
section     11           PROVISION
19          11           PROVISION
(           11           PROVISION
3           11           PROVISION
)           11           PROVISION
of          0            OTHERS
the         0            OTHERS
act         0            OTHERS
specifie    0            OTHERS
##s         0            OTHERS
that        0            OTHERS
unless      0            OTHERS
a           0            OTHERS
different   0            OTHERS
intention   0            OTHERS
appears     0            OTHERS
,           0            OTHERS
rules       0            OTHERS
contained   0            OTHERS
in          0            OTHERS
section     11           PROVISION
20          11           PROVISION
to          11           PROVISION
24          11           PROVISION
rules       0            OTHERS
for         0            OTHERS
ascertain   0            OTHERS
##ing       0