# iSyncTab Demo Example

### Demo configuration (read this before running)

> **Attention:**  
> This script is configured as a **demo run** of **iSyncTab + NS-PFS** on the **HAM10000** dataset, using:
> - **Optuna** with `N_TRIALS = 5`
> - **Final training** with `FINAL_EPOCHS = 5`
> - **Tuning** with `EPOCHS_TUNE = 3`

For a **full experiment**, you should **increase**:
- `N_TRIALS` &nbsp;→ more Optuna trials  
- `FINAL_EPOCHS` → more epochs for the final train on train+val  
- Optionally `EPOCHS_TUNE` → more epochs per trial during tuning

The code will:
- Run Optuna hyperparameter search on the validation set  
- Print the **best validation objective** and the corresponding **best hyperparameters**  
- Retrain iSyncTab on the **combined train+val split** using those best hyperparameters  
- Report the final **test accuracy** and **test loss** on HAM10000

---

### Required Python packages

Make sure these libraries are installed in your environment:

```bash
pip install torch torchvision linformer optuna pandas numpy pillow

### HAM10000 dataset preparation

We use the **HAM10000** dataset from Kaggle and assume a simple layout:

- One folder containing **all images**: `ham_images_link/`
- One processed metadata file: `HAM10000_metadata_processed.csv`

---

#### 1. Download HAM10000 from Kaggle

Source (Kaggle):  
https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000

Download and extract the archive locally. You should get:
- Image folders such as `HAM10000_images_part_1`, `HAM10000_images_part_2`, …
- A metadata CSV like `HAM10000_metadata.csv`

---

#### 2. Put all images into a single folder

Create a directory in your working path, for example:

```text
ham_images_link/

In [1]:
# ============================ QUIET LOGS ============================
import os, warnings
warnings.filterwarnings("ignore", message=r"Deterministic behavior.*")
warnings.filterwarnings("ignore", message=r".*does not have a deterministic implementation.*")
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)

# ============================ IMPORTS ===============================
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torch import optim
import optuna

# Your model (simple iSyncTab)
from iSyncTab import iSyncTab, set_seed

# ============================ USER PATHS ============================
# Processed HAM metadata with columns:
# lesion_id, image_id, dx, dx_type, age, sex, localization, filename
CSV_PATH = "HAM10000_metadata_processed.csv"
IMG_ROOT = "ham_images_link"  # directory containing image files

# ============================ DEVICE PICKER (fixed to cuda:6) =======
def pick_device(force_idx=6):
    if torch.cuda.is_available():
        if force_idx < torch.cuda.device_count():
            torch.cuda.set_device(force_idx)
            return torch.device(f"cuda:{force_idx}")
        torch.cuda.set_device(0)
        return torch.device("cuda:0")
    return torch.device("cpu")

set_seed(123)
device = pick_device()
print("Device:", device)

# ============================ HELPERS ===============================
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}

def _has_img_ext(x: str) -> bool:
    s = str(x).lower().strip()
    return any(s.endswith(ext) for ext in IMG_EXTS)

def _ensure_jpg(x: str) -> str:
    s = str(x).strip()
    return s if _has_img_ext(s) else (s + ".jpg")

# ============================ DATASET ===============================
class HamDataset(Dataset):
    """
    HAM10000 multimodal dataset (processed version):
      Required columns:
        lesion_id, image_id, dx, dx_type, age, sex, localization, filename
      - Image: IMG_ROOT / filename  (fallback: image_id + '.jpg' if needed)
      - Tabular features:
          * Numeric: age
          * Categorical: dx_type, sex, localization
      - Target: dx (string → label-encoded integer)
    """
    def __init__(self, csv_path, img_root, transform=None):
        df = pd.read_csv(csv_path)
        self.img_root = Path(img_root)
        self.transform = transform

        required_cols = {
            "lesion_id", "image_id", "dx", "dx_type",
            "age", "sex", "localization", "filename"
        }
        missing = required_cols - set(df.columns)
        if missing:
            raise AssertionError(f"CSV missing required columns: {missing}")

        # ----- target & class names from dx -----
        dx_str = df["dx"].astype(str)
        classes = sorted(dx_str.unique().tolist())
        mapping = {c: i for i, c in enumerate(classes)}
        df["label"] = dx_str.map(mapping).astype(int)

        self.classes = classes
        self.class_to_id = mapping
        self.id_to_class = {i: c for c, i in mapping.items()}
        self.y = torch.tensor(df["label"].to_numpy(), dtype=torch.long)
        self.n_classes = len(self.classes)

        # ----- resolve image paths -----
        keep_idx, img_paths = [], []
        for i, row in df.iterrows():
            # Prefer filename column
            fname = str(row["filename"])
            p = self.img_root / fname
            if not p.exists():
                # Fallback to image_id + .jpg if needed
                iid = str(row["image_id"])
                name = iid if _has_img_ext(iid) else _ensure_jpg(iid)
                p = self.img_root / name
            if p.exists():
                keep_idx.append(i)
                img_paths.append(p)

        if not keep_idx:
            raise RuntimeError("No valid images found under IMG_ROOT using 'filename' or 'image_id'.")

        df = df.iloc[keep_idx].reset_index(drop=True)
        self.img_paths = img_paths
        self.y = self.y[keep_idx]

        # ----- build feature lists explicitly -----
        # We use only: age (numeric), dx_type/sex/localization (categorical)
        self.num_cols = ["age"]
        self.cat_cols = ["dx_type", "sex", "localization"]
        self.text_cols = []  # none

        # numeric tensor (age)
        num_df = df[self.num_cols].astype(float)
        # keep NaNs; TabularTokenEncoder does median imputation
        self.x_num = torch.tensor(num_df.to_numpy(copy=True), dtype=torch.float32)

        # categorical -> ids (per-column vocab, -1 for missing/unseen)
        self.x_cat = None
        self._cat_vocabs = {}
        cat_arrays = []
        for col in self.cat_cols:
            vals = df[col].astype("object")
            uniq_vals = sorted({str(v) for v in vals.dropna().unique().tolist()})
            vocab = {tok: i for i, tok in enumerate(uniq_vals)}
            self._cat_vocabs[col] = vocab
            ids = [vocab.get(str(v), -1) if pd.notna(v) else -1 for v in vals]
            cat_arrays.append(torch.tensor(ids, dtype=torch.long))
        if cat_arrays:
            self.x_cat = torch.stack(cat_arrays, dim=1)
        else:
            self.x_cat = torch.zeros((len(df), 0), dtype=torch.long)

        # dummy text channel so EmbeddingBag never sees O_text=0
        self.x_text = torch.zeros((len(df), 1, 1), dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = Image.open(img_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        if self.transform:
            img = self.transform(img)

        x_num  = self.x_num[idx] if self.x_num.numel() else torch.zeros(0, dtype=torch.float32)
        x_cat  = self.x_cat[idx] if self.x_cat is not None else torch.zeros(0, dtype=torch.long)
        x_text = self.x_text[idx]  # (1,1)
        y      = self.y[idx]

        # shape to (1, O_*) to match TabularTokenEncoder expectations
        x_tab = {
            "num":  x_num.unsqueeze(0),
            "cat":  x_cat.unsqueeze(0),
            "text": x_text.unsqueeze(0)
        }
        return x_tab, img, y

# -------------- Collate --------------
def collate(batch):
    x_tab_b, x_img_b, y_b = zip(*batch)
    x_num_b  = torch.cat([b["num"]  for b in x_tab_b], dim=0)
    if x_tab_b[0]["cat"].numel():
        x_cat_b = torch.cat([b["cat"] for b in x_tab_b], dim=0)
    else:
        x_cat_b = torch.zeros((len(x_tab_b), 0), dtype=torch.long)
    x_text_b = torch.cat([b["text"] for b in x_tab_b], dim=0)
    x_tab_batch = {"num": x_num_b, "cat": x_cat_b, "text": x_text_b}
    x_img_b = torch.stack(x_img_b)
    y_b     = torch.tensor(y_b)
    return x_tab_batch, x_img_b, y_b

def accuracy_from_logits(logits, y):
    return (logits.argmax(dim=1) == y).float().mean().item()

# ========================== DATA & LOADERS ===========================
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])  # ImageNet
])

ds = HamDataset(CSV_PATH, IMG_ROOT, transform=transform)

n_classes = ds.n_classes
N = len(ds)
n_train = int(round(0.64 * N))
n_val   = int(round(0.16 * N))
n_test  = N - n_train - n_val
assert n_train > 0 and n_val > 0 and n_test > 0, f"Bad split sizes (N={N})."

g = torch.Generator().manual_seed(123)
ds_train, ds_val, ds_test = random_split(ds, [n_train, n_val, n_test], generator=g)

def make_loaders(batch_size):
    pin = device.type == "cuda"
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=pin, collate_fn=collate)
    dl_val   = DataLoader(ds_val,   batch_size=batch_size, shuffle=False,
                          num_workers=0, pin_memory=pin, collate_fn=collate)
    dl_test  = DataLoader(ds_test,  batch_size=batch_size, shuffle=False,
                          num_workers=0, pin_memory=pin, collate_fn=collate)
    return dl_train, dl_val, dl_test

# token hint = numeric + categorical + (dummy) text(=1)
NUM_TAB_FEATURES = len(ds.num_cols) + len(ds.cat_cols) + 1

# ====================== SAFETY HOOK (nspfs -> Long valid indices) ======================
def _install_nspfs_safety_hook(model: torch.nn.Module):
    """
    Ensure model.nspfs forward returns Long indices in [0, L-1].
    """
    if not hasattr(model, "nspfs") or not isinstance(model.nspfs, torch.nn.Module):
        return None

    def hook(mod, inputs, output):
        try:
            t_tok = inputs[0]; i_tok = inputs[1]
            L = int(t_tok.size(1) + i_tok.size(1))
        except Exception:
            L = None
            t_tok = None
        if isinstance(output, torch.Tensor):
            if output.dtype != torch.long:
                output = output.long()
            if (L is not None) and output.numel() > 0:
                output = output.clamp_(0, max(L - 1, 0))
            if t_tok is not None and isinstance(t_tok, torch.Tensor):
                output = output.to(t_tok.device)
        return output

    handle = model.nspfs.register_forward_hook(hook)
    return handle

# ============================ MODEL ================================
def build_model(params):
    core = dict(
        num_tab_features = NUM_TAB_FEATURES,
        num_classes      = n_classes,
        num_clusters     = params["num_clusters"],
        metric           = params["metric"],
        lambda_fs        = params["lambda_fs"],
        pretrained_resnet= True
    )
    opt_kwargs = dict(
        d_model         = params.get("d_model"),
        linformer_depth = params.get("linformer_depth"),
        linformer_heads = params.get("linformer_heads"),
        linformer_k     = params.get("linformer_k"),
    )
    opt_kwargs = {k: v for k, v in opt_kwargs.items() if v is not None}
    try:
        model = iSyncTab(**core, **opt_kwargs).to(device)
    except TypeError:
        model = iSyncTab(**core).to(device)

    # install safety hook (kept as attribute to avoid GC)
    model._nspfs_hook_handle = _install_nspfs_safety_hook(model)
    return model

def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train(train)
    losses, accs = [], []
    for x_tab, img, y in loader:
        img, y = img.to(device), y.to(device)
        x_tab = {k: v.to(device) for k, v in x_tab.items()}
        out = model(x_tab, img, y=y)
        if train:
            opt.zero_grad()
            out["loss"].backward()
            opt.step()
        losses.append(out["loss"].item())
        accs.append(accuracy_from_logits(out["logits"], y))
    return (float(np.mean(losses)) if losses else 0.0,
            float(np.mean(accs))   if accs   else 0.0)

# ================================ OPTUNA =============================
# Objective = for best performance
PENALIZE_LAMBDA = 0.0
EPOCHS_TUNE     = 3
FINAL_EPOCHS    = 5

def suggest_params(trial):
    return {
        "d_model": trial.suggest_categorical("d_model", [128, 192, 256]),
        "linformer_heads": trial.suggest_categorical("linformer_heads", [2, 4, 8]),
        "linformer_depth": trial.suggest_int("linformer_depth", 3, 5),
        "linformer_k": trial.suggest_categorical("linformer_k", [16, 32, 64]),
        "num_memory_tokens": trial.suggest_int("num_memory_tokens", 1, 3),
        "num_clusters": trial.suggest_int("num_clusters", 3, 7),
        "metric": trial.suggest_categorical("metric", ["variance", "euclidean", "cosine", "correlation", "kl", "js", "manhattan"]),
        "lr": trial.suggest_float("lr", 1e-4, 4e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 2e-4, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [8, 16, 32]),
        "lambda_fs": trial.suggest_float("lambda_fs", 0.0, 0.3),
    }

def objective(trial):
    params = suggest_params(trial)

    # sanity: d_model divisible by heads
    if params["d_model"] % params["linformer_heads"] != 0:
        raise optuna.TrialPruned()

    set_seed(1000 + trial.number)
    dl_train, dl_val, _ = make_loaders(params["batch_size"])

    model = build_model(params)
    opti  = optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=params["weight_decay"])

    # warmup (optional, to initialize NSPFS etc.)
    try:
        x_tab_b, img_b, y_b = next(iter(dl_train))
        img_b, y_b = img_b.to(device), y_b.to(device)
        x_tab_b = {k: v.to(device) for k, v in x_tab_b.items()}
        _ = model(x_tab_b, img_b, y=y_b)
    except StopIteration:
        pass

    for _ in range(EPOCHS_TUNE):
        run_epoch(model, dl_train, opt=opti)

    _, val_acc = run_epoch(model, dl_val, opt=None)
    obj = float(val_acc - PENALIZE_LAMBDA * params["lambda_fs"])

    print(f"[Trial {trial.number:02d}] val_acc={val_acc:.4f}, "
          f"lambda_fs={params['lambda_fs']:.3f}, objective={obj:.4f}")

    return obj

# Keep Optuna metadata minimal (no explicit study_name)
N_TRIALS = 5
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=123)
)
study.optimize(objective, n_trials=N_TRIALS,
               gc_after_trial=True, show_progress_bar=True)

print("\n=== Best (validation objective) ===")
print(f"Score = {study.best_value:.4f}")
best = study.best_trial.params
for k, v in best.items():
    print(f"{k}: {v}")

# ==================== RETRAIN BEST + TEST EVAL =======================
set_seed(777)
dl_train_best, dl_val_best, dl_test_best = make_loaders(best["batch_size"])

# Merge train+val for final training
trainval_indices = list(range(len(ds_train))) + [len(ds_train) + i for i in range(len(ds_val))]
subset = torch.utils.data.Subset(ds, trainval_indices)
dl_trainval = DataLoader(subset, batch_size=best["batch_size"], shuffle=True,
                         num_workers=0, pin_memory=(device.type=="cuda"),
                         collate_fn=collate)

model_best = build_model(best)
opt_best   = optim.AdamW(model_best.parameters(), lr=best["lr"], weight_decay=best["weight_decay"])

# warmup once
try:
    x_tab_b, img_b, y_b = next(iter(dl_trainval))
    img_b, y_b = img_b.to(device), y_b.to(device)
    x_tab_b = {k: v.to(device) for k, v in x_tab_b.items()}
    _ = model_best(x_tab_b, img_b, y=y_b)
except StopIteration:
    pass

for ep in range(FINAL_EPOCHS):
    train_loss, train_acc = run_epoch(model_best, dl_trainval, opt=opt_best)
    print(f"[Final Train] Epoch {ep+1:02d} | loss={train_loss:.4f} | acc={train_acc:.4f}")

test_loss, test_acc = run_epoch(model_best, dl_test_best, opt=None)
print("\n=== FINAL TEST RESULTS (iSyncTab + NS-PFS, HAM10000) ===")
print(f"Test accuracy: {test_acc:.4f}")
print(f"Test loss:     {test_loss:.4f}")
print("Classes:", ds.classes)

Device: cuda:6


[I 2025-11-21 00:59:39,615] A new study created in memory with name: no-name-3ea1bec3-35fb-492b-a160-10e035324833


  0%|          | 0/5 [00:00<?, ?it/s]

[Trial 00] val_acc=0.8103, lambda_fs=0.217, objective=0.8103
[I 2025-11-21 01:26:00,601] Trial 0 finished with value: 0.8103233830845771 and parameters: {'d_model': 128, 'linformer_heads': 4, 'linformer_depth': 5, 'linformer_k': 16, 'num_memory_tokens': 2, 'num_clusters': 6, 'metric': 'correlation', 'lr': 0.0002090220546562389, 'weight_decay': 2.882541940368984e-05, 'batch_size': 8, 'lambda_fs': 0.21673301477106646}. Best is trial 0 with value: 0.8103233830845771.
[Trial 01] val_acc=0.8057, lambda_fs=0.184, objective=0.8057
[I 2025-11-21 01:40:23,233] Trial 1 finished with value: 0.8056930693069307 and parameters: {'d_model': 192, 'linformer_heads': 4, 'linformer_depth': 4, 'linformer_k': 32, 'num_memory_tokens': 1, 'num_clusters': 5, 'metric': 'euclidean', 'lr': 0.0003323304105591294, 'weight_decay': 3.769687142814138e-06, 'batch_size': 16, 'lambda_fs': 0.1838683577288903}. Best is trial 0 with value: 0.8103233830845771.
[Trial 02] val_acc=0.7730, lambda_fs=0.005, objective=0.7730
[I 