In [1]:
import os
import sys
import torch
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
import os
import zipfile
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pl.seed_everything(0)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/tutorial/checkpoints/",
    mode="min",
    monitor="val/RMSE",
    filename="best-{epoch:02d}",
)

from lightning.pytorch.callbacks import Callback


Seed set to 0


In [3]:
datamodule = terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule(
    batch_size=4,
    num_workers=2,
    num_classes=2,
    check_stackability = False,
    # Define dataset paths
    train_data_root="../../tests/resources/inputs",
    train_label_data_root="../../tests/resources/inputs",
    val_data_root="../../tests/resources/inputs",
    val_label_data_root="../../tests/resources/inputs",
    test_data_root="../../tests/resources/inputs",
    test_label_data_root="../../tests/resources/inputs",
    img_grep='*input*.tif',
    label_grep='*label*.tif',

    train_transform=[
        albumentations.D4(), # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,

    # Define standardization values
    means=[
      547.36707,
      898.5121,
      1020.9082,
      2665.5352,
      2340.584,
      1610.1407,
    ],
    stds=[
      411.4701,
      558.54065,
      815.94025,
      812.4403,
      1113.7145,
      1067.641,
    ],
    dataset_bands = [-1, "BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2", -1, -1, -1, -1],
    output_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    rgb_indices = [2, 1, 0],
    no_data_replace=0,
    no_label_replace=-1,
)

# Setup train and val datasets
datamodule.setup("fit")


In [7]:
# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1,
    #precision='bf16-mixed',
    num_nodes=1,
    logger=True,
    max_epochs=5,
    log_every_n_steps=1,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/tutorial",
    detect_anomaly=True,
)

You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Model
model = terratorch.tasks.PixelwiseRegressionTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # Backbone
        "backbone": "prithvi_eo_v2_300", 
        "backbone_pretrained": True,
        "backbone_num_frames": 1, 
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "necks": [
            {
                "name": "SelectIndices",
                "indices": [5, 11, 17, 23]
            },
            {"name": "ReshapeTokensToImage",},
            {"name": "LearnedInterpolateToPyramidal"}
        ],
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
    },

    loss="rmse",
    optimizer="AdamW",
    lr=1e-3,
    ignore_index=-1,
    freeze_backbone=True,
    freeze_decoder=False,
    plot_on_val=True,
)

INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed


In [8]:
trainer.fit(model, datamodule=datamodule)

`Trainer.fit` stopped: `max_epochs=5` reached.


In [13]:
ckpt_path = "output/tutorial/checkpoints/best-epoch=04.ckpt"

In [14]:
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)

Restoring states from the checkpoint path at output/tutorial/checkpoints/best-epoch=04.ckpt


Restored all states from the checkpoint at output/tutorial/checkpoints/best-epoch=04.ckpt
`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
# Lightning Trainer
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices=1,
    #precision='bf16-mixed',
    num_nodes=1,
    logger=True,
    max_epochs=7,
    log_every_n_steps=1,
    enable_checkpointing=True,
    #callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
    default_root_dir="output/tutorial",
    detect_anomaly=True,
)
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)

You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at output/tutorial/checkpoints/best-epoch=04.ckpt

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | PixelWiseModel   | 324 M 
1 | criterion     | RootLossWrapper  | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | ModuleList       | 0     
---------------------------------------------------
20.3 M    Trainable params
303 M     Non-trainable params
324 M     Total params
1,296.818 Total estimated model params size (MB)
Restored all states from the checkpoint at output/tutorial/checkpoints/best-epoch=04.ckpt


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00,  0.04it/s, v_num=4]
Validation: |                                                                                                                                                          | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                         | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|████████████████████████████████████████████████████████████████▌                                                                | 1/2 [00:16<00:16,  0.06it/s][A
Validation DataLoader 0: 100%|████████