# Setup
1. In Colab: go to "Runtime" -> "Change runtime type" -> select "T4 GPU".
2. Install TerraTorch

Install the basic dependencies in the environment (the pip output is written to `install.log`).

In [2]:
!pip install terratorch==1.0 gdown tensorboard &> install.log
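Because the pip output is redirected to `install.log`, a failed installation is easy to miss. Below is a minimal check that uses only the standard library's `importlib.metadata`. `installed_version` is a hypothetical helper, not part of TerraTorch.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# The packages installed by the pip cell above
for pkg in ["terratorch", "gdown", "tensorboard"]:
    print(f"{pkg}: {installed_version(pkg) or 'NOT INSTALLED (check install.log)'}")
```

Given the `terratorch==1.0` pin, the reported TerraTorch version should match that pin.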

In [7]:
import os
import sys
import zipfile
import warnings
from pathlib import Path

import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

# Only needed when the notebook server sets NB_PREFIX and proxies TensorBoard (not in Colab)
#os.environ["TENSORBOARD_PROXY_URL"]= os.environ["NB_PREFIX"]+"/proxy/6006/"

# Silence library warnings to keep the notebook output readable
warnings.filterwarnings('ignore')

Checkpointing is the method used to periodically save the model to disk during training.
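In Lightning, checkpointing is usually configured with the `ModelCheckpoint` callback, whose `monitor`, `mode` and `save_top_k` arguments decide which checkpoints are kept. As a library-free illustration of that "keep the k best checkpoints" policy, here is a toy sketch. `TopKCheckpointer` and the file-name pattern are hypothetical, and nothing is actually written to disk.

```python
import heapq
from typing import List, Tuple


class TopKCheckpointer:
    """Toy 'keep the k best checkpoints' policy, where a lower monitored metric is better."""

    def __init__(self, k: int):
        self.k = k
        # heapq is a min-heap; storing the negated loss keeps the WORST kept checkpoint on top
        self._heap: List[Tuple[float, str]] = []

    def update(self, epoch: int, val_loss: float) -> bool:
        """Return True if this epoch's checkpoint would be kept."""
        path = f"epoch={epoch}-val_loss={val_loss:.4f}.ckpt"  # hypothetical file name
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-val_loss, path))
            return True
        worst_loss = -self._heap[0][0]
        if val_loss < worst_loss:
            # A real callback would also delete the evicted file from disk here
            heapq.heapreplace(self._heap, (-val_loss, path))
            return True
        return False

    def kept(self) -> List[str]:
        """Paths of the kept checkpoints, best (lowest loss) first."""
        return [path for _, path in sorted(self._heap, reverse=True)]


ckpt = TopKCheckpointer(k=2)
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    ckpt.update(epoch, loss)
print(ckpt.kept())  # ['epoch=3-val_loss=0.5000.ckpt', 'epoch=1-val_loss=0.7000.ckpt']
```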

The object we have called `model` is in fact a combination of a task (segmentation, regression, classification, ...), a neural network architecture, and an optimizer.
In TerraTorch, the neural network is a combination of:
* **backbone**: the encoder of the pretrained model.
* **neck**: an intermediary network that adapts the output of the backbone to the input expected by the decoder.
* **decoder**: a network we add to convert the backbone embeddings into outputs suited to the target task.
* **head**: the last, specialized stage; a small task-specific network that maps the decoder outputs to the predicted targets.
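To make the data flow concrete, here is a toy sketch that chains the four parts in order and tracks only tensor shapes. The stage functions are hypothetical stand-ins, not TerraTorch classes. Patch size 16 and width 1024 match the `Conv3d(6, 1024, kernel_size=(1, 16, 16), ...)` patch embedding of `prithvi_eo_v2_300` in the model printout. The 224 x 224 input, the 64-channel decoder output and the single-channel regression output are assumptions. The sketch also ignores details such as multi-layer backbone outputs and any class token.

```python
import math
from typing import Callable, Sequence, Tuple

Shape = Tuple[int, ...]
Stage = Callable[[Shape], Shape]


def build_model(backbone: Stage, necks: Sequence[Stage], decoder: Stage, head: Stage) -> Stage:
    """Chain the parts in order: backbone -> necks -> decoder -> head."""
    stages = [backbone, *necks, decoder, head]

    def forward(shape: Shape) -> Shape:
        for stage in stages:
            shape = stage(shape)
        return shape

    return forward


PATCH, EMBED = 16, 1024  # ViT patch size and embedding width, as in prithvi_eo_v2_300


def vit_backbone(s: Shape) -> Shape:
    _, h, w = s  # (channels, height, width)
    return ((h // PATCH) * (w // PATCH), EMBED)  # one token per 16x16 patch


def tokens_to_image(s: Shape) -> Shape:
    n, d = s
    side = math.isqrt(n)  # assumes a square patch grid
    return (d, side, side)


def toy_decoder(s: Shape) -> Shape:
    _, h, w = s
    return (64, h * PATCH, w * PATCH)  # hypothetical: upsample back to the input resolution


def regression_head(s: Shape) -> Shape:
    _, h, w = s
    return (1, h, w)  # assumed: one regressed value per pixel


toy_model = build_model(vit_backbone, [tokens_to_image], toy_decoder, regression_head)
print(toy_model((6, 224, 224)))  # (1, 224, 224)
```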

Experiment with different configurations to see how the resulting object changes.

In [8]:
# Model
model = terratorch.tasks.PixelwiseRegressionTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_eo_v2_300", # Options: prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl, prithvi_eo_v2_600, prithvi_eo_v2_600_tl
        "backbone_pretrained": True,
        "backbone_num_frames": 1, # 1 is the default value
        # "backbone_img_size": 224,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        # "backbone_coords_encoding": [], # use ["time", "location"] for time and location metadata

        # Necks
        "necks": [
            {
                "name": "SelectIndices",
                # "indices": [2, 5, 8, 11] # indices for prithvi_eo_v1_100
                "indices": [5, 11, 17, 23] # indices for prithvi_eo_v2_300
                # "indices": [7, 15, 23, 31] # indices for prithvi_eo_v2_600
            },
            {"name": "ReshapeTokensToImage",},
            {"name": "LearnedInterpolateToPyramidal"}
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        # "head_dropout": 0.16194593880230534,
        # "head_final_act": torch.nn.ReLU,
        # "head_learned_upscale_layers": 2
    },

    loss="rmse",
    optimizer="AdamW",
    lr=1e-3,
    ignore_index=-1,
    freeze_backbone=True, # Only to speed up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
    # class_names=['no burned', 'burned']  # optionally define class names
)
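The commented-out `indices` options follow a regular pattern. For a backbone with `depth` transformer blocks, each option picks the last block of every quarter, i.e. `depth / 4 * (k + 1) - 1` for `k = 0..3`. The depths used here (12, 24, 32) are inferred from these indices, and 24 matches the `(0-23): 24 x Block` line in the model printout. The sketch below reproduces them. `evenly_spaced_indices` is a hypothetical helper, not a TerraTorch function.

```python
from typing import List


def evenly_spaced_indices(depth: int, n_outputs: int = 4) -> List[int]:
    """Index of the last block in each of `n_outputs` equal groups of transformer blocks."""
    if depth % n_outputs:
        raise ValueError("depth must be divisible by n_outputs")
    step = depth // n_outputs
    return [step * (k + 1) - 1 for k in range(n_outputs)]


# Block counts implied by the commented indices in the cell above
for name, depth in [("prithvi_eo_v1_100", 12), ("prithvi_eo_v2_300", 24), ("prithvi_eo_v2_600", 32)]:
    print(name, evenly_spaced_indices(depth))
```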

An overview of the model we just instantiated.

In [9]:
model

PixelwiseRegressionTask(
  (model): PixelWiseModel(
    (encoder): PrithviViT(
      (patch_embed): PatchEmbed(
        (proj): Conv3d(6, 1024, kernel_size=(1, 16, 16), stride=(1, 16, 16))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x Block(
          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate='none')
   

Now let's do the same for another model, for example one from the TorchGeo model gallery. To find a model, you can read the TerraTorch documentation on backbones or the TorchGeo documentation on models. If you just want to explore the models registered in TerraTorch interactively, you can use Python from the command line or a script, as shown below.

We will work with a registry object, which behaves much like a Python dictionary.
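As a rough illustration of why `len(...)`, iteration, and `in` work on the registry, here is a toy registry built on `collections.abc.Mapping`. It is hypothetical and only mimics the behaviour seen in this notebook. Keys are exposed with a source prefix such as `timm_` or `terratorch_`, and lookups also accept the name without its prefix. TerraTorch's actual registry class may differ in its details.

```python
from collections.abc import Mapping
from typing import Callable, Dict, Iterator, Tuple


class ToyMultiSourceRegistry(Mapping):
    """Hypothetical stand-in for a registry that aggregates several model sources."""

    def __init__(self, sources: Dict[str, Dict[str, Callable]]):
        self._sources = sources  # e.g. {"timm": {...}, "terratorch": {...}}

    def _split(self, key: str) -> Tuple[str, str]:
        prefix, _, name = key.partition("_")
        return prefix, name

    def __getitem__(self, key: str) -> Callable:
        prefix, name = self._split(key)
        if prefix in self._sources and name in self._sources[prefix]:
            return self._sources[prefix][name]
        # Fall back to an unprefixed name: search every source in order
        for source in self._sources.values():
            if key in source:
                return source[key]
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        for prefix, source in self._sources.items():
            for name in source:
                yield f"{prefix}_{name}"

    def __len__(self) -> int:
        return sum(len(source) for source in self._sources.values())


toy_registry = ToyMultiSourceRegistry({
    "terratorch": {"prithvi_eo_v2_300": lambda: "vit"},
    "timm": {"prithvi_swin_L": lambda: "swin", "swin_base_patch4_window7_224": lambda: "swin"},
})
print(len(toy_registry))                        # 3
print([k for k in toy_registry if "swin" in k])  # prefixed keys
print(toy_registry["prithvi_swin_L"]())         # 'swin' (prefix omitted)
```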

In [14]:
from terratorch import BACKBONE_REGISTRY

Print all the available models (the list is very long, so the line is commented out).

In [16]:
#BACKBONE_REGISTRY

Let's see how many backbones are available in the registry.

In [17]:
len(BACKBONE_REGISTRY)

1287

Let's explore a different architecture, such as the Swin Transformer.

In [26]:
swin_list = [item for item in BACKBONE_REGISTRY if "swin" in item]
swin_list

['terratorch_satlas_swin_t_sentinel2_mi_ms',
 'terratorch_satlas_swin_t_sentinel2_mi_rgb',
 'terratorch_satlas_swin_t_sentinel2_si_ms',
 'terratorch_satlas_swin_t_sentinel2_si_rgb',
 'terratorch_satlas_swin_b_sentinel2_mi_ms',
 'terratorch_satlas_swin_b_sentinel2_mi_rgb',
 'terratorch_satlas_swin_b_sentinel2_si_ms',
 'terratorch_satlas_swin_b_sentinel2_si_rgb',
 'terratorch_satlas_swin_b_naip_mi_rgb',
 'terratorch_satlas_swin_b_naip_si_rgb',
 'terratorch_satlas_swin_b_landsat_mi_ms',
 'terratorch_satlas_swin_b_landsat_mi_rgb',
 'terratorch_satlas_swin_b_sentinel1_mi',
 'terratorch_satlas_swin_b_sentinel1_si',
 'timm_hiera_base_abswin_256',
 'timm_hiera_small_abswin_256',
 'timm_prithvi_swin_B',
 'timm_prithvi_swin_L',
 'timm_swin_base_patch4_window7_224',
 'timm_swin_base_patch4_window12_384',
 'timm_swin_large_patch4_window7_224',
 'timm_swin_large_patch4_window12_384',
 'timm_swin_s3_base_224',
 'timm_swin_s3_small_224',
 'timm_swin_s3_tiny_224',
 'timm_swin_small_patch4_window7_224'
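A plain substring test also matches names such as `timm_hiera_base_abswin_256`, where `swin` is only part of a word. A slightly stricter sketch matches the token as a whole underscore-separated word and groups the results by source prefix, returning the short names you would pass without the prefix. It runs on a plain list (here a subset of the names printed above); with TerraTorch installed you could pass `list(BACKBONE_REGISTRY)` instead. `find_backbones` is a hypothetical helper.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, List

# A subset of the names printed above
sample_names = [
    "terratorch_satlas_swin_t_sentinel2_mi_ms",
    "terratorch_satlas_swin_b_naip_si_rgb",
    "timm_hiera_base_abswin_256",
    "timm_prithvi_swin_B",
    "timm_prithvi_swin_L",
    "timm_swin_base_patch4_window7_224",
]


def find_backbones(names: Iterable[str], token: str) -> Dict[str, List[str]]:
    """Match `token` as a whole underscore-separated word (so 'abswin' does not match 'swin')."""
    pattern = re.compile(rf"(?:^|_){re.escape(token)}(?:_|$)", re.IGNORECASE)
    by_source: Dict[str, List[str]] = defaultdict(list)
    for name in names:
        if pattern.search(name):
            source, _, short = name.partition("_")  # split off the source prefix
            by_source[source].append(short)
    return dict(by_source)


print(find_backbones(sample_names, "swin"))
```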

Let's choose `prithvi_swin_L` (notice that we removed the `timm_` prefix).

In [34]:
# Model
model = terratorch.tasks.PixelwiseRegressionTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_swin_L",
        "backbone_pretrained": True,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "necks": [
            {"name": "ReshapeTokensToImage",},
            {"name": "LearnedInterpolateToPyramidal"}
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
    },

    loss="rmse",
    optimizer="AdamW",
    lr=1e-3,
    ignore_index=-1,
    freeze_backbone=True, # Only to speed up fine-tuning
    freeze_decoder=False,
    plot_on_val=True,
)

No pretrained configuration was found for the model prithvi_swin_L.
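Unlike the ViT backbone, whose tokens all sit on one 14 x 14 grid for a 224 x 224 input, a Swin Transformer is hierarchical. This sketch assumes the standard Swin design: 4 stages, each halving the spatial resolution and doubling the channels. It also uses the patch size 4 and width 192 from the `Conv2d(6, 192, kernel_size=(4, 4), stride=(4, 4))` patch embedding in this model's printout. With those inputs, it computes the per-stage feature shapes. The 224 input size is an assumption, and `swin_stage_shapes` is a hypothetical helper, not a TerraTorch function.

```python
from typing import List, Tuple


def swin_stage_shapes(img_size: int = 224, patch: int = 4, embed_dim: int = 192,
                      n_stages: int = 4) -> List[Tuple[int, int, int]]:
    """(channels, height, width) per stage: each stage halves H and W and doubles C."""
    shapes = []
    side, dim = img_size // patch, embed_dim
    for _ in range(n_stages):
        shapes.append((dim, side, side))
        side, dim = side // 2, dim * 2
    return shapes


print(swin_stage_shapes())  # [(192, 56, 56), (384, 28, 28), (768, 14, 14), (1536, 7, 7)]
```

This built-in multi-scale output is a plausible reason why the Swin configuration has no `SelectIndices` neck.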


In [35]:
model

PixelwiseRegressionTask(
  (model): PixelWiseModel(
    (encoder): TimmBackboneWrapper(
      (patch_embed): PatchEmbed(
        (drop): Dropout(p=0.0, inplace=False)
        (adap_padding): AdaptivePadding()
        (projection): Conv2d(6, 192, kernel_size=(4, 4), stride=(4, 4))
        (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      )
      (stages_0): SwinBlockSequence(
        (blocks): ModuleList(
          (0): SwinBlock(
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): ShiftWindowMSA(
              (w_msa): WindowMSA(
                (qkv): Linear(in_features=192, out_features=576, bias=True)
                (attn_drop): Dropout(p=0.0, inplace=False)
                (proj): Linear(in_features=192, out_features=192, bias=True)
                (proj_drop): Dropout(p=0.0, inplace=False)
                (softmax): Softmax(dim=-1)
              )
              (drop): DropPath(drop_prob=0.000)
            )
           