In [2]:
# ============================================================================
# CONFIGURATION PARAMETERS
# ============================================================================

# Data Configuration
DATA_PATH = "C:\\Users\\rhrou\\Downloads\\justice.csv"
TARGET_COLUMN = 'first_party_winner'
TEXT_COLUMN = 'facts'

# Train-Test Split Configuration
TEST_SIZE = 0.25
RANDOM_STATE = 42

# BERT Configuration
# Options: 'bert-base-uncased', 'nlpaueb/legal-bert-base-uncased'
BERT_MODEL_NAME = 'bert-base-uncased'  

BERT_MAX_LENGTH = 512
BERT_BATCH_SIZE = 40

# Text Preprocessing Configuration
LOWERCASE = True  # Set to True for 'uncased' models, False for 'cased' models
NORMALIZE_UNICODE = True

## Data Loading

In [13]:
import numpy as np
import pandas as pd
import unicodedata
import warnings
warnings.filterwarnings('ignore')

import torch
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm

df = pd.read_csv("data/justice.csv")
df = df.dropna(subset=["facts", "first_party_winner"])
df["label"] = df["first_party_winner"].astype(int)

print(f"      Loaded: {df.shape[0]} cases, {df.shape[1]} columns")

# 2. Remove Missing Values
print(f"\n[2/6] Removing missing values...")
initial_count = df.shape[0]
df.dropna(inplace=True)
print(f"      Retained: {df.shape[0]} cases ({df.shape[0]/initial_count*100:.1f}%)")

# 4. Extract and Clean Text
print(f"\n[3/6] Preprocessing text data...")
df_nlp = df[[TEXT_COLUMN]].copy()

# Remove HTML tags
df['facts'] = df['facts'].str.replace(r'<[^<>]*>', '', regex=True)

# Apply text preprocessing
df_nlp['facts_clean'] = df_nlp[TEXT_COLUMN].apply(
    lambda x: preprocess_text(x, lowercase=LOWERCASE, normalize=NORMALIZE_UNICODE))


df.head()

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["facts"].tolist(),
    df["label"].tolist(),
    test_size=0.2,
    stratify=df["label"],
    random_state=42
)

      Loaded: 3288 cases, 17 columns

[2/6] Removing missing values...
      Retained: 3098 cases (94.2%)

[3/6] Preprocessing text data...


## 

In [29]:
# Wraps tokenized text and labels so PyTorch's dataloader can batch and feed them to the model during training
class JusticeDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels, model_type=None):
        self.encodings = encodings
        self.labels = labels
        self.model_type = model_type
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
         # DistilBERT and MiniLM don't use token_type_ids
        if self.model_type in ["Step1_DistilBERT", "Step4_MiniLM", "Step3_TinyBERT"]:
            item.pop("token_type_ids", None)
        return item

# Compute metrics function    
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    macro_f1 = f1_score(labels, preds, average="macro")
    return {
        "accuracy": acc,
        "precision_class0": precision[0],
        "recall_class0": recall[0],
        "f1_class0": f1[0],
        "precision_class1": precision[1],
        "recall_class1": recall[1],
        "f1_class1": f1[1],
        "macro_f1": macro_f1
    }

## Training hyperparameters

In [30]:
import torch
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=50
)

## Define Interpretability Function through Integrated Gradients

In [31]:
def explain_prediction(model, tokenizer, text, label=None):
    model.eval()
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length"
    )
    
    # FIX: Move inputs to same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    if label is None:
        with torch.no_grad():
            outputs = model(**inputs)
            label = outputs.logits.argmax(-1).item()

    def forward_func(embeddings, attention_mask=None):
        outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask)
        return torch.softmax(outputs.logits, dim=1)[:, label]

    embeddings = model.get_input_embeddings()(inputs["input_ids"])
    ig = IntegratedGradients(forward_func)
    attributions, _ = ig.attribute(
        embeddings,
        additional_forward_args=(inputs["attention_mask"],),
        return_convergence_delta=True
    )

    scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()  # .cpu() before .numpy()
    return scores

## Model Selections

In [None]:
# 

model_configs = {
    "Teacher_BERT": {
        "tokenizer": "bert-base-uncased",
        "model": lambda: BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    },
    "Step1_DistilBERT": {
        "tokenizer": "distilbert-base-uncased",
        "model": lambda: DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    },
    "Step3_TinyBERT": {
        "tokenizer": "bert-base-uncased",
        "model": lambda: BertForSequenceClassification.from_pretrained("huawei-noah/TinyBERT_General_4L_512D", num_labels=2)
    }
    # "Step4_MiniLM": {
    #     "tokenizer": "bert-base-uncased",
    #     "model": lambda: BertForSequenceClassification.from_pretrained("microsoft/MiniLM-L12-H384-uncased", num_labels=2)
    # } - this one keeps gives the error on token_type_ids not found
}


## Loop through models

In [36]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from transformers import (
    BertForSequenceClassification, BertTokenizer, BertConfig,
    DistilBertForSequenceClassification, DistilBertTokenizer,
    Trainer, TrainingArguments
)
from captum.attr import IntegratedGradients
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

results = {}
teacher_vectors = {}

for name, cfg in model_configs.items():
    print(f"\nüîπ Training {name}...")

    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(cfg["tokenizer"]) if "bert" in cfg["tokenizer"] else DistilBertTokenizer.from_pretrained(cfg["tokenizer"])

    # FIX: Don't create token_type_ids for DistilBERT
    return_token_type_ids = "distil" not in cfg["tokenizer"].lower()
    
    # Tokenize
    train_enc = tokenizer(
        train_texts, 
        truncation=True, 
        padding=True, 
        max_length=128,
        return_token_type_ids=return_token_type_ids  # ‚Üê Add this
    )
    val_enc = tokenizer(
        val_texts, 
        truncation=True, 
        padding=True, 
        max_length=128,
        return_token_type_ids=return_token_type_ids  # ‚Üê Add this
    )
    
    train_ds = JusticeDataset(train_enc, train_labels, model_type=name)
    val_ds = JusticeDataset(val_enc, val_labels, model_type=name)
    # Load model
    model = cfg["model"]()

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    metrics = trainer.evaluate()
    results[name] = metrics

    # Explanations: only store teacher‚Äôs attributions once
    if name == "Teacher_BERT":
        teacher_vectors = [explain_prediction(model, tokenizer, text) for text in val_texts[:50]]
    else:
        student_vectors = [explain_prediction(model, tokenizer, text) for text in val_texts[:50]]
        cos_sims = []
        for t_vec, text in zip(teacher_vectors, val_texts[:50]):
            s_vec = explain_prediction(model, tokenizer, text)
            sim = cosine_similarity([t_vec], [s_vec])[0][0]
            cos_sims.append(sim)
        results[name]["cosine_similarity"] = np.mean(cos_sims)


üîπ Training Teacher_BERT...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision Class0,Recall Class0,F1 Class0,Precision Class1,Recall Class1,F1 Class1,Macro F1
1,0.609,0.641219,0.667742,0.0,0.0,0.0,0.667742,1.0,0.800774,0.400387
2,0.6605,0.634657,0.669355,1.0,0.004854,0.009662,0.668821,1.0,0.801549,0.405605
3,0.6029,0.636075,0.662903,0.461538,0.087379,0.146939,0.67642,0.949275,0.78995,0.468444



üîπ Training Step1_DistilBERT...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision Class0,Recall Class0,F1 Class0,Precision Class1,Recall Class1,F1 Class1,Macro F1
1,0.6102,0.636517,0.667742,0.0,0.0,0.0,0.667742,1.0,0.800774,0.400387
2,0.6384,0.644972,0.670968,0.55,0.053398,0.097345,0.675,0.978261,0.798817,0.448081
3,0.5796,0.66235,0.630645,0.370787,0.160194,0.223729,0.6742,0.864734,0.757672,0.4907


TypeError: DistilBertForSequenceClassification.forward() got an unexpected keyword argument 'token_type_ids'