In [1]:
import os

In [2]:
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask
from torchgeo.transforms import AugmentationSequential
from torchmetrics import JaccardIndex, MetricCollection, Precision, Recall
from tqdm import tqdm

import kornia.augmentation as K

from src.ftw.datamodules import preprocess, FTWDataModule
from src.ftw.datasets import FTW
from src.ftw.metrics import get_object_level_metrics
from src.ftw.trainers import CustomSemanticSegmentationTask

In [3]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Record the selected compute device in the notebook output.
print(f"Using device: {device}")

Using device: cuda:0


In [5]:
# Normalisation statistics for four channels: subtract 0 and divide by 3000.
mean = torch.tensor([0] * 4)
std = torch.tensor([3000] * 4)

# Training-time pipeline: normalisation followed by random geometric and
# sharpness augmentations, each applied with probability 0.5.
train_ops = [
    K.Normalize(mean=mean, std=std),
    K.RandomRotation(p=0.5, degrees=90),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomSharpness(p=0.5),
]
train_aug = AugmentationSequential(*train_ops, data_keys=["image", "mask"])

# Evaluation-time pipeline: normalisation only, no random augmentation.
aug = AugmentationSequential(
    K.Normalize(mean=mean, std=std),
    data_keys=["image", "mask"],
)

In [6]:
countries = ["austria", "belgium", "brazil", "cambodia", "corsica", "croatia", "denmark", "estonia", "finland",
            "france", "germany", "india", "kenya", "latvia", "lithuania", "luxembourg", "netherlands", "portugal",
            "rwanda", "slovakia", "slovenia", "south_africa", "spain", "sweden", "vietnam"]

# Dataset location. Defaults to the path this notebook was originally run with;
# set the FTW_DATA_ROOT environment variable to run it on another machine.
FTW_ROOT = os.environ.get("FTW_DATA_ROOT", "/home/airg/rbalogun/ftwfieldmapper/data/ftw_data")
TEMPORAL_OPTION = "windowB"  # "windowA", "windowB" , "median"
# NOTE(review): the recorded run of this notebook ends in a CUDA out-of-memory
# error on the first training batch; lowering BATCH_SIZE is the first knob to try.
BATCH_SIZE = 64
NUM_WORKERS = 12


def _build_ftw_split(split, transforms, shuffle):
    """Create the FTW dataset for one split together with its DataLoader.

    Args:
        split: Split name passed to ``FTW`` ("train", "val" or "test").
        transforms: Augmentation pipeline applied by the dataset.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    dataset = FTW(
        root=FTW_ROOT,
        countries=countries,
        split=split,
        transforms=transforms,
        load_boundaries=True,
        temporal_options=TEMPORAL_OPTION,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS)
    return dataset, loader


# Only the training split uses random augmentation and shuffling.
train_ds, train_dl = _build_ftw_split("train", train_aug, shuffle=True)
val_ds, val_dl = _build_ftw_split("val", aug, shuffle=False)
test_ds, test_dl = _build_ftw_split("test", aug, shuffle=False)

Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  55779  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  6878  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  7805  samples


In [7]:
# metrics = MetricCollection([
#             JaccardIndex(task="multiclass", average="none", num_classes=3, ignore_index=3),
#             Precision(task="multiclass", average="none", num_classes=3, ignore_index=3),
#             Recall(task="multiclass", average="none", num_classes=3, ignore_index=3)
#         ]).to(device)

In [8]:
import gc
import torch

# Collect unreachable Python objects first so that any tensors they hold are
# freed, then ask PyTorch to release its unused cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()


In [9]:
import yaml
import time
import torch

import numpy as np
import pandas as pd 
from torch.utils.data import DataLoader

# Experiment configuration file, resolved relative to the notebook's working directory.
configPath = "./src/configs/config_lacuna.yaml"

with open(configPath) as config:
    params = yaml.safe_load(config)

# Split the configuration into its training/validation and test sections.
params_train = params["Train_Validate"]
params_test = params["Test"]

In [10]:
from src.lacuna.compiler import ModelCompiler
from src.lacuna.models import unet_att_d

# SECURITY NOTE: eval() executes whatever string the config file contains.
# It is used here to turn the configured model name into the imported model
# constructor; only run this with a trusted config file.
model_cls = eval(params_train['model'])
network = model_cls(
    params_train['n_classes'],
    params_train['channels'],
    **params_train['model_kwargs']
)

compiler_kwargs = dict(
    working_dir=params_train["model_working_dir"],
    out_dir=params_train["model_out_dir"],
    buffer=params_train["buffer"],
    class_mapping=params_train["class_mapping"],
    gpu_devices=params_train.get("gpu_devices", [0]),
    use_sync_bn=params_train.get("use_sync_bn", False),
    model_init_type=params_train.get("model_init_type", "kaiming"),
    params_init=None,  # params_train.get("params_init", None),
    # Same eval() caveat as above: the config value is evaluated as Python.
    freeze_params=eval(params_train.get("freeze_params", "None")),
)
model = ModelCompiler(model=network, **compiler_kwargs)

# From here on `model` is the compiler's inner `.model` moved to `device`; the
# ModelCompiler wrapper itself is no longer reachable through this name.
model = model.model.to(device)

--------- Vanilla Model compiled successfully ---------
----------GPU available----------
total number of trainable parameters: 157.9M


In [14]:
from src.lacuna.losses import tversky_focal 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import os
import yaml
import pandas as pd
from src.lacuna.evaluate import evaluate

class TupleWrapper(torch.utils.data.Dataset):
    """Expose a dict-sample dataset as ``(image, mask)`` tuples.

    Each item of the wrapped dataset must be a mapping with ``"image"`` and
    ``"mask"`` entries. The mask is cast to ``long`` and a leading singleton
    axis, if present, is removed; the image is returned unchanged.
    """

    def __init__(self, ds):
        # Keep a reference to the wrapped dataset; nothing is copied.
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        # NOTE(review): the image is passed through as-is, whereas the training
        # loop squeezes a singleton axis from its image batches -- confirm that
        # the consumer of these tuples expects the unsqueezed shape.
        image = record["image"]
        mask = record["mask"].long().squeeze(0)
        return image, mask


def run_segmentation_ftw(config_path, model, train_dl, val_dl):
    """Train ``model`` with mixed precision and evaluate it after every epoch.

    The training loop iterates over ``train_dl`` (dict batches with ``"image"``
    and ``"mask"`` entries). After each epoch the model is evaluated on both the
    training and validation datasets through ``evaluate``; per-epoch metric
    files and a final summary CSV (``final_metrics_ftw.csv``) are written to
    the output directory.

    Args:
        config_path: Path to a YAML file with a ``Train_Validate`` section
            providing ``train_batch``, ``validate_batch``, ``learning_rate_init``,
            ``epoch``, ``n_classes``, ``class_mapping`` and optionally ``buffer``.
        model: Network to train. It must already live on the device chosen
            here (CUDA when available, otherwise CPU). If it exposes
            ``working_dir`` / ``out_dir`` attributes they define the output
            directory; otherwise the config keys ``model_working_dir`` /
            ``model_out_dir`` are used.
        train_dl: DataLoader used for optimisation; its ``.dataset`` is also
            re-wrapped for the per-epoch training-set evaluation.
        val_dl: DataLoader whose ``.dataset`` is used for validation.

    Returns:
        None. Results are written to disk and printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only meaningful on CUDA; disable it explicitly on CPU.
    use_amp = device.type == "cuda"

    # Load config
    with open(config_path, "r") as f:
        params = yaml.safe_load(f)
    params_train = params["Train_Validate"]

    # Resolve the output directory. A bare network has no `working_dir` /
    # `out_dir` attributes, so fall back to the values stored in the config.
    working_dir = getattr(model, "working_dir", None) or params_train["model_working_dir"]
    out_dir_name = getattr(model, "out_dir", None) or params_train["model_out_dir"]
    output_dir = os.path.join(working_dir, out_dir_name)
    # Create it up front: per-epoch metric paths inside it are passed to
    # `evaluate` long before the final summary is saved.
    os.makedirs(output_dir, exist_ok=True)

    # Wrap datasets for evaluation (tuple samples, no shuffling).
    # NOTE(review): `train_dl.dataset` keeps whatever transforms the training
    # dataset was built with, so training-set metrics are computed on those
    # transformed samples -- confirm this is intended.
    train_dl_eval = torch.utils.data.DataLoader(
        TupleWrapper(train_dl.dataset),
        batch_size=params_train["train_batch"],
        shuffle=False,
        num_workers=12
    )
    val_dl_eval = torch.utils.data.DataLoader(
        TupleWrapper(val_dl.dataset),
        batch_size=params_train["validate_batch"],
        shuffle=False,
        num_workers=12
    )

    optimizer = optim.Adam(model.parameters(), lr=params_train["learning_rate_init"])
    # The scheduler is stepped with validation Mean IoU, which is better when
    # higher, so it must watch for a plateau of the maximum ("max" mode).
    scheduler = ReduceLROnPlateau(optimizer, mode="max", patience=5, verbose=True)
    criterion = tversky_focal.LocallyWeightedTverskyFocalLoss()

    scaler = GradScaler(enabled=use_amp)  # Mixed precision scaler

    def _evaluate_split(loader, split_name, epoch_number):
        """Run `evaluate` on one split and return its metrics mapping."""
        out_csv = os.path.join(output_dir, f"{split_name}_epoch{epoch_number}_metrics.csv")
        # Metric computation never needs gradients; disabling them keeps
        # activation memory from piling up during evaluation.
        with torch.no_grad():
            return evaluate(
                model=model,
                dataloader=loader,
                num_classes=params_train["n_classes"],
                class_mapping=params_train["class_mapping"],
                device=device,
                buffer=params_train.get("buffer", None),
                out_name=out_csv
            )

    metrics_records = []

    for epoch in range(params_train["epoch"]):
        epoch_number = epoch + 1

        model.train()
        train_loss = 0.0
        for batch in tqdm(train_dl, desc=f"Epoch {epoch_number} Training"):
            # Drop the singleton axis at position 1 from images and masks.
            inputs = batch["image"].to(device).squeeze(1)
            labels = batch["mask"].long().squeeze(1).to(device)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        # Guard against an empty loader instead of dividing by zero.
        train_loss /= max(len(train_dl), 1)

        # Evaluate on training set
        print(f"Epoch {epoch_number} evaluating on training set")
        model.eval()
        train_metrics = _evaluate_split(train_dl_eval, "train", epoch_number)

        # Evaluate on validation set
        print(f"Epoch {epoch_number} evaluating on validation set")
        val_metrics = _evaluate_split(val_dl_eval, "val", epoch_number)

        scheduler.step(val_metrics["Mean IoU"])

        print(f"Train Loss: {train_loss:.4f} | "
              f"Train OA: {train_metrics['Overall Accuracy']:.4f} | "
              f"Val OA: {val_metrics['Overall Accuracy']:.4f} | "
              f"Val IoU: {val_metrics['Mean IoU']:.4f}")

        metrics_records.append({
            "Epoch": epoch_number,
            "Train Loss": train_loss,
            "Train OA": train_metrics["Overall Accuracy"],
            "Train IoU": train_metrics["Mean IoU"],
            "Train Precision": train_metrics["mean Precision"],
            "Train Recall": train_metrics["mean Recall"],
            "Train F1": train_metrics["Mean F1 Score"],
            "Val OA": val_metrics["Overall Accuracy"],
            "Val IoU": val_metrics["Mean IoU"],
            "Val Precision": val_metrics["mean Precision"],
            "Val Recall": val_metrics["mean Recall"],
            "Val F1": val_metrics["Mean F1 Score"],
        })

    # Save metrics to CSV
    metrics_path = os.path.join(output_dir, "final_metrics_ftw.csv")
    pd.DataFrame(metrics_records).to_csv(metrics_path, index=False)
    print(f"Training and validation metrics saved to {metrics_path}")


In [15]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [16]:
# Launch training. The function re-reads the YAML config from this path itself.
# NOTE(review): this is an absolute path, whereas the config cell earlier in the
# notebook loads "./src/configs/config_lacuna.yaml" -- confirm both refer to the same file.
config_path = "/home/airg/rbalogun/ftwfieldmapper/src/configs/config_lacuna.yaml"
run_segmentation_ftw(config_path, model, train_dl, val_dl)

Epoch 1 Training:   0%|                           | 0/872 [00:35<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 