In [None]:
import os
import numpy as np
import random
import itertools
import matplotlib.pyplot as plt
import torch
import terratorch
import albumentations as A
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

import warnings
warnings.filterwarnings("ignore")

from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY

# Download dataset and format it

- The LEVIR-CD+ dataset can be downloaded from [TorchGeo datasets](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#id11)
- Check `levir_cd.ipynb` notebook for how to download dataset and view samples of it
- First run `convert_levircdplus_to_genericnongeosegdatamodule.py` to format the dataset for the TerraTorch datamodule (this notebook points at the resulting restructured directory)
- Also run `plot_levircdplus.py` to visualise random stacked image and equivalent mask
- Uncomment cell below to compute the means & stds for standardization of the 3 RGB channels (this imports `compute_stats_for_stacked_tifs`)

In [None]:
# Compute the means and stds per channel (pointing at the restructured directory)
# from geofm.compute_levircdplus_means_stds import compute_stats_for_stacked_tifs

# train_images_dir = Path("/Users/samuel.omole/Desktop/repos/geofm_datasets/levircdplus_restructured")
# means, stds  = compute_stats_for_stacked_tifs(train_images_dir)
# print("means:", means)
# print("stds:", stds)

In [None]:
# Precomputed per-channel means and stds (one value per RGB channel), hardcoded so they
# don't have to be recomputed every time; uncomment the cell above to regenerate them.
means = [100.09773106949002, 98.8373331565483, 84.30711440011567]
stds = [48.06710882295874, 45.49288485657313, 42.16697994338281]

# Preparing the datamodule

In [None]:
def mask_to_binary(mask, **kwargs):
    """
    Binarize a label mask (used as the `mask` callback of `A.Lambda`).

    Every strictly positive pixel becomes 1 and every other pixel becomes 0.
    Extra keyword arguments are accepted and ignored.

    Returns:
        Array of the same shape as `mask` with dtype uint8.
    """
    is_positive = mask > 0
    return is_positive.astype("uint8")

def _find_mask_in_batch(batch):
    """Return the label/mask entry of a dataloader batch given as a tuple/list or a dict."""
    # Sequence batches: the second element is taken as the labels
    if isinstance(batch, (list, tuple)) and len(batch) >= 2:
        return batch[1]

    if isinstance(batch, dict):
        # Try common mask keys if dataset has different naming convention
        for key in ("mask", "masks", "label", "labels", "mask_target"):
            if key in batch:
                return batch[key]

        # Fallback to pick the second tensor-like item (or the only one available)
        tensor_items = [v for v in batch.values() if isinstance(v, torch.Tensor)]
        if len(tensor_items) >= 2:
            return tensor_items[1]
        if len(tensor_items) == 1:
            return tensor_items[0]
        # str() each key so non-string keys cannot turn this into a TypeError
        raise RuntimeError(
            "Couldn't locate mask tensor in batch. Keys: " + ", ".join(map(str, batch.keys()))
        )

    raise RuntimeError(f"Unexpected batch type: {type(batch)}")


def check_label_values(datamodule, max_batches=20):
    """
    Helper code to confirm mask_to_binary has converted mask labels to binary

    Runs the datamodule's "fit" setup, walks over at most `max_batches` training
    batches, collects the unique label values (cast to int) and prints them together
    with the number of batches that were actually inspected.

    Args:
        datamodule: The TerraTorch datamodule
        max_batches (int, optional): Upper bound on the number of batches to inspect.
            Values <= 0 inspect no batch. Defaults to 20.

    Raises:
        RuntimeError: If a batch has an unexpected type or no mask can be located in it.
    """
    datamodule.setup("fit")
    loader = datamodule.train_dataloader()
    vals = set()
    n_inspected = 0

    # islice stops pulling from the loader as soon as the limit is reached
    for batch in itertools.islice(loader, max(max_batches, 0)):
        y = _find_mask_in_batch(batch)

        # Works for both torch tensors and anything numpy can interpret as an array
        if isinstance(y, torch.Tensor):
            uniq = torch.unique(y).cpu().numpy()
        else:
            uniq = np.unique(y)
        vals.update(int(v) for v in uniq)
        n_inspected += 1

    # Report the real number of inspected batches (the loader may hold fewer than max_batches)
    print("Unique label values in first", n_inspected, "batches:", sorted(vals))

In [None]:
# Point dataset path to the restructured directory. The location can be overridden by
# setting the LEVIRCD_ROOT environment variable, so the notebook does not have to be
# edited on each machine; the default below is the original hardcoded location.
dataset_path = Path(
    os.environ.get(
        "LEVIRCD_ROOT",
        "/Users/samuel.omole/Desktop/repos/geofm_datasets/levircdplus_restructured",
    )
)

# Training pipeline: flatten time into channels so albumentations can process the sample,
# augment, binarize the mask, convert to tensors, then restore the 2 timesteps.
train_transform = [
        terratorch.datasets.transforms.FlattenTemporalIntoChannels(),
        A.D4(), # Random flips and rotation
        A.Lambda(mask=mask_to_binary),
        A.pytorch.transforms.ToTensorV2(),
        terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=2),
    ]

# Validation/test pipeline: identical to training but without the random D4 augmentation.
val_transform = [
        terratorch.datasets.transforms.FlattenTemporalIntoChannels(),
        A.Lambda(mask=mask_to_binary),
        A.pytorch.transforms.ToTensorV2(),
        terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=2),
    ]

datamodule = terratorch.datamodules.GenericNonGeoSegmentationDataModule(
    batch_size=8,
    num_workers=0,
    num_classes=2,
    # Define dataset paths, having restructured the original dataset
    train_data_root=dataset_path / 'train' / 'images',
    train_label_data_root=dataset_path / 'train' / 'labels',
    val_data_root=dataset_path / 'val' / 'images',
    val_label_data_root=dataset_path / 'val' / 'labels',
    test_data_root=dataset_path / 'test' / 'images',
    test_label_data_root=dataset_path / 'test' / 'labels',

    # File name patterns used to find images and their masks
    img_grep='*_stacked.tif',
    label_grep='*_mask.png',

    # NOTE(review): the model cell lists its RGB bands as RED, GREEN, BLUE while this
    # order is BLUE, GREEN, RED - confirm which order the stacked TIFs really use.
    dataset_bands=["BLUE","GREEN","RED"],
    output_bands=["BLUE","GREEN","RED"],

    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=val_transform, # Apply the same transform as validation set
    expand_temporal_dimension=True,
    means=means,
    stds=stds,
    no_label_replace=-1, # Same value as the task's ignore_index in the model cell
    no_data_replace=0,
)

In [None]:
# Setup train and val datasets (run the "fit" stage so that datamodule.train_dataset
# and datamodule.val_dataset, read in the cells below, are available)
datamodule.setup("fit")

In [None]:
# Check that the labels are binary: prints the unique label values found in the first
# training batches; only 0 and 1 are expected once mask_to_binary has been applied
check_label_values(datamodule)

In [None]:
# Training split built by the "fit" setup; the cell output is its number of samples
train_dataset = datamodule.train_dataset
len(train_dataset)

In [None]:
# Validation split built by the "fit" setup; the cell output is its number of samples
val_dataset = datamodule.val_dataset
len(val_dataset)

# Building the TerraMind model and fine-tuning with PyTorch Lightning

## Setting up the trainer

In [None]:
pl.seed_everything(0) # Set seed for reproducibility

# Keep the checkpoint with the lowest validation loss
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="../output_levirpluscd/terramind_small/checkpoints/", # Change as appropriate
    mode="min",
    monitor="val/loss", # Variable to monitor
    filename="best-loss",
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="auto",
    devices=1,
    precision='32',
    num_nodes=1,
    logger=True,
    max_epochs=5, # For demos
    log_every_n_steps=1,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="../output_levirpluscd/terramind_small/", # Change as appropriate
)

# Model
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "terramind_v1_small",
        "backbone_pretrained": True,
        # "backbone_in_channels": 6,
        "backbone_use_temporal": True,
        "backbone_temporal_pooling": 'diff',
        "backbone_temporal_concat": False,
        "backbone_modalities": ["RGB"], #["S2L2A"],
        # NOTE(review): the datamodule declares its bands as BLUE, GREEN, RED - confirm
        # that the channel order of the stacked TIFs matches the order given here.
        "backbone_bands": {"RGB": ["RED","GREEN","BLUE"]},
        # Necks
        "necks": [
            {
                "name": "SelectIndices",
                "indices": [2, 5, 8, 11]
            },
            {"name": "ReshapeTokensToImage",},
            {"name": "LearnedInterpolateToPyramidal"}
        ],
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],

        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    },

    loss="dice",
    optimizer="AdamW",
    lr=1e-4,
    ignore_index=-1, # Same value as no_label_replace in the datamodule
    freeze_backbone=True,
    freeze_decoder=False,
    plot_on_val=True,
    # Names are listed in class-index order. LEVIR-CD+ masks mark changed pixels with
    # non-zero values, which mask_to_binary maps to 1, so class 0 is 'no change' and
    # class 1 is 'change'.
    class_names=['no change', 'change']
)

In [None]:
# Training: fine-tune the model on the datamodule's train/val splits; checkpoints are
# written by checkpoint_callback configured in the trainer cell
trainer.fit(model, datamodule=datamodule)

## Evaluate the model performance on the test dataset

In [None]:
# Prefer the path recorded by the checkpoint callback during this session: when
# best-loss.ckpt already exists from an earlier run, Lightning saves the new file under
# a versioned name (e.g. best-loss-v1.ckpt), so a hardcoded name could silently evaluate
# a stale checkpoint. The attribute is an empty string if no training ran in this
# kernel, in which case the hardcoded path is used.
best_ckpt_path = (
    checkpoint_callback.best_model_path
    or "../output_levirpluscd/terramind_small/checkpoints/best-loss.ckpt" # Change path to saved model
)
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

In [None]:
# Required later on when plotting test set predictions: the prediction cell calls
# test_dataset.plot(...) on each example. The cell output is the number of test samples.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

## Predicting & plotting some test set examples

In [None]:
# Sizes that an axis must have to be recognised as the channel axis
COMMON_CHANNEL_COUNTS = (1, 2, 3, 4, 6)
# Largest size an axis may have to be recognised as the time axis
_MAX_TIMESTEPS = 50


def _select_timestep(arr, axis, pick_time):
    """Collapse `axis`: first slice for 'first', last slice for 'last', mean otherwise."""
    if pick_time == 'first':
        return np.take(arr, 0, axis=axis)
    if pick_time == 'last':
        return np.take(arr, -1, axis=axis)
    return arr.mean(axis=axis)


def ensure_channel_first(img, n_timesteps=None, pick_time='last'):
    """
    Return an image numpy array with shape (C, H, W) that the
    dataloader can use for plotting

    The layout of the input is guessed from its shape after singleton axes are
    squeezed out: an axis whose size is in COMMON_CHANNEL_COUNTS is taken as the
    channel axis and an axis of size <= 50 as the time axis.

    Args:
        img: torch tensor or array-like with 2, 3 or 4 non-singleton dimensions.
        n_timesteps (int, optional): Number of timesteps, used to recognise (T, H, W)
            inputs and flattened (T*C, H, W) inputs. Defaults to None.
        pick_time (str, optional): 'first' or 'last' selects that timestep; any other
            value averages over time. Defaults to 'last'.

    Returns:
        np.ndarray of dtype float32 with shape (C, H, W).

    Raises:
        ValueError: If the squeezed input does not have 2, 3 or 4 dimensions.
    """
    if torch.is_tensor(img):
        # detach() so tensors that require grad can also be converted
        arr = img.detach().cpu().numpy()
    else:
        arr = np.asarray(img)
    arr = np.squeeze(arr)

    # 2D array -> (1, H, W)
    if arr.ndim == 2:
        return arr[np.newaxis, :, :].astype(np.float32)

    # 3D array - could be (C,H,W), (H,W,C), flattened (T*C,H,W) or (T,H,W)
    if arr.ndim == 3:
        a, _, c = arr.shape
        # (C, H, W)
        if a in COMMON_CHANNEL_COUNTS:
            return arr.astype(np.float32)
        # (H, W, C)
        if c in COMMON_CHANNEL_COUNTS:
            return arr.transpose(2, 0, 1).astype(np.float32)
        # Flattened (T*C, H, W) with known n_timesteps.
        # Assumes the time index varies slowest along the first axis - TODO confirm
        # against the flattening transform that produced the array.
        if (
            n_timesteps is not None
            and n_timesteps > 0
            and a != n_timesteps
            and a % n_timesteps == 0
            and a // n_timesteps in COMMON_CHANNEL_COUNTS
        ):
            stacked = arr.reshape(n_timesteps, a // n_timesteps, *arr.shape[1:])  # (T, C, H, W)
            return _select_timestep(stacked, 0, pick_time).astype(np.float32)
        # (T, H, W)
        if a <= _MAX_TIMESTEPS and (n_timesteps is None or a == n_timesteps):
            chosen = _select_timestep(arr, 0, pick_time)
            return chosen[np.newaxis, :, :].astype(np.float32)
        # fallback assume (H, W, C)
        return arr.transpose(2, 0, 1).astype(np.float32)

    # 4D array - try the supported (channel, time) axis arrangements in turn
    if arr.ndim == 4:
        d0, d1, d2, d3 = arr.shape

        # (C, T, H, W)
        if d0 in COMMON_CHANNEL_COUNTS and d1 <= _MAX_TIMESTEPS:
            return _select_timestep(arr, 1, pick_time).astype(np.float32)

        # (T, C, H, W)
        if d1 in COMMON_CHANNEL_COUNTS and d0 <= _MAX_TIMESTEPS:
            return _select_timestep(arr, 0, pick_time).astype(np.float32)

        # (H, W, C, T)
        if d2 in COMMON_CHANNEL_COUNTS and d3 <= _MAX_TIMESTEPS:
            out = _select_timestep(arr, 3, pick_time)  # (H, W, C)
            return out.transpose(2, 0, 1).astype(np.float32)

        # (H, W, T, C)
        if d3 in COMMON_CHANNEL_COUNTS and d2 <= _MAX_TIMESTEPS:
            out = _select_timestep(arr, 2, pick_time)  # (H, W, C)
            return out.transpose(2, 0, 1).astype(np.float32)

        # Fallback to average over the first axis and hope for best
        out = arr.mean(axis=0)
        # Fine if this produced (C,H,W) else try transpose
        if out.shape[0] in COMMON_CHANNEL_COUNTS:
            return out.astype(np.float32)
        return out.transpose(2, 0, 1).astype(np.float32)

    raise ValueError(f"Unsupported image ndim={arr.ndim}, shape={arr.shape}")

In [None]:
def get_random_batch(loader, seed=None):
    """
    Gets random batch not just the first

    A batch index is drawn uniformly from range(len(loader)) and the loader is iterated
    until that batch is reached. When `seed` is given, a private random generator is
    used, so the selection is reproducible and the global `random` state is left
    untouched.

    Args:
        loader: A sized iterable of batches (e.g. the test dataloader).
        seed (int, optional): The seed for reproducibility. Defaults to None, in which
            case the global `random` generator is used.

    Returns:
        The extracted batch from the loader

    Raises:
        ValueError: If the loader contains no batches.
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("Cannot draw a random batch from an empty loader")

    # A dedicated generator yields the same index as seeding the global one would,
    # without reseeding the global random state as a side effect
    rng = random.Random(seed) if seed is not None else random
    idx = rng.randrange(n_batches)

    # Skip the first idx batches and take the next one
    return next(itertools.islice(iter(loader), idx, None))

In [None]:
# Reload the best checkpoint, reusing the factory and model arguments of the current task
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)
# Switch to inference mode so dropout (head_dropout=0.1) and any normalisation layers
# do not behave as in training while predicting
model.eval()

test_loader = datamodule.test_dataloader()
with torch.no_grad():
    # Reproducibly pick one batch of the test loader (not necessarily the first)
    batch = get_random_batch(test_loader, seed=10)
    images = batch["image"].to(model.device)   # leave for inference
    masks = batch["mask"].cpu().numpy()
    outputs = model(images)
    # Predicted class per pixel: index of the largest score along the class axis
    preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot example predictions (at most 3; the selected batch may hold fewer samples)
n_examples = min(3, len(preds))
for i in range(n_examples):
    raw_img = batch["image"][i].cpu()   # torch tensor or numpy
    print(raw_img.shape)
    # Reduce the sample to a single (C, H, W) image of the first timestep for plotting
    prepared_img = ensure_channel_first(raw_img, n_timesteps=2, pick_time='first')
    print(prepared_img.shape)
    print(f"mask shape: {masks[i].shape}")

    sample = {
        "image": prepared_img,
        "mask": masks[i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()
