# LEVIR-CD+ change detection example notebook

We start off by installing torchgeo. If you are running this on Colab, then you will need to restart your runtime after this step.

In [1]:
!pip install torchgeo




[notice] A new release of pip is available: 23.2.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import os

import torchgeo
from torchgeo.datasets import LEVIRCDPlus
from torchgeo.datasets.utils import unbind_samples
from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.datamodules.utils import dataset_split

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch import LightningDataModule

import torch
from torch.utils.data import DataLoader
import kornia.augmentation as K

import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision.transforms import Compose
from tqdm import tqdm

from sklearn.metrics import precision_score, recall_score

In [3]:
# Versions of torchgeo, lightning and torch this notebook is running with
tuple(module.__version__ for module in (torchgeo, pl, torch))

('0.5.1', '2.1.3', '2.1.1+cu118')

In [4]:
# The Trainer below requests accelerator="gpu", so this should be True
cuda_available = torch.cuda.is_available()
cuda_available

True

In [5]:
# some experiment parameters

experiment_name = "experiment_test"
experiment_dir = f"results/{experiment_name}"
os.makedirs(experiment_dir, exist_ok=True)

# Seed python/numpy/torch (and dataloader workers) so the random train/val split
# and the training augmentations are reproducible across runs
seed = 0
seed_everything(seed, workers=True)

batch_size = 2
learning_rate = 0.0001
gpu_id = 0
device = torch.device(f"cuda:{gpu_id}")
# On Windows, DataLoader workers are spawned and cannot import classes defined in
# this notebook (e.g. LEVIRCDPlusDataModule.preprocess), so load in the main process there
num_dataloader_workers = 0 if os.name == "nt" else 10
patch_size = 32
val_split_pct = 0.1 # how much of our training set to hold out as a validation set

In [6]:
# Download the dataset (skipped if it is already under `root`) and see how many
# images are in the train and test splits

train_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="train", download=True)
test_dataset = LEVIRCDPlus(root="data/LEVIRCDPlus", split="test", download=True)
len(train_dataset), len(test_dataset)

(637, 348)

## Exercise 1

Plot some examples from `train_dataset` (hint: TorchGeo datasets come with a `plot` method that will help you here).

## Define a PyTorch Lightning module and datamodule

PyTorch Lightning organizes the steps required for training deep learning models in `LightningModules`, and organizes dataset handling — from loading the data to creating dataloaders — in `LightningDataModules`. TorchGeo provides pre-built LightningDataModules for a handful of datasets, and pre-built "trainers" (i.e. LightningModules) for a variety of different types of tasks.

For this tutorial, we will lightly extend TorchGeo's `SemanticSegmentationTask` (just to add some custom plotting code) and create a new LightningDataModule for the LEVIR-CD+ dataset.

In [7]:
class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    """SemanticSegmentationTask that logs bi-temporal prediction figures.

    Model, loss, metrics and optimizer all come from the parent class. Only
    ``training_step`` and ``validation_step`` are overridden, so that the first
    sample of each of the first ``max_logged_batches`` batches per epoch is
    rendered with :meth:`plot` and sent to the logger.
    """

    # Number of leading batches per epoch (per stage) that get a logged figure
    max_logged_batches = 10

    def plot(self, sample):
        """Render both acquisitions, the ground-truth mask and the prediction.

        Args:
            sample: dict of CPU tensors with keys ``image``, ``mask`` and
                ``prediction``. Channels 0-2 of ``image`` are shown as the first
                acquisition and channels 3+ as the second (assumes two stacked
                RGB images, matching ``in_channels=6`` used in this notebook).

        Returns:
            The matplotlib figure; the caller is responsible for closing it.
        """
        panels = [
            (sample["image"][:3].permute(1, 2, 0), "Image 1"),
            (sample["image"][3:].permute(1, 2, 0), "Image 2"),
            (sample["mask"], "Mask"),
            (sample["prediction"], "Prediction"),
        ]

        fig, axs = plt.subplots(nrows=1, ncols=len(panels), figsize=(len(panels) * 5, 5))
        for ax, (data, title) in zip(axs, panels):
            ax.imshow(data)
            ax.axis("off")
            ax.set_title(title)

        fig.tight_layout()
        return fig

    def _log_prediction_figure(self, batch, batch_idx, y_hat_hard, stage):
        """Plot the first sample of ``batch`` and log it as ``image/{stage}/{batch_idx}``.

        Does nothing past the first ``max_logged_batches`` batches, or when the
        active logger's experiment has no ``add_figure`` (no logger configured,
        or a non-TensorBoard logger such as WandbLogger).
        Note: moves ``image``, ``mask`` and the added ``prediction`` of ``batch`` to the CPU.
        """
        if batch_idx >= self.max_logged_batches:
            return
        experiment = getattr(self.logger, "experiment", None) if self.logger else None
        if not hasattr(experiment, "add_figure"):
            return

        batch["prediction"] = y_hat_hard
        for key in ["image", "mask", "prediction"]:
            batch[key] = batch[key].cpu()
        sample = unbind_samples(batch)[0]
        fig = self.plot(sample)
        experiment.add_figure(
            f"image/{stage}/{batch_idx}", fig, global_step=self.global_step
        )
        plt.close(fig)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        """Compute the training loss, update/log metrics and log example figures.

        Returns:
            The training loss for this batch.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.criterion(y_hat, y)

        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)
        # Mirrors the parent's training_step: without this the metric collection
        # is updated but never reported
        self.log_dict(self.train_metrics)

        self._log_prediction_figure(batch, batch_idx, y_hat_hard, "train")

        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        """Compute the validation loss, update/log metrics and log example figures."""
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.criterion(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)
        # Mirrors the parent's validation_step: without this the metric collection
        # is updated but never reported
        self.log_dict(self.val_metrics)

        self._log_prediction_figure(batch, batch_idx, y_hat_hard, "val")

In [8]:
class LEVIRCDPlusDataModule(pl.LightningDataModule):
    """LightningDataModule for the LEVIR-CD+ change-detection dataset.

    Every sample is passed through :meth:`preprocess`. During training only, each
    batch is additionally augmented and randomly cropped to ``patch_size`` after
    it has been moved to the device (see :meth:`on_after_batch_transfer`).

    Args:
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Number of DataLoader worker processes.
        val_split_pct: Fraction of the ``train`` split held out for validation.
            When 0, the whole ``train`` split is also used for validation.
        patch_size: Size of the random training crops, passed to ``K.RandomCrop``.
        **kwargs: Forwarded to every ``LEVIRCDPlus`` constructor (e.g. ``root``).
    """

    def __init__(
        self,
        batch_size=32,
        num_workers=0,
        val_split_pct=0.2,
        patch_size=(256, 256),
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.patch_size = patch_size
        self.kwargs = kwargs

        # Built once here instead of on every training batch
        self.train_augmentations = K.AugmentationSequential(
            K.RandomRotation(p=0.5, degrees=90),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomCrop(self.patch_size),
            K.RandomSharpness(p=0.5),
            data_keys=["input", "mask"],
        )

    def on_after_batch_transfer(
        self, batch, batch_idx
    ):
        """Augment and crop a batch that has reached the device, while training only.

        Batches seen outside of training (validation, testing, or with no trainer
        attached) are returned unchanged.
        """
        trainer = getattr(self, "trainer", None)
        if trainer is None or not getattr(trainer, "training", False):
            return batch

        # Kornia expects masks to be floats with a channel dimension
        x = batch["image"]
        y = batch["mask"].float().unsqueeze(1)

        x, y = self.train_augmentations(x, y)

        # torchmetrics expects masks to be longs without a channel dimension
        batch["image"] = x
        batch["mask"] = y.squeeze(1).long()

        return batch

    def preprocess(self, sample):
        """Prepare one raw LEVIR-CD+ sample in place and return it.

        ``image`` is divided by 255 and cast to float (presumably 8-bit pixels
        to [0, 1]). Its first two dimensions are then merged, assumed to turn the
        (2, C, H, W) pair of acquisitions into one (2*C, H, W) tensor (TODO:
        confirm against LEVIRCDPlus). ``mask`` is cast to long for the loss.
        """
        sample["image"] = (sample["image"]  / 255.0).float()
        sample["image"] = torch.flatten(sample["image"], 0, 1)
        sample["mask"] = sample["mask"].long()
        return sample

    def prepare_data(self):
        """Download the dataset if needed; an explicit ``download`` kwarg takes precedence."""
        LEVIRCDPlus(split="train", **{"download": True, **self.kwargs})

    def setup(self, stage=None):
        """Create the train/val/test datasets, all transformed by :meth:`preprocess`.

        The validation set is split off the ``train`` split by ``dataset_split``,
        which is random; seed the global RNG for a reproducible split.
        """
        transforms = Compose([self.preprocess])

        train_dataset = LEVIRCDPlus(
            split="train", transforms=transforms, **self.kwargs
        )

        if self.val_split_pct > 0.0:
            self.train_dataset, self.val_dataset, _ = dataset_split(
                train_dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            self.train_dataset = train_dataset
            self.val_dataset = train_dataset

        self.test_dataset = LEVIRCDPlus(
            split="test", transforms=transforms, **self.kwargs
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)

## Setting up a training run

In [9]:
# Hyperparameters come from the config cell; `root` is forwarded to LEVIRCDPlus
datamodule = LEVIRCDPlusDataModule(
    batch_size=batch_size,
    num_workers=num_dataloader_workers,
    patch_size=(patch_size,) * 2,
    val_split_pct=val_split_pct,
    root="data/LEVIRCDPlus",
)

In [10]:
# torchgeo passes `patience` to its ReduceLROnPlateau scheduler. It must be smaller
# than the EarlyStopping patience, otherwise training stops before the learning
# rate is ever reduced.
lr_patience = 5
early_stopping_patience = 10

task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=True,
    in_channels=6,  # two stacked RGB acquisitions
    num_classes=2,
    loss="ce",
    ignore_index=None,
    lr=learning_rate,
    patience=lr_patience,
)

# Keep the checkpoint with the lowest val_loss (used by trainer.test(ckpt_path="best"))
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_dir,
    save_top_k=1,
    save_last=True,
)

early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=early_stopping_patience,
)

tb_logger = TensorBoardLogger(
    save_dir="logs/",
    name=experiment_name,
)

In [11]:
!pip install tensorboard




[notice] A new release of pip is available: 23.2.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [12]:
# Load (or reload, if it is already loaded) the TensorBoard notebook extension
%reload_ext tensorboard

In [13]:
# Show TensorBoard inline for "logs/", the save_dir given to TensorBoardLogger above
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 2456), started 0:21:58 ago. (Use '!kill 2456' to kill it.)

In [None]:
# Train on the configured GPU for 10-200 epochs; EarlyStopping may end the run sooner
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[gpu_id],
    min_epochs=10,
    max_epochs=200,
    default_root_dir=experiment_dir,
    logger=[tb_logger],
    callbacks=[checkpoint_callback, early_stopping_callback],
)

_ = trainer.fit(model=task, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | Unet             | 14.3 M
---------------------------------------------------
14.3 M    Trainable params
0         Non-trainable

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Documents\RoofSense\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


In [None]:
# Evaluate the checkpoint with the lowest val_loss (tracked by checkpoint_callback)
# rather than the last-epoch weights; Lightning restores those weights into `task`
trainer.test(model=task, datamodule=datamodule, ckpt_path="best")

## Custom evaluation loop to compute precision, recall, and F1 on the test set

In [None]:
# Example of how to load a trained task from a checkpoint file
# task = CustomSemanticSegmentationTask.load_from_checkpoint("results/...")
# datamodule.setup("test")

In [None]:
# Put the trained network on `device` and switch it to inference mode
model = task.model.to(device)
model.eval();

In [None]:
# Keep every `pixel_stride`-th pixel of each flattened batch so the arrays stay small;
# the metrics below are therefore estimates over a subsample of the test pixels
pixel_stride = 500

y_preds = []
y_trues = []
with torch.inference_mode():
    for batch in tqdm(datamodule.test_dataloader()):
        images = batch["image"].to(device)
        y_trues.append(batch["mask"].numpy().ravel()[::pixel_stride])
        y_pred = model(images).argmax(dim=1).cpu().numpy().ravel()[::pixel_stride]
        y_preds.append(y_pred)

y_preds = np.concatenate(y_preds)
y_trues = np.concatenate(y_trues)

In [None]:
# Binary scores for label 1; zero_division=0 reports 0 instead of warning when
# a score's denominator is empty
precision = precision_score(y_trues, y_preds, zero_division=0)
recall = recall_score(y_trues, y_preds, zero_division=0)
# Harmonic mean of precision and recall, defined as 0 when both are 0
f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

In [None]:
# Test-set precision, recall and F1 (computed on the subsampled pixels)
(precision, recall, f1)