# PART 02: MODEL FINETUNING FOR LAND USE CLASSIFICATION OF SPOT IMAGES

## Introduction
This notebook guides you through fine-tuning a pretrained **Temporal Vision Transformer (TemporalViT)** model to classify land use from multi-temporal SPOT satellite imagery, focusing on a dataset from Khon Kaen, Thailand. Fine-tuning leverages pretrained weights (Prithvi-100M) to adapt the model to our specific task, improving accuracy for land use classes (Urban, Agricultural, Forest, Water, Oil Palm, Para Rubber) using imagery from 2016, 2020, and 2022. The process includes configuring the model, customizing the backbone, defining class-specific loss functions, and optimizing training parameters.

### Why Fine-Tune a Pretrained Model Instead of Training from Scratch for Land Use Classification?
Fine-tuning a pretrained model is preferred because it:
- Avoids training from scratch, which requires massive labeled datasets and high computational power.
- Leverages pre-learned spatial and temporal representations.
- Starts with a model already familiar with remote sensing data.

In this workflow, we use the Prithvi-100M pretrained weights. This model has been trained on diverse satellite imagery, enabling it to capture general patterns that transfer effectively to our specific region (e.g., Khon Kaen) and task (land use classification).

### Benefits of Using Pretrained Models
Using a pretrained model offers several key benefits:
- **Faster Convergence**: Learns faster due to existing knowledge of patterns.
- **Reduced Overfitting**: Better generalization with fewer labeled samples.
- **Improved Accuracy**: Performs better on real-world imagery with limited training data.
- **Transferable Knowledge**: Can be fine-tuned for new regions, sensors, or land use types.

This is why our strategy emphasizes fine-tuning rather than training from scratch.

### Why Use This Model?
We chose the **Temporal Vision Transformer (TemporalViTEncoder)** because it:
- Supports multi-frame input, ideal for our three-year SPOT image stack.
- Captures temporal changes, such as seasonal or yearly land cover shifts.
- Is fully compatible with the pretrained Prithvi model.
- Handles high-dimensional multi-band data better than classical CNNs.

This module guides you through fine-tuning a pretrained TemporalViT to classify land use in a targeted region using multi-year SPOT imagery. You will learn how to configure the model, apply proper normalization, define class-specific loss functions, and optimize the training process for high performance on a real-world remote sensing classification task.

## Main Objective
Fine-tune a pretrained TemporalViT model to achieve optimal land use classification performance for multi-temporal SPOT imagery.

## Specific Objectives
By the end of this module, learners will be able to:
- Configure the TemporalViT model with appropriate data paths and parameters for SPOT imagery.
- Customize the TemporalViT backbone to handle multi-temporal and multi-spectral data.
- Define land use classes and a focal loss function to address class imbalance.
- Optimize training parameters and evaluate model performance using metrics like mIoU.

## Prerequisites
To successfully complete this module, learners should have:
- A basic understanding of remote sensing and deep learning concepts (e.g., neural networks, loss functions).
- Familiarity with Python and libraries like `MMSegmentation`.
- A prepared dataset from the previous notebook (`Part_01_Preprocessing_SPOT_image_and_Land_use_data_for_finetuning.ipynb`), including SPOT image patches, land use mask patches, and a stacked image.
- A Python environment with required libraries installed and access to a GPU for training.

## Required Inputs
- **Stacked SPOT Image**: A multi-band GeoTIFF (`stack.tif`) with 18 bands (6 bands × 3 years).
- **Training, Validation, and Test Datasets**: Image and mask patch pairs in `final_training_data/train`, `final_training_data/val`, and `final_training_data/test` directories, generated from the previous notebook.
- **Pretrained Model Weights**: Prithvi-100M weights located at `/prithvi/Prithvi_100M.pt`.
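
Before training, it can help to confirm that the expected folder layout is in place. The following is a minimal sketch, assuming the `images` and `masks` subfolders that the configuration later in this notebook points to. The helper name `missing_dirs` is hypothetical and not part of any library.

```python
from pathlib import Path

# Sub-directories the configuration in this notebook expects under the data root.
EXPECTED_DIRS = [
    f"{split}/{kind}"
    for split in ("train", "val", "test")
    for kind in ("images", "masks")
]


def missing_dirs(data_root):
    """Return the expected sub-directories that do not exist under data_root."""
    root = Path(data_root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    # Hypothetical location; replace with your own final_training_data path.
    print(missing_dirs("final_training_data"))
```

An empty list means all six folders were found; any entries returned are the folders that still need to be created by the preprocessing notebook.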

---------------------------------------------------------------------------------------------------------

### Set up the model configuration
In this step, learners prepare the foundational structure for training. This includes setting up paths for data, experiments, and model outputs, as well as importing necessary modules and initializing critical parameters.

Key elements include:

- Setting `data_root`, `project_dir`, and `work_dir`
- Defining the number of input frames (`num_frames`) and the patch size (`img_size`)
- Pointing to the pretrained model weights (`pretrained_weights_path`)
- Computing normalization values from the input image (`img_stack`)
- Selecting image bands and constructing data pipelines (`train_pipeline` and `test_pipeline`)

This step ensures that all subsequent training components know where to find the data and how to handle it.
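
The normalization values are per-band statistics taken over every pixel of the stack. As a dependency-free sketch of that reduction and of the usual `(value - mean) / std` standardization, consider the example below. The function names `band_stats` and `normalize` are hypothetical, and the tiny 2×2 bands are made-up data rather than SPOT values.

```python
from statistics import fmean, pstdev


def band_stats(stack):
    """Per-band mean and population std for a stack shaped [band][row][col]."""
    means, stds = [], []
    for band in stack:
        pixels = [value for row in band for value in row]
        means.append(fmean(pixels))
        stds.append(pstdev(pixels))  # population std (same default as NumPy's std)
    return means, stds


def normalize(stack, means, stds):
    """Standardize each band: (value - band mean) / band std."""
    return [
        [[(value - m) / s for value in row] for row in band]
        for band, m, s in zip(stack, means, stds)
    ]


# Two made-up 2x2 bands standing in for the 18-band SPOT stack.
toy_stack = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, 10.0], [20.0, 20.0]],
]
means, stds = band_stats(toy_stack)
normalized = normalize(toy_stack, means, stds)
print(means, stds)
```

In the configuration itself, the same reduction is applied to the full `stack.tif` with `rasterio` and NumPy, which yields 18 means and 18 standard deviations (one pair per band).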

### Customize the Temporal Vision Transformer (ViT) backbone
Learners can tailor the backbone of the model — the TemporalViTEncoder — to match the temporal and spectral characteristics of SPOT imagery.

Key configurable parameters include:
- `patch_size`: side length, in pixels, of the square patches each frame is split into
- `tubelet_size`: number of consecutive frames grouped into one token (1 means treating each time step separately)
- `in_chans`: number of bands per frame (e.g., 6: red, green, blue, NIR, NDVI, NDWI)
- `num_frames`: how many time steps to consider (e.g., 3 years of SPOT imagery)
- `embed_dim`, `num_layers`, and `num_heads`: ViT architectural choices that determine how wide, deep, and expressive the model is

This step is crucial for capturing temporal patterns and multi-spectral information that improve land use classification accuracy over time.
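
To see how these parameters interact, the sketch below works out the token grid implied by the values used in this notebook. It assumes the usual ViT patch-embedding arithmetic with non-overlapping patches and tubelets. `token_grid` is a hypothetical helper, not part of the model code.

```python
def token_grid(img_size, patch_size, num_frames, tubelet_size):
    """Return (temporal steps, patches per side, total tokens) for one sample."""
    if img_size % patch_size or num_frames % tubelet_size:
        raise ValueError("img_size and num_frames must divide evenly")
    steps = num_frames // tubelet_size
    side = img_size // patch_size
    return steps, side, steps * side * side


# Values from this notebook's configuration.
img_size, patch_size, num_frames, tubelet_size, embed_dim = 224, 16, 3, 1, 768

steps, side, tokens = token_grid(img_size, patch_size, num_frames, tubelet_size)
print(steps, side, tokens)     # 3 14 588
print(embed_dim * num_frames)  # 2304
```

The 14 × 14 grid corresponds to the `Hp` and `Wp` values passed to the neck in the configuration, and 768 × 3 = 2304 is the `output_embed_dim` consumed by the decode heads. Changing `patch_size` or `img_size` therefore requires updating `Hp` and `Wp` as well.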

### Define land use classes and loss functions
This step prepares the model to correctly learn from labeled land use (LU) data. Learners define which classes to predict and how to handle training imbalance using specialized loss functions.

Key components:
- `CLASSES`: the class labels (`"URB"`, `"AGR"`, `"FOR"`, `"WTR"`, `"OIL"`, `"PRB"` for urban, agricultural, forest, water, oil palm, para rubber)
- `loss_func`: typically a focal loss, which focuses learning on difficult examples and can rebalance class representation
    - `gamma`: controls how strongly easy (already well-classified) examples are down-weighted
    - `alpha`: sets a weight for each class, which matters when some classes dominate; in the configuration below all six weights are 1.0, so every class is weighted equally

Correctly configuring the loss function is critical for improving performance on underrepresented land use classes.
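
As an illustration of how `gamma` and `alpha` act, here is a minimal sketch of the textbook multi-class focal loss for a single pixel, `FL = -alpha * (1 - p)^gamma * log(p)`, where `p` is the predicted probability of the true class. This is not necessarily what the `FocalLoss` class in the training library computes internally, since implementations differ (for example, in using sigmoid rather than softmax probabilities). The function name `focal_loss` is hypothetical.

```python
import math


def focal_loss(p_true, gamma=2.0, alpha=1.0):
    """Focal loss for one pixel, given the predicted probability of its true class."""
    if not 0.0 < p_true <= 1.0:
        raise ValueError("p_true must be in (0, 1]")
    # (1 - p)^gamma shrinks the loss of confident, correct predictions.
    return -alpha * (1.0 - p_true) ** gamma * math.log(p_true)


easy = focal_loss(0.9)  # confident, correct prediction
hard = focal_loss(0.1)  # true class received a low probability
print(f"easy={easy:.5f} hard={hard:.5f}")
```

With `gamma=0` and `alpha=1` the expression reduces to ordinary cross-entropy, `-log(p)`. Raising `gamma` shrinks the contribution of the easy pixel far more than that of the hard one, which is how the loss shifts attention toward difficult examples.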

### Optimize training parameters
In this step, learners define how the training loop behaves — including optimizer, learning rate schedule, checkpointing, and evaluation.

Key components:
- `optimizer`: defines the algorithm (e.g., AdamW) used to adjust weights
    - Includes `lr`, `weight_decay`, and `betas`
- `lr_config`: configures how the learning rate changes over time (e.g., poly decay with linear warmup)
- `runner` and `max_epochs`: define the training loop type and the number of training epochs
- `checkpoint_config`: manages model saving frequency and location
- `evaluation`: defines how often and by what metric (e.g., mIoU) model performance is assessed

This step balances training efficiency, convergence, and model generalization.
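
The shape of the learning-rate schedule (polynomial decay with linear warmup) can be sketched as follows. This uses one common formulation of poly decay and linear warmup and is meant only to show the curve; the training framework's own scheduler may differ in detail. `total_iters` is a hypothetical value, since the real number of iterations depends on dataset size, batch size, and `max_epochs`.

```python
def poly_lr(it, total_iters, base_lr=1e-6, power=1.0, min_lr=0.0,
            warmup_iters=500, warmup_ratio=1e-6):
    """Learning rate at iteration `it` for poly decay with linear warmup."""
    # Polynomial decay from base_lr towards min_lr over the whole run.
    decay = (1.0 - it / total_iters) ** power
    lr = (base_lr - min_lr) * decay + min_lr
    if it < warmup_iters:
        # Linear warmup: ramp from warmup_ratio * lr up to the full lr.
        k = (1.0 - it / warmup_iters) * (1.0 - warmup_ratio)
        lr *= 1.0 - k
    return lr


total_iters = 10_000  # hypothetical run length
for it in (0, 250, 500, 5_000, 9_999):
    print(it, poly_lr(it, total_iters))
```

With `power=1.0` the decay after warmup is a straight line from the base rate down to `min_lr`, so the highest learning rate is reached right at the end of the warmup phase.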

In [5]:
%%writefile finetuning_config_spot.py

import os, rasterio

experiment = 'exp01' #"<experiment name>"
data_root = '/home/jovyan/shared/Dan/THA Space-AI training materials/Experiment_SPOT/learning_notebooks/Khon_Kaen/final_training_data' #<path to data root>
project_dir = 'spot_a_training' #"<project directory name>"
work_dir = os.path.join(project_dir, experiment)
save_path = work_dir
img_stack = '/home/jovyan/shared/Dan/THA Space-AI training materials/Experiment_SPOT/learning_notebooks/Khon_Kaen/stack.tif'

dist_params = dict(backend="nccl")
log_level = "INFO"
load_from = None
resume_from = None
cudnn_benchmark = True
custom_imports = dict(imports=["geospatial_fm"])
# import geofm
num_frames = 3
img_size = 224 #size of the patch images used in training (224x224)
num_workers = 1 #original is 4


### MODEL PARAMETERS TO BE DEFINED BY USER
pretrained_weights_path = "/prithvi/Prithvi_100M.pt" #"<path to pretrained weights>"
num_layers = 12
patch_size = 16 #original is 16
embed_dim = 768
num_heads = 12
tubelet_size = 1
max_epochs = 50
eval_epoch_interval = 1


loss_weights_multi = [ #0.0,
                      1.0,
                      1.0,
                      1.0,
                      1.0,
                      1.0,
                      1.0]
# Define the land-use or land-cover classes 
CLASSES = (
            #'NAN', #NAN
            "URB", #urban
            "AGR", #agricultural
            "FOR", #forest
            "WTR", #water
            "OIL", #oil palm
            "PRB") #para rubber


loss_func = dict(
    type="FocalLoss",
    gamma=2.0,  # Down-weights easy pixels so training focuses on hard examples (e.g., PRB)
    alpha=loss_weights_multi,  # Per-class weights (all 1.0 here, so classes are weighted equally)
#    type="CrossEntropyLoss",
#    use_sigmoid=False,
#    class_weight=loss_weights_multi,
#    avg_non_ignore=True,
)

output_embed_dim = embed_dim * num_frames



dataset_type = "GeospatialDataset"

## The 18 means and stds (6 bands x 3 years) are computed from the Khon Kaen stack; point img_stack to a different image if the model is applied to another region.


img_norm_cfg = dict(
    means=rasterio.open(img_stack).read().mean(axis=(1,2)).tolist(),
    stds=rasterio.open(img_stack).read().std(axis=(1,2)).tolist(),
)
          
bands = [1, 2, 3, 4, 5, 6]

tile_size = 224
orig_nsize = 512
crop_size = (tile_size, tile_size)
train_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=True),
    dict(type="LoadGeospatialAnnotations", reduce_zero_label=True),
    dict(type="RandomFlip",  prob=0.5),
    dict(type="RandomRotate", prob=0.5, degree=10),
    #dict(type="Resize", img_scale=(512, 512), ratio_range=(0.8, 1.2)),
    #dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(type="TorchRandomCrop", crop_size=crop_size),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, tile_size, tile_size),
    ),
    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]

test_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=True),
    dict(type="ToTensor", keys=["img"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, -1, -1),
        look_up=dict({"2": 1, "3": 2}),
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            "img_norm_cfg",
        ],
    ),
]


dataset = "GeospatialDataset"
data = dict(
    samples_per_gpu=1, #original is 8, decrease this parameter if the memory consumption is too much
    workers_per_gpu=1, #original is 2, decrease this parameter if the memory consumption is too much
    train=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir='train/images',
        ann_dir='train/masks',
        pipeline=train_pipeline,
        img_suffix=".tif",
        seg_map_suffix=".tif",
    ),
    val=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir='val/images',
        ann_dir='val/masks',
        pipeline=test_pipeline,
        img_suffix=".tif",
        seg_map_suffix=".tif",
    ),
    test=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir='test/images',
        ann_dir='test/masks',
        pipeline=test_pipeline,
        img_suffix=".tif",
        seg_map_suffix=".tif",
    ),
)

optimizer = dict(type="AdamW", lr=1e-6, betas=(0.9, 0.999), weight_decay=0.05)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy="poly",
    warmup="linear",
    warmup_iters=500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False,
)

log_config = dict(
    interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
)

checkpoint_config = dict(by_epoch=True, interval=5, out_dir=save_path)

evaluation = dict(
    interval=eval_epoch_interval,
    metric="mIoU",
    pre_eval=True,
    save_best="mIoU",
    by_epoch=True,
)
reduce_train_set = dict(reduce_train_set=False)
reduce_factor = dict(reduce_factor=1)
runner = dict(type="EpochBasedRunner", max_epochs=max_epochs)
workflow = [("train", 1)]
norm_cfg = dict(type="BN", requires_grad=True)

model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=num_layers,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False,
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=embed_dim * num_frames,
        output_embed_dim=output_embed_dim,
        drop_cls_token=True,
        Hp=img_size // patch_size,  # 224 // 16 = 14 patches per side
        Wp=img_size // patch_size,
    ),
    decode_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        #ignore_index=0, # to remove pixels with 0 values due to LU Gaps
        loss_decode=loss_func,
    ),
    auxiliary_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        #ignore_index=0, # to remove pixels with 0 values due to LU Gaps
        loss_decode=loss_func,
    ),
    train_cfg=dict(),
    test_cfg=dict(
        mode="slide",
        stride=(int(tile_size / 2), int(tile_size / 2)),
        crop_size=(tile_size, tile_size),
    ),
)

auto_resume = False

Overwriting finetuning_config_spot.py


In [None]:
!mim train mmsegmentation finetuning_config_spot.py

Training command is /opt/conda/bin/python3.10 /opt/conda/lib/python3.10/site-packages/mmseg/.mim/tools/train.py finetuning_config_spot.py --launcher none --gpus 1. 
2025-07-06 15:09:22,232 - mmseg - INFO - Multi-processing start method is `None`
2025-07-06 15:09:22,232 - mmseg - INFO - OpenCV num_threads is `8
fatal: not a git repository (or any parent up to mount point /home/jovyan)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
2025-07-06 15:09:22,405 - mmseg - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.10.11 | packaged by conda-forge | (main, May 10 2023, 18:58:44) [GCC 11.3.0]
CUDA available: True
GPU 0: NVIDIA GeForce GTX TITAN X
CUDA_HOME: /usr
NVCC: Cuda compilation tools, release 11.5, V11.5.119
GCC: gcc (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
PyTorch: 1.11.0+cu115
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library