In [1]:
RAW_CONFIG_NAME = "raw"
SIMPLIFIED_CONFIG_NAME = "simplified"

BASE_MODELS = [
    {"name": "nreimers/MiniLM-L6-H384-uncased", "hpo": True, "hpo_mode": RAW_CONFIG_NAME, "train": False, "train_mode": RAW_CONFIG_NAME},
    {"name": "google/electra-small-discriminator", "hpo": False, "hpo_mode": RAW_CONFIG_NAME, "train": False, "train_mode": RAW_CONFIG_NAME},
    {"name": "roberta-base", "hpo": False, "hpo_mode": RAW_CONFIG_NAME, "train": False, "train_mode": RAW_CONFIG_NAME},
]

RUN_QUANTIZATION = False    # <<< CHANGE TO TRUE when all training has been done after all HPO has been done


In [2]:
# ====================================================
# RUN MODE: MINI vs FULL
# ====================================================

SMOKE_TEST = False    # True = tiny debugging subset, False = full HPO
MINI_TRAIN_SIZE = 50
MINI_VALID_SIZE = 20

In [3]:
# ====================================================
# 0. Environment: install, import, Drive mount, config
# ====================================================

# ---- Clean install of HF stack (no Python switching needed) ----
!pip uninstall -y tokenizers transformers datasets evaluate accelerate
!pip cache purge

!pip install -U transformers datasets tokenizers evaluate accelerate

# ---- Standard imports ----
import os
import random
import json
from typing import Dict, Any, List, Tuple

import numpy as np
import torch

# ---- Mount Google Drive ----
from google.colab import drive
drive.mount('/content/drive')

# ---- HuggingFace imports ----
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, DatasetDict, load_from_disk, Dataset
import evaluate
from torch import nn

# ====================================================
# Directory structure
# ====================================================

PROJECT_ROOT = "/content/drive/MyDrive/emotion_project"

if SMOKE_TEST:
    MODE_ROOT = os.path.join(PROJECT_ROOT, "mini")
    RUN_MODE_NAME = "MINI"
else:
    MODE_ROOT = os.path.join(PROJECT_ROOT, "full")
    RUN_MODE_NAME = "FULL"

DATA_CACHE_DIR = os.path.join(MODE_ROOT, "cache")
TOKENIZED_DIR  = os.path.join(MODE_ROOT, "tokenized")
TRIALS_DIR     = os.path.join(MODE_ROOT, "trials")
FINAL_DIR      = os.path.join(MODE_ROOT, "final_models")
LOGS_DIR       = os.path.join(MODE_ROOT, "logs")

for d in [PROJECT_ROOT, MODE_ROOT, DATA_CACHE_DIR, TOKENIZED_DIR, TRIALS_DIR, FINAL_DIR, LOGS_DIR]:
    os.makedirs(d, exist_ok=True)

print(f"RUN_MODE = {RUN_MODE_NAME}")
print("MODE_ROOT:", MODE_ROOT)


# ====================================================
# Model and HPO configs
# ====================================================

MAX_LENGTH = 64
BATCH_SIZE = 8
NUM_EPOCHS = 3
SEED = 42

# ---- Manual HPO grid ----
HPO_GRID = [
    {
        "learning_rate": 2e-5,
        "warmup_ratio": 0.05,
        "weight_decay": 0.01,
        "scheduler": "linear",
        "batch_size": 16,
        "max_length": 64,
        "optimizer": "adamw_torch",
    },
    {
        "learning_rate": 3e-5,
        "warmup_ratio": 0.1,
        "weight_decay": 0.01,
        "scheduler": "cosine",
        "batch_size": 8,
        "max_length": 128,
        "optimizer": "adamw_torch",
    },
    {
        "learning_rate": 5e-5,
        "warmup_ratio": 0.0,
        "weight_decay": 0.0,
        "scheduler": "linear",
        "batch_size": 16,
        "max_length": 64,
        "optimizer": "adafactor",
    },
]

# ---- Reproducibility ----
def set_global_seed(seed: int = SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Found existing installation: tokenizers 0.22.1
Uninstalling tokenizers-0.22.1:
  Successfully uninstalled tokenizers-0.22.1
Found existing installation: transformers 4.57.1
Uninstalling transformers-4.57.1:
  Successfully uninstalled transformers-4.57.1
Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
[0mFound existing installation: accelerate 1.11.0
Uninstalling accelerate-1.11.0:
  Successfully uninstalled accelerate-1.11.0
[0mFiles removed: 0
Collecting transformers
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting tokenizers
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting evaluate
  Downloading evaluate-0.4.6-py3

In [4]:
# ====================================================
# 1. GoEmotions: loading, label-info, RAW → multi-label
# ====================================================

from functools import lru_cache
from itertools import chain


# ====================================================
# RAW label-info (optimized)
# ====================================================
def get_label_info_raw(ds: DatasetDict) -> Dict[str, Any]:
    """
    RAW config: 36 independent binary label columns.
    We compute statistics using NumPy instead of Python loops for speed.
    """
    train_ds = ds["train"]

    META_COLS = {
        "text", "id", "author", "subreddit",
        "link_id", "parent_id", "created_utc",
        "rater_id", "example_very_unclear"
    }
    all_columns = train_ds.column_names
    label_names = [col for col in all_columns if col not in META_COLS]

    num_labels = len(label_names)

    # ---- FAST LABEL COUNTS ----
    # Stack each binary column into a 2D NumPy array and sum along axis 0.
    arr = np.stack([train_ds[col] for col in label_names], axis=1).astype(np.float32)
    counts = arr.sum(axis=0)

    total = counts.sum()
    neg_counts = total - counts
    pos_weight = neg_counts / (counts + 1e-6)

    id2label = {i: name for i, name in enumerate(label_names)}
    label2id = {name: i for i, name in id2label.items()}

    return {
        "mode": "raw",
        "num_labels": num_labels,
        "id2label": id2label,
        "label2id": label2id,
        "pos_weight": torch.tensor(pos_weight, dtype=torch.float32),
        "class_weights": None,
    }


# ====================================================
# Simplified label-info (optimized)
# ====================================================
def get_label_info_simplified(ds: DatasetDict) -> Dict[str, Any]:
    train_ds = ds["train"]
    features = train_ds.features

    labels_feature = features["labels"]

    # Ensure the feature is a Sequence(ClassLabel)
    if hasattr(labels_feature, "feature") and hasattr(labels_feature.feature, "names"):
        label_names = labels_feature.feature.names
    else:
        raise ValueError("Simplified 'labels' must be Sequence(ClassLabel).")

    num_labels = len(label_names)

    # ---- FAST LABEL COUNTS ----
    counts = np.zeros(num_labels, dtype=np.float32)
    for lab in chain.from_iterable(train_ds["labels"]):
        counts[lab] += 1

    total = counts.sum()
    pos_weight = (total - counts) / (counts + 1e-6)

    id2label = {i: n for i, n in enumerate(label_names)}
    label2id = {n: i for i, n in id2label.items()}

    return {
        "mode": SIMPLIFIED_CONFIG_NAME,
        "num_labels": num_labels,
        "id2label": id2label,
        "label2id": label2id,
        "pos_weight": torch.tensor(pos_weight, dtype=torch.float32),
        "class_weights": None,
    }


# ====================================================
# RAW → multi-label conversion
# ====================================================
def convert_raw_to_multilabel(ds: DatasetDict, label_names: List[str]) -> DatasetDict:
    """
    Converts 36 independent binary label columns into
    a single 'labels' multi-hot list of ints.
    """
    def build_labels(example):
        return {"labels": [int(example[lbl]) for lbl in label_names]}

    converted = ds.map(build_labels)
    converted = converted.remove_columns(label_names)
    return converted


# ====================================================
# Load datasets + compute label infos (cached)
# ====================================================
@lru_cache()
def get_datasets_and_label_infos() -> Tuple[Dict[str, DatasetDict], Dict[str, Dict[str, Any]]]:
    """
    Loads GoEmotions datasets (RAW + simplified),
    applies mini-mode slicing if enabled,
    computes label statistics,
    and returns dataset + info dicts.
    """
    set_global_seed(SEED)

    print("Loading GoEmotions RAW config...")
    raw_full = load_dataset("go_emotions", RAW_CONFIG_NAME)

    # Split RAW train into train / validation
    raw_train_valid = raw_full["train"].train_test_split(test_size=0.1, seed=SEED)
    raw_train_valid["validation"] = raw_train_valid.pop("test")

    print("Loading GoEmotions SIMPLIFIED config...")
    simplified_ds = load_dataset("go_emotions", SIMPLIFIED_CONFIG_NAME)

    # ---- MINI MODE slicing ----
    if SMOKE_TEST:
        print("SMOKE_TEST ACTIVE — dataset slicing enabled")
        raw_train_valid["train"] = raw_train_valid["train"].select(range(MINI_TRAIN_SIZE))
        raw_train_valid["validation"] = raw_train_valid["validation"].select(range(MINI_VALID_SIZE))
        simplified_ds["train"] = simplified_ds["train"].select(range(MINI_TRAIN_SIZE))
        simplified_ds["validation"] = simplified_ds["validation"].select(range(MINI_VALID_SIZE))
    # ----------------------------

    print("Computing RAW label-info...")
    raw_info = get_label_info_raw(raw_train_valid)

    print("Computing SIMPLIFIED label-info...")
    simplified_info = get_label_info_simplified(simplified_ds)

    # Convert RAW → multi-label
    label_names_raw = list(raw_info["id2label"].values())
    print("Converting RAW → multi-label format...")
    raw_ml = convert_raw_to_multilabel(raw_train_valid, label_names_raw)

    datasets = {
        RAW_CONFIG_NAME: raw_ml,
        SIMPLIFIED_CONFIG_NAME: simplified_ds,
    }
    infos = {
        RAW_CONFIG_NAME: raw_info,
        SIMPLIFIED_CONFIG_NAME: simplified_info,
    }
    return datasets, infos


# ====================================================
# Test run (safe)
# ====================================================
datasets_dict, infos_dict = get_datasets_and_label_infos()

print(datasets_dict[RAW_CONFIG_NAME])
print(datasets_dict[SIMPLIFIED_CONFIG_NAME])
print("RAW num_labels:", infos_dict[RAW_CONFIG_NAME]["num_labels"])
print("SIMPLIFIED num_labels:", infos_dict[SIMPLIFIED_CONFIG_NAME]["num_labels"])


Loading GoEmotions RAW config...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

raw/train-00000-of-00001.parquet:   0%|          | 0.00/24.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/211225 [00:00<?, ? examples/s]

Loading GoEmotions SIMPLIFIED config...


simplified/train-00000-of-00001.parquet:   0%|          | 0.00/2.77M [00:00<?, ?B/s]

simplified/validation-00000-of-00001.par(…):   0%|          | 0.00/350k [00:00<?, ?B/s]

simplified/test-00000-of-00001.parquet:   0%|          | 0.00/347k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/43410 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5426 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5427 [00:00<?, ? examples/s]

Computing RAW label-info...
Computing SIMPLIFIED label-info...
Converting RAW → multi-label format...


Map:   0%|          | 0/190102 [00:00<?, ? examples/s]

Map:   0%|          | 0/21123 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear', 'labels'],
        num_rows: 190102
    })
    validation: Dataset({
        features: ['text', 'id', 'author', 'subreddit', 'link_id', 'parent_id', 'created_utc', 'rater_id', 'example_very_unclear', 'labels'],
        num_rows: 21123
    })
})
DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 43410
    })
    validation: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5426
    })
    test: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5427
    })
})
RAW num_labels: 28
SIMPLIFIED num_labels: 28


In [5]:
# ====================================================
# 2. Tokenization + caching (tokenized dataset saved to Drive)
# ====================================================

from transformers import AutoTokenizer


# ----------------------------------------------------
# Tokenizer loader
# ----------------------------------------------------
def prepare_tokenizer(model_name: str) -> AutoTokenizer:
    return AutoTokenizer.from_pretrained(model_name, use_fast=True)


# ----------------------------------------------------
# RAW preprocessing (multi-label)
# ----------------------------------------------------
def preprocess_function_raw(examples, tokenizer, num_labels: int, max_len: int):
    texts = examples["text"]

    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_len,
    )

    # examples["labels"] is already multi-hot lists
    multi_hot = np.array(examples["labels"], dtype=np.float32)

    if multi_hot.shape[1] != num_labels:
        raise ValueError(f"Expected {num_labels} labels, got {multi_hot.shape[1]}")

    encodings["labels"] = multi_hot.tolist()
    return encodings


# ----------------------------------------------------
# Simplified preprocessing (single-label)
# ----------------------------------------------------
def preprocess_function_simplified(examples, tokenizer, max_len: int):
    texts = examples["text"]

    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_len,
    )

    # Simplified: labels = [[13], [7], ...]
    encodings["labels"] = [lbl_list[0] for lbl_list in examples["labels"]]

    return encodings


# ----------------------------------------------------
# Tokenize dataset for a given model
# ----------------------------------------------------
def tokenize_dataset_for_model(
    model_name: str,
    ds: DatasetDict,
    mode: str,            # "raw" or "simplified"
    label_info: Dict[str, Any],
    max_len: int = MAX_LENGTH,
) -> Tuple[DatasetDict, AutoTokenizer]:

    tokenizer = prepare_tokenizer(model_name)

    if mode == RAW_CONFIG_NAME:
        fn = lambda ex: preprocess_function_raw(
            ex, tokenizer, label_info["num_labels"], max_len
        )
    else:
        fn = lambda ex: preprocess_function_simplified(
            ex, tokenizer, max_len
        )

    # Remove original columns to keep dataset compact
    tokenized = ds.map(
        fn,
        batched=True,
        remove_columns=ds["train"].column_names,
        desc=f"Tokenizing for model={model_name}, mode={mode}",
    )

    tokenized.set_format(type="torch")
    return tokenized, tokenizer


# ----------------------------------------------------
# Cached tokenized dataset loader
# ----------------------------------------------------
def get_tokenized_dataset_cached(
    base_model_name: str,
    mode: str,
    max_len: int = MAX_LENGTH,
) -> Tuple[DatasetDict, AutoTokenizer, Dict[str, Any]]:

    datasets_dict, infos_dict = get_datasets_and_label_infos()
    ds = datasets_dict[mode]
    label_info = infos_dict[mode]

    model_key = base_model_name.replace("/", "_")
    cache_dir = os.path.join(TOKENIZED_DIR, f"{model_key}_{mode}_len{max_len}")

    # ---- Load from cache ----
    if os.path.exists(cache_dir):
        print(f"[CACHE] Loading tokenized dataset from {cache_dir}")
        tokenized = load_from_disk(cache_dir)
        tokenizer = prepare_tokenizer(base_model_name)
        return tokenized, tokenizer, label_info

    # ---- Create new tokenized version ----
    print(f"[TOKENIZE] Creating tokenized dataset for model={base_model_name} | mode={mode}")
    tokenized, tokenizer = tokenize_dataset_for_model(
        model_name=base_model_name,
        ds=ds,
        mode=mode,
        label_info=label_info,
        max_len=max_len,
    )

    os.makedirs(cache_dir, exist_ok=True)
    tokenized.save_to_disk(cache_dir)
    print(f"[TOKENIZE] Saved tokenized dataset to {cache_dir}")

    return tokenized, tokenizer, label_info


# ----------------------------------------------------
# Quick sanity check
# ----------------------------------------------------
test_tokenized, test_tok, _ = get_tokenized_dataset_cached(
    "distilbert-base-uncased",
    RAW_CONFIG_NAME,
    max_len=MAX_LENGTH,
)
print(test_tokenized)


[TOKENIZE] Creating tokenized dataset for model=distilbert-base-uncased | mode=raw


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Tokenizing for model=distilbert-base-uncased, mode=raw:   0%|          | 0/190102 [00:00<?, ? examples/s]

Tokenizing for model=distilbert-base-uncased, mode=raw:   0%|          | 0/21123 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/190102 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21123 [00:00<?, ? examples/s]

[TOKENIZE] Saved tokenized dataset to /content/drive/MyDrive/emotion_project/full/tokenized/distilbert-base-uncased_raw_len64
DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 190102
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 21123
    })
})


In [6]:
# ====================================================
# 3. Plutchik-distance metrics + WeightedTrainer
# ====================================================

# Simple mapping: GoEmotions label → Plutchik primary category
RAW_LABEL_TO_PLUTCHIK = {
    "admiration": "trust",
    "amusement": "joy",
    "anger": "anger",
    "annoyance": "anger",
    "approval": "trust",
    "caring": "trust",
    "confusion": "surprise",
    "curiosity": "anticipation",
    "desire": "anticipation",
    "disappointment": "sadness",
    "disapproval": "disgust",
    "disgust": "disgust",
    "embarrassment": "sadness",
    "excitement": "joy",
    "fear": "fear",
    "gratitude": "joy",
    "grief": "sadness",
    "joy": "joy",
    "love": "joy",
    "nervousness": "fear",
    "optimism": "anticipation",
    "pride": "joy",
    "realization": "surprise",
    "relief": "joy",
    "remorse": "sadness",
    "sadness": "sadness",
    "surprise": "surprise",
    "neutral": "neutral",
}


def plutchik_distance(primary_a: str, primary_b: str) -> float:
    """
    Very simple distance on Plutchik primary categories:
    - 0.0 if same category
    - 0.5 if one of them is 'neutral'
    - 1.0 otherwise
    """
    if primary_a == primary_b:
        return 0.0
    if "neutral" in (primary_a, primary_b):
        return 0.5
    return 1.0


def compute_plutchik_metrics_multilabel(
    logits: np.ndarray,
    labels: np.ndarray,
    id2label: Dict[int, str],
    threshold: float = 0.5,
) -> Dict[str, float]:
    """
    Multi-label evaluation:
    - standard micro F1
    - Plutchik "soft" F1 that gives partial credit when
      predicted emotions fall in the same Plutchik primary category.
    """
    sigmoid = 1.0 / (1.0 + np.exp(-logits))
    preds = (sigmoid >= threshold).astype(int)

    eps = 1e-8

    # Standard micro F1
    tp = np.sum((preds == 1) & (labels == 1))
    fp = np.sum((preds == 1) & (labels == 0))
    fn = np.sum((preds == 0) & (labels == 1))

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1_micro = 2 * prec * rec / (prec + rec + eps)

    # Precompute primary category for each label id
    num_labels = len(id2label)
    id_to_primary = [
        RAW_LABEL_TO_PLUTCHIK.get(id2label[i], "neutral") for i in range(num_labels)
    ]

    soft_tp = 0.0
    soft_fp = 0.0
    soft_fn = 0.0

    # Plutchik-aware soft counts
    for i in range(labels.shape[0]):
        true_ids = np.where(labels[i] == 1)[0].tolist()
        pred_ids = np.where(preds[i] == 1)[0].tolist()

        true_set = set(true_ids)
        pred_set = set(pred_ids)

        for pid in pred_ids:
            if pid in true_set:
                soft_tp += 1.0
            else:
                p_primary = id_to_primary[pid]
                # Check if there is a true label with the same primary category
                if any(
                    plutchik_distance(
                        p_primary,
                        id_to_primary[tid],
                    )
                    == 0.0
                    for tid in true_ids
                ):
                    soft_tp += 0.5
                else:
                    soft_fp += 1.0

        for tid in true_ids:
            if tid not in pred_set:
                soft_fn += 1.0

    s_prec = soft_tp / (soft_tp + soft_fp + eps)
    s_rec = soft_tp / (soft_tp + soft_fn + eps)
    plutchik_f1 = 2 * s_prec * s_rec / (s_prec + s_rec + eps)

    return {
        "f1_micro": float(f1_micro),
        "plutchik_f1": float(plutchik_f1),
    }


def compute_plutchik_metrics_singlelabel(
    logits: np.ndarray,
    labels: np.ndarray,
    id2label: Dict[int, str],
) -> Dict[str, float]:
    """
    Single-label evaluation (simplified config):
    - accuracy
    - Plutchik "soft" score with partial credit:
        1.0  for exact match
        0.5  if same primary category
        0.25 if one is neutral w.r.t. Plutchik distance
    """
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).mean()
    soft_score = 0.0
    n = len(labels)

    num_labels = len(id2label)
    id_to_primary = [
        RAW_LABEL_TO_PLUTCHIK.get(id2label[i], "neutral") for i in range(num_labels)
    ]

    for p, t in zip(preds, labels):
        if p == t:
            soft_score += 1.0
        else:
            primary_p = id_to_primary[p]
            primary_t = id_to_primary[t]
            d = plutchik_distance(primary_p, primary_t)
            if d == 0.0:
                soft_score += 0.5
            elif d == 0.5:
                soft_score += 0.25

    return {
        "accuracy": float(acc),
        "plutchik_soft": float(soft_score / n),
    }


def make_compute_metrics_raw(id2label: Dict[int, str]):
    """
    Factory that returns a compute_metrics function for the RAW (multi-label) setting,
    ready to be passed to the HF Trainer.
    """
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        logits = np.array(logits)
        labels = np.array(labels)
        return compute_plutchik_metrics_multilabel(logits, labels, id2label)

    return compute_metrics


def make_compute_metrics_simplified(id2label: Dict[int, str]):
    """
    Factory that returns a compute_metrics function for the SIMPLIFIED (single-label) setting.
    """
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        logits = np.array(logits)
        labels = np.array(labels)
        return compute_plutchik_metrics_singlelabel(logits, labels, id2label)

    return compute_metrics


class WeightedTrainer(Trainer):
    """
    Extension of HF Trainer with:
    - BCEWithLogitsLoss + pos_weight for multi-label
    - CrossEntropyLoss + class weights for single-label
    """
    def __init__(
        self,
        is_multilabel: bool,
        pos_weight: torch.Tensor = None,
        class_weights: torch.Tensor = None,
        num_labels: int = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.is_multilabel = is_multilabel
        self.pos_weight = pos_weight
        self.class_weights = class_weights
        self.num_labels = num_labels

        if self.pos_weight is not None:
            self.pos_weight = self.pos_weight.to(self.args.device)
        if self.class_weights is not None:
            self.class_weights = self.class_weights.to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        if self.is_multilabel:
            # Multi-label: BCE with logits, optional pos_weight for imbalance
            labels = labels.to(logits.dtype)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
            loss = loss_fct(logits, labels)
        else:
            # Single-label: standard CrossEntropy with optional class weights
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss, outputs) if return_outputs else loss


In [7]:
# ====================================================
# 4. Run a single HPO trial and save results to Drive
# ====================================================

def save_json(obj, path: str):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)


def run_hpo_trial(
    base_model_name: str,
    mode: str,         # "raw" or "simplified"
    trial_index: int,  # HPO grid index
) -> Dict[str, Any]:

    # Disable HPO in MINI mode (safety)
    if SMOKE_TEST:
        raise RuntimeError(
            "HPO (run_hpo_trial) is disabled in SMOKE_TEST. "
            "Set SMOKE_TEST=False for full runs."
        )

    assert mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]
    assert 0 <= trial_index < len(HPO_GRID)

    set_global_seed(SEED)

    # -------------------------------------------------
    # Load HPO config
    # -------------------------------------------------
    cfg = HPO_GRID[trial_index]
    model_key = base_model_name.replace("/", "_")

    # effective batch size (prevent OOM on large models)
    effective_bs = cfg["batch_size"]
    if "roberta" in base_model_name.lower():
        effective_bs = min(effective_bs, 8)

    # -------------------------------------------------
    # Prepare trial dirs
    # -------------------------------------------------
    trial_root = os.path.join(TRIALS_DIR, f"{model_key}_{mode}")
    os.makedirs(trial_root, exist_ok=True)

    trial_dir = os.path.join(trial_root, f"trial_{trial_index:02d}")
    os.makedirs(trial_dir, exist_ok=True)

    metrics_path = os.path.join(trial_dir, "metrics.json")
    if os.path.exists(metrics_path):
        print(f"[SKIP] Trial already exists, loading metrics from {metrics_path}")
        with open(metrics_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # -------------------------------------------------
    print("\n" + "="*90)
    print(f"Running HPO TRIAL for {base_model_name} | mode={mode} | trial={trial_index}")
    print("HPO config:", cfg)
    print("="*90)

    # -------------------------------------------------
    # Load tokenized dataset
    # -------------------------------------------------
    tokenized_ds, tokenizer, label_info = get_tokenized_dataset_cached(
        base_model_name, mode, max_len=cfg["max_length"]
    )

    num_labels = label_info["num_labels"]
    id2label = label_info["id2label"]
    label2id = label_info["label2id"]

    # -------------------------------------------------
    # Build model config
    # -------------------------------------------------
    problem_type = (
        "multi_label_classification" if mode == RAW_CONFIG_NAME
        else "single_label_classification"
    )

    config = AutoConfig.from_pretrained(
        base_model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        problem_type=problem_type,
    )

    # -------------------------------------------------
    # Metrics + loss mode
    # -------------------------------------------------
    if mode == RAW_CONFIG_NAME:
        compute_metrics = make_compute_metrics_raw(id2label)
        is_multilabel = True
        pos_weight = label_info["pos_weight"]
        class_weights = None
        best_metric_name = "plutchik_f1"
    else:
        compute_metrics = make_compute_metrics_simplified(id2label)
        is_multilabel = False
        pos_weight = None
        class_weights = label_info["class_weights"]
        best_metric_name = "plutchik_soft"

    # -------------------------------------------------
    # Load pretrained model
    # -------------------------------------------------
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_name,
        config=config,
    )

    # -------------------------------------------------
    # TrainingArguments (fully HPO-driven)
    # -------------------------------------------------
    training_args = TrainingArguments(
        output_dir=trial_dir,

        eval_strategy="epoch",
        save_strategy="no",

        learning_rate=cfg["learning_rate"],
        warmup_ratio=cfg["warmup_ratio"],
        weight_decay=cfg["weight_decay"],
        num_train_epochs=NUM_EPOCHS,
        lr_scheduler_type=cfg["scheduler"],
        optim=cfg["optimizer"],

        per_device_train_batch_size=effective_bs,
        per_device_eval_batch_size=effective_bs,

        logging_steps=200,
        report_to=[],
        seed=SEED,

        fp16=False,
        bf16=torch.cuda.is_available(),   # stable on Ampere GPUs
    )

    # -------------------------------------------------
    # Trainer
    # -------------------------------------------------
    trainer = WeightedTrainer(
        is_multilabel=is_multilabel,
        pos_weight=pos_weight,
        class_weights=class_weights,
        num_labels=num_labels,
        model=model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Run training + evaluation
    trainer.train()
    eval_metrics = trainer.evaluate(tokenized_ds["validation"])
    print("Eval metrics:", eval_metrics)

    main_metric = eval_metrics.get(best_metric_name)
    if main_metric is None:  # backup fallback
        main_metric = eval_metrics.get("f1_micro", eval_metrics.get("accuracy", 0.0))

    # -------------------------------------------------
    # Save results
    # -------------------------------------------------
    result = {
        "base_model_name": base_model_name,
        "mode": mode,
        "trial_index": trial_index,
        "hpo_config": cfg,
        "eval_metrics": eval_metrics,
        "main_metric_name": best_metric_name,
        "main_metric_value": float(main_metric),
    }

    trainer.save_model(trial_dir)
    tokenizer.save_pretrained(trial_dir)
    save_json(result, metrics_path)

    print(f"[DONE] Trial metrics saved to {metrics_path}")
    return result


In [8]:
# ====================================================
# 5. Best trial selection + quantization + export
# ====================================================

from glob import glob


def load_all_trials(base_model_name: str, mode: str) -> List[Dict[str, Any]]:
    """
    Loads all HPO trial results (metrics.json) for a given base model and mode.
    """
    model_key = base_model_name.replace("/", "_")
    trial_root = os.path.join(TRIALS_DIR, f"{model_key}_{mode}")

    if not os.path.exists(trial_root):
        print("No trials directory:", trial_root)
        return []

    results = []
    for trial_dir in sorted(glob(os.path.join(trial_root, "trial_*"))):
        metrics_path = os.path.join(trial_dir, "metrics.json")
        if not os.path.exists(metrics_path):
            continue
        with open(metrics_path, "r", encoding="utf-8") as f:
            res = json.load(f)
        res["trial_dir"] = trial_dir
        results.append(res)

    return results


def select_best_trial(
    base_model_name: str,
    mode: str,
    metric_name: str = None,
) -> Dict[str, Any]:
    """
    Selects the best HPO trial based on the target metric.
    Uses Plutchik F1 (RAW) or Plutchik soft score (SIMPLIFIED) by default.
    """
    results = load_all_trials(base_model_name, mode)
    if not results:
        raise ValueError("No trial results found. Run run_hpo_trial first.")

    if metric_name is None:
        metric_name = "plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft"

    best = None
    best_value = -1e9

    for r in results:
        mv = r["eval_metrics"].get(metric_name)

        if mv is None:
            mv = r["eval_metrics"].get("plutchik_f1", r["eval_metrics"].get("plutchik_soft"))

        if mv is None:
            continue

        if mv > best_value:
            best_value = mv
            best = r

    if best is None:
        raise ValueError(f"No trial contains metric '{metric_name}' or fallback metrics.")

    print(f"Best trial for {base_model_name} | mode={mode}")
    print(f"  trial_index = {best['trial_index']}")
    print(f"  {metric_name} (fallback) = {best_value:.4f}")
    print(f"  trial_dir = {best['trial_dir']}")

    return best


def quantize_best_trial(
    base_model_name: str,
    mode: str,
    metric_name: str = None,
) -> Dict[str, Any]:
    """
    Loads the best-performing model, attempts dynamic quantization,
    and exports the quantized model (or FP32 fallback).
    """

    if SMOKE_TEST:
        raise RuntimeError(
            "Quantization/export is disabled in SMOKE_TEST. "
            "Set SMOKE_TEST=False and rerun HPO in FULL mode."
        )

    best = select_best_trial(base_model_name, mode, metric_name)
    trial_dir = best["trial_dir"]

    print("\nExporting (with quantization attempt) from:", trial_dir)

    # Load model on CPU for quantization
    model = AutoModelForSequenceClassification.from_pretrained(trial_dir).cpu()
    model.eval()

    model_key = base_model_name.replace("/", "_")
    out_dir = os.path.join(
        FINAL_DIR,
        f"{model_key}_{mode}_quantized",
    )
    os.makedirs(out_dir, exist_ok=True)

    quantized = False

    # Attempt PyTorch dynamic quantization
    try:
        try:
            from torch.ao.quantization import quantize_dynamic
        except ImportError:
            from torch.quantization import quantize_dynamic

        qmodel = quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8,
        )
        qmodel.save_pretrained(out_dir)
        quantized = True
        print("Dynamic quantization successful.")
    except Exception as e:
        print("⚠️ Quantization failed, exporting FP32 instead:\n", e)
        model.save_pretrained(out_dir)

    tokenizer = AutoTokenizer.from_pretrained(trial_dir, use_fast=True)
    tokenizer.save_pretrained(out_dir)

    summary = {
        "base_model_name": base_model_name,
        "mode": mode,
        "metric_name": metric_name or ("plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft"),
        "metric_value": best["eval_metrics"].get(
            metric_name or ("plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft")
        ),
        "best_trial_index": best["trial_index"],
        "best_trial_dir": trial_dir,
        "quantized_dir": out_dir,
        "actually_quantized": quantized,
    }

    save_json(summary, os.path.join(out_dir, "quantized_summary.json"))
    print("Final model saved to:", out_dir, "| quantized =", quantized)

    return summary


In [None]:
# ============================================================
# HPO Launcher (improved)
# ============================================================

if not SMOKE_TEST:
  for model_cfg in BASE_MODELS:
      if model_cfg["hpo"]:
          model_name = model_cfg["name"]
          mode = model_cfg["hpo_mode"]

          print("\n" + "#"*80)
          print(f"RUNNING HPO FOR MODEL: {model_name} | MODE: {mode}")
          print("#"*80)

          for trial_idx in range(len(HPO_GRID)):
              print(f"\nRunning HPO trial {trial_idx} / {len(HPO_GRID)-1}")
              run_hpo_trial(model_name, mode, trial_idx)

          print("\n=== HPO finished ===")



################################################################################
RUNNING HPO FOR MODEL: nreimers/MiniLM-L6-H384-uncased | MODE: raw
################################################################################

Running HPO trial 0 / 2

Running HPO TRIAL for nreimers/MiniLM-L6-H384-uncased | mode=raw | trial=0
HPO config: {'learning_rate': 2e-05, 'warmup_ratio': 0.05, 'weight_decay': 0.01, 'scheduler': 'linear', 'batch_size': 16, 'max_length': 64, 'optimizer': 'adamw_torch'}
[TOKENIZE] Creating tokenized dataset for model=nreimers/MiniLM-L6-H384-uncased | mode=raw


tokenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Tokenizing for model=nreimers/MiniLM-L6-H384-uncased, mode=raw:   0%|          | 0/190102 [00:00<?, ? examples…

Tokenizing for model=nreimers/MiniLM-L6-H384-uncased, mode=raw:   0%|          | 0/21123 [00:00<?, ? examples/…

Saving the dataset (0/1 shards):   0%|          | 0/190102 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21123 [00:00<?, ? examples/s]

[TOKENIZE] Saved tokenized dataset to /content/drive/MyDrive/emotion_project/full/tokenized/nreimers_MiniLM-L6-H384-uncased_raw_len64


pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nreimers/MiniLM-L6-H384-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  super().__init__(*args, **kwargs)


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,F1 Micro,Plutchik F1
1,0.9579,1.001038,0.230716,0.349383
2,0.9673,0.968881,0.248957,0.364751
3,0.9054,0.956853,0.248472,0.366345


Eval metrics: {'eval_loss': 0.9568529725074768, 'eval_f1_micro': 0.24847205394933983, 'eval_plutchik_f1': 0.36634547401792666, 'eval_runtime': 20.064, 'eval_samples_per_second': 1052.783, 'eval_steps_per_second': 65.839, 'epoch': 3.0}
[DONE] Trial metrics saved to /content/drive/MyDrive/emotion_project/full/trials/nreimers_MiniLM-L6-H384-uncased_raw/trial_00/metrics.json

Running HPO trial 1 / 2

Running HPO TRIAL for nreimers/MiniLM-L6-H384-uncased | mode=raw | trial=1
HPO config: {'learning_rate': 3e-05, 'warmup_ratio': 0.1, 'weight_decay': 0.01, 'scheduler': 'cosine', 'batch_size': 8, 'max_length': 128, 'optimizer': 'adamw_torch'}
[TOKENIZE] Creating tokenized dataset for model=nreimers/MiniLM-L6-H384-uncased | mode=raw


Tokenizing for model=nreimers/MiniLM-L6-H384-uncased, mode=raw:   0%|          | 0/190102 [00:00<?, ? examples…

Tokenizing for model=nreimers/MiniLM-L6-H384-uncased, mode=raw:   0%|          | 0/21123 [00:00<?, ? examples/…

Saving the dataset (0/1 shards):   0%|          | 0/190102 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21123 [00:00<?, ? examples/s]

[TOKENIZE] Saved tokenized dataset to /content/drive/MyDrive/emotion_project/full/tokenized/nreimers_MiniLM-L6-H384-uncased_raw_len128


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nreimers/MiniLM-L6-H384-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  super().__init__(*args, **kwargs)


Epoch,Training Loss,Validation Loss,F1 Micro,Plutchik F1
1,0.9298,1.061858,0.266855,0.374193


In [None]:
# ============================================================
# TRAINING Launcher (uses best HPO config)
# ============================================================
if not SMOKE_TEST:
  for model_cfg in BASE_MODELS:
      if model_cfg["train"]:
          model_name = model_cfg["name"]
          mode = model_cfg["train_mode"]

          print("\n" + "#"*80)
          print(f"RUNNING FINAL TRAINING FOR MODEL: {model_name} | MODE: {mode}")
          print("#"*80)

          # =========================================
          # Load best HPO trial
          # =========================================
          best = select_best_trial(model_name, RAW_CONFIG_NAME)

          best_trial_idx = best["trial_index"]
          best_cfg = best["hpo_config"]

          print(f"Using best trial index: {best_trial_idx}")
          print("Best HPO config:", best_cfg)

          # =========================================
          # Load dataset (tokenized)
          # =========================================
          tokenized_ds, tokenizer, label_info = get_tokenized_dataset_cached(
              base_model_name=model_name,
              mode=mode,
              max_len=best_cfg["max_length"],
          )

          num_labels = label_info["num_labels"]
          id2label = label_info["id2label"]
          label2id = label_info["label2id"]

          problem_type = (
              "multi_label_classification" if mode == RAW_CONFIG_NAME
              else "single_label_classification"
          )

          # =========================================
          # Build config + model
          # =========================================
          config = AutoConfig.from_pretrained(
              model_name,
              num_labels=num_labels,
              id2label=id2label,
              label2id=label2id,
              problem_type=problem_type,
          )

          model = AutoModelForSequenceClassification.from_pretrained(
              model_name,
              config=config,
          )

          # =========================================
          # Loss + metrics
          # =========================================
          if mode == RAW_CONFIG_NAME:
              compute_metrics = make_compute_metrics_raw(id2label)
              is_multilabel = True
              pos_weight = label_info["pos_weight"]
              class_weights = None
              best_metric_name = "plutchik_f1"
          else:
              compute_metrics = make_compute_metrics_simplified(id2label)
              is_multilabel = False
              pos_weight = None
              class_weights = label_info["class_weights"]
              best_metric_name = "plutchik_soft"

          # =========================================
          # TrainingArguments using best HPO config
          # =========================================
          out_dir = os.path.join(
              FINAL_DIR,
              f"{model_name.replace('/', '_')}_{mode}_final_trained"
          )
          os.makedirs(out_dir, exist_ok=True)

          training_args = TrainingArguments(
              output_dir=out_dir,
              learning_rate=best_cfg["learning_rate"],
              warmup_ratio=best_cfg["warmup_ratio"],
              weight_decay=best_cfg["weight_decay"],
              lr_scheduler_type=best_cfg["scheduler"],
              optim=best_cfg["optimizer"],
              num_train_epochs=NUM_EPOCHS,

              per_device_train_batch_size=best_cfg["batch_size"],
              per_device_eval_batch_size=best_cfg["batch_size"],

              eval_strategy="epoch",
              save_strategy="no",
              report_to=[],
              seed=SEED,
              fp16=False,
              bf16=torch.cuda.is_available(),
          )

          # =========================================
          # Trainer
          # =========================================
          trainer = WeightedTrainer(
              is_multilabel=is_multilabel,
              pos_weight=pos_weight,
              class_weights=class_weights,
              num_labels=num_labels,
              model=model,
              args=training_args,
              train_dataset=tokenized_ds["train"],
              eval_dataset=tokenized_ds["validation"],
              tokenizer=tokenizer,
              compute_metrics=compute_metrics,
          )

          # =========================================
          # TRAIN + EVALUATE
          # =========================================
          trainer.train()
          final_metrics = trainer.evaluate(tokenized_ds["validation"])

          print("\nFinal evaluation metrics:")
          print(final_metrics)

          # Save summary
          summary_path = os.path.join(out_dir, "training_summary.json")
          save_json({
              "model_name": model_name,
              "mode": mode,
              "best_trial_index": best_trial_idx,
              "best_hpo_config": best_cfg,
              "final_metrics": final_metrics,
          }, summary_path)

          print(f"\nSaved final trained model + summary to {out_dir}")
          print("#"*80)


In [None]:
# ============================================================
# QUANTIZATION + FINAL EVALUATION Launcher (SAFE MODE)
# ============================================================

#assert not SMOKE_TEST, "❌ Quantization cannot run in SMOKE_TEST=True!"
if not RUN_QUANTIZATION:
    print("⏭ Quantization skipped. Set RUN_QUANTIZATION=True to run.")
else:
    print("🚀 Starting QUANTIZATION + FINAL EVALUATION pipeline...")

# ============================================================

if RUN_QUANTIZATION:
    for model_cfg in BASE_MODELS:

        if not model_cfg["train"]:
            continue   # skip models not trained

        model_name = model_cfg["name"]

        # You can change this to run both versions:
        # for mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]:
        mode = RAW_CONFIG_NAME

        print("\n" + "#"*80)
        print(f"QUANTIZATION + FINAL EVALUATION for MODEL: {model_name} | MODE: {mode}")
        print("#"*80)

        # ===============================
        # Check trained model directory
        # ===============================
        trained_dir = os.path.join(
            FINAL_DIR,
            f"{model_name.replace('/', '_')}_{mode}_final_trained",
        )
        if not os.path.exists(trained_dir):
            print(f"❌ Trained model not found at: {trained_dir}")
            print("Skipping this model...\n")
            continue

        # Output directory
        quantized_out = os.path.join(
            FINAL_DIR,
            f"{model_name.replace('/', '_')}_{mode}_quantized_final",
        )
        os.makedirs(quantized_out, exist_ok=True)

        # ===============================
        # Load dataset (cached)
        # ===============================
        tokenized_ds, tokenizer, label_info = get_tokenized_dataset_cached(
            base_model_name=model_name,
            mode=mode,
            max_len=MAX_LENGTH,
        )

        num_labels = label_info["num_labels"]
        id2label = label_info["id2label"]
        label2id = label_info["label2id"]

        # ===============================
        # Build model config
        # ===============================
        problem_type = (
            "multi_label_classification" if mode == RAW_CONFIG_NAME
            else "single_label_classification"
        )

        config = AutoConfig.from_pretrained(
            trained_dir,
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
            problem_type=problem_type,
        )

        # ===============================
        # Load trained model (CPU)
        # ===============================
        print("Loading trained model...")
        model_cpu = AutoModelForSequenceClassification.from_pretrained(
            trained_dir,
            config=config,
        ).cpu()
        model_cpu.eval()

        # ===============================
        # Attempt dynamic quantization
        # ===============================
        print("Attempting dynamic quantization...")
        quantized_success = False

        try:
            try:
                from torch.ao.quantization import quantize_dynamic
            except ImportError:
                from torch.quantization import quantize_dynamic

            qmodel = quantize_dynamic(
                model_cpu,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            qmodel.save_pretrained(quantized_out)
            quantized_success = True
            print("✔ Dynamic quantization successful!")
        except Exception as e:
            print("⚠ Quantization failed, exporting FP32. Error:", e)
            model_cpu.save_pretrained(quantized_out)

        tokenizer.save_pretrained(quantized_out)

        # ===============================
        # Final evaluation
        # ===============================
        print("\nRunning evaluation of quantized model...")
        eval_model = AutoModelForSequenceClassification.from_pretrained(
            quantized_out,
            config=config,
        ).to(DEVICE)
        eval_model.eval()

        if mode == RAW_CONFIG_NAME:
            compute_metrics = make_compute_metrics_raw(id2label)
            is_multilabel = True
            pos_weight = label_info["pos_weight"]
            class_weights = None
        else:
            compute_metrics = make_compute_metrics_simplified(id2label)
            is_multilabel = False
            pos_weight = None
            class_weights = label_info["class_weights"]

        eval_args = TrainingArguments(
            output_dir=os.path.join(quantized_out, "eval_tmp"),
            per_device_eval_batch_size=8,
            report_to=[],
        )

        eval_trainer = WeightedTrainer(
            is_multilabel=is_multilabel,
            pos_weight=pos_weight,
            class_weights=class_weights,
            num_labels=num_labels,
            model=eval_model,
            args=eval_args,
            eval_dataset=tokenized_ds["validation"],
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )

        quant_metrics = eval_trainer.evaluate()

        print("\nFinal quantized model metrics:")
        print(quant_metrics)

        # ===============================
        # Save summary
        # ===============================
        summary = {
            "model_name": model_name,
            "mode": mode,
            "quantized_dir": quantized_out,
            "quantized_success": quantized_success,
            "quantized_metrics": quant_metrics,
        }

        save_json(summary, os.path.join(quantized_out, "quantization_summary.json"))

        print(f"\n✔ Saved quantized model + summary to {quantized_out}")
        print("#"*80)


In [None]:
# =========================================================
# SMOKE TEST (NO TOKENIZATION, NO DATASET DOWNLOAD)
# =========================================================
# Uses only cached tokenized datasets.
# Runs a 1-epoch training for each (model × mode).
# =========================================================

if SMOKE_TEST:
  SMOKE_EPOCHS = 1
  smoke_results = {}

  for model_cfg in BASE_MODELS:
      model_name = model_cfg["name"]
      for mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]:

          print("\n" + "="*80)
          print(f"SMOKE TEST: model={model_name} | mode={mode}")
          print("="*80)

          # ------------------------------------------
          # Load *cached* tokenized dataset only
          # ------------------------------------------
          tokenized_ds, tokenizer, label_info = get_tokenized_dataset_cached(
              base_model_name=model_name,
              mode=mode,
              max_len=MAX_LENGTH,   # same as your earlier tokenization
          )

          num_labels = label_info["num_labels"]
          id2label = label_info["id2label"]
          label2id = label_info["label2id"]

          # ------------------------------------------
          # Build config
          # ------------------------------------------
          problem_type = (
              "multi_label_classification" if mode == RAW_CONFIG_NAME
              else "single_label_classification"
          )

          config = AutoConfig.from_pretrained(
              model_name,
              num_labels=num_labels,
              id2label=id2label,
              label2id=label2id,
              problem_type=problem_type,
          )

          # ------------------------------------------
          # Metrics + loss flags
          # ------------------------------------------
          if mode == RAW_CONFIG_NAME:
              compute_metrics = make_compute_metrics_raw(id2label)
              is_multilabel = True
              pos_weight = label_info["pos_weight"]
              class_weights = None
          else:
              compute_metrics = make_compute_metrics_simplified(id2label)
              is_multilabel = False
              pos_weight = None
              class_weights = label_info["class_weights"]

          # ------------------------------------------
          # Load pretrained model
          # ------------------------------------------
          model = AutoModelForSequenceClassification.from_pretrained(
              model_name,
              config=config,
          )

          # ------------------------------------------
          # Minimal training args
          # ------------------------------------------
          training_args = TrainingArguments(
              output_dir=f"/content/smoke_tmp/{model_name.replace('/', '_')}_{mode}",
              per_device_train_batch_size=4,
              per_device_eval_batch_size=4,
              num_train_epochs=SMOKE_EPOCHS,

              eval_strategy="epoch",
              save_strategy="no",
              report_to=[],
              seed=SEED,

              fp16=False,
              bf16=torch.cuda.is_available(),
          )

          # ------------------------------------------
          # Trainer
          # ------------------------------------------
          trainer = WeightedTrainer(
              is_multilabel=is_multilabel,
              pos_weight=pos_weight,
              class_weights=class_weights,
              num_labels=num_labels,
              model=model,
              args=training_args,
              train_dataset=tokenized_ds["train"],
              eval_dataset=tokenized_ds["validation"],
              tokenizer=tokenizer,
              compute_metrics=compute_metrics,
          )

          # ------------------------------------------
          # TRAIN + EVAL
          # ------------------------------------------
          trainer.train()
          metrics = trainer.evaluate(tokenized_ds["validation"])

          smoke_results[(model_name, mode)] = metrics
          print(f"RESULT ({model_name}, {mode}):")
          print(metrics)
          print("="*80)


In [None]:
# ============================================================
# UNIVERSAL EXPERIMENT DASHBOARD
# Added:
#   ✔ Summary table (Pandas dataframe)
#   ✔ Saving all plots to Drive
# ============================================================

import json
import matplotlib.pyplot as plt
import pandas as pd
from glob import glob

PLOTS_DIR = os.path.join(FINAL_DIR, "plots")
os.makedirs(PLOTS_DIR, exist_ok=True)

def savefig(name):
    """Helper to save current plot to Drive."""
    path = os.path.join(PLOTS_DIR, name)
    plt.savefig(path, bbox_inches="tight", dpi=200)
    print(f"📁 Saved plot: {path}")


# ------------------------------------------------------------
# Helpers
# ------------------------------------------------------------
def try_load_json(path):
    if not os.path.exists(path):
        return None
    with open(path, "r") as f:
        return json.load(f)

def find_hpo_results():
    results = {}
    for model_cfg in BASE_MODELS:
        name = model_cfg["name"]
        key = name.replace("/", "_")
        for mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]:
            trial_root = os.path.join(TRIALS_DIR, f"{key}_{mode}")
            if not os.path.exists(trial_root):
                continue
            trial_jsons = sorted(glob(os.path.join(trial_root, "trial_*/metrics.json")))
            trials = [try_load_json(p) for p in trial_jsons if try_load_json(p)]
            if trials:
                results[(name, mode)] = trials
    return results

def find_trained_models():
    trained = []
    for model_cfg in BASE_MODELS:
        name = model_cfg["name"]
        key = name.replace("/", "_")
        for mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]:
            p = os.path.join(FINAL_DIR, f"{key}_{mode}_final_trained")
            if os.path.exists(p):
                trained.append((name, mode, p))
    return trained

def find_quantized_models():
    quant = []
    for model_cfg in BASE_MODELS:
        name = model_cfg["name"]
        key = name.replace("/", "_")
        for mode in [RAW_CONFIG_NAME, SIMPLIFIED_CONFIG_NAME]:
            p = os.path.join(FINAL_DIR, f"{key}_{mode}_quantized_final")
            if os.path.exists(p):
                summary = try_load_json(os.path.join(p, "quantization_summary.json"))
                quant.append((name, mode, p, summary))
    return quant


# ============================================================
# 1) PLOT HPO RESULTS (AND SAVE)
# ============================================================

hpo_results = find_hpo_results()

if hpo_results:
    print("📊 Plotting HPO results...")

    plt.figure(figsize=(12, 6))

    for (model_name, mode), trials in hpo_results.items():
        metric = "plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft"
        xs = [t["trial_index"] for t in trials]
        ys = [t["eval_metrics"].get(metric, 0.0) for t in trials]
        plt.plot(xs, ys, marker="o", label=f"{model_name} | {mode}")

    plt.title("HPO Results Across Models")
    plt.xlabel("Trial Index")
    plt.ylabel("Score")
    plt.grid(True)
    plt.legend()
    savefig("hpo_results.png")
    plt.show()
else:
    print("⚠ No HPO results found.")


# ============================================================
# 2) PLOT TRAINING CURVES (AND SAVE)
# ============================================================

trained_models = find_trained_models()

for (model_name, mode, path) in trained_models:
    print(f"\n📈 Plotting training curves for {model_name} | {mode}")

    state_path = os.path.join(path, "trainer_state.json")
    state = try_load_json(state_path)

    if not state:
        print("  ⚠ trainer_state.json missing")
        continue

    logs = state.get("log_history", [])
    train_loss = [l["loss"] for l in logs if "loss" in l]
    eval_loss = [l["eval_loss"] for l in logs if "eval_loss" in l]

    # Training Loss
    if train_loss:
        plt.figure(figsize=(10,4))
        plt.plot(train_loss)
        plt.title(f"Training Loss — {model_name} ({mode})")
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.grid(True)
        name = f"training_loss_{model_name.replace('/', '_')}_{mode}.png"
        savefig(name)
        plt.show()

    # Eval Loss
    if eval_loss:
        plt.figure(figsize=(10,4))
        plt.plot(eval_loss, marker="o")
        plt.title(f"Eval Loss — {model_name} ({mode})")
        plt.xlabel("Epoch")
        plt.ylabel("Eval Loss")
        plt.grid(True)
        name = f"eval_loss_{model_name.replace('/', '_')}_{mode}.png"
        savefig(name)
        plt.show()


# ============================================================
# 3) FULL vs QUANTIZED PERFORMANCE (AND SAVE)
# ============================================================

quantized_models = find_quantized_models()

for (model_name, mode, p, summary) in quantized_models:
    if not summary:
        continue

    metric = "plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft"
    q_val = summary["quantized_metrics"].get(f"eval_{metric}")

    # Load full model metric
    key = model_name.replace("/", "_")
    full_path = os.path.join(FINAL_DIR, f"{key}_{mode}_final_trained")
    trainer_state = try_load_json(os.path.join(full_path, "trainer_state.json"))

    full_val = None
    if trainer_state:
        eval_logs = [l for l in trainer_state["log_history"] if f"eval_{metric}" in l]
        if eval_logs:
            full_val = eval_logs[-1][f"eval_{metric}"]

    if full_val is None or q_val is None:
        print(f"⚠ Missing full/quantized metrics for {model_name} ({mode})")
        continue

    plt.figure(figsize=(6,4))
    plt.bar(["Full", "Quantized"], [full_val, q_val])
    plt.title(f"{metric}: Full vs Quantized — {model_name} ({mode})")
    plt.ylabel(metric)
    plt.grid(axis="y")
    name = f"full_vs_quant_{model_name.replace('/', '_')}_{mode}.png"
    savefig(name)
    plt.show()


# ============================================================
# 4) SUMMARY TABLE
# ============================================================

rows = []
for (model_name, mode), hpo_trials in hpo_results.items():
    # best trial
    metric = "plutchik_f1" if mode == RAW_CONFIG_NAME else "plutchik_soft"
    best_trial = max(hpo_trials, key=lambda t: t["eval_metrics"].get(metric, -1))

    # trained?
    trained_path = os.path.join(
        FINAL_DIR,
        f"{model_name.replace('/', '_')}_{mode}_final_trained"
    )
    is_trained = os.path.exists(trained_path)

    # quantized?
    quant_path = os.path.join(
        FINAL_DIR,
        f"{model_name.replace('/', '_')}_{mode}_quantized_final"
    )
    quant_summary = try_load_json(os.path.join(quant_path, "quantization_summary.json"))

    rows.append({
        "Model": model_name,
        "Mode": mode,
        "Best HPO Trial": best_trial["trial_index"],
        "Best HPO Score": best_trial["eval_metrics"].get(metric),
        "Trained?": is_trained,
        "Trained Path": trained_path if is_trained else None,
        "Quantized?": quant_summary is not None,
        "Quantized Score": quant_summary["quantized_metrics"].get(f"eval_{metric}")
            if quant_summary else None,
        "Quantized Path": quant_path if quant_summary else None,
    })

df_summary = pd.DataFrame(rows)
print("\n📘 EXPERIMENT SUMMARY TABLE:")
display(df_summary)

# Save summary table
df_summary.to_csv(os.path.join(PLOTS_DIR, "experiment_summary.csv"), index=False)
print(f"📁 Saved summary table to: {os.path.join(PLOTS_DIR, 'experiment_summary.csv')}")
