In [2]:
!pip uninstall -y torch torchvision torchaudio

!pip install \
  torch==2.5.1 \
  torchvision==0.20.1 \
  torchaudio==2.5.1 \
  --index-url https://download.pytorch.org/whl/cu124

!pip install -q terratorch==1.1 lightning albumentations rasterio
!pip install "protobuf<5.0.0" --force-reinstall
!pip install -q rasterio
!pip install -q terratorch
!pip install -q pytorch-lightning torchmetrics


Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.1
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.1%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision==0.20.1
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.20.1%2Bcu124-cp311-cp311-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m5.1 MB/s[0m eta 

In [17]:
import os
import torch
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules.generic_scalar_label_data_module import GenericNonGeoClassificationDataModule
from terratorch.models import EncoderDecoderFactory
from terratorch.datasets import HLSBands
from terratorch.tasks import ClassificationTask
import rasterio
import shutil
import numpy as np
import random

In [19]:
pl.seed_everything(0)
np.random.seed(0)
torch.manual_seed(0)

data_root = Path("/kaggle/input/lucas-dataset/Lucas_dataset")
train_root = data_root / "training"
val_root = data_root / "validation"

valid_classes = []
for cls_dir in sorted(train_root.iterdir(), key=lambda p: p.name):
    if cls_dir.is_dir():
        cls_name = cls_dir.name
        train_tifs = list(cls_dir.glob("*.tif"))
        val_dir = val_root / cls_name
        val_tifs = list(val_dir.glob("*.tif")) if val_dir.exists() else []
        if len(train_tifs) > 0 and len(val_tifs) > 0:
            valid_classes.append(cls_name)

valid_classes = [c for c in valid_classes if c not in ["9", "10"]]
valid_classes = sorted(valid_classes, key=int)

subset_root = Path("/kaggle/working/Lucas_subset")
if subset_root.exists():
    shutil.rmtree(subset_root)

for split in ["training", "validation"]:
    src_split = data_root / split
    dst_split = subset_root / split
    dst_split.mkdir(parents=True, exist_ok=True)
    for cls in valid_classes:
        src_dir = src_split / cls
        dst_dir = dst_split / cls
        dst_dir.mkdir(parents=True, exist_ok=True)
        tif_files = list(src_dir.glob("*.tif"))
        for f in tif_files:
            shutil.copy(f, dst_dir / f.name)

def augment_array(data: np.ndarray) -> np.ndarray:
    choice = random.choice(["flip_h", "flip_v", "rot90", "rot180", "rot270", "noise"])
    if choice == "flip_h":
        data_aug = np.flip(data, axis=2)
    elif choice == "flip_v":
        data_aug = np.flip(data, axis=1)
    elif choice == "rot90":
        data_aug = np.rot90(data, k=1, axes=(1, 2))
    elif choice == "rot180":
        data_aug = np.rot90(data, k=2, axes=(1, 2))
    elif choice == "rot270":
        data_aug = np.rot90(data, k=3, axes=(1, 2))
    else:
        if np.issubdtype(data.dtype, np.integer):
            max_val = np.iinfo(data.dtype).max
        else:
            max_val = 1.0
        std = 0.02 * max_val
        noise = np.random.normal(0, std, size=data.shape)
        data_aug = data.astype(np.float32) + noise
        data_aug = np.clip(data_aug, 0, max_val).astype(data.dtype)
    return data_aug

def augment_tif(in_path: Path, out_path: Path):
    with rasterio.open(in_path) as src:
        data = src.read()
        profile = src.profile.copy()
    data_aug = augment_array(data)
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(data_aug)

target_per_split = {"training": 75, "validation": 25}

for split, target in target_per_split.items():
    split_dir = subset_root / split
    for cls in valid_classes:
        cls_dir = split_dir / cls
        tif_files = sorted(cls_dir.glob("*.tif"))
        n = len(tif_files)
        if n > target:
            keep_files = set(random.sample(tif_files, target))
            for f in tif_files:
                if f not in keep_files:
                    f.unlink()
        elif n < target:
            original_files = tif_files.copy()
            k = 0
            while len(list(cls_dir.glob("*.tif"))) < target and original_files:
                src_path = random.choice(original_files)
                new_name = f"{src_path.stem}_aug{k}.tif"
                out_path = cls_dir / new_name
                augment_tif(src_path, out_path)
                k += 1

coarse_root = Path("/kaggle/working/Lucas_subset_coarse")
if coarse_root.exists():
    shutil.rmtree(coarse_root)

merge_scheme = {
    "0_Arable": ["1"],
    "1_Vegetation": ["2", "3", "5"],
    "2_Forest": ["4"],
    "3_BuiltBare": ["6", "7"],
    "4_Water": ["8"],
}

for split in ["training", "validation"]:
    for new_cls_name, old_cls_list in merge_scheme.items():
        new_dir = coarse_root / split / new_cls_name
        new_dir.mkdir(parents=True, exist_ok=True)
        for old_cls in old_cls_list:
            old_dir = subset_root / split / old_cls
            if not old_dir.exists():
                continue
            for f in old_dir.glob("*.tif"):
                dst = new_dir / f"{old_cls}_{f.name}"
                shutil.copy(f, dst)

train_data_root = coarse_root / "training"
val_data_root = coarse_root / "validation"
test_data_root = val_data_root

sample_tif = next(train_data_root.rglob("*.tif"))
with rasterio.open(sample_tif) as src:
    num_bands = src.count

means = [0.0] * num_bands
stds = [1.0] * num_bands

class_names = ["Arable", "Vegetation", "Forest", "BuiltBare", "Water"]
num_classes = len(class_names)

datamodule_coarse = GenericNonGeoClassificationDataModule(
    batch_size=4,
    num_workers=2,
    train_data_root=train_data_root,
    val_data_root=val_data_root,
    test_data_root=test_data_root,
    means=means,
    stds=stds,
    num_classes=num_classes,
)

datamodule_coarse.setup("fit")

model_args_coarse = dict(
    backbone="prithvi_eo_v2_300",
    backbone_pretrained=True,
    backbone_num_frames=1,
    backbone_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    decoder="IdentityDecoder",
    num_classes=num_classes,
    head_dropout=0.1,
)

task_coarse = ClassificationTask(
    model_args=model_args_coarse,
    model_factory="EncoderDecoderFactory",
    lr=1e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.01},
    class_names=class_names,
)

import torch.nn as nn
task_coarse.criterion = nn.CrossEntropyLoss()



torch.cuda.empty_cache()

trainer_coarse = pl.Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=15,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    precision="16-mixed",
    accumulate_grad_batches=4,
    logger=False,
)

trainer_coarse.fit(task_coarse, datamodule=datamodule_coarse)
trainer_coarse.test(task_coarse, datamodule=datamodule_coarse)


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
INFO: Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO:lightning.pytorch.utilities.rank_zero:Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

INFO:terratorch:Checking stackability.
INFO:terratorch:Checking stackability.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=15` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:terratorch:Checking stackability.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/loss': 1.2710461616516113,
  'test/Accuracy': 0.3866666555404663,
  'test/Accuracy_Micro': 0.49000000953674316,
  'test/Class_Accuracy_Arable': 0.1599999964237213,
  'test/Class_Accuracy_Vegetation': 0.8533333539962769,
  'test/Class_Accuracy_Forest': 0.20000000298023224,
  'test/Class_Accuracy_BuiltBare': 0.2800000011920929,
  'test/Class_Accuracy_Water': 0.4399999976158142,
  'test/Class_F1_Arable': 0.21052631735801697,
  'test/Class_F1_Vegetation': 0.5791855454444885,
  'test/Class_F1_Forest': 0.3030303120613098,
  'test/Class_F1_BuiltBare': 0.4000000059604645,
  'test/Class_F1_Water': 0.5789473652839661,
  'test/F1_Score': 0.4143378734588623,
  'test/Precision': 0.5834404826164246,
  'test/Recall': 0.3866666555404663}]