## Temporal Wrapper Example

This example uses `TemporalWrapper` with a TerraTorch backbone (here, TerraMind) to process a temporal stack of inputs.


### Load Backbone

In [49]:
from terratorch.registry import BACKBONE_REGISTRY

backbone = BACKBONE_REGISTRY.build(
    "terramind_v1_base",
    modalities=["S2L2A"],
    pretrained=True,
)

### Dummy Temporal Input Sample

In [50]:
import torch

# Shape [B, C, T, H, W]: 1 sample, 12 bands, 4 timesteps, 224 x 224 pixels
x = torch.randn(1, 12, 4, 224, 224)
x_dict = {"S2L2A": x}  # Wrap into a modality dict for TerraMind

print({k: v.shape for k, v in x_dict.items()})

{'S2L2A': torch.Size([1, 12, 4, 224, 224])}
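
In practice the timesteps often arrive as separate images rather than as one tensor. The following is a minimal sketch of assembling the `[B, C, T, H, W]` layout used above with plain PyTorch; the `frames` list and the `[B, T, C, H, W]` loader layout are hypothetical stand-ins, not part of TerraTorch.

```python
import torch

# Four hypothetical acquisitions, each [B, C, H, W] with 12 bands.
frames = [torch.randn(1, 12, 224, 224) for _ in range(4)]

# Stack along a new axis at position 2 to get [B, C, T, H, W].
x = torch.stack(frames, dim=2)
print(x.shape)

# If a data loader yields [B, T, C, H, W] instead, swap the T and C axes.
x_btchw = torch.stack(frames, dim=1)           # [B, T, C, H, W]
x_from_btchw = x_btchw.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W]
print(x_from_btchw.shape)
```

Both routes give the same tensor, so either can be wrapped into the modality dict.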


### Wrap backbone with TemporalWrapper

We wrap the backbone in `TemporalWrapper` with `pooling="mean"`, which averages the latent representations across the 4 input timesteps after the backbone has processed them.

In [51]:
from terratorch.models.utils import TemporalWrapper

temporal_backbone = TemporalWrapper(backbone, pooling="mean")
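
To make the idea of mean pooling over time concrete, here is a minimal sketch in plain PyTorch. It is not the `TemporalWrapper` implementation: `ToyEncoder` and `mean_pool_over_time` are hypothetical names, and the toy encoder only mimics a backbone that returns a list of token tensors.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a backbone: [B, C, H, W] -> list of [B, N, D] tokens."""

    def __init__(self, in_channels: int, embed_dim: int = 8, patch_size: int = 16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # [B, N, D]
        return [tokens]


def mean_pool_over_time(encoder, x):
    """Encode each timestep of x ([B, C, T, H, W]) separately, then average over T."""
    per_step = [encoder(x[:, :, t]) for t in range(x.shape[2])]  # T lists of features
    # Regroup by feature level, stack along a new time axis, and average it away.
    return [torch.stack(level, dim=1).mean(dim=1) for level in zip(*per_step)]


x = torch.randn(1, 12, 4, 224, 224)
with torch.no_grad():
    pooled = mean_pool_over_time(ToyEncoder(in_channels=12), x)
print(len(pooled), pooled[0].shape)
```

The pooled features have no time axis, which is why the rest of the model can treat them like the output of a single-timestep backbone.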

### Using wrapped backbone in an EncoderDecoder Task Model

We pass the wrapped backbone as the `backbone` argument when building an EncoderDecoder model for a downstream segmentation task.

In [None]:
from terratorch.tasks import SemanticSegmentationTask

model = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": temporal_backbone,
        # Necks
        "necks": [
            {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            # TerraMind is trained without a CLS token
            {"name": "ReshapeTokensToImage", "remove_cls_token": False},
            # Some decoders, such as UNet or UperNet, expect hierarchical features
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    },
    loss="dice",
    optimizer="AdamW",
    lr=2e-5,
    ignore_index=-1,
    freeze_backbone=False,
    freeze_decoder=False,
)
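
The first two necks can be pictured with plain tensor operations. The sketch below is an illustration under assumptions, not the TerraTorch neck code: it assumes the backbone returns 12 token tensors of shape `[B, 196, 768]` (the shape printed in the forward pass of this example), and `tokens_to_image` is a hypothetical helper.

```python
import math

import torch

# Stand-in for the backbone output: 12 layers of [B, N, D] tokens.
feats = [torch.randn(1, 196, 768) for _ in range(12)]

# Like a "select indices" step: keep four intermediate layers.
selected = [feats[i] for i in (2, 5, 8, 11)]


def tokens_to_image(tokens):
    """Reshape [B, N, D] tokens into a [B, D, H, W] feature map (square grid assumed)."""
    b, n, d = tokens.shape
    side = math.isqrt(n)
    if side * side != n:
        raise ValueError("token count is not a perfect square")
    return tokens.transpose(1, 2).reshape(b, d, side, side)


maps = [tokens_to_image(t) for t in selected]
print(len(maps), maps[0].shape)
```

All four maps still share one spatial resolution here; producing different scales for a hierarchical decoder is the job of the third neck and is not sketched.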



### Sample Forward Pass

A sample forward pass shows that the 4-timestep input produces a single segmentation mask: the temporal dimension is averaged away after the backbone, before the features pass through the necks, decoder, and head of the model.

In [48]:
device = "cpu"
model = model.to(device).eval()
x_dict = {k: v.to(device) for k, v in x_dict.items()}

print("Input:", {k: v.shape for k, v in x_dict.items()})
print("Output:", model(x_dict).output.shape)
print()

# Forward through temporal encoder (backbone)
feats = model.model.encoder(x_dict)
print("Encoder output:", len(feats), "features of shape", feats[0].shape)

# Forward through necks
necks_out = model.model.neck(feats)

# Forward through decoder
decoder_out = model.model.decoder(necks_out)
print("After decoder:", decoder_out.shape)

# Forward through head
head_out = model.model.head(decoder_out)
print("Final head output:", head_out.shape)

Input: {'S2L2A': torch.Size([1, 12, 4, 224, 224])}
Output: torch.Size([1, 2, 224, 224])

Encoder output: 12 features of shape torch.Size([1, 196, 768])
After decoder: torch.Size([1, 64, 112, 112])
Final head output: torch.Size([1, 2, 112, 112])
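
The printed shapes show the head output at 112 x 112 while the full model output is 224 x 224, so the prediction is resized to the input resolution somewhere after the head. A minimal sketch of such a resize, assuming bilinear interpolation (the output above does not show which method the model actually uses):

```python
import torch
import torch.nn.functional as F

# Stand-in for the head output printed above: [B, num_classes, H/2, W/2].
head_out = torch.randn(1, 2, 112, 112)

# Upsample the logits to the input resolution (interpolation mode is an assumption).
resized = F.interpolate(head_out, size=(224, 224), mode="bilinear", align_corners=False)
print(resized.shape)
```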
