In [None]:
# Import libraries
import keras
import dask.array as da
from pathlib import Path
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import pickle
from sklearn import svm

import dask
dask.config.set(scheduler='synchronous') # single-threaded scheduler: run dask tasks sequentially to avoid memory issues

# Classification

### Steps:
1. load data: 1) test sample; 2) class examples
1. load the trained Siamese model
1. pair the test sample with all class examples
1. compute similarity scores using Siamese model
1. perform classification based on the similarity scores:
    - average similarity score
    - K-nearest neighbors
    - Support Vector Machine (SVM)

### Load test sample

In [None]:
# Open the test sample from its Zarr store and display the resulting dataset
test_sample = xr.open_zarr('data/test_example_brazil.zarr/')
test_sample

In [None]:
# Visualize the test sample: cast the 'X' variable to integers and draw it as an image
test_sample['X'].astype('int').plot.imshow()

### Load example classes

In [None]:
# Collect every Zarr store found under the example classes directory, in sorted order
class_files = sorted(Path('data/example_classes/').rglob('*.zarr'))
class_files

In [None]:
# Manually create a map between integer label and class text.
# The integer label of a class is its position in the sorted `class_files` list,
# so this map assumes the sorted file order matches the order below -- verify
# against the displayed `class_files` when classes are added or renamed.
class_map = {0: 'banana', 1:'cacao', 2:'fruit', 3:'palmtree'}

# Fail early with a clear message (e.g. wrong working directory, missing class)
# instead of hitting a KeyError/IndexError in a later cell.
if len(class_files) != len(class_map):
    raise ValueError(
        f"Found {len(class_files)} example class files but class_map defines "
        f"{len(class_map)} classes; check the data path and class_map."
    )

# Open one dataset per example class, keyed by its integer label
class_data = {
    class_i: xr.open_zarr(class_file)
    for class_i, class_file in enumerate(class_files)
}
class_data

In [None]:
# Visualize the first three examples in each example class (one row per class).
# The number of rows follows the number of loaded classes instead of a fixed 4,
# and squeeze=False keeps `axs` two-dimensional even when there is a single class.
n_examples_shown = 3
fig, axs = plt.subplots(len(class_data), n_examples_shown, figsize=(15, 15), squeeze=False)
for row_i, class_i in enumerate(class_data):
    for example_i in range(n_examples_shown):
        class_data[class_i]['X'][example_i].astype('int').plot.imshow(ax=axs[row_i, example_i])

### Load the trained Siamese model

In [None]:
@keras.saving.register_keras_serializable(package="MyLayers")
class euclidean_lambda(keras.layers.Layer):
    """Custom layer that returns the element-wise squared difference of two inputs.

    The class is registered with Keras under the "MyLayers" package so that a
    saved model containing this layer can be deserialized.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The layer name is always set to this fixed value after construction
        self.name = 'euclidean_lambda'

    def call(self, featA, featB):
        """Return the element-wise square of ``featA - featB`` (no sum or square root)."""
        difference = featA - featB
        return keras.ops.square(difference)

In [None]:
# Display the custom objects currently registered with Keras
keras.saving.get_custom_objects()

In [None]:
# Load the trained Siamese model, handing Keras the registered custom objects
# so that custom layers in the saved file can be resolved, then print a summary.
siamese_model = keras.models.load_model(
    "../optimized_models/siamese_model.keras",
    custom_objects = keras.saving.get_custom_objects()
)
siamese_model.summary()

In [None]:
# Due to memory limits, the similarity scores are computed one batch at a time
batch_size = 10  # number of samples to process at once to compute similarity score
def predict_per_chunk(x, y):
    """Compute similarity scores between two batches of images.

    Parameters
    ----------
    x, y : array-like
        Paired image batches, passed as the two inputs of ``siamese_model``.

    Returns
    -------
    numpy.ndarray
        1-D array of scores (assumes the model emits one score per pair).
        The prediction is flattened with ``reshape(-1)`` rather than
        ``squeeze()``: ``squeeze()`` turns a batch holding a single pair into
        a 0-d array, which cannot be concatenated with the other blocks.
    """
    return siamese_model.predict([x, y], verbose=0).reshape(-1)

# Compute similarity scores between the test sample and each example class
similarity_scores = {}
for class_i in class_map.keys():

    # Pair the test sample with every example of this class by repeating it
    # along a new "sample" dimension; pixel values are scaled by 1/255.
    n_examples = class_data[class_i]["sample"].shape[0]
    X_sample_norm = test_sample.expand_dims({"sample": n_examples})["X"] / 255.0
    X_class_norm = class_data[class_i]["X"] / 255.0

    # Chunk the data so that each block holds at most `batch_size` samples
    X_sample_norm = X_sample_norm.chunk({"sample": batch_size})
    X_class_norm = X_class_norm.chunk({"sample": batch_size})

    # Compute similarity scores per batch. No explicit `chunks` is passed: the
    # output then inherits the block sizes of the inputs along the kept axis,
    # which stays correct when the last batch is smaller than `batch_size`.
    scores = da.map_blocks(
        predict_per_chunk,
        X_sample_norm.data,
        X_class_norm.data,
        dtype="float32",
        drop_axis=(1, 2, 3),
    )
    scores = scores.compute()

    similarity_scores[class_i] = scores

In [None]:
# Display the similarity scores, keyed by integer class label
similarity_scores

In [None]:
# Save the similarity scores to a pickle file
with open('data/similarity_scores.pkl', 'wb') as f:
    pickle.dump(similarity_scores, f)

## Classification

In [None]:
# Load the similarity scores from the pickle file.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('data/similarity_scores.pkl', 'rb') as f:
    similarity_scores = pickle.load(f)
similarity_scores

### Method 1: Average similarity score

In [None]:
# Average the similarity scores of each example class
average_scores = {
    class_i: np.mean(class_scores)
    for class_i, class_scores in similarity_scores.items()
}
average_scores

In [None]:
# Pick the class whose average similarity score is highest. The argmax position
# is mapped back to the dictionary key rather than being used as the label
# itself, so the result does not rely on the keys being 0..n-1 in order.
class_labels = list(average_scores.keys())
best_position = int(np.argmax([average_scores[label] for label in class_labels]))
predicted_class = class_map[class_labels[best_position]]
print(f"Prediction by average similarity score: {predicted_class}")

### Method 2: K-nearest neighbors

In [None]:
# Manual input: K, the number of highest similarity scores that take part in the vote
k = 3

In [None]:
# Get the top K highest similarity scores and their corresponding class

# First keep the top k scores per class (the overall top k can only come from these)
top_k_scores = {
    class_i: np.sort(class_scores)[-k:]
    for class_i, class_scores in similarity_scores.items()
}

# Pool the candidates as (score, class) pairs. Pairs are used instead of a
# score -> class dictionary because identical scores from different classes
# would overwrite each other in a dictionary and be credited to a single class.
scored_candidates = [
    (score, class_i)
    for class_i, class_scores in top_k_scores.items()
    for score in class_scores
]

# Then sort the pooled candidates by score, highest first, and keep the k best
scored_candidates.sort(key=lambda pair: pair[0], reverse=True)
top_k_neighbors = scored_candidates[:k]

# Find the class that occurs most often among the k best scores
top_k_classes = [class_i for _, class_i in top_k_neighbors]
counter = Counter(top_k_classes)
most_common_value = counter.most_common(1)[0][0]

print(f"Prediction by KNN: {class_map[most_common_value]}")

### Method 3: Support Vector Machine (SVM)

In [None]:
# Import the trained SVM and perform prediction based on statistics of similarity scores
# Load the trained SVM model from pickle file
# NOTE: pickle.load can execute arbitrary code -- only load model files you trust.
with open('../optimized_models/svm_classifier.pkl', 'rb') as f:
    svm_model = pickle.load(f)
svm_model

In [None]:
# The SVM takes the mean similarity score of each class as its input features:
# gather them in label order and shape them as a single row (1 sample, n classes)
per_class_means = [average_scores[class_i] for class_i in range(len(class_map))]
mean_scores = np.array(per_class_means).reshape(1, -1)
mean_scores

In [None]:
# Predict with the SVM and convert the first (only) predicted label to an integer
svm_output = svm_model.predict(mean_scores)
prediction = svm_output[0].astype(int)
print(f"Prediction by SVM: {class_map[prediction]}")
