<a href="https://colab.research.google.com/github/BSEU-Misal/SWUs-for-Cervical-Spinal-Cord/blob/main/SwinUnets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Model SWU1: Encoder = Swin Transformer, Decoder = CNN

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm

class DecoderBlockCNN(nn.Module):
    """CNN decoder stage: resize to the skip's spatial size, concatenate, refine.

    The incoming feature map is bilinearly resized to ``skip``'s height/width,
    concatenated with ``skip`` along channels, and passed through two
    3x3 conv -> BatchNorm -> ReLU stages.

    Args:
        in_ch: channels of the incoming (coarser) decoder feature map.
        skip_ch: channels of the skip-connection feature map.
        out_ch: channels produced by the block.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch + skip_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Kept as one Sequential named `conv` so state_dict keys stay `conv.<idx>.*`.
        self.conv = nn.Sequential(*stages)

    def forward(self, x, skip):
        resized = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return self.conv(torch.cat([resized, skip], dim=1))

class SwinEncoder_CNNDecoder_UNet(nn.Module):
    """U-Net style segmenter: pretrained Swin-B encoder (timm) + CNN decoder.

    The encoder is ``swin_base_patch4_window12_384`` created with
    ``features_only=True``; its four stage outputs (strides 4/8/16/32) feed a
    chain of ``DecoderBlockCNN`` stages. The logits are bilinearly resized
    back to the input's spatial size.

    Args:
        img_size: stored on the instance; the backbone is the 384x384 Swin
            variant and is not resized to match this value.
        num_classes: number of output logit channels.
    """

    def __init__(self, img_size=384, num_classes=1):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.backbone = timm.create_model(
            'swin_base_patch4_window12_384',
            pretrained=True,
            features_only=True
        )
        # Per-stage channel widths reported by timm (Swin-B: [128, 256, 512, 1024]).
        chs = self.backbone.feature_info.channels()

        self.dec4 = DecoderBlockCNN(chs[3], chs[2], 512)
        self.dec3 = DecoderBlockCNN(512, chs[1], 256)
        self.dec2 = DecoderBlockCNN(256, chs[0], 128)
        # dec1 re-uses the stride-4 map f1 as its skip (see forward), so its
        # skip width must be chs[0]; a hard-coded 64 does not match Swin-B.
        self.dec1 = DecoderBlockCNN(128, chs[0], 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape [B, num_classes, H, W] for an input [B, 3, H, W]."""
        feats = self.backbone(x)
        # NOTE(review): assumes this timm version emits Swin features
        # channels-last ([B, H, W, C]); convert to [B, C, H, W] for the convs.
        feats = [f.permute(0, 3, 1, 2).contiguous() for f in feats]

        f1, f2, f3, f4 = feats
        d4 = self.dec4(f4, f3)
        d3 = self.dec3(d4, f2)
        d2 = self.dec2(d3, f1)
        d1 = self.dec1(d2, f1)
        out = self.final_conv(d1)

        # The decoder stops at stride 4; upsample the logits to the input size.
        return F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=False)


Model SWU2: Encoder = CNN, Decoder = Swin Transformer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models.swin_transformer import SwinTransformerBlock

class CNNEncoderBlock(nn.Module):
    """Downsampling encoder stage.

    A stride-2 3x3 conv followed by a stride-1 3x3 conv, each followed by
    BatchNorm and ReLU. Spatial size goes ``n -> ceil(n / 2)``; channels go
    ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return x

class SwinDecoderBlock(nn.Module):
    """Swin-Transformer decoder stage for a U-Net.

    Steps: 2x transposed-conv upsampling (``dim -> dim // 2`` channels),
    concatenation with the skip, LayerNorm + Linear projection back to
    ``dim // 2`` channels, then one regular and one shifted-window
    ``SwinTransformerBlock``.

    Args:
        dim: channels of the incoming decoder feature map; the block outputs ``dim // 2``.
        skip_dim: channels of the skip-connection feature map.
        resolution: (H, W) token grid the Swin blocks are built for; must equal
            the skip's spatial size.
        num_heads: attention heads in each Swin block.
        window_size: attention window side; the shifted block uses ``window_size // 2``.
            NOTE(review): when this does not divide ``resolution`` the timm
            version in use must pad internally — confirm.
    """

    def __init__(self, dim, skip_dim, resolution, num_heads, window_size=7):
        super().__init__()
        half = dim // 2
        self.up = nn.ConvTranspose2d(dim, half, kernel_size=2, stride=2)
        self.norm = nn.LayerNorm(half + skip_dim)
        self.linear = nn.Linear(half + skip_dim, half)
        self.swin1 = SwinTransformerBlock(dim=half, input_resolution=resolution, num_heads=num_heads, window_size=window_size)
        self.swin2 = SwinTransformerBlock(dim=half, input_resolution=resolution, num_heads=num_heads, window_size=window_size, shift_size=window_size // 2)

    def forward(self, x, skip):
        """Fuse ``x`` [B, dim, h, w] with ``skip`` [B, skip_dim, H, W]; return [B, dim // 2, H, W]."""
        x = self.up(x)
        # An odd-sized skip cannot be matched exactly by a 2x transposed conv;
        # align to the skip like DecoderBlockCNN does instead of failing in cat.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        # NOTE(review): assumes this timm SwinTransformerBlock consumes
        # channels-last [B, H, W, C] input — confirm against the installed version.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.linear(self.norm(x))
        x = self.swin2(self.swin1(x))
        return x.permute(0, 3, 1, 2).contiguous()

class CNNEncoder_SwinDecoder_UNet(nn.Module):
    """U-Net with a 4-stage CNN encoder and Swin-Transformer decoder stages.

    Encoder widths are 64/128/256/512. Each stage halves the spatial size
    (stride-2, padding-1, 3x3 conv: ``n -> ceil(n / 2)``). Three
    ``SwinDecoderBlock`` stages and a conv refinement stage follow. The logits
    are bilinearly resized to the input's spatial size.

    Args:
        img_size: input side length; the Swin blocks' token grids are built for it,
            so inputs must be ``img_size`` x ``img_size``.
        num_classes: number of output logit channels.
    """

    def __init__(self, img_size=384, num_classes=1):
        super().__init__()
        self.img_size = img_size
        self.enc1 = CNNEncoderBlock(3, 64)
        self.enc2 = CNNEncoderBlock(64, 128)
        self.enc3 = CNNEncoderBlock(128, 256)
        self.enc4 = CNNEncoderBlock(256, 512)

        # Side length of each encoder output (e1, e2, e3). Each decoder stage
        # upsamples to the size of its skip, so that is the token grid its Swin
        # blocks actually see (e.g. 48/96/192 for a 384 input, not 24/48/96).
        e1_res = (img_size + 1) // 2
        e2_res = (e1_res + 1) // 2
        e3_res = (e2_res + 1) // 2

        self.dec4 = SwinDecoderBlock(512, 256, (e3_res, e3_res), num_heads=8)
        self.dec3 = SwinDecoderBlock(256, 128, (e2_res, e2_res), num_heads=4)
        self.dec2 = SwinDecoderBlock(128, 64, (e1_res, e1_res), num_heads=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits [B, num_classes, H, W] for an input [B, 3, H, W]."""
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        d4 = self.dec4(e4, e3)
        d3 = self.dec3(d4, e2)
        d2 = self.dec2(d3, e1)
        d1 = self.dec1(d2)
        out = self.final_conv(d1)
        # d1 is at half resolution; bring the logits back to the input size.
        return F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)


Model SWU3: Full Swin‑UNet (Encoder + Decoder = Swin Transformer)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models.swin_transformer import SwinTransformerBlock

class SwinDecoderTransformerBlock(nn.Module):
    """Swin-Transformer decoder stage for the full Swin-UNet.

    Steps: 2x transposed-conv upsampling (``dim -> dim // 2`` channels),
    concatenation with the skip, LayerNorm + Linear projection back to
    ``dim // 2`` channels, then one regular and one shifted-window
    ``SwinTransformerBlock``.

    Args:
        dim: channels of the incoming decoder feature map; the block outputs ``dim // 2``.
        skip_dim: channels of the skip-connection feature map.
        resolution: (H, W) token grid the Swin blocks are built for; must equal
            the skip's spatial size.
        num_heads: attention heads in each Swin block.
        window_size: attention window side; the shifted block uses ``window_size // 2``.
            NOTE(review): when this does not divide ``resolution`` the timm
            version in use must pad internally — confirm.
    """

    def __init__(self, dim, skip_dim, resolution, num_heads, window_size=7):
        super().__init__()
        half = dim // 2
        self.up = nn.ConvTranspose2d(dim, half, kernel_size=2, stride=2)
        self.norm = nn.LayerNorm(half + skip_dim)
        self.linear = nn.Linear(half + skip_dim, half)
        self.swin1 = SwinTransformerBlock(dim=half, input_resolution=resolution, num_heads=num_heads, window_size=window_size)
        self.swin2 = SwinTransformerBlock(dim=half, input_resolution=resolution, num_heads=num_heads, window_size=window_size, shift_size=window_size // 2)

    def forward(self, x, skip):
        """Fuse ``x`` [B, dim, h, w] with ``skip`` [B, skip_dim, H, W]; return [B, dim // 2, H, W]."""
        x = self.up(x)
        # An odd-sized skip cannot be matched exactly by a 2x transposed conv;
        # align to the skip instead of failing in torch.cat.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        # NOTE(review): assumes this timm SwinTransformerBlock consumes
        # channels-last [B, H, W, C] input — confirm against the installed version.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.linear(self.norm(x))
        x = self.swin2(self.swin1(x))
        return x.permute(0, 3, 1, 2).contiguous()

class SwinUNetFull(nn.Module):
    """Full Swin-UNet: pretrained Swin-B encoder (timm) + Swin-Transformer decoder.

    The encoder's four stage outputs (strides 4/8/16/32) feed three
    ``SwinDecoderTransformerBlock`` stages. Each halves the channel width and
    doubles the spatial size. A small conv head follows, and the logits are
    bilinearly resized to the input size.

    Args:
        img_size: input side length used to size the decoder token grids
            (the backbone itself is the 384x384 Swin variant).
        num_classes: number of output logit channels.
    """

    def __init__(self, img_size=384, num_classes=1):
        super().__init__()
        self.img_size = img_size
        self.backbone = timm.create_model(
            'swin_base_patch4_window12_384',
            pretrained=True,
            features_only=True
        )
        chs = self.backbone.feature_info.channels()
        resolution = img_size // 32  # side of the deepest (stride-32) feature map

        # Each decoder stage outputs half its input width, so derive every
        # stage's input from the previous stage's real output width.
        dec4_in = chs[3]
        dec3_in = dec4_in // 2
        dec2_in = dec3_in // 2
        head_in = dec2_in // 2  # width leaving dec2 (chs[0] for Swin, not chs[0] // 2)

        # NOTE(review): 12/6/3 heads do not divide the 512/256/128-wide decoder
        # tokens of Swin-B; confirm the installed timm tolerates this (head_dim = dim // heads).
        self.dec4 = SwinDecoderTransformerBlock(dec4_in, chs[2], (resolution * 2, resolution * 2), num_heads=12)
        self.dec3 = SwinDecoderTransformerBlock(dec3_in, chs[1], (resolution * 4, resolution * 4), num_heads=6)
        self.dec2 = SwinDecoderTransformerBlock(dec2_in, chs[0], (resolution * 8, resolution * 8), num_heads=3)
        self.final_conv = nn.Sequential(
            nn.Conv2d(head_in, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return logits [B, num_classes, H, W] for an input [B, 3, H, W]."""
        feats = self.backbone(x)
        # NOTE(review): assumes this timm version emits Swin features
        # channels-last ([B, H, W, C]); convert to [B, C, H, W].
        feats = [f.permute(0, 3, 1, 2).contiguous() for f in feats]
        f1, f2, f3, f4 = feats
        d4 = self.dec4(f4, f3)
        d3 = self.dec3(d4, f2)
        d2 = self.dec2(d3, f1)
        out = self.final_conv(d2)
        # d2 is at stride 4; bring the logits back to the input size.
        return F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)


In [None]:
import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

# ========================
# Parameters
# ========================
SIZE, BATCH_SIZE = 384, 8  # resize side length (pixels) and DataLoader batch size

# Training split: images and their masks ("Maskes" is the on-disk folder name).
train_image_dir, train_mask_dir = (
    "/mnt/data/sagSpinMS/train/Images/",
    "/mnt/data/sagSpinMS/train/Maskes/",
)

# Test split, laid out the same way.
test_image_dir, test_mask_dir = (
    "/mnt/data/sagSpinMS/test/Images/",
    "/mnt/data/sagSpinMS/test/Maskes/",
)




# ========================
# Image loader
# ========================
def load_image(path, size, is_mask=False):
    """Read one image file with OpenCV and resize it to ``size`` x ``size``.

    Images are read in colour, converted from OpenCV's BGR order to RGB and
    resized bilinearly. Masks are read as grayscale, resized with
    nearest-neighbour (so labels are not blended) and binarised with ``> 127``.

    Returns:
        uint8 RGB array for images, uint8 {0, 1} array for masks, or None
        (after printing a warning) when OpenCV cannot read the file.
    """
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE if is_mask else cv2.IMREAD_COLOR)
    if raw is None:
        print(f"Uyarı: Dosya okunamadı: {path}")
        return None

    if is_mask:
        resized = np.array(Image.fromarray(raw).resize((size, size), resample=Image.NEAREST))
        # NOTE(review): masks stored with foreground value 1 (instead of 255)
        # would become all-zero here — confirm the dataset's mask encoding.
        return (resized > 127).astype(np.uint8)

    rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    return np.array(Image.fromarray(rgb).resize((size, size), resample=Image.BILINEAR))

# ========================
# Load the dataset lists
# ========================
def load_dataset(image_dir, mask_dir, size):
    """Load every ``.png`` image/mask pair from two directories into memory.

    Names ending in ``.png`` (case-insensitive) are sorted in each directory
    and paired by position, so both folders must sort into the same order.
    Pairs where either file fails to load are skipped.

    Returns:
        (images, masks): two equally long lists of arrays from ``load_image``.

    Raises:
        AssertionError: if the two directories hold different numbers of PNGs
            (note: stripped when Python runs with ``-O``).
    """
    def _png_names(folder):
        return sorted(name for name in os.listdir(folder) if name.lower().endswith(".png"))

    image_names = _png_names(image_dir)
    mask_names = _png_names(mask_dir)
    assert len(image_names) == len(mask_names), "Görüntü ve maske sayısı eşleşmiyor!"

    images, masks = [], []
    for img_name, mask_name in zip(image_names, mask_names):
        image = load_image(os.path.join(image_dir, img_name), size, is_mask=False)
        mask = load_image(os.path.join(mask_dir, mask_name), size, is_mask=True)
        if image is None or mask is None:
            continue
        images.append(image)
        masks.append(mask)

    print(f"Toplam {len(images)} örnek yüklendi.")
    return images, masks

# ========================
# Dataset class
# ========================
class SegmentationDataset(Dataset):
    """In-memory dataset yielding ``(image, mask)`` float32 tensors.

    Args:
        images: sequence of HWC RGB arrays (presumably uint8 from ``load_image``).
        masks: sequence of HW mask arrays (values {0, 1} from ``load_image``).
        transform: optional albumentations-style callable, called as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'``. It receives the image as a contiguous HWC
            float32 array scaled to [0, 1] and the mask as an HW float32 array.

    Items are ``image`` [3, H, W] in [0, 1] and ``mask`` [1, H, W].
    """

    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Scale in numpy (HWC) so the transform gets a C-contiguous array; the
        # .numpy() view of a permuted tensor is not contiguous, which
        # OpenCV-backed augmentations can reject.
        image = (np.asarray(self.images[idx]) / 255.0).astype(np.float32)
        # np.array always copies, so returned tensors never alias dataset storage.
        mask = np.array(self.masks[idx], dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        # ascontiguousarray also normalises negative-stride outputs (e.g. flips),
        # which torch.from_numpy cannot wrap directly.
        image = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).permute(2, 0, 1)  # [3, H, W]
        mask = torch.from_numpy(np.ascontiguousarray(mask, dtype=np.float32)).unsqueeze(0)  # [1, H, W]
        return image, mask

# ========================
# Load the training and test data
# ========================
train_images, train_masks = load_dataset(train_image_dir, train_mask_dir, SIZE)
test_images, test_masks = load_dataset(test_image_dir, test_mask_dir, SIZE)

train_dataset, test_dataset = (
    SegmentationDataset(train_images, train_masks),
    SegmentationDataset(test_images, test_masks),
)

# Only shuffling and worker count differ between the two loaders.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)

print(f"[✓] Train loader hazır → {len(train_dataset)} örnek")
print(f"[✓] Test loader hazır  → {len(test_dataset)} örnek")


In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import numpy as np

# ==================================================
# 1. Helper: compute metrics (accuracy, dice, iou)
# ==================================================
def calculate_metrics(outputs, masks, threshold=0.5):
    """Score a batch of logits against ground-truth masks.

    Args:
        outputs: raw logits; ``sigmoid`` is applied here. Per-sample sums run
            over dims (1, 2, 3), so a 4-D [B, C, H, W] layout is assumed.
        masks: ground truth with the same shape (assumed 0/1 values).
        threshold: probability above which a pixel counts as foreground.

    Returns:
        (correct, total, dice, iou): matching-pixel count, total pixel count,
        and the batch-mean Dice and IoU (each smoothed by 1e-6).
    """
    eps = 1e-6
    per_sample = (1, 2, 3)
    with torch.no_grad():
        binary = (torch.sigmoid(outputs) > threshold).float()

        n_correct = (binary == masks).sum().item()
        n_total = masks.numel()

        overlap = (binary * masks).sum(dim=per_sample)
        area_sum = binary.sum(dim=per_sample) + masks.sum(dim=per_sample)
        dice = ((2 * overlap + eps) / (area_sum + eps)).mean().item()

        union = (binary + masks - binary * masks).sum(dim=per_sample)
        iou = ((overlap + eps) / (union + eps)).mean().item()

    return n_correct, n_total, dice, iou

# ========================
# 2. Validation function
# ========================
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Loss, Dice and IoU are averaged per sample: each batch's value is weighted
    by its batch size, so a smaller final batch does not count as much as a
    full one. Accuracy is pixel accuracy over every pixel seen. Leaves the
    model in eval mode.

    Args:
        model: network mapping images to logits.
        val_loader: iterable of ``(images, masks)`` batches.
        criterion: loss function ``criterion(outputs, masks)``; assumed to
            return a batch-mean scalar.
        device: device the batches are moved to.

    Returns:
        (avg_val_loss, avg_val_acc_percent, avg_dice, avg_iou)

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    n_samples = 0
    loss_sum = dice_sum = iou_sum = 0.0
    correct_preds = total_preds = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            batch = images.size(0)

            outputs = model(images)
            loss_sum += criterion(outputs, masks).item() * batch

            c, t, dice, iou = calculate_metrics(outputs, masks)
            correct_preds += c
            total_preds += t
            # calculate_metrics returns batch means; re-weight by batch size.
            dice_sum += dice * batch
            iou_sum += iou * batch
            n_samples += batch

    if n_samples == 0:
        raise ValueError("validate(): val_loader yielded no samples")

    avg_val_loss = loss_sum / n_samples
    avg_val_acc = correct_preds / total_preds * 100
    avg_dice = dice_sum / n_samples
    avg_iou = iou_sum / n_samples
    return avg_val_loss, avg_val_acc, avg_dice, avg_iou

# ========================
# 3. Training parameters
# ========================
num_epochs, learning_rate = 100, 1e-4
batch_size = 8  # NOTE(review): not used below; the DataLoaders were built with BATCH_SIZE
SIZE = 384  # re-binds the value from the data-loading cell
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ========================
# 4. Model, loss, optimizer
# ========================
# NOTE(review): neither `decoSwin` nor `ComboLoss` is defined in this notebook
# as shown. The models defined above are SwinEncoder_CNNDecoder_UNet,
# CNNEncoder_SwinDecoder_UNet and SwinUNetFull — confirm which one `decoSwin`
# refers to and where ComboLoss comes from before running.
model = decoSwin(img_size=SIZE, num_classes=1)
model = model.to(device)
criterion = ComboLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# ========================
# 5. Training records
# ========================
# One list per tracked quantity, appended once per epoch; the key order is
# also the column order of the CSV written from this dict.
history = {
    key: []
    for key in (
        "epoch",
        "train_loss",
        "train_acc",
        "val_loss",
        "val_acc",
        "val_dice",
        "val_iou",
        "epoch_time_sec",
    )
}

best_val_loss, best_epoch = float('inf'), -1

# ========================
# 6. Training loop (times each epoch and the whole run)
# ========================
total_start_time = time.time()  # wall-clock start of the whole training run

for epoch in range(num_epochs):
    epoch_start_time = time.time()

    model.train()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for images, masks in loop:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Training pixel accuracy, scored on the outputs computed before this step's update.
        c, t, _, _ = calculate_metrics(outputs, masks)
        correct_preds += c
        total_preds += t

        loop.set_postfix(loss=loss.item())

    # Mean of per-batch losses over the epoch.
    epoch_train_loss = running_loss / len(train_loader)
    epoch_train_acc = correct_preds / total_preds * 100

    # NOTE(review): the test split (test_loader) doubles as the validation set,
    # so the "best" checkpoint below is selected on test data — consider a
    # separate validation split to avoid an optimistic test score.
    epoch_val_loss, epoch_val_acc, epoch_val_dice, epoch_val_iou = validate(model, test_loader, criterion, device)

    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}% | "
          f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%, "
          f"Dice: {epoch_val_dice:.4f}, IoU: {epoch_val_iou:.4f}, "
          f"Time: {epoch_duration:.1f}s")

    # Save the best model: overwrite model.pth whenever the validation loss improves.
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "model.pth")
        print(f"✅ Model kaydedildi. (Epoch {best_epoch}, Val Loss: {best_val_loss:.4f})")

    # History update (one entry per key per epoch)
    history["epoch"].append(epoch + 1)
    history["train_loss"].append(epoch_train_loss)
    history["train_acc"].append(epoch_train_acc)
    history["val_loss"].append(epoch_val_loss)
    history["val_acc"].append(epoch_val_acc)
    history["val_dice"].append(epoch_val_dice)
    history["val_iou"].append(epoch_val_iou)
    history["epoch_time_sec"].append(epoch_duration)

# Total training time
total_end_time = time.time()
total_training_time = total_end_time - total_start_time
print(f"\n⏱️ Toplam Eğitim Süresi: {total_training_time:.1f} saniye")

# ========================
# 7. Save the results to a CSV file
# ========================
# One path variable so the log message always names the file actually written.
csv_path = "deco-swin_training.csv"
df = pd.DataFrame(history)
df.to_csv(csv_path, index=False)
print(f"📁 Eğitim geçmişi '{csv_path}' olarak kaydedildi.")

# ========================
# 8. Save the results to an HDF5 file
# ========================
with h5py.File("model.h5", "w") as hf:
    # One dataset per history column, then run-level scalars as file attributes.
    for key in history:
        hf.create_dataset(key, data=np.array(history[key]))
    run_attrs = (
        ("total_training_time_sec", total_training_time),
        ("best_epoch", best_epoch),
        ("best_val_loss", best_val_loss),
    )
    for attr_name, attr_value in run_attrs:
        hf.attrs[attr_name] = attr_value

print("📁 Eğitim geçmişi 'model.h5' olarak kaydedildi.")


In [None]:
# ========================
# 9. Training plots
# ========================
plot_path = "training_plots.png"

# (subplot index on a 2x3 grid, title, [(history key, legend label), ...]);
# the sixth grid cell is left empty.
panels = [
    (1, "Loss", [("train_loss", "Train Loss"), ("val_loss", "Val Loss")]),
    (2, "Accuracy (%)", [("train_acc", "Train Acc"), ("val_acc", "Val Acc")]),
    (3, "Validation Dice Score", [("val_dice", "Val Dice")]),
    (4, "Validation IoU Score", [("val_iou", "Val IoU")]),
    (5, "Epoch Süresi", [("epoch_time_sec", "Epoch Duration (s)")]),
]

plt.figure(figsize=(16,8))
for position, title, series in panels:
    plt.subplot(2, 3, position)
    for key, label in series:
        plt.plot(history["epoch"], history[key], label=label)
    plt.title(title)
    plt.legend()

plt.tight_layout()
plt.savefig(plot_path)
plt.show()
# Report the file actually written.
print(f"📈 Eğitim grafikleri '{plot_path}' olarak kaydedildi.")




In [None]:
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def f1_score_metric(pred, target, smooth=1e-6):
    """Return the smoothed F1 score (harmonic mean of precision and recall).

    ``pred`` and ``target`` are cast to float32 and multiplied element-wise,
    so they are expected to hold 0/1 values. ``smooth`` keeps every division
    finite (two empty masks score 0).
    """
    p = pred.astype(np.float32)
    t = target.astype(np.float32)
    true_pos = (p * t).sum()
    prec = true_pos / (p.sum() + smooth)
    rec = true_pos / (t.sum() + smooth)
    return 2 * prec * rec / (prec + rec + smooth)

def voe_metric(pred, target):
    """Return the Volume Overlap Error, ``1 - |A ∩ B| / |A ∪ B|``.

    This is one minus the Jaccard index; nonzero entries count as foreground.
    Two empty masks are a perfect match (VOE 0.0).
    """
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 0.0
    return 1 - np.logical_and(pred, target).sum() / union

def mean_surface_distance(pred, target):
    """Return the mean surface distance between two binary masks, in pixels.

    A mask's surface is its foreground pixels that touch the background
    (foreground minus its binary erosion). This covers edges in every
    direction, not just row-to-row transitions. For each surface pixel of one
    mask, the Euclidean distance to the nearest surface pixel of the other is
    taken. The result is the average of the two directional *means*; the
    Hausdorff distance, by contrast, would use the maxima.

    Returns 0.0 when either mask has no foreground. NOTE(review): that scores
    a completely missed structure as perfect.
    """
    from scipy.ndimage import binary_erosion
    from scipy.spatial import cKDTree

    def _surface(mask):
        fg = np.asarray(mask).astype(bool)
        return np.argwhere(fg & ~binary_erosion(fg))

    pred_surface = _surface(pred)
    target_surface = _surface(target)
    if len(pred_surface) == 0 or len(target_surface) == 0:
        return 0.0

    # Nearest-neighbour queries avoid building a full pairwise distance matrix.
    d_pred_to_target, _ = cKDTree(target_surface).query(pred_surface)
    d_target_to_pred, _ = cKDTree(pred_surface).query(target_surface)
    return float((d_pred_to_target.mean() + d_target_to_pred.mean()) / 2)

def compute_roc_auc(outputs, masks):
    """Return the pixel-wise ROC-AUC of scores ``outputs`` against labels ``masks``.

    Both arrays are flattened. Returns NaN when the AUC is undefined (e.g.
    ``masks`` contains a single class), which ``roc_auc_score`` signals with
    ``ValueError``. Other errors are not swallowed.
    """
    masks_flat = masks.flatten()
    outputs_flat = outputs.flatten()
    try:
        return roc_auc_score(masks_flat, outputs_flat)
    except ValueError:
        return float('nan')

def dice_coefficient(preds, targets, threshold=0.5):
    """Return the Dice coefficient ``2|P ∩ T| / (|P| + |T|)``.

    ``preds`` is binarised with ``> threshold``; ``targets`` is used as given
    (expected 0/1). The 1e-8 term keeps empty masks finite (both empty -> 0.0).
    """
    fg = preds > threshold
    overlap = np.sum(fg * targets)
    denom = np.sum(fg) + np.sum(targets) + 1e-8
    return 2. * overlap / denom

def jaccard_index(preds, targets, threshold=0.5):
    """Return the Jaccard index (IoU) ``|P ∩ T| / |P ∪ T|``.

    ``preds`` is binarised with ``> threshold``; the union is computed as
    ``|P| + |T| - |P ∩ T|``. The 1e-8 term keeps empty masks finite.
    """
    fg = preds > threshold
    overlap = np.sum(fg * targets)
    return overlap / (np.sum(fg) + np.sum(targets) - overlap + 1e-8)

def precision(preds, targets, threshold=0.5):
    """Return precision ``TP / (TP + FP)`` after binarising ``preds`` with ``> threshold``.

    ``targets`` is expected to hold 0/1 values; the 1e-8 term avoids 0/0.
    """
    fg = preds > threshold
    tp = np.sum(fg * targets)
    fp = np.sum(fg * (1 - targets))
    return tp / (tp + fp + 1e-8)

def recall(preds, targets, threshold=0.5):
    """Return recall (sensitivity) ``TP / (TP + FN)`` after binarising ``preds``.

    ``targets`` is expected to hold 0/1 values; the 1e-8 term avoids 0/0.
    """
    fg = preds > threshold
    tp = np.sum(fg * targets)
    fn = np.sum((1 - fg) * targets)
    return tp / (tp + fn + 1e-8)

def specificity(preds, targets, threshold=0.5):
    """Return specificity ``TN / (TN + FP)`` after binarising ``preds``.

    ``targets`` is expected to hold 0/1 values; the 1e-8 term avoids 0/0.
    """
    fg = preds > threshold
    bg_truth = 1 - targets
    tn = np.sum((1 - fg) * bg_truth)
    fp = np.sum(fg * bg_truth)
    return tn / (tn + fp + 1e-8)

def hausdorff_distance(preds, targets):
    """Return the symmetric Hausdorff distance (pixels) between two masks.

    Every pixel equal to 1 is used, not just boundary pixels. The result is
    the larger of the two directed Hausdorff distances. Returns 0 if either
    mask has no pixel equal to 1.
    """
    from scipy.spatial.distance import directed_hausdorff

    pred_pts = np.argwhere(preds == 1)
    target_pts = np.argwhere(targets == 1)
    if not (len(pred_pts) and len(target_pts)):
        return 0

    forward = directed_hausdorff(pred_pts, target_pts)[0]
    backward = directed_hausdorff(target_pts, pred_pts)[0]
    return max(forward, backward)

def assd(preds, targets):
    """Return the average symmetric surface distance (ASSD) between two masks, in pixels.

    A mask's surface is its ``== 1`` pixels that touch the background
    (foreground minus its binary erosion). ASSD pools the nearest-surface
    distances from both directions:
    ``(sum d(P->T) + sum d(T->P)) / (|surface P| + |surface T|)``.
    Using whole regions instead of surfaces would make every overlapping pixel
    contribute 0 and under-report the error.

    Returns 0 when either mask has no pixel equal to 1. NOTE(review): that
    scores a completely missed structure as perfect.
    """
    from scipy.ndimage import binary_erosion
    from scipy.spatial import cKDTree

    def _surface(mask):
        fg = np.asarray(mask) == 1
        return np.argwhere(fg & ~binary_erosion(fg))

    pred_surface = _surface(preds)
    target_surface = _surface(targets)
    if len(pred_surface) == 0 or len(target_surface) == 0:
        return 0

    # KD-tree queries keep memory linear in the number of surface pixels,
    # unlike a full pairwise distance matrix.
    d_pred_to_target, _ = cKDTree(target_surface).query(pred_surface)
    d_target_to_pred, _ = cKDTree(pred_surface).query(target_surface)
    total = d_pred_to_target.sum() + d_target_to_pred.sum()
    return float(total / (len(d_pred_to_target) + len(d_target_to_pred)))


In [None]:
import os
import matplotlib.pyplot as plt
import torch
import numpy as np
from scipy.spatial.distance import directed_hausdorff, cdist
from sklearn.metrics import roc_auc_score, f1_score

# (metric key, per-sample label, summary label), in print order.
_METRIC_LABELS = (
    ("dice", "DICE", "DICE"),
    ("jaccard", "Jaccard", "Jaccard"),
    ("precision", "Precision", "Precision"),
    ("recall", "Recall", "Recall"),
    ("specificity", "Specificity", "Specificity"),
    ("hd", "Hausdorff Distance", "Hausdorff Distance"),
    ("assd", "ASSD", "ASSD"),
    ("f1", "F1 Score", "F1 Score"),
    ("voe", "VOE (Volume Overlap Error)", "VOE"),
    ("msd", "Mean Surface Distance (MSD)", "Mean Surface Distance"),
    ("roc_auc", "ROC-AUC", "ROC-AUC"),
)


def _plot_sample(image, mask, pred, sample_index, save_dir):
    """Draw input / ground-truth / prediction panels for one sample, then save (if save_dir) or show the figure."""
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (axs[0], image, "Input Image", None),
        (axs[1], mask, "Ground Truth Mask", 'gray'),
        (axs[2], pred, "Predicted Mask", 'gray'),
    )
    for ax, data, title, cmap in panels:
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()

    if save_dir:
        file_path = f"{save_dir}/sample_{sample_index:04d}.png"
        plt.savefig(file_path, dpi=150)
        print(f"📁 Saved image to: {file_path}")
    else:
        plt.show()
    plt.close(fig)


def _sample_metrics(pred, mask, prob):
    """Compute every per-sample metric from a binarised prediction, its ground truth and the probabilities."""
    try:
        f1 = f1_score(mask.flatten().astype(int), pred.flatten().astype(int))
    except ValueError:
        f1 = np.nan
    try:
        # Undefined (ValueError) when the mask contains a single class.
        roc_auc = roc_auc_score(mask.flatten(), prob.flatten())
    except ValueError:
        roc_auc = np.nan

    return {
        "dice": dice_coefficient(pred, mask),
        "jaccard": jaccard_index(pred, mask),
        "precision": precision(pred, mask),
        "recall": recall(pred, mask),
        "specificity": specificity(pred, mask),
        "hd": hausdorff_distance(pred, mask),
        "assd": assd(pred, mask),
        "f1": f1,
        "voe": voe_metric(pred, mask),
        "msd": mean_surface_distance(pred, mask),
        "roc_auc": roc_auc,
    }


def visualize_all_predictions(model, test_loader, device, save_dir=None):
    """Plot, score and summarise every sample in ``test_loader``.

    Predictions are ``sigmoid(model(images)) > 0.5``. For each sample, a
    3-panel figure is saved under ``save_dir`` (created if missing), or shown
    when ``save_dir`` is falsy, and its metrics are printed. At the end, the
    NaN-ignoring mean of each metric is printed.

    Args:
        model: network mapping [B, C, H, W] images to logits with one channel.
        test_loader: iterable of ``(images, masks)`` batches; sample numbering
            uses ``test_loader.batch_size``.
        device: device the images are moved to for inference.
        save_dir: output directory for figures, or None to display them.

    Returns:
        dict mapping each metric key (e.g. ``"dice"``) to its NaN-ignoring mean.
    """
    model.eval()
    per_sample = {key: [] for key, _, _ in _METRIC_LABELS}

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(test_loader):
            probs = torch.sigmoid(model(images.to(device)))
            preds = (probs > 0.5).float()

            images_np = images.cpu().permute(0, 2, 3, 1).numpy()  # channels-last for imshow
            masks_np = masks.cpu().numpy()
            preds_np = preds.cpu().numpy()
            probs_np = probs.cpu().numpy()

            for i in range(images.size(0)):
                mask = masks_np[i][0]
                pred = preds_np[i][0]
                prob = probs_np[i][0]
                sample_index = batch_idx * test_loader.batch_size + i

                _plot_sample(images_np[i], mask, pred, sample_index, save_dir)

                scores = _sample_metrics(pred, mask, prob)
                print(f"Sample {sample_index:04d} Metrics:")
                for key, label, _ in _METRIC_LABELS:
                    print(f"{label}: {scores[key]:.4f}")
                    per_sample[key].append(scores[key])
                print("-" * 50)

    summary = {key: np.nanmean(per_sample[key]) for key, _, _ in _METRIC_LABELS}
    print("\n📈 Average Test Metrics (Overall):")
    for key, _, label in _METRIC_LABELS:
        print(f"{label}: {summary[key]:.4f}")
    return summary


In [None]:
# Write every prediction figure under ./model_predictions in the working directory.
save_dir = "./model_predictions"
visualize_all_predictions(model=model, test_loader=test_loader, device=device, save_dir=save_dir)

