In [None]:
# Install required libraries
!pip install -q transformers albumentations opencv-python tqdm
!pip install torchmetrics

# %% Mount Google Drive to access your dataset
from google.colab import drive
drive.mount('/content/drive')


In [None]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassJaccardIndex
from transformers import SegformerForSemanticSegmentation
from torch.optim.lr_scheduler import ReduceLROnPlateau


import random


def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Request deterministic cuDNN behaviour and switch off benchmark autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at import time with the default value.
set_seed()

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Resolution (in pixels) every image and mask is resized to before entering the model.
IMG_HEIGHT, IMG_WIDTH = 512, 512

# Training hyper-parameters.
BATCH_SIZE = 4         # kept small so a batch fits in Colab free-tier memory
NUM_EPOCHS = 30        # TODO: consider lowering to 25 and retraining
LEARNING_RATE = 1e-4

# Class index <-> class name mappings; add entries to id2label to support more classes.
id2label = {0: "background", 1: "road"}
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
NUM_CLASSES = len(id2label)


In [None]:
# Offline augmentation pipeline: random flips, 90-degree rotations, a mild
# brightness/contrast change, and a small shift/scale/rotate. It is applied by
# `augment_and_save` to produce extra image/mask pairs on disk.
_offline_steps = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.05,
        rotate_limit=15,
        p=0.5,
    ),
]
offline_transform = A.Compose(_offline_steps)

def augment_and_save(images_dir, masks_dir, save_images_dir, save_masks_dir, augmentations_per_image=3):
    """Write every image/mask pair plus randomly augmented copies to disk as PNG.

    Images and masks are paired by position after sorting both directory
    listings, so the two directories must hold matching files in the same sort
    order. For each pair the untouched original is saved as ``<name>.png``,
    followed by ``augmentations_per_image`` variants produced by the
    module-level ``offline_transform`` and saved as ``<name>_aug_<i>.png``.

    Args:
        images_dir: Directory holding the source images.
        masks_dir: Directory holding the ground-truth masks (read as grayscale).
        save_images_dir: Output directory for images (created if missing).
        save_masks_dir: Output directory for masks (created if missing).
        augmentations_per_image: Number of augmented copies written per pair.

    Raises:
        ValueError: If the two input directories contain a different number of
            files, since positional pairing would then be unreliable.
    """
    os.makedirs(save_images_dir, exist_ok=True)
    os.makedirs(save_masks_dir, exist_ok=True)

    image_paths = sorted(glob(os.path.join(images_dir, "*")))
    mask_paths = sorted(glob(os.path.join(masks_dir, "*")))

    # zip() would silently drop the surplus files; fail loudly instead.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Found {len(image_paths)} images in {images_dir} but "
            f"{len(mask_paths)} masks in {masks_dir}."
        )

    for img_path, mask_path in tqdm(zip(image_paths, mask_paths), total=len(image_paths)):
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None for unreadable / non-image files; skip the
        # whole pair so images and masks stay aligned.
        if image is None or mask is None:
            print(f"Skipping unreadable pair: {img_path} / {mask_path}")
            continue

        # The augmentation pipeline is fed RGB; convert back to BGR when writing.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        base_name = os.path.splitext(os.path.basename(img_path))[0]

        # Save original image and mask
        cv2.imwrite(os.path.join(save_images_dir, f"{base_name}.png"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        cv2.imwrite(os.path.join(save_masks_dir, f"{base_name}.png"), mask)

        for i in range(augmentations_per_image):
            # Image and mask go through the same random transform together.
            augmented = offline_transform(image=image, mask=mask)
            aug_image_bgr = cv2.cvtColor(augmented['image'], cv2.COLOR_RGB2BGR)

            cv2.imwrite(os.path.join(save_images_dir, f"{base_name}_aug_{i}.png"), aug_image_bgr)
            cv2.imwrite(os.path.join(save_masks_dir, f"{base_name}_aug_{i}.png"), augmented['mask'])


# Source data (inputs) and destination folders for the offline-augmented copy.
images_dir = '/content/drive/MyDrive/Road Segmentation/training/images'
masks_dir = '/content/drive/MyDrive/Road Segmentation/training/groundtruth'
save_images_dir = '/content/drive/MyDrive/Road Segmentation/training_augmented/images'
save_masks_dir = '/content/drive/MyDrive/Road Segmentation/training_augmented/groundtruth'

# Only run the augmentation when the output image folder is missing or still
# empty, so re-running the notebook does not regenerate the files.
augmented_images_ready = os.path.exists(save_images_dir) and len(os.listdir(save_images_dir)) > 0
if not augmented_images_ready:
    augment_and_save(images_dir, masks_dir, save_images_dir, save_masks_dir, augmentations_per_image=10)


In [None]:
# On-the-fly transformations, applied every time a sample is fetched.

# Per-channel statistics passed to A.Normalize in both pipelines
# (these are the commonly used ImageNet values).
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)

# With probability 0.3, one of these geometric distortions is applied.
# NOTE(review): some of these keyword arguments may be deprecated in newer
# albumentations releases -- confirm against the installed version.
_geometric_distortion = A.OneOf(
    [
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
    ],
    p=0.3,
)

# Training pipeline: resize, random flips/rotations, photometric noise,
# geometric distortion, colour jitter, then normalisation and tensor conversion.
train_transform = A.Compose(
    [
        A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.GaussNoise(p=0.2),
        _geometric_distortion,
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
        A.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        ToTensorV2(),
    ]
)

# Validation / inference pipeline: deterministic resize + normalise only.
val_transform = A.Compose(
    [
        A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
        A.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        ToTensorV2(),
    ]
)


In [None]:
class RoadSegmentationDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Images and masks are matched by position after sorting both directory
    listings. Masks are read as grayscale and binarised: pixels brighter than
    127 become class 1, everything else class 0.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the ground-truth masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_paths = sorted(glob(os.path.join(images_dir, "*")))
        self.masks_paths = sorted(glob(os.path.join(masks_dir, "*")))
        self.transform = transform

        assert len(self.images_paths) == len(self.masks_paths), "Mismatch between images and masks count."

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load, binarise, optionally transform, and return one ``(image, mask)`` pair.

        Returns:
            A tuple ``(image, mask)`` where ``mask`` is always an int64 tensor.
            ``image`` is whatever the transform produced; if it is still a
            NumPy HWC array (no transform, or no tensor conversion in it) it is
            converted to a CHW tensor.

        Raises:
            FileNotFoundError: If the image or mask file cannot be decoded.
        """
        image_path = self.images_paths[idx]
        mask_path = self.masks_paths[idx]

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = np.where(mask > 127, 1, 0).astype(np.uint8)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform both values are still NumPy
        # arrays; convert them so the dataset also works with transform=None.
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).contiguous()
        mask = torch.as_tensor(mask).long()

        return image, mask

In [None]:
# Update the paths to point to your augmented dataset
TRAIN_IMAGES_DIR = '/content/drive/MyDrive/Road Segmentation/training_augmented/images'
TRAIN_MASKS_DIR = '/content/drive/MyDrive/Road Segmentation/training_augmented/groundtruth'

# Two views over the same files: the training view applies the random
# on-the-fly augmentations, the validation view only resizes and normalises.
# (Wrapping a single dataset in two Subsets would push validation samples
# through the random training augmentations as well.)
full_dataset = RoadSegmentationDataset(TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR, transform=train_transform)
full_val_dataset = RoadSegmentationDataset(TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR, transform=val_transform)

# Not used by this cell; retained in case other cells rely on the name.
from torch.utils.data import random_split


def _source_image_key(path):
    """Return the stem of the original image an (augmented) file derives from.

    Offline augmentation names its outputs ``<name>.png`` and
    ``<name>_aug_<i>.png``; both map to ``<name>``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.rsplit("_aug_", 1)[0]


# Split by source image (80/20) rather than by file: otherwise augmented copies
# of one photo end up on both sides of the split and validation metrics are
# inflated by leakage.
indices = list(range(len(full_dataset)))
group_keys = [_source_image_key(path) for path in full_dataset.images_paths]
train_groups, val_groups = train_test_split(sorted(set(group_keys)), test_size=0.2, random_state=42)
val_group_set = set(val_groups)

train_indices = [i for i in indices if group_keys[i] not in val_group_set]
val_indices = [i for i in indices if group_keys[i] in val_group_set]
train_size = len(train_indices)
val_size = len(val_indices)

train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_val_dataset, val_indices)

# Create data loaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")

In [None]:
# Load the pretrained SegFormer checkpoint and configure it for our label set.
PRETRAINED_CHECKPOINT = "nvidia/segformer-b2-finetuned-ade-512-512"  # B2 is larger than B0

# ignore_mismatched_sizes lets loading proceed even though the checkpoint's
# output layer was sized for a different number of classes than NUM_CLASSES.
_segformer_kwargs = dict(
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model = SegformerForSemanticSegmentation.from_pretrained(PRETRAINED_CHECKPOINT, **_segformer_kwargs)
model.to(device)

In [None]:
from torch.optim import AdamW

# Per-pixel cross-entropy on the class logits, optimised with AdamW over all
# model parameters at the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Per-epoch metric history: one independent list per "<split>_<metric>" key,
# appended to once per epoch by the training loop.
history = {
    f"{split}_{metric}": []
    for metric in ("loss", "acc", "f1", "iou", "dice")
    for split in ("train", "val")
}

In [None]:
def calculate_dice_score(pred, target, num_classes=2, smooth=1e-6):
    """Return the Dice coefficient averaged over classes ``0 .. num_classes - 1``.

    For every class, both tensors are turned into 0/1 indicator maps and the
    score ``(2 * |pred & target| + smooth) / (|pred| + |target| + smooth)`` is
    computed over all elements; the per-class scores are then averaged.

    Args:
        pred: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, comparable element-wise
            with ``pred``.
        num_classes: Number of classes to average over.
        smooth: Small constant added to numerator and denominator so a class
            absent from both tensors does not divide by zero.

    Returns:
        The mean Dice score as a Python float.
    """
    def _dice_for_class(class_id):
        pred_hits = (pred == class_id).float()
        target_hits = (target == class_id).float()
        overlap = (pred_hits * target_hits).sum()
        total = pred_hits.sum() + target_hits.sum()
        return ((2. * overlap + smooth) / (total + smooth)).item()

    per_class = [_dice_for_class(class_id) for class_id in range(num_classes)]
    return sum(per_class) / len(per_class)

# Modified train and validation functions to compute accuracy, F1, IoU, and Dice

def _run_epoch(model, loader, criterion, device, optimizer, training, desc):
    """Run one pass over ``loader`` and return the aggregated loss and metrics.

    Shared implementation behind ``train_one_epoch`` and ``validate``.

    Args:
        model: Network called as ``model(pixel_values=images)``; its ``.logits``
            are upsampled to the mask resolution before the loss is computed.
        loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks)``.
        device: Device the batches and metric objects are moved to.
        optimizer: Optimizer stepped after every batch when ``training`` is
            true; ignored otherwise.
        training: True to update the model, False to evaluate without gradients.
        desc: Label shown on the progress bar.

    Returns:
        Tuple ``(loss, accuracy, f1, iou, dice)``. Loss and Dice are averaged
        over the samples actually seen; the other three come from torchmetrics
        accumulated over the whole pass.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train(training)

    acc_metric = MulticlassAccuracy(num_classes=NUM_CLASSES).to(device)
    f1_metric = MulticlassF1Score(num_classes=NUM_CLASSES).to(device)
    iou_metric = MulticlassJaccardIndex(num_classes=NUM_CLASSES).to(device)

    running_loss = 0.0
    dice_total = 0.0
    samples_seen = 0

    with torch.set_grad_enabled(training):
        for images, masks in tqdm(loader, desc=desc, leave=False):
            images, masks = images.to(device), masks.to(device)
            if training:
                optimizer.zero_grad()

            # Bring the logits to the mask's spatial size before loss / argmax.
            outputs = model(pixel_values=images).logits
            outputs = F.interpolate(outputs, size=masks.shape[1:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, masks)

            if training:
                loss.backward()
                optimizer.step()

            # Weight per-batch means by batch size so a smaller last batch
            # does not skew the epoch average.
            batch_size = images.size(0)
            samples_seen += batch_size
            running_loss += loss.item() * batch_size

            preds = torch.argmax(outputs, dim=1)
            acc_metric.update(preds, masks)
            f1_metric.update(preds, masks)
            iou_metric.update(preds, masks)
            dice_total += calculate_dice_score(preds, masks, NUM_CLASSES) * batch_size

    if samples_seen == 0:
        raise ValueError("The data loader yielded no samples; cannot compute epoch metrics.")

    # Normalise by the samples actually processed: len(loader.dataset) is wrong
    # when the loader drops the last batch or draws through a sampler.
    epoch_loss = running_loss / samples_seen
    epoch_dice = dice_total / samples_seen
    epoch_acc = acc_metric.compute().item()
    epoch_f1 = f1_metric.compute().item()
    epoch_iou = iou_metric.compute().item()
    return epoch_loss, epoch_acc, epoch_f1, epoch_iou, epoch_dice


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one epoch over ``loader``.

    Returns:
        Tuple ``(loss, accuracy, f1, iou, dice)`` aggregated over the epoch.
    """
    return _run_epoch(model, loader, criterion, device, optimizer, training=True, desc="Training")


def validate(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` in eval mode with gradients disabled.

    Returns:
        Tuple ``(loss, accuracy, f1, iou, dice)`` aggregated over the pass.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=None, training=False, desc="Validating")

In [None]:
best_val_loss = float('inf')
patience = 7       # early stopping: epochs allowed without a new best validation loss
lr_patience = 3    # LR scheduler: must stay below `patience`, otherwise early
                   # stopping ends training before the learning rate is ever reduced
early_stopping_counter = 0
# `verbose` is deprecated in recent PyTorch releases; LR changes are logged below instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=lr_patience)

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    train_loss, train_acc, train_f1, train_iou, train_dice = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_f1, val_iou, val_dice = validate(model, val_loader, criterion, device)

    # Let the scheduler react to the validation loss and report any LR change.
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr != previous_lr:
        print(f"Learning rate reduced: {previous_lr:.2e} -> {current_lr:.2e}")

    # Store metrics in history
    epoch_metrics = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_acc': train_acc, 'val_acc': val_acc,
        'train_f1': train_f1, 'val_f1': val_f1,
        'train_iou': train_iou, 'val_iou': val_iou,
        'train_dice': train_dice, 'val_dice': val_dice,
    }
    for metric_key, metric_value in epoch_metrics.items():
        history[metric_key].append(metric_value)

    # Logging
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
    print(f"Train IoU: {train_iou:.4f} | Val IoU: {val_iou:.4f}")
    print(f"Train Dice: {train_dice:.4f} | Val Dice: {val_dice:.4f}")

    # Early stopping and checkpointing: keep only the weights with the lowest
    # validation loss, and stop after `patience` epochs without improvement.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), "/content/drive/MyDrive/Road Segmentation/best_segformer_model.pt")
        print("Saved best model")
    else:
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{patience}")
        if early_stopping_counter >= patience:
            print("Early stopping triggered")
            break

In [None]:
# Plotting metrics
import matplotlib.pyplot as plt

def plot_metric(name, ylabel):
    """Plot the training and validation curves of one metric from ``history``.

    Args:
        name: Metric suffix used in the ``history`` keys, e.g. ``'loss'`` reads
            ``history['train_loss']`` and ``history['val_loss']``.
        ylabel: Text for the y-axis label and the figure title.
    """
    _, ax = plt.subplots()
    pretty_name = name.capitalize()
    for split_label, key_prefix in (("Train", "train"), ("Val", "val")):
        ax.plot(history[f'{key_prefix}_{name}'], label=f'{split_label} {pretty_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(f'{ylabel} over Epochs')
    plt.show()

# One figure per tracked metric, in the order they were logged.
for metric_key, axis_label in (
    ('loss', 'Loss'),
    ('acc', 'Accuracy'),
    ('f1', 'F1 Score'),
    ('iou', 'IoU'),
    ('dice', 'Dice Score'),
):
    plot_metric(metric_key, axis_label)

In [None]:
import matplotlib.pyplot as plt

def decode_segmentation(mask):
    """Turn a 2-D class-index mask into an RGB image for display.

    Args:
        mask: 2-D array of class indices.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``: class 0 (background) is
        black, class 1 (road) is white, and any other label stays black.
    """
    palette = (
        (0, (0, 0, 0)),        # background
        (1, (255, 255, 255)),  # road
    )
    height, width = mask.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb_value in palette:
        canvas[mask == class_id] = rgb_value
    return canvas

# Access the underlying dataset. A torch Subset exposes `.dataset` and
# `.indices`; a plain dataset is indexed directly.
if hasattr(train_dataset, 'dataset'):
    original_dataset = train_dataset.dataset
    sample_index = train_dataset.indices[0]
else:
    original_dataset = train_dataset
    sample_index = 0

# Load the first sample's image straight from disk, without any transform
sample_image_path = original_dataset.images_paths[sample_index]
sample_image = cv2.imread(sample_image_path)
sample_image = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)
orig_h, orig_w, _ = sample_image.shape

# Preprocess with the validation transforms (resize, normalize, tensor conversion)
aug = val_transform(image=sample_image)
input_tensor = aug["image"].unsqueeze(0).to(device)

# Model inference. The logits are upsampled to the original image size BEFORE
# the argmax, as in the training, validation and test-set code; taking the
# argmax on the raw logits and resizing the label map gives blockier borders.
model.eval()
with torch.no_grad():
    logits = model(pixel_values=input_tensor).logits  # shape: [1, num_labels, h, w]
    logits = F.interpolate(logits, size=(orig_h, orig_w), mode='bilinear', align_corners=False)
prediction = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

# Already at the original resolution; only the dtype needs adjusting
prediction_resized = prediction.astype(np.uint8)

# Visualize
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(sample_image)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(decode_segmentation(prediction_resized))
plt.title("Predicted Road Segmentation")
plt.axis("off")
plt.show()


In [None]:
#Test‐time inference + single‐sample visualization
import os
import cv2
import torch
import torch.nn.functional as F
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

# Input folder with the test images and output folder for the predicted masks
TEST_IMAGES_DIR = '/content/drive/MyDrive/Road Segmentation/test_set_images'
PREDICTIONS_DIR = '/content/drive/MyDrive/Road Segmentation/test_predictions'
os.makedirs(PREDICTIONS_DIR, exist_ok=True)

# Inference only: put the model in evaluation mode
model.eval()

# Every file with an extension below the test directory, searched recursively
test_glob_pattern = os.path.join(TEST_IMAGES_DIR, '**', '*.*')
test_image_paths = sorted(glob(test_glob_pattern, recursive=True))

# Predict one mask per test image and write it under the input's file name
for img_path in tqdm(test_image_paths, desc="Running Inference on Test Set"):
    bgr_image = cv2.imread(img_path)
    if bgr_image is None:
        # Not a decodable image (the glob matches any file type); report and move on
        print(f"Failed to read image: {img_path}")
        continue

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    full_height, full_width = rgb_image.shape[:2]

    # Same preprocessing as validation, plus a leading batch dimension
    batch = val_transform(image=rgb_image)["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        # Upsample the logits to the source resolution, then take the best class per pixel
        upsampled_logits = F.interpolate(
            model(pixel_values=batch).logits,
            size=(full_height, full_width),
            mode='bilinear',
            align_corners=False,
        )
        class_map = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Scale class indices to 8-bit intensities (0 background, 255 road)
    road_mask = (class_map * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(PREDICTIONS_DIR, os.path.basename(img_path)), road_mask)

# Visualize one example

# Pick the first test file that was read successfully AND has a saved
# prediction. The glob matches every file type, so the first path is not
# guaranteed to be an image, and the list may be empty.
sample_img_path = None
sample_mask_path = None
orig = None
mask = None
for candidate_img_path in test_image_paths:
    candidate_mask_path = os.path.join(PREDICTIONS_DIR, os.path.basename(candidate_img_path))
    if not os.path.isfile(candidate_mask_path):
        continue
    candidate_img = cv2.imread(candidate_img_path)
    candidate_mask = cv2.imread(candidate_mask_path, cv2.IMREAD_GRAYSCALE)
    if candidate_img is None or candidate_mask is None:
        continue
    sample_img_path, sample_mask_path = candidate_img_path, candidate_mask_path
    orig = cv2.cvtColor(candidate_img, cv2.COLOR_BGR2RGB)
    mask = candidate_mask
    break

if orig is None:
    print("No test image with a saved prediction was found; nothing to visualize.")
else:
    # Plot them side by side
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(orig)
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap="gray")
    plt.title("Predicted Road Mask")
    plt.axis("off")

    plt.tight_layout()
    plt.show()
