In [10]:
import os
import cv2
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils.metrics import IoU

In [3]:
class RoadDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name.replace('.tiff', '.tif')) 
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            mask = (mask > 0).float().unsqueeze(0)  
        return image, mask

train_transform = A.Compose([
    A.RandomCrop(height=512, width=512),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
], additional_targets={'mask': 'mask'})

val_transform = A.Compose([
    A.PadIfNeeded(min_height=1504, min_width=1504, border_mode=cv2.BORDER_CONSTANT, fill=0, fill_mask=0),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
], additional_targets={'mask': 'mask'})

In [13]:
train_images = './MassachusettsRoads/train'
train_masks = './MassachusettsRoads/train_labels'
val_images = './MassachusettsRoads/val'
val_masks = './MassachusettsRoads/val_labels'
test_images = './MassachusettsRoads/test'
test_masks = './MassachusettsRoads/test_labels'

train_dataset = RoadDataset(train_images, train_masks, transform=train_transform)
val_dataset = RoadDataset(val_images, val_masks, transform=val_transform)
test_dataset = RoadDataset(test_images, test_masks, transform=val_transform)

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", classes=1, activation=None)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")
model = model.to(device)

loss_func = nn.BCEWithLogitsLoss()
iou = IoU(threshold=0.5)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Device: mps


In [7]:
num_epochs = 5
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{num_epochs}", colour="green"):
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_func(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0
    val_iou = 0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f"Val Epoch {epoch+1}/{num_epochs}", colour="red"):
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = loss_func(outputs, masks)
            val_loss += loss.item()
            val_iou += iou(outputs, masks).item()
    val_loss /= len(val_loader)
    val_iou /= len(val_loader)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_weights.pth")
        print(f"Model saved with Val Loss: {val_loss:.4f}")

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}")

Train Epoch 1/5: 100%|[32m██████████[0m| 277/277 [01:40<00:00,  2.76it/s]
Val Epoch 1/5: 100%|[31m██████████[0m| 4/4 [00:03<00:00,  1.12it/s]


Model saved with Val Loss: 0.1023
Epoch 1, Train Loss: 0.0953, Val Loss: 0.1023, Val IoU: 0.4534


Train Epoch 2/5: 100%|[32m██████████[0m| 277/277 [01:40<00:00,  2.75it/s]
Val Epoch 2/5: 100%|[31m██████████[0m| 4/4 [00:03<00:00,  1.13it/s]


Model saved with Val Loss: 0.1009
Epoch 2, Train Loss: 0.0933, Val Loss: 0.1009, Val IoU: 0.4579


Train Epoch 3/5: 100%|[32m██████████[0m| 277/277 [01:40<00:00,  2.76it/s]
Val Epoch 3/5: 100%|[31m██████████[0m| 4/4 [00:03<00:00,  1.15it/s]


Model saved with Val Loss: 0.0986
Epoch 3, Train Loss: 0.0908, Val Loss: 0.0986, Val IoU: 0.4654


Train Epoch 4/5: 100%|[32m██████████[0m| 277/277 [01:39<00:00,  2.77it/s]
Val Epoch 4/5: 100%|[31m██████████[0m| 4/4 [00:03<00:00,  1.12it/s]


Model saved with Val Loss: 0.0970
Epoch 4, Train Loss: 0.0894, Val Loss: 0.0970, Val IoU: 0.4818


Train Epoch 5/5: 100%|[32m██████████[0m| 277/277 [01:40<00:00,  2.77it/s]
Val Epoch 5/5: 100%|[31m██████████[0m| 4/4 [00:03<00:00,  1.13it/s]

Model saved with Val Loss: 0.0953
Epoch 5, Train Loss: 0.0881, Val Loss: 0.0953, Val IoU: 0.5004





In [8]:
model.load_state_dict(torch.load(weights_path, map_location=device))
print(f"Loaded weights from {weights_path}")

model.eval()
test_loss = 0
test_iou = 0
with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        loss = loss_func(outputs, masks)
        test_loss += loss.item()
        test_iou += iou(outputs, masks).item()
test_loss /= len(test_loader)
test_iou /= len(test_loader)
print(f"Test Loss: {test_loss:.4f}, Test IoU: {test_iou:.4f}")

Loaded weights from best_weights.pth
Test Loss: 0.0602, Test IoU: 0.5357
