In [None]:
# =========================
# Cell 1: Imports + global config
# =========================
import os, json, math, random
from pathlib import Path
from dataclasses import dataclass

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import Food101
from torchvision.models import resnet50, ResNet50_Weights

from skimage.color import rgb2lab
from sklearn.cluster import MiniBatchKMeans
from sklearn.neighbors import NearestNeighbors

# Reproducibility: seed every RNG this notebook touches (Python, NumPy, torch CPU + all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Artifacts folder: cached centers / weights / metadata are written here.
# The directory is created as a side effect of running this cell.
ART_DIR = Path("artifacts/food101_step10_sigma5_T042")
ART_DIR.mkdir(parents=True, exist_ok=True)

# Requested number of ab cluster centers. This is an upper bound: Cell 5 uses
# min(K, number of observed grid points).
K = 300
# Locked choices: ab range used for clamping/snapping and the spacing of the regular ab grid.
AB_MIN, AB_MAX = -110.0, 110.0
GRID_STEP = 10.0

SOFT_KNN = 5           # nearest centers used per pixel in the soft encoding
SIGMA_SOFT = 5.0       # target soft-encoding sigma
SIGMA_SMOOTH = 5.0     # prior smoothing sigma
LAMBDA_UNIFORM = 0.5   # weight of the uniform distribution mixed into the smoothed prior

# Only written to the weights metadata in the cells visible here;
# presumably the annealed-mean temperature used at inference -- TODO confirm.
ANNEAL_T = 0.42

# Bin-building + prior settings
PRUNE_IMAGES = 30_000         # grid prune uses 30K images
PRIOR_USE_ALL_TRAIN = True    # use all train images for prior
PRIOR_BATCH_IMAGES = 256       # for prior estimation loop (image-level batches)

# Training settings
# NOTE(review): none of these are referenced in the cells visible here;
# presumably consumed by later training cells -- verify.
NUM_WORKERS = 6
EPOCHS = 15
LR_DECODER = 1e-3
LR_ENCODER = 1e-4
WEIGHT_DECAY = 1e-4
GRAD_CLIP = 1.0
FREEZE_EPOCHS = 1  # epoch 0..FREEZE_EPOCHS-1 frozen encoder

print("DEVICE:", DEVICE)
print("Artifacts:", ART_DIR.resolve())


In [None]:
# =========================
# Cell 2: Resize policy A transforms (torchvision)
#   - Train: Resize short side 256, RandomCrop 224, optional flip
#   - Val/Test/Prior: Resize short side 256, CenterCrop 224
# =========================

class ResizeShortSide:
    """Callable transform that rescales a PIL image so its shorter side equals ``short_side``.

    The aspect ratio is preserved (up to rounding of the new width/height).
    An image whose shorter side already matches is returned untouched.
    """

    def __init__(self, short_side: int, interpolation=Image.BICUBIC):
        """Store the target short-side length and the PIL resampling filter."""
        self.short_side = short_side
        self.interpolation = interpolation

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return ``img`` resized so that ``min(width, height) == short_side``."""
        width, height = img.size
        shorter = min(width, height)

        # Nothing to do: hand back the very same object (no copy).
        if shorter == self.short_side:
            return img

        factor = self.short_side / float(shorter)
        target_size = (int(round(width * factor)), int(round(height * factor)))
        return img.resize(target_size, self.interpolation)

# Training pipeline: short side -> 256, random 224x224 crop, horizontal flip with p=0.5.
# No ToTensor/normalisation step here, so the output is still a PIL image.
train_tf = transforms.Compose([
    ResizeShortSide(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
])

# Evaluation pipeline: short side -> 256, deterministic central 224x224 crop.
eval_tf = transforms.Compose([
    ResizeShortSide(256),
    transforms.CenterCrop(224),
])

# For bin pruning and priors we want center crop stability
# (prior_tf is the same object as eval_tf, not a copy).
prior_tf = eval_tf


In [None]:
# =========================
# Cell 3: Food101 datasets + train/val split
# =========================
# Root folder handed to torchvision for the Food101 files.
DATA_ROOT = Path("./data")

# No transform is passed here: later cells index these datasets as (image, label)
# pairs and apply train_tf / eval_tf / prior_tf themselves.
train_base = Food101(root=str(DATA_ROOT), split="train", download=True)
test_base  = Food101(root=str(DATA_ROOT), split="test",  download=True)

# Deterministic val split from train:
# shuffle all train indices with a generator seeded by SEED, then take the
# first `val_frac` share as validation and the remainder as training.
val_frac = 0.1
n_train = len(train_base)
idx = np.arange(n_train)
rng = np.random.default_rng(SEED)
rng.shuffle(idx)  # in-place shuffle of idx
n_val = int(round(val_frac * n_train))
val_idx = idx[:n_val].tolist()
trn_idx = idx[n_val:].tolist()  # plain Python lists of indices into train_base

print("Food101 sizes:")
print(" train:", len(train_base), " -> trn:", len(trn_idx), " val:", len(val_idx))
print(" test :", len(test_base))


In [None]:
# =========================
# Cell 4: LAB helpers (CPU)
# =========================
def pil_to_rgb01(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an HxWx3 float32 array with values scaled to [0, 1].

    The image is first converted to RGB mode, so other input modes are accepted.
    """
    rgb_img = img.convert("RGB")
    pixels = np.asarray(rgb_img, dtype=np.float32)
    return pixels / 255.0

def rgb01_to_lab(rgb01: np.ndarray) -> np.ndarray:
    """Convert an HxWx3 RGB array in [0, 1] to CIE LAB, returned as float32 HxWx3.

    L lies in [0, 100]; the a/b channels roughly span [-128, 127].
    """
    converted = rgb2lab(rgb01)
    return converted.astype(np.float32)

def clamp_ab(lab: np.ndarray, ab_min=AB_MIN, ab_max=AB_MAX) -> np.ndarray:
    """Return a copy of ``lab`` whose a and b channels are clipped to [ab_min, ab_max].

    The input array is left unmodified and channel 0 (L) is copied through as-is.
    """
    clamped = lab.copy()
    # Channels 1 and 2 of the last axis are a and b; clip each one in place on the copy.
    for channel in (1, 2):
        plane = clamped[..., channel]
        np.clip(plane, ab_min, ab_max, out=plane)
    return clamped


In [None]:
# =========================
# Cell 5: ab grid + pruning (nearest grid point), then weighted k-means with min(K, number of observed grid points) clusters
# =========================
# Cache file for the ab cluster centers. The name is keyed on the requested K,
# not on the number of clusters actually fitted.
CENTERS_NPY = ART_DIR / f"ab_centers_k{K}.npy"

def build_ab_grid(step=GRID_STEP, ab_min=AB_MIN, ab_max=AB_MAX):
    """Build the regular 2-D grid of candidate (a, b) points.

    Returns:
      grid: (G, 2) float32 array of (a, b) pairs, flattened with ``a`` varying
            fastest (row index = b coordinate, column index = a coordinate).
      vals: 1-D float32 array of the axis values shared by a and b.
    """
    vals = np.arange(ab_min, ab_max + 1e-6, step, dtype=np.float32)
    n_axis = vals.shape[0]
    # Same layout as a flattened "xy" meshgrid: a cycles through vals within
    # each row, b stays constant over a whole row.
    a_flat = np.tile(vals, n_axis)
    b_flat = np.repeat(vals, n_axis)
    grid = np.stack([a_flat, b_flat], axis=1)  # (G,2)
    return grid, vals

# Materialise the grid once with the default step/range; G is the total number
# of grid points and is used as the histogram length during pruning.
GRID_POINTS, GRID_VALS = build_ab_grid()
G = GRID_POINTS.shape[0]
print("Grid points:", G)

# Nearest grid-point snapping on a regular grid reduces to per-axis rounding:
#   round((x - ab_min) / step) * step + ab_min
def snap_to_grid(ab: np.ndarray, step=GRID_STEP, ab_min=AB_MIN, ab_max=AB_MAX):
    """Snap ab values to the nearest grid point and return them as float32.

    Values are clipped to [ab_min, ab_max] before rounding and once more
    afterwards, so the result never leaves the grid range.
    """
    bounded = np.clip(ab, ab_min, ab_max)
    cell = np.rint((bounded - ab_min) / step)
    on_grid = np.clip(cell * step + ab_min, ab_min, ab_max)
    return on_grid.astype(np.float32)

def grid_index(snapped_ab: np.ndarray, step=GRID_STEP, ab_min=AB_MIN, ab_max=AB_MAX):
    """Map grid-aligned (a, b) rows of an (N, 2) array to flat grid indices.

    The flat index is ``b_cell * side + a_cell``, which matches the row order
    of the grid produced by ``build_ab_grid`` (a varies fastest).
    """
    # Integer cell coordinates along each axis.
    cells = np.round((snapped_ab - ab_min) / step).astype(np.int32)
    # Number of grid values per axis, e.g. 23 for step 10 over [-110, 110].
    side = int(round((ab_max - ab_min) / step)) + 1
    a_cell, b_cell = cells[:, 0], cells[:, 1]
    return b_cell * side + a_cell

def build_centers_from_food101():
    """Build (or load from cache) the ab cluster centers used as colour bins.

    Procedure for a fresh build:
      1. Take the first ``PRUNE_IMAGES`` indices of ``trn_idx`` and center-crop each image.
      2. Convert to LAB, clamp ab, snap every pixel to the nearest grid point and
         accumulate a hit count per grid point.
      3. Run weighted MiniBatchKMeans on the observed grid points (weights = hit
         counts) with ``min(K, number of observed points)`` clusters.
      4. Save the centers to ``CENTERS_NPY`` plus a JSON metadata file next to it.

    Returns:
      (centers, k_chosen): ``centers`` is a float32 array of shape (k_chosen, 2);
      ``k_chosen`` is a plain ``int``. The same pair is returned whether the
      centers were freshly computed or loaded from the cache file.

    Raises:
      ValueError: if the cache file does not hold an (n, 2) array, or if no grid
        point was observed (nothing to cluster).
    """
    if CENTERS_NPY.exists():
        print("Centers already exist:", CENTERS_NPY)
        cached = np.load(CENTERS_NPY).astype(np.float32)
        if cached.ndim != 2 or cached.shape[1] != 2:
            raise ValueError(
                f"Cached centers in {CENTERS_NPY} have shape {cached.shape}; expected (k, 2)."
            )
        # The caller unpacks two values, so the cached path must return the same
        # (centers, k) pair as a fresh build; k is recovered from the array itself.
        return cached, int(cached.shape[0])

    # select PRUNE_IMAGES from train split
    prune_n = min(PRUNE_IMAGES, len(trn_idx))
    prune_ids = trn_idx[:prune_n]

    counts = np.zeros(G, dtype=np.int64)

    for j, i in enumerate(prune_ids, 1):
        img, _ = train_base[i]
        img = prior_tf(img)  # center crop
        rgb01 = pil_to_rgb01(img)
        lab = clamp_ab(rgb01_to_lab(rgb01))
        ab = lab[..., 1:3].reshape(-1, 2)

        snapped = snap_to_grid(ab)
        idxs = grid_index(snapped)
        # count hits (frequency-weighted)
        counts += np.bincount(idxs, minlength=G)

        if j % 2000 == 0:
            observed = int((counts > 0).sum())
            print(f"  prune {j}/{prune_n} | observed grid points: {observed}/{G}")

    # Keep only grid points that were hit at least once; their counts become weights.
    observed_mask = counts > 0
    obs_points = GRID_POINTS[observed_mask]
    obs_counts = counts[observed_mask].astype(np.float64)

    print("Observed grid points:", obs_points.shape[0])
    # Plain Python int (not a NumPy scalar) so it is safe for sklearn and JSON.
    k_chosen = int(min(K, obs_points.shape[0]))
    if k_chosen < 1:
        raise ValueError("No ab grid points were observed; cannot build cluster centers.")
    print("Using K =", k_chosen)
    # Weighted k-means on observed grid points
    kmeans = MiniBatchKMeans(
        n_clusters=k_chosen,
        random_state=SEED,
        batch_size=2048,
        n_init=10,
        max_iter=300,
        init_size=20000,
        reassignment_ratio=0.01,
        verbose=0,
    )

    try:
        kmeans.fit(obs_points, sample_weight=obs_counts)
        print("KMeans fit used sample_weight.")
    except TypeError:
        # Fallback for estimators whose fit() does not accept sample_weight:
        # approximate the weighting by repeating each point (1..200 copies).
        print("KMeans sample_weight not supported; falling back to repetition.")
        rep = np.clip((obs_counts / obs_counts.mean()).round().astype(np.int32), 1, 200)
        rep_points = np.repeat(obs_points, rep, axis=0)
        kmeans.fit(rep_points)

    centers = kmeans.cluster_centers_.astype(np.float32)
    np.save(CENTERS_NPY, centers)
    meta = {
        "dataset": "Food101",
        "K": k_chosen,
        "ab_min": AB_MIN,
        "ab_max": AB_MAX,
        "grid_step": GRID_STEP,
        "prune_images": prune_n,
        "resize_policy": "short_side=256, center_crop=224",
        "nearest_grid": True,
        "weighted_kmeans": True,
        "sigma_soft": SIGMA_SOFT,
        "seed": SEED,
    }
    CENTERS_META = ART_DIR / f"ab_centers_k{K}_meta.json"
    CENTERS_META.write_text(json.dumps(meta, indent=2))
    print("Saved:", CENTERS_NPY)
    print("Saved:", CENTERS_META)
    return centers, k_chosen

# Build (or load) the ab cluster centers together with the number of clusters
# actually used; both names are read as globals by the following cell.
centers, k_chosen = build_centers_from_food101()
print("centers shape:", centers.shape)
# Sanity check: extent of the centers along the a and b axes.
print("centers range:",
      "a:", float(centers[:,0].min()), float(centers[:,0].max()),
      "| b:", float(centers[:,1].min()), float(centers[:,1].max()))


In [None]:
# =========================
# Cell 6: Compute rebalancing weights (soft prior + smoothing + lambda mix)
# =========================
# Cache paths for the rebalancing weights; unlike the centers file these are
# keyed on the number of clusters actually used (k_chosen).
WEIGHTS_NPY  = ART_DIR / f"ab_weights_k{k_chosen}.npy"
WEIGHTS_META = ART_DIR / f"ab_weights_k{k_chosen}_meta.json"

# kNN index over the centers for soft encoding. It uses min(SOFT_KNN, k_chosen)
# neighbours, so the "5" in the name only holds while SOFT_KNN == 5.
nn_centers_5 = NearestNeighbors(n_neighbors=min(SOFT_KNN, k_chosen), algorithm="auto").fit(centers)
# kNN index over the centers for prior smoothing, capped at 60 neighbours per center.
nn_centers_smooth = NearestNeighbors(n_neighbors=min(60, k_chosen), algorithm="auto").fit(centers)

def soft_encode_ab(ab_hw2: np.ndarray, sigma=SIGMA_SOFT):
    """
    Soft-encode ab pixels onto their nearest cluster centers with a Gaussian kernel.

    ab_hw2: (N,2) array of ab values (N = H*W for one image)
    sigma : Gaussian kernel width in ab units
    returns:
      idx: (N,k) int64   indices of the k nearest centers, k = neighbours of nn_centers_5
      w  : (N,k) float32 Gaussian weights, every row sums to 1
    """
    dists, idx = nn_centers_5.kneighbors(ab_hw2, return_distance=True)  # dists: (N,k)

    sq = np.square(dists.astype(np.float64))
    # Shift by the per-row minimum so the closest neighbour always gets exponent 0
    # (weight exactly 1 before normalisation). The normalised weights are
    # mathematically unchanged, but a pixel far from all of its neighbours can no
    # longer underflow to an all-zero row.
    sq -= sq.min(axis=1, keepdims=True)
    w = np.exp(-sq / (2.0 * sigma * sigma))
    # Each row sum is >= 1 after the shift, so no epsilon is needed.
    w /= w.sum(axis=1, keepdims=True)
    return idx.astype(np.int64), w.astype(np.float32)

def smooth_prior(p: np.ndarray, sigma=SIGMA_SMOOTH):
    """
    Smooth a per-center prior with a Gaussian kernel in ab space.

    The kernel is truncated to each center's nearest neighbours (as given by
    nn_centers_smooth). Every center spreads its own mass over its neighbours
    using row-normalised Gaussian weights; the result is renormalised to sum
    to 1 and returned as float64.
    """
    neighbor_dist, neighbor_idx = nn_centers_smooth.kneighbors(centers, return_distance=True)  # both (K,Kn)

    # Row-normalised Gaussian kernel: row q holds the shares center q gives away.
    two_sigma_sq = 2.0 * sigma * sigma
    kernel = np.exp(-np.square(neighbor_dist) / two_sigma_sq).astype(np.float64)  # (K,Kn)
    kernel /= kernel.sum(axis=1, keepdims=True) + 1e-12

    smoothed = np.zeros_like(p, dtype=np.float64)
    # Scatter-add: source center `src` pushes p[src] * kernel[src] onto its neighbours.
    for src in range(k_chosen):
        smoothed[neighbor_idx[src]] += p[src] * kernel[src]

    total = smoothed.sum() + 1e-12
    return (smoothed / total).astype(np.float64)

def compute_rebalancing_weights():
    """Compute (or load from cache) the per-bin class-rebalancing weights.

    Steps for a fresh computation:
      1. Accumulate a soft histogram over the cluster centers from center-cropped
         training images (all of ``trn_idx``, or its first ``PRUNE_IMAGES`` entries).
      2. Normalise it to a prior, smooth it with ``smooth_prior``.
      3. Mix with a uniform distribution using ``LAMBDA_UNIFORM``.
      4. Invert and rescale so the weights have mean 1.
      5. Save the float32 weights and a JSON metadata file.

    Returns:
      float32 array of length ``k_chosen`` (or whatever the cache file holds).
    """
    if WEIGHTS_NPY.exists():
        print("Weights already exist:", WEIGHTS_NPY)
        return np.load(WEIGHTS_NPY).astype(np.float32)

    # Which training images feed the prior.
    if PRIOR_USE_ALL_TRAIN:
        prior_ids = trn_idx
    else:
        prior_ids = trn_idx[:PRUNE_IMAGES]
    print("Computing prior from images:", len(prior_ids))

    soft_hist = np.zeros(k_chosen, dtype=np.float64)

    # Images are walked in chunks of PRIOR_BATCH_IMAGES; the chunking only
    # controls how often progress is reported.
    for batch_no, start in enumerate(range(0, len(prior_ids), PRIOR_BATCH_IMAGES)):
        stop = start + PRIOR_BATCH_IMAGES
        for image_id in prior_ids[start:stop]:
            pil_img, _ = train_base[image_id]
            cropped = prior_tf(pil_img)  # center crop
            lab = clamp_ab(rgb01_to_lab(pil_to_rgb01(cropped)))
            ab_pixels = lab[..., 1:3].reshape(-1, 2)

            nbr_idx, nbr_w = soft_encode_ab(ab_pixels, sigma=SIGMA_SOFT)
            # Each bin receives the sum of the soft weights assigned to it.
            soft_hist += np.bincount(
                nbr_idx.reshape(-1),
                weights=nbr_w.reshape(-1),
                minlength=k_chosen,
            )

        # Report after every 20th chunk, starting with the first one.
        if batch_no % 20 == 0:
            done = min(stop, len(prior_ids))
            print(f"  prior progress: {done}/{len(prior_ids)}")

    # Empirical prior -> smoothed prior -> mixture with uniform.
    p = soft_hist / (soft_hist.sum() + 1e-12)
    p_smooth = smooth_prior(p, sigma=SIGMA_SMOOTH)
    p_mixed = (1.0 - LAMBDA_UNIFORM) * p_smooth + LAMBDA_UNIFORM * (1.0 / k_chosen)

    # Inverse-probability weights, rescaled to mean 1, stored as float32.
    w = 1.0 / (p_mixed + 1e-12)
    w = (w / w.mean()).astype(np.float32)
    np.save(WEIGHTS_NPY, w)

    meta = {
        "dataset": "Food101",
        "K": int(k_chosen),
        "lambda_uniform": float(LAMBDA_UNIFORM),
        "sigma_soft": float(SIGMA_SOFT),
        "sigma_smooth": float(SIGMA_SMOOTH),
        "prior_soft_counts": True,
        "prior_transform": "short_side=256, center_crop=224",
        "annealed_T": float(ANNEAL_T),
        "centers_file": str(CENTERS_NPY.resolve()),
        "seed": int(SEED),
        "prior_images_used": int(len(prior_ids)),
    }
    WEIGHTS_META.write_text(json.dumps(meta, indent=2))

    print("Saved:", WEIGHTS_NPY)
    print("Saved:", WEIGHTS_META)
    print("sum(p):", p.sum())
    print("mean(w):", w.mean(), "min/max:", w.min(), w.max())
    print("top-10 weight bins:", np.sort(w)[-10:])

    return w

# Compute (or load from cache) the per-bin rebalancing weights and show their range.
ab_weights = compute_rebalancing_weights()
print("weights:", ab_weights.shape, "min/max:", float(ab_weights.min()), float(ab_weights.max()))
