#### Test the performance of the n-shot leanrning using the leave-one-out test set of each fold
**Note**: each k and each fold is an independent training

In each k and each fold we do the following to test the performance of the siamese network:
1. Load the support set and the test set for each k and each fold 
2. Select a sample from the test set, and predict the similarity score between it with each sample in the support set
3. Repeat 2 for all the samples from the test set

In [1]:
import xarray as xr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt

import keras
import tensorflow as tf
import os
from keras import backend as k
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import pandas as pd
from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, Callback

from sklearn.metrics import classification_report, confusion_matrix

2024-09-01 22:20:49.179608: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-01 22:20:49.193277: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-01 22:20:49.197161: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-01 22:20:49.210992: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load the support data and test data
# Here we take 3-shot 1 fold for example
# Change this to a loop to exhaust all the data partitioning 
k = 1
i = 5

support_smaples_path = f'/data/Projects/2024_Invasive_species/Tree_Classification/notebooks/data/n_fold_x_validation/{k}_shot_{i}_fold_supp_samples.zarr'
test_samples_path = f'/data/Projects/2024_Invasive_species/Tree_Classification/notebooks/data/n_fold_x_validation/{k}_shot_{i}_fold_test_samples.zarr'

support_samples = xr.open_zarr(support_smaples_path)
test_samples = xr.open_zarr(test_samples_path)

# support_samples
test_samples

Unnamed: 0,Array,Chunk
Bytes,11.25 MiB,11.25 MiB
Shape,"(30, 128, 128, 3)","(30, 128, 128, 3)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.25 MiB 11.25 MiB Shape (30, 128, 128, 3) (30, 128, 128, 3) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",30  1  3  128  128,

Unnamed: 0,Array,Chunk
Bytes,11.25 MiB,11.25 MiB
Shape,"(30, 128, 128, 3)","(30, 128, 128, 3)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,240 B,48 B
Shape,"(30,)","(6,)"
Dask graph,5 chunks in 2 graph layers,5 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 240 B 48 B Shape (30,) (6,) Dask graph 5 chunks in 2 graph layers Data type float64 numpy.ndarray",30  1,

Unnamed: 0,Array,Chunk
Bytes,240 B,48 B
Shape,"(30,)","(6,)"
Dask graph,5 chunks in 2 graph layers,5 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [3]:
# Load the refined model
@keras.saving.register_keras_serializable(package="MyLayers")
class euclidean_lambda(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(euclidean_lambda, self).__init__(**kwargs)
        self.name = 'euclidean_lambda'

    def call(self, featA, featB):
        squared = keras.ops.square(featA-featB)
        return squared

# Refined model, the model retrained using the support set
refined_model_path = f'/data/Projects/2024_Invasive_species/Tree_Classification/optimized_models/refine_model/{k}_shot_{i}_fold/siamese_model_refined.keras'
refined_model = keras.saving.load_model(refined_model_path)

# Base model, the model trained only using the initial data
# base_model_path = '/data/Projects/2024_Invasive_species/Tree_Classification/optimized_models/siamese_model.keras'
base_model_path = '/data/Projects/2024_Invasive_species/Tree_Classification/optimized_models/results_training/Agu_pairs_training_v8/siamese_model_mobilenet03.keras'
base_model = keras.saving.load_model(base_model_path)

refined_model.summary()

2024-09-01 22:20:51.589627: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 674 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:21:00.0, compute capability: 8.6


In [4]:
### Compute the classification performance
def predict_label(gt_label, score_dic, metric="max"):
    """
    gt_label: ground truth label
    score_dic: the dic that contains the predicted similarity score for each support sample
               the key is the class label
    metric: the metric to aggregate the similarity scores across the support samples within each class
    
    return:
        result: [ifcorrect, similarity_score_of_the_target_class, predicted_class, similarity_score_of_the_predicted_class]
    """
    reduced_score = {}
           
    for key, values in score_dic.items():
        if metric == "avg": 
            reduced_score[key] = sum(values) / len(values) if values else 0
        elif metric == "max":
            reduced_score[key] = max(values) if values else 0
        
    largest_key = max(reduced_score, key=reduced_score.get)
    largest_value = reduced_score[largest_key]
    gt_label = int(gt_label)
    
    if gt_label==largest_key:
        result = [1, gt_label, largest_value, largest_key, largest_value]
    else:
        result = [0, gt_label, reduced_score[gt_label], largest_key, largest_value]
             
    return result

In [5]:
# predict the similarity score with each support sample and sort the similarity scores by class
def get_similarity_score(test_X, support_samples, model):   
    support_X = support_samples["X"] / 255.0  
    similarity_score = model.predict([test_X, support_X], verbose=0).squeeze()

    # store the score into each class dic
    unique_labels = np.unique(support_samples['Y'].values)
    score_dic = {int(unique_label):[] for unique_label in unique_labels}
    for j, support_Y in enumerate(support_samples['Y'].values):
        score_dic[support_Y].append(similarity_score[j])
    
    return score_dic

In [6]:
# Compute the results
zero_shot_results = np.zeros((0, 5))
refined_results = np.zeros((0, 5))
print('[ifcorrect, gt_label, similarity_score_of_the_target_class, predicted_class, similarity_score_of_the_predicted_class]')

num_test_smaples = len(test_samples['X']['sample'])
num_test_smaples = 18
for j in range(num_test_smaples):
    test_sample_j = test_samples.isel(sample=j)
    
    # Make the batch size as the total support_sample size
    support_sample_size =  len(support_samples['X']['sample'])  
    test_Y = test_sample_j['Y'].values 
    test_X = test_sample_j.expand_dims({"sample": support_sample_size})["X"] / 255.0  
    
    # ### Test the base model zero-shot learning
    # # Compute the similarity scores across classes
    zeroshot_score_dic = get_similarity_score(test_X, support_samples, base_model)
    
    # # Compute the prediction results
    zeroshot_result_j = predict_label(test_Y, zeroshot_score_dic, metric="avg")
    print("zero shot", zeroshot_result_j)
        
    
    ### Test the refined model for n-shot learning
    # Compute the similarity scores across classes
    nshot_score_dic = get_similarity_score(test_X, support_samples, refined_model)
    
    # Compute the prediction results
    nshot_result_j = predict_label(test_Y, nshot_score_dic, metric="avg")
    
    # All the results
    print("n-shot", nshot_result_j)
    
    # All the results
    zero_shot_results = np.vstack((zero_shot_results, zeroshot_result_j))
    refined_results = np.vstack((refined_results, nshot_result_j))

print("-" * 20)   
print("Overall accuracy of the base model", sum(zero_shot_results[:, 0])/num_test_smaples)
print("Overall accuracy of the refined model", sum(refined_results[:, 0])/num_test_smaples)  

gt_zero_shot = zero_shot_results[:, 1]  
pd_zero_shot = zero_shot_results[:, 3]
zero_shot_results = classification_report(gt_zero_shot, pd_zero_shot)
print("***** zero shot results *****")
print(zero_shot_results)

gt_n_shot = refined_results[:, 1]  
pd_n_shot = refined_results[:, 3]
n_shot_results = classification_report(gt_n_shot, pd_n_shot)
cm = confusion_matrix(gt_n_shot, pd_n_shot)
print("***** n shot results *****")
print(n_shot_results)
print(cm)

[ifcorrect, gt_label, similarity_score_of_the_target_class, predicted_class, similarity_score_of_the_predicted_class]


I0000 00:00:1725222053.500225 3596473 service.cc:146] XLA service 0x7f9138005a40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725222053.500294 3596473 service.cc:154]   StreamExecutor device (0): NVIDIA A40, Compute Capability 8.6
2024-09-01 22:20:53.560225: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-09-01 22:20:53.891382: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8902
I0000 00:00:1725222055.252766 3596473 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


zero shot [0, 333988661248, 0.594398021697998, 394585504768, 0.8974649906158447]
n-shot [0, 333988661248, 0.6139724254608154, 394585504768, 0.7296779155731201]
zero shot [0, 333988661248, 0.7416532039642334, 394585504768, 0.8384000658988953]
n-shot [0, 333988661248, 0.40814208984375, 394585504768, 0.6538810729980469]
zero shot [0, 333988661248, 0.7514841556549072, 399601058816, 0.8184464573860168]
n-shot [0, 333988661248, 0.2631523311138153, 394585504768, 0.2881317436695099]
zero shot [0, 333988661248, 0.702363133430481, 394585504768, 0.9665350317955017]
n-shot [0, 333988661248, 0.3744087815284729, 394585504768, 0.8574883937835693]
zero shot [0, 333988661248, 0.691322386264801, 399601058816, 0.8492051362991333]
n-shot [0, 333988661248, 0.356015682220459, 394585504768, 0.5804061889648438]
zero shot [0, 394585504768, 0.13548170030117035, 664680244048, 0.996723473072052]
n-shot [0, 394585504768, 8.603134915574628e-07, 664680244048, 0.9507625102996826]
zero shot [0, 394585504768, 0.9801779

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


| k-shot n-Fold | support set | test set     | n-Accuracy  |0-Accuracy  | 
| :---   |    :---              |        :---             | :---   |  :---   |
| 3-1      | (0, 1, 2)                |   (3, 4, 5)                       |  0.89   |  |
| 3-2      | (3, 4, 5)                      | (0, 1, 2)                        | 0.83    | |
| 2-1      | (0, 1)                      | (2, 3, 4, 5)                         | 0.61    |0.22 |
| 2-2      | (2, 3)                      | (0, 1, 4, 5)                         | 0.56    |0.22  |
| 2-3      | (4, 5)                      | (0, 1, 2, 3)                         | 0.56    |0.22  |
| 1-1      | (0)                      | (1, 2, 3, 4, 5)                         | 0.61    |0.31  |
| 1-2      | (1)                      | (0, 1, 2, 3, 4)                         | 0.56    |0.28  |
| 1-3      | (2)                      | (0, 1, 3, 4, 5)                         | 0.50    |0.22  |
| 1-4      | (3)                      | (0, 1, 2, 4, 5)                         | 0.56    |0.11  |
| 1-5      | (4)                      | (0, 1, 2, 3, 5)                         | 0.06    |0.00  |
| 1-6      | (5)                      | (0, 1, 2, 3, 4)                         | 0.61    |0.11  |