# Phase 1B – Large-Scale LULC Inference for Noord-Brabant, Netherlands

## Overview
This notebook performs large-scale land use / land cover (LULC) inference over
the province of Noord-Brabant in the Netherlands, using the Vision Transformer (ViT) model trained in Phase 1A.

It applies the trained model to Sentinel-2 imagery to generate a province-wide LULC map.

## Inputs
- Satellite data:
  - Sentinel-2 RGB imagery (cloud-masked B4/B3/B2 median composite for July–August 2024, exported as an 8-bit GeoTIFF via Google Earth Engine)
- Trained model:
  - `best_model.pth` from Phase 1A
- Inference utilities:
  - `datafactory.py`
  - `inference.py`

## Methodology
- Image preprocessing:
  - Tiling and resizing of Sentinel-2 imagery
  - Normalization using ImageNet statistics
- Inference strategy:
  - Sliding-window inference over 64 × 64 px tiles with a stride of 32 px (50% overlap)
  - Batch processing on GPU
- Post-processing:
  - Stitching predictions into a full-resolution LULC map

## Outputs
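- LULC classification map for Noord-Brabant (GeoTIFF `outputs/brabant_map.tif`: single-band uint8 class indices, 255 = nodata)
- Visualizations of predicted land cover classes: random sample tiles titled with the predicted class and score, plus an interactive map overlay with a legend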

In [1]:
# --- Import packages ---
import ee
import geemap
import torch
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import random
from rasterio.windows import Window
from torchvision import models

# custom modules
from datafactory import get_transforms
from inference import GeoInference

In [2]:
# --- Load and export Sentinel-2 data ---

# Initialize GEE
ee.Authenticate()
ee.Initialize()

# Define Region (Noord-Brabant) from the FAO GAUL 2015 level-1 admin boundaries.
# The ADM1_NAME string must match the dataset's value exactly (note the lower-case "b").
brabant = ee.FeatureCollection("FAO/GAUL/2015/level1").filter(
    ee.Filter.And(
        ee.Filter.eq("ADM0_NAME", "Netherlands"),
        ee.Filter.eq("ADM1_NAME", "Noord-brabant"),
    )
)

# Fail fast if the name filter matched nothing; otherwise the empty region would
# presumably only surface later as an empty composite or a failed export.
assert brabant.size().getInfo() > 0, "No GAUL level-1 feature matched the Noord-Brabant filter"
# Cloud Mask
def maskS2clouds(image):
    """Mask cloudy pixels of a Sentinel-2 image using its QA60 band.

    A pixel is kept only when both QA60 bit 10 and bit 11 are zero
    (per the GEE Sentinel-2 catalog: opaque clouds and cirrus respectively).

    Args:
        image: an ``ee.Image`` that has a ``QA60`` band.

    Returns:
        ``image`` with ``updateMask`` applied using the combined clear-sky mask.
    """
    qa_band = image.select("QA60")
    opaque_clear = qa_band.bitwiseAnd(1 << 10).eq(0)
    cirrus_clear = qa_band.bitwiseAnd(1 << 11).eq(0)
    return image.updateMask(opaque_clear.And(cirrus_clear))

# Median composite of cloud-masked Sentinel-2 surface-reflectance scenes for summer 2024
# (1 July – 31 August). filterDate's end date is exclusive, so "2024-09-01" is required
# to include scenes acquired on 31 August.
s2_composite = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
    .filterBounds(brabant)
    .filterDate("2024-07-01", "2024-09-01")
    .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))  # scene-level cloud pre-filter
    .map(maskS2clouds)                                    # pixel-level QA60 cloud mask
    .median())

# Converts Sentinel-2 imagery to 8-bit RGB for model compatibility
# (bands B4/B3/B2, reflectance range [0, 2000] mapped onto the 8-bit display range),
# clipped to the province boundary.
s2_export = s2_composite.visualize(bands=['B4', 'B3', 'B2'], min=0, max=2000).clip(brabant)

# Export task (Pull the image from your google drive into the data folder when complete)
task = ee.batch.Export.image.toDrive(
    image=s2_export,
    description='noord_brabant_export',
    folder='earthengine_projects',
    fileNamePrefix='brabant_s2rgb',
    region=brabant.geometry(),
    scale=10,          # 10 m: native resolution of the B2/B3/B4 bands
    crs='EPSG:32631',  # WGS 84 / UTM zone 31N
    maxPixels=1e13,
)
task.start()

In [1]:
# --- Load model & run inference ---

input_tif = 'data/brabant_s2rgb.tif'    # RGB GeoTIFF exported by the Earth Engine cell
output_tif = 'outputs/brabant_map.tif'  # output tif file name
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model...")

# best_model.pth holds a fully pickled model object (torch.load returns the module itself),
# so it must be loaded with weights_only=False. That performs an unrestricted unpickle,
# which can execute arbitrary code: only load checkpoints you produced yourself (Phase 1A).
# torch.serialization.add_safe_globals only affects weights_only=True loads, so it is not
# needed here.
model = torch.load('outputs/best_model.pth', map_location=device, weights_only=False)
model.eval()

# Initialize inference Class
inferencer = GeoInference(
    model=model,
    transform=get_transforms('inference'),  # Using the modular transform
    device=device,
    tile_size=64,
    stride=32,  # 50% overlap
)

# Run sliding window
print("Starting Inference...")
pred_map, score_map, profile = inferencer.predict_sliding_window(input_tif)

# Save result as a single-band uint8 class map with 255 as nodata. The explicit cast keeps
# the array dtype consistent with the profile written to disk.
profile.update(dtype='uint8', count=1, nodata=255)
with rasterio.open(output_tif, 'w', **profile) as dst:
    dst.write(np.asarray(pred_map, dtype=np.uint8), 1)

print(f" Saved classification to {output_tif}")

In [2]:
# --- Visualize predicted samples with scores ---

# Class names in prediction-index order: a value i in the prediction map is shown as classes[i].
classes = ('AnnualCrop Forest Herbaceous Highway Industrial '
           'Pasture PermanentCrop Residential River SeaLake').split()


def plot_prediction_samples(tif_path, pred_map, score_map, classes, num_samples=10,
                            tile_size=64, max_attempts=10_000, seed=None):
    """Plot random RGB tiles from the input raster, titled with predicted class and score.

    Random ``tile_size`` x ``tile_size`` windows are drawn from ``tif_path``. A window is
    plotted only if at least half of its RGB values are non-zero and the prediction at the
    window centre is not the nodata value 255.

    Args:
        tif_path: path to the RGB GeoTIFF inference was run on (bands 1-3 are read).
        pred_map: class-index array indexed as ``pred_map[row, col]``, assumed to be
            pixel-aligned with ``tif_path``; 255 marks masked pixels.
        score_map: per-class score array indexed as ``score_map[class, row, col]``
            (assumed to hold softmax scores from GeoInference — TODO confirm).
        classes: class names, indexed by the values in ``pred_map``.
        num_samples: number of tiles to plot (laid out up to 5 per row).
        tile_size: side length of each sampled window, in pixels.
        max_attempts: upper bound on random draws, so the loop terminates even when the
            raster contains too few valid windows (e.g. an all-nodata prediction).
        seed: optional seed for reproducible window selection; ``None`` uses the
            global ``random`` module state.

    Raises:
        ValueError: if the raster is smaller than ``tile_size`` in either dimension.
    """
    if num_samples <= 0:
        return

    rng = random if seed is None else random.Random(seed)
    half = tile_size // 2
    ncols = min(5, num_samples)
    nrows = (num_samples + ncols - 1) // ncols  # ceil division; 10 samples -> 2 x 5

    plt.figure(figsize=(4 * ncols, 4 * nrows))
    samples_found = 0
    attempts = 0
    with rasterio.open(tif_path) as src:
        if src.width < tile_size or src.height < tile_size:
            raise ValueError(
                f"Raster {tif_path} ({src.width}x{src.height}) is smaller than "
                f"tile_size={tile_size}"
            )

        while samples_found < num_samples and attempts < max_attempts:
            attempts += 1
            # Pick a random window that lies fully inside the raster
            x = rng.randint(0, src.width - tile_size)
            y = rng.randint(0, src.height - tile_size)

            rgb = src.read([1, 2, 3], window=Window(x, y, tile_size, tile_size))

            # Skip background or mostly empty patches (this also covers all-zero patches)
            if np.count_nonzero(rgb) < rgb.size * 0.5:
                continue

            # Get prediction and score at the centre of the window
            row_idx, col_idx = y + half, x + half
            pred_idx = int(pred_map[row_idx, col_idx])

            # Skip if the centre pixel is masked (255)
            if pred_idx == 255:
                continue

            # Extract the specific score for the predicted class
            confidence_score = score_map[pred_idx, row_idx, col_idx]

            plt.subplot(nrows, ncols, samples_found + 1)
            plt.imshow(np.moveaxis(rgb, 0, -1))
            plt.title(f"{classes[pred_idx]}\nScore: {confidence_score:.2f}",
                      fontsize=10, fontweight='bold')
            plt.axis('off')
            samples_found += 1

    if samples_found < num_samples:
        print(f"Only found {samples_found}/{num_samples} valid samples "
              f"after {attempts} attempts.")

    plt.tight_layout()
    plt.show()

# Show 10 random valid tiles from the inference input together with their predictions
plot_prediction_samples(input_tif, pred_map, score_map, classes, num_samples=10)

In [3]:
# --- Interactive Map Overlay ---
# Uses `brabant` and `s2_export` from the Earth Engine cell, and `classes` / `output_tif`
# from the cells above, so those cells must have been run in this kernel session.

# One colour per class, in the same index order as `classes`
hex_colors = [
    '#F3D060',  # AnnualCrop
    '#004D1A',  # Forest
    '#91AF40',  # Herbaceous
    '#5D4037',  # Highway
    '#7A4988',  # Industrial
    '#7CFC00',  # Pasture
    '#BDB76B',  # PermanentCrop
    '#FF4500',  # Residential
    '#00BFFF',  # River
    '#000080',  # SeaLake
]
assert len(hex_colors) == len(classes), "need exactly one colour per class"

Map = geemap.Map()
Map.centerObject(brabant, 9)

# Add Sentinel-2 Background
Map.addLayer(s2_export, {}, "Sentinel-2 RGB")

# Add LULC map (you need to install xarray and localtileserver to use add_raster).
# vmin/vmax pin the colour ramp to the full class range 0..N-1, so class i is always drawn
# with hex_colors[i]. Without them the ramp may be fitted to the classes actually present
# in the raster (e.g. when no SeaLake pixels are predicted) and colours stop matching the legend.
try:
    Map.add_raster(
        output_tif,
        palette=hex_colors,
        vmin=0,
        vmax=len(hex_colors) - 1,
        layer_name="LULC map",
        nodata=255,
    )
except Exception as e:  # best effort: the background layer and legend are still useful
    print(f"Could not load local raster directly: {e}")
    print("Try installing localtileserver and xarray")

# Add Legend
legend_dict = dict(zip(classes, hex_colors))
Map.add_legend(title="Land Cover", legend_dict=legend_dict)

Map