In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timezone
import torch
from pathlib import Path
from torchgeo.datasets import OSCD

In [None]:
def select_samples(n, root=Path("/Users/samuel.omole/Desktop/repos/geofm_datasets/oscd")):
    """
    Select n random samples from the OSCD dataset stored under root.

    Change root as appropriate to match the path to the downloaded dataset;
    the default is a machine-specific absolute path.

    Args:
        n: Number of samples to select. Must not exceed the dataset length.
        root: Path (or path string) to the dataset directory.

    Returns:
        A list of n tuples ``(region, dates, sample)`` where ``region`` and
        ``dates`` are read from ``ds.files[index]`` and ``sample`` is the
        dataset item ``ds[index]``.

    Raises:
        ValueError: If n is negative or larger than the dataset
            (raised by ``random.sample``).
    """
    root = Path(root)  # accept plain strings as well as Path objects
    root.mkdir(parents=True, exist_ok=True)
    ds = OSCD(root=root, download=False)  # download can be True to download dataset in root location
    print("Length of OSCD dataset:", len(ds))
    print("Sample keys:", ds[0].keys())  # check sample keys
    print(f"Selecting {n} random samples")
    indices = random.sample(range(len(ds)), n)
    # Index the dataset directly: materialising list(ds) for every selected
    # index would load the entire dataset n times just to keep one item each.
    return [
        (ds.files[ind]["region"], ds.files[ind]["dates"], ds[ind])
        for ind in indices
    ]

def percentile_clip_and_scale(band, pmin=2, pmax=98):
    """
    Clip a band to a percentile range and rescale it linearly to [0, 1].

    Args:
        band: The specific band to clip and scale (any array-like that
            ``np.asarray`` can convert).
        pmin: The min percentile. Defaults to 2.
        pmax: The max percentile. Defaults to 98.

    Returns:
        A NumPy array with the same shape as ``band``. Values at or below the
        pmin percentile map to 0 and values at or above the pmax percentile
        map to 1. If both percentiles are equal (e.g. a constant band) an
        all-zero array is returned instead of dividing by zero.
    """
    band = np.asarray(band)
    lo, hi = np.percentile(band, (pmin, pmax))
    span = hi - lo
    if span == 0:
        # Flat band: (band - lo) / 0 would yield NaN for every pixel and a
        # RuntimeWarning, so return a well-defined all-zero image instead.
        return np.zeros_like(band, dtype=float)
    return (np.clip(band, lo, hi) - lo) / span

def make_rgb_from_bands(img, band_indices):
    """
    Build a channels-last composite from selected bands of an image.

    Each requested band is taken from the first axis of ``img``, passed
    through ``percentile_clip_and_scale`` and the results are stacked along
    a new last axis.

    Args:
        img: The image; its ``shape`` must unpack into exactly three values,
            the first being the number of bands.
        band_indices: The indices to stack e.g., RGB

    Returns:
        The stacked, clipped and scaled bands.

    Raises:
        IndexError: If any index is negative or not smaller than the number
            of bands.
    """
    n_bands, _rows, _cols = img.shape

    def _pick(idx):
        # Negative indices are rejected explicitly rather than wrapping around.
        if not 0 <= idx < n_bands:
            raise IndexError(f"band index {idx} out of range for C={n_bands}")
        return img[idx]

    # Validate and collect every band first so a bad index fails before any scaling.
    selected = [_pick(idx) for idx in band_indices]
    scaled = [percentile_clip_and_scale(plane) for plane in selected]
    return np.stack(scaled, axis=-1)

In [None]:
# Apply the helper functions above and plot the selected samples.
# NOTE: sampling is unseeded, so every run shows different samples; call
# random.seed(<int>) before select_samples(...) for reproducible output.
samples = select_samples(14)
color_idx = (3, 2, 1)  # band indices stacked as the three display channels
for region, dates, sample in samples:
    # Build both composites before creating the figure so a bad sample
    # does not leave an empty figure behind.
    composites = [
        make_rgb_from_bands(sample[key], color_idx) for key in ("image1", "image2")
    ]
    mask = sample["mask"]

    # One row per sample: first image, second image, change mask.
    fig, axs = plt.subplots(1, 3, figsize=(15, 7))
    for ax, rgb, timestamp in zip(axs[:2], composites, (dates[0], dates[1])):
        ax.imshow(rgb)
        # dates are converted with int() and interpreted as POSIX timestamps in UTC.
        date = datetime.fromtimestamp(int(timestamp), tz=timezone.utc).isoformat()
        ax.set_title(f"{region.capitalize()}, {date}")
        ax.axis("off")

    axs[2].imshow(mask, cmap='gray')
    axs[2].set_title("Mask")
    axs[2].axis("off")
    plt.show()
    plt.close(fig)  # release the figure so figures do not accumulate across iterations