# Practical 3 (Path 2) — Fine-tuning a Geospatial Foundation Model with **TerraTorch** (Prithvi-EO-2.0, Flood Mapping)

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

This practical follows a **realistic EO workflow**: using the **TerraTorch** toolkit to fine-tune a **Geospatial Foundation Model (GFM)** for **flood segmentation**.

We focus on:
- **Prithvi-EO-2.0** (IBM + NASA) as the foundation backbone
- **Sen1Floods11** as the downstream dataset (flood masks)
- **TerraTorch** CLI + config-driven training
- a **low-resource path**: short runs, small subsets, and PEFT pointers (LoRA/VPT)

---

## Key links (keep these handy)

- Prithvi-EO-2.0 release + example configs/notebooks: https://github.com/NASA-IMPACT/Prithvi-EO-2.0  
- Sen1Floods11 dataset repo (download instructions + citation): https://github.com/cloudtostreet/Sen1Floods11  
- TerraTorch repo + quickstart + install notes (GDAL!): https://github.com/terrastackai/terratorch  
- Fine-tuned reference model (flood segmentation): https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11  
- PEFT in geospatial (LoRA, VPT, ViT-adapter) integrated into TerraTorch: https://github.com/IBM/peft-geofm  

> **Tip:** If installs get painful (GDAL) or downloads are blocked, skip to **Section 9 (Fallback plans)**.


## 0. What you’ll build

A minimal end-to-end pipeline:

1. Download a **tiny subset** of Sen1Floods11 chips (or use a pre-downloaded folder).
2. Sanity check: read GeoTIFFs, visualise RGB + label.
3. Create a **small TerraTorch config** (based on the Prithvi-EO-2.0 example) and run:

```bash
terratorch fit -c sen1floods11_small.yaml
```

4. Evaluate a checkpoint and run inference on a sample chip.

This is *config-first* to match how TerraTorch is used in practice.


## 1. Setup and installs (TerraTorch + GDAL)

TerraTorch depends on **GDAL**. The least painful way is usually **conda**:

```bash
conda create -n terratorch python=3.10 -y
conda activate terratorch
conda install -c conda-forge gdal -y
pip install terratorch
```

If you *must* use pip only, you will likely need a system-wide GDAL installation first.

In this notebook, we’ll assume the environment can run `terratorch` as a CLI.
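Before going further, it can help to check what the notebook can actually see. The sketch below uses only the standard library: it looks for the `terratorch` executable on PATH and checks whether the `terratorch`, `rasterio`, and `osgeo` (GDAL's Python bindings) packages can be found, without importing them. A missing `osgeo` is not necessarily fatal, since rasterio wheels bundle their own GDAL; it mainly tells you whether the conda GDAL bindings are present.

```python
import importlib.util
import shutil


def check_environment() -> dict:
    """Report which parts of the TerraTorch toolchain are visible (True/False, non-fatal)."""
    return {
        # Is the `terratorch` CLI on PATH for shell commands?
        "terratorch_cli": shutil.which("terratorch") is not None,
        # Can the Python packages be located? (find_spec does not import them)
        "terratorch_pkg": importlib.util.find_spec("terratorch") is not None,
        "rasterio": importlib.util.find_spec("rasterio") is not None,
        # GDAL's Python bindings live in the `osgeo` package
        "gdal_python": importlib.util.find_spec("osgeo") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name:15s} {'OK' if ok else 'missing'}")
```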


In [None]:
# Optional pip installs for the notebook helper steps (visualisation + yaml)
# (TerraTorch itself is installed separately; see instructions above)
# %pip -q install rasterio pyyaml matplotlib numpy

import os
import glob
import yaml
import numpy as np
import matplotlib.pyplot as plt

# rasterio can be finicky if GDAL isn't present
import rasterio


## 2. Data: Sen1Floods11

Sen1Floods11 is hosted in a public Google Cloud Storage bucket and can be downloaded with `gsutil`.

### Option A — Download the full dataset (~14GB)
```bash
gsutil -m rsync -r gs://sen1floods11 /home/files/sen1floods11
```

### Option B — Download a tiny subset (recommended for workshop)
We’ll pull **just a handful of chips** for a quick run.

> If you don't have `gsutil`, you can (1) install it, or (2) ask the workshop organisers to provide a pre-downloaded subset folder.


In [None]:
# Set your local dataset directory here
DATA_ROOT = "./sen1floods11_subset"   # put chips here (or point to pre-downloaded data)
os.makedirs(DATA_ROOT, exist_ok=True)

print("DATA_ROOT:", os.path.abspath(DATA_ROOT))


### 2.1 Download a small subset using `gsutil` (optional)

This cell uses shell commands. If you're on Windows or don't have gsutil, skip it.

We try to download:
- one Sentinel-2 hand-labelled raster (`*_S2Hand.tif`) and
- its corresponding flood mask (`*_LabelHand.tif`) if present in the bucket layout.

Bucket structure can change between versions; if a path fails, consult the Sen1Floods11 README.


In [None]:
# %%bash
# # If you have gsutil available, uncomment this whole cell; %%bash must stay on the first line.
# # The bucket layout differs between dataset versions, so we discover paths by listing them
# # instead of hardcoding them; adjust the patterns if nothing matches.
# set -e
# mkdir -p sen1floods11_subset
# echo "Listing S2Hand files (may take a moment)..."
# gsutil ls "gs://sen1floods11/**/*_S2Hand.tif" > s2hand_list.txt || true
# head -n 10 s2hand_list.txt
#
# # Copy the first chip listed (or set FILE to any path printed above):
# FILE=$(head -n 1 s2hand_list.txt)
# CHIP=$(basename "$FILE" _S2Hand.tif)
# gsutil cp "$FILE" sen1floods11_subset/
#
# # Try common label name patterns for the same chip (may need adjustment):
# gsutil -m cp "gs://sen1floods11/**/${CHIP}_LabelHand.tif" sen1floods11_subset/ || true
# gsutil -m cp "gs://sen1floods11/**/${CHIP}_Label.tif" sen1floods11_subset/ || true
#
# echo "Done. Files:"
# ls -lh sen1floods11_subset


## 3. Sanity check: read a chip and visualise

We’ll:
- load one `*_S2Hand.tif` chip (Sentinel-2 bands)
- (optionally) load the label chip
- plot an RGB composite + mask

If you don’t have labels downloaded, you can still visualise the imagery.


In [None]:
def find_one(pattern: str):
    matches = sorted(glob.glob(os.path.join(DATA_ROOT, pattern)))
    return matches[0] if matches else None

s2_path = find_one("*_S2Hand.tif")
lbl_path = find_one("*_Label*.tif")  # tries LabelHand / Label etc.

print("S2 chip:", s2_path)
print("Label:", lbl_path)
assert s2_path is not None, "No *_S2Hand.tif found. Download a subset or point DATA_ROOT to the dataset folder."

with rasterio.open(s2_path) as src:
    s2 = src.read()  # (bands, H, W)
    profile = src.profile

print("S2 shape:", s2.shape)
print("dtype:", s2.dtype)
print("crs:", profile.get("crs"))
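`find_one("*_Label*.tif")` returns the first label in the folder, which may belong to a different chip once you have downloaded several. A minimal sketch of matching labels by chip ID instead, assuming the `<Region>_<id>_S2Hand.tif` / `<Region>_<id>_LabelHand.tif` naming used in this practical (the helper names are hypothetical, not part of any library):

```python
import os
from typing import List, Optional


def chip_id(s2_filename: str) -> Optional[str]:
    """Extract the chip ID (e.g. 'Bolivia_103757') from an '*_S2Hand.tif' name."""
    base = os.path.basename(s2_filename)
    suffix = "_S2Hand.tif"
    if not base.endswith(suffix):
        return None
    return base[: -len(suffix)]


def label_candidates(s2_filename: str) -> List[str]:
    """Label filenames to look for, in order of preference."""
    cid = chip_id(s2_filename)
    if cid is None:
        return []
    return [f"{cid}_LabelHand.tif", f"{cid}_Label.tif"]


def find_label(s2_path: str) -> Optional[str]:
    """Return the label file that sits next to `s2_path` for the same chip, if any."""
    folder = os.path.dirname(s2_path)
    for name in label_candidates(s2_path):
        candidate = os.path.join(folder, name)
        if os.path.exists(candidate):
            return candidate
    return None
```

With this in place, `lbl_path = find_label(s2_path)` pairs the label with the chip you are actually plotting.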


In [None]:
def to_rgb(s2_arr):
    # Assumption: S2Hand chips store the 13 Sentinel-2 bands in standard order
    # (B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B10, B11, B12), so true colour is
    # B4/B3/B2 = indices 3/2/1. Check the Sen1Floods11 README if your chips differ.
    bands = s2_arr.shape[0]
    if bands == 13:
        rgb_idx = (3, 2, 1)
    elif bands >= 3:
        rgb_idx = (2, 1, 0)  # unknown layout: guess the first three bands are B, G, R
    else:
        rgb_idx = (0, 0, 0)  # single band: show it as greyscale
    rgb = np.stack([s2_arr[i] for i in rgb_idx], axis=-1).astype(np.float32)

    # Robust contrast stretch; nanpercentile skips NaN no-data pixels
    lo, hi = np.nanpercentile(rgb, [2, 98])
    rgb = np.clip((rgb - lo) / (hi - lo + 1e-6), 0, 1)
    return np.nan_to_num(rgb)

rgb = to_rgb(s2)

plt.figure(figsize=(5, 5))
plt.imshow(rgb)
plt.title(os.path.basename(s2_path))
plt.axis("off")
plt.show()

if lbl_path is not None:
    with rasterio.open(lbl_path) as src:
        lbl = src.read(1)
    # Sen1Floods11 hand labels: 1 = water, 0 = not water, -1 = no data (masked out here)
    plt.figure(figsize=(5, 5))
    plt.imshow(np.ma.masked_equal(lbl, -1), cmap="viridis", vmin=0, vmax=1)
    plt.title(os.path.basename(lbl_path) + "\nyellow = water, blank = no data")
    plt.axis("off")
    plt.show()
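Flood labels are usually very imbalanced, and the no-data pixels must be excluded from both training loss and metrics. A small sketch for summarising a label chip, assuming the Sen1Floods11 hand-label convention (1 = water, 0 = not water, -1 = no data):

```python
import numpy as np


def label_summary(lbl: np.ndarray) -> dict:
    """Count water / non-water / no-data pixels in a hand-labelled flood mask."""
    water = int(np.sum(lbl == 1))
    dry = int(np.sum(lbl == 0))
    nodata = int(np.sum(lbl == -1))
    valid = water + dry
    return {
        "water": water,
        "not_water": dry,
        "no_data": nodata,
        # Fraction of *valid* pixels that are water (NaN if no pixel is valid)
        "water_fraction": water / valid if valid else float("nan"),
    }
```

For example, `print(label_summary(lbl))` after the cell above. A low `water_fraction` is a hint to watch IoU for the water class rather than overall pixel accuracy.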


## 4. TerraTorch training: config-driven fine-tuning

TerraTorch is typically run via CLI with a YAML config.

### What we will do
1. Create a **small** training config (`sen1floods11_small.yaml`) by starting from the Prithvi-EO-2.0 example and editing:
   - dataset paths
   - batch size / epochs
   - limits for a short workshop run
2. Run:
```bash
terratorch fit -c sen1floods11_small.yaml
```

> The Prithvi-EO-2.0 repo provides an example config for Sen1Floods11 and notes that the reference model was fine-tuned with `terratorch fit -c sen1floods11.yaml`.


### 4.1 Create a small config skeleton

Because workshop environments differ, we generate a **template config** you can edit.

If you have the official config file, you can download it and modify it instead:
- https://github.com/NASA-IMPACT/Prithvi-EO-2.0 (see `configs/sen1floods11.yaml`)

Below we provide a **minimal template** that you should treat as a starting point.

> TerraTorch configs can evolve between versions. If something fails, check the TerraTorch docs / examples for your installed version.


In [None]:
CONFIG_OUT = "sen1floods11_small.yaml"

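# Minimal / illustrative template: it shows WHICH settings matter, not the exact schema.
# TerraTorch's CLI is built on LightningCLI, so a runnable config nests `model`, `data`,
# `optimizer`, etc. as `class_path` + `init_args` entries. Copy that structure (and the key
# names for your TerraTorch version) from the official Prithvi-EO-2.0 sen1floods11.yaml.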
config = {
    "seed_everything": 42,
    "trainer": {
        "accelerator": "auto",
        "devices": 1,
        "max_epochs": 2,
        "log_every_n_steps": 10,
        # Short-run controls (workshop):
        "limit_train_batches": 20,
        "limit_val_batches": 5,
    },
    "model": {
        # You typically choose a backbone (Prithvi-EO-2.0) and a decoder head (e.g., UNet-like)
        # Exact names depend on TerraTorch version.
        "task": "segmentation",
        "backbone": {
            "name": "prithvi_eo_v2_300_tl",  # common shorthand in examples; adjust to match your registry
            "pretrained": True,
        },
        "decoder": {
            "name": "unet",                 # example; may be "unet" / "fcn" / etc. depending on version
            "num_classes": 2,               # water vs non-water (often binary)
        },
    },
    "data": {
        "name": "sen1floods11",
        "data_root": DATA_ROOT,
        "batch_size": 2,
        "num_workers": 2,
        # Optionally: patch size, bands, normalisation settings, etc.
    },
    "optimizer": {
        "name": "adamw",
        "lr": 1e-4,
        "weight_decay": 1e-2,
    },
    "lr_scheduler": {
        "name": "cosine",
    },
}

with open(CONFIG_OUT, "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)

print("Wrote:", CONFIG_OUT)
print("\n--- config preview ---")
print(open(CONFIG_OUT, "r").read()[:1200])
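For comparison, here is a sketch of the same short-run settings arranged in the `class_path` / `init_args` layout that LightningCLI-based tools read. The class names (`terratorch.tasks.SemanticSegmentationTask`, `terratorch.datamodules.Sen1Floods11NonGeoDataModule`), the `EncoderDecoderFactory` / `UperNetDecoder` choices, and the `init_args` keys are assumptions based on TerraTorch's public examples; verify them against your installed version and the official `configs/sen1floods11.yaml`, which also supplies the necks, band list, and normalisation statistics deliberately omitted here.

```python
def build_terratorch_config(data_root: str, batch_size: int = 2) -> dict:
    """LightningCLI-style config sketch. Class names and init_args are ASSUMPTIONS
    to check against your TerraTorch version; necks/bands/normalisation omitted."""
    return {
        "seed_everything": 42,
        "trainer": {
            "accelerator": "auto",
            "devices": 1,
            "max_epochs": 2,
            "limit_train_batches": 20,  # short workshop run
            "limit_val_batches": 5,
        },
        "model": {
            "class_path": "terratorch.tasks.SemanticSegmentationTask",  # assumed
            "init_args": {
                "model_factory": "EncoderDecoderFactory",  # assumed
                "model_args": {
                    "backbone": "prithvi_eo_v2_300_tl",
                    "backbone_pretrained": True,
                    "decoder": "UperNetDecoder",  # assumed; FCN-style decoders also exist
                    "num_classes": 2,             # water vs non-water
                },
                "loss": "ce",
                "ignore_index": -1,  # Sen1Floods11 marks no-data pixels as -1
            },
        },
        "data": {
            "class_path": "terratorch.datamodules.Sen1Floods11NonGeoDataModule",  # assumed
            "init_args": {
                "data_root": data_root,
                "batch_size": batch_size,
                "num_workers": 2,
            },
        },
        "optimizer": {
            "class_path": "torch.optim.AdamW",
            "init_args": {"lr": 1e-4, "weight_decay": 1e-2},
        },
    }
```

To write it out, reuse the pattern from the cell above: `yaml.safe_dump(build_terratorch_config(DATA_ROOT), f, sort_keys=False)`.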


### 4.2 Run training

If `terratorch` is installed and on PATH, run:

```bash
terratorch fit -c sen1floods11_small.yaml
```

You should see training progress and metrics in the output, and checkpoint (`.ckpt`) files written to the trainer's log directory.

> If you get errors about model names/registries, check TerraTorch examples and adjust `model.backbone.name` / `decoder.name`.


In [None]:
# Uncomment to run (requires TerraTorch installed + working config)
# !terratorch fit -c {CONFIG_OUT}


## 5. (Optional) Evaluate and run inference

Once training finishes, you typically have a `.ckpt` file.

You can test using TerraTorch:

```bash
terratorch test -c sen1floods11_small.yaml --ckpt_path path/to/checkpoint.ckpt
```

Or export a Torch model and run inference over chips.

Below is a placeholder cell you can adapt once you know your checkpoint path.


In [None]:
# Replace with your checkpoint path if you ran training
CKPT_PATH = None  # e.g., "lightning_logs/version_0/checkpoints/epoch=1-step=....ckpt"

# Example test command:
# !terratorch test -c {CONFIG_OUT} --ckpt_path {CKPT_PATH}
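Rather than copying the checkpoint path by hand, you can pick the newest `.ckpt` file under the log directory. A small standard-library sketch; `lightning_logs` is Lightning's usual default location, but your config may set a different output directory:

```python
import glob
import os
from typing import Optional


def find_latest_checkpoint(root: str = ".") -> Optional[str]:
    """Return the most recently modified *.ckpt file anywhere under `root`, or None."""
    ckpts = glob.glob(os.path.join(root, "**", "*.ckpt"), recursive=True)
    if not ckpts:
        return None
    return max(ckpts, key=os.path.getmtime)
```

Usage: `CKPT_PATH = find_latest_checkpoint("lightning_logs")`, then run the test command above.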


## 6. PEFT (LoRA / VPT) with TerraTorch

Under tighter resource constraints, you often want **parameter-efficient** adaptation:
- LoRA (train low-rank adapters)
- Visual Prompt Tuning (train small prompt tokens)
- ViT-Adapters (train small adapter modules)
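To make "train low-rank adapters" concrete, here is a NumPy sketch of the LoRA idea applied to a single linear layer: the pretrained weight `W` stays frozen and only two small matrices `A` and `B` would be trained, with the update scaled by `alpha / r`. This illustrates the technique only; it is not the peft-geofm or TerraTorch implementation, and it has no training loop.

```python
import numpy as np


class LoRALinear:
    """y = x W^T + (alpha / r) * x A^T B^T, with W frozen and A, B trainable."""

    def __init__(self, weight: np.ndarray, rank: int = 4, alpha: float = 8.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        out_dim, in_dim = weight.shape
        self.weight = weight                                 # frozen pretrained weight (out x in)
        self.A = rng.normal(0.0, 0.01, size=(rank, in_dim))  # trainable, small random init
        self.B = np.zeros((out_dim, rank))                   # trainable, zero init
        self.scale = alpha / rank

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Because B starts at zero, the adapted layer initially equals the frozen one
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self) -> int:
        return self.A.size + self.B.size

    def frozen_params(self) -> int:
        return self.weight.size
```

For a 1024 x 1024 projection (a size chosen for illustration) with rank 4, LoRA trains 4 x 1024 + 1024 x 4 = 8,192 values against 1,048,576 frozen ones, about 0.78%.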

A good “known working” starting point is the **peft-geofm** repo, which integrates these methods into TerraTorch and provides configs:

- https://github.com/IBM/peft-geofm

### Suggested workshop activity
1. Clone `peft-geofm`
2. Choose a config for Prithvi-EO-2.0 + a downstream task (e.g., burn scars or floods if available)
3. Run `terratorch fit -c <config.yaml>`

> Config keys differ by model/task; the repo’s configs are the ground truth.


In [None]:
# Suggested (optional) workflow:
# !git clone https://github.com/IBM/peft-geofm
# %cd peft-geofm
# # Explore available configs:
# !find configs -maxdepth 4 -type f -name "*.yaml" | head -n 30
#
# # Then run something like:
# # !terratorch fit -c configs/peft/<...>/lora.yaml


## 7. Compare: full fine-tune vs PEFT (what to measure)

Have students report:

- **Trainable parameter count**
- GPU memory / batch size feasible
- Training time per epoch
- Validation IoU / F1 (segmentation)
- Robustness across flood events (generalisation)

Even if you can't run a full benchmark in class, logging these values helps them “think like practitioners”.
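TerraTorch logs its own validation metrics during training, but for spot checks on a few chips (for example, with the fine-tuned model in Fallback A) a self-contained IoU/F1 for the water class is handy. A minimal NumPy sketch, assuming binary predictions and the Sen1Floods11 convention that label `-1` means no data:

```python
import numpy as np


def binary_iou_f1(pred: np.ndarray, target: np.ndarray, ignore_index: int = -1) -> dict:
    """IoU and F1 for the positive (water) class, skipping pixels where target == ignore_index."""
    valid = target != ignore_index
    p = (pred == 1) & valid
    t = (target == 1) & valid
    tp = int(np.sum(p & t))
    fp = int(np.sum(p & ~t))
    fn = int(np.sum(~p & t))
    denom = tp + fp + fn
    iou = tp / denom if denom else float("nan")
    f1 = 2 * tp / (2 * tp + fp + fn) if denom else float("nan")
    return {"iou": iou, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
```

Averaging per-chip IoU and computing IoU from `tp`/`fp`/`fn` summed over all chips give different numbers; state which one students report.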


## 8. Connection to embeddings and latent spaces

How this connects to Practical 2 (embeddings):

- The **backbone encoder** produces a representation (latent embedding) per patch/token.
- Decoders (UNet/FCN) turn these into per-pixel predictions.
- PEFT methods (LoRA/VPT) nudge the latent space **just enough** to match the new task, without rewriting all weights.
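The step from "one embedding per patch" to "a feature map a decoder can upsample" is just a reshape. A NumPy sketch, assuming row-major patch order and no class token (drop any class token first); the official TerraTorch configs do this inside the model via neck components. As an illustration, a 224 x 224 input with 16-pixel patches gives a 14 x 14 grid of 196 tokens.

```python
import numpy as np


def tokens_to_grid(tokens: np.ndarray, height: int, width: int, patch: int) -> np.ndarray:
    """Reshape ViT patch tokens of shape (N, D) into a feature map of shape (D, H/p, W/p)."""
    gh, gw = height // patch, width // patch
    n, d = tokens.shape
    if n != gh * gw:
        raise ValueError(f"expected {gh * gw} tokens, got {n}")
    # (N, D) -> (gh, gw, D) -> (D, gh, gw): channels-first, like a CNN feature map
    return tokens.reshape(gh, gw, d).transpose(2, 0, 1)
```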

A useful mental model:

> Pretraining learns a *general-purpose latent space*; fine-tuning learns a *task-aligned readout* (and sometimes small latent tweaks).


## 9. Fallback plans (low resources / low connectivity)

If GPUs, storage, or bandwidth are limited:

### Fallback A — Use the already fine-tuned flood model
Use the reference model (Prithvi-EO-2.0-300M-TL-Sen1Floods11) and run inference only:
- https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11

Students can still:
- visualise predictions
- compute IoU on a few labelled chips
- do error analysis and qualitative checks

### Fallback B — Use a tiny/small backbone
IBM has released “tiny/small” variants of EO models (Prithvi/TerraMind) aimed at running on consumer devices.
(See IBM Research blog posts and the TerraMind repo.)

### Fallback C — Freeze the backbone, train decoder only
In config:
- set `backbone.pretrained = True`
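- freeze backbone parameters (TerraTorch's `SemanticSegmentationTask` exposes a `freeze_backbone` option; check the name in your version)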
- train only the segmentation head/decoder

### Fallback D — Train on *embeddings*, not pixels
If you can’t run segmentation:
- extract patch embeddings for chips (Practical 2)
- train a lightweight classifier/regressor on top
- use nearest-neighbour retrieval for weak localisation / similarity search
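Both ideas can be sketched with NumPy alone, assuming you already have one embedding vector per chip (for example, mean-pooled patch embeddings from Practical 2) and a hypothetical per-chip label such as "contains flood water" (1) or not (0). Cosine similarity drives retrieval, and a k-nearest-neighbour vote serves as the simplest possible classifier on top.

```python
import numpy as np


def cosine_top_k(query: np.ndarray, bank: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k rows of `bank` most cosine-similar to `query`, best first."""
    q = query / (np.linalg.norm(query) + 1e-12)
    b = bank / (np.linalg.norm(bank, axis=1, keepdims=True) + 1e-12)
    sims = b @ q
    return np.argsort(-sims)[:k]


def knn_predict(query: np.ndarray, bank: np.ndarray, labels: np.ndarray, k: int = 5) -> int:
    """Majority vote over the labels of the k nearest embeddings."""
    idx = cosine_top_k(query, bank, k)
    values, counts = np.unique(labels[idx], return_counts=True)
    return int(values[np.argmax(counts)])
```

A logistic regression on the same embeddings is a natural next step once this baseline works.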

### Fallback E — Reduce everything
- fewer bands
- fewer chips
- smaller patches
- fewer epochs
- mixed precision if available
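Fewer bands and smaller patches can be applied directly to the arrays. A sketch, assuming chips shaped `(bands, H, W)`; the band indices in the example assume the standard 13-band Sentinel-2 order and pick Blue, Green, Red, narrow NIR (B8A), SWIR1 and SWIR2.

```python
import numpy as np


def shrink_chip(chip: np.ndarray, band_idx, patch: int) -> np.ndarray:
    """Keep a subset of bands and centre-crop a (bands, H, W) chip to patch x patch."""
    sub = chip[list(band_idx)]
    _, h, w = sub.shape
    if patch > h or patch > w:
        raise ValueError("patch larger than chip")
    top, left = (h - patch) // 2, (w - patch) // 2
    return sub[:, top:top + patch, left:left + patch]
```

For example, `shrink_chip(s2, [1, 2, 3, 8, 11, 12], 224)` turns a 13 x 512 x 512 chip (3,407,872 values) into 6 x 224 x 224 (301,056 values), roughly 9% of the original.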


## 10. Where this goes next

If you want a capstone:

1. Use STAC (Practical 2 Option B) to pull a real flood event AOI.
2. Extract embeddings with a GFM encoder.
3. Fine-tune with PEFT using a small set of hand labels.
4. Deploy a “flood alert” notebook pipeline that:
   - ingests new imagery
   - runs inference
   - produces a raster + quicklook map + summary metrics
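For the "summary metrics" step, a minimal sketch of turning a binary prediction into flooded area. It assumes square pixels in a projected CRS with metre units (10 m is the default here, matching Sentinel-2's 10 m bands); check `src.res` and `src.crs` on your raster, and reproject first if the pixels are in degrees.

```python
import numpy as np


def flood_summary(pred_mask: np.ndarray, pixel_size_m: float = 10.0, nodata_mask=None) -> dict:
    """Water pixel count, area (km^2) and fraction for a binary flood prediction (1 = water)."""
    if nodata_mask is None:
        valid = np.ones(pred_mask.shape, dtype=bool)
    else:
        valid = ~np.asarray(nodata_mask, dtype=bool)
    water_px = int(np.sum((pred_mask == 1) & valid))
    valid_px = int(np.sum(valid))
    pixel_area_km2 = (pixel_size_m ** 2) / 1e6
    return {
        "water_pixels": water_px,
        "water_area_km2": water_px * pixel_area_km2,
        "water_fraction": water_px / valid_px if valid_px else float("nan"),
    }
```

At 10 m resolution each pixel covers 100 m², so 1,000 water pixels correspond to 0.1 km².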
