Data can be found in: https://zenodo.org/records/13380286

In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from itertools import combinations, product
from matplotlib import pyplot as plt
from skimage.transform import resize
from scipy.ndimage import rotate
import dask
# Run dask computations with the "synchronous" scheduler
dask.config.set(scheduler="synchronous")

# Seeded random generator shared by the notebook cells below
# (shuffling, index sampling and rotation choices all draw from it)
rng = np.random.default_rng(seed=42)

In [None]:
# Select files used for training: one zarr store of cutouts per species
dir_cutouts = Path("/home/oku/Developments/XAI4GEO/data/cleaned_data/all_cutouts")
list_files = [
    "label142377591163_murumuru.zarr",
    "label244751236943_tucuma.zarr",
    "label174675723264_banana.zarr",
    "label999240878592_cacao.zarr",
    "label370414265344_fruit.zarr",
]

# For every store, report the image array sizes and the label values it holds
for file in list_files:
    data = xr.open_zarr(dir_cutouts / file)
    print(
        file,
        f"shape:{data['X'].sizes}",
        f"label:{np.unique(data['Y'].values)}",
        "---",
        sep="\n",
    )

In [None]:
# Manually investigate and select
# f_invextigate = list_files[2]
# ds_investigate = xr.open_zarr(dir_cutouts / f_invextigate)
# print(f_invextigate)
# # ds_investigate["X"].plot.imshow(col="sample", col_wrap=5)
# ds_investigate["X"].isel(sample = range(0, 50)).plot.imshow(col="sample", col_wrap=5)

## Manually select cutouts to make training pairs


In [None]:
# Murumuru: open the cutout store, keep the first 13 samples and preview them
ds_murumuru = xr.open_zarr(dir_cutouts / "label142377591163_murumuru.zarr").isel(
    sample=range(13)
)
ds_murumuru["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Tucuma: every cutout in the store is kept; preview them in a 5-column grid
ds_tucuma = xr.open_zarr(dir_cutouts / "label244751236943_tucuma.zarr")
ds_tucuma["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Banana: keep only the hand-picked cutouts listed in idx_banana
idx_banana = [1,5,10,13,14,15,22,27,33,38,43,45,46,49,51,55,57,59,62,70,77,78,81,91,96,101]
ds_banana = xr.open_zarr(dir_cutouts / "label174675723264_banana.zarr").isel(
    sample=idx_banana
)
# Randomise the order along the sample dimension with the seeded generator
ds_banana = ds_banana.sel(sample=rng.permutation(ds_banana.sample))
ds_banana["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Cacao: keep only the hand-picked cutouts listed in idx_cacao
idx_cacao = [7,10,23,25,28,48,50,51,53,56,58,62,67,69]
ds_cacao = xr.open_zarr(dir_cutouts / "label999240878592_cacao.zarr").isel(
    sample=idx_cacao
)
# Randomise the order along the sample dimension with the seeded generator
ds_cacao = ds_cacao.sel(sample=rng.permutation(ds_cacao.sample))
ds_cacao["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Fruit: keep only the hand-picked cutouts listed in idx_fruit
idx_fruit = [5,8,11,15,16,19,22,25,27,32,35,41,47,49,50,53,54]
ds_fruit = xr.open_zarr(dir_cutouts / "label370414265344_fruit.zarr").isel(
    sample=idx_fruit
)
# Randomise the order along the sample dimension with the seeded generator
ds_fruit = ds_fruit.sel(sample=rng.permutation(ds_fruit.sample))
ds_fruit["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Select three samples from each class, and plot them in a grid:
# one row per sample index, one column per species
fig, axes = plt.subplots(3, 5, figsize=(20, 10))

for idx, ax in enumerate(axes):
    ds_plot_murumuru = ds_murumuru.isel(sample=idx)
    ds_plot_tucuma = ds_tucuma.isel(sample=idx)
    ds_plot_banana = ds_banana.isel(sample=idx)
    ds_plot_cacao = ds_cacao.isel(sample=idx)
    ds_plot_fruit = ds_fruit.isel(sample=idx)
    for i, ds_plot in enumerate([ds_plot_murumuru, ds_plot_tucuma, ds_plot_banana, ds_plot_cacao, ds_plot_fruit]):
        # Cast to integers so the values are drawn as integer RGB values
        ds_plot["X"].astype(np.int64).plot.imshow(ax=ax[i])

        # Species titles only on the first row
        if idx == 0:
            ax[i].set_title(f"Species {i+1}", fontsize=20)
        # Sample labels only on the first column; all other panels lose their axes
        if i == 0:
            ax[i].set_ylabel(f"Sample {idx+1}", fontsize=20)
            ax[i].set_yticks([])
            ax[i].set_xticks([])
            ax[i].set_xlabel("")
        else:
            ax[i].axis("off")

# Adjust the layout once, after every panel has been drawn and decorated
# (it used to be recomputed for each of the 15 panels)
plt.tight_layout()

In [None]:
# Destination folder for the manually selected cutouts
dir_selected_cutouts = Path("/home/oku/Developments/XAI4GEO/data/cleaned_data/selected_cutouts")
list_files = [
    "label142377591163_murumuru.zarr",
    "label244751236943_tucuma.zarr",
    "label174675723264_banana.zarr",
    "label999240878592_cacao.zarr",
    "label370414265344_fruit.zarr",
]

# Report every selected dataset, then store it under the name of its source file
for ds, file in zip([ds_murumuru, ds_tucuma, ds_banana, ds_cacao, ds_fruit], list_files):
    print(
        file,
        f"shape:{ds['X'].sizes}",
        f"label:{np.unique(ds['Y'].values)}",
        sep="\n",
    )
    ds.chunk({"sample": 50}).to_zarr(dir_selected_cutouts/file, mode="w")
    print("---")

## Generate training pairs

In [None]:
# Folder holding the selected cutouts, one zarr store per species
dir_selected_cutouts = Path("/home/oku/Developments/XAI4GEO/data/cleaned_data/selected_cutouts")
list_files = [
    "label142377591163_murumuru.zarr",
    "label244751236943_tucuma.zarr",
    "label174675723264_banana.zarr",
    "label999240878592_cacao.zarr",
    "label370414265344_fruit.zarr",
]

# Open every store and stack them along the sample dimension
list_ds = [xr.open_zarr(dir_selected_cutouts / file) for file in list_files]
ds_all = xr.concat(list_ds, dim="sample")
ds_all

In [None]:
# Function to add Gaussian noise to an RGB image
def add_gaussian_noise(image, mean=0, std=25, generator=None):
    """Add Gaussian noise to an image while keeping empty pixels empty.

    Parameters
    ----------
    image : array-like
        Image values. Entries that are not strictly positive are treated as
        empty and are forced to 0 in the output.
    mean : float, optional
        Mean of the Gaussian noise, by default 0.
    std : float, optional
        Standard deviation of the Gaussian noise, by default 25.
    generator : np.random.Generator, optional
        Random generator used to draw the noise. When omitted, the
        module-level seeded ``rng`` is used so that the noise is reproducible.

    Returns
    -------
    array-like
        Noisy image with the shape of ``image``, values clipped to [0, 255]
        and cast to ``int64``.
    """
    if generator is None:
        # Use the seeded generator defined at the top of the notebook instead
        # of the unseeded global numpy random state
        generator = rng

    # Remember which entries carry data, so noise is not added to empty ones
    foreground = image > 0

    # Generate Gaussian noise and add it to the image
    noise = generator.normal(mean, std, image.shape)
    noisy_image = image + noise

    # Clip the image to ensure pixel values are in the range [0, 255],
    # then blank the entries that were empty in the input
    noisy_image = np.clip(noisy_image, 0, 255).astype(np.int64) * foreground

    return noisy_image

def random_crop(img_crop, crop_size=(108, 108), generator=None):
    """Cut a random window out of an image and resize it back to the input size.

    Parameters
    ----------
    img_crop : np.ndarray
        Image array with three axes; the first two are treated as the image
        plane (rows, columns) and the third is kept whole.
    crop_size : tuple of int, optional
        (rows, columns) of the window, by default (108, 108). It may be equal
        to the image size, in which case the whole image is used.
    generator : np.random.Generator, optional
        Random generator used to place the window. When omitted, the
        module-level seeded ``rng`` is used so that the crop is reproducible.

    Returns
    -------
    np.ndarray
        The window resized to the first two axis sizes of the input, with
        values clipped to [0, 255].

    Raises
    ------
    ValueError
        If the window is larger than the image along either axis.
    """
    n_rows, n_cols = img_crop.shape[:2]
    crop_rows, crop_cols = crop_size
    if crop_rows > n_rows or crop_cols > n_cols:
        raise ValueError("Crop size should be less than image size")

    if generator is None:
        # Use the seeded generator defined at the top of the notebook instead
        # of the unseeded global numpy random state
        generator = rng

    # The upper bound is exclusive, hence the +1: the largest valid offset
    # (image size - crop size) can be drawn, and a window as large as the
    # image yields offset 0 instead of an error
    top = generator.integers(0, n_rows - crop_rows + 1)
    left = generator.integers(0, n_cols - crop_cols + 1)
    window = img_crop[top:top + crop_rows, left:left + crop_cols, :]

    # preserve_range keeps the input value range, so integer images are not
    # rescaled before the clip below
    resized = resize(window, (n_rows, n_cols), preserve_range=True)
    return np.clip(resized, 0, 255)

def aug_img_pair(img):
    """Augment one image and return it together with its augmented variants.

    Parameters
    ----------
    img : xr.DataArray
        A single image with ``x`` and ``y`` dimensions. The array helpers used
        here operate on the first two axes of ``img.values`` and
        ``random_crop`` indexes a third axis (presumably (y, x, channel) --
        TODO confirm against the stored cutouts).

    Returns
    -------
    list of xr.DataArray
        Seven images, in this order: the original, a rotation by a random
        multiple of 90 degrees, a left-right flip, an up-down flip, a rotation
        by a random angle, a noisy copy rotated by half of that angle, and a
        random crop resized to the input size. The length of this list must
        stay equal to ``N_AUGMENTED``.
    """
    # NOTE: a noise-only variant used to be computed here but was never part
    # of the returned list; the unused computation has been removed.

    # Load the pixel values once; every variant below starts from them
    pixels = img.values

    # randomly rotate img 90, 180, 270
    img_rot = img.copy()
    img_rot.data = np.rot90(pixels, k=rng.integers(1, 4))

    # random rotate another angle which is not 90, 180, 270
    angle = rng.integers(1, 359)
    while angle in {90, 180, 270}:
        angle = rng.integers(1, 359)
    img_ran_rot_1 = img.copy()
    img_ran_rot_1.data = np.clip(rotate(pixels, angle, reshape=False), 0, 255)

    # add noise, then rotate by half of the random angle
    img_ran_rot_2 = img.copy()
    noisy_pixels = add_gaussian_noise(pixels, mean=0, std=25)
    img_ran_rot_2.data = np.clip(rotate(noisy_pixels, angle/2, reshape=False), 0, 255)

    # random crop
    img_crop = img.copy()
    img_crop.data = random_crop(pixels)

    # Flip the pixel values rather than selecting with a reversed slice: a
    # reversed slice would also reverse any x/y coordinate, and a later
    # xr.concat may re-align such coordinates and undo the flip.

    # flip left-right img
    img_flip_lr = img.copy()
    img_flip_lr.data = np.flip(pixels, axis=img.get_axis_num("x"))

    # flip up-down img
    img_flip_ud = img.copy()
    img_flip_ud.data = np.flip(pixels, axis=img.get_axis_num("y"))

    img_list = [
        img,
        img_rot,
        img_flip_lr,
        img_flip_ud,
        img_ran_rot_1,
        img_ran_rot_2,
        img_crop,
    ]

    return img_list

In [None]:
# Number of images aug_img_pair returns per input image (augmented images plus
# the original); generate_train_image_pairs relies on this to index the pairs
N_AUGMENTED = 7

In [None]:
def generate_train_image_pairs(images_dataset, labels_dataset):
    """Generate labelled image pairs and write them to Zarr, one store per class.

    Every class is first cut down to the size of the smallest class and each
    remaining image is expanded with ``aug_img_pair``. For every class label
    the function then builds

    * the similar pairs (pair label 1): every 2-combination of the augmented
      images of that class;
    * the dissimilar pairs (pair label 0): every combination of an augmented
      image of that class with an augmented image of each class that has a
      larger label value, so that no dissimilar pair is generated twice.

    The pairs of each class are written to ``./all_pairs_<label>.zarr`` in the
    current working directory; an existing store of that name is overwritten.

    Parameters
    ----------
    images_dataset : xr.DataArray
        Images with a ``sample`` dimension, can be backed by a dask array. The
        written stores are chunked along ``x``, ``y`` and ``channel``, so
        these dimensions must be present as well.
    labels_dataset : xr.DataArray
        Class label of every sample, along the ``sample`` dimension.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the number of augmented images of a class differs from
        ``min_n_sample * N_AUGMENTED``, i.e. ``N_AUGMENTED`` is out of sync
        with the number of images returned by ``aug_img_pair``.
    """
    labels_dataset = labels_dataset.compute()
    unique_labels = np.unique(labels_dataset.values)

    # Find the minimum number of samples over all classes
    min_n_sample = min(
        images_dataset.where(labels_dataset == label, drop=True).sizes["sample"]
        for label in unique_labels
    )
    n_per_class = min_n_sample * N_AUGMENTED

    # Generate a ds of augmented images, class by class
    list_ds_aug = []
    for label in unique_labels:
        imgs_curr = images_dataset.where(labels_dataset == label, drop=True)
        imgs_curr = imgs_curr.isel(sample=range(min_n_sample))
        list_imgs_curr_aug = []
        for idx_img in range(min_n_sample):
            list_imgs_curr_aug.extend(aug_img_pair(imgs_curr.isel(sample=idx_img)))
        da_curr_aug = xr.concat(list_imgs_curr_aug, dim="sample")
        if da_curr_aug.sizes["sample"] != n_per_class:
            # The pair indices below are built from n_per_class; a mismatch
            # would silently skip images or index out of range
            raise ValueError(
                f"Expected {n_per_class} augmented images per class, "
                f"got {da_curr_aug.sizes['sample']}; check N_AUGMENTED"
            )
        ds_curr_aug = xr.Dataset({"X": da_curr_aug})
        ds_curr_aug = ds_curr_aug.assign(
            Y=xr.DataArray(np.full(da_curr_aug.sizes["sample"], label), dims="sample")
        )
        list_ds_aug.append(ds_curr_aug)
    # Concatenate once instead of growing the dataset inside the loop
    images_dataset_aug = xr.concat(list_ds_aug, dim="sample")

    # Generate all possible pair index combinations within one class (similar)
    # and between two classes (dissimilar)
    pairs_idx_similar = list(combinations(range(n_per_class), 2))
    pairs_idx_dissimilar = list(product(range(n_per_class), repeat=2))

    label_dataset_aug = images_dataset_aug["Y"].compute()

    for label in unique_labels:
        pair_images = []
        pair_labels = []
        # Images of current label
        imgs_curr = images_dataset_aug.where(label_dataset_aug == label, drop=True)
        imgs_curr = imgs_curr.expand_dims(pair=1)
        # Make similar pairs
        for pair in pairs_idx_similar:
            pair_images_sim = xr.concat(
                [
                    imgs_curr.isel(sample=pair[0]).expand_dims(sample=1),
                    imgs_curr.isel(sample=pair[1]).expand_dims(sample=1),
                ],
                dim="pair",
            )

            pair_images.append(pair_images_sim)
            pair_labels.append(1)  # similar so label is 1

        # Find the other class labels to make dissimilar pairs
        label_other = np.setdiff1d(unique_labels, label)
        label_other = label_other[
            label_other > label
        ]  # Only select labels with higher value to avoid duplicate dissimilar pairs
        mask_da = xr.DataArray(np.isin(label_dataset_aug, label_other), dims="sample")

        # Augmented images of all the selected other classes
        imgs_curr_other = images_dataset_aug.where(mask_da, drop=True)
        labels_imgs_curr_other = imgs_curr_other["Y"].compute()

        # Make dissimilar pairs
        for label_other_curr in label_other:
            imgs_curr_other_curr = imgs_curr_other.where(
                labels_imgs_curr_other == label_other_curr, drop=True
            )
            for pair in pairs_idx_dissimilar:
                pair_images_dissim = xr.concat(
                    [
                        imgs_curr.isel(sample=pair[0]).expand_dims(sample=1),
                        imgs_curr_other_curr.isel(sample=pair[1]).expand_dims(sample=1),
                    ],
                    dim="pair",
                )
                pair_images.append(pair_images_dissim)
                pair_labels.append(0)  # dissimilar so label is 0

        # Write the pairs of the current label to their own zarr store;
        # mode="w" lets the function be re-run over existing stores
        ds_out = xr.concat(pair_images, dim="sample")
        ds_out = ds_out.assign(Y=(["sample"], pair_labels))
        ds_out = ds_out.chunk({"sample": 100, "pair": -1, "x": -1, "y": -1, "channel": -1})
        ds_out.to_zarr(f"./all_pairs_{label.astype(int)}.zarr", mode="w")

    return None

In [None]:
# Alternative call kept for reference; it uses every third of the first 90 samples only:
# generate_train_image_pairs(ds_all["X"].isel(sample=range(0,90,3)), ds_all["Y"].isel(sample=range(0,90,3)))
# Generate the pairs from all selected cutouts
generate_train_image_pairs(ds_all["X"], ds_all["Y"])

## Balance the training pairs

In [None]:
# Convert to a single zarr: training_pairs_unbalanced.zarr
# Load the generated per-label pair stores. The paths are sorted because glob
# yields them in arbitrary order, which would make the sample order of the
# merged dataset differ between runs and machines.
list_ds = []
for zarr_file in sorted(Path(".").glob("all_pairs_*.zarr")):
    ds = xr.open_zarr(zarr_file)
    list_ds.append(ds)
ds_images_pair = xr.concat(list_ds, dim="sample")

ds_images_pair_chunk = ds_images_pair.chunk({"sample": 100, "pair": -1, "x": -1, "y": -1, "channel": -1})
ds_images_pair_chunk.to_zarr("./training_pairs_unbalanced.zarr", mode="w")


In [None]:
# Re-open the merged, still unbalanced pairs from disk and display them
ds_images_pair = xr.open_zarr("./training_pairs_unbalanced.zarr")
ds_images_pair

In [None]:
# select similar and dissimilar pairs
# Load the pair labels once and reuse them for both selections
labels_pair = ds_images_pair["Y"].compute()
ds_images_pair_similar = ds_images_pair.where(labels_pair == 1, drop=True)
ds_images_pair_dissimilar = ds_images_pair.where(labels_pair == 0, drop=True)

In [None]:
# Count both kinds of pairs before balancing, then plot the pair labels
for caption, pair_label in (("similar pairs", 1), ("non similar pairs", 0)):
    print(f"{caption}: {np.sum(ds_images_pair['Y']==pair_label).values}")
ds_images_pair["Y"].plot()

In [None]:
# Randomly select in ds_images_pair_dissimilar to make it same size as ds_images_pair_similar
n_similar = ds_images_pair_similar.sizes["sample"]
n_dissimilar = ds_images_pair_dissimilar.sizes["sample"]
# Draw without replacement so that no dissimilar pair ends up in the balanced
# set more than once; replacement is only a fallback for the case where there
# are fewer dissimilar than similar pairs
idx_select = rng.choice(
    n_dissimilar,
    size=n_similar,
    replace=n_similar > n_dissimilar,
)
# Keep the selected indices in ascending order
idx_select = np.sort(idx_select)
ds_images_pair_dissimilar = ds_images_pair_dissimilar.isel(sample=idx_select, drop=True)

In [None]:
# Stack the similar pairs first and the selected dissimilar pairs after them
ds_images_pair_balanced = xr.concat(
    (ds_images_pair_similar, ds_images_pair_dissimilar),
    dim="sample",
)
ds_images_pair_balanced

In [None]:
# Count both kinds of pairs after balancing, then plot the pair labels
for caption, pair_label in (
    ("similar pairs after balancing", 1),
    ("non similar pairs after balancing", 0),
):
    print(f"{caption}: {np.sum(ds_images_pair_balanced['Y']==pair_label).values}")
ds_images_pair_balanced["Y"].plot()

In [None]:
# Build the index list used to reorder the balanced dataset:
# shuffle the similar and the dissimilar indices separately, then interleave
# them so that every similar pair is followed by a dissimilar pair

n_balanced = ds_images_pair_balanced.sizes["sample"]
# The similar pairs occupy the first half of the balanced dataset,
# the dissimilar pairs the second half
idx_similar = range(0, n_balanced//2)
idx_non_similar = range(n_balanced//2, n_balanced)
idx_similar_shuffled = rng.permutation(idx_similar)
idx_non_similar_shuffled = rng.permutation(idx_non_similar)

# Interleave the shuffled indices: similar, dissimilar, similar, ...
idx_mix = []
for idx_sim, idx_dissim in zip(idx_similar_shuffled, idx_non_similar_shuffled):
    idx_mix.extend((idx_sim, idx_dissim))
idx_mix

In [None]:
# Reorder the balanced dataset along the sample dimension with the interleaved
# index, so that similar and dissimilar pairs follow one after the other
ds_images_pair_balanced_shuffled = ds_images_pair_balanced.isel(sample=idx_mix)

In [None]:
# Check the number of similar and dissimilar pairs in the first and second half of dataset
half_size = ds_images_pair_similar.sizes["sample"]
labels_first_half = ds_images_pair_balanced_shuffled['Y'].isel(sample=range(half_size))
labels_second_half = ds_images_pair_balanced_shuffled['Y'].isel(
    sample=range(half_size, half_size*2)
)
print(f"similar pairs first half: {np.sum(labels_first_half==1).values}")
print(f"non similar pairs first half: {np.sum(labels_first_half==0).values}")
print(f"similar pairs second half: {np.sum(labels_second_half==1).values}")
# This count belongs to the second half (the message used to say "first half")
print(f"non similar pairs second half: {np.sum(labels_second_half==0).values}")

In [None]:
# Randomly plot 10 similar pairs, one pair per row
from matplotlib import pyplot as plt

ds_plot_similar = ds_images_pair_balanced_shuffled.where(
    ds_images_pair_balanced_shuffled["Y"].compute() == 1,
    drop=True,
)
idx_sel = rng.integers(0, ds_plot_similar.sizes["sample"], size=10)
ds_plot = ds_plot_similar.isel(sample=idx_sel)
fig, axs = plt.subplots(10, 2, figsize=(10, 60))
# Left column shows the first image of the pair, right column the second
for i, row_axes in enumerate(axs):
    for j, panel in enumerate(row_axes):
        ds_plot["X"].isel(sample=i, pair=j).astype("int").plot.imshow(ax=panel)

In [None]:
# Randomly plot 10 non-similar pairs, one pair per row
from matplotlib import pyplot as plt

ds_plot_dissimilar = ds_images_pair_balanced_shuffled.where(
    ds_images_pair_balanced_shuffled["Y"].compute() == 0,
    drop=True,
)
idx_sel = rng.integers(0, ds_plot_dissimilar.sizes["sample"], size=10)
ds_plot = ds_plot_dissimilar.isel(sample=idx_sel)
fig, axs = plt.subplots(10, 2, figsize=(10, 60))
# Left column shows the first image of the pair, right column the second
for i, row_axes in enumerate(axs):
    for j, panel in enumerate(row_axes):
        ds_plot["X"].isel(sample=i, pair=j).astype("int").plot.imshow(ax=panel)

## Write to Zarr

In [None]:
# Save the dataset
# ds_images_pair_balanced_shuffled = ds_images_pair_balanced_shuffled.chunk(
#     {"sample": 500, "pair": -1, "y": -1, "x": -1, "channel": -1}
# )
# ds_images_pair_balanced_shuffled.to_zarr("./traing_pairs.zarr", mode="w")



## (When memory is limited) Save the dataset in batches

In [None]:
# Write the shuffled dataset in parts of `batch` samples, one zarr store per
# part, named after the index of the first sample of the part
batch = 2000
for i in range(0, ds_images_pair_balanced_shuffled.sizes["sample"], batch):
    idx_end = min(i + batch, ds_images_pair_balanced_shuffled.sizes["sample"])
    # Load only the current part into memory
    ds_out = ds_images_pair_balanced_shuffled.isel(sample=range(i, idx_end)).compute()
    # A single re-chunk is enough here: the former xr.unify_chunks call
    # discarded its result, and the second identical re-chunk was a duplicate
    ds_out = ds_out.chunk({"sample": 500, "pair": -1, "y": -1, "x": -1, "channel": -1})
    ds_out.to_zarr(f"training_pairs_parts/training_pairs_{i}.zarr", mode="w")

In [None]:
# Merge into one zarr
# Open the parts in the order they were written. glob yields paths in
# arbitrary order and a plain string sort would put e.g. "10000" before
# "2000", so sort on the numeric sample offset at the end of each store name.
list_zarr_parts = sorted(
    Path("./training_pairs_parts").glob("training_pairs_*.zarr"),
    key=lambda zarr_path: int(zarr_path.stem.rsplit("_", 1)[-1]),
)
list_ds = []
for zarr_file in list_zarr_parts:
    ds = xr.open_zarr(zarr_file)
    list_ds.append(ds)
ds_images_pair_balanced_shuffled = xr.concat(list_ds, dim="sample")
ds_images_pair_balanced_shuffled

In [None]:
# Check again before saving: count both kinds of pairs in each half of the dataset
half_size = ds_images_pair_balanced_shuffled.sizes["sample"]//2
labels_first_half = ds_images_pair_balanced_shuffled['Y'].isel(sample=range(half_size))
labels_second_half = ds_images_pair_balanced_shuffled['Y'].isel(
    sample=range(half_size, half_size*2)
)
print(f"similar pairs first half: {np.sum(labels_first_half==1).values}")
print(f"non similar pairs first half: {np.sum(labels_first_half==0).values}")
print(f"similar pairs second half: {np.sum(labels_second_half==1).values}")
# This count belongs to the second half (the message used to say "first half")
print(f"non similar pairs second half: {np.sum(labels_second_half==0).values}")

In [None]:
# Re-chunk the merged dataset (500 samples per chunk, -1 for every other
# dimension) and write the final training set with mode="w"
ds_images_pair_balanced_shuffled = ds_images_pair_balanced_shuffled.chunk({"sample": 500, "pair": -1, "y": -1, "x": -1, "channel": -1})
ds_images_pair_balanced_shuffled.to_zarr("./training_pairs.zarr", mode="w")