In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timezone
import torch
from pathlib import Path
from torchgeo.datasets import LEVIRCDPlus

## Exploring the [LEVIR-CD+](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#id11) dataset

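- Make sure the `root` path used by `select_samples` points to your downloaded copy of the dataset (or pass `download=True` to `LEVIRCDPlus` to fetch it into that location).
- Run the cells in order, from top to bottom.
- The last cell plots randomly selected examples: both images of each pair alongside the corresponding mask.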

In [None]:
def select_samples(
    n,
    root=Path("/Users/samuel.omole/Desktop/repos/geofm_datasets/levircdplus"),
    seed=None,
):
    """
    Select ``n`` distinct random samples from the LEVIR-CD+ dataset.

    The default ``root`` is a machine-specific path; change it (or pass
    ``root=...``) to match the location of your downloaded dataset.

    Args:
        n: The number of samples to select. Must satisfy
            ``0 <= n <= len(dataset)``.
        root: Path (or path string) to the downloaded dataset. The directory
            is created if it does not already exist.
        seed: Optional seed for a private ``random.Random`` instance, making
            the selection reproducible without touching the global random
            state. When ``None`` (the default) the module-level ``random``
            state is used, exactly as before.

    Returns:
        A list of the ``n`` selected samples, in the order they were drawn.

    Raises:
        ValueError: If ``n`` is negative or larger than the dataset
            (raised by ``random.sample``).
    """
    root = Path(root)  # accept plain strings as well as Path objects
    root.mkdir(parents=True, exist_ok=True)
    ds = LEVIRCDPlus(root=root, download=False)  # download can be True to download dataset in root location
    num_samples = len(ds)
    print("Length of LEVIRCDPlus dataset:", num_samples)
    if num_samples:
        print("Sample keys:", ds[0].keys())  # check sample keys
    print(f"Selecting {n} random samples")
    sampler = random if seed is None else random.Random(seed)
    indices = sampler.sample(range(num_samples), n)
    # Index the dataset directly: each lookup loads only the requested sample,
    # whereas list(ds) would load the *entire* dataset once per selected index.
    return [ds[i] for i in indices]

def percentile_clip_and_scale(band, pmin=2, pmax=98):
    """
    Contrast-stretch a single band using percentile clipping.

    Values below the ``pmin`` percentile or above the ``pmax`` percentile are
    clipped. The result is then linearly rescaled to ``[0, 1]``, so it can be
    passed directly to ``matplotlib.pyplot.imshow``.

    Args:
        band: Array-like holding a single band; anything ``np.asarray`` can
            convert (e.g. a NumPy array or a CPU tensor).
        pmin: The lower percentile. Defaults to 2.
        pmax: The upper percentile. Defaults to 98.

    Returns:
        A float NumPy array with the same shape as ``band`` and values in
        ``[0, 1]``. If the clipped range is empty (e.g. a constant band), the
        result is all zeros instead of NaN/inf.
    """
    band = np.asarray(band, dtype=float)
    lo, hi = np.percentile(band, (pmin, pmax))
    span = hi - lo
    if span <= 0:
        # Flat band: (band - lo) / 0 would produce NaN/inf and a divide warning.
        return np.zeros_like(band)
    return (np.clip(band, lo, hi) - lo) / span

def make_rgb_from_bands(img, band_indices):
    """
    Build a displayable image from selected bands of a band-first image.

    Each requested band is contrast-stretched on its own with
    ``percentile_clip_and_scale``. The stretched bands are then stacked along
    a new last axis, in the order given, ready for ``plt.imshow``.

    Args:
        img: Three-dimensional array-like with bands first
            (``img.shape[0]`` is the band count).
        band_indices: Band indices to pick, in output order, e.g. ``(0, 1, 2)``
            for an RGB composite.

    Returns:
        A NumPy array of shape ``(H, W, len(band_indices))``.

    Raises:
        ValueError: If ``img`` is not three-dimensional.
        IndexError: If any index is negative or ``>= img.shape[0]``.
    """
    num_bands, _height, _width = img.shape
    requested = list(band_indices)  # materialize once so iterables are consumed a single time
    for band_idx in requested:
        if not 0 <= band_idx < num_bands:
            raise IndexError(f"band index {band_idx} out of range for C={num_bands}")
    stretched = [percentile_clip_and_scale(img[band_idx]) for band_idx in requested]
    return np.stack(stretched, axis=-1)

In [None]:
# Applying the helper functions above and plotting some samples.
SEED = 42       # fixed seed so the same samples are drawn on every Restart & Run All
N_SAMPLES = 10  # number of random image pairs to plot

random.seed(SEED)
samples = select_samples(N_SAMPLES)

# torchgeo loads LEVIR-CD+ PNGs in RGB channel order, so (0, 1, 2) displays
# true colour; the previous (2, 1, 0) swapped the red and blue channels.
# NOTE(review): re-check if your torchgeo version loads images differently.
color_idx = (0, 1, 2)

for sample_num, s in enumerate(samples, start=1):
    img1_rgb = make_rgb_from_bands(s["image1"], color_idx)
    img2_rgb = make_rgb_from_bands(s["image2"], color_idx)
    mask = s["mask"]

    # plot both images and the mask side by side
    fig, axs = plt.subplots(1, 3, figsize=(15, 7))
    axs[0].imshow(img1_rgb)
    axs[0].set_title("Image 1")

    axs[1].imshow(img2_rgb)
    axs[1].set_title("Image 2")

    axs[2].imshow(mask, cmap="gray")
    axs[2].set_title("Mask")

    for ax in axs:
        ax.axis("off")
    fig.suptitle(f"Sample {sample_num} of {len(samples)}")
    plt.show()
    plt.close(fig)  # release each figure so a long loop does not accumulate memory