In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask
from torchgeo.transforms import AugmentationSequential
from torchmetrics import JaccardIndex, MetricCollection, Precision, Recall
from tqdm import tqdm

import kornia.augmentation as K

from src.ftw.datamodules import preprocess, FTWDataModule
from src.ftw.datasets import FTW
from src.ftw.metrics import get_object_level_metrics
from src.ftw.trainers import CustomSemanticSegmentationTask

In [3]:
if torch.cuda.is_available():
    device = torch.device(f"cuda:{0}")
else:
    device = torch.device("cpu")

In [4]:
print(f"Using device: {device}")

Using device: cuda:0


In [5]:
mean = torch.tensor([0, 0, 0, 0]) 
std = torch.tensor([3000, 3000, 3000, 3000])

train_aug = AugmentationSequential(
            K.Normalize(mean=mean, std=std),
            K.RandomRotation(p=0.5, degrees=90),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomSharpness(p=0.5),
            data_keys=["image", "mask"],
        )

aug = AugmentationSequential(
            K.Normalize(mean=mean, std=std), data_keys=["image", "mask"]
        )

In [6]:
countries = ["austria", "belgium", "brazil", "cambodia", "corsica", "croatia", "denmark", "estonia", "finland", 
            "france", "germany", "india", "kenya", "latvia", "lithuania", "luxembourg", "netherlands", "portugal", 
            "rwanda", "slovakia", "slovenia", "south_africa", "spain", "sweden", "vietnam"]
train_ds = FTW(
        root="/home/airg/rbalogun/ftwfieldmapper/data/ftw_data",
        countries=countries,
        split="train",
        transforms=train_aug,
        load_boundaries=True,
        temporal_options="windowB" , # "windowA", "windowB" , "median"
    )
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=12)

val_ds = FTW(
        root="/home/airg/rbalogun/ftwfieldmapper/data/ftw_data",
        countries=countries,
        split="val",
        transforms=aug,
        load_boundaries=True,
        temporal_options="windowB" 
    )
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=12)

test_ds = FTW(
        root="/home/airg/rbalogun/ftwfieldmapper/data/ftw_data",
        countries=countries,
        split="test",
        transforms=aug,
        load_boundaries=True,
        temporal_options="windowB" 
    )
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=12)

Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  55779  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  6878  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  7805  samples


In [7]:
metrics = MetricCollection([
            JaccardIndex(task="multiclass", average="none", num_classes=3, ignore_index=3),
            Precision(task="multiclass", average="none", num_classes=3, ignore_index=3),
            Recall(task="multiclass", average="none", num_classes=3, ignore_index=3)
        ]).to(device)



In [8]:
import os
import yaml
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import pandas as pd

from src.ftw.datamodules import preprocess, FTW
from src.lacuna.utils import make_reproducible
from src.lacuna.compiler import ModelCompiler
from src.lacuna.models import unet, unet_att_d
from src.lacuna.evaluate import evaluate


import kornia.augmentation as K
from torchgeo.transforms import AugmentationSequential

def run_segmentation_ftw(config_path, do_train=True, test_only=False, do_prediction=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set reproducibility
    make_reproducible(seed=42)

    # Load configuration
    with open(config_path, "r") as file:
        params = yaml.safe_load(file)

    params_train = params["Train_Validate"]

    # Dataset loading
    mean = torch.tensor([0, 0, 0, 0])
    std = torch.tensor([3000, 3000, 3000, 3000])
    countries = params_train["countries"]

    train_aug = AugmentationSequential(
        K.Normalize(mean=mean, std=std),
        K.RandomRotation(p=0.5, degrees=90),
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomSharpness(p=0.5),
        data_keys=["image", "mask"],
    )

    val_aug = AugmentationSequential(
        K.Normalize(mean=mean, std=std),
        data_keys=["image", "mask"],
    )

    train_ds = FTW(
        root=params_train["data_path"],
        countries=countries,
        split="train",
        transforms=train_aug,
        load_boundaries=True,
        temporal_options="windowB",
    )
    val_ds = FTW(
        root=params_train["data_path"],
        countries=countries,
        split="val",
        transforms=val_aug,
        load_boundaries=True,
        temporal_options="windowB",
    )
    test_ds = FTW(
        root=params_train["data_path"],
        countries=countries,
        split="test",
        transforms=val_aug,
        load_boundaries=True,
        temporal_options="windowB",
    )

    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=params_train["train_batch"], shuffle=True, num_workers=12)
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=params_train["validate_batch"], shuffle=False, num_workers=12)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=params_train["validate_batch"], shuffle=False, num_workers=12)

    # Compile model

    model = ModelCompiler(
            model = eval(params_train['model'])( 
                    params_train['n_classes'],
                    params_train['channels'],
                    **params_train['model_kwargs']
                    ),
            working_dir=params_train["model_working_dir"],
            out_dir=params_train["model_out_dir"],
            buffer=params_train["buffer"],
            class_mapping=params_train["class_mapping"],
            gpu_devices=params_train.get("gpu_devices", [0]),
            use_sync_bn=params_train.get("use_sync_bn", False),
            model_init_type=params_train.get("model_init_type", "kaiming"),
            params_init=params_train.get("params_init", None),
            freeze_params=eval(params_train.get("freeze_params", "None"))
        )
    
    model.to(device)
    
    # If in full training mode
    if do_train and not test_only:
        print("Training...")
        model.fit(
            train_dataset=train_dl,
            val_dataset=val_dl,
            epochs=params_train["epoch"],
            optimizer_name=params_train["optimizer_name"],
            lr_init=params_train["learning_rate_init"],
            lr_policy=params_train["learning_rate_policy"],
            criterion=nn.CrossEntropyLoss(),
            momentum=params_train.get("momentum", 0.9),
            checkpoint_interval=params_train.get("checkpoint_interval", 5),
            early_stopping_patience=params_train.get("early_stopping_patience", 10),
            min_delta=params_train.get("min_delta", 0.001),
            warmup_period=params_train.get("warmup_period", 10),
        )

    if do_prediction:
        raise NotImplementedError("Prediction module not adapted yet. Coming soon.")

    # Final evaluation
    print("Final Evaluation...")
    model.accuracy_evaluation(test_dl, filename="final_metrics_ftw.csv")
    print(f"Evaluation saved at: {os.path.join(model.working_dir, model.out_dir)}")


In [9]:
import yaml
import time
import torch

import numpy as np
import pandas as pd 
from torch.utils.data import DataLoader

configPath = "./src/configs/config_lacuna.yaml"

with open(configPath, "r") as config:
        params = yaml.safe_load(config)

# parameters
params_train = params['Train_Validate']
params_test = params['Test']

In [16]:
from src.lacuna.compiler import ModelCompiler
from src.lacuna.models import unet_att_d

model = ModelCompiler(
        model = eval(params_train['model'])( 
                params_train['n_classes'],
                params_train['channels'],
                **params_train['model_kwargs']
                ),
        working_dir=params_train["model_working_dir"],
        out_dir=params_train["model_out_dir"],
        buffer=params_train["buffer"],
        class_mapping=params_train["class_mapping"],
        gpu_devices=params_train.get("gpu_devices", [0]),
        use_sync_bn=params_train.get("use_sync_bn", False),
        model_init_type=params_train.get("model_init_type", "kaiming"),
        params_init= None, #params_train.get("params_init", None),
        freeze_params=eval(params_train.get("freeze_params", "None"))
    )

# model.to(device)

--------- Vanilla Model compiled successfully ---------
----------GPU available----------
total number of trainable parameters: 157.9M


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import os
import pandas as pd
import numpy as np
import segmentation_models_pytorch as smp



# Define loss (Dice Loss + CrossEntropy Loss)
class DiceCrossEntropyLoss(nn.Module):
    def __init__(self, weight_dice=0.5, weight_ce=0.5):
        super().__init__()
        self.dice_loss = smp.losses.DiceLoss(mode='multiclass')
        self.ce_loss = nn.CrossEntropyLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, preds, targets):
        dice = self.dice_loss(preds, targets)
        ce = self.ce_loss(preds, targets)
        return self.weight_dice * dice + self.weight_ce * ce

criterion = DiceCrossEntropyLoss()

# Optimizer and Scheduler
optimizer = optim.Adam(model.parameters(), lr=params_train["learning_rate_init"])
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=5, verbose=True)

# Paths
num_epochs = params_train["epoch"]
early_stopping_patience = params_train.get("early_stopping_patience", 10)
checkpoint_dir = params_train.get("model_out_dir", "./checkpoints")
metrics_log_file = os.path.join(checkpoint_dir, "training_metrics.csv")
os.makedirs(checkpoint_dir, exist_ok=True)

# Class Mapping for evaluation
class_mapping = params_train["class_mapping"]  # Must be a dict like {0: "classA", 1: "classB", ...}

# Training
best_val_loss = float("inf")
early_stopping_counter = 0
metrics_records = []

for epoch in range(1, num_epochs + 1):
    print(f"\n--- Epoch {epoch}/{num_epochs} ---")

    model.train()
    train_loss = 0.0

    for imgs, labels, mask in tqdm(train_dataloader, desc="Training"):
        imgs, labels, mask = imgs.to(device), labels.to(device), mask.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * imgs.size(0)

    train_loss /= len(train_dataloader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, labels, mask in tqdm(val_dataloader, desc="Validating"):
            imgs, labels, mask = imgs.to(device), labels.to(device), mask.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)

    val_loss /= len(val_dataloader.dataset)
    scheduler.step(val_loss)

    # Evaluate using your Evaluator
    eval_metrics = Evaluator(
        model=model,
        dataloader=val_dataloader,
        num_classes=model.classes,
        class_mapping=class_mapping,
        device=device,
        buffer=params_train.get('buffer', None),
        out_name=os.path.join(checkpoint_dir, f"epoch{epoch}_metrics.csv")
    )

    epoch_metrics = {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "overall_accuracy": eval_metrics["Overall Accuracy"],
        "mean_accuracy": eval_metrics["Mean Accuracy"],
        "mean_iou": eval_metrics["Mean IoU"],
        "mean_precision": eval_metrics["mean Precision"],
        "mean_recall": eval_metrics["mean Recall"],
        "mean_f1_score": eval_metrics["Mean F1 Score"],
    }
    metrics_records.append(epoch_metrics)

    print(
        f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
        f"OA: {eval_metrics['Overall Accuracy']:.4f} | mIoU: {eval_metrics['Mean IoU']:.4f}"
    )

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        checkpoint_path = os.path.join(checkpoint_dir, f"best_model_epoch{epoch}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"  → New best model saved to {checkpoint_path}")
    else:
        early_stopping_counter += 1
        print(f"  → Early stopping counter: {early_stopping_counter}/{early_stopping_patience}")

    if early_stopping_counter >= early_stopping_patience:
        print("Early stopping triggered. Stopping training.")
        break

# Save metrics
metrics_df = pd.DataFrame(metrics_records)
metrics_df.to_csv(metrics_log_file, index=False)
print(f"\nAll training metrics saved to {metrics_log_file}")


In [11]:
metrics

MetricCollection(
  (MulticlassJaccardIndex): MulticlassJaccardIndex()
  (MulticlassPrecision): MulticlassPrecision()
  (MulticlassRecall): MulticlassRecall()
)

In [13]:
# training loop
all_tps = 0
all_fps = 0
all_fns = 0

for batch in tqdm(dl):
    images = batch["image"].to(device)
    masks = batch["mask"].to(device)

    with torch.inference_mode():
        outputs = model(images)

        outputs = outputs.argmax(dim=1)

        new_outputs = torch.zeros(outputs.shape[0], outputs.shape[1], outputs.shape[2], device=device)
        new_outputs[outputs == 2] = 0  # Boundary pixels
        new_outputs[outputs == 0] = 0  # Background pixels
        new_outputs[outputs == 1] = 1  # Crop pixels
        outputs = new_outputs

        metrics.update(outputs, masks)
        outputs = outputs.cpu().numpy().astype(np.uint8)
        masks = masks.cpu().numpy().astype(np.uint8)

        for i in range(len(outputs)):
            output = outputs[i]
            mask = masks[i]
            if postprocess:
                post_processed_output = out.copy()
                output = post_processed_output
            tps, fps, fns = get_object_level_metrics(mask, output, iou_threshold=iou_threshold)
            all_tps += tps
            all_fps += fps
            all_fns += fns

    results = metrics.compute()
    pixel_level_iou = results["MulticlassJaccardIndex"][1].item()
    pixel_level_precision = results["MulticlassPrecision"][1].item()
    pixel_level_recall = results["MulticlassRecall"][1].item()

    if all_tps + all_fps > 0:
        object_precision = all_tps / (all_tps + all_fps)
    else:
        object_precision = float('nan')

    if all_tps + all_fns > 0:
        object_recall = all_tps / (all_tps + all_fns)
    else:
        object_recall = float('nan')

    print(f"Pixel level IoU: {pixel_level_iou:.4f}")
    print(f"Pixel level precision: {pixel_level_precision:.4f}")
    print(f"Pixel level recall: {pixel_level_recall:.4f}")
    print(f"Object level precision: {object_precision:.4f}")
    print(f"Object level recall: {object_recall:.4f}")

    if out is not None:
        if not os.path.exists(out):
            with open(out, "w") as f:
                f.write("train_checkpoint,test_countries,pixel_level_iou,pixel_level_precision,pixel_level_recall,object_level_precision,object_level_recall\n")
        with open(out, "a") as f:
            f.write(f"{model},{countries},{pixel_level_iou},{pixel_level_precision},{pixel_level_recall},{object_precision},{object_recall}\n")


NameError: name 'dl' is not defined