# Legal-BERT GDPR: Training, Evaluation, and Inference

This notebook helps you:
- Inspect the dataset
- Train (or reuse an existing) Legal-BERT 3-way classifier
- Evaluate and visualize results
- Run quick inference examples

Notes:
- Make sure your virtual environment is active before running.
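- For RTX 50xx GPUs on Windows, this setup uses torch 2.9.0+cu128 (installed before running this notebook); the environment-check cell below prints the active torch and CUDA versions.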
- fp16 is disabled by default here (`USE_FP16 = False`) to avoid the fp16 scaler error "Attempting to unscale FP16 gradients".

In [None]:
# Environment check: report the interpreter and, when importable, the torch/CUDA setup
import platform
import sys

print(f'PYTHON: {sys.executable}')
print(f'VERSION: {platform.python_version()}')
try:
    import torch
    # torch.version.cuda may be missing on some builds, hence the getattr fallback.
    print(
        f'TORCH: {torch.__version__} '
        f"| CUDA: {getattr(torch.version, 'cuda', None)} "
        f'| CUDA_AVAILABLE: {torch.cuda.is_available()}'
    )
    if torch.cuda.is_available():
        print(f'CUDA_DEVICE: {torch.cuda.get_device_name(0)}')
except Exception as e:
    # Best-effort diagnostic: report the failure instead of aborting the cell.
    print('Torch import error:', e)

In [None]:
# Config: the tunables a reader is most likely to change
from pathlib import Path

# Relative paths; .resolve() below shows what they expand to from the current working directory.
DATA_PATH = Path('training') / 'gdpr_final_training_dataset.csv'
OUTPUT_DIR = Path('training') / 'models' / 'legalbert_3way'

# Base checkpoint that is fine-tuned.
MODEL_NAME = 'nlpaueb/legal-bert-base-uncased'

# Training hyperparameters.
EPOCHS = 4
BATCH_SIZE = 8
LR = 2e-5
MAX_LENGTH = 256  # tokenizer max_length; longer inputs are truncated
SEED = 42
USE_FP16 = False  # disable to avoid GradScaler error observed

# Create the output directory up front so later cells can write into it.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f'DATA_PATH= {DATA_PATH.resolve()}')
print(f'OUTPUT_DIR= {OUTPUT_DIR.resolve()}')

In [None]:
# Dataset summary: load the CSV, then report its size, columns, and label distribution
import json
import pandas as pd

df = pd.read_csv(DATA_PATH)
print(f'Rows: {len(df)}')
print(f'Columns: {list(df.columns)}')
print(f"Label counts: {df['label'].value_counts().to_dict()}")
# Last expression of the cell: rendered as a table preview of the first three rows.
df.head(3)

In [None]:
# Helpers: split, metrics, plotting
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Fixed class order: a label's position in this list is its integer class id.
CANONICAL_LABELS = ['compliant', 'ambiguous', 'non_compliant']
LABEL2ID = {name: idx for idx, name in enumerate(CANONICAL_LABELS)}
ID2LABEL = dict(enumerate(CANONICAL_LABELS))

def normalize_labels(s):
    """Map a Series of raw label values onto the canonical label names.

    Each value is stringified, stripped, lower-cased, and has '-' and ' '
    replaced by '_'. Canonical names pass through unchanged, a handful of
    known aliases are translated, and every other value falls back to
    'ambiguous'.
    """
    aliases = {
        'compliance': 'compliant',
        'noncompliant': 'non_compliant',
        'non_compliance': 'non_compliant',
        'unclear': 'ambiguous',
        'unknown': 'ambiguous',
    }

    def canonicalize(raw):
        key = str(raw).strip().lower().replace('-', '_').replace(' ', '_')
        if key in CANONICAL_LABELS:
            return key
        # Unrecognised spellings are deliberately folded into 'ambiguous'.
        return aliases.get(key, 'ambiguous')

    return s.map(canonicalize)

# Overwrite the raw label column with canonical names; stratified_split maps this column through LABEL2ID.
df['label'] = normalize_labels(df['label'])

def stratified_split(df, seed=42):
    """Split ``df`` into stratified train/dev/test frames.

    The frame is first cut 80/20 into train and a held-out part, and the
    held-out part is then halved into dev and test. Both cuts stratify on
    the integer class id obtained by mapping the 'label' column through
    LABEL2ID.

    Args:
        df: DataFrame with a 'label' column.
        seed: Random state passed to both shuffle splits.

    Returns:
        Tuple ``(train_df, dev_df, test_df)``; each frame has its index reset.
    """
    # First cut: 80% train, 20% held out.
    outer = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    y = df['label'].map(LABEL2ID)
    train_idx, temp_idx = next(outer.split(df, y))
    train_df = df.iloc[train_idx].reset_index(drop=True)
    temp_df = df.iloc[temp_idx].reset_index(drop=True)

    # Second cut: halve the held-out part into dev and test.
    inner = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=seed)
    y_temp = temp_df['label'].map(LABEL2ID)
    dev_idx, test_idx = next(inner.split(temp_df, y_temp))
    dev_df = temp_df.iloc[dev_idx].reset_index(drop=True)
    test_df = temp_df.iloc[test_idx].reset_index(drop=True)

    return train_df, dev_df, test_df

def plot_confusion(cm, labels):
    """Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: 2-D array of integer counts (cells are formatted with 'd').
        labels: Tick labels applied to both the x and y axes.
    """
    _, ax = plt.subplots(figsize=(4, 3))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_ylabel('True')
    ax.set_xlabel('Pred')
    plt.show()

In [None]:
# Train (if needed) or load existing model
import random
import torch
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          DataCollatorWithPadding, Trainer, TrainingArguments)
from transformers.trainer_utils import set_seed

# Seed via the transformers helper plus Python's and NumPy's RNGs before the split and training run.
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

train_df, dev_df, test_df = stratified_split(df, seed=SEED)

# Tokenizer matching the base checkpoint (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

def encode(texts, max_length):
    """Tokenize ``texts`` with the module-level ``tokenizer``.

    Sequences are truncated to ``max_length``; no padding is applied here
    (padding=False).
    """
    return tokenizer(
        texts,
        padding=False,
        truncation=True,
        max_length=max_length,
    )

# Tokenize the 'text' column of each split once; all three share MAX_LENGTH.
train_enc, dev_enc, test_enc = (
    encode(split['text'].tolist(), MAX_LENGTH)
    for split in (train_df, dev_df, test_df)
)

class EncodedDataset(torch.utils.data.Dataset):
    """Dataset pairing per-example encoded features with per-example labels."""

    def __init__(self, enc, labels):
        # enc: mapping of feature name -> indexable per-example values
        # labels: indexable sequence parallel to the examples in enc
        self.enc = enc
        self.labels = labels

    def __len__(self):
        # The number of examples is taken from the labels.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return every feature of example ``idx`` plus its label under 'labels'."""
        example = {name: self.enc[name][idx] for name in self.enc}
        example['labels'] = self.labels[idx]
        return example

# Integer class ids for each split, derived from the same frames as the encodings.
y_train, y_dev, y_test = (
    split['label'].map(LABEL2ID).tolist()
    for split in (train_df, dev_df, test_df)
)

# Pair each split's encodings with its labels.
train_ds = EncodedDataset(train_enc, y_train)
dev_ds = EncodedDataset(dev_enc, y_dev)
test_ds = EncodedDataset(test_enc, y_test)

def compute_metrics(eval_pred):
    """Compute accuracy and macro-F1 from an evaluation prediction pair.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax over the last axis of ``logits``.

    Returns:
        Dict with float entries 'accuracy' and 'f1_macro'.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # Pin the label set so the report always has one row per canonical class.
    # Without `labels`, classification_report raises when a class is missing
    # from both `labels` and `preds`, because target_names no longer matches
    # the inferred classes. An absent class scores 0 (zero_division=0).
    rep = classification_report(
        labels,
        preds,
        labels=list(range(len(CANONICAL_LABELS))),
        target_names=CANONICAL_LABELS,
        output_dict=True,
        zero_division=0,
    )
    return {
        'accuracy': float(rep.get('accuracy', 0.0)),
        'f1_macro': float(rep['macro avg']['f1-score']),
    }

# Reuse a previously saved model when one exists; otherwise fine-tune from MODEL_NAME.
# Note: the evaluation artifacts (eval_report.json, confusion_matrix.csv) are only
# written on the training path.
if (OUTPUT_DIR / 'model.safetensors').exists():
    print('Model exists. Loading from', OUTPUT_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(str(OUTPUT_DIR))
else:
    print('Training new model...')
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=len(CANONICAL_LABELS), id2label=ID2LABEL, label2id=LABEL2ID
    )
    training_args = TrainingArguments(
        output_dir=str(OUTPUT_DIR / 'checkpoints'),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=max(8, BATCH_SIZE),
        learning_rate=LR,
        eval_strategy='epoch',
        save_strategy='epoch',
        # Keep the checkpoint with the best dev macro-F1 (key returned by compute_metrics).
        load_best_model_at_end=True,
        metric_for_best_model='f1_macro',
        greater_is_better=True,
        # Honour the config-cell switch instead of a hard-coded False; only
        # enable mixed precision when a CUDA device is actually present.
        fp16=bool(USE_FP16) and torch.cuda.is_available(),
        logging_steps=50,
        report_to=[],
        seed=SEED,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        # Examples are tokenized without padding, so pad per batch here.
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.save_model(str(OUTPUT_DIR))
    tokenizer.save_pretrained(str(OUTPUT_DIR))

    # Test-set prediction and minimal artifacts for the evaluation cell.
    preds = trainer.predict(test_ds)
    y_pred = np.argmax(preds.predictions, axis=-1)
    # Use one explicit label set for both the report and the confusion matrix so
    # they agree, and so the report does not fail if a class is absent from the
    # test split and the predictions.
    all_label_ids = list(range(len(CANONICAL_LABELS)))
    rep = classification_report(
        y_test,
        y_pred,
        labels=all_label_ids,
        target_names=CANONICAL_LABELS,
        output_dict=True,
        zero_division=0,
    )
    with open(OUTPUT_DIR / 'eval_report.json', 'w', encoding='utf-8') as f:
        json.dump(rep, f, indent=2)
    cm = confusion_matrix(y_test, y_pred, labels=all_label_ids)
    pd.DataFrame(cm, index=CANONICAL_LABELS, columns=CANONICAL_LABELS).to_csv(OUTPUT_DIR / 'confusion_matrix.csv')

print('Ready.')

In [None]:
# Evaluate (uses saved artifacts if they exist)
from pathlib import Path
import json
from pprint import pprint

import pandas as pd

rep_path = OUTPUT_DIR / 'eval_report.json'
if rep_path.exists():
    # Read with the same encoding the report was written with, and use a
    # context manager so the file handle is closed deterministically.
    with open(rep_path, 'r', encoding='utf-8') as f:
        rep = json.load(f)
    pprint({'accuracy': rep.get('accuracy'), 'macro_f1': rep.get('macro avg', {}).get('f1-score')})
else:
    print('No eval_report.json found. Rerun training cell to create.')

# Show confusion matrix if present
cm_path = OUTPUT_DIR / 'confusion_matrix.csv'
if cm_path.exists():
    # The first CSV column holds the row labels, so use it as the index.
    cm_df = pd.read_csv(cm_path, index_col=0)
    plot_confusion(cm_df.values, cm_df.columns.tolist())
else:
    print('No confusion_matrix.csv found.')

In [None]:
# Inference helper
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, numpy as np

# Prefer the tokenizer saved next to the fine-tuned model; otherwise reuse the in-memory one.
if (OUTPUT_DIR / 'tokenizer.json').exists():
    inf_tokenizer = AutoTokenizer.from_pretrained(str(OUTPUT_DIR))
else:
    inf_tokenizer = tokenizer

# Load the saved classifier and move it to the GPU when one is available.
inf_model = AutoModelForSequenceClassification.from_pretrained(str(OUTPUT_DIR))
inf_model = inf_model.to('cuda' if torch.cuda.is_available() else 'cpu')

def infer(text):
    """Classify a single text with the loaded inference model.

    Args:
        text: Input string; tokenized with truncation at MAX_LENGTH.

    Returns:
        Dict with:
            'label': canonical label with the highest probability,
            'confidence': that label's softmax probability,
            'scores': softmax probability for every canonical label,
            'device': string form of the model's device.
    """
    enc = inf_tokenizer([text], truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(inf_model.device)
    with torch.no_grad():
        logits = inf_model(**enc).logits.squeeze(0).cpu().numpy()
    # Softmax with the maximum subtracted first: mathematically identical, but
    # exp() can no longer overflow to inf (and turn the result into nan) when
    # a logit is large. The exponentials are also computed only once.
    exps = np.exp(logits - logits.max())
    probs = exps / exps.sum()
    idx = int(np.argmax(probs))
    scores = {CANONICAL_LABELS[i]: float(probs[i]) for i in range(len(CANONICAL_LABELS))}
    return {'label': CANONICAL_LABELS[idx], 'confidence': float(probs[idx]), 'scores': scores, 'device': str(inf_model.device)}

# Quick example: classify one sentence; as the cell's last expression, the result dict is displayed.
infer('We collect your data only with explicit consent and provide an easy opt-out at any time.')