#### Test the performance of n-shot learning using the leave-one-out test set of each fold
**Note**: the cells below evaluate the base model only; the evaluation of the refined model is commented out.
In each k and each fold we do the following to test the performance of the siamese network:
1. Load the support set and the test set for each k and each fold 
2. Select a sample from the test set, and predict the similarity score between it and each sample in the support set
3. Repeat 2 for all the samples from the test set

In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt

import keras
import tensorflow as tf
import os
from keras import backend as k
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import pandas as pd
from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, Callback

from sklearn.metrics import classification_report, confusion_matrix

In [None]:
# Load the support data and test data
#
# For every file, the first `n_support_per_file` samples go to the support
# set and any remaining samples go to the test set.

data_dir = Path("../../../../data/cleaned_data/selected_cutouts/")
list_samples_file = [
    "label142377591163_murumuru.zarr",
    "label244751236943_tucuma.zarr",
    "label174675723264_banana.zarr",
    "label999240878592_cacao.zarr",
    "label370414265344_fruit.zarr",
]

# Number of leading samples of each file used as support samples
n_support_per_file = 13

support_parts = []
test_parts = []
for file in list_samples_file:
    ds_samples = xr.open_zarr(data_dir / file)
    n_sample = ds_samples.sizes["sample"]

    # A slice (unlike an explicit range of indices) does not fail when the
    # file holds fewer samples than n_support_per_file
    support_parts.append(ds_samples.isel(sample=slice(n_support_per_file)))

    # Only files with more samples than the support quota contribute test samples
    if n_sample > n_support_per_file:
        test_parts.append(
            ds_samples.isel(sample=slice(n_support_per_file, None))
        )

# Concatenate once per set; a set with no contribution stays None
support_samples = (
    xr.concat(support_parts, dim="sample") if support_parts else None
)
test_samples = xr.concat(test_parts, dim="sample") if test_parts else None

In [None]:
# Show the support set assembled above as the cell output
support_samples

In [None]:
# Histogram (20 bins) of the "Y" values of the test set
test_samples['Y'].plot.hist(bins=20)

In [None]:
# Register the custom layer (presumably needed to deserialize the saved model) and load the base model
@keras.saving.register_keras_serializable(package="MyLayers")
class euclidean_lambda(keras.layers.Layer):
    """Layer returning the element-wise squared difference of two inputs.

    Registered as a Keras serializable object under the "MyLayers" package.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fixed layer name; assigned after the parent constructor, so it
        # replaces any name passed through kwargs
        self.name = 'euclidean_lambda'

    def call(self, featA, featB):
        """Return square(featA - featB); no sum or square root is applied."""
        difference = featA - featB
        return keras.ops.square(difference)

# Base model, the model trained only using the initial data
# (path is relative to the notebook's working directory)
base_model_path = '../../optimized_models/results_training/Agu_pairs_training_v8/siamese_model_mobilenet03.keras'
base_model = keras.saving.load_model(base_model_path)

# Print the layer-by-layer summary of the loaded model
base_model.summary()

In [None]:
### Compute the classification performance
def predict_label(gt_label, score_dic, metric="max"):
    """
    Predict the class of one test sample from its per-class similarity scores.

    gt_label: ground truth label (anything accepted by int())
    score_dic: the dic that contains the predicted similarity score for each support sample
               the key is the class label, the value is the list of scores of that class
    metric: the metric to aggregate the similarity scores across the support samples
            within each class, either "avg" or "max"

    return:
        result: [ifcorrect, gt_label, similarity_score_of_the_target_class,
                 predicted_class, similarity_score_of_the_predicted_class]

    raises:
        ValueError: if metric is neither "avg" nor "max", or if score_dic is empty
    """
    reducers = {
        "avg": lambda values: sum(values) / len(values),
        "max": max,
    }
    if metric not in reducers:
        raise ValueError(f'Unknown metric "{metric}"; expected "avg" or "max"')
    if not score_dic:
        raise ValueError("score_dic is empty: there is no class to predict")

    # A class without any score is given the score 0
    reduce_scores = reducers[metric]
    reduced_score = {
        key: reduce_scores(values) if values else 0
        for key, values in score_dic.items()
    }

    # On ties the first class in the dic order wins
    largest_key = max(reduced_score, key=reduced_score.get)
    largest_value = reduced_score[largest_key]
    gt_label = int(gt_label)

    if gt_label == largest_key:
        return [1, gt_label, largest_value, largest_key, largest_value]

    # A ground truth class that is absent from the support set can never be
    # predicted; report its score as 0, like a class without any score
    return [0, gt_label, reduced_score.get(gt_label, 0), largest_key, largest_value]

In [None]:
# predict the similarity score with each support sample and sort the similarity scores by class
def get_similarity_score(test_X, support_samples, model):
    """
    Predict the similarity between the test input and every support sample,
    and group the scores by the class label of the support sample.

    test_X: test input passed to the model as the first input
            (the caller is expected to have scaled it already)
    support_samples: object providing support_samples["X"] (divided by 255.0 here)
                     and support_samples["Y"].values (one label per support sample)
    model: object with predict([test_X, support_X], verbose=0) returning one score
           per support sample

    return:
        score_dic: dic with the class label (as int) as key and the list of
                   similarity scores of that class as value
    """
    support_X = support_samples["X"] / 255.0
    # atleast_1d keeps the scores indexable when the support set holds a single
    # sample, where squeeze() alone collapses the prediction to a 0-d array
    similarity_score = np.atleast_1d(
        model.predict([test_X, support_X], verbose=0).squeeze()
    )

    # store the score into each class dic
    support_labels = support_samples["Y"].values
    score_dic = {int(label): [] for label in np.unique(support_labels)}
    for j, support_Y in enumerate(support_labels):
        # same int() conversion as used to build the keys
        score_dic[int(support_Y)].append(similarity_score[j])

    return score_dic

In [None]:
# Aggregation passed as `metric` to predict_label for the base model evaluation
predict_metric = "avg" # "avg" or "max"

In [None]:
# Compute the results
# Each result row is
# [ifcorrect, gt_label, similarity_score_of_the_target_class, predicted_class, similarity_score_of_the_predicted_class]

num_test_samples = test_samples.sizes["sample"]
# num_test_samples = 18

# Make the batch size as the total support_sample size
# (the support set does not change, so this is computed once)
support_sample_size = len(support_samples["X"]["sample"])

base_result_rows = []
# refined_result_rows = []
for j in range(num_test_samples):
    test_sample_j = test_samples.isel(sample=j)
    test_Y = test_sample_j["Y"].values
    # Repeat the test sample so that it is paired with every support sample
    test_X = test_sample_j.expand_dims({"sample": support_sample_size})["X"] / 255.0

    # ### Test the base model zero-shot learning
    # # Compute the similarity scores across classes
    zeroshot_score_dic = get_similarity_score(test_X, support_samples, base_model)

    # # Compute the prediction results
    zeroshot_result_j = predict_label(test_Y, zeroshot_score_dic, metric=predict_metric)
    print("zero shot", zeroshot_result_j)
    base_result_rows.append(zeroshot_result_j)

    # ### Test the refined model for n-shot learning
    # # Compute the similarity scores across classes
    # nshot_score_dic = get_similarity_score(test_X, support_samples, refined_model)

    # # Compute the prediction results
    # nshot_result_j = predict_label(test_Y, nshot_score_dic, metric="avg")
    # print("n-shot", nshot_result_j)
    # refined_result_rows.append(nshot_result_j)

# All the results, one row per test sample
base_results = np.array(base_result_rows, dtype=float).reshape(-1, 5)
# refined_results = np.array(refined_result_rows, dtype=float).reshape(-1, 5)

print("-" * 20)
print(
    "Overall accuracy of the base model",
    sum(base_results[:, 0]) / num_test_samples,
)
# print(
#     "Overall accuracy of the refined model",
#     sum(refined_results[:, 0]) / num_test_samples,
# )

# Labels are integers stored in a float array; cast them back so that the
# report and the confusion matrix show integer class labels
gt_zero_shot = base_results[:, 1].astype(np.int64)
pd_zero_shot = base_results[:, 3].astype(np.int64)
# The report gets its own name so that base_results keeps the per-sample rows
zero_shot_report = classification_report(gt_zero_shot, pd_zero_shot)
cm = confusion_matrix(gt_zero_shot, pd_zero_shot)
print("***** zero shot results *****")
print(zero_shot_report)
print(cm)
# gt_n_shot = refined_results[:, 1].astype(np.int64)
# pd_n_shot = refined_results[:, 3].astype(np.int64)
# n_shot_results = classification_report(gt_n_shot, pd_n_shot)
# cm = confusion_matrix(gt_n_shot, pd_n_shot)
# print("***** n shot results *****")
# print(n_shot_results)
# print(cm)

| k-shot n-Fold | support set | test set     | n-Accuracy  |0-Accuracy  | 
| :---   |    :---              |        :---             | :---   |  :---   |
| 3-1      | (0, 1, 2)                |   (3, 4, 5)                       |  0.89   |  |
| 3-2      | (3, 4, 5)                      | (0, 1, 2)                        | 0.83    | |
| 2-1      | (0, 1)                      | (2, 3, 4, 5)                         | 0.61    |0.22 |
| 2-2      | (2, 3)                      | (0, 1, 4, 5)                         | 0.56    |0.22  |
| 2-3      | (4, 5)                      | (0, 1, 2, 3)                         | 0.56    |0.22  |
| 1-1      | (0)                      | (1, 2, 3, 4, 5)                         | 0.61    |0.31  |
| 1-2      | (1)                      | (0, 2, 3, 4, 5)                         | 0.56    |0.28  |
| 1-3      | (2)                      | (0, 1, 3, 4, 5)                         | 0.50    |0.22  |
| 1-4      | (3)                      | (0, 1, 2, 4, 5)                         | 0.56    |0.11  |
| 1-5      | (4)                      | (0, 1, 2, 3, 5)                         | 0.06    |0.00  |
| 1-6      | (5)                      | (0, 1, 2, 3, 4)                         | 0.61    |0.11  |