## Multi-task transformer (shared encoder, hazard + product heads)

In [3]:
import os, json, random
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, Any

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm import tqdm


# -------------------------
# Config
# -------------------------
# Hyper-parameters and paths for the multi-task (hazard + product) run.
SEED = 42                                  # passed to seed_everything() below
MODEL_NAME = "microsoft/deberta-v3-base"   # checkpoint used for both the tokenizer and the encoder
MAX_LEN = 256                              # tokenizer max_length (inputs are truncated/padded to this)
BATCH_SIZE = 16                            # batch size for the train/valid/test DataLoaders
LR = 2e-5                                  # AdamW learning rate
EPOCHS = 4                                 # number of passes over the training loader
WARMUP_RATIO = 0.06                        # fraction of total optimizer steps used for LR warmup
WEIGHT_DECAY = 0.01                        # AdamW weight decay
GRAD_CLIP = 1.0                            # max gradient norm passed to clip_grad_norm_

# Input splits; each CSV must provide title, text, hazard-category and product-category columns.
TRAIN_CSV = "incidents_train.csv"
VALID_CSV = "incidents_valid.csv"
TEST_CSV  = "incidents_test.csv"

# Output directory for label maps, the best checkpoint and test metrics.
OUT_DIR = "sft_results"
os.makedirs(OUT_DIR, exist_ok=True)


# -------------------------
# Repro
# -------------------------
def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Args:
        seed: value handed unchanged to each library's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs before any model or loader is created.
seed_everything(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)


# -------------------------
# Load data
# -------------------------
def load_df(path: str) -> pd.DataFrame:
    """Read an incidents CSV and normalise the four columns the pipeline uses.

    Args:
        path: location of the CSV file.

    Returns:
        The loaded frame with ``title``, ``text``, ``hazard-category`` and
        ``product-category`` converted to strings (missing values become "").

    Raises:
        ValueError: if one of the four required columns is absent; the first
            missing column (in the order above) is named in the message.
    """
    required = ("title", "text", "hazard-category", "product-category")
    frame = pd.read_csv(path)

    first_missing = next((col for col in required if col not in frame.columns), None)
    if first_missing is not None:
        raise ValueError(f"Missing column: {first_missing}")

    # Empty string instead of NaN so later string concatenation / label lookup never sees a float.
    for col in required:
        frame[col] = frame[col].fillna("").astype(str)
    return frame

# Load the three pre-split CSVs; load_df() validates columns and fills missing values.
train_df = load_df(TRAIN_CSV)
valid_df = load_df(VALID_CSV)
test_df  = load_df(TEST_CSV)

# Row counts per split, as a quick sanity check.
print(f"Loaded train={len(train_df)} valid={len(valid_df)} test={len(test_df)}")


# -------------------------
# Label maps (from TRAIN only)
# -------------------------
# Class vocabularies come from the training split only, sorted so ids are stable across runs.
haz_labels = sorted(set(train_df["hazard-category"]))
prod_labels = sorted(set(train_df["product-category"]))

# id -> label is just the enumeration; label -> id is its inverse.
id2haz = dict(enumerate(haz_labels))
haz2id = {label: idx for idx, label in id2haz.items()}

id2prod = dict(enumerate(prod_labels))
prod2id = {label: idx for idx, label in id2prod.items()}

print("Hazard classes:", len(haz_labels))
print("Product classes:", len(prod_labels))

# Persist the label order so predictions can be decoded outside this notebook.
with open(os.path.join(OUT_DIR, "label_maps.json"), "w", encoding="utf-8") as f:
    f.write(json.dumps(
        {"haz_labels": haz_labels, "prod_labels": prod_labels},
        ensure_ascii=False,
        indent=2,
    ))


# -------------------------
# Class weights (imbalance handling)
# -------------------------
def make_class_weights(series: pd.Series, label2id: Dict[str,int]) -> torch.Tensor:
    """Build inverse-frequency class weights, rescaled to have mean 1.

    Each class gets ``N / (K * count)`` where ``N`` is ``len(series)``, ``K`` is
    the number of classes in ``label2id`` and ``count`` is how often the label
    occurs in ``series`` (labels that never occur are treated as count 1).

    Args:
        series: one label per training row.
        label2id: maps each label to its position in the returned tensor.

    Returns:
        float32 tensor of length ``len(label2id)``.
    """
    freq = series.value_counts()
    total = len(series)
    n_classes = len(label2id)

    raw = np.zeros(n_classes, dtype=np.float32)
    for label, slot in label2id.items():
        raw[slot] = total / (n_classes * freq.get(label, 1))

    # Divide by the mean so the average weight is 1.
    return torch.tensor(raw / raw.mean(), dtype=torch.float32)

# Per-task class weights computed from the training split and moved to DEVICE;
# they are passed as `weight=` to the CrossEntropyLoss instances below.
haz_w = make_class_weights(train_df["hazard-category"], haz2id).to(DEVICE)
prod_w = make_class_weights(train_df["product-category"], prod2id).to(DEVICE)


# -------------------------
# Dataset
# -------------------------
# Tokenizer matching MODEL_NAME; use_fast=False asks transformers for the slow (non-Rust) implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
class IncidentDataset(Dataset):
    """Map-style dataset that tokenises one incident report per access.

    Relies on the module-level ``tokenizer``, ``MAX_LEN``, ``haz2id`` and
    ``prod2id``.
    """

    def __init__(self, df: pd.DataFrame):
        """Keep a re-indexed frame and validate its labels up front.

        Args:
            df: frame with ``title``, ``text``, ``hazard-category`` and
                ``product-category`` columns.

        Raises:
            ValueError: if ``df`` contains a hazard or product label that is
                not present in the train-derived label maps.
        """
        self.df = df.reset_index(drop=True)

        # Fail fast: without this check an unseen label only surfaces when the
        # offending row is fetched, i.e. possibly after a full training epoch.
        unseen_haz = sorted(set(self.df["hazard-category"]) - set(haz2id), key=str)
        unseen_prod = sorted(set(self.df["product-category"]) - set(prod2id), key=str)
        if unseen_haz or unseen_prod:
            raise ValueError(
                "Found label in valid/test not seen in train. "
                f"hazard-category={unseen_haz} product-category={unseen_prod}"
            )

    def __len__(self):
        """Return the number of rows."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenise row ``idx`` and return model inputs plus both label ids.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (each padded or
            truncated to ``MAX_LEN``) and scalar long tensors ``haz_label``
            and ``prod_label``.

        Raises:
            ValueError: if the row's label is missing from the label maps
                (only reachable if ``self.df`` was modified after construction).
        """
        r = self.df.iloc[idx]
        # simple concatenation of title and body with a literal separator
        text = (r["title"] + " [SEP] " + r["text"]).strip()

        enc = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt"
        )

        haz_id = haz2id.get(r["hazard-category"], None)
        prod_id = prod2id.get(r["product-category"], None)
        if haz_id is None or prod_id is None:
            raise ValueError("Found label in valid/test not seen in train.")

        # return_tensors="pt" adds a leading batch dimension of 1; drop it so
        # the DataLoader's default collation can stack items itself.
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "haz_label": torch.tensor(haz_id, dtype=torch.long),
            "prod_label": torch.tensor(prod_id, dtype=torch.long),
        }

# One dataset per split, built in train / valid / test order.
train_ds, valid_ds, test_ds = (
    IncidentDataset(frame) for frame in (train_df, valid_df, test_df)
)

# Only the training loader shuffles; all other loader settings are shared.
train_loader, valid_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=0, pin_memory=True)
    for ds, do_shuffle in ((train_ds, True), (valid_ds, False), (test_ds, False))
)


# -------------------------
# Multi-task Model
# -------------------------
class MultiTaskClassifier(nn.Module):
    """Shared pretrained encoder feeding two linear heads (hazard and product)."""

    def __init__(self, base_model_name: str, n_haz: int, n_prod: int, dropout: float = 0.1):
        """Load the encoder and create one linear head per task.

        Args:
            base_model_name: name or path handed to ``AutoModel.from_pretrained``.
            n_haz: number of hazard classes (output size of ``haz_head``).
            n_prod: number of product classes (output size of ``prod_head``).
            dropout: dropout probability applied to the pooled representation.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        width = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.haz_head = nn.Linear(in_features=width, out_features=n_haz)
        self.prod_head = nn.Linear(in_features=width, out_features=n_prod)

    def forward(self, input_ids, attention_mask):
        """Return ``(haz_logits, prod_logits)`` for a batch of token ids."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # The first token's hidden state serves as the sequence summary (CLS-like).
        summary = self.dropout(hidden_states[:, 0, :])
        return self.haz_head(summary), self.prod_head(summary)

model = MultiTaskClassifier(MODEL_NAME, n_haz=len(haz_labels), n_prod=len(prod_labels))
model = model.to(DEVICE).float()

# -------------------------
# Optimizer / Scheduler
# -------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# One scheduler step per batch: linear warmup for the first WARMUP_RATIO of steps, then linear decay.
total_steps = EPOCHS * len(train_loader)
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Class-weighted cross-entropy, one criterion per task.
haz_loss_fn, prod_loss_fn = (nn.CrossEntropyLoss(weight=w) for w in (haz_w, prod_w))


# -------------------------
# Eval helper
# -------------------------
@torch.no_grad()
def evaluate(loader: DataLoader, split_name: str):
    """Score the global ``model`` on ``loader`` for both tasks and print a report.

    Puts ``model`` in eval mode (it is not switched back here), collects
    argmax predictions for every batch, then prints accuracy, macro/micro/
    weighted F1 and a per-class classification report for each task.

    Args:
        loader: yields batches with ``input_ids``, ``attention_mask``,
            ``haz_label`` and ``prod_label``.
        split_name: tag shown in the progress bar and the printed headers.

    Returns:
        ``{"hazard": m, "product": m}`` where each ``m`` has the keys ``acc``,
        ``macro_f1``, ``micro_f1`` and ``weighted_f1``.
    """

    def summarize(y_true, y_pred, labels, title):
        # Compute the summary scores, print them with the per-class report, and return them.
        scores = {"acc": accuracy_score(y_true, y_pred)}
        for avg in ("macro", "micro", "weighted"):
            scores[f"{avg}_f1"] = f1_score(y_true, y_pred, average=avg, zero_division=0)

        print(f"\n=== {title} ({split_name}) ===")
        print("Accuracy:", scores["acc"])
        print("Macro-F1:", scores["macro_f1"])
        print("Micro-F1:", scores["micro_f1"])
        print("Weighted-F1:", scores["weighted_f1"])
        print("\nClassification Report:\n")
        # Passing every label id keeps classes with zero support in the report.
        print(classification_report(
            y_true, y_pred,
            labels=list(range(len(labels))),
            target_names=labels,
            zero_division=0,
        ))
        return scores

    model.eval()
    haz_true, haz_pred, prod_true, prod_pred = [], [], [], []

    for batch in tqdm(loader, desc=f"Eval {split_name}", leave=False):
        haz_logits, prod_logits = model(
            batch["input_ids"].to(DEVICE),
            batch["attention_mask"].to(DEVICE),
        )

        haz_true += batch["haz_label"].tolist()
        haz_pred += haz_logits.argmax(dim=-1).cpu().tolist()
        prod_true += batch["prod_label"].tolist()
        prod_pred += prod_logits.argmax(dim=-1).cpu().tolist()

    hazard_scores = summarize(haz_true, haz_pred, haz_labels, "Hazard-category")
    product_scores = summarize(prod_true, prod_pred, prod_labels, "Product-category")
    return {"hazard": hazard_scores, "product": product_scores}


# -------------------------
# Train
# -------------------------
# Best validation score seen so far and where its checkpoint is written.
best_score = -1.0
best_path = os.path.join(OUT_DIR, "best_model.pt")

for epoch in range(1, EPOCHS + 1):
    model.train()
    pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{EPOCHS}")
    running_loss = 0.0

    # Count steps explicitly: tqdm refreshes `pbar.n` lazily (only when it
    # redraws), so using it as the divisor skews the running mean whenever
    # iterations are faster than the refresh interval.
    for step, batch in enumerate(pbar, start=1):
        optimizer.zero_grad(set_to_none=True)

        input_ids = batch["input_ids"].to(DEVICE)
        attn = batch["attention_mask"].to(DEVICE)
        haz_y = batch["haz_label"].to(DEVICE)
        prod_y = batch["prod_label"].to(DEVICE)

        haz_logits, prod_logits = model(input_ids, attn)

        loss_h = haz_loss_fn(haz_logits, haz_y)
        loss_p = prod_loss_fn(prod_logits, prod_y)

        # simple equal weighting of the two task losses
        loss = loss_h + loss_p

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix(loss=running_loss / step)

    # Validate
    val_metrics = evaluate(valid_loader, "valid")

    # choose best by avg macro-f1 of two tasks
    score = (val_metrics["hazard"]["macro_f1"] + val_metrics["product"]["macro_f1"]) / 2.0
    print(f"\nEpoch {epoch} combined macro-F1 (haz+prod)/2 = {score:.4f}")

    if score > best_score:
        best_score = score
        # Store the score as a plain Python float: a NumPy scalar in the
        # checkpoint cannot be read back by torch.load(weights_only=True).
        torch.save(
            {"model_state_dict": model.state_dict(), "epoch": epoch, "score": float(score)},
            best_path,
        )
        print("✅ Saved new best:", best_path)

print("\nBest combined macro-F1:", best_score)

# -------------------------
# Final test eval (load best)
# -------------------------
# weights_only=True restricts unpickling to tensors and primitive containers,
# which is all this checkpoint holds.
ckpt = torch.load(best_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
print("\nLoaded best checkpoint from epoch", ckpt["epoch"], "score", ckpt["score"])

test_metrics = evaluate(test_loader, "test")
with open(os.path.join(OUT_DIR, "test_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(test_metrics, f, ensure_ascii=False, indent=2)

print("\nSaved test metrics to:", os.path.join(OUT_DIR, "test_metrics.json"))

Device: cuda
Loaded train=5082 valid=565 test=997
Hazard classes: 10
Product classes: 22


Loading weights:   0%|          | 0/198 [00:00<?, ?it/s]

[1mDebertaV2Model LOAD REPORT[0m from: microsoft/deberta-v3-base
Key                                     | Status     |  | 
----------------------------------------+------------+--+-
lm_predictions.lm_head.bias             | UNEXPECTED |  | 
lm_predictions.lm_head.LayerNorm.bias   | UNEXPECTED |  | 
mask_predictions.dense.bias             | UNEXPECTED |  | 
lm_predictions.lm_head.dense.weight     | UNEXPECTED |  | 
lm_predictions.lm_head.dense.bias       | UNEXPECTED |  | 
mask_predictions.classifier.bias        | UNEXPECTED |  | 
mask_predictions.dense.weight           | UNEXPECTED |  | 
mask_predictions.LayerNorm.weight       | UNEXPECTED |  | 
lm_predictions.lm_head.LayerNorm.weight | UNEXPECTED |  | 
mask_predictions.LayerNorm.bias         | UNEXPECTED |  | 
mask_predictions.classifier.weight      | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
Train epoch 1/4: 100%|███████████


=== Hazard-category (valid) ===
Accuracy: 0.46371681415929206
Macro-F1: 0.2501607729172971
Micro-F1: 0.46371681415929206
Weighted-F1: 0.47879384126282853

Classification Report:

                                precision    recall  f1-score   support

                     allergens       0.60      0.57      0.59       207
                    biological       0.70      0.44      0.54       194
                      chemical       0.18      0.50      0.26        28
food additives and flavourings       0.00      0.00      0.00         2
                foreign bodies       0.29      0.38      0.33        63
                         fraud       0.30      0.41      0.35        41
                     migration       0.00      0.00      0.00         0
          organoleptic aspects       0.00      0.00      0.00         8
                  other hazard       0.14      0.29      0.19        14
              packaging defect       0.00      0.00      0.00         8

                      accu

Train epoch 2/4: 100%|████████████████████████████████████████████████████| 318/318 [48:30<00:00,  9.15s/it, loss=4.52]
                                                                                                                       


=== Hazard-category (valid) ===
Accuracy: 0.8389380530973451
Macro-F1: 0.4859949506052377
Micro-F1: 0.8389380530973451
Weighted-F1: 0.8301882738287512

Classification Report:

                                precision    recall  f1-score   support

                     allergens       0.93      0.97      0.95       207
                    biological       0.90      0.95      0.93       194
                      chemical       0.79      0.54      0.64        28
food additives and flavourings       0.00      0.00      0.00         2
                foreign bodies       0.65      0.78      0.71        63
                         fraud       0.80      0.49      0.61        41
                     migration       0.00      0.00      0.00         0
          organoleptic aspects       0.00      0.00      0.00         8
                  other hazard       0.33      0.21      0.26        14
              packaging defect       0.33      0.25      0.29         8

                      accurac

Train epoch 3/4: 100%|████████████████████████████████████████████████████| 318/318 [48:29<00:00,  9.15s/it, loss=3.58]
                                                                                                                       


=== Hazard-category (valid) ===
Accuracy: 0.9061946902654867
Macro-F1: 0.6692479349530613
Micro-F1: 0.9061946902654867
Weighted-F1: 0.9004645325715366

Classification Report:

                                precision    recall  f1-score   support

                     allergens       0.92      1.00      0.96       207
                    biological       0.96      0.96      0.96       194
                      chemical       0.83      0.68      0.75        28
food additives and flavourings       0.00      0.00      0.00         2
                foreign bodies       0.91      0.95      0.93        63
                         fraud       0.88      0.56      0.69        41
                     migration       0.00      0.00      0.00         0
          organoleptic aspects       0.67      0.50      0.57         8
                  other hazard       0.44      0.50      0.47        14
              packaging defect       0.67      0.75      0.71         8

                      accurac

Train epoch 4/4: 100%|████████████████████████████████████████████████████| 318/318 [48:29<00:00,  9.15s/it, loss=2.99]
                                                                                                                       


=== Hazard-category (valid) ===
Accuracy: 0.9097345132743363
Macro-F1: 0.7534333642141887
Micro-F1: 0.9097345132743363
Weighted-F1: 0.9077597212370787

Classification Report:

                                precision    recall  f1-score   support

                     allergens       0.95      0.98      0.96       207
                    biological       0.97      0.95      0.96       194
                      chemical       0.78      0.75      0.76        28
food additives and flavourings       1.00      0.50      0.67         2
                foreign bodies       0.86      0.95      0.90        63
                         fraud       0.78      0.68      0.73        41
                     migration       0.00      0.00      0.00         0
          organoleptic aspects       0.67      0.50      0.57         8
                  other hazard       0.54      0.50      0.52        14
              packaging defect       0.67      0.75      0.71         8

                      accurac

                                                                                                                       


=== Hazard-category (test) ===
Accuracy: 0.8996990972918756
Macro-F1: 0.6411859062976499
Micro-F1: 0.8996990972918756
Weighted-F1: 0.9020635107831233

Classification Report:

                                precision    recall  f1-score   support

                     allergens       0.95      0.95      0.95       365
                    biological       0.99      0.94      0.96       343
                      chemical       0.75      0.87      0.80        52
food additives and flavourings       0.50      0.50      0.50         4
                foreign bodies       0.93      0.95      0.94       111
                         fraud       0.68      0.63      0.65        75
                     migration       0.00      0.00      0.00         1
          organoleptic aspects       0.67      0.60      0.63        10
                  other hazard       0.50      0.54      0.52        26
              packaging defect       0.33      0.70      0.45        10

                      accuracy



## Separate model (product-category only)



In [1]:
import os, json, random, gc, math
import numpy as np
import pandas as pd
from typing import Dict, Any

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm import tqdm


# -------------------------
# Config
# -------------------------
# Hyper-parameters and paths for the product-category-only run.
SEED = 42                                  # passed to seed_everything() below
MODEL_NAME = "microsoft/deberta-v3-base"   # checkpoint used for both the tokenizer and the encoder
MAX_LEN = 512                              # tokenizer max_length (inputs are truncated/padded to this)

# Memory-saving settings (small micro-batches, accumulation, mixed precision, checkpointing)
BATCH_SIZE = 4                  # micro-batch size used by every DataLoader
GRAD_ACCUM_STEPS = 4            # micro-batches accumulated per optimizer step
USE_AMP = True                  # enables autocast / GradScaler when running on CUDA
USE_GRAD_CHECKPOINTING = True   # enables gradient checkpointing on the encoder when it supports it

LR = 2e-5                       # AdamW learning rate
EPOCHS = 4                      # number of passes over the training loader
WARMUP_RATIO = 0.06             # fraction of total optimizer steps used for LR warmup
WEIGHT_DECAY = 0.01             # AdamW weight decay
GRAD_CLIP = 1.0                 # max gradient norm passed to clip_grad_norm_

# Input splits; each CSV must provide title, text, hazard-category and product-category columns.
TRAIN_CSV = "incidents_train.csv"
VALID_CSV = "incidents_valid.csv"
TEST_CSV  = "incidents_test.csv"

# Output directory for the label map, the best checkpoint and test metrics.
OUT_DIR = "sft_results_prod_only_len512_amp"
os.makedirs(OUT_DIR, exist_ok=True)


# -------------------------
# Repro
# -------------------------
def seed_everything(seed: int = 42):
    """Apply ``seed`` to the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

# Seed all RNGs before any model or loader is created.
seed_everything(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Drop unreferenced Python objects and release PyTorch's cached CUDA blocks
# (useful when a previous cell's model is still occupying GPU memory).
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()


# -------------------------
# Load data
# -------------------------
def load_df(path: str) -> pd.DataFrame:
    """Load an incidents CSV, check its schema and stringify the used columns.

    Args:
        path: location of the CSV file.

    Returns:
        The frame with ``title``, ``text``, ``hazard-category`` and
        ``product-category`` as strings; missing values become "".

    Raises:
        ValueError: naming the first required column that is absent.
    """
    needed = ("title", "text", "hazard-category", "product-category")
    table = pd.read_csv(path)

    absent = next((name for name in needed if name not in table.columns), None)
    if absent is not None:
        raise ValueError(f"Missing column: {absent}")

    for name in needed:
        # NaN -> "" first, so astype(str) never produces the literal "nan".
        table[name] = table[name].fillna("").astype(str)
    return table

# Load the three pre-split CSVs; load_df() validates columns and fills missing values.
train_df = load_df(TRAIN_CSV)
valid_df = load_df(VALID_CSV)
test_df  = load_df(TEST_CSV)

# Row counts per split, as a quick sanity check.
print(f"Loaded train={len(train_df)} valid={len(valid_df)} test={len(test_df)}")


# -------------------------
# Label maps (from TRAIN only)
# -------------------------
# Product vocabulary from the training split only, sorted so ids are stable across runs.
prod_labels = sorted(set(train_df["product-category"]))
id2prod = dict(enumerate(prod_labels))
prod2id = {label: idx for idx, label in id2prod.items()}

print("Product classes:", len(prod_labels))

# Persist the label order so predictions can be decoded outside this notebook.
with open(os.path.join(OUT_DIR, "label_maps.json"), "w", encoding="utf-8") as f:
    f.write(json.dumps({"prod_labels": prod_labels}, ensure_ascii=False, indent=2))


# -------------------------
# Class weights
# -------------------------
def make_class_weights(series: pd.Series, label2id: Dict[str,int]) -> torch.Tensor:
    """Return mean-normalised inverse-frequency weights, one per class.

    The raw weight of a class is ``len(series) / (num_classes * count)``;
    a label that does not occur in ``series`` is given a count of 1. The raw
    weights are then divided by their mean.

    Args:
        series: one label per training row.
        label2id: maps each label to its index in the returned tensor.

    Returns:
        float32 tensor with ``len(label2id)`` entries.
    """
    occurrences = series.value_counts()
    num_rows = len(series)
    num_classes = len(label2id)

    weights = np.zeros(num_classes, dtype=np.float32)
    for label, position in label2id.items():
        weights[position] = num_rows / (num_classes * occurrences.get(label, 1))

    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32)

# Class weights from the training split, moved to DEVICE for use as CrossEntropyLoss `weight=`.
prod_w = make_class_weights(train_df["product-category"], prod2id).to(DEVICE)


# -------------------------
# Dataset
# -------------------------
# Tokenizer matching MODEL_NAME; use_fast=False asks transformers for the slow (non-Rust) implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

class IncidentDataset(Dataset):
    """Map-style dataset that tokenises one incident report per access.

    Relies on the module-level ``tokenizer``, ``MAX_LEN`` and ``prod2id``.
    """

    def __init__(self, df: pd.DataFrame):
        """Keep a re-indexed frame and validate its product labels up front.

        Args:
            df: frame with ``title``, ``text`` and ``product-category`` columns.

        Raises:
            ValueError: if ``df`` contains a product label that is not present
                in the train-derived ``prod2id`` map.
        """
        self.df = df.reset_index(drop=True)

        # Fail fast: without this check an unseen label only surfaces when the
        # offending row is fetched, i.e. possibly after a full training epoch.
        unseen = sorted(set(self.df["product-category"]) - set(prod2id), key=str)
        if unseen:
            raise ValueError(
                "Found product-category label in valid/test not seen in train. "
                f"product-category={unseen}"
            )

    def __len__(self):
        """Return the number of rows."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenise row ``idx`` and return model inputs plus the product label id.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (each padded or
            truncated to ``MAX_LEN``) and a scalar long tensor ``prod_label``.

        Raises:
            ValueError: if the row's label is missing from ``prod2id`` (only
                reachable if ``self.df`` was modified after construction).
        """
        r = self.df.iloc[idx]
        # Title and body joined with a literal separator string.
        text = (r["title"] + " [SEP] " + r["text"]).strip()

        enc = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt"
        )

        prod_id = prod2id.get(r["product-category"], None)
        if prod_id is None:
            raise ValueError("Found product-category label in valid/test not seen in train.")

        # return_tensors="pt" adds a leading batch dimension of 1; drop it so
        # the DataLoader's default collation can stack items itself.
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "prod_label": torch.tensor(prod_id, dtype=torch.long),
        }

# One dataset per split, built in train / valid / test order.
train_ds, valid_ds, test_ds = (
    IncidentDataset(frame) for frame in (train_df, valid_df, test_df)
)

# Only the training loader shuffles; all other loader settings are shared.
train_loader, valid_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=0, pin_memory=True)
    for ds, do_shuffle in ((train_ds, True), (valid_ds, False), (test_ds, False))
)


# -------------------------
# Product-only Model
# -------------------------
class ProductClassifier(nn.Module):
    """Pretrained encoder followed by dropout and a single linear product head."""

    def __init__(self, base_model_name: str, n_prod: int, dropout: float = 0.1):
        """Load the encoder and build the classification head.

        Args:
            base_model_name: name or path handed to ``AutoModel.from_pretrained``.
            n_prod: number of product classes (output size of ``prod_head``).
            dropout: dropout probability applied to the pooled representation.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)

        # Gradient checkpointing is opt-in via the module-level flag and only
        # enabled when the loaded encoder exposes the hook.
        supports_ckpt = hasattr(self.encoder, "gradient_checkpointing_enable")
        if USE_GRAD_CHECKPOINTING and supports_ckpt:
            self.encoder.gradient_checkpointing_enable()

        width = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.prod_head = nn.Linear(in_features=width, out_features=n_prod)

    def forward(self, input_ids, attention_mask):
        """Return product-category logits for a batch of token ids."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # The first token's hidden state serves as the sequence summary.
        summary = self.dropout(hidden_states[:, 0, :])
        return self.prod_head(summary)

model = ProductClassifier(MODEL_NAME, n_prod=len(prod_labels))
model = model.to(DEVICE).float()


# -------------------------
# Optimizer / Scheduler
# -------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# The scheduler advances once per optimizer step, not per micro-batch, so the
# step count is the ceiling of batches / accumulation (the last step of an
# epoch may cover a shorter window).
steps_per_epoch = (len(train_loader) + GRAD_ACCUM_STEPS - 1) // GRAD_ACCUM_STEPS
total_steps = EPOCHS * steps_per_epoch
warmup_steps = int(WARMUP_RATIO * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
prod_loss_fn = nn.CrossEntropyLoss(weight=prod_w)

# Device-qualified torch.amp API; the scaler is a no-op unless AMP is on and we run on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(USE_AMP and DEVICE == "cuda"))


# -------------------------
# Eval helper
# -------------------------
@torch.no_grad()
def evaluate(loader: DataLoader, split_name: str):
    """Score the global ``model`` on ``loader`` and print a product-category report.

    Puts ``model`` in eval mode (it is not switched back here), runs every
    batch under autocast when AMP is enabled on CUDA, then prints accuracy,
    macro/micro/weighted F1 and a per-class classification report.

    Args:
        loader: yields batches with ``input_ids``, ``attention_mask`` and
            ``prod_label``.
        split_name: tag shown in the progress bar and the printed header.

    Returns:
        Dict with the keys ``acc``, ``macro_f1``, ``micro_f1`` and ``weighted_f1``.
    """
    model.eval()
    amp_on = USE_AMP and DEVICE == "cuda"
    targets, predictions = [], []

    for batch in tqdm(loader, desc=f"Eval {split_name}", leave=False):
        ids = batch["input_ids"].to(DEVICE, non_blocking=True)
        mask = batch["attention_mask"].to(DEVICE, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=amp_on):
            logits = model(ids, mask)

        targets += batch["prod_label"].tolist()
        predictions += logits.argmax(dim=-1).cpu().tolist()

    scores = {"acc": accuracy_score(targets, predictions)}
    for avg in ("macro", "micro", "weighted"):
        scores[f"{avg}_f1"] = f1_score(targets, predictions, average=avg, zero_division=0)

    print(f"\n=== Product-category ({split_name}) ===")
    print("Accuracy:", scores["acc"])
    print("Macro-F1:", scores["macro_f1"])
    print("Micro-F1:", scores["micro_f1"])
    print("Weighted-F1:", scores["weighted_f1"])
    print("\nClassification Report:\n")
    # Passing every label id keeps classes with zero support in the report.
    print(classification_report(
        targets, predictions,
        labels=list(range(len(prod_labels))),
        target_names=prod_labels,
        zero_division=0,
    ))

    return scores


# -------------------------
# Train
# -------------------------
# Best validation score seen so far and where its checkpoint is written.
best_score = -1.0
best_path = os.path.join(OUT_DIR, "best_model_prod.pt")

# Size of the final, shorter accumulation window of each epoch
# (0 when the number of batches is a multiple of GRAD_ACCUM_STEPS).
num_batches = len(train_loader)
tail_batches = num_batches % GRAD_ACCUM_STEPS

for epoch in range(1, EPOCHS + 1):
    model.train()
    pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{EPOCHS}")
    running_loss = 0.0

    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(pbar, start=1):
        input_ids = batch["input_ids"].to(DEVICE, non_blocking=True)
        attn = batch["attention_mask"].to(DEVICE, non_blocking=True)
        y = batch["prod_label"].to(DEVICE, non_blocking=True)

        # Average over the number of micro-batches that actually feed this
        # optimizer step; dividing the trailing partial window by the full
        # GRAD_ACCUM_STEPS would under-scale its gradient.
        in_tail = step > num_batches - tail_batches
        window = tail_batches if in_tail else GRAD_ACCUM_STEPS

        with torch.amp.autocast("cuda", enabled=(USE_AMP and DEVICE == "cuda")):
            logits = model(input_ids, attn)
            loss = prod_loss_fn(logits, y) / window

        scaler.scale(loss).backward()
        # Undo the accumulation scaling so the progress bar shows the mean per-batch loss.
        running_loss += loss.item() * window
        pbar.set_postfix(loss=running_loss / step)

        # Step at the end of every full window and once more on the last
        # batch, which flushes a trailing partial window.
        if step % GRAD_ACCUM_STEPS == 0 or step == num_batches:
            # Unscale first so clipping operates on the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()  # one scheduler step per optimizer step

    # Validate
    val_m = evaluate(valid_loader, "valid")
    score = val_m["macro_f1"]
    print(f"\nEpoch {epoch} product macro-F1 (valid) = {score:.4f}")

    if score > best_score:
        best_score = score
        # Store the score as a plain Python float: a NumPy scalar in the
        # checkpoint cannot be read back by torch.load(weights_only=True).
        torch.save(
            {"model_state_dict": model.state_dict(), "epoch": epoch, "score": float(score)},
            best_path,
        )
        print("✅ Saved new best:", best_path)

print("\nBest product macro-F1:", best_score)


# -------------------------
# Final test eval (load best)
# -------------------------
# weights_only=True restricts unpickling to tensors and primitive containers,
# which is all this checkpoint holds.
ckpt = torch.load(best_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
print("\nLoaded best checkpoint from epoch", ckpt["epoch"], "score", ckpt["score"])

test_metrics = evaluate(test_loader, "test")
with open(os.path.join(OUT_DIR, "test_metrics_prod.json"), "w", encoding="utf-8") as f:
    json.dump(test_metrics, f, ensure_ascii=False, indent=2)

print("\nSaved test metrics to:", os.path.join(OUT_DIR, "test_metrics_prod.json"))

Device: cuda
Loaded train=5082 valid=565 test=997
Product classes: 22


Loading weights:   0%|          | 0/198 [00:00<?, ?it/s]

[1mDebertaV2Model LOAD REPORT[0m from: microsoft/deberta-v3-base
Key                                     | Status     |  | 
----------------------------------------+------------+--+-
lm_predictions.lm_head.LayerNorm.bias   | UNEXPECTED |  | 
lm_predictions.lm_head.LayerNorm.weight | UNEXPECTED |  | 
lm_predictions.lm_head.dense.bias       | UNEXPECTED |  | 
mask_predictions.dense.bias             | UNEXPECTED |  | 
mask_predictions.LayerNorm.bias         | UNEXPECTED |  | 
mask_predictions.dense.weight           | UNEXPECTED |  | 
lm_predictions.lm_head.bias             | UNEXPECTED |  | 
mask_predictions.classifier.bias        | UNEXPECTED |  | 
mask_predictions.classifier.weight      | UNEXPECTED |  | 
mask_predictions.LayerNorm.weight       | UNEXPECTED |  | 
lm_predictions.lm_head.dense.weight     | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
Train epoch 1/4: 100%|███████████


=== Product-category (valid) ===
Accuracy: 0.20707964601769913
Macro-F1: 0.08868873882537093
Micro-F1: 0.20707964601769913
Weighted-F1: 0.20757837137610594

Classification Report:

                                                   precision    recall  f1-score   support

                              alcoholic beverages       0.00      0.00      0.00         7
                      cereals and bakery products       0.22      0.05      0.09        75
     cocoa and cocoa preparations, coffee and tea       0.06      0.60      0.11        15
                                    confectionery       0.06      0.38      0.11        26
dietetic foods, food supplements, fortified foods       0.00      0.00      0.00        14
                                    fats and oils       0.00      0.00      0.00         4
                                   feed materials       0.00      0.00      0.00         1
                   food additives and flavourings       0.00      0.00      0.00         

Train epoch 2/4: 100%|██████████████████████████████████████████████████| 1271/1271 [04:53<00:00,  4.33it/s, loss=2.45]
                                                                                                                       


=== Product-category (valid) ===
Accuracy: 0.479646017699115
Macro-F1: 0.2659434868607698
Micro-F1: 0.479646017699115
Weighted-F1: 0.4304328061767163

Classification Report:

                                                   precision    recall  f1-score   support

                              alcoholic beverages       0.26      1.00      0.41         7
                      cereals and bakery products       0.53      0.12      0.20        75
     cocoa and cocoa preparations, coffee and tea       0.19      0.93      0.31        15
                                    confectionery       0.12      0.08      0.10        26
dietetic foods, food supplements, fortified foods       0.14      0.07      0.10        14
                                    fats and oils       0.00      0.00      0.00         4
                                   feed materials       0.00      0.00      0.00         1
                   food additives and flavourings       0.00      0.00      0.00         0
    

Train epoch 3/4: 100%|██████████████████████████████████████████████████| 1271/1271 [04:53<00:00,  4.33it/s, loss=1.48]
                                                                                                                       


=== Product-category (valid) ===
Accuracy: 0.7185840707964601
Macro-F1: 0.5991021467618702
Micro-F1: 0.7185840707964601
Weighted-F1: 0.7169233868257296

Classification Report:

                                                   precision    recall  f1-score   support

                              alcoholic beverages       0.88      1.00      0.93         7
                      cereals and bakery products       0.64      0.63      0.63        75
     cocoa and cocoa preparations, coffee and tea       0.65      0.73      0.69        15
                                    confectionery       0.90      0.35      0.50        26
dietetic foods, food supplements, fortified foods       0.62      0.57      0.59        14
                                    fats and oils       1.00      0.50      0.67         4
                                   feed materials       0.00      0.00      0.00         1
                   food additives and flavourings       0.00      0.00      0.00         0
  

Train epoch 4/4: 100%|██████████████████████████████████████████████████| 1271/1271 [04:53<00:00,  4.33it/s, loss=1.01]
                                                                                                                       


=== Product-category (valid) ===
Accuracy: 0.7185840707964601
Macro-F1: 0.5910692627418885
Micro-F1: 0.7185840707964601
Weighted-F1: 0.7127597112599545

Classification Report:

                                                   precision    recall  f1-score   support

                              alcoholic beverages       0.78      1.00      0.88         7
                      cereals and bakery products       0.69      0.67      0.68        75
     cocoa and cocoa preparations, coffee and tea       0.55      0.73      0.63        15
                                    confectionery       0.69      0.42      0.52        26
dietetic foods, food supplements, fortified foods       0.57      0.57      0.57        14
                                    fats and oils       1.00      0.50      0.67         4
                                   feed materials       0.00      0.00      0.00         1
                   food additives and flavourings       0.00      0.00      0.00         0
  

                                                                                                                       


=== Product-category (test) ===
Accuracy: 0.7171514543630892
Macro-F1: 0.5343781588767038
Micro-F1: 0.7171514543630892
Weighted-F1: 0.7152838858456307

Classification Report:

                                                   precision    recall  f1-score   support

                              alcoholic beverages       1.00      0.88      0.93        16
                      cereals and bakery products       0.69      0.67      0.68       121
     cocoa and cocoa preparations, coffee and tea       0.72      0.81      0.76        42
                                    confectionery       0.70      0.21      0.33        33
dietetic foods, food supplements, fortified foods       0.50      0.46      0.48        26
                                    fats and oils       1.00      0.17      0.29         6
                                   feed materials       0.00      0.00      0.00         0
                   food additives and flavourings       0.00      0.00      0.00         4
   

