# Smart Product Categorization System

This notebook consolidates the entire project for training in Google Colab.

**Categories:** `beverages`, `snacks`, `dry_food`, `non_food`

**Architectures:** EfficientNet-B0 / ResNet-18 / ResNet-50 / MobileNet-V2 / MobileNet-V3-Large / SimpleCNN

## 0. Install Dependencies

In [None]:
!pip install -q torch torchvision scikit-learn matplotlib pandas Pillow tqdm huggingface-hub python-dotenv

## 1. Configuration

In [None]:
import os
import re
import csv
import json
import random
import tarfile
import argparse
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
    Callable, Dict, Iterable, List, Optional, Tuple, Union, Literal,
)

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.models import (
    EfficientNet_B0_Weights,
    ResNet18_Weights,
    MobileNet_V2_Weights,
)
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    print("huggingface_hub not installed — manual data upload required.")

In [None]:
# ─── Set your Hugging Face token here ─────────────────────────────────────────
# The token is resolved from, in order of precedence:
#   1. the literal below (paste directly),
#   2. Colab secrets (userdata.get("HF_TOKEN")),
#   3. an HF_TOKEN variable already exported in the process environment.
# Never share or commit a notebook with a real token pasted into option 1.

# Option 1: paste directly
HF_TOKEN = ""  # <-- paste your token or leave empty

# Option 2: use Colab secrets
if not HF_TOKEN:
    try:
        from google.colab import userdata
        HF_TOKEN = userdata.get('HF_TOKEN')
    except Exception as exc:  # not on Colab, secret missing, or access not granted
        # Best-effort lookup: say why it was skipped instead of hiding it.
        print(f"Colab secret HF_TOKEN unavailable ({type(exc).__name__}); trying environment.")

# Option 3: reuse a token already present in the environment (e.g. from .env)
if not HF_TOKEN:
    HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Normalize None (e.g. an empty secret) to "" so later checks see a str.
HF_TOKEN = HF_TOKEN or ""

if HF_TOKEN:
    os.environ["HF_TOKEN"] = HF_TOKEN
    print("HF_TOKEN set ✓")
else:
    print("⚠️  No HF_TOKEN found. Set it above or upload data manually.")

HF_TOKEN set ✓


## 2. Data Config

In [None]:
@dataclass
class DataConfig:
    """Settings for downloading, cleaning, validating and splitting the dataset.

    Fields are grouped by pipeline stage: Hugging Face source, local directory
    layout, label filtering / de-duplication, image validation, and the
    train/val/test split.
    """

    # ── Hugging Face source ──
    repo_id: str = "Phathanan/product-categorization-system"
    repo_type: str = "dataset"
    raw_tar_in_repo: str = "data/raw/data_v2.tar"
    token: Optional[str] = None
    revision: Optional[str] = None

    # ── Local layout ──
    dataset_dir: Path = Path("data_local")
    raw_extract_dirname: str = "raw_extracted"
    processed_dirname: str = "processed"
    raw_metadata_name: str = "metadata.csv"

    # ── Cleaning ──
    labels: List[str] = field(
        default_factory=lambda: ["beverages", "snacks", "dry_food", "non_food"]
    )
    dedup_by_barcode: bool = False
    cap_per_label: Optional[int] = None

    # ── Image validation ──
    min_side: int = 128
    do_verify: bool = False
    num_workers: int = 8

    # ── Splitting ──
    seed: int = 42
    train_frac: float = 0.70
    val_frac: float = 0.15
    test_frac: float = 0.15

    def paths(self) -> dict:
        """Return the derived on-disk locations for this dataset version.

        The archive file name minus its archive extension becomes the
        per-version sub-directory. For example, ``data/raw/data_v2.tar`` and
        ``data/raw/data_v2.tar.gz`` both map to ``.../data_v2``. Names without
        a recognised archive extension fall back to ``Path.stem``.
        """
        archive_name = Path(self.raw_tar_in_repo).name
        # Compound suffixes are checked explicitly; Path.stem alone would
        # leave "data_v2.tar" for "data_v2.tar.gz".
        for suffix in (".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tar"):
            if archive_name.endswith(suffix):
                tar_stem = archive_name[: -len(suffix)]
                break
        else:
            tar_stem = Path(archive_name).stem

        raw_dir = self.dataset_dir / self.raw_extract_dirname / tar_stem
        proc_dir = self.dataset_dir / self.processed_dirname / tar_stem

        return {
            "raw_dir": raw_dir,
            "proc_dir": proc_dir,
            "raw_metadata": raw_dir / self.raw_metadata_name,
            "manifest_clean": proc_dir / "manifest_clean.csv",
            "splits": proc_dir / "splits.json",
            "stats": proc_dir / "stats.json",
            "label_map": proc_dir / "label_map.json",
        }

## 3. Train Config

In [None]:
@dataclass
class TrainConfig:
    """Hyper-parameters and input/output paths for a single training run.

    ``run_dir()`` gives ``output_dir / model_name``, creating it on demand,
    as the place where a run's artifacts go.
    """

    # ── Model ──
    model_name: str = "resnet18"
    freeze_backbone: bool = True
    dropout: float = 0.5

    # ── Inputs produced by the data-preparation step ──
    manifest: Path = Path("data_local/processed/data_v2/manifest_clean.csv")
    label_map: Path = Path("data_local/processed/data_v2/label_map.json")

    # ── Optimisation ──
    epochs: int = 20
    batch_size: int = 32
    num_workers: int = 2
    lr: float = 1e-3
    weight_decay: float = 1e-2

    # ── LR schedule (step size / gamma presumably apply to a step scheduler) ──
    lr_scheduler: str = "cosine"
    lr_step_size: int = 7
    lr_gamma: float = 0.1

    # ── Early stopping ──
    early_stop_patience: int = 3
    early_stop_min_delta: float = 1e-4
    early_stop_metric: str = "val_loss"

    # ── Preprocessing ──
    image_size: int = 224

    # ── Output ──
    output_dir: Path = Path("outputs")

    # ── Reproducibility / hardware ──
    seed: int = 42
    device: Optional[str] = None

    def run_dir(self) -> Path:
        """Return ``output_dir / model_name``, creating the directory if needed."""
        target = self.output_dir / self.model_name
        target.mkdir(exist_ok=True, parents=True)
        return target

## 4. Data Utilities — Loader, Prepare, Validate, Split, Stats

In [None]:
# ─── loader.py ────────────────────────────────────────────────────────────────

def _extract_tar(tar_path: Path, out_dir: Path) -> None:
    out_dir.mkdir(parents=True, exist_ok=True)
    base = out_dir.resolve()
    with tarfile.open(tar_path, "r:*") as tf:
        members = tf.getmembers()
        for m in members:
            target = (out_dir / m.name).resolve()
            if not str(target).startswith(str(base)):
                raise RuntimeError(f"Unsafe path in tar: {m.name}")
        tf.extractall(out_dir)


def download_raw_tar(
    repo_id: str,
    path_in_repo: str,
    repo_type: str = "dataset",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Fetch ``path_in_repo`` from a Hugging Face Hub repo and return its local path.

    Thin wrapper over ``huggingface_hub.hf_hub_download``. All arguments are
    forwarded unchanged (``path_in_repo`` becomes ``filename``), and the
    returned path string is wrapped in a ``Path``.
    """
    request = dict(
        repo_id=repo_id,
        filename=path_in_repo,
        repo_type=repo_type,
        revision=revision,
        token=token,
    )
    local_file = hf_hub_download(**request)
    return Path(local_file)


def ensure_extracted(raw_tar: Path, extract_dir: Path) -> None:
    """Extract ``raw_tar`` into ``extract_dir`` at most once.

    The ``.extracted.ok`` marker is written only after ``_extract_tar``
    returns. A completed extraction is therefore skipped on later calls, and
    an interrupted one is retried.
    """
    done_marker = extract_dir / ".extracted.ok"
    if not done_marker.exists():
        extract_dir.mkdir(parents=True, exist_ok=True)
        _extract_tar(raw_tar, extract_dir)
        done_marker.write_text("ok", encoding="utf-8")

In [None]:
# ─── prepare.py ───────────────────────────────────────────────────────────────

def norm_barcode(x: object) -> str:
    """Normalise a raw barcode value to a digits-only, zero-padded string.

    - Missing values (None, NaN, pd.NA) and other falsy values map to "".
    - Integral floats such as ``8851234567890.0`` are turned back into integer
      digits first. pandas yields these for numeric columns that contain NaN.
      Without this step the ".0" would add a spurious trailing 0.
    - Strings ending in ".0" get the same treatment.
    - Remaining non-digit characters are dropped, and the result is
      left-padded to 13 digits (presumably EAN-13 width). Longer codes are
      kept as-is.
    """
    try:
        if pd.isna(x):
            return ""
    except (TypeError, ValueError):
        pass  # non-scalar input; fall through to plain string handling
    if isinstance(x, float) and x.is_integer():  # also covers numpy.float64
        x = int(x)
    s = re.sub(r"\.0+$", "", str(x or ""))
    s = re.sub(r"\D", "", s)
    return s.zfill(13) if s else ""


def load_metadata(meta_path: Path) -> pd.DataFrame:
    """Read the raw metadata CSV using pandas' default type inference."""
    return pd.read_csv(meta_path)


def add_paths(df: pd.DataFrame, raw_dir: Path) -> pd.DataFrame:
    """Return a copy of ``df`` with normalised ids and an ``abs_path`` column.

    Images are looked up under ``raw_dir/images`` when that directory exists,
    otherwise directly under ``raw_dir``. ``barcode`` goes through
    ``norm_barcode``. ``image_id`` is cast to the pandas string dtype,
    missing values become "", and surrounding whitespace is stripped. The id
    is then joined onto the image root after "/" is converted to the OS
    separator and leading separators are removed, so a leading "/" cannot
    make the join absolute.
    """
    root = Path(raw_dir)
    candidate = root / "images"
    image_root = candidate if candidate.exists() else root

    out = df.copy()
    out["barcode"] = out["barcode"].map(norm_barcode)
    out["image_id"] = out["image_id"].astype("string").fillna("").str.strip()

    relative = out["image_id"].str.replace("/", os.sep, regex=False)
    relative = relative.str.lstrip("\\/")
    out["abs_path"] = relative.map(lambda rel: str(image_root / rel))
    return out


def basic_clean(
    df: pd.DataFrame,
    labels: Optional[list] = None,
    dedup_by_barcode: bool = True,
    cap_per_label: Optional[int] = None,
    seed: int = 42,
) -> pd.DataFrame:
    """Filter, de-duplicate and optionally down-sample the metadata table.

    Steps, in order:
      1. strip ``label_coarse``; drop rows whose barcode / image_id / abs_path
         string is empty;
      2. keep only ``labels`` when given;
      3. drop duplicate ``abs_path`` rows (first wins);
      4. if ``dedup_by_barcode``, keep one row per barcode: the first after
         sorting by (barcode, label_coarse, image_id);
      5. if ``cap_per_label`` is set, randomly keep at most that many rows per
         label, using ``seed``.

    Returns a new DataFrame with a fresh 0..n-1 index.

    Raises:
        ValueError: if the ``label_coarse`` column is missing.
    """
    df = df.copy()

    if "label_coarse" not in df.columns:
        raise ValueError("metadata must contain label_coarse")

    df["label_coarse"] = df["label_coarse"].astype(str).str.strip()
    df = df[df["barcode"].astype(str).str.len() > 0]
    df = df[df["image_id"].astype(str).str.len() > 0]
    df = df[df["abs_path"].astype(str).str.len() > 0]

    if labels:
        df = df[df["label_coarse"].isin(labels)]

    df = df.drop_duplicates(subset=["abs_path"], keep="first")

    if dedup_by_barcode:
        df = df.sort_values(["barcode", "label_coarse", "image_id"])
        df = df.drop_duplicates(subset=["barcode"], keep="first")

    if cap_per_label is not None:
        rng = np.random.default_rng(seed)
        kept = []
        for _, group in df.groupby("label_coarse"):
            if len(group) <= cap_per_label:
                kept.append(group)
            else:
                # Sample row *positions* and select with iloc. Label-based
                # df.loc[...] would return every row sharing an index label
                # when the incoming index has duplicates.
                pos = rng.choice(len(group), size=cap_per_label, replace=False)
                kept.append(group.iloc[pos])
        # pd.concat([]) raises, so an empty (post-filter) frame is left as-is.
        if kept:
            df = pd.concat(kept, ignore_index=True)

    return df.reset_index(drop=True)


def attach_label_map(labels: list) -> dict:
    """Map each label to its position in ``labels`` (first label -> 0).

    If a label repeats, its last position wins.
    """
    mapping = {}
    for position, label in enumerate(labels):
        mapping[label] = position
    return mapping

In [None]:
# ─── validate.py ──────────────────────────────────────────────────────────────

def _check_one(path: str, min_side: int, do_verify: bool) -> Tuple[int, int, int, int]:
    p = Path(path)
    if not p.exists():
        return 0, 0, 0, 0
    try:
        if do_verify:
            with Image.open(p) as im:
                im.verify()
        with Image.open(p) as im:
            w, h = im.size
        ok = int(min(w, h) >= min_side)
        size = int(p.stat().st_size)
        return ok, w, h, size
    except Exception:
        return 0, 0, 0, 0


def validate_images(
    df: pd.DataFrame,
    min_side: int = 128,
    do_verify: bool = False,
    num_workers: int = 8,
) -> pd.DataFrame:
    """Check every ``abs_path`` in parallel and attach the results as columns.

    Each path is inspected with ``_check_one`` on a thread pool of
    ``max(1, num_workers)`` threads. Returns a copy of ``df`` with added
    columns ``img_ok``, ``w``, ``h`` and ``file_size``, aligned to the
    original row order.

    Raises:
        ValueError: if ``abs_path`` is missing.
    """
    if "abs_path" not in df.columns:
        raise ValueError("df must contain abs_path")

    paths = df["abs_path"].astype(str).tolist()
    outcomes = [None] * len(paths)
    workers = max(1, int(num_workers))

    with ThreadPoolExecutor(max_workers=workers) as pool:
        index_of = {
            pool.submit(_check_one, p, min_side, do_verify): i
            for i, p in enumerate(paths)
        }
        # Futures finish out of order; write each result back to its row slot.
        for done in tqdm(as_completed(index_of), total=len(index_of), desc="validate images"):
            outcomes[index_of[done]] = done.result()

    out = df.copy()
    for column, pos in (("img_ok", 0), ("w", 1), ("h", 2), ("file_size", 3)):
        out[column] = [o[pos] for o in outcomes]
    return out


def keep_only_ok(df: pd.DataFrame) -> pd.DataFrame:
    """Return only rows whose ``img_ok`` equals 1, re-indexed from 0.

    Raises:
        ValueError: if ``img_ok`` is missing.
    """
    if "img_ok" not in df.columns:
        raise ValueError("df must contain img_ok")
    passed = df["img_ok"].eq(1)
    return df.loc[passed].copy().reset_index(drop=True)

In [None]:
# ─── split.py ─────────────────────────────────────────────────────────────────

@dataclass
class SplitConfig:
    """Seed and split fractions consumed by ``split_by_barcode``.

    Allocation per label is done by ``_alloc_counts``, which sizes val and
    test from their fractions; train receives the remainder.
    """

    seed: int = field(default=42)
    train_frac: float = field(default=0.8)
    val_frac: float = field(default=0.1)
    test_frac: float = field(default=0.1)


def _alloc_counts(n: int, train_f: float, val_f: float, test_f: float) -> Tuple[int, int, int]:
    if n <= 0:
        return 0, 0, 0
    if n == 1:
        return 1, 0, 0
    if n == 2:
        return 1, 1, 0
    val = max(1, int(round(n * val_f)))
    test = max(1, int(round(n * test_f)))
    if val + test >= n:
        val = 1
        test = 1
    train = n - val - test
    if train <= 0:
        train = max(1, n - 2)
        val = 1 if n - train >= 1 else 0
        test = n - train - val
    return train, val, test


def split_by_barcode(df: pd.DataFrame, cfg: SplitConfig) -> Tuple[pd.DataFrame, Dict]:
    """Assign each row to train/val/test so that a barcode lives in one split.

    Within every ``label_coarse`` group, the unique barcodes are shuffled
    with a generator seeded by ``cfg.seed`` and cut into train/val/test by
    ``_alloc_counts``. Rows inherit their barcode's split, and unassigned
    barcodes default to "train". A barcode listed under several labels keeps
    the split from the last group that assigned it.

    Returns:
        (copy of ``df`` with a ``split`` column, metadata dict holding the
        seed, fractions, per-split barcode counts and lists, and the
        barcode -> label mapping).

    Raises:
        ValueError: if ``barcode`` or ``label_coarse`` is missing.
    """
    if "barcode" not in df.columns or "label_coarse" not in df.columns:
        raise ValueError("df must contain barcode and label_coarse")

    rng = np.random.default_rng(cfg.seed)
    pairs = df[["barcode", "label_coarse"]].drop_duplicates()
    barcode_to_label = dict(zip(pairs["barcode"], pairs["label_coarse"]))

    split_map: Dict[str, str] = {}
    for _, group in pairs.groupby("label_coarse"):
        shuffled = group["barcode"].tolist()
        rng.shuffle(shuffled)
        n_train, n_val, n_test = _alloc_counts(
            len(shuffled), cfg.train_frac, cfg.val_frac, cfg.test_frac
        )
        cut_val = n_train
        cut_test = n_train + n_val
        cut_end = cut_test + n_test
        for name, lo, hi in (("train", 0, cut_val), ("val", cut_val, cut_test), ("test", cut_test, cut_end)):
            for barcode in shuffled[lo:hi]:
                split_map[barcode] = name

    out = df.copy()
    out["split"] = out["barcode"].map(split_map).fillna("train")

    splits = {name: [] for name in ("train", "val", "test")}
    for barcode, name in split_map.items():
        splits[name].append(barcode)

    meta = {
        "seed": cfg.seed,
        "fractions": {
            "train": cfg.train_frac,
            "val": cfg.val_frac,
            "test": cfg.test_frac,
        },
        "counts": {name: len(members) for name, members in splits.items()},
        "splits": splits,
        "barcode_label": barcode_to_label,
    }
    return out, meta


def save_splits_json(meta: Dict, out_path: Path) -> None:
    """Write ``meta`` as 2-space-indented UTF-8 JSON, creating parent dirs.

    ``ensure_ascii=False`` keeps non-ASCII text (e.g. labels) human-readable.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(meta, ensure_ascii=False, indent=2)
    out_path.write_text(payload, encoding="utf-8")

In [None]:
# ─── stats.py ─────────────────────────────────────────────────────────────────

def compute_stats(df: pd.DataFrame) -> Dict:
    """Summarise row counts per label, per split and per (label, split).

    Sections whose source column is absent are returned empty.
    ``img_ok_rate`` is the mean of ``img_ok``: 0.0 for an empty frame, and
    None when the column does not exist.
    """
    has_label = "label_coarse" in df.columns
    has_split = "split" in df.columns

    by_label = df["label_coarse"].value_counts().to_dict() if has_label else {}
    by_split = df["split"].value_counts().to_dict() if has_split else {}

    by_label_split: Dict = {}
    if has_label and has_split:
        sizes = df.groupby(["label_coarse", "split"]).size()
        for (lbl, split_name), count in sizes.items():
            by_label_split.setdefault(lbl, {})[split_name] = int(count)

    if "img_ok" not in df.columns:
        img_ok_rate = None
    elif len(df):
        img_ok_rate = float(df["img_ok"].mean())
    else:
        img_ok_rate = 0.0

    return {
        "total": int(len(df)),
        "by_label": {k: int(v) for k, v in by_label.items()},
        "by_split": {k: int(v) for k, v in by_split.items()},
        "by_label_split": by_label_split,
        "img_ok_rate": img_ok_rate,
    }


def save_stats(stats: Dict, out_path: Path) -> None:
    """Persist ``stats`` as pretty-printed UTF-8 JSON, creating parent dirs."""
    serialized = json.dumps(stats, ensure_ascii=False, indent=2)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(serialized, encoding="utf-8")

## 5. Transforms

In [None]:
# Per-channel (R, G, B) mean / std used by Normalize in both pipelines below.
# These are the standard ImageNet statistics, matching the ImageNet-pretrained
# torchvision weights loaded in the models section.
IMAGENET_MEAN: List[float] = [0.485, 0.456, 0.406]
IMAGENET_STD: List[float] = [0.229, 0.224, 0.225]


def get_train_transforms(size: int = 224) -> transforms.Compose:
    """Build the training augmentation pipeline (PIL image -> normalized tensor).

    Order: resize to ``size + 32`` -> random resized crop to ``size`` ->
    TrivialAugmentWide -> random affine -> horizontal flip -> Gaussian blur
    (p = 0.3) -> ToTensor -> ImageNet normalization.
    """
    # 1. Initial resize, then a random crop keeping 70-100% of the area at
    #    aspect ratio 0.75-1.33, resampled to `size`.
    geometry = [
        transforms.Resize(size + 32),
        transforms.RandomResizedCrop(
            size,
            scale=(0.7, 1.0),
            ratio=(0.75, 1.33),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
    ]

    appearance = [
        # 2. TrivialAugmentWide: applies one randomly chosen augmentation per image.
        transforms.TrivialAugmentWide(),
        # 3. RandomAffine: robustness to rotation (degrees), translation,
        #    scaling and shear.
        transforms.RandomAffine(
            degrees=20,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            shear=10
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        # 4. Gaussian blur applied with probability 30%.
        # NOTE(review): kernel_size=(3, 7) is a fixed 3-wide x 7-tall kernel,
        # not a random size in [3, 7]; confirm the anisotropic blur is intended.
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 2.0))
        ], p=0.3),
    ]

    # 5. Convert to tensor and normalize with ImageNet statistics.
    to_model_input = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]

    return transforms.Compose(geometry + appearance + to_model_input)


def get_val_transforms(size: int = 224) -> transforms.Compose:
    """Deterministic evaluation pipeline: resize, center-crop to ``size``, normalize.

    The resize target is ``size + 32``, the same pre-crop margin the training
    pipeline uses. For the default 224 this is 256, so the default behaviour
    is unchanged. For larger ``size`` values the resized image stays bigger
    than the crop. Previously a fixed 256 made CenterCrop pad with black
    borders whenever size > 256.
    """
    return transforms.Compose(
        [
            transforms.Resize(size + 32),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )

## 6. Dataset

In [None]:
class ProductDataset(Dataset):
    """
    PyTorch Dataset for product package image classification.

    Parameters
    ----------
    manifest : pd.DataFrame | Path | str
        Either a pre-loaded DataFrame or a path to manifest_clean.csv.
        The DataFrame must contain columns: ``abs_path``, ``label_coarse``.
    label_map : Dict[str, int] | Path | str
        Either a pre-built {class_name: int} dict or a path to label_map.json.
    split : str | None
        If not None, keep only rows belonging to that split.
        Membership comes from the manifest's ``split`` column when it exists.
        Otherwise, for path input only, it comes from ``splits.json`` next to
        the manifest, matched on a normalized ``barcode``: digits only,
        integral float / trailing ``.0`` undone, zero-padded to 13. This lets
        barcodes whose leading zeros were dropped by ``pd.read_csv`` still
        match. Typical values: ``"train"``, ``"val"``, ``"test"``.
    transform : Callable | None
        torchvision transform pipeline applied to the PIL image.
        Use ``get_train_transforms()`` / ``get_val_transforms()`` from transforms.py.

    Returns  (via __getitem__)
    -------
    image  : torch.FloatTensor  shape (C, H, W) after transform
    label  : torch.LongTensor   scalar integer class index

    Raises
    ------
    ValueError
        Missing required columns, an unresolvable or empty split, or labels
        absent from ``label_map``.
    """

    def __init__(
        self,
        manifest: Union[pd.DataFrame, Path, str],
        label_map: Union[Dict[str, int], Path, str],
        split: Optional[str] = None,
        transform: Optional[Callable] = None,
    ) -> None:
        # ── 1. Load manifest ──────────────────────────────────────────────
        manifest_path = manifest if isinstance(manifest, (str, Path)) else None
        if manifest_path is not None:
            manifest = pd.read_csv(manifest_path)

        if "abs_path" not in manifest.columns:
            raise ValueError("manifest must contain column 'abs_path'")
        if "label_coarse" not in manifest.columns:
            raise ValueError("manifest must contain column 'label_coarse'")

        # ── 2. Filter by split ────────────────────────────────────────────
        if split is not None:
            if "split" in manifest.columns:
                manifest = manifest[manifest["split"] == split].copy()
            elif manifest_path is None:
                raise ValueError(
                    f"split='{split}' requested but manifest has no 'split' column. "
                    "Run scripts/prepare_dataset.py first."
                )
            else:
                manifest = self._filter_by_splits_json(manifest, Path(manifest_path), split)

            if len(manifest) == 0:
                raise ValueError(
                    f"No rows found for split='{split}'. "
                    "Check that prepare_dataset.py completed successfully."
                )

        self._df = manifest.reset_index(drop=True)

        # ── 3. Load label_map ─────────────────────────────────────────────
        if isinstance(label_map, (str, Path)):
            label_map = json.loads(Path(label_map).read_text(encoding="utf-8"))

        self._label_map: Dict[str, int] = label_map

        # Fail fast: every label in the manifest must have a class index.
        unknown = set(self._df["label_coarse"].unique()) - set(self._label_map.keys())
        if unknown:
            raise ValueError(
                f"Labels found in manifest but missing from label_map: {unknown}"
            )

        self.transform = transform
        # Class names ordered by their integer index.
        self.classes: list = sorted(self._label_map, key=self._label_map.get)  # type: ignore[arg-type]
        self.num_classes: int = len(self._label_map)

    # ── Split helpers ─────────────────────────────────────────────────────

    @staticmethod
    def _normalize_barcode(value: object) -> str:
        """Canonical barcode join key: digits only, zero-padded to 13; missing -> ""."""
        if value is None:
            return ""
        if isinstance(value, float):
            if value != value:  # NaN
                return ""
            if value.is_integer():  # e.g. 8851234567890.0 from a NaN-bearing column
                value = int(value)
        text = re.sub(r"\.0+$", "", str(value).strip())
        digits = re.sub(r"\D", "", text)
        return digits.zfill(13) if digits else ""

    @classmethod
    def _filter_by_splits_json(
        cls, manifest: pd.DataFrame, manifest_path: Path, split: str
    ) -> pd.DataFrame:
        """Keep manifest rows whose barcode is listed under ``split`` in splits.json."""
        splits_path = manifest_path.parent / "splits.json"
        if not splits_path.exists():
            raise ValueError(f"split '{split}' requested, no 'split' column, and no splits.json found at {splits_path}")
        if "barcode" not in manifest.columns:
            raise ValueError(
                f"split '{split}' requested via {splits_path}, but manifest has no 'barcode' column"
            )
        splits_data = json.loads(splits_path.read_text(encoding="utf-8"))
        wanted = {
            cls._normalize_barcode(b)
            for b in splits_data.get("splits", {}).get(split, [])
        }
        wanted.discard("")  # never match rows whose barcode is missing
        keys = manifest["barcode"].map(cls._normalize_barcode)
        return manifest[keys.isin(wanted)].copy()

    # ── Dataset protocol ──────────────────────────────────────────────────

    def __len__(self) -> int:
        return len(self._df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self._df.iloc[idx]

        # ── Load image (PIL, RGB) ────────────────────────────────────────
        img_path = Path(str(row["abs_path"]))
        try:
            # `with` closes the file handle; convert() returns an independent copy.
            with Image.open(img_path) as raw:
                image = raw.convert("RGB")
        except Exception as exc:
            raise RuntimeError(
                f"Cannot open image at index {idx}: {img_path}"
            ) from exc

        # ── Apply transform pipeline ─────────────────────────────────────
        if self.transform is not None:
            image = self.transform(image)
        else:
            # Fallback: at minimum convert PIL -> tensor if no transform supplied
            from torchvision.transforms.functional import to_tensor
            image = to_tensor(image)  # type: ignore[assignment]

        # ── Encode label ─────────────────────────────────────────────────
        label_int: int = self._label_map[row["label_coarse"]]
        label = torch.tensor(label_int, dtype=torch.long)

        return image, label  # type: ignore[return-value]

    # ── Convenience helpers ───────────────────────────────────────────────

    def __repr__(self) -> str:
        return (
            f"ProductDataset("
            f"n={len(self)}, "
            f"split={self._df['split'].unique().tolist() if 'split' in self._df.columns else 'N/A'}, "
            f"classes={self.classes}"
            f")"
        )

    @property
    def label_map(self) -> Dict[str, int]:
        """Copy of the {class_name: index} mapping (callers cannot mutate ours)."""
        return dict(self._label_map)


# ── Convenience factory ───────────────────────────────────────────────────────

def build_datasets(
    manifest_path: Union[Path, str],
    label_map_path: Union[Path, str],
    train_transform: Optional[Callable] = None,
    val_transform: Optional[Callable] = None,
) -> Dict[str, "ProductDataset"]:
    """
    Build train / val / test datasets in one call.

    The manifest CSV is read once. If it has no ``split`` column, splits are
    taken from ``splits.json`` next to it, when that file exists. Barcodes on
    both sides are normalized before matching: digits only, integral float /
    trailing ``.0`` undone, zero-padded to 13. Without this, leading zeros
    dropped by ``pd.read_csv`` would silently empty a split. Rows whose
    barcode appears in no split are dropped. The train split gets
    ``train_transform``; val and test get ``val_transform``.

    Raises
    ------
    ValueError
        If splits.json must be used but the manifest has no ``barcode``
        column, or (from ProductDataset) if a split is unresolvable or empty.

    Example
    -------
    >>> from src.data.transforms import get_train_transforms, get_val_transforms
    >>> from src.config.data_config import DataConfig
    >>> cfg = DataConfig()
    >>> p = cfg.paths()
    >>> datasets = build_datasets(
    ...     p["manifest_clean"], p["label_map"],
    ...     train_transform=get_train_transforms(),
    ...     val_transform=get_val_transforms(),
    ... )
    >>> datasets["train"], datasets["val"], datasets["test"]
    """

    def _barcode_key(value: object) -> str:
        # Canonical join key shared by manifest rows and splits.json entries.
        if value is None:
            return ""
        if isinstance(value, float):
            if value != value:  # NaN
                return ""
            if value.is_integer():
                value = int(value)
        digits = re.sub(r"\D", "", re.sub(r"\.0+$", "", str(value).strip()))
        return digits.zfill(13) if digits else ""

    manifest = pd.read_csv(manifest_path)

    if "split" not in manifest.columns:
        splits_path = Path(manifest_path).parent / "splits.json"
        if splits_path.exists():
            if "barcode" not in manifest.columns:
                raise ValueError(
                    f"{splits_path} found but manifest has no 'barcode' column to join on"
                )
            splits_data = json.loads(splits_path.read_text(encoding="utf-8"))
            barcode_to_split = {
                _barcode_key(bc): sp
                for sp, bcs in splits_data.get("splits", {}).items()
                for bc in bcs
            }
            barcode_to_split.pop("", None)  # never match rows with a missing barcode
            manifest["split"] = manifest["barcode"].map(_barcode_key).map(barcode_to_split)
            manifest = manifest.dropna(subset=["split"])

    label_map: Dict[str, int] = json.loads(
        Path(label_map_path).read_text(encoding="utf-8")
    )

    ds: Dict[str, ProductDataset] = {}
    for split_name in ("train", "val", "test"):
        transform = train_transform if split_name == "train" else val_transform
        ds[split_name] = ProductDataset(
            manifest=manifest,
            label_map=label_map,
            split=split_name,
            transform=transform,
        )
    return ds

## 7. Models

In [None]:
# ─── ProductClassifier (EfficientNet-B0) ──────────────────────────────────────

class ProductClassifier(nn.Module):
    """EfficientNet-B0 classifier: ``features`` trunk -> global avg pool -> head.

    The head is Flatten -> Dropout -> Linear(1280, num_classes), with the
    Linear layer Xavier-initialised and its bias zeroed.

    Args:
        num_classes: number of output logits.
        freeze_backbone: if True, backbone parameters start with
            ``requires_grad=False``, so only the head trains.
        dropout: dropout probability before the final Linear layer.
        pretrained: load torchvision's ``IMAGENET1K_V1`` weights when True,
            random initialisation otherwise.
    """

    # Channel width of EfficientNet-B0's last feature stage (input to the head).
    BACKBONE_OUT_FEATURES: int = 1280

    def __init__(
        self,
        num_classes: int = 4,
        freeze_backbone: bool = True,
        dropout: float = 0.3,
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.efficientnet_b0(weights=weights)

        # Keep only the convolutional trunk; pooling and head are rebuilt here.
        self.backbone: nn.Module = backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=dropout),
            nn.Linear(self.BACKBONE_OUT_FEATURES, num_classes),
        )

        nn.init.xavier_uniform_(self.head[2].weight)
        nn.init.zeros_(self.head[2].bias)

        self.num_classes = num_classes
        # True only while every backbone parameter is frozen.
        self._backbone_frozen = False

        if freeze_backbone:
            self.freeze_backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for an image batch."""
        x = self.backbone(x)
        x = self.pool(x)
        x = self.head(x)
        return x

    def freeze_backbone(self) -> None:
        """Disable gradients for all backbone parameters (head stays trainable)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        self._backbone_frozen = True

    def unfreeze_backbone(self) -> None:
        """Enable gradients for all backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        self._backbone_frozen = False

    def unfreeze_last_n_blocks(self, n: int = 3) -> None:
        """Enable gradients for the last ``n`` top-level backbone blocks.

        ``n <= 0`` is a no-op. Previously ``blocks[-0:]`` selected every block
        and so unfroze the whole backbone.
        """
        if n <= 0:
            return
        blocks = list(self.backbone.children())
        for block in blocks[-n:]:
            for param in block.parameters():
                param.requires_grad = True
        # Keep the repr flag truthful after a partial unfreeze.
        self._backbone_frozen = not any(p.requires_grad for p in self.backbone.parameters())

    def trainable_params(self) -> int:
        """Number of scalar parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def total_params(self) -> int:
        """Number of scalar parameters in the whole model."""
        return sum(p.numel() for p in self.parameters())

    def param_summary(self) -> Dict[str, int]:
        """Return total / trainable / frozen parameter counts."""
        total = self.total_params()
        trainable = self.trainable_params()
        return {"total": total, "trainable": trainable, "frozen": total - trainable}

    def save(self, path: Path) -> None:
        """Write ``{"model_state_dict", "num_classes"}`` to ``path`` (dirs created)."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "num_classes": self.num_classes,
            },
            path,
        )

    @classmethod
    def load(cls, path: Path, num_classes: Optional[int] = None, map_location: str = "cpu") -> "ProductClassifier":
        """Rebuild a model from a checkpoint written by ``save``.

        ``num_classes`` overrides the stored value when given, and must match
        the checkpoint's head shape for ``load_state_dict`` to succeed. The
        model is built without pretrained weights and unfrozen, then the
        stored state is loaded.
        """
        ckpt = torch.load(path, map_location=map_location)
        nc = num_classes if num_classes is not None else ckpt["num_classes"]
        model = cls(num_classes=nc, freeze_backbone=False, pretrained=False)
        model.load_state_dict(ckpt["model_state_dict"])
        return model

    def __repr__(self) -> str:
        return (
            f"ProductClassifier("
            f"backbone=EfficientNet-B0, "
            f"num_classes={self.num_classes}, "
            f"frozen={self._backbone_frozen}, "
            f"trainable_params={self.trainable_params():,})"
        )

In [None]:
# ─── SimpleCNN ────────────────────────────────────────────────────────────────

class SimpleCNN(nn.Module):
    """Small from-scratch CNN baseline.

    Four stages of Conv3x3 (no bias) -> BatchNorm -> ReLU -> MaxPool2 widen
    the channels 3 -> 32 -> 64 -> 128 -> 256. Adaptive average pooling to 4x4
    follows, then an MLP head: Flatten -> Linear(4096, 512) -> ReLU ->
    Dropout -> Linear(512, num_classes). There is no pretrained backbone, so
    the freeze/unfreeze hooks are no-ops kept for interface parity.
    """

    def __init__(self, num_classes: int = 4, dropout: float = 0.3) -> None:
        super().__init__()

        stages = []
        in_ch = 3
        for out_ch in (32, 64, 128, 256):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            in_ch = out_ch
        self.features = nn.Sequential(*stages)

        # Fixed 4x4 grid so the head size does not depend on the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))

        flat_dim = 256 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, num_classes),
        )

        self.num_classes = num_classes
        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal convs, unit/zero BatchNorm, Xavier-uniform linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits."""
        return self.classifier(self.pool(self.features(x)))

    def freeze_backbone(self) -> None:
        """No-op: this model has no separate pretrained backbone."""

    def unfreeze_backbone(self) -> None:
        """No-op: this model has no separate pretrained backbone."""

    def trainable_params(self) -> int:
        """Count of scalar parameters that receive gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def total_params(self) -> int:
        """Count of all scalar parameters."""
        return sum(p.numel() for p in self.parameters())

    def param_summary(self) -> dict:
        """Return total / trainable / frozen parameter counts."""
        total, trainable = self.total_params(), self.trainable_params()
        return {"total": total, "trainable": trainable, "frozen": total - trainable}

    def save(self, path) -> None:
        """Write ``{"model_state_dict", "num_classes"}`` to ``path`` (dirs created)."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {"model_state_dict": self.state_dict(), "num_classes": self.num_classes}
        torch.save(checkpoint, target)

    def __repr__(self) -> str:
        return f"SimpleCNN(num_classes={self.num_classes}, trainable_params={self.trainable_params():,})"

In [None]:
# ─── TransferModel wrapper (for ResNet-18 / MobileNet-V2) ────────────────────

class _TransferModel(nn.Module):
    def __init__(self, backbone: nn.Module, num_classes: int, freeze_backbone: bool) -> None:
        super().__init__()
        self._backbone = backbone
        self.num_classes = num_classes
        self._backbone_frozen = False
        if freeze_backbone:
            self.freeze_backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._backbone(x)

    def freeze_backbone(self) -> None:
        for name, param in self._backbone.named_parameters():
            if not name.startswith("fc.") and not name.startswith("classifier."):
                param.requires_grad = False
        self._backbone_frozen = True

    def unfreeze_backbone(self) -> None:
        for param in self._backbone.parameters():
            param.requires_grad = True
        self._backbone_frozen = False

    def trainable_params(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def param_summary(self) -> dict:
        total = self.total_params()
        trainable = self.trainable_params()
        return {"total": total, "trainable": trainable, "frozen": total - trainable}

    def save(self, path) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {"model_state_dict": self.state_dict(), "num_classes": self.num_classes},
            path,
        )

    def __repr__(self) -> str:
        name = type(self._backbone).__name__
        return (
            f"{name}Wrapper("
            f"num_classes={self.num_classes}, "
            f"frozen={self._backbone_frozen}, "
            f"trainable_params={self.trainable_params():,})"
        )

In [None]:
# ─── Model factory ────────────────────────────────────────────────────────────

def _build_resnet18(num_classes: int, freeze_backbone: bool, dropout: float) -> _TransferModel:
    """ImageNet-pretrained ResNet-18 with ``fc`` replaced by Dropout -> Linear(num_classes)."""
    net = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    head = nn.Linear(net.fc.in_features, num_classes)
    nn.init.xavier_uniform_(head.weight)
    nn.init.zeros_(head.bias)
    net.fc = nn.Sequential(nn.Dropout(p=dropout), head)
    return _TransferModel(net, num_classes, freeze_backbone)


def _build_mobilenetv2(num_classes: int, freeze_backbone: bool, dropout: float) -> _TransferModel:
    """ImageNet-pretrained MobileNet-V2 whose whole ``classifier`` becomes Dropout -> Linear."""
    net = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    head = nn.Linear(net.classifier[1].in_features, num_classes)
    nn.init.xavier_uniform_(head.weight)
    nn.init.zeros_(head.bias)
    net.classifier = nn.Sequential(nn.Dropout(p=dropout), head)
    return _TransferModel(net, num_classes, freeze_backbone)

def _build_resnet50(num_classes: int, freeze_backbone: bool, dropout: float) -> _TransferModel:
    """ResNet-50 (ImageNet V2 weights) with ``fc`` replaced by ``Dropout -> Linear``.

    The new Linear layer gets Xavier-uniform weights and a zero bias.
    """
    net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    head_linear = nn.Linear(net.fc.in_features, num_classes)
    nn.init.xavier_uniform_(head_linear.weight)
    nn.init.zeros_(head_linear.bias)
    net.fc = nn.Sequential(nn.Dropout(p=dropout), head_linear)
    return _TransferModel(net, num_classes, freeze_backbone)


def _build_mobilenetv3_large(num_classes: int, freeze_backbone: bool, dropout: float) -> _TransferModel:
    """MobileNet-V3-Large (ImageNet V2 weights) with its final classifier layer replaced.

    Only ``classifier[-1]`` is swapped for ``Dropout -> Linear``; the earlier
    classifier layers from torchvision are kept. The new Linear layer gets
    Xavier-uniform weights and a zero bias.
    """
    net = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
    head_linear = nn.Linear(net.classifier[-1].in_features, num_classes)
    nn.init.xavier_uniform_(head_linear.weight)
    nn.init.zeros_(head_linear.bias)
    # classifier[-1] becomes a nested Sequential, so its Linear lives at index [1].
    # NOTE(review): torchvision's stock MobileNetV3 classifier appears to already
    # include a Dropout before its last Linear, so this likely stacks a second
    # Dropout — confirm that is intended.
    net.classifier[-1] = nn.Sequential(nn.Dropout(p=dropout), head_linear)
    return _TransferModel(net, num_classes, freeze_backbone)


# Maps a lower-case model name to a factory ``(num_classes, freeze_backbone, dropout) -> nn.Module``.
# ``simple_cnn`` ignores ``freeze_backbone`` (its lambda does not forward it).
_REGISTRY = {
    "efficientnet_b0": lambda nc, fb, do: ProductClassifier(num_classes=nc, freeze_backbone=fb, dropout=do),
    "simple_cnn": lambda nc, fb, do: SimpleCNN(num_classes=nc, dropout=do),
    "resnet18": _build_resnet18,
    "resnet50": _build_resnet50,
    "mobilenetv2": _build_mobilenetv2,
    "mobilenetv3_large": _build_mobilenetv3_large,
}

# Static type for model names; keep in sync with the keys of ``_REGISTRY``.
ModelName = Literal[
    "efficientnet_b0",
    "simple_cnn",
    "resnet18",
    "resnet50",
    "mobilenetv2",
    "mobilenetv3_large",
]


def build_model(
    name: str,
    num_classes: int = 4,
    freeze_backbone: bool = True,
    dropout: float = 0.3,
) -> nn.Module:
    """Instantiate a registered model by name.

    Args:
        name: Registry key, matched after lower-casing and stripping whitespace.
        num_classes: Size of the output layer.
        freeze_backbone: Forwarded to the factory (``simple_cnn`` ignores it).
        dropout: Dropout probability passed to the factory.

    Raises:
        ValueError: If *name* is not a registered model.
    """
    key = name.lower().strip()
    factory = _REGISTRY.get(key)
    if factory is None:
        raise ValueError(
            f"Unknown model '{key}'. Choose from: {list(_REGISTRY.keys())}"
        )
    return factory(num_classes, freeze_backbone, dropout)


def available_models() -> list:
    """Return the registered model names in registration order."""
    return [model_name for model_name in _REGISTRY]

## 8. Metrics & Logger

In [None]:
# ─── metrics.py ───────────────────────────────────────────────────────────────

def compute_metrics(
    all_labels: List[int],
    all_preds: List[int],
    class_names: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Compute accuracy, macro-F1 and, when *class_names* is given, per-class F1.

    With *class_names*, label ids ``0..len(class_names)-1`` are scored (even
    if a class never occurs) and per-class scores are stored under
    ``f1_<class name>``. ``zero_division=0`` reports undefined F1 values as 0.
    """
    truth = np.array(all_labels)
    pred = np.array(all_preds)
    label_ids = None if class_names is None else list(range(len(class_names)))

    metrics: Dict[str, float] = {
        "accuracy": float(accuracy_score(truth, pred)),
        "f1_macro": float(
            f1_score(truth, pred, labels=label_ids, average="macro", zero_division=0)
        ),
    }
    if not class_names:
        return metrics

    per_class = f1_score(truth, pred, labels=label_ids, average=None, zero_division=0)
    metrics.update(
        {f"f1_{cls}": float(value) for cls, value in zip(class_names, per_class)}
    )
    return metrics


def get_classification_report(
    all_labels: List[int],
    all_preds: List[int],
    class_names: Optional[List[str]] = None,
) -> str:
    """Return scikit-learn's text classification report.

    With *class_names*, rows cover label ids ``0..len(class_names)-1`` and are
    titled with those names; undefined scores are shown as 0.
    """
    label_ids = None
    if class_names is not None:
        label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    )
    return report


def get_confusion_matrix(
    all_labels: List[int],
    all_preds: List[int],
    class_names: Optional[List[str]] = None,
) -> np.ndarray:
    """Return scikit-learn's confusion matrix.

    With *class_names*, the matrix is indexed by label ids
    ``0..len(class_names)-1`` so absent classes still get a row/column.
    """
    label_ids = None
    if class_names is not None:
        label_ids = list(range(len(class_names)))
    matrix = confusion_matrix(all_labels, all_preds, labels=label_ids)
    return matrix

In [None]:
# ─── logger.py ────────────────────────────────────────────────────────────────

class CSVLogger:
    """Append-only CSV logger for per-epoch metrics.

    Rows may carry different keys (e.g. train rows log only aggregate metrics
    while val rows add per-class F1). The header is kept as the union of all
    keys seen so far. When a new key appears after rows were already written,
    the file is rewritten once with the widened header, so every row stays
    aligned with its column names. Missing values are written as empty cells.
    """

    def __init__(self, path: Path) -> None:
        self.path = Path(path)
        # Column order of the file; seeded from an existing file's header so a
        # resumed run appends rows aligned with what is already on disk.
        self._fieldnames: List[str] = self._read_header()
        self._header_written = bool(self._fieldnames)

    def _read_header(self) -> List[str]:
        """Return the header of an existing log file, or ``[]`` if there is none."""
        if not self.path.exists():
            return []
        with self.path.open(newline="") as fh:
            header = next(csv.reader(fh), None)
        return list(header) if header else []

    def _rewrite(self, fieldnames: List[str]) -> None:
        """Rewrite the existing file so that its header is *fieldnames*."""
        with self.path.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        with self.path.open("w", newline="") as fh:
            # extrasaction="ignore" drops overflow cells (DictReader key None)
            # that a malformed legacy file may contain.
            writer = csv.DictWriter(
                fh, fieldnames=fieldnames, restval="", extrasaction="ignore"
            )
            writer.writeheader()
            writer.writerows(rows)

    def log(self, epoch: int, split: str, **metrics: float) -> None:
        """Append one row ``{"epoch", "split", **metrics}`` to the CSV file."""
        row = {"epoch": epoch, "split": split, **metrics}
        new_keys = [key for key in row if key not in self._fieldnames]
        if new_keys:
            widened = self._fieldnames + new_keys
            if self._header_written:
                self._rewrite(widened)
            self._fieldnames = widened

        self.path.parent.mkdir(parents=True, exist_ok=True)
        with self.path.open("a", newline="") as fh:
            writer = csv.DictWriter(fh, fieldnames=self._fieldnames, restval="")
            if not self._header_written:
                writer.writeheader()
                self._header_written = True
            writer.writerow(row)


def _read_csv(path: Path) -> List[Dict]:
    """Load every row of the CSV file at *path* as a dict keyed by its header.

    The file is opened with ``newline=""`` as the :mod:`csv` module requires,
    so quoted fields that contain line breaks round-trip unchanged. *path* may
    be a ``str`` or a ``Path``.
    """
    with Path(path).open(newline="") as fh:
        return list(csv.DictReader(fh))


def plot_loss_curves(csv_path: Path, out_path: Path) -> None:
    """Plot per-epoch train and val loss from a metrics CSV and save the figure.

    Each split is plotted against its own ``epoch`` column, so the curves stay
    correct even when the splits have different row counts (e.g. a run that
    stopped between the train and val log calls). Returns without plotting if
    either split has no rows, and prints a notice instead of failing when
    matplotlib is not installed.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("[logger] matplotlib not installed — skipping loss curve plot.")
        return

    rows = _read_csv(csv_path)
    train_rows = [r for r in rows if r["split"] == "train"]
    val_rows = [r for r in rows if r["split"] == "val"]

    if not train_rows or not val_rows:
        return

    train_epochs = [int(r["epoch"]) for r in train_rows]
    train_loss = [float(r["loss"]) for r in train_rows]
    val_epochs = [int(r["epoch"]) for r in val_rows]
    val_loss = [float(r["loss"]) for r in val_rows]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_epochs, train_loss, label="train_loss", marker="o", markersize=3)
    ax.plot(val_epochs, val_loss, label="val_loss", marker="s", markersize=3)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Train vs Val Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_accuracy_curve(csv_path: Path, out_path: Path) -> None:
    """Save a plot of validation accuracy and macro-F1 per epoch.

    Reads the metrics CSV at *csv_path*. Returns silently when there are no
    ``val`` rows, and prints a notice (without failing) if matplotlib is missing.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("[logger] matplotlib not installed — skipping accuracy curve plot.")
        return

    val_rows = [row for row in _read_csv(csv_path) if row["split"] == "val"]
    if not val_rows:
        return

    epochs = [int(row["epoch"]) for row in val_rows]
    accuracy = [float(row["accuracy"]) for row in val_rows]
    f1 = [float(row["f1_macro"]) for row in val_rows]

    fig, ax = plt.subplots(figsize=(8, 5))
    series = ((accuracy, "val_accuracy", "o"), (f1, "val_f1_macro", "s"))
    for values, label, marker in series:
        ax.plot(epochs, values, label=label, marker=marker, markersize=3)
    ax.set(xlabel="Epoch", ylabel="Score", title="Val Accuracy & F1-Macro")
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def plot_confusion_matrix(
    cm: np.ndarray,
    class_names: List[str],
    out_path: Path,
    title: str = "Confusion Matrix",
) -> None:
    """Render *cm* as an annotated heat-map and save it to *out_path*.

    Integer matrices (raw counts) are annotated as integers; any other dtype
    (e.g. a row-normalised matrix) is annotated with two decimals. The cell
    text is white on cells above half the matrix maximum so it stays readable.
    *cm* may be any array-like. Prints a notice instead of failing when
    matplotlib is not installed.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("[logger] matplotlib not installed — skipping confusion matrix plot.")
        return

    cm = np.asarray(cm)

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
    fig.colorbar(im, ax=ax)
    ax.set(
        xticks=range(len(class_names)),
        yticks=range(len(class_names)),
        xticklabels=class_names,
        yticklabels=class_names,
        xlabel="Predicted",
        ylabel="True",
        title=title,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    # "d" only accepts integers; normalised (float) matrices need a float format.
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"
    # cm.max() raises on an empty array, so guard it.
    thresh = cm.max() / 2.0 if cm.size else 0.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j, i, format(cm[i, j], fmt),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)

## 9. Trainer

In [None]:
class Trainer:
    """Train/validate loop with early stopping, CSV metric logging and plots.

    The AdamW optimiser is built from the parameters that have
    ``requires_grad=True`` when the Trainer is constructed. After freezing or
    unfreezing layers, create a new Trainer so the optimiser sees the change.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        cfg: TrainConfig,
        class_names: List[str],
    ) -> None:
        """Move *model* to the resolved device and set up training state.

        Args:
            model: Network to train; moved to the device in place.
            train_loader: Iterable of ``(images, labels)`` batches for training.
            val_loader: Iterable of ``(images, labels)`` batches for validation.
            cfg: Training configuration (device, lr, weight_decay, scheduler
                settings, epochs, ``run_dir()``, optional early-stop fields).
            class_names: Class names indexed by label id.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.cfg = cfg
        self.class_names = class_names

        # Device
        self.device = self._resolve_device(cfg.device)
        self.model.to(self.device)

        # Optimiser: only parameters that are currently trainable.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
        )

        # LR scheduler (None when cfg.lr_scheduler is neither "cosine" nor "step").
        self.scheduler = self._build_scheduler()

        # Logging / checkpointing
        run_dir = cfg.run_dir()
        self.csv_logger = CSVLogger(run_dir / "metrics.csv")
        self.best_ckpt_path = run_dir / "best_checkpoint.pt"
        self._best_f1: float = -1.0

    @staticmethod
    def _resolve_device(requested: Optional[str]) -> torch.device:
        """Return the requested device, else the first available of CUDA, MPS, CPU."""
        if requested:
            return torch.device(requested)
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _build_scheduler(self):
        """Build the LR scheduler named by ``cfg.lr_scheduler``.

        ``"cosine"`` anneals over ``cfg.epochs``; ``"step"`` uses
        ``cfg.lr_step_size`` / ``cfg.lr_gamma``; any other value returns None.
        """
        if self.cfg.lr_scheduler == "cosine":
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.cfg.epochs
            )
        elif self.cfg.lr_scheduler == "step":
            return torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.cfg.lr_step_size,
                gamma=self.cfg.lr_gamma,
            )
        return None

    def _train_epoch(self, epoch: int) -> float:
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Mean training loss per sample.

        Raises:
            RuntimeError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        running_loss = 0.0
        n_samples = 0

        for images, labels in self.train_loader:
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits = self.model(images)
            loss = self.criterion(logits, labels)
            loss.backward()
            self.optimizer.step()

            # The criterion returns a batch mean; weight it by batch size so a
            # short final batch does not skew the epoch average.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

        if n_samples == 0:
            raise RuntimeError(
                "train_loader yielded no batches "
                "(e.g. drop_last=True with fewer samples than batch_size)."
            )
        return running_loss / n_samples

    @torch.no_grad()
    def _val_epoch(self) -> Tuple[float, dict, List[int], List[int]]:
        """Evaluate on ``val_loader`` without gradients.

        Returns:
            ``(mean_loss_per_sample, metrics, labels, predictions)``. *metrics*
            comes from ``compute_metrics``; labels/predictions are Python ints
            in loader order.

        Raises:
            RuntimeError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        running_loss = 0.0
        n_samples = 0
        all_labels: List[int] = []
        all_preds: List[int] = []

        for images, labels in self.val_loader:
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            logits = self.model(images)
            loss = self.criterion(logits, labels)
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

            preds = logits.argmax(dim=1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

        if n_samples == 0:
            raise RuntimeError("val_loader yielded no batches.")

        avg_loss = running_loss / n_samples
        metrics = compute_metrics(all_labels, all_preds, self.class_names)

        return avg_loss, metrics, all_labels, all_preds

    def _save_best_checkpoint(self, val_f1: float, epoch: int) -> bool:
        """Save a checkpoint if *val_f1* beats the best seen so far.

        Uses ``model.save`` when the model provides it; otherwise stores the
        state dict together with *epoch*. Returns True when a checkpoint was
        written.
        """
        if val_f1 > self._best_f1:
            self._best_f1 = val_f1
            if hasattr(self.model, "save"):
                self.model.save(self.best_ckpt_path)
            else:
                torch.save(
                    {"model_state_dict": self.model.state_dict(), "epoch": epoch},
                    self.best_ckpt_path,
                )
            print(f"  ★ New best F1={val_f1:.4f} — checkpoint saved")
            return True
        return False

    def fit(self) -> None:
        """Train for up to ``cfg.epochs`` epochs, then write plots and a report.

        Early stopping is driven by optional ``cfg`` attributes:
        ``early_stop_metric`` (``"val_loss"`` by default; any other value uses
        val macro-F1), ``early_stop_patience`` (default 3) and
        ``early_stop_min_delta`` (default 1e-4). Independently of that metric,
        the checkpoint with the best val macro-F1 is kept at ``best_ckpt_path``.
        The final report and confusion matrix describe the last completed
        epoch, which is not necessarily the best one.
        """
        print(f"\n{'='*60}")
        print(f"  Model    : {self.cfg.model_name}")
        print(f"  Device   : {self.device}")
        print(f"  Epochs   : {self.cfg.epochs}")
        print(f"  LR       : {self.cfg.lr}  scheduler={self.cfg.lr_scheduler}")
        print(f"  Run dir  : {self.cfg.run_dir()}")
        print(f"{'='*60}\n")

        last_val_labels: List[int] = []
        last_val_preds: List[int] = []

        # Early-stopping settings are optional on cfg; fall back to defaults.
        early_metric = getattr(self.cfg, "early_stop_metric", "val_loss")  # "val_loss" or "val_f1_macro"
        patience = getattr(self.cfg, "early_stop_patience", 3)
        min_delta = getattr(self.cfg, "early_stop_min_delta", 1e-4)
        minimize = early_metric == "val_loss"

        best_score = float("inf") if minimize else -float("inf")
        best_epoch = -1
        patience_counter = 0

        for epoch in range(1, self.cfg.epochs + 1):
            print(f"── Epoch {epoch}/{self.cfg.epochs} ──")

            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics, val_labels, val_preds = self._val_epoch()
            last_val_labels, last_val_preds = val_labels, val_preds

            if self.scheduler is not None:
                self.scheduler.step()

            # Train rows carry placeholder accuracy/F1 so both splits share the core columns.
            self.csv_logger.log(
                epoch=epoch, split="train",
                loss=train_loss, accuracy=0.0, f1_macro=0.0,
            )
            self.csv_logger.log(
                epoch=epoch, split="val",
                loss=val_loss, **val_metrics,
            )

            print(
                f"  train_loss={train_loss:.4f}  "
                f"val_loss={val_loss:.4f}  "
                f"val_acc={val_metrics['accuracy']:.4f}  "
                f"val_f1={val_metrics['f1_macro']:.4f}"
            )

            # Consider the best-F1 checkpoint on every epoch. It only writes when
            # F1 improves, so it is independent of the early-stopping metric: a
            # better-F1 epoch is never skipped just because val_loss did not improve.
            self._save_best_checkpoint(val_metrics["f1_macro"], epoch)

            # Early-stopping check
            if minimize:
                current = float(val_loss)
                improved = (best_score - current) > min_delta
            else:
                current = float(val_metrics["f1_macro"])
                improved = (current - best_score) > min_delta

            if improved:
                best_score = current
                best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(
                    f"Early stopping at epoch {epoch} "
                    f"(best_epoch={best_epoch}, best_{early_metric}={best_score:.4f})"
                )
                break

        if not last_val_labels:
            print("No validation results — skipping plots and report.")
            return

        # Post-training outputs
        run_dir = self.cfg.run_dir()
        csv_path = run_dir / "metrics.csv"

        print(f"\n── Generating plots → {run_dir} ──")
        plot_loss_curves(csv_path, run_dir / "loss_curve.png")
        plot_accuracy_curve(csv_path, run_dir / "accuracy_curve.png")

        cm = get_confusion_matrix(last_val_labels, last_val_preds, self.class_names)
        plot_confusion_matrix(
            cm, self.class_names,
            run_dir / "confusion_matrix.png",
            title=f"Confusion Matrix — {self.cfg.model_name}",
        )

        print("\n── Final Classification Report (val) ──")
        report = get_classification_report(
            last_val_labels, last_val_preds, self.class_names
        )
        print(report)

        print("\n── Confusion Matrix (val) ──")
        _print_confusion_matrix(cm, self.class_names)

        print(f"\nBest val F1-macro : {self._best_f1:.4f}")
        print(f"Best checkpoint   : {self.best_ckpt_path}")
        print(f"Metrics CSV       : {csv_path}")


def _print_confusion_matrix(cm, class_names: List[str]) -> None:
    """Print *cm* as a fixed-width text table.

    Row and column headers both come from *class_names*. Each column is at
    least 12 characters wide, widened to fit the longest name plus 2.
    """
    widest = max(len(label) for label in class_names)
    width = max(12, widest + 2)
    print(" " * width + "".join(f"{label:>{width}}" for label in class_names))
    for row_idx, row_label in enumerate(class_names):
        cells = "".join(
            f"{cm[row_idx, col_idx]:>{width}}" for col_idx in range(len(class_names))
        )
        print(f"{row_label:<{width}}{cells}")

## 10. Data Preparation Pipeline

In [None]:
def prepare_data(data_cfg: Optional[DataConfig] = None) -> dict:
    """Download, extract, clean, validate and split the dataset.

    If both the clean manifest and the label map already exist, every other
    step is skipped. Otherwise the raw tarball is fetched from the Hugging Face
    repo configured in *data_cfg*, cleaned, image-validated and split by
    barcode. The manifest, split metadata, label map and stats are then written.

    Note: when ``data_cfg.token`` is None it is filled in place from the
    ``HF_TOKEN`` environment variable.

    Returns:
        The paths dict produced by ``data_cfg.paths()``.

    Raises:
        FileNotFoundError: If the extracted archive has no ``metadata.csv``.
    """
    cfg = DataConfig() if data_cfg is None else data_cfg
    if cfg.token is None:
        cfg.token = os.environ.get("HF_TOKEN")

    p = cfg.paths()
    for dir_key in ("raw_dir", "proc_dir"):
        p[dir_key].mkdir(parents=True, exist_ok=True)

    already_done = p["manifest_clean"].exists() and p["label_map"].exists()
    if already_done:
        print("✓ Data already prepared. Skipping download & processing.")
        return p

    print("Downloading raw tar from Hugging Face …")
    raw_tar = download_raw_tar(
        repo_id=cfg.repo_id,
        path_in_repo=cfg.raw_tar_in_repo,
        repo_type=cfg.repo_type,
        revision=cfg.revision,
        token=cfg.token,
    )

    print("Extracting …")
    ensure_extracted(raw_tar, p["raw_dir"])

    meta_path = p["raw_metadata"]
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.csv not found at {meta_path}")

    print("Loading metadata …")
    df = add_paths(load_metadata(meta_path), p["raw_dir"])

    print("Cleaning …")
    df = basic_clean(
        df,
        labels=cfg.labels,
        dedup_by_barcode=cfg.dedup_by_barcode,
        cap_per_label=cfg.cap_per_label,
        seed=cfg.seed,
    )

    print("Validating images …")
    validated = validate_images(
        df,
        min_side=cfg.min_side,
        do_verify=cfg.do_verify,
        num_workers=cfg.num_workers,
    )
    df = keep_only_ok(validated)

    print("Splitting by barcode …")
    df, split_meta = split_by_barcode(
        df,
        SplitConfig(
            seed=cfg.seed,
            train_frac=cfg.train_frac,
            val_frac=cfg.val_frac,
            test_frac=cfg.test_frac,
        ),
    )

    # Persist outputs
    df.to_csv(p["manifest_clean"], index=False)
    save_splits_json(split_meta, p["splits"])

    label_map_json = json.dumps(
        attach_label_map(cfg.labels), ensure_ascii=False, indent=2
    )
    p["label_map"].write_text(label_map_json, encoding="utf-8")

    save_stats(compute_stats(df), p["stats"])

    print("✓ Data preparation complete!")
    for caption, key in (
        ("manifest ", "manifest_clean"),
        ("label_map", "label_map"),
        ("stats    ", "stats"),
    ):
        print(f"  {caption}: {p[key]}")

    return p

## 11. Reproducibility

In [None]:
def seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device if available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

## 12. Run Data Preparation

In [None]:
# Run the download -> clean -> validate -> split pipeline with default settings.
# prepare_data() returns early when the processed outputs already exist.
data_cfg = DataConfig()
paths = prepare_data(data_cfg=data_cfg)

Downloading raw tar from Hugging Face …


data/raw/data_v2.tar:   0%|          | 0.00/242M [00:00<?, ?B/s]

Extracting …


  tf.extractall(out_dir)


Loading metadata …
Cleaning …
Validating images …


validate images: 100%|██████████| 8000/8000 [00:01<00:00, 4019.40it/s]


Splitting by barcode …
✓ Data preparation complete!
  manifest : data_local/processed/data_v2/manifest_clean.csv
  label_map: data_local/processed/data_v2/label_map.json
  stats    : data_local/processed/data_v2/stats.json


## 13. Configure Training

Edit these parameters as needed:

In [None]:
# ╔══════════════════════════════════════════════════════════════╗
# ║                    TRAINING SETTINGS                        ║
# ║  Edit these values to control your experiment               ║
# ╚══════════════════════════════════════════════════════════════╝

# One of available_models(): "efficientnet_b0" | "simple_cnn" | "resnet18" |
# "resnet50" | "mobilenetv2" | "mobilenetv3_large"
MODEL_NAME = "efficientnet_b0"
FREEZE_BACKBONE = True           # Stage 1: head-only training
DROPOUT = 0.3
EPOCHS = 30
BATCH_SIZE = 32
LR = 1e-4
WEIGHT_DECAY = 1e-2
LR_SCHEDULER = "cosine"         # "cosine" | "step" | "none"
IMAGE_SIZE = 224
NUM_WORKERS = 2                  # Colab typically has 2 CPUs
SEED = 42
OUTPUT_DIR = Path("outputs")

print(f"Available models: {available_models()}")
print(f"Selected model  : {MODEL_NAME}")

# Fail fast on a typo here rather than after datasets are loaded in the training cell.
# The name is normalised the same way build_model() does.
if MODEL_NAME.lower().strip() not in available_models():
    raise ValueError(
        f"Unknown MODEL_NAME '{MODEL_NAME}'. Choose from: {available_models()}"
    )

Available models: ['efficientnet_b0', 'simple_cnn', 'resnet18', 'resnet50', 'mobilenetv2', 'mobilenetv3_large']
Selected model  : resnet50


## 14. Train!

In [None]:
seed_everything(SEED)

# Default output locations of prepare_data() for DataConfig() (as printed by the
# data-preparation cell above).
MANIFEST_PATH = "data_local/processed/data_v2/manifest_clean.csv"
LABEL_MAP_PATH = "data_local/processed/data_v2/label_map.json"

# ======================================================
# Stage 1: train head only (FREEZE_BACKBONE=True freezes the backbone)
# ======================================================
cfg1 = TrainConfig(
    model_name=MODEL_NAME,
    freeze_backbone=FREEZE_BACKBONE,
    dropout=DROPOUT,
    manifest=MANIFEST_PATH,
    label_map=LABEL_MAP_PATH,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    lr_scheduler=LR_SCHEDULER,
    image_size=IMAGE_SIZE,
    output_dir=OUTPUT_DIR,
    seed=SEED,
    device=None,
)

print("Loading datasets …")
datasets = build_datasets(
    manifest_path=cfg1.manifest,
    label_map_path=cfg1.label_map,
    train_transform=get_train_transforms(size=cfg1.image_size),
    val_transform=get_val_transforms(size=cfg1.image_size),
)

train_ds = datasets["train"]
val_ds = datasets["val"]

print(f"  train : {train_ds}")
print(f"  val   : {val_ds}")

class_names = train_ds.classes
num_classes = train_ds.num_classes

train_loader = DataLoader(
    train_ds,
    batch_size=cfg1.batch_size,
    shuffle=True,
    num_workers=cfg1.num_workers,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=cfg1.batch_size,
    shuffle=False,
    num_workers=cfg1.num_workers,
    pin_memory=True,
)

print(f"Building model: {cfg1.model_name} …")
model = build_model(
    name=cfg1.model_name,
    num_classes=num_classes,
    freeze_backbone=cfg1.freeze_backbone,
    dropout=cfg1.dropout,
)

if hasattr(model, "param_summary"):
    ps = model.param_summary()
    print(f"[Stage1] params total={ps['total']:,} trainable={ps['trainable']:,} frozen={ps['frozen']:,}")

trainer1 = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    cfg=cfg1,
    class_names=class_names,
)
trainer1.fit()



In [None]:
seed_everything(SEED)

MANIFEST_PATH = "data_local/processed/data_v2/manifest_clean.csv"
LABEL_MAP_PATH = "data_local/processed/data_v2/label_map.json"

MODEL_NAME = "resnet50"   # or "mobilenetv3_large"

STAGE1_EPOCHS = 3
STAGE2_EPOCHS = 20  # 23 epochs in total (adjustable)

# ---------- Build datasets once (shared by both stages) ----------
print("Loading datasets …")
datasets = build_datasets(
    manifest_path=MANIFEST_PATH,
    label_map_path=LABEL_MAP_PATH,
    train_transform=get_train_transforms(size=IMAGE_SIZE),
    val_transform=get_val_transforms(size=IMAGE_SIZE),
)

train_ds = datasets["train"]
val_ds   = datasets["val"]

class_names = train_ds.classes
num_classes = train_ds.num_classes

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# ======================================================
# Stage 1: head-only warmup (backbone frozen)
# ======================================================
cfg1 = TrainConfig(
    model_name=MODEL_NAME,
    freeze_backbone=True,
    dropout=DROPOUT,
    manifest=MANIFEST_PATH,
    label_map=LABEL_MAP_PATH,
    epochs=STAGE1_EPOCHS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    lr=1e-3,
    weight_decay=WEIGHT_DECAY,
    lr_scheduler=LR_SCHEDULER,
    image_size=IMAGE_SIZE,
    output_dir=OUTPUT_DIR,
    seed=SEED,
    device=None,

    early_stop_patience=2,
    early_stop_min_delta=1e-4,
    early_stop_metric="val_loss",
)

print(f"Building model: {cfg1.model_name} (Stage1) …")
model = build_model(
    name=cfg1.model_name,
    num_classes=num_classes,
    freeze_backbone=cfg1.freeze_backbone,
    dropout=cfg1.dropout,
)

if hasattr(model, "param_summary"):
    ps = model.param_summary()
    print(f"[Stage1] params total={ps['total']:,} trainable={ps['trainable']:,} frozen={ps['frozen']:,}")

trainer1 = Trainer(model, train_loader, val_loader, cfg1, class_names)
trainer1.fit()

# ======================================================
# Stage 2: unfreeze and fine-tune at a lower LR
# ======================================================
if cfg1.model_name == "efficientnet_b0":
    # EfficientNet: only unfreeze the last 3 backbone blocks.
    print("Unfreezing last 3 blocks of the backbone for fine-tuning...")
    model.unfreeze_last_n_blocks(n=3)
else:
    # Other backbones: unfreeze everything (needs a low LR, see cfg2.lr).
    print("Unfreezing the entire backbone...")
    model.unfreeze_backbone()

if hasattr(model, "param_summary"):
    ps = model.param_summary()
    print(f"  params  total={ps['total']:,}  trainable={ps['trainable']:,}  frozen={ps['frozen']:,}")

cfg2 = TrainConfig(
    model_name=MODEL_NAME,
    freeze_backbone=False,
    dropout=DROPOUT,
    manifest=MANIFEST_PATH,
    label_map=LABEL_MAP_PATH,
    epochs=STAGE2_EPOCHS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    lr=5e-5,               # Important: use a lower LR once unfrozen (try 3e-5, 5e-5, 1e-4)
    weight_decay=WEIGHT_DECAY,
    lr_scheduler=LR_SCHEDULER,
    image_size=IMAGE_SIZE,
    output_dir=OUTPUT_DIR,
    seed=SEED,
    device=str(trainer1.device),

    early_stop_patience=3,
    early_stop_min_delta=1e-4,
    early_stop_metric="val_loss",
)

# A new Trainer builds a fresh optimizer that includes the newly unfrozen params.
trainer2 = Trainer(model, train_loader, val_loader, cfg2, class_names)
# When both stages write to the same checkpoint file, start stage 2 from stage 1's
# best F1 so a weaker early stage-2 epoch cannot overwrite a better stage-1 checkpoint.
if trainer2.best_ckpt_path == trainer1.best_ckpt_path:
    trainer2._best_f1 = trainer1._best_f1
trainer2.fit()

Loading datasets …
Building model: resnet50 (Stage1) …
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:01<00:00, 81.1MB/s]



  Model    : resnet50
  Device   : cuda
  Epochs   : 3
  LR       : 0.001  scheduler=cosine
  Run dir  : outputs/resnet50

── Epoch 1/3 ──
  train_loss=1.0993  val_loss=0.9497  val_acc=0.6242  val_f1=0.6227
  ★ New best F1=0.6227 — checkpoint saved
── Epoch 2/3 ──
  train_loss=0.9170  val_loss=0.8941  val_acc=0.6392  val_f1=0.6390
  ★ New best F1=0.6390 — checkpoint saved
── Epoch 3/3 ──
  train_loss=0.8805  val_loss=0.8801  val_acc=0.6533  val_f1=0.6520
  ★ New best F1=0.6520 — checkpoint saved

── Generating plots → outputs/resnet50 ──

── Final Classification Report (val) ──
              precision    recall  f1-score   support

   beverages       0.61      0.66      0.63       300
      snacks       0.72      0.60      0.65       300
    dry_food       0.66      0.59      0.62       300
    non_food       0.64      0.77      0.70       300

    accuracy                           0.65      1200
   macro avg       0.66      0.65      0.65      1200
weighted avg       0.66      0.65 

## 15. Display Training Plots

In [None]:
from IPython.display import Image as IPImage, display

# Plots from the most recent run (stage 2 of the two-stage training cell).
run_dir = cfg2.run_dir()

for plot_name in ["loss_curve.png", "accuracy_curve.png", "confusion_matrix.png"]:
    plot_path = run_dir / plot_name
    if plot_path.exists():
        print(f"\n{plot_name}:")
        display(IPImage(filename=str(plot_path)))
    else:
        print(f"⚠️  {plot_name} not found")

## 16. (Optional) Test Set Evaluation

In [None]:
# Evaluate the best checkpoint of the two-stage run (cfg2 / trainer2) on the test split.
test_ds = datasets["test"]
test_loader = DataLoader(
    test_ds,
    batch_size=cfg2.batch_size,
    shuffle=False,
    num_workers=cfg2.num_workers,
    pin_memory=True,
)

device = trainer2.device
best_path = trainer2.best_ckpt_path
# Resolved here so this cell does not depend on the plots cell having been run.
run_dir = cfg2.run_dir()

if best_path.exists():
    print(f"Loading best checkpoint: {best_path}")
    ckpt = torch.load(best_path, map_location=str(device))
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Test"):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            preds = logits.argmax(dim=1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    test_metrics = compute_metrics(all_labels, all_preds, class_names)
    print("\n── Test Results ──")
    print(f"  Accuracy : {test_metrics['accuracy']:.4f}")
    print(f"  F1-Macro : {test_metrics['f1_macro']:.4f}")

    print("\n── Test Classification Report ──")
    print(get_classification_report(all_labels, all_preds, class_names))

    cm = get_confusion_matrix(all_labels, all_preds, class_names)
    plot_confusion_matrix(
        cm, class_names,
        run_dir / "test_confusion_matrix.png",
        title=f"Test Confusion Matrix — {cfg2.model_name}",
    )
    display(IPImage(filename=str(run_dir / "test_confusion_matrix.png")))
else:
    print("No checkpoint found. Train the model first.")

## 17. (Optional) Download Model & Results

Uncomment to download outputs to your local machine:

In [None]:
# Download training artefacts from Colab to your local machine.
# Uncomment the lines you need; the names refer to the two-stage run (trainer2 / cfg2).
# from google.colab import files
#
# # Download best checkpoint
# if trainer2.best_ckpt_path.exists():
#     files.download(str(trainer2.best_ckpt_path))
#
# # Download metrics CSV
# csv_path = cfg2.run_dir() / "metrics.csv"
# if csv_path.exists():
#     files.download(str(csv_path))
#
# # Or zip the entire outputs folder
# !zip -r outputs.zip outputs/
# files.download("outputs.zip")