In [None]:
#!/usr/bin/env python3
"""
CPU-optimized DistilBERT trainer (pure PyTorch; no HF Trainer/accelerate; no TF/Keras)

Includes:
- WeightedRandomSampler to balance minority classes
- FocalLoss (with class weights) to focus on hard/minority examples
- Optional freezing of lower layers for speed
- Early stopping on macro-F1
- Saves metrics, curves, confusion matrix, model, tokenizer

Run:
    python train_distilbert_cpu_fast.py
"""

import os, sys, json, random, warnings
from pathlib import Path

# ---- Force Transformers torch-only (belt & suspenders) ----
# Ask transformers to run torch-only and skip TensorFlow integration.
os.environ.update({"TRANSFORMERS_NO_TF": "1", "USE_TF": "0"})
# A None entry in sys.modules makes any later `import <name>` fail with
# ImportError, so TF/Keras cannot be pulled in even if they are installed.
sys.modules.update(dict.fromkeys(("tensorflow", "keras", "tf_keras", "tensorflow.keras")))

warnings.filterwarnings("ignore", category=UserWarning, module="tqdm")

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
from torch.optim import AdamW

# --- CPU perf knobs
torch.backends.mkldnn.enabled = True  # oneDNN fast path on CPU
# Intra-op threads: all cores but one (never fewer than 2); assumes 4 cores if unknown.
_NUM_CORES = max(2, (os.cpu_count() or 4) - 1)
torch.set_num_threads(_NUM_CORES)
try:
    # Allowed only once per process and before any parallel work has run;
    # re-executing this code (e.g. re-running a notebook cell) raises
    # RuntimeError, in which case the existing setting is kept.
    torch.set_num_interop_threads(max(1, _NUM_CORES // 2))
except RuntimeError:
    pass

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)

# ---------------- Hard-coded settings ----------------
in_csv = "out/alerts_pseudo.csv"   # <--- change if needed
text_col = "Pseudo_Description"    # column holding the input text
label_col = "Priority_Level"       # column holding the class label (cast to str)
out_dir = Path("out/priority_model")

# Model choice:
#   "distilbert-base-uncased" -> best balance
#   "prajjwal1/bert-tiny"     -> much faster, lower accuracy
model_name = "distilbert-base-uncased"

# Speed/quality tradeoffs (good CPU defaults)
# max_len = max tokens kept per text. It must not exceed the model's position
# limit (512 for DistilBERT / BERT), otherwise texts longer than that fail in
# the forward pass. Smaller is faster; raise it (<= 512) only if many texts
# are being truncated.
max_len = 128
batch_size = 32          # drop to 24/16 if RAM tight
epochs = 3               # maximum number of epochs; raise it for early stopping to matter
early_stop_patience = 2  # early-stopping patience (epochs) on val macro-F1
lr = 3e-5
weight_decay = 0.01      # AdamW weight decay
warmup_ratio = 0.06      # fraction of total steps used for linear LR warmup
grad_clip = 1.0          # max global gradient norm
seed = 42

# Freeze lower layers for speed (DistilBERT has 6 layers total)
FREEZE_EMBEDDINGS = True
FREEZE_FIRST_N_LAYERS = 3   # 0 = full finetune (best quality, slower); 2~3 = faster
# -----------------------------------------------------

def set_seed(s=42):
    """Seed Python `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `s`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)

def clean_text(s: str) -> str:
    """Stringify `s` and collapse every whitespace run to one space (trimmed).

    Missing values (None / NaN, per `pd.isna`) become the empty string.
    """
    if pd.isna(s):
        return ""
    return " ".join(str(s).split())

class TextClsDataset(Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer class labels.

    `encodings` maps field name -> per-example sequences (e.g. tokenizer output);
    `labels` is indexable by example position. Each item is a dict of tensors
    with the encoding fields plus a `labels` int64 tensor.
    """

    def __init__(self, encodings, labels):
        self.enc = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field_name, column in self.enc.items():
            sample[field_name] = torch.tensor(column[idx])
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

def compute_class_weights(y: np.ndarray, num_classes: int) -> torch.Tensor:
    """Return 'balanced' per-class weights n_samples / (n_classes * count_c) as float32.

    A class absent from `y` is given a count of 1 so the division never hits zero.
    """
    freq = np.maximum(np.bincount(y, minlength=num_classes), 1).astype(np.float32)
    balanced = len(y) / (num_classes * freq)
    return torch.tensor(balanced.astype(np.float32), dtype=torch.float32)

def split_stratified(texts, y, seed=42):
    """Stratified 70/15/15 split: 30% is held out, then halved into val and test.

    Returns (X_train, X_val, X_test, y_train, y_val, y_test).
    """
    from sklearn.model_selection import train_test_split

    def _stratified_cut(features, targets, fraction):
        # Both cuts stratify on the targets being split and share the same seed.
        return train_test_split(
            features, targets, test_size=fraction, random_state=seed, stratify=targets
        )

    X_train, X_holdout, y_train, y_holdout = _stratified_cut(texts, y, 0.30)
    X_val, X_test, y_val, y_test = _stratified_cut(X_holdout, y_holdout, 0.50)
    return X_train, X_val, X_test, y_train, y_val, y_test

def accuracy(y_true, y_pred):
    """Fraction of positions where `y_pred` matches `y_true`; 0.0 when empty."""
    if len(y_true) == 0:
        return 0.0
    hits = y_true == y_pred
    return float(hits.mean())

def per_class_metrics(y_true, y_pred, n_classes):
    """Compute precision / recall / F1 / support per class plus their unweighted (macro) mean.

    Classes are 0..n_classes-1; any undefined ratio (zero denominator) is reported as 0.0.

    Returns:
        (per_class, macro): per_class maps class id -> {"precision", "recall", "f1", "support"};
        macro maps "precision" / "recall" / "f1" -> mean over classes (0.0 if n_classes == 0).
    """
    def _safe_ratio(num, den):
        return num / den if den > 0 else 0.0

    per_class = {}
    for cls in range(n_classes):
        is_true = y_true == cls
        is_pred = y_pred == cls
        tp = int((is_true & is_pred).sum())
        fp = int((~is_true & is_pred).sum())
        fn = int((is_true & ~is_pred).sum())
        precision = _safe_ratio(tp, tp + fp)
        recall = _safe_ratio(tp, tp + fn)
        per_class[cls] = {
            "precision": precision,
            "recall": recall,
            "f1": _safe_ratio(2 * precision * recall, precision + recall),
            "support": int(is_true.sum()),
        }

    summary_keys = ("precision", "recall", "f1")
    if not per_class:
        macro = {key: 0.0 for key in summary_keys}
    else:
        macro = {
            key: float(np.mean([stats[key] for stats in per_class.values()]))
            for key in summary_keys
        }
    return per_class, macro

def confusion_matrix_counts(y_true, y_pred, n_classes):
    """Count (true, predicted) label pairs into an n_classes x n_classes integer matrix.

    Row index = true label, column index = predicted label; labels are converted with int().
    """
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    pair_tally = {}
    for truth, guess in zip(y_true, y_pred):
        key = (int(truth), int(guess))
        pair_tally[key] = pair_tally.get(key, 0) + 1
    for (row, col), count in pair_tally.items():
        matrix[row, col] += count
    return matrix

class FocalLoss(torch.nn.Module):
    """Multi-class focal loss: alpha[y] * (1 - p_y)^gamma * (-log p_y).

    ``p_y`` is the softmax probability of the target class. With ``gamma=0`` this
    is (optionally class-weighted) cross-entropy. The modulating factor stays in
    the autograd graph, as in the focal-loss definition, so it also shapes the
    gradient (not just the forward value).

    Args:
        alpha: optional per-class weight tensor of shape [C] (same role as
            ``weight`` in ``F.cross_entropy``), or None for no class weighting.
        gamma: focusing parameter; larger values down-weight easy examples more.
        reduction: "mean" (average over the batch), "sum", or "none"
            (per-example losses of shape [N]).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.

    Note: "mean" divides by the batch size, not by the summed alpha weights as
    ``F.cross_entropy(weight=..., reduction="mean")`` does.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha  # tensor [C] or None
        self.gamma = float(gamma)
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the loss for ``logits`` [N, C] and int64 class indices ``targets`` [N]."""
        log_probs = F.log_softmax(logits, dim=1)
        # log p of each example's target class, shape [N]; gather keeps everything on logits' device.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce = -log_pt
        if self.alpha is not None:
            ce = ce * self.alpha.to(device=logits.device, dtype=logits.dtype)[targets]
        # Clamp keeps the pow's gradient finite when p_y -> 1 (matters for gamma < 1).
        modulator = (1.0 - log_pt.exp()).clamp_min(1e-8).pow(self.gamma)
        loss = modulator * ce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

def freeze_for_speed(model, freeze_embeddings=None, freeze_first_n_layers=None):
    """Freeze the embeddings and the lowest N transformer blocks so fewer weights are trained.

    Supports the DistilBERT layout (``model.distilbert`` with ``transformer.layer``)
    and BERT-style backbones reached via ``model.base_model`` with ``encoder.layer``
    (e.g. "prajjwal1/bert-tiny"). Frozen parameters get ``requires_grad=False``;
    modules outside the backbone (e.g. the classification head) stay trainable.
    If an expected submodule is missing, a message is printed and that step is
    skipped instead of silently doing nothing.

    Args:
        model: sequence-classification model, modified in place.
        freeze_embeddings: freeze the embedding module; None -> module-level
            ``FREEZE_EMBEDDINGS``.
        freeze_first_n_layers: number of lowest blocks to freeze (negative is
            treated as 0); None -> module-level ``FREEZE_FIRST_N_LAYERS``.

    Returns:
        The same ``model`` object.
    """
    if freeze_embeddings is None:
        freeze_embeddings = FREEZE_EMBEDDINGS
    if freeze_first_n_layers is None:
        freeze_first_n_layers = FREEZE_FIRST_N_LAYERS
    n_freeze = max(0, int(freeze_first_n_layers))

    # DistilBERT names its backbone `distilbert`; Hugging Face models also expose `base_model`.
    backbone = getattr(model, "distilbert", None)
    if backbone is None:
        backbone = getattr(model, "base_model", None)
    if backbone is None:
        print("freeze_for_speed: no backbone found on model; nothing frozen.")
        return model

    if freeze_embeddings:
        embeddings = getattr(backbone, "embeddings", None)
        if embeddings is None:
            print("freeze_for_speed: backbone has no `embeddings`; embeddings not frozen.")
        else:
            embeddings.requires_grad_(False)

    if n_freeze > 0:
        blocks = None
        for container_name in ("transformer", "encoder"):  # DistilBERT / BERT naming
            container = getattr(backbone, container_name, None)
            if container is not None and hasattr(container, "layer"):
                blocks = container.layer
                break
        if blocks is None:
            print("freeze_for_speed: no transformer blocks found; layers not frozen.")
        else:
            for block in list(blocks)[:n_freeze]:
                block.requires_grad_(False)
    return model

def _make_loader(dataset, collator, num_workers, sampler=None):
    """Build a CPU DataLoader with dynamic padding; iterates in order unless `sampler` is given."""
    return DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, shuffle=False,
        collate_fn=collator, num_workers=num_workers,
        pin_memory=False, persistent_workers=False,
    )


def _predict(model, loader, device):
    """Run `model` over `loader` without gradients; return (predicted ids, true ids) as 1-D arrays."""
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            preds.append(torch.argmax(logits, dim=1).cpu().numpy())
            trues.append(batch["labels"].cpu().numpy())
    if not preds:  # empty split: np.concatenate([]) would raise ValueError
        return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64)
    return np.concatenate(preds), np.concatenate(trues)


def _save_plots(history, cm, label_names):
    """Write curves.png (per-epoch train loss / val metrics) and confusion_matrix.png to out_dir."""
    if history["epoch"]:
        fig = plt.figure(figsize=(8, 5))
        plt.plot(history["epoch"], history["train_loss"], label="train_loss")
        plt.plot(history["epoch"], history["val_macro_f1"], label="val_macro_f1")
        plt.plot(history["epoch"], history["val_accuracy"], label="val_accuracy")
        plt.xlabel("epoch")
        plt.ylabel("value")
        plt.title("Training curves (CPU)")
        plt.legend()
        plt.tight_layout()
        fig.savefig(out_dir / "curves.png", dpi=160)
        plt.close(fig)

    num_classes = len(label_names)
    # Row-normalize; the clip avoids 0/0 for classes absent from the test split.
    cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True).clip(min=1.0)
    fig = plt.figure(figsize=(6 + 0.3 * num_classes, 5 + 0.3 * num_classes))
    plt.imshow(cm_norm, aspect="auto")
    ticks = np.arange(num_classes)
    plt.xticks(ticks, label_names, rotation=45, ha="right")
    plt.yticks(ticks, label_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix (row-normalized)")
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            plt.text(j, i, f"{cm_norm[i, j]:.2f}", ha="center", va="center")
    plt.tight_layout()
    fig.savefig(out_dir / "confusion_matrix.png", dpi=160)
    plt.close(fig)


def main():
    """Train on `in_csv`, early-stop on validation macro-F1, evaluate on test, save artifacts.

    Steps: clean + deduplicate the data, encode labels, do a stratified 70/15/15
    split, fine-tune `model_name` on CPU with a class-balanced sampler and
    class-weighted focal loss, restore the best-validation weights, then write
    the label map, test reports, plots, model, and tokenizer to `out_dir`.

    Raises:
        SystemExit: if `text_col` or `label_col` is missing from the CSV.
    """
    set_seed(seed)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Data
    df = pd.read_csv(in_csv)
    missing = {text_col, label_col} - set(df.columns)
    if missing:
        raise SystemExit(f"Missing columns: {missing}")

    df = df[[text_col, label_col]].dropna().copy()
    df[text_col] = df[text_col].apply(clean_text)
    # Deduplicate AFTER cleaning so whitespace-only variants of the same text
    # cannot land in different splits (train/test leakage).
    df = df[df[text_col].str.len() > 0].drop_duplicates().reset_index(drop=True)

    # Label encoding (deterministic: sorted string labels -> ids 0..C-1)
    labels_raw = df[label_col].astype(str).values
    classes_sorted = sorted(np.unique(labels_raw).tolist())
    label2id = {lbl: i for i, lbl in enumerate(classes_sorted)}
    id2label = {i: lbl for lbl, i in label2id.items()}
    y = np.array([label2id[s] for s in labels_raw], dtype=np.int64)
    num_classes = len(classes_sorted)
    label_names = [id2label[i] for i in range(num_classes)]

    # Written up front so the mapping exists even if training is interrupted.
    (out_dir / "label_map.json").write_text(
        json.dumps({"label2id": label2id, "id2label": {int(k): v for k, v in id2label.items()}}, indent=2),
        encoding="utf-8"
    )

    # Split
    X_train, X_val, X_test, y_train, y_val, y_test = split_stratified(df[text_col].tolist(), y, seed=seed)

    # Tokenize. Inputs longer than the model's position limit (reported by the
    # tokenizer as model_max_length, 512 for DistilBERT) fail in the forward
    # pass, so never truncate to more than that.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    seq_len = min(int(max_len), int(tokenizer.model_max_length))
    if seq_len < max_len:
        print(f"max_len={max_len} exceeds the tokenizer limit; using {seq_len}")
    enc_train = tokenizer(X_train, truncation=True, padding=False, max_length=seq_len)
    enc_val   = tokenizer(X_val,   truncation=True, padding=False, max_length=seq_len)
    enc_test  = tokenizer(X_test,  truncation=True, padding=False, max_length=seq_len)

    train_ds = TextClsDataset(enc_train, y_train)
    val_ds   = TextClsDataset(enc_val, y_val)
    test_ds  = TextClsDataset(enc_test, y_test)

    # DataLoaders (padding is done per batch by the collator)
    collator = DataCollatorWithPadding(tokenizer)
    # Sample train examples with probability proportional to 1 / (their class count).
    class_counts = np.bincount(y_train, minlength=num_classes).astype(np.float64)
    class_counts[class_counts == 0] = 1.0
    sampler = WeightedRandomSampler(
        weights=torch.tensor(1.0 / class_counts[y_train], dtype=torch.double),
        num_samples=len(y_train),
        replacement=True,
    )
    num_workers = 0 if os.name == "nt" else 2
    train_loader = _make_loader(train_ds, collator, num_workers, sampler=sampler)
    val_loader = _make_loader(val_ds, collator, num_workers)
    test_loader = _make_loader(test_ds, collator, num_workers)

    print(f"CPU threads={torch.get_num_threads()} interop={torch.get_num_interop_threads()}  "
          f"bsize={batch_size}  max_len={seq_len}  workers={num_workers}")

    # ---- Model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_classes, id2label=id2label, label2id=label2id
    )
    model = freeze_for_speed(model)
    device = torch.device("cpu")
    model.to(device)

    # Optimizer & scheduler over trainable params only. No weight decay on biases
    # or LayerNorm weights; matched case-insensitively because DistilBERT names
    # them `sa_layer_norm` / `output_layer_norm`, BERT names them `LayerNorm`.
    no_decay = ("bias", "layernorm", "layer_norm")
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    grouped = [
        {"params": [p for n, p in trainable if not any(nd in n.lower() for nd in no_decay)],
         "weight_decay": weight_decay},
        {"params": [p for n, p in trainable if any(nd in n.lower() for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    trainable_params = [p for _, p in trainable]  # hoisted: reused for clipping every step
    optimizer = AdamW(grouped, lr=lr)
    total_steps = epochs * max(1, len(train_loader))
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    # Class weights and FocalLoss.
    # NOTE(review): the sampler above already draws classes roughly uniformly;
    # these alpha weights up-weight minority classes a second time. Confirm the
    # double correction is intended.
    class_weights = compute_class_weights(y_train, num_classes=num_classes).to(device)
    focal = FocalLoss(alpha=class_weights, gamma=2.0)

    # ---- Train loop with early stopping on val macro-F1
    history = {"epoch": [], "train_loss": [], "val_macro_f1": [], "val_accuracy": []}
    best_f1, best_state = -1.0, None
    bad_epochs = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            loss = focal(logits, batch["labels"])

            loss.backward()
            clip_grad_norm_(trainable_params, grad_clip)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / max(1, len(train_loader))

        # ---- Validation (same metric code as the test report)
        val_preds, val_true = _predict(model, val_loader, device)
        _, val_macro = per_class_metrics(val_true, val_preds, n_classes=num_classes)
        val_macro_f1 = val_macro["f1"]
        val_acc = accuracy(val_true, val_preds)

        history["epoch"].append(epoch)
        history["train_loss"].append(avg_train_loss)
        history["val_macro_f1"].append(val_macro_f1)
        history["val_accuracy"].append(val_acc)

        print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f}  val_macro_f1={val_macro_f1:.4f}  val_acc={val_acc:.4f}")

        # Early stopping: stop once `early_stop_patience` consecutive epochs fail to improve.
        if val_macro_f1 > best_f1 + 1e-6:
            best_f1 = val_macro_f1
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= early_stop_patience:
                print("Early stopping.")
                break

    # Restore the best-validation weights
    if best_state is not None:
        model.load_state_dict(best_state)

    # ---- Test
    test_preds, test_true = _predict(model, test_loader, device)
    per_class, macro = per_class_metrics(test_true, test_preds, n_classes=num_classes)
    acc = accuracy(test_true, test_preds)
    cm = confusion_matrix_counts(test_true, test_preds, n_classes=num_classes)

    # ---- Save reports
    rep = {
        "accuracy": acc,
        "macro": macro,
        "per_class": {id2label[i]: per_class[i] for i in range(num_classes)}
    }
    (out_dir / "test_classification_report.json").write_text(json.dumps(rep, indent=2), encoding="utf-8")
    pd.DataFrame(cm, index=label_names, columns=label_names).to_csv(
        out_dir / "test_confusion_matrix.csv", index=True
    )

    summary = {
        "test_accuracy": float(acc),
        "test_macro_f1": float(macro["f1"]),
        "n_train": int(len(train_ds)), "n_val": int(len(val_ds)), "n_test": int(len(test_ds)),
        "labels": label_names
    }
    (out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))

    # ---- Graphs
    _save_plots(history, cm, label_names)

    # ---- Save model & tokenizer
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)

if __name__ == "__main__":
    # Entry point: run the full load -> train -> evaluate -> save pipeline.
    main()


CPU threads=3 interop=1  bsize=32  max_len=1296  workers=0


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1: train_loss=0.6728  val_macro_f1=0.2704  val_acc=0.2247


KeyboardInterrupt: 