# Gated Ensemble: Landcover + Mangrove Models on 0.5m Data

Run both models on our private 0.5m mangrove dataset and combine predictions
via a gated ensemble.

**Landcover Model**: SegFormer trained on Landcover.ai v1
- 5 output classes: Background, Building, Woodland, Water, Road

**Mangrove Model**: SegFormer trained on 0.5m mangrove data
- Binary output: Mangrove / Not Mangrove

**Gated Ensemble Strategy**:
The landcover model cannot distinguish between woodland and mangrove — it labels
most mangrove pixels as "woodland". We use the mangrove model as a gate:
wherever it predicts **mangrove**, we override the landcover prediction with a
dedicated **mangrove** class. This produces a 6-class output:

| Class | Label |
|-------|-------|
| 0 | Background |
| 1 | Building |
| 2 | Woodland |
| 3 | Water |
| 4 | Road |
| 5 | **Mangrove** (gated) |

## 1. Setup and Configuration

In [None]:
import sys
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
from tqdm import tqdm

# ------------------------------------------------------------
# Notebook-wide settings
# ------------------------------------------------------------

# Filesystem layout (relative to the notebook's working directory)
DATA_ROOT = Path('../data/0_5m')
WEIGHTS_DIR = Path('../weights')
PLOTS_DIR = Path('../plots/landcover_05m_inference')

# Input arrays: image tiles and their mangrove masks
IMAGES_FILE = DATA_ROOT / '512dataset_images.npy'
LABELS_FILE = DATA_ROOT / '512dataset_labels.npy'

# Landcover checkpoint: 5-way classifier
MODEL_NAME = 'segformer'
LANDCOVER_WEIGHTS = WEIGHTS_DIR / 'human_segformer.pth'
LANDCOVER_NUM_CLASSES = 5

# Mangrove checkpoint: one output channel, thresholded after a sigmoid
MANGROVE_WEIGHTS = WEIGHTS_DIR / 'mangrove_segformer.pth'
MANGROVE_NUM_CLASSES = 1  # sigmoid output

# Inference settings
BATCH_SIZE = 16
NUM_WORKERS = 0
IGNORE_INDEX = 255
MANGROVE_THRESHOLD = 0.5

# Landcover legend: list position / dict key == class id emitted by the model (0-4)
LANDCOVER_NAMES = ['background', 'building', 'woodland', 'water', 'road']
LANDCOVER_COLORS = dict(enumerate([
    [0.8, 0.8, 0.8],  # Background - gray
    [1.0, 0.0, 0.0],  # Building - red
    [0.0, 0.5, 0.0],  # Woodland - dark green
    [0.0, 0.0, 1.0],  # Water - blue
    [1.0, 1.0, 0.0],  # Road - yellow
]))

# Ensemble legend: the five landcover classes plus the gated mangrove class (id 5).
# Colour lists are copied so the two palettes never share mutable state.
ENSEMBLE_NAMES = LANDCOVER_NAMES + ['mangrove']
NUM_ENSEMBLE_CLASSES = len(ENSEMBLE_NAMES)
ENSEMBLE_COLORS = {class_id: list(rgb) for class_id, rgb in LANDCOVER_COLORS.items()}
ENSEMBLE_COLORS[5] = [0.0, 0.9, 0.4]  # Mangrove - bright green

# Ground-truth (binary mangrove mask) legend
MANGROVE_NAMES = ['not_mangrove', 'mangrove']
MANGROVE_COLORS = dict(enumerate([
    [0.8, 0.8, 0.8],
    [0.0, 0.7, 0.0],
]))

# Make sure the output folder exists before any figure is saved.
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
print(f"Landcover weights: {LANDCOVER_WEIGHTS}")
print(f"Mangrove weights:  {MANGROVE_WEIGHTS}")
print(f"Target data: {DATA_ROOT}")

## 2. Load Models

In [None]:
# Put the directory two levels up at the front of the import path so that the
# project-local `models` package can be imported from this notebook.
sys.path.insert(0, '../../')

from models import SegFormer

# --- Landcover model -------------------------------------------------------
# Build the architecture with LANDCOVER_NUM_CLASSES outputs, restore the saved
# weights, move the model to `device`, and call eval() before inference.
print("=== Loading Landcover Model (5-class) ===")
landcover_model = SegFormer(num_classes=LANDCOVER_NUM_CLASSES)
# map_location=device remaps the checkpoint's tensors onto `device` while loading.
state_dict = torch.load(LANDCOVER_WEIGHTS, map_location=device)
landcover_model.load_state_dict(state_dict)
landcover_model = landcover_model.to(device)
landcover_model.eval()
print(f"Loaded: {LANDCOVER_WEIGHTS}")

print()

# --- Mangrove model --------------------------------------------------------
# Same class, but with MANGROVE_NUM_CLASSES output channels.
# NOTE: `state_dict` is reused, so after this cell it holds the mangrove weights.
print("=== Loading Mangrove Model (binary) ===")
mangrove_model = SegFormer(num_classes=MANGROVE_NUM_CLASSES)
state_dict = torch.load(MANGROVE_WEIGHTS, map_location=device)
mangrove_model.load_state_dict(state_dict)
mangrove_model = mangrove_model.to(device)
mangrove_model.eval()
print(f"Loaded: {MANGROVE_WEIGHTS}")

## 3. Load Mangrove Dataset

In [3]:
class MangroveDataset(Dataset):
    """Serve image tiles and masks from the 0.5m mangrove ``.npy`` arrays.

    Both arrays are opened memory-mapped and read-only, so only the tiles that
    are actually requested are copied into memory.

    Args:
        images_path: Path to the ``.npy`` file holding the image tiles. Tiles
            are expected channel-first (3, H, W), matching the (3, 1, 1)
            shape of ``MEAN`` / ``STD``.
        labels_path: Path to the ``.npy`` file holding the masks.
        indices: Optional sequence of positions into the arrays to expose.
            When omitted, every tile is exposed in storage order.

    Each item is a tuple ``(image, label, real_idx)``: the ImageNet-normalized
    float image, the mask as an int64 tensor (a leading singleton axis is
    removed), and the position of the tile inside the underlying arrays.
    """

    # ImageNet channel statistics, shaped to broadcast over (3, H, W) tensors.
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, images_path, labels_path, indices=None):
        self.images = np.load(images_path, mmap_mode='r')
        self.labels = np.load(labels_path, mmap_mode='r')
        self.indices = indices if indices is not None else np.arange(len(self.images))
        print(f"Loaded {len(self.indices)} samples")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]

        # The memmaps are read-only; copy before handing the data to torch.
        raw_image = self.images[real_idx]
        image = torch.from_numpy(raw_image.copy()).float()

        # Integer-typed tiles are always on the 0-255 scale. Deciding purely
        # from the maximum value would leave a near-black integer tile
        # (max <= 1) unscaled and turn it almost white after normalization.
        # The value-based test is kept for float arrays stored as 0-255.
        stored_as_integer = np.issubdtype(raw_image.dtype, np.integer)
        if stored_as_integer or image.max() > 1.5:
            image = image / 255.0
        image = (image - self.MEAN) / self.STD

        label = torch.from_numpy(self.labels[real_idx].copy()).long()
        if label.dim() == 3:
            # Drop a leading singleton axis: (1, H, W) -> (H, W).
            label = label.squeeze(0)

        return image, label, real_idx


print("=== Loading Mangrove Dataset ===")
print()

# Expose every tile in the .npy arrays (no index subset).
dataset = MangroveDataset(IMAGES_FILE, LABELS_FILE)

# shuffle=False keeps batches in dataset order, so row i of the concatenated
# predictions produced later corresponds to dataset[i].
# Pinned host memory only helps host->GPU copies; requesting it on a CPU-only
# machine just triggers a warning, so it is enabled only when CUDA is present.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

print(f"\nTotal samples: {len(dataset):,}")
print(f"Batches: {len(loader):,}")

=== Loading Mangrove Dataset ===

Loaded 573 samples

Total samples: 573
Batches: 36


## 4. Run Gated Ensemble Inference

In [None]:
print("=== Running Gated Ensemble Inference ===")
print()

# Per-batch results are collected here and stitched together after the loop.
all_landcover_preds, all_mangrove_preds, all_targets = [], [], []

with torch.no_grad():
    for images, masks, _ in tqdm(loader, desc="Inference"):
        images = images.to(device)

        # Run both networks on the same batch.
        lc_out = landcover_model(images)
        mg_out = mangrove_model(images)

        # Landcover: most likely of the 5 classes per pixel.
        lc_preds = lc_out.argmax(dim=1)

        # Mangrove: sigmoid probability thresholded to {0, 1}; dim 1 (the
        # channel axis, expected to be singleton) is squeezed away.
        mg_preds = torch.sigmoid(mg_out).gt(MANGROVE_THRESHOLD).squeeze(1).long()

        all_landcover_preds.append(lc_preds.cpu().numpy())
        all_mangrove_preds.append(mg_preds.cpu().numpy())
        all_targets.append(masks.numpy())

# One array per quantity, with samples stacked along axis 0.
all_landcover_preds = np.concatenate(all_landcover_preds, axis=0)
all_mangrove_preds = np.concatenate(all_mangrove_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

# --- Gated ensemble ---
# Keep the landcover class (0-4) everywhere except where the mangrove model
# fires; those pixels become class 5 (mangrove).
mangrove_gate = all_mangrove_preds == 1
all_ensemble_preds = np.where(mangrove_gate, 5, all_landcover_preds)

torch.cuda.empty_cache()

print(f"\nLandcover preds shape: {all_landcover_preds.shape}")
print(f"Mangrove preds shape:  {all_mangrove_preds.shape}")
print(f"Ensemble preds shape:  {all_ensemble_preds.shape}")
print(f"Ensemble classes present: {np.unique(all_ensemble_preds).tolist()}")
print(f"Pixels gated to mangrove: {mangrove_gate.sum():,} ({mangrove_gate.mean()*100:.1f}%)")

## 5. Ensemble Class Distribution Analysis

What does the gated ensemble predict across the dataset?

In [None]:
print("=== Ensemble Prediction Distribution ===")
print()

# 1-D copies of predictions and ground truth; later cells reuse both.
preds_flat = all_ensemble_preds.flatten()
targets_flat = all_targets.flatten()

total_pixels = len(preds_flat)
print(f"Total pixels: {total_pixels:,}")
print()
print("Ensemble prediction distribution:")
print(f"{'Class':<12} {'Count':>12} {'Percentage':>10}")
print("-" * 36)

# Pixel count per ensemble class id, kept for the bar chart further down.
pred_counts = {
    class_id: (preds_flat == class_id).sum()
    for class_id in range(NUM_ENSEMBLE_CLASSES)
}
for c, count in pred_counts.items():
    pct = count / total_pixels * 100
    print(f"{ENSEMBLE_NAMES[c]:<12} {count:>12,} {pct:>9.1f}%")

In [None]:
# For each ground-truth value, report how the ensemble labels those pixels.
print("=== Ensemble Predictions Conditioned on Ground Truth ===")
print()

for gt_class, gt_name in enumerate(MANGROVE_NAMES):
    mask = targets_flat == gt_class
    if not mask.any():
        # This ground-truth value does not occur in the data.
        continue

    gt_preds = preds_flat[mask]
    gt_total = len(gt_preds)

    print(f"Where ground truth = {gt_name} ({gt_total:,} pixels):")
    for c in range(NUM_ENSEMBLE_CLASSES):
        count = (gt_preds == c).sum()
        pct = count / gt_total * 100
        bar = '#' * int(pct / 2)  # one '#' per two percentage points
        print(f"  -> {ENSEMBLE_NAMES[c]:<12}: {pct:>5.1f}%  {bar}")
    print()

# Same breakdown for pixels carrying the ignore label (printed without bars).
ignore_mask = targets_flat == IGNORE_INDEX
if ignore_mask.any():
    ignore_preds = preds_flat[ignore_mask]
    ignore_total = len(ignore_preds)
    print(f"Where ground truth = ignore/255 ({ignore_total:,} pixels):")
    for c in range(NUM_ENSEMBLE_CLASSES):
        count = (ignore_preds == c).sum()
        pct = count / ignore_total * 100
        print(f"  -> {ENSEMBLE_NAMES[c]:<12}: {pct:>5.1f}%")

In [None]:
# Three bar charts side by side: overall class counts, then the class mix on
# ground-truth mangrove pixels and on ground-truth non-mangrove pixels.
print("=== Plotting Ensemble Distribution ===")
print()

fig, axes = plt.subplots(1, 3, figsize=(20, 5))
fig.suptitle('Gated Ensemble Predictions on 0.5m Mangrove Data', fontsize=14, fontweight='bold')

# Bar colours in class-id order, matching ENSEMBLE_NAMES.
colors = [ENSEMBLE_COLORS[i] for i in range(NUM_ENSEMBLE_CLASSES)]

# 1. Overall (absolute pixel counts from the distribution cell's pred_counts)
counts = [pred_counts[c] for c in range(NUM_ENSEMBLE_CLASSES)]
axes[0].bar(ENSEMBLE_NAMES, counts, color=colors, edgecolor='black')
axes[0].set_title('Overall Ensemble Distribution')
axes[0].set_ylabel('Pixel Count')
axes[0].tick_params(axis='x', rotation=45)

# 2. Predictions where GT = mangrove (percentages of those pixels)
mangrove_mask = targets_flat == 1
if mangrove_mask.sum() > 0:
    mangrove_preds_cond = preds_flat[mangrove_mask]
    m_pcts = [(mangrove_preds_cond == c).sum() / len(mangrove_preds_cond) * 100 for c in range(NUM_ENSEMBLE_CLASSES)]
    bars = axes[1].bar(ENSEMBLE_NAMES, m_pcts, color=colors, edgecolor='black')
    # Annotate only bars above 1% to keep the chart readable.
    for bar, pct in zip(bars, m_pcts):
        if pct > 1:
            axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                        f'{pct:.1f}%', ha='center', va='bottom', fontsize=9)
    axes[1].set_title('Ensemble on MANGROVE pixels')
    axes[1].set_ylabel('Percentage (%)')
    axes[1].tick_params(axis='x', rotation=45)

# 3. Predictions where GT = not_mangrove (percentages of those pixels)
nonmangrove_mask = targets_flat == 0
if nonmangrove_mask.sum() > 0:
    nm_preds = preds_flat[nonmangrove_mask]
    nm_pcts = [(nm_preds == c).sum() / len(nm_preds) * 100 for c in range(NUM_ENSEMBLE_CLASSES)]
    bars = axes[2].bar(ENSEMBLE_NAMES, nm_pcts, color=colors, edgecolor='black')
    # Annotate only bars above 1% to keep the chart readable.
    for bar, pct in zip(bars, nm_pcts):
        if pct > 1:
            axes[2].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                        f'{pct:.1f}%', ha='center', va='bottom', fontsize=9)
    axes[2].set_title('Ensemble on NOT MANGROVE pixels')
    axes[2].set_ylabel('Percentage (%)')
    axes[2].tick_params(axis='x', rotation=45)

# Save before show() so the file is written from the fully laid-out figure.
plt.tight_layout()
plt.savefig(PLOTS_DIR / 'ensemble_distribution.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {PLOTS_DIR / 'ensemble_distribution.png'}")

## 6. Visualize Predictions

In [8]:
def denormalize(img):
    """Undo ImageNet normalization and clamp the result to [0, 1].

    Args:
        img: Normalized image tensor whose channel axis is third from the
            end, e.g. (3, H, W) or (N, 3, H, W).

    Returns:
        Tensor of the same shape holding ``img * std + mean`` clipped to [0, 1].
    """
    # Build the statistics on the input's device so GPU tensors can be passed
    # directly; CPU-only constants would raise a device-mismatch error.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(-1, 1, 1)
    return torch.clamp(img * std + mean, 0, 1)


def mask_to_rgb(mask, class_colors, ignore_index=None):
    """Paint a 2-D class-id mask as an (H, W, 3) float RGB image.

    Pixels whose id has no entry in ``class_colors`` stay black. When
    ``ignore_index`` is given, pixels with that id are painted white, even if
    ``class_colors`` also defines a colour for it.
    """
    height, width = mask.shape
    canvas = np.zeros((height, width, 3))

    # Work on a copy so the caller's palette is untouched; the ignore entry
    # replaces any colour registered under the same id.
    palette = dict(class_colors)
    if ignore_index is not None:
        palette[ignore_index] = [1.0, 1.0, 1.0]  # White for ignore

    for class_id, color in palette.items():
        canvas[mask == class_id] = color
    return canvas


print("Visualization functions defined")

Visualization functions defined


In [None]:
print("=== Grid Visualization ===")
print()

# Show up to 8 random tiles. Capping at the dataset size keeps sampling
# without replacement valid when fewer than 8 tiles are available.
n_samples = min(8, len(dataset))
np.random.seed(67)  # fixed seed -> the same tiles are drawn on every run
sample_indices = np.random.choice(len(dataset), n_samples, replace=False)

# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(n_samples, 5, figsize=(25, 5 * n_samples), squeeze=False)
fig.suptitle('Gated Ensemble: Landcover + Mangrove on 0.5m Data', fontsize=16, fontweight='bold')

col_titles = ['Image', 'Ground Truth', 'Landcover Pred', 'Mangrove Pred', 'Ensemble']

# Palette for the binary mangrove prediction panel (same for every row).
mg_vis = {0: [0.8, 0.8, 0.8], 1: [0.0, 0.9, 0.4]}

torch.cuda.empty_cache()

with torch.no_grad():
    for row, idx in enumerate(sample_indices):
        img, mask, real_idx = dataset[idx]
        img_dev = img.unsqueeze(0).to(device)  # add the batch dimension

        lc_pred = torch.argmax(landcover_model(img_dev), dim=1).squeeze().cpu().numpy()
        mg_pred = (torch.sigmoid(mangrove_model(img_dev)) > MANGROVE_THRESHOLD).squeeze().cpu().numpy()

        # Ensemble: landcover base, mangrove gate overrides to class 5
        ens_pred = lc_pred.copy()
        ens_pred[mg_pred == 1] = 5

        img_np = denormalize(img).numpy().transpose(1, 2, 0)  # CHW -> HWC for imshow
        mask_np = mask.numpy()

        # Image. axis('off') would also hide the y-label carrying the sample
        # index, so only the ticks and the frame are removed here.
        axes[row, 0].imshow(img_np)
        axes[row, 0].set_ylabel(f'idx={real_idx}', fontsize=10)
        axes[row, 0].set_xticks([])
        axes[row, 0].set_yticks([])
        for spine in axes[row, 0].spines.values():
            spine.set_visible(False)

        # Ground truth (binary)
        axes[row, 1].imshow(mask_to_rgb(mask_np, MANGROVE_COLORS, IGNORE_INDEX))
        axes[row, 1].axis('off')

        # Landcover prediction (5-class)
        axes[row, 2].imshow(mask_to_rgb(lc_pred, LANDCOVER_COLORS))
        axes[row, 2].axis('off')

        # Mangrove prediction (binary)
        axes[row, 3].imshow(mask_to_rgb(mg_pred.astype(int), mg_vis))
        axes[row, 3].axis('off')

        # Ensemble (6-class)
        axes[row, 4].imshow(mask_to_rgb(ens_pred, ENSEMBLE_COLORS))
        axes[row, 4].axis('off')

        # Column headers go on the first row only.
        if row == 0:
            for col, title in enumerate(col_titles):
                axes[row, col].set_title(title, fontsize=12)

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'ensemble_grid.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {PLOTS_DIR / 'ensemble_grid.png'}")

## 7. Per-Sample Inspection

Pick specific indices to examine in detail.

In [None]:
def inspect_sample(landcover_model, mangrove_model, dataset, idx):
    """Print class statistics and plot a five-panel figure for one tile.

    Args:
        landcover_model: Model whose output is arg-maxed over dim 1 to get the
            landcover class per pixel.
        mangrove_model: Model whose sigmoid output is thresholded at
            MANGROVE_THRESHOLD to get the mangrove gate.
        dataset: Indexable dataset returning ``(image, mask, real_idx)``.
        idx: Position in ``dataset`` of the tile to inspect.

    The percentages are computed over pixels whose ground truth is not
    IGNORE_INDEX. Nothing is returned; the figure is shown with ``plt.show()``.
    """
    torch.cuda.empty_cache()

    img, mask, real_idx = dataset[idx]
    batch = img.unsqueeze(0).to(device)  # single-tile batch

    with torch.no_grad():
        lc_logits = landcover_model(batch)
        mg_logits = mangrove_model(batch)
    lc_pred = lc_logits.argmax(dim=1).squeeze().cpu().numpy()
    mg_pred = (torch.sigmoid(mg_logits) > MANGROVE_THRESHOLD).squeeze().cpu().numpy()

    # Gated ensemble: landcover class everywhere, class 5 where the gate fires.
    ens_pred = np.where(mg_pred == 1, 5, lc_pred)

    img_np = denormalize(img).numpy().transpose(1, 2, 0)
    mask_np = mask.numpy()

    # Statistics are restricted to pixels that carry a real label.
    valid = mask_np != IGNORE_INDEX
    total_valid = valid.sum()
    ens_valid = ens_pred[valid]

    print(f"Sample {real_idx}:")
    print(f"  GT mangrove: {(mask_np == 1).sum() / total_valid * 100:.1f}%")
    print(f"  Ensemble predictions:")
    for c, class_name in enumerate(ENSEMBLE_NAMES[:NUM_ENSEMBLE_CLASSES]):
        pct = (ens_valid == c).sum() / total_valid * 100
        print(f"    {class_name:<12}: {pct:>5.1f}%")

    fig, axes = plt.subplots(1, 5, figsize=(25, 5))
    fig.suptitle(f'Sample {real_idx}', fontsize=14)

    # (title, picture) for each panel, left to right.
    mg_vis = {0: [0.8, 0.8, 0.8], 1: [0.0, 0.9, 0.4]}
    panels = [
        ('Image', img_np),
        ('Ground Truth', mask_to_rgb(mask_np, MANGROVE_COLORS, IGNORE_INDEX)),
        ('Landcover Pred', mask_to_rgb(lc_pred, LANDCOVER_COLORS)),
        ('Mangrove Pred', mask_to_rgb(mg_pred.astype(int), mg_vis)),
        ('Ensemble', mask_to_rgb(ens_pred, ENSEMBLE_COLORS)),
    ]
    for ax, (title, picture) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


print("inspect_sample() defined")
print("Usage: inspect_sample(landcover_model, mangrove_model, dataset, idx=0)")

In [None]:
# Look at three fixed tiles from the dataset in detail.
for sample_position in (0, 100, 200):
    inspect_sample(landcover_model, mangrove_model, dataset, sample_position)

## 8. Save Results

In [None]:
print("=== Saving Results ===")
print()

# All distributions below exclude pixels whose ground truth is IGNORE_INDEX.
valid_mask = targets_flat != IGNORE_INDEX
valid_preds = preds_flat[valid_mask]
valid_targets = targets_flat[valid_mask]

# Fraction of valid pixels assigned to each ensemble class. Left empty when
# there are no valid pixels, so no NaN is written into the JSON file.
overall_dist = {}
if len(valid_preds) > 0:
    for c in range(NUM_ENSEMBLE_CLASSES):
        overall_dist[ENSEMBLE_NAMES[c]] = float((valid_preds == c).sum() / len(valid_preds))

# The ground-truth masks and the prediction subsets do not depend on the class
# being counted, so they are computed once instead of once per class.
m_mask = valid_targets == 1
nm_mask = valid_targets == 0
preds_on_mangrove = valid_preds[m_mask]
preds_on_nonmangrove = valid_preds[nm_mask]

# Class mix of the predictions on ground-truth mangrove / non-mangrove pixels.
# Each dict stays empty when the corresponding ground-truth value is absent.
mangrove_dist = {}
nonmangrove_dist = {}
for c in range(NUM_ENSEMBLE_CLASSES):
    if len(preds_on_mangrove) > 0:
        mangrove_dist[ENSEMBLE_NAMES[c]] = float((preds_on_mangrove == c).sum() / len(preds_on_mangrove))
    if len(preds_on_nonmangrove) > 0:
        nonmangrove_dist[ENSEMBLE_NAMES[c]] = float((preds_on_nonmangrove == c).sum() / len(preds_on_nonmangrove))

# NOTE: the two gate statistics are taken over all pixels, including ignored ones.
results = {
    'method': 'gated_ensemble',
    'landcover_model': str(LANDCOVER_WEIGHTS),
    'mangrove_model': str(MANGROVE_WEIGHTS),
    'mangrove_threshold': MANGROVE_THRESHOLD,
    'num_samples': len(dataset),
    'ensemble_classes': ENSEMBLE_NAMES,
    'overall_prediction_distribution': overall_dist,
    'predictions_on_mangrove_pixels': mangrove_dist,
    'predictions_on_non_mangrove_pixels': nonmangrove_dist,
    'pixels_gated_to_mangrove': int(mangrove_gate.sum()),
    'gate_percentage': float(mangrove_gate.mean() * 100),
}

results_file = PLOTS_DIR / 'ensemble_results.json'
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Saved: {results_file}")
print()
print("All outputs:")
print(f"  {PLOTS_DIR / 'ensemble_distribution.png'}")
print(f"  {PLOTS_DIR / 'ensemble_grid.png'}")
print(f"  {PLOTS_DIR / 'ensemble_results.json'}")

## 9. Legend

In [None]:
# Two colour keys side by side: ensemble classes (left), ground truth (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 3))
ensemble_ax, gt_ax = axes

# Ensemble legend (6 classes): one full-width bar per class, labelled in the middle.
for i in range(NUM_ENSEMBLE_CLASSES):
    name = ENSEMBLE_NAMES[i]
    color = ENSEMBLE_COLORS[i]
    ensemble_ax.barh(i, 1, color=color, edgecolor='black')
    ensemble_ax.text(0.5, i, name, ha='center', va='center', fontsize=10, fontweight='bold')
ensemble_ax.set_title('Ensemble Output Classes', fontsize=11)
ensemble_ax.set_xlim(0, 1)
ensemble_ax.axis('off')

# GT legend: the two mask values, then a white bar for the ignore label.
for i in range(2):
    name = MANGROVE_NAMES[i]
    color = MANGROVE_COLORS[i]
    gt_ax.barh(i, 1, color=color, edgecolor='black')
    gt_ax.text(0.5, i, name, ha='center', va='center', fontsize=10, fontweight='bold')
gt_ax.barh(2, 1, color=[1, 1, 1], edgecolor='black')
gt_ax.text(0.5, 2, 'ignore (255)', ha='center', va='center', fontsize=10, fontweight='bold')
gt_ax.set_title('Ground Truth (Mangrove)', fontsize=11)
gt_ax.set_xlim(0, 1)
gt_ax.axis('off')

plt.tight_layout()
plt.savefig(PLOTS_DIR / 'ensemble_legend.png', dpi=150, bbox_inches='tight')
plt.show()