In [41]:
import os
import json
import torch
import pandas as pd
from torch.utils.data import Dataset
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments
)
from sklearn.metrics import f1_score

# 1. Device setup (M1 MPS or CPU)
device = torch.device("cpu") 
print(f"Using device: {device}")

# 2. Load & flatten JSON docs
with open('../data/TRDataChallenge2023.txt','r',encoding='utf-8') as f:
    raw = [json.loads(l) for l in f if l.strip()]
rows = []
for d in raw:
    chunks = []
    for sec in d['sections']:
        h = sec.get('headtext','').strip()
        if h: chunks.append(h)
        chunks.extend(sec.get('paragraphs',[]))
    rows.append({'id': d['documentId'], 'text': '\n'.join(chunks), 'labels': d.get('postures',[])})
df = pd.DataFrame(rows)

# 3. Binarize labels & subset
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df['labels'])
df = pd.concat([df, pd.DataFrame(Y, columns=mlb.classes_)], axis=1)
subset = df.sample(frac=0.1, random_state=42).reset_index(drop=True)

# 4. Train/val split
texts = subset['text'].tolist()
labels = subset[mlb.classes_].values.tolist()
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.1, random_state=42
)

# 5. Dataset class
class LegalDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=256):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len
    def __len__(self): return len(self.texts)
    def __getitem__(self, i):
        enc = self.tokenizer(
            self.texts[i],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        item = {k: v.squeeze(0) for k,v in enc.items()}
        item['labels'] = torch.tensor(self.labels[i], dtype=torch.float)
        return item

# 6. Tokenizer & Datasets
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_ds = LegalDataset(train_texts, train_labels, tokenizer)
val_ds   = LegalDataset(val_texts,   val_labels,   tokenizer)

# 7. Model & freeze encoder
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    problem_type='multi_label_classification',
    num_labels=len(mlb.classes_)
)
# freeze all DistilBERT layers
for p in model.distilbert.parameters():
    p.requires_grad = False
model.to(device)

# 8. TrainingArguments & Trainer
training_args = TrainingArguments(
    output_dir='./out',
    eval_strategy='epoch',
    save_strategy='epoch',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    logging_steps=100,
    learning_rate=2e-5,
    use_cpu=True,
    load_best_model_at_end=True,
    metric_for_best_model='micro_f1',
)

def compute_metrics(pred):
    logits = pred.predictions
    preds = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    return {
        'micro_f1': f1_score(pred.label_ids, preds, average='micro'),
        'macro_f1': f1_score(pred.label_ids, preds, average='macro'),
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

# 9. Train & evaluate
trainer.train()
metrics = trainer.evaluate()
print("Eval metrics:", metrics)

# 10. Inference example
examples = [
    "On appeal, the trial court lacked jurisdiction.",
    "Motion to dismiss for failure to state a claim."
]
enc = tokenizer(examples, truncation=True, padding=True, return_tensors='pt').to(device)
with torch.no_grad():
    logits = model(**enc).logits
probs = torch.sigmoid(logits).cpu().numpy()
for text, prob in zip(examples, probs):
    labs = [mlb.classes_[i] for i,p in enumerate(prob) if p>0.5]
    print(f"\n>> {text}\n→ {labs} (scores: {prob.round(2)})")

Using device: cpu


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Micro F1,Macro F1
1,0.4752,0.434434,0.030457,0.000576


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Eval metrics: {'eval_loss': 0.43443435430526733, 'eval_micro_f1': 0.030456852791878174, 'eval_macro_f1': 0.000576036866359447, 'eval_runtime': 6.2116, 'eval_samples_per_second': 28.978, 'eval_steps_per_second': 3.703, 'epoch': 1.0}

>> On appeal, the trial court lacked jurisdiction.
→ ['Motion to Vacate Wardship'] (scores: [0.38 0.37 0.38 0.29 0.35 0.35 0.39 0.38 0.37 0.34 0.38 0.37 0.34 0.3
 0.33 0.33 0.27 0.32 0.31 0.4  0.34 0.35 0.37 0.36 0.32 0.4  0.33 0.33
 0.31 0.37 0.41 0.36 0.32 0.33 0.33 0.37 0.47 0.32 0.44 0.38 0.31 0.38
 0.35 0.38 0.4  0.41 0.43 0.29 0.42 0.41 0.33 0.35 0.41 0.43 0.33 0.4
 0.35 0.32 0.28 0.37 0.4  0.36 0.34 0.34 0.37 0.34 0.35 0.43 0.42 0.3
 0.4  0.46 0.36 0.34 0.41 0.29 0.39 0.42 0.36 0.4  0.37 0.32 0.32 0.39
 0.4  0.4  0.35 0.41 0.35 0.43 0.31 0.28 0.35 0.33 0.37 0.34 0.3  0.35
 0.35 0.39 0.39 0.43 0.35 0.34 0.37 0.37 0.39 0.29 0.33 0.36 0.37 0.35
 0.37 0.39 0.28 0.37 0.28 0.37 0.39 0.41 0.32 0.35 0.46 0.34 0.33 0.34
 0.36 0.35 0.4  0.39 0.41 0.43 0.34 0.3

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
