In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from itertools import combinations, product

# One seeded generator is used for every random draw in this notebook, so a
# fresh "Restart & Run All" reproduces the same augmentation and shuffling.
SEED = 42
rng = np.random.default_rng(seed=SEED)

In [None]:
# Same labels as step 2: integer label id -> description string.
# The position in the list is the label id (0..5).
description_labels = dict(
    enumerate(
        [
            "netflora_murumuru_embrapa00",
            "netflora_buriti_emprapa00",
            "netflora_tucuma_emprapa00",
            "reforestree_banana",
            "reforestree_cacao",
            "reforestree_fruit",
        ]
    )
)

In [None]:
# Inventory of every cutout store: print its path, the sizes of `X`, and the
# distinct label values found in `Y`.
dir_cutouts = Path('/home/oku/Developments/XAI4GEO/data/cleaned_data/all_cutouts')
list_cutouts = sorted(dir_cutouts.glob('*.zarr'))
for f_cutouts in list_cutouts:
    data = xr.open_zarr(f_cutouts)
    report = (
        f_cutouts,
        f"shape:{data['X'].sizes}",
        f"label:{np.unique(data['Y'].values)}",
        "---",
    )
    print(*report, sep="\n")


## Manually select cutouts to make training pairs

The strategy is to take 50 images of each class, except label 4 (cacao), which is not significantly different from label 5 (fruit).

Label 1 only has 11, and label 3 has 29. So we will take all of them.

In [None]:
# Label 0: keep the first 13 cutouts of the store.
ds_label0 = xr.open_zarr(dir_cutouts / "label0_murumuru_embrapa00.zarr").isel(
    sample=range(13)
)

# Label 1: keep the first 50 cutouts of the store.
# NOTE(review): the markdown above quotes different per-label counts; confirm
# these ranges against the actual store sizes.
path_label1 = dir_cutouts / "label1_netflora_buriti_emprapa00.zarr"
ds_label1 = xr.open_zarr(path_label1)
ds_label1 = ds_label1.isel(sample=range(50))

# Label 2: use every cutout in the store.
ds_label2 = xr.open_zarr(dir_cutouts / "label2_netflora_tucuma_emprapa00.zarr")
# Hand-picked label-3 (banana) cutouts, listed as half-open [start, stop) runs
# of sample indices and flattened in the order given.
runs_label3 = [
    (0, 6),
    (24, 31),
    (35, 41),
    (47, 48),
    (51, 53),
    (54, 56),
    (93, 98),
    (99, 103),
    (576, 581),
    (596, 600),
]
idx_label3 = [i for start, stop in runs_label3 for i in range(start, stop)]
ds_label3 = xr.open_zarr(dir_cutouts / "label3_reforestree_banana.zarr")
ds_label3 = ds_label3.isel(sample=idx_label3)

# Hand-picked label-5 (fruit) cutouts. The order is kept exactly as chosen;
# it is not sorted (the 55..91 runs come after 144).
idx_label5 = [
    0, 3, 5,
    *range(18, 21),
    *range(29, 32),
    40, 47,
    *range(105, 111),
    *range(135, 142),
    *range(144, 145),
    *range(55, 58),
    *range(59, 61),
    *range(90, 92),
]
ds_label5 = xr.open_zarr(dir_cutouts / "label5_reforestree_fruit.zarr")
ds_label5 = ds_label5.isel(sample=idx_label5)

## Check all samples

In [None]:
# Visual check of the selected label-0 cutouts: one panel per sample, wrapping after 5 columns.
ds_label0["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Visual check of the selected label-1 cutouts: one panel per sample, wrapping after 5 columns.
ds_label1["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Visual check of the label-2 cutouts: one panel per sample, wrapping after 5 columns.
ds_label2["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Visual check of the hand-picked label-3 cutouts: one panel per sample, wrapping after 5 columns.
ds_label3["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Visual check of the hand-picked label-5 cutouts: one panel per sample, wrapping after 5 columns.
ds_label5["X"].plot.imshow(col="sample", col_wrap=5)

In [None]:
# Show the underlying array of `X` for the selected label-3 cutouts.
# NOTE(review): `.values` loads the whole array into memory; presumably this is
# a quick check of the value range / missing data — confirm, or narrow it down.
ds_label3['X'].values

## Save manually picked cutouts

In [None]:
# Write each manually selected subset to its own Zarr store, one per label.
# The loop variable is named `label_id` rather than `id`, so the builtin `id()`
# is not shadowed in the kernel namespace for the rest of the session.
# NOTE(review): stores are written relative to the current working directory,
# while the next section reads them from an absolute `selected_cutouts` folder;
# confirm the notebook is run from (or the stores are moved to) that folder.
for ds, label_id in zip(
    [ds_label0, ds_label1, ds_label2, ds_label3, ds_label5], [0, 1, 2, 3, 5]
):
    print(f"shape:{ds['X'].sizes}")
    print(f"label:{np.unique(ds['Y'].values)}")
    print("---")
    # Load the selection into memory before re-chunking and writing.
    ds = ds.compute()
    # Store 0 as the fill value of X.
    # NOTE(review): with xarray's default mask-and-scale decoding, values equal
    # to the fill value are presumably read back as NaN — confirm.
    ds['X'].encoding['_FillValue'] = 0
    ds.chunk({'sample': 50}).to_zarr(
        f"./label{label_id}_{description_labels[label_id]}.zarr", mode="w"
    )

## Generate training pairs

In [None]:
# Reload the manually selected cutouts, one Zarr store per label.
path_cutouts_selected = Path("/home/oku/Developments/XAI4GEO/data/cleaned_data/selected_cutouts")

ds_label0 = xr.open_zarr(path_cutouts_selected / "label0_netflora_murumuru_embrapa00.zarr")
ds_label1 = xr.open_zarr(path_cutouts_selected / "label1_netflora_buriti_emprapa00.zarr")
ds_label2 = xr.open_zarr(path_cutouts_selected / "label2_netflora_tucuma_emprapa00.zarr")
ds_label3 = xr.open_zarr(path_cutouts_selected / "label3_reforestree_banana.zarr")
ds_label5 = xr.open_zarr(path_cutouts_selected / "label5_reforestree_fruit.zarr")

# Print the sizes of X and the label values of each store as a sanity check.
# The label id was never used inside this loop, so only the datasets are
# iterated; this also avoids rebinding the builtin name `id`.
for ds in [ds_label0, ds_label1, ds_label2, ds_label3, ds_label5]:
    print(f"shape:{ds['X'].sizes}")
    print(f"label:{np.unique(ds['Y'].values)}")
    print("---")

In [None]:
# Stack the per-label selections into one dataset along the sample dimension.
selected_per_label = [ds_label0, ds_label1, ds_label2, ds_label3, ds_label5]
ds_all = xr.concat(selected_per_label, dim="sample")
ds_all

In [None]:
def _augmented_views(img):
    """Build the four views of one image that are used to form similar pairs.

    Parameters
    ----------
    img : xarray.DataArray
        A single image (one sample) with ``x`` and ``y`` dimensions.

    Returns
    -------
    list of xarray.DataArray
        ``[original, rotated, flipped left-right, flipped up-down]``, each with
        a new ``pair`` dimension of length 1 so two views can be concatenated
        along ``pair``.

    Notes
    -----
    Draws exactly one value from the notebook-level ``rng`` per call.
    """
    # Rotate by a random multiple of 90 degrees (k in {1, 2, 3}).
    # np.rot90 rotates in the plane of the first two array axes.
    # NOTE(review): assumes those axes are the spatial ones and the image is
    # square (otherwise assigning `.data` fails on the shape) — confirm.
    img_rot = img.copy()
    img_rot.data = np.rot90(img.values, k=rng.integers(1, 4))

    # Mirror along x (left-right) and along y (up-down).
    # NOTE(review): if x/y carry coordinate labels, the reversed labels travel
    # with the data and xr.concat may realign them — confirm the cutouts have
    # no x/y coordinates.
    img_flip_lr = img.isel(x=slice(None, None, -1))
    img_flip_ud = img.isel(y=slice(None, None, -1))

    return [
        view.expand_dims(pair=1)
        for view in (img, img_rot, img_flip_lr, img_flip_ud)
    ]


def generate_train_image_pairs(images_dataset, labels_dataset):
    """Generate similar / dissimilar image pairs for training.

    Similar pairs (label 1): for every label, each unordered pair among its
    first ``min_n_sample`` images is expanded into 4 x 4 = 16 pairs by
    combining the original and three augmented views of both images.

    Dissimilar pairs (label 0): for every label, each of its first
    ``min_n_sample`` images is paired with every image of every label with a
    higher value (so each cross-label pair is produced only once).

    ``min_n_sample`` is the number of samples of the least frequent label.

    Parameters
    ----------
    images_dataset : xarray.DataArray
        Images with a ``sample`` dimension; may be backed by a dask array.
    labels_dataset : xarray.DataArray
        One label per sample along ``sample``. It is computed eagerly here
        (``.compute()``), so it may be dask-backed as well.

    Returns
    -------
    pair_images : xarray.DataArray
        All pairs concatenated along ``sample``; each entry has a ``pair``
        dimension of length 2.
    pair_labels : numpy.ndarray
        1 for similar pairs, 0 for dissimilar pairs, in the same order as
        ``pair_images``.

    Raises
    ------
    ValueError
        If ``labels_dataset`` contains no labels.

    Notes
    -----
    Random rotations use the notebook-level ``rng``; results are reproducible
    only for a given state of that generator.
    """
    labels_dataset = labels_dataset.compute()

    # Count samples per label directly from the labels instead of masking the
    # whole image array once per label just to read its size.
    unique_labels, label_counts = np.unique(
        labels_dataset.values, return_counts=True
    )
    if unique_labels.size == 0:
        raise ValueError("labels_dataset is empty; cannot build image pairs")
    min_n_sample = int(label_counts.min())

    # All unordered index pairs within one (truncated) class.
    pairs_idx = list(combinations(range(min_n_sample), 2))

    pair_images = []
    pair_labels = []

    for label in unique_labels:
        # Images of the current label, truncated to the common class size.
        imgs_curr = images_dataset.where(labels_dataset == label, drop=True)
        imgs_curr = imgs_curr.isel(sample=range(min_n_sample))

        # Images of all labels with a higher value. Restricting to higher
        # values avoids producing the same cross-label pair twice.
        label_other = np.setdiff1d(unique_labels, label)
        label_other = label_other[label_other > label]
        mask_da = xr.DataArray(np.isin(labels_dataset, label_other), dims="sample")
        imgs_curr_other = images_dataset.where(mask_da, drop=True)

        # Similar pairs: every view of the first image against every view of
        # the second image. Views are built first-image-then-second-image so
        # the order of random draws is stable.
        for first_idx, second_idx in pairs_idx:
            views_first = _augmented_views(imgs_curr.isel(sample=first_idx))
            views_second = _augmented_views(imgs_curr.isel(sample=second_idx))
            for view_a, view_b in product(views_first, views_second):
                curr_pair = xr.concat([view_a, view_b], dim="pair")
                pair_images.append(curr_pair.expand_dims(sample=1))
                pair_labels.append(1)

        # Dissimilar pairs: every current image against every other-label image.
        pairs_idx_non_similar = product(
            range(imgs_curr.sizes["sample"]), range(imgs_curr_other.sizes["sample"])
        )
        for idx_curr, idx_other in pairs_idx_non_similar:
            curr_pair_diff = xr.concat(
                [
                    imgs_curr.isel(sample=idx_curr).expand_dims(pair=1),
                    imgs_curr_other.isel(sample=idx_other).expand_dims(pair=1),
                ],
                dim="pair",
            )
            pair_images.append(curr_pair_diff.expand_dims(sample=1))
            pair_labels.append(0)

    return xr.concat(pair_images, dim="sample"), np.array(pair_labels)

In [None]:
# Build all pairs, then attach the similarity flag as variable `Y` along `sample`.
images_pair, labels_pair = generate_train_image_pairs(ds_all["X"], ds_all["Y"])
ds_images_pair = images_pair.to_dataset().assign(Y=(["sample"], labels_pair))
ds_images_pair

In [None]:
# Split the generated pairs by their similarity flag (1 = similar, 0 = dissimilar).
ds_images_pair_similar = ds_images_pair.where(ds_images_pair['Y'] == 1, drop=True)
ds_images_pair_dissimilar = ds_images_pair.where(ds_images_pair['Y'] == 0, drop=True)

# Resample the similar pairs so both groups have the same size.
# Sampling is done without replacement whenever enough distinct similar pairs
# exist, so no pair is duplicated needlessly; replacement is only used when
# there are fewer similar than dissimilar pairs.
n_similar = ds_images_pair_similar.sizes['sample']
n_dissimilar = ds_images_pair_dissimilar.sizes['sample']
idx_similar = rng.choice(
    n_similar, size=n_dissimilar, replace=n_similar < n_dissimilar
)
ds_images_pair_similar = ds_images_pair_similar.isel(sample=idx_similar)

In [None]:
# Stack the two groups along `sample`: all similar pairs first, then all dissimilar pairs.
pair_groups = [ds_images_pair_similar, ds_images_pair_dissimilar]
ds_images_pair = xr.concat(pair_groups, dim="sample")
ds_images_pair

In [None]:
# Report how many pairs carry each flag, then plot the flag along `sample`.
pair_flags = ds_images_pair["Y"]
count_similar = np.sum(pair_flags == 1).values
count_non_similar = np.sum(pair_flags == 0).values
print(f"similar pairs: {count_similar}")
print(f"non similar pairs: {count_non_similar}")
pair_flags.plot()

In [None]:
# Positions of the two groups inside ds_images_pair: similar pairs occupy the
# first block, dissimilar pairs the rest. The first block's length is taken
# from the dissimilar group's size.
n_first_block = ds_images_pair_dissimilar.sizes["sample"]
idx_similar = range(0, n_first_block)
idx_non_similar = range(n_first_block, ds_images_pair.sizes["sample"])

# Shuffle each group independently (similar first, so the draw order is fixed).
idx_similar_shuffled = rng.permutation(idx_similar)
idx_non_similar_shuffled = rng.permutation(idx_non_similar)

# Interleave: similar, dissimilar, similar, dissimilar, ...
idx_mix = []
for sim_pos, non_sim_pos in zip(idx_similar_shuffled, idx_non_similar_shuffled):
    idx_mix += [sim_pos, non_sim_pos]
idx_mix

In [None]:
# Reorder ds_images_pair along `sample` so that similar and dissimilar pairs alternate.
ds_images_pair_shuffled = ds_images_pair.isel(sample=idx_mix)

# Balance check on the head of the shuffled dataset. The window is capped at
# the dataset size so a dataset with fewer than 3200 pairs does not raise an
# IndexError; for larger datasets the output is unchanged.
n_head = min(3200, ds_images_pair_shuffled.sizes['sample'])
flags_head = ds_images_pair_shuffled['Y'].isel(sample=range(n_head))
print(f"similar pairs first {n_head}: {np.sum(flags_head==1).values}")
print(f"non similar pairs first {n_head}: {np.sum(flags_head==0).values}")
ds_images_pair_shuffled['Y'].plot()

In [None]:
# Randomly plot 10 similar pairs, one pair per row (left: first image, right: second image).
from matplotlib import pyplot as plt

ds_images_pair_shuffled_similar = ds_images_pair_shuffled.where(
    ds_images_pair_shuffled["Y"] == 1, drop=True
)
idx_sel = rng.integers(0, ds_images_pair_shuffled_similar.sizes["sample"], size=10)
ds_plot = ds_images_pair_shuffled_similar.isel(sample=idx_sel)
fig, axs = plt.subplots(10, 2, figsize=(10, 60))
for row, row_axes in enumerate(axs):
    for member, ax in enumerate(row_axes):
        ds_plot["X"].isel(sample=row, pair=member).astype("int").plot.imshow(ax=ax)

In [None]:
# Randomly plot 10 non-similar pairs, one pair per row (left: first image, right: second image).
from matplotlib import pyplot as plt

# Named `_dissimilar` so this cell no longer overwrites the similar-pair subset
# with data of the opposite class.
ds_images_pair_shuffled_dissimilar = ds_images_pair_shuffled.where(
    ds_images_pair_shuffled['Y'] == 0, drop=True
)
idx_sel = rng.integers(0, ds_images_pair_shuffled_dissimilar.sizes["sample"], size=10)
ds_plot = ds_images_pair_shuffled_dissimilar.isel(sample=idx_sel)
fig, axs = plt.subplots(10, 2, figsize=(10, 60))
for i in range(10):
    ds_plot['X'].isel(sample=i, pair=0).astype('int').plot.imshow(ax=axs[i, 0])
    ds_plot['X'].isel(sample=i, pair=1).astype('int').plot.imshow(ax=axs[i, 1])

In [None]:
# Display the interleaved pair dataset before the final clean-up and export.
ds_images_pair_shuffled

In [None]:
# Replace every NaN in X with 0 before export.
images_without_nan = ds_images_pair_shuffled["X"].fillna(value=0)
ds_images_pair_shuffled["X"] = images_without_nan


In [None]:
# Chunk by blocks of 500 samples; every other dimension stays in a single chunk (-1).
chunk_spec = {"sample": 500}
chunk_spec.update({dim: -1 for dim in ("pair", "y", "x", "channel")})
ds_images_pair_shuffled = ds_images_pair_shuffled.chunk(chunk_spec)
ds_images_pair_shuffled

In [None]:
# Write the training pairs to a Zarr store in the working directory,
# overwriting any existing store (mode="w").
# NOTE(review): the store name is spelled "traing_pairs"; it is kept as-is
# because later steps may open it by this name — confirm before renaming.
ds_images_pair_shuffled.to_zarr("./traing_pairs.zarr", mode="w")