In [None]:
import os
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import MultiTemporalCropClassificationDataModule
import warnings
warnings.filterwarnings("ignore")
import tarfile

from terratorch.registry import BACKBONE_REGISTRY, TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY

# Downloading the multi-temporal crop classification dataset

- Uncomment the cell below to download the dataset the first time you run this notebook
- Comment the cell out again after the first download

In [None]:
# One-off dataset download (kept commented out; uncomment for the first run).
# Adjust `dataset_root` to a folder on your machine before running.
# The archive is only downloaded when it is missing, and only extracted when
# the target directory does not exist yet, so re-running the cell is cheap.
# dataset_root = "/Users/samuel.omole/Desktop/repos/geofm_datasets"
# url = "https://drive.google.com/uc?id=1SycflNslu47yfMg2i_z8FqYkhZQv7JQM"
# archive = dataset_root + "/multi-temporal-crop-classification-subset.tar.gz"
# extract_dir = dataset_root + "/multi-temporal-crop-classification-subset"

# # download if missing
# if not os.path.isfile(archive):
#     gdown.download(url, output=archive, quiet=False)

# # extract if not already extracted
# if not os.path.isdir(extract_dir):
#     with tarfile.open(archive, "r:gz") as tar:
#         tar.extractall(path=dataset_root)

# Preparing the dataset with a TerraTorch datamodule
- We use the `MultiTemporalCropClassificationDataModule` provided by TerraTorch

## Setting up the datamodule

In [None]:

# Folder that contains the extracted dataset directory. Set the
# GEOFM_DATASET_ROOT environment variable to point somewhere else without
# editing the notebook; the default keeps the original location.
geofm_root = Path(
    os.environ.get(
        "GEOFM_DATASET_ROOT",
        "/Users/samuel.omole/Desktop/repos/geofm_datasets",
    )
)
dataset_path = geofm_root / "multi-temporal-crop-classification-subset"

# Datamodule dedicated to the multi-temporal crop classification dataset
datamodule = MultiTemporalCropClassificationDataModule(
    batch_size=8,
    num_workers=2,
    data_root=dataset_path,
    train_transform=[
        # Albumentations works on channel-stacked images, so the temporal axis
        # is folded into the channels first and restored at the end.
        terratorch.datasets.transforms.FlattenTemporalIntoChannels(),  # Required for temporal data
        albumentations.D4(),  # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
        terratorch.datasets.transforms.UnflattenTemporalFromChannels(n_timesteps=3),  # There are 3 timestamps in the dataset
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,
    expand_temporal_dimension=True,
    use_metadata=False,  # The crop dataset has metadata for location and time; not used here
    reduce_zero_label=True,
)

# Setup train and val datasets
datamodule.setup("fit")

In [None]:
# Inspect the dataset means and standard deviations exposed by the datamodule
datamodule.means, datamodule.stds

In [None]:
# Keep a handle on the training split and report how many samples it holds
train_dataset = datamodule.train_dataset
len(train_dataset)

In [None]:
# List the band names the dataset reports as available
train_dataset.all_band_names

In [None]:
# List the class names defined by the dataset
train_dataset.class_names

## Plotting some training and validation examples

In [None]:
# Visualise the first five samples of the training split
for sample_idx in range(5):
    train_sample = train_dataset[sample_idx]
    train_dataset.plot(train_sample)

In [None]:
# Keep a handle on the validation split and report how many samples it holds
val_dataset = datamodule.val_dataset
len(val_dataset)

In [None]:
# Visualise the first five samples of the validation split as well
for sample_idx in range(5):
    val_sample = val_dataset[sample_idx]
    val_dataset.plot(val_sample)

In [None]:
# Set up the test stage and report the size of the test split.
# `test_dataset` is reused later when plotting test set predictions.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Building the TerraMind model and fine-tuning with PyTorch Lightning

## Setting up the trainer

In [None]:
# Fix all random seeds so that runs are reproducible
pl.seed_everything(0)

# Keep only the checkpoint with the lowest validation loss (weights only)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="../output_multicrop/terramind_small_multicrop/checkpoints/", # Change as appropriate
    mode="min",
    monitor="val/loss",
    filename="best-loss",
    save_weights_only=True,
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="cpu", # Use gpu if available
    strategy="auto",
    devices=1,
    # fp16 autocast ("16-mixed") is not supported on CPU: Lightning replaces it
    # with bf16 and only emits a warning, which this notebook silences. Request
    # bf16 explicitly here; switch to "16-mixed" when training on a GPU.
    precision="bf16-mixed",
    num_nodes=1,
    logger=True,
    max_epochs=50, # Change as appropriate
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="../output_multicrop/terramind_small_multicrop/", # Change as appropriate
)

# Building the Model
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # TerraMind backbone
        "backbone": "terramind_v1_small",
        "backbone_pretrained": True,
        "backbone_modalities": ["S2L2A"],
        "backbone_bands": {"S2L2A": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]},

        # Apply temporal wrapper (params are passed with prefix backbone_temporal)
        "backbone_use_temporal": True,
        "backbone_temporal_pooling": "concat",  # Defaults to "mean" which also supports flexible input lengths
        "backbone_temporal_n_timestamps": 3,  # Required for pooling = concat

        # Necks
        "necks": [
            {
                "name": "SelectIndices",
                "indices": [2, 5, 8, 11] # indices for terramind_v1_tiny, small, and base
                # "indices": [5, 11, 17, 23] # indices for terramind_v1_large
            },
            {
                "name": "ReshapeTokensToImage",
                "remove_cls_token": False,
            },
            {"name": "LearnedInterpolateToPyramidal"},
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],

        # Head
        "head_dropout": 0.1,
        "num_classes": 13,  # Must match the number of entries in class_names below
    },

    loss="ce",
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,  # NOTE(review): presumably the value produced for the dropped zero label (reduce_zero_label) - confirm
    freeze_backbone=False,  # Full fine-tuning; set to True to freeze the backbone and speed up training
    freeze_decoder=False,
    plot_on_val=True,
    class_names=["Natural Vegetation", "Forest", "Corn", "Soybeans", "Wetlands", "Developed / Barren", "Open Water", "Winter Wheat", "Alfalfa", "Fallow / Idle Cropland", "Cotton", "Sorghum", "Other"],
)

In [None]:
# Training
trainer.fit(model, datamodule=datamodule)

## Evaluate the model performance on the test dataset

In [None]:
# Prefer the checkpoint path recorded by the callback during this session.
# When a "best-loss.ckpt" from an earlier run already exists, Lightning writes
# the new checkpoint under a versioned name (e.g. "best-loss-v1.ckpt"), so a
# hard-coded path could silently evaluate stale weights. Fall back to the
# default location when no training was run in this kernel.
best_ckpt_path = (
    checkpoint_callback.best_model_path
    or "../output_multicrop/terramind_small_multicrop/checkpoints/best-loss.ckpt"
)
trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

## Predicting on some example test samples

In [None]:
# Reload the best checkpoint so that predictions come from the selected weights
model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
    best_ckpt_path,
    model_factory=model.hparams.model_factory,
    model_args=model.hparams.model_args,
)
# A freshly loaded module is not guaranteed to be in inference mode; switch to
# eval so dropout (head_dropout) is disabled and predictions are deterministic.
model.eval()

# Take a single batch from the test split
test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))
images = batch["image"].to(model.device)
masks = batch["mask"].numpy()  # ground truth as NumPy, kept for interactive inspection

# Forward pass without gradient tracking
with torch.no_grad():
    outputs = model(images)

# Class index with the highest score per pixel (class axis is dim=1)
preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

# Plot up to five examples, never more than the batch actually contains
n_examples = min(5, len(preds))
for i in range(n_examples):
    sample = {
        "image": batch["image"][i].cpu(),
        "mask": batch["mask"][i],
        "prediction": preds[i],
    }
    test_dataset.plot(sample)
    plt.show()