In [1]:
! pip install medpy
! pip install segmentation_models_pytorch
! pip install -U albumentations

Collecting medpy
  Downloading medpy-0.5.2.tar.gz (156 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.3/156.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: medpy
  Building wheel for medpy (setup.py) ... [?25l[?25hdone
  Created wheel for medpy: filename=MedPy-0.5.2-cp310-cp310-linux_x86_64.whl size=762838 sha256=2c68dc339f4d6d539b85b98aa7e9995c8ee22c1df3a233041dc8a130462ca04f
  Stored in directory: /root/.cache/pip/wheels/a1/b8/63/bdf557940ec60d1b8822e73ff9fbe7727ac19f009d46b5d175
Successfully built medpy
Installing collected packages: medpy
Successfully installed medpy-0.5.2
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.4.0-py3-none-any.whl.metadata (32 kB)
Collecting efficientnet-pytorch>=0.6.1 (from segmentation_models_pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?

In [5]:
import os
import re
import math
import copy
import glob
import random
import warnings
import numpy as np
from collections import defaultdict
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import cv2  # CLAHE 적용을 위해 OpenCV 사용

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from skimage.segmentation import find_boundaries
from skimage.measure import label, regionprops
import scipy.ndimage

# Albumentations 기반 augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# medpy HD metric
from medpy.metric import binary as medpy_binary

###############################################
# Seed 및 Warning 설정
###############################################
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)
warnings.filterwarnings('ignore')

###############################################
# BUSI Segmentation Dataset
###############################################
class BUSISegmentationDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform  
        self.samples = []  # 반드시 초기화
        self._prepare_samples()

    def _prepare_samples(self):
        labels = os.listdir(self.data_path)
        for label in labels:
            folder_path = os.path.join(self.data_path, label)
            if not os.path.isdir(folder_path):
                continue
            files = os.listdir(folder_path)
            image_files = sorted([f for f in files if '_mask' not in f and f.endswith('.png')])
            mask_files  = sorted([f for f in files if '_mask' in f and f.endswith('.png')])
            pattern_img = re.compile(rf'{re.escape(label)} \((\d+)\)\.png')
            pattern_mask = re.compile(rf'{re.escape(label)} \((\d+)\)_mask(?:_\d+)?\.png')
            mask_dict = {}
            for mf in mask_files:
                m = pattern_mask.fullmatch(mf)
                if m:
                    idx = m.group(1)
                    mask_dict.setdefault(idx, []).append(mf)
            for im in image_files:
                m = pattern_img.fullmatch(im)
                if m:
                    idx = m.group(1)
                    img_path = os.path.join(folder_path, im)
                    if idx in mask_dict:
                        mask_paths = [os.path.join(folder_path, mf) for mf in mask_dict[idx]]
                        combined_mask = None
                        for mp in mask_paths:
                            mask_img = Image.open(mp).convert('L')
                            mask_arr = np.array(mask_img)
                            mask_binary = (mask_arr > 128).astype(np.uint8)
                            if combined_mask is None:
                                combined_mask = mask_binary
                            else:
                                combined_mask = np.maximum(combined_mask, mask_binary)
                        self.samples.append((img_path, combined_mask, label))
                    else:
                        image_pil = Image.open(img_path)
                        empty_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
                        self.samples.append((img_path, empty_mask, label))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img_path, mask_array, label = self.samples[index]
        image = Image.open(img_path).convert('RGB')
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)
        return image, mask, label

###############################################
# Per-image Z-score normalization (adaptive)
###############################################
def z_score_normalize(tensor):
    # tensor: (C, H, W)
    mean = tensor.mean()
    std = tensor.std() + 1e-6
    return (tensor - mean) / std

###############################################
# CLAHE Augmentation 함수 (OpenCV)
###############################################
def apply_clahe(image, clipLimit=2.0, tileGridSize=(8,8)):
    img_np = np.array(image)
    if len(img_np.shape) == 3 and img_np.shape[2] == 3:
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        cl = clahe.apply(l)
        limg = cv2.merge((cl, a, b))
        final = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
    else:
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        final = clahe.apply(img_np)
    return Image.fromarray(final)

###############################################
# joint_transform (Albumentations 기반 Augmentation)
###############################################
def joint_transform(image, mask, size=(224,224)):
    # 1. Geometric transforms (image와 mask에 동시에 적용)
    geom_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.5),
        A.ElasticTransform(alpha=100, sigma=10, alpha_affine=15, p=0.5),
        A.Resize(height=size[0], width=size[1])
    ])
    image_np = np.array(image)
    mask_np = np.array(mask)
    augmented = geom_transform(image=image_np, mask=mask_np)
    image = augmented['image']
    mask = augmented['mask']
    
    # 2. Intensity transforms (오직 image에만 적용)
    intensity_transform = A.Compose([
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=3.0, p=0.5),
        A.GaussianBlur(p=0.3)
    ])
    image = intensity_transform(image=image)['image']
    
    # 3. ToTensor 및 per-image z-score normalization (image에 적용)
    image = transforms.ToTensor()(image)
    image = z_score_normalize(image)
    mask = transforms.ToTensor()(mask)
    return image, mask

###############################################
# CutMix 함수 (강화: alpha=1.5, p=0.7)
###############################################
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, masks, alpha=1.5, p=0.7):
    if np.random.rand() > p:
        return images, masks
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    masks[:, :, bbx1:bbx2, bby1:bby2] = masks[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, masks

###############################################
# Advanced 후처리: Morphological Closing 포함
###############################################
def postprocess_mask(mask, min_size=100):
    # 기본적으로 작은 객체 제거
    labeled_mask = label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    # Morphological closing (hole filling)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

###############################################
# 1. AttentionGate 모듈
###############################################
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

###############################################
# 2. MultiScaleFusion 모듈
###############################################
class MultiScaleFusion(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(MultiScaleFusion, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(ch, out_channels, kernel_size=1) for ch in in_channels_list])
        self.fuse = nn.Sequential(
            nn.Conv2d(len(in_channels_list)*out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        target_size = features[-1].size()[2:]
        processed = []
        for i, f in enumerate(features):
            if f.size()[2:] != target_size:
                f = F.interpolate(f, size=target_size, mode='bilinear', align_corners=False)
            f = self.convs[i](f)
            processed.append(f)
        out = torch.cat(processed, dim=1)
        out = self.fuse(out)
        return out

###############################################
# 3. 기타 모델 관련 모듈
###############################################
class StochasticDepth(nn.Module):
    def __init__(self, p, mode="row"):
        super(StochasticDepth, self).__init__()
        self.p = p
        self.mode = mode

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        survival_rate = 1 - self.p
        if self.mode == "row":
            batch_size = x.shape[0]
            noise = torch.rand(batch_size, 1, 1, 1, device=x.device, dtype=x.dtype)
            binary_mask = (noise < survival_rate).float()
            return x / survival_rate * binary_mask
        else:
            if torch.rand(1).item() < self.p:
                return torch.zeros_like(x)
            else:
                return x

class GhostModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, ratio=2, dw_kernel_size=3, stride=1, padding=0, use_relu=True):
        super(GhostModule, self).__init__()
        self.out_channels = out_channels
        self.primary_channels = int(torch.ceil(torch.tensor(out_channels / ratio)))
        self.cheap_channels = out_channels - self.primary_channels
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, self.primary_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(self.primary_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
        self.cheap_conv = nn.Sequential(
            nn.Conv2d(self.primary_channels, self.cheap_channels, dw_kernel_size, stride=1,
                      padding=dw_kernel_size//2, groups=self.primary_channels, bias=False),
            nn.BatchNorm2d(self.cheap_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_conv(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_channels, :, :].contiguous()

def ghost_conv_block(in_channels, out_channels, use_relu=True):
    return nn.Sequential(
        GhostModule(in_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu),
        GhostModule(out_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu)
    )

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DynamicGating(nn.Module):
    def __init__(self, num_branches, hidden_dim=64, dropout_prob=0.1, init_temperature=2.0, iterations=3):
        super(DynamicGating, self).__init__()
        self.temperature = init_temperature
        self.iterations = iterations
        self.dropout_prob = dropout_prob
        self.fc = nn.Sequential(
            nn.Linear(num_branches, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, num_branches)
        )
        self.layernorm = nn.LayerNorm(num_branches)
        self.softmax = nn.Softmax(dim=1)
        self.res_scale = nn.Parameter(torch.ones(1))
        self.saved_log_probs = None
        self.entropy = None
        self.rl_loss = 0.0

    def forward(self, features):
        norm_features = self.layernorm(features)
        logits = self.fc(norm_features)
        logits = self.layernorm(logits)
        scaled_logits = logits / self.temperature
        gates = self.softmax(scaled_logits)
        for _ in range(self.iterations - 1):
            updated_features = norm_features + self.res_scale * gates
            logits = self.fc(updated_features)
            logits = self.layernorm(logits)
            scaled_logits = logits / self.temperature
            new_gates = self.softmax(scaled_logits)
            gates = 0.5 * gates + 0.5 * new_gates
        return gates

###############################################
###############################################
###############################################

class QuadAgentBlock(nn.Module):
    def __init__(self, in_channels, out_channels, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5):
        super().__init__()
        branch_channels = out_channels // 4
        self.branch1 = GhostModule(in_channels, branch_channels, ratio=4)
        # branch2: kernel_size=3, padding=1로 spatial size 유지
        self.branch2 = GhostModule(in_channels, branch_channels, kernel_size=3, padding=1, ratio=4)
        self.branch3 = nn.Sequential(
            GhostModule(in_channels, branch_channels, ratio=4),
            SELayer(branch_channels, reduction=32)
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        # num_branches를 3으로 설정
        self.gating = DynamicGating(num_branches=3, hidden_dim=gating_hidden_dim, dropout_prob=gating_dropout, init_temperature=gating_temperature)
        self.fusion_conv = GhostModule(branch_channels * 3, out_channels, ratio=2)
        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) if in_channels != out_channels else nn.Identity()
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob)

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        gap_b1 = self.gap(b1).view(x.size(0), -1)
        gap_b2 = self.gap(b2).view(x.size(0), -1)
        gap_b3 = self.gap(b3).view(x.size(0), -1)
        features = torch.stack([gap_b1.mean(dim=1), gap_b2.mean(dim=1), gap_b3.mean(dim=1)], dim=1)
        gates = self.gating(features)  # (B, 3)
        b1 = b1 * gates[:, 0].view(-1, 1, 1, 1)
        b2 = b2 * gates[:, 1].view(-1, 1, 1, 1)
        b3 = b3 * gates[:, 2].view(-1, 1, 1, 1)
        out = torch.cat([b1, b2, b3], dim=1)
        out = self.fusion_conv(out)
        res = self.res_conv(x)
        res = self.stochastic_depth(res)
        return out + res


"""
🔥 결과 기대 효과
개선 방법	설명	기대 효과
브랜치 개수 축소	4개 → 3개로 변경	연산량 약 25% 절감
Dilation → Dynamic Conv	Depthwise 적용	경량화 및 성능 유지
SE → Lightweight SE	채널 축소 후 적용	성능 유지, 메모리 절감
Gating 개선	Softmax → Gumbel-Softmax	학습 속도 및 게이팅 성능 향상
Fusion 최적화	Conv1x1 → GhostFusion	연산량 절감, 성능 유지

"""
###############################################
###############################################
###############################################


###############################################
# 4. PretrainedResNetUNet_AttentionFusion (전체 모델)
###############################################
###############################################
# EfficientNet-B0 Backbone 기반 UNet with Attention + Lite Bottleneck
###############################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm  # timm 라이브러리 (ViT 백본)
from medpy.metric import binary as medpy_binary

###############################################
# Transformer Bottleneck (추가)
###############################################
class TransformerBottleneck(nn.Module):
    def __init__(self, d_model=768, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                                    dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, x):
        # x: (B, C, H, W) -> (HW, B, C)
        B, C, H, W = x.size()
        x = x.view(B, C, H*W).permute(2, 0, 1)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).view(B, C, H, W)
        return x

###############################################
# UNet Decoder (간단한 Decoder)
###############################################
class UNetDecoder(nn.Module):
    def __init__(self):
        super(UNetDecoder, self).__init__()
        self.up1 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.ConvTranspose2d(96, 64, kernel_size=2, stride=2)
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.up1(x)
        x = self.conv1(x)
        x = self.up2(x)
        x = self.conv2(x)
        x = self.up3(x)
        x = self.conv3(x)
        x = self.up4(x)
        x = self.conv4(x)
        return x

###############################################
# DARKViT_UNet: ViT 기반 Encoder + QuadAgentBlock + Transformer Bottleneck + UNet Decoder
###############################################
from timm.models import create_model

class DARKViT_UNet(nn.Module):
    def __init__(self, out_channels=1):
        super(DARKViT_UNet, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True)
        self.bottleneck = QuadAgentBlock(in_channels=1024, out_channels=1024, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1024, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1024, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        features_list = self.encoder(x)  # features_list: [f0, f1, f2, f3]
        # f3가 NHWC 형식로 반환되는 경우 차원 변환
        f3 = features_list[-1].permute(0, 3, 1, 2)  # (B,1024,H,W)
        b_out = self.bottleneck(f3)
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        # 만약 decoder가 NCHW 형식을 요구한다면, features_list 나 f3만 교체합니다.
        features_list[-1] = t_out2
        decoded = self.decoder(features_list)
        return self.final_conv(decoded)



###############################################
# 5. Loss & Metric Functions
###############################################
from segmentation_models_pytorch.losses import LovaszLoss
lv_loss = LovaszLoss(mode='binary')

def focal_tversky_loss(pred, target, alpha=0.5, beta=0.5, gamma=4/3, smooth=1e-6):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred = pred.view(pred.size(0), -1)
    target = target.view(target.size(0), -1)
    tp = (pred * target).sum(dim=1)
    fp = ((1 - target) * pred).sum(dim=1)
    fn = (target * (1 - pred)).sum(dim=1)
    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    loss = (1 - tversky_index) ** gamma
    return loss.mean()

def boundary_loss(pred, target):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    boundary_masks = []
    for i in range(pred_np.shape[0]):
        gt_mask = target_np[i, 0]
        boundary = find_boundaries(gt_mask, mode='thick')
        boundary_masks.append(boundary.astype(np.float32))
    boundary_masks = np.stack(boundary_masks, axis=0)[:, None, :, :]
    boundary_masks_torch = torch.from_numpy(boundary_masks).to(pred.device)
    intersect = (pred * boundary_masks_torch).sum()
    denom = pred.sum() + boundary_masks_torch.sum()
    boundary_dice = (2.0 * intersect) / (denom + 1e-6)
    return 1.0 - boundary_dice

def combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    loss_ft = focal_tversky_loss(outputs, masks)
    if pos_weight is not None:
        loss_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(outputs, masks.float())
    else:
        loss_bce = nn.BCEWithLogitsLoss()(outputs, masks.float())
    bl = boundary_loss(outputs, masks)
    lv = lv_loss(outputs, masks)
    return 0.4 * lv + 0.3 * loss_ft + 0.3 * loss_bce + lambda_boundary * bl

def dice_f1_precision_recall(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    precision = intersection / (pred_bin.sum() + smooth)
    recall = intersection / (target_bin.sum() + smooth)
    f1 = 2 * (precision * recall) / (precision + recall + smooth)
    return f1.item(), precision.item(), recall.item()

def iou_metric(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    union = pred_bin.sum() + target_bin.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def postprocess_mask(mask, min_size=100):
    labeled_mask = label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    # Kernel 크기를 (7,7)로 늘려 후처리 강화
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7,7))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed


def compute_hd_metric(outputs, masks, threshold=0.5):
    # outputs와 masks를 후처리하여 binary mask로 변환한 후, 각 sample마다 HD (hd95)를 계산
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy()
    hd_list = []
    for i in range(outputs.size(0)):
        pred_bin = (pred_probs[i, 0] > threshold).astype(np.uint8)
        gt_bin = (masks_np[i, 0] > 0.5).astype(np.uint8)
        try:
            hd = medpy_binary.hd95(pred_bin, gt_bin)
        except Exception:
            hd = np.nan
        hd_list.append(hd)
    return np.nanmean(hd_list)


def compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach()
    processed_preds = []
    for i in range(pred_probs.size(0)):
        pred_np = pred_probs[i].cpu().numpy()[0]
        pred_bin = (pred_np > threshold).astype(np.uint8)
        processed = postprocess_mask(pred_bin, min_size=100)
        processed_preds.append(processed)
    batch_iou = []
    batch_f1 = []
    batch_precision = []
    batch_recall = []
    for i in range(pred_probs.size(0)):
        pred = processed_preds[i]
        gt = masks[i].cpu().numpy()[0]
        intersection = np.sum(pred * gt)
        union = np.sum(pred) + np.sum(gt) - intersection
        iou = (intersection + smooth) / (union + smooth)
        batch_iou.append(iou)
        precision = intersection / (np.sum(pred) + smooth)
        recall = intersection / (np.sum(gt) + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        batch_f1.append(f1)
        batch_precision.append(precision)
        batch_recall.append(recall)
    probs = pred_probs.cpu().numpy().flatten()
    masks_np = masks.cpu().numpy().flatten()
    try:
        auc_score = roc_auc_score(masks_np, probs)
    except ValueError:
        auc_score = float('nan')
    return auc_score, np.mean(batch_iou), np.mean(batch_f1), np.mean(batch_precision), np.mean(batch_recall)

def find_optimal_threshold_iou(model, dataloader, device, smooth=1e-6):
    thresholds = np.linspace(0.1, 0.9, 17)  # 0.05 간격 세밀한 탐색
    best_threshold = 0.65
    best_iou = 0.0
    model.eval()
    with torch.no_grad():
        for t in thresholds:
            total_iou = 0.0
            sample_count = 0
            for images, masks, _ in dataloader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
                masks_np = masks.cpu().numpy()
                for i in range(outputs.size(0)):
                    pred_bin = (pred_probs[i, 0] > t).astype(np.uint8)
                    pred_processed = postprocess_mask(pred_bin, min_size=100)
                    gt = masks_np[i, 0]
                    intersection = np.sum(pred_processed * gt)
                    union = np.sum(pred_processed) + np.sum(gt) - intersection
                    iou_val = (intersection + smooth) / (union + smooth)
                    total_iou += iou_val
                    sample_count += 1
            avg_iou = total_iou / sample_count if sample_count > 0 else 0
            if avg_iou > best_iou:
                best_iou = avg_iou
                best_threshold = t
    return best_threshold, best_iou

def evaluate_with_metrics(model, dataloader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None):
    model.eval()
    total_loss = 0.0
    total_auc  = 0.0
    total_iou  = 0.0
    total_f1   = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_hd = 0.0  # HD 합계
    total_samples = 0
    with torch.no_grad():
        for images, masks, _ in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=lambda_boundary)
            auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=threshold, smooth=1e-6)
            hd_val = compute_hd_metric(outputs, masks, threshold=threshold)
            bs = images.size(0)
            total_loss += loss.item() * bs
            total_auc  += auc_score * bs
            total_iou  += iou_val * bs
            total_f1   += f1_val * bs
            total_precision += prec * bs
            total_recall += rec * bs
            total_hd += hd_val * bs
            total_samples += bs
    avg_loss = total_loss / total_samples
    avg_auc  = total_auc / total_samples
    avg_iou  = total_iou / total_samples
    avg_f1   = total_f1 / total_samples
    avg_precision = total_precision / total_samples
    avg_recall = total_recall / total_samples
    avg_hd = total_hd / total_samples
    return avg_loss, avg_auc, avg_iou, avg_f1, avg_precision, avg_recall, avg_hd


###############################################
# Hausdorff Distance (HD) 계산 함수
###############################################
def hausdorff_distance(pred, target):
    # pred, target: numpy binary masks
    return medpy_binary.hd95(pred, target)

###############################################
# TTA (Test Time Augmentation) 함수
###############################################
def tta_predict(model, image, device):
    # image: single image tensor, shape (C, H, W)
    def identity(x): return x
    def hflip(x): return torch.flip(x, dims=[-1])
    def vflip(x): return torch.flip(x, dims=[-2])
    def rot90(x): return torch.rot90(x, k=1, dims=[-2, -1])
    # 역변환 함수: flip은 동일한 연산으로 복구, rot90의 경우 3회 회전하면 원상복구
    def inv_hflip(x): return torch.flip(x, dims=[-1])
    def inv_vflip(x): return torch.flip(x, dims=[-2])
    def inv_rot90(x): return torch.rot90(x, k=3, dims=[-2, -1])
    
    transforms_list = [
        (identity, identity),
        (hflip, inv_hflip),
        (vflip, inv_vflip),
        (rot90, inv_rot90)
    ]
    predictions = []
    model.eval()
    with torch.no_grad():
        for aug, inv in transforms_list:
            augmented = aug(image)
            output = model(augmented.unsqueeze(0).to(device))
            output = torch.sigmoid(output)
            output = inv(output).cpu()
            predictions.append(output)
    avg_prediction = torch.mean(torch.stack(predictions), dim=0)
    return avg_prediction

###############################################
# Optimizer Scheduler 함수
###############################################
def get_lambda_rl(epoch):
    return min(0.1, epoch / 100.0)

def get_lambda_entropy(epoch):
    return max(0.01, 0.1 - (epoch / 100.0))

class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps, max_steps, max_lr, min_lr=1e-6):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.global_step = 0

    def step(self):
        self.global_step += 1
        if self.global_step < self.warmup_steps:
            lr = self.max_lr * (self.global_step / self.warmup_steps)
        else:
            progress = (self.global_step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
            lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

###############################################
# 데이터 로더 및 BUSI Dataset 설정
###############################################
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
full_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)
indices = np.arange(len(full_dataset))
train_val_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)
train_dataset = Subset(full_dataset, train_idx)
val_dataset   = Subset(full_dataset, val_idx)
test_dataset  = Subset(full_dataset, test_idx)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=4,
    worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
    pin_memory=True, persistent_workers=True
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

###############################################
# 모델, Optimizer, Scheduler 설정
###############################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DARKViT_UNet(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

def calculate_pos_weight(loader):
    total_pixels = 0
    positive_pixels = 0
    for images, masks, _ in loader:
        positive_pixels += masks.sum().item()
        total_pixels += masks.numel()
    negative_pixels = total_pixels - positive_pixels
    return torch.tensor(negative_pixels / (positive_pixels + 1e-6)).to(device)

pos_weight = calculate_pos_weight(train_loader)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
num_epochs = 500

scheduler = WarmupCosineScheduler(
    optimizer, warmup_steps=10, max_steps=num_epochs,
    max_lr=1e-4, min_lr=1e-5
)

###############################################
# Training Loop
###############################################
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
patience = 20
best_val_f1 = 0.0
patience_counter = 0

for epoch in range(num_epochs):
    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp


    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    current_lambda_rl = get_lambda_rl(epoch)
    current_lambda_entropy = get_lambda_entropy(epoch)

    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        # CutMix 활성화 (alpha=1.5, p=0.7)
        images, masks = cutmix_data(images, masks, alpha=1.5, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss_seg = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=0.2, lambda_lovasz=0.6)
            policy_loss = torch.tensor(0.0, device=loss_seg.device)  # RL 미사용
            loss_total = loss_seg + current_lambda_rl * policy_loss
        scaler.scale(loss_total).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        bs = images.size(0)
        epoch_loss += loss_total.item() * bs
        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step()
    avg_loss = epoch_loss / total_samples
    avg_auc = epoch_auc / total_samples
    avg_iou = epoch_iou / total_samples
    avg_f1 = epoch_f1 / total_samples
    avg_precision = epoch_precision / total_samples
    avg_recall = epoch_recall / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")

    
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break


model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]

NameError: name 'UNetDecoder_Swin' is not defined

In [18]:
test_loss, test_auc, test_iou, test_f1, test_precision, test_recall, test_hd = evaluate_with_metrics(model, test_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight)
print(f"[Test] Loss: {test_loss:.4f}, AUC: {test_auc:.4f}, IoU: {test_iou:.4f}, Dice/F1: {test_f1:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, HD95: {test_hd:.4f}")


[Test] Loss: 1.0922, AUC: 0.9580, IoU: 0.6721, Dice/F1: 0.6004, Precision: 0.6322, Recall: 0.6067, HD95: 23.9212


# DARK-Swin

In [6]:
import os
import re
import math
import copy
import glob
import random
import warnings
import numpy as np
from collections import defaultdict
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import cv2  # CLAHE 적용을 위해 OpenCV 사용

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from skimage.segmentation import find_boundaries
from skimage.measure import label, regionprops
import scipy.ndimage

# Albumentations 기반 augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# medpy HD metric
from medpy.metric import binary as medpy_binary
from torchvision.transforms import ColorJitter

###############################################
# Seed 및 Warning 설정
###############################################
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)
warnings.filterwarnings('ignore')

###############################################
# BUSI Segmentation Dataset
###############################################
class BUSISegmentationDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform  
        self.samples = []  # 반드시 초기화
        self._prepare_samples()

    def _prepare_samples(self):
        labels = os.listdir(self.data_path)
        for label in labels:
            folder_path = os.path.join(self.data_path, label)
            if not os.path.isdir(folder_path):
                continue
            files = os.listdir(folder_path)
            image_files = sorted([f for f in files if '_mask' not in f and f.endswith('.png')])
            mask_files  = sorted([f for f in files if '_mask' in f and f.endswith('.png')])
            pattern_img = re.compile(rf'{re.escape(label)} \((\d+)\)\.png')
            pattern_mask = re.compile(rf'{re.escape(label)} \((\d+)\)_mask(?:_\d+)?\.png')
            mask_dict = {}
            for mf in mask_files:
                m = pattern_mask.fullmatch(mf)
                if m:
                    idx = m.group(1)
                    mask_dict.setdefault(idx, []).append(mf)
            for im in image_files:
                m = pattern_img.fullmatch(im)
                if m:
                    idx = m.group(1)
                    img_path = os.path.join(folder_path, im)
                    if idx in mask_dict:
                        mask_paths = [os.path.join(folder_path, mf) for mf in mask_dict[idx]]
                        combined_mask = None
                        for mp in mask_paths:
                            mask_img = Image.open(mp).convert('L')
                            mask_arr = np.array(mask_img)
                            mask_binary = (mask_arr > 128).astype(np.uint8)
                            if combined_mask is None:
                                combined_mask = mask_binary
                            else:
                                combined_mask = np.maximum(combined_mask, mask_binary)
                        self.samples.append((img_path, combined_mask, label))
                    else:
                        image_pil = Image.open(img_path)
                        empty_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
                        self.samples.append((img_path, empty_mask, label))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img_path, mask_array, label = self.samples[index]
        image = Image.open(img_path).convert('RGB')
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)
        return image, mask, label

###############################################
# Per-image Z-score normalization (adaptive)
###############################################
def z_score_normalize(tensor):
    mean = tensor.mean()
    std = tensor.std() + 1e-6
    return (tensor - mean) / std

###############################################
# CLAHE Augmentation 함수 (OpenCV)
###############################################
def apply_clahe(image, clipLimit=2.0, tileGridSize=(8,8)):
    img_np = np.array(image)
    if len(img_np.shape) == 3 and img_np.shape[2] == 3:
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        cl = clahe.apply(l)
        limg = cv2.merge((cl, a, b))
        final = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
    else:
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        final = clahe.apply(img_np)
    return Image.fromarray(final)

###############################################
# joint_transform (Albumentations 기반 Augmentation)
###############################################
def joint_transform(image, mask, size=(224,224)):
    geom_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.ElasticTransform(alpha=10, sigma=5, alpha_affine=5, p=0.3),
        A.Resize(height=size[0], width=size[1])
    ])
    image_np = np.array(image)
    mask_np = np.array(mask)
    augmented = geom_transform(image=image_np, mask=mask_np)
    image = augmented['image']
    mask = augmented['mask']
    
    intensity_transform = A.Compose([
        A.RandomBrightnessContrast(p=0.5),
        #A.CLAHE(clip_limit=3.0, p=0.5),
        A.CLAHE(clip_limit=1.0, tile_grid_size=(8,8), p=0.5),
        A.GaussianBlur(p=0.3)
    ])
    image = intensity_transform(image=image)['image']
    
    image = transforms.ToTensor()(image)
    image = z_score_normalize(image)
    mask = transforms.ToTensor()(mask)
    return image, mask

###############################################
# CutMix 함수 (alpha=1.5, p=0.7)
###############################################
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, masks, alpha=1.5, p=0.7):
    if np.random.rand() > p:
        return images, masks
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    masks[:, :, bbx1:bbx2, bby1:bby2] = masks[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, masks

###############################################
# Advanced 후처리: Morphological Closing (Kernel 9×9)
###############################################
def postprocess_mask(mask, min_size=100):
    labeled_mask = label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

###############################################
# AttentionGate 모듈
###############################################
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

###############################################
# MultiScaleFusion 모듈
###############################################
class MultiScaleFusion(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(MultiScaleFusion, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(ch, out_channels, kernel_size=1) for ch in in_channels_list])
        self.fuse = nn.Sequential(
            nn.Conv2d(len(in_channels_list) * out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        target_size = features[-1].size()[2:]
        processed = []
        for i, f in enumerate(features):
            if f.size()[2:] != target_size:
                f = F.interpolate(f, size=target_size, mode='bilinear', align_corners=False)
            f = self.convs[i](f)
            processed.append(f)
        out = torch.cat(processed, dim=1)
        out = self.fuse(out)
        return out

###############################################
# 기타 모델 관련 모듈 (StochasticDepth, GhostModule, SELayer, DynamicGating)
###############################################
class StochasticDepth(nn.Module):
    def __init__(self, p, mode="row"):
        super(StochasticDepth, self).__init__()
        self.p = p
        self.mode = mode

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        survival_rate = 1 - self.p
        if self.mode == "row":
            batch_size = x.shape[0]
            noise = torch.rand(batch_size, 1, 1, 1, device=x.device, dtype=x.dtype)
            binary_mask = (noise < survival_rate).float()
            return x / survival_rate * binary_mask
        else:
            if torch.rand(1).item() < self.p:
                return torch.zeros_like(x)
            else:
                return x

class GhostModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, ratio=2, dw_kernel_size=3, stride=1, padding=0, use_relu=True):
        super(GhostModule, self).__init__()
        self.out_channels = out_channels
        self.primary_channels = int(torch.ceil(torch.tensor(out_channels / ratio)))
        self.cheap_channels = out_channels - self.primary_channels
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, self.primary_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(self.primary_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
        self.cheap_conv = nn.Sequential(
            nn.Conv2d(self.primary_channels, self.cheap_channels, dw_kernel_size, stride=1,
                      padding=dw_kernel_size // 2, groups=self.primary_channels, bias=False),
            nn.BatchNorm2d(self.cheap_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_conv(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_channels, :, :].contiguous()

def ghost_conv_block(in_channels, out_channels, use_relu=True):
    return nn.Sequential(
        GhostModule(in_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu),
        GhostModule(out_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu)
    )

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DynamicGating(nn.Module):
    def __init__(self, num_branches, hidden_dim=64, dropout_prob=0.1, init_temperature=2.0, iterations=3):
        super(DynamicGating, self).__init__()
        self.temperature = init_temperature
        self.iterations = iterations
        self.dropout_prob = dropout_prob
        self.fc = nn.Sequential(
            nn.Linear(num_branches, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, num_branches)
        )
        self.layernorm = nn.LayerNorm(num_branches)
        self.softmax = nn.Softmax(dim=1)
        self.res_scale = nn.Parameter(torch.ones(1))
        self.saved_log_probs = None
        self.entropy = None
        self.rl_loss = 0.0

    def forward(self, features):
        norm_features = self.layernorm(features)
        logits = self.fc(norm_features)
        logits = self.layernorm(logits)
        scaled_logits = logits / self.temperature
        gates = self.softmax(scaled_logits)
        for _ in range(self.iterations - 1):
            updated_features = norm_features + self.res_scale * gates
            logits = self.fc(updated_features)
            logits = self.layernorm(logits)
            scaled_logits = logits / self.temperature
            new_gates = self.softmax(scaled_logits)
            gates = 0.5 * gates + 0.5 * new_gates
        return gates

###############################################
# QuadAgentBlock (Lite 버전, 3 브랜치)
###############################################
class QuadAgentBlock(nn.Module):
    def __init__(self, in_channels, out_channels, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5):
        super().__init__()
        branch_channels = out_channels // 4
        self.branch1 = GhostModule(in_channels, branch_channels, ratio=4)
        self.branch2 = GhostModule(in_channels, branch_channels, kernel_size=3, padding=1, ratio=4)
        self.branch3 = nn.Sequential(
            GhostModule(in_channels, branch_channels, ratio=4),
            SELayer(branch_channels, reduction=32)
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gating = DynamicGating(num_branches=3, hidden_dim=gating_hidden_dim, dropout_prob=gating_dropout, init_temperature=gating_temperature)
        self.fusion_conv = GhostModule(branch_channels * 3, out_channels, ratio=2)
        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) if in_channels != out_channels else nn.Identity()
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob)

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        gap_b1 = self.gap(b1).view(x.size(0), -1)
        gap_b2 = self.gap(b2).view(x.size(0), -1)
        gap_b3 = self.gap(b3).view(x.size(0), -1)
        features = torch.stack([gap_b1.mean(dim=1), gap_b2.mean(dim=1), gap_b3.mean(dim=1)], dim=1)
        gates = self.gating(features)
        b1 = b1 * gates[:, 0].view(-1, 1, 1, 1)
        b2 = b2 * gates[:, 1].view(-1, 1, 1, 1)
        b3 = b3 * gates[:, 2].view(-1, 1, 1, 1)
        out = torch.cat([b1, b2, b3], dim=1)
        out = self.fusion_conv(out)
        res = self.res_conv(x)
        res = self.stochastic_depth(res)
        return out + res


###############################################
# DARKViT_UNet: ViT 기반 Encoder + QuadAgentBlock + Transformer Bottleneck + UNet Decoder
###############################################
from timm.models import create_model

class DARKViT_UNet(nn.Module):
    def __init__(self, out_channels=1):
        super(DARKViT_UNet, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True)
        self.bottleneck = QuadAgentBlock(in_channels=1536, out_channels=1536, 
                                         gating_dropout=0.3, gating_hidden_dim=32, 
                                         gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, 
                                                             dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, 
                                                             dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        # encoder에서 feature map 리스트 얻기: [f0, f1, f2, f3]
        features = self.encoder(x)
        
        # 만약 각 feature map이 (N, H, W, C) 형식이라면 (N, C, H, W)로 변환
        for i in range(len(features)):
            # f0: 192, f1:384, f2:768, f3:1536 채널이어야 함
            if features[i].shape[1] not in {192, 384, 768, 1536}:
                features[i] = features[i].permute(0, 3, 1, 2).contiguous()
        
        # 반환된 순서는 [f0, f1, f2, f3]이므로, decoder가 기대하는 [f3, f2, f1, f0] 순서로 재정렬
        f0, f1, f2, f3 = features  
        features_reordered = [f3, f2, f1, f0]
        
        # 이제 features_reordered[0]는 f3: (B,1536,7,7)
        b_out = self.bottleneck(features_reordered[0])
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        features_reordered[0] = t_out2
        
        decoded = self.decoder(features_reordered)
        return self.final_conv(decoded)



###############################################
# UNet Decoder for Swin Backbone
###############################################
class UNetDecoder_Swin(nn.Module):
    def __init__(self):
        super(UNetDecoder_Swin, self).__init__()
        # Swin‑Large feature sizes (가정):
        # f0: (B,192,56,56), f1: (B,384,28,28), f2: (B,768,14,14), f3: (B,1536,7,7)
        self.up1 = nn.ConvTranspose2d(1536, 768, kernel_size=2, stride=2)  # 7->14
        self.conv1 = nn.Sequential(
            nn.Conv2d(768 + 768, 768, kernel_size=3, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)   # 14->28
        self.conv2 = nn.Sequential(
            nn.Conv2d(384 + 384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)   # 28->56
        self.conv3 = nn.Sequential(
            nn.Conv2d(192 + 192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.ConvTranspose2d(192, 64, kernel_size=2, stride=2)    # 56->112
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
    def forward(self, features_list):
        # features_list: [f0, f1, f2, f3]
        f3, f2, f1, f0 = features_list  # ✅ 순서 맞추기
        x = f3  # (B,1536,7,7)
        x = self.up1(x)  # -> (B,768,14,14)
        x = torch.cat([x, f2], dim=1)  # f2: (B,768,14,14) → (B,1536,14,14)
        x = self.conv1(x)  # -> (B,768,14,14)
        x = self.up2(x)  # -> (B,384,28,28)
        x = torch.cat([x, f1], dim=1)  # f1: (B,384,28,28) → (B,768,28,28)
        x = self.conv2(x)  # -> (B,384,28,28)
        x = self.up3(x)  # -> (B,192,56,56)
        x = torch.cat([x, f0], dim=1)  # f0: (B,192,56,56) → (B,384,56,56)
        x = self.conv3(x)  # -> (B,192,56,56)
        x = self.up4(x)  # -> (B,64,112,112)
        x = self.conv4(x)  # -> (B,64,112,112)
        x = F.interpolate(x, size=(224,224), mode='bilinear', align_corners=False)
        return x

###############################################
# Transformer Bottleneck (연속 두 개 적용)
###############################################
class TransformerBottleneck(nn.Module):
    def __init__(self, d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, x):
        B, C, H, W = x.size()
        x = x.view(B, C, H * W).permute(2, 0, 1)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).view(B, C, H, W)
        return x

###############################################
# Swin‑Transformer Backbone 기반 UNet with Attention + Lite Bottleneck
###############################################
from timm.models import create_model

class PretrainedSwin_UNet_AttentionFusion(nn.Module):
    def __init__(self, out_channels=1):
        super(PretrainedSwin_UNet_AttentionFusion, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True)
        self.bottleneck = QuadAgentBlock(in_channels=1536, out_channels=1536, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        features_list = self.encoder(x)  # [f0, f1, f2, f3]
        
        # 디버깅: Feature Map Shape 확인
        #print("Original feature shapes:")
        #for i, f in enumerate(features_list):
        #    print(f"f{i}: {f.shape}")

        # (N, H, W, C) → (N, C, H, W) 변환이 필요한 경우 적용
        for i in range(len(features_list)):
            if features_list[i].shape[1] not in {192, 384, 768, 1536}:  
                features_list[i] = features_list[i].permute(0, 3, 1, 2).contiguous()

        # 디코더에 올바른 순서로 전달되도록 변환
        f0, f1, f2, f3 = features_list  # Swin Transformer의 기본 반환 순서
        features_list = [f3, f2, f1, f0]  # (B, 1536, 7, 7) → (B, 192, 56, 56)

        # 디버깅: 변환 후 Feature Map Shape 확인
        #print("\nReordered feature shapes (for decoder):")
        #for i, f in enumerate(features_list):
        #    print(f"features_list[{i}]: {f.shape}")

        # Bottleneck과 Transformer Bottleneck 적용
        b_out = self.bottleneck(features_list[0])  # f3가 들어가야 함
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        features_list[0] = t_out2  # Bottleneck을 거친 f3

        # UNet 디코더 적용
        decoded = self.decoder(features_list)
        return self.final_conv(decoded)




###############################################
# Loss & Metric Functions (HD 포함)
###############################################
from segmentation_models_pytorch.losses import LovaszLoss
lv_loss = LovaszLoss(mode='binary')

def focal_tversky_loss(pred, target, alpha=0.5, beta=0.5, gamma=4/3, smooth=1e-6):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred = pred.view(pred.size(0), -1)
    target = target.view(target.size(0), -1)
    tp = (pred * target).sum(dim=1)
    fp = ((1 - target) * pred).sum(dim=1)
    fn = (target * (1 - pred)).sum(dim=1)
    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    loss = (1 - tversky_index) ** gamma
    return loss.mean()

def boundary_loss(pred, target):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    boundary_masks = []
    for i in range(pred_np.shape[0]):
        gt_mask = target_np[i, 0]
        boundary = find_boundaries(gt_mask, mode='thick')
        boundary_masks.append(boundary.astype(np.float32))
    boundary_masks = np.stack(boundary_masks, axis=0)[:, None, :, :]
    boundary_masks_torch = torch.from_numpy(boundary_masks).to(pred.device)
    intersect = (pred * boundary_masks_torch).sum()
    denom = pred.sum() + boundary_masks_torch.sum()
    boundary_dice = (2.0 * intersect) / (denom + 1e-6)
    return 1.0 - boundary_dice

def combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    loss_ft = focal_tversky_loss(outputs, masks)
    if pos_weight is not None:
        loss_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(outputs, masks.float())
    else:
        loss_bce = nn.BCEWithLogitsLoss()(outputs, masks.float())
    bl = boundary_loss(outputs, masks)
    lv = lv_loss(outputs, masks)
    return 0.4 * lv + 0.3 * loss_ft + 0.3 * loss_bce + lambda_boundary * bl

def dice_f1_precision_recall(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    precision = intersection / (pred_bin.sum() + smooth)
    recall = intersection / (target_bin.sum() + smooth)
    f1 = 2 * (precision * recall) / (precision + recall + smooth)
    return f1.item(), precision.item(), recall.item()

def iou_metric(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    union = pred_bin.sum() + target_bin.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach()
    processed_preds = []
    for i in range(pred_probs.size(0)):
        pred_np = pred_probs[i].cpu().numpy()[0]
        pred_bin = (pred_np > threshold).astype(np.uint8)
        processed = postprocess_mask(pred_bin, min_size=100)
        processed_preds.append(processed)
    batch_iou = []
    batch_f1 = []
    batch_precision = []
    batch_recall = []
    for i in range(pred_probs.size(0)):
        pred = processed_preds[i]
        gt = masks[i].cpu().numpy()[0]
        intersection = np.sum(pred * gt)
        union = np.sum(pred) + np.sum(gt) - intersection
        iou = (intersection + smooth) / (union + smooth)
        batch_iou.append(iou)
        precision = intersection / (np.sum(pred) + smooth)
        recall = intersection / (np.sum(gt) + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        batch_f1.append(f1)
        batch_precision.append(precision)
        batch_recall.append(recall)
    probs = pred_probs.cpu().numpy().flatten()
    masks_np = masks.cpu().numpy().flatten()
    try:
        auc_score = roc_auc_score(masks_np, probs)
    except ValueError:
        auc_score = float('nan')
    return auc_score, np.mean(batch_iou), np.mean(batch_f1), np.mean(batch_precision), np.mean(batch_recall)

def hausdorff_distance(pred, target):
    return medpy_binary.hd95(pred, target)

def compute_hd_metric(outputs, masks, threshold=0.5):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy()
    hd_list = []
    for i in range(outputs.size(0)):
        pred_bin = (pred_probs[i, 0] > threshold).astype(np.uint8)
        gt_bin = (masks_np[i, 0] > 0.5).astype(np.uint8)
        try:
            hd = medpy_binary.hd95(pred_bin, gt_bin)
        except Exception:
            hd = np.nan
        hd_list.append(hd)
    return np.nanmean(hd_list)

def evaluate_with_metrics(model, dataloader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None):
    model.eval()
    total_loss = 0.0
    total_auc  = 0.0
    total_iou  = 0.0
    total_f1   = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_hd = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, masks, _ in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=lambda_boundary)
            auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=threshold, smooth=1e-6)
            hd_val = compute_hd_metric(outputs, masks, threshold=threshold)
            bs = images.size(0)
            total_loss += loss.item() * bs
            total_auc  += auc_score * bs
            total_iou  += iou_val * bs
            total_f1   += f1_val * bs
            total_precision += prec * bs
            total_recall += rec * bs
            total_hd += hd_val * bs
            total_samples += bs
    avg_loss = total_loss / total_samples
    avg_auc  = total_auc / total_samples
    avg_iou  = total_iou / total_samples
    avg_f1   = total_f1 / total_samples
    avg_precision = total_precision / total_samples
    avg_recall = total_recall / total_samples
    avg_hd = total_hd / total_samples
    return avg_loss, avg_auc, avg_iou, avg_f1, avg_precision, avg_recall, avg_hd

###############################################
# TTA (Test Time Augmentation) 함수 (Noise & ColorJitter 추가)
###############################################
def tta_predict(model, image, device):
    cj = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    def add_noise(x, std=0.05):
        noise = torch.randn_like(x) * std
        return x + noise
    def identity(x): return x
    def hflip(x): return torch.flip(x, dims=[-1])
    def vflip(x): return torch.flip(x, dims=[-2])
    def rot90(x): return torch.rot90(x, k=1, dims=[-2, -1])
    def inv_hflip(x): return torch.flip(x, dims=[-1])
    def inv_vflip(x): return torch.flip(x, dims=[-2])
    def inv_rot90(x): return torch.rot90(x, k=3, dims=[-2, -1])
    
    transforms_list = [
        (identity, identity),
        (hflip, inv_hflip),
        (vflip, inv_vflip),
        (rot90, inv_rot90),
        (lambda x: add_noise(cj(x)), lambda x: x)
    ]
    predictions = []
    model.eval()
    with torch.no_grad():
        for aug, inv in transforms_list:
            augmented = aug(image)
            output = model(augmented.unsqueeze(0).to(device))
            output = torch.sigmoid(output)
            output = inv(output).cpu()
            predictions.append(output)
    avg_prediction = torch.mean(torch.stack(predictions), dim=0)
    return avg_prediction

###############################################
# Optimizer Scheduler (SGDR: CosineAnnealingWarmRestarts)
###############################################
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

###############################################
# 데이터 로더 및 BUSI Dataset 설정
###############################################
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
full_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)
indices = np.arange(len(full_dataset))
train_val_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=4,
    worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
    pin_memory=True, persistent_workers=True
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

def calculate_pos_weight(loader):
    total_pixels = 0
    positive_pixels = 0
    for images, masks, _ in loader:
        positive_pixels += masks.sum().item()
        total_pixels += masks.numel()
    negative_pixels = total_pixels - positive_pixels
    return torch.tensor(negative_pixels / (positive_pixels + 1e-6)).to(device)

###############################################
# 모델, Optimizer, Scheduler 설정
###############################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

pos_weight = calculate_pos_weight(train_loader)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
num_epochs = 500

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)

"""###############################################
# Training Loop (Early Stopping patience=50)
###############################################
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
patience = 15
best_val_f1 = 0.0
patience_counter = 0

for epoch in range(num_epochs):
    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss_seg = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=0.2, lambda_lovasz=0.6)
            loss_total = loss_seg  # RL 미사용
        scaler.scale(loss_total).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        bs = images.size(0)
        epoch_loss += loss_total.item() * bs
        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples
    avg_auc = epoch_auc / total_samples
    avg_iou = epoch_iou / total_samples
    avg_f1 = epoch_f1 / total_samples
    avg_precision = epoch_precision / total_samples
    avg_recall = epoch_recall / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")
    
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
"""

'###############################################\n# Training Loop (Early Stopping patience=50)\n###############################################\nfrom torch.cuda.amp import GradScaler, autocast\nscaler = GradScaler()\npatience = 15\nbest_val_f1 = 0.0\npatience_counter = 0\n\nfor epoch in range(num_epochs):\n    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)\n    if hasattr(model, \'bottleneck\'):\n        if isinstance(model, nn.DataParallel):\n            model.module.bottleneck.gating.temperature = current_temp\n        else:\n            model.bottleneck.gating.temperature = current_temp\n\n    model.train()\n    epoch_loss = 0.0\n    epoch_auc = 0.0\n    epoch_iou = 0.0\n    epoch_f1 = 0.0\n    epoch_precision = 0.0\n    epoch_recall = 0.0\n    total_samples = 0\n\n    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):\n        images = images.to(device)\n        masks = masks.to(device)\n        images, masks = cutmix_data(images, ma

In [4]:
test_loss, test_auc, test_iou, test_f1, test_precision, test_recall, test_hd = evaluate_with_metrics(model, test_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight)
print(f"[Test] Loss: {test_loss:.4f}, AUC: {test_auc:.4f}, IoU: {test_iou:.4f}, Dice/F1: {test_f1:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, HD95: {test_hd:.4f}")


[Test] Loss: 1.0127, AUC: 0.9585, IoU: 0.7442, Dice/F1: 0.6459, Precision: 0.6761, Recall: 0.6465, HD95: 20.2945


### 🔍 **SelectiveGatingBlock과 유사한 네이밍 아이디어**
Selective Gating의 개념을 반영하면서도 고유성을 유지할 수 있는 새로운 이름들을 제안해볼게요.

---
## 🚀 **새로운 네이밍 후보**
| **이름**                         | **설명** |
|----------------------------------|------------------------------------------------------------|
| **SelectiveGateBlock**           | `SelectiveGatingBlock`을 더 간결하게 표현 |
| **AdaptiveSelectionGate**        | 선택적(Gating) & 적응적(Adaptive) 요소를 반영 |
| **FilteredGatingBlock**          | 불필요한 Feature를 "걸러내는(Filter)" 개념을 반영 |
| **DynamicSelectionGate**         | 동적으로 선택하는 Gating을 강조 |
| **SmartSelectiveGate**           | "스마트한 Feature 선택"을 강조 |
| **FeatureSelectionGate**         | Feature Selection을 직접적으로 표현 |
| **GatedSelectionModule (GSM)**   | Selection 개념과 Gating을 모듈로 조합 (논문에 적절) |
| **S-GateBlock**                  | "Selective"의 약자로 줄인 형태 (코드에서 직관적) |
| **ActiveSelectionGate**          | "Active Selection"을 통해 Feature를 강조하는 느낌 |
| **PrioritizedGatingBlock**       | 중요한 Feature를 우선적으로 반영하는 개념 |

---
## 🔥 **최종 추천**
✅ **`SelectiveGateBlock`** → 기존 이름에서 더 직관적으로 축약  
✅ **`DynamicSelectionGate`** → 동적 선택이라는 핵심 개념을 강조  
✅ **`GatedSelectionModule (GSM)`** → 논문에서 사용하기 좋은 고급스러운 네이밍  
✅ **`PrioritizedGatingBlock`** → 중요한 Feature를 선별하는 개념을 포함  

📌 **최종 의견:**  
- **Selective Gating** 개념을 유지하려면 `SelectiveGateBlock`이 가장 직관적  
- 좀 더 역동적인 느낌을 주고 싶다면 `DynamicSelectionGate` 추천  
- 논문에서는 **GSM (GatedSelectionModule)** 같은 약어가 깔끔할 수도 있음!  

> 👉 **"SelectiveGatingBlock" 대신 `SelectiveGateBlock`, `DynamicSelectionGate`, `GatedSelectionModule (GSM)`을 추천합니다!** 🚀



### 📑 **논문에 적합한 새로운 블록 네이밍 10개 추천**  

논문에서 사용할 수 있도록 **고급스럽고 기술적인 느낌을 주는 이름**을 선정했어!  

---

### **1️⃣ Feature Selection & Gating 기반**  
1. **AdaptiveSelectiveGate (ASG)** → 선택적 Gating + 적응적 구조  
2. **DynamicFeatureGating (DFG)** → Feature 선택을 동적으로 수행  
3. **HierarchicalSelectionGate (HSG)** → 계층적으로 Feature를 선택하는 개념  
4. **PriorityGatedFusion (PGF)** → 중요한 Feature를 우선적으로 선택 후 결합  
5. **Context-Aware Selection Gate (CSG)** → 문맥(Context) 기반으로 선택적 활성화  

---

### **2️⃣ Gating & Fusion을 강조한 이름**  
6. **EfficientSelectiveGating (ESG)** → 경량화와 선택적 Gating을 강조  
7. **Multi-Scale Gated Fusion (MSGF)** → 다중 해상도 Feature의 선택적 융합  
8. **Neural Adaptive Selection Gate (NASG)** → 신경망 기반 적응적 선택  
9. **Attentive Gating Module (AGM)** → Attention과 Gating을 결합한 구조  
10. **Selective Representation Learning Block (SRLB)** → 선택적 Feature 학습을 강조  

---

### ✅ **논문용 최종 추천 Top 3**  
1. **AdaptiveSelectiveGate (ASG)** → 논문에 쓰기 직관적이면서 깔끔  
2. **HierarchicalSelectionGate (HSG)** → 계층적 구조 강조 (논문에서 고급스럽게 보임)  
3. **Neural Adaptive Selection Gate (NASG)** → 신경망 기반 적응적 선택 (고급 AI 모델 느낌)  

📌 **특징**  
- 약어를 포함해 논문에서 사용하기 용이  
- 기존 `SelectiveGatingBlock`과 차별화되면서도 같은 개념을 포함  
- 최신 연구 트렌드(Adaptive, Dynamic, Context-Aware 등) 반영  

> 🚀 **"논문용 네이밍"으로 `AdaptiveSelectiveGate (ASG)`, `HierarchicalSelectionGate (HSG)`, `Neural Adaptive Selection Gate (NASG)` 추천!**

# 256x256

In [2]:
import os
import re
import math
import copy
import glob
import random
import warnings
import numpy as np
from collections import defaultdict
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import cv2  # CLAHE 적용을 위해 OpenCV 사용

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from skimage.segmentation import find_boundaries
from skimage.measure import label, regionprops
import scipy.ndimage

# Albumentations 기반 augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# medpy HD metric
from medpy.metric import binary as medpy_binary
from torchvision.transforms import ColorJitter

###############################################
# Seed 및 Warning 설정
###############################################
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)
warnings.filterwarnings('ignore')

###############################################
# BUSI Segmentation Dataset
###############################################
class BUSISegmentationDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform  
        self.samples = []  # 반드시 초기화
        self._prepare_samples()

    def _prepare_samples(self):
        labels = os.listdir(self.data_path)
        for label in labels:
            folder_path = os.path.join(self.data_path, label)
            if not os.path.isdir(folder_path):
                continue
            files = os.listdir(folder_path)
            image_files = sorted([f for f in files if '_mask' not in f and f.endswith('.png')])
            mask_files  = sorted([f for f in files if '_mask' in f and f.endswith('.png')])
            pattern_img = re.compile(rf'{re.escape(label)} \((\d+)\)\.png')
            pattern_mask = re.compile(rf'{re.escape(label)} \((\d+)\)_mask(?:_\d+)?\.png')
            mask_dict = {}
            for mf in mask_files:
                m = pattern_mask.fullmatch(mf)
                if m:
                    idx = m.group(1)
                    mask_dict.setdefault(idx, []).append(mf)
            for im in image_files:
                m = pattern_img.fullmatch(im)
                if m:
                    idx = m.group(1)
                    img_path = os.path.join(folder_path, im)
                    if idx in mask_dict:
                        mask_paths = [os.path.join(folder_path, mf) for mf in mask_dict[idx]]
                        combined_mask = None
                        for mp in mask_paths:
                            mask_img = Image.open(mp).convert('L')
                            mask_arr = np.array(mask_img)
                            mask_binary = (mask_arr > 128).astype(np.uint8)
                            if combined_mask is None:
                                combined_mask = mask_binary
                            else:
                                combined_mask = np.maximum(combined_mask, mask_binary)
                        self.samples.append((img_path, combined_mask, label))
                    else:
                        image_pil = Image.open(img_path)
                        empty_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
                        self.samples.append((img_path, empty_mask, label))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img_path, mask_array, label = self.samples[index]
        image = Image.open(img_path).convert('RGB')
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)
        return image, mask, label

###############################################
# Per-image Z-score normalization (adaptive)
###############################################
def z_score_normalize(tensor):
    mean = tensor.mean()
    std = tensor.std() + 1e-6
    return (tensor - mean) / std

###############################################
# CLAHE Augmentation 함수 (OpenCV)
###############################################
def apply_clahe(image, clipLimit=2.0, tileGridSize=(8,8)):
    img_np = np.array(image)
    if len(img_np.shape) == 3 and img_np.shape[2] == 3:
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        cl = clahe.apply(l)
        limg = cv2.merge((cl, a, b))
        final = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
    else:
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        final = clahe.apply(img_np)
    return Image.fromarray(final)

###############################################
# joint_transform (Albumentations 기반 Augmentation)
###############################################
def joint_transform(image, mask, size=(256,256)):  # 기본 size를 (256,256)으로 변경
    geom_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.ElasticTransform(alpha=10, sigma=5, alpha_affine=5, p=0.3),
        A.Resize(height=size[0], width=size[1])
    ])
    image_np = np.array(image)
    mask_np = np.array(mask)
    augmented = geom_transform(image=image_np, mask=mask_np)
    image = augmented['image']
    mask = augmented['mask']
    
    intensity_transform = A.Compose([
        A.RandomBrightnessContrast(p=0.5),
        #A.CLAHE(clip_limit=3.0, p=0.5),
        A.CLAHE(clip_limit=1.0, tile_grid_size=(8,8), p=0.5),
        A.GaussianBlur(p=0.3)
    ])
    image = intensity_transform(image=image)['image']
    
    image = transforms.ToTensor()(image)
    image = z_score_normalize(image)
    mask = transforms.ToTensor()(mask)
    return image, mask

###############################################
# CutMix 함수 (alpha=1.5, p=0.7)
###############################################
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, masks, alpha=1.5, p=0.7):
    if np.random.rand() > p:
        return images, masks
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    masks[:, :, bbx1:bbx2, bby1:bby2] = masks[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, masks

###############################################
# Advanced 후처리: Morphological Closing (Kernel 9×9)
###############################################
def postprocess_mask(mask, min_size=100):
    labeled_mask = label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

###############################################
# AttentionGate 모듈
###############################################
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

###############################################
# MultiScaleFusion 모듈
###############################################
class MultiScaleFusion(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(MultiScaleFusion, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(ch, out_channels, kernel_size=1) for ch in in_channels_list])
        self.fuse = nn.Sequential(
            nn.Conv2d(len(in_channels_list) * out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        target_size = features[-1].size()[2:]
        processed = []
        for i, f in enumerate(features):
            if f.size()[2:] != target_size:
                f = F.interpolate(f, size=target_size, mode='bilinear', align_corners=False)
            f = self.convs[i](f)
            processed.append(f)
        out = torch.cat(processed, dim=1)
        out = self.fuse(out)
        return out

###############################################
# 기타 모델 관련 모듈 (StochasticDepth, GhostModule, SELayer, DynamicGating)
###############################################
class StochasticDepth(nn.Module):
    def __init__(self, p, mode="row"):
        super(StochasticDepth, self).__init__()
        self.p = p
        self.mode = mode

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        survival_rate = 1 - self.p
        if self.mode == "row":
            batch_size = x.shape[0]
            noise = torch.rand(batch_size, 1, 1, 1, device=x.device, dtype=x.dtype)
            binary_mask = (noise < survival_rate).float()
            return x / survival_rate * binary_mask
        else:
            if torch.rand(1).item() < self.p:
                return torch.zeros_like(x)
            else:
                return x

class GhostModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, ratio=2, dw_kernel_size=3, stride=1, padding=0, use_relu=True):
        super(GhostModule, self).__init__()
        self.out_channels = out_channels
        self.primary_channels = int(torch.ceil(torch.tensor(out_channels / ratio)))
        self.cheap_channels = out_channels - self.primary_channels
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, self.primary_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(self.primary_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
        self.cheap_conv = nn.Sequential(
            nn.Conv2d(self.primary_channels, self.cheap_channels, dw_kernel_size, stride=1,
                      padding=dw_kernel_size // 2, groups=self.primary_channels, bias=False),
            nn.BatchNorm2d(self.cheap_channels),
            nn.ReLU(inplace=True) if use_relu else nn.Identity()
        )
    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_conv(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_channels, :, :].contiguous()

def ghost_conv_block(in_channels, out_channels, use_relu=True):
    return nn.Sequential(
        GhostModule(in_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu),
        GhostModule(out_channels, out_channels, kernel_size=3, ratio=2, dw_kernel_size=3, stride=1, padding=1, use_relu=use_relu)
    )

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DynamicGating(nn.Module):
    def __init__(self, num_branches, hidden_dim=64, dropout_prob=0.1, init_temperature=2.0, iterations=3):
        super(DynamicGating, self).__init__()
        self.temperature = init_temperature
        self.iterations = iterations
        self.dropout_prob = dropout_prob
        self.fc = nn.Sequential(
            nn.Linear(num_branches, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout_prob),
            nn.Linear(hidden_dim, num_branches)
        )
        self.layernorm = nn.LayerNorm(num_branches)
        self.softmax = nn.Softmax(dim=1)
        self.res_scale = nn.Parameter(torch.ones(1))
        self.saved_log_probs = None
        self.entropy = None
        self.rl_loss = 0.0

    def forward(self, features):
        norm_features = self.layernorm(features)
        logits = self.fc(norm_features)
        logits = self.layernorm(logits)
        scaled_logits = logits / self.temperature
        gates = self.softmax(scaled_logits)
        for _ in range(self.iterations - 1):
            updated_features = norm_features + self.res_scale * gates
            logits = self.fc(updated_features)
            logits = self.layernorm(logits)
            scaled_logits = logits / self.temperature
            new_gates = self.softmax(scaled_logits)
            gates = 0.5 * gates + 0.5 * new_gates
        return gates

###############################################
# QuadAgentBlock (Lite 버전, 3 브랜치)
###############################################
class QuadAgentBlock(nn.Module):
    def __init__(self, in_channels, out_channels, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5):
        super().__init__()
        branch_channels = out_channels // 4
        self.branch1 = GhostModule(in_channels, branch_channels, ratio=4)
        self.branch2 = GhostModule(in_channels, branch_channels, kernel_size=3, padding=1, ratio=4)
        self.branch3 = nn.Sequential(
            GhostModule(in_channels, branch_channels, ratio=4),
            SELayer(branch_channels, reduction=32)
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gating = DynamicGating(num_branches=3, hidden_dim=gating_hidden_dim, dropout_prob=gating_dropout, init_temperature=gating_temperature)
        self.fusion_conv = GhostModule(branch_channels * 3, out_channels, ratio=2)
        self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) if in_channels != out_channels else nn.Identity()
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob)

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        gap_b1 = self.gap(b1).view(x.size(0), -1)
        gap_b2 = self.gap(b2).view(x.size(0), -1)
        gap_b3 = self.gap(b3).view(x.size(0), -1)
        features = torch.stack([gap_b1.mean(dim=1), gap_b2.mean(dim=1), gap_b3.mean(dim=1)], dim=1)
        gates = self.gating(features)
        b1 = b1 * gates[:, 0].view(-1, 1, 1, 1)
        b2 = b2 * gates[:, 1].view(-1, 1, 1, 1)
        b3 = b3 * gates[:, 2].view(-1, 1, 1, 1)
        out = torch.cat([b1, b2, b3], dim=1)
        out = self.fusion_conv(out)
        res = self.res_conv(x)
        res = self.stochastic_depth(res)
        return out + res

###############################################
# DARKViT_UNet: ViT 기반 Encoder + QuadAgentBlock + Transformer Bottleneck + UNet Decoder
###############################################
from timm.models import create_model

class DARKViT_UNet(nn.Module):
    def __init__(self, out_channels=1):
        super(DARKViT_UNet, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True)
        self.bottleneck = QuadAgentBlock(in_channels=1536, out_channels=1536, 
                                         gating_dropout=0.3, gating_hidden_dim=32, 
                                         gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, 
                                                             dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, 
                                                             dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        # encoder에서 feature map 리스트 얻기: [f0, f1, f2, f3]
        features = self.encoder(x)
        
        # 만약 각 feature map이 (N, H, W, C) 형식이라면 (N, C, H, W)로 변환
        for i in range(len(features)):
            # f0: 192, f1:384, f2:768, f3:1536 채널이어야 함
            if features[i].shape[1] not in {192, 384, 768, 1536}:
                features[i] = features[i].permute(0, 3, 1, 2).contiguous()
        
        # 반환된 순서는 [f0, f1, f2, f3]이므로, decoder가 기대하는 [f3, f2, f1, f0] 순서로 재정렬
        f0, f1, f2, f3 = features  
        features_reordered = [f3, f2, f1, f0]
        
        # 이제 features_reordered[0]는 f3: (B,1536,7,7)
        b_out = self.bottleneck(features_reordered[0])
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        features_reordered[0] = t_out2
        
        decoded = self.decoder(features_reordered)
        return self.final_conv(decoded)

###############################################
# UNet Decoder for Swin Backbone
###############################################
class UNetDecoder_Swin(nn.Module):
    def __init__(self):
        super(UNetDecoder_Swin, self).__init__()
        # Swin‑Large feature sizes (가정):
        # f0: (B,192,56,56), f1: (B,384,28,28), f2: (B,768,14,14), f3: (B,1536,7,7)
        self.up1 = nn.ConvTranspose2d(1536, 768, kernel_size=2, stride=2)  # 7->14
        self.conv1 = nn.Sequential(
            nn.Conv2d(768 + 768, 768, kernel_size=3, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)   # 14->28
        self.conv2 = nn.Sequential(
            nn.Conv2d(384 + 384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)   # 28->56
        self.conv3 = nn.Sequential(
            nn.Conv2d(192 + 192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.ConvTranspose2d(192, 64, kernel_size=2, stride=2)    # 56->112
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
    def forward(self, features_list):
        # features_list: [f0, f1, f2, f3]
        f3, f2, f1, f0 = features_list  # ✅ 순서 맞추기
        x = f3  # (B,1536,7,7)
        x = self.up1(x)  # -> (B,768,14,14)
        x = torch.cat([x, f2], dim=1)  # f2: (B,768,14,14) → (B,1536,14,14)
        x = self.conv1(x)  # -> (B,768,14,14)
        x = self.up2(x)  # -> (B,384,28,28)
        x = torch.cat([x, f1], dim=1)  # f1: (B,384,28,28) → (B,768,28,28)
        x = self.conv2(x)  # -> (B,384,28,28)
        x = self.up3(x)  # -> (B,192,56,56)
        x = torch.cat([x, f0], dim=1)  # f0: (B,192,56,56) → (B,384,56,56)
        x = self.conv3(x)  # -> (B,192,56,56)
        x = self.up4(x)  # -> (B,64,112,112)
        x = self.conv4(x)  # -> (B,64,112,112)
        # 최종 interpolation을 (256,256)으로 변경
        x = F.interpolate(x, size=(256,256), mode='bilinear', align_corners=False)
        return x

###############################################
# Transformer Bottleneck (연속 두 개 적용)
###############################################
class TransformerBottleneck(nn.Module):
    def __init__(self, d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, x):
        B, C, H, W = x.size()
        x = x.view(B, C, H * W).permute(2, 0, 1)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).view(B, C, H, W)
        return x

###############################################
# Swin‑Transformer Backbone 기반 UNet with Attention + Lite Bottleneck
###############################################
from timm.models import create_model

class PretrainedSwin_UNet_AttentionFusion(nn.Module):
    def __init__(self, out_channels=1):
        super(PretrainedSwin_UNet_AttentionFusion, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True, img_size=256)
        self.bottleneck = QuadAgentBlock(in_channels=1536, out_channels=1536, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        features_list = self.encoder(x)  # [f0, f1, f2, f3]

        # (N, H, W, C) → (N, C, H, W) 변환이 필요한 경우 적용
        for i in range(len(features_list)):
            if features_list[i].shape[1] not in {192, 384, 768, 1536}:  
                features_list[i] = features_list[i].permute(0, 3, 1, 2).contiguous()

        # 디코더에 올바른 순서로 전달되도록 변환
        f0, f1, f2, f3 = features_list  # Swin Transformer의 기본 반환 순서
        features_list = [f3, f2, f1, f0]  # (B, 1536, 7, 7) → (B, 192, 56, 56)

        # Bottleneck과 Transformer Bottleneck 적용
        b_out = self.bottleneck(features_list[0])
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        features_list[0] = t_out2

        decoded = self.decoder(features_list)
        return self.final_conv(decoded)

###############################################
# Loss & Metric Functions (HD 포함)
###############################################
from segmentation_models_pytorch.losses import LovaszLoss
lv_loss = LovaszLoss(mode='binary')

def focal_tversky_loss(pred, target, alpha=0.5, beta=0.5, gamma=4/3, smooth=1e-6):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred = pred.view(pred.size(0), -1)
    target = target.view(target.size(0), -1)
    tp = (pred * target).sum(dim=1)
    fp = ((1 - target) * pred).sum(dim=1)
    fn = (target * (1 - pred)).sum(dim=1)
    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    loss = (1 - tversky_index) ** gamma
    return loss.mean()

def boundary_loss(pred, target):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    boundary_masks = []
    for i in range(pred_np.shape[0]):
        gt_mask = target_np[i, 0]
        boundary = find_boundaries(gt_mask, mode='thick')
        boundary_masks.append(boundary.astype(np.float32))
    boundary_masks = np.stack(boundary_masks, axis=0)[:, None, :, :]
    boundary_masks_torch = torch.from_numpy(boundary_masks).to(pred.device)
    intersect = (pred * boundary_masks_torch).sum()
    denom = pred.sum() + boundary_masks_torch.sum()
    boundary_dice = (2.0 * intersect) / (denom + 1e-6)
    return 1.0 - boundary_dice

def combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    loss_ft = focal_tversky_loss(outputs, masks)
    if pos_weight is not None:
        loss_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(outputs, masks.float())
    else:
        loss_bce = nn.BCEWithLogitsLoss()(outputs, masks.float())
    bl = boundary_loss(outputs, masks)
    lv = lv_loss(outputs, masks)
    return 0.4 * lv + 0.3 * loss_ft + 0.3 * loss_bce + lambda_boundary * bl

def dice_f1_precision_recall(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    precision = intersection / (pred_bin.sum() + smooth)
    recall = intersection / (target_bin.sum() + smooth)
    f1 = 2 * (precision * recall) / (precision + recall + smooth)
    return f1.item(), precision.item(), recall.item()

def iou_metric(pred, target, threshold=0.5, smooth=1e-6):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    union = pred_bin.sum() + target_bin.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach()
    processed_preds = []
    for i in range(pred_probs.size(0)):
        pred_np = pred_probs[i].cpu().numpy()[0]
        pred_bin = (pred_np > threshold).astype(np.uint8)
        processed = postprocess_mask(pred_bin, min_size=100)
        processed_preds.append(processed)
    batch_iou = []
    batch_f1 = []
    batch_precision = []
    batch_recall = []
    for i in range(pred_probs.size(0)):
        pred = processed_preds[i]
        gt = masks[i].cpu().numpy()[0]
        intersection = np.sum(pred * gt)
        union = np.sum(pred) + np.sum(gt) - intersection
        iou = (intersection + smooth) / (union + smooth)
        batch_iou.append(iou)
        precision = intersection / (np.sum(pred) + smooth)
        recall = intersection / (np.sum(gt) + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        batch_f1.append(f1)
        batch_precision.append(precision)
        batch_recall.append(recall)
    probs = pred_probs.cpu().numpy().flatten()
    masks_np = masks.cpu().numpy().flatten()
    try:
        auc_score = roc_auc_score(masks_np, probs)
    except ValueError:
        auc_score = float('nan')
    return auc_score, np.mean(batch_iou), np.mean(batch_f1), np.mean(batch_precision), np.mean(batch_recall)

def hausdorff_distance(pred, target):
    return medpy_binary.hd95(pred, target)

def compute_hd_metric(outputs, masks, threshold=0.5):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy()
    hd_list = []
    for i in range(outputs.size(0)):
        pred_bin = (pred_probs[i, 0] > threshold).astype(np.uint8)
        gt_bin = (masks_np[i, 0] > 0.5).astype(np.uint8)
        try:
            hd = medpy_binary.hd95(pred_bin, gt_bin)
        except Exception:
            hd = np.nan
        hd_list.append(hd)
    return np.nanmean(hd_list)

def evaluate_with_metrics(model, dataloader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None):
    model.eval()
    total_loss = 0.0
    total_auc  = 0.0
    total_iou  = 0.0
    total_f1   = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_hd = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, masks, _ in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=lambda_boundary)
            auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=threshold, smooth=1e-6)
            hd_val = compute_hd_metric(outputs, masks, threshold=threshold)
            bs = images.size(0)
            total_loss += loss.item() * bs
            total_auc  += auc_score * bs
            total_iou  += iou_val * bs
            total_f1   += f1_val * bs
            total_precision += prec * bs
            total_recall += rec * bs
            total_hd += hd_val * bs
            total_samples += bs
    avg_loss = total_loss / total_samples
    avg_auc  = total_auc / total_samples
    avg_iou  = total_iou / total_samples
    avg_f1   = total_f1 / total_samples
    avg_precision = total_precision / total_samples
    avg_recall = total_recall / total_samples
    avg_hd = total_hd / total_samples
    return avg_loss, avg_auc, avg_iou, avg_f1, avg_precision, avg_recall, avg_hd

###############################################
# TTA (Test Time Augmentation) 함수 (Noise & ColorJitter 추가)
###############################################
def tta_predict(model, image, device):
    cj = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    def add_noise(x, std=0.05):
        noise = torch.randn_like(x) * std
        return x + noise
    def identity(x): return x
    def hflip(x): return torch.flip(x, dims=[-1])
    def vflip(x): return torch.flip(x, dims=[-2])
    def rot90(x): return torch.rot90(x, k=1, dims=[-2, -1])
    def inv_hflip(x): return torch.flip(x, dims=[-1])
    def inv_vflip(x): return torch.flip(x, dims=[-2])
    def inv_rot90(x): return torch.rot90(x, k=3, dims=[-2, -1])
    
    transforms_list = [
        (identity, identity),
        (hflip, inv_hflip),
        (vflip, inv_vflip),
        (rot90, inv_rot90),
        (lambda x: add_noise(cj(x)), lambda x: x)
    ]
    predictions = []
    model.eval()
    with torch.no_grad():
        for aug, inv in transforms_list:
            augmented = aug(image)
            output = model(augmented.unsqueeze(0).to(device))
            output = torch.sigmoid(output)
            output = inv(output).cpu()
            predictions.append(output)
    avg_prediction = torch.mean(torch.stack(predictions), dim=0)
    return avg_prediction

###############################################
# Optimizer Scheduler (SGDR: CosineAnnealingWarmRestarts)
###############################################
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

###############################################
# 데이터 로더 및 BUSI Dataset 설정
###############################################
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
full_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)
indices = np.arange(len(full_dataset))
train_val_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=4,
    worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id),
    pin_memory=True, persistent_workers=True
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

def calculate_pos_weight(loader):
    total_pixels = 0
    positive_pixels = 0
    for images, masks, _ in loader:
        positive_pixels += masks.sum().item()
        total_pixels += masks.numel()
    negative_pixels = total_pixels - positive_pixels
    return torch.tensor(negative_pixels / (positive_pixels + 1e-6)).to(device)

###############################################
# 모델, Optimizer, Scheduler 설정
###############################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

pos_weight = calculate_pos_weight(train_loader)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
num_epochs = 500

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)

###############################################
# Training Loop (Early Stopping patience=15)
###############################################
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
patience = 15
best_val_f1 = 0.0
patience_counter = 0
"""
for epoch in range(num_epochs):
    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss_seg = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=0.2, lambda_lovasz=0.6)
            loss_total = loss_seg  # RL 미사용
        scaler.scale(loss_total).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        bs = images.size(0)
        epoch_loss += loss_total.item() * bs
        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples
    avg_auc = epoch_auc / total_samples
    avg_iou = epoch_iou / total_samples
    avg_f1 = epoch_f1 / total_samples
    avg_precision = epoch_precision / total_samples
    avg_recall = epoch_recall / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")
    
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
"""

model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]

'\nfor epoch in range(num_epochs):\n    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)\n    if hasattr(model, \'bottleneck\'):\n        if isinstance(model, nn.DataParallel):\n            model.module.bottleneck.gating.temperature = current_temp\n        else:\n            model.bottleneck.gating.temperature = current_temp\n\n    model.train()\n    epoch_loss = 0.0\n    epoch_auc = 0.0\n    epoch_iou = 0.0\n    epoch_f1 = 0.0\n    epoch_precision = 0.0\n    epoch_recall = 0.0\n    total_samples = 0\n\n    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):\n        images = images.to(device)\n        masks = masks.to(device)\n        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)\n        optimizer.zero_grad()\n        with autocast():\n            outputs = model(images)\n            loss_seg = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=0.2, lambda_lovasz=0.6)\n            loss_total = loss_seg  # 

# Puesudo Labeling

In [3]:
import os
import re
import math
import copy
import glob
import random
import warnings
import numpy as np
from collections import defaultdict
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import cv2  # CLAHE 적용을 위해 OpenCV 사용

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from skimage.segmentation import find_boundaries
from skimage.measure import label, regionprops
import scipy.ndimage

# Albumentations 기반 augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# medpy HD metric
from medpy.metric import binary as medpy_binary
from torchvision.transforms import ColorJitter

###############################################
# Seed 및 Warning 설정
###############################################
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)
warnings.filterwarnings('ignore')

###############################################
# BUSI Segmentation Dataset (라벨 데이터)
###############################################
class BUSISegmentationDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform  
        self.samples = []  # 반드시 초기화
        self._prepare_samples()

    def _prepare_samples(self):
        labels = os.listdir(self.data_path)
        for label in labels:
            folder_path = os.path.join(self.data_path, label)
            if not os.path.isdir(folder_path):
                continue
            files = os.listdir(folder_path)
            image_files = sorted([f for f in files if '_mask' not in f and f.endswith('.png')])
            mask_files  = sorted([f for f in files if '_mask' in f and f.endswith('.png')])
            pattern_img = re.compile(rf'{re.escape(label)} \((\d+)\)\.png')
            pattern_mask = re.compile(rf'{re.escape(label)} \((\d+)\)_mask(?:_\d+)?\.png')
            mask_dict = {}
            for mf in mask_files:
                m = pattern_mask.fullmatch(mf)
                if m:
                    idx = m.group(1)
                    mask_dict.setdefault(idx, []).append(mf)
            for im in image_files:
                m = pattern_img.fullmatch(im)
                if m:
                    idx = m.group(1)
                    img_path = os.path.join(folder_path, im)
                    if idx in mask_dict:
                        mask_paths = [os.path.join(folder_path, mf) for mf in mask_dict[idx]]
                        combined_mask = None
                        for mp in mask_paths:
                            mask_img = Image.open(mp).convert('L')
                            mask_arr = np.array(mask_img)
                            mask_binary = (mask_arr > 128).astype(np.uint8)
                            if combined_mask is None:
                                combined_mask = mask_binary
                            else:
                                combined_mask = np.maximum(combined_mask, mask_binary)
                        self.samples.append((img_path, combined_mask, label))
                    else:
                        image_pil = Image.open(img_path)
                        empty_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
                        self.samples.append((img_path, empty_mask, label))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img_path, mask_array, label = self.samples[index]
        image = Image.open(img_path).convert('RGB')
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)
        return image, mask, label

###############################################
# Per-image Z-score normalization (adaptive)
###############################################
def z_score_normalize(tensor):
    mean = tensor.mean()
    std = tensor.std() + 1e-6
    return (tensor - mean) / std

###############################################
# CLAHE Augmentation 함수 (OpenCV)
###############################################
def apply_clahe(image, clipLimit=2.0, tileGridSize=(8,8)):
    img_np = np.array(image)
    if len(img_np.shape) == 3 and img_np.shape[2] == 3:
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        cl = clahe.apply(l)
        limg = cv2.merge((cl, a, b))
        final = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
    else:
        clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
        final = clahe.apply(img_np)
    return Image.fromarray(final)

###############################################
# joint_transform (Albumentations 기반 Augmentation)
# 여기서 입력 이미지와 mask를 256×256으로 resize
###############################################
def joint_transform(image, mask, size=(256,256)):
    geom_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.ElasticTransform(alpha=10, sigma=5, alpha_affine=5, p=0.3),
        A.Resize(height=size[0], width=size[1])
    ])
    image_np = np.array(image)
    mask_np = np.array(mask)
    augmented = geom_transform(image=image_np, mask=mask_np)
    image = augmented['image']
    mask = augmented['mask']
    
    intensity_transform = A.Compose([
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=1.0, tile_grid_size=(8,8), p=0.5),
        A.GaussianBlur(p=0.3)
    ])
    image = intensity_transform(image=image)['image']
    
    image = transforms.ToTensor()(image)
    image = z_score_normalize(image)
    mask = transforms.ToTensor()(mask)
    return image, mask

###############################################
# CutMix 함수 (alpha=1.5, p=0.7)
###############################################
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, masks, alpha=1.5, p=0.7):
    if np.random.rand() > p:
        return images, masks
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    masks[:, :, bbx1:bbx2, bby1:bby2] = masks[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, masks

###############################################
# Advanced 후처리: Morphological Closing (Kernel 9×9)
###############################################
def postprocess_mask(mask, min_size=100):
    labeled_mask = label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

# ----------------------------------------------------------------
# ... (AttentionGate, MultiScaleFusion, StochasticDepth, GhostModule, SELayer, DynamicGating,
# QuadAgentBlock, DARKViT_UNet, UNetDecoder_Swin, TransformerBottleneck, PretrainedSwin_UNet_AttentionFusion 등
# 기존 코드 정의는 그대로 사용) 
# ----------------------------------------------------------------

# 여기서는 이미 기존 코드의 모델, loss, metric 함수들을 정의했다고 가정합니다.
# (위에 제공된 코드의 나머지 부분을 그대로 사용)

###############################################
# Unlabeled Dataset 클래스 (라벨 없는 데이터)
###############################################
class UnlabeledDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # 폴더 내 모든 PNG 파일 경로 읽기
        self.image_paths = sorted([os.path.join(data_path, fname) for fname in os.listdir(data_path) if fname.endswith('.png')])
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert("RGB")
        # mask는 없으므로 더미 값을 사용 (예: None 혹은 0으로 채워진 이미지)
        dummy_mask = Image.new("L", image.size, 0)
        if self.transform:
            image, _ = self.transform(image, dummy_mask)
        else:
            image = transforms.ToTensor()(image)
        return image, img_path

###############################################
# Pseudo Labeling 함수
###############################################
def generate_pseudo_labels(model, unlabeled_loader, device, threshold=0.9):
    model.eval()
    pseudo_data = []  # (image, pseudo mask) 튜플 리스트
    with torch.no_grad():
        for images, paths in unlabeled_loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            # threshold 기준을 넘는 부분을 pseudo label로 사용 (여기서는 픽셀 단위)
            pseudo_masks = (preds > threshold).float()
            for i in range(images.size(0)):
                # 각 이미지와 해당 pseudo mask를 CPU로 가져와 저장
                pseudo_data.append((images[i].cpu(), pseudo_masks[i].cpu()))
    return pseudo_data

###############################################
# Combined Dataset: labeled + pseudo labeled 데이터
###############################################
class CombinedDataset(Dataset):
    def __init__(self, labeled_dataset, pseudo_data):
        self.labeled_dataset = labeled_dataset
        self.pseudo_data = pseudo_data

    def __len__(self):
        return len(self.labeled_dataset) + len(self.pseudo_data)
    
    def __getitem__(self, index):
        if index < len(self.labeled_dataset):
            return self.labeled_dataset[index]
        else:
            pseudo_index = index - len(self.labeled_dataset)
            image, mask = self.pseudo_data[pseudo_index]
            # pseudo 데이터의 label은 'pseudo'로 구분
            return image, mask, "pseudo"



In [4]:
import os

# /kaggle/working/ 경로에 "unlabeled-busi-images" 폴더 생성
folder_path = '/kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/'
os.makedirs(folder_path, exist_ok=True)
print(f"Folder created: {folder_path}")


Folder created: /kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/


In [5]:
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import copy

# 1. Device, 모델, Optimizer, Scheduler 등 초기 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 라벨 데이터셋과 unlabeled 데이터셋 생성
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
labeled_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)

unlabeled_data_path = '/kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/'  # 실제 unlabeled 데이터 경로로 수정
unlabeled_dataset = UnlabeledDataset(unlabeled_data_path, transform=joint_transform)

labeled_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=True, num_workers=4)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=16, shuffle=False, num_workers=4)

# 모델 생성 (Swin 백본에 img_size=256 적용되어 있다고 가정)
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)

scaler = GradScaler()
patience = 15
best_val_f1 = 0.0
patience_counter = 0

# 2. 초기 학습 (라벨 데이터만 사용, 예: 10 에폭)
num_initial_epochs = 10
print("===> Initial training on labeled data...")

for epoch in range(num_initial_epochs):
    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_initial_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    total_samples = 0

    for images, masks, _ in tqdm(labeled_loader, desc=f"Initial Epoch {epoch+1}/{num_initial_epochs}"):
        images = images.to(device)
        masks = masks.to(device)
        # CutMix 적용 (선택 사항)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)

        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        bs = images.size(0)
        epoch_loss += loss.item() * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples
    print(f"Initial Epoch {epoch+1}/{num_initial_epochs} | Loss: {avg_loss:.4f}")



===> Initial training on labeled data...


Initial Epoch 1/10: 100%|██████████| 49/49 [01:24<00:00,  1.73s/it]


Initial Epoch 1/10 | Loss: 1.0146


Initial Epoch 2/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 2/10 | Loss: 0.9148


Initial Epoch 3/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 3/10 | Loss: 0.8546


Initial Epoch 4/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 4/10 | Loss: 0.8292


Initial Epoch 5/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 5/10 | Loss: 0.7959


Initial Epoch 6/10: 100%|██████████| 49/49 [01:01<00:00,  1.26s/it]


Initial Epoch 6/10 | Loss: 0.7556


Initial Epoch 7/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 7/10 | Loss: 0.7172


Initial Epoch 8/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 8/10 | Loss: 0.7016


Initial Epoch 9/10: 100%|██████████| 49/49 [01:01<00:00,  1.25s/it]


Initial Epoch 9/10 | Loss: 0.6822


Initial Epoch 10/10: 100%|██████████| 49/49 [01:01<00:00,  1.26s/it]

Initial Epoch 10/10 | Loss: 0.6603





In [6]:
from skimage.measure import label as sk_label, regionprops

def postprocess_mask(mask, min_size=100):
    # skimage.measure의 label 함수를 sk_label로 호출
    labeled_mask = sk_label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

# 3. Pseudo Labeling: unlabeled 데이터에 대해 모델 예측 수행
print("===> Generating pseudo labels on unlabeled data...")
pseudo_data = generate_pseudo_labels(model, unlabeled_loader, device, threshold=0.9)
print(f"Pseudo labels generated for {len(pseudo_data)} images.")

# 4. 라벨 데이터와 pseudo labeled 데이터를 합쳐 CombinedDataset 생성
combined_dataset = CombinedDataset(labeled_dataset, pseudo_data)
combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True, num_workers=4)

# 5. Fine-tuning on Combined Dataset (labeled + pseudo labeled)
num_finetune_epochs = 20
print("===> Fine-tuning on combined dataset (labeled + pseudo labeled)...")


===> Generating pseudo labels on unlabeled data...
Pseudo labels generated for 0 images.
===> Fine-tuning on combined dataset (labeled + pseudo labeled)...


from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import copy

# 기본 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 데이터셋 생성 (경로는 실제 데이터 경로로 수정)
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
labeled_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)
unlabeled_data_path = '/kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/'
unlabeled_dataset = UnlabeledDataset(unlabeled_data_path, transform=joint_transform)

# DataLoader (디버깅을 위해 num_workers=0 사용 - 필요 시 값을 늘리세요)
labeled_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=False, num_workers=0)

# 모델 생성 (백본 생성 시 img_size=256 적용되어 있다고 가정)
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

# 옵티마이저와 scheduler 설정: 학습률을 낮게 설정하여 안정화 (예: 1e-5)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
scaler = GradScaler()

# Early stopping 설정
patience = 50
best_val_f1 = 0.0
patience_counter = 0

# 온도 스케줄: 온도가 최소 2.0 이하로 내려가지 않도록 함 (최대 5.0에서 선형 감소)
min_temp = 2.0
max_temp = 5.0

num_epochs = 500
combined_loader = None  # 이후 pseudo labeled 데이터를 위한 DataLoader

for epoch in range(num_epochs):
    # 온도 조절: 선형적으로 감소하되 최소값은 min_temp로 유지
    current_temp = max_temp - ((max_temp - min_temp) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    # 0~29 epoch: labeled 데이터만 사용, 이후부터 combined_loader 사용
    if epoch < 30:
        train_loader = labeled_loader
    else:
        # epoch 30에서 한 번만 pseudo labeling 생성
        if epoch == 30:
            print("===> Generating pseudo labels on unlabeled data...")
            pseudo_loader = DataLoader(unlabeled_dataset, batch_size=16, shuffle=False, num_workers=0)
            pseudo_data = generate_pseudo_labels(model, pseudo_loader, device, threshold=0.9)
            print(f"Pseudo labels generated for {len(pseudo_data)} images.")
            combined_dataset = CombinedDataset(labeled_dataset, pseudo_data)
            train_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True, num_workers=0)
            combined_loader = train_loader
        else:
            train_loader = combined_loader

    # Training phase
    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        # CutMix 적용 (옵션)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6)
        # 만약 손실값이 nan이면 해당 배치를 건너뛰기
        if torch.isnan(loss):
            print("Warning: loss is NaN, skipping batch")
            continue
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        # 더 엄격한 gradient clipping: max_norm=0.5
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        scaler.step(optimizer)
        scaler.update()

        bs = images.size(0)
        epoch_loss += loss.item() * bs

        # 배치별 metric 계산 (compute_batch_metrics_new 함수는 내부에서 GPU 연산 대신 NumPy로 변환하는 부분이 있으므로, 여기서는 그대로 사용)
        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-5)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples if total_samples > 0 else float('nan')
    avg_auc = epoch_auc / total_samples if total_samples > 0 else float('nan')
    avg_iou = epoch_iou / total_samples if total_samples > 0 else float('nan')
    avg_f1 = epoch_f1 / total_samples if total_samples > 0 else float('nan')
    avg_precision = epoch_precision / total_samples if total_samples > 0 else float('nan')
    avg_recall = epoch_recall / total_samples if total_samples > 0 else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, "
          f"Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    # Validation 평가 (라벨 데이터 사용)
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(
        model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, "
          f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")
    
    # Early stopping: validation Dice/F1가 개선되지 않으면 patience_counter 증가
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model_with_pseudo.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

print("Training complete. Best model saved as 'best_model_with_pseudo.pth'.")


In [8]:
from skimage.measure import label as sk_label, regionprops

def postprocess_mask(mask, min_size=100):
    # skimage.measure의 label 함수를 sk_label로 호출
    labeled_mask = sk_label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

def compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach()
    processed_preds = []
    for i in range(pred_probs.size(0)):
        pred_np = pred_probs[i].cpu().numpy()[0]
        pred_bin = (pred_np > threshold).astype(np.uint8)
        processed = postprocess_mask(pred_bin, min_size=100)
        processed_preds.append(processed)
    batch_iou = []
    batch_f1 = []
    batch_precision = []
    batch_recall = []
    for i in range(pred_probs.size(0)):
        pred = processed_preds[i]
        gt = masks[i].cpu().numpy()[0]
        intersection = np.sum(pred * gt)
        union = np.sum(pred) + np.sum(gt) - intersection
        iou = (intersection + smooth) / (union + smooth)
        batch_iou.append(iou)
        precision = intersection / (np.sum(pred) + smooth)
        recall = intersection / (np.sum(gt) + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        batch_f1.append(f1)
        batch_precision.append(precision)
        batch_recall.append(recall)
    # AUC 계산: ground truth에 0과 1이 모두 존재하지 않으면 0으로 반환
    probs = pred_probs.cpu().numpy().flatten()
    masks_np = masks.cpu().numpy().flatten()
    if np.all(masks_np == 0) or np.all(masks_np == 1):
        auc_score = 0.0
    else:
        try:
            auc_score = roc_auc_score(masks_np, probs)
        except Exception:
            auc_score = 0.0
    return auc_score, np.mean(batch_iou), np.mean(batch_f1), np.mean(batch_precision), np.mean(batch_recall)

def compute_hd_metric(outputs, masks, threshold=0.5):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy()
    hd_list = []
    for i in range(outputs.size(0)):
        pred_bin = (pred_probs[i, 0] > threshold).astype(np.uint8)
        gt_bin = (masks_np[i, 0] > 0.5).astype(np.uint8)
        # 두 마스크 모두 비어있으면 hd=0, 한쪽만 비어있으면 hd=100 (임의로 높은 값)
        if np.sum(gt_bin)==0 and np.sum(pred_bin)==0:
            hd = 0.0
        elif np.sum(gt_bin)==0 or np.sum(pred_bin)==0:
            hd = 100.0
        else:
            try:
                hd = medpy_binary.hd95(pred_bin, gt_bin)
            except Exception:
                hd = np.nan
        hd_list.append(hd)
    return np.nanmean(hd_list) if len(hd_list) > 0 else 0.0


In [None]:
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import copy

# ------------------------------
# 기본 설정: device, 데이터셋, DataLoader, 모델, optimizer, scheduler
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 라벨 데이터셋 (BUSI 데이터)
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
labeled_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)

# unlabeled 데이터셋 (경로는 실제 unlabeled 데이터 위치로 수정)
unlabeled_data_path = '/kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/'
unlabeled_dataset = UnlabeledDataset(unlabeled_data_path, transform=joint_transform)

# DataLoader 설정
labeled_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=False, num_workers=0)  # 검증은 라벨 데이터 사용

# 모델 생성 (백본 생성 시 img_size=256 적용되어 있다고 가정)
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)
scaler = GradScaler()

# Early stopping 설정: patience 15
patience = 15
best_val_f1 = 0.0
patience_counter = 0

# ------------------------------
# Training Loop: 전체 500 epoch
# 0~29 epoch: labeled 데이터만 사용
# 30 epoch부터: pseudo labeling을 생성하여 labeled + pseudo 데이터로 학습
# ------------------------------
num_epochs = 500
combined_loader = None  # 이후에 사용할 combined 데이터 로더

for epoch in range(num_epochs):
    # 온도 조절 (예: 5.0 -> 1.0로 선형 감소)
    current_temp = 5.0 - ((5.0 - 1.0) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    # 0~29 epoch: labeled_loader 사용, 이후부터 combined_loader 사용
    if epoch < 30:
        train_loader = labeled_loader
    else:
        # epoch 30에서 한 번만 pseudo labeling 생성
        if epoch == 30:
            print("===> Generating pseudo labels on unlabeled data...")
            pseudo_data = generate_pseudo_labels(model, DataLoader(unlabeled_dataset, batch_size=16, shuffle=False, num_workers=4), device, threshold=0.9)
            print(f"Pseudo labels generated for {len(pseudo_data)} images.")
            # 라벨 데이터와 pseudo 데이터를 결합하여 CombinedDataset 생성
            combined_dataset = CombinedDataset(labeled_dataset, pseudo_data)
            train_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True, num_workers=4)
            combined_loader = train_loader
        else:
            train_loader = combined_loader

    # Training phase
    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        bs = images.size(0)
        epoch_loss += loss.item() * bs

        # 배치별 metric 계산
        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-6)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples
    avg_auc = epoch_auc / total_samples
    avg_iou = epoch_iou / total_samples
    avg_f1 = epoch_f1 / total_samples
    avg_precision = epoch_precision / total_samples
    avg_recall = epoch_recall / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    # Validation 평가 (라벨 데이터 사용)
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(
        model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")
    
    # Early stopping: validation Dice/F1가 개선되지 않으면 patience_counter 증가
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model_with_pseudo.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

print("Training complete. Best model saved as 'best_model_with_pseudo.pth'.")


Epoch 1/500 Training: 100%|██████████| 49/49 [01:47<00:00,  2.20s/it]


Epoch 1/500 | Loss: 1.0411, AUC: 0.8724, IoU: 0.4025, Dice/F1: 0.4417, Precision: 0.4259, Recall: 0.6278 (Temp: 5.00)
[Val] Loss: 1.0729, AUC: 0.8207, IoU: 0.4467, Dice/F1: 0.5280, Precision: 0.4490, Recall: 0.7296, HD95: 54.2330
🍀 New best validation Dice achieved! Dice: 0.5280 🍀


Epoch 2/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.08s/it]


Epoch 2/500 | Loss: 0.9619, AUC: 0.9496, IoU: 0.5168, Dice/F1: 0.5436, Precision: 0.5833, Recall: 0.5872 (Temp: 4.99)
[Val] Loss: 1.0467, AUC: 0.8288, IoU: 0.5386, Dice/F1: 0.5379, Precision: 0.4766, Recall: 0.6975, HD95: 35.4523
🍀 New best validation Dice achieved! Dice: 0.5379 🍀


Epoch 3/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.07s/it]


Epoch 3/500 | Loss: 0.9125, AUC: 0.9472, IoU: 0.5606, Dice/F1: 0.5732, Precision: 0.5921, Recall: 0.6266 (Temp: 4.98)
[Val] Loss: 0.9464, AUC: 0.8183, IoU: 0.6432, Dice/F1: 0.5992, Precision: 0.7003, Recall: 0.5519, HD95: 28.1701
🍀 New best validation Dice achieved! Dice: 0.5992 🍀


Epoch 4/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.08s/it]


Epoch 4/500 | Loss: 0.8651, AUC: 0.9588, IoU: 0.5884, Dice/F1: 0.5935, Precision: 0.6205, Recall: 0.6257 (Temp: 4.98)
[Val] Loss: 0.9136, AUC: 0.8297, IoU: 0.6902, Dice/F1: 0.6459, Precision: 0.6442, Recall: 0.6849, HD95: 26.8383
🍀 New best validation Dice achieved! Dice: 0.6459 🍀


Epoch 5/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.06s/it]


Epoch 5/500 | Loss: 0.8130, AUC: 0.9564, IoU: 0.6242, Dice/F1: 0.6133, Precision: 0.6348, Recall: 0.6455 (Temp: 4.97)
[Val] Loss: 0.7970, AUC: 0.8294, IoU: 0.7449, Dice/F1: 0.6575, Precision: 0.6984, Recall: 0.6487, HD95: 17.9238
🍀 New best validation Dice achieved! Dice: 0.6575 🍀


Epoch 6/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 6/500 | Loss: 0.7998, AUC: 0.9636, IoU: 0.6401, Dice/F1: 0.6295, Precision: 0.6456, Recall: 0.6645 (Temp: 4.96)
[Val] Loss: 0.7831, AUC: 0.8403, IoU: 0.7529, Dice/F1: 0.6628, Precision: 0.6818, Recall: 0.6698, HD95: 17.3736
🍀 New best validation Dice achieved! Dice: 0.6628 🍀


Epoch 7/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 7/500 | Loss: 0.7372, AUC: 0.9723, IoU: 0.6921, Dice/F1: 0.6530, Precision: 0.6800, Recall: 0.6640 (Temp: 4.95)
[Val] Loss: 0.8010, AUC: 0.8445, IoU: 0.7595, Dice/F1: 0.6768, Precision: 0.6627, Recall: 0.7132, HD95: 16.8521
🍀 New best validation Dice achieved! Dice: 0.6768 🍀


Epoch 8/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.06s/it]


Epoch 8/500 | Loss: 0.7493, AUC: 0.9778, IoU: 0.6717, Dice/F1: 0.6528, Precision: 0.6763, Recall: 0.6743 (Temp: 4.94)
[Val] Loss: 0.7432, AUC: 0.8478, IoU: 0.7865, Dice/F1: 0.6952, Precision: 0.6884, Recall: 0.7191, HD95: 14.5468
🍀 New best validation Dice achieved! Dice: 0.6952 🍀


Epoch 9/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 9/500 | Loss: 0.7089, AUC: 0.9833, IoU: 0.6902, Dice/F1: 0.6624, Precision: 0.6856, Recall: 0.6757 (Temp: 4.94)
[Val] Loss: 0.6951, AUC: 0.8484, IoU: 0.8020, Dice/F1: 0.7038, Precision: 0.7098, Recall: 0.7136, HD95: 13.0606
🍀 New best validation Dice achieved! Dice: 0.7038 🍀


Epoch 10/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 10/500 | Loss: 0.6972, AUC: 0.9854, IoU: 0.7224, Dice/F1: 0.6941, Precision: 0.7109, Recall: 0.7050 (Temp: 4.93)
[Val] Loss: 0.6928, AUC: 0.8492, IoU: 0.8072, Dice/F1: 0.7088, Precision: 0.7163, Recall: 0.7158, HD95: 12.6782
🍀 New best validation Dice achieved! Dice: 0.7088 🍀


Epoch 11/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 11/500 | Loss: 0.6853, AUC: 0.9874, IoU: 0.7355, Dice/F1: 0.6881, Precision: 0.7024, Recall: 0.7029 (Temp: 4.92)
[Val] Loss: 0.6853, AUC: 0.8505, IoU: 0.8136, Dice/F1: 0.7094, Precision: 0.7145, Recall: 0.7191, HD95: 12.0262
🍀 New best validation Dice achieved! Dice: 0.7094 🍀


Epoch 12/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 12/500 | Loss: 0.7033, AUC: 0.9791, IoU: 0.7045, Dice/F1: 0.6580, Precision: 0.6829, Recall: 0.6688 (Temp: 4.91)
[Val] Loss: 0.8074, AUC: 0.8476, IoU: 0.7770, Dice/F1: 0.7064, Precision: 0.6906, Recall: 0.7375, HD95: 16.9035


Epoch 13/500 Training:  16%|█▋        | 8/49 [00:16<01:24,  2.07s/it]

[Val] Loss: 0.7330, AUC: 0.8485, IoU: 0.7989, Dice/F1: 0.7055, Precision: 0.7069, Recall: 0.7186, HD95: 13.8151


Epoch 14/500 Training: 100%|██████████| 49/49 [01:39<00:00,  2.03s/it]


Epoch 14/500 | Loss: 0.7098, AUC: 0.9848, IoU: 0.6932, Dice/F1: 0.6739, Precision: 0.6992, Recall: 0.6955 (Temp: 4.90)
[Val] Loss: 0.7033, AUC: 0.8491, IoU: 0.7952, Dice/F1: 0.7062, Precision: 0.7279, Recall: 0.7028, HD95: 14.5662


Epoch 15/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.07s/it]


Epoch 15/500 | Loss: 0.6825, AUC: 0.9858, IoU: 0.7220, Dice/F1: 0.6760, Precision: 0.6935, Recall: 0.6928 (Temp: 4.89)
[Val] Loss: 0.6796, AUC: 0.8499, IoU: 0.8073, Dice/F1: 0.7103, Precision: 0.7140, Recall: 0.7253, HD95: 14.2156
🍀 New best validation Dice achieved! Dice: 0.7103 🍀


Epoch 16/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 16/500 | Loss: 0.6644, AUC: 0.9865, IoU: 0.7305, Dice/F1: 0.6813, Precision: 0.7016, Recall: 0.6958 (Temp: 4.88)
[Val] Loss: 0.6798, AUC: 0.8511, IoU: 0.8094, Dice/F1: 0.7163, Precision: 0.7374, Recall: 0.7112, HD95: 13.0505
🍀 New best validation Dice achieved! Dice: 0.7163 🍀


Epoch 17/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 17/500 | Loss: 0.6410, AUC: 0.9900, IoU: 0.7484, Dice/F1: 0.6918, Precision: 0.7100, Recall: 0.7040 (Temp: 4.87)
[Val] Loss: 0.6598, AUC: 0.8524, IoU: 0.8230, Dice/F1: 0.7236, Precision: 0.7219, Recall: 0.7400, HD95: 11.6106
🍀 New best validation Dice achieved! Dice: 0.7236 🍀


Epoch 18/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.04s/it]


Epoch 18/500 | Loss: 0.6433, AUC: 0.9891, IoU: 0.7370, Dice/F1: 0.6924, Precision: 0.7157, Recall: 0.7032 (Temp: 4.86)
[Val] Loss: 0.5870, AUC: 0.8512, IoU: 0.8356, Dice/F1: 0.7288, Precision: 0.7340, Recall: 0.7355, HD95: 10.1023
🍀 New best validation Dice achieved! Dice: 0.7288 🍀


Epoch 19/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 19/500 | Loss: 0.6225, AUC: 0.9905, IoU: 0.7419, Dice/F1: 0.7063, Precision: 0.7294, Recall: 0.7141 (Temp: 4.86)
[Val] Loss: 0.5731, AUC: 0.8530, IoU: 0.8371, Dice/F1: 0.7276, Precision: 0.7208, Recall: 0.7456, HD95: 10.0082


Epoch 20/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 20/500 | Loss: 0.6024, AUC: 0.9920, IoU: 0.7615, Dice/F1: 0.7171, Precision: 0.7330, Recall: 0.7264 (Temp: 4.85)
[Val] Loss: 0.5905, AUC: 0.8533, IoU: 0.8385, Dice/F1: 0.7277, Precision: 0.7345, Recall: 0.7331, HD95: 10.3703


Epoch 21/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 21/500 | Loss: 0.6161, AUC: 0.9908, IoU: 0.7480, Dice/F1: 0.7075, Precision: 0.7246, Recall: 0.7233 (Temp: 4.84)
[Val] Loss: 0.5568, AUC: 0.8536, IoU: 0.8460, Dice/F1: 0.7344, Precision: 0.7464, Recall: 0.7345, HD95: 9.4208
🍀 New best validation Dice achieved! Dice: 0.7344 🍀


Epoch 22/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.04s/it]


Epoch 22/500 | Loss: 0.5878, AUC: 0.9931, IoU: 0.7615, Dice/F1: 0.7120, Precision: 0.7297, Recall: 0.7199 (Temp: 4.83)
[Val] Loss: 0.5992, AUC: 0.8538, IoU: 0.8411, Dice/F1: 0.7317, Precision: 0.7225, Recall: 0.7522, HD95: 9.8819


Epoch 23/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.08s/it]


Epoch 23/500 | Loss: 0.5679, AUC: 0.9948, IoU: 0.7809, Dice/F1: 0.7230, Precision: 0.7373, Recall: 0.7299 (Temp: 4.82)
[Val] Loss: 0.5680, AUC: 0.8539, IoU: 0.8446, Dice/F1: 0.7345, Precision: 0.7248, Recall: 0.7556, HD95: 9.7436
🍀 New best validation Dice achieved! Dice: 0.7345 🍀


Epoch 24/500 Training: 100%|██████████| 49/49 [01:41<00:00,  2.07s/it]


Epoch 24/500 | Loss: 0.5712, AUC: 0.9938, IoU: 0.7827, Dice/F1: 0.7211, Precision: 0.7287, Recall: 0.7348 (Temp: 4.82)
[Val] Loss: 0.5736, AUC: 0.8541, IoU: 0.8534, Dice/F1: 0.7401, Precision: 0.7354, Recall: 0.7545, HD95: 9.4841
🍀 New best validation Dice achieved! Dice: 0.7401 🍀


Epoch 25/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.06s/it]


Epoch 25/500 | Loss: 0.5525, AUC: 0.9940, IoU: 0.7896, Dice/F1: 0.7353, Precision: 0.7426, Recall: 0.7483 (Temp: 4.81)
[Val] Loss: 0.5249, AUC: 0.8543, IoU: 0.8551, Dice/F1: 0.7399, Precision: 0.7596, Recall: 0.7321, HD95: 8.7593


Epoch 26/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 26/500 | Loss: 0.5583, AUC: 0.9938, IoU: 0.7866, Dice/F1: 0.7353, Precision: 0.7504, Recall: 0.7424 (Temp: 4.80)
[Val] Loss: 0.5313, AUC: 0.8545, IoU: 0.8615, Dice/F1: 0.7450, Precision: 0.7455, Recall: 0.7535, HD95: 8.5376
🍀 New best validation Dice achieved! Dice: 0.7450 🍀


Epoch 27/500 Training: 100%|██████████| 49/49 [01:40<00:00,  2.05s/it]


Epoch 27/500 | Loss: 0.5379, AUC: 0.9958, IoU: 0.7882, Dice/F1: 0.7248, Precision: 0.7348, Recall: 0.7339 (Temp: 4.79)
[Val] Loss: 0.5061, AUC: 0.8545, IoU: 0.8627, Dice/F1: 0.7444, Precision: 0.7520, Recall: 0.7450, HD95: 8.0646


Epoch 28/500 Training:  51%|█████     | 25/49 [00:51<00:48,  2.04s/it]

In [10]:
# Best model 가중치 로드 (파일명은 학습 시 저장한 파일명으로 변경)
model.load_state_dict(torch.load("best_model_with_pseudo.pth", map_location=device))

# 모델을 evaluation 모드로 전환
model.eval()
with torch.no_grad():
    test_loss, test_auc, test_iou, test_f1, test_precision, test_recall, test_hd = evaluate_with_metrics(
        model, test_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=pos_weight
    )

print(f"[Test] Loss: {test_loss:.4f}, AUC: {test_auc:.4f}, IoU: {test_iou:.4f}, Dice/F1: {test_f1:.4f}, "
      f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, HD95: {test_hd:.4f}")


[Test] Loss: 0.4169, AUC: 0.9989, IoU: 0.8918, Dice/F1: 0.7728, Precision: 0.7755, Recall: 0.7755, HD95: 5.6043


In [None]:
a = 1
a

# full code

In [1]:
! pip install medpy
! pip install segmentation_models_pytorch
! pip install -U albumentations

Collecting medpy
  Downloading medpy-0.5.2.tar.gz (156 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.3/156.3 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: medpy
  Building wheel for medpy (setup.py) ... [?25l[?25hdone
  Created wheel for medpy: filename=MedPy-0.5.2-cp310-cp310-linux_x86_64.whl size=762835 sha256=045e8a1051aaa6bff85809d7c68395da517aa5dae22f0fa368964d5b67899b5f
  Stored in directory: /root/.cache/pip/wheels/a1/b8/63/bdf557940ec60d1b8822e73ff9fbe7727ac19f009d46b5d175
Successfully built medpy
Installing collected packages: medpy
Successfully installed medpy-0.5.2
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.4.0-py3-none-any.whl.metadata (32 kB)
Collecting efficientnet-pytorch>=0.6.1 (from segmentation_models_pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?

In [4]:
import os
import re
import random
import warnings
import numpy as np
from collections import defaultdict
from PIL import Image
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from skimage.segmentation import find_boundaries
from skimage.measure import label as sk_label, regionprops
import albumentations as A
from albumentations.pytorch import ToTensorV2
from medpy.metric import binary as medpy_binary
from torchvision.transforms import ColorJitter

import os

# /kaggle/working/ 경로에 "unlabeled-busi-images" 폴더 생성
folder_path = '/kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/'
os.makedirs(folder_path, exist_ok=True)
print(f"Folder created: {folder_path}")


# ------------------------------
# Seed 및 Warning 설정
# ------------------------------
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)
warnings.filterwarnings('ignore')

# ------------------------------
# BUSI Segmentation Dataset
# ------------------------------
class BUSISegmentationDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform  
        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        labels = os.listdir(self.data_path)
        for label in labels:
            folder_path = os.path.join(self.data_path, label)
            if not os.path.isdir(folder_path):
                continue
            files = os.listdir(folder_path)
            image_files = sorted([f for f in files if '_mask' not in f and f.endswith('.png')])
            mask_files  = sorted([f for f in files if '_mask' in f and f.endswith('.png')])
            pattern_img = re.compile(rf'{re.escape(label)} \((\d+)\)\.png')
            pattern_mask = re.compile(rf'{re.escape(label)} \((\d+)\)_mask(?:_\d+)?\.png')
            mask_dict = {}
            for mf in mask_files:
                m = pattern_mask.fullmatch(mf)
                if m:
                    idx = m.group(1)
                    mask_dict.setdefault(idx, []).append(mf)
            for im in image_files:
                m = pattern_img.fullmatch(im)
                if m:
                    idx = m.group(1)
                    img_path = os.path.join(folder_path, im)
                    if idx in mask_dict:
                        mask_paths = [os.path.join(folder_path, mf) for mf in mask_dict[idx]]
                        combined_mask = None
                        for mp in mask_paths:
                            mask_img = Image.open(mp).convert('L')
                            mask_arr = np.array(mask_img)
                            mask_binary = (mask_arr > 128).astype(np.uint8)
                            if combined_mask is None:
                                combined_mask = mask_binary
                            else:
                                combined_mask = np.maximum(combined_mask, mask_binary)
                        self.samples.append((img_path, combined_mask, label))
                    else:
                        image_pil = Image.open(img_path)
                        empty_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
                        self.samples.append((img_path, empty_mask, label))
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img_path, mask_array, label = self.samples[index]
        image = Image.open(img_path).convert('RGB')
        mask = Image.fromarray((mask_array * 255).astype(np.uint8))
        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            image = transforms.ToTensor()(image)
            mask = transforms.ToTensor()(mask)
        return image, mask, label

# ------------------------------
# Utility Functions
# ------------------------------
def z_score_normalize(tensor):
    mean = tensor.mean()
    std = tensor.std() + 1e-6
    return (tensor - mean) / std

def joint_transform(image, mask, size=(256,256)):
    geom_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.ElasticTransform(alpha=10, sigma=5, alpha_affine=5, p=0.3),
        A.Resize(height=size[0], width=size[1])
    ])
    image_np = np.array(image)
    mask_np = np.array(mask)
    augmented = geom_transform(image=image_np, mask=mask_np)
    image = augmented['image']
    mask = augmented['mask']
    
    intensity_transform = A.Compose([
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(clip_limit=1.0, tile_grid_size=(8,8), p=0.5),
        A.GaussianBlur(p=0.3)
    ])
    image = intensity_transform(image=image)['image']
    
    image = transforms.ToTensor()(image)
    image = z_score_normalize(image)
    mask = transforms.ToTensor()(mask)
    return image, mask

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, masks, alpha=1.5, p=0.7):
    if np.random.rand() > p:
        return images, masks
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(images.size(0)).to(images.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    masks[:, :, bbx1:bbx2, bby1:bby2] = masks[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, masks

def postprocess_mask(mask, min_size=100):
    # postprocess_mask: 빈 영역을 제거하는 처리
    labeled_mask = sk_label(mask)
    processed_mask = np.zeros_like(mask)
    for region in regionprops(labeled_mask):
        if region.area >= min_size:
            processed_mask[labeled_mask == region.label] = 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9,9))
    closed = cv2.morphologyEx(processed_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return closed

# ------------------------------
# Pseudo Labeling 관련 Dataset 및 함수
# ------------------------------
class UnlabeledDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.image_paths = sorted([os.path.join(data_path, fname) for fname in os.listdir(data_path) if fname.endswith('.png')])
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert("RGB")
        dummy_mask = Image.new("L", image.size, 0)
        if self.transform:
            image, _ = self.transform(image, dummy_mask)
        else:
            image = transforms.ToTensor()(image)
        return image, img_path

def generate_pseudo_labels(model, unlabeled_loader, device, threshold=0.9):
    model.eval()
    pseudo_data = []
    with torch.no_grad():
        for images, paths in unlabeled_loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs)
            pseudo_masks = (preds > threshold).float()
            for i in range(images.size(0)):
                pseudo_data.append((images[i].cpu(), pseudo_masks[i].cpu()))
    return pseudo_data

class CombinedDataset(Dataset):
    def __init__(self, labeled_dataset, pseudo_data):
        self.labeled_dataset = labeled_dataset
        self.pseudo_data = pseudo_data

    def __len__(self):
        return len(self.labeled_dataset) + len(self.pseudo_data)
    
    def __getitem__(self, index):
        if index < len(self.labeled_dataset):
            return self.labeled_dataset[index]
        else:
            pseudo_index = index - len(self.labeled_dataset)
            image, mask = self.pseudo_data[pseudo_index]
            return image, mask, "pseudo"

# ------------------------------
# Model 아키텍처 (PretrainedSwin_UNet_AttentionFusion 등)
# ------------------------------
# ※ 모델 관련 코드는 그대로 사용합니다. (필요한 경우 중복된 부분은 제거하였음)
from timm.models import create_model

class UNetDecoder_Swin(nn.Module):
    def __init__(self):
        super(UNetDecoder_Swin, self).__init__()
        self.up1 = nn.ConvTranspose2d(1536, 768, kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(768 + 768, 768, kernel_size=3, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(384 + 384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(192 + 192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.ConvTranspose2d(192, 64, kernel_size=2, stride=2)
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
    def forward(self, features_list):
        f3, f2, f1, f0 = features_list
        x = f3
        x = self.up1(x)
        x = torch.cat([x, f2], dim=1)
        x = self.conv1(x)
        x = self.up2(x)
        x = torch.cat([x, f1], dim=1)
        x = self.conv2(x)
        x = self.up3(x)
        x = torch.cat([x, f0], dim=1)
        x = self.conv3(x)
        x = self.up4(x)
        x = self.conv4(x)
        x = F.interpolate(x, size=(256,256), mode='bilinear', align_corners=False)
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1):
        super(TransformerBottleneck, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, x):
        B, C, H, W = x.size()
        x = x.view(B, C, H * W).permute(2, 0, 1)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).view(B, C, H, W)
        return x

class PretrainedSwin_UNet_AttentionFusion(nn.Module):
    def __init__(self, out_channels=1):
        super(PretrainedSwin_UNet_AttentionFusion, self).__init__()
        self.encoder = create_model('swin_large_patch4_window7_224', pretrained=True, features_only=True, img_size=256)
        self.bottleneck = QuadAgentBlock(in_channels=1536, out_channels=1536, gating_dropout=0.3, gating_hidden_dim=32, gating_temperature=1.5, stochastic_depth_prob=0.5)
        self.transformer_bottleneck1 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.transformer_bottleneck2 = TransformerBottleneck(d_model=1536, nhead=8, num_layers=1, dim_feedforward=2048, dropout=0.1)
        self.decoder = UNetDecoder_Swin()
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        features_list = self.encoder(x)  # [f0, f1, f2, f3]

        # (N, H, W, C) → (N, C, H, W) 변환이 필요한 경우 적용
        for i in range(len(features_list)):
            if features_list[i].shape[1] not in {192, 384, 768, 1536}:  
                features_list[i] = features_list[i].permute(0, 3, 1, 2).contiguous()

        # 디코더에 올바른 순서로 전달되도록 변환
        f0, f1, f2, f3 = features_list  # Swin Transformer의 기본 반환 순서
        features_list = [f3, f2, f1, f0]  # (B, 1536, 7, 7) → (B, 192, 56, 56)

        # Bottleneck과 Transformer Bottleneck 적용
        b_out = self.bottleneck(features_list[0])
        t_out1 = self.transformer_bottleneck1(b_out)
        t_out2 = self.transformer_bottleneck2(t_out1)
        features_list[0] = t_out2

        decoded = self.decoder(features_list)
        return self.final_conv(decoded)

# ------------------------------
# Loss 및 Metric 함수
# ------------------------------
from segmentation_models_pytorch.losses import LovaszLoss
lv_loss = LovaszLoss(mode='binary')

def focal_tversky_loss(pred, target, alpha=0.5, beta=0.5, gamma=4/3, smooth=1e-5):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred = pred.view(pred.size(0), -1)
    target = target.view(target.size(0), -1)
    tp = (pred * target).sum(dim=1)
    fp = ((1 - target) * pred).sum(dim=1)
    fn = (target * (1 - pred)).sum(dim=1)
    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    loss = (1 - tversky_index) ** gamma
    return loss.mean()

def boundary_loss(pred, target):
    pred = torch.sigmoid(pred)
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    boundary_masks = []
    for i in range(pred_np.shape[0]):
        gt_mask = target_np[i, 0]
        boundary = find_boundaries(gt_mask, mode='thick')
        boundary_masks.append(boundary.astype(np.float32))
    boundary_masks = np.stack(boundary_masks, axis=0)[:, None, :, :]
    boundary_masks_torch = torch.from_numpy(boundary_masks).to(pred.device)
    intersect = (pred * boundary_masks_torch).sum()
    denom = pred.sum() + boundary_masks_torch.sum()
    boundary_dice = (2.0 * intersect) / (denom + 1e-5)
    return 1.0 - boundary_dice

def combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    loss_ft = focal_tversky_loss(outputs, masks)
    if pos_weight is not None:
        loss_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(outputs, masks.float())
    else:
        loss_bce = nn.BCEWithLogitsLoss()(outputs, masks.float())
    bl = boundary_loss(outputs, masks)
    lv = lv_loss(outputs, masks)
    return 0.4 * lv + 0.3 * loss_ft + 0.3 * loss_bce + lambda_boundary * bl

def dice_f1_precision_recall(pred, target, threshold=0.5, smooth=1e-5):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    precision = intersection / (pred_bin.sum() + smooth)
    recall = intersection / (target_bin.sum() + smooth)
    f1 = 2 * (precision * recall) / (precision + recall + smooth)
    return f1.item(), precision.item(), recall.item()

def iou_metric(pred, target, threshold=0.5, smooth=1e-5):
    if pred.shape[-2:] != target.shape[-2:]:
        pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)
    pred_bin = (torch.sigmoid(pred) > threshold).float()
    target_bin = target.float()
    intersection = (pred_bin * target_bin).sum()
    union = pred_bin.sum() + target_bin.sum() - intersection
    return (intersection + smooth) / (union + smooth)

def compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-5):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach()
    processed_preds = []
    for i in range(pred_probs.size(0)):
        pred_np = pred_probs[i].cpu().numpy()[0]
        pred_bin = (pred_np > threshold).astype(np.uint8)
        processed = postprocess_mask(pred_bin, min_size=100)
        processed_preds.append(processed)
    batch_iou = []
    batch_f1 = []
    batch_precision = []
    batch_recall = []
    for i in range(pred_probs.size(0)):
        pred = processed_preds[i]
        gt = masks[i].cpu().numpy()[0]
        intersection = np.sum(pred * gt)
        union = np.sum(pred) + np.sum(gt) - intersection
        iou = (intersection + smooth) / (union + smooth)
        batch_iou.append(iou)
        precision = intersection / (np.sum(pred) + smooth)
        recall = intersection / (np.sum(gt) + smooth)
        f1 = 2 * (precision * recall) / (precision + recall + smooth)
        batch_f1.append(f1)
        batch_precision.append(precision)
        batch_recall.append(recall)
    probs = pred_probs.cpu().numpy().flatten()
    masks_np = masks.cpu().numpy().flatten()
    # AUC 계산: 만약 ground truth에 0 또는 1만 존재하면 0을 반환
    if np.all(masks_np == 0) or np.all(masks_np == 1):
        auc_score = 0.0
    else:
        try:
            auc_score = roc_auc_score(masks_np, probs)
        except ValueError:
            auc_score = 0.0
    return auc_score, np.mean(batch_iou), np.mean(batch_f1), np.mean(batch_precision), np.mean(batch_recall)

def compute_hd_metric(outputs, masks, threshold=0.5):
    if outputs.shape[-2:] != masks.shape[-2:]:
        outputs = F.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)
    pred_probs = torch.sigmoid(outputs).detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy()
    hd_list = []
    for i in range(outputs.size(0)):
        pred_bin = (pred_probs[i, 0] > threshold).astype(np.uint8)
        gt_bin = (masks_np[i, 0] > 0.5).astype(np.uint8)
        try:
            hd = medpy_binary.hd95(pred_bin, gt_bin)
        except Exception:
            hd = np.nan
        hd_list.append(hd)
    return np.nanmean(hd_list)

def evaluate_with_metrics(model, dataloader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None):
    model.eval()
    total_loss = 0.0
    total_auc  = 0.0
    total_iou  = 0.0
    total_f1   = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_hd = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, masks, _ in dataloader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=pos_weight, lambda_boundary=lambda_boundary)
            auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=threshold, smooth=1e-5)
            hd_val = compute_hd_metric(outputs, masks, threshold=threshold)
            bs = images.size(0)
            total_loss += loss.item() * bs
            total_auc  += auc_score * bs
            total_iou  += iou_val * bs
            total_f1   += f1_val * bs
            total_precision += prec * bs
            total_recall += rec * bs
            total_hd += hd_val * bs
            total_samples += bs
    avg_loss = total_loss / total_samples
    avg_auc  = total_auc / total_samples
    avg_iou  = total_iou / total_samples
    avg_f1   = total_f1 / total_samples
    avg_precision = total_precision / total_samples
    avg_recall = total_recall / total_samples
    avg_hd = total_hd / total_samples
    return avg_loss, avg_auc, avg_iou, avg_f1, avg_precision, avg_recall, avg_hd

# ------------------------------
# TTA 함수 (Test Time Augmentation)
# ------------------------------
def tta_predict(model, image, device):
    cj = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    def add_noise(x, std=0.05):
        noise = torch.randn_like(x) * std
        return x + noise
    def identity(x): return x
    def hflip(x): return torch.flip(x, dims=[-1])
    def vflip(x): return torch.flip(x, dims=[-2])
    def rot90(x): return torch.rot90(x, k=1, dims=[-2, -1])
    def inv_hflip(x): return torch.flip(x, dims=[-1])
    def inv_vflip(x): return torch.flip(x, dims=[-2])
    def inv_rot90(x): return torch.rot90(x, k=3, dims=[-2, -1])
    
    transforms_list = [
        (identity, identity),
        (hflip, inv_hflip),
        (vflip, inv_vflip),
        (rot90, inv_rot90),
        (lambda x: add_noise(cj(x)), lambda x: x)
    ]
    predictions = []
    model.eval()
    with torch.no_grad():
        for aug, inv in transforms_list:
            augmented = aug(image)
            output = model(augmented.unsqueeze(0).to(device))
            output = torch.sigmoid(output)
            output = inv(output).cpu()
            predictions.append(output)
    avg_prediction = torch.mean(torch.stack(predictions), dim=0)
    return avg_prediction

# ------------------------------
# 데이터 로더 설정 (테스트 포함)
# ------------------------------
data_path = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'
full_dataset = BUSISegmentationDataset(data_path, transform=joint_transform)
indices = np.arange(len(full_dataset))
train_val_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0,
                          worker_init_fn=lambda worker_id: np.random.seed(42 + worker_id))
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

def calculate_pos_weight(loader):
    total_pixels = 0
    positive_pixels = 0
    for images, masks, _ in loader:
        positive_pixels += masks.sum().item()
        total_pixels += masks.numel()
    negative_pixels = total_pixels - positive_pixels
    return torch.tensor(negative_pixels / (positive_pixels + 1e-6)).to(device)

# ------------------------------
# 모델, Optimizer, Scheduler 설정 (Training Loop)
# ------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PretrainedSwin_UNet_AttentionFusion(out_channels=1)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

pos_weight = calculate_pos_weight(train_loader)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)

scaler = GradScaler()
patience = 50
best_val_f1 = 0.0
patience_counter = 0

min_temp = 2.0
max_temp = 5.0

num_epochs = 500
combined_loader = None

for epoch in range(num_epochs):
    current_temp = max_temp - ((max_temp - min_temp) * epoch / num_epochs)
    if hasattr(model, 'bottleneck'):
        if isinstance(model, nn.DataParallel):
            model.module.bottleneck.gating.temperature = current_temp
        else:
            model.bottleneck.gating.temperature = current_temp

    model.train()
    epoch_loss = 0.0
    epoch_auc = 0.0
    epoch_iou = 0.0
    epoch_f1 = 0.0
    epoch_precision = 0.0
    epoch_recall = 0.0
    total_samples = 0

    if epoch < 30:
        train_loader = train_loader  # labeled_loader 사용
    else:
        if epoch == 30:
            print("===> Generating pseudo labels on unlabeled data...")
            pseudo_loader = DataLoader(UnlabeledDataset('/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/', transform=joint_transform),
                                       batch_size=16, shuffle=False, num_workers=0)
            pseudo_data = generate_pseudo_labels(model, pseudo_loader, device, threshold=0.9)
            print(f"Pseudo labels generated for {len(pseudo_data)} images.")
            combined_dataset = CombinedDataset(train_dataset, pseudo_data)
            combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True, num_workers=0)
            train_loader = combined_loader
        else:
            train_loader = combined_loader

    for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        images = images.to(device)
        masks = masks.to(device)
        images, masks = cutmix_data(images, masks, alpha=0.8, p=0.7)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = combined_loss(outputs, masks, pos_weight=None, lambda_boundary=0.2, lambda_lovasz=0.6)
        if torch.isnan(loss):
            print("Warning: loss is NaN, skipping batch")
            continue
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        scaler.step(optimizer)
        scaler.update()

        bs = images.size(0)
        epoch_loss += loss.item() * bs

        auc_score, iou_val, f1_val, prec, rec = compute_batch_metrics_new(outputs, masks, threshold=0.5, smooth=1e-5)
        epoch_auc += auc_score * bs
        epoch_iou += iou_val * bs
        epoch_f1 += f1_val * bs
        epoch_precision += prec * bs
        epoch_recall += rec * bs
        total_samples += bs

    scheduler.step(epoch)
    avg_loss = epoch_loss / total_samples if total_samples > 0 else float('nan')
    avg_auc = epoch_auc / total_samples if total_samples > 0 else float('nan')
    avg_iou = epoch_iou / total_samples if total_samples > 0 else float('nan')
    avg_f1 = epoch_f1 / total_samples if total_samples > 0 else float('nan')
    avg_precision = epoch_precision / total_samples if total_samples > 0 else float('nan')
    avg_recall = epoch_recall / total_samples if total_samples > 0 else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, IoU: {avg_iou:.4f}, Dice/F1: {avg_f1:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f} (Temp: {current_temp:.2f})")
    
    val_loss, val_auc, val_iou, val_f1, val_precision, val_recall, val_hd = evaluate_with_metrics(model, val_loader, device, threshold=0.5, lambda_boundary=0.2, pos_weight=None)
    print(f"[Val] Loss: {val_loss:.4f}, AUC: {val_auc:.4f}, IoU: {val_iou:.4f}, Dice/F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, HD95: {val_hd:.4f}")
    
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_weights = copy.deepcopy(model.state_dict())
        torch.save(best_weights, 'best_model_with_pseudo.pth')
        patience_counter = 0
        print("🍀 New best validation Dice achieved! Dice: {:.4f} 🍀".format(best_val_f1))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

print("Training complete. Best model saved as 'best_model_with_pseudo.pth'.")


Folder created: /kaggle/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/unlabeled-busi-images/


Epoch 1/500 Training: 100%|██████████| 30/30 [01:23<00:00,  2.79s/it]


Epoch 1/500 | Loss: 1.0333, AUC: 0.8713, IoU: 0.3501, Dice/F1: 0.4121, Precision: 0.4011, Recall: 0.5522 (Temp: 5.00)
[Val] Loss: 1.0126, AUC: 0.9479, IoU: 0.4751, Dice/F1: 0.5198, Precision: 0.4828, Recall: 0.6790, HD95: 66.5344
🍀 New best validation Dice achieved! Dice: 0.5198 🍀


Epoch 2/500 Training:  37%|███▋      | 11/30 [00:25<00:43,  2.29s/it]


KeyboardInterrupt: 