# HANS-Net Standalone (Kaggle-ready)
- Self-contained pipeline: data loader, model, training loop in this notebook.
- Set `KAGGLE_DATASET_NAME` to your Kaggle dataset slug (folder under `/kaggle/input`).
- Run all cells to validate folders, smoke-test the model, and optionally train.
- Checkpoints write to `/kaggle/working/checkpoints` on Kaggle or `./checkpoints` locally.

In [None]:
import os
import re
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torchvision.transforms.functional as TF
from tqdm import tqdm

# Kaggle-aware configuration
# On Kaggle the dataset is looked up at KAGGLE_INPUT_ROOT/KAGGLE_DATASET_NAME;
# everywhere else (or when that folder is missing) DEFAULT_DATA_ROOT is used.
KAGGLE_DATASET_NAME = "replace-with-your-dataset"  # e.g., "liver-tumor-segmentation"
KAGGLE_INPUT_ROOT = "/kaggle/input"
KAGGLE_WORK_ROOT = "/kaggle/working"

# A Kaggle session is assumed whenever KAGGLE_KERNEL_RUN_TYPE is set in the environment.
AUTO_KAGGLE = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
DEFAULT_DATA_ROOT = "Dataset"
DATA_ROOT = DEFAULT_DATA_ROOT

if AUTO_KAGGLE:
    candidate = os.path.join(KAGGLE_INPUT_ROOT, KAGGLE_DATASET_NAME)
    if os.path.isdir(candidate):
        DATA_ROOT = candidate
        print(f"Detected Kaggle; using dataset at {DATA_ROOT}")
    else:
        # DATA_ROOT stays at DEFAULT_DATA_ROOT in this case.
        print(f"[WARN] Kaggle detected but dataset folder not found: {candidate}")
else:
    print("Kaggle not detected; using local paths.")

# Hyperparameters and run settings.
# IMG_SIZE is (height, width): the loaders swap it to (width, height) for PIL's resize.
# NOTE(review): IMG_SIZE is used by the smoke test and the evaluation cell, but the
# train_pipeline(...) call does not forward it; training relies on the loader default.
IMG_SIZE = (128, 128)
EPOCHS = 1
BATCH_SIZE = 2
LR = 1e-4
# Checkpoints go under the Kaggle working directory on Kaggle, otherwise under ".".
CHECKPOINT_DIR = os.path.join(KAGGLE_WORK_ROOT if AUTO_KAGGLE else ".", "checkpoints")
RESUME = None  # path to checkpoint if resuming
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seeds torch's global RNG (weight init, random_split, shuffling all draw from it).
torch.manual_seed(42)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Dataset root: {DATA_ROOT}")
print(f"Checkpoints: {CHECKPOINT_DIR}")

In [None]:
# Dataset helpers and loaders
def get_case_id_from_filename(fname: str) -> str:
    """Return the case identifier encoded in a slice filename.

    The filename stem (directory and extension removed) must contain exactly
    one ``_slice_`` separator, and the part before it must look like
    ``volume-<digits>``.

    Args:
        fname: File name or path of a slice image.

    Returns:
        The case identifier, e.g. ``"volume-12"``.

    Raises:
        ValueError: If the separator is missing or repeated, or if the case
            identifier does not match the expected pattern.
    """
    stem, _ext = os.path.splitext(os.path.basename(fname))
    pieces = stem.split('_slice_')
    # No separator yields one piece, a repeated separator yields three or more.
    if len(pieces) != 2:
        raise ValueError(f"Invalid filename format: '{fname}'")
    case_id = pieces[0]
    if re.match(r'^volume-\d+$', case_id) is None:
        raise ValueError(f"Invalid case ID format: '{case_id}'")
    return case_id

def get_slice_index(fname: str) -> int:
    """Return the integer slice index encoded in a slice filename.

    The filename stem (directory and extension removed) must contain exactly
    one ``_slice_`` separator; the text after it is parsed with ``int()``.

    Args:
        fname: File name or path of a slice image.

    Returns:
        The slice index as an integer.

    Raises:
        ValueError: If the separator is missing or repeated, or if the suffix
            is not a valid integer. The message always names the file.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    parts = stem.split('_slice_')
    # One piece means no separator, three or more means it was repeated.
    if len(parts) != 2:
        raise ValueError(f"Invalid filename format: '{fname}'")
    slice_str = parts[1]
    try:
        return int(slice_str)
    except ValueError as err:
        # Re-raise with the offending file so a bad dataset entry is easy to find.
        raise ValueError(
            f"Invalid slice index '{slice_str}' in filename: '{fname}'"
        ) from err

@dataclass
class SliceMeta:
    """Location and identity of one image slice and its matching mask file."""
    # Path to the slice image file.
    img_path: str
    # Path to the mask file (same file name as the image, under the mask folder).
    mask_path: str
    # Case identifier parsed from the file name (``volume-<digits>``).
    case_id: str
    # Slice index parsed from the file name; used to order slices within a case.
    slice_idx: int

def build_slice_metadata(img_root: str, mask_root: str) -> List[SliceMeta]:
    """Pair every image in ``img_root`` with the same-named mask in ``mask_root``.

    Only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are
    considered. Records are returned in sorted file-name order so that the
    result, and anything derived from it such as a seeded train/val split,
    does not depend on the directory listing order of the filesystem.

    Args:
        img_root: Folder containing the slice images.
        mask_root: Folder containing the masks.

    Returns:
        One ``SliceMeta`` per image, holding absolute image/mask paths, the
        case id and the slice index parsed from the file name.

    Raises:
        FileNotFoundError: If either folder does not exist.
        ValueError: If an image has no mask with the same file name, or if a
            file name does not follow the expected naming scheme.
    """
    if not os.path.isdir(img_root):
        raise FileNotFoundError(f"Image root not found: {img_root}")
    if not os.path.isdir(mask_root):
        raise FileNotFoundError(f"Mask root not found: {mask_root}")
    img_root = os.path.abspath(img_root)
    mask_root = os.path.abspath(mask_root)
    image_exts = ('.jpg', '.jpeg', '.png')
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    img_files = sorted(f for f in os.listdir(img_root) if f.lower().endswith(image_exts))
    mask_files_set = {f for f in os.listdir(mask_root) if f.lower().endswith(image_exts)}
    metadata_list: List[SliceMeta] = []
    for img_fname in img_files:
        if img_fname not in mask_files_set:
            raise ValueError(f"No corresponding mask for {img_fname}")
        metadata_list.append(
            SliceMeta(
                os.path.join(img_root, img_fname),
                os.path.join(mask_root, img_fname),
                get_case_id_from_filename(img_fname),
                get_slice_index(img_fname),
            )
        )
    return metadata_list

class LITSSliceDataset(Dataset):
    """Dataset yielding (previous, center, next) slice stacks with the center mask.

    Slices are grouped by case id and ordered by slice index. Every slice of
    every case is used once as a center slice; at the first/last slice of a
    case the missing neighbour is replaced by the center slice itself.
    """

    def __init__(self, img_root: str, mask_root: str, img_size: Tuple[int, int] = (128, 128)) -> None:
        """Index all image/mask pairs found under the two folders.

        Args:
            img_root: Folder containing slice images.
            mask_root: Folder containing masks with matching file names.
            img_size: Target size as (height, width).
        """
        self.img_size = img_size
        grouped: Dict[str, List[SliceMeta]] = {}
        for entry in build_slice_metadata(img_root, mask_root):
            if entry.case_id not in grouped:
                grouped[entry.case_id] = []
            grouped[entry.case_id].append(entry)
        # Order each case's slices by their index so neighbours are adjacent.
        for case_slices in grouped.values():
            case_slices.sort(key=lambda m: m.slice_idx)
        self.case_to_slices = grouped
        # One sample per (case, position-in-case) pair.
        self.samples: List[Tuple[str, int]] = [
            (case_id, position)
            for case_id, case_slices in grouped.items()
            for position in range(len(case_slices))
        ]

    def __len__(self) -> int:
        """Return the number of (case, center slice) samples."""
        return len(self.samples)

    def _get_triplet_indices(self, num_slices: int, center_idx: int) -> Tuple[int, int, int]:
        """Return (prev, center, next) positions, clamped to the case boundaries."""
        before = center_idx - 1 if center_idx > 0 else 0
        after = center_idx + 1 if center_idx + 1 < num_slices else num_slices - 1
        return (before, center_idx, after)

    def _load_slice(self, path: str) -> torch.Tensor:
        """Load an image as grayscale, resize bilinearly and convert to a tensor."""
        height, width = self.img_size[0], self.img_size[1]
        grayscale = Image.open(path).convert('L')
        # PIL's resize expects (width, height).
        resized = grayscale.resize((width, height), resample=Image.BILINEAR)
        return TF.to_tensor(resized)

    def _load_mask(self, path: str) -> torch.Tensor:
        """Load a mask as grayscale, resize with nearest-neighbour and binarize at 0.5."""
        height, width = self.img_size[0], self.img_size[1]
        grayscale = Image.open(path).convert('L')
        resized = grayscale.resize((width, height), resample=Image.NEAREST)
        as_tensor = TF.to_tensor(resized)
        binary = as_tensor > 0.5
        return binary.float()

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the stacked slice triplet and the center slice's binary mask."""
        case_id, center_idx = self.samples[idx]
        case_slices = self.case_to_slices[case_id]
        triplet = self._get_triplet_indices(len(case_slices), center_idx)
        stack = torch.stack(
            [self._load_slice(case_slices[pos].img_path) for pos in triplet],
            dim=0,
        )
        center_mask = self._load_mask(case_slices[triplet[1]].mask_path)
        return stack, center_mask

def build_dataloaders(dataset_root: str, img_size: Tuple[int, int] = (128, 128), batch_size: int = 4, train_ratio: float = 0.8, num_workers: int = 4, seed: Optional[int] = None) -> Tuple[DataLoader, DataLoader]:
    """Build train and validation loaders from ``<root>/train_images`` and ``<root>/train_masks``.

    Args:
        dataset_root: Folder containing ``train_images/train_images`` and
            ``train_masks/train_masks``.
        img_size: Target size as (height, width) passed to the dataset.
        batch_size: Batch size for both loaders.
        train_ratio: Fraction of samples assigned to the training split.
        num_workers: Worker processes per loader.
        seed: Optional seed for the split. When given, the split uses its own
            generator and no longer depends on how much of torch's global RNG
            was consumed beforehand. ``None`` keeps the global-RNG split.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If no image slices are found.
    """
    img_root = os.path.join(dataset_root, "train_images", "train_images")
    mask_root = os.path.join(dataset_root, "train_masks", "train_masks")
    full_dataset = LITSSliceDataset(img_root=img_root, mask_root=mask_root, img_size=img_size)
    n_total = len(full_dataset)
    if n_total == 0:
        # Fail here with a clear message instead of later with a division by zero.
        raise ValueError(f"No image slices found under: {img_root}")
    n_train = int(train_ratio * n_total)
    n_val = n_total - n_train
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    # NOTE(review): the split is per slice, so neighbouring slices of one case
    # can end up on different sides of the train/val boundary.
    train_dataset, val_dataset = random_split(full_dataset, [n_train, n_val], **split_kwargs)
    # Pinned host memory only helps when batches are copied to a CUDA device.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader

def demo_dataset_structure(dataset_root: str) -> None:
    """Print a short report about the dataset folders under ``dataset_root``.

    Reports the expected image/mask folders, then (if both exist) the total
    number of slices, the number of cases, and the first slice indices of the
    first two cases in sorted case-id order.

    Args:
        dataset_root: Folder expected to contain ``train_images/train_images``
            and ``train_masks/train_masks``.
    """
    img_root = os.path.join(dataset_root, "train_images", "train_images")
    mask_root = os.path.join(dataset_root, "train_masks", "train_masks")
    banner = "=" * 60
    header_lines = (
        banner,
        "Dataset Structure Validation",
        banner,
        f"Dataset root: {dataset_root}",
        f"Image folder: {img_root}",
        f"Mask folder:  {mask_root}",
    )
    for line in header_lines:
        print(line)
    # Both folders are required before any metadata can be read.
    if not (os.path.isdir(img_root) and os.path.isdir(mask_root)):
        print("Folders missing; cannot extract metadata.")
        return
    records = build_slice_metadata(img_root, mask_root)
    print(f"Total slices: {len(records)}")
    indices_by_case: dict = {}
    for record in records:
        if record.case_id not in indices_by_case:
            indices_by_case[record.case_id] = []
        indices_by_case[record.case_id].append(record.slice_idx)
    print(f"Cases: {len(indices_by_case)}")
    # Show only the first two cases to keep the report short.
    for case_id in sorted(indices_by_case)[:2]:
        ordered = sorted(indices_by_case[case_id])
        print(f"  Case {case_id}: {len(ordered)} slices; first 6: {ordered[:6]}")
    print("Validation complete.")

In [None]:
# HANS-Net model (standalone)
class WaveletDecomposition(nn.Module):
    """Single-level 2x2 decomposition with four fixed +/-0.5 filters (LL, LH, HL, HH).

    Each input channel is filtered independently with stride 2, so the output
    has four times the channels and half the spatial size.
    """

    def __init__(self):
        super().__init__()
        # Filter bank in the order LL, LH, HL, HH; every tap is +/-0.5.
        bank = torch.tensor(
            [
                [[1, 1], [1, 1]],
                [[1, 1], [-1, -1]],
                [[1, -1], [1, -1]],
                [[1, -1], [-1, 1]],
            ],
            dtype=torch.float32,
        ) / 2.0
        # conv2d weight layout: (out_channels=4, in_channels=1, 2, 2).
        self.register_buffer('filters', bank.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, C, H, W) to (B, 4*C, H_out, W_out) sub-band coefficients."""
        batch, channels, height, width = x.shape
        # Fold channels into the batch so every channel sees the same four filters.
        planes = x.view(batch * channels, 1, height, width)
        coeffs = F.conv2d(planes, self.filters, stride=2, padding=0)
        out_h, out_w = coeffs.shape[2], coeffs.shape[3]
        return coeffs.view(batch, channels * 4, out_h, out_w)

class WaveletReconstruction(nn.Module):
    """Inverse of the 2x2 sub-band decomposition done by ``WaveletDecomposition``.

    Takes groups of four coefficient channels (LL, LH, HL, HH) and rebuilds
    one image channel at twice the spatial size. Because the four filters are
    orthonormal and applied on non-overlapping 2x2 patches, the transposed
    convolution with the same filters reconstructs the original values.
    """

    def __init__(self):
        super().__init__()
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) / 2.0
        lh = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) / 2.0
        hl = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) / 2.0
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) / 2.0
        # conv_transpose2d expects weights as (in_channels, out_channels, kH, kW):
        # four coefficient channels in, one image channel out -> (4, 1, 2, 2).
        filters = torch.stack([ll, lh, hl, hh], dim=0).unsqueeze(1)
        self.register_buffer('filters', filters)

    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        """Map (B, 4*C, H, W) coefficients to a (B, C, 2*H, 2*W) tensor.

        Args:
            coeffs: Coefficients laid out as four consecutive sub-band channels
                per image channel.

        Returns:
            The reconstructed tensor with a quarter of the channels and twice
            the height and width.
        """
        B, C4, H, W = coeffs.shape
        C = C4 // 4
        # reshape (not view) so non-contiguous inputs are accepted as well.
        coeffs = coeffs.reshape(B * C, 4, H, W)
        x = F.conv_transpose2d(coeffs, self.filters, stride=2, padding=0)
        return x.reshape(B, C, H * 2, W * 2)

class SynapticPlasticity(nn.Module):
    """Per-channel rescaling driven by each channel's spatial mean activation.

    The scale applied to channel ``c`` of sample ``b`` is
    ``gain[c] * (1 + sigmoid(plasticity_rate * (mean[b, c] - threshold[c])))``.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Learnable per-channel gain and threshold, plus one shared rate.
        self.gain = nn.Parameter(torch.ones(channels))
        self.threshold = nn.Parameter(torch.zeros(channels))
        self.plasticity_rate = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale a (B, C, H, W) tensor channel-wise; the shape is unchanged."""
        batch, chans, _height, _width = x.shape
        spatial_mean = x.mean(dim=(2, 3))
        modulation = torch.sigmoid(
            self.plasticity_rate * (spatial_mean - self.threshold.view(1, -1))
        )
        boost = 1 + modulation.view(batch, chans, 1, 1)
        scale = self.gain.view(1, -1, 1, 1) * boost
        return x * scale

class PlasticConvBlock(nn.Module):
    """Two conv + batch-norm layers with optional plasticity scaling and residual.

    Layout: conv -> BN -> GELU -> conv -> BN -> [plasticity] -> [+ residual]
    -> dropout -> GELU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of both convolutions (padding is ``k // 2``).
        stride: Stride of the first convolution.
        use_plasticity: Apply ``SynapticPlasticity`` after the second batch norm.
        use_residual: Add a skip connection from the block input. When the
            input and output shapes differ (channel count or stride), the
            input is first passed through a 1x1 strided projection. Earlier
            the residual was silently dropped in that case, which left the
            projection branch unreachable.
        dropout_p: Channel dropout probability; 0 disables dropout.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, use_plasticity: bool = True, use_residual: bool = False, dropout_p: float = 0.0):
        super().__init__()
        self.use_plasticity = use_plasticity
        self.use_residual = use_residual
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if use_plasticity:
            self.plasticity = SynapticPlasticity(out_channels)
        self.act = nn.GELU()
        self.dropout = nn.Dropout2d(dropout_p) if dropout_p > 0.0 else nn.Identity()
        # A projection is only needed when the skip path cannot be added as-is.
        # NOTE: spatial sizes of both paths match for odd kernel sizes.
        needs_projection = in_channels != out_channels or stride != 1
        if self.use_residual and needs_projection:
            self.residual_proj = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
        else:
            self.residual_proj = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, in_channels, H, W) tensor."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_plasticity:
            out = self.plasticity(out)
        if self.use_residual:
            if self.residual_proj is not None:
                identity = self.residual_proj(identity)
            out = out + identity
        out = self.dropout(out)
        out = self.act(out)
        return out

def exp_map_zero(v: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Origin exponential map: ``tanh(sqrt(c)*|v|) * v / (sqrt(c)*|v|)``.

    Norms are taken over the last dimension. The norm is floored at ``eps``,
    the tanh argument is capped at 15, and ``eps`` is added to the denominator
    for numerical safety.

    Args:
        v: Vectors, with the vector components along the last dimension.
        c: Positive curvature parameter (tensor).
        eps: Small constant used for clamping and in the denominator.

    Returns:
        A tensor of the same shape as ``v``.
    """
    root_c = torch.sqrt(c)
    norm = v.norm(dim=-1, keepdim=True).clamp(min=eps)
    scaled_norm = root_c * norm
    squashed = torch.tanh(scaled_norm.clamp(max=15.0))
    return squashed * v / (scaled_norm + eps)

def log_map_zero(y: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Origin logarithmic map: ``arctanh(sqrt(c)*|y|) * y / (sqrt(c)*|y|)``.

    Norms are taken over the last dimension. The arctanh argument is clamped
    to ``[eps, 1 - eps]`` so it stays finite, and ``eps`` is added to the
    denominator.

    Args:
        y: Points, with the components along the last dimension.
        c: Positive curvature parameter (tensor).
        eps: Small constant used for clamping and in the denominator.

    Returns:
        A tensor of the same shape as ``y``.
    """
    root_c = torch.sqrt(c)
    norm = y.norm(dim=-1, keepdim=True).clamp(min=eps)
    scaled_norm = root_c * norm
    bounded = scaled_norm.clamp(min=eps, max=1.0 - eps)
    return torch.arctanh(bounded) * y / (scaled_norm + eps)

def project_to_ball(x: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Shrink vectors whose norm exceeds ``(1 - eps) / sqrt(c)`` back to that radius.

    Norms are taken over the last dimension; vectors already inside the
    radius are returned unchanged (scale factor capped at 1).

    Args:
        x: Vectors, with the components along the last dimension.
        c: Positive curvature parameter (tensor).
        eps: Margin kept from the ball boundary; also floors the norm.

    Returns:
        A tensor of the same shape as ``x``.
    """
    radius = (1.0 - eps) / torch.sqrt(c)
    norm = x.norm(dim=-1, keepdim=True).clamp(min=eps)
    shrink = torch.clamp(radius / norm, max=1.0)
    return x * shrink

def mobius_add(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Combine ``x`` and ``y`` with the Mobius-addition formula, then re-project.

    With inner products over the last dimension the result is
    ``((1 + 2c<x,y> + c|y|^2) x + (1 - c|x|^2) y) / (1 + 2c<x,y> + c^2|x|^2|y|^2)``,
    with the denominator floored at ``eps`` and the quotient passed through
    ``project_to_ball``.

    Args:
        x: First operand, components along the last dimension.
        y: Second operand, same shape as ``x``.
        c: Positive curvature parameter (tensor).
        eps: Floor for the denominator; forwarded to ``project_to_ball``.

    Returns:
        A tensor of the same shape as the operands.
    """
    x_norm_sq = (x * x).sum(dim=-1, keepdim=True)
    y_norm_sq = (y * y).sum(dim=-1, keepdim=True)
    inner = (x * y).sum(dim=-1, keepdim=True)
    # 1 + 2c<x,y> appears in both the numerator and the denominator.
    shared = 1 + 2 * c * inner
    numerator = (shared + c * y_norm_sq) * x + (1 - c * x_norm_sq) * y
    denominator = (shared + c * c * x_norm_sq * y_norm_sq).clamp(min=eps)
    return project_to_ball(numerator / denominator, c, eps)

class HyperbolicConvBlock(nn.Module):
    """Convolution applied after a round trip through the origin exp/log maps.

    The input is group-normalised, scaled, mapped with ``exp_map_zero``,
    re-projected, mapped back with ``log_map_zero`` and then passed through
    conv -> bias -> group norm -> GELU.

    The curvature is stored in inverse-softplus form in both the learnable
    and the fixed case, so ``softplus(self.curvature)`` equals the requested
    ``curvature`` at construction. Earlier the fixed case stored the raw
    value, so the effective curvature was ``softplus(curvature)`` instead.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (padding is ``k // 2``).
        curvature: Requested positive curvature. ``forward`` additionally
            clamps the effective value to ``[0.1, 10.0]``.
        learnable_curvature: Store the curvature as a parameter (True) or as
            a buffer (False).

    Raises:
        ValueError: If ``curvature`` is not positive.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, curvature: float = 1.0, learnable_curvature: bool = True):
        super().__init__()
        if curvature <= 0:
            raise ValueError(f"curvature must be positive, got {curvature}")
        # Inverse softplus: log(exp(c) - 1); expm1 keeps precision for small c.
        raw_curvature = math.log(math.expm1(curvature))
        if learnable_curvature:
            self.curvature = nn.Parameter(torch.tensor(raw_curvature))
        else:
            self.register_buffer('curvature', torch.tensor(raw_curvature))
        self.input_norm = nn.GroupNorm(min(8, in_channels), in_channels)
        padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_channels))
        self.output_norm = nn.GroupNorm(min(8, out_channels), out_channels)
        self.act = nn.GELU()
        self.scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, in_channels, H, W) tensor."""
        c = F.softplus(self.curvature).clamp(min=0.1, max=10.0)
        x = self.input_norm(x)
        x_scaled = x * torch.abs(self.scale)
        # The maps take norms over the last dimension, so move channels last.
        x_bhwc = x_scaled.permute(0, 2, 3, 1).contiguous()
        x_hyp = exp_map_zero(x_bhwc, c)
        x_hyp = project_to_ball(x_hyp, c)
        x_tangent = log_map_zero(x_hyp, c)
        x_tangent = x_tangent.permute(0, 3, 1, 2).contiguous()
        out = self.conv(x_tangent)
        out = out + self.bias.view(1, -1, 1, 1)
        out = self.output_norm(out)
        out = self.act(out)
        return out

class TemporalAttention(nn.Module):
    """Multi-head attention across the slice axis, queried by the center slice.

    For every spatial location independently, the center slice's query
    attends over the keys/values of all T slices at that same location.
    The result is projected, added to the (position-encoded) center slice and
    layer-normalised over channels.
    """
    def __init__(self, embed_dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        assert embed_dim % num_heads == 0
        # 1x1 convolutions act as per-pixel linear projections.
        self.q_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        self.k_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        self.v_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        self.out_proj = nn.Conv2d(embed_dim, embed_dim, 1, bias=False)
        # Learned embedding per slice position; sized for at most 3 slices.
        self.temporal_pos = nn.Parameter(torch.randn(1, 3, embed_dim, 1, 1) * 0.02)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse a (B, T, C, H, W) slice stack into one (B, C, H, W) feature map."""
        B, T, C, H, W = x.shape
        x = x + self.temporal_pos[:, :T]
        center_idx = T // 2
        # Query comes from the center slice only (after adding the position term).
        q_input = x[:, center_idx]
        q = self.q_proj(q_input)
        # Keys/values come from all slices; fold T into the batch for the 1x1 convs.
        x_flat = x.view(B * T, C, H, W)
        k = self.k_proj(x_flat)
        v = self.v_proj(x_flat)
        # Split channels into heads and flatten space: S = H * W.
        q = q.view(B, self.num_heads, self.head_dim, H * W)
        k = k.view(B, T, self.num_heads, self.head_dim, H * W)
        v = v.view(B, T, self.num_heads, self.head_dim, H * W)
        q = q.permute(0, 1, 3, 2)  # (B, heads, S, head_dim)
        k = k.permute(0, 2, 3, 4, 1)  # (B, heads, head_dim, S, T)
        v = v.permute(0, 2, 3, 4, 1)  # (B, heads, head_dim, S, T)
        # Dot product over head_dim at each location -> scores of shape (B, heads, S, T).
        attn = torch.einsum('bnsd,bndst->bnst', q, k) * self.scale
        # Normalise over the slice axis T.
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        # Weighted sum of values over T -> (B, heads, S, head_dim).
        out = torch.einsum('bnst,bndst->bnsd', attn, v)
        # Back to (B, heads, head_dim, S) so heads and head_dim merge into C again.
        out = out.permute(0, 1, 3, 2).contiguous()
        out = out.view(B, C, H, W)
        out = self.out_proj(out)
        # Residual connection to the center slice features.
        out = out + q_input
        # LayerNorm normalises the last dimension, so move channels last and back.
        out = out.permute(0, 2, 3, 1)
        out = self.norm(out)
        out = out.permute(0, 3, 1, 2)
        return out

class PositionalEncoding(nn.Module):
    """Sin/cos encoding of 2-D coordinates at power-of-two frequencies.

    For each coordinate and each frequency ``2**k`` (k = 0..num_frequencies-1)
    the encoding contains ``sin(pi * 2**k * coord)`` and ``cos(pi * 2**k * coord)``.
    ``out_dim`` is the size of the encoded last dimension for 2-D coordinates.
    """

    def __init__(self, num_frequencies: int = 10, include_input: bool = True):
        super().__init__()
        self.num_frequencies = num_frequencies
        self.include_input = include_input
        exponents = torch.linspace(0, num_frequencies - 1, num_frequencies)
        self.register_buffer('freq_bands', 2.0 ** exponents)
        # 2 coordinates x num_frequencies x (sin, cos), plus the raw coordinates.
        encoded_width = 2 * num_frequencies * 2
        self.out_dim = encoded_width + 2 if include_input else encoded_width

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode coordinates of shape (..., 2) into shape (..., out_dim)."""
        angles = coords.unsqueeze(-1) * self.freq_bands * math.pi
        sin_cos = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
        # Flatten (coordinate, frequency, sin/cos) into one feature axis.
        flat = sin_cos.view(*coords.shape[:-1], -1)
        if not self.include_input:
            return flat
        return torch.cat([coords, flat], dim=-1)

def make_coord_grid(H: int, W: int, device: torch.device = None) -> torch.Tensor:
    """Return an (H, W, 2) grid of (x, y) coordinates spanning [-1, 1].

    Args:
        H: Number of rows.
        W: Number of columns.
        device: Device on which to create the grid (default device if None).

    Returns:
        Tensor of shape (H, W, 2); the last axis holds x (column) first,
        then y (row).
    """
    row_coords = torch.linspace(-1, 1, H, device=device)
    col_coords = torch.linspace(-1, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(row_coords, col_coords, indexing='ij')
    return torch.stack((grid_x, grid_y), dim=-1)

class INRBranch(nn.Module):
    """Per-pixel MLP over encoded coordinates concatenated with projected features.

    Each pixel's input is its positional encoding joined with a 1x1-projected
    copy of the feature map at that pixel; the MLP outputs one value per pixel.
    """

    def __init__(self, feature_dim: int, hidden_dim: int = 256, num_frequencies: int = 10, num_layers: int = 3):
        super().__init__()
        self.pos_encoder = PositionalEncoding(num_frequencies, include_input=True)
        projected_dim = hidden_dim // 2
        self.feature_proj = nn.Conv2d(feature_dim, projected_dim, 1)
        first_in = self.pos_encoder.out_dim + projected_dim
        # Hidden layers are Linear + GELU; the last layer is a bare Linear to 1.
        final_depth = num_layers - 1
        mlp_layers = []
        for depth in range(num_layers):
            fan_in = hidden_dim if depth else first_in
            if depth < final_depth:
                mlp_layers += [nn.Linear(fan_in, hidden_dim), nn.GELU()]
            else:
                mlp_layers.append(nn.Linear(fan_in, 1))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, features: torch.Tensor, coords: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Map (B, C, H, W) features to a (B, 1, H, W) output.

        Args:
            features: Feature map of shape (B, C, H, W).
            coords: Optional coordinate grid; when omitted, a grid is built
                with ``make_coord_grid(H, W)`` on the features' device.
        """
        batch, _chans, height, width = features.shape
        if coords is None:
            coords = make_coord_grid(height, width, features.device)
        # Share one encoded grid across the batch.
        encoded = self.pos_encoder(coords).unsqueeze(0).expand(batch, -1, -1, -1)
        projected = self.feature_proj(features).permute(0, 2, 3, 1)
        per_pixel = torch.cat([encoded, projected], dim=-1)
        # Channels-last MLP output back to channels-first.
        return self.mlp(per_pixel).permute(0, 3, 1, 2)

class HANSNet(nn.Module):
    """Encoder-decoder segmentation network over a 3-slice stack.

    Pipeline (as wired in ``forward``): wavelet decomposition -> three
    plastic-conv encoder stages applied to every slice -> temporal attention
    fusing the slices at the deepest encoder level -> hyperbolic bottleneck ->
    three decoder stages with skip connections from the center slice -> final
    upsampling -> sum of a 1x1 segmentation head and the INR branch.

    Attribute names and their creation order define the checkpoint keys and
    the optimizer parameter order, so they must stay stable across versions.
    """
    def __init__(self, base_channels: int = 32, num_classes: int = 1):
        super().__init__()
        # Channel widths of the four resolution levels.
        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8
        # Encoder: the wavelet turns the single input channel into 4 sub-bands.
        self.wavelet = WaveletDecomposition()
        self.enc1 = PlasticConvBlock(4, c1, use_plasticity=True)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = PlasticConvBlock(c1, c2, use_plasticity=True)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = PlasticConvBlock(c2, c3, use_plasticity=True)
        self.temporal_attn = TemporalAttention(embed_dim=c3, num_heads=4)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = HyperbolicConvBlock(c3, c4, curvature=1.0, learnable_curvature=True)
        # Decoder: each stage upsamples x2, concatenates the skip, then convolves.
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = PlasticConvBlock(c3 + c3, c3, use_plasticity=True, dropout_p=0.3)
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = PlasticConvBlock(c2 + c2, c2, use_plasticity=True, dropout_p=0.3)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = PlasticConvBlock(c1 + c1, c1, use_plasticity=True, dropout_p=0.3)
        # Undo the wavelet's x2 downsampling to get back to input resolution.
        self.final_up = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2)
        self.final_conv = PlasticConvBlock(c1, c1, use_plasticity=True, dropout_p=0.3)
        self.seg_head = nn.Conv2d(c1, num_classes, kernel_size=1)
        # NOTE(review): the INR branch always emits one channel, which is added to
        # the seg_head output; for num_classes > 1 it would broadcast across classes.
        self.inr_branch = INRBranch(feature_dim=c1, hidden_dim=128, num_frequencies=10, num_layers=3)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return segmentation logits for the center slice of each stack.

        Args:
            x: Tensor of shape (B, 3, 1, H, W): three neighbouring single-channel
                slices per sample.
                NOTE(review): the skip concatenations appear to need H and W to
                be multiples of 16 (one stride-2 wavelet plus three 2x poolings).

        Returns:
            Logits of shape (B, num_classes, H, W); apply a sigmoid for probabilities.
        """
        B, T, C, H, W = x.shape
        assert T == 3, f"Expected T=3 slices, got {T}"
        assert C == 1, f"Expected C=1 channel, got {C}"
        # Run the encoder on all slices at once by folding T into the batch.
        x_flat = x.view(B * T, C, H, W)
        e1 = self.wavelet(x_flat)
        e1 = self.enc1(e1)
        e2 = self.pool1(e1)
        e2 = self.enc2(e2)
        e3 = self.pool2(e2)
        e3 = self.enc3(e3)
        # Unfold T again so attention can fuse the slices into one feature map.
        _, c3_ch, h3, w3 = e3.shape
        e3_temporal = e3.view(B, T, c3_ch, h3, w3)
        f_center = self.temporal_attn(e3_temporal)
        bottleneck = self.pool3(f_center)
        bottleneck = self.bottleneck(bottleneck)
        # Shallower skips use the center slice's features only.
        center_idx = T // 2
        _, c1_ch, h1, w1 = e1.shape
        e1_center = e1.view(B, T, c1_ch, h1, w1)[:, center_idx]
        _, c2_ch, h2, w2 = e2.shape
        e2_center = e2.view(B, T, c2_ch, h2, w2)[:, center_idx]
        # The deepest skip is the attention-fused map rather than the raw center slice.
        e3_center = f_center
        d3 = self.up3(bottleneck)
        d3 = torch.cat([d3, e3_center], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2_center], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1_center], dim=1)
        d1 = self.dec1(d1)
        dec_out = self.final_up(d1)
        dec_out = self.final_conv(dec_out)
        # Final logits are the sum of the conv head and the coordinate-based branch.
        coarse_logits = self.seg_head(dec_out)
        refine_logits = self.inr_branch(dec_out)
        return coarse_logits + refine_logits

In [None]:
# Training utilities
def save_checkpoint(model, optimizer, epoch: int, path: str) -> None:
    """Write model state, optimizer state and epoch number to ``path``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"model_state"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``"optim_state"``.
        epoch: Epoch index stored under ``"epoch"``.
        path: Destination file. Missing parent folders are created; a bare
            file name (no folder part) is written to the current directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so only create folders when the path has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {"model_state": model.state_dict(), "optim_state": optimizer.state_dict(), "epoch": epoch}
    torch.save(checkpoint, path)

def load_checkpoint(path: str, model, optimizer=None, device: str = "cuda") -> int:
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        path: Checkpoint file to read.
        model: Module that receives the stored ``"model_state"``.
        optimizer: Optional optimizer; restored only if the checkpoint has
            an ``"optim_state"`` entry.
        device: ``map_location`` used when loading the file.

    Returns:
        The stored epoch index, or 0 if the checkpoint has none.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state"])
    wants_optimizer = optimizer is not None
    if wants_optimizer and "optim_state" in state:
        optimizer.load_state_dict(state["optim_state"])
    return state.get("epoch", 0)

def dice_coeff(pred_probs: torch.Tensor, target_mask: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return the batch-mean smoothed Dice coefficient.

    Each sample is flattened and scored as
    ``(2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)``; no thresholding is
    applied to ``pred_probs``. With an all-zero target and an all-zero
    prediction the score is 1.

    Args:
        pred_probs: Predictions with the batch on dimension 0.
        target_mask: Targets with the same per-sample element count.
        eps: Smoothing constant added to numerator and denominator.

    Returns:
        Scalar tensor with the mean Dice over the batch.
    """
    flat_pred = pred_probs.view(pred_probs.size(0), -1)
    flat_target = target_mask.view(target_mask.size(0), -1)
    overlap = (flat_pred * flat_target).sum(dim=1)
    mass = flat_pred.sum(dim=1) + flat_target.sum(dim=1)
    per_sample = (2 * overlap + eps) / (mass + eps)
    return per_sample.mean()

def iou_score(pred_probs: torch.Tensor, target_mask: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return the batch-mean smoothed intersection-over-union.

    Each sample is flattened and scored as
    ``(sum(p * t) + eps) / (sum(p) + sum(t) - sum(p * t) + eps)``; no
    thresholding is applied to ``pred_probs``.

    Args:
        pred_probs: Predictions with the batch on dimension 0.
        target_mask: Targets with the same per-sample element count.
        eps: Smoothing constant added to numerator and denominator.

    Returns:
        Scalar tensor with the mean IoU over the batch.
    """
    flat_pred = pred_probs.view(pred_probs.size(0), -1)
    flat_target = target_mask.view(target_mask.size(0), -1)
    overlap = (flat_pred * flat_target).sum(dim=1)
    combined = flat_pred.sum(dim=1) + flat_target.sum(dim=1) - overlap
    per_sample = (overlap + eps) / (combined + eps)
    return per_sample.mean()

def dice_loss(pred_logits: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    """Return ``1 - dice_coeff`` computed on the sigmoid of the logits.

    Args:
        pred_logits: Raw model outputs, batch on dimension 0.
        target_mask: Target masks matching the logits per sample.

    Returns:
        Scalar loss tensor.
    """
    soft_dice = dice_coeff(torch.sigmoid(pred_logits), target_mask)
    return 1.0 - soft_dice

def combined_loss(pred_logits: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    """Return binary cross-entropy (on logits) plus the Dice loss, unweighted.

    Args:
        pred_logits: Raw model outputs.
        target_mask: Target masks with the same shape as the logits.

    Returns:
        Scalar loss tensor.
    """
    cross_entropy = F.binary_cross_entropy_with_logits(pred_logits, target_mask)
    return cross_entropy + dice_loss(pred_logits, target_mask)

def train_one_epoch(model, loader, optimizer, device):
    """Run one training pass over ``loader`` and return per-batch averages.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(imgs_3, mask_center)`` batches with a length.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_dice, mean_iou)`` averaged over batches. Dice and
        IoU are computed on sigmoid probabilities without thresholding. If the
        loader is empty, all three values are NaN.
    """
    model.train()
    n = len(loader)
    if n == 0:
        # No batches: report "undefined" instead of dividing by zero.
        nan = float("nan")
        return nan, nan, nan
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0
    for imgs_3, mask_center in tqdm(loader, desc="Training", leave=False):
        imgs_3 = imgs_3.to(device)
        mask_center = mask_center.to(device)
        optimizer.zero_grad()
        pred = model(imgs_3)
        loss = combined_loss(pred, mask_center)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Metrics are for reporting only; keep them out of the autograd graph.
        with torch.no_grad():
            pred_probs = torch.sigmoid(pred)
            total_dice += dice_coeff(pred_probs, mask_center).item()
            total_iou += iou_score(pred_probs, mask_center).item()
    return total_loss / n, total_dice / n, total_iou / n

def validate_one_epoch(model, loader, device):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(imgs_3, mask_center)`` batches with a length.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_dice, mean_iou)`` averaged over batches. Dice and
        IoU are computed on sigmoid probabilities without thresholding. If the
        loader is empty, all three values are NaN, so an empty validation
        split cannot be mistaken for a perfect score.
    """
    model.eval()
    n = len(loader)
    if n == 0:
        # No batches: report "undefined" instead of dividing by zero.
        nan = float("nan")
        return nan, nan, nan
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0
    with torch.no_grad():
        for imgs_3, mask_center in tqdm(loader, desc="Validation", leave=False):
            imgs_3 = imgs_3.to(device)
            mask_center = mask_center.to(device)
            pred = model(imgs_3)
            loss = combined_loss(pred, mask_center)
            total_loss += loss.item()
            pred_probs = torch.sigmoid(pred)
            total_dice += dice_coeff(pred_probs, mask_center).item()
            total_iou += iou_score(pred_probs, mask_center).item()
    return total_loss / n, total_dice / n, total_iou / n

def train_pipeline(dataset_root: str, epochs: int = 3, lr: float = 1e-4, batch_size: int = 4, device: str = "cuda", checkpoint_dir: str = "checkpoints", resume_path: Optional[str] = None, img_size: Tuple[int, int] = (128, 128)):
    """Train a fresh ``HANSNet`` and write ``last.pt`` / ``best.pt`` checkpoints.

    Args:
        dataset_root: Dataset folder passed to ``build_dataloaders``.
        epochs: Total number of epochs (the epoch loop ends at this index).
        lr: AdamW learning rate.
        batch_size: Batch size for both loaders.
        device: Device for the model and the batches.
        checkpoint_dir: Folder receiving ``last.pt`` (every epoch) and
            ``best.pt`` (whenever the validation loss improves).
        resume_path: Optional checkpoint to resume from. If the file does not
            exist, a warning is printed and training starts from scratch.
        img_size: Slice size as (height, width) forwarded to the data loaders.

    Returns:
        The trained model.
    """
    model = HANSNet().to(device)
    train_loader, val_loader = build_dataloaders(dataset_root, img_size=img_size, batch_size=batch_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    start_epoch = 0
    if resume_path is not None:
        if os.path.isfile(resume_path):
            print(f"Resuming from checkpoint: {resume_path}")
            # The checkpoint stores the last finished epoch; continue with the next.
            start_epoch = load_checkpoint(resume_path, model, optimizer, device) + 1
        else:
            print(f"[WARN] Resume checkpoint not found: {resume_path}; starting from scratch.")
    # NOTE: the best validation loss is not stored in checkpoints, so after a
    # resume the first finished epoch always overwrites best.pt.
    best_val_loss = float("inf")
    print(f"Starting training for {epochs} epochs (from epoch {start_epoch + 1})...")
    print(f"Train samples: {len(train_loader.dataset)}, Val samples: {len(val_loader.dataset)}")
    last_ckpt_path = os.path.join(checkpoint_dir, "last.pt")
    best_ckpt_path = os.path.join(checkpoint_dir, "best.pt")
    for epoch in range(start_epoch, epochs):
        train_loss, train_dice, train_iou = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_dice, val_iou = validate_one_epoch(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs} | train_loss={train_loss:.4f} | train_dice={train_dice:.4f} | train_iou={train_iou:.4f} | val_loss={val_loss:.4f} | val_dice={val_dice:.4f} | val_iou={val_iou:.4f}")
        save_checkpoint(model, optimizer, epoch, last_ckpt_path)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, best_ckpt_path)
            print(f"  [checkpoint] best  -> {best_ckpt_path} (val improved)")
    print(f"Training complete. Best val_loss: {best_val_loss:.4f}")
    return model

In [None]:
# Validate folders, smoke test, and train (optional)
# Expected dataset layout: <DATA_ROOT>/train_images/train_images and
# <DATA_ROOT>/train_masks/train_masks.
IMG_ROOT = os.path.join(DATA_ROOT, "train_images", "train_images")
MASK_ROOT = os.path.join(DATA_ROOT, "train_masks", "train_masks")
img_root_exists = os.path.isdir(IMG_ROOT)
mask_root_exists = os.path.isdir(MASK_ROOT)
# Training needs both folders.
data_available = img_root_exists and mask_root_exists

# Report each folder on its own so a single missing folder is easy to spot.
print(f"Image folder exists: {img_root_exists} -> {IMG_ROOT}")
print(f"Mask folder exists:  {mask_root_exists} -> {MASK_ROOT}")

if data_available:
    demo_dataset_structure(DATA_ROOT)
else:
    print("[WARN] Dataset folders missing; training will be skipped.")

def smoke_test(model_cls, device=DEVICE, img_size=IMG_SIZE):
    """Instantiate ``model_cls`` and push one random 3-slice sample through it.

    Args:
        model_cls: Zero-argument callable returning the model.
        device: Device for the model and the random input.
        img_size: Input size as (height, width).

    Returns:
        The instantiated model (left in its default mode).
    """
    height, width = img_size[0], img_size[1]
    model = model_cls().to(device)
    # One sample, three slices, one channel.
    x = torch.randn(1, 3, 1, height, width, device=device)
    with torch.no_grad():
        out = model(x)
    in_shape, out_shape = tuple(x.shape), tuple(out.shape)
    print(f"Input: {in_shape}, Output: {out_shape}")
    return model

# The smoke test runs regardless of whether the dataset is present.
_ = smoke_test(HANSNet)

# Train only when both dataset folders were found.
if data_available:
    _ = train_pipeline(
        dataset_root=DATA_ROOT,
        epochs=EPOCHS,
        lr=LR,
        batch_size=BATCH_SIZE,
        device=DEVICE,
        checkpoint_dir=CHECKPOINT_DIR,
        resume_path=RESUME,
    )
else:
    print("Dataset not found; skipping training. Set KAGGLE_DATASET_NAME or DATA_ROOT and rerun.")

### **Simple Visual Evaluation**

In [None]:
# Quick visual evaluation sampling directly from masks with foreground
import random
import matplotlib.pyplot as plt
import numpy as np

# Prefer best.pt, fall back to last.pt, and use None when neither exists.
_ckpt_candidates = [os.path.join(CHECKPOINT_DIR, name) for name in ("best.pt", "last.pt")]
ckpt_path = next((p for p in _ckpt_candidates if os.path.isfile(p)), None)

def _load_slice(path: str, img_size=IMG_SIZE):
    """Load an image as grayscale, resize bilinearly to ``img_size`` (H, W), return a tensor."""
    height, width = img_size[0], img_size[1]
    grayscale = Image.open(path).convert("L")
    # PIL's resize expects (width, height).
    resized = grayscale.resize((width, height), resample=Image.BILINEAR)
    return TF.to_tensor(resized)

def _load_mask(path: str, img_size=IMG_SIZE):
    """Load a mask as grayscale, resize with nearest-neighbour, binarize at 0.5."""
    height, width = img_size[0], img_size[1]
    grayscale = Image.open(path).convert("L")
    resized = grayscale.resize((width, height), resample=Image.NEAREST)
    as_tensor = TF.to_tensor(resized)
    binary = as_tensor > 0.5
    return binary.float()

# Load the selected checkpoint and plot image / ground truth / prediction rows
# for a random sample of slices whose mask file contains non-zero pixels.
if ckpt_path is None:
    print("No checkpoint found in CHECKPOINT_DIR; run training first.")
else:
    print(f"Loading checkpoint: {ckpt_path}")
    model = HANSNet().to(DEVICE)
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    # Evaluation mode: disables dropout and uses batch-norm running statistics.
    model.eval()

    img_root = os.path.join(DATA_ROOT, "train_images", "train_images")
    mask_root = os.path.join(DATA_ROOT, "train_masks", "train_masks")
    if not (os.path.isdir(img_root) and os.path.isdir(mask_root)):
        print("Image/mask folders missing; cannot sample.")
    else:
        # Group slices per case and order them by slice index, as the dataset does.
        metadata_list = build_slice_metadata(img_root, mask_root)
        case_to_slices = {}
        for meta in metadata_list:
            case_to_slices.setdefault(meta.case_id, []).append(meta)
        for cid in case_to_slices:
            case_to_slices[cid].sort(key=lambda m: m.slice_idx)

        # collect slices whose mask has foreground
        # NOTE(review): "foreground" here means any non-zero pixel in the raw mask,
        # while _load_mask binarizes at 0.5 after scaling; a mask with only low
        # non-zero values would be selected here but displayed as empty. Confirm
        # which definition is intended.
        pool = []
        for case_id, slices in case_to_slices.items():
            for idx, meta in enumerate(slices):
                arr = np.array(Image.open(meta.mask_path).convert("L"))
                if np.count_nonzero(arr) > 0:
                    pool.append((case_id, idx, meta))

        num_samples = 8
        if len(pool) == 0:
            print("No non-empty masks found in dataset; cannot plot.")
        else:
            # Sampling is not seeded here, so the chosen slices differ between runs.
            chosen = random.sample(pool, k=min(num_samples, len(pool)))
            plt.figure(figsize=(12, 4 * len(chosen)))
            for plot_idx, (case_id, center_idx, meta) in enumerate(chosen):
                # Build the (prev, center, next) stack with the same boundary
                # clamping as the training dataset.
                slices = case_to_slices[case_id]
                prev_idx = max(0, center_idx - 1)
                next_idx = min(len(slices) - 1, center_idx + 1)
                img_prev = _load_slice(slices[prev_idx].img_path)
                img_center = _load_slice(slices[center_idx].img_path)
                img_next = _load_slice(slices[next_idx].img_path)
                imgs_3 = torch.stack([img_prev, img_center, img_next], dim=0)
                mask_center = _load_mask(slices[center_idx].mask_path)

                with torch.no_grad():
                    # unsqueeze(0) adds the batch dimension expected by the model.
                    pred_logits = model(imgs_3.unsqueeze(0).to(DEVICE))
                    pred_probs = torch.sigmoid(pred_logits).detach().cpu().squeeze(0).squeeze(0)

                # Index 1 of the stack is the center slice.
                img_show = imgs_3[1].cpu().squeeze(0)
                mask_np = mask_center.cpu().squeeze(0)
                # Binarize the prediction at probability 0.5 for display.
                pred_np = (pred_probs > 0.5).float()

                # One row per sample: image, ground truth, prediction.
                ax1 = plt.subplot(len(chosen), 3, plot_idx * 3 + 1)
                ax1.imshow(img_show.numpy(), cmap="gray")
                ax1.set_title(f"Image {plot_idx + 1}")
                ax1.axis("off")

                ax2 = plt.subplot(len(chosen), 3, plot_idx * 3 + 2)
                ax2.imshow(mask_np.numpy(), cmap="gray")
                ax2.set_title("Ground Truth (non-zero)")
                ax2.axis("off")

                ax3 = plt.subplot(len(chosen), 3, plot_idx * 3 + 3)
                ax3.imshow(pred_np.numpy(), cmap="gray")
                ax3.set_title("Prediction")
                ax3.axis("off")

            plt.tight_layout()
            plt.show()

### **Masks check**

In [None]:
# Mask coverage summary: counts black vs non-black masks
from pathlib import Path
import numpy as np
from tqdm import tqdm

# Count how many mask files are entirely black versus contain any non-zero pixel.
mask_root = Path(os.path.join(DATA_ROOT, "train_masks", "train_masks"))
if mask_root.is_dir():
    mask_suffixes = (".png", ".jpg", ".jpeg")
    mask_files = [p for p in mask_root.iterdir() if p.suffix.lower() in mask_suffixes]
    total = len(mask_files)
    zero_count = 0
    for p in tqdm(mask_files, desc="Scanning masks", leave=False):
        arr = np.array(Image.open(p).convert("L"))
        # A mask is "all-zero" when no pixel is non-zero.
        if not np.any(arr):
            zero_count += 1
    # Every scanned mask is either all-zero or not.
    nonzero_count = total - zero_count
    print(f"Total masks:    {total}")
    print(f"All-zero masks: {zero_count}")
    print(f"Non-zero masks: {nonzero_count}")
    if total > 0:
        print(f"Percent zero:   {zero_count / total * 100:.2f}%")
        print(f"Percent nonzero:{nonzero_count / total * 100:.2f}%")
else:
    print(f"Mask folder not found: {mask_root}")

### **Summary of Today's Training (11-12-2025)**
- Ran the model for up to 5 epochs (learning rate 1e-4, batch size 64) and got unrealistically good training results.
- Found that none of the masks in the validation split contain any foreground.
- This is the main reason the validation split reports a Dice score of 1: with an all-zero target and a near-zero prediction, the smoothed Dice ratio evaluates to approximately 1.