In [2]:
import terratorch
import lightning.pytorch as pl

pl.seed_everything(0)

# By default, TerraTorch saves the model with the best validation loss. You can overwrite this by defining a custom ModelCheckpoint, e.g., saving the model with the highest validation mIoU.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/terramind_base_sen1floods11/checkpoints/",
    mode="max",
    monitor="val/mIoU", # Variable to monitor
    filename="best-mIoU",
    save_weights_only=True,
)

# Lightning Trainer
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="auto",
    devices=1, # Deactivate multi-gpu because it often fails in notebooks
    precision="16-mixed",  # Speed up training with half precision, delete for full precision training.
    num_nodes=1,
    logger=True,  # Uses TensorBoard by default
    max_epochs=3, # For demos
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/terramind_base_sen1floods11/",
)


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
/Users/isw/Documents/Code/CF_Demo/cf_demo/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: True (mps), used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (mps), used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/Users/isw/Documents/Code/CF_Demo/cf_demo/lib/python3.13/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available

In [3]:

# Segmentation mask that build the model and handles training and validation steps.
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",  # Combines a backbone with necks, the decoder, and a head
    model_args={
        # TerraMind backbone
        "backbone": "IdentityBackbone", # large version: terramind_v1_large
        "backbone_out_channels": [768,768,768,768],
        # "backbone_pretrained": True,
        # "backbone_modalities": ["S2L1C", "S1GRD"],
        # Optionally, define the input bands. This is only needed if you select a subset of the pre-training bands, as explained above.
        # "backbone_bands": {"S1GRD": ["VV"]},

        # Necks
        "necks": [
            {
                "name": "SelectIndices",
                # "indices": [2, 5, 8, 11]
                "indices": [0, 1, 2, 3] # indices for terramind_v1_base
                # "indices": [5, 11, 17, 23] # indices for terramind_v1_large
            },
            {"name": "ReshapeTokensToImage",
             "remove_cls_token": False},  # TerraMind is trained without CLS token, which neads to be specified.
            {"name": "LearnedInterpolateToPyramidal"}  # Some decoders like UNet or UperNet expect hierarchical features. Therefore, we need to learn a upsampling for the intermediate embedding layers when using a ViT like TerraMind.
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],

        # Head
        "head_dropout": 0.1,
        "num_classes": 14,
    },

    loss="dice",  # We recommend dice for binary tasks and ce for tasks with multiple classes.
    ignore_index=255,
    optimizer="AdamW",
    lr=2e-5,  # The optimal learning rate varies between datasets, we recommend testing different once between 1e-5 and 1e-4. You can perform hyperparameter optimization using terratorch-iterate.
    # ignore_index=-1,
    freeze_backbone=True, # Only used to speed up fine-tuning in this demo, we highly recommend fine-tuning the backbone for the best performance.
    freeze_decoder=False,  # Should be false in most cases as the decoder is randomly initialized.
    plot_on_val=True,  # Plot predictions during validation steps
    class_names=["Others", "Water"]  # optionally define class names
)



[768, 768, 768, 768]
[768, 768, 768, 768]
[0, 1, 2, 3]
[768, 768, 768, 768]
[768, 768, 768, 768]


In [4]:
from terratorch.datamodules import GenericEmbeddingDataModule

datamodule = GenericEmbeddingDataModule(data_root = path/to/embeddingparquet)
datamodule.setup("fit")

[EmbeddingRowDataset] START 
[EmbeddingRowDataset] Found 1 Parquet file(s) under /Users/isw/Documents/Code/CF_Demo/output2/embeddings
[EmbeddingRowDataset] Using dtype float32
[EmbeddingRowDataset] Using columns: ['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] sentinel_stack_S2A_MSIL2A_20200912T100031_N0500_R122_T34VEH_20230311T210416_embedding.parquet: 1 row groups, 2401 rows
[EmbeddingRowDataset] 1 file(s), 2401 row(s). Columns=['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] START 
[EmbeddingRowDataset] Found 1 Parquet file(s) under /Users/isw/Documents/Code/CF_Demo/output2/embeddings
[EmbeddingRowDataset] Using dtype float32
[EmbeddingRowDataset] Using columns: ['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] sentinel_stack_S2A_MSIL2A_20200912T100031_N0500_R122_T34VEH_20230311T210416_embedding.parquet: 1 ro

In [5]:
trainer.fit(model, datamodule=datamodule)

[EmbeddingRowDataset] START 
[EmbeddingRowDataset] Found 1 Parquet file(s) under /Users/isw/Documents/Code/CF_Demo/output2/embeddings
[EmbeddingRowDataset] Using dtype float32
[EmbeddingRowDataset] Using columns: ['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] sentinel_stack_S2A_MSIL2A_20200912T100031_N0500_R122_T34VEH_20230311T210416_embedding.parquet: 1 row groups, 2401 rows
[EmbeddingRowDataset] 1 file(s), 2401 row(s). Columns=['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] START 
[EmbeddingRowDataset] Found 1 Parquet file(s) under /Users/isw/Documents/Code/CF_Demo/output2/embeddings
[EmbeddingRowDataset] Using dtype float32
[EmbeddingRowDataset] Using columns: ['embedding_layer_2', 'embedding_layer_5', 'embedding_layer_8', 'embedding_layer_11']
[EmbeddingRowDataset] sentinel_stack_S2A_MSIL2A_20200912T100031_N0500_R122_T34VEH_20230311T210416_embedding.parquet: 1 ro

INFO: 
Detected KeyboardInterrupt, attempting graceful shutdown ...
INFO:lightning.pytorch.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined