# 🧠 Learned Skeletonization of Road Networks - Final Notebook

In [None]:

# 📍 Number of training images to load: the dataset keeps the first N file names
# in sorted order. Try 5, 50, or 500 (the full set is described as 500 images).
NUM_TRAIN_IMAGES: int = 10


In [None]:

import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt


In [None]:

# 📍 Step 1: Define Dataset
class SkeletonDataset(Dataset):
    """Paired (input, skeleton target) grayscale images read from two folders.

    Each input file ``<name>`` in ``input_dir`` is paired with the file in
    ``target_dir`` named ``<name>`` with every ``"image"`` replaced by
    ``"target"``. Both images are converted to single-channel ("L") and
    returned as tensors produced by ``transforms.ToTensor``.

    Args:
        input_dir: Folder containing the input images.
        target_dir: Folder containing the target images.
        file_list: Optional explicit list of input file names. When ``None``,
            every regular file in ``input_dir`` is used.
        augment: If True, apply the same random horizontal flip to input and
            target of a sample.
        limit: Optional cap on the number of files, applied after sorting.
    """

    def __init__(self, input_dir, target_dir, file_list=None, augment=False, limit=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        if file_list is None:
            # os.listdir also returns sub-directories, which Image.open cannot read.
            file_list = [
                name for name in os.listdir(input_dir)
                if os.path.isfile(os.path.join(input_dir, name))
            ]
        self.files = sorted(file_list)
        if limit is not None:
            self.files = self.files[:limit]
        self.augment = augment
        self.base_transform = transforms.ToTensor()
        # Kept for its flip probability ``p``. The flip itself is applied to both
        # tensors in __getitem__ so that input and target stay aligned.
        self.aug_transform = transforms.RandomHorizontalFlip()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        input_path = os.path.join(self.input_dir, fname)
        target_path = os.path.join(self.target_dir, fname.replace("image", "target"))

        # Close the file handles promptly. convert() returns a new in-memory
        # image, so it remains usable after the source file is closed.
        with Image.open(input_path) as img:
            input_img = img.convert("L")
        with Image.open(target_path) as img:
            target_img = img.convert("L")

        input_tensor = self.base_transform(input_img)
        target_tensor = self.base_transform(target_img)

        # One flip decision is shared by both tensors. It is drawn from the global
        # RNG without reseeding it; reseeding per sample would clobber the RNG
        # state that DataLoader shuffling and other code rely on.
        if self.augment and torch.rand(1).item() < self.aug_transform.p:
            input_tensor = torch.flip(input_tensor, dims=(-1,))
            target_tensor = torch.flip(target_tensor, dims=(-1,))

        return input_tensor, target_tensor


In [None]:

# 📍 Step 2: Define U-Net model
class UNet(nn.Module):
    """Four-level U-Net that outputs a sigmoid probability map per pixel.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output maps (each passed through a sigmoid).
        init_features: Channel width of the first encoder level; it doubles at
            each of the four downsampling steps.

    Input height/width need not be multiples of 16. When pooling floors an odd
    size, the upsampled decoder map is zero-padded to match its skip connection.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super().__init__()
        features = init_features
        self.encoder1 = self._block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self._block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = self._block(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = self._block(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = self._block(features * 8, features * 16)
        # Each decoder takes the upsampled map concatenated with the skip
        # tensor, hence twice the channels of its output.
        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, 2, 2)
        self.decoder4 = self._block(features * 16, features * 8)
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, 2, 2)
        self.decoder3 = self._block(features * 8, features * 4)
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, 2, 2)
        self.decoder2 = self._block(features * 4, features * 2)
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, 2, 2)
        self.decoder1 = self._block(features * 2, features)
        self.final_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.decoder4(self._merge(self.upconv4(bottleneck), enc4))
        dec3 = self.decoder3(self._merge(self.upconv3(dec4), enc3))
        dec2 = self.decoder2(self._merge(self.upconv2(dec3), enc2))
        dec1 = self.decoder1(self._merge(self.upconv1(dec2), enc1))
        return torch.sigmoid(self.final_conv(dec1))

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate an upsampled map with its skip tensor along channels.

        MaxPool2d(2) floors odd sizes, so the 2x transposed conv can come back
        one row/column short of ``skip``. The missing bottom/right edge (the
        part pooling dropped) is zero-padded so the spatial sizes agree. For
        sizes that already match, this is a plain ``torch.cat``.
        """
        dh = skip.size(-2) - upsampled.size(-2)
        dw = skip.size(-1) - upsampled.size(-1)
        if dh or dw:
            upsampled = F.pad(upsampled, (0, dw, 0, dh))
        return torch.cat((upsampled, skip), dim=1)

    def _block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding keeps H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


In [None]:

# 📍 Step 3: Define Dice Loss
def dice_loss(pred, target, smooth=1e-5):
    """Soft Dice loss computed over all elements of ``pred`` and ``target``.

    Both tensors are flattened, so the Dice coefficient is taken over the whole
    batch at once rather than averaged per sample.

    Args:
        pred: Predicted values (in this notebook, sigmoid probabilities).
        target: Ground-truth mask with the same number of elements as ``pred``.
        smooth: Additive constant that prevents division by zero; two all-zero
            inputs give a loss of 0.

    Returns:
        Scalar tensor ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. transposes.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return 1 - (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)


In [None]:

# 📍 Step 4: Training
# Paths presumably point at a Google Drive mounted under /content/gdrive
# (e.g. in Colab). Adjust them for other environments.
input_dir = "/content/gdrive/MyDrive/Classes/ML Advanced/Project/thinning_data/data/images_distorted"
target_dir = "/content/gdrive/MyDrive/Classes/ML Advanced/Project/thinning_data/data/target"

train_dataset = SkeletonDataset(input_dir, target_dir, augment=True, limit=NUM_TRAIN_IMAGES)
if len(train_dataset) == 0:
    # Fail early with a clear message instead of a ZeroDivisionError when the
    # epoch loss is averaged below.
    raise ValueError(
        f"No training images found in {input_dir!r} (NUM_TRAIN_IMAGES={NUM_TRAIN_IMAGES})"
    )
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Choose the device once; later cells reuse `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
bce_loss = nn.BCELoss()

epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        preds = model(images)
        # Combined objective: pixel-wise BCE plus overlap-based Dice loss.
        loss = bce_loss(preds, targets) + dice_loss(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of per-batch losses; the last batch may be smaller than batch_size.
    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/len(train_loader):.4f}")


In [None]:

# 📍 Step 5: Visualization
def show_triplet(input_img, predicted_mask, target_img, idx=None, figsize=(12, 4), threshold=0.3):
    """Plot input, thresholded prediction and ground truth side by side.

    Args:
        input_img: 2-D array-like image shown in grayscale.
        predicted_mask: 2-D array-like of predicted values. Pixels strictly
            greater than ``threshold`` are shown as foreground.
        target_img: 2-D array-like ground-truth skeleton.
        idx: Optional sample index appended to every subplot title.
        figsize: Matplotlib figure size in inches.
        threshold: Binarization cut-off applied to ``predicted_mask``.
    """
    # np.asarray also accepts CPU tensors and nested lists, not only ndarrays.
    binarized_pred = (np.asarray(predicted_mask) > threshold).astype(np.uint8)
    suffix = f" #{idx}" if idx is not None else ''
    panels = [
        ('Input Image', input_img, {}),
        # Fixed colour range for the 0/1 mask. With autoscaling, an all-foreground
        # mask has vmin == vmax and would be rendered entirely black.
        ('Predicted Skeleton', binarized_pred, {'vmin': 0, 'vmax': 1}),
        ('Ground Truth', target_img, {}),
    ]

    plt.figure(figsize=figsize)
    for position, (title, img, scale) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(img, cmap='gray', **scale)
        plt.title(title + suffix)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Evaluate on one sample. This is a training image, read through a non-augmented
# view of the same files so the displayed orientation is deterministic.
# Drawing from train_dataset would apply a random flip to the sample.
eval_dataset = SkeletonDataset(input_dir, target_dir, file_list=train_dataset.files, augment=False)
model.eval()
input_img, target_img = eval_dataset[0]
with torch.no_grad():
    input_batch = input_img.unsqueeze(0).to(device)
    pred = model(input_batch).cpu().numpy()[0, 0]
show_triplet(input_img.squeeze().numpy(), pred, target_img.squeeze().numpy(), idx=0)
