In [11]:
# ================================
# 1. Load Dataset
# ================================
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split

# Read the labelled documents from the local CSV and preview the schema.
df = pd.read_csv("batch.csv")
print("Columns:", df.columns.tolist())
print(df.head())

# Hold out 20% of the rows for testing. Stratifying on "label" keeps the class
# proportions equal in both partitions; the fixed random_state makes the split
# repeatable.
train_df, test_df = train_test_split(
    df,
    stratify=df["label"],
    test_size=0.2,
    random_state=42,
)

# Wrap each partition as a Hugging Face Dataset. The index is reset first so
# the shuffled row positions from the split are discarded.
train_dataset, test_dataset = (
    Dataset.from_pandas(part.reset_index(drop=True))
    for part in (train_df, test_df)
)


Columns: ['filename', 'text', 'label']
                                   filename  \
0          Custom_Excise_and_Gold_1996_1183   
1  Delhi_District_Court_2007_2020_2011_1216   
2                       Kolkata_HC_1987_173   
3                     SupremeCourt_1997_944   
4                         Patna_HC_2012_847   

                                                text  label  
0  L. Bhat, J.(President) Appellants were engaged...      1  
1  BRIEF FACTS Adumbrated in brief the prosecutio...      0  
2  Monoj Kumar J.The appellants before us the sui...      1  
3  WITH CIVIL APPEALS NOS 6263 6264 OF 1997 ( S.L...      2  
4  (Per HONOURABLE MR. JUSTICE AMARESH KUMAR LAL)...      1  


In [12]:
# ================================
# 2. Load InLegalBERT Model + Tokenizer
# ================================
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub identifier of the pretrained checkpoint used for both the
# tokenizer and the encoder weights.
model_name = "law-ai/InLegalBERT"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# num_labels=3 attaches a three-way sequence-classification head on top of the
# pretrained encoder; that head is freshly initialised and must be fine-tuned.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at law-ai/InLegalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# ================================
# 3. Tokenization
# ================================
def tokenize(batch):
    """Tokenize the "text" entries of a batch with the module-level tokenizer.

    Args:
        batch: Mapping with a "text" key holding an iterable of values. Each
            value is coerced with ``str()`` so non-string entries do not make
            the tokenizer fail.

    Returns:
        Whatever the tokenizer returns for the list of strings, with every
        sequence truncated and padded to exactly 512 tokens.
    """
    as_strings = list(map(str, batch["text"]))
    encoded = tokenizer(
        as_strings,
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    return encoded

# Tokenize both splits in batches of 16 rows, then rename "label" -> "labels".
map_options = {"batched": True, "batch_size": 16}
train_dataset = train_dataset.map(tokenize, **map_options)
test_dataset = test_dataset.map(tokenize, **map_options)

train_dataset = train_dataset.rename_column("label", "labels")
test_dataset = test_dataset.rename_column("label", "labels")

# Keep only the label column and the tokenizer outputs; drop every other
# column. The drop list is derived from the training split and applied to both.
model_columns = ("labels", "input_ids", "attention_mask", "token_type_ids")
cols_to_remove = [
    name for name in train_dataset.column_names if name not in model_columns
]
train_dataset = train_dataset.remove_columns(cols_to_remove)
test_dataset = test_dataset.remove_columns(cols_to_remove)

# Have both datasets hand back PyTorch tensors when indexed.
for split in (train_dataset, test_dataset):
    split.set_format("torch")

# Sanity check: show the first prepared training example.
print(train_dataset[0])


Map:   0%|          | 0/23999 [00:00<?, ? examples/s]

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

{'labels': tensor(0), 'input_ids': tensor([  101,   154,   117,   226,   450,  2713,   615,   223,   218,   207,
          371,   116,  1324,   213,   145,   511,   842,   256,   230,   117,
         1043,   210,  4215,   373,   207,   296,   240,   210,  8334,   189,
         3850,   236,   351,   207,   308,   210,   210,   207,   296,  2713,
          615,  3740,   218,   207,  5289,   865,   236,  9886,   179,   171,
          861,   303,   210,   237,   308,   210,   207,  6532,  8334,   189,
         3850,   210,   207,  3678,   235,   240,  4793,   211,  2549,   207,
          842,   417,   222,   207, 14479,   210,   207,   371,   116,  1324,
          117,   207,   444,   889,   245,   219,   763,   213,  2259,   117,
          145,   246,  1013,   351,   207,   371,   116,  1324,   212,   207,
          511,   842,   256,   246,  4305,   217,   842,   210,   207,   786,
          217,  1762,   210,   207,   371,   116,  1324,   115,   233,  1763,
          115,   678,   145, 

In [17]:
# ================================
# 4. Training Setup
# ================================
from transformers import TrainingArguments, Trainer
import evaluate

# Load the accuracy and F1 metric objects from the `evaluate` library once at
# module level; compute_metrics calls their .compute() methods on every
# evaluation pass.
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and per-class F1 for a 3-class evaluation.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` must support
            ``.argmax(axis=-1)`` (e.g. a NumPy array with one score per class
            on the last axis); ``labels`` holds the reference class ids.

    Returns:
        dict with the keys ``accuracy``, ``macro_f1``, ``f1_class_0``,
        ``f1_class_1`` and ``f1_class_2``.
    """
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)

    # The class ids are passed explicitly to the F1 metric. Without `labels=`,
    # average=None returns one score per class that actually occurs in the
    # batch, so a class missing from both predictions and references would
    # shift the array and make index 2 raise IndexError (or silently attribute
    # a score to the wrong class).
    class_ids = [0, 1, 2]

    acc = accuracy.compute(predictions=preds, references=labels)
    f1_macro = f1.compute(
        predictions=preds, references=labels, average="macro", labels=class_ids
    )
    f1_per_class = f1.compute(
        predictions=preds, references=labels, average=None, labels=class_ids
    )

    metrics = {
        "accuracy": acc["accuracy"],
        "macro_f1": f1_macro["f1"],
    }
    for class_id, score in zip(class_ids, f1_per_class["f1"]):
        # Cast to a plain float so the value is a simple scalar in logs.
        metrics[f"f1_class_{class_id}"] = float(score)
    return metrics

from transformers import TrainingArguments

# Fine-tuning configuration.
#
# `do_eval=True` on its own does not schedule any evaluation: with the default
# evaluation strategy ("no") the Trainer never evaluates during training, and
# the flag is only consulted by external training scripts. Evaluation is
# therefore scheduled explicitly once per epoch via `eval_strategy`.
#
# Checkpoints are likewise written once per epoch and capped with
# `save_total_limit`; the defaults (every 500 steps, no limit) would pile up
# many full checkpoints over a multi-epoch run.
training_args = TrainingArguments(
    output_dir="./results_inlegalbert",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",             # run compute_metrics after every epoch
    save_strategy="epoch",             # must match eval_strategy for best-model reload
    save_total_limit=2,                # bound the disk used by checkpoints
    load_best_model_at_end=True,       # restore the best epoch when training ends
    metric_for_best_model="macro_f1",  # key returned by compute_metrics
    seed=42,                           # same seed as the train/test split
    logging_dir="./logs",
)




In [18]:
# ================================
# 5. Train Model
# ================================
# Assemble the Trainer from the model, the training configuration, both
# tokenized splits and the metric callback.
#
# The tokenizer is passed as `processing_class`; the `tokenizer=` keyword of
# Trainer.__init__ is deprecated and emits a warning on current transformers
# releases.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

# Run fine-tuning. This is the long-running step of the notebook.
trainer.train()


  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Step,Training Loss


KeyboardInterrupt: 

In [None]:
# ================================
# 6. Evaluate
# ================================
# Evaluate with the Trainer's configured evaluation dataset and keep the
# returned metrics dict; the plotting cell reads its "eval_*" entries.
results = trainer.evaluate()
print(f"Evaluation Results: {results}")


In [None]:
# ================================
# 7. Graphs
# ================================
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# --- Bar: Accuracy vs Macro-F1 ---
overall_names = ["Accuracy", "Macro-F1"]
overall_scores = [results["eval_accuracy"], results["eval_macro_f1"]]

fig, ax = plt.subplots()
ax.bar(overall_names, overall_scores, color=["skyblue", "orange"])
ax.set_title("InLegalBERT Performance")
ax.set_ylim(0, 1)  # scores are fractions, so pin the axis to [0, 1]
plt.show()

# --- Bar: Per-class F1 ---
class_names = ["Rejected (0)", "Accepted (1)", "Partially Accepted (2)"]
class_scores = [results[f"eval_f1_class_{idx}"] for idx in range(3)]

fig, ax = plt.subplots()
ax.bar(class_names, class_scores, color=["red", "green", "blue"])
ax.set_title("InLegalBERT Per-class F1 Scores")
ax.set_ylim(0, 1)
plt.show()

# --- Confusion Matrix ---
# Take predictions and reference labels from the same predict() call.
# `label_ids` is returned alongside `predictions` as a NumPy array, so the two
# are guaranteed to be aligned row-for-row and scikit-learn receives plain
# arrays instead of a torch-formatted dataset column.
prediction_output = trainer.predict(test_dataset)
preds = prediction_output.predictions.argmax(-1)
true_labels = prediction_output.label_ids

# Fix the label order so rows/columns always map to class ids 0, 1, 2 even if
# a class is never predicted.
cm = confusion_matrix(true_labels, preds, labels=[0, 1, 2])
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=["Rejected", "Accepted", "Partially Accepted"],
)
disp.plot(cmap="Blues", values_format="d")  # "d": render counts as integers
plt.title("Confusion Matrix - InLegalBERT")
plt.show()
