## Building Segmentation Training Migration
This notebook migrates the training pipeline from fastai to PyTorch Lightning and TorchGeo.
This notebook works with data prepared by the BEAM DataTiler.

## Setup and Imports


In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchgeo.datasets as datasets
import torchgeo.transforms as transforms
from torchvision import transforms as T
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
import segmentation_models_pytorch as smp


In [None]:
# Fix the global seed once so repeated runs of the notebook are comparable
RANDOM_SEED = 42
pl.seed_everything(RANDOM_SEED)


## Configuration


In [3]:

# Data locations: tiles live under the per-city project directory
PROJECT_DIR = Path('/datadrive/training_data_by_city/sd')
TRAIN_DIR = PROJECT_DIR.joinpath('tiles/train')
VAL_DIR = PROJECT_DIR.joinpath('tiles/test')  # the "test" split doubles as the validation set

# DataLoader settings
BATCH_SIZE, NUM_WORKERS = 8, 4
IMG_SIZE = 256  # tile edge length in pixels; assumed 256x256 — TODO confirm against the tiler config

# Network and optimisation settings consumed by the model and the trainer setup
MODEL_CONFIG = dict(
    architecture='unet',
    backbone='resnet18',
    learning_rate=1e-4,
    epochs=50,
)

In [4]:
class BuildingSegmentationDataset(Dataset):
    """Dataset of paired image/mask raster tiles for building segmentation.

    Images are read from ``<data_dir>/images`` and masks from
    ``<data_dir>/masks`` (files matching ``*.TIF``). Both listings are sorted
    and paired by position, so this presumably relies on images and masks
    sharing file names — verify against the tiler output.

    Args:
        data_dir: Directory containing the ``images`` and ``masks`` folders.
        transform: Optional callable applied to the image and mask stacked
            along the channel axis, so both receive the same augmentation.

    Raises:
        ValueError: If the number of image files and mask files differ.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = Path(data_dir)
        self.image_dir = self.data_dir / 'images'
        self.mask_dir = self.data_dir / 'masks'
        self.transform = transform

        # Sorted listings so that index i refers to the same tile in both lists
        self.image_files = sorted(self.image_dir.glob('*.TIF'))
        self.mask_files = sorted(self.mask_dir.glob('*.TIF'))

        # Raise explicitly: an ``assert`` would be stripped under ``python -O``
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Number of images ({len(self.image_files)}) and masks ({len(self.mask_files)}) don't match"
            )

    def __len__(self):
        """Return the number of image tiles."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load tile ``idx`` and return ``(image, mask)``.

        Returns:
            image: float tensor, CxHxW, divided by 255.
            mask: long tensor, HxW, with values 0 (background) or 1.
        """
        with rasterio.open(self.image_files[idx]) as src:
            image = src.read()  # all bands, CxHxW

        with rasterio.open(self.mask_files[idx]) as src:
            mask = src.read(1)  # first band only, HxW

        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).long()

        # Scale pixel values; assumes 8-bit imagery (0-255) — TODO confirm
        image = image / 255.0

        # Any non-zero mask value counts as "building"
        mask = (mask > 0).long()

        if self.transform is not None:
            image, mask = self._apply_joint_transform(image, mask, self.transform)

        return image, mask

    @staticmethod
    def _apply_joint_transform(image, mask, transform):
        """Apply ``transform`` to image and mask together and split them again.

        The mask is appended as an extra channel so that random spatial
        transforms act identically on both. The split position is taken from
        the image's own channel count instead of assuming three bands, and the
        mask is returned as a binary long tensor so its dtype matches the
        untransformed path.
        """
        n_channels = image.shape[0]
        # Cast explicitly so concatenation never depends on implicit promotion
        mask_channel = mask.unsqueeze(0).to(image.dtype)
        stacked = torch.cat([image, mask_channel], dim=0)
        transformed = transform(stacked)
        out_image = transformed[:n_channels]
        # Re-binarize in case the transform interpolated the mask values
        out_mask = (transformed[n_channels] > 0.5).long()
        return out_image, out_mask

## DataModule Implementation


In [5]:
class BuildingSegmentationDataModule(pl.LightningDataModule):
    """LightningDataModule serving the train and validation tile datasets.

    Args:
        train_dir: Directory handed to the training dataset.
        val_dir: Directory handed to the validation dataset.
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.
    """

    def __init__(self, train_dir, val_dir, batch_size=8, num_workers=4):
        super().__init__()
        self.train_dir, self.val_dir = train_dir, val_dir
        self.batch_size, self.num_workers = batch_size, num_workers

        # Augmentations are applied to training tiles only
        augmentations = [
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomRotation(45),
        ]
        self.train_transform = T.Compose(augmentations)

        # Validation tiles are passed through untouched
        self.val_transform = None

    def setup(self, stage=None):
        """Instantiate both datasets; ``stage`` is accepted but not used."""
        self.train_dataset = BuildingSegmentationDataset(self.train_dir, transform=self.train_transform)
        self.val_dataset = BuildingSegmentationDataset(self.val_dir, transform=self.val_transform)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_dataset, False)


In [6]:
class BuildingSegmentationModel(pl.LightningModule):
    """LightningModule wrapping a single-class segmentation network.

    Args:
        architecture: Network family; only ``'unet'`` (case-insensitive) is
            implemented.
        backbone: Encoder name passed to ``smp.Unet``.
        learning_rate: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If ``architecture`` is not ``'unet'``.
    """

    def __init__(self, architecture='unet', backbone='resnet18', learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate

        if architecture.lower() == 'unet':
            self.model = smp.Unet(
                encoder_name=backbone,      # encoder, e.g. 'resnet18'
                encoder_weights="imagenet", # start from pre-trained weights
                in_channels=3,              # three input bands
                classes=1,                  # one output channel (binary mask logits)
            )
        else:
            raise ValueError(f"Architecture {architecture} not implemented yet")

        # Operates on raw logits, so the network has no final sigmoid
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return the network output (logits) for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        images, masks = batch
        outputs = self(images)
        # Masks get a channel axis so they line up with the 1-channel output
        loss = self.criterion(outputs, masks.float().unsqueeze(1))

        self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True)

        # Log sample images every 100 batches. The global step is used as the
        # TensorBoard step because batch_idx restarts at 0 each epoch and
        # would make later epochs collide with earlier ones.
        if batch_idx % 100 == 0:
            self._log_images(images, masks, outputs, 'train', self.global_step)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss, IoU and Dice for one batch."""
        images, masks = batch
        outputs = self(images)
        loss = self.criterion(outputs, masks.float().unsqueeze(1))

        # Threshold probabilities at 0.5 to obtain hard predictions
        preds = (torch.sigmoid(outputs) > 0.5).float()
        iou = self.calculate_iou(preds, masks.unsqueeze(1))
        dice = self.calculate_dice(preds, masks.unsqueeze(1))

        metrics = {
            'val/loss': loss,
            'val/iou': iou,
            'val/dice': dice
        }
        # sync_dist reduces the epoch-level values across devices so that
        # callbacks and the LR scheduler monitoring 'val/loss' see one
        # consistent number on every process.
        self.log_dict(metrics, on_epoch=True, prog_bar=True, sync_dist=True)

        # One set of sample images per validation run, indexed by epoch
        if batch_idx == 0:
            self._log_images(images, masks, outputs, 'val', self.current_epoch)

        return metrics

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by 'val/loss'."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # 'verbose' is deprecated in recent PyTorch releases and is therefore
        # no longer passed.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss"
        }

    @staticmethod
    def calculate_iou(preds, targets):
        """Return intersection-over-union aggregated over the whole input.

        Non-zero entries count as positive. The 1e-6 term avoids division by
        zero; when both inputs are entirely empty the result is 0.
        """
        intersection = torch.logical_and(preds, targets).sum()
        union = torch.logical_or(preds, targets).sum()
        return intersection / (union + 1e-6)

    @staticmethod
    def calculate_dice(preds, targets):
        """Return the Dice coefficient aggregated over the whole input.

        Computed as ``2 * |A and B| / (sum(A) + sum(B) + 1e-6)``.
        """
        intersection = torch.logical_and(preds, targets).sum()
        total = preds.sum() + targets.sum()
        return 2 * intersection / (total + 1e-6)

    def _log_images(self, images, masks, outputs, stage, step):
        """Write up to four image / mask / prediction triples to the logger.

        Does nothing when no logger is attached or when the logger's
        experiment object has no ``add_image`` method.
        """
        experiment = getattr(self.logger, 'experiment', None)
        if experiment is None or not hasattr(experiment, 'add_image'):
            return

        n_images = min(4, images.shape[0])
        predictions = torch.sigmoid(outputs.detach()) > 0.5

        for idx in range(n_images):
            experiment.add_image(
                f'{stage}/image_{idx}',
                images[idx],
                step,
                dataformats='CHW'
            )

            experiment.add_image(
                f'{stage}/mask_{idx}',
                masks[idx].float().unsqueeze(0),
                step,
                dataformats='CHW'
            )

            experiment.add_image(
                f'{stage}/pred_{idx}',
                predictions[idx].float(),
                step,
                dataformats='CHW'
            )


In [7]:
def setup_training(model_config, data_module):
    """Create the ``pl.Trainer`` with logger and callbacks for this project.

    Args:
        model_config: Mapping with ``'architecture'``, ``'backbone'`` and
            ``'epochs'``. An optional ``'devices'`` entry overrides the
            default device count.
        data_module: Not used here; kept so existing calls keep working.

    Returns:
        A ``pl.Trainer`` with TensorBoard logging, checkpointing on
        ``val/loss``, early stopping and a progress bar.
    """
    logger = TensorBoardLogger(
        save_dir='lightning_logs',
        name=f"building_seg_{model_config['architecture']}_{model_config['backbone']}",
        default_hp_metric=False  # Disable automatic hp logging
    )

    callbacks = [
        ModelCheckpoint(
            monitor='val/loss',
            dirpath='checkpoints',
            # The metric is logged as 'val/loss'; the previous '{val_loss}'
            # placeholder matched no metric and always rendered as 0.000.
            # With auto_insert_metric_name=False the 'name=' prefixes are
            # written out by hand, which keeps '/' out of the file name.
            filename=f"building_seg_{model_config['architecture']}" +
                     "_epoch={epoch:02d}_val_loss={val/loss:.3f}",
            auto_insert_metric_name=False,
            save_top_k=3,
            mode='min'
        ),
        EarlyStopping(
            monitor='val/loss',
            patience=10,
            mode='min'
        ),
        TQDMProgressBar(refresh_rate=20)
    ]

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        # Use up to 4 GPUs, but never more than the host actually has
        default_devices = max(1, min(4, torch.cuda.device_count()))
    else:
        default_devices = 1
    devices = model_config.get('devices', default_devices)

    trainer = pl.Trainer(
        max_epochs=model_config['epochs'],
        accelerator='gpu' if use_gpu else 'cpu',
        devices=devices,
        callbacks=callbacks,
        logger=logger,
        # 16-bit mixed precision is only requested when a GPU is present
        precision=16 if use_gpu else 32,
        log_every_n_steps=10
    )

    return trainer

In [None]:
# Data module: tile directories and loader settings come from the configuration cell
data_module = BuildingSegmentationDataModule(
    TRAIN_DIR,
    VAL_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Model: forward only the constructor-relevant entries of MODEL_CONFIG
model = BuildingSegmentationModel(
    **{key: MODEL_CONFIG[key] for key in ('architecture', 'backbone', 'learning_rate')}
)

# Trainer with logger and callbacks
trainer = setup_training(MODEL_CONFIG, data_module)

# Run the training loop
trainer.fit(model, data_module)