In [None]:
# Mount Google Drive inside the Colab VM at /content/drive.
# The dataset and model paths used further down are relative
# ('./drive/MyDrive/...'), so they only resolve when the working directory
# is /content -- NOTE(review): presumably the Colab default; confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install tensorflow_addons

In [3]:
import os
import tensorflow as tf
import tensorflow_addons as tfa
tfa.register.register_all()
import cv2
import matplotlib.pyplot as plt
import numpy as np
import numpy.random as rng
import seaborn as sns
import pickle
from sklearn.metrics import classification_report

In [4]:
class PokeMetric:
    """Siamese-network one-shot classifier over a fixed set of support images.

    One reference ("support") image is kept per class. The model scores how
    similar two images are, and a query image is classified as the support
    image it is scored most similar to.

    Attributes (all set in ``__init__``):
        width, height, n_channel: size and channel count images are resized to.
        support: support images divided by 255, ordered by numeric file name.
        support_size: highest numeric file name / class id that gets loaded.
        model: the Keras model (``None`` until ``build_model`` is called).
        test_images, test_labels: validation images and their 0-based class ids.
        current_epoch: number of training batches run on the current model.
    """

    def __init__(self, width, height, n_channel, support_dataset, support_size=151):
        """Store the image geometry and load the support set from `support_dataset`."""
        self.width = width
        self.height = height
        self.n_channel = n_channel
        self.support = []
        self.model = None
        self.test_images = []
        self.test_labels = []
        self.current_epoch = 0
        self.support_size = support_size
        self.load_support_images(support_dataset)

    def _read_image(self, path):
        """Read `path`, resize it to (width, height) and divide by 255.

        Returns None when OpenCV cannot decode the file.
        """
        # cv2.imread flags: 0 loads as grayscale, 1 loads as 3-channel colour.
        loading_mode = 0 if self.n_channel == 1 else 1
        image = cv2.imread(path, loading_mode)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            return None
        return cv2.resize(image, (self.width, self.height)) / 255.

    def load_support_images(self, dataset_location):
        """Load one support image per class from `dataset_location`.

        Only files whose name before the first '.' is a number no greater than
        ``support_size`` are used. The images end up in ``self.support`` sorted
        by that number, so (assuming files 1..support_size all exist) the class
        with file name ``k`` sits at index ``k - 1``.

        Raises:
            ValueError: if a selected file cannot be read as an image; skipping
                it would silently shift the index of every later class.
        """
        images = []
        labels = []
        for data in os.listdir(dataset_location):
            name = data.split('.')[0]
            if not (name.isdigit() and int(name) <= self.support_size):
                continue
            path = os.path.join(dataset_location, data)
            image = self._read_image(path)
            if image is None:
                raise ValueError("could not read support image: " + path)
            images.append(image)
            labels.append(int(name))
        self.support = [images[i] for i in np.argsort(labels)]

    def load_val_images(self, dataset_location):
        """Load the validation images from `dataset_location`.

        File names are expected to look like ``<class id>-<anything>``. Files
        whose prefix is not a number, or is greater than ``support_size``, are
        ignored, and unreadable files are skipped with a notice. Labels are
        stored 0-based (``class id - 1``) so they line up with indices into
        ``self.support``.
        """
        self.test_images = []
        labels = []
        for data in os.listdir(dataset_location):
            name = data.split('-')[0]
            # Same guard as load_support_images: stray non-numeric files must
            # not crash the int() conversion.
            if not (name.isdigit() and int(name) <= self.support_size):
                continue
            path = os.path.join(dataset_location, data)
            image = self._read_image(path)
            if image is None:
                print("skipping unreadable image: " + path)
                continue
            self.test_images.append(image)
            labels.append(int(name) - 1)
        self.test_labels = np.array(labels)

    def build_model(self, drop_rate=0, kernel_reg=0, std_init=0.01):
        """Build the Siamese model and store it in ``self.model``.

        Both inputs go through the same convolutional encoder (shared weights);
        the element-wise absolute difference of the two 4096-d encodings feeds
        a single sigmoid unit.

        Args:
            drop_rate: dropout rate applied after the last conv layer and after
                the dense encoding layer.
            kernel_reg: L2 regularisation factor for all kernels.
            std_init: standard deviation of the normal initialisers used for
                the convolutional kernels and biases.
        """
        # cv2.resize(image, (width, height)) produces arrays of shape
        # (height, width[, channels]), so the model input is height-first.
        input_shape = (self.height, self.width, self.n_channel)
        left_input = tf.keras.layers.Input(input_shape)
        right_input = tf.keras.layers.Input(input_shape)

        w_init = tf.keras.initializers.RandomNormal(0, std_init)
        b_init = tf.keras.initializers.RandomNormal(0.5, std_init)
        dense_init = tf.keras.initializers.RandomNormal(0, 0.1)
        regul = tf.keras.regularizers.l2(kernel_reg)

        convnet = tf.keras.models.Sequential()
        convnet.add(tf.keras.layers.Conv2D(64, (10, 10), activation='relu', input_shape=input_shape,
                                           kernel_initializer=w_init, kernel_regularizer=regul))
        convnet.add(tf.keras.layers.MaxPooling2D())
        convnet.add(tf.keras.layers.Conv2D(128, (7, 7), activation='relu',
                                           kernel_initializer=w_init, kernel_regularizer=regul,
                                           bias_initializer=b_init))
        convnet.add(tf.keras.layers.MaxPooling2D())
        convnet.add(tf.keras.layers.Conv2D(128, (4, 4), activation='relu',
                                           kernel_initializer=w_init, kernel_regularizer=regul,
                                           bias_initializer=b_init))
        convnet.add(tf.keras.layers.MaxPooling2D())
        convnet.add(tf.keras.layers.Conv2D(256, (4, 4), activation='relu',
                                           kernel_initializer=w_init, kernel_regularizer=regul,
                                           bias_initializer=b_init))
        convnet.add(tf.keras.layers.Dropout(drop_rate))
        convnet.add(tf.keras.layers.Flatten())
        convnet.add(tf.keras.layers.Dense(4096, activation="sigmoid",
                                          kernel_initializer=dense_init,
                                          kernel_regularizer=regul,
                                          bias_initializer=dense_init))
        convnet.add(tf.keras.layers.Dropout(drop_rate))

        # The same `convnet` instance encodes both inputs, i.e. weights are shared.
        encoded_l = convnet(left_input)
        encoded_r = convnet(right_input)
        merged = tf.math.abs(encoded_l - encoded_r)
        prediction = tf.keras.layers.Dense(1, activation='sigmoid',
                                           kernel_initializer=dense_init,
                                           kernel_regularizer=regul,
                                           bias_initializer=dense_init)(merged)

        self.model = tf.keras.Model(inputs=[left_input, right_input], outputs=prediction)
        self.current_epoch = 0

    def compile(self, optimizer, loss):
        """Compile ``self.model`` with the given optimizer and loss."""
        self.model.compile(optimizer, loss)

    def get_batch(self, batch_size=32):
        """Build a training batch from augmented support images.

        The first ``batch_size // 2`` pairs are positives (the same augmented
        image on both sides, label 1); the second half are negatives (two
        different classes, label 0).

        Returns:
            ``([left_images, right_images], labels)`` as numpy arrays.
        """
        half = batch_size // 2
        support = np.array(self.support).reshape(
            self.support_size, self.height, self.width, self.n_channel)
        # Data augmentation: rotate every support image by a random angle.
        train_support = np.array(tfa.image.rotate(support, rng.randn(self.support_size)))
        # Avoid overfitting on the all-zero background: replace exact zeros
        # with noise centred on 0.5.
        noise = 0.1 * rng.randn(*train_support.shape) + 0.5
        train_support += noise * (train_support == 0)

        ref = rng.randint(0, self.support_size, half)
        # A non-zero offset modulo support_size picks uniformly among the
        # *other* classes without rebuilding an index list per sample.
        other = (ref + rng.randint(1, self.support_size, half)) % self.support_size

        left = np.concatenate([train_support[ref], train_support[ref]])
        right = np.concatenate([train_support[ref], train_support[other]])
        labels = np.array([1] * half + [0] * half)
        return [left, right], labels

    def get_val_batch(self, batch_size):
        """Build a validation batch of (support image, validation image) pairs.

        The first ``batch_size // 2`` pairs are positives (same class, label 1),
        the second half negatives (different classes, label 0). A sampled class
        that has no validation image falls back to class 0 for its positive pair.

        Returns:
            ``([support_images, validation_images], labels)`` as numpy arrays.
        """
        half = batch_size // 2
        ref = rng.randint(0, self.support_size, half)
        left = []
        right = []

        # Positive pairs.
        for label in ref:
            if label not in self.test_labels:
                label = 0
            candidates = np.where(self.test_labels == label)[0]
            left.append(self.support[label])
            right.append(self.test_images[rng.choice(candidates)])

        # Negative pairs: exclude the sampled class itself (not the loop
        # counter), otherwise a same-class pair can end up labelled 0.
        for label in ref:
            candidates = np.where(self.test_labels != label)[0]
            left.append(self.support[label])
            right.append(self.test_images[rng.choice(candidates)])

        labels = np.array([1] * half + [0] * half)
        return [np.array(left), np.array(right)], labels

    def train(self, n_batch, batch_size, saving_period=100, val_step=100,
              model_path='./drive/MyDrive/kaggle-one-shot-pokemon/model'):
        """Train on `n_batch` freshly sampled batches.

        Args:
            n_batch: number of training batches to run.
            batch_size: number of pairs per batch (half positive, half negative).
            saving_period: a checkpoint named ``poke<batch count>`` is saved
                every `saving_period` batches (except at batch 0).
            val_step: every `val_step` batches the current training loss and
                the loss on one validation batch are recorded.
            model_path: where the model is saved once training finishes.

        Returns:
            ``(train_loss_track, val_loss_track)``, two lists with one entry
            per recorded step.
        """
        val_loss_track = []
        train_loss_track = []
        for i in range(n_batch):
            x, y = self.get_batch(batch_size)
            train_loss = self.model.train_on_batch(x, y)
            if i % saving_period == 0 and i != 0:
                self.model.save("poke" + str(self.current_epoch + i))
            if i % val_step == 0:
                print(i)
                train_loss_track.append(train_loss)
                x_val, y_val = self.get_val_batch(batch_size)
                val_loss_track.append(self.model.test_on_batch(x_val, y_val))

        self.current_epoch += n_batch
        self.model.save(model_path)

        return train_loss_track, val_loss_track

    def test(self, test_size):
        """Classify `test_size` validation images drawn at random with replacement.

        Returns:
            ``(accuracy_percent, results)`` where `results` is a list of
            ``(predicted_label, true_label)`` tuples.

        Raises:
            ValueError: if `test_size` is not positive or no validation images
                have been loaded.
        """
        if test_size <= 0:
            raise ValueError("test_size must be a positive integer")
        if len(self.test_images) == 0:
            raise ValueError("no validation images loaded; call load_val_images first")

        results = []
        for idx in rng.randint(0, len(self.test_images), test_size):
            predicted = self.prediction(self.test_images[idx], test=True)
            results.append((predicted, self.test_labels[idx]))
        correct = sum(1 for predicted, truth in results if predicted == truth)
        return (correct / test_size) * 100, results

    def prediction(self, img, test=False):
        """Return the index of the support image that `img` is scored closest to.

        Args:
            img: the query image.
            test: True when `img` is already resized and divided by 255 (as the
                images in ``self.test_images`` are); otherwise that
                preprocessing is applied here.
        """
        if not test:
            img = cv2.resize(img, (self.width, self.height)) / 255.
        # Pair the query with every support image and score all pairs at once.
        pairs = [np.array([img] * len(self.support)), np.array(self.support)]
        probs = self.model.predict(pairs)
        return np.argmax(probs)

In [5]:
# Dataset folders on the mounted Drive: one reference image per class, and
# the images used for validation / testing.
support_dataset = "./drive/MyDrive/kaggle-one-shot-pokemon/pokemon-a/"
test_dataset = "./drive/MyDrive/kaggle-one-shot-pokemon/pokemon-tcg-images/"
inputs, values = [], []
# Every image is resized to 128x128 and loaded as a single grayscale channel;
# only the first 50 classes are used.
width = height = 128
pokemetric = PokeMetric(width=width, height=height, n_channel=1,
                        support_dataset=support_dataset, support_size=50)

In [6]:
# Load the validation images and their class labels into the classifier.
pokemetric.load_val_images(dataset_location=test_dataset)


In [7]:
# Build a fresh Siamese model (dropout + L2 regularisation) and compile it
# with Adam and binary cross-entropy on the same/different prediction.
pokemetric.build_model(kernel_reg=0.0002, drop_rate=0.3)
pokemetric.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=6e-5),
    loss=tf.losses.BinaryCrossentropy(),
)
# Loss histories; the training cell appends to them on every run.
t = []
v = []

In [None]:
# 50001 batches (not 50000) so that batch index 50000 is reached and gets
# its checkpoint and loss record.
tbis, vbis = pokemetric.train(50001, 32, saving_period=5000)
# Extend the histories so that re-running this cell continues the curves.
t += tbis
v += vbis

plt.figure(figsize = (15,10))
plt.plot(t, label='training')
# The validation loss is recorded at the same steps; plot it alongside the
# training loss so the two can be compared.
plt.plot(v, label='validation')
plt.xlabel("recorded step (one point every 100 training batches)")
plt.ylabel("loss")
plt.title("Losses")
plt.legend();

In [None]:
# Classify 20 randomly drawn validation images; r[0] is the accuracy in
# percent, r[1] the list of (predicted, true) label pairs.
r = pokemetric.test(test_size=20)
print(r[0])