This notebook was made to compare different approaches performance on the classification task.

## Git Stuff


In [2]:
! git clone https://github.com/aassegai/LegalEval23
import os
os.chdir('LegalEval23')

fatal: destination path 'LegalEval23' already exists and is not an empty directory.


## imports

In [4]:
'''
Sometimes when GPU runtime is used an error appears during spacy installation
NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968
This piece of code fixes this problem
'''
import locale
if str(locale.getpreferredencoding()) != 'UTF-8':
    def getpreferredencoding(do_setlocale = True):
      return "UTF-8"
    locale.getpreferredencoding = getpreferredencoding

In [5]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import torch 
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm 
import pickle
import seaborn as sns 
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, NLLLoss
from torch.optim import Adam
from tqdm.auto import tqdm
from torch import nn

import requests

! pip install nltk
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.tokenize import word_tokenize 
from nltk.stem import WordNetLemmatizer

from string import punctuation
import re

! pip install transformers==4.28.0 
! pip install evaluate
! pip install accelerate transformers==4.28.0
from transformers import AutoModel, AutoTokenizer, BertModel


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.28.0
  Downloading transformers-4.28.0-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers==4.28.0)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.28.0)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transform

In [6]:
legal_bert_name = "nlpaueb/legal-bert-base-uncased"
indian_legal_uncased_bert_name = 'law-ai/InLegalBERT'

In [7]:
USED_MODEL_NAME = indian_legal_uncased_bert_name

## data preparation

In [8]:
label2id = {'PREAMBLE': 1,
            'FAC': 2,
            'RLC': 3,
            'ISSUE': 4,
            'ARG_PETITIONER': 5,
            'ARG_RESPONDENT': 6,
            'ANALYSIS': 7,
            'STA': 8,
            'PRE_RELIED': 9,
            'PRE_NOT_RELIED': 10,
            'RATIO': 11,
            'RPC': 12,
            'NONE': 0
}

id2label = {1: 'PREAMBLE',
            2: 'FAC',
            3: 'RLC',
            4: 'ISSUE',
            5: 'ARG_PETITIONER',
            6: 'ARG_RESPONDENT',
            7: 'ANALYSIS',
            8: 'STA',
            9: 'PRE_RELIED',
            10: 'PRE_NOT_RELIED',
            11: 'RATIO',
            12: 'RPC',
            0: 'NONE'
}

num_labels = 13

In [9]:
train_df, val_df = train_test_split(pd.read_json('./data/raw/train.json'), train_size=0.8, shuffle=True, random_state=17)
test_df = pd.read_json('./data/raw/dev.json')

In [10]:
from src.preprocessing.data_preprocessing import DataPreprocessor
data_preprocessor = DataPreprocessor(lower=True)
train_df = data_preprocessor(train_df)
val_df = data_preprocessor(val_df)
test_df = data_preprocessor(test_df)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


  0%|          | 0/197 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [11]:
train_df

Unnamed: 0_level_0,annotations,data,meta
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
4157,"[{'id': 'b7ed1b2b1b7040908d373c6bb3733689', 's...",non-reportable in the supreme court of india c...,{'group': 'Criminal'}
1723,"[{'id': '0701a8f05d5649bba6393dac66b6de47', 's...",petitioner: bakul cashew co. & ors. vs. respon...,{'group': 'Tax'}
4257,"[{'id': 'a435d470ca2f454dbc6c096ca9a1cce7', 's...","1/11 in the high court of karnataka, bengaluru...",{'group': 'Tax'}
4278,"[{'id': '3ab5646f1f2a4dcd9336fcc5572d43a7', 's...","1/22 in the high court of karnataka, bengaluru...",{'group': 'Tax'}
4175,"[{'id': '6b7999cc3e2f42dc9805036f31d46fa9', 's...",in the high court of karnataka at bangalore da...,{'group': 'Criminal'}
...,...,...,...
4109,"[{'id': 'c81e335b7eed4c6f973e892ef37f7519', 's...","petitioner: commissioner of income-tax, calcut...",{'group': 'Tax'}
1720,"[{'id': '95c7de1ad7d34736add939d27ed86bc3', 's...",petitioner: bakulbhai and anr. vs. respondent:...,{'group': 'Criminal'}
1754,"[{'id': '5d3a9649503641a1ba16445876f55cf5', 's...",$~23 & 24 (common order) in the high court of ...,{'group': 'Criminal'}
4159,"[{'id': 'dfdf79e97bf34de692c03d95c984213d', 's...",* in the high court of delhi at new delhi date...,{'group': 'Tax'}


In [12]:
# don't judge me for this code :(( 
# I just made it once and I don't want to refactor it since it does not appear in final version

train_ids = list(np.concatenate([[result['id'] for result in train_df.loc[id].annotations] for id in train_df.index]))
train_sentences = list(np.concatenate([[result['text'] for result in train_df.loc[id].annotations] for id in train_df.index]))
train_labels = list(np.concatenate([[result['label'][0] for result in train_df.loc[id].annotations] for id in train_df.index]))

val_ids = list(np.concatenate([[result['id'] for result in val_df.loc[id].annotations] for id in val_df.index]))
val_sentences = list(np.concatenate([[result['text'] for result in val_df.loc[id].annotations] for id in val_df.index]))
val_labels = list(np.concatenate([[result['label'][0] for result in val_df.loc[id].annotations] for id in val_df.index]))

test_ids = list(np.concatenate([[result['id'] for result in test_df.loc[id].annotations] for id in test_df.index]))
test_sentences = list(np.concatenate([[result['text'] for result in test_df.loc[id].annotations] for id in test_df.index]))
test_labels = list(np.concatenate([[result['label'][0] for result in test_df.loc[id].annotations] for id in test_df.index]))


In [13]:
train_sentence_df = pd.DataFrame(data=np.array([train_ids, train_sentences, train_labels], dtype=object).T, columns=['id', 'text', 'label']).rename(columns={'text': 'sentence'})
val_sentence_df = pd.DataFrame(data=np.array([val_ids, val_sentences, val_labels], dtype=object).T, columns=['id', 'text', 'label']).rename(columns={'text': 'sentence'})
test_sentence_df = pd.DataFrame(data=np.array([test_ids, test_sentences, test_labels], dtype=object).T, columns=['id', 'text', 'label']).rename(columns={'text': 'sentence'})
train_sentence_df['context'] = ''
val_sentence_df['context'] = ''
test_sentence_df['context'] = ''

In [14]:
train_sentence_df

Unnamed: 0,id,sentence,label,context
0,b7ed1b2b1b7040908d373c6bb3733689,non-reportable,PREAMBLE,
1,7213be2a986946d7bf628f213fc4b390,in the supreme court of india,PREAMBLE,
2,18f8f4feaaae445b8bd23b4e69b39da0,civil appellate jurisdiction civil appeal no.1...,PREAMBLE,
3,cd954bce1e3244c3a32fed20d3ff9927,.... appellant(s) versus chandra bhushan yadav...,PREAMBLE,
4,7786ce217d1a49ac9a11e643bbb038b9,civil appeal no.7440 of 2018,PREAMBLE,
...,...,...,...,...
23414,a959b03eddf3493484c41f482019c65f,we are accordingly of the opinion that the hig...,RATIO,
23415,4a240502a3174fa38c64ad4d1fc466e4,mr. rana contended that there was no proof fro...,ARG_RESPONDENT,
23416,7a855cd3ea4d462d9982909d6714831b,apart from the fact that both the courts have ...,ANALYSIS,
23417,9be7825aa37e4e88afec5bfa7dd1c6c4,the omission of this fact in the medical repor...,RATIO,


In [15]:
from src.datasets.dataset_builder import DatasetBuilder
builder = DatasetBuilder(USED_MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/516 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [16]:
train_dataset = builder.build_dataset(train_sentence_df)
val_dataset = builder.build_dataset(val_sentence_df)
test_dataset = builder.build_dataset(test_sentence_df, for_test=True)

Processing...


Map:   0%|          | 0/23419 [00:00<?, ? examples/s]

Processing...


Map:   0%|          | 0/5567 [00:00<?, ? examples/s]

Processing...


Map:   0%|          | 0/2879 [00:00<?, ? examples/s]

Map:   0%|          | 0/2879 [00:00<?, ? examples/s]

## Baseline: tuning indian legal BERT on plain sentences

In [None]:
trainer_params = {'batch_size': 32,
                  'n_epochs': 3,
                  'lr': 2e-5,
                  'optimizer': 'adamw_torch',
                  'weight_decay': 0.015,
                  'do_fp16': True,
                  'num_workers': 2}

In [None]:
from src.model.transformer_trainer import TransformerTrainer
from torch.cuda import empty_cache
import gc
tf_trainer = TransformerTrainer(bert_name=USED_MODEL_NAME, 
                             num_labels=num_labels,
                             params=trainer_params,
                             id2label=id2label,
                             label2id=label2id)
# removing cached file to avoid any possible conflict between consecutive model trainings
print("Cleaning memory...")
! rm -rf ./root/.cache/huggingface/hub/model*
empty_cache()
gc.collect()
trainer_baseline = tf_trainer.fit(train_dataset, 
            val_dataset,
            save_model=False)

Cleaning memory...
Downloading : law-ai/InLegalBERT


Downloading (…)lve/main/config.json:   0%|          | 0.00/671 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/534M [00:00<?, ?B/s]

Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initial

Epoch,Training Loss,Validation Loss,Accuracy,Precison,Recall,Weighted F1
1,1.0305,1.07145,0.666248,0.638135,0.666248,0.646737
2,0.9331,1.061047,0.667325,0.642689,0.667325,0.649955
3,0.9278,1.061047,0.667325,0.642689,0.667325,0.649955


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))


Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  _warn_prf(average, modifier, msg_start, len(result))
You're using a BertTokenizerFast tokenize

In [None]:
from sklearn.metrics import classification_report
preds_for_baseline = tf_trainer.predict(test_dataset, trainer_baseline)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
preds_for_baseline = [id2label[np.argmax(pred)] for pred in preds_for_baseline]
print(classification_report(test_labels, preds_for_baseline))

                precision    recall  f1-score   support

      ANALYSIS       0.63      0.79      0.70       984
ARG_PETITIONER       0.25      0.36      0.29        70
ARG_RESPONDENT       0.00      0.00      0.00        38
           FAC       0.66      0.73      0.69       580
         ISSUE       0.74      0.74      0.74        50
          NONE       0.89      0.86      0.87       190
      PREAMBLE       0.88      0.65      0.75       508
PRE_NOT_RELIED       0.00      0.00      0.00        12
    PRE_RELIED       0.60      0.39      0.48       142
         RATIO       0.29      0.06      0.10        70
           RLC       0.52      0.28      0.36       116
           RPC       0.81      0.82      0.82        91
           STA       0.55      0.64      0.59        28

      accuracy                           0.67      2879
     macro avg       0.52      0.49      0.49      2879
  weighted avg       0.67      0.67      0.66      2879



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## Tuning indian legal on pairs

Since the baseline model perform badly on the classes that are comparatively rare, we will try to fix this by feeding sentence pairs to the model. The logic behind this decision is that we can balance amount of sentence classes, whereas amount of pairs available is $n^2$ where $n$ is size of the class in train set. 

Implementation is rough, we lose the classes frequency distribution and see rarre sentences many times. But maybe it will be enough to see if the approach is good. 

In [None]:
from itertools import combinations
def generate_pairs_with_same_label(df, start_samples=200, samples=2000):
    dfs = []
    for label in tqdm(set(df.label)):
        temp_df = df.loc[df.label==label]
        temp_df = temp_df.sample(n=start_samples if start_samples < len(temp_df) else len(temp_df), random_state=17)
        pairs = temp_df.groupby('label')['sentence'].apply(combinations,2)\
                     .apply(list).apply(pd.Series)\
                     .stack().apply(pd.Series)\
                     .set_axis(labels=['text_1','text_2'],axis=1)\
                     .reset_index(level=0)
        pairs = pairs.sample(n=samples, random_state=17, ignore_index=True)
        pairs['sentence'] = pairs.text_1 + '[SEP]' + pairs.text_2 
        pairs['context'] = ''
        pairs.drop(columns=['text_1', 'text_2'], inplace=True)
        dfs.append(pairs)
    result_df = pd.concat(dfs, ignore_index=True)
    return result_df

The first idea was to make twin-head model and feed pairs of sentences into it. But the model was too large and didn't seem to perform any better than a baseline model. This is the second generation that also has compatibility with trainer used by all other models. 

In [None]:
train_pair_df = generate_pairs_with_same_label(train_sentence_df, start_samples=200, samples=2800)
# val_pair_df = generate_pairs_with_same_label(val_sentence_df, start_samples=100, samples=800)

  0%|          | 0/13 [00:00<?, ?it/s]

In [None]:
train_pair_df.head()

Unnamed: 0,label,sentence,context
0,PRE_RELIED,a director so dismissed was only entitled to r...,
1,PRE_RELIED,"consequently, the learned single judge in the ...",
2,PRE_RELIED,the court concluded that merely because there ...,
3,PRE_RELIED,section 5 (vi) of that act had left it to the ...,
4,PRE_RELIED,if the answer of this question is prima facie ...,


In [None]:
train_pair_dataset = builder.build_dataset(train_pair_df)

Processing...


Map:   0%|          | 0/36400 [00:00<?, ? examples/s]

Processing...


Map:   0%|          | 0/5567 [00:00<?, ? examples/s]

Processing...


Map:   0%|          | 0/2879 [00:00<?, ? examples/s]

Map:   0%|          | 0/2879 [00:00<?, ? examples/s]

In [None]:
trainer_params = {'batch_size': 32,
                  'n_epochs': 3,
                  'lr': 2e-5,
                  'optimizer': 'adamw_torch',
                  'weight_decay': 0.015,
                  'do_fp16': True,
                  'num_workers': 2}

In [None]:
from src.model.transformer_trainer import TransformerTrainer
from torch.cuda import empty_cache
import gc
tf_trainer = TransformerTrainer(bert_name=USED_MODEL_NAME, 
                             num_labels=num_labels,
                             params=trainer_params,
                             id2label=id2label,
                             label2id=label2id)
# removing cached file to avoid any possible conflict between consecutive model trainings
print("Cleaning memory...")
! rm -rf ./root/.cache/huggingface/hub/model*
empty_cache()
gc.collect()
trainer = tf_trainer.fit(train_pair_dataset, 
            val_dataset,
            save_model=False)

Cleaning memory...
Downloading : law-ai/InLegalBERT


Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initial

Epoch,Training Loss,Validation Loss,Accuracy,Precison,Recall,Weighted F1
1,0.0228,3.06863,0.474223,0.612229,0.474223,0.488644
2,0.0052,3.40197,0.463805,0.61253,0.463805,0.474985
3,0.0018,3.40197,0.463805,0.61253,0.463805,0.474985


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `_

In [None]:
from sklearn.metrics import classification_report
preds_for_pairs = tf_trainer.predict(test_dataset, trainer)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Here we can see that despite unstable training process and worse prediction quality in general, the goal to predict rare classes was achieved. In comparison with the baseline model which did not predict classes like ARG_RESPONDENT or PRE_NOT_RELIED at all, pairwise-trained model actually succeed on them. Results are not that good, but an idea has the right to live. 

In [None]:
preds_for_pairs = [id2label[np.argmax(pred)] for pred in preds_for_pairs]
print(classification_report(test_labels, preds_for_pairs))

                precision    recall  f1-score   support

      ANALYSIS       0.70      0.29      0.41       984
ARG_PETITIONER       0.07      0.17      0.10        70
ARG_RESPONDENT       0.14      0.47      0.22        38
           FAC       0.64      0.59      0.61       580
         ISSUE       0.46      0.88      0.60        50
          NONE       0.75      0.86      0.80       190
      PREAMBLE       0.84      0.53      0.65       508
PRE_NOT_RELIED       0.11      0.42      0.17        12
    PRE_RELIED       0.24      0.46      0.31       142
         RATIO       0.14      0.33      0.19        70
           RLC       0.22      0.46      0.30       116
           RPC       0.65      0.80      0.72        91
           STA       0.15      0.89      0.26        28

      accuracy                           0.48      2879
     macro avg       0.39      0.55      0.41      2879
  weighted avg       0.62      0.48      0.50      2879



## bert + gru  

Paying attention to the fact that in token classification tasks models with BERT-like embedder base and heavy head perform really decently, I propose that in this task it will have some success too. As a solution I suggest a 3-layer GRU with indian legal BERT base. The reason why I'm using GRU instead of LSTM is that it is less complex, hence GRU based models are computationally more efficient. Also simplier structure is better for this task since many sentences are semantically similar and quite short, so simplier model will more likely to learn something.

In [17]:
trainer_params = {'batch_size': 32,
                  'n_epochs': 3,
                  'lr': 2e-5,
                  'optimizer': 'adamw_torch',
                  'weight_decay': 0.015,
                  'do_fp16': True,
                  'num_workers': 2}

In [18]:
from src.model.transformer_trainer import TransformerTrainer

In [20]:
from torch.cuda import empty_cache
import gc
tf_trainer = TransformerTrainer(bert_name=USED_MODEL_NAME, 
                             num_labels=num_labels,
                             params=trainer_params,
                             id2label=id2label,
                             label2id=label2id,
                             custom=True)
# removing cached file to avoid any possible conflict between consecutive model trainings
print("Cleaning memory...")
! rm -rf ./root/.cache/huggingface/hub/model*
empty_cache()
gc.collect()
trainer_GRU_0 = tf_trainer.fit(train_dataset, 
            val_dataset,
            save_model=False)

Cleaning memory...
Loading : law-ai/InLegalBERT


Downloading (…)lve/main/config.json:   0%|          | 0.00/671 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/534M [00:00<?, ?B/s]

Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster 

Epoch,Training Loss,Validation Loss,Precison,Recall,Weighted F1
1,1.5694,1.595195,0.668599,0.671636,0.651004
2,1.4936,1.594255,0.642805,0.668583,0.650435
3,1.4823,1.594255,0.642805,0.668583,0.650435


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))


Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  _warn_prf(average, modifier, msg_start, len(result))
You're using a BertTokenizerFast tokenize

In [25]:
from sklearn.metrics import classification_report
preds_for_gru_wo_context = tf_trainer.predict(test_dataset, trainer_GRU_0)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  _warn_prf(average, modifier, msg_start, len(result))


Well, we can see that the performance difference with the baseline lies within the margin of error.

In [26]:
preds_for_gru_wo_context = [id2label[np.argmax(pred)] for pred in preds_for_gru_wo_context]
print(classification_report(test_labels, preds_for_gru_wo_context))

                precision    recall  f1-score   support

      ANALYSIS       0.64      0.79      0.71       984
ARG_PETITIONER       0.25      0.36      0.30        70
ARG_RESPONDENT       0.00      0.00      0.00        38
           FAC       0.65      0.73      0.69       580
         ISSUE       0.81      0.76      0.78        50
          NONE       0.90      0.86      0.88       190
      PREAMBLE       0.88      0.67      0.76       508
PRE_NOT_RELIED       0.00      0.00      0.00        12
    PRE_RELIED       0.64      0.40      0.49       142
         RATIO       0.33      0.04      0.08        70
           RLC       0.55      0.26      0.35       116
           RPC       0.81      0.85      0.83        91
           STA       0.53      0.61      0.57        28

      accuracy                           0.68      2879
     macro avg       0.54      0.49      0.49      2879
  weighted avg       0.68      0.68      0.67      2879



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## bert + context

The task is actually something like long text token classification with very long tokens. This fact gives motivation to make the model know about the context of the sentence it tries to label. Hence, we will feed the left and right contexts of every sentence if it is available and see if it will make results better.

### Downloading

In [8]:
raw_train_df = pd.read_json('./data/raw/train.json', encoding='utf-8')
raw_test_df = pd.read_json('./data/raw/dev.json', encoding='utf-8')

In [9]:
label2id = {'NONE': 0,
            'PREAMBLE': 1,
            'FAC': 2,
            'RLC': 3,
            'ISSUE': 4,
            'ARG_PETITIONER': 5,
            'ARG_RESPONDENT': 6,
            'ANALYSIS': 7,
            'STA': 8,
            'PRE_RELIED': 9,
            'PRE_NOT_RELIED': 10,
            'RATIO': 11,
            'RPC': 12
}

id2label = {1: 'PREAMBLE',
            2: 'FAC',
            3: 'RLC',
            4: 'ISSUE',
            5: 'ARG_PETITIONER',
            6: 'ARG_RESPONDENT',
            7: 'ANALYSIS',
            8: 'STA',
            9: 'PRE_RELIED',
            10: 'PRE_NOT_RELIED',
            11: 'RATIO',
            12: 'RPC',
            0: 'NONE'
}


num_labels = 13
max_seq_len = 512

### Preprocessing

In [11]:
from src.preprocessing.data_preprocessing import DataPreprocessor, ContextExtractor

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [12]:
train_data_preprocessor = DataPreprocessor(lower=True)
test_data_preprocessor = DataPreprocessor(lower=True)
train_df = train_data_preprocessor(raw_train_df)
test_df = test_data_preprocessor(raw_test_df)

  0%|          | 0/247 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [13]:
train_df.head()

Unnamed: 0_level_0,annotations,data,meta
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1735,"[{'id': 'd7a902fe9c23417499a7ef782f9fbdeb', 's...","in the high court of karnataka, circuit bench ...",{'group': 'Criminal'}
4183,"[{'id': 'ac4523a0252e4007986cefbd6d5f571a', 's...","1/11 in the high court of karnataka, bengaluru...",{'group': 'Tax'}
4207,"[{'id': '43499bd62ea94624b2f38f4cbc677913', 's...",in the high court of karnataka circuit bench a...,{'group': 'Criminal'}
4097,"[{'id': 'ec5e65782b1949e4a5445a2115ab5382', 's...",petitioner: raghubar mandal harihar mandal vs....,{'group': 'Tax'}
1778,"[{'id': '7323f9247fbc4618bf006ef103d7cb3a', 's...",petitioner: p.k. badiani vs. respondent: the c...,{'group': 'Tax'}


In [14]:
context_extractor = ContextExtractor()

In [15]:
train_df = context_extractor(train_df)
test_df = context_extractor(test_df)
test_labels = test_df['label'].apply(lambda x: label2id[x]).to_list()

Extracting context...


  0%|          | 0/244 [00:00<?, ?it/s]

Extracting context...


  0%|          | 0/30 [00:00<?, ?it/s]

In [16]:
train_df.head()

Unnamed: 0,doc_id,text,context,sentence,label
0,1735,"in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...",PREAMBLE
1,1735,"in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...",before the hon'ble mr.justice anand byrareddy ...,PREAMBLE
2,1735,"in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...",this criminal appeal is filed under section 37...,PREAMBLE
3,1735,"in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...","this appeal coming on for hearing this day, th...",PREAMBLE
4,1735,"in the high court of karnataka, circuit bench ...","in the high court of karnataka, circuit bench ...",heard the learned counsel for the appellant an...,NONE


### Train-val split

In [17]:
from sklearn.model_selection import GroupShuffleSplit
splitter = GroupShuffleSplit(n_splits=1, test_size=.15)  
train_idx, val_idx = next(splitter.split(train_df, groups=train_df['text']))
val_df = train_df.iloc[val_idx]
train_df = train_df.iloc[train_idx]

In [18]:
from src.datasets.dataset_builder import DatasetBuilder
builder = DatasetBuilder(indian_legal_uncased_bert_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/516 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [19]:
train_dataset = builder.build_dataset(train_df)
val_dataset = builder.build_dataset(val_df)
test_dataset = builder.build_dataset(test_df, for_test=True)

Building dataset...


Map:   0%|          | 0/23343 [00:00<?, ? examples/s]

Building dataset...


Map:   0%|          | 0/4607 [00:00<?, ? examples/s]

Building dataset...


Map:   0%|          | 0/2863 [00:00<?, ? examples/s]

Map:   0%|          | 0/2863 [00:00<?, ? examples/s]

### model

In [20]:
from src.model.transformer_trainer import TransformerTrainer
trainer_params = {'batch_size': 32,
                  'n_epochs': 3,
                  'lr': 2e-5,
                  'optimizer': 'adamw_torch',
                  'weight_decay': 0.01,
                  'do_fp16': True,
                  'num_workers': 2}

In [27]:
from torch.cuda import empty_cache
import gc
# cleaning cache to avoid out of memory errors
! rm -rf ./root/.cache/huggingface/hub/model*
empty_cache()
gc.collect()

5062

In [28]:
TFTrainer = TransformerTrainer(bert_name=indian_legal_uncased_bert_name, 
                                num_labels=num_labels,
                                params=trainer_params,
                                id2label=id2label,
                                label2id=label2id)

In [29]:
device = 'cpu'
if torch.cuda.is_available():
    device = torch.device('cuda:0')

In [30]:
context_trainer = TFTrainer.fit(train_dataset, 
            val_dataset,
            save_model=False)

Loading : law-ai/InLegalBERT


Some weights of the model checkpoint at law-ai/InLegalBERT were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initial

Epoch,Training Loss,Validation Loss,Precison,Recall,Weighted F1
1,1.3799,1.53659,0.773763,0.777078,0.770834
2,1.3244,1.541585,0.778602,0.779466,0.774508
3,1.3004,1.541585,0.778602,0.779466,0.774508


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))


Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  _warn_prf(average, modifier, msg_start, len(result))
You're using a BertTokenizerFast tokenize

In [40]:
from sklearn.metrics import classification_report
context_preds = TFTrainer.predict(test_dataset, context_trainer)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  _warn_prf(average, modifier, msg_start, len(result))


Apparently, it made prediction quality almost 15% better. Very good result. Sadly, some classes are still fully or almost fully ignored by a model.

In [39]:
context_preds = [np.argmax(pred) for pred in context_preds]
print(classification_report(test_labels, context_preds))

              precision    recall  f1-score   support

           0       0.07      1.00      0.12       187
           1       0.00      0.00      0.00       505
           2       0.00      0.00      0.00       580
           3       0.00      0.00      0.00       116
           4       0.00      0.00      0.00        50
           5       0.00      0.00      0.00        64
           6       0.00      0.00      0.00        38
           7       0.00      0.00      0.00       981
           8       0.00      0.00      0.00        28
           9       0.00      0.00      0.00       142
          10       0.00      0.00      0.00        12
          11       0.00      0.00      0.00        69
          12       0.00      0.00      0.00        91

    accuracy                           0.07      2863
   macro avg       0.01      0.08      0.01      2863
weighted avg       0.00      0.07      0.01      2863



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
