#### Test the performance of the few-shot learning using the leave-one-out test set of each fold
**Note**: each k and each fold is an independent training

In each k and each fold we do the following to test the performance of the siamese network:
1. Load the support set and the test set for each k and each fold 
2. Select a sample from the test set, and predict the similarity score between it and each sample in the support set
3. Repeat 2 for all the samples from the test set

Data needed for this notebook:

- [n_fold_x_validation](https://zenodo.org/records/13833791/files/n_fold_x_validation.zip?download=1)
- [refined models](https://zenodo.org/records/13833791/files/optimized_models.zip?download=1)

In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt

import keras
import tensorflow as tf
import os
from keras import backend as k
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix
from scipy.special import softmax
from scipy.stats import entropy
from scipy.ndimage import rotate
from skimage.transform import resize

# Shared random generator used by the augmentation helpers below
rng = np.random.default_rng(seed=42)

# Project root that all the relative dirs in this notebook are resolved against.
# Set the TREE_CLASSIFICATION_ROOT environment variable to the correct dir on your
# machine; when it is not set, the original hard-coded location is used.
PROJECT_ROOT = os.environ.get(
    'TREE_CLASSIFICATION_ROOT',
    '/data/Projects/2024_Invasive_species/Tree_Classification',
)
os.chdir(PROJECT_ROOT)
print(os.getcwd())

Here we manually take a single fold of the 3-shot setting as an example.

To exhaust all the data partitionings, change the following code block into a loop over every k and every fold.

In [None]:
# Load the support data and test data
# `k` (number of shots) and `iii` (fold index) select which partitioning is loaded;
# both are interpolated into the file names below and reused by later cells.
# NOTE: this rebinds `k`, which the import cell bound to the keras backend
# (`from keras import backend as k`); after this line `k` is the integer 3.
k = 3
iii = 2

support_smaples_path = f'./notebooks/data/n_fold_x_validation/{k}_shot_{iii}_fold_supp_samples.zarr'
test_samples_path = f'./notebooks/data/n_fold_x_validation/{k}_shot_{iii}_fold_test_samples.zarr'

# Open both zarr stores with xarray
support_samples = xr.open_zarr(support_smaples_path)
test_samples = xr.open_zarr(test_samples_path)

# support_samples
# Last expression of the cell: display the test set
test_samples

### Specify the base and refined models

In [3]:
### Base model
# Alternative: the plain CNN base model
# base_model_name = 'siamese_model_CNN'
# Currently selected: the mobilenet03 base model
base_model_name = 'siamese_model_mobilenet03'

### Refine model
# Alternative: the model refined from the CNN
# refine_model_name = 'siamese_model_refined_CNN'
# Currently selected: the refined model from mobilenet03
refine_model_name = 'siamese_model_refined_best'

In [None]:
# Load the refined model
@keras.saving.register_keras_serializable(package="MyLayers")
class euclidean_lambda(keras.layers.Layer):
    """Custom layer returning the element-wise squared difference of two inputs.

    It is registered under the "MyLayers" package so that saved models that
    contain it can be deserialized by `keras.saving.load_model`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The layer name is always forced to this fixed value
        self.name = 'euclidean_lambda'

    def call(self, featA, featB):
        """Return square(featA - featB), computed element-wise."""
        return keras.ops.square(featA - featB)

# Refined model, the model retrained using the support set
# (one model file per k-shot / fold combination, selected through `k` and `iii`)
refined_model_path = f'./optimized_models/refine_model/{k}_shot_{iii}_fold/{refine_model_name}.keras'
refined_model = keras.saving.load_model(refined_model_path)

# Base model, the model trained only using the initial data
# (the same file is used regardless of `k` and `iii`)
base_model_path = f'./optimized_models/results_training/Agu_pairs_training_v8/{base_model_name}.keras'
base_model = keras.saving.load_model(base_model_path)

# Print the layer overview of the refined model
refined_model.summary()

In [5]:
### Compute the classification performance
def predict_label(gt_label, score_dic, metric="max"):
    """
    Predict the class of one test sample from its per-class similarity scores.

    gt_label: ground truth label (anything accepted by int())
    score_dic: the dic that contains the predicted similarity score for each support sample
               the key is the class label, the value is the list of scores of that class
    metric: the metric to aggregate the similarity scores across the support samples within each class,
            either "avg" or "max"; a class without any score is given the score 0

    return:
        result: [ifcorrect, gt_label, similarity_score_of_the_target_class,
                 predicted_class, similarity_score_of_the_predicted_class]
                similarity_score_of_the_target_class is 0 when the ground truth
                class does not occur in score_dic

    raise:
        ValueError: if metric is neither "avg" nor "max", or if score_dic is empty
    """
    reducers = {
        "avg": lambda values: sum(values) / len(values),
        "max": max,
    }
    if metric not in reducers:
        raise ValueError(f"Unknown metric {metric!r}; expected 'avg' or 'max'")
    if not score_dic:
        raise ValueError("score_dic is empty; there is no class to predict")

    reduce_scores = reducers[metric]
    reduced_score = {
        key: (reduce_scores(values) if values else 0)
        for key, values in score_dic.items()
    }

    # The predicted class is the one with the largest aggregated score
    # (on ties the first class in score_dic order wins)
    largest_key = max(reduced_score, key=reduced_score.get)
    largest_value = reduced_score[largest_key]
    gt_label = int(gt_label)

    # A ground truth class that is missing from the support set can never be
    # predicted; report score 0 for it instead of failing with a KeyError
    target_value = reduced_score.get(gt_label, 0)

    return [int(gt_label == largest_key), gt_label, target_value, largest_key, largest_value]

#### Add the correctness metric
Correctness computes the ratio of the top-k most similar support samples that belong to the same class as the ground-truth (gt) sample



In [6]:
def predict_cor_cst(gt, score_dic, k):
    '''
    Compute the correctness and the (unnormalized) contrastivity of one prediction.

    gt: ground truth label (anything accepted by int())
    score_dic: the dic that contains the predicted similarity score for each support sample
               the key is the class label
    k: number of most similar support samples to take into account

    return:
        correctness: fraction of the top-k support samples whose class equals gt
                     (always divided by k, also when fewer than k samples exist)
        contrastivity_unnormalized: entropy of the top-k entries of the softmax
                     taken over ALL similarity scores

    raise:
        ValueError: if score_dic does not contain any similarity score
    '''
    # Flatten the dic into rows of [class label, similarity score]
    sim_scores_list = [[int(key), v] for key, values in score_dic.items() for v in values]
    if not sim_scores_list:
        raise ValueError("score_dic does not contain any similarity score")
    sim_scores = np.asarray(sim_scores_list, dtype=float)

    # Sort the rows by similarity score, descending
    sorted_indices = np.argsort(sim_scores[:, 1])[::-1]
    sim_scores_sorted = sim_scores[sorted_indices]
    top_k_scores = sim_scores_sorted[:k, :]

    # Only the label column may be compared with gt; comparing the whole
    # [label, score] row would also count a score that happens to equal gt
    correctness = np.sum(top_k_scores[:, 0] == int(gt)) / k

    # compute contrastivity
    sim_prob_sorted = softmax(sim_scores_sorted[:, 1])
    contrastivity_unnormalized = entropy(sim_prob_sorted[:k])

    return correctness, contrastivity_unnormalized

In [7]:
# predict the similarity score with each support sample and sort the similarity scores by class
def get_similarity_score(test_X, support_samples, model):
    """
    Predict the similarity between the test input and every support sample.

    test_X: first input handed to model.predict (the test image batch)
    support_samples: mapping with "X" (support images, divided by 255 here)
                     and "Y" (class label of every support sample, read via .values)
    model: object with a predict([test_X, support_X], verbose=0) method that
           yields one similarity score per support sample

    return:
        score_dic: {class label (int): [similarity scores of that class]}

    raise:
        ValueError: if the number of predicted scores differs from the number of labels
    """
    support_X = support_samples["X"] / 255.0
    # reshape(-1) instead of squeeze(): squeeze() would turn the prediction for a
    # single support sample into a 0-d array that cannot be indexed
    similarity_score = np.asarray(model.predict([test_X, support_X], verbose=0)).reshape(-1)

    support_labels = np.ravel(support_samples['Y'].values)
    if len(similarity_score) != len(support_labels):
        raise ValueError(
            f"Got {len(similarity_score)} similarity scores for {len(support_labels)} support samples"
        )

    # store the score into each class dic
    score_dic = {int(unique_label): [] for unique_label in np.unique(support_labels)}
    for support_Y, score in zip(support_labels, similarity_score):
        score_dic[int(support_Y)].append(score)

    return score_dic

#### Compute continuity
Augment the test_X using one of the seven augmentation schemes used for pairing


In [8]:
# Function to add Gaussian noise to an RGB image
def add_gaussian_noise(image, mean=0, std=25):
    """
    Add Gaussian noise to an image while keeping its zero-valued pixels at zero.

    The noise is drawn from a private generator seeded with 42, so the same
    noise pattern is produced on every call for a given shape. The global
    NumPy random state is NOT touched: reseeding it here would make every
    later np.random call in the notebook return the same values.

    image: array of pixel values; entries that are > 0 receive noise
    mean: mean of the Gaussian noise
    std: standard deviation of the Gaussian noise

    return:
        noisy_image: uint8 array of the same shape, clipped to [0, 255]
    """
    # Remember which entries are non-zero so that the others stay zero
    non_zeros = image > 0

    # Generate Gaussian noise (same values as np.random.seed(42) followed by
    # np.random.normal, but without side effects on the global state)
    noise = np.random.RandomState(42).normal(mean, std, image.shape)

    # Add the noise to the image
    noisy_image = image + noise

    # Clip the image to ensure pixel values are in the range [0, 255],
    # convert to uint8 and zero out the entries that were not > 0 before
    noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8) * non_zeros

    return noisy_image

def random_crop(img, crop_size=(108, 108)):
    """
    Cut a random window out of an image and resize it back to the image size.

    img: array whose first two axes are (rows, columns) and which has a third axis
    crop_size: (rows, columns) of the window to cut out; may equal the image size

    return:
        the window resized to the original (rows, columns) with the value range
        preserved; if the window contains only zeros, the clipped original
        image is returned instead
    """
    height, width = img.shape[:2]
    crop_h, crop_w = crop_size
    assert crop_h <= height and crop_w <= width, "Crop size should be less than image size"

    img = np.clip(img, 0, 255)

    # Offsets of the window: `top` along the rows, `left` along the columns.
    # The upper bound is exclusive, hence +1 (this also covers crop == image size).
    # The notebook-level generator `rng` keeps the crops reproducible.
    top = rng.integers(0, height - crop_h + 1)
    left = rng.integers(0, width - crop_w + 1)
    img_crop = img[top:top + crop_h, left:left + crop_w, :]

    # An all-zero window carries no information: fall back to the full image
    if not np.any(img_crop):
        return img

    # preserve_range=True keeps the 0-255 scale; without it integer input is
    # rescaled to [0, 1] by skimage
    return resize(img_crop, (height, width), preserve_range=True)


def aug_img_pair(img):
    """Augment an image and generate a list of augmented images

    Parameters
    ----------
    img : xr.DataArray
        Image to augment. It is accessed through `.copy()`, `.values`, `.data`
        and `.isel(x=..., y=...)`, so it needs dimensions named `x` and `y`.
        NOTE(review): presumably laid out as (y, x, band) with values in 0-255;
        confirm against the caller.

    Returns
    -------
    list
        Seven images in this fixed order:
        [original, rot90, flip left-right, flip up-down, random rotation,
        noise + random rotation, random crop].
        Index 0 is the unmodified input itself (not a copy).
    """
    
    # randomly add gaussian noise
    # (img_gaussian is computed here but is not part of the returned list)
    img_gaussian = img.copy()
    img_gaussian.data = add_gaussian_noise(img_gaussian.values, mean=0, std=25)                       
            
    # randomly rotate img 90, 180, 270
    # rng.integers(1, 4) draws 1, 2 or 3 quarter turns
    img_rot = img.copy()
    img_rot.data = np.rot90(img.values, k=rng.integers(1, 4))
    
    # random rotate another angle which is not 90, 180, 270
    # rng.integers(1, 359) has an exclusive upper bound, so the angle is in 1..358
    angle = rng.integers(1, 359)
    while angle in {90, 180, 270}:
        angle = rng.integers(1, 359)
    img_ran_rot_1 = img.copy()
    # reshape=False keeps the array shape unchanged by the rotation
    img_ran_rot_1.data = np.clip(rotate(img_ran_rot_1.values, angle, reshape=False), 0, 255)
    
    # random rotate and add noise
    # the noisy image is rotated by half of the angle drawn above
    img_ran_rot_2 = img.copy()
    img_ran_rot_2.data = add_gaussian_noise(img_ran_rot_2.data, mean=0, std=25) 
    img_ran_rot_2.data = np.clip(rotate(img_ran_rot_2.values, angle/2, reshape=False), 0, 255)
    
    # random crop
    img_crop = img.copy()
    img_crop.data = random_crop(img_crop.values)

    # flip left-right img
    img_flip_lr = img.isel(x=slice(None, None, -1))

    # flip up-down img
    img_flip_ud = img.isel(y=slice(None, None, -1))

    # The order of this list matters: callers select an augmentation by its index
    img_list = [
        img,
        img_rot,
        img_flip_lr,
        img_flip_ud,
        img_ran_rot_1,
        img_ran_rot_2,
        img_crop
    ]
    
    return img_list


def predict_cty(test_X, test_X_agu, support_samples, k, model):
    """
    Compute the continuity between the original and the augmented test input.

    test_X: original test input handed to model.predict
    test_X_agu: augmented test input handed to model.predict
    support_samples: mapping with "X" (support images, divided by 255 here)
                     and "Y" (class label of every support sample)
    k: number of most similar support samples to compare
    model: object with a predict([query, support_X], verbose=0) method

    return:
        continuity: share of the top-k support samples of the augmented input
                    that are also among the top-k of the original input
        ori_similarity_score: rows of [index, similarity_score, support_sample_classID]
                    for the original input, sorted by descending similarity
        agu_similarity_score: the same table for the augmented input
    """
    support_X = support_samples["X"] / 255.0
    label_column = support_samples["Y"].values.reshape(-1, 1)
    index_column = np.arange(support_X.sizes['sample']).reshape(-1, 1)

    def _ranked(query):
        # Score the query against every support sample and sort best-first
        scores = model.predict([query, support_X], verbose=0)
        table = np.hstack((index_column, scores, label_column))
        order = np.argsort(table[:, 1])[::-1]
        return order, table[order]

    ori_indices, ori_similarity_score = _ranked(test_X)
    agu_indices, agu_similarity_score = _ranked(test_X_agu)

    # compute continuity: overlap of the two top-k index sets
    shared = np.isin(agu_indices[:k], ori_indices[:k])
    continuity = int(shared.sum()) / k

    return continuity, ori_similarity_score, agu_similarity_score

#### Plot the prediction results

In [9]:
import matplotlib
def plot_results(test_sample, support_samples, similarity_scores, explation_scores, k, name=None, fold=None):
    """
    Plot the test image next to its top-k support images and save the figure.

    test_sample: the X and Y of the test image
    support_samples: the X and Y of the support images
    similarity_scores: rows of [index, similarity_score, support_sample_classID],
                       one per support image, sorted by descending similarity
    explation_scores: [correctness, contrastivity, continuity]
    k: number of support images to show (k >= 1)
    name: identifier used in the output file name `sample_{name}.png`
    fold: fold index used in the output directory; when None the notebook-level
          variable `iii` is used

    The figure is written to
    ./optimized_models/refine_model/{k}_shot_{fold}_fold/plot_evaluations/.
    """
    matplotlib.rcParams.update({'font.size': 14})
    if fold is None:
        # Fall back to the fold selected at notebook level
        fold = iii

    # only predict the top-k support images
    lable_dic = {92352972800: 'Species 11',
                333988661248: 'Species 6',
                394585504768: 'Species 7',
                399601058816: 'Species 8',
                578797953024: 'Species 10',
                664680244048: 'Species 9'}

    # Normalization of Contrastivity: the entropy of k probabilities is at most
    # ln(k) (about 1.1 for k=3). For k=1 the entropy is always 0, so any
    # non-zero divisor works.
    max_Contrastivity = np.log(k) if k > 1 else 1.0

    fig1, axs1 = plt.subplots(1, k+1, figsize=(4.5*(k+1), 4))
    plt.setp(fig1.get_axes(), xticks=[], yticks=[])
    ex_str = "Correctness: " + f"{explation_scores[0]:.2f}, " + "Continuity: " + f"{explation_scores[2]:.2f}, " + "Contrastivity: " + f"{explation_scores[1]/max_Contrastivity:.2f} \n"
    if k!=1:
        fig1.suptitle("Support samples - "+ex_str, y=1, x=0.625)

    # First panel: the test image itself
    test_sample['X'].astype('int').plot.imshow(ax=axs1[0])
    axs1[0].set_title('Input: ' + lable_dic[int(test_sample['Y'].values)], pad=1)

    # Remaining panels: the k most similar support images, best first
    for i in range(1, k+1):
        support_index = int(similarity_scores[i-1, 0])
        support_samples['X'].isel(sample=support_index).astype('int').plot.imshow(ax=axs1[i])
        support_label = int(support_samples['Y'].isel(sample=support_index).values)
        support_label = lable_dic[support_label] + ' ' + '(sim: '+ f"{similarity_scores[i-1, 1]:.2f})"
        axs1[i].set_title(support_label, pad=1)

    for ax in axs1:
        ax.set_xlabel('')
        ax.set_ylabel('')

    results_dir = Path(f'./optimized_models/refine_model/{k}_shot_{fold}_fold')/'plot_evaluations/'
    # parents=True: do not fail when the model directory does not exist yet
    results_dir.mkdir(parents=True, exist_ok=True)
    # Save this figure explicitly instead of whatever the current figure is
    fig1.savefig(os.path.join(results_dir, f'sample_{name}.png'))


In [None]:
# Compute the results
# For every test sample: classify it with the base (zero-shot) and the refined
# (k-shot) model, compute correctness / contrastivity / continuity, and plot
# the refined-model prediction.
zero_shot_results = np.zeros((0, 5))
refined_results = np.zeros((0, 5))
print('[ifcorrect, gt_label, similarity_score_of_the_target_class, predicted_class, similarity_score_of_the_predicted_class]')

zero_shot_correctness = []
n_shot_correctness = []

zero_shot_contrastivity  = []
n_shot_contrastivity = []

zero_shot_continuity  = []
n_shot_continuity = []

num_test_smaples = len(test_samples['X']['sample'])
# Make the batch size as the total support_sample size
support_sample_size = len(support_samples['X']['sample'])

for j in range(num_test_smaples):
    test_sample_j = test_samples.isel(sample=j)

    # Repeat the test image once per support sample so that both model inputs
    # have the same batch size
    test_Y = test_sample_j['Y'].values
    test_X = test_sample_j.expand_dims({"sample": support_sample_size})["X"] / 255.0

    ### Test the base model zero-shot learning
    # Compute the similarity scores across classes
    zeroshot_score_dic = get_similarity_score(test_X, support_samples, base_model)

    # Compute the prediction results
    zeroshot_result_j = predict_label(test_Y, zeroshot_score_dic, metric="avg")
    print("zero shot", zeroshot_result_j)

    ### Test the refined model for k-shot learning
    # Compute the similarity scores across classes
    nshot_score_dic = get_similarity_score(test_X, support_samples, refined_model)

    # Compute the prediction results
    nshot_result_j = predict_label(test_Y, nshot_score_dic, metric="avg")

    # All the results
    print("k-shot", nshot_result_j)

    # correctness, contrastivity
    zero_shot_cor, zero_shot_cst = predict_cor_cst(test_Y, zeroshot_score_dic, k)
    zero_shot_correctness.append(zero_shot_cor)
    zero_shot_contrastivity.append(zero_shot_cst)

    n_shot_cor, n_shot_cst = predict_cor_cst(test_Y, nshot_score_dic, k)
    n_shot_correctness.append(n_shot_cor)
    n_shot_contrastivity.append(n_shot_cst)

    # All the results
    zero_shot_results = np.vstack((zero_shot_results, zeroshot_result_j))
    refined_results = np.vstack((refined_results, nshot_result_j))

    # Compute the continuity
    # Pick one of the six augmented versions (index 0 is the original image).
    # The index comes from the notebook generator `rng`, not from the global
    # np.random state: that state is reseeded inside the augmentation code, which
    # would yield the same index for every sample after the first one.
    index = rng.integers(1, 7)
    test_x_agu = aug_img_pair(test_sample_j['X'])[index]
    test_X_agu = test_x_agu.expand_dims({"sample": support_sample_size})/255.0

    zero_shot_cty, zero_ori_similarity_score, zero_agu_similarity_score = predict_cty(test_X, test_X_agu, support_samples, k, base_model)
    zero_shot_continuity.append(zero_shot_cty)

    n_shot_cty, n_ori_similarity_score, n_agu_similarity_score  = predict_cty(test_X, test_X_agu, support_samples, k, refined_model)
    n_shot_continuity.append(n_shot_cty)

    ### Plot the results
    zero_shor_explanation_scores = [zero_shot_cor, zero_shot_cst, zero_shot_cty]
    n_shot_explanation_scores = [n_shot_cor, n_shot_cst, n_shot_cty]

    plot_results(test_sample_j, support_samples, n_ori_similarity_score, n_shot_explanation_scores, k, name=j)

print("-" * 20)
print("Overall accuracy of the base model", sum(zero_shot_results[:, 0])/num_test_smaples)
print("Overall accuracy of the refined model", sum(refined_results[:, 0])/num_test_smaples)

print("Correctness of the base model", sum(zero_shot_correctness)/len(zero_shot_correctness))
print("Correctness of the refined model", sum(n_shot_correctness)/len(n_shot_correctness))

# Contrastivity is normalized by the largest value observed in this fold; the
# lists are converted to arrays explicitly so that the division is well defined
print("zero shot contrastivity", sum(np.asarray(zero_shot_contrastivity)/max(zero_shot_contrastivity))/len(zero_shot_contrastivity))
print("n shot contrastivity", sum(np.asarray(n_shot_contrastivity)/max(n_shot_contrastivity))/len(n_shot_contrastivity))

print("zero shot continuity", sum(zero_shot_continuity)/len(zero_shot_continuity))
print("n shot continuity", sum(n_shot_continuity)/len(n_shot_continuity))

# Column 1 holds the ground truth label, column 3 the predicted label. The result
# arrays are float, so cast the labels back to integers for a readable report.
gt_zero_shot = zero_shot_results[:, 1].astype(np.int64)
pd_zero_shot = zero_shot_results[:, 3].astype(np.int64)
zero_shot_results = classification_report(gt_zero_shot, pd_zero_shot)
print("***** zero shot results *****")
print(zero_shot_results)

gt_n_shot = refined_results[:, 1].astype(np.int64)
pd_n_shot = refined_results[:, 3].astype(np.int64)
n_shot_results = classification_report(gt_n_shot, pd_n_shot)
cm = confusion_matrix(gt_n_shot, pd_n_shot)
print("***** n shot results *****")
print(n_shot_results)
print(cm)