In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere below the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Pipeline
# 1 - Augment: 
#         + Augmentation on the fly
#         + Robust to skin tones
#         + Hair removal with conditions (resize w.r.t. ratio -> keep only the long, thin hair strands)
#         + Horizontal Flip
#         + Random rotate 90
#         + Affine
#         + Elastic Transform
#         + RGB or HSV shift
#         + Median or Motion Blur
#         + Coarse Dropout
# 2 - Preprocessing: 
#         + Clahe on L channel
#         + Conditional (Median Blur): blur only when the image std is too large
#         + Resize
#         + Normalize
# 3 - Loss function: 
#         + Focal Tversky (60%) → Class imbalance + recall focus (catching all lesion pixels).
#         + Dynamic BCE (20%) → Stabilizes pixel-wise classification, avoids trivial "all background."
#         + Surface loss (20%) → Fine-tunes edges so the segmentation matches lesion contours.
# 4 - Model Selection:
#         + EPolar-ResUnet
#         + Unet++
#         + Unet
#         + PSPNet
#         + DeepLabv3
#         + Polar-ResUnet++
#         + SegFormer
# 5 - Postprocessing
#         + CRFs
#         + mask2rle
#         + csv submission


        

In [None]:
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install lightning
!pip install segmentation_models_pytorch
!pip install albumentations opencv-python

In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics import JaccardIndex
from torchmetrics.segmentation import DiceScore
from torchvision.transforms import v2 as T
import os
from tqdm import tqdm
from glob import glob
import timm
from PIL import Image, ImageEnhance, ImageFilter
import random
from sklearn.model_selection import train_test_split
import csv
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import random
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax
import albumentations as A
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score
from pathlib import Path
from albumentations.pytorch import ToTensorV2  
from scipy.ndimage import distance_transform_edt as dted
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor


In [None]:
def read_image(path: Path):
    """Load the image at ``path`` with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for the path.
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is not None:
        # OpenCV decodes to BGR; the rest of the notebook works in RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image not found: {path}")

def read_mask(path: Path, binarize=True):
    """Load the mask at ``path`` as a single-channel grayscale array.

    Args:
        path: location of the mask file.
        binarize: when true (default) pixels above 127 become 1 and all others 0
            (dtype uint8); when false the grayscale values are returned as read.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for the path.
    """
    gray = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Mask not found: {path}")
    return (gray > 127).astype("uint8") if binarize else gray
def plot_image_and_mask(image, mask=None, title="Sample"):
    """Show an image, optionally side by side with its mask.

    Args:
        image: array-like accepted by ``imshow`` (H x W x 3).
        mask: optional array-like (H x W); drawn with a gray colormap when given.
        title: figure-level title.

    Exactly one figure is created per call and it is actually rendered
    (``plt.show()`` is called, not merely referenced).
    """
    if mask is not None:
        fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
        ax_mask.imshow(mask, cmap="gray")
        ax_mask.axis("off")
        ax_mask.set_title("Mask")
    else:
        fig, ax_img = plt.subplots(figsize=(6, 6))

    ax_img.imshow(image)
    ax_img.axis("off")
    ax_img.set_title("Image")

    fig.suptitle(title)
    plt.show()
def show_random_sample(image_paths, mask_paths, title="Random Sample"):
    """Pick one index at random, print basic stats for that image/mask pair and plot it.

    ``image_paths`` and ``mask_paths`` are indexed with the same position, so they
    must be aligned element by element.
    """
    pick = random.randint(0, len(image_paths) - 1)
    image_path = image_paths[pick]
    img = read_image(image_path)
    mask = read_mask(mask_paths[pick])

    mask_stats = (mask.shape, mask.dtype, mask.min().item(), mask.max().item())
    img_stats = (img.shape, img.dtype, img.min().item(), img.max().item(), img.mean().item())
    print("mask:", *mask_stats)
    print("img:", *img_stats)

    plot_image_and_mask(img, mask, title=f"{title} - {image_path.stem}")


In [None]:
# Dataset roots inside the Kaggle input directory.
test_dir = Path("/kaggle/input/warm-up-program-ai-vietnam-skin-segmentation/Test/Test")
train_dir = Path("/kaggle/input/warm-up-program-ai-vietnam-skin-segmentation/Train/Train")

# Both lists are sorted so that images and masks can be paired by position.
# NOTE(review): the pairing is purely positional -- it assumes every image has exactly
# one mask and that both folders sort in the same order; verify against the dataset.
train_image_path_list = sorted(list(train_dir.glob("Image/*.jpg")))
train_mask_path_list = sorted(list(train_dir.glob("Mask/*.png")))
test_image_path_list = sorted(list(test_dir.glob("Image/*")))

In [None]:
# Sanity check: (train images, train masks, test images) found on disk.
len(train_image_path_list), len(train_mask_path_list), len(test_image_path_list)

In [None]:
# Visual spot-check of one randomly chosen training image/mask pair.
show_random_sample(train_image_path_list, train_mask_path_list)

In [None]:
class SkinSegmentationDataset(Dataset):
    """Image (and optional mask) dataset backed by lists of file paths.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths aligned with ``image_paths``
            position by position; required unless ``test`` is true.
        test: when true, items are images only and ``mask_paths`` is not read.
        transform: optional callable invoked as ``transform(image=...)`` in test
            mode or ``transform(image=..., mask=...)`` otherwise; it must return
            a mapping with an ``"image"`` (and ``"mask"``) entry.

    Raises:
        ValueError: if ``test`` is false and ``mask_paths`` is missing or has a
            different length than ``image_paths``. Checking this up front avoids
            a late failure (or a silently wrong pairing) in the middle of training.
    """

    def __init__(self, image_paths, mask_paths=None, test=False, transform=None):
        if not test:
            if mask_paths is None:
                raise ValueError("mask_paths is required when test=False")
            if len(mask_paths) != len(image_paths):
                raise ValueError(
                    f"image_paths and mask_paths differ in length: "
                    f"{len(image_paths)} vs {len(mask_paths)}"
                )
        self.test = test
        self.transform = transform
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index: int):
        """Return ``image`` in test mode, otherwise the pair ``(image, mask)``."""
        image = read_image(self.image_paths[index])

        if self.test:
            if self.transform:
                image = self.transform(image=image)["image"]
            return image

        mask = read_mask(self.mask_paths[index])
        if self.transform:
            # Image and mask go through the transform together so it can treat them as a pair.
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]
        return image, mask

In [None]:
def hair_removal(
    img_path: str,
    mask_path: str,
    process_original: bool = False,
    min_side_for_fast: int = 512,       # when process_original=False, larger images are downscaled (see docstring)
    blackhat_kernel: int = 17,
    threshold_method: str = "otsu", # "adaptive" or "otsu" or "fixed"
    fixed_thresh: int = 10,
    min_hair_area_base: int = 100,      # base min area for ~512 input; will scale by image area
    aspect_ratio_thresh: float = 3.0,   # gentle default
    compactness_thresh: float = 0.65,   # gentle default (area / (w*h))
    dilate_iter: int = 1,
    dilate_kernel: int = 3,
    inpaint_radius: int = 2,
    inpaint_method = cv2.INPAINT_NS
):
    """
    Detect long/elongated dark structures (hairs) in an image and inpaint them.

    Pipeline: grayscale -> black-hat morphology -> threshold -> connected-component
    filtering (area, aspect ratio, compactness) -> open + dilate -> ``cv2.inpaint``.

    Working resolution: with ``process_original=True``, or when the shorter image
    side is at most ``min_side_for_fast``, the image is processed as read. Otherwise
    image and segmentation mask are both resized by the factor
    ``1 / int(min_side / min_side_for_fast)`` before processing.

    Args:
        img_path: path of the RGB image to clean.
        mask_path: path of the segmentation mask belonging to the image; it is read
            unchanged (converted to gray if it has 3 dimensions) and only ever resized.
        threshold_method: ``"otsu"`` (threshold at ``max(otsu // 3, fixed_thresh)``),
            ``"fixed"`` (threshold at ``fixed_thresh``); any other value selects the
            adaptive Gaussian threshold.
        min_hair_area_base: minimum component area for a 512x512 image; scaled by the
            working image area, with a floor of 8 pixels.
        aspect_ratio_thresh, compactness_thresh: a component is kept as hair when its
            bounding box is at least this elongated AND it fills at most this fraction
            of the box, or when its area is at least 4x the minimum area.

    Returns:
        ``(inpainted, mask)``: the inpainted RGB image and the segmentation mask, both
        at the working resolution described above.

    Raises:
        FileNotFoundError: if the image cannot be read.
        RuntimeError: if the mask cannot be read.
    """
    def read_rgb(path):
        im = cv2.imread(path, cv2.IMREAD_COLOR)
        if im is None:
            raise FileNotFoundError(path)
        return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Named differently from the module-level read_mask to avoid shadowing it.
    def read_mask_raw(path):
        m = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if m is None:
            raise RuntimeError(f"Cannot read mask: {path}")
        if m.ndim == 3:
            m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
        return m

    img = read_rgb(img_path)
    mask = read_mask_raw(mask_path)
    H, W = img.shape[:2]

    # choose working image (either original or resized while keeping aspect ratio)
    min_side = min(H, W)
    if process_original or min_side <= min_side_for_fast:
        work = img.copy()
    else:
        scale = float(1) / int(min_side / float(min_side_for_fast))
        new_w = int(round(W * scale))
        new_h = int(round(H * scale))
        resize_transform = A.Resize(new_h, new_w, interpolation=cv2.INTER_AREA)
        resize = resize_transform(image=img, mask=mask)
        work, mask = resize['image'], resize['mask']

    wh, ww = work.shape[:2]

    # grayscale + blackhat (highlights thin dark structures on a brighter background)
    gray = cv2.cvtColor(work, cv2.COLOR_RGB2GRAY)
    k = cv2.getStructuringElement(cv2.MORPH_RECT, (blackhat_kernel, blackhat_kernel))
    bh = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, k)
    bh = cv2.normalize(bh, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # threshold
    if threshold_method == "otsu":
        thr_val, _ = cv2.threshold(bh, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # use a third of the Otsu level (but never below fixed_thresh) as the cut-off
        lower = max(thr_val // 3, fixed_thresh)
        _, th = cv2.threshold(bh, lower, 255, cv2.THRESH_BINARY)
    elif threshold_method == "fixed":
        _, th = cv2.threshold(bh, fixed_thresh, 255, cv2.THRESH_BINARY)
    else:  # adaptive
        # blockSize must be odd and >=3
        bs = 15 if min(wh, ww) > 100 else 11
        th = cv2.adaptiveThreshold(bh, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY, bs, -4)

    # connected components filtering (area + elongation + compactness)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(th, connectivity=8)
    hair_mask = np.zeros_like(th)

    # scale min hair area by image area relative to 512x512
    scale_area = (wh * ww) / float(512 * 512)
    min_hair_area = max(8, int(round(min_hair_area_base * scale_area)))

    for i in range(1, num_labels):  # label 0 is the background component
        x, y, w, h, area = stats[i]
        if area < min_hair_area:
            continue

        # elongation measure
        aspect = max(w, h) / (min(w, h) + 1e-8)
        compactness = area / float(w * h + 1e-8)

        # keep if reasonably elongated OR very large (covering long hair)
        if (aspect >= aspect_ratio_thresh and compactness <= compactness_thresh) or (area >= 4 * min_hair_area):
            hair_mask[labels == i] = 255

    # small morphological clean + dilate to ensure hair fully covered
    if np.any(hair_mask):
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8))
        hair_mask = cv2.dilate(hair_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_kernel, dilate_kernel)), iterations=dilate_iter)

    # inpaint the working image; cv2.inpaint expects a single-channel uint8 mask
    mask_for_inpaint = (hair_mask > 0).astype(np.uint8) * 255
    inpainted = cv2.inpaint(work, mask_for_inpaint, inpaint_radius, inpaint_method)
    return inpainted, mask


In [None]:
# Per-channel mean / std handed to A.Normalize in the transform pipelines of this cell.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# --- Custom preprocessing ---
def conditional_median_blur(img, std_thresh=30, ksize=3):
    """Median-blur ``img`` only when it looks noisy.

    The image counts as noisy when its overall standard deviation is greater than
    ``std_thresh``; in that case a ``ksize`` median blur is returned. Otherwise the
    very same array object is handed back untouched.
    """
    is_noisy = img.std() > std_thresh
    return cv2.medianBlur(img, ksize) if is_noisy else img

# Albumentations allows custom transforms using A.Lambda
def preprocess_fn(img, **kwargs):
    """Image-only callback for ``A.Lambda``: apply the conditional median blur.

    Extra keyword arguments are accepted and ignored. No hair removal happens here.
    """
    return conditional_median_blur(img)

# --- Train Augmentations ---
train_transforms = A.Compose([
    A.Lambda(image=preprocess_fn, p=1.0),  # conditional median blur (image only; no hair removal here)
    A.Resize(512, 512, interpolation=cv2.INTER_LINEAR),
    # --- geometric augmentation ---
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.3),
    A.Affine(
        scale=(0.9, 1.1),
        translate_percent=(0.1, 0.1),
        rotate=(-20, 20),
        shear=(-10, 10),
        p=0.7
    ),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    # --- colour augmentation (one of the two, half of the time) ---
    A.OneOf([
        A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    ], p=0.5),
    # --- blur augmentation ---
    A.OneOf([
        A.MotionBlur(blur_limit=5, p=0.2),
        A.MedianBlur(blur_limit=3, p=0.2),
    ], p=0.2),
    # CLAHE is preprocessing, not augmentation: val_transforms applies it to every
    # image, so training must do the same or the model is trained and evaluated on
    # differently distributed inputs (it was p=0.3 here before).
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])  # image= and mask= are passed together so the mask is transformed consistently

# --- Validation pipeline ---
# No random augmentation: every step below runs with probability 1 (A.Lambda and
# A.CLAHE are given p=1.0 explicitly), so the same input always yields the same output.
val_transforms = A.Compose([
    A.Lambda(image=preprocess_fn, p=1.0),  # conditional median blur on the image
    A.Resize(512, 512, interpolation=cv2.INTER_LINEAR),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),  # deterministic
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])

# --- Test uses same as validation ---
# This is the very same Compose object, not a copy.
test_transforms = val_transforms

In [None]:
# Hold out 10% of the (image, mask) pairs for validation; the fixed seed keeps the
# split identical across runs.
train_img, val_img, train_msk, val_msk = train_test_split(
    train_image_path_list, train_mask_path_list,
    test_size=0.1, shuffle=True, random_state=42,
)

In [None]:
# Training and validation datasets both carry masks; they differ in the file lists
# and in the transform pipeline applied.
train_dataset = SkinSegmentationDataset(train_img, train_msk, test=False, transform=train_transforms)
val_dataset = SkinSegmentationDataset(val_img, val_msk, test=False, transform=val_transforms)

# The test dataset has no masks, so it yields images only.
test_dataset = SkinSegmentationDataset(test_image_path_list, mask_paths=None, test=True, transform=test_transforms)

In [None]:
# Pull the first item of each dataset as a smoke test of loading + transforms.
img, mask = train_dataset[0]
validation_img, validation_mask = val_dataset[0]
testing_img = test_dataset[0]  # test items are images only (no mask)

In [None]:
# The tensor was standardised with IMAGENET_MEAN / IMAGENET_STD, so its values are not
# in [0, 1]. Undo that (and clamp) before plotting; otherwise imshow clips the data
# and the preview colours are wrong.
img_vis = (img.permute(1, 2, 0) * torch.tensor(IMAGENET_STD) + torch.tensor(IMAGENET_MEAN)).clamp(0, 1)
plot_image_and_mask(img_vis, mask)


In [None]:
# The tensor was standardised with IMAGENET_MEAN / IMAGENET_STD, so its values are not
# in [0, 1]. Undo that (and clamp) before plotting; otherwise imshow clips the data
# and the preview colours are wrong.
validation_vis = (validation_img.permute(1, 2, 0) * torch.tensor(IMAGENET_STD) + torch.tensor(IMAGENET_MEAN)).clamp(0, 1)
plot_image_and_mask(validation_vis, validation_mask)


In [None]:
# The tensor was standardised with IMAGENET_MEAN / IMAGENET_STD, so its values are not
# in [0, 1]. Undo that (and clamp) before plotting; otherwise imshow clips the data
# and the preview colours are wrong.
testing_vis = (testing_img.permute(1, 2, 0) * torch.tensor(IMAGENET_STD) + torch.tensor(IMAGENET_MEAN)).clamp(0, 1)
plt.imshow(testing_vis)

In [None]:
# --- Core Tversky & Focal-Tversky ---
class TverskyLoss(nn.Module):
    """Tversky loss computed from raw logits.

    Per (sample, channel) the index is ``(TP + eps) / (TP + alpha*FP + beta*FN + eps)``
    with soft counts taken over the two trailing (spatial) dimensions; the loss is the
    mean of ``1 - index``. Inputs must be 4-D, e.g. ``[B, 1, H, W]``.
    """

    def __init__(self, alpha=0.3, beta=0.7, eps=1e-6):
        super().__init__()
        self.a = alpha   # weight of false positives
        self.b = beta    # weight of false negatives
        self.eps = eps   # smoothing term, keeps the ratio defined for empty masks

    def forward(self, logits, y):
        target = y.float()
        prob = torch.sigmoid(logits)
        spatial = (2, 3)
        true_pos = (prob * target).sum(spatial)
        false_pos = (prob * (1 - target)).sum(spatial)
        false_neg = ((1 - prob) * target).sum(spatial)
        index = (true_pos + self.eps) / (true_pos + self.a * false_pos + self.b * false_neg + self.eps)
        return (1.0 - index).mean()

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss computed from raw logits.

    Same Tversky index as ``(TP + eps) / (TP + alpha*FP + beta*FN + eps)`` over the two
    trailing (spatial) dimensions, but ``1 - index`` is raised to the power ``gamma``
    before averaging. Inputs must be 4-D, e.g. ``[B, 1, H, W]``.
    """

    def __init__(self, alpha=0.3, beta=0.7, gamma=1.0, eps=1e-6):
        super().__init__()
        self.a = alpha   # weight of false positives
        self.b = beta    # weight of false negatives
        self.g = gamma   # focusing exponent
        self.eps = eps   # smoothing term

    def forward(self, logits, y):
        target = y.float()
        prob = torch.sigmoid(logits)
        spatial = (2, 3)
        true_pos = (prob * target).sum(spatial)
        false_pos = (prob * (1 - target)).sum(spatial)
        false_neg = ((1 - prob) * target).sum(spatial)
        index = (true_pos + self.eps) / (true_pos + self.a * false_pos + self.b * false_neg + self.eps)
        return (1.0 - index).pow(self.g).mean()

# --- Boundary-aware (Surface) loss: boosts edge accuracy ---
class SurfaceLoss(nn.Module):
    """Distance-weighted absolute error between sigmoid(logits) and the target mask.

    For every sample a weight map is built (without gradients, on the CPU) as the sum
    of two Euclidean distance transforms: one of the foreground (``mask > 0``) and one
    of the background (``mask == 0``). The map is divided by its per-sample maximum and
    multiplied with ``|p - y|``; the loss is the mean over all elements.

    Expects ``y`` with a channel axis of size 1 (``[B, 1, H, W]``).

    NOTE(review): with this construction the weight grows with distance from the mask
    boundary, so errors far from the contour are penalised most -- confirm that this
    is the intended emphasis.
    """

    def forward(self, logits, y):
        prob = torch.sigmoid(logits)
        with torch.no_grad():
            masks_np = y.squeeze(1).detach().cpu().numpy()
            weight_maps = [
                torch.from_numpy(dted(m > 0) + dted(m == 0)).unsqueeze(0)
                for m in masks_np
            ]
            dist = torch.stack(weight_maps, 0).to(logits.device).float()
            # per-sample normalisation to at most 1
            dist = dist / (dist.amax(dim=(2, 3), keepdim=True) + 1e-6)
        return ((prob - y.float()).abs() * dist).mean()

# --- BCE with dynamic pos_weight per batch (controls FP vs FN) ---
def bce_with_logits_dynamic_pw(logits, targets, max_pw=3.0):
    """Binary cross-entropy on logits with a positive-class weight derived from the batch.

    The weight is ``(#negative pixels) / (#positive pixels)`` over the whole batch
    (a target counts as positive when it is above 0.5), clipped to ``[1.0, max_pw]``.
    """
    target = targets.float()
    is_pos = (target > 0.5).float()
    n_pos = is_pos.sum()
    n_neg = (1 - is_pos).sum()
    # the small constant keeps the ratio finite when the batch has no positives
    pos_weight = torch.clamp(n_neg / (n_pos + 1e-6), min=1.0, max=max_pw)
    return F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)

In [None]:
class ComboLossGen(nn.Module):
    """Weighted sum of three segmentation losses computed from logits.

        0.6 * FocalTversky(alpha=a, beta=b, gamma=g)
      + 0.2 * BCE-with-logits using a per-batch positive weight (capped at ``bce_max_pw``)
      + surf_w * SurfaceLoss

    Only the surface weight is configurable; the 0.6 / 0.2 factors are fixed.
    Shapes: logits ``[B,1,H,W]`` or ``[B,H,W]``; targets ``[B,1,H,W]`` or ``[B,H,W]``.
    """

    def __init__(self, a=0.3, b=0.7, g=1.0, bce_max_pw=3.0, surf_w=0.20):
        super().__init__()
        self.ft = FocalTverskyLoss(alpha=a, beta=b, gamma=g)
        self.surf = SurfaceLoss()
        self.bce_max_pw = bce_max_pw
        self.surf_w = surf_w

    def forward(self, logits, targets):
        # The component losses expect an explicit channel axis.
        targets = targets.unsqueeze(1) if targets.ndim == 3 else targets
        logits = logits.unsqueeze(1) if logits.ndim == 3 else logits

        tversky_term = self.ft(logits, targets)
        bce_term = bce_with_logits_dynamic_pw(logits, targets, max_pw=self.bce_max_pw)
        surface_term = self.surf(logits, targets)
        return 0.6*tversky_term + 0.2*bce_term + self.surf_w*surface_term

In [None]:
# NOTE(review): segmentation_models_pytorch is already installed by the install cell
# near the top of the notebook, so this repeat install looks redundant.
!pip install segmentation_models_pytorch

In [None]:
import segmentation_models_pytorch as smp

# ---------------------------
# Utils: soft-argmax and polar sampling
# ---------------------------
def soft_argmax_2d(heatmaps: torch.Tensor, eps=1e-6) -> torch.Tensor:
    """Differentiable arg-max: expected (x, y) pixel position under a spatial softmax.

    Args:
        heatmaps: ``(B,1,H,W)`` (channel 0 is used) or ``(B,H,W)`` scores.
        eps: accepted for interface compatibility; not used by the computation.

    Returns:
        ``(B, 2)`` tensor of ``(x, y)`` coordinates in pixel units.
    """
    hmap = heatmaps[:, 0] if heatmaps.dim() == 4 else heatmaps
    B, H, W = hmap.shape

    # Softmax over all H*W positions turns the scores into a probability map.
    prob = F.softmax(hmap.view(B, -1), dim=1).view(B, H, W)

    col_idx = torch.linspace(0, W - 1, W, device=heatmaps.device).view(1, 1, W)
    row_idx = torch.linspace(0, H - 1, H, device=heatmaps.device).view(1, H, 1)

    # Expectation of the column / row index under that probability map.
    exp_x = (prob * col_idx).sum(dim=(1, 2))
    exp_y = (prob * row_idx).sum(dim=(1, 2))
    return torch.stack([exp_x, exp_y], dim=1)  # (B,2) (x,y)

def polar_grid(center_xy: torch.Tensor, out_h: int, out_w: int, max_radius: float, H: int, W: int, device) -> torch.Tensor:
    """Build a ``grid_sample`` grid that reads a Cartesian image along polar coordinates.

    Output row ``i`` corresponds to a radius in ``[0, max_radius]`` and output column
    ``j`` to an angle in ``[0, 2*pi]`` (both end points included), measured around
    ``center_xy``.

    Args:
        center_xy: ``(B,2)`` pole positions in pixel coordinates ``(x, y)``.
        out_h, out_w: number of radius / angle samples.
        H, W: size of the image that will be sampled; used to normalise to ``[-1, 1]``.

    Returns:
        ``(B, out_h, out_w, 2)`` grid of normalised ``(x, y)`` sampling locations.
    """
    n = center_xy.shape[0]
    radii = torch.linspace(0.0, max_radius, out_h, device=device)
    angles = torch.linspace(0.0, 2 * math.pi, out_w, device=device)
    rad_grid, ang_grid = torch.meshgrid(radii, angles, indexing="ij")  # (out_h, out_w)

    # Offsets from the pole, shared by the batch, then shifted by each sample's pole.
    off_x = (rad_grid * torch.cos(ang_grid)).unsqueeze(0).expand(n, -1, -1)
    off_y = (rad_grid * torch.sin(ang_grid)).unsqueeze(0).expand(n, -1, -1)
    px = off_x + center_xy[:, 0].view(n, 1, 1)
    py = off_y + center_xy[:, 1].view(n, 1, 1)

    # Pixel coordinates -> [-1, 1] (align_corners=True convention).
    px_norm = (px / (W - 1)) * 2.0 - 1.0
    py_norm = (py / (H - 1)) * 2.0 - 1.0
    return torch.stack([px_norm, py_norm], dim=-1)  # (B, out_h, out_w, 2)

def polar_transform_batch(img: torch.Tensor, center_xy: torch.Tensor, out_h: int, out_w: int, max_radius: float) -> torch.Tensor:
    """Resample a batch of images into polar coordinates around per-sample poles.

    Args:
        img: ``(B,C,H,W)`` images.
        center_xy: ``(B,2)`` pole positions in pixel coordinates ``(x, y)``.
        out_h, out_w: output size (radius samples x angle samples).
        max_radius: radius, in pixels, mapped to the last output row.

    Returns:
        ``(B,C,out_h,out_w)`` bilinearly sampled polar images; locations that fall
        outside the input are filled with zeros.
    """
    _, _, height, width = img.shape
    sampling_grid = polar_grid(center_xy, out_h, out_w, max_radius, height, width, img.device)
    return F.grid_sample(img, sampling_grid, mode="bilinear", padding_mode="zeros", align_corners=True)

def inverse_polar_to_cartesian(polar_img: torch.Tensor, center_xy: torch.Tensor, H: int, W: int, max_radius: float) -> torch.Tensor:
    """Warp a polar image (rows = radius, columns = angle) back onto an ``H x W`` Cartesian grid.

    For every Cartesian pixel the radius and angle relative to ``center_xy`` are
    computed and the polar image is sampled bilinearly at that position.

    Args:
        polar_img: ``(B, C, out_h, out_w)``.
        center_xy: ``(B,2)`` pole positions in pixel coordinates ``(x, y)``.
        max_radius: radius, in pixels, that the last polar row stands for.

    Returns:
        ``(B,C,H,W)`` tensor; samples beyond the polar image are zero.
    """
    B, _, out_h, out_w = polar_img.shape
    device = polar_img.device

    rows = torch.linspace(0, H - 1, H, device=device)
    cols = torch.linspace(0, W - 1, W, device=device)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing='ij')  # (H,W)

    # Displacement of every Cartesian pixel from its sample's pole.
    dx = col_grid.unsqueeze(0).expand(B, -1, -1) - center_xy[:, 0].view(B, 1, 1)
    dy = row_grid.unsqueeze(0).expand(B, -1, -1) - center_xy[:, 1].view(B, 1, 1)

    radius = torch.sqrt(dx ** 2 + dy ** 2)
    angle = torch.atan2(dy, dx) % (2 * math.pi)  # wrapped into [0, 2*pi)

    # radius / angle -> fractional row / column in the polar image
    radius_px = (radius / max_radius) * (out_h - 1)
    angle_px = (angle / (2 * math.pi)) * (out_w - 1)

    # ... -> normalised [-1, 1] coordinates; grid_sample wants (x, y) = (column, row)
    grid_x = (angle_px / (out_w - 1)) * 2.0 - 1.0
    grid_y = (radius_px / (out_h - 1)) * 2.0 - 1.0
    sampling_grid = torch.stack([grid_x, grid_y], dim=-1)  # (B,H,W,2)

    return F.grid_sample(polar_img, sampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True)

# ---------------------------
# Stacked hourglass (small)
# ---------------------------
class ConvBNReLU(nn.Module):
    """3x3 convolution (padding 1, no bias) followed by batch norm and ReLU.

    Spatial size is preserved; channels go from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),  # bias is redundant before BN
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class SimpleHourglass(nn.Module):
    """Small 3-level encoder/decoder that predicts a single-channel heatmap.

    Two 2x max-pool steps on the way down and two stride-2 transposed convolutions on
    the way up, with the same-level encoder features concatenated at each up step.

    ``forward`` returns ``(heatmap, features)``: the 1-channel output of the 1x1 head
    and the ``base``-channel feature map it was computed from.
    """

    def __init__(self, in_ch=1, base=32):
        super().__init__()
        # encoder
        self.e1 = ConvBNReLU(in_ch, base)
        self.p1 = nn.MaxPool2d(2)
        self.e2 = ConvBNReLU(base, base*2)
        self.p2 = nn.MaxPool2d(2)
        self.e3 = ConvBNReLU(base*2, base*4)
        # decoder: each fusion conv takes (up-sampled + skip) = twice its output channels
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.u2_conv = ConvBNReLU(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.u1_conv = ConvBNReLU(base*2, base)
        self.out_head = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        # down path
        enc1 = self.e1(x)
        enc2 = self.e2(self.p1(enc1))
        bottom = self.e3(self.p2(enc2))
        # up path with skip concatenation
        dec2 = self.u2_conv(torch.cat([self.up2(bottom), enc2], dim=1))
        dec1 = self.u1_conv(torch.cat([self.up1(dec2), enc1], dim=1))
        return self.out_head(dec1), dec1

class StackedHourglass(nn.Module):
    """``n_stacks`` independent ``SimpleHourglass`` modules applied to the same input.

    The hourglasses are not chained: every one receives the original ``x``.
    ``forward`` returns the list of their heatmaps, in stack order.
    """

    def __init__(self, n_stacks=3, in_ch=1, base=24):
        super().__init__()
        self.stacks = nn.ModuleList([SimpleHourglass(in_ch, base) for _ in range(n_stacks)])

    def forward(self, x):
        # keep only the heatmap (first element) of each hourglass output
        return [hourglass(x)[0] for hourglass in self.stacks]


# ---------------------------
# Edge-attending decoder block (UpEdgeAttention)
# ---------------------------
class UpEdgeAttention(nn.Module):
    """Decoder block that fuses an encoder skip with the running decoder feature.

    Steps:
      1. the skip is up-sampled 2x by a stride-2 transposed convolution -> ``out_ch``;
      2. the decoder feature is projected to ``out_ch`` by a 1x1 convolution;
      3. the decoder feature is bilinearly resized to the up-sampled skip if their
         spatial sizes differ;
      4. ``edge = skip_up - decoder`` is concatenated with the decoder feature and
         passed through two 3x3 conv + BN + ReLU layers.

    The output therefore always has the spatial size of the up-sampled skip
    (twice the size of ``enc_feat``).

    Fix: on a size mismatch the block used to shrink the up-sampled skip down to the
    decoder size. When the decoder input is the lower-resolution tensor, that kept
    the output at the decoder's resolution, so a stack of these blocks never
    increased resolution. Resizing the decoder feature instead lets each block
    actually up-sample. Parameter shapes are unchanged.

    Args:
        enc_ch: channels of the encoder skip.
        dec_in_ch: channels of the incoming decoder feature.
        out_ch: channels of the block output.
    """

    def __init__(self, enc_ch, dec_in_ch, out_ch):
        super().__init__()
        # upsample encoder skip -> out_ch
        self.enc_up = nn.ConvTranspose2d(enc_ch, out_ch, kernel_size=2, stride=2)
        # project decoder input -> out_ch
        self.dec_proj = nn.Conv2d(dec_in_ch, out_ch, kernel_size=1)
        # fusion convs
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch * 2, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, enc_feat, dec_feat):
        enc_up = self.enc_up(enc_feat)
        dec_proj = self.dec_proj(dec_feat)

        # Bring the decoder feature to the skip's (up-sampled) resolution, never the
        # other way round, so spatial detail from the skip is preserved.
        if dec_proj.shape[-2:] != enc_up.shape[-2:]:
            dec_proj = F.interpolate(dec_proj, size=enc_up.shape[-2:], mode='bilinear', align_corners=True)

        edge = enc_up - dec_proj
        return self.conv(torch.cat([edge, dec_proj], dim=1))

# ---------------------------
# SMP-based encoder wrapper to expose multi-scale features
# ---------------------------
class SMPEncoderWrapper(nn.Module):
    """Expose the encoder of an ``smp.Unet`` as a stand-alone multi-scale feature extractor.

    A full ``smp.Unet`` is instantiated only to obtain its encoder; the rest of the
    network is not kept as an attribute.

    Attributes:
        encoder: the encoder taken from the Unet.
        out_channels: ``encoder.out_channels`` as reported by the encoder.
    """

    def __init__(self, encoder_name="resnet34", in_channels=3, pretrained=True):
        super().__init__()
        weights = "imagenet" if pretrained else None
        unet = smp.Unet(encoder_name=encoder_name, encoder_weights=weights, in_channels=in_channels, classes=1)
        self.encoder = unet.encoder
        self.out_channels = self.encoder.out_channels

    def forward(self, x):
        """Return whatever the encoder returns for ``x`` (its per-stage features)."""
        return self.encoder(x)

# ---------------------------
# Polar/Cartesian decoders with edge attention + multi-scale fusion
# ---------------------------
class DecoderWithEdgeFusion(nn.Module):
    """Chain of ``UpEdgeAttention`` blocks followed by a 1x1 convolution to one logit channel.

    Stage ``i`` fuses the encoder feature at index ``-(i + 2)`` (second-to-last first)
    with the output of the previous stage; the first stage starts from the bottleneck.

    Args:
        encoder_channels: channel count of every encoder feature, shallowest first.
        decoder_channels: output channels of each decoder stage, in execution order.
    """

    def __init__(self, encoder_channels, decoder_channels):
        super().__init__()
        self.num_stages = len(decoder_channels)
        self.up_blocks = nn.ModuleList()

        prev_ch = encoder_channels[-1]  # the bottleneck feeds the first stage
        for stage, out_ch in enumerate(decoder_channels):
            skip_ch = encoder_channels[-(stage + 2)]
            self.up_blocks.append(UpEdgeAttention(skip_ch, prev_ch, out_ch))
            prev_ch = out_ch

        self.final_conv = nn.Conv2d(decoder_channels[-1], 1, kernel_size=1)

    def forward(self, encoder_features, bottleneck_feat):
        """Return ``(logits, out_feats)``: final logits and the output of every stage."""
        out_feats = []
        x = bottleneck_feat
        # offset 2 -> encoder_features[-2], offset 3 -> encoder_features[-3], ...
        for offset, block in enumerate(self.up_blocks, start=2):
            x = block(encoder_features[-offset], x)
            out_feats.append(x)
        return self.final_conv(out_feats[-1]), out_feats

# ---------------------------
# Full EPolar-UNet (paper-close)
# ---------------------------
class EPolarUNetPaper(nn.Module):
    """Two-branch (Cartesian + polar) segmentation network with a learned pole.

    Forward pass, as implemented below:
      1. the channel-mean of the input goes through a ``StackedHourglass``; the soft
         arg-max of the last heatmap is used as the pole ``(x, y)`` of each sample;
      2. the input is encoded as is (Cartesian branch);
      3. the input is resampled to polar coordinates around the pole and encoded by a
         second, separately instantiated encoder (polar branch);
      4. both branches are decoded by their own ``DecoderWithEdgeFusion``;
      5. per-stage decoder features of the two branches are concatenated and reduced
         by 1x1 convs (``fused_feats``) -- these are returned for inspection only and
         do not feed the prediction;
      6. the polar logits are warped back to Cartesian space, concatenated with the
         Cartesian logits, mixed by a 1x1 conv and resized to the input size.

    Args:
        in_channels: channels of the input image.
            NOTE(review): the default is 5, while the Lightning module in this
            notebook constructs the model with ``in_channels=3``.
        encoder_name: encoder name forwarded to ``SMPEncoderWrapper``.
        n_stacks: number of hourglasses in the pole predictor.
        polar_h, polar_w: size of the polar image (radius samples x angle samples).
    """
    def __init__(self, in_channels=5, encoder_name="resnet34", n_stacks=3, polar_h=512, polar_w=512):
        super().__init__()
        self.in_channels = in_channels
        self.encoder_name = encoder_name
        self.polar_h = polar_h
        self.polar_w = polar_w
        # pole predictor (stacked hourglass) - expects single-chan input; if rgb convert mean
        self.hourglass = StackedHourglass(n_stacks=n_stacks, in_ch=1, base=24)

        # SMP encoders for Cartesian and Polar branches (shared architecture, separate weights)
        self.cart_encoder = SMPEncoderWrapper(encoder_name=encoder_name, in_channels=in_channels, pretrained=True)
        self.polar_encoder = SMPEncoderWrapper(encoder_name=encoder_name, in_channels=in_channels, pretrained=True)

        enc_ch = self.cart_encoder.out_channels  # list
        # choose decoder channel sizes (you might tune these to match encoder depth)
        # We'll use reversed smaller sizes for decoder
        decoder_channels = [256, 128, 64]  # adjust depending on encoder depth/out_channels
        # create decoders
        self.cart_decoder = DecoderWithEdgeFusion(enc_ch, decoder_channels)
        self.polar_decoder = DecoderWithEdgeFusion(enc_ch, decoder_channels)

        # multi-scale fusion convs: fuse decoder features at each decoder stage
        # number of decoder stages = len(decoder_channels)
        self.fusion_convs = nn.ModuleList([nn.Conv2d(c*2, c, 1) for c in decoder_channels])  # careful ordering later

        # final fusion head (after inverse warp of polar logits)
        # 2 input channels = Cartesian logits + polar logits warped back to Cartesian
        self.fusion_head = nn.Sequential(
            nn.Conv2d(2, 1, 1)
        )

    def forward(self, x):
        """Run both branches and fuse them.

        Args:
            x: ``(B, C, H, W)`` input batch.

        Returns:
            ``(out, aux)`` where ``out`` is the fused logit map resized to ``(H, W)``
            and ``aux`` is a dict with the intermediate tensors ``coords``,
            ``cart_logits``, ``polar_logits_cart``, ``heatmaps``, ``fused_feats`` and
            ``polar_img``.
        """
        # x: (B, C, H, W)
        B, C, H, W = x.shape
        device = x.device  # NOTE(review): not used below

        # 1) pole heatmaps (stacked hourglass expects single channel)
        gray = x.mean(dim=1, keepdim=True)  # (B,1,H,W)
        heatmaps = self.hourglass(gray)  # list of (B,1,H,W)
        last_hm = heatmaps[-1]
        coords = soft_argmax_2d(last_hm)  # (B,2) x,y

        # compute max_radius (use diag)
        # the full image diagonal is used as the largest radius
        max_radius = float(math.sqrt(H*H + W*W))

        # 2) Cartesian branch encoder
        cart_feats = self.cart_encoder(x)  # list of features length L
        # deepest bottleneck is last element
        cart_bottleneck = cart_feats[-1]

        # 3) Polar transform (differentiable)
        polar_img = polar_transform_batch(x, coords, self.polar_h, self.polar_w, max_radius)  # (B,C,ph,pw)

        # polar branch encoder
        polar_feats = self.polar_encoder(polar_img)
        polar_bottleneck = polar_feats[-1]

        # 4) decode both branches with edge-attending decoders
        cart_logits, cart_decoded_feats = self.cart_decoder(cart_feats, cart_bottleneck)
        polar_logits, polar_decoded_feats = self.polar_decoder(polar_feats, polar_bottleneck)
        # cart_decoded_feats & polar_decoded_feats: list length num_stages, each (B, ch, h_i, w_i)

        # 5) Multi-scale fusion: fuse each corresponding decoder feature
        # The fused features are only collected and returned in the aux dict; they are
        # not used to compute `out`.
        fused_feats = []
        for i in range(len(cart_decoded_feats)):
            cfeat = cart_decoded_feats[i]
            pfeat = polar_decoded_feats[i]
            # if polar feature spatial size != cart feature, resize
            if pfeat.shape[-2:] != cfeat.shape[-2:]:
                pfeat = F.interpolate(pfeat, size=cfeat.shape[-2:], mode='bilinear', align_corners=True)
            fused = torch.cat([cfeat, pfeat], dim=1)  # concat channels
            # fusion conv (1x1) to reduce
            fuse_conv = self.fusion_convs[i]
            fused = fuse_conv(fused)
            fused_feats.append(fused)

        # 6) Final prediction from the two logit maps
        # convert polar_logits (in polar image space) back to cartesian
        polar_logits_cart = inverse_polar_to_cartesian(polar_logits, coords, H, W, max_radius)  # (B,1,H,W)

        # ensure shapes for concatenation
        if polar_logits_cart.shape != cart_logits.shape:
            polar_logits_cart = F.interpolate(polar_logits_cart, size=cart_logits.shape[-2:], mode='bilinear', align_corners=True)

        fused_final = torch.cat([cart_logits, polar_logits_cart], dim=1)  # (B,2,h,w) at the Cartesian decoder's output size
        out = self.fusion_head(fused_final)  # (B,1,h,w)
        # resize the prediction to the input resolution
        out = F.interpolate(out, size=(H, W), mode="bilinear", align_corners=True)
        return out, {
            "coords": coords,
            "cart_logits": cart_logits,
            "polar_logits_cart": polar_logits_cart,
            "heatmaps": heatmaps,
            "fused_feats": fused_feats,
            "polar_img": polar_img
        }


In [None]:
class LightningSampleModel(pl.LightningModule):
    """Lightning wrapper that trains ``EPolarUNetPaper`` with ``ComboLossGen``.

    Logs loss and Dice per stage, plus epoch-level IoU and F1 from torchmetrics.

    Args:
        lr: learning rate passed to Adam.
        weight_decay: weight decay passed to Adam.
        base_ch: stored as a hyperparameter only.
            NOTE(review): not used to build the model -- confirm whether it can go.
    """
    def __init__(self, lr=1e-2, weight_decay=0.0, base_ch=32):
        super().__init__()
        # records the constructor arguments in self.hparams (read in configure_optimizers)
        self.save_hyperparameters()
        
        self.model =  EPolarUNetPaper(
            in_channels=3,
            encoder_name="resnet34",
            n_stacks=2,
            polar_h=512,
            polar_w=512
        ) # <= Remember to set this to your model

        # reuse your losses
        self.criterion = ComboLossGen(a=0.3, b=0.7, g=1.0, bce_max_pw=3.0, surf_w=0.20)

        # metrics
        # separate instances per stage so train and val statistics never mix
        self.train_iou = BinaryJaccardIndex(threshold=0.5)
        self.val_iou   = BinaryJaccardIndex(threshold=0.5)
        self.train_f1  = BinaryF1Score(threshold=0.5)
        self.val_f1    = BinaryF1Score(threshold=0.5)

    @staticmethod
    def dice_from_logits(logits, target, eps=1e-7):
        """Mean per-sample Dice of the hard prediction ``sigmoid(logits) > 0.5`` against ``target``."""
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float().view(probs.size(0), -1)
        target = target.view(target.size(0), -1).float()
        inter = (preds * target).sum(dim=1)
        dice = (2 * inter + eps) / (preds.sum(dim=1) + target.sum(dim=1) + eps)
        return dice.mean()

    def forward(self, x):
        """Return the model output unchanged: the tuple ``(logits, aux_dict)``."""
        return self.model(x)

    def _step(self, batch, stage: str):
        """Shared train/val step: forward pass, loss, metric update and logging.

        ``stage`` is ``"train"`` or ``"val"``; returns the loss tensor.
        """
        x, y = batch                    # y: (B,H,W) {0,1}
        x = x.float()
        # keep only the logits (first element of the model output) and drop the channel axis
        logits = self(x)[0].squeeze(1)     # -> (B,H,W)
        loss = self.criterion(logits, y)
        dice = self.dice_from_logits(logits, y)
        probs = torch.sigmoid(logits)

        if stage == "train":
            self.train_iou.update(probs, y.int()); self.train_f1.update(probs, y.int())
            self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log("train/dice", dice, on_step=False, on_epoch=True, prog_bar=True)
        elif stage == "val":
            self.val_iou.update(probs, y.int()); self.val_f1.update(probs, y.int())
            self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("val/dice", dice, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx): return self._step(batch, "train")
    # the validation step logs via _step and returns nothing
    def validation_step(self, batch, batch_idx): self._step(batch, "val")

    def on_train_epoch_end(self):
        # compute the accumulated epoch metrics, log them, then reset for the next epoch
        self.log("train/iou", self.train_iou.compute(), prog_bar=True)
        self.log("train/f1",  self.train_f1.compute())
        self.train_iou.reset(); self.train_f1.reset()
    def on_validation_epoch_end(self):
        # same as above for the validation metrics
        self.log("val/iou", self.val_iou.compute(), prog_bar=True)
        self.log("val/f1",  self.val_f1.compute())
        self.val_iou.reset(); self.val_f1.reset()

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau that halves the LR when ``val/dice`` stops improving."""
        # Paper uses Adam; keep your plateau scheduler for simplicity
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        # mode="max" because the monitored quantity (Dice) should increase
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", patience=3, factor=0.5)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "monitor": "val/dice"}}

In [None]:
class SkinDataModule(pl.LightningDataModule):
    """LightningDataModule serving pre-built train / val / (optional) test datasets.

    All loaders share the same ``batch_size`` and ``num_workers``; only the
    training loader shuffles.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset=None, batch_size=5, num_workers=4):
        """Store the datasets and the loader settings shared by every split."""
        super().__init__()
        self.train_dataset, self.val_dataset, self.test_dataset = (
            train_dataset,
            val_dataset,
            test_dataset,
        )
        self.batch_size, self.num_workers = batch_size, num_workers

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module-wide settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Sequential (unshuffled) loader over the validation dataset."""
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Sequential loader over the test dataset.

        Raises:
            ValueError: if the module was created without a test dataset.
        """
        if self.test_dataset is not None:
            return self._build_loader(self.test_dataset, shuffle=False)
        raise ValueError("No test dataset provided!")

In [None]:
# Create checkpoints/epolarunet (and its parent) relative to the working directory.
# exist_ok=True keeps the cell idempotent: re-running it no longer fails with
# "File exists" the way a plain `mkdir` does.
os.makedirs(os.path.join("checkpoints", "epolarunet"), exist_ok=True)

In [None]:
student_id = "10423057"  # TODO: replace with your student ID

# SECURITY: the W&B key is read from the environment only -- never paste it into
# the notebook. On Kaggle, store it as a notebook secret and export it to
# WANDB_API_KEY before this cell runs.
# NOTE(review): a real key used to be hardcoded here as the fallback value; it is
# exposed in the notebook history and should be revoked in the W&B settings.
api_key = os.environ.get("WANDB_API_KEY", "").strip()

if not api_key:
    raise ValueError("Please set your wandb key in the environment variable WANDB_API_KEY")

print("WandB API key is set. Proceeding with login...")
wandb.login(key=api_key)

In [None]:

def find_latest_checkpoint(directory):
    """Return the most recently modified ``*.ckpt`` entry directly inside ``directory``.

    Args:
        directory: Folder to scan. Sub-folders are not searched.

    Returns:
        Path of the newest checkpoint, or ``None`` when ``directory`` is not an
        existing folder or contains no ``.ckpt`` entries.
    """
    # isdir (not exists): a plain file at this path would make listdir raise.
    if not os.path.isdir(directory):
        return None
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".ckpt")
    ]
    # Rank by modification time. os.path.getctime is NOT creation time on Unix
    # (it is the last metadata change), so it can pick the wrong file.
    return max(candidates, key=os.path.getmtime, default=None)


# Find latest checkpoint (if exists)
checkpoint_dir = "/kaggle/working/checkpoints/epolarunet"
ckpt_path = find_latest_checkpoint(checkpoint_dir)

if ckpt_path:
    print("Resuming from checkpoint:", ckpt_path)
else:
    print("No checkpoint found.")


In [None]:
pl.seed_everything(42)

# Hyperparameters live in one place so the model, the DataModule and the W&B
# run config cannot drift apart.
LR = 1e-2             # paper starts at 1e-2, but you can use 3e-4 if safer
WEIGHT_DECAY = 0.0
BASE_CHANNELS = 32    # starting channels in the UNet
BATCH_SIZE = 5
# os.cpu_count() may return None; fall back to loading in the main process.
NUM_WORKERS = os.cpu_count() or 0

dm = SkinDataModule(
    train_dataset,
    val_dataset,
    test_dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
MODEL_NAME = "EPolar-Unet" # <= CHANGE YOUR MODEL NAME
PROJECT_NAME = "Skin-Lesion Segmentation" # <= CHANGE YOUR PROJECT NAME

model = LightningSampleModel(lr=LR, weight_decay=WEIGHT_DECAY, base_ch=BASE_CHANNELS)

# 1) W&B logger
wandb_logger = WandbLogger(
    project=PROJECT_NAME,
    name=MODEL_NAME,
    log_model=True  # upload the saved checkpoints as W&B artifacts at the end of training
)

# Optional: track your hyperparams in the run config
wandb_logger.experiment.config.update({
    "lr": LR,
    "weight_decay": WEIGHT_DECAY,
    "batch_size": dm.batch_size,  # e.g. 8 for 512x512 on 15 GB
    "model": MODEL_NAME,      # instead of encoder_name
    "base_channels": BASE_CHANNELS,
    "loss": "Combo Loss",
    "optimizer": "Adam",      # matches paper
    "img_size": "512x512",    # hardcoded label -- keep in sync with the dataset transforms
})

# 2) Callbacks
ckpt_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,  # same folder the resume cell scans for *.ckpt files
    monitor="val/dice", mode="max", save_top_k=3, save_last=True,
    filename="epoch{epoch:02d}-valdice{val/dice:.4f}",
    # The metric name contains "/". With the default auto_insert_metric_name=True
    # Lightning injects "val/dice=" into the filename, which creates a nested
    # folder per checkpoint. Disable it so files are written flat as
    # e.g. "epoch05-valdice0.8123.ckpt".
    auto_insert_metric_name=False,
)

early_cb = EarlyStopping(monitor="val/dice", mode="max", patience=8)
lr_cb = LearningRateMonitor(logging_interval="epoch")

# 3) Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    precision=32,           # use 32 if AMP causes issues
    gradient_clip_val=1.0,
    log_every_n_steps=10,
    callbacks=[ckpt_cb, early_cb, lr_cb],
    logger=wandb_logger,
    fast_dev_run=False
)

# 4) Train, resuming from ckpt_path when the resume cell found a checkpoint
trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
