In [1]:
import torch

# Confirm that the GPU is detected

assert torch.cuda.is_available()

# Get the GPU device name.

device_name = torch.cuda.get_device_name()
n_gpu = torch.cuda.device_count()
print(f"Found device: {device_name}, n_gpu: {n_gpu}")
device = torch.device("cuda")

Found device: Tesla T4, n_gpu: 1


In [2]:
!pip install torch
!pip install transformers
!pip install -U -q PyDrive
!pip install datasets
!pip install seqeval
!pip install ray[tune]
!pip install transformers accelerate
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import os
import itertools
import pandas as pd
import numpy as np
from datasets import Dataset
from datasets import load_metric
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=15) # num_labels = 14 + 1


Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification wer

# Importing Datasets

In [4]:
import pandas as pd
import json
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/Colab Notebooks/

train_df_judgement = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_TRAIN_JUDGEMENT_PREPROCESSED.json')
train_df_preamble = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_TRAIN_PREAMBLE_PREPROCESSED.json')

test_df_judgement = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_DEV_JUDGEMENT_PREPROCESSED.json')
test_df_preamble = pd.read_json('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/NER_DEV_PREAMBLE_PREPROCESSED.json')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Colab Notebooks


In [5]:
print(len(train_df_judgement), len(train_df_preamble))
print(len(test_df_judgement), len(test_df_preamble))
total_data = len(train_df_judgement) + len(train_df_preamble) + len(test_df_judgement) + len(test_df_preamble)

8494 553
839 28


# Train : Validation : Test split (70:20:10)

In [6]:
# concatenating the judgement + preamble data
df = pd.concat([train_df_judgement, train_df_preamble])
df.reset_index(inplace=True, drop=True)
# print(df)
print(total_data)
# splitting train (70% of the total data)
split=int(total_data*0.7)
print(split)
train_df = df[:split]
# splitting validation (appx 20% of the total data)
val_df = df[split:]
# test data (appx 10% of the total data)
test_df = pd.concat([test_df_judgement, test_df_preamble])

9914
6939


In [7]:
len(train_df)

6939

In [8]:
len(val_df)

2108

In [9]:
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Tokenization

In [10]:
labels_list = ["OTHERS", "PETITIONER", "COURT", "RESPONDENT", "JUDGE", "OTHER_PERSON", "LAWYER", "DATE", "ORG", "GPE", "STATUTE", "PROVISION", "PRECEDENT", "CASE_NUMBER", "WITNESS"]
label_encoding_dict = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14}
label_list_encoding_dict = {0: "OTHERS", 1: "PETITIONER", 2: "COURT", 3: "RESPONDENT", 4: "JUDGE", 5: "OTHER_PERSON", 6: "LAWYER", 7: "DATE", 8: "ORG", 9: "GPE", 10: "STATUTE", 11: "PROVISION", 12: "PRECEDENT", 13: "CASE_NUMBER", 14: "WITNESS"}

In [11]:
def tokenize_all_labels(rows):
    tokenized_inputs = tokenizer(list(rows["tokens"]), truncation = True, is_split_into_words = True)
    labels, label_all = [], True
    for index, label in enumerate(rows["ner_tags"]):
        # print(i, label)
        prior_idx = None
        word_ids = tokenized_inputs.word_ids(batch_index = index)
        
        label_ids = []
        for current_idx in word_ids:
            if current_idx is None: label_ids.append(-100)
            elif label[current_idx] == '0': label_ids.append(0)
            elif current_idx != prior_idx: label_ids.append(label_encoding_dict[label[current_idx]])
            else: label_ids.append(label_encoding_dict[label[current_idx]] if label_all else -100)
            prior_idx = current_idx
        labels.append(label_ids)
        
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [12]:
train_dataset_tokenized = train_dataset.map(tokenize_all_labels, batched=True)
val_dataset_tokenized = val_dataset.map(tokenize_all_labels, batched=True)

Map:   0%|          | 0/6939 [00:00<?, ? examples/s]

Map:   0%|          | 0/2108 [00:00<?, ? examples/s]

# Training

In [13]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [14]:
import warnings
warnings.filterwarnings('ignore')

# hyperparameters
batch_size = 5
epoch = 5
learning_rate = 2.910635913133073e-05

# tokenizing and loading legalBERT model 
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = AutoModelForTokenClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=len(labels_list))
metric = load_metric("seqeval")

training_arguments = TrainingArguments(
    "eval_indian_legal_ner",
    report_to='wandb',
    evaluation_strategy = "steps",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epoch,
    weight_decay=1e-5,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    run_name='gold'  
)

# training on train set and evaluating on validation set
data_collator = DataCollatorForTokenClassification(tokenizer)

def evaluate_metrics(pred_tuple):
    predictions, labels = pred_tuple
    predictions = np.argmax(predictions, axis=2)

    actual_predictions = [[labels_list[pred_tuple] for (pred_tuple, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)]
    actual_labels = [[labels_list[l] for (pred_tuple, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels)]
    results = metric.compute(predictions=actual_predictions, references=actual_labels)
    return {"precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]}
    
trainer = Trainer(
    model,
    training_arguments,
    train_dataset = train_dataset_tokenized,
    eval_dataset = val_dataset_tokenized,
    data_collator = data_collator,
    tokenizer=tokenizer,
    compute_metrics=evaluate_metrics
)
trainer.train()
trainer.evaluate()
trainer.save_model('indian_legal_ner.model')
wandb.finish()

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initia

Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

[34m[1mwandb[0m: Currently logged in as: [33mastha[0m ([33m685_data_augmentation_legal_ner[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
500,0.344,0.569515,0.440463,0.533378,0.482488,0.796461
1000,0.1649,0.457239,0.461686,0.64219,0.53718,0.820823
1500,0.1268,0.702162,0.435174,0.61393,0.509323,0.768313
2000,0.0918,0.583336,0.482961,0.659101,0.557448,0.785091
2500,0.0942,0.545131,0.486081,0.707165,0.576142,0.793985
3000,0.0793,0.454535,0.538707,0.713841,0.61403,0.828136
3500,0.0592,0.512078,0.520792,0.70227,0.598067,0.817577
4000,0.0527,0.565799,0.507934,0.698042,0.588004,0.807566
4500,0.0358,0.678364,0.489626,0.666889,0.564673,0.787559
5000,0.0343,0.619108,0.508846,0.697597,0.588456,0.80138


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▄▇▁▃▄█▇▆▃▅▆▆▆█
eval/f1,▁▄▂▅▆█▇▇▅▇▇█▇█
eval/loss,▄▁█▅▄▁▃▄▇▆▅▆▆▁
eval/precision,▁▃▁▄▄█▇▆▅▆▇▇▇█
eval/recall,▁▅▄▆███▇▆▇████
eval/runtime,▃▁▇▂▃▂▂▂▅▃█▅▇▃
eval/samples_per_second,▅█▂▇▆▆▇▇▄▅▁▄▂▆
eval/steps_per_second,▅█▂▇▆▆▇▇▄▅▁▄▂▆
train/epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇████
train/global_step,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇████

0,1
eval/accuracy,0.82814
eval/f1,0.61403
eval/loss,0.45453
eval/precision,0.53871
eval/recall,0.71384
eval/runtime,19.4791
eval/samples_per_second,108.219
eval/steps_per_second,21.664
train/epoch,5.0
train/global_step,6940.0


# Evaluating Test Dataset

In [15]:
wandb.init(name = "gold-test")

In [16]:
test_dataset_tokenized = test_dataset.map(tokenize_all_labels, batched=True)
test_results = trainer.evaluate(test_dataset_tokenized)
test_results

Map:   0%|          | 0/867 [00:00<?, ? examples/s]

{'eval_loss': 0.21242676675319672,
 'eval_precision': 0.7685279187817259,
 'eval_recall': 0.8701149425287357,
 'eval_f1': 0.8161725067385445,
 'eval_accuracy': 0.921025964374986,
 'eval_runtime': 6.509,
 'eval_samples_per_second': 133.2,
 'eval_steps_per_second': 26.732,
 'epoch': 5.0}

In [17]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.117229…

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.92103
eval/f1,0.81617
eval/loss,0.21243
eval/precision,0.76853
eval/recall,0.87011
eval/runtime,6.509
eval/samples_per_second,133.2
eval/steps_per_second,26.732
train/epoch,5.0
train/global_step,6940.0


# Evaluating legalBART Samples

In [18]:
wandb.init(name = "legalbart-test")

In [19]:
# loading legalBART samples
legalBARTSamples = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/legal_BART_Samples.csv')
legalBARTSamples['tokens'] = legalBARTSamples['tokens'].map(lambda x: eval(x))
legalBARTSamples['ner_tags'] = legalBARTSamples['ner_tags'].map(lambda x: eval(x))

In [20]:
legalBARTSamples

Unnamed: 0,tokens,ner_tags
0,"[On, specific, query, by, the, Bench, ,, it, w...","[0, 0, 0, 0, 5, 1, 5, 7, 6, 0, 0, 0, 0, 5, 0, ..."
1,"[According, to, the, Agya, ,, span, class, '',...","[0, 5, 5, 3, 5, 0, 2, 0, 0, 0, 0, 0, 0, 0, 7, ..."
2,"[PW3, Vijay, Mishra, ,, Deputy, Manager, ,, HD...","[0, 7, 7, 5, 2, 0, 5, 0, 2, 5, 5, 5, 0, 0, 0, ..."
3,"[He, was, asked, to, come, and, carry, out, th...","[0, 6, 0, 5, 0, 5, 0, 0, 5, 1, 0, 0, 0, 0, 6, ..."
4,"[The, pillion, rider, ,, Satyanarayana, Murthy...","[5, 0, 0, 5, 3, 3, 6, 0, 0, 5, 5, 0, 5, 5, 0, ..."
...,...,...
8389,"[The, Honble, Sri, Justice, Nooty, Ramamohana,...","[2, 0, 0, 2, 2, 2, 2, 0, 0, 0, 1, 0, 0, 0, 1, ..."
8390,"[Petitioner, :, Indira, Nehru, Gandhi, Respond...","[0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, ..."
8391,"[In, The, Court, Of, Sh, ., R.K, ., Sharma, ,,...","[0, 2, 2, 1, 0, 0, 2, 2, 1, 0, 0, 0, 2, 0, 0, ..."
8392,"[Petitioner, :, Bhagwant, Singh, Vs, ., Respon...","[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, ..."


In [21]:
# concatinating gold test samples + legalBART samples
legalBARTSamples = pd.concat([test_df, legalBARTSamples])

legalBARTSamplesDataset = Dataset.from_pandas(legalBARTSamples)
legalBARTSamples_tokenized = legalBARTSamplesDataset.map(tokenize_all_labels, batched=True)
legalBARTSamples_results = trainer.evaluate(legalBARTSamples_tokenized)
legalBARTSamples_results

Map:   0%|          | 0/9261 [00:00<?, ? examples/s]

{'eval_loss': 5.416061878204346,
 'eval_precision': 0.04532750030380362,
 'eval_recall': 0.01564466068282862,
 'eval_f1': 0.02326088990053319,
 'eval_accuracy': 0.398821179327813,
 'eval_runtime': 93.3935,
 'eval_samples_per_second': 99.161,
 'eval_steps_per_second': 19.841,
 'epoch': 5.0}

In [22]:
wandb.finish()

VBox(children=(Label(value='0.010 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.39882
eval/f1,0.02326
eval/loss,5.41606
eval/precision,0.04533
eval/recall,0.01564
eval/runtime,93.3935
eval/samples_per_second,99.161
eval/steps_per_second,19.841
train/epoch,5.0
train/global_step,6940.0


# Evaluating DAGA Samples

In [23]:
wandb.init(name = "daga-test")

In [24]:
# loading DAGA samples
DAGASamples = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/daga_samples_postprocessed.csv')

DAGASamples['tokens'] = DAGASamples['tokens'].map(lambda x: eval(x))
DAGASamples['ner_tags'] = DAGASamples['ner_tags'].map(lambda x: eval(x))

In [25]:
# concatinating gold test samples + DAGA samples
DAGASamples = pd.concat([test_df, DAGASamples])

DAGASamplesDataset = Dataset.from_pandas(DAGASamples)
DAGASamples_tokenized = DAGASamplesDataset.map(tokenize_all_labels, batched=True)
DAGASamples_results = trainer.evaluate(DAGASamples_tokenized)
DAGASamples_results

Map:   0%|          | 0/13332 [00:00<?, ? examples/s]

{'eval_loss': 1.090808629989624,
 'eval_precision': 0.6811816094024407,
 'eval_recall': 0.5363233469243042,
 'eval_f1': 0.6001348885186066,
 'eval_accuracy': 0.8335702581576292,
 'eval_runtime': 91.4461,
 'eval_samples_per_second': 145.791,
 'eval_steps_per_second': 29.165,
 'epoch': 5.0}

In [26]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.117087…

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.83357
eval/f1,0.60013
eval/loss,1.09081
eval/precision,0.68118
eval/recall,0.53632
eval/runtime,91.4461
eval/samples_per_second,145.791
eval/steps_per_second,29.165
train/epoch,5.0
train/global_step,6940.0


# Evaluating MulDA Samples

In [27]:
wandb.init(name = "mulda-test")

In [28]:
# loading MulDa samples
MulDaSamples_train = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/train_postprocessed.csv')
MulDaSamples_test = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/test_postprocessed.csv')
MulDaSamples_dev = pd.read_csv('Data_Augmentation_for_Low_Resource_Indian_Legal_NER/dev_postprocessed.csv')

MulDaSamples_train_val = MulDaSamples_train.append(MulDaSamples_dev, ignore_index=True)
MulDaSamples = MulDaSamples_train_val.append(MulDaSamples_test, ignore_index=True)

MulDaSamples['tokens'] = MulDaSamples['tokens'].map(lambda x: eval(x))
MulDaSamples['ner_tags'] = MulDaSamples['ner_tags'].map(lambda x: eval(x))

In [29]:
MulDaSamples

Unnamed: 0,tokens,ner_tags
0,"[that, extent.]","[0, 0]"
1,"[The, jurisdiction, granted, to, the, Tribunal...","[0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 0, 0, 0, ..."
2,"[In, the, said, case,, Kailasm, ,, J.,, while,...","[0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"[In, Bhinka, v., Charan, Singh,, (AIR, 1959, S...","[0, 12, 12, 12, 12, 12, 12, 12, 12, 0, 0, 0, 0..."
4,"[Thereafter,, he, shouted, and, his, aunt, Gau...","[0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 5, 0, 0, ..."
...,...,...
11111,"[Carbp-196, of, 2016.odt, rrpillai, In, The, H...","[0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, ..."
11112,"[dss, 1, Judgment-wp-10835-18-g.doc, In, The, ...","[0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, ..."
11113,"[In, The, High, Court, Of, Kerala, At, Ernakul...","[0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 4, 4, ..."
11114,"[High, Court, Of, Judicature, At, Allahabad,, ...","[2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, ..."


In [30]:
# concatinating gold test samples + MulDa samples
MulDaSamples = pd.concat([test_df, MulDaSamples])

MulDaSamplesDataset = Dataset.from_pandas(MulDaSamples)
MulDaSamples_tokenized = MulDaSamplesDataset.map(tokenize_all_labels, batched=True)
MulDaSamples_results = trainer.evaluate(MulDaSamples_tokenized)
MulDaSamples_results

Map:   0%|          | 0/11983 [00:00<?, ? examples/s]

{'eval_loss': 0.3808448016643524,
 'eval_precision': 0.5797254996388154,
 'eval_recall': 0.6697265570669559,
 'eval_f1': 0.6214845312923502,
 'eval_accuracy': 0.8153815461116569,
 'eval_runtime': 140.8327,
 'eval_samples_per_second': 85.087,
 'eval_steps_per_second': 17.02,
 'epoch': 5.0}

In [31]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.116978…

0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁
train/global_step,▁

0,1
eval/accuracy,0.81538
eval/f1,0.62148
eval/loss,0.38084
eval/precision,0.57973
eval/recall,0.66973
eval/runtime,140.8327
eval/samples_per_second,85.087
eval/steps_per_second,17.02
train/epoch,5.0
train/global_step,6940.0


# Extras: Demonstrating Results on an Indian Legal Example

In [32]:
# importing the trained model from the checkpoint
tokenizer = AutoTokenizer.from_pretrained('./indian_legal_ner.model/')
model = AutoModelForTokenClassification.from_pretrained('./indian_legal_ner.model/', num_labels=len(labels_list))
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [33]:
sentence = '''Further, section 19(3) of the act specifies that unless a different intention appears, rules contained in section 20 to 24 rules for ascertaining the intention of the parties as to the time at which the property and goods is to passed to the buyer.'''
tokenizer = AutoTokenizer.from_pretrained('./indian_legal_ner.model/')
input_tensor = tokenizer(sentence)['input_ids']
attention_tensor = tokenizer(sentence)['attention_mask']

model = AutoModelForTokenClassification.from_pretrained('./indian_legal_ner.model/', num_labels=len(labels_list))

predictions = model.forward(input_ids=torch.tensor(input_tensor).unsqueeze(0), attention_mask=torch.tensor(attention_tensor).unsqueeze(0))
predictions = torch.argmax(predictions.logits.squeeze(), axis=1)
entities = [labels_list[i] for i in predictions]

tokens = tokenizer.batch_decode(input_tensor)
for tag, entity, token in zip(predictions, entities, tokens):
  print(f"{token:<12}{tag:<12}{entity}")

[CLS]       0           OTHERS
further     0           OTHERS
,           0           OTHERS
section     11          PROVISION
19          11          PROVISION
(           11          PROVISION
3           11          PROVISION
)           11          PROVISION
of          0           OTHERS
the         0           OTHERS
act         0           OTHERS
specifie    0           OTHERS
##s         0           OTHERS
that        0           OTHERS
unless      0           OTHERS
a           0           OTHERS
different   0           OTHERS
intention   0           OTHERS
appears     0           OTHERS
,           0           OTHERS
rules       0           OTHERS
contained   0           OTHERS
in          0           OTHERS
section     11          PROVISION
20          11          PROVISION
to          11          PROVISION
24          11          PROVISION
rules       0           OTHERS
for         0           OTHERS
ascertain   0           OTHERS
##ing       0           OTHERS
the         