# *Introduction*

# Build the environment

In [1]:
# Install TerraTorch and dependencies
!pip install terratorch
!pip install gdown tensorboard lightning



In [2]:
import os
import shutil
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import gdown
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from terratorch.datamodules import GenericNonGeoClassificationDataModule
from terratorch.tasks import ClassificationTask
import rasterio
import zipfile
import warnings
warnings.filterwarnings("ignore")

  _C._set_float32_matmul_precision(precision)


In [3]:
PREPARED_ROOT = Path("./prepared")
TRAIN_DIR = PREPARED_ROOT / "train"
VAL_DIR = PREPARED_ROOT / "val"
TEST_DIR = PREPARED_ROOT / "test"
OUTPUT_PATH = "./output/terramind_romania"

IMG_SIZE = 224
NUM_CLASSES = 10
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
MAX_EPOCHS = 30
SEED = 42

pl.seed_everything(SEED)

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


42

In [4]:
print("="*60)
print("DOWNLOADING DATASET FROM GOOGLE DRIVE")
print("="*60)

FILE_ID = "1bF7f_qgRIEnNySeHwwQLHZv_cCrqqs_H"
DATASET_ZIP = "data_Romania.zip"

if os.path.isfile(DATASET_ZIP):
    print(f"Removing existing {DATASET_ZIP}...")
    os.remove(DATASET_ZIP)

print(f"Downloading {DATASET_ZIP}...")
url = f"https://drive.google.com/uc?id={FILE_ID}&confirm=t"
gdown.download(url, DATASET_ZIP, quiet=False, fuzzy=True)
print("‚úÖ Download complete!")

# Extract dataset
if not os.path.exists("Data"):
    print("Extracting dataset...")
    with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
        zip_ref.extractall(".")
    print("‚úÖ Extraction complete!")

# Find data root
if os.path.exists("Data"):
    DATA_ROOT = Path("Data")
elif os.path.exists("data_Romania/Data"):
    DATA_ROOT = Path("data_Romania/Data")
else:
    raise ValueError("Cannot find Data folder!")

print(f"‚úÖ DATA_ROOT: {DATA_ROOT}")

DOWNLOADING DATASET FROM GOOGLE DRIVE
Downloading data_Romania.zip...


Downloading...
From (original): https://drive.google.com/uc?id=1bF7f_qgRIEnNySeHwwQLHZv_cCrqqs_H
From (redirected): https://drive.google.com/uc?id=1bF7f_qgRIEnNySeHwwQLHZv_cCrqqs_H&confirm=t&uuid=001fa1d1-a82c-4a9c-b06b-7e9a7ad4f1bf
To: /content/data_Romania.zip
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 55.5M/55.5M [00:01<00:00, 42.5MB/s]


‚úÖ Download complete!
Extracting dataset...
‚úÖ Extraction complete!
‚úÖ DATA_ROOT: Data


In [5]:
if not TRAIN_DIR.exists():
    print("\nCreating train/val/test splits...")
    TRAIN_DIR.mkdir(parents=True, exist_ok=True)
    VAL_DIR.mkdir(parents=True, exist_ok=True)
    TEST_DIR.mkdir(parents=True, exist_ok=True)

    classes = [d for d in os.listdir(DATA_ROOT)
               if os.path.isdir(DATA_ROOT / d) and d.isdigit()]

    print(f"Found classes: {sorted(classes)}")

    for cls in classes:
        cls_path = DATA_ROOT / cls
        images = [f for f in os.listdir(cls_path)
                  if f.endswith((".jpg", ".png", ".tif"))]

        if len(images) == 0:
            print(f"Warning: No images in class {cls}")
            continue

        np.random.seed(SEED)
        np.random.shuffle(images)

        n = len(images)
        train_split = int(0.7 * n)
        val_split = int(0.9 * n)

        print(f"Class {cls}: {n} images (train={train_split}, val={val_split-train_split}, test={n-val_split})")

        for i, img_name in enumerate(images):
            src = cls_path / img_name
            if i < train_split:
                dst_dir = TRAIN_DIR / cls
            elif i < val_split:
                dst_dir = VAL_DIR / cls
            else:
                dst_dir = TEST_DIR / cls

            dst_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(src, dst_dir)

    print("‚úÖ Dataset split created")
else:
    print("‚úÖ Dataset splits already exist")


Creating train/val/test splits...
Found classes: ['1', '10', '2', '3', '4', '5', '6', '7', '8', '9']
Class 7: 10 images (train=7, val=2, test=1)
Class 4: 10 images (train=7, val=2, test=1)
Class 8: 10 images (train=7, val=2, test=1)
Class 5: 10 images (train=7, val=2, test=1)
Class 10: 10 images (train=7, val=2, test=1)
Class 3: 10 images (train=7, val=2, test=1)
Class 9: 10 images (train=7, val=2, test=1)
Class 1: 10 images (train=7, val=2, test=1)
Class 6: 10 images (train=7, val=2, test=1)
Class 2: 10 images (train=7, val=2, test=1)
‚úÖ Dataset split created


In [6]:
print("\n" + "="*60)
print("COMPUTING DATASET STATISTICS (ALL BANDS)")
print("="*60)

all_pixels = []
num_bands = None

for cls in os.listdir(TRAIN_DIR):
    cls_folder = TRAIN_DIR / cls
    if not cls_folder.is_dir():
        continue

    for img_name in os.listdir(cls_folder):
        img_path = cls_folder / img_name
        if img_path.suffix.lower() != ".tif":
            continue

        try:
            with rasterio.open(img_path) as src:
                arr = src.read()  # shape: (bands, H, W)

                if arr.shape[0] == 0:
                    continue

                if num_bands is None:
                    num_bands = arr.shape[0]
                    print(f"Detected {num_bands} bands in images")

                # Transpose to (H, W, bands)
                arr = np.transpose(arr, (1, 2, 0))
                arr = arr.astype(np.float32)

                # Normalize per band
                for band_idx in range(arr.shape[2]):
                    band_max = arr[:, :, band_idx].max()
                    if band_max > 0:
                        arr[:, :, band_idx] = arr[:, :, band_idx] / band_max

                # Resize each band
                resized_bands = []
                for band_idx in range(arr.shape[2]):
                    band = Image.fromarray((arr[:, :, band_idx] * 255).astype(np.uint8))
                    band = band.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
                    resized_bands.append(np.array(band) / 255.0)

                arr_resized = np.stack(resized_bands, axis=-1)
                all_pixels.append(arr_resized)

        except Exception as e:
            print(f"Error processing {img_name}: {e}")

if len(all_pixels) == 0:
    raise RuntimeError("No valid images found!")

all_pixels_stacked = np.stack(all_pixels)
means = all_pixels_stacked.mean(axis=(0, 1, 2)).tolist()
stds = all_pixels_stacked.std(axis=(0, 1, 2)).tolist()

# Ensure no zero std (would cause NaN in normalization)
stds = [max(s, 1e-6) for s in stds]

print(f"‚úÖ Means: {means}")
print(f"‚úÖ Stds: {stds}")
print(f"‚úÖ Channels: {len(means)}")


COMPUTING DATASET STATISTICS (ALL BANDS)




Detected 7 bands in images




‚úÖ Means: [0.2052698869697226, 0.20597397329211606, 0.1809676738979829, 0.24996581779931318, 0.26791339750153054, 0.23594622670547327, 0.0]
‚úÖ Stds: [0.23311969550603803, 0.2557454820862881, 0.23774046965092685, 0.2138183312919335, 0.22443376649545302, 0.2080685734895017, 1e-06]
‚úÖ Channels: 7


In [7]:
print("\n" + "="*60)
print("CREATING DATAMODULE")
print("="*60)

datamodule = GenericNonGeoClassificationDataModule(
    train_data_root=str(TRAIN_DIR),
    val_data_root=str(VAL_DIR),
    test_data_root=str(TEST_DIR),
    batch_size=BATCH_SIZE,
    num_workers=2,
    num_classes=NUM_CLASSES,
    means=means,
    stds=stds,
)

print("‚úÖ DataModule created")



CREATING DATAMODULE
‚úÖ DataModule created


In [8]:
def get_backbone_info(backbone_name):
    """Get detailed information about a specific backbone."""
    from terratorch.models import BACKBONE_REGISTRY

    if backbone_name in BACKBONE_REGISTRY:
        print(f"\nüìã Information for: {backbone_name}")
        print(f"  Registry entry: {BACKBONE_REGISTRY[backbone_name]}")
        # Try to get more info if available
        try:
            model_class = BACKBONE_REGISTRY[backbone_name]
            print(f"  Model class: {model_class.__name__}")
        except:
            pass
    else:
        print(f"‚ùå Backbone '{backbone_name}' not found in registry")
        print(f"Available options: {list(BACKBONE_REGISTRY.keys())[:5]}...")

# Usage:
get_backbone_info("prithvi_eo_v2_300")

ImportError: cannot import name 'BACKBONE_REGISTRY' from 'terratorch.models' (/usr/local/lib/python3.12/dist-packages/terratorch/models/__init__.py)

In [None]:
print("\n" + "="*60)
print("CREATING MODEL")
print("="*60)

num_channels = len(means)
print(f"Using {num_channels} channels")

# Generate band names for all channels
all_band_names = [f'BAND_{i+1}' for i in range(num_channels)]

# FIX: Create CrossEntropyLoss with proper ignore_index
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

task = ClassificationTask(
    model_args={
        "backbone": "prithvi_eo_v2_300",
        "backbone_pretrained": True,
        "backbone_bands": all_band_names,  # All 7 bands
        "backbone_num_frames": 1,
        "decoder": "IdentityDecoder",
        "head_dropout": 0.1,
        "num_classes": NUM_CLASSES
    },
    model_factory="EncoderDecoderFactory",
    loss=criterion,
    lr=LEARNING_RATE,
    aux_loss={},
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    class_names=[str(i+1) for i in range(NUM_CLASSES)],
)

print("‚úÖ Task created")


CREATING MODEL
Using 7 channels




‚úÖ Task created


In [None]:
print("\n" + "="*60)
print("CONFIGURING TRAINER")
print("="*60)

checkpoint_dir = os.path.join(OUTPUT_PATH, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor="val/Accuracy",
    mode="max",
    filename="best-acc-{epoch:02d}-{val/Accuracy:.3f}",
    save_top_k=3,
    save_weights_only=True,
)

trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=MAX_EPOCHS,
    precision="16-mixed",
    callbacks=[
        checkpoint_callback,
        pl.callbacks.RichProgressBar(),
        pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
    ],
    default_root_dir=OUTPUT_PATH,
    log_every_n_steps=10,
    num_sanity_val_steps=2,
    gradient_clip_val=1.0,
)

print("‚úÖ Trainer configured")

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs



CONFIGURING TRAINER
‚úÖ Trainer configured


In [None]:
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

try:
    trainer.fit(task, datamodule=datamodule)
    print(f"\n‚úÖ Training complete!")
    print(f"Best model: {checkpoint_callback.best_model_path}")

    # ====== TEST ======
    print("\n" + "="*60)
    print("TESTING MODEL")
    print("="*60)

    if checkpoint_callback.best_model_path:
        trainer.test(task, datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)
    else:
        print("‚ö†Ô∏è No checkpoint saved, testing with current model")
        trainer.test(task, datamodule=datamodule)

except Exception as e:
    print(f"\n‚ùå Training failed with error: {e}")
    import traceback
    traceback.print_exc()


STARTING TRAINING


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be dec

INFO: Restoring states from the checkpoint path at /content/output/terramind_romania/checkpoints/best-acc-epoch=09-val/Accuracy=0.300.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/output/terramind_romania/checkpoints/best-acc-epoch=09-val/Accuracy=0.300.ckpt



‚úÖ Training complete!
Best model: /content/output/terramind_romania/checkpoints/best-acc-epoch=09-val/Accuracy=0.300.ckpt

TESTING MODEL


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/output/terramind_romania/checkpoints/best-acc-epoch=09-val/Accuracy=0.300.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/output/terramind_romania/checkpoints/best-acc-epoch=09-val/Accuracy=0.300.ckpt
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel tha

Output()

ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7


In [None]:
print("\n" + "="*60)
print("VISUALIZING PREDICTIONS")
print("="*60)

try:
    # Load best model
    best_model = ClassificationTask.load_from_checkpoint(
        checkpoint_callback.best_model_path,
        model_factory="EncoderDecoderFactory",
    )
    best_model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    best_model = best_model.to(device)

    # Get test batch
    test_loader = datamodule.test_dataloader()
    batch = next(iter(test_loader))

    # Extract images and labels correctly
    # The batch structure from TerraTorch is: {"image": tensor, "label": tensor}
    images = batch["image"]
    labels = batch["label"]

    # Move to device
    images = images.to(device)

    # For Prithvi, we need to wrap the tensor in a dict with the modality name
    # TerraTorch expects: {"optical": tensor} format for the model
    images_dict = {"optical": images}

    # Get predictions
    with torch.no_grad():
        outputs = best_model(images_dict)
        preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

    # Visualize
    fig, axes = plt.subplots(5, 3, figsize=(12, 15))
    class_names = [str(i+1) for i in range(NUM_CLASSES)]

    num_samples = min(5, len(labels))
    for i in range(num_samples):
        # Get image (use first 3 bands for RGB visualization)
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        if img.shape[2] > 3:
            img = img[:, :, :3]

        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"Sample {i+1}")
        axes[i, 0].axis('off')

        axes[i, 1].text(0.5, 0.5, f"True: Class {class_names[labels[i]]}",
                        ha='center', va='center', fontsize=14)
        axes[i, 1].axis('off')

        is_correct = preds[i] == labels[i]
        axes[i, 2].text(0.5, 0.5, f"Pred: Class {class_names[preds[i]]}",
                        ha='center', va='center', fontsize=14,
                        color='green' if is_correct else 'red')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'predictions.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print("‚úÖ Visualization complete!")

except Exception as e:
    print(f"‚ö†Ô∏è Visualization failed: {e}")
    import traceback
    traceback.print_exc()

print("\n‚úÖ ALL DONE!")


VISUALIZING PREDICTIONS


ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be decoded: 7
ERROR:PIL.TiffImagePlugin:More samples per pixel than can be dec

‚ö†Ô∏è Visualization failed: 'dict' object has no attribute 'shape'

‚úÖ ALL DONE!


Traceback (most recent call last):
  File "/tmp/ipython-input-328878395.py", line 33, in <cell line: 0>
    outputs = best_model(images_dict)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torchgeo/trainers/base.py", line 81, in forward
    return self.model(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/m