# Prithvi Sen3 CM Model config and training

Run `prithvi_S3_CM_datapreproc.ipynb` first to create the dataset.

In [None]:
!pip install terratorch==1.0.1

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import pandas as pd

First we create and analyze the datamodule!


In [None]:
# Mount Google Drive so that the dataset and all outputs persist across Colab sessions.
from google.colab import drive

drive.mount('/content/drive')

# Project root on Drive; later cells build their log/plot/checkpoint paths from it.
base_dir = '/content/drive/MyDrive/terratorch_S3_CM/'

# Folder holding the merged dataset (Path collapses the doubled slash).
dataset_path = Path(f"{base_dir}/merged/")

In [None]:
# Load means and stds
with open(dataset_path / 'data/means_stds.txt') as f:
    # The first two lines are the header and the separator
    lines = f.readlines()[2:]

# One tuple of floats per remaining line; the first column of each line is dropped
stats = []
for line in lines:
    numeric_fields = line.strip().split()[1:]
    stats.append(tuple(float(field) for field in numeric_fields))

stats

In [None]:
# Everything the datamodule reads lives below <dataset_path>/data.
data_root = dataset_path / 'data'
splits_root = data_root / 'splits'

# Standardization uses the first six rows of `stats`:
# column 0 is passed as the mean and column 1 as the std of each band.
band_rows = range(6)
band_means = [stats[row][0] for row in band_rows]
band_stds = [stats[row][1] for row in band_rows]

datamodule = terratorch.datamodules.GenericNonGeoSegmentationDataModule(
    batch_size=8,
    num_workers=2,
    num_classes=2,

    # Images and labels of a split share one folder; the grep patterns below tell them apart
    train_data_root=data_root / 'train-data',
    train_label_data_root=data_root / 'train-data',
    val_data_root=data_root / 'val-data',
    val_label_data_root=data_root / 'val-data',
    test_data_root=data_root / 'test-data',
    test_label_data_root=data_root / 'test-data',

    # Split definition files
    train_split=splits_root / 'train.txt',
    val_split=splits_root / 'val.txt',
    test_split=splits_root / 'test.txt',

    img_grep='*_reflectance.tif',
    label_grep='*_binary.tif',

    train_transform=[
        albumentations.D4(),  # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,

    # Standardization values
    means=band_means,
    stds=band_stds,
    no_data_replace=0,
    no_label_replace=-1,
    # We use all six bands of the data, so we don't need to define dataset_bands and output_bands.
)

# Setup train and val datasets
datamodule.setup("fit")

In [None]:
# Build the fit-stage datasets (train / val) and then the test-stage dataset.
datamodule.setup("fit")
train_dataset, val_dataset = datamodule.train_dataset, datamodule.val_dataset

datamodule.setup("test")
test_dataset = datamodule.test_dataset

# Report the number of samples per split
for split_label, split_dataset in (
    ("Train:", train_dataset),
    ("Val:  ", val_dataset),
    ("Test: ", test_dataset),
):
    print(split_label, len(split_dataset))



In [None]:
import numpy as np
import torch
from collections import Counter

def analyze_class_distribution(dataset, name="dataset"):
    """Count mask pixels per class over a whole dataset and print the shares.

    Args:
        dataset: Sized, indexable collection whose items are mappings with a
            ``'mask'`` entry (a ``torch.Tensor`` or something ``np.unique``
            accepts).
        name: Label used in the printed header.

    Returns:
        collections.Counter mapping each distinct mask value to the total
        number of pixels carrying it.
    """
    pixel_counter = Counter()

    for sample_idx in range(len(dataset)):
        mask = dataset[sample_idx]['mask']
        if isinstance(mask, torch.Tensor):
            # squeeze() drops singleton dimensions (e.g. a channel axis)
            mask = mask.squeeze().cpu().numpy()
        labels, label_counts = np.unique(mask, return_counts=True)
        for label, n_pixels in zip(labels, label_counts):
            pixel_counter[label] += n_pixels

    total_pixels = sum(pixel_counter.values())

    print(f"\n📊 Class distribution in '{name}':")
    for label, count in sorted(pixel_counter.items()):
        share = (count / total_pixels) * 100
        print(f"Class {int(label)}: {count} pixels ({share:.2f}%)")

    return pixel_counter

# Analyze each split (train, val, test — in that order)
train_counts, val_counts, test_counts = (
    analyze_class_distribution(split_dataset, name=split_name)
    for split_name, split_dataset in (
        ("train", train_dataset),
        ("val", val_dataset),
        ("test", test_dataset),
    )
)



In [None]:
# check if there is any overlap in the data
# Collect the file name of every sample in each split.
train_ids = {sample["filename"] for sample in train_dataset}
val_ids = {sample["filename"] for sample in val_dataset}
test_ids = {sample["filename"] for sample in test_dataset}

# File names shared between each pair of splits
overlap_train_val = train_ids.intersection(val_ids)
overlap_train_test = train_ids.intersection(test_ids)
overlap_val_test = val_ids.intersection(test_ids)

# Write one section per pair: a heading followed by the sorted shared file names
report_sections = [
    ("Train ∩ Val:\n", overlap_train_val),
    ("\nTrain ∩ Test:\n", overlap_train_test),
    ("\nVal ∩ Test:\n", overlap_val_test),
]
with open("overlap_report.txt", "w") as f:
    for heading, shared_files in report_sections:
        f.write(heading)
        for fn in sorted(shared_files):
            f.write(f"{fn}\n")

print("✅ Saved overlaps to overlap_report.txt")


In [None]:
# Visual sanity check: plot three samples (indices 1, 4 and 8) from the training split
# using the dataset's own plot() method.
print('Train Plots')
train_dataset.plot(train_dataset[1])
train_dataset.plot(train_dataset[4])
train_dataset.plot(train_dataset[8])
# Same three indices from the validation split.
# The last call is the final expression of the cell, so its return value is the cell output.
print('Val Plots')
val_dataset.plot(val_dataset[1])
val_dataset.plot(val_dataset[4])
val_dataset.plot(val_dataset[8])

In [None]:
# Check the size of the test split.
# This repeats the "test" setup done in an earlier cell and rebinds `test_dataset`;
# the final expression displays the number of test samples as the cell output.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
len(test_dataset)

# Fine-tune Prithvi

In [None]:
# first we create a dictionary, so that we can run through different configurations:
def _experiment(exp_no, backbone, *, pretrained, freeze_backbone,
                freeze_decoder=False, encoding=None, max_epochs=20):
    """Build one experiment entry; the key order defines the CSV column order."""
    return {
        "exp_no": exp_no,
        "backbone": backbone,
        "backbone_pretrained": pretrained,
        "backbone_encoding": list(encoding or []),
        "freeze_backbone": freeze_backbone,
        "freeze_decoder": freeze_decoder,
        "max_epochs": max_epochs,
    }


_experiments = [
    _experiment(0, "prithvi_eo_v2_300", pretrained=False, freeze_backbone=False, max_epochs=5),
    _experiment(1, "prithvi_eo_v2_300", pretrained=False, freeze_backbone=True),
    _experiment(2, "prithvi_eo_v2_300", pretrained=True, freeze_backbone=False),
    _experiment(3, "prithvi_eo_v2_300", pretrained=True, freeze_backbone=True),
    _experiment(4, "prithvi_eo_v2_300_tl", pretrained=True, freeze_backbone=True,
                encoding=['time', 'location']),
    _experiment(5, "prithvi_eo_v2_600", pretrained=False, freeze_backbone=False),
    _experiment(6, "prithvi_eo_v2_600", pretrained=True, freeze_backbone=True),
    _experiment(7, "prithvi_eo_v2_300", pretrained=False, freeze_backbone=True,
                freeze_decoder=True),
    _experiment(31, "prithvi_eo_v2_300", pretrained=False, freeze_backbone=True, max_epochs=150),
    _experiment(33, "prithvi_eo_v2_300", pretrained=True, freeze_backbone=True, max_epochs=150),
    # Add more experiments here...
]

# Keyed as "exp_<exp_no>", in the order listed above
model_configs = {f"exp_{entry['exp_no']}": entry for entry in _experiments}

# Convert to DataFrame (one row per experiment) and save to CSV
df = pd.DataFrame.from_dict(model_configs, orient="index")
df.to_csv("model_configs.csv", index_label="experiment")

print("✅ Saved: model_configs.csv")


# These functions have to be defined before running the fine-tuning loop!

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

def generate_uncertainty_maps(model, test_loader, exp_no, save_dir=f"{base_dir}/plots/uncertainty_art",
                              normalize=None):
    """Save one per-pixel predictive-entropy map (PNG) for every test sample.

    For each batch the class probabilities are obtained with a softmax over
    dim 1 of ``model(images).output`` and the entropy ``-sum(p * log(p))`` is
    computed per pixel and plotted with the ``magma`` colormap.

    Args:
        model: Callable module exposing ``.device``, ``.training``, ``.eval()``
            and ``.train()``; its result must have an ``.output`` attribute
            with the class scores on dim 1.
        test_loader: Iterable of batches (mappings with an ``"image"`` tensor).
        exp_no: Experiment number; used for the sub-folder and file names.
        save_dir: Parent folder; maps are written to ``<save_dir>/exp_<exp_no>``.
        normalize: Optional callable applied to every batch before inference
            (it may modify the batch in place and/or return it). ``None``
            (default) feeds the batches exactly as the loader yields them.
            NOTE(review): batches iterated outside a Trainer presumably are not
            standardized — ``visualize_full_segmentation_analysis`` calls
            ``datamodule.aug(batch)`` for this; confirm and pass it here.
    """
    save_dir = os.path.join(save_dir, f"exp_{exp_no}")
    os.makedirs(save_dir, exist_ok=True)

    # Inference must not run with dropout / batch-norm in training mode.
    # Remember the current mode so the caller's model is handed back unchanged.
    was_training = model.training
    model.eval()
    try:
        for batch_idx, batch in enumerate(test_loader):
            if normalize is not None:
                normalized = normalize(batch)
                # Support both in-place normalizers and ones that return the batch
                if normalized is not None:
                    batch = normalized

            with torch.no_grad():
                images = batch["image"].to(model.device)
                outputs = model(images)

                # Class probabilities
                probs_np = torch.softmax(outputs.output, dim=1).cpu().numpy()

                # Pixel-wise entropy (uncertainty); the epsilon avoids log(0)
                entropy = -np.sum(probs_np * np.log(probs_np + 1e-10), axis=1)

            for i in range(len(images)):
                fig = plt.figure(figsize=(10, 10))
                plt.imshow(entropy[i], cmap='magma')
                plt.axis("off")
                plt.tight_layout(pad=0)

                filename = f"exp_{exp_no}_batch_{batch_idx}_sample_{i}_uncertainty.png"
                save_path = os.path.join(save_dir, filename)
                fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
                # Close explicitly: one figure per sample would otherwise pile up in memory
                plt.close(fig)
    finally:
        model.train(was_training)

    print(f"✅ Saved all uncertainty maps to: {save_dir}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

def plot_experiment_metrics(exp_no):
    """Plot and save loss / validation-metric curves of one experiment.

    Reads ``metrics.csv`` from the highest ``version_*`` folder below
    ``<base_dir>/logs/exp_<exp_no>`` and draws three panels: loss per step,
    loss per epoch (train loss averaged per epoch) and all ``val/*`` columns
    that are not losses. The figure is written to
    ``<base_dir>/plots/loss_metrics/`` and shown.

    Args:
        exp_no: Experiment number used to locate the logs and name the file.

    Raises:
        FileNotFoundError: If the experiment folder has no ``version_<n>``
            sub-folder.
    """
    # === Find highest version folder ===
    exp_path = f"{base_dir}/logs/exp_{exp_no}"
    version_nums = [
        int(d.split("_")[1])
        for d in os.listdir(exp_path)
        if d.startswith("version_") and d.split("_")[1].isdigit()
    ]
    if not version_nums:
        raise FileNotFoundError(f"No version_<n> folder found in {exp_path}")
    latest_version = f"version_{max(version_nums)}"
    metrics_path = os.path.join(exp_path, latest_version, "metrics.csv")

    # === Load CSV ===
    df = pd.read_csv(metrics_path)

    # === Step-level loss plot ===
    step_df = df.dropna(subset=["step", "train/loss"])
    grouped = step_df.groupby("step").mean()
    has_val_loss = "val/loss" in grouped.columns and not grouped["val/loss"].isnull().all()

    # === Epoch-level average loss ===
    # One validation row per epoch, joined with the mean train loss of that epoch
    val_df = df[df["val/loss"].notna()].drop_duplicates(subset="epoch")
    train_df = df[["epoch", "train/loss"]].dropna()
    avg_train = train_df.groupby("epoch").mean().rename(columns={"train/loss": "avg_train_loss"})
    plot_df = val_df.merge(avg_train, on="epoch")

    # === Plotting ===
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    # Plot 1: Step-wise loss
    axes[0].plot(grouped.index, grouped["train/loss"], label="Train Loss", linewidth=2)
    if has_val_loss:
        axes[0].plot(grouped.index, grouped["val/loss"], label="Val Loss", linewidth=2)
    axes[0].set_ylim(0, 0.3)  # ✅ Fixed y-limits for loss
    axes[0].set_title("Loss Per Step")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Loss")
    axes[0].legend()
    axes[0].grid(True)

    # Plot 2: Epoch-wise loss
    axes[1].plot(plot_df["epoch"], plot_df["avg_train_loss"], label="Train Loss (avg)", marker="o")
    axes[1].plot(plot_df["epoch"], plot_df["val/loss"], label="Val Loss", marker="o")
    axes[1].set_ylim(0, 0.3)  # ✅ Fixed y-limits for loss
    axes[1].set_title("Loss Per Epoch")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Loss")
    axes[1].legend()
    axes[1].grid(True)

    # Plot 3: Validation metrics (every non-loss val/* column that has values)
    val_metrics = [
        col for col in plot_df.columns
        if col.startswith("val/") and "loss" not in col and plot_df[col].notna().any()
    ]
    for metric in val_metrics:
        short_name = metric.replace("val/", "").replace("_", " ")
        axes[2].plot(plot_df["epoch"], plot_df[metric], label=short_name, marker="o")
    axes[2].set_ylim(0.6, 1.0)  # ✅ Fixed y-limits for val metrics
    axes[2].set_title("Validation Metrics Per Epoch")
    axes[2].set_xlabel("Epoch")
    axes[2].set_ylabel("Metric Value")
    if val_metrics:
        # Only draw a legend when there is something to label
        axes[2].legend(loc="lower center", bbox_to_anchor=(0.5, -0.35), ncol=2)
    axes[2].grid(True)

    # Save and show
    plt.tight_layout()
    save_path = f"{base_dir}/plots/loss_metrics/exp_{exp_no}_combined_loss_and_metrics.png"
    # The target folder is not created anywhere else; savefig fails if it is missing
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    # Release the figure so repeated calls (one per experiment) do not accumulate
    plt.close(fig)
    print(f"✅ Saved: {save_path}")


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import math
import os
from sklearn.metrics import jaccard_score
import matplotlib.colors as mcolors
from itertools import islice

def dice_coefficient(pred, target, smooth=1e-6):
    """Return the smoothed Dice coefficient of two equally shaped arrays.

    Computes ``(2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``;
    ``smooth`` keeps the ratio defined when both inputs sum to zero.
    """
    overlap = np.sum(pred * target)
    total = np.sum(pred) + np.sum(target)
    return (2. * overlap + smooth) / (total + smooth)

def compute_jaccard_index(pred, target):
    """Return the binary Jaccard score of ``pred`` against ``target``.

    Both arrays are flattened and passed to ``jaccard_score`` with
    ``average='binary'``; ``target`` is the ground truth, ``pred`` the prediction.
    """
    y_true = target.flatten()
    y_pred = pred.flatten()
    return jaccard_score(y_true, y_pred, average='binary')

from pathlib import Path
import re

def extract_timestamps(filenames):
    """Extract timestamps like 20230510T101427 from a list of filenames.

    Only the final path component is searched. The result has one entry per
    input, with ``None`` where no timestamp was found.
    """
    stamp_pattern = re.compile(r'\d{8}T\d{6}')
    hits = (stamp_pattern.search(Path(fname).name) for fname in filenames)
    return [hit.group(0) if hit else None for hit in hits]


def visualize_full_segmentation_analysis(exp_no, best_ckpt_path, model, datamodule, target_class=1,
                                          save_dir=f"{base_dir}/plots/full_analysis", max_rows_per_fig=4,
                                          amount_batches_to_plot='all', trainer=None):
    """Run the best checkpoint on the test split and save per-sample analysis figures.

    For every test sample a row of five panels is drawn: the first image
    channel (min-max scaled), the reference mask, the prediction, their
    difference and the probability of ``target_class``. Per-sample Dice loss
    (``1 - dice``) and Jaccard index for ``target_class`` are collected and
    their means and full lists written to ``metrics.txt`` in the output folder.

    Args:
        exp_no: Experiment number; used for the sub-folder and file names.
        best_ckpt_path: Checkpoint to evaluate and to load the model from.
        model: Task instance; only its ``hparams.model_factory`` /
            ``hparams.model_args`` are reused to rebuild the model from the
            checkpoint (it is also passed to ``trainer.test``).
        datamodule: Provides ``test_dataloader()`` and ``aug`` (applied to
            every batch before prediction).
        target_class: Class index used for the metrics and probability panel.
        save_dir: Parent folder; output goes to ``<save_dir>/exp_<exp_no>``.
        max_rows_per_fig: Maximum number of samples (rows) per saved figure.
        amount_batches_to_plot: Number of test batches to process, or ``'all'``.
        trainer: Optional Trainer used for the ``test`` run. When omitted, a
            notebook-level variable named ``trainer`` is used if it exists
            (the previous implicit behavior); otherwise the run is skipped.
    """
    save_dir = os.path.join(save_dir, f"exp_{exp_no}")
    os.makedirs(save_dir, exist_ok=True)

    # Fall back to the notebook-level trainer instead of silently requiring it
    if trainer is None:
        trainer = globals().get("trainer")
    if trainer is not None:
        trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)
    else:
        print("⚠️ No trainer available; skipping trainer.test().")

    # --- Load model from checkpoint ---
    model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
        best_ckpt_path,
        model_factory=model.hparams.model_factory,
        model_args=model.hparams.model_args,
    )
    # A freshly constructed module is in training mode; switch to inference
    # mode so dropout / batch-norm do not perturb the predictions.
    model.eval()

    # --- Load test dataloader ---
    test_loader = datamodule.test_dataloader()
    if amount_batches_to_plot == 'all':
        amount_batches_to_plot = len(test_loader)

    # Per-sample metrics over all processed batches
    jaccard_list = []
    dice_list = []

    # Plot layout shared by every row
    cols = 5
    cmap_mask = mcolors.ListedColormap(['blue', 'white'])  # discrete colors for the two masks
    cmaps = ["cubehelix", cmap_mask, cmap_mask, "bwr", "viridis"]
    vmins = [0, 0, 0, -1, 0]
    vmaxs = [1, 1, 1, 1, 1]

    for batch_idx, batch in enumerate(islice(test_loader, amount_batches_to_plot)):
        print(f"Processing batch {batch_idx}...")

        with torch.no_grad():
            # Called for its effect on the batch: the images are read from
            # batch["image"] afterwards, exactly as before.
            # NOTE(review): presumably this standardizes batch["image"] in place — confirm.
            datamodule.aug(batch)
            images = batch["image"].to(model.device)
            filenames = batch["filename"]
            timestamps = extract_timestamps(filenames)
            outputs = model(images)

            probs = torch.softmax(outputs.output, dim=1).cpu().numpy()
            preds = torch.argmax(outputs.output, dim=1).cpu().numpy()

        images_np = images.cpu().numpy()
        gts = batch["mask"].numpy() if "mask" in batch else np.zeros_like(preds)

        num_images = len(images)
        num_parts = math.ceil(num_images / max_rows_per_fig)

        for part_idx in range(num_parts):
            start_idx = part_idx * max_rows_per_fig
            end_idx = min((part_idx + 1) * max_rows_per_fig, num_images)
            current_rows = end_idx - start_idx

            # squeeze=False keeps `axes` two-dimensional even for a single row
            fig, axes = plt.subplots(current_rows, cols, figsize=(cols * 4, current_rows * 3),
                                     squeeze=False)

            for i in range(start_idx, end_idx):
                row_axes = axes[i - start_idx]

                # First channel, min-max scaled to [0, 1] for display
                img = images_np[i][0]
                img = (img - img.min()) / (img.max() - img.min() + 1e-8)

                gt_mask = gts[i]
                pred_mask = preds[i]
                diff_mask = gt_mask - pred_mask
                prob_map = probs[i, target_class, :, :]

                # --- Metrics (note: `dice_loss` is 1 - Dice coefficient) ---
                dice_loss = 1 - dice_coefficient(pred_mask == target_class, gt_mask == target_class)
                jaccard = compute_jaccard_index(pred_mask == target_class, gt_mask == target_class)

                dice_list.append(dice_loss)
                jaccard_list.append(jaccard)

                titles = [
                    f"Sen3/SLSTR Band2 (659nm)\n{timestamps[i]}",
                    "Reference Mask",
                    f"Prediction\nDice Loss: {dice_loss:.3f}, Jaccard: {jaccard:.3f}",
                    "difference(Ref - Pred)",
                    f"Prob. Class {target_class}"
                ]
                visuals = [img, gt_mask, pred_mask, diff_mask, prob_map]

                # One panel per column, each with its own colormap and colorbar
                for j in range(cols):
                    ax = row_axes[j]
                    im = ax.imshow(visuals[j], cmap=cmaps[j], vmin=vmins[j], vmax=vmaxs[j])
                    ax.set_title(titles[j])
                    ax.grid(True, linestyle=":", linewidth=0.5)

                    if j == 1 or j == 2:  # gt_mask and pred_mask: only the two class ticks
                        fig.colorbar(im, ax=ax, ticks=[0, 1], fraction=0.046, pad=0.04)
                    else:
                        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            plt.tight_layout()
            filename = f"exp_{exp_no}_batch_{batch_idx}_part_{part_idx}.png"
            save_path = os.path.join(save_dir, filename)
            fig.savefig(save_path)
            plt.close(fig)
            print(f"✅ Saved: {save_path}")

    # Nothing processed (empty loader or zero batches requested): avoid dividing by zero
    if not jaccard_list:
        print("⚠️ No test samples were processed; no metrics written.")
        return

    # Means over all processed samples
    mean_jaccard = sum(jaccard_list) / len(jaccard_list)
    mean_loss = sum(dice_list) / len(dice_list)

    # Save the metrics to a text file.
    # The "Mean Validation Loss" / "Dice list" entries hold the per-sample Dice *loss*
    # computed on the test batches above; labels are kept unchanged for existing reports.
    metrics_file = os.path.join(save_dir, "metrics.txt")
    with open(metrics_file, "w") as f:
        f.write(f"Mean Jaccard Index: {mean_jaccard:.4f}\n")
        f.write(f"Mean Validation Loss: {mean_loss:.4f}\n")
        f.write(f"Jaccard list: {jaccard_list}\n")
        f.write(f"Dice list: {dice_list}\n")
    print(f"✅ Metrics saved to: {metrics_file}")


# TRAINING in a Loop

In [None]:
import os

# Take the logger from the same package as the Trainer (`lightning.pytorch`, imported as `pl`)
# instead of the separate `pytorch_lightning` distribution, and import it once, not per iteration.
from lightning.pytorch.loggers import CSVLogger

results_file = "experiment_results.txt"

# Only create and write header if the file does not exist
if not os.path.exists(results_file):
    with open(results_file, "w") as f:
        f.write("Experiment Results:\n")

for config_name, cfg in list(model_configs.items())[0:1]:   ## CHANGE HERE FOR ALL
    print(f"\n🚀 Running {config_name}...")

    # Same seed for every experiment so runs are comparable
    pl.seed_everything(0)
    ####################
    exp_no = cfg['exp_no']
    ####################

    # Keep only the checkpoint with the highest validation Jaccard index
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=f"{base_dir}/output/sen3_cm/checkpoints/exp_{exp_no}",
        mode="max",
        monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
        filename="best-{epoch:02d}",
        save_top_k=1
    )

    # metrics.csv is written to <base_dir>/logs/exp_<exp_no>/version_<n>/
    logger = CSVLogger(f"{base_dir}/logs/", name=f"exp_{exp_no}")

    # Lightning Trainer
    trainer = pl.Trainer(
        accelerator="auto",
        strategy="auto",
        devices=1, # Deactivate multi-gpu because it often fails in notebooks
        precision='16-mixed',  # Speed up training
        num_nodes=1,
        logger=logger,
        max_epochs=cfg["max_epochs"],
        log_every_n_steps=20,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
        default_root_dir=f"{base_dir}/output/sen3_cm",
    )

    # Encoder layers whose outputs are handed to the decoder, chosen per backbone size
    if cfg["backbone"].startswith("prithvi_eo_v2_100"):
        indices = [2, 5, 8, 11]
    elif cfg["backbone"].startswith("prithvi_eo_v2_300"):
        indices = [5, 11, 17, 23]
    elif cfg["backbone"].startswith("prithvi_eo_v2_600"):
        indices = [7, 15, 23, 31]
    else:
        raise ValueError(f"Unsupported backbone: {cfg['backbone']}")

    # Model
    model = terratorch.tasks.SemanticSegmentationTask(
        model_factory="EncoderDecoderFactory",
        model_args={
            # Backbone
            "backbone": cfg["backbone"],
            "backbone_pretrained": cfg["backbone_pretrained"],
            "backbone_num_frames": 1, # 1 is the default value
            "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], #this is actually not true when using Sen3 data
            "backbone_coords_encoding": cfg["backbone_encoding"],

            # Necks
            "necks": [
                {"name": "SelectIndices",
                 "indices": indices},
                {"name": "ReshapeTokensToImage",},
                {"name": "LearnedInterpolateToPyramidal"}
            ],

            # Decoder
            "decoder": "UNetDecoder",
            "decoder_channels": [512, 256, 128, 64],

            # Head
            "head_dropout": 0.1,
            "num_classes": 2,
        },

        loss="dice",
        optimizer="AdamW",
        lr=1e-4,
        ignore_index=-1,
        freeze_backbone=cfg["freeze_backbone"],
        freeze_decoder=cfg["freeze_decoder"],
        plot_on_val=True,
        class_names=['clear-sky ocean', 'else']  # optionally define class names
    )

    # Training
    trainer.fit(model, datamodule=datamodule)
    best_ckpt_path = checkpoint_callback.best_model_path
    best_score = checkpoint_callback.best_model_score
    print(f"✅ Best checkpoint: {best_ckpt_path}")
    print(f"🏆 Best score: {best_score}")

    # best_model_score is None when no checkpoint was recorded; ":.4f" would raise on it
    best_score_text = f"{best_score:.4f}" if best_score is not None else "n/a"
    with open(results_file, "a") as f:
        f.write(f"\n📁 Experiment: {config_name}\n")
        f.write(f"Best Checkpoint: {best_ckpt_path}\n")
        f.write(f"Best Score: {best_score_text}\n")

    plot_experiment_metrics(exp_no)

    # Evaluate the best checkpoint on the test split, then create the uncertainty maps
    # and the full per-sample analysis with it.
    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)
    test_loader = datamodule.test_dataloader()
    generate_uncertainty_maps(model, test_loader, exp_no)

    visualize_full_segmentation_analysis(
        exp_no,
        best_ckpt_path,
        model=model,
        datamodule=datamodule
        )





# Here you can run the plots without training

Note: building the complete model definition may be unnecessary here, because the weights are loaded from a saved checkpoint.

In [None]:
import glob
import os
import re

# Take the logger from the same package as the Trainer (`lightning.pytorch`, imported as `pl`).
from lightning.pytorch.loggers import CSVLogger


def extract_epoch(filename):
    """Return the epoch number encoded in a checkpoint file name, or -1 if absent.

    Only the digits directly after ``epoch=`` are read, so names carrying an
    extra suffix (for example ``best-epoch=05-v1.ckpt``) are parsed as well.
    """
    match = re.search(r"epoch=(\d+)", os.path.basename(filename))
    return int(match.group(1)) if match else -1


for i, (config_name, cfg) in enumerate(model_configs.items()):
    print(config_name)
    #if i in [0,1,2,3,4,5,6,7,8,9]:
    if i in [0]:
        print(f"🚀 Running evaluation of {config_name}...")

        pl.seed_everything(0)
        ####################
        exp_no = cfg['exp_no']
        ####################

        # === Find .ckpt file with the highest epoch number ===
        # Done before building the trainer/model so experiments without a checkpoint are skipped cheaply.
        checkpoint_dir = f"{base_dir}/output/sen3_cm/checkpoints/exp_{exp_no}"
        ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))

        if not ckpt_files:
            print(f"⚠️ No checkpoint found in {checkpoint_dir}. Skipping.")
            continue

        # Highest epoch wins; among equal epochs the most recently written file is used
        best_ckpt_path = max(ckpt_files, key=lambda path: (extract_epoch(path), os.path.getmtime(path)))
        print(f"📦 Loading checkpoint: {best_ckpt_path}")

        # NOTE(review): the paths below are relative (not under base_dir), unlike the training
        # cell; presumably this keeps evaluation runs from adding version_* folders next to the
        # training logs — confirm the intent.
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=f"output/sen3_cm/checkpoints/exp_{exp_no}",
            mode="max",
            monitor="val/Multiclass_Jaccard_Index", # Variable to monitor
            filename="best-{epoch:02d}",
            save_top_k=1
        )

        logger = CSVLogger("logs/", name=f"exp_{exp_no}")

        # Lightning Trainer
        trainer = pl.Trainer(
            accelerator="auto",
            strategy="auto",
            devices=1, # Deactivate multi-gpu because it often fails in notebooks
            precision='16-mixed',  # Speed up training
            num_nodes=1,
            logger=logger,
            max_epochs=cfg["max_epochs"],
            log_every_n_steps=20,
            enable_checkpointing=True,
            callbacks=[checkpoint_callback, pl.callbacks.RichProgressBar()],
            default_root_dir="output/sen3_cm",
        )

        # Encoder layers whose outputs are handed to the decoder, chosen per backbone size
        if cfg["backbone"].startswith("prithvi_eo_v2_100"):
            indices = [2, 5, 8, 11]
        elif cfg["backbone"].startswith("prithvi_eo_v2_300"):
            indices = [5, 11, 17, 23]
        elif cfg["backbone"].startswith("prithvi_eo_v2_600"):
            indices = [7, 15, 23, 31]
        else:
            raise ValueError(f"Unsupported backbone: {cfg['backbone']}")

        # Model (same definition as in the training loop)
        model = terratorch.tasks.SemanticSegmentationTask(
            model_factory="EncoderDecoderFactory",
            model_args={
                # Backbone
                "backbone": cfg["backbone"],
                "backbone_pretrained": cfg["backbone_pretrained"],
                "backbone_num_frames": 1, # 1 is the default value
                "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], #this is actually not true when using Sen3 data
                "backbone_coords_encoding": cfg["backbone_encoding"],

                # Necks
                "necks": [
                    {"name": "SelectIndices",
                     "indices": indices},
                    {"name": "ReshapeTokensToImage",},
                    {"name": "LearnedInterpolateToPyramidal"}
                ],

                # Decoder
                "decoder": "UNetDecoder",
                "decoder_channels": [512, 256, 128, 64],

                # Head
                "head_dropout": 0.1,
                "num_classes": 2,
            },

            loss="dice",
            optimizer="AdamW",
            lr=1e-4,
            ignore_index=-1,
            freeze_backbone=cfg["freeze_backbone"],
            freeze_decoder=cfg["freeze_decoder"],
            plot_on_val=True,
            class_names=['clear-sky ocean', 'else']  # optionally define class names
        )

       # plot_experiment_metrics(exp_no)
        trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)
        #test_loader = datamodule.test_dataloader()
        #generate_uncertainty_maps(model, test_loader, exp_no)

        visualize_full_segmentation_analysis(
            exp_no,
            best_ckpt_path,
            model=model,
            datamodule=datamodule,
            amount_batches_to_plot = 'all'
        )









# compare different loss and metrics in ONE plot

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

def plot_multiple_experiment_metrics(exp_nos, metric="Multiclass_Jaccard_Index"):
    """Compare train loss, validation loss and one validation metric across experiments.

    For every experiment the ``metrics.csv`` of the highest ``version_*``
    folder below ``<base_dir>/logs/exp_<n>`` is read and reduced to one row
    per epoch. Three stacked panels are drawn (train loss, validation loss,
    ``val/<metric>``), the figure is saved below ``<base_dir>/plots`` and the
    best values per experiment are written to a text file next to it.

    Args:
        exp_nos: Sequence of experiment numbers to compare.
        metric: Name of the validation metric column without the ``val/``
            prefix. Experiments that did not log it are left out of the third
            panel and reported as ``n/a``.

    Note:
        A later cell of this notebook defines a function with the same name,
        which replaces this one once that cell has been run.
    """
    # Per-experiment series for plotting
    all_train_losses = []
    all_val_losses = []
    all_jaccard_indices = []

    # Per-experiment best values for the legends and the text export
    min_train_losses = []
    min_val_losses = []
    max_jaccard_indices = []

    def _format_score(value):
        """Format a best value; experiments without the metric yield 'n/a'."""
        return f"{value:.4f}" if value is not None else "n/a"

    def _label_epochs(ax, series_list):
        """Put one tick per epoch (labelled from 1) spanning the longest series."""
        n_epochs = max((len(s) for s in series_list if s is not None), default=0)
        ax.set_xticks(range(n_epochs))
        ax.set_xticklabels(range(1, n_epochs + 1))

    for exp_no in exp_nos:
        print(f"\nProcessing experiment {exp_no}...")

        # === Find highest version folder ===
        exp_path = f"{base_dir}/logs/exp_{exp_no}"
        versions = [d for d in os.listdir(exp_path) if d.startswith("version_")]
        version_nums = [int(v.split("_")[1]) for v in versions]
        latest_version = f"version_{max(version_nums)}"
        metrics_path = os.path.join(exp_path, latest_version, "metrics.csv")

        # === Load CSV ===
        df = pd.read_csv(metrics_path)

        # === Epoch-level average loss ===
        val_df = df[df["val/loss"].notna()].drop_duplicates(subset="epoch")
        train_df = df[["epoch", "train/loss"]].dropna()
        avg_train = train_df.groupby("epoch").mean().rename(columns={"train/loss": "avg_train_loss"})
        plot_df = val_df.merge(avg_train, on="epoch")

        all_train_losses.append(plot_df["avg_train_loss"])
        all_val_losses.append(plot_df["val/loss"])

        # The metric column may be missing for an experiment -> None
        jaccard_index = plot_df.get(f"val/{metric}", None)
        all_jaccard_indices.append(jaccard_index)

        min_train_losses.append(plot_df["avg_train_loss"].min())
        min_val_losses.append(plot_df["val/loss"].min())
        max_jaccard_indices.append(jaccard_index.max() if jaccard_index is not None else None)

    # === Plotting ===
    fig, axes = plt.subplots(3, 1, figsize=(20, 15))

    # Plot 1: All Train Losses
    for i, train_loss in enumerate(all_train_losses):
        axes[0].plot(train_loss, label=f"Exp {exp_nos[i]} - Min Loss: {min_train_losses[i]:.4f}")
    axes[0].set_ylim(0, 0.3)  # Fixed y-limits for loss
    axes[0].set_title("Train Loss (Dice) Per Epoch")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].legend(ncol=2)
    axes[0].grid(True)
    _label_epochs(axes[0], all_train_losses)

    # Plot 2: All Validation Losses
    for i, val_loss in enumerate(all_val_losses):
        axes[1].plot(val_loss, label=f"Exp {exp_nos[i]} - Min Val Loss: {min_val_losses[i]:.4f}")
    axes[1].set_ylim(0, 0.3)  # Fixed y-limits for loss
    axes[1].set_title("Validation Loss (Dice) Per Epoch")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Loss")
    axes[1].legend(ncol=2)
    axes[1].grid(True)
    _label_epochs(axes[1], all_val_losses)

    # Plot 3: All Multiclass Jaccard Indices (experiments without the metric are skipped)
    for i, jaccard_index in enumerate(all_jaccard_indices):
        if jaccard_index is not None:
            axes[2].plot(jaccard_index, label=f"Exp {exp_nos[i]} - Max Jaccard: {max_jaccard_indices[i]:.4f}")
    axes[2].set_ylim(0.6, 1.0)  # Fixed y-limits for Jaccard Index
    axes[2].set_title("Multiclass Jaccard Index Per Epoch")
    axes[2].set_xlabel("Epoch")
    axes[2].set_ylabel("Jaccard Index")
    axes[2].legend(ncol=2)
    axes[2].grid(True)
    _label_epochs(axes[2], all_jaccard_indices)

    # Save and show
    plt.tight_layout()

    exp_tag = '_'.join(map(str, exp_nos))
    # The plots folder is not created anywhere else; savefig/open fail if it is missing
    os.makedirs(f"{base_dir}/plots", exist_ok=True)
    save_path = f"{base_dir}/plots/comparison_{metric}_exp_{exp_tag}_loss_and_metrics.png"
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)

    print(f"✅ Saved: {save_path}")

    # === Save Min/Max Losses and Jaccard Indices to a .txt File ===
    results_filename = f"{base_dir}/plots/experiment_comparison_min_max_{exp_tag}.txt"
    with open(results_filename, "w") as f:
        f.write("Experiment Comparison: Min Losses, Max Jaccard Indices\n")
        for i, exp_no in enumerate(exp_nos):
            f.write(
                f"Exp {exp_no}: Min Train Loss: {min_train_losses[i]:.4f}, "
                f"Min Val Loss: {min_val_losses[i]:.4f}, "
                f"Max Jaccard Index: {_format_score(max_jaccard_indices[i])}\n"
            )

    print(f"✅ Min/Max values saved to: {results_filename}")


In [None]:
import pandas as pd
import os

def _trailing_int(text, default=-1):
    """Return the integer after the last ``_`` in ``text``, or ``default`` if there is none."""
    _, sep, tail = text.rpartition("_")
    return int(tail) if sep and tail.isdecimal() else default


def _find_latest_version_dir(exp_path):
    """Return the path of the ``version_<n>`` entry in ``exp_path`` with the highest ``n``."""
    candidates = [
        name for name in os.listdir(exp_path)
        if name.startswith("version_") and _trailing_int(name) >= 0
    ]
    if not candidates:
        raise FileNotFoundError(f"No 'version_<n>' folder found in {exp_path}")
    return os.path.join(exp_path, max(candidates, key=_trailing_int))


def _select_metrics_csv(version_path):
    """Return the largest ``.csv`` file in ``version_path``; ties go to the higher ``_<n>`` suffix."""
    csv_names = [name for name in os.listdir(version_path) if name.endswith(".csv")]
    if not csv_names:
        raise FileNotFoundError(f"No .csv metrics file found in {version_path}")

    def rank(name):
        # A file without a numeric suffix (e.g. "metrics.csv") ranks as -1
        # instead of crashing the name parsing.
        size = os.path.getsize(os.path.join(version_path, name))
        return size, _trailing_int(os.path.splitext(name)[0])

    return os.path.join(version_path, max(csv_names, key=rank))


def _load_epoch_metrics(metrics_path):
    """Load a metrics CSV and return one row per epoch with val columns plus ``avg_train_loss``."""
    df = pd.read_csv(metrics_path)
    # Keep the first row per epoch that carries a validation loss.
    val_df = df[df["val/loss"].notna()].drop_duplicates(subset="epoch")
    # Average every logged train loss within an epoch.
    train_df = df[["epoch", "train/loss"]].dropna()
    avg_train = train_df.groupby("epoch").mean().rename(columns={"train/loss": "avg_train_loss"})
    return val_df.merge(avg_train, on="epoch")


def _draw_panel(ax, exp_nos, curves, best_values, best_label, title, ylabel, ylim):
    """Plot one curve per experiment on ``ax``; experiments whose curve is ``None`` are skipped."""
    drawn = 0
    for exp_no, curve, best in zip(exp_nos, curves, best_values):
        if curve is None:
            continue
        ax.plot(curve, label=f"Exp {exp_no} - {best_label}: {best:.4f}")
        drawn += 1
    ax.set_ylim(*ylim)  # Fixed y-limits so figures are comparable across runs
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    if drawn:
        ax.legend(ncol=2)
    ax.grid(True)
    # Size the ticks from the longest available curve so that no experiment's
    # epochs are left unlabeled; display epochs as 1..n.
    n_epochs = max((len(curve) for curve in curves if curve is not None), default=0)
    ax.set_xticks(range(n_epochs))
    ax.set_xticklabels(range(1, n_epochs + 1))


def plot_multiple_experiment_metrics(exp_nos, metric="Multiclass_Jaccard_Index"):
    """Compare train loss, validation loss and a validation metric across experiments.

    For every experiment number the function looks into
    ``{base_dir}/logs/exp_<n>``, takes the highest ``version_<k>`` folder, picks
    the largest ``.csv`` file in it and reduces it to one row per epoch. It then
    draws three stacked panels (train loss, validation loss, ``val/<metric>``),
    saves the figure and a text summary of the best values under
    ``{base_dir}/plots`` (created if missing), and shows the figure.

    Args:
        exp_nos: Experiment numbers to compare (any iterable; must not be empty).
        metric: Name of the validation metric column without the ``val/`` prefix.
            Experiments whose CSV lacks this column are still plotted for the
            losses and reported as ``N/A`` in the summary.

    Raises:
        ValueError: If ``exp_nos`` is empty.
        FileNotFoundError: If an experiment has no ``version_<k>`` folder or the
            chosen folder has no ``.csv`` file.
    """
    exp_nos = list(exp_nos)
    if not exp_nos:
        raise ValueError("exp_nos must contain at least one experiment number")

    # === Per-experiment curves and their best values ===
    all_train_losses = []
    all_val_losses = []
    all_jaccard_indices = []
    min_train_losses = []
    min_val_losses = []
    max_jaccard_indices = []

    for exp_no in exp_nos:
        print(f"\nProcessing experiment {exp_no}...")

        version_path = _find_latest_version_dir(f"{base_dir}/logs/exp_{exp_no}")
        metrics_path = _select_metrics_csv(version_path)
        print(f"Chosen metrics file: {metrics_path}")

        plot_df = _load_epoch_metrics(metrics_path)
        jaccard_index = plot_df.get(f"val/{metric}", None)

        all_train_losses.append(plot_df["avg_train_loss"])
        all_val_losses.append(plot_df["val/loss"])
        all_jaccard_indices.append(jaccard_index)

        # Series.min()/max() skip NaN, unlike the builtin min()/max().
        min_train_losses.append(plot_df["avg_train_loss"].min())
        min_val_losses.append(plot_df["val/loss"].min())
        max_jaccard_indices.append(None if jaccard_index is None else jaccard_index.max())

    # === Plotting ===
    fig, axes = plt.subplots(3, 1, figsize=(20, 15))
    _draw_panel(axes[0], exp_nos, all_train_losses, min_train_losses,
                "Min Loss", "Train Loss (Dice) Per Epoch", "Loss", (0, 0.3))
    _draw_panel(axes[1], exp_nos, all_val_losses, min_val_losses,
                "Min Val Loss", "Validation Loss (Dice) Per Epoch", "Loss", (0, 0.3))
    _draw_panel(axes[2], exp_nos, all_jaccard_indices, max_jaccard_indices,
                "Max Jaccard", "Multiclass Jaccard Index Per Epoch", "Jaccard Index", (0.6, 1.0))

    # === Save and show ===
    plt.tight_layout()

    plots_dir = f"{base_dir}/plots"
    os.makedirs(plots_dir, exist_ok=True)
    exp_tag = "_".join(map(str, exp_nos))

    save_path = f"{plots_dir}/comparison_{metric}_exp_{exp_tag}_loss_and_metrics.png"
    plt.savefig(save_path)
    plt.show()
    plt.close(fig)  # Release the figure so repeated calls do not accumulate memory

    print(f"✅ Saved: {save_path}")

    # === Save Min/Max Losses and Jaccard Indices to a .txt File ===
    results_filename = f"{plots_dir}/experiment_comparison_min_max_{exp_tag}.txt"
    with open(results_filename, "w", encoding="utf-8") as f:
        f.write("Experiment Comparison: Min Losses, Max Jaccard Indices\n")
        for i, exp_no in enumerate(exp_nos):
            best_jaccard = max_jaccard_indices[i]
            jaccard_text = "N/A" if best_jaccard is None else f"{best_jaccard:.4f}"
            f.write(
                f"Exp {exp_no}: Min Train Loss: {min_train_losses[i]:.4f}, "
                f"Min Val Loss: {min_val_losses[i]:.4f}, "
                f"Max Jaccard Index: {jaccard_text}\n"
            )

    print(f"✅ Min/Max values saved to: {results_filename}")


In [None]:
plot_multiple_experiment_metrics(list(range(10)))  # Compare experiments 0 through 9
