# Model Training: 0.5m Mangrove Dataset

Train semantic segmentation models on the 0.5m resolution mangrove dataset.

**Dataset**: 0.5m resolution aerial imagery
- Pre-tiled 512x512 tiles stored as .npy
- Binary classification: Not Mangrove (0) vs Mangrove (1)
- Label value 255 = ignore/no-data pixels (excluded from the loss and metrics via `IGNORE_INDEX`)

**Models**: DeepLab, ResNet-UNet, SegFormer

**Prerequisites**:
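- Run `02_preprocess_0_5m.ipynb` first to verify the data and generate class weights (`class_weights.json`); if that file is missing, this notebook falls back to equal class weights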

## 1. Setup and Configuration

In [1]:
import sys
import os
import json
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import torchvision.transforms.v2 as v2
from torchvision import tv_tensors

# ============================================================
# CONFIGURATION - Edit these paths and hyperparameters
# ============================================================

# --- Paths (relative to the notebook's working directory) ---
DATA_ROOT = Path('../data/0_5m')
WEIGHTS_DIR = Path('../weights')
PLOTS_DIR = Path('../plots/0_5m')
EXPERIMENTS_DIR = Path('./experiments')

# --- Input arrays (pre-tiled 512x512 data) ---
IMAGES_FILE = DATA_ROOT.joinpath('512dataset_images.npy')
LABELS_FILE = DATA_ROOT.joinpath('512dataset_labels.npy')

# Create every output directory up front (no-op when it already exists)
for _out_dir in (WEIGHTS_DIR, PLOTS_DIR, EXPERIMENTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

# --- Training hyperparameters ---
BATCH_SIZE = 16
NUM_EPOCHS = 50
INIT_LR = 5e-5
WEIGHT_DECAY = 0.01
NUM_WORKERS = 0  # DataLoader worker processes; >0 may misbehave on Windows
TRAIN_SPLIT = 0.8  # fraction of samples for training; the rest is validation

# --- Model / experiment (MODEL_NAME: 'deeplab', 'resnet_unet' or 'segformer') ---
MODEL_NAME = 'deeplab'
EXPERIMENT_NAME = f'{MODEL_NAME}_mangrove_0_5m'

# --- Labels (binary task) ---
CLASS_NAMES = ['not_mangrove', 'mangrove']
NUM_CLASSES = len(CLASS_NAMES)
IGNORE_INDEX = 255  # label value that marks no-data pixels

# --- Compute device ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(
    f"PyTorch: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if device == 'cuda':
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 2**30:.1f} GB")
print(
    f"\nUsing device: {device}",
    f"\nData root: {DATA_ROOT.absolute()}",
    f"Model: {MODEL_NAME}",
    f"Experiment: {EXPERIMENT_NAME}",
    f"Classes: {CLASS_NAMES} (ignore={IGNORE_INDEX})",
    sep="\n",
)

PyTorch: 2.10.0+cu126
CUDA available: True
Device: NVIDIA GeForce RTX 4060
Memory: 8.0 GB

Using device: cuda

Data root: c:\vscode workspace\ml-mangrove\DroneClassification\human_infra\03_model_training\..\data\0_5m
Model: deeplab
Experiment: deeplab_mangrove_0_5m
Classes: ['not_mangrove', 'mangrove'] (ignore=255)


## 2. Data Augmentation

In [2]:
class Rotate90Only(v2.Transform):
    """Randomly rotate by 0, 90, 180 or 270 degrees.

    A single rotation is drawn per call and applied to both the image and the
    optional mask, so the two stay pixel-aligned. The rotation acts on the
    last two (height, width) dimensions. Quarter turns only rearrange pixels,
    so no values are interpolated.
    """

    def _transform_image(self, img: torch.Tensor, k):
        # Codes 1, 2 and 3 mean 90, 180 and 270 degrees clockwise; any other
        # code leaves the input untouched.
        if k not in (1, 2, 3):
            return img
        # torch.rot90 turns counter-clockwise, so k clockwise quarter turns
        # are the same as (4 - k) counter-clockwise ones.
        return torch.rot90(img, 4 - int(k), dims=(-2, -1))

    def forward(self, img: torch.Tensor, mask=None):
        quarter_turns = random.randint(0, 3)
        rotated_img = self._transform_image(img, quarter_turns)
        if mask is None:
            return rotated_img
        return rotated_img, self._transform_image(mask, quarter_turns)


# Training augmentation: only flips and quarter-turn rotations, so pixel and
# label values are never interpolated.
# NOTE(review): torchvision v2's built-in transforms appear to transform only
# the first *plain* tensor and pass any others through unchanged. Callers
# should wrap the pair as tv_tensors.Image / tv_tensors.Mask so that masks are
# flipped too. Confirm against the torchvision v2 docs.
train_augmentation = v2.Compose(
    [
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        Rotate90Only(),
    ]
)

print(
    "Augmentation pipeline:\n"
    "  - Random horizontal flip (p=0.5)\n"
    "  - Random vertical flip (p=0.5)\n"
    "  - Random 90-degree rotation"
)

Augmentation pipeline:
  - Random horizontal flip (p=0.5)
  - Random vertical flip (p=0.5)
  - Random 90-degree rotation


## 3. Dataset Class

In [3]:
class MangroveDataset(Dataset):
    """
    Dataset for 0.5m mangrove .npy data.

    Images and labels are memory-mapped, so only the tiles that are actually
    requested are read from disk.

    Each item is ``(image, label)``:
      - ``image``: float32 tensor (C, H, W). uint8 data is scaled to [0, 1];
        then ImageNet normalization is applied.
      - ``label``: int64 tensor (H, W). A leading singleton channel dimension
        in the stored labels is squeezed away.

    Args:
        images_path: Path to images .npy file
        labels_path: Path to labels .npy file
        indices: Optional subset of indices to use (defaults to all samples)
        augment: Whether to apply ``train_augmentation`` jointly to the image
            and its label
    """

    # ImageNet normalization statistics, shaped to broadcast over (C, H, W)
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, images_path, labels_path, indices=None, augment=False):
        # Load as memory-mapped for efficiency
        self.images = np.load(images_path, mmap_mode='r')
        self.labels = np.load(labels_path, mmap_mode='r')
        self.indices = indices if indices is not None else np.arange(len(self.images))
        self.augment = augment

        print(f"Loaded {len(self.indices)} samples (augment={augment})")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]

        # Copy out of the memory map so the tensors own writable memory
        raw_image = self.images[real_idx].copy()
        raw_label = self.labels[real_idx].copy()

        # Image is (C, H, W), normally uint8 in 0-255. Scale uint8 data
        # unconditionally: a dark tile whose max value is <= 1 would otherwise
        # skip scaling and be normalized on the wrong range. The max-value
        # heuristic is kept as a fallback for float-typed data.
        image = torch.from_numpy(raw_image).float()
        if raw_image.dtype == np.uint8 or image.max() > 1.5:
            image = image / 255.0

        # Normalize with ImageNet stats
        image = (image - self.MEAN) / self.STD

        # Label is (1, H, W) or (H, W)
        label = torch.from_numpy(raw_label).long()
        if label.dim() == 3:
            label = label.squeeze(0)

        if self.augment:
            # torchvision v2 transforms only the first *plain* tensor and
            # passes further plain tensors through untouched. Called with
            # (image, label) as plain tensors, the flips would move the image
            # but not the mask. Tagging the pair as Image/Mask makes every
            # transform act on both consistently.
            image, label = train_augmentation(
                tv_tensors.Image(image), tv_tensors.Mask(label)
            )
            # Drop the tv_tensor subclasses so callers always get plain tensors
            image = image.as_subclass(torch.Tensor)
            label = label.as_subclass(torch.Tensor)

        return image, label


print("MangroveDataset class defined")

MangroveDataset class defined


## 4. Load and Split Data

In [4]:
print("=== Loading Data ===")
print()

# Verify files exist
print(f"Images file: {IMAGES_FILE}")
print(f"  exists: {IMAGES_FILE.exists()}")
print(f"Labels file: {LABELS_FILE}")
print(f"  exists: {LABELS_FILE.exists()}")
print()

if not IMAGES_FILE.exists() or not LABELS_FILE.exists():
    # Name the missing file(s) so the error is actionable on its own
    _missing = [str(p) for p in (IMAGES_FILE, LABELS_FILE) if not p.exists()]
    raise FileNotFoundError(f"Data files not found: {', '.join(_missing)}")

# Memory-map just to inspect shapes without reading the arrays into RAM
images = np.load(IMAGES_FILE, mmap_mode='r')
labels = np.load(LABELS_FILE, mmap_mode='r')

print(f"Images shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Total samples: {len(images)}")

# Images and labels are paired purely by position, so their sample counts and
# tile sizes must agree. Catch a mismatch here rather than as misaligned pairs
# or an IndexError in the middle of training.
if len(images) != len(labels) or images.shape[-2:] != labels.shape[-2:]:
    raise ValueError(
        f"Images {images.shape} and labels {labels.shape} do not match "
        "in sample count and/or tile size"
    )

=== Loading Data ===

Images file: ..\data\0_5m\512dataset_images.npy
  exists: True
Labels file: ..\data\0_5m\512dataset_labels.npy
  exists: True

Images shape: (573, 3, 512, 512)
Labels shape: (573, 1, 512, 512)
Total samples: 573


In [5]:
def _validate_cached_split(train_idx, val_idx, n_samples):
    """Raise ValueError if a cached train/val split does not fit the dataset.

    Checks that every index lies in [0, n_samples) and that no sample is in
    both the train and the val subset.
    """
    combined = np.concatenate([np.ravel(train_idx), np.ravel(val_idx)])
    if combined.size and (combined.min() < 0 or combined.max() >= n_samples):
        raise ValueError(
            f"Cached split has indices outside [0, {n_samples}); the dataset "
            "appears to have changed. Delete train_indices.npy and "
            "val_indices.npy to regenerate the split."
        )
    overlap = np.intersect1d(train_idx, val_idx)
    if overlap.size:
        raise ValueError(
            f"Cached split has {overlap.size} indices in both train and val; "
            "delete train_indices.npy and val_indices.npy to regenerate it."
        )


print("=== Creating Train/Val Split ===")
print()

# Check for existing split files
train_split_file = DATA_ROOT / 'train_indices.npy'
val_split_file = DATA_ROOT / 'val_indices.npy'

if train_split_file.exists() and val_split_file.exists():
    print("Loading existing split...")
    train_indices = np.load(train_split_file)
    val_indices = np.load(val_split_file)
    # A saved split goes stale if the dataset is regenerated. Fail loudly here
    # rather than with an IndexError (or silent train/val leakage) later.
    _validate_cached_split(train_indices, val_indices, len(images))
else:
    print(f"Creating new {TRAIN_SPLIT*100:.0f}/{(1-TRAIN_SPLIT)*100:.0f} split...")
    n_samples = len(images)
    indices = np.arange(n_samples)
    # NOTE: this reseeds NumPy's *global* RNG, which later np.random calls in
    # this notebook also draw from.
    np.random.seed(42)
    np.random.shuffle(indices)

    split_idx = int(n_samples * TRAIN_SPLIT)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]

    # Save for reproducibility
    np.save(train_split_file, train_indices)
    np.save(val_split_file, val_indices)
    print(f"Saved: {train_split_file.name}, {val_split_file.name}")

print(f"\nTrain: {len(train_indices):,} samples")
print(f"Val:   {len(val_indices):,} samples")

=== Creating Train/Val Split ===

Loading existing split...

Train: 458 samples
Val:   115 samples


In [6]:
print("=== Creating Datasets ===", "", sep="\n")

# Both datasets share the same backing arrays. They differ only in which
# indices they expose and in whether augmentation is applied (training only).
train_dataset, val_dataset = (
    MangroveDataset(IMAGES_FILE, LABELS_FILE, indices=subset, augment=use_aug)
    for subset, use_aug in ((train_indices, True), (val_indices, False))
)

# Spot-check one training sample (augmented, so its pixels vary run to run)
img, mask = train_dataset[0]
print(f"\nSample shapes: Image {img.shape}, Mask {mask.shape}")
print(f"Mask unique values: {mask.unique().tolist()}")

=== Creating Datasets ===

Loaded 458 samples (augment=True)
Loaded 115 samples (augment=False)

Sample shapes: Image torch.Size([3, 512, 512]), Mask torch.Size([512, 512])
Mask unique values: [0, 1, 255]


## 5. Create DataLoaders

In [7]:
print("=== Creating DataLoaders ===")
print()

# Settings shared by both loaders. Pinned (page-locked) host memory only helps
# host->GPU copies; without CUDA, PyTorch just emits a warning. So pin memory
# only when training on CUDA.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=(device == 'cuda'),
)

# Reshuffle training batches every epoch; keep validation order fixed
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Batch size: {BATCH_SIZE}")
print(f"Num workers: {NUM_WORKERS}")
print(f"\nBatches per epoch:")
print(f"  Train: {len(train_loader):,}")
print(f"  Val:   {len(val_loader):,}")

=== Creating DataLoaders ===

Batch size: 16
Num workers: 0

Batches per epoch:
  Train: 29
  Val:   8


## 6. Verify Data

In [8]:
print("=== Verifying Data ===")
print()

batch = next(iter(train_loader))
x, y = batch

print(f"Image batch: {x.shape}, dtype={x.dtype}")
print(f"Mask batch:  {y.shape}, dtype={y.dtype}")
print(f"\nImage range: [{x.min():.3f}, {x.max():.3f}]")
print(f"Mask values: {sorted(torch.unique(y).tolist())}")

# Count ignore pixels
ignore_pct = (y == IGNORE_INDEX).float().mean() * 100
print(f"Ignore pixels ({IGNORE_INDEX}): {ignore_pct:.1f}%")

# Check for NaN/Inf in the images
has_nan = torch.isnan(x).any()
has_inf = torch.isinf(x).any()

# Every label must be a class id or IGNORE_INDEX. Anything else is out of
# range for the loss and, on CUDA, typically surfaces only as an opaque
# device-side assert. (Only this one batch is checked.)
valid_labels = torch.tensor([*range(NUM_CLASSES), IGNORE_INDEX])
unexpected_labels = torch.unique(y[~torch.isin(y, valid_labels)]).tolist()

if not has_nan and not has_inf and not unexpected_labels:
    print("\nData quality check passed")
else:
    if has_nan or has_inf:
        print(f"\nWARNING: NaN={has_nan}, Inf={has_inf}")
    if unexpected_labels:
        print(
            f"\nWARNING: unexpected mask values {unexpected_labels} "
            f"(expected 0..{NUM_CLASSES - 1} or {IGNORE_INDEX})"
        )

=== Verifying Data ===

Image batch: torch.Size([16, 3, 512, 512]), dtype=torch.float32
Mask batch:  torch.Size([16, 512, 512]), dtype=torch.int64

Image range: [-2.118, 2.640]
Mask values: [0, 1, 255]
Ignore pixels (255): 47.6%

Data quality check passed


## 7. Load Class Weights

In [9]:
print("=== Loading Class Weights ===")
print()

weights_file = DATA_ROOT / 'class_weights.json'

if weights_file.exists():
    with open(weights_file, encoding='utf-8') as f:
        weights_dict = json.load(f)

    # Force float32 so integer-valued JSON still yields float weights
    class_frequencies = torch.tensor(weights_dict['class_frequencies'], dtype=torch.float32)
    class_weights = torch.tensor(
        weights_dict['weights_inverse_sqrt'], dtype=torch.float32
    ).to(device)

    # Exactly one value per class is required. A file produced for a different
    # label set would otherwise fail later with an opaque shape error.
    if len(class_frequencies) != NUM_CLASSES or len(class_weights) != NUM_CLASSES:
        raise ValueError(
            f"{weights_file.name} has {len(class_frequencies)} frequencies and "
            f"{len(class_weights)} weights; expected {NUM_CLASSES} of each "
            f"({CLASS_NAMES}). Re-run the preprocessing notebook."
        )

    print(f"Loaded from: {weights_file.name}")
    print(f"\nClass frequencies:")
    for i, name in enumerate(CLASS_NAMES):
        print(f"  {name:12s}: {class_frequencies[i]:.4f}")
    print(f"\nClass weights (inverse sqrt):")
    for i, name in enumerate(CLASS_NAMES):
        print(f"  {name:12s}: {class_weights[i]:.4f}")
else:
    # Fallback: equal weight for every class
    print("class_weights.json not found, using balanced weights")
    class_weights = torch.ones(NUM_CLASSES, device=device)
    print(f"Class weights: {class_weights.tolist()}")

=== Loading Class Weights ===

Loaded from: class_weights.json

Class frequencies:
  not_mangrove: 0.6267
  mangrove    : 0.3733

Class weights (inverse sqrt):
  not_mangrove: 0.8711
  mangrove    : 1.1289


## 8. Import Models and Training Utilities

In [10]:
# Make the project root (two directories up) importable.
# - Resolving it to an absolute path keeps imports working if the working
#   directory changes later.
# - Removing earlier copies first keeps sys.path free of duplicates when this
#   cell is re-run, while still giving the project root top priority.
PROJECT_ROOT = str(Path('../../').resolve())
while PROJECT_ROOT in sys.path:
    sys.path.remove(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

from models import DeepLab, ResNet_UNet, SegFormer, JaccardLoss, DiceLoss
from training_utils import TrainingSession

print("Imported:")
print("  Models: DeepLab, ResNet_UNet, SegFormer")
print("  Losses: JaccardLoss, DiceLoss")
print("  Training: TrainingSession")

  from .autonotebook import tqdm as notebook_tqdm


Imported:
  Models: DeepLab, ResNet_UNet, SegFormer
  Losses: JaccardLoss, DiceLoss
  Training: TrainingSession


## 9. Initialize Model

In [11]:
print(f"=== Initializing Model: {MODEL_NAME} ===", "", sep="\n")

# Zero-argument builders keyed by MODEL_NAME; only the selected one is called.
_model_builders = {
    'deeplab': lambda: DeepLab(
        num_classes=NUM_CLASSES,
        input_image_size=512,
        backbone='resnet50',
        output_stride=4,
    ),
    'resnet_unet': lambda: ResNet_UNet(num_classes=NUM_CLASSES),
    'segformer': lambda: SegFormer(num_classes=NUM_CLASSES),
}

if MODEL_NAME not in _model_builders:
    raise ValueError(f"Unknown model: {MODEL_NAME}")
model = _model_builders[MODEL_NAME]().to(device)

num_params = sum(t.numel() for t in model.parameters())
trainable_params = sum(t.numel() for t in model.parameters() if t.requires_grad)

print(
    f"Model: {MODEL_NAME}",
    f"Num classes: {NUM_CLASSES} (binary)",
    f"Total parameters: {num_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    f"Device: {device}",
    sep="\n",
)

=== Initializing Model: deeplab ===

Model: deeplab
Num classes: 2 (binary)
Total parameters: 41,999,191
Trainable parameters: 41,999,191
Device: cuda


## 10. Setup Loss Function

In [12]:
print("=== Setting Up Loss Function ===")
print()

# Loss hyperparameters, defined once so the summary printed below always
# matches what is actually passed to the loss.
LOSS_ALPHA = 0.3  # IoU term weight
LOSS_BOUNDARY_WEIGHT = 0.3

# JaccardLoss with ignore_index so no-data (255) pixels are excluded
loss_fn = JaccardLoss(
    num_classes=NUM_CLASSES,
    weight=class_weights,
    alpha=LOSS_ALPHA,
    boundary_weight=LOSS_BOUNDARY_WEIGHT,
    ignore_index=IGNORE_INDEX,
)

print("Loss: JaccardLoss (CE + IoU + Boundary)")
print(f"  alpha (IoU weight): {LOSS_ALPHA}")
print(f"  boundary_weight: {LOSS_BOUNDARY_WEIGHT}")
print(f"  ignore_index: {IGNORE_INDEX}")
print(f"  class_weights: {[f'{w:.2f}' for w in class_weights.tolist()]}")

=== Setting Up Loss Function ===

Loss: JaccardLoss (CE + IoU + Boundary)
  alpha (IoU weight): 0.3
  boundary_weight: 0.3
  ignore_index: 255
  class_weights: ['0.87', '1.13']


## 11. Setup Optimizer and Scheduler

In [13]:
print("=== Setting Up Optimizer ===", "", sep="\n")

steps_per_epoch = len(train_loader)
num_training_steps = steps_per_epoch * NUM_EPOCHS

# AdamW: Adam with decoupled weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    betas=(0.9, 0.999),
    lr=INIT_LR,
    weight_decay=WEIGHT_DECAY,
)

# Cosine decay from INIT_LR down to 0 over `num_training_steps` scheduler steps.
# NOTE(review): T_max is counted in batches, which assumes TrainingSession
# calls scheduler.step() once per batch. If it steps once per epoch, the LR
# will barely decay over the run. Confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_training_steps, eta_min=0
)

print(
    "Optimizer: AdamW",
    f"  Initial LR: {INIT_LR}",
    f"  Weight decay: {WEIGHT_DECAY}",
    "\nScheduler: CosineAnnealingLR",
    f"  Total steps: {num_training_steps:,}",
    f"  Steps per epoch: {steps_per_epoch}",
    sep="\n",
)

=== Setting Up Optimizer ===

Optimizer: AdamW
  Initial LR: 5e-05
  Weight decay: 0.01

Scheduler: CosineAnnealingLR
  Total steps: 1,450
  Steps per epoch: 29


## 12. Create Training Session

In [14]:
print("=== Creating Training Session ===", "", sep="\n")

# The validation split doubles as the session's test loader. ignore_index is
# forwarded for metric calculation.
trainer = TrainingSession(
    model=model,
    trainLoader=train_loader,
    testLoader=val_loader,
    lossFunc=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    init_lr=INIT_LR,
    num_epochs=NUM_EPOCHS,
    experiment_name=EXPERIMENT_NAME,
    class_names=CLASS_NAMES,
    ignore_index=IGNORE_INDEX,
)

print(
    f"Experiment: {EXPERIMENT_NAME}",
    f"Epochs: {NUM_EPOCHS}",
    f"Ignore index: {IGNORE_INDEX}",
    sep="\n",
)

=== Creating Training Session ===

Using CUDA device.
Experiment: deeplab_mangrove_0_5m
Epochs: 50
Ignore index: 255


## 13. Train Model

In [None]:
# Echo the run configuration, then hand control to the training loop
print("=== Starting Training ===", "", sep="\n")
print(
    f"Model: {MODEL_NAME}",
    f"Dataset: 0.5m Mangrove ({len(train_dataset):,} training samples)",
    f"Classes: {CLASS_NAMES}",
    f"Epochs: {NUM_EPOCHS}",
    f"Batch size: {BATCH_SIZE}",
    "",
    sep="\n",
)

trainer.learn()

## 14. Evaluate on Validation Set

In [None]:
print("=== Evaluating on Validation Set ===", "", sep="\n")

val_metrics = trainer.evaluate(val_loader)

# NOTE(review): 'IoU' is reported as the mean IoU; confirm this is how
# TrainingSession.evaluate aggregates it.
print("\nValidation Results:")
print(f"  Pixel Accuracy: {val_metrics['Pixel_Accuracy']:.4f}")
print(f"  Mean IoU: {val_metrics['IoU']:.4f}")

# Per-class IoU curves recorded by the trainer
trainer.plot_metrics("Class IoU", metrics_wanted=["class_ious"])

## 15. Save Final Model

In [None]:
print("=== Saving Model ===")
print()

# Save to weights directory
model_path = WEIGHTS_DIR / f'{EXPERIMENT_NAME}_final.pth'
torch.save(model.state_dict(), model_path)
print(f"Saved: {model_path}")

# Training config plus headline validation metrics, saved next to the weights.
# The metrics go through float() because evaluate() may return NumPy scalars
# or 0-d tensors, which json.dump cannot serialize.
config = {
    'model_name': MODEL_NAME,
    'num_classes': NUM_CLASSES,
    'class_names': CLASS_NAMES,
    'ignore_index': IGNORE_INDEX,
    'batch_size': BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'init_lr': INIT_LR,
    'weight_decay': WEIGHT_DECAY,
    'val_pixel_accuracy': float(val_metrics['Pixel_Accuracy']),
    'val_miou': float(val_metrics['IoU']),
}

config_path = WEIGHTS_DIR / f'{EXPERIMENT_NAME}_config.json'
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)
print(f"Saved: {config_path}")

## 16. Visualize Predictions

In [None]:
def denormalize(img):
    """Reverse ImageNet normalization for display.

    Args:
        img: normalized tensor whose channel axis is third from last, e.g.
            (3, H, W) or a batch (B, 3, H, W). It may live on any device.

    Returns:
        Tensor of the same shape with the ImageNet mean/std undone, clamped
        to [0, 1].
    """
    # Move the statistics to img's device so non-CPU tensors work too
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(img.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(img.device)
    return torch.clamp(img * std + mean, 0, 1)


# Display colours (RGB floats in [0, 1]) keyed by class id, for the binary task
CLASS_COLORS = dict(enumerate([
    [0.8, 0.8, 0.8],  # 0: not mangrove - gray
    [0.0, 0.7, 0.0],  # 1: mangrove - green
]))


def _colorize_mask(class_map):
    """Map a 2-D array of class ids to an RGB float image for display.

    Ids listed in CLASS_COLORS get their colour, IGNORE_INDEX pixels are shown
    white, and any other value stays black.
    """
    rgb = np.zeros((*class_map.shape, 3))
    for class_id, color in CLASS_COLORS.items():
        rgb[class_map == class_id] = color
    rgb[class_map == IGNORE_INDEX] = [1.0, 1.0, 1.0]  # White for ignore
    return rgb


def visualize_predictions(model, dataset, indices, save_path=None):
    """Plot image, ground truth, prediction and overlay for selected samples.

    Args:
        model: segmentation model. Its output is argmax-ed over dim 1 to get
            per-pixel class ids.
        dataset: dataset returning ``(normalized_image, mask)`` pairs.
        indices: positions within ``dataset`` to plot, one figure row each.
        save_path: optional path. If given, the figure is also saved there.

    The model's train/eval mode is restored afterwards. With no indices,
    nothing is plotted.
    """
    indices = list(indices)
    if not indices:
        print("No samples to visualize")
        return

    was_training = model.training
    model.eval()

    # squeeze=False keeps `axes` 2-D even for a single row, so the
    # axes[row, col] indexing below works for any number of samples.
    fig, axes = plt.subplots(
        len(indices), 4, figsize=(16, 4 * len(indices)), squeeze=False
    )
    fig.suptitle('0.5m Mangrove Predictions', fontsize=14, fontweight='bold')

    try:
        with torch.no_grad():
            for row, idx in enumerate(indices):
                img, mask = dataset[idx]

                # Per-pixel class = argmax over dim 1 of the model output
                logits = model(img.unsqueeze(0).to(device))
                pred_mask = torch.argmax(logits, dim=1).squeeze().cpu().numpy()

                img_np = denormalize(img).numpy().transpose(1, 2, 0)
                pred_display = _colorize_mask(pred_mask)
                panels = (
                    (img_np, 'Image'),
                    (_colorize_mask(mask.numpy()), 'Ground Truth'),
                    (pred_display, 'Prediction'),
                    (np.clip(0.6 * img_np + 0.4 * pred_display, 0, 1), 'Overlay'),
                )
                for col, (panel, title) in enumerate(panels):
                    ax = axes[row, col]
                    ax.imshow(panel)
                    ax.set_title(title)
                    ax.axis('off')
    finally:
        # Leave the model in whatever mode the caller had it in
        model.train(was_training)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=100, bbox_inches='tight')
        print(f"Saved: {save_path}")

    plt.show()


# Plot a few random validation samples. These are positions within
# val_dataset. Use a separate name so the train/val split stored in
# `val_indices` is not overwritten, and never request more samples than the
# validation set holds.
vis_indices = np.random.choice(
    len(val_dataset), min(4, len(val_dataset)), replace=False
).tolist()
print(f"Visualizing validation samples: {vis_indices}")

visualize_predictions(
    model,
    val_dataset,
    vis_indices,
    save_path=PLOTS_DIR / f'{EXPERIMENT_NAME}_predictions.png',
)

## 17. Summary

In [None]:
# Final recap of the run, its validation metrics and where outputs were saved
print("=" * 60, "Training Complete", "=" * 60, "", sep="\n")
print(
    f"Model: {MODEL_NAME}",
    "Dataset: 0.5m Mangrove (binary)",
    f"Classes: {CLASS_NAMES}",
    f"Epochs: {NUM_EPOCHS}",
    "",
    sep="\n",
)
print("Validation Results:")
print(f"  Pixel Accuracy: {val_metrics['Pixel_Accuracy']:.4f}")
print(f"  Mean IoU: {val_metrics['IoU']:.4f}")
print()
print(
    "Saved Files:",
    f"  Model: {WEIGHTS_DIR / f'{EXPERIMENT_NAME}_final.pth'}",
    f"  Config: {WEIGHTS_DIR / f'{EXPERIMENT_NAME}_config.json'}",
    f"  Plots: {PLOTS_DIR}",
    sep="\n",
)