In [None]:
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt


In [None]:
# Dataset locations: one image folder and one label-mask folder per split.
train_image_dir, train_mask_dir = 'dataset/train', 'dataset/train_label'
val_image_dir, val_mask_dir = 'dataset/val', 'dataset/val_label'

In [None]:
def analyze_image_properties(image_dir):
    """Collect size and band-count statistics for every entry of a directory.

    Each entry returned by ``os.listdir(image_dir)`` is opened with PIL.
    Entries that raise while being read are reported on stdout and recorded
    by name instead of aborting the scan.

    Args:
        image_dir: Path of the directory to scan.

    Returns:
        A ``(resolutions, channels, missing)`` tuple of lists: ``img.size`` of
        each readable image, the number of bands of each readable image, and
        the names of the entries that could not be read.
    """
    resolutions = []
    channels = []
    missing = []

    for entry in os.listdir(image_dir):
        full_path = os.path.join(image_dir, entry)
        try:
            with Image.open(full_path) as img:
                resolutions.append(img.size)
                band_count = len(img.getbands())
                channels.append(band_count)
        except Exception as err:
            # Keep scanning: one unreadable entry should not hide the rest.
            print(f"Error loading {entry}: {err}")
            missing.append(entry)

    return resolutions, channels, missing

# Summarise the scan instead of printing one list entry per image, which
# floods the cell output for any realistically sized dataset.
train_resolutions, train_channels, train_unreadable = analyze_image_properties(train_image_dir)
print(f"Readable images: {len(train_resolutions)} | unreadable entries: {len(train_unreadable)}")
print(f"Distinct resolutions (W, H): {sorted(set(train_resolutions))}")
print(f"Distinct channel counts: {sorted(set(train_channels))}")

def get_unique_classes(mask_dir):
    """Return the set of distinct pixel values found across all masks.

    Every entry of ``mask_dir`` is opened with PIL and converted to a NumPy
    array; its unique values are merged into one set. Entries that raise while
    being read are reported on stdout and skipped.

    Args:
        mask_dir: Path of the directory holding the mask files.

    Returns:
        A ``set`` with the unique values seen in the readable masks.
    """
    found = set()
    for entry in os.listdir(mask_dir):
        mask_path = os.path.join(mask_dir, entry)
        try:
            with Image.open(mask_path) as mask:
                values = np.unique(np.array(mask))
                found.update(values)
        except Exception as err:
            print(f"Error loading mask {entry}: {err}")
    return found


In [None]:
class SegmentationDataset(Dataset):
    """Paired image / label-mask dataset read from two directories.

    The mask of ``<name>.png`` is looked up as ``<name>_label.png`` inside
    ``mask_dir``; file names without a ``.png`` suffix are looked up unchanged.

    Note:
        ``transform`` is applied to the image ONLY. The mask is merely resized
        to ``mask_size`` with nearest-neighbour interpolation. ``transform``
        must therefore not contain random geometric operations (flips,
        rotations, crops): they would move image pixels without moving the
        corresponding labels.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the RGB PIL image.
        mask_size: ``(height, width)`` the mask is resized to when
            ``transform`` is given. Must match the output size of ``transform``.
    """

    _PNG_SUFFIX = '.png'

    def __init__(self, image_dir, mask_dir, transform=None, mask_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_size = mask_size
        # os.listdir order is arbitrary and also returns hidden files and
        # sub-directories (e.g. .ipynb_checkpoints); keep regular, visible
        # files only and sort them so that indices are reproducible.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        ``image`` is the RGB image after ``transform`` (or the raw PIL image
        when no transform is set); ``mask`` is an int64 tensor built from the
        single-channel ('L') mask.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Only rewrite the extension; str.replace would touch every '.png'
        # occurrence inside the file name.
        if img_name.endswith(self._PNG_SUFFIX):
            mask_name = img_name[:-len(self._PNG_SUFFIX)] + '_label.png'
        else:
            mask_name = img_name
        mask_path = os.path.join(self.mask_dir, mask_name)

        # convert() returns a fully loaded copy, so the file handles can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')

        if self.transform:
            image = self.transform(image)
            # Nearest-neighbour keeps class ids intact (no blended labels).
            mask = TF.resize(
                mask,
                self.mask_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            )

        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask

In [None]:
# Image-only preprocessing for training.
# SegmentationDataset applies this pipeline to the image alone (the mask only
# gets a nearest-neighbour resize), so it must NOT contain random geometric
# operations such as flips or rotations: they would move image pixels without
# moving the labels. Photometric jitter does not change geometry and is safe.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
    transforms.ToTensor(),
])

# Validation must be deterministic so that loss / IoU / Dice are comparable
# between epochs: resize and tensor conversion only, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_loader = DataLoader(
    SegmentationDataset(train_image_dir, train_mask_dir, train_transform),
    batch_size=16, shuffle=True
)
val_loader = DataLoader(
    SegmentationDataset(val_image_dir, val_mask_dir, val_transform),
    batch_size=16, shuffle=False
)


In [None]:
class UNet(nn.Module):
    """U-Net with five encoder stages, a bottleneck and five decoder stages.

    Each encoder output is concatenated (along the channel axis) with the
    upsampled decoder feature of the same resolution. The input is halved five
    times by 2x2 max-pooling and doubled five times by stride-2 transposed
    convolutions, and a final 1x1 convolution maps to ``num_classes`` channels.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, num_classes):
        super().__init__()
        w1, w2, w3, w4, w5 = 64, 128, 256, 512, 1024
        w_mid = 2048

        # Contracting path (3-channel input).
        self.enc1 = self.conv_block(3, w1)
        self.enc2 = self.conv_block(w1, w2)
        self.enc3 = self.conv_block(w2, w3)
        self.enc4 = self.conv_block(w3, w4)
        self.enc5 = self.conv_block(w4, w5)
        self.bottleneck = self.conv_block(w5, w_mid)

        # Expanding path: every decoder block receives the upsampled feature
        # concatenated with the matching skip, hence twice the channel count.
        self.up5 = nn.ConvTranspose2d(w_mid, w5, kernel_size=2, stride=2)
        self.dec5 = self.conv_block(2 * w5, w5)
        self.up4 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(2 * w4, w4)
        self.up3 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(2 * w3, w3)
        self.up2 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(2 * w2, w2)
        self.up1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(2 * w1, w1)
        self.out = nn.Conv2d(w1, num_classes, kernel_size=1)

    def conv_block(self, in_ch, out_ch):
        """Build two 3x3 same-padding convolutions, each followed by ReLU."""
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the per-class output map."""
        skips = []
        feat = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            feat = stage(feat)
            skips.append(feat)
            feat = F.max_pool2d(feat, 2)
        feat = self.bottleneck(feat)

        decoder = (
            (self.up5, self.dec5),
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        # Deepest skip first: reversed(skips) yields e5, e4, ..., e1.
        for (upsample, refine), skip in zip(decoder, reversed(skips)):
            feat = refine(torch.cat([upsample(feat), skip], dim=1))
        return self.out(feat)


In [None]:
# Seed every RNG in use so that weight initialisation, loader shuffling and
# random sample selection are reproducible after "Restart & Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# ReduceLROnPlateau defaults to patience=10, but the training loop stops after
# 5 epochs without improvement, so with the default the learning rate would
# practically never be reduced. Keep the scheduler patience below that limit.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)



In [None]:
def calculate_iou(preds, masks, num_classes):
    """Mean intersection-over-union over ``num_classes`` classes.

    Args:
        preds: Per-class scores; the predicted label is the argmax over dim 1.
        masks: Ground-truth class ids, compared element-wise with the labels.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The unweighted mean of the per-class IoU values as a Python float.
        A class that appears in neither prediction nor target counts as 1.0.
    """
    labels = preds.argmax(dim=1)
    per_class = []
    for cls in range(num_classes):
        pred_hit = labels.eq(cls).float()
        true_hit = masks.eq(cls).float()
        overlap = (pred_hit * true_hit).sum()
        combined = pred_hit.sum() + true_hit.sum() - overlap
        if combined == 0:
            per_class.append(1.0)
        else:
            per_class.append((overlap / combined).item())
    return sum(per_class) / len(per_class)

def calculate_dice(preds, masks, num_classes):
    """Mean Dice coefficient over ``num_classes`` classes.

    Args:
        preds: Per-class scores; the predicted label is the argmax over dim 1.
        masks: Ground-truth class ids, compared element-wise with the labels.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The unweighted mean of the per-class Dice values as a Python float.
        A class that appears in neither prediction nor target counts as 1.0.
    """
    labels = preds.argmax(dim=1)
    per_class = []
    for cls in range(num_classes):
        pred_hit = labels.eq(cls).float()
        true_hit = masks.eq(cls).float()
        overlap = (pred_hit * true_hit).sum()
        total = pred_hit.sum() + true_hit.sum()
        if total == 0:
            per_class.append(1.0)
        else:
            per_class.append((2 * overlap / total).item())
    return sum(per_class) / len(per_class)


In [None]:
# Training configuration and per-epoch history (the lists feed plot_metrics).
num_epochs = 50
train_loss_list, val_loss_list = [], []
iou_list, dice_list = [], []

# Early stopping / checkpointing state.
best_val_loss = float('inf')
early_patience, early_counter = 5, 0
save_path = 'best_model.pth'

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device).long()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    val_loss = iou_total = dice_total = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device).long()
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            # Take the class count from the channel dimension of the output
            # so the metrics cannot drift from the model configuration.
            n_classes = outputs.shape[1]
            iou_total += calculate_iou(outputs, masks, n_classes)
            dice_total += calculate_dice(outputs, masks, n_classes)

    # Per-batch averages for this epoch.
    avg_train = running_loss / len(train_loader)
    avg_val   = val_loss / len(val_loader)
    avg_iou   = iou_total / len(val_loader)
    avg_dice  = dice_total / len(val_loader)

    train_loss_list.append(avg_train)
    val_loss_list.append(avg_val)
    iou_list.append(avg_iou)
    dice_list.append(avg_dice)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"| Train Loss: {avg_train:.4f} "
          f"| Val Loss: {avg_val:.4f} "
          f"| MIoU: {avg_iou:.4f} "
          f"| Dice: {avg_dice:.4f}")

    scheduler.step(avg_val)

    if avg_val < best_val_loss:
        # New best epoch: checkpoint the weights and reset the patience counter.
        best_val_loss = avg_val
        torch.save(model.state_dict(), save_path)
        early_counter = 0
    else:
        early_counter += 1
        if early_counter >= early_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# The in-memory model holds the weights of the LAST epoch, which is several
# epochs past the best one whenever early stopping triggered. Reload the best
# checkpoint so that the following cells evaluate the model that was selected.
# The guard skips the reload when this run never saved a checkpoint, which
# avoids loading a stale file left over from a previous run.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(save_path, map_location=device))
    print(f"Restored best weights (val loss {best_val_loss:.4f}) from {save_path}")


In [None]:
def plot_metrics(train_loss, val_loss, iou, dice):
    """Show two figures: the loss curves and the IoU / Dice curves.

    Args:
        train_loss: Per-epoch training loss; its length defines the x axis.
        val_loss: Per-epoch validation loss.
        iou: Per-epoch mean IoU.
        dice: Per-epoch Dice coefficient.
    """
    epochs = range(1, len(train_loss) + 1)

    def _draw(curves, ylabel, title):
        # One 10x5 figure holding every (values, label) pair of `curves`.
        _, axis = plt.subplots(figsize=(10, 5))
        for values, label in curves:
            axis.plot(epochs, values, label=label)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(ylabel)
        axis.set_title(title)
        axis.legend()
        axis.grid(True)
        plt.show()

    _draw([(train_loss, 'Train Loss'), (val_loss, 'Val Loss')], 'Loss', 'Loss Curve')
    _draw([(iou, 'Mean IoU'), (dice, 'Dice Coefficient')], 'Score', 'Metrics Curve')

# Plot the history gathered by the training loop.
plot_metrics(
    train_loss=train_loss_list,
    val_loss=val_loss_list,
    iou=iou_list,
    dice=dice_list,
)



In [None]:
# NOTE(review): this cell repeats the plot_metrics call of the previous cell
# and renders the same two figures a second time; it looks like a leftover
# and can probably be removed.
plot_metrics(train_loss_list, val_loss_list, iou_list, dice_list)

In [None]:
def visualize_predictions(model, dataloader, device, num_images=5):
    """Plot input, ground truth and prediction for randomly chosen samples.

    Samples are drawn without replacement from ``dataloader.dataset`` and read
    directly from the dataset, one grid row per sample. Reading by dataset
    index avoids reconstructing indices from batch offsets, which breaks on a
    short final batch and on shuffled loaders.

    Args:
        model: Network returning per-class scores with classes on dim 1.
        dataloader: Loader whose ``dataset`` yields ``(image, mask)`` pairs;
            ``image`` is expected to be a (C, H, W) tensor, as required by the
            ``permute(1, 2, 0)`` used for display.
        device: Device the model lives on.
        num_images: Number of samples to show; clamped to the dataset size.
    """
    model.eval()
    dataset = dataloader.dataset
    # random.sample raises when asked for more items than are available.
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        print("No samples available to visualize.")
        return
    chosen = random.sample(range(len(dataset)), num_images)

    # squeeze=False keeps `ax` two-dimensional even for a single row.
    fig, ax = plt.subplots(num_images, 3, figsize=(10, num_images * 4), squeeze=False)
    with torch.no_grad():
        for row, sample_idx in enumerate(chosen):
            image, mask = dataset[sample_idx]
            logits = model(image.unsqueeze(0).to(device))
            pred = torch.argmax(logits, dim=1)[0].cpu()

            ax[row][0].imshow(image.permute(1, 2, 0))
            ax[row][0].set_title('Input')
            ax[row][1].imshow(mask, cmap='gray')
            ax[row][1].set_title('Ground Truth')
            ax[row][2].imshow(pred, cmap='gray')
            ax[row][2].set_title('Prediction')

    plt.tight_layout()
    plt.show()

# Show a few validation samples next to the model's predictions.
visualize_predictions(model, dataloader=val_loader, device=device)



In [None]:
def measure_inference_speed(model, dataloader, device, num_images=10):
    """Print the average forward-pass time per image and the resulting FPS.

    Batches are taken from ``dataloader`` until at least ``num_images`` images
    have been timed. Whole batches are always timed, so more than
    ``num_images`` images may be included in the average.

    Args:
        model: Network to benchmark; switched to eval mode.
        dataloader: Iterable yielding ``(images, _)`` batches.
        device: Device the model lives on (``torch.device`` or string).
        num_images: Minimum number of images to time.

    Returns:
        None. Results are printed; if the loader yields no batch a notice is
        printed instead of dividing by zero.
    """
    model.eval()
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        # CUDA kernels are launched asynchronously; without a synchronisation
        # point the timer would only measure the launch overhead.
        if on_cuda:
            torch.cuda.synchronize(device)

    total_time, processed = 0.0, 0
    warmed_up = False
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            if not warmed_up:
                # Untimed first pass: absorbs one-off costs such as lazy
                # initialisation and algorithm selection.
                model(images)
                warmed_up = True
            _sync()
            start = time.perf_counter()
            model(images)
            _sync()
            total_time += time.perf_counter() - start
            processed += len(images)
            if processed >= num_images:
                break

    if processed == 0:
        print("No images were processed; cannot measure inference speed.")
        return

    avg_time = total_time / processed
    fps = 1 / avg_time if avg_time > 0 else float('inf')
    print(f"Avg inference time per image: {avg_time:.4f}s")
    print(f"Estimated FPS: {fps:.2f}")

# Benchmark the forward pass on validation batches (at least 10 images).
measure_inference_speed(model, dataloader=val_loader, device=device, num_images=10)
