In [None]:
# Importing All Required Libraries

import os
import numpy as np
import cv2
from cv2 import imread
from tensorflow.keras import Input
from tensorflow.keras.backend import abs as kerasAbs
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Convolution2D, MaxPool2D, Flatten, Dense, Dropout, Lambda
from matplotlib import pyplot as plt

In [None]:
# N-way K-shot Learning configuration
N = 6 # total number of people (classes) across the train and val splits (3 + 3 in load_images)
K = 5 # intended shots per class; NOTE(review): not referenced in the visible cells — LIMIT is what caps images per class
LIMIT = 8 # maximum number of images load_images reads per person (applies to both splits)

In [None]:
def load_images(path=None, label=0, mode=None):
    """Load up to ``LIMIT`` images per person from the sub-folders of ``path``.

    Only the sub-folders whitelisted for the split are read:
    mode 0 / "train" -> Hiranmay, Neel, Shilpi; mode 1 / "val" -> Shreya,
    Srijani, Richa. People and image files are visited in sorted-name order, so
    the class order (and therefore every label) is reproducible across runs and
    file systems.

    Args:
        path: directory containing one sub-folder per person. Required.
        label: starting value of the running per-image counter recorded in
            ``classes_dict``.
        mode: 0 / "train" or 1 / "val"; ``None`` means 0.

    Returns:
        classes: ``np.stack`` of per-person image stacks, i.e. shape
            (n_people, images_per_person, *image_shape). Stacking requires every
            person to contribute the same number of equally sized images.
        labels: (n_images, 1) array holding each image's class index
            (0 for the first person loaded, 1 for the next, ...).
        classes_dict: ``{person_name: [first_counter, last_counter]}``, the
            inclusive range of the running counter (starting at ``label``)
            assigned to that person's images; insertion order == class order.

    Raises:
        ValueError: missing path, invalid mode, a whitelisted person folder with
            no readable image, or no whitelisted folder under ``path``.
    """
    if path is None:
        raise ValueError("Path is not defined")

    if mode is None:
        mode = 0
    elif isinstance(mode, int):
        if mode not in [0, 1]:
            raise ValueError("Mode should be either 0 or 1")
    elif isinstance(mode, str):
        if mode not in ["train", "val"]:
            raise ValueError("Mode should be either 'train' or 'val'")
        mode = {"train": 0, "val": 1}[mode]

    valid_set = [{"Hiranmay", "Neel", "Shilpi"}, {"Shreya", "Srijani", "Richa"}][mode]

    # `with` closes the scandir iterator; sorting gives a stable class order.
    with os.scandir(path) as entries:
        people = sorted(
            (entry for entry in entries if entry.is_dir() and entry.name in valid_set),
            key=lambda entry: entry.name,
        )

    classes = []
    labels = []
    classes_dict = dict()
    current_label = label

    for target, person in enumerate(people):
        print(f"Loading images from {person.name}")
        with os.scandir(person.path) as entries:
            image_files = sorted(
                (entry for entry in entries if entry.is_file()),
                key=lambda entry: entry.name,
            )

        first_label = current_label
        person_images = []
        for image in image_files:
            # Count images actually loaded (not directory entries), so sub-folders
            # or unreadable files cannot leave a person with fewer than LIMIT images.
            if len(person_images) == LIMIT:
                break
            img = imread(image.path)
            if img is None:  # cv2.imread returns None for unreadable / non-image files
                continue
            person_images.append(img)
            labels.append(target)
            current_label += 1

        if not person_images:
            raise ValueError(f"No readable images found for {person.name}")

        classes_dict[person.name] = [first_label, current_label - 1]
        classes.append(np.stack(person_images))

    if not classes:
        raise ValueError(f"No person folders for mode {mode} found under {path}")

    labels = np.vstack(labels)
    classes = np.stack(classes)

    return classes, labels, classes_dict


In [None]:
# Both splits are read from the same folder; load_images chooses which people
# belong to each split from its per-mode name whitelist.
X_train, y_train, c_train = load_images("dataset/train/labelled data", mode="train")
X_val, y_val, c_val = load_images("dataset/train/labelled data", mode="val")

In [None]:
"""
FUNCTION TO TEST THE 'load_iamges' FUNCTION
"""
def test_load_images():
    mode = 0
    X = [X_train, X_val][mode]
    y = [y_train, y_val][mode]
    targets = ['Hiranmay', 'Neel', 'Richa', 'Shilpi', 'Shreya', 'Srijani']
    fig, axes = plt.subplots(nrows=N//2, ncols=LIMIT, figsize=(28, 28))
    for id, person, label in zip(range(N//2), X, y):
        for idx, img in enumerate(person, start = 0):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            axes[id, idx].imshow(img)
            axes[id, idx].set_title(f"{targets[id]} {idx+1}")
            axes[id, idx].axis('off')
    plt.show()
test_load_images()

In [None]:
def load_batch(batch_size, mode = None, data = None):
    """Sample a batch of image pairs for Siamese training.

    The first ``batch_size // 2`` pairs come from two different classes
    (label 0); the remaining pairs come from the same class (label 1). Within
    every pair the two image indices differ.

    Args:
        batch_size: number of pairs to draw.
        mode: 0 / "train" samples from the global ``X_train``, 1 / "val" from
            ``X_val``; ``None`` means 0. Always validated.
        data: optional array shaped (n_classes, n_samples, *image_shape) to
            sample from instead of the notebook globals (``mode`` then only
            gets validated).

    Returns:
        ``(pairs, labels)``: ``pairs`` is a list of two float64 arrays of shape
        (batch_size, *image_shape); ``labels`` is a float64 array of shape
        (batch_size,).

    Raises:
        ValueError: invalid mode, a class with fewer than 2 images, or fewer
            than 2 classes when different-class pairs are requested.
    """
    if mode is None:
        mode = 0
    elif isinstance(mode, int):
        if mode not in [0, 1]:
            raise ValueError("Mode should be either 0 or 1")
    elif isinstance(mode, str):
        if mode not in ["train", "val"]:
            raise ValueError("Mode should be either 'train' or 'val'")
        mode = {"train": 0, "val": 1}[mode]

    # The notebook globals are only looked up when no explicit `data` is given.
    classes = np.asarray([X_train, X_val][mode] if data is None else data)
    n_classes, n_samples = classes.shape[:2]
    image_shape = classes.shape[2:]  # e.g. (rows, cols, channels) — any channel count works
    n_different = batch_size // 2

    if batch_size > 0 and n_samples < 2:
        raise ValueError("Each class needs at least 2 images to build a pair")
    if n_different > 0 and n_classes < 2:
        raise ValueError("At least 2 classes are needed to build different-class pairs")

    random_classes = np.random.choice(n_classes, batch_size)

    training_pairs = [np.zeros((batch_size,) + image_shape) for _ in range(2)]
    training_labels = np.zeros((batch_size,))
    # First half: different-class pairs (0); second half: same-class pairs (1).
    training_labels[n_different:] = 1

    for i, class_id in enumerate(random_classes):
        idx_1 = np.random.randint(0, n_samples)
        # Uniform over the other n_samples - 1 indices; unlike a rejection loop,
        # this cannot spin forever.
        idx_2 = (idx_1 + np.random.randint(1, n_samples)) % n_samples

        if i >= n_different:
            class_id2 = class_id
        else:
            # Uniform over the other n_classes - 1 classes.
            class_id2 = (class_id + np.random.randint(1, n_classes)) % n_classes

        training_pairs[0][i] = classes[class_id, idx_1]
        training_pairs[1][i] = classes[class_id2, idx_2]

    return training_pairs, training_labels
        

In [None]:
"""
FUNCTION TO TEST THE 'load_batch' FUNCTION
"""
batch_size = 6
pairs, labels = load_batch(batch_size, mode = 1)
fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(28, 28))
for id, img1, img2 in zip(range(batch_size), pairs[0], pairs[1]):
    axes[id, 0].imshow(cv2.cvtColor(img1.astype('uint8'), cv2.COLOR_BGR2RGB))
    axes[id, 1].imshow(cv2.cvtColor(img2.astype('uint8'), cv2.COLOR_BGR2RGB))
    axes[id, 1].set_title(f"{labels[id]}")
    axes[id, 0].axis('off')
    axes[id, 1].axis('off')

plt.show()

In [None]:
def get_siamese_model(input_shape = (160, 160, 3)):
    """Build a Siamese network that scores a pair of images.

    One convolutional encoder, with weights shared by both branches, maps each
    input to a 4096-d sigmoid vector. A single sigmoid unit then turns the
    element-wise absolute difference of the two vectors into a score in [0, 1].

    Args:
        input_shape: shape of each input image, channels last (the first conv
            layer uses ``data_format='channels_last'``).

    Returns:
        An uncompiled Keras ``Model`` taking ``[left_image, right_image]`` and
        producing a (batch, 1) score.
    """
    # Inputs are created before the encoder, matching the original layer-creation order.
    left_input = Input(input_shape)
    right_input = Input(input_shape)

    shared_encoder = Sequential([
        Convolution2D(64, (10, 10), activation='relu', input_shape=input_shape, data_format='channels_last'),
        MaxPool2D(),
        Convolution2D(128, (7, 7), activation='relu'),
        MaxPool2D(),
        Convolution2D(128, (4, 4), activation='relu'),
        MaxPool2D(),
        Convolution2D(256, (4, 4), activation='relu'),
        Flatten(),
        Dense(4096, activation='sigmoid'),
    ])

    # Calling the same Sequential on both inputs is what ties the branch weights together.
    left_embedding = shared_encoder(left_input)
    right_embedding = shared_encoder(right_input)

    abs_difference = Lambda(lambda pair: kerasAbs(pair[0] - pair[1]))([left_embedding, right_embedding])
    pair_score = Dense(1, activation='sigmoid')(abs_difference)

    return Model(inputs=[left_input, right_input], outputs=pair_score)
    

In [None]:
def generate(batch_size, mode = "train"):
    """Yield ``(pairs, targets)`` batches from ``load_batch`` forever.

    Every ``next()`` draws a fresh random batch, so the generator can be passed
    straight to ``model.fit`` as training or validation data, as the training
    cell does.
    """
    while True:
        batch_pairs, batch_targets = load_batch(batch_size, mode)
        yield batch_pairs, batch_targets

In [None]:
# Train on batches of 6 pairs (3 different-person + 3 same-person) and validate
# on batches of 2. The generators never end, so the step counts bound each epoch.
model = get_siamese_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(
    x=generate(6, mode="train"),
    validation_data=generate(2, mode="val"),
    epochs=10,
    steps_per_epoch=10,
    validation_steps=2,
)

In [None]:
# Build a 2-shot support set: two distinct images for every person.
# Each split holds only some of the people, so pool train and val. This assumes
# both splits have the same images-per-person and image size (np.concatenate
# requires it).
X_all = np.concatenate([X_train, X_val])  # (people, images_per_person, *image_shape)
n_people, n_images = X_all.shape[:2]

support_set = np.zeros((n_people, 2) + X_all.shape[2:])
for i in range(n_people):
    # Two distinct image indices; raises ValueError if a person has < 2 images.
    idx, idx1 = np.random.choice(n_images, size=2, replace=False)
    support_set[i, 0] = X_all[i, idx]
    support_set[i, 1] = X_all[i, idx1]

print(support_set[0].shape, X_all[0][0].shape)

# squeeze=False keeps `axes` 2-D even for a single person.
fig, axes = plt.subplots(nrows=n_people, ncols=2, figsize=(28, 28), squeeze=False)
for i in range(n_people):
    for j in range(2):
        axes[i, j].imshow(cv2.cvtColor(support_set[i, j].astype('uint8'), cv2.COLOR_BGR2RGB))
        axes[i, j].axis('off')
plt.show()


In [None]:
# Score the first support image of person 0 against the first support image of person 2
# (a different-person pair); np.newaxis adds the batch axis the model expects.
model.predict([support_set[0, 0][np.newaxis], support_set[2, 0][np.newaxis]])