In [1]:
import torch
import timm
import terratorch # even though we don't use the import directly, we need it so that the models are available in the timm registry

# Backbone factory

Learn more about timm at https://huggingface.co/docs/timm/en/index

In [2]:
# find available prithvi models by name
print(timm.list_models("prithvi*"))
# and those with pretrained weights
print(timm.list_pretrained("prithvi*"))

['prithvi_swin_B', 'prithvi_swin_B', 'prithvi_swin_B_MP', 'prithvi_swin_L', 'prithvi_swin_L_MP', 'prithvi_vit_100', 'prithvi_vit_100_us', 'prithvi_vit_300', 'prithvi_vit_tiny']
['prithvi_swin_B', 'prithvi_swin_B', 'prithvi_swin_L', 'prithvi_vit_100', 'prithvi_vit_100_us', 'prithvi_vit_300']


In [3]:
# instantiate your desired model with features_only=True to obtain a backbone
# this defaults to the weights present in CCC.
model = timm.create_model(
    "prithvi_swin_B", num_frames=1, pretrained=True, features_only=True
)

# Rest of your PyTorch / PyTorchLightning code

## The model
The resulting model is a torch module. Because we set `features_only=True`, it is only the encoder portion.
By default, the model is instantiated with the same bands it was pretrained on, with the same order.

We can inspect both of these.

In [4]:
print(f"The model was pretrained on bands {model.pretrained_bands}.\n The model is using bands {model.model_bands}")

The model was pretrained on bands [<HLSBands.BLUE: 'BLUE'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.RED: 'RED'>, <HLSBands.NIR_NARROW: 'NIR_NARROW'>, <HLSBands.SWIR_1: 'SWIR_1'>, <HLSBands.SWIR_2: 'SWIR_2'>].
 The model is using bands [<HLSBands.BLUE: 'BLUE'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.RED: 'RED'>, <HLSBands.NIR_NARROW: 'NIR_NARROW'>, <HLSBands.SWIR_1: 'SWIR_1'>, <HLSBands.SWIR_2: 'SWIR_2'>]


The model output is a list with the output of each encoder stage. This may be different for each encoder.

In [5]:
trial_data = torch.zeros(1, 6, 224, 224) # batch_size, channels, height, width
features = model(trial_data)
for index, feature in enumerate(features):
    print(f"Feature index {index} has shape {feature.shape}")

Feature index 0 has shape torch.Size([1, 56, 56, 128])
Feature index 1 has shape torch.Size([1, 28, 28, 256])
Feature index 2 has shape torch.Size([1, 14, 14, 512])
Feature index 3 has shape torch.Size([1, 7, 7, 1024])


## Band choice
Sometimes you may wish to use a separate set of bands than was used in pretraining. This may be a different ordering, a subset, a superset, or a completely different set.

To do this, you may specify the bands you wish to train on using a mixture of integers and members of the `HLSBands` enum.

In the patch embed layer, the weights corresponding to bands that exist in the pretrained bands will be mapped to the correct order. Bands that do not exist will be randomly initialized.

**Warning:** the enum maps to integers 1 through 12. If using integers, make sure they are outside this range!

In [6]:
from terratorch.datasets import HLSBands
# lets get only the RGB bands, and put them in that order rather than BGR, and lets add an extra band not in HLSBands
bands = [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE, 14]
model = timm.create_model( # let's use a vit model this time
    "prithvi_vit_100", num_frames=1, pretrained=True, features_only=True, bands=bands
)

In [7]:
model.model_bands

[<HLSBands.RED: 'RED'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.BLUE: 'BLUE'>, 14]

In [8]:
# the model now expects 4 channels, not 6
trial_data = torch.zeros(1, 4, 224, 224) # batch_size, channels, height, width
features = model(trial_data)
for index, feature in enumerate(features):
    print(f"Feature index {index} has shape {feature.shape}")

Feature index 0 has shape torch.Size([1, 197, 768])
Feature index 1 has shape torch.Size([1, 197, 768])
Feature index 2 has shape torch.Size([1, 197, 768])
Feature index 3 has shape torch.Size([1, 197, 768])
Feature index 4 has shape torch.Size([1, 197, 768])
Feature index 5 has shape torch.Size([1, 197, 768])
Feature index 6 has shape torch.Size([1, 197, 768])
Feature index 7 has shape torch.Size([1, 197, 768])
Feature index 8 has shape torch.Size([1, 197, 768])
Feature index 9 has shape torch.Size([1, 197, 768])
Feature index 10 has shape torch.Size([1, 197, 768])
Feature index 11 has shape torch.Size([1, 197, 768])


# Model factory
The model factories let us create full models ready for specific tasks, including decoders and task specific heads.
They create normal `torch.nn.Module` s that you can use anywhere in your code.

Lets create a model for semantic segmentation with 4 classes

In [9]:
from terratorch.models import PrithviModelFactory
model_factory = PrithviModelFactory()

# Let's build a segmentation model
# Parameters prefixed with backbone_ get passed to the backbone
# Parameters prefixed with decoder_ get passed to the decoder
# Parameters prefixed with head_ get passed to the head

model = model_factory.build_model(task="segmentation",
        backbone="prithvi_vit_100",
        decoder="FCNDecoder",
        in_channels=6,
        bands=[
            HLSBands.BLUE,
            HLSBands.GREEN,
            HLSBands.RED,
            HLSBands.NIR_NARROW,
            HLSBands.SWIR_1,
            HLSBands.SWIR_2,
        ],
        num_classes=4,
        pretrained=True,
        num_frames=1,
        decoder_channels=128,
        head_dropout=0.2
    )

Their output is a `ModelOutput` object, including the main `output` and the output of any `auxiliary_heads`.

In [10]:
trial_data = torch.zeros(1, 6, 224, 224) # batch_size, channels, height, width
out = model(trial_data)
print(out.output.shape)

torch.Size([1, 4, 224, 224])


# Datamodule
You can create datamodules for training by creating your own subclasses of `torchgeo.datamodules.GeoDataModule` or `torchgeo.datamodules.NonGeoDataModule`.

Alternatively, leverage one of our generic data modules.

Datamodules package train, test and validation datasets as well as any transforms done.


In [15]:
from terratorch.datamodules import GenericNonGeoPixelwiseRegressionDataModule
batch_size = 4
num_workers = 4
train_val_test = [
    "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/train_images",
    "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/val_images",
    "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/test_images",
]

train_val_test_labels = {
    "train_label_data_root": "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/train_labels",
    "val_label_data_root": "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/val_labels",
    "test_label_data_root": "/dccstor/hhr-weather/latest_filters_all_agb_patches_tts_clipped_0_500/test_labels",
}

means = [385.88501817, 714.60615207, 658.96267376, 3314.57774238, 2238.71812558, 1250.00982518]
stds = [264.62872, 355.62848, 504.54855, 898.4953, 947.22894, 828.1297]
datamodule = GenericNonGeoPixelwiseRegressionDataModule(
    batch_size,
    num_workers,
    *train_val_test,
    means,
    stds,
    **train_val_test_labels,
    # train_transform=train_transform,
    dataset_bands=[
        -1,
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
        14,
        15,
        16,
        17,
    ],
    output_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit") 

# Lightning Trainers
At the highest level of abstraction, you can operate with task specific trainers. These encapsulate the model, loss, optimizer and any training hyperparameters.

They build on the model factory we introduced previously and are able to take any. To use a task with a model not supported by a currently existing model factory, simply create your own model factory!

Let's create a Trainer for PixelWise Regression

We also show how to use the popular CosineLrDecay scheduler into training

In [12]:
from terratorch.tasks import IBMPixelwiseRegressionTask
from terratorch.models.model import AuxiliaryHead
from torch.optim.lr_scheduler import OneCycleLR
import math

epochs = 100
lr = 1e-3
model_args = {
        "backbone":"prithvi_vit_100",
        "decoder":"FCNDecoder",
        "in_channels": 6,
        "bands": [
            HLSBands.RED,
            HLSBands.GREEN,
            HLSBands.BLUE,
            HLSBands.NIR_NARROW,
            HLSBands.SWIR_1,
            HLSBands.SWIR_2,
        ],
        "pretrained": True,
        "num_frames":1,
        "decoder_channels":128,
        "head_dropout":0.2
}

task = IBMPixelwiseRegressionTask(
    model_args,
    "PrithviModelFactory",
    loss="rmse",
    aux_loss={"fcn_aux_head": 0.4},
    lr=lr,
    ignore_index=-1,
    optimizer=torch.optim.AdamW,
    optimizer_hparams={"weight_decay": 0.05},
    scheduler=OneCycleLR,
    scheduler_hparams={
        "max_lr": lr,
        "epochs": epochs,
        "steps_per_epoch": math.ceil(len(datamodule.train_dataset) / batch_size),
        "pct_start": 0.05,
        "interval": "step",
    },
    aux_heads=[
        AuxiliaryHead( # define an auxiliary head
            "fcn_aux_head",
            "FCNDecoder",
            {"decoder_channels": 512, "decoder_in_index": 2, "decoder_num_convs": 2, "head_channel_list": [64]},
        )
    ],
)

Now we can use a Lightning Trainer to train this model on the datamodule we specified

In [13]:
import os
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer

accelerator = "gpu"
experiment = "tutorial"
default_root_dir = os.path.join("/dccstor/geofm-finetuning/carlosgomes/", "tutorial_experiments", experiment)
checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir=default_root_dir, name=experiment)

trainer = Trainer(
    # precision="16-mixed",
    accelerator=accelerator,
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=1, # train only one epoch for demo
    default_root_dir=default_root_dir,
)
trainer.fit(model=task, datamodule=datamodule)

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [16]:
trainer.test(model=task, datamodule=datamodule)

[{'test/loss': 130.076416015625,
  'test/decode_head_epoch': 95.1190185546875,
  'test/fcn_aux_head_epoch': 87.39353942871094,
  'test/MAE': 74.75745391845703,
  'test/MSE': 13011.7548828125,
  'test/RMSE': 114.06907653808594}]

# Configs
Alternatively, define all this in a config file and run it through the cli. This is the way models must be specified to be onboarded to the studio.

For example

```yaml
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger: True # will use tensorboardlogger
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <path to root dir>
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 4
    num_workers: 8
    constant_scale: 0.0001
    rgb_indices:
      - 2
      - 1
      - 0
    filter_indices:
      - 2
      - 1
      - 0
      - 3
      - 4
      - 5
    train_data_root: <path to train data root>
    val_data_root: <path to val data root>
    test_data_root: <path to test data root>
    img_grep: "*_S2GeodnHand.tif"
    label_grep: "*_LabelHand.tif"
    means:
      - 0.107582
      - 0.13471393
      - 0.12520133
      - 0.3236181
      - 0.2341743
      - 0.15878009
    stds:
      - 0.07145836
      - 0.06783548
      - 0.07323416
      - 0.09489725
      - 0.07938496
      - 0.07089546
    num_classes: 2

model:
  class_path: IBMSemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      pretrained: true
      backbone: prithvi_vit_100
      img_size: 512
      decoder_channels: 256
      in_channels: 6
      bands:
        - RED
        - GREEN
        - BLUE
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      num_classes: 2
      head_dropout: 0.1
      head_channel_list:
        - 256
    loss: ce
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: 2
          decoder_num_convs: 1
          head_channel_list:
            - 64
    aux_loss:
      aux_head: 1.0
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 6.e-5
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
```

You can train with `terratorch fit --config <path_to_config_file>`

You can test with `terratorch test --config <path_to_config_file> --ckpt_path <path_to_checkpoint_file>`

You can run inference with 
```shell
terratorch predict -c <path_to_config_file> --ckpt_path<path_to_checkpoint> --predict_output_dir <path_to_output_dir> --data.init_args.predict_data_root <path_to_input_dir> --data.init_args.predict_dataset_bands <all bands in the predicted dataset, e.g. [BLUE,GREEN,RED,NIR_NARROW,SWIR_1,SWIR_2,0]>
```

```