In [1]:
# ===========================================
# 📒 Notebook: Legal Model Training (TFIDF-SRT)
# ===========================================

# ------- 1. Setup --------
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4,5,6'

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, set_seed
)
from sklearn.metrics import f1_score
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm
from collections import defaultdict
import random
import itertools

# Seed the Python / NumPy / torch RNGs through transformers.set_seed so the
# whole sweep is reproducible; `seed` is also handed to every TrainingArguments.
seed = 42
set_seed(seed)

# ------- 2. Configs --------
# Experiment key -> Hugging Face Hub checkpoint id.
model_names = dict(
    legal_BERT="nlpaueb/legal-bert-base-uncased",
    legal_longformer="lexlms/legal-longformer-base",
    legal_Roberta="lexlms/legal-roberta-base",
)

# Grid-search axes and shared hyper-parameters.
learning_rates = [1e-5, 2e-5, 3e-5]
epochs_list = [3, 4]
tfidf_bucket_sizes = [16, 32]
dropout_rate = 0.1

# ------- 3. Load Datasets --------
# LexGLUE SCOTUS plus two variants of it (deduplicated+sorted and
# normalized+deduplicated+sorted, going by their Hub names).
original_dataset, dedup_and_sort, norm_dedup_sort = (
    load_dataset(*hub_spec)
    for hub_spec in (
        ("coastalcph/lex_glue", "scotus"),
        ("victorambrose11/scotus_deduplicate_sort",),
        ("victorambrose11/scotus_normalize_deduplicate_sort",),
    )
)

# The class inventory is taken from the original dataset's "label" feature.
label_feature = original_dataset["train"].features["label"]
label_list = label_feature.names
num_labels = len(label_list)

# ------- 4. TFIDF-SRT-EMB Preprocessing --------
def tfidf_score_to_bucket(score, num_buckets=32, min_val=0.0, max_val=10.0):
    """Quantise a scalar (IDF) score into an integer bucket index.

    Scores below ``min_val`` fall into bucket 0. Scores at or above ``max_val``
    fall into the last bucket, ``num_buckets - 1``. Anything in between is
    scaled linearly onto ``[0, num_buckets - 1]`` and truncated toward zero.
    """
    top_bucket = num_buckets - 1
    if score < min_val:
        return 0
    if score >= max_val:
        return top_bucket
    fraction = (score - min_val) / (max_val - min_val)
    return int(fraction * top_bucket)

def preprocess_tfidf_srt_emb(dataset, tokenizer, max_length=512, num_buckets=32):
    """Build TFIDF-SRT-EMB model inputs for the "train" and "validation" splits.

    Processing steps for each document:

    1. Tokenize it with ``tokenizer`` and drop duplicate tokens (first
       occurrence wins).
    2. Re-order the unique tokens by descending IDF, fitted on the training
       split only. Tokens never seen in training score 0 and sink to the end.
    3. Keep the top ``max_length - 2``, wrap them in the tokenizer's CLS/SEP
       tokens and right-pad to ``max_length``.
    4. Give each position an IDF bucket id via ``tfidf_score_to_bucket``.
       CLS, SEP and padding use bucket 0.

    Note: the score is the IDF weight alone (no per-document term frequency),
    despite the "tfidf" name.

    Args:
        dataset: object with "train" and "validation" splits, each exposing
            "text" and "label" columns.
        tokenizer: Hugging Face tokenizer providing ``tokenize``,
            ``convert_tokens_to_ids``, ``cls_token``, ``sep_token`` and
            ``pad_token_id``.
        max_length: final sequence length, including CLS/SEP and padding.
        num_buckets: number of IDF buckets. Must match the model's bucket
            embedding size.

    Returns:
        Dict mapping "train" and "validation" to ``Dataset`` objects. Their
        rows hold "input_ids", "attention_mask", "labels" and
        "tfidf_bucket_ids"; the three list fields have length ``max_length``.

    Raises:
        ValueError: if ``max_length`` cannot hold the CLS and SEP tokens.
    """
    if max_length < 2:
        raise ValueError(f"max_length must be >= 2 to hold CLS and SEP, got {max_length}")

    # Fit IDF on the training split only so validation text never influences the scores.
    tokenized_train = [" ".join(tokenizer.tokenize(text)) for text in dataset["train"]["text"]]
    # lowercase=False: TfidfVectorizer lowercases by default, which makes the
    # fitted vocabulary disagree with the tokenizer's own tokens. For example,
    # "[UNK]" or RoBERTa-style "Ġ"-prefixed pieces would then look up an IDF of 0.
    tfidf_vectorizer = TfidfVectorizer(analyzer="word", token_pattern=r"\S+", lowercase=False)
    tfidf_vectorizer.fit(tokenized_train)
    idf_dict = dict(zip(tfidf_vectorizer.get_feature_names_out(), tfidf_vectorizer.idf_))
    del tokenized_train  # the joined training text can be large; release it before encoding

    def process_split(split_name):
        split = dataset[split_name]
        data = []
        for text, label in zip(split["text"], split["label"]):
            # dict.fromkeys drops duplicates while keeping first-occurrence order.
            unique_tokens = list(dict.fromkeys(tokenizer.tokenize(text)))
            token_scores = {t: idf_dict.get(t, 0.0) for t in unique_tokens}
            # sorted() is stable even with reverse=True, so equal-IDF tokens keep
            # their document order.
            sorted_tokens = sorted(unique_tokens, key=token_scores.__getitem__, reverse=True)[: max_length - 2]
            tokens_final = [tokenizer.cls_token, *sorted_tokens, tokenizer.sep_token]
            input_ids = tokenizer.convert_tokens_to_ids(tokens_final)
            bucket_ids = [0] + [tfidf_score_to_bucket(token_scores[t], num_buckets) for t in sorted_tokens] + [0]
            pad_len = max_length - len(input_ids)
            data.append({
                "input_ids": input_ids + [tokenizer.pad_token_id] * pad_len,
                "attention_mask": [1] * len(input_ids) + [0] * pad_len,
                "labels": label,
                "tfidf_bucket_ids": bucket_ids + [0] * pad_len,
            })
        return Dataset.from_list(data)

    return {
        "train": process_split("train"),
        "validation": process_split("validation"),
    }

# ------- 5. Custom Model for TFIDF-SRT-EMB --------
from torch import nn
from transformers import BertModel

class TfidfSRTEMBLegalBERT(nn.Module):
    """BERT sequence classifier whose inputs are enriched with IDF-bucket embeddings.

    One learned vector per IDF bucket is added to the word embeddings *before*
    the encoder, the same way BERT adds position and segment embeddings.
    Self-attention can then carry the IDF signal into the [CLS] state used for
    classification.

    Adding the bucket vectors after the encoder and reading only [CLS] would
    reduce them to a constant bias, because [CLS] is always bucket 0.

    Args:
        model_name: Hub id or path of a checkpoint loadable by ``BertModel``.
        num_labels: number of output classes.
        bucket_size: number of IDF buckets; bucket ids must lie in ``[0, bucket_size)``.
        dropout: dropout probability applied to the [CLS] vector before the classifier.
    """

    def __init__(self, model_name, num_labels, bucket_size, dropout=0.1):
        super().__init__()
        # Classification reads the raw [CLS] hidden state, so BERT's pooler would
        # be dead weight that never receives gradients; don't build it.
        self.base_model = BertModel.from_pretrained(model_name, add_pooling_layer=False)
        hidden_size = self.base_model.config.hidden_size
        self.bucket_embedding = nn.Embedding(bucket_size, hidden_size)
        # Zero init: training starts from exactly the pretrained input
        # representation. nn.Embedding's default N(0, 1) init is large relative
        # to the pretrained embeddings and would perturb them from step one.
        nn.init.zeros_(self.bucket_embedding.weight)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, tfidf_bucket_ids, labels=None):
        """Classify a batch from token ids plus per-position IDF bucket ids.

        Args:
            input_ids: token ids, presumably a (batch, seq_len) LongTensor with
                the same shape as ``tfidf_bucket_ids``.
            attention_mask: 1 for real tokens, 0 for padding.
            tfidf_bucket_ids: IDF bucket id for every position.
            labels: optional class indices; when given, the cross-entropy loss is computed.

        Returns:
            Dict with "loss" (None when ``labels`` is None) and "logits"
            (batch, num_labels).
        """
        word_embeds = self.base_model.get_input_embeddings()(input_ids)
        inputs_embeds = word_embeds + self.bucket_embedding(tfidf_bucket_ids)
        # BertModel adds position/segment embeddings and LayerNorm on top of inputs_embeds.
        outputs = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] position
        logits = self.classifier(self.dropout(pooled))
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

# ------- 6. Training Function & Grid Search --------
from transformers import Trainer, TrainingArguments

def compute_f1(pred):
    """Compute micro- and macro-averaged F1 for single-label classification.

    Args:
        pred: evaluation prediction object. ``predictions`` holds the logits
            (n_examples, num_labels), or a tuple whose first element is those
            logits. ``label_ids`` holds the gold class indices.

    Returns:
        Dict with "micro_f1" and "macro_f1". ``zero_division=0`` scores
        undefined per-class F1 (a class never predicted and never present) as 0
        instead of warning.
    """
    logits = pred.predictions
    # Models that return extra outputs alongside the logits produce a tuple here;
    # the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=1)
    labels = pred.label_ids
    return {
        "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
    }

def train_and_evaluate(model_key, dataset, dataset_label, lr, epochs, is_tfidf_emb=False, bucket_size=32):
    """Fine-tune one grid-search configuration and return its validation metrics.

    Args:
        model_key: key into ``model_names`` selecting the checkpoint.
        dataset: presumably a ``DatasetDict`` with "train" and "validation"
            splits holding "text" and "label" columns.
        dataset_label: short tag used in the output/logging directory names.
        lr: learning rate.
        epochs: number of training epochs.
        is_tfidf_emb: if True, train ``TfidfSRTEMBLegalBERT`` on the IDF-sorted
            inputs built by ``preprocess_tfidf_srt_emb``. That model wraps
            ``BertModel``, so ``model_key`` must name a BERT checkpoint.
            Otherwise, fine-tune ``AutoModelForSequenceClassification`` on the
            raw text truncated/padded to 512 tokens.
        bucket_size: number of IDF buckets (TFIDF branch only).

    Returns:
        The metrics dict returned by ``trainer.evaluate()`` on the validation
        split, after the best epoch checkpoint (by macro F1) has been reloaded.
    """
    import gc

    tokenizer = AutoTokenizer.from_pretrained(model_names[model_key])

    if is_tfidf_emb:
        processed_dataset = preprocess_tfidf_srt_emb(dataset, tokenizer, num_buckets=bucket_size)
        model = TfidfSRTEMBLegalBERT(
            model_name=model_names[model_key],
            num_labels=num_labels,
            bucket_size=bucket_size,
            dropout=dropout_rate,  # honour the shared config instead of the constructor default
        )

        def collate_fn(batch):
            # Rows from Dataset.from_list are plain Python lists; stack them into tensors.
            return {
                key: torch.tensor([item[key] for item in batch])
                for key in ("input_ids", "attention_mask", "tfidf_bucket_ids", "labels")
            }
    else:
        def tokenize_fn(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
        processed_dataset = dataset.map(tokenize_fn, batched=True)
        processed_dataset = processed_dataset.rename_column("label", "labels")
        processed_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        model = AutoModelForSequenceClassification.from_pretrained(model_names[model_key], num_labels=num_labels)
        collate_fn = None

    # One name per run, so neither checkpoints nor logs of different
    # (lr, epochs) runs share a directory.
    run_name = f"{model_key}_{dataset_label}_{lr}_{epochs}"
    args = TrainingArguments(
        output_dir=f"./results_{run_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        eval_strategy="epoch",  # `evaluation_strategy` is the deprecated spelling
        save_strategy="epoch",
        save_total_limit=1,
        learning_rate=lr,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        logging_dir=f"./logs_{run_name}",
        seed=seed
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["validation"],
        compute_metrics=compute_f1,
        tokenizer=tokenizer,
        data_collator=collate_fn
    )

    try:
        trainer.train()
        return trainer.evaluate()
    finally:
        # The sweep calls this function many times in one process: drop this
        # run's model/optimizer and release cached GPU memory before the next one.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ------- 7. Run Grid Search --------
datasets_dict = {
    "original": original_dataset,
    "dedup_sort": dedup_and_sort,
    "norm_dedup_sort": norm_dedup_sort,
}

# The sweep trains dozens of models in sequence. Results are flushed to disk
# after every run, so a crash or KeyboardInterrupt part-way through keeps every
# finished run instead of discarding them all.
results_csv_path = "training_results.csv"
results = []


def record_result(row):
    """Append one run's result row to ``results`` and rewrite the CSV checkpoint."""
    results.append(row)
    pd.DataFrame(results).to_csv(results_csv_path, index=False)


for model_key in model_names:
    for dataset_label, dataset in datasets_dict.items():
        for lr, ep in itertools.product(learning_rates, epochs_list):
            print(f"\n🚀 Training {model_key} on {dataset_label} | lr={lr}, epochs={ep}")
            metrics = train_and_evaluate(model_key, dataset, dataset_label, lr, ep)
            record_result({"Model": model_key, "Dataset": dataset_label, "lr": lr, "Epochs": ep, **metrics})

# TFIDF-SRT-EMB grid search (LegalBERT on the original dataset only)
for bucket_size in tfidf_bucket_sizes:
    for lr, ep in itertools.product(learning_rates, epochs_list):
        print(f"\n📘 Training TFIDF-SRT-EMB (LegalBERT) | buckets={bucket_size}, lr={lr}, epochs={ep}")
        metrics = train_and_evaluate("legal_BERT", original_dataset, f"tfidf_emb_{bucket_size}", lr, ep, is_tfidf_emb=True, bucket_size=bucket_size)
        record_result({"Model": "TFIDF-SRT-EMB", "Dataset": "original", "Buckets": bucket_size, "lr": lr, "Epochs": ep, **metrics})

# ------- 8. Results Table --------
# Trainer.evaluate() prefixes every metric with "eval_", so compute_f1's
# "macro_f1" appears in each result row as "eval_macro_f1".
metric_col = "eval_macro_f1"

results_df = pd.DataFrame(results)
results_df.to_csv("training_results.csv", index=False)

print("\n✅ All training complete. Top results:")
print(results_df.sort_values(metric_col, ascending=False).head(10))

# ------- 9. Plot --------
# Best validation macro F1 reached by each model across all of its runs.
plt.figure(figsize=(12, 6))
results_df.groupby("Model")[metric_col].max().sort_values().plot(kind="barh")
plt.title("Max Macro F1 Score per Model")
plt.xlabel("Macro F1")
plt.grid(True, linestyle="--", alpha=0.6)
plt.tight_layout()
plt.show()


  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 100%|██████████| 5000/5000 [00:00<00:00, 37779.44 examples/s]
Generating test split: 100%|██████████| 1400/1400 [00:00<00:00, 30251.33 examples/s]
Generating validation split: 100%|██████████| 1400/1400 [00:00<00:00, 35709.66 examples/s]
Generating train split: 100%|██████████| 5000/5000 [00:00<00:00, 56622.09 examples/s]
Generating test split: 100%|██████████| 1400/1400 [00:00<00:00, 33800.11 examples/s]
Generating validation split: 100%|██████████| 1400/1400 [00:00<00:00, 43183.33 examples/s]



🚀 Training legal_BERT on original | lr=1e-05, epochs=3


Map: 100%|██████████| 5000/5000 [00:38<00:00, 130.43 examples/s]
Map: 100%|██████████| 1400/1400 [00:16<00:00, 86.41 examples/s]
Map: 100%|██████████| 1400/1400 [00:15<00:00, 89.73 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.5673,1.157971,0.665,0.40937
2,0.922,0.959948,0.717857,0.542829
3,0.7393,0.920369,0.721429,0.562951



🚀 Training legal_BERT on original | lr=1e-05, epochs=4


Map: 100%|██████████| 1400/1400 [00:14<00:00, 98.74 examples/s] 
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.542,1.118323,0.670714,0.407651
2,0.8954,0.927648,0.717857,0.525364
3,0.7085,0.86254,0.744286,0.634787
4,0.5006,0.864499,0.751429,0.642821



🚀 Training legal_BERT on original | lr=2e-05, epochs=3


Map: 100%|██████████| 1400/1400 [00:11<00:00, 125.28 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.3111,1.000257,0.685714,0.465686
2,0.7395,0.815324,0.757143,0.649954
3,0.5377,0.820757,0.762143,0.666257



🚀 Training legal_BERT on original | lr=2e-05, epochs=4


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.3071,0.942622,0.690714,0.498688
2,0.7344,0.811502,0.763571,0.658393
3,0.5273,0.832741,0.769286,0.679355
4,0.291,0.883916,0.767857,0.676017



🚀 Training legal_BERT on original | lr=3e-05, epochs=3


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.2284,0.947471,0.696429,0.545299
2,0.7092,0.799445,0.768571,0.667877
3,0.4894,0.847755,0.772143,0.678068



🚀 Training legal_BERT on original | lr=3e-05, epochs=4


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.2197,0.866162,0.728571,0.622241
2,0.6841,0.819217,0.769286,0.677119
3,0.4707,0.86317,0.772143,0.695557
4,0.2055,0.968443,0.777857,0.694248



🚀 Training legal_BERT on dedup_sort | lr=1e-05, epochs=3


Map: 100%|██████████| 5000/5000 [00:07<00:00, 691.84 examples/s]
Map: 100%|██████████| 1400/1400 [00:02<00:00, 548.75 examples/s]
Map: 100%|██████████| 1400/1400 [00:02<00:00, 574.05 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.6412,1.269697,0.626429,0.369338
2,1.0644,1.135549,0.667857,0.398601
3,0.9031,1.121426,0.667143,0.409822



🚀 Training legal_BERT on dedup_sort | lr=1e-05, epochs=4


Map: 100%|██████████| 1400/1400 [00:03<00:00, 450.81 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.6465,1.255107,0.645,0.379073
2,1.0569,1.133512,0.671429,0.393746
3,0.8921,1.086531,0.680714,0.463737
4,0.7071,1.077594,0.691429,0.471879



🚀 Training legal_BERT on dedup_sort | lr=2e-05, epochs=3


Map: 100%|██████████| 1400/1400 [00:02<00:00, 479.50 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.5099,1.184935,0.648571,0.383958
2,0.9327,1.075147,0.696429,0.497329
3,0.7483,1.053212,0.709286,0.542249



🚀 Training legal_BERT on dedup_sort | lr=2e-05, epochs=4


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.6082,1.232247,0.642857,0.376883
2,0.974,1.118428,0.677857,0.447552
3,0.7948,1.072031,0.695714,0.536559
4,0.5436,1.078284,0.705,0.543021



🚀 Training legal_BERT on dedup_sort | lr=3e-05, epochs=3


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,1.4784,1.208966,0.657857,0.402549
2,0.9074,1.077614,0.694286,0.513076


KeyboardInterrupt: 