## Introduction

### In this notebook we use [DeepLabV3+](https://arxiv.org/abs/1802.02611) to extract roads from satellite imagery, training on the [DeepGlobe Road Extraction Dataset](https://www.kaggle.com/balraj98/deepglobe-road-extraction-dataset).

### Libraries 📚⬇

In [1]:
import os, cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

In [2]:
# Upgrade pip itself before installing the notebook's extra dependencies.
!pip install --upgrade pip

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1


In [3]:
# Quietly install/upgrade segmentation-models-pytorch and albumentations
# (stdout is discarded), then import the segmentation library as `smp`.
# NOTE(review): no versions are pinned, so results may change between runs.
# NOTE(review): `albumentations` was already imported in the first cell, so an
# upgrade performed here presumably only takes effect after a kernel restart — confirm.
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

### Read Data & Create train / valid splits 📁

In [5]:
DATA_DIR = '../input/deepglobe-road-extraction-dataset'

# Single seed shared by the shuffle and the train/val split so that a
# Restart & Run All reproduces exactly the same partition.
SPLIT_SEED = 42

# Keep only the rows marked as the 'train' split, keep the three columns used
# downstream, and turn the relative image/mask paths into paths under DATA_DIR.
# The columns are rebuilt with `.assign` on the chained result, which avoids
# assigning into a filtered slice of the original frame.
metadata_df = (
    pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
    .query("split == 'train'")
    .loc[:, ['image_id', 'sat_image_path', 'mask_path']]
    .assign(
        sat_image_path=lambda df: df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth)),
        mask_path=lambda df: df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth)),
    )
    # Shuffle DataFrame (seeded: an unseeded shuffle would make the seeded
    # split below select different rows on every run).
    .sample(frac=1, random_state=SPLIT_SEED)
    .reset_index(drop=True)
)

# Perform 90/10 split for train / val
valid_df = metadata_df.sample(frac=0.1, random_state=SPLIT_SEED)
train_df = metadata_df.drop(valid_df.index)
len(train_df), len(valid_df)

(5603, 623)

In [6]:
metadata_df.head(10)

Unnamed: 0,image_id,sat_image_path,mask_path
0,272122,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
1,622973,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
2,207813,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
3,493904,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
4,110236,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
5,998129,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
6,682308,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
7,507376,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
8,664839,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...
9,841576,../input/deepglobe-road-extraction-dataset/tra...,../input/deepglobe-road-extraction-dataset/tra...


In [7]:
# Define data augmentation
def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    The pipeline consists of, in order: horizontal and vertical flips (each
    with probability 0.5), an always-applied shift/scale/rotate, random
    brightness/contrast and random gamma (each with probability 0.2), and an
    always-applied random 256x256 crop.

    Returns:
        An ``album.Compose`` wrapping the transforms above.
    """
    return album.Compose([
        # Geometric augmentations
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
        album.ShiftScaleRotate(scale_limit=0.5, rotate_limit=45, shift_limit=0.1, p=1, border_mode=0),
        # Photometric augmentations
        album.RandomBrightnessContrast(p=0.2),
        album.RandomGamma(p=0.2),
        # Final spatial crop
        album.RandomCrop(height=256, width=256, p=1.0),
    ])

In [8]:
# Define normalization transformation
def to_tensor(x, **kwargs):
    """Reorder a 3-D array's axes from (0, 1, 2) to (2, 0, 1) and cast to float32.

    Args:
        x: Array exposing ``transpose`` and ``astype`` (e.g. a NumPy array).
        **kwargs: Accepted and ignored.

    Returns:
        The transposed array with dtype ``float32``.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn=None):
    """Compose the preprocessing applied to every sample.

    Args:
        preprocessing_fn: Optional callable applied to the image only. When
            falsy, no image-specific step is added.

    Returns:
        An ``album.Compose`` that optionally runs ``preprocessing_fn`` on the
        image and then runs ``to_tensor`` on both image and mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    # Always finish with the axis reorder / float32 cast for image and mask.
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

In [9]:
from torch.utils.data import DataLoader, Dataset

# Custom dataset class
class RoadsDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs read from paths in a DataFrame.

    Args:
        df: DataFrame with ``sat_image_path`` and ``mask_path`` columns.
        class_rgb_values: RGB colours passed to ``one_hot_encode`` to turn the
            colour mask into one channel per class.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: Optional callable with the same calling convention,
            applied after resizing.
        target_size: Size handed to ``cv2.resize`` as its ``dsize`` argument,
            i.e. ``(width, height)``.
    """

    def __init__(self, df, class_rgb_values=None, augmentation=None, preprocessing=None, target_size=(1024, 1024)):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.target_size = target_size

    @staticmethod
    def _read_rgb(path):
        """Read an image file and return it in RGB channel order."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of an opaque cvtColor error.
            raise FileNotFoundError(f'Could not read image file: {path}')
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Load, one-hot encode, augment, resize and preprocess sample ``i``.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        # Read images and masks
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])

        # One-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # Apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        image = cv2.resize(image, self.target_size)
        # Nearest-neighbour keeps the mask strictly one-hot; the default
        # bilinear interpolation would blend classes into fractional labels
        # whenever the size actually changes.
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # Apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.image_paths)

In [10]:
# Helper functions for one-hot encoding and color coding
def one_hot_encode(label, label_values):
    """Convert a colour-coded label array into one boolean channel per class.

    Args:
        label: Array whose last axis holds the colour of each pixel.
        label_values: Iterable of colours, one per class.

    Returns:
        Boolean array with the class maps stacked along a new last axis;
        channel ``k`` is True where every component of the pixel equals
        ``label_values[k]``.
    """
    per_class_maps = [
        np.all(np.equal(label, colour), axis=-1)
        for colour in label_values
    ]
    return np.stack(per_class_maps, axis=-1)

def reverse_one_hot(image):
    """Collapse the last axis to the index of its maximum value.

    Args:
        image: Array whose last axis holds per-class scores or one-hot flags.

    Returns:
        Array of indices with the last axis removed.
    """
    return np.argmax(image, axis=-1)

def colour_code_segmentation(image, label_values):
    """Map an array of class indices to the corresponding colours.

    Args:
        image: Array of class indices (cast to int before lookup).
        label_values: Sequence of colours, indexed by class.

    Returns:
        Array in which each index has been replaced by its row of
        ``label_values``.
    """
    palette = np.array(label_values)
    class_ids = image.astype(int)
    return palette[class_ids]

# Load the class legend: class names and their RGB colours
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
class_names = class_dict['name'].tolist()
class_rgb_values = class_dict[['r', 'g', 'b']].values.tolist()

# Keep only the classes of interest and look up their RGB rows
select_classes = ['background', 'road']
select_class_indices = [class_names.index(name.lower()) for name in select_classes]
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]

# Training dataset: augmented, then preprocessed with the encoder's function
train_dataset = RoadsDataset(
    train_df,
    class_rgb_values=select_class_rgb_values,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(smp.encoders.get_preprocessing_fn('resnet50', 'imagenet')),
)

# Validation dataset: same preprocessing, no augmentation
valid_dataset = RoadsDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
    preprocessing=get_preprocessing(smp.encoders.get_preprocessing_fn('resnet50', 'imagenet')),
)

In [11]:
# Data loaders: small batches; only the training loader is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
)
val_loader = DataLoader(
    valid_dataset,
    batch_size=4,
    shuffle=False,
)

# DeepLabV3+ with a ResNet-50 encoder initialised from ImageNet weights,
# taking 3-channel input and emitting one output channel per selected class
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=len(select_classes),
)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 327MB/s]


In [12]:
import torch.optim as optim

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define the learning rate scheduler.
# mode='max' because the scheduler is stepped with the validation IoU, which
# should increase; the LR is halved after `patience` epochs without improvement.
# The `verbose` argument is deprecated in recent PyTorch releases and is
# therefore not passed; the current LR can be read from
# optimizer.param_groups[0]['lr'] when it needs to be logged.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

# Define combined loss function
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy (on logits) and a soft Dice loss.

    Args:
        smooth: Constant added to the Dice numerator and denominator to keep
            the ratio defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1):
        super(DiceBCELoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_true, y_pred):
        """Compute the combined loss.

        Note the argument order: targets first, predictions second.

        Args:
            y_true: Target tensor with values in [0, 1]; the code sums over
                dims 2 and 3, so at least 4 dimensions are expected.
            y_pred: Raw, un-activated model outputs (logits) of the same shape.

        Returns:
            Scalar tensor ``bce + (1 - mean_dice)``.
        """
        y_true = y_true.contiguous().float()
        logits = y_pred.contiguous().float()

        # The Dice term must be computed on probabilities; on raw logits the
        # "overlap" is unbounded and the loss can become negative.
        probs = torch.sigmoid(logits)

        intersection = (y_true * probs).sum(dim=2).sum(dim=2)
        denominator = y_true.sum(dim=2).sum(dim=2) + probs.sum(dim=2).sum(dim=2)
        dice_score = (2. * intersection + self.smooth) / (denominator + self.smooth)
        dice_loss = 1 - dice_score.mean()

        # BCE-with-logits applies the sigmoid internally, so it takes the logits.
        bce_loss = nn.functional.binary_cross_entropy_with_logits(logits, y_true)
        loss = bce_loss + dice_loss

        return loss

# Define the loss function
loss_fn = DiceBCELoss()

# Define IoU calculation
def calculate_iou(y_true, y_pred, smooth=1e-6, threshold=0.5, from_logits=True):
    """Mean intersection-over-union between targets and binarised predictions.

    Args:
        y_true: Target tensor with values in {0, 1}; sums run over dims 2 and
            3, so at least 4 dimensions are expected.
        y_pred: Model outputs of the same shape. Treated as raw logits unless
            ``from_logits`` is False, in which case they are taken as
            probabilities.
        smooth: Small constant keeping the ratio defined for empty masks.
        threshold: Probability above which a prediction counts as positive.
        from_logits: Apply a sigmoid to ``y_pred`` before thresholding.

    Returns:
        Scalar tensor: IoU averaged over the leading (non-summed) dimensions.
    """
    # Work in float32: half-precision outputs can overflow when summed over
    # large spatial extents.
    y_true = y_true.contiguous().float()
    y_pred = y_pred.contiguous().float()

    # IoU is only meaningful on a hard 0/1 prediction; on raw logits the
    # result is unbounded and can be negative.
    if from_logits:
        y_pred = torch.sigmoid(y_pred)
    y_pred = (y_pred > threshold).float()

    intersection = (y_true * y_pred).sum(dim=2).sum(dim=2)
    union = y_true.sum(dim=2).sum(dim=2) + y_pred.sum(dim=2).sum(dim=2) - intersection

    iou = (intersection + smooth) / (union + smooth)
    return iou.mean()

# Define Early Stopping
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Each call compares the negated validation loss with the best value seen
    so far. An improvement saves the model's state dict and resets the
    counter; otherwise the counter is incremented, and ``early_stop`` is set
    once it reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: Print a message on every counter increment and checkpoint.
        delta: Amount by which the score must exceed the best score to count
            as an improvement.
        path: File the best state dict is written to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        # Higher score == better, so negate the loss.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Training loop with mixed precision
num_epochs = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Mixed precision only makes sense on CUDA; on CPU both the scaler and the
# autocast context are switched off so the loop still runs unchanged.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
# NOTE(review): patience=10 exceeds num_epochs=4, so early stopping cannot
# trigger with these settings; it still checkpoints the best epoch.
early_stopping = EarlyStopping(patience=10, verbose=True)

# Per-epoch history, kept so the curves can be plotted afterwards.
train_losses, val_losses, val_ious = [], [], []

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        images = images.float().to(device)
        masks = masks.float().to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            # The loss signature is forward(y_true, y_pred); pass by keyword
            # so the logits and the targets cannot be swapped.
            loss = loss_fn(y_true=masks, y_pred=outputs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    val_iou = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.float().to(device)
            masks = masks.float().to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)
                loss = loss_fn(y_true=masks, y_pred=outputs)
            val_loss += loss.item()
            # Cast to float32 first: autocast outputs may be half precision,
            # which can overflow when summed over the whole image.
            val_iou += calculate_iou(masks, outputs.float()).item()
    val_loss /= len(val_loader)
    val_iou /= len(val_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_ious.append(val_iou)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}')

    # The scheduler tracks IoU (mode='max'); early stopping tracks the loss.
    scheduler.step(val_iou)
    early_stopping(val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break

In [None]:
# Save the model
# 'checkpoint.pt' is where EarlyStopping writes the weights of the epoch with
# the lowest validation loss. Restore it first so that 'best_model.pth' really
# contains the best weights rather than those of the last epoch.
if os.path.exists('checkpoint.pt'):
    model.load_state_dict(torch.load('checkpoint.pt', map_location='cpu'))
torch.save(model.state_dict(), 'best_model.pth')

# Use the per-epoch loss history if `train_losses` / `val_losses` were
# populated during training; otherwise fall back to empty placeholders so the
# cell still runs (the figure is then simply empty).
try:
    train_losses, val_losses
except NameError:
    train_losses, val_losses = [], []

# Plot against the number of recorded epochs rather than `num_epochs`: the two
# differ when training stops early, and mismatched x/y lengths make
# plt.plot raise.
plt.figure(figsize=(20,8))
plt.plot(range(1, len(train_losses) + 1), train_losses, lw=3, label='Train')
plt.plot(range(1, len(val_losses) + 1), val_losses, lw=3, label='Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Loss', fontsize=20)
plt.title('Training and Validation Loss', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('loss_plot.png')
plt.show()