In [1]:
import os

# Restrict PyTorch to physical GPU 0. This must happen before torch is imported
# (next cell), otherwise the CUDA runtime has already enumerated devices.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [2]:
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask
from torchmetrics import JaccardIndex, MetricCollection, Precision, Recall
from tqdm import tqdm

from src.ftw.datamodules import preprocess, FTWDataModule
from src.ftw.datasets import FTW
from src.ftw.metrics import get_object_level_metrics
from src.ftw.trainers import CustomSemanticSegmentationTask

In [3]:
# GPU index *after* CUDA_VISIBLE_DEVICES masking (set in the first cell).
# With CUDA_VISIBLE_DEVICES="0" exactly one GPU is visible, so its index is 0.
# Previously `gpu` was referenced here without being defined (NameError on GPU hosts).
gpu = 0

if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")

In [4]:
# Report which device the rest of the notebook will run on.
print("Using device:", device)

Using device: cpu


In [None]:
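# NOTE: this cell is disabled — the datasets are built directly with FTW in the
# next cell. The "Loaded datamodule with: ..." output shown below is stale, left
# over from an earlier run. If re-enabling, do not rebind the imported
# `FTWDataModule` class name to an instance (use e.g. `datamodule = ...`).
# FTWDataModule = FTWDataModule(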
#     batch_size=64,
#     num_workers=4,
#     train_countries = ["france"],
#     val_countries = ["france"],
#     test_countries = ["france"],
#     temporal_options = "windowB" , # "windowA", "windowB" , "median"
#     load_boundaries = True,
#     num_samples = -1
# )
# # # FTWDataModule.prepare_data()
# # train_dataset = FTWDataModule.setup(stage="fit")
# # train_dataloader = DataLoader(
# #                             train_dataset,
# #                             batch_size=64,
# #                             shuffle=True,
# #                             num_workers=12
# #                             )

# # val_dataset = FTWDataModule.setup(stage="validate")
# # val_dataloader = DataLoader(
# #                             val_dataset,
# #                             batch_size=64,
# #                             shuffle=False,
# #                             num_workers=12
# #                             )
# # test_dataset = FTWDataModule.setup(stage="test")
# # test_dataloader = DataLoader(
# #                             test_dataset,
# #                             batch_size=64,
# #                             shuffle=False,
# #                             num_workers=12
# #                             )

Loaded datamodule with:
Train countries: ['france']
Val countries: ['france']
Test countries: ['france']
Number of samples: -1


In [27]:
# Data configuration. FTW_DATA_ROOT can override the on-disk location so the
# notebook isn't tied to one machine; the default is the original path.
FTW_ROOT = os.environ.get("FTW_DATA_ROOT", "/home/airg/rbalogun/ftwfieldmapper/data/ftw_data")
FTW_COUNTRIES = ["france"]
FTW_TEMPORAL_OPTION = "windowB"  # "windowA", "windowB" or "median"
BATCH_SIZE = 64
NUM_WORKERS = 12


def _make_ftw_split(split, shuffle):
    """Build the FTW dataset for one split and wrap it in a DataLoader.

    Args:
        split: split name passed to FTW ("train", "val" or "test").
        shuffle: whether the DataLoader shuffles samples each epoch.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = FTW(
        root=FTW_ROOT,
        countries=FTW_COUNTRIES,
        split=split,
        # transforms=preprocess,  # uncomment to apply the datamodule preprocessing
        load_boundaries=True,
        temporal_options=FTW_TEMPORAL_OPTION,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS)
    return dataset, loader


train_ds, train_dl = _make_ftw_split("train", shuffle=True)
val_ds, val_dl = _make_ftw_split("val", shuffle=False)
# The test loader gets its own name; it must not overwrite `val_dl`.
test_ds, test_dl = _make_ftw_split("test", shuffle=False)

Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  2773  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  339  samples
Loading 3 Class Masks, with Boundaries
Temporal option:  windowB
Selecting :  396  samples


In [26]:
# Peek at the keys of one validation sample. Indexing the dataset is the
# idiomatic spelling of `__getitem__`.
val_dl.dataset[0].keys()

dict_keys(['image', 'mask'])

In [None]:
# Per-class (average="none") pixel metrics over 3 classes; target label 3 is
# excluded via ignore_index. All three metrics share identical settings.
metrics = MetricCollection(
    [
        metric_cls(task="multiclass", average="none", num_classes=3, ignore_index=3)
        for metric_cls in (JaccardIndex, Precision, Recall)
    ]
).to(device)

In [None]:
# Evaluation (not training): pixel-level and object-level metrics for a model.
def evaluate_model(
    model,
    dl,
    metrics,
    device,
    *,
    model_predicts_3_classes=True,
    test_on_3_classes=False,
    postprocess=None,
    iou_threshold=0.5,
    out=None,
    checkpoint="",
    countries=(),
):
    """Evaluate a segmentation model on `dl` at pixel and object level.

    Pixel-level IoU / precision / recall are reported for class 1 (crop) from
    `metrics`. Object-level precision / recall are accumulated from
    `get_object_level_metrics` over every sample.

    Args:
        model: module mapping an image batch to per-class scores; assumed to be
            (B, C, H, W) because predictions are taken as argmax over dim 1.
        dl: DataLoader yielding dicts with "image" and "mask" tensors.
        metrics: MetricCollection containing MulticlassJaccardIndex,
            MulticlassPrecision and MulticlassRecall with average="none".
            It is reset first so repeated calls don't accumulate state.
        device: device to move images/masks to (model and metrics should
            already live there).
        model_predicts_3_classes: if True, predictions are collapsed to
            crop (1) vs. everything else (0); boundary (2) becomes 0.
        test_on_3_classes: scoring on 3 classes is impossible for a 2-class
            model; that combination raises ValueError.
        postprocess: optional callable applied to a copy of each (H, W) uint8
            prediction before object-level scoring.
        iou_threshold: IoU threshold forwarded to get_object_level_metrics.
        out: optional CSV path; one result row is appended, and the header is
            written first if the file does not exist yet.
        checkpoint: identifier for the "train_checkpoint" CSV column.
        countries: countries for the "test_countries" CSV column (a string, or
            an iterable joined with ";" so it stays one CSV field).

    Returns:
        dict with pixel_level_iou, pixel_level_precision, pixel_level_recall,
        object_level_precision and object_level_recall.

    Raises:
        ValueError: if test_on_3_classes is requested for a 2-class model.
    """
    import csv

    # Checked once up front instead of on every batch.
    if not model_predicts_3_classes and test_on_3_classes:
        raise ValueError("Cannot test on 3 classes when the model was trained on 2 classes")

    model.eval()
    metrics.reset()
    all_tps = all_fps = all_fns = 0

    for batch in tqdm(dl):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)
        with torch.inference_mode():
            outputs = model(images)

        outputs = outputs.argmax(dim=1)

        if model_predicts_3_classes:
            # Keep crop (1); fold background (0) and boundary (2) into 0.
            # Integer labels, same values as the old float zero-fill mapping.
            outputs = (outputs == 1).long()

        # NOTE(review): the dataset cells use load_boundaries=True, so masks may
        # still contain boundary labels (2) while predictions are 0/1 — confirm
        # this is the intended scoring setup.
        metrics.update(outputs, masks)
        outputs_np = outputs.cpu().numpy().astype(np.uint8)
        masks_np = masks.cpu().numpy().astype(np.uint8)

        for output, mask in zip(outputs_np, masks_np):
            if postprocess is not None:
                # Previously `out.copy()` copied the CSV path, not the prediction.
                output = postprocess(output.copy())
            tps, fps, fns = get_object_level_metrics(mask, output, iou_threshold=iou_threshold)
            all_tps += tps
            all_fps += fps
            all_fns += fns

    results = metrics.compute()
    # Index 1 selects the crop class from the per-class metric vectors.
    pixel_level_iou = results["MulticlassJaccardIndex"][1].item()
    pixel_level_precision = results["MulticlassPrecision"][1].item()
    pixel_level_recall = results["MulticlassRecall"][1].item()

    object_precision = all_tps / (all_tps + all_fps) if all_tps + all_fps > 0 else float('nan')
    object_recall = all_tps / (all_tps + all_fns) if all_tps + all_fns > 0 else float('nan')

    print(f"Pixel level IoU: {pixel_level_iou:.4f}")
    print(f"Pixel level precision: {pixel_level_precision:.4f}")
    print(f"Pixel level recall: {pixel_level_recall:.4f}")
    print(f"Object level precision: {object_precision:.4f}")
    print(f"Object level recall: {object_recall:.4f}")

    if out is not None:
        write_header = not os.path.exists(out)
        countries_field = countries if isinstance(countries, str) else ";".join(countries)
        # csv.writer quotes fields as needed, so values containing commas
        # cannot shift columns.
        with open(out, "a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow([
                    "train_checkpoint", "test_countries", "pixel_level_iou",
                    "pixel_level_precision", "pixel_level_recall",
                    "object_level_precision", "object_level_recall",
                ])
            writer.writerow([
                checkpoint, countries_field, pixel_level_iou, pixel_level_precision,
                pixel_level_recall, object_precision, object_recall,
            ])

    return {
        "pixel_level_iou": pixel_level_iou,
        "pixel_level_precision": pixel_level_precision,
        "pixel_level_recall": pixel_level_recall,
        "object_level_precision": object_precision,
        "object_level_recall": object_recall,
    }


# Example (needs a trained `model`; not run here):
# scores = evaluate_model(model, val_dl, metrics, device,
#                         checkpoint="path/to/model.ckpt", countries=["france"])
