# Train SVM classifier

In [None]:
import numpy as np
import pickle
import xarray as xr
from sklearn import svm
from pathlib import Path
from keras import backend  # required for loading model
from keras.models import load_model
import dask.array as da

import dask
# Execute every dask graph in the calling thread (no worker pool) — presumably
# so the Keras model used inside map_blocks is never called from other threads.
dask.config.set({'scheduler': 'synchronous'})

## Load data and Siamese network

In [None]:
# Open the SVM training set lazily from its zarr store and display a summary.
imgs = xr.open_zarr(store='data/svm_training_data.zarr')
imgs

In [None]:
# Load the example images of every class. Files are sorted so that the integer
# key assigned to each store is deterministic between runs.
class_files = sorted(Path('data/example_classes/').rglob('*.zarr'))
# Manually create a map between class text and integer label.
# NOTE(review): assumes the sorted file order matches these labels — confirm.
class_map = {0: 'banana', 1: 'cacao', 2: 'fruit', 3: 'palmtree'}

# The similarity loop looks up class_data[k] for every key k of class_map, so a
# count mismatch would either crash later with a KeyError (too few files) or
# silently ignore classes (too many files). Fail early with a clear message.
if len(class_files) != len(class_map):
    raise ValueError(
        f"Found {len(class_files)} example-class zarr stores in "
        f"'data/example_classes/', but class_map defines {len(class_map)} classes"
    )

class_data = {
    class_i: xr.open_zarr(class_file)
    for class_i, class_file in enumerate(class_files)
}
class_data

In [None]:
# Load the trained Siamese network from disk and print its layer summary.
# (The `keras.backend` import at the top of the notebook is noted as required
# for loading this model.)
model_file = '../optimized_models/siamese_model.h5'
siamese_model = load_model(model_file)
siamese_model.summary()

## Compute similarity matrix for training data

In [None]:
# Due to memory limits, similarity scores are computed one batch at a time.
batch_size = 10  # number of image pairs passed to the network per call


def predict_per_chunk(x, y, model=None):
    """Compute similarity scores for the image pairs of one batch.

    Parameters
    ----------
    x, y : array-like
        Two batches with the same leading (pair) dimension; element ``i`` of
        ``x`` is compared with element ``i`` of ``y``.
    model : optional
        Object with a Keras-style ``predict([x, y], verbose=0)`` method.
        Defaults to the notebook-level ``siamese_model`` (looked up at call
        time, so it only needs to exist when the function runs).

    Returns
    -------
    numpy.ndarray
        1-D array of scores, assuming the model returns one score per pair
        (e.g. shape ``(n, 1)``) — TODO confirm against the model's output layer.
    """
    if model is None:
        model = siamese_model
    # reshape(-1) instead of squeeze(): squeeze would turn a single-pair batch
    # of shape (1, 1) into a 0-d scalar, which breaks the 1-D block contract
    # expected by dask.array.map_blocks when the last batch has one element.
    return np.asarray(model.predict([x, y], verbose=0)).reshape(-1)

# Compute similarity statistics of every training sample against each class.
# Each row of matrix_stats is [mean, std, max, min] for class 0, then class 1,
# and so on in class_map order.
# Only the first `max_examples` example images of each class are compared
# (previously a hard-coded debug subset of 10); set to None to use all of them.
max_examples = 10
n_stats = 4  # mean, std, max, min

rows = []
for sample_i in range(imgs.sizes['sample']):
    test_sample = imgs.isel(sample=sample_i)
    print(f"Processing sample {sample_i}")
    sample_stats = []
    for class_i in class_map:

        # Pair the sample with every example of the class by repeating it
        # along a new "sample" dimension of the same length.
        n_examples = class_data[class_i]["sample"].shape[0]
        X_sample_norm = test_sample.expand_dims({"sample": n_examples})["X"] / 255.0
        X_class_norm = class_data[class_i]["X"] / 255.0

        # Optionally restrict to a subset of the class examples.
        X_sample_norm = X_sample_norm.isel(sample=slice(0, max_examples))
        X_class_norm = X_class_norm.isel(sample=slice(0, max_examples))

        # Chunk the data so the network only sees `batch_size` pairs at once.
        X_sample_norm = X_sample_norm.chunk({"sample": batch_size})
        X_class_norm = X_class_norm.chunk({"sample": batch_size})

        # One score per pair. The output chunks mirror the input chunks along
        # the kept axis, so a final partial batch is described with its true
        # length instead of being declared `batch_size` long.
        sample_chunks = X_class_norm.data.chunks[0]
        scores = da.map_blocks(
            predict_per_chunk,
            X_sample_norm.data,
            X_class_norm.data,
            dtype="float32",
            chunks=(sample_chunks,),
            drop_axis=(1, 2, 3),
        ).compute()

        sample_stats.extend([scores.mean(), scores.std(), scores.max(), scores.min()])
    rows.append(sample_stats)

# Build the matrix once rather than re-allocating it with vstack per sample;
# the explicit reshape keeps the (0, n_stats * n_classes) shape for no samples.
matrix_stats = np.asarray(rows, dtype=np.float64).reshape(
    len(rows), n_stats * len(class_map)
)

In [None]:
# Cache the feature matrix on disk so the slow similarity computation does
# not have to be repeated on the next run.
np.save(file="data/matrix_stats.npy", arr=matrix_stats)

In [None]:
# Reload the cached feature matrix (lets the notebook resume from this cell
# without recomputing the similarity scores).
matrix_stats = np.load(file="data/matrix_stats.npy")

In [None]:
# NOTE(review): disabled experiment, kept for reference only. As written this
# is not a Gaussian (RBF) kernel — an RBF kernel is exp(-gamma * ||a - b||^2).
# Also, a callable passed as `svm.SVC(kernel=...)` receives two feature
# matrices of shapes (n_a, n_features) and (n_b, n_features) and must return
# their (n_a, n_b) Gram matrix; confirm the intended formula before enabling.
# def gaussian_kernel(mean_similarity, y):
#     """customized gaussian kernel function for SVM.
#     """
#     distance = 1 - mean_similarity

#     return np.exp(np.dot(distance,(-y.T)))

# classifier = svm.SVC(kernel=gaussian_kernel)

In [None]:
# Features: only the per-class mean similarity. With the per-class layout
# [mean, std, max, min], the means are every 4th column starting at 0.
mean_similarity = matrix_stats[:, ::4]
y = imgs['Y'].values
classifier = svm.SVC()  # default RBF kernel
classifier.fit(mean_similarity, y)

In [None]:
# Persist the fitted SVM as a pickle file. Only load this file from a trusted
# source: unpickling can execute arbitrary code.
Path('./svm_classifier.pkl').write_bytes(pickle.dumps(classifier))