In [None]:
import os
from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold

from monai.networks.nets import resnet

try:
    import torchio as tio
except Exception:
    tio = None

def get_transform_torchio():
    """Build the training-time augmentation callable.

    Returns:
        ``None`` when TorchIO could not be imported (a notice is printed),
        otherwise a function mapping a 4-D ``np.ndarray`` / ``torch.Tensor``
        to the augmented ``torch.Tensor`` produced by TorchIO.
    """
    if tio is None:
        print("[TorchIO] Not installed; training will run without augmentation.")
        return None

    aug = tio.Compose([
        # Rotation only (±10°). The scale range is pinned to exactly 1:
        # a *scalar* `scales=x` means U(1 - x, 1 + x) in TorchIO, so the
        # previous `scales=1.0` sampled zoom factors from U(0, 2).
        tio.RandomAffine(scales=(1.0, 1.0), degrees=10, translation=0),
        # Flips along AP & SI (0.5 probability each)
        tio.RandomFlip(axes=('AP', 'SI'), flip_probability=0.5),
        # Mild intensity gamma jitter (±10% in log space)
        tio.RandomGamma(log_gamma=(-0.1, 0.1)),
    ])

    def apply(x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Augment one multi-channel volume; all channels share the same sampled transform."""
        # x: [C, D, H, W]; NumPy input is converted to a float32 tensor first.
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x.astype(np.float32))
        # Wrap the tensor in a single-image Subject so the Compose can be applied.
        subj = tio.Subject(img=tio.ScalarImage(tensor=x))
        out = aug(subj)
        return out['img'].tensor  # [C, D, H, W]
    return apply

class Cfg:
    """Static hyper-parameters and paths read by the rest of this script."""

    # --- reproducibility / cross-validation ---
    seed: int = 42
    n_splits: int = 5

    # --- training schedule ---
    max_epochs: int = 1000
    patience: int = 50             # epochs without val improvement before stopping
    warmup_freeze_epochs: int = 0  # set >0 to freeze encoders at start

    # --- data loading ---
    batch_size: int = 3
    num_workers: int = 16
    target_shape: Tuple[int, int, int] = (160, 160, 48)  # (D, H, W)

    # --- optimization: one learning rate per component family ---
    lr_encoders: float = 1e-4
    lr_mlps: float = 1e-3
    lr_decoders: float = 3e-3
    weight_decay: float = 1e-4     # encoders and MLPs
    wd_decoders: float = 0.0       # often better w/ 0 WD on decoders

    # --- IO ---
    save_root: Path = Path("runs/ae_sep_160x160x48")

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also switches cuDNN to deterministic mode and turns off its autotuner.
    """
    import random

    # Same seed for every generator, in a fixed order.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Global compute device: CUDA when available, otherwise CPU. Read by the
# train/eval functions and by run_training when moving batches and models.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed all RNGs once, at module execution time, with the configured seed.
set_seed(Cfg.seed)


# -------------------------------- Dataset ------------------------------------
class MyDataset(Dataset):
    """
    Paired tumor/liver volume dataset.

    Args:
      ids: sequence of case identifiers. Each item may be
           - a 2-element tuple/list ``(tumor_path, liver_path)``,
           - a dict with keys ``"tumor_path"`` and ``"liver_path"``,
           - anything else, in which case ``get_paths`` must be given.
      transform: optional callable applied to the stacked ``[C, D, H, W]``
           float32 array (e.g. the TorchIO wrapper); ``None`` disables it.
      get_paths: optional callable mapping id -> (tumor_path, liver_path).
      target_shape: stored on the instance for reference only; volumes are
           NOT resampled or validated against it here, so inputs are assumed
           to already have the intended spatial shape.

    Returns (per item):
      x: torch.float32 ``[2, *volume_shape]`` (channel 0 tumor, channel 1 liver)
      case_id: the identifier exactly as given in ``ids``
    """
    def __init__(
        self,
        ids: Sequence,
        transform: Optional[Callable] = None,
        get_paths: Optional[Callable[[object], Tuple[str, str]]] = None,
        target_shape: Tuple[int, int, int] = Cfg.target_shape,
    ):
        self.ids = list(ids)
        self.transform = transform
        self.get_paths = get_paths
        self.target_shape = target_shape

    def __len__(self):
        return len(self.ids)

    def _resolve_paths(self, case_id) -> Tuple[str, str]:
        """Turn a case identifier into ``(tumor_path, liver_path)``.

        Raises:
            ValueError: the id is neither a path pair nor a dict and no
                ``get_paths`` resolver was supplied.
        """
        # If ID is already a tuple of paths, use it directly
        if isinstance(case_id, (tuple, list)) and len(case_id) == 2:
            return case_id[0], case_id[1]
        # If ID is a dict-like with explicit keys
        if isinstance(case_id, dict):
            return case_id["tumor_path"], case_id["liver_path"]
        # Otherwise, require a resolver callable
        if self.get_paths is None:
            raise ValueError(
                "MyDataset needs 'get_paths' to resolve case_id -> (tumor_path, liver_path)"
            )
        return self.get_paths(case_id)

    @staticmethod
    def _load_volume(path: Union[str, os.PathLike]) -> np.ndarray:
        """Load one volume as a float32 array.

        ``.npy`` files go through ``np.load``; every other path is handed to
        nibabel (imported lazily so it is only required for NIfTI input).
        A 4-D array is reduced to 3-D by keeping index 0 of its last axis.
        No axis reordering is performed.
        """
        # Accept pathlib.Path as well as str: Path has no `.endswith`.
        path = os.fspath(path)
        if path.endswith(".npy"):
            vol = np.load(path).astype(np.float32)
        else:
            import nibabel as nib
            vol = nib.load(path).get_fdata().astype(np.float32)
        # If 4D, take first channel
        if vol.ndim == 4:
            vol = vol[..., 0]
        return vol

    def __getitem__(self, idx: int):
        """Return ``(x, case_id)`` for sample ``idx``.

        Raises:
            ValueError: the tumor and liver volumes have different shapes.
        """
        case_id = self.ids[idx]
        tumor_path, liver_path = self._resolve_paths(case_id)
        t = self._load_volume(tumor_path)
        l = self._load_volume(liver_path)

        # np.stack would raise ValueError too, but without saying which case.
        if t.shape != l.shape:
            raise ValueError(
                f"Shape mismatch for case {case_id!r}: "
                f"tumor {t.shape} ({tumor_path}) vs liver {l.shape} ({liver_path})"
            )

        # Stack to channels-first [C,D,H,W]; both inputs are float32 already,
        # so avoid a second full copy.
        x = np.stack([t, l], axis=0).astype(np.float32, copy=False)

        # TorchIO aug (train only)
        if self.transform is not None:
            x = self.transform(x)  # expects/returns [C,D,H,W]
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(np.asarray(x, dtype=np.float32))
        else:
            x = torch.from_numpy(x)

        return x, case_id


# ---------------------------- Model Components --------------------------------
class MLP(nn.Module):
    """Projection head: input_dim -> 1024 -> latent_dim (defaults 2048 -> 256)."""

    def __init__(self, input_dim: int = 2048, latent_dim: int = 256, dropout: float = 0.0):
        super().__init__()
        hidden_dim = 1024
        # Linear -> ReLU -> Dropout -> Linear; no activation on the latent.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features ``[..., input_dim]`` to latents ``[..., latent_dim]``."""
        latent = self.proj(x)
        return latent

class Decoder(nn.Module):
    """
    Latent vector -> single-channel volume.

    A linear layer expands the latent to ``[128, *base_shape]``; four stages
    of (ConvTranspose3d stride 2 -> ReLU -> Conv3d 3x3x3 -> ReLU) each double
    every spatial axis, so the output spatial size is ``16 * base_shape``
    (default (10,10,3) -> (160,160,48)). The final 3x3x3 conv has no
    activation (linear output, Option A).
    """
    def __init__(self, latent_dim: int = 256, base_shape=(10, 10, 3)):
        super().__init__()
        self.base_shape = base_shape
        # Channel count of the seed volume; shared by `fc` and `forward`.
        self.base_channels = 128
        # int(): np.prod returns a NumPy scalar, which would otherwise end up
        # stored as nn.Linear.out_features instead of a plain Python int.
        seed_voxels = int(np.prod(self.base_shape))
        self.fc = nn.Linear(latent_dim, self.base_channels * seed_voxels)
        self.up1 = self._up_block(self.base_channels, 64)
        self.up2 = self._up_block(64, 32)
        self.up3 = self._up_block(32, 16)
        self.up4 = self._up_block(16, 8)
        self.out = nn.Conv3d(8, 1, kernel_size=3, padding=1)
        self.reset_parameters()

    @staticmethod
    def _up_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """One x2 upsampling stage followed by a shape-preserving 3x3x3 conv."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def reset_parameters(self):
        """Kaiming-normal (ReLU gain) weights and zero biases for every conv/linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if getattr(m, "bias", None) is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents ``[B, latent_dim]`` into ``[B, 1, 16*d, 16*h, 16*w]``."""
        B = z.size(0)
        x = self.fc(z).view(B, self.base_channels, *self.base_shape)  # [B,128,d,h,w]
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)        # [B,8,16d,16h,16w]
        x = self.out(x)        # [B,1,16d,16h,16w] (linear)
        return x

def build_encoder() -> nn.Module:
    """Create a single-channel 3D ResNet-50 encoder via MONAI.

    The classification head is disabled (``feed_forward=False``) and
    pretrained weights are requested; the rest of this file consumes the
    output as a 2048-D feature vector per sample.
    """
    encoder_kwargs = {
        "n_input_channels": 1,
        "feed_forward": False,
        "shortcut_type": "B",
        "bias_downsample": False,
        "pretrained": True,
    }
    return resnet.resnet50(**encoder_kwargs)


# --------------------------------- Loss --------------------------------------
def recon_loss(pred: torch.Tensor, target: torch.Tensor,
               l1_w: float = 1.0, mse_w: float = 1.0):
    """Weighted sum of mean L1 and mean MSE between ``pred`` and ``target``.

    A term whose weight is falsy (0 / 0.0) is skipped entirely; with both
    weights falsy the plain float ``0.0`` is returned.
    Option A: z-score domain — no BCE.
    """
    total = 0.0
    weighted_criteria = (
        (l1_w, F.l1_loss),
        (mse_w, F.mse_loss),
    )
    for weight, criterion in weighted_criteria:
        if weight:
            total = total + weight * criterion(pred, target, reduction="mean")
    return total


# --------------------------- Optimizer grouping -------------------------------
def add_groups(module: nn.Module, lr: float, wd: float):
    """Split a module's trainable parameters into optimizer param groups.

    Parameters with at most one dimension (biases, normalization scale/shift)
    and parameters whose leaf name is ``bias`` get ``weight_decay=0.0``;
    all others get ``wd``. Classifying by rank instead of by substrings of the
    parameter name also catches norm layers that sit inside unnamed
    containers (e.g. ``downsample.1.weight``) and cannot mis-fire on names
    that merely contain "ln"/"gn"/"bn".

    Args:
        module: source of the parameters; frozen ones are skipped.
        lr: learning rate assigned to both groups.
        wd: weight decay for the decayed group.

    Returns:
        A list with up to two dicts (decay group first, then no-decay);
        empty groups are omitted.
    """
    decays, nodecays = [], []
    for name, p in module.named_parameters():
        if not p.requires_grad:
            continue
        is_bias = name.rsplit(".", 1)[-1] == "bias"
        if p.ndim <= 1 or is_bias:
            nodecays.append(p)  # no weight decay for norms/bias
        else:
            decays.append(p)

    groups = []
    if decays:
        groups.append({"params": decays, "lr": lr, "weight_decay": wd})
    if nodecays:
        groups.append({"params": nodecays, "lr": lr, "weight_decay": 0.0})
    return groups


# ------------------------------- Train / Eval --------------------------------
def train_one_epoch(
    enc_tumor, enc_liver,
    mlp_tumor, mlp_liver,
    dec_tumor, dec_liver,
    loader, optimizer,
    epoch: int
):
    """Run one optimisation pass over ``loader`` for both autoencoder branches.

    Channel 0 of each batch goes through the tumor branch and channel 1
    through the liver branch; the two reconstruction losses are summed and
    optimised jointly. While ``epoch <= Cfg.warmup_freeze_epochs`` (and that
    setting is > 0) the encoder parameters have ``requires_grad`` turned off.

    Returns:
        Sample-weighted mean of the summed loss (0.0 if the loader is empty).
    """
    # Optional warmup: encoders are trainable unless we are inside the freeze window.
    warmup = Cfg.warmup_freeze_epochs
    encoders_trainable = not (warmup > 0 and epoch <= warmup)
    for encoder in (enc_tumor, enc_liver):
        for param in encoder.parameters():
            param.requires_grad = encoders_trainable

    for net in (enc_tumor, enc_liver, mlp_tumor, mlp_liver, dec_tumor, dec_liver):
        net.train()

    loss_sum = 0.0
    sample_count = 0
    for batch, _ in loader:
        batch = batch.to(device)              # [B, 2, 160, 160, 48]
        tumor_vol = batch[:, 0:1, ...]
        liver_vol = batch[:, 1:2, ...]

        optimizer.zero_grad(set_to_none=True)

        # encoder -> MLP -> decoder, tumor branch first, then liver
        tumor_recon = dec_tumor(mlp_tumor(enc_tumor(tumor_vol)))   # [B,1,160,160,48]
        liver_recon = dec_liver(mlp_liver(enc_liver(liver_vol)))

        batch_loss = recon_loss(tumor_recon, tumor_vol) + recon_loss(liver_recon, liver_vol)
        batch_loss.backward()
        optimizer.step()

        batch_len = batch.size(0)
        loss_sum += batch_loss.item() * batch_len
        sample_count += batch_len

    return loss_sum / max(1, sample_count)

@torch.no_grad()
def evaluate(
    enc_tumor, enc_liver,
    mlp_tumor, mlp_liver,
    dec_tumor, dec_liver,
    loader
):
    """Compute the mean reconstruction loss over ``loader`` without gradients.

    All six modules are put in eval mode. Each batch is expected to be a
    sequence whose first element is the image tensor ``[B, 2, ...]``
    (channel 0 tumor, channel 1 liver); any further elements are ignored.

    Returns:
        ``(mean_loss, last_batch)`` where ``mean_loss`` is the sample-weighted
        mean of the summed tumor + liver loss (0.0 for an empty loader) and
        ``last_batch`` is ``(tumor, tumor_recon, liver, liver_recon)`` on the
        CPU for the final batch, or ``None`` if the loader was empty.
    """
    for net in (enc_tumor, enc_liver, mlp_tumor, mlp_liver, dec_tumor, dec_liver):
        net.eval()

    running, n = 0.0, 0
    last = None
    # The dataset yields (x, case_id); index the batch instead of unpacking a
    # fixed number of fields (the previous 3-way unpack raised ValueError).
    for batch in loader:
        images = batch[0].to(device)
        tumor = images[:, 0:1, ...]
        liver = images[:, 1:2, ...]

        f_t = enc_tumor(tumor); z_t = mlp_tumor(f_t); r_t = dec_tumor(z_t)
        f_l = enc_liver(liver); z_l = mlp_liver(f_l); r_l = dec_liver(z_l)

        loss = recon_loss(r_t, tumor) + recon_loss(r_l, liver)

        bs = images.size(0)
        running += loss.item() * bs
        n += bs
        last = (tumor, r_t, liver, r_l)

    # Only the final batch is returned, so copy to the CPU once, after the loop.
    last_batch = None if last is None else tuple(t.detach().cpu() for t in last)
    return running / max(1, n), last_batch


# --------------------------------- Driver ------------------------------------
def run_training(
    ids: Sequence,
    transform_torchio: Optional[Callable] = None,
    get_paths: Optional[Callable[[object], Tuple[str, str]]] = None,
):
    """
    K-fold training driver for the tumor and liver autoencoders.

    For every fold: build datasets/loaders, create fresh encoders, MLPs and
    decoders, train with early stopping on the validation loss, checkpoint
    the best weights under ``Cfg.save_root / fold_XX`` and report the test
    reconstruction loss with those weights.

    ids: sequence of IDs you want to train on. Each item can be:
         - a tuple (tumor_path, liver_path)
         - any other form accepted by MyDataset (see ``get_paths``)
    transform_torchio: optional augmentation callable, applied to the
         training set only.
    get_paths: optional resolver id -> (tumor_path, liver_path) forwarded to
         MyDataset. When omitted, a module-level ``get_paths`` is used if one
         exists; otherwise ``None`` (sufficient for path-pair ids).
    """
    if get_paths is None:
        # The function used to read a global `get_paths` that this file never
        # defines (NameError). Keep honouring such a global when present.
        get_paths = globals().get("get_paths")

    Cfg.save_root.mkdir(parents=True, exist_ok=True)

    kf = KFold(n_splits=Cfg.n_splits, shuffle=True, random_state=Cfg.seed)
    idx_array = np.arange(len(ids))

    for fold_idx, (idx_train_val, idx_test) in enumerate(kf.split(idx_array)):
        # NOTE(review): validation consists of exactly ONE sample (the last
        # index of the non-test part), so early stopping is driven by a
        # single case — confirm this is intended.
        idx_train = idx_train_val[:-1]
        idx_val   = idx_train_val[-1:]

        # Datasets (augmentation on the training split only)
        train_set = MyDataset([ids[i] for i in idx_train], transform_torchio, get_paths)
        val_set   = MyDataset([ids[i] for i in idx_val],   None,               get_paths)
        test_set  = MyDataset([ids[i] for i in idx_test],  None,               get_paths)

        # Persist split indices
        fold_dir = Cfg.save_root / f"fold_{fold_idx:02d}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        np.save(fold_dir / "idx_train.npy", idx_train)
        np.save(fold_dir / "idx_val.npy", idx_val)
        np.save(fold_dir / "idx_test.npy", idx_test)

        # Loaders. Evaluation uses the training batch size rather than the
        # whole split in one batch: the reported loss is sample-weighted, so
        # it is the same mean, without needing the full fold in memory.
        train_loader = DataLoader(train_set, batch_size=Cfg.batch_size, shuffle=True,
                                  num_workers=Cfg.num_workers, drop_last=True)
        val_loader   = DataLoader(val_set,   batch_size=Cfg.batch_size, shuffle=False)
        test_loader  = DataLoader(test_set,  batch_size=Cfg.batch_size, shuffle=False)

        # Build separate components. Insertion order matters: it is the
        # positional order expected by train_one_epoch / evaluate.
        models = {
            "enc_tumor": build_encoder().to(device),
            "enc_liver": build_encoder().to(device),
            "mlp_tumor": MLP(input_dim=2048, latent_dim=256).to(device),
            "mlp_liver": MLP(input_dim=2048, latent_dim=256).to(device),
            "dec_tumor": Decoder(latent_dim=256).to(device),
            "dec_liver": Decoder(latent_dim=256).to(device),
        }
        nets = tuple(models.values())

        # Optimizer with explicit groups (encoders vs mlps vs decoders)
        group_settings = {
            "enc_tumor": (Cfg.lr_encoders, Cfg.weight_decay),
            "enc_liver": (Cfg.lr_encoders, Cfg.weight_decay),
            "mlp_tumor": (Cfg.lr_mlps,     Cfg.weight_decay),
            "mlp_liver": (Cfg.lr_mlps,     Cfg.weight_decay),
            "dec_tumor": (Cfg.lr_decoders, Cfg.wd_decoders),
            "dec_liver": (Cfg.lr_decoders, Cfg.wd_decoders),
        }
        param_groups = []
        for name, net in models.items():
            group_lr, group_wd = group_settings[name]
            param_groups += add_groups(net, lr=group_lr, wd=group_wd)

        # Safety: ensure no duplicate params across groups
        all_params = [p for g in param_groups for p in g["params"]]
        assert len(all_params) == len(set(map(id, all_params))), "Duplicate params across groups!"

        optimizer = torch.optim.Adam(param_groups)

        best_val, no_improve = float("inf"), 0
        # One checkpoint file per component, e.g. fold_00/enc_tumor_best.pth
        ckpt = {name: fold_dir / f"{name}_best.pth" for name in models}

        for epoch in range(1, Cfg.max_epochs + 1):
            tr = train_one_epoch(*nets, train_loader, optimizer, epoch)
            va, _ = evaluate(*nets, val_loader)
            print(f"[Fold {fold_idx:02d}] Epoch {epoch:04d} | train {tr:.6f} | val {va:.6f}")

            # An improvement must beat the best value by more than 1e-6.
            if va < best_val - 1e-6:
                best_val, no_improve = va, 0
                for name, net in models.items():
                    torch.save(net.state_dict(), ckpt[name])
            else:
                no_improve += 1
                if no_improve >= Cfg.patience:
                    print(f"[Fold {fold_idx:02d}] Early stopping at epoch {epoch}, best val {best_val:.6f}")
                    break

        # Test with best weights (falls back to the current weights when any
        # checkpoint file is missing)
        if all(p.exists() for p in ckpt.values()):
            for name, net in models.items():
                net.load_state_dict(torch.load(ckpt[name], map_location=device))

        te, _ = evaluate(*nets, test_loader)
        print(f"[Fold {fold_idx:02d}] TEST recon loss: {te:.6f}")


# --------------------------------- Usage -------------------------------------
# Build the augmentation callable (None when TorchIO is unavailable).
transform_torchio = get_transform_torchio()

# Prepare your ID list. Easiest is to pass pairs of file paths:
# ids = [
#   ("/path/to/tumor_case1.nii.gz", "/path/to/liver_case1.nii.gz"),
#   ("/path/to/tumor_case2.nii.gz", "/path/to/liver_case2.nii.gz"),
#   ...
# ]

# NOTE(review): `ids` is only shown in the commented example above; it must be
# defined (here or in an earlier cell) before this call, otherwise the line
# below raises NameError.
run_training(ids, transform_torchio=transform_torchio)
