In [None]:
# Install required libraries
!pip install -q transformers albumentations opencv-python tqdm
!pip install torchmetrics

from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#necessary imports
import os
import cv2
import numpy as np
from glob import glob
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassJaccardIndex
from transformers import SegformerForSemanticSegmentation
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch.optim import AdamW

import random


def set_seed(seed=42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic kernels.

    Seeds Python's `random`, NumPy's global RNG, PyTorch's CPU RNG and the RNGs of
    all CUDA devices with the same value, then disables cuDNN autotuning in favour
    of deterministic algorithms so repeated runs are more reproducible.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade cuDNN's benchmark-driven kernel selection for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

set_seed()

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Class-index -> name mapping (extend id2label if you have more classes).
# label2id and NUM_CLASSES are derived from it so they cannot drift out of sync.
id2label = {0: "background", 1: "road"}
label2id = {label: class_id for class_id, label in id2label.items()}
NUM_CLASSES = len(id2label)

# Training hyper-parameters
NUM_EPOCHS = 15
LEARNING_RATE = 5e-5  # Half of original rate, but higher than typical fine-tuning
BATCH_SIZE = 8
# Size every image/mask is resized to by the transform pipelines.
IMG_HEIGHT = 400
IMG_WIDTH = 400


In [None]:
# Previously fine-tuned weights to resume from.
CHECKPOINT_PATH = "/content/gdrive/MyDrive/Road Segmentation/segformer_model_epochs_25.pt"

# Initialize the SegFormer-B2 architecture with a NUM_CLASSES-way head.
# ignore_mismatched_sizes=True lets the hub checkpoint load even though its
# classification head has a different number of labels (that head is re-initialised).
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-ade-512-512",
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Overwrite all weights with the saved state dict.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only restricts unpickling to tensors/plain containers, which is all a
# state_dict contains (and avoids executing arbitrary pickled code).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Move the model to the appropriate device. The trailing ';' stops Jupyter from
# echoing the (very long) module repr as the cell output.
model.to(device);

In [None]:
# On-the-fly preprocessing / augmentation pipelines (albumentations).
# Normalisation statistics shared by every pipeline.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Training: resize, random flips/rotations, photometric noise, an occasional
# geometric distortion and colour jitter, then normalise and convert to a tensor.
train_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    # Orientation augmentations
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # Photometric augmentations
    A.RandomBrightnessContrast(p=0.2),
    A.GaussNoise(p=0.2),
    # With probability 0.3, apply one of these geometric distortions.
    # NOTE(review): newer albumentations releases deprecate/remove
    # `alpha_affine` (ElasticTransform) and `shift_limit` (OpticalDistortion) —
    # confirm against the installed version.
    A.OneOf(
        [
            A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
        ],
        p=0.3,
    ),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2(),
])

# Validation / test: deterministic resize + normalisation only.
val_test_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2(),
])

In [None]:

class RoadSegmentationDataset(Dataset):
    """Image / binary road-mask pairs loaded from two directories.

    Every entry of ``images_dir`` and ``masks_dir`` is listed and sorted, and the
    i-th image is paired with the i-th mask. Pairing is purely positional, so the
    two sorted listings must line up (e.g. identical base names); only the counts
    are checked here.

    Args:
        images_dir: Directory with the input images (returned as RGB).
        masks_dir: Directory with the ground-truth masks (read as grayscale and
            binarised: pixel > 127 -> 1 (road), otherwise 0 (background)).
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_paths = sorted(glob(os.path.join(images_dir, "*")))
        self.masks_paths = sorted(glob(os.path.join(masks_dir, "*")))
        self.transform = transform

        assert len(self.images_paths) == len(self.masks_paths), "Mismatch between images and masks count."

    def __len__(self):
        return len(self.images_paths)

    @staticmethod
    def _imread(path, flags):
        """Read ``path`` with OpenCV, raising a clear error instead of returning None.

        ``cv2.imread`` reports missing, unreadable or non-image files by returning
        None, which would otherwise surface later as an opaque error in
        ``cvtColor`` or the mask threshold.

        Raises:
            FileNotFoundError: if OpenCV could not decode ``path``.
        """
        data = cv2.imread(path, flags)
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Without a transform these are an HxWx3 RGB array and an HxW uint8 array
        of {0, 1}; with a transform, whatever it returns for them.
        """
        # OpenCV decodes colour images as BGR; convert to RGB.
        image = self._imread(self.images_paths[idx], cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Binarise the grayscale mask: > 127 -> class 1, everything else -> class 0.
        mask = self._imread(self.masks_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = (mask > 127).astype(np.uint8)

        # Apply transformations if any (image and mask get the same spatial ops).
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [None]:
"""import os

# Define your directories
images_dir = '/content/gdrive/MyDrive/Road Segmentation/images'
masks_dir = '/content/gdrive/MyDrive/Road Segmentation/groundtruth'

# List all files in each directory
image_files = sorted(os.listdir(images_dir))
mask_files = sorted(os.listdir(masks_dir))

# Normalize filenames by removing extensions
image_basenames = set(os.path.splitext(f)[0] for f in image_files)
mask_basenames = set(os.path.splitext(f)[0] for f in mask_files)
"""

In [None]:
"""# Images without corresponding masks
missing_masks = image_basenames - mask_basenames
if missing_masks:
    print("Images without corresponding masks:")
    for name in missing_masks:
        print(f"{name}")

# Masks without corresponding images
missing_images = mask_basenames - image_basenames
if missing_images:
    print("\nMasks without corresponding images:")
    for name in missing_images:
        print(f"{name}")
"""

In [None]:
# Dataset locations on the mounted Google Drive (update to point to your dataset).
IMAGES_DIR = '/content/gdrive/MyDrive/Road Segmentation/images'
MASKS_DIR = '/content/gdrive/MyDrive/Road Segmentation/groundtruth'

import os

# One-off cleanup of a single ground-truth file.
# NOTE(review): this PERMANENTLY deletes the file from Drive the first time the
# cell runs (later runs only report it missing). Presumably it is a mask with no
# matching image, which would otherwise trip the dataset's count assertion — confirm.
file_path = '/content/gdrive/MyDrive/Road Segmentation/groundtruth/boston_2363.png'

if not os.path.exists(file_path):
    print(f"File not found: {file_path}")
else:
    os.remove(file_path)
    print(f"Deleted: {file_path}")

# An untransformed dataset is built only to count samples; indices are then split
# 70 / 15 / 15 into train / validation / test with a fixed random_state.
full_dataset = RoadSegmentationDataset(IMAGES_DIR, MASKS_DIR, transform=None)
dataset_size = len(full_dataset)
indices = list(range(dataset_size))
train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)


def _split_subset(transform, split_indices):
    """Wrap a fresh dataset using `transform` in a Subset restricted to `split_indices`."""
    return Subset(RoadSegmentationDataset(IMAGES_DIR, MASKS_DIR, transform=transform), split_indices)


# Each split gets its own dataset instance so train and val/test can use
# different transforms while sharing the same index space.
train_dataset = _split_subset(train_transform, train_indices)
val_dataset = _split_subset(val_test_transform, val_indices)
test_dataset = _split_subset(val_test_transform, test_indices)

# Only the training loader shuffles; validation/test order stays fixed.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Print dataset sizes
for split_name, split_dataset in (("training", train_dataset), ("validation", val_dataset), ("test", test_dataset)):
    print(f"Number of {split_name} samples: {len(split_dataset)}")

In [None]:
# Fine-tuning objective: unweighted per-pixel cross-entropy over the class
# logits, optimised with AdamW over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Per-epoch metric history: history[f"{split}_{metric}"] is a list with one value
# appended per epoch, for split in {train, val} and metric in
# {loss, acc, f1, iou, dice}.
history = {
    f'{split}_{metric}': []
    for metric in ('loss', 'acc', 'f1', 'iou', 'dice')
    for split in ('train', 'val')
}

In [None]:
def calculate_dice_score(pred, target, num_classes=2, smooth=1e-6):
    """Mean (over classes) Dice coefficient between two label tensors.

    For each class c, both tensors are turned into 0/1 indicator tensors and
    Dice = (2*|P∩T| + smooth) / (|P| + |T| + smooth) is computed over all
    elements. A class absent from both tensors therefore scores 1.0.

    Args:
        pred: Tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``pred``.
        num_classes: Number of classes to score (0 .. num_classes-1).
        smooth: Small constant avoiding division by zero.

    Returns:
        Python float: the unweighted mean of the per-class Dice scores.
    """
    def _class_dice(class_id):
        pred_mask = pred.eq(class_id).float()
        target_mask = target.eq(class_id).float()
        overlap = (pred_mask * target_mask).sum()
        total = pred_mask.sum() + target_mask.sum()
        return ((2. * overlap + smooth) / (total + smooth)).item()

    per_class = [_class_dice(class_id) for class_id in range(num_classes)]
    return sum(per_class) / len(per_class)

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return epoch-level metrics.

    Args:
        model: Segmentation model whose output exposes ``.logits``.
        loader: DataLoader yielding ``(images, masks)`` batches.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Loss taking ``(logits, long class-index masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, f1, iou, dice)``. Loss and Dice are per-sample averages
        (batch values weighted by batch size, divided by ``len(loader.dataset)``);
        accuracy/F1/IoU come from torchmetrics objects accumulated across all batches.
    """
    model.train()
    metrics = {
        "acc": MulticlassAccuracy(num_classes=NUM_CLASSES).to(device),
        "f1": MulticlassF1Score(num_classes=NUM_CLASSES).to(device),
        "iou": MulticlassJaccardIndex(num_classes=NUM_CLASSES).to(device),
    }
    loss_sum = 0.0
    dice_sum = 0.0

    for batch_images, batch_masks in tqdm(loader, desc="Training", leave=False):
        batch_images = batch_images.to(device)
        # Cross-entropy needs integer (long) class-index targets.
        batch_masks = batch_masks.to(device, dtype=torch.long)

        optimizer.zero_grad()
        logits = model(pixel_values=batch_images).logits
        # Resize the logits to the masks' spatial size before computing the loss.
        logits = F.interpolate(logits, size=batch_masks.shape[1:], mode='bilinear', align_corners=False)
        loss = criterion(logits, batch_masks)
        loss.backward()
        optimizer.step()

        batch_size = batch_images.size(0)
        loss_sum += loss.item() * batch_size
        predictions = logits.argmax(dim=1)

        for metric in metrics.values():
            metric.update(predictions, batch_masks)
        dice_sum += calculate_dice_score(predictions, batch_masks, NUM_CLASSES) * batch_size

    sample_count = len(loader.dataset)
    return (
        loss_sum / sample_count,
        metrics["acc"].compute().item(),
        metrics["f1"].compute().item(),
        metrics["iou"].compute().item(),
        dice_sum / sample_count,
    )


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode with gradient tracking disabled.

    Args:
        model: Segmentation model whose output exposes ``.logits``.
        loader: DataLoader yielding ``(images, masks)`` batches.
        criterion: Loss taking ``(logits, long class-index masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, f1, iou, dice)``. Loss and Dice are per-sample averages
        (batch values weighted by batch size, divided by ``len(loader.dataset)``);
        accuracy/F1/IoU come from torchmetrics objects accumulated across all batches.
    """
    model.eval()
    acc_meter = MulticlassAccuracy(num_classes=NUM_CLASSES).to(device)
    f1_meter = MulticlassF1Score(num_classes=NUM_CLASSES).to(device)
    iou_meter = MulticlassJaccardIndex(num_classes=NUM_CLASSES).to(device)
    meters = (acc_meter, f1_meter, iou_meter)
    loss_sum = 0.0
    dice_sum = 0.0

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(loader, desc="Validating", leave=False):
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device, dtype=torch.long)  # class-index targets
            # Resize the logits to the masks' spatial size before scoring.
            logits = F.interpolate(
                model(pixel_values=batch_images).logits,
                size=batch_masks.shape[1:],
                mode='bilinear',
                align_corners=False,
            )
            batch_size = batch_images.size(0)
            loss_sum += criterion(logits, batch_masks).item() * batch_size

            predictions = logits.argmax(dim=1)
            for meter in meters:
                meter.update(predictions, batch_masks)
            dice_sum += calculate_dice_score(predictions, batch_masks, NUM_CLASSES) * batch_size

    sample_count = len(loader.dataset)
    return (
        loss_sum / sample_count,
        acc_meter.compute().item(),
        f1_meter.compute().item(),
        iou_meter.compute().item(),
        dice_sum / sample_count,
    )

In [None]:
# Training loop: LR halving on validation-loss plateaus, best-model checkpointing
# and early stopping.
best_val_loss = float('inf')
patience = 10  # epochs without val-loss improvement before stopping early
early_stopping_counter = 0
# `verbose=True` is deprecated for PyTorch LR schedulers, so the current learning
# rate is printed explicitly after each scheduler step instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
best_model_path = f"/content/gdrive/MyDrive/Road Segmentation/srs_finetuned_segformer_model_epochs_{NUM_EPOCHS}.pt"
best_model_saved = False

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    train_loss, train_acc, train_f1, train_iou, train_dice = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_f1, val_iou, val_dice = validate(model, val_loader, criterion, device)
    scheduler.step(val_loss)
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Store metrics in history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)
    history['train_dice'].append(train_dice)
    history['val_dice'].append(val_dice)

    # Logging
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
    print(f"Train IoU: {train_iou:.4f} | Val IoU: {val_iou:.4f}")
    print(f"Train Dice: {train_dice:.4f} | Val Dice: {val_dice:.4f}")

    # Early stopping and checkpointing on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), best_model_path)
        best_model_saved = True
        print("Saved best model")
    else:
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{patience}")
        if early_stopping_counter >= patience:
            print("Early stopping triggered")
            break

# The loop leaves `model` holding the weights of the LAST epoch it ran, which are
# not necessarily the best ones. Reload the best checkpoint so the evaluation and
# inference cells that follow use it.
if best_model_saved:
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    print(f"Restored best model (val loss {best_val_loss:.4f}) from {best_model_path}")

In [None]:
# Plotting metrics
import matplotlib.pyplot as plt

def plot_metric(name, ylabel):
    """Plot the train and validation curves of one metric stored in ``history``.

    Args:
        name: Metric suffix used in the ``history`` keys (``train_<name>`` and
            ``val_<name>``), e.g. ``'loss'`` or ``'iou'``.
        ylabel: Y-axis label, also used in the figure title.
    """
    train_values = history[f'train_{name}']
    val_values = history[f'val_{name}']
    # 1-based epoch numbers so the x-axis matches the "Epoch k/N" training log.
    epochs = range(1, len(train_values) + 1)

    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label=f'Train {name.capitalize()}')
    ax.plot(epochs, val_values, label=f'Val {name.capitalize()}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(f'{ylabel} over Epochs')
    ax.legend()
    plt.show()

plot_metric('loss', 'Loss')
plot_metric('acc', 'Accuracy')
plot_metric('f1', 'F1 Score')
plot_metric('iou', 'IoU')
plot_metric('dice', 'Dice Score')

In [None]:
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import cv2

# Define the directory to save predictions
PREDICTIONS_DIR = '/content/gdrive/MyDrive/Road Segmentation/predictions_for_srs_data'
os.makedirs(PREDICTIONS_DIR, exist_ok=True)

# Class index -> colour in OpenCV's BGR channel order. cv2.imwrite interprets
# the last axis as BGR, so red must be written as (0, 0, 255); the previous
# [255, 0, 0] produced blue files despite being labelled red.
CLASS_COLORS_BGR = {
    0: (0, 0, 0),      # background -> black
    1: (0, 0, 255),    # road -> red
}

# Ensure the model is in evaluation mode
model.eval()

with torch.no_grad():
    for batch_idx, (images, _) in enumerate(tqdm(test_loader, desc="Running Inference on Test Set")):
        images = images.to(device)
        logits = model(images).logits
        # Upsample logits to the network-input size (the resized test image),
        # not the original on-disk image size.
        logits = F.interpolate(logits, size=images.shape[2:], mode='bilinear', align_corners=False)
        # Per-pixel class index for every image in the batch at once.
        predictions = logits.argmax(dim=1).cpu().numpy()

        for i, prediction in enumerate(predictions):
            prediction_mask_colored = np.zeros((*prediction.shape, 3), dtype=np.uint8)
            for class_id, color in CLASS_COLORS_BGR.items():
                prediction_mask_colored[prediction == class_id] = color

            # File name encodes <batch index>_<position within batch>.
            save_path = os.path.join(PREDICTIONS_DIR, f"prediction_{batch_idx}_{i}.png")
            # cv2.imwrite reports failure via its return value rather than raising.
            if not cv2.imwrite(save_path, prediction_mask_colored):
                raise IOError(f"Failed to write prediction to {save_path}")

In [None]:
import cv2

# Select a sample index (position within test_dataset) to visualize
sample_idx = 0

# Retrieve the preprocessed test image (CHW tensor). It was normalised with this
# mean/std (must match A.Normalize in val_test_transform), so undo that to get RGB
# values in [0, 1] that imshow can display without clipping.
norm_mean = np.array([0.485, 0.456, 0.406])
norm_std = np.array([0.229, 0.224, 0.225])
orig_image, _ = test_dataset[sample_idx]
orig_image_np = orig_image.permute(1, 2, 0).numpy()  # CHW -> HWC for matplotlib
orig_image_np = np.clip(orig_image_np * norm_std + norm_mean, 0.0, 1.0)

# Inference saved files as prediction_<batch index>_<position in batch>.png, so map
# the dataset index to that pair via the (unshuffled) test loader's batch size.
batch_number, index_in_batch = divmod(sample_idx, test_loader.batch_size)
mask_path = os.path.join(PREDICTIONS_DIR, f"prediction_{batch_number}_{index_in_batch}.png")
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

# Check if the mask was loaded successfully
if mask is None:
    print(f"Error: Could not load mask from {mask_path}. Check if the file exists and is a valid image.")
else:
    # Plot the original image and the predicted mask side by side
    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

    ax_image.imshow(orig_image_np)
    ax_image.set_title("Original Image")
    ax_image.axis("off")

    # Background was saved as black, so any non-zero pixel is predicted road;
    # show it as white regardless of the colour used for roads.
    ax_mask.imshow(mask > 0, cmap="gray", vmin=0, vmax=1)
    ax_mask.set_title("Predicted Road Mask")
    ax_mask.axis("off")

    fig.tight_layout()
    plt.show()