In [2]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
# --- Configuration: everything a reader might want to tune lives here ---

# Directories holding the input images and their segmentation masks.
IMG_PATH = "output/pairwise_2x2/images"
MASK_PATH = "output/pairwise_2x2/masks"

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-3

# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the random number generators up front so that weight initialisation,
# the train/test split and batch shuffling are repeatable after
# "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


In [None]:
class PairwiseSegmentationDataset(Dataset):
    """Dataset of (grayscale image, binary mask) pairs read from two directories.

    Images and masks are paired by position in the *sorted* directory
    listings, so the i-th file of ``img_dir`` is matched with the i-th file of
    ``mask_dir``.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        transform: Optional callable applied to each image. When
            ``mask_transform`` is not given it is applied to the mask as well.
        mask_transform: Optional callable applied to each mask instead of
            ``transform``. Lets callers resize masks differently from images
            (e.g. nearest-neighbour interpolation so that mask edges are not
            blurred before binarisation).

    Raises:
        ValueError: If the two directories contain a different number of
            entries, because positional pairing would then be wrong.
    """

    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_paths = sorted(os.listdir(img_dir))
        self.mask_paths = sorted(os.listdir(mask_dir))
        if len(self.img_paths) != len(self.mask_paths):
            # Positional pairing silently mismatches (or raises IndexError
            # later) when the listings differ in length; fail early instead.
            raise ValueError(
                f"Image/mask count mismatch: {len(self.img_paths)} entries in "
                f"{img_dir!r} vs {len(self.mask_paths)} entries in {mask_dir!r}"
            )
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Fall back to the image transform when no mask transform is supplied.
        self.mask_transform = mask_transform if mask_transform is not None else transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(img, mask)``.

        ``mask`` is a float tensor containing only 0.0 and 1.0 (1.0 wherever
        the transformed mask is strictly positive). ``img`` is whatever
        ``transform`` returns; if it is still a PIL image it is converted to a
        float tensor of shape (1, H, W) scaled to [0, 1].
        """
        img_file = os.path.join(self.img_dir, self.img_paths[idx])
        mask_file = os.path.join(self.mask_dir, self.mask_paths[idx])

        # ``convert`` returns a new, fully loaded image, so the underlying
        # files can be closed as soon as the ``with`` blocks exit.
        with Image.open(img_file) as raw_img:
            img = raw_img.convert("L")
        with Image.open(mask_file) as raw_mask:
            mask = raw_mask.convert("L")

        if self.transform is not None:
            img = self.transform(img)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        # Without a tensor-producing transform the objects are still PIL
        # images; comparing a PIL image with ``> 0`` raises TypeError, so
        # convert them to (1, H, W) tensors here.
        if isinstance(img, Image.Image):
            img = torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0
        if isinstance(mask, Image.Image):
            mask = torch.from_numpy(np.array(mask)).unsqueeze(0)

        # Binarise: any strictly positive value counts as foreground.
        mask = (mask > 0).float()
        return img, mask

In [None]:
# Preprocessing pipeline handed to the dataset: resize to 56x56, then convert
# the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(56, 56)),
        transforms.ToTensor(),
    ]
)

In [None]:
class SimpleUNet(nn.Module):
    """Small convolutional encoder/decoder for single-channel segmentation.

    The encoder applies two (3x3 conv -> ReLU -> 2x2 max-pool) stages
    (1 -> 16 -> 32 channels). The decoder applies two stride-2 transposed
    convolutions (32 -> 16 -> 1 channels), ending in a sigmoid so that every
    output value lies in [0, 1]. ``forward`` chains encoder and decoder
    directly; there are no skip connections between them.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as they are applied.
        down_layers = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        up_layers = [
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        ]

        self.encoder = nn.Sequential(*down_layers)
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode then decode ``x`` and return the sigmoid-activated map."""
        return self.decoder(self.encoder(x))

In [None]:
def dice_score(pred, target, threshold=0.5):
    """Compute the Dice coefficient between a prediction and a target mask.

    The prediction is binarised with ``pred > threshold`` and the score
    ``2 * |pred ∩ target| / (|pred| + |target|)`` is computed over all
    elements of the tensors together.

    Args:
        pred: Tensor of predicted scores.
        target: Tensor of the same shape as ``pred`` holding the reference
            mask.
        threshold: Values of ``pred`` strictly above this become foreground.

    Returns:
        The Dice coefficient as a Python float. When both the binarised
        prediction and the target are empty the result is 1.0, because
        predicting "nothing" for an empty mask is a perfect match.
    """
    pred = (pred > threshold).float()
    intersection = (pred * target).sum()
    total = pred.sum() + target.sum()

    # Empty prediction and empty target: without this guard the formula
    # evaluates to 0 / eps = 0 and scores a correct prediction as a miss.
    if total.item() == 0:
        return 1.0

    dice = (2. * intersection) / (total + 1e-8)
    return dice.item()

In [None]:
dataset = PairwiseSegmentationDataset(IMG_PATH, MASK_PATH, transform=transform)

# 80 % of the samples go to training, the remainder to testing.
train_len = int(0.8 * len(dataset))
test_len = len(dataset) - train_len

# Use a dedicated, fixed-seed generator for the split. Without it every
# re-execution of this cell produces a different partition, so samples the
# model was already trained on can end up in the test set.
SPLIT_SEED = 42
train_ds, test_ds = random_split(
    dataset,
    [train_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are reshuffled every epoch; the test set is evaluated one
# sample at a time, so each Dice score corresponds to a single image.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=1)

In [None]:
# Build the network and move its parameters to the selected device.
model = SimpleUNet()
model.to(DEVICE)

# Binary cross-entropy on the model's sigmoid outputs, optimised with Adam.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

In [None]:
# Train for EPOCHS passes over the training set, reporting the mean batch loss
# after each pass.
for epoch in range(EPOCHS):
    model.train()
    batch_losses = []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, masks in progress:
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)

        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()

        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    running_loss = sum(batch_losses, 0.0)
    print(f"Epoch {epoch+1} - Loss: {running_loss/len(train_loader):.4f}")

Epoch 1/5: 100%|██████████| 375/375 [00:58<00:00,  6.43it/s]


Epoch 1 - Loss: 0.1933


Epoch 2/5: 100%|██████████| 375/375 [00:28<00:00, 13.05it/s]


Epoch 2 - Loss: 0.0515


Epoch 3/5: 100%|██████████| 375/375 [00:38<00:00,  9.64it/s]


Epoch 3 - Loss: 0.0425


Epoch 4/5: 100%|██████████| 375/375 [00:31<00:00, 12.04it/s]


Epoch 4 - Loss: 0.0388


Epoch 5/5: 100%|██████████| 375/375 [00:33<00:00, 11.20it/s]

Epoch 5 - Loss: 0.0365





In [12]:
# Evaluate on the held-out set: one Dice score per test batch, then the mean.
model.eval()
dice_scores = []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for imgs, masks in test_loader:
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)
        preds = model(imgs)
        score = dice_score(preds, masks)
        dice_scores += [score]

mean_dice = np.mean(dice_scores)
print(f"\nMean Dice Coefficient on Test Set: {mean_dice:.4f}")


Mean Dice Coefficient on Test Set: 0.9470
