# 1. Implementing "Siamese Neural Networks for One-shot Image Recognition": https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

# 1.1. Install

In [None]:
# !pip install tensorflow tensorflow-gpu opencv-python matplotlib

# 1.2. Import

In [None]:
# Standard dependencies
import cv2 as cv
import os
import random
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# TF dependencies - Functional Api
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
import tensorflow as tf

# 1.3. Set GPU growth

In [None]:
# Avoid OOM error by setting GPU memory consumption growth
# gpus = tf.config.experimental.list_physical_devices('GPU')
# for gpu in gpus:
#   tf.config.experimental.set_memory_growth(gpu, True)

# 1.4. Folder Structures

In [None]:
# Setup paths (relative to the notebook's working directory)
POS_PATH = os.path.join('data', 'positive') # pos verification images path
NEG_PATH = os.path.join('data', 'negative') # neg verification images path
ANC_PATH = os.path.join('data', 'anchor') # anchor images path

In [None]:
# Display the positive-image path as a quick sanity check
POS_PATH

In [None]:
# Make the directories.
# exist_ok=True keeps this cell re-runnable: without it os.makedirs raises
# FileExistsError as soon as the folders already exist.
for data_dir in (POS_PATH, NEG_PATH, ANC_PATH):
    os.makedirs(data_dir, exist_ok=True)

# 2. Collect Positive and Anchor

# 2.1. Untar the Labeled Faces in the Wild (LFW) dataset

#### Data: https://vis-www.cs.umass.edu/lfw/

In [None]:
# Uncompress the LFW archive; lfw.tgz must be in the working directory.
# The next cell expects the extraction to produce an `lfw/` folder.
!tar -xf lfw.tgz

In [None]:
# Flatten the LFW tree: every file in lfw/<sub-folder>/ is moved into NEG_PATH.
# os.replace overwrites a destination file that has the same name.
for person_dir in os.listdir('lfw'):
    source_dir = os.path.join('lfw', person_dir)
    for image_name in os.listdir(source_dir):
        os.replace(os.path.join(source_dir, image_name),
                   os.path.join(NEG_PATH, image_name))

# 2.2. Collect Positive and Anchor classes

In [None]:
import uuid

In [None]:
# Establish a connection to the webcam (device 0) and collect samples.
# Keys (with the preview window focused):
#   'a' -> save the current crop as an anchor image
#   'p' -> save the current crop as a positive image
#   'q' -> quit
cap = cv.VideoCapture(0)
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame was delivered (camera busy/unplugged); slicing None would crash.
            break

        # Cut the frame down to a 250x250 crop
        frame = frame[120:120+250, 200:200+250, :]

        # Show image
        cv.imshow("Image Collection", frame)

        # Poll the keyboard exactly once per frame. Calling waitKey several
        # times per iteration lets each call consume a key press, so presses
        # meant for a later check were silently lost.
        key = cv.waitKey(1) & 0xFF

        if key == ord('a'):
            # Collect anchor under a unique file name
            imgname = os.path.join(ANC_PATH, '{}.jpg'.format(uuid.uuid1()))
            cv.imwrite(imgname, frame)
        elif key == ord('p'):
            # Collect positive under a unique file name
            imgname = os.path.join(POS_PATH, '{}.jpg'.format(uuid.uuid1()))
            cv.imwrite(imgname, frame)
        elif key == ord('q'):
            break
finally:
    # Always release the webcam and close the preview window, even on error
    cap.release()
    cv.destroyAllWindows()

# 2.3. Data augmentation

In [None]:
def data_aug(img):
    """Return nine randomly augmented variants of ``img``.

    Each variant is derived independently from the original image by applying,
    in order: random brightness, contrast, horizontal flip, JPEG quality and
    saturation (stateless TF ops, each with a freshly drawn seed).

    The previous version reassigned ``img`` inside the loop, so every variant
    was built on top of the previous one, and brightness/contrast used fixed
    seeds. The same contrast factor and brightness delta were therefore
    re-applied up to nine times, progressively degrading later variants.

    Args:
        img: image array/tensor accepted by the ``tf.image`` ops
            (callers pass the result of ``cv.imread``).

    Returns:
        list of 9 tensors, one per augmented variant.
    """
    def _seed():
        # Stateless ops need an explicit 2-element seed; draw a new one per op
        return (np.random.randint(100), np.random.randint(100))

    data = []
    for _ in range(9):
        aug = tf.image.stateless_random_brightness(img, max_delta=0.02, seed=_seed())
        aug = tf.image.stateless_random_contrast(aug, lower=0.6, upper=1, seed=_seed())
        aug = tf.image.stateless_random_flip_left_right(aug, seed=_seed())
        aug = tf.image.stateless_random_jpeg_quality(aug, min_jpeg_quality=90, max_jpeg_quality=100, seed=_seed())
        aug = tf.image.stateless_random_saturation(aug, lower=0.9, upper=1, seed=_seed())

        data.append(aug)

    return data

In [None]:
# Augmenting all pos images: each readable image in POS_PATH gets its
# augmented variants written back into POS_PATH under fresh UUID names.
# os.listdir returns a snapshot list, so files written here are not revisited
# in this run (re-running the cell will augment them again, though).
for file_name in os.listdir(POS_PATH):
    img_path = os.path.join(POS_PATH, file_name)
    img = cv.imread(img_path)
    if img is None:
        # cv.imread returns None for unreadable / non-image files; skip them
        continue

    for image in data_aug(img):
        cv.imwrite(os.path.join(POS_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())

In [None]:
# Augmenting all neg images: each readable image in NEG_PATH gets its
# augmented variants written back into NEG_PATH under fresh UUID names.
# os.listdir returns a snapshot list, so files written here are not revisited
# in this run (re-running the cell will augment them again, though).
for file_name in os.listdir(NEG_PATH):
    img_path = os.path.join(NEG_PATH, file_name)
    img = cv.imread(img_path)
    if img is None:
        # cv.imread returns None for unreadable / non-image files; skip them
        continue

    for image in data_aug(img):
        cv.imwrite(os.path.join(NEG_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())

# 3. Load and preprocess images

# 3.1. Get images directories

In [None]:
# Build file-path datasets for each class, capped at 3000 files apiece.
# os.path.join produces the right separator on every OS; the previous
# hard-coded '\*.jpg' only formed a valid glob on Windows (and '\*' is an
# invalid string escape sequence).
anchor = tf.data.Dataset.list_files(os.path.join(ANC_PATH, '*.jpg')).take(3000)
positive = tf.data.Dataset.list_files(os.path.join(POS_PATH, '*.jpg')).take(3000)
negative = tf.data.Dataset.list_files(os.path.join(NEG_PATH, '*.jpg')).take(3000)

# 3.2. Preprocess - Scale & Resize

In [None]:
def preprocess(file_path):
    """Load a JPEG from disk and turn it into a model-ready tensor.

    Steps:
        1. Read the raw file bytes from ``file_path``.
        2. Decode the JPEG, forcing 3 channels so that grayscale files also
           yield the (100, 100, 3) shape the network input requires.
        3. Resize to 100x100 and rescale pixel values to the range 0-1.

    Args:
        file_path: path of the JPEG file (string or string tensor).

    Returns:
        float tensor of shape (100, 100, 3) with values scaled by 1/255.
    """
    # 1. Raw bytes
    byte_image = tf.io.read_file(file_path)

    # 2. Decode; channels=3 guards against single-channel (grayscale) JPEGs
    image = tf.io.decode_jpeg(byte_image, channels=3)

    # 3. Resize, then rescale
    image = tf.image.resize(image, (100, 100))
    image /= 255.0

    return image

# 3.3. Create labelled dataset

In [None]:
# Pair every anchor with a positive (label 1.0) and with a negative (label 0.0),
# then stack the two sets into one dataset of (anchor, candidate, label) triples.
n_anchor = len(anchor)
pos_labels = tf.data.Dataset.from_tensor_slices(tf.ones(n_anchor))
neg_labels = tf.data.Dataset.from_tensor_slices(tf.zeros(n_anchor))

positives = tf.data.Dataset.zip((anchor, positive, pos_labels))
negatives = tf.data.Dataset.zip((anchor, negative, neg_labels))
data = positives.concatenate(negatives)

In [None]:
# Number of (anchor, candidate, label) triples in the combined dataset
len(data)

# 3.4. Build train & test partition

In [None]:
def preprocess_twin(input_img, validation_img, label):
    """Run ``preprocess`` on both image paths of a pair; the label passes through unchanged.

    Returns:
        tuple ``(preprocessed input image, preprocessed validation image, label)``.
    """
    anchor_tensor = preprocess(input_img)
    candidate_tensor = preprocess(validation_img)
    return anchor_tensor, candidate_tensor, label

In [None]:
# Dataloader pipeline: decode/resize every pair, cache the decoded tensors,
# then shuffle.
data = data.map(preprocess_twin)
data = data.cache()
# reshuffle_each_iteration=False fixes the shuffled order across iterations.
# With the default (True) every pass over `data` is reshuffled, so the
# take()/skip() partitions built below would draw different, overlapping
# samples each epoch, leaking train samples into validation/test.
data = data.shuffle(buffer_size=10000, reshuffle_each_iteration=False)

In [None]:
# Training partition: the first 70% of the samples, in batches of 16,
# with up to 8 batches prefetched
train_size = round(len(data) * 0.7)
train_data = data.take(train_size).batch(16).prefetch(8)

In [None]:
# Validation partition (can be used for further model tuning):
# skip the first 70% of the samples, then keep 70% of what remains
valid_pool = data.skip(round(len(data) * 0.7))
valid_size = round(len(valid_pool) * 0.7)
valid_data = valid_pool.take(valid_size).batch(16).prefetch(8)

In [None]:
# Test partition for the final performance check: everything after the first 91%
test_offset = round(len(data) * 0.91)
test_data = data.skip(test_offset).batch(16).prefetch(8)

# 4. Model Engineering

# 4.1. Building Embedding layer

In [None]:
def make_embedding():
    """Build the convolutional embedding network shared by both Siamese branches.

    Architecture: three Conv2D(relu) + 2x2 max-pool blocks, a fourth Conv2D,
    then Flatten and a 4096-unit sigmoid Dense layer.

    The pooling layers previously read ``MaxPooling2D(64, (2, 2), ...)``. The
    first positional argument of ``MaxPooling2D`` is ``pool_size``, so that
    pooled over 64x64 windows with stride 2 instead of 2x2 windows. Output
    shapes are the same either way (padding="same", stride 2), so only the
    pooled values change.

    Returns:
        keras ``Model`` named "embedding" mapping a (100, 100, 3) image to a
        4096-dimensional feature vector.
    """
    inp = Input((100, 100, 3), name="input_image")

    # First block
    c1 = Conv2D(64, (10, 10), activation="relu", name="conv_layer_1")(inp)
    m1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same", name="max-pool_layer_1")(c1)

    # 2nd block
    c2 = Conv2D(128, (7, 7), activation="relu", name="conv_layer_2")(m1)
    m2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same", name="max-pool_layer_2")(c2)

    # 3rd block
    c3 = Conv2D(128, (4, 4), activation="relu", name="conv_layer_3")(m2)
    m3 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same", name="max-pool_layer_3")(c3)

    # Final embedding
    c4 = Conv2D(256, (4, 4), activation="relu", name="conv_layer_4")(m3)
    f1 = Flatten(name="flatten_layer_1")(c4)
    d1 = Dense(4096, activation="sigmoid", name="FCD")(f1)

    return Model(inputs = [inp], outputs = [d1], name = "embedding")

In [None]:
# Instantiate the shared embedding network (used as a global by siamese_model)
# and print its layer summary
embedding = make_embedding()
embedding.summary()

# 4.2. Build distance layer

In [None]:
class L1Distance(Layer):
    """Custom Keras layer returning the element-wise absolute difference of two embeddings."""

    def __init__(self, **kwargs):
        # Forward kwargs (name, dtype, trainable, ...) to the base Layer.
        # They were previously accepted but dropped, so e.g. the name stored
        # in a saved model's config was ignored when the layer was rebuilt.
        super().__init__(**kwargs)

    def call(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        return tf.math.abs(input_embedding - validation_embedding)

# 4.3. Make Siamese model

In [None]:
def siamese_model():
    """Assemble the Siamese verification network.

    Both inputs go through the shared global ``embedding`` model, the two
    embeddings are combined by an ``L1Distance`` layer (renamed "distance"),
    and a single sigmoid unit produces the output score.

    Returns:
        keras ``Model`` named "Siamese_Network" with inputs
        [input_image, validation_image], each of shape (100, 100, 3).
    """
    # The two image inputs: anchor and the image to validate against it
    anchor_in = Input(shape=(100, 100, 3), name="input_image")
    candidate_in = Input(shape=(100, 100, 3), name="validation_image")

    # Shared embedding on both branches, combined by the distance layer
    l1_layer = L1Distance()
    l1_layer._name = "distance"
    l1 = l1_layer(embedding(anchor_in), embedding(candidate_in))

    # Classification layer
    score = Dense(1, activation="sigmoid", name="FCD")(l1)

    return Model(inputs=[anchor_in, candidate_in], outputs=score, name="Siamese_Network")

In [None]:
# Build the network. This rebinds the name `siamese_model` from the factory
# function to the resulting Model, so guard the call: without the check,
# re-running this cell would call the Model with no arguments and fail.
if not isinstance(siamese_model, Model):
    siamese_model = siamese_model()
siamese_model.summary()

# 5. Training

# 5.1. Setup Loss & Optimizer

In [None]:
# Binary cross-entropy on probabilities (the model's final layer is a sigmoid).
binary_cross_loss = tf.losses.BinaryCrossentropy() # pass from_logits=True only if the model outputs raw, un-normalized logits

In [None]:
# Adam optimizer with learning rate 1e-4
opt = tf.keras.optimizers.Adam(1e-4)

# 5.2. Establish checkpoints

In [None]:
# Checkpoints track both the optimizer state and the model weights.
# To resume from one, use:
#   checkpoint.restore(tf.train.latest_checkpoint(checkpoints_dir))
checkpoints_dir = './training_checkpoints'                 # directory the checkpoint files are written to
checkpoint_prefix = os.path.join(checkpoints_dir, 'chkpt')
checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)

# 5.3. Build train Step function

In [None]:
@tf.function
def train_step(batch):
    """Run one optimization step on a single batch.

    Args:
        batch: tuple ``(anchor images, pos/neg images, labels)``.

    Returns:
        the binary cross-entropy loss computed for this batch.
    """
    pair = batch[:2]   # anchor & pos/neg image
    labels = batch[2]  # labels

    # Record the forward pass so gradients can be taken afterwards
    with tf.GradientTape() as tape:
        preds = siamese_model(pair, training=True)
        loss = binary_cross_loss(labels, preds)

    # Back-propagate and update the weights
    weights = siamese_model.trainable_variables
    grads = tape.gradient(loss, weights)
    opt.apply_gradients(zip(grads, weights))

    return loss

# 5.4. Build training loop

In [None]:
# Import metric calculations
from tensorflow.keras.metrics import Precision, Recall

In [None]:
def training(data, EPOCHS):
    """Train the global ``siamese_model`` for ``EPOCHS`` epochs over ``data``.

    For every epoch: runs ``train_step`` on each batch, tracks recall and
    precision of the model's predictions on those batches, and prints the
    mean batch loss together with both metrics. A checkpoint is saved every
    10th epoch.

    Changes from the previous version:
      * the reported loss is the epoch mean instead of the last batch only;
      * an empty dataset no longer raises NameError on ``loss``;
      * ``predict`` runs with ``verbose=0`` so it does not print one
        progress bar per batch;
      * the "Precission" label typo is fixed.

    Args:
        data: batched dataset of ``(anchor, candidate, label)`` tuples; must
            support ``len()``.
        EPOCHS: number of epochs to run.
    """
    for epoch in range(1, EPOCHS+1):
        print("\n Epoch {}/{}".format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(len(data))

        # Fresh metric objects each epoch so results do not accumulate
        r = Recall()
        p = Precision()

        total_loss = 0.0
        n_batches = 0

        for idx, batch in enumerate(data):
            loss = train_step(batch)
            total_loss += float(loss)
            n_batches += 1

            # Predictions are taken after the weight update for this batch
            yhat = siamese_model.predict(batch[:2], verbose=0)
            r.update_state(batch[2], yhat)
            p.update_state(batch[2], yhat)
            progbar.update(idx+1)

        if n_batches:
            print("Loss: ", total_loss / n_batches, "Recall: ", r.result().numpy(), "Precision: ", p.result().numpy())
        else:
            print("No batches in dataset; nothing was trained this epoch.")

        if epoch % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

# 5.5. Train the model

In [None]:
# Train on the training partition; with 20 epochs checkpoints are saved at epochs 10 and 20
EPOCHS=20
training(train_data, EPOCHS)

# 6. Save the model

In [None]:
# import sys
# !pip uninstall --yes --prefix {sys.prefix} h5py=3.6.0 -c pkgs/main

In [None]:
# !conda install -y -n {envname} {package}

In [None]:
# import sys
# !{sys.executable} -m pip install h5py

In [None]:
import h5py

In [None]:
# Save the whole model in HDF5 format.
# NOTE: the file name is spelled "saimese_model.h5"; the reload cell uses the same spelling.
siamese_model.save("saimese_model.h5")

In [None]:
# Reload model from the HDF5 file; custom_objects lets Keras resolve the
# custom L1Distance layer (and the loss class) by name while deserializing
siamese_model = tf.keras.models.load_model('saimese_model.h5', 
                                   custom_objects={'L1Distance':L1Distance, 'BinaryCrossentropy':tf.losses.BinaryCrossentropy})

# 6.1. Testing metrics on test data

In [None]:
# Evaluate recall and precision over every batch of the test partition
r = Recall()
p = Precision()

for anchor_batch, candidate_batch, labels in test_data.as_numpy_iterator():
    preds = siamese_model.predict([anchor_batch, candidate_batch])
    for metric in (r, p):
        metric.update_state(labels, preds)

print(r.result().numpy(), p.result().numpy())

In [None]:
# Let's just make a function of above cells
def validation_(data):
    """Print recall and precision of the global ``siamese_model`` over ``data``.

    Args:
        data: batched dataset of ``(anchor, candidate, label)`` tuples; must
            support ``len()`` (used to size the progress bar).
    """
    progbar = tf.keras.utils.Progbar(len(data))

    r = Recall()
    p = Precision()

    for idx, batch in enumerate(data):
        # verbose=0: predict would otherwise print its own progress bar per batch
        yhat = siamese_model.predict(batch[:2], verbose=0)
        r.update_state(batch[2], yhat)
        p.update_state(batch[2], yhat)
        progbar.update(idx+1)

    # Label typo "Precission" fixed
    print("Recall: ", r.result().numpy(), "Precision: ", p.result().numpy())

In [None]:
# Report recall / precision on the test partition
validation_(test_data)

# 7. Verification function

In [None]:
def verify(model, detection_threshold, verification_threshold):
    """Compare the captured input image against every stored verification image.

    Reads ``application_data/input_image/input_image.jpg`` and each file in
    ``application_data/verification_images``, scores every pair with
    ``model.predict`` and decides whether enough pairs match.

    Args:
        model: model whose ``predict`` accepts a list of two image batches.
        detection_threshold: a prediction above this value counts as a match.
        verification_threshold: verification succeeds when the proportion of
            matching pairs exceeds this value.

    Returns:
        tuple ``(results, verified)``: the list of per-image predictions and
        the boolean verdict. With no verification images the result is
        ``([], False)``.
    """
    verification_dir = os.path.join('application_data', 'verification_images')
    image_names = os.listdir(verification_dir)  # listed once, reused for the ratio below

    if not image_names:
        # Nothing to compare against: avoid the 0/0 division and report "not verified"
        return [], False

    # The input image is the same for every comparison, so load it once
    input_image = preprocess(os.path.join('application_data', 'input_image', 'input_image.jpg'))

    results = []
    for image in image_names:
        validation_image = preprocess(os.path.join(verification_dir, image))

        # expand_dims(axis=1) + list() yields two batches of one image each
        result = model.predict(list(np.expand_dims([input_image, validation_image], axis=1)))
        results.append(result)

    # Detection threshold: metric above which a prediction is considered positive
    detection = np.sum(np.array(results) > detection_threshold)
    # Verification threshold: proportion of positive predictions / total samples
    verification = detection / len(image_names)
    verified = verification > verification_threshold

    return results, verified

# 7.1. OpenCV Real Time Verification

In [None]:
# Real-time verification loop.
# Keys (with the preview window focused): 'v' runs verification on the
# current crop, 'q' quits.
cap = cv.VideoCapture(0)

# cv.imwrite fails silently when the target folder is missing, so create it up front
input_dir = os.path.join('application_data', 'input_image')
os.makedirs(input_dir, exist_ok=True)

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame was delivered (camera busy/unplugged); slicing None would crash.
            break

        # Crop to the same 250x250 region used during data collection
        frame = frame[120:120+250,200:200+250, :]

        cv.imshow('Verification', frame)

        # Poll the keyboard once per frame; two separate waitKey calls let the
        # first one consume a key press intended for the second check.
        key = cv.waitKey(10) & 0xFF

        # Verification trigger
        if key == ord('v'):
            # Save input image to application_data/input_image folder
            cv.imwrite(os.path.join(input_dir, 'input_image.jpg'), frame)
            # Run verification
            results, verified = verify(siamese_model, 0.5, 0.5)
            print(verified)
        elif key == ord('q'):
            break
finally:
    # Always release the webcam and close the window, even on error
    cap.release()
    cv.destroyAllWindows()