In [1]:
!pip uninstall -y torch torchvision torchaudio

!pip install \
  torch==2.5.1 \
  torchvision==0.20.1 \
  torchaudio==2.5.1 \
  --index-url https://download.pytorch.org/whl/cu124

!pip install -q terratorch==1.1 lightning albumentations rasterio
!pip install "protobuf<5.0.0" --force-reinstall



Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.1
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.1%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision==0.20.1
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.20.1%2Bcu124-cp311-cp311-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m2.1 MB/s[0m eta 

In [2]:
import os
import torch
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules.generic_scalar_label_data_module import GenericNonGeoClassificationDataModule
from terratorch.models import EncoderDecoderFactory
from terratorch.datasets import HLSBands
from terratorch.tasks import ClassificationTask
import rasterio
import shutil

INFO:numexpr.utils:NumExpr defaulting to 4 threads.
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
2025-11-15 13:53:29.056722: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763214809.295042      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763214809.364450      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
from pathlib import Path
import rasterio
import shutil
import random
import numpy as np

path = Path("/kaggle/input/lucas-dataset/Lucas_dataset")

sample_tif = next((path / "training" / "1").glob("*.tif"))
with rasterio.open(sample_tif) as src:
    num_bands = src.count
    height, width = src.height, src.width

train_root = path / "training"
val_root = path / "validation"

valid_classes = []
for cls_dir in sorted(train_root.iterdir(), key=lambda p: p.name):
    if cls_dir.is_dir():
        cls_name = cls_dir.name
        train_tifs = list(cls_dir.glob("*.tif"))
        val_dir = val_root / cls_name
        val_tifs = list(val_dir.glob("*.tif")) if val_dir.exists() else []
        print(cls_name, "-> train:", len(train_tifs), "val:", len(val_tifs))
        if len(train_tifs) > 0 and len(val_tifs) > 0:
            valid_classes.append(cls_name)

valid_classes = [c for c in valid_classes if c not in ["9", "10"]]
valid_classes = sorted(valid_classes, key=int)

subset_root = Path("/kaggle/working/Lucas_subset")
if subset_root.exists():
    shutil.rmtree(subset_root)

for split in ["training", "validation"]:
    src_split = path / split
    dst_split = subset_root / split
    dst_split.mkdir(parents=True, exist_ok=True)
    for cls in valid_classes:
        src_dir = src_split / cls
        dst_dir = dst_split / cls
        dst_dir.mkdir(parents=True, exist_ok=True)
        tif_files = list(src_dir.glob("*.tif"))
        for f in tif_files:
            shutil.copy(f, dst_dir / f.name)


#DATA AUGMENTATION:
#if the training set of a class is less than 75 ==> augment
#if the validation set of a class is less than 25 ==> augment
random.seed(0)
np.random.seed(0)

def augment_array(data: np.ndarray) -> np.ndarray:
    choice = random.choice(["flip_h", "flip_v", "rot90", "rot180", "rot270", "noise"])

    if choice == "flip_h":
        data_aug = np.flip(data, axis=2)  
    elif choice == "flip_v":
        data_aug = np.flip(data, axis=1)  
    elif choice == "rot90":
        data_aug = np.rot90(data, k=1, axes=(1, 2))
    elif choice == "rot180":
        data_aug = np.rot90(data, k=2, axes=(1, 2))
    elif choice == "rot270":
        data_aug = np.rot90(data, k=3, axes=(1, 2))
    else:  
        # "noise"
        if np.issubdtype(data.dtype, np.integer):
            max_val = np.iinfo(data.dtype).max
        else:
            max_val = 1.0
        std = 0.02 * max_val
        noise = np.random.normal(0, std, size=data.shape)
        data_aug = data.astype(np.float32) + noise
        data_aug = np.clip(data_aug, 0, max_val).astype(data.dtype)

    return data_aug

def augment_tif(in_path: Path, out_path: Path):
    with rasterio.open(in_path) as src:
        data = src.read()   
        profile = src.profile.copy()

    data_aug = augment_array(data)
    with rasterio.open(out_path, 'w', **profile) as dst:
        dst.write(data_aug)

target_per_split = {
    "training": 75,
    "validation": 25,
}

for split, target in target_per_split.items():
    print(f"\nBalancing split '{split}' to {target} chips per class")
    split_dir = subset_root / split

    for cls in valid_classes:
        cls_dir = split_dir / cls
        tif_files = sorted(cls_dir.glob("*.tif"))
        n = len(tif_files)

        if n > target:
            keep_files = set(random.sample(tif_files, target))
            for f in tif_files:
                if f not in keep_files:
                    f.unlink()
            print(f"Class {cls} ({split}): {n} -> {target} (undersampled)")

        elif n < target:
            print(f"Class {cls} ({split}): {n} -> {target} (augmenting)")
            original_files = tif_files.copy()
            k = 0
            while len(list(cls_dir.glob("*.tif"))) < target:
                src_path = random.choice(original_files)
                new_name = f"{src_path.stem}_aug{k}.tif"
                out_path = cls_dir / new_name
                augment_tif(src_path, out_path)
                k += 1

        else:
            print(f"Class {cls} ({split}): already {target}")


train_data_root = subset_root / "training"
val_data_root = subset_root / "validation"
test_data_root = val_data_root  

num_classes = len(valid_classes)

means = [0.0] * num_bands
stds = [1.0] * num_bands

datamodule = GenericNonGeoClassificationDataModule(
    batch_size=4,
    num_workers=2,
    train_data_root=train_data_root,
    val_data_root=val_data_root,
    test_data_root=test_data_root,
    means=means,
    stds=stds,
    num_classes=num_classes,
)

datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
print("Batch image shape:", batch["image"].shape, "label shape:", batch["label"].shape)
print("Num train samples:", len(datamodule.train_dataset))
print("Num val samples:", len(datamodule.val_dataset))

pl.seed_everything(0)

class_names = valid_classes  

model_args = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_num_frames=1,
    backbone_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    decoder="IdentityDecoder",
    num_classes=num_classes,
    head_dropout=0.1,
)

task = ClassificationTask(
    model_args=model_args,
    model_factory="EncoderDecoderFactory",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.01},
    class_names=class_names,
)

torch.cuda.empty_cache()

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=30,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    precision="16-mixed",
    accumulate_grad_batches=4,
)

trainer.fit(task, datamodule=datamodule)
trainer.test(task, datamodule=datamodule)




1 -> train: 75 val: 25
10 -> train: 1 val: 0
2 -> train: 40 val: 17
3 -> train: 81 val: 19
4 -> train: 84 val: 16
5 -> train: 75 val: 25
6 -> train: 30 val: 8
7 -> train: 83 val: 17
8 -> train: 88 val: 12

Classes with at least 1 .tif: ['1', '2', '3', '4', '5', '6', '7', '8']





Balancing split 'training' to 75 chips per class
Class 1 (training): already 75
Class 2 (training): 40 -> 75 (augmenting)




Class 3 (training): 81 -> 75 (undersampled)
Class 4 (training): 84 -> 75 (undersampled)
Class 5 (training): already 75
Class 6 (training): 30 -> 75 (augmenting)




Class 7 (training): 83 -> 75 (undersampled)
Class 8 (training): 88 -> 75 (undersampled)

Balancing split 'validation' to 25 chips per class
Class 1 (validation): already 25
Class 2 (validation): 17 -> 25 (augmenting)




Class 3 (validation): 19 -> 25 (augmenting)




Class 4 (validation): 16 -> 25 (augmenting)




Class 5 (validation): already 25
Class 6 (validation): 8 -> 25 (augmenting)




Class 7 (validation): 17 -> 25 (augmenting)




Class 8 (validation): 12 -> 25 (augmenting)


INFO:terratorch:Checking stackability.
INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0


Batch image shape: torch.Size([4, 6, 224, 224]) label shape: torch.Size([4])
Num train samples: 600
Num val samples: 200


Prithvi_EO_V2_300M.pt:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO:lightning.pytorch.utilities.rank_zero:Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

INFO:terratorch:Checking stackability.
INFO:terratorch:Checking stackability.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:terratorch:Checking stackability.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/loss': 2.0332818031311035,
  'test/Accuracy': 0.3199999928474426,
  'test/Accuracy_Micro': 0.3199999928474426,
  'test/Class_Accuracy_1': 0.6000000238418579,
  'test/Class_Accuracy_2': 0.11999999731779099,
  'test/Class_Accuracy_3': 0.20000000298023224,
  'test/Class_Accuracy_4': 0.47999998927116394,
  'test/Class_Accuracy_5': 0.11999999731779099,
  'test/Class_Accuracy_6': 0.5600000023841858,
  'test/Class_Accuracy_7': 0.0,
  'test/Class_Accuracy_8': 0.47999998927116394,
  'test/Class_F1_1': 0.4000000059604645,
  'test/Class_F1_2': 0.0833333358168602,
  'test/Class_F1_3': 0.21276596188545227,
  'test/Class_F1_4': 0.47058823704719543,
  'test/Class_F1_5': 0.15000000596046448,
  'test/Class_F1_6': 0.6511628031730652,
  'test/Class_F1_7': 0.0,
  'test/Class_F1_8': 0.6153846383094788,
  'test/F1_Score': 0.3229043781757355,
  'test/Precision': 0.3609451949596405,
  'test/Recall': 0.3199999928474426}]

In [5]:
import torch
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

lit_model = task

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model.eval()
lit_model.to(device)

all_preds = []
all_labels = []

test_loader = datamodule.test_dataloader()

with torch.no_grad():
    for batch in test_loader:
        x = batch["image"].to(device)
        y = batch["label"].to(device)
        output = lit_model.model(x)

        if isinstance(output, torch.Tensor):
            logits = output
        else:
            logits = None
            for attr in ["logits", "preds", "logits_per_image"]:
                if hasattr(output, attr):
                    logits = getattr(output, attr)
                    break

            if logits is None:
                for name in dir(output):
                    if name.startswith("_"):
                        continue
                    try:
                        val = getattr(output, name)
                    except Exception:
                        continue
                    if isinstance(val, torch.Tensor):
                        logits = val
                        break

            if logits is None:
                raise RuntimeError(
                    f"No tesnsors. "
                    f"Type: {type(output)}; attributes: {[n for n in dir(output) if not n.startswith('_')]}"
                )

        preds = logits.argmax(dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(y.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

cm = confusion_matrix(all_labels, all_preds)
print("Confusion matrix:\n", cm)


INFO:terratorch:Checking stackability.


Confusion matrix:
 [[15  5  0  0  3  0  2  0]
 [ 9  3  5  3  4  0  1  0]
 [ 5 11  5  1  2  1  0  0]
 [ 2  4  2 12  2  1  2  0]
 [ 6  6  4  3  3  0  1  2]
 [ 3  4  1  2  0 14  1  0]
 [10 10  3  1  1  0  0  0]
 [ 0  4  2  4  0  2  1 12]]
