# 🏠 Rooftop Segmentation with DeepLabV3Plus

This notebook contains everything needed for rooftop segmentation:
- Data loading and preprocessing
- Model training
- Visualization and inference

Run cells sequentially from top to bottom.


## 1. Imports and Setup


In [None]:
import os
import sys
import torch
import numpy as np
import albumentations as A
from PIL import Image, ImageFile
from matplotlib import pyplot as plt
from glob import glob
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import time
from tqdm import tqdm
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
import random
from torchvision import transforms as tfs

# Allow PIL to open image files whose data is truncated rather than raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Windows-specific multiprocessing setup (see multiprocessing.freeze_support docs).
if sys.platform == 'win32':
    import multiprocessing
    multiprocessing.freeze_support()


In [None]:
class CustomSegmentationDataset(Dataset):
    """Image / binary-mask pairs for rooftop segmentation.

    Images are read from ``{root}/images/images/*``; each image ``<name>.<ext>``
    is paired with ``{root}/label/label/<name>_label.tif``. Images without a
    matching label file are skipped.

    Args:
        root: dataset directory.
        transformations: optional albumentations-style callable invoked as
            ``transformations(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, root, transformations=None):
        pairs = []
        for img_path in sorted(glob(f"{root}/images/images/*")):
            label_path = f"{root}/label/label/{self.get_filename(img_path)}_label.tif"
            if os.path.exists(label_path):
                pairs.append((img_path, label_path))

        # Both lists are built from the same pairs, so they always align.
        self.image_paths = [img for img, _ in pairs]
        self.label_paths = [lbl for _, lbl in pairs]
        self.transformations = transformations
        self.n_cls = 2  # binary masks: 0 -> class 0, 255 -> class 1

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where label is an int32 tensor of class indices."""
        img, label = self.load_image_pair(self.image_paths[idx], self.label_paths[idx])

        if self.transformations:
            transformed = self.transformations(image=img, mask=label)
            img = transformed["image"]
            label = transformed["mask"]

        # Without transformations the label is still a NumPy array, which has no
        # ``.int()``; converting here makes both code paths return a tensor.
        label = torch.as_tensor(label)
        return img, (label / 255).int()

    def get_filename(self, path):
        """Return the file name of *path* without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def load_image_pair(self, img_path, label_path):
        """Load an RGB image and a single-channel label as NumPy arrays."""
        img = np.array(Image.open(img_path).convert("RGB"))
        label = np.array(Image.open(label_path).convert("L"))
        return img, label

def get_dataloaders(root, transformations, batch_size, split=(0.9, 0.05, 0.05), num_workers=0, seed=None):
    """Build train / validation / test DataLoaders from one dataset folder.

    Args:
        root: dataset directory (layout as expected by ``CustomSegmentationDataset``).
        transformations: transform applied to every split.
        batch_size: batch size of the train and validation loaders (test uses 1).
        split: ``(train, val, test)`` fractions; must sum to 1.
        num_workers: number of DataLoader worker processes.
        seed: optional integer seed for a reproducible random split; ``None``
            uses PyTorch's global RNG (the previous behaviour).

    Returns:
        ``(train_dl, val_dl, test_dl, n_classes)``.

    Raises:
        ValueError: if no image/label pairs are found under *root*.
    """
    assert abs(sum(split) - 1.0) < 0.001, "Split ratios must sum to 1"

    dataset = CustomSegmentationDataset(root=root, transformations=transformations)
    n_classes = dataset.n_cls
    total = len(dataset)

    if total == 0:
        print(f"\n⚠️  WARNING: No images found in dataset!")
        print(f"   Please check:")
        print(f"   1. Dataset directory exists: {root}/")
        print(f"   2. Images are in: {root}/images/images/")
        print(f"   3. Labels are in: {root}/label/label/")
        print(f"   4. Label naming: {{image_name}}_label.tif")
        print(f"\n   Please check your dataset structure matches the requirements.\n")
        raise ValueError(f"No images found in dataset directory: {root}")

    # Floor the train/val sizes and give the remainder to test so they sum to total.
    sizes = [int(total * split[0]), int(total * split[1])]
    sizes.append(total - sum(sizes))

    # Ensure at least 1 sample in each split if possible. An empty split borrows one
    # sample from the currently largest split, so the sizes keep summing to ``total``
    # (random_split rejects lengths that do not).
    if total >= len(sizes):
        for i in range(len(sizes)):
            if sizes[i] == 0:
                donor = max(range(len(sizes)), key=sizes.__getitem__)
                sizes[donor] -= 1
                sizes[i] = 1

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, val_ds, test_ds = torch.utils.data.random_split(dataset, sizes, **split_kwargs)

    print(f"\n✅ Dataset loaded successfully!")
    print(f"Train set: {len(train_ds)} images")
    print(f"Validation set: {len(val_ds)} images")
    print(f"Test set: {len(test_ds)} images\n")

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=num_workers)

    return train_dl, val_dl, test_dl, n_classes


## 5. Load Dataset


In [None]:
# ---- Configuration ---------------------------------------------------------
root = "dataset"                 # folder holding images/images and label/label
mean = [0.485, 0.456, 0.406]     # per-channel normalisation mean (ImageNet values)
std = [0.229, 0.224, 0.225]      # per-channel normalisation std (ImageNet values)
img_size = 256
batch_size = 32

# Report whether the dataset folder exists; get_dataloaders raises if it finds no pairs.
if os.path.exists(root):
    print(f"‚úÖ Dataset directory found: {os.path.abspath(root)}")
else:
    print(f"‚ö†Ô∏è  Dataset directory '{root}' not found!")
    print(f"   Current working directory: {os.getcwd()}")
    print(f"   Please ensure the dataset folder exists.\n")

# Resize, normalise, then convert image and mask to tensors.
transform = A.Compose(
    [
        A.Resize(img_size, img_size),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(transpose_mask=True),
    ]
)

# Worker processes are turned off on Windows to avoid multiprocessing issues.
num_workers = 2
if sys.platform == 'win32':
    num_workers = 0

tr_dl, val_dl, test_dl, n_cls = get_dataloaders(
    root=root,
    transformations=transform,
    batch_size=batch_size,
    num_workers=num_workers,
)


## 7. Model Setup


In [None]:
# Select the compute device and report it (plus the GPU name when CUDA is used).
if torch.cuda.is_available():
    device = "cuda"
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print(f"Using device: {device}")

# DeepLabV3+ on an ImageNet-pretrained ResNet-50 encoder, RGB input,
# one output channel per class.
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=n_cls,
)

print(f"\n‚úÖ Model created: DeepLabV3Plus with ResNet50 encoder")
print(f"   Classes: {n_cls}")


## 8. Loss Function and Metrics


In [None]:
class CombinedLoss(torch.nn.Module):
    """Weighted sum of cross-entropy and soft Dice loss for segmentation logits.

    Args:
        ce_weight: weight of the cross-entropy term.
        dice_weight: weight of the Dice term.
        n_cls: number of classes, used to one-hot encode the target.

    ``pred`` is expected to be logits with classes along dim 1. ``target`` holds
    class indices as ``(B, H, W)`` or ``(B, 1, H, W)``.
    """

    def __init__(self, ce_weight=0.6, dice_weight=0.4, n_cls=2):
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.n_cls = n_cls

    @staticmethod
    def _index_target(target):
        """Return *target* as a long index tensor without a singleton channel dim.

        Only a 4-D ``(B, 1, H, W)`` target is squeezed, so a ``(B, H, W)`` target
        with ``H == 1`` keeps its height dimension.
        """
        if target.dim() == 4 and target.size(1) == 1:
            target = target.squeeze(1)
        return target.long()

    def dice(self, pred, target):
        """Return ``1 - mean soft Dice`` over batch items and classes.

        Softmax probabilities are used instead of an argmax one-hot. An argmax has
        no gradient, which would make this term a constant during training.
        """
        smooth = 1.0
        probs = F.softmax(pred.float(), dim=1)
        target_hot = (
            F.one_hot(self._index_target(target), self.n_cls)
            .permute(0, 3, 1, 2)
            .to(probs.dtype)
        )

        inter = (probs * target_hot).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + target_hot.sum(dim=(2, 3))
        dice_score = (2.0 * inter + smooth) / (union + smooth)
        return 1 - dice_score.mean()

    def forward(self, pred, target):
        target = self._index_target(target)
        ce_loss = self.ce(pred, target)
        dice_loss = self.dice(pred, target)
        return self.ce_weight * ce_loss + self.dice_weight * dice_loss

class Metrics:
    """Per-batch segmentation metrics computed from logits and a ground-truth mask.

    Args:
        pred: logits with class scores along dim 1.
        gt: ground-truth class indices; dim 1 is squeezed (removed if it has size 1).
        loss_fn: callable ``loss_fn(logits, target)`` used by ``compute_loss``.
        eps: smoothing added to the IoU numerator and denominator.
        n_cls: number of classes evaluated by ``mean_iou``.
    """

    def __init__(self, pred, gt, loss_fn, eps=1e-10, n_cls=2):
        self.pred_full = pred
        self.pred = pred.argmax(dim=1)
        self.gt = gt.squeeze(1)
        self.loss_fn, self.eps, self.n_cls = loss_fn, eps, n_cls

    def to_flat(self, x):
        """Return *x* as a contiguous 1-D view."""
        return x.contiguous().view(-1)

    def pixel_accuracy(self):
        """Return the fraction of pixels where the predicted class equals the ground truth."""
        with torch.no_grad():
            hits = (self.pred == self.gt).int()
        return float(hits.sum()) / float(hits.numel())

    def mean_iou(self):
        """Return the mean IoU over classes that occur in the ground truth.

        Classes absent from the ground truth get NaN and are ignored by ``np.nanmean``.
        """
        with torch.no_grad():
            flat_pred = self.to_flat(self.pred)
            flat_gt = self.to_flat(self.gt)

            def class_iou(c):
                in_pred = flat_pred == c
                in_gt = flat_gt == c
                if not bool(in_gt.any()):
                    return np.nan
                overlap = (in_pred & in_gt).sum().float().item()
                combined = (in_pred | in_gt).sum().float().item()
                return (overlap + self.eps) / (combined + self.eps)

            return np.nanmean([class_iou(c) for c in range(self.n_cls)])

    def compute_loss(self):
        """Evaluate ``loss_fn`` on the stored logits and the squeezed ground truth."""
        return self.loss_fn(self.pred_full, self.gt.long())

def timer(start=None):
    """Return the current wall-clock time, or the seconds elapsed since *start*."""
    now = time.time()
    return now if start is None else now - start

# Initialize loss function
loss_fn = CombinedLoss(ce_weight=0.6, dice_weight=0.4, n_cls=n_cls)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Halve the learning rate when the monitored value (train() steps it with the
# validation loss) stops decreasing for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)

# Loss scaling is only meaningful for CUDA mixed precision. Disable it explicitly
# on CPU instead of relying on GradScaler's "CUDA not available" warning fallback.
scaler = GradScaler(enabled=(device == "cuda"))

print("✅ Loss function, optimizer, and scheduler initialized")


In [None]:
def _amp_enabled(device):
    """Return True when CUDA mixed precision should be used for *device*."""
    return torch.device(device).type == "cuda"


def _train_one_epoch(model, loader, loss_fn, optimizer, scaler, device, n_cls, grad_clip, use_amp):
    """Run one optimisation pass over *loader*; return per-batch means of (loss, IoU, PA)."""
    model.train()
    total_loss, total_iou, total_pa = 0.0, 0.0, 0.0

    for images, masks in tqdm(loader, desc="Training"):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            predictions = model(images)
            metrics = Metrics(predictions, masks, loss_fn, n_cls=n_cls)
            loss = metrics.compute_loss()

        scaler.scale(loss).backward()
        # Unscale first so gradient clipping acts on the true gradient norms.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            total_iou += metrics.mean_iou()
            total_pa += metrics.pixel_accuracy()
            total_loss += loss.item()

    n_batches = len(loader)
    # Cast to plain floats: mean_iou returns NumPy scalars, which would end up in
    # checkpoints and history otherwise.
    return float(total_loss / n_batches), float(total_iou / n_batches), float(total_pa / n_batches)


def _validate_one_epoch(model, loader, loss_fn, device, n_cls, use_amp):
    """Evaluate *model* on *loader* without gradients; return per-batch means of (loss, IoU, PA)."""
    model.eval()
    total_loss, total_iou, total_pa = 0.0, 0.0, 0.0

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validation"):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            # The loss is computed inside autocast as well, matching the training pass.
            with autocast(enabled=use_amp):
                predictions = model(images)
                metrics = Metrics(predictions, masks, loss_fn, n_cls=n_cls)
                loss = metrics.compute_loss()

            total_loss += loss.item()
            total_iou += metrics.mean_iou()
            total_pa += metrics.pixel_accuracy()

    n_batches = len(loader)
    return float(total_loss / n_batches), float(total_iou / n_batches), float(total_pa / n_batches)


def _save_checkpoint(path, epoch, model, optimizer, scheduler, **extra):
    """Save model/optimizer/scheduler state dicts plus *extra* fields to *path*."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    checkpoint.update(extra)
    torch.save(checkpoint, path)


def train(model, train_loader, val_loader, loss_fn, optimizer, scheduler, scaler, device, epochs, save_prefix, n_cls, threshold=0.005, save_path="saved_models", grad_clip=1.0, early_stop=10):
    """Train *model* with mixed precision, checkpointing and IoU-based early stopping.

    Args:
        model: segmentation model producing per-class logits along dim 1.
        train_loader, val_loader: DataLoaders yielding ``(images, masks)``.
        loss_fn: loss passed to ``Metrics`` (called as ``loss_fn(logits, target)``).
        optimizer: optimizer over ``model.parameters()``.
        scheduler: stepped once per epoch with the validation loss.
        scaler: ``GradScaler`` used for the backward pass.
        device: device to train on; CUDA autocast is enabled only for CUDA devices.
        epochs: maximum number of epochs.
        save_prefix: file-name prefix of saved checkpoints.
        n_cls: number of classes.
        threshold: minimum validation-loss decrease that triggers a loss checkpoint.
        save_path: directory receiving the checkpoints.
        grad_clip: maximum gradient norm.
        early_stop: stop after this many consecutive epochs without an IoU improvement.

    Returns:
        Dict of per-epoch histories (``tr_loss``, ``tr_iou``, ``tr_pa``, ``val_loss``,
        ``val_iou``, ``val_pa``) plus ``best_iou`` and ``best_loss``.
    """
    train_losses, train_pa, train_iou = [], [], []
    val_losses, val_pa, val_iou = [], [], []
    best_iou = 0.0
    best_loss = float('inf')
    no_improve = 0
    use_amp = _amp_enabled(device)

    os.makedirs(save_path, exist_ok=True)
    model.to(device)

    if use_amp:
        print(f"Model on GPU: {next(model.parameters()).is_cuda}")

    start_time = timer()
    print("Starting training...")

    for epoch in range(1, epochs + 1):
        epoch_start = timer()
        print(f"\nEpoch {epoch}/{epochs}")

        train_loss, train_iou_val, train_pa_val = _train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, n_cls, grad_clip, use_amp
        )
        val_loss, val_iou_val, val_pa_val = _validate_one_epoch(
            model, val_loader, loss_fn, device, n_cls, use_amp
        )

        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"\nEpoch {epoch} Results:")
        print(f"Time: {timer(epoch_start):.2f}s | LR: {current_lr:.6f}")
        print(f"Train - Loss: {train_loss:.4f}, PA: {train_pa_val:.4f}, IoU: {train_iou_val:.4f}")
        print(f"Val   - Loss: {val_loss:.4f}, PA: {val_pa_val:.4f}, IoU: {val_iou_val:.4f}\n")

        train_losses.append(train_loss)
        train_iou.append(train_iou_val)
        train_pa.append(train_pa_val)

        val_losses.append(val_loss)
        val_iou.append(val_iou_val)
        val_pa.append(val_pa_val)

        improved = False
        if val_iou_val > best_iou:
            print(f"IoU improved: {best_iou:.4f} -> {val_iou_val:.4f}")
            best_iou = val_iou_val
            improved = True
            no_improve = 0
            _save_checkpoint(
                f"{save_path}/{save_prefix}_best_model_iou.pt",
                epoch, model, optimizer, scheduler,
                best_iou=best_iou, val_loss=val_loss,
            )
            print(f"Saved model (IoU: {best_iou:.4f})\n")

        if val_loss < best_loss - threshold:
            print(f"Loss improved: {best_loss:.4f} -> {val_loss:.4f}")
            best_loss = val_loss
            _save_checkpoint(
                f"{save_path}/{save_prefix}_best_model_loss.pt",
                epoch, model, optimizer, scheduler,
                best_loss=best_loss, val_iou=val_iou_val,
            )
            print(f"Saved model (Loss: {best_loss:.4f})\n")

        # Only an IoU improvement resets the early-stopping counter. A loss-only
        # improvement saves a checkpoint but still counts as a stagnant epoch.
        if not improved:
            no_improve += 1
            print(f"No improvement for {no_improve} epochs")
            if no_improve >= early_stop:
                print(f"\nEarly stopping after {no_improve} epochs\n")
                break

    total_time = timer(start_time) / 60
    print(f"\nTraining completed in {total_time:.2f} minutes")
    print(f"Best IoU: {best_iou:.4f} | Best Loss: {best_loss:.4f}\n")

    return {
        "tr_loss": train_losses, "tr_iou": train_iou, "tr_pa": train_pa,
        "val_loss": val_losses, "val_iou": val_iou, "val_pa": val_pa,
        "best_iou": best_iou, "best_loss": best_loss
    }


In [None]:
# Verify dataloaders have data before training
if len(tr_dl) == 0:
    raise ValueError("Training dataloader is empty! Please check your dataset.")
if len(val_dl) == 0:
    raise ValueError("Validation dataloader is empty! Please check your dataset.")

# Single source of truth for the epoch count shown below and passed to train().
num_epochs = 50

print(f"\n✅ Ready to train!")
print(f"   Training batches: {len(tr_dl)}")
print(f"   Validation batches: {len(val_dl)}")
print(f"   Test batches: {len(test_dl)}")
print(f"   Model: DeepLabV3Plus with ResNet50 encoder")
print(f"   Classes: {n_cls}")
print(f"   Epochs: {num_epochs}\n")

# Start training
history = train(
    model=model,
    train_loader=tr_dl,
    val_loader=val_dl,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    device=device,
    epochs=num_epochs,
    save_prefix="rooftop",
    n_cls=n_cls,
    grad_clip=1.0
)


## 11. Visualize Learning Curves


In [None]:
class LearningCurves:
    """Plot train/validation curves for IoU, pixel accuracy and loss.

    All three figures are drawn immediately on construction.

    Args:
        results: history dict with the keys ``tr_iou``/``val_iou``, ``tr_pa``/``val_pa``
            and ``tr_loss``/``val_loss``, each a per-epoch list.
    """

    def __init__(self, results):
        self.results = results
        self.plot_curve("tr_iou", "val_iou", "Train IoU", "Val IoU", "IoU", "IoU Score")
        self.plot_curve("tr_pa", "val_pa", "Train PA", "Val PA", "Pixel Accuracy", "PA Score")
        self.plot_curve("tr_loss", "val_loss", "Train Loss", "Val Loss", "Loss", "Loss Value")

    def plot_curve(self, train_key, val_key, train_label, val_label, title, ylabel):
        """Plot one train/val metric pair against 1-based epoch numbers."""
        train_vals = self.results[train_key]
        val_vals = self.results[val_key]

        plt.figure(figsize=(10, 5))
        # Start the x-axis at 1 to match the "Epoch N" numbering printed during training.
        plt.plot(range(1, len(train_vals) + 1), train_vals, label=train_label)
        plt.plot(range(1, len(val_vals) + 1), val_vals, label=val_label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.legend()
        plt.show()

# Draw the curves only if a training run in this session produced `history`.
if 'history' not in locals():
    print("‚ö†Ô∏è  No training history found. Please train the model first.")
else:
    LearningCurves(history)


## 12. Inference on Test Set


In [None]:
def run_inference(data_loader, model, device, num_images=15,
                  mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Predict on the first few batches of *data_loader* and plot the results.

    Each displayed sample takes one row: input image, ground-truth mask, prediction.
    Only the first item of each batch is shown, and iteration stops once enough
    samples have been collected.

    Args:
        data_loader: yields ``(image, mask)`` batches of normalised RGB images.
        model: segmentation model returning per-class logits along dim 1.
        device: device the model lives on.
        num_images: panel budget; ``num_images // 3`` samples (at least one) are
            shown, three panels each.
        mean, std: per-channel normalisation to undo before displaying the input.
    """
    n_samples = max(1, num_images // 3)

    samples = []
    n_classes = 2
    with torch.no_grad():
        for img, mask in data_loader:
            logits = model(img.to(device))
            n_classes = logits.shape[1]
            pred = torch.argmax(logits, dim=1)
            h, w = mask.shape[-2:]
            samples.append((img[0].cpu(), mask[0].reshape(h, w).cpu(), pred[0].cpu()))
            if len(samples) >= n_samples:
                break

    if not samples:
        print("No samples available for inference.")
        return

    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    # Fix the colour range so masks with a single class still render correctly.
    vmax = max(n_classes - 1, 1)

    rows = len(samples)
    plt.figure(figsize=(15, 5 * rows))
    for row, (img, mask, pred) in enumerate(samples):
        # Undo the normalisation so the input image displays with natural colours.
        rgb = (img.float() * std_t + mean_t).clamp(0, 1).permute(1, 2, 0).numpy()
        panels = [
            (rgb, "Original", {}),
            (mask.numpy(), "Ground Truth", {"cmap": "gray", "vmin": 0, "vmax": vmax}),
            (pred.numpy(), "Prediction", {"cmap": "gray", "vmin": 0, "vmax": vmax}),
        ]
        for col, (data, title, style) in enumerate(panels, start=1):
            plt.subplot(rows, 3, row * 3 + col)
            plt.imshow(data, **style)
            plt.title(title)
            plt.axis("off")

    plt.tight_layout()
    plt.show()

# Load best model: prefer the best-IoU checkpoint, then the best-loss checkpoint,
# then a legacy file that stores the whole pickled model.
iou_path = "saved_models/rooftop_best_model_iou.pt"
loss_path = "saved_models/rooftop_best_model_loss.pt"
old_path = "saved_models/rooftop_best_model.pt"


def _load_trusted(path):
    """``torch.load`` a locally produced file onto ``device``.

    ``weights_only=False`` is needed because these files may hold a whole pickled
    model (legacy format) or NumPy scalars in the metadata, which PyTorch rejects
    when ``weights_only=True`` (the default in newer releases). Only use this on
    files you created yourself.
    """
    return torch.load(path, map_location=device, weights_only=False)


# Each entry: checkpoint path, metadata key holding its score, label to print.
for ckpt_path, score_key, score_label in (
    (iou_path, 'best_iou', "Best IoU"),
    (loss_path, 'best_loss', "Best Loss"),
):
    if os.path.exists(ckpt_path):
        print(f"Loading model from {ckpt_path}")
        checkpoint = _load_trusted(ckpt_path)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"{score_label}: {checkpoint.get(score_key, 'N/A')}")
        else:
            model = checkpoint
        break
else:  # no break: neither checkpoint exists, fall back to the legacy full-model file
    if os.path.exists(old_path):
        print(f"Loading model from {old_path}")
        model = _load_trusted(old_path)
    else:
        raise FileNotFoundError("No model found! Train the model first.")

model.to(device)
model.eval()

# Run inference
if len(test_dl) > 0:
    run_inference(test_dl, model, device)
else:
    print("⚠️  Test dataloader is empty. Cannot run inference.")
