# Make training pairs for refinement training

This step generates the training data (image pairs) for the refinement training.

It takes data in `few_shot_learning`, which can be retrieved from the link below:

- `few_shot_learning`: [link](https://zenodo.org/records/13833791/files/few_shot_learning.zip?download=1)

The output of this step is available on Zenodo:

 - `n_fold_x_validation`: [link](https://zenodo.org/records/13833540/files/n_fold_x_validation.zip?download=1)

In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from itertools import combinations, product, combinations_with_replacement
import os
from sklearn.model_selection import KFold
from scipy.ndimage import rotate
from skimage.transform import resize
from matplotlib import pyplot as plt

# Shared, seeded random generator used by the sampling/augmentation cells below
rng = np.random.default_rng(seed=42)

# The relative paths used in this notebook are resolved against the project root.
# Set the TREE_CLASSIFICATION_ROOT environment variable to your local checkout;
# otherwise the original author's path is used.
os.chdir(os.environ.get('TREE_CLASSIFICATION_ROOT', '/data/Projects/2024_Invasive_species/Tree_Classification'))
print(os.getcwd())

#### For now, only the selected cutout samples are used, to counterbalance the uneven distribution of samples across classes

In [None]:
path_cutouts_selected = Path("./notebooks/data_agu/selected_cutouts")


def _open_selected(store_name):
    """Open one per-class cutout zarr store located in ``path_cutouts_selected``."""
    return xr.open_zarr(path_cutouts_selected / store_name)


# One store per base class (note: no ``ds_label4`` is loaded in this cell)
ds_label0 = _open_selected("label142377591163_murumuru.zarr")
ds_label1 = _open_selected("label244751236943_tucuma.zarr")
ds_label2 = _open_selected("label174675723264_banana.zarr")
ds_label3 = _open_selected("label999240878592_cacao.zarr")
ds_label5 = _open_selected("label370414265344_fruit.zarr")

# Stack all base classes along the sample dimension
base_parts = [ds_label0, ds_label1, ds_label2, ds_label3, ds_label5]
ds_all_base = xr.concat(base_parts, dim='sample')
ds_all_base

#### Set up the test samples for few-shot learning
The data for the few-shot learning is stored in

`./notebooks/data_agu/few_shot_learning/` (the path used in the cell below).

In [None]:
# Root folder of the few-shot data; the merged label store lives in a subfolder
test_data_path = Path('./notebooks/data_agu/few_shot_learning')
test_ds_path = os.path.join(test_data_path, 'Tree_labels_merged', 'tree_labels_merged.zarr')

# The merged dataset is small, so load it eagerly into memory
test_ds = xr.open_zarr(test_ds_path).compute()
test_ds

#### Select and rescale the cutouts to reduce the bias caused by the correlation between canopy size and class

In [31]:
def _nonzero_extent(cutout):
    """Return ``(x_size, y_size)`` for a single cutout.

    ``x_size`` is the number of ``x`` positions, and ``y_size`` the number of
    ``y`` positions, that contain at least one non-zero value (NaN counts as
    non-zero). The reductions use dimension names, so the result does not
    depend on the axis order of ``cutout``.
    """
    is_nonzero = cutout != 0
    x_size = int(is_nonzero.any(dim=[d for d in cutout.dims if d != 'x']).sum())
    y_size = int(is_nonzero.any(dim=[d for d in cutout.dims if d != 'y']).sum())
    return x_size, y_size


unique_species_ids = np.unique(test_ds['Y'].values)
unique_species_names = [test_ds.attrs[str(species_id)]['ESPECIE'] for species_id in unique_species_ids]

for species_id, name in zip(unique_species_ids, unique_species_names):
    ds_species = test_ds.where(test_ds['Y'] == species_id, drop=True)
    # Species with 5 or fewer samples are not written out
    if ds_species.sizes['sample'] <= 5:
        continue
    # Central window, selected by coordinate label
    ds_species = ds_species.sel(x=slice(64, 192), y=slice(64, 192))

    processed = []
    for i in range(ds_species.sizes['sample']):
        ds_i = ds_species.isel(sample=i)
        cutout = ds_i['X']
        cutout.values = np.clip(cutout.values, 0, 255)

        # Extent of the non-zero (canopy) part along x and y
        x_size, y_size = _nonzero_extent(cutout)

        # Enlarge small canopies so canopy size is less tied to the class
        if x_size < 64 and y_size < 64:
            # Select the non-zero part (assumes it is centred at positional index 64)
            cutout = cutout.isel(
                x=range(64 - int(x_size / 2), 64 + int(x_size / 2)),
                y=range(64 - int(y_size / 2), 64 + int(y_size / 2)),
            )

            # Interpolate the non-zero part to 64x64 pixels
            cutout = cutout.interp(
                x=np.linspace(cutout.x.min(), cutout.x.max(), 64),
                y=np.linspace(cutout.y.min(), cutout.y.max(), 64),
            )

            # Label the 64x64 patch with coordinates 32..95 ...
            cutout['x'] = range(32, 96)
            cutout['y'] = range(32, 96)

            # ... then interpolate onto 0..127 (zero outside the patch), i.e. pad to a centred 128x128 image
            cutout = cutout.interp(x=range(0, 128), y=range(0, 128), kwargs={'fill_value': 0})

            # NOTE(review): the rescaled cutout is labelled x/y = 0..127. If ds_i's
            # x/y labels differ (e.g. the labels chosen by `sel` above), xarray's
            # label alignment on assignment/concat could crop the image or pad it with NaN.
            # Confirm the coordinate labels.
            ds_i['X'] = cutout

        processed.append(ds_i)

    # A single concat instead of re-concatenating the growing dataset for every sample
    ds_output = xr.concat(processed, dim='sample')

    name = name.replace(' ', '_')
    ds_output.to_zarr(test_data_path / 'Tree_labels_merged' / f'{species_id}_{name}.zarr', mode='w')

In [None]:
test_file_path = test_data_path / 'Tree_labels_merged'

# Keep only the per-species stores written above. Exclude the merged source
# store by name instead of assuming it sorts last in the directory listing.
merged_store_name = Path(test_ds_path).name
test_file_list = sorted(
    store for store in test_file_path.glob('*.zarr') if store.name != merged_store_name
)

test_file_list

In [None]:
# Sanity check: show the image dimensions and the label(s) stored in each per-species file
for zarr_store in test_file_list:
    ds_check = xr.open_zarr(zarr_store)
    print(zarr_store)
    x_dims = ds_check['X'].sizes
    print(f"shape:{x_dims}")
    labels_found = np.unique(ds_check['Y'].values)
    print(f"label:{labels_found}")
    print("---")

In [27]:
# Because the smallest number of samples across the test tree species is 6,
# M = 6 samples are drawn per species so that every class is equally represented.
M = 6
select_balanced_data_list = []
for f_cutouts in test_file_list:
    data = xr.open_zarr(f_cutouts)
    # Draw WITHOUT replacement: with replacement the same cutout could be picked
    # twice and then land in both the support and the test split of a fold.
    selected_samples = rng.choice(data.sizes["sample"], size=M, replace=False)
    select_balanced_data_list.append(data.isel(sample=selected_samples))

#### Partitioning the test data for few-shot learning with n-fold cross-validation.

In [8]:
def n_fold_cross_idx(k, M):
    """Split ``M`` sample positions into k-shot support/test folds.

    ``KFold`` without shuffling cuts ``range(M)`` into ``int(M/k)`` contiguous
    folds. For each fold, the held-out indices (about ``k`` of them) form the
    *support* set and all remaining indices form the *test* set. This swaps the
    usual train/test roles.

    To randomise the fold assignment, build ``KFold`` with ``shuffle=True`` and
    a fixed ``random_state`` instead.

    Returns
    -------
    support_idx_list, test_idx_list : list of np.ndarray
        One array of sample positions per fold.
    """
    splitter = KFold(n_splits=int(M/k), shuffle=False)
    folds = list(splitter.split(np.arange(M)))
    support_idx_list = [held_out for _, held_out in folds]
    test_idx_list = [remaining for remaining, _ in folds]
    return support_idx_list, test_idx_list

In [9]:
def n_class_k_shot(data_list, k, M=6):
    """Split every class into per-fold support/test sets and merge the classes.

    Parameters
    ----------
    data_list : list of xr.Dataset
        One dataset per class, each with ``M`` samples along ``sample``.
    k : int
        Number of support samples per class in each fold.
    M : int, default 6
        Number of samples per class; ``int(M/k)`` folds are produced.

    Returns
    -------
    support_sets, test_sets : list of xr.Dataset
        Entry ``i`` holds the support (respectively test) samples of all classes
        for fold ``i``, concatenated along ``sample``.
    """
    support_idx_list, test_idx_list = n_fold_cross_idx(k, M)

    support_sets = []
    test_sets = []
    for fold_no, (support_idx, test_idx) in enumerate(zip(support_idx_list, test_idx_list), start=1):
        fold_support = []
        fold_test = []
        for data in data_list:
            support_set = data.isel(sample=support_idx)
            test_set = data.isel(sample=test_idx)
            fold_support.append(support_set)
            fold_test.append(test_set)
            print(f"Fold {fold_no}:")
            print(f"Support indices: {support_idx}")
            print(f"Test indices: {test_idx}")
            print(f"support_set shape:{support_set['X'].sizes}")
            print(f"support_set label:{np.unique(support_set['Y'].values)}")
            print(f"test_set shape:{test_set['X'].sizes}")
            print(f"test_set label:{np.unique(test_set['Y'].values)}")
            print("-" * 20)

        # Merge the per-class pieces of this fold into one dataset each
        support_sets.append(xr.concat(fold_support, dim='sample'))
        test_sets.append(xr.concat(fold_test, dim='sample'))

    return support_sets, test_sets

In [10]:
# Function to add Gaussian noise to an image, leaving zero (background) pixels at zero
def add_gaussian_noise(image, mean=0, std=25, *, rng=None):
    """Add Gaussian noise to the positive pixels of an image.

    The previous implementation called ``np.random.seed(42)`` on every call.
    That produced the exact same noise pattern for every image and reset the
    global NumPy random state as a side effect. Noise is now drawn from a
    per-call ``Generator`` and the global state is left untouched.

    Parameters
    ----------
    image : array_like
        Image values (presumably in [0, 255]).
    mean, std : float
        Mean and standard deviation of the Gaussian noise.
    rng : np.random.Generator, optional
        Noise source. When omitted, a freshly seeded generator is used, so
        repeated calls produce different noise. Pass a seeded generator for
        reproducible output.

    Returns
    -------
    np.ndarray
        ``uint8`` image clipped to [0, 255]. Elements that were <= 0 (or NaN)
        in the input are 0 in the output.
    """
    image = np.asarray(image)
    generator = np.random.default_rng() if rng is None else rng

    # Remember which pixels were originally positive; background stays zero
    non_zeros = image > 0
    noise = generator.normal(mean, std, image.shape)

    # Clip to the valid pixel range before casting to uint8
    noisy_image = np.clip(image + noise, 0, 255).astype(np.uint8)
    return noisy_image * non_zeros

In [11]:
def random_crop(img, crop_size=(108, 108), *, rng=None):
    """Crop a random window from ``img`` and resize it back to the input size.

    Parameters
    ----------
    img : np.ndarray
        Image whose first two axes are spatial (rows, cols). Any trailing axes
        (e.g. channels) are kept.
    crop_size : tuple of int
        ``(rows, cols)`` of the crop window; must not exceed the image size.
    rng : np.random.Generator, optional
        Source of the crop offset. Defaults to the global ``np.random`` state.

    Returns
    -------
    np.ndarray
        The crop resized to ``img.shape[:2]``. If the crop contains only zeros
        (e.g. it fell entirely on padding), the clipped original image is
        returned instead.
    """
    n_rows, n_cols = img.shape[:2]
    crop_rows, crop_cols = crop_size
    assert crop_rows <= n_rows and crop_cols <= n_cols, "Crop size should be less than image size"
    img = np.clip(img, 0, 255)

    integers = np.random.randint if rng is None else rng.integers
    # Upper bound is exclusive: the "+ 1" lets the window touch the last row/col
    # and avoids an error when the crop is exactly the image size.
    top = integers(0, n_rows - crop_rows + 1)
    left = integers(0, n_cols - crop_cols + 1)
    img_crop = img[top:top + crop_rows, left:left + crop_cols, ...]

    # preserve_range keeps the 0-255 scale (integer input would otherwise be rescaled to [0, 1])
    img_crop = resize(img_crop, (n_rows, n_cols), preserve_range=True)

    # Fall back to the full image only when the crop is empty
    if np.any(img_crop):
        return img_crop
    return img

In [12]:
def aug_img_pair(img):
    """Build a list of augmented variants of a single image.

    Parameters
    ----------
    img : xr.DataArray
        One image with spatial dimensions ``x`` and ``y`` (e.g. a single
        ``sample`` of ``X``).

    Returns
    -------
    list of xr.DataArray
        Seven images, in order:

        1. the original;
        2. a random 90/180/270-degree rotation;
        3. a left-right flip;
        4. an up-down flip;
        5. a rotation by a random angle that is not a multiple of 90 degrees;
        6. a noisy copy rotated by half that angle;
        7. a random crop resized back to the original size.
    """
    # Random 90, 180 or 270 degree rotation
    img_rot = img.copy()
    img_rot.data = np.rot90(img.values, k=rng.integers(1, 4))

    # Random rotation by another angle which is not 90, 180 or 270
    angle = rng.integers(1, 359)
    while angle in {90, 180, 270}:
        angle = rng.integers(1, 359)
    img_ran_rot_1 = img.copy()
    img_ran_rot_1.data = np.clip(rotate(img_ran_rot_1.values, angle, reshape=False), 0, 255)

    # Add noise, then rotate by half of the random angle
    img_ran_rot_2 = img.copy()
    img_ran_rot_2.data = add_gaussian_noise(img_ran_rot_2.values, mean=0, std=25)
    img_ran_rot_2.data = np.clip(rotate(img_ran_rot_2.values, angle / 2, reshape=False), 0, 255)

    # Random crop
    img_crop = img.copy()
    img_crop.data = random_crop(img_crop.values)

    # Flip the pixel data but keep the x/y coordinate labels. An `isel` with a
    # reversed slice reverses the labels together with the data, and xarray
    # aligns objects by label when they are combined (e.g. in xr.concat), which
    # can silently undo such a flip.
    img_flip_lr = img.copy(data=np.flip(img.values, axis=img.get_axis_num('x')))
    img_flip_ud = img.copy(data=np.flip(img.values, axis=img.get_axis_num('y')))

    # A noise-only variant is not part of the list; noise is covered by img_ran_rot_2
    img_list = [
        img,
        img_rot,
        img_flip_lr,
        img_flip_ud,
        img_ran_rot_1,
        img_ran_rot_2,
        img_crop,
    ]

    return img_list

In [13]:
def generate_similar_pair(imgs_curr, pair=None):
    """Build positive (label 1) pairs from two images of the same class.

    The first image is sample ``pair[0]`` of ``imgs_curr``. The second is
    sample ``pair[1]`` after adding Gaussian noise and a random
    90/180/270-degree rotation. With ``pair=None`` both start from sample 0, so
    the second image is a corrupted copy of the first. Each image is expanded
    with ``aug_img_pair``, and every (first-variant, second-variant)
    combination becomes one pair.

    Returns
    -------
    pair_images : list of xr.DataArray
        Each with a new leading ``sample`` dim of length 1 and a ``pair`` dim of length 2.
    pair_labels : list of int
        All 1, one per pair.
    """
    first_idx, second_idx = (0, 0) if pair is None else (pair[0], pair[1])

    # First image: used as-is
    anchor = imgs_curr.isel(sample=first_idx)
    anchor_variants = aug_img_pair(anchor)

    # Second image: slightly corrupted (noise + right-angle rotation) to act as a pseudo new image
    partner = imgs_curr.isel(sample=second_idx)
    partner.values = add_gaussian_noise(partner.values, mean=0, std=25)
    partner.data = np.rot90(partner.values, k=rng.integers(1, 4))
    partner_variants = aug_img_pair(partner)

    # Exhaustively pair every variant of the first image with every variant of the second
    pair_images = [
        xr.concat([left, right], dim="pair").expand_dims(sample=1)
        for left, right in product(anchor_variants, partner_variants)
    ]
    pair_labels = [1] * len(pair_images)

    print('number of pairs positive samples', len(pair_labels))
    return pair_images, pair_labels

In [14]:
def _augment_each_sample(images):
    """Augment every sample of ``images`` with ``aug_img_pair`` and stack all variants along ``sample``."""
    variants = []
    for sample_idx in range(images.sizes["sample"]):
        variants.extend(aug_img_pair(images.isel(sample=sample_idx)))
    return xr.concat(variants, dim="sample")


def generate_dissimilar_pair(imgs_curr, imgs_curr_other):
    """Build negative (label 0) pairs between two groups of images.

    Both groups are first augmented sample by sample (current group first).
    Then every augmented image of ``imgs_curr`` is paired with every augmented
    image of ``imgs_curr_other``.

    Returns
    -------
    pair_images : list of xr.DataArray
        Each with a leading ``sample`` dim of length 1 and a ``pair`` dim of length 2.
    pair_labels : list of int
        All 0, one per pair.
    """
    imgs_curr = _augment_each_sample(imgs_curr)
    imgs_curr_other = _augment_each_sample(imgs_curr_other)

    n_curr = imgs_curr.sizes["sample"]
    n_other = imgs_curr_other.sizes["sample"]

    # Cartesian product of the two augmented groups
    pair_images = [
        xr.concat(
            [
                imgs_curr.isel(sample=a).expand_dims(pair=1),
                imgs_curr_other.isel(sample=b).expand_dims(pair=1),
            ],
            dim="pair",
        ).expand_dims(sample=1)
        for a, b in product(range(n_curr), range(n_other))
    ]
    pair_labels = [0] * len(pair_images)

    print('imgs_curr', n_curr)
    print('imgs_curr_other', n_other)
    print('number of pairs negtive samples', len(pair_labels))
    return pair_images, pair_labels

In [15]:
def generate_refine_image_paris(support_set, base_set):
    """Generate labelled image pairs for the k-shot refinement training.

    For every class present in ``support_set``:

    * positive pairs (label 1) are built from that class's support images
      with ``generate_similar_pair``;
    * negative pairs (label 0) pair the same support images with every image
      of any *other* class, taken from both ``support_set`` and ``base_set``,
      with ``generate_dissimilar_pair``.

    Pairs made only of ``base_set`` images are never generated.

    Parameters
    ----------
    support_set : xr.Dataset
        Support samples, with image variable ``X`` and label variable ``Y``
        along the ``sample`` dimension.
    base_set : xr.Dataset
        Base samples with the same ``X``/``Y`` layout; used only as a source
        of negative (dissimilar) partners.

    Returns
    -------
    pair_images : xr.DataArray
        All generated pairs concatenated along ``sample``.
    pair_labels : np.ndarray
        1 for similar pairs and 0 for dissimilar pairs, in the same order as
        ``pair_images``.
    """
    support_labels = support_set['Y'].compute()
    support_images = support_set['X']
    unique_support_labels = np.unique(support_labels.values)

    # Negatives may come from the base set or from the other support classes
    combined_set = xr.concat([support_set, base_set], dim="sample")
    combined_images = combined_set['X']
    combined_labels = combined_set['Y'].compute()
    unique_combined_labels = np.unique(combined_labels.values)

    # Smallest number of support samples in any class. Counting once per unique
    # label (instead of once per sample) gives the same minimum without O(n^2) work.
    min_n_sample = min(
        int((support_labels == label).sum()) for label in unique_support_labels
    )

    pair_images = []
    pair_labels = []

    # Loop over support classes only, so no base-base pairs are produced
    for label in unique_support_labels:
        imgs_curr = support_images.where(support_labels == label, drop=True)

        # Keep the first min_n_sample samples of this class; a no-op when the
        # support set is already balanced across classes.
        imgs_curr = imgs_curr.isel(sample=range(min_n_sample))

        # Every image (support or base) whose label differs from the current one
        label_other = np.setdiff1d(unique_combined_labels, label)
        mask_da = xr.DataArray(np.isin(combined_labels, label_other), dims="sample")
        imgs_curr_other = combined_images.where(mask_da, drop=True)

        if min_n_sample >= 2:
            # Several support samples per class: pair every sample with every
            # sample, including itself (combinations_with_replacement)
            for pair in combinations_with_replacement(range(min_n_sample), 2):
                pos_images, pos_labels = generate_similar_pair(imgs_curr, pair=pair)
                pair_images.extend(pos_images)
                pair_labels.extend(pos_labels)
        else:
            # A single support sample: pair it with a corrupted copy of itself
            pos_images, pos_labels = generate_similar_pair(imgs_curr, pair=None)
            pair_images.extend(pos_images)
            pair_labels.extend(pos_labels)

        neg_images, neg_labels = generate_dissimilar_pair(imgs_curr, imgs_curr_other)
        pair_images.extend(neg_images)
        pair_labels.extend(neg_labels)

    return xr.concat(pair_images, dim="sample"), np.array(pair_labels)

#### Two steps for the few-shot learning
* Using the test set to do the zero-shot test
* Using the support set to do the few-shot learning

**Note: in this notebook, k is set manually to 1, 2 and 3 in turn to generate all the partitions.**

Later, the final performance results are averaged across all folds for each k.

In [None]:
# Partitioning: k is the number of support samples (shots) per class,
# set manually to 1, 2 and 3 in turn.
k = 1
# Pass M explicitly so the folds follow the per-species sample count used when
# the balanced data was drawn, instead of relying on the matching default.
support_sets, test_sets = n_class_k_shot(select_balanced_data_list, k, M=M)
print(f"{len(support_sets)} folds")

In [17]:
# Shuffle the pairs and optionally plot some examples
def _plot_random_pairs(ds_pairs, npair):
    """Plot ``npair`` randomly drawn pairs of ``ds_pairs`` (one row per pair) and check each image is 128x128."""
    idx_sel = rng.integers(0, ds_pairs.sizes["sample"], size=npair)
    ds_plot = ds_pairs.isel(sample=idx_sel)
    # squeeze=False keeps `axs` two-dimensional even when npair == 1
    _, axs = plt.subplots(npair, 2, figsize=(10, 6 * npair), squeeze=False)
    for row in range(npair):
        for side in (0, 1):
            img = ds_plot['X'].isel(sample=row, pair=side)
            img.astype('int').plot.imshow(ax=axs[row, side])
            assert img.shape[0] == img.shape[1] == 128, "image wh is not 128"


def shuffle_pairs(ds_images_pair, npair=10, plot_examples=True):
    """Interleave shuffled similar and dissimilar pairs, optionally plotting examples.

    Assumes the first half of ``sample`` holds similar pairs and the second half
    dissimilar pairs (the layout returned by ``get_balanced_pairs``). Each half
    is shuffled and the two are interleaved (similar, dissimilar, similar, ...).
    If the halves differ in length, the extra sample is dropped by ``zip``.

    Parameters
    ----------
    ds_images_pair : xr.Dataset
        Pairs with image variable ``X`` and label variable ``Y`` along ``sample``.
    npair : int, default 10
        Number of randomly drawn similar and dissimilar pairs to plot.
    plot_examples : bool, default True
        Whether to plot example pairs.

    Returns
    -------
    xr.Dataset
        The reordered pairs, with NaNs in ``X`` replaced by 0.
    """
    n_total = ds_images_pair.sizes['sample']
    # Shuffle each half independently
    idx_similar_shuffled = rng.permutation(range(0, n_total // 2))
    idx_non_similar_shuffled = rng.permutation(range(n_total // 2, n_total))
    idx_mix = [
        val
        for pair in zip(idx_similar_shuffled, idx_non_similar_shuffled)
        for val in pair
    ]

    ds_images_pair_shuffled = ds_images_pair.isel(sample=idx_mix)
    ds_images_pair_shuffled['X'] = ds_images_pair_shuffled['X'].fillna(0)

    if plot_examples:
        # Similar pairs first, then dissimilar pairs
        for target_label in (1, 0):
            ds_subset = ds_images_pair_shuffled.where(
                ds_images_pair_shuffled['Y'] == target_label, drop=True
            )
            _plot_random_pairs(ds_subset, npair)

    return ds_images_pair_shuffled

In [18]:
# Balance the negative and positive pairs
def get_balanced_pairs(images_pair, labels_pair):
    """Return a dataset with equally many similar and dissimilar pairs.

    The larger of the two groups is randomly subsampled, without replacement,
    down to the size of the smaller one.

    Parameters
    ----------
    images_pair : xr.DataArray
        Pair images along ``sample``. Its name becomes the data variable of the
        returned dataset (``to_dataset`` requires a named DataArray).
    labels_pair : array_like
        Pair labels (1 = similar, 0 = dissimilar), one per ``sample``.

    Returns
    -------
    xr.Dataset
        Similar pairs first, then dissimilar pairs, with the labels in ``Y``.
    """
    ds_images_pair = images_pair.to_dataset()
    ds_images_pair = ds_images_pair.assign(Y=(['sample'], labels_pair))

    # Split into similar and dissimilar pairs
    ds_images_pair_similar = ds_images_pair.where(ds_images_pair['Y'] == 1, drop=True)
    ds_images_pair_dissimilar = ds_images_pair.where(ds_images_pair['Y'] == 0, drop=True)

    similar_pair_size = ds_images_pair_similar.sizes['sample']
    dissimilar_pair_size = ds_images_pair_dissimilar.sizes['sample']

    # Sample WITHOUT replacement. Drawing with replacement would duplicate some
    # pairs and drop others, even when both groups already have the same size.
    # Sorting keeps the retained pairs in their original order.
    if similar_pair_size >= dissimilar_pair_size:
        idx_similar = np.sort(rng.choice(similar_pair_size, size=dissimilar_pair_size, replace=False))
        ds_images_pair_similar = ds_images_pair_similar.isel(sample=idx_similar)
    else:
        idx_dissimilar = np.sort(rng.choice(dissimilar_pair_size, size=similar_pair_size, replace=False))
        ds_images_pair_dissimilar = ds_images_pair_dissimilar.isel(sample=idx_dissimilar)

    # Combine similar and dissimilar pairs one after the other
    ds_images_pair = xr.concat([ds_images_pair_similar, ds_images_pair_dissimilar], dim='sample')

    print(f"similar pairs: {np.sum(ds_images_pair['Y']==1).values}")
    print(f"non similar pairs: {np.sum(ds_images_pair['Y']==0).values}")

    return ds_images_pair

In [None]:
def _zero_fill_rechunk_save(ds, store):
    """Zero-fill NaNs in ``ds['X']`` (modifying ``ds`` in place), rechunk, write to ``store`` and return the rechunked dataset."""
    ds['X'] = ds['X'].fillna(0)
    rechunked = ds.chunk({'sample': 500, 'y': -1, 'x': -1, 'channel': -1})
    rechunked.to_zarr(store, mode="w")
    return rechunked


n_folds = len(support_sets)
for i in range(n_folds):
    print(f"{i+1}/{n_folds} fold, {k}-shot learning")
    fold_prefix = f'./notebooks/data_agu/n_fold_x_validation/{k}_shot_{i+1}_fold'

    # Support samples of this fold
    support_set = _zero_fill_rechunk_save(support_sets[i], f'{fold_prefix}_supp_samples.zarr')

    # Corresponding held-out test samples of this fold
    test_set = _zero_fill_rechunk_save(test_sets[i], f'{fold_prefix}_test_samples.zarr')

    # Every generated pair contains at least one support image;
    # pairs made only of base-set images are not generated.
    images_pair, labels_pair = generate_refine_image_paris(support_set, ds_all_base)

    # Balance positive/negative pairs, then shuffle (and plot examples)
    ds_images_pair = get_balanced_pairs(images_pair, labels_pair)
    print("-" * 20)
    ds_images_pair_shuffled = shuffle_pairs(ds_images_pair, npair=10)

    # Pairs used for the few-shot refinement training
    ds_images_pair_shuffled.to_zarr(f'{fold_prefix}_supp_pairs.zarr', mode="w")