In [None]:
# Gerekli temel kütüphaneler
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# PyTorch ve torchvision modülleri
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# PyTorch ile TensorBoard için gerekli
from torch.utils.tensorboard import SummaryWriter

# MedSegBench veri kümesi
from medsegbench import NusetMSBench

# Sklearn metrikleri
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, jaccard_score


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
# Asking for the CUDA device name raises on CPU-only machines, so only query it
# when a GPU was actually selected.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))

# Dataset için gerekli dönüşümler

In [None]:
# Images and masks both go through ToTensor; no augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Keyword arguments shared by all three splits, passed unchanged to each one.
split_options = dict(size=256, transform=transform, target_transform=transform, download=True)
train_dataset = NusetMSBench(split='train', **split_options)
val_dataset = NusetMSBench(split='val', **split_options)
test_dataset = NusetMSBench(split='test', **split_options)

# Report how many samples each split contains.
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

print("-------------------------------------------")

# Mini-batch size used by every loader.
bs = 8

# Only the training loader is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=bs)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=bs)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=bs)

# Report how many batches each loader yields per pass.
print(f"Number of batches in training set: {len(train_loader)}")
print(f"Number of batches in validation set: {len(val_loader)}")
print(f"Number of batches in test set: {len(test_loader)}")

print("-------------------------------------------")

# Inspect the first training sample: element 0 is the image, element 1 the mask.
first_sample = train_dataset[0]
print(f"Image size: {first_sample[0].shape}")
print(f"Label size: {first_sample[1].shape}")


# Loss fonksiyonunu tanımlıyoruz

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    The loss is ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)`` where
    ``P`` is the sigmoid of the logits and ``T`` the target mask. Both are
    flattened over every dimension, so one Dice score is computed for the
    whole batch rather than one per sample.

    Args:
        smooth: Small constant added to numerator and denominator so the
            ratio stays defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar Dice loss.

        Args:
            inputs: Raw model outputs (logits); the sigmoid is applied here.
            targets: Ground-truth mask with the same number of elements as
                ``inputs``.

        Returns:
            A zero-dimensional tensor holding ``1 - dice``.
        """
        # The model emits logits, so squash them into probabilities first.
        inputs = torch.sigmoid(inputs)

        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # sliced masks, are flattened instead of raising a RuntimeError.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice

# Model

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    When the stride is not 1 or the channel count changes, the shortcut is
    passed through a 1x1 convolution plus batch norm so it can be added to the
    main path; otherwise the input itself is used as the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path: conv -> bn -> relu -> conv -> bn.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection shortcut, only built when the shapes would not match.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        # Main path.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Shortcut: identity, or the projection when one was built.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class ResNet18_UNet(nn.Module):
    """U-Net-style segmentation network with a ResNet-18-shaped encoder.

    The encoder is a 7x7 stride-2 stem, a stride-2 max-pool and four stages of
    two ``BasicBlock`` modules (64, 128, 256 and 512 channels). The decoder
    upsamples with four stride-2 transposed convolutions and merges encoder
    features by element-wise addition rather than concatenation.

    The returned logits have half the spatial size of the input: the encoder
    downsamples by 32 in total while the four decoder steps upsample by 16.
    Input height and width should be multiples of 32 (e.g. 256); otherwise the
    skip additions receive tensors of different sizes.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        input_channels: Number of channels of the input image. Defaults to 1,
            the value that used to be hard-coded.
    """

    def __init__(self, n_classes, input_channels=1):
        super().__init__()

        # Encoder
        # Tracks the channel count fed to the next stage built by _make_layer.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.encoder1 = self._make_layer(64, 2)
        self.encoder2 = self._make_layer(128, 2, stride=2)
        self.encoder3 = self._make_layer(256, 2, stride=2)
        self.encoder4 = self._make_layer(512, 2, stride=2)

        # Decoder: each block doubles the resolution and matches the channel
        # count of the encoder feature map it is added to in forward().
        self.decoder4 = self.upsample_block(512, 256)
        self.decoder3 = self.upsample_block(256, 128)
        self.decoder2 = self.upsample_block(128, 64)
        self.decoder1 = self.upsample_block(64, 64)

        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one encoder stage of ``blocks`` residual blocks.

        Only the first block uses ``stride`` and changes the channel count;
        the remaining blocks keep the shape. ``self.in_channels`` is updated
        so the next stage knows its input width.
        """
        layers = [BasicBlock(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def upsample_block(self, in_channels, out_channels):
        """Return a 2x upsampling block: transposed conv -> ReLU -> batch norm."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """Return raw logits with ``n_classes`` channels at half the input size."""
        # Stem output (1/2 resolution) is kept for the last skip connection.
        x1 = self.relu(self.bn1(self.conv1(x)))
        x2 = self.maxpool(x1)
        e1 = self.encoder1(x2)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Additive skip connections: each upsampled map is summed with the
        # encoder feature map of the same resolution.
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2) + x1

        out = self.final_conv(d1)
        return out


# One output channel (binary mask), trained with Adam on the Dice objective.
model = ResNet18_UNet(n_classes=1)
model = model.to(device)
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Model Training

In [None]:
# Checkpoints and the plain-text training log live in the same directory.
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)
log_file_path = os.path.join(save_dir, "NusetMSBenchDiceLoss.txt")

# Training schedule and the running best validation loss (lower is better).
num_epochs = 150
best_val_loss = float("inf")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TensorBoard event writer for scalars and validation images.
writer = SummaryWriter(log_dir="./logs/NusetMSBenchDiceLoss")

# Helper used to report pixel accuracy during training and validation.
def compute_accuracy(pred_mask, true_mask):
    """Return the fraction of elements on which the two masks agree.

    Args:
        pred_mask: Tensor of predicted mask values.
        true_mask: Tensor of reference mask values, comparable element-wise
            with ``pred_mask``.

    Returns:
        The agreement ratio as a Python float.
    """
    n_matching = torch.eq(pred_mask, true_mask).sum().float()
    return (n_matching / true_mask.numel()).item()

# Eğitim
# Train for `num_epochs`, validating after every epoch. Per-epoch averages go
# to stdout, TensorBoard and the text log; the weights with the lowest
# validation loss are saved as the best checkpoint. The try/finally makes sure
# the TensorBoard writer is flushed and closed even if training is interrupted.
try:
    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        train_loss = []
        train_acc = []

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):

            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass (raw logits).
            logits_mask = model(inputs)

            # The network output is smaller than its input, so resize the
            # labels to whatever spatial size the logits actually have instead
            # of a hard-coded (128, 128).
            labels = F.interpolate(labels, size=logits_mask.shape[-2:], mode='bilinear', align_corners=False)

            # Loss
            loss = criterion(logits_mask, labels)
            train_loss.append(loss.item())

            # Pixel accuracy is only a metric, so keep it out of the autograd graph.
            with torch.no_grad():
                prob_mask = logits_mask.sigmoid()
                pred_mask = (prob_mask > 0.5).float()
                true_mask = (labels > 0.5).float()  # binarise the resized labels
                acc = compute_accuracy(pred_mask, true_mask)
            train_acc.append(acc)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---- Validation pass ----
        val_loss = []
        val_acc = []
        model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(val_loader):
                inputs, labels = inputs.to(device), labels.to(device)

                logits_mask = model(inputs)

                # Resize the labels to the logits' spatial size (see above).
                labels = F.interpolate(labels, size=logits_mask.shape[-2:], mode='bilinear', align_corners=False)

                loss = criterion(logits_mask, labels)
                val_loss.append(loss.item())

                prob_mask = logits_mask.sigmoid()
                pred_mask = (prob_mask > 0.5).float()
                true_mask = (labels > 0.5).float()  # binarise the resized labels
                acc = compute_accuracy(pred_mask, true_mask)
                val_acc.append(acc)

                # Every 5th epoch, log example images. Only the first batch is
                # written: logging every batch under the same tag and step
                # just piles duplicate entries into the event file.
                if epoch % 5 == 0 and batch_idx == 0:
                    writer.add_images('Validation/Input', inputs, epoch)
                    writer.add_images('Validation/Predicted', pred_mask, epoch)
                    writer.add_images('Validation/Target', true_mask, epoch)

        # ---- End-of-epoch summaries (unweighted mean over batches) ----
        train_loss_avg = sum(train_loss) / len(train_loss)
        train_acc_avg = sum(train_acc) / len(train_acc)
        val_loss_avg = sum(val_loss) / len(val_loss)
        val_acc_avg = sum(val_acc) / len(val_acc)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss_avg:.4f} | Accuracy: {train_acc_avg:.4f}")
        print(f"            Val Loss:   {val_loss_avg:.4f} | Accuracy: {val_acc_avg:.4f}")

        # TensorBoard
        writer.add_scalars(
            "Loss", {"Train": train_loss_avg, "Validation": val_loss_avg}, epoch
        )
        writer.add_scalars(
            "Accuracy", {"Train": train_acc_avg, "Validation": val_acc_avg}, epoch
        )

        # Append the epoch results to the text log.
        with open(log_file_path, "a") as f:
            f.write(f"Epoch {epoch+1}/{num_epochs}\n")
            f.write(f"Train Loss: {train_loss_avg:.4f} | Train Accuracy: {train_acc_avg:.4f}\n")
            f.write(f"Val Loss:   {val_loss_avg:.4f} | Val Accuracy:   {val_acc_avg:.4f}\n")
            f.write("-" * 50 + "\n")

        # Keep the weights with the lowest validation loss seen so far.
        current_val_loss = val_loss_avg
        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            torch.save(model.state_dict(), os.path.join(save_dir, "NusetMSBenchDiceLoss.pt"))
            print(f"En iyi model kaydedildi. Kayıp: {best_val_loss:.4f}")

            # Record the new best in the text log as well.
            with open(log_file_path, "a") as f:
                f.write(f"--> En iyi model kaydedildi. Validation Loss: {best_val_loss:.4f}\n")
                f.write("=" * 50 + "\n")

        print("")
finally:
    writer.close()


# Model Test

In [None]:
def test_and_evaluate(model, test_loader, device, num_samples=5):
    """Evaluate a binary segmentation model on a whole loader and plot examples.

    Every batch of ``test_loader`` contributes to the printed metrics
    (accuracy, precision, recall, F1, IoU, all computed per pixel). Only the
    first ``num_samples`` images are kept for the side-by-side plots.

    Args:
        model: Network returning one-channel logits; they may be smaller than
            the labels and are upsampled to the label size here.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_samples: Number of examples to visualise.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()

    # Per-batch binary masks for the metrics, plus a few examples for plotting.
    all_true_masks = []
    all_pred_masks = []
    input_images = []
    true_masks = []
    pred_masks = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            prob_mask = torch.sigmoid(model(inputs))

            # Upsample the probabilities to the label resolution and only then
            # threshold, so the predicted mask stays strictly binary.
            prob_mask = F.interpolate(prob_mask, size=labels.shape[-2:], mode='bilinear', align_corners=False)
            pred_mask = (prob_mask > 0.5).float()

            # Same 0.5 threshold the training loop applies to the labels.
            true_mask = (labels > 0.5).float()

            all_true_masks.append(true_mask.cpu().numpy().astype(np.uint8))
            all_pred_masks.append(pred_mask.cpu().numpy().astype(np.uint8))

            # Keep the first few samples for plotting. The loop itself keeps
            # going so the metrics cover the full test set, not just the
            # batches needed to fill the gallery.
            n_missing = max(num_samples - len(input_images), 0)
            for i in range(min(n_missing, inputs.size(0))):
                input_images.append(inputs[i].cpu().squeeze().numpy())
                true_masks.append(labels[i].cpu().squeeze().numpy())
                pred_masks.append(pred_mask[i].cpu().squeeze().numpy())

    if not all_true_masks:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # --- Compute and print the metrics over all pixels ---
    all_true_masks = np.concatenate(all_true_masks, axis=0).ravel()
    all_pred_masks = np.concatenate(all_pred_masks, axis=0).ravel()

    accuracy = accuracy_score(all_true_masks, all_pred_masks)
    precision = precision_score(all_true_masks, all_pred_masks)
    recall = recall_score(all_true_masks, all_pred_masks)
    f1 = f1_score(all_true_masks, all_pred_masks)
    iou = jaccard_score(all_true_masks, all_pred_masks)

    print("=== Overall Metrics ===")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1-Score : {f1:.4f}")
    print(f"IoU      : {iou:.4f}")
    print("=" * 30)

    # --- Show input / ground truth / prediction for the kept samples ---
    for idx in range(len(input_images)):
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(input_images[idx], cmap='gray')
        axs[0].set_title('Input Image')
        axs[1].imshow(true_masks[idx], cmap='gray')
        axs[1].set_title('Ground Truth')
        axs[2].imshow(pred_masks[idx], cmap='gray')
        axs[2].set_title('Predicted Mask')
        for ax in axs:
            ax.axis('off')
        plt.tight_layout()
        plt.show()

# Modeli yükleyip test et
# Read the best checkpoint onto the active device, then restore it into a
# freshly built network before evaluating.
best_state = torch.load("./checkpoints/NusetMSBenchDiceLoss.pt", map_location=device)
model = ResNet18_UNet(n_classes=1)
model = model.to(device)
model.load_state_dict(best_state)
model.eval()

test_and_evaluate(model, test_loader, device, num_samples=5)


# Model Output


In [None]:

# Rebuild the network and load the best checkpoint. map_location keeps this
# cell usable on CPU-only machines when the checkpoint was saved from a GPU.
model = ResNet18_UNet(n_classes=1).to(device)
model.load_state_dict(torch.load("./checkpoints/NusetMSBenchDiceLoss.pt", map_location=device))
model.eval()

# Visualise the first test batch: input, ground truth, prediction, error map.
for inputs, labels in test_loader:
    inputs, labels = inputs.to(device), labels.to(device)

    with torch.no_grad():
        logits_mask = model(inputs)
        prob_mask = logits_mask.sigmoid()

        # Bring the probabilities to the ground-truth resolution and threshold
        # afterwards. Resizing an already binary mask bilinearly leaves
        # fractional values along the edges, which then never compare equal
        # to the ground truth.
        prob_mask = F.interpolate(prob_mask, size=labels.shape[2:], mode='bilinear', align_corners=False)
        pred_mask = (prob_mask > 0.5).float()

    batch_size = inputs.shape[0]
    plt.figure(figsize=(15, batch_size * 4))

    for i in range(batch_size):
        img = inputs[i, 0].cpu().numpy()
        true = labels[i, 0].cpu().numpy()
        pred = pred_mask[i, 0].cpu().numpy()

        # Per-pixel agreement, comparing both masks as binary values.
        correct = (true > 0.5) == (pred > 0.5)
        error_map = np.zeros((*true.shape, 3))
        error_map[correct] = [0, 1, 0]     # correct -> green
        error_map[~correct] = [1, 0, 0]    # wrong -> red

        # Input image
        plt.subplot(batch_size, 4, i * 4 + 1)
        plt.imshow(img, cmap='gray')
        plt.title("Girdi Görüntü")
        plt.axis('off')

        # Ground-truth mask
        plt.subplot(batch_size, 4, i * 4 + 2)
        plt.imshow(true, cmap='gray')
        plt.title("Gerçek Maske")
        plt.axis('off')

        # Predicted mask
        plt.subplot(batch_size, 4, i * 4 + 3)
        plt.imshow(pred, cmap='viridis')
        plt.title("Tahmin Maske")
        plt.axis('off')

        # Error map
        plt.subplot(batch_size, 4, i * 4 + 4)
        plt.imshow(error_map)
        plt.title("Doğru (Yeşil) / Hatalı (Kırmızı)")
        plt.axis('off')

    plt.tight_layout()
    plt.show()
    break  # only the first batch is shown