#### We retrain the model using the support set to do the few-shot learning
**Note: every (k, fold) combination is trained independently.**

For example, in our experiment we have 6 samples per class for few-shot learning.

The data partitioning is listed below for 1-, 2-, 3-shot learning.

In total, we have 11 independently trained models (6 + 3 + 2 folds).

For each trained model, the support and test samples are disjoint, so the model never sees its test samples during training.






| K shot | \#Samples in supp. set | \#Samples in test set     | \#Folds  | 
| :---   |    :----:              |        :----:             | :----:   | 
| 1      | 1                      | 5                         | 6/1=6    | 
| 2      | 2                      | 4                         | 6/2=3    |
| 3      | 3                      | 3                         | 6/3=2    |

Data needed for this notebook is available on Zenodo:
- [n_fold_x_validation](https://zenodo.org/records/13833791/files/n_fold_x_validation.zip?download=1)

The output of this notebook, i.e. the refined models, is available on Zenodo:
- [optimized_models](https://zenodo.org/records/13833791/files/optimized_models.zip?download=1)

In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt

import keras
import tensorflow as tf
import os
from keras import backend as k
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import pandas as pd
from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt
from keras.callbacks import ReduceLROnPlateau, Callback

# Change the working directory to the project root so that the relative paths
# used below (./notebooks/..., ./optimized_models/...) resolve correctly.
# Set the TREE_CLASSIFICATION_ROOT environment variable to point at your own
# checkout instead of editing the hard-coded default below.
project_root = os.environ.get(
    'TREE_CLASSIFICATION_ROOT',
    '/data/Projects/2024_Invasive_species/Tree_Classification',
)
os.chdir(project_root)
print(os.getcwd())

#### In this notebook, we manually set `k` (the number of shots) and `iii` (the fold index) for k-shot learning on fold iii

In [17]:
# Load the support dataset for one (k-shot, fold) data partition.
# Here we take 1-shot learning on fold 6 as an example; change k and iii
# (or turn this cell into a loop) to cover every partition in the table above.
k = 1    # number of shots, i.e. support samples per class
iii = 6  # fold index, 1-based (1..6 // k folds per the table above)

# Fail early with a clear message instead of an opaque missing-store error
# from zarr when k / iii do not describe an existing partition.
n_samples_per_class = 6  # samples per class available for few-shot learning
if k < 1:
    raise ValueError(f"k must be >= 1, got k={k}")
n_folds = n_samples_per_class // k
if not 1 <= iii <= n_folds:
    raise ValueError(f"iii must be in 1..{n_folds} for {k}-shot learning, got iii={iii}")

support_path = f'./notebooks/data_agu/n_fold_x_validation/{k}_shot_{iii}_fold_supp_pairs.zarr'
support_set = xr.open_zarr(support_path)
support_set  # notebook display of the dataset summary

images_pair = support_set["X"].to_numpy() / 255  # Scale to [0, 1]
labels_pair = support_set["Y"].to_numpy()


### Specify the base model


In [None]:
# Select the pre-trained Siamese base model that will be refined.
# 'CNN' is the shallow CNN model; set base_model_name = 'mobilenet03'
# to refine the MobileNet-based model instead.
base_model_name = 'CNN'
base_model_dir = (
    './optimized_models/results_training/Agu_pairs_training_v8/'
    f'siamese_model_{base_model_name}.keras'
)


In [None]:
@keras.saving.register_keras_serializable(package="MyLayers")
class euclidean_lambda(keras.layers.Layer):
    """Element-wise squared difference between two feature tensors.

    Despite its name, this layer takes no square root and does not reduce
    over any axis: ``call`` returns ``(featA - featB) ** 2`` element-wise.

    The class is registered under the package "MyLayers" so that
    ``keras.saving.load_model`` can rebuild saved Siamese models that use it.
    Do not rename the class or the package, because the saved models refer
    to them.
    """

    def __init__(self, **kwargs):
        # Default the layer name to 'euclidean_lambda' (the name previously
        # forced on every instance), but honour an explicit ``name`` passed by
        # the caller or restored from a saved config rather than overwriting
        # it after construction.
        kwargs.setdefault("name", "euclidean_lambda")
        super().__init__(**kwargs)

    def call(self, featA, featB):
        """Return the element-wise square of ``featA - featB``."""
        return keras.ops.square(featA - featB)

# Load the pre-trained Siamese base model and print its architecture.
# The custom euclidean_lambda layer must be registered before this call.
model = keras.saving.load_model(filepath=base_model_dir)
model.summary()

In [18]:
# Setup the refinement training: recompile the loaded base model with a small
# learning rate (presumably so the pre-trained weights are only fine-tuned on
# the support set rather than retrained from scratch).
lr_init = 2.5e-05  # initial learning rate
metrics = [keras.metrics.BinaryAccuracy(threshold=0.5)]
opt = keras.optimizers.Adam(learning_rate=lr_init)
# from_logits=False: the loss expects the model output to already be a probability.
loss = keras.losses.BinaryCrossentropy(from_logits=False)
model.compile(loss=loss, optimizer=opt, metrics=metrics)

# One output directory per (k-shot, fold) run,
# e.g. ./optimized_models/refine_model/1_shot_6_fold/
train_name = f'{k}_shot_{iii}_fold/'
dir_training = Path("./optimized_models/refine_model/") / train_name
# parents=True: the first run on a fresh checkout has no refine_model/ directory yet.
dir_training.mkdir(parents=True, exist_ok=True)

In [5]:
def plot_history(history, metrics):
    """Draw training curves stored in a Keras ``History`` object.

    Creates a new 10x5-inch figure and plots the selected columns of
    ``history.history`` against the epoch index. The y-axis is fixed to [0, 1].

    Args:
        history: Keras History object returned by ``model.fit()``.
        metrics (str | list[str]): Name(s) of the logged metric(s) to plot,
            e.g. ``'loss'`` or ``['loss', 'val_loss']``.
    """
    curves = pd.DataFrame.from_dict(history.history)[metrics]
    plt.figure(figsize=(10, 5))
    ax = sns.lineplot(data=curves)
    ax.set_xlabel("epochs")
    ax.set_ylabel("metric")
    ax.set_ylim(0, 1)

In [6]:
def plot_prediction(labels_pred, labels_pair):
    """Plot predicted scores next to the true labels, split by pair label.

    The left panel shows pairs labelled 1 and the right panel pairs labelled 0.
    In each panel the predictions are drawn as small dots and the true labels
    as a line.

    Args:
        labels_pred: Predicted scores, one per pair (indexed with a boolean
            mask built from ``labels_pair``).
        labels_pair: Ground-truth pair labels (1 or 0).
    """
    figure = plt.figure(figsize=(10, 5))
    for position, label in enumerate((1, 0), start=1):
        selected = labels_pair == label
        axes = figure.add_subplot(1, 2, position)
        axes.plot(labels_pred[selected], '.', markersize=0.5)
        axes.plot(labels_pair[selected])

In [7]:
def plot_prediction_hist(labels_pred, labels_pair=None):
    """Plot histograms of predicted scores, split by pair label.

    The left panel shows predictions for pairs labelled 1 and the right panel
    predictions for pairs labelled 0.

    Args:
        labels_pred: Predicted scores, one per pair.
        labels_pair: Ground-truth pair labels aligned with ``labels_pred``.
            If omitted, the notebook-level global ``labels_pair`` is used,
            which keeps existing one-argument calls working.

    Raises:
        NameError: if ``labels_pair`` is not passed and no global
            ``labels_pair`` exists.
    """
    if labels_pair is None:
        # Previously this function always read the global silently. Keep that
        # as a fallback, but let callers pass the labels explicitly, as
        # plot_prediction already requires.
        try:
            labels_pair = globals()["labels_pair"]
        except KeyError:
            raise NameError(
                "labels_pair was not passed and no global 'labels_pair' is defined"
            ) from None

    fig = plt.figure(figsize=(10, 5))
    for position, label in ((1, 1), (2, 0)):
        counts, bins = np.histogram(labels_pred[labels_pair == label])
        sub = fig.add_subplot(1, 2, position)
        sub.stairs(counts, bins)

In [8]:
# Custom callback that plots predictions periodically during training
class PlotPredictionsCallback(Callback):
    """Keras callback that plots model predictions on a fixed set of pairs.

    Every ``every_n_epochs`` epochs (default 2, i.e. the 1st, 3rd, 5th, ...
    epoch when counted from 1), the model predicts on ``x_val``. Two figures
    are then written to ``dir_training / f"epoch_{epoch + 1}"``:
    ``prediction.png`` and ``prediction_hist.png``.

    Args:
        x_val: Image pairs; ``x_val[:, 0]`` and ``x_val[:, 1]`` are fed as the
            two Siamese inputs.
        y_val: Pair labels (1 / 0) used to split the prediction plot.
        dir_training (pathlib.Path | str): Directory for the per-epoch plots.
        every_n_epochs (int): Plotting interval in epochs; must be >= 1.

    Raises:
        ValueError: if ``every_n_epochs`` is smaller than 1.
    """

    def __init__(self, x_val, y_val, dir_training, every_n_epochs=2):
        super().__init__()
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.x_val = x_val
        self.y_val = y_val
        self.dir_training = dir_training
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, epoch, logs=None):
        # ``epoch`` is 0-based; directories are named with the 1-based epoch.
        if epoch % self.every_n_epochs != 0:
            return

        # Make predictions on the given pairs
        pairs = self.x_val
        labels_pred = self.model.predict([pairs[:, 0], pairs[:, 1]])

        dir_epoch = Path(self.dir_training) / f'epoch_{epoch+1}'
        dir_epoch.mkdir(parents=True, exist_ok=True)

        # Plots
        plot_prediction(labels_pred, self.y_val)
        plt.savefig(dir_epoch / 'prediction.png')
        # NOTE: plot_prediction_hist is called without labels here, so it
        # splits by the notebook-level ``labels_pair`` rather than self.y_val.
        plot_prediction_hist(labels_pred)
        plt.savefig(dir_epoch / 'prediction_hist.png')
        # Close the figures so they do not accumulate across epochs.
        plt.close('all')

In [None]:
# Set callbacks
# Halve the learning rate when val_loss has not improved for 5 epochs (floor 1e-7).
learning_rate_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7)
# NOTE: this plots predictions on the support (training) pairs, not on a held-out set.
plot_pred_callback = PlotPredictionsCallback(images_pair, labels_pair, dir_training)
# Take EarlyStopping from keras.callbacks, like ReduceLROnPlateau, so that all
# callbacks come from the same Keras package as the model.
# NOTE(review): with start_from_epoch=5, patience=5 and epochs=10 below, early
# stopping cannot trigger within this run. restore_best_weights therefore most
# likely has no effect, and the final weights are those of the last epoch.
# Increase epochs or lower patience if early stopping is intended.
earlystop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, start_from_epoch=5, restore_best_weights=True)
callbacks = [learning_rate_scheduler, plot_pred_callback, earlystop]

# validation_split holds out the last 20% of the pairs (taken before shuffling).
# NOTE(review): if the pairs are ordered by label, this split may be single-class.
history = model.fit(
    [images_pair[:, 0], images_pair[:, 1]],
    labels_pair,
    batch_size=4,
    epochs=10,
    validation_split=0.2,
    callbacks=callbacks,
)


In [10]:
# Persist the refined model twice: as a full .keras archive and as weights only.
refined_model_path = dir_training / 'siamese_model_refined.keras'
refined_weights_path = dir_training / "optimized_weights_refined.weights.h5"
keras.saving.save_model(model, refined_model_path, overwrite=True)
model.save_weights(refined_weights_path)

In [None]:
# Save the training history.
# NOTE(review): this pickles the whole Keras History object, which as a
# callback presumably also references the model. Pickling ``history.history``
# (a plain dict) would be lighter and more portable, but readers of
# history.pkl may expect the object, so the format is left unchanged.
with open(dir_training / "history.pkl", "wb") as file_pi:
    pickle.dump(history, file_pi)

# Model evaluation plots.
# Predict with the model reloaded from disk, so the plots also confirm that
# the saved .keras archive loads and produces usable predictions.
model_loaded = keras.saving.load_model(dir_training / 'siamese_model_refined.keras', compile=False)
# NOTE: these predictions are on the support (training) pairs, not on a test set.
labels_pred = model_loaded.predict([images_pair[:, 0], images_pair[:, 1]])
plot_history(history, ['loss', 'binary_accuracy', 'val_loss', 'val_binary_accuracy'])
plt.savefig(dir_training / "history.png")
plot_prediction(labels_pred, labels_pair)
plt.savefig(dir_training / "prediction.png")
plot_prediction_hist(labels_pred)
plt.savefig(dir_training / "prediction_hist.png")