# IMAGE SIMILARITY USING TRIPLET LOSS

## A - Set up the working environment

### I - Import packages

In [None]:
import os
import math
import random
import gc

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
import PIL
from matplotlib import image as mpimg
import tensorflow_datasets as tfds
import pathlib

In [None]:
# Silence TensorFlow's C++ INFO and WARNING logs (level "2" keeps errors).
# NOTE(review): TensorFlow reads this variable when its native runtime loads;
# it may need to be set before `import tensorflow` to take full effect.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")

### II - Define global constants

In [None]:
# --- Data split ---
VAL_SIZE = 0.2      # fraction of image pairs held out for validation
RANDOM_STATE = 21   # seed passed to train_test_split

# --- Training loop ---
BATCH_SIZE = 32     # triplets per batch
EPOCHS = 1000       # upper bound; EarlyStopping is configured to stop earlier

# --- Network input resolution (images are resized to this inside the model) ---
# Presumed alternatives: 245 x 200 (Totally-Looks-Like native size), 32 x 32 (CIFAR-10).
IMAGE_SIZE_H = 64
IMAGE_SIZE_W = 64

In [None]:
# Directory where ModelCheckpoint writes the best model (by val_accuracy).
CHECKPOINT_PATH = "checkpoint/"

## B - Preprocess the data

### I - Split the dataset into training and validation sets

In [None]:
# Root of the Totally-Looks-Like images on Kaggle
# (alternative copy: /kaggle/input/totally-looks-like-dataset).
cache_dir = pathlib.Path("/kaggle/input") / "totally-looks-like-ds2" / "totally_looks_like_ds2"
print(cache_dir)

# A pair shares one file name: left/left/<n>.jpg matches right/right/<n>.jpg.
anchor_images_path = cache_dir.joinpath("left", "left")
similar_images_path = cache_dir.joinpath("right", "right")

In [None]:
# Pair every left-hand ("anchor") image with the right-hand ("similar") image
# that has the same zero-padded file name, e.g. left/00042.jpg <-> right/00042.jpg.
# Each pair gets its own integer label (the file index), so within the dataset
# the two images of a pair are the only positives of each other.
image_count = len(list(anchor_images_path.glob('*.jpg')))
datasets = []
labels = []

for i in range(image_count):
    file_name = f'{i:05d}.jpg'
    # Build the paths directly: running a glob per index rescans the whole
    # directory each time and makes this loop quadratic in the image count.
    anchor_image_path = anchor_images_path / file_name
    similar_image_path = similar_images_path / file_name

    if not (anchor_image_path.is_file() and similar_image_path.is_file()):
        # Report exactly which pair is incomplete instead of an opaque IndexError.
        print(f"Skipping index {i}: missing {anchor_image_path} or {similar_image_path}")
        continue

    labels.append(i)
    datasets.append([anchor_image_path, similar_image_path])

# dataset_images: object array of shape (n_pairs, 2) holding [anchor, similar] paths.
dataset_images = np.array(datasets)
dataset_labels = np.array(labels)


In [None]:
# Preview one image from the "similar" (right-hand) folder by its file index.
i = 5870
sample_image_path = similar_images_path / f'{i:05d}.jpg'
# The context manager closes the file once the pixels have been copied out.
with PIL.Image.open(sample_image_path) as sample_image:
    plt.imshow(np.asarray(sample_image))

In [None]:
# Sanity-check sizes: labels are 1-D, images are (n_pairs, 2) path objects.
print(dataset_labels.shape, dataset_images.shape, sep="\n")
print(type(dataset_images[0, 0]))

In [None]:
# Hold out VAL_SIZE of the pairs for validation. Rows are whole pairs, so both
# images of a pair always land in the same subset.
train_images, val_images, train_labels, val_labels = train_test_split(
    dataset_images,
    dataset_labels,
    random_state=RANDOM_STATE,
    test_size=VAL_SIZE,
)

In [None]:
# Drop the unsplit arrays now that the train/val subsets exist, then force a
# garbage-collection pass (its return value is shown as the cell output).
del dataset_images, dataset_labels
gc.collect()

In [None]:
# Expected (n_train_pairs, 2) before flattening.
np.shape(train_images)

In [None]:
# Flatten the (n_pairs, 2) path arrays to one path per element, ordered
# [anchor_0, similar_0, anchor_1, similar_1, ...], and repeat each pair's label
# twice so the labels stay aligned with the flattened paths.
train_images, val_images = train_images.ravel(), val_images.ravel()
train_labels, val_labels = train_labels.repeat(2), val_labels.repeat(2)

In [None]:
# Both should now have length 2 * n_train_pairs.
print(train_images.shape, train_labels.shape, sep="\n")

In [None]:
# Number of validation images (two per validation pair).
val_images.shape[0]

### II - Group images into triplets

In [None]:
class ImageTriplets(keras.utils.Sequence):
    """
    Keras ``Sequence`` that yields batches of (anchor, positive, negative)
    image triplets for triplet-loss training.

    Images that share a label are positives of one another; the negative of a
    triplet is drawn from a different, randomly chosen label. Every image
    index is used as an anchor once per epoch, in shuffled order.

    Args:
        images: array of image file paths (anything ``PIL.Image.open`` accepts).
        labels: array of labels aligned element-wise with ``images``.
        batch_size: number of triplets per batch.

    Raises:
        ValueError: if fewer than two distinct labels are given, since no
            negative image could then be sampled.
    """

    def __init__(self, images, labels, batch_size):
        self._images = images
        self._labels = labels
        self._indices = np.arange(len(images))
        self._idx_groups = self._group_indices_by_labels()
        # A list (not a dict view) so a label can be drawn with random.choice.
        self._unique_labels = list(self._idx_groups)
        if len(self._unique_labels) < 2:
            raise ValueError(
                "ImageTriplets needs at least two distinct labels to sample negatives."
            )
        self._batch_size = batch_size
        # Same shuffling routine as on_epoch_end.
        np.random.shuffle(self._indices)

    def __len__(self):
        return math.ceil(len(self._images) / self._batch_size)

    def __getitem__(self, batch_idx):
        """
        Return batch ``batch_idx`` as an array of shape
        (batch, 3, height, width, 3): anchor, positive and negative images
        (assuming all images share one size).

        The last batch is shifted back so it still holds ``batch_size``
        triplets, overlapping the previous batch, whenever at least
        ``batch_size`` images exist.
        """
        indices_count = len(self._indices)
        start_idx = max(
            0, min(self._batch_size * batch_idx, indices_count - self._batch_size)
        )
        end_idx = min(self._batch_size * (batch_idx + 1), indices_count)
        anchor_image_indices = self._indices[start_idx:end_idx]
        return np.stack([self._make_triplet(idx) for idx in anchor_image_indices])

    def on_epoch_end(self):
        np.random.shuffle(self._indices)

    def _group_indices_by_labels(self):
        """Return a dict mapping each label to the list of image indices with that label."""
        idx_groups = {}
        for label, idx in zip(self._labels.ravel(), self._indices):
            idx_groups.setdefault(label, []).append(idx)
        return idx_groups

    @staticmethod
    def _load_image(image_path):
        """Load an image file as an RGB float32 array scaled to [0, 1]."""
        # convert("RGB") gives grayscale/palette images three channels, so every
        # image stacks with the others; the context manager closes the file.
        with PIL.Image.open(image_path) as image:
            return np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0

    def _make_triplet(self, anchor_image_idx):
        """
        Return a stacked (anchor, positive, negative) image triplet, such that:
        - A triplet is uniquely identified by its anchor image.
        - The anchor and positive images are different images.
        - The positive and negative image indices are randomly chosen on every call.
        """
        positive_label = int(self._labels[anchor_image_idx].squeeze())
        positive_group = self._idx_groups[positive_label]
        negative_group = self._choose_negative_group(positive_label)
        positive_image_idx = self._choose_positive_image_idx(
            positive_group, anchor_image_idx
        )
        negative_image_idx = random.choice(negative_group)

        return np.stack(
            [
                self._load_image(self._images[anchor_image_idx]),
                self._load_image(self._images[positive_image_idx]),
                self._load_image(self._images[negative_image_idx]),
            ]
        )

    def _choose_negative_group(self, positive_label):
        """Choose a group, other than the positive one, to sample the negative image from."""
        # Rejection sampling is uniform over the other labels, like picking from
        # a filtered list, but avoids rebuilding an O(#labels) list per triplet.
        # __init__ guarantees at least two labels, so this terminates.
        while True:
            negative_label = random.choice(self._unique_labels)
            if negative_label != positive_label:
                return self._idx_groups[negative_label]

    def _choose_positive_image_idx(self, positive_group, anchor_image_idx):
        """Choose an index other than the anchor image index from the positive group."""
        candidates = [idx for idx in positive_group if idx != anchor_image_idx]
        if not candidates:
            raise ValueError(
                f"Image index {anchor_image_idx} has no other image with the same "
                "label to use as a positive."
            )
        return random.choice(candidates)

In [None]:
# Triplet generators for training and validation (no test split is built above).
train_triplets, val_triplets = (
    ImageTriplets(images, labels, BATCH_SIZE)
    for images, labels in ((train_images, train_labels), (val_images, val_labels))
)

In [None]:
# Remove the global names for the flattened path arrays. The ImageTriplets
# objects keep their own references, so this mainly tidies the namespace.
del train_images, val_images
gc.collect()

### III - Visualize some data samples

In [None]:
def show_images_in_triplets(image_triplets):
    """
    Plot image triplets: one triplet per row, in anchor / positive / negative
    columns.

    Args:
        image_triplets: sequence of triplets, each holding three images that
            ``imshow`` can display (e.g. one batch from ``ImageTriplets``).
            An empty sequence plots nothing.
    """
    triplet_count = len(image_triplets)
    if triplet_count == 0:
        # plt.subplots rejects nrows=0.
        return

    # squeeze=False keeps `axes` 2-D even for a single triplet; otherwise
    # subplots returns a 1-D array and each "row" would be a lone Axes.
    _, axes = plt.subplots(
        nrows=triplet_count,
        ncols=3,
        figsize=(15, 5 * triplet_count),
        squeeze=False,
    )

    for row, image_triplet in zip(axes, image_triplets):
        for grid, image, title in zip(
            row, image_triplet, ["anchor", "positive", "negative"]
        ):
            grid.imshow(image)
            grid.set_title(title)

In [None]:
# random_triplets = train_triplets[random.randrange(len(train_triplets))]

# show_images_in_triplets(random_triplets)

## C - Define and train the model

### I - Define the backbone model

#### 1. Define the model building blocks

In [None]:
def cb_block(input_shape, filters, kernel_size, strides):
    """
    Build a Conv2D -> BatchNormalization stack ("cb" = conv + batch norm).

    The convolution uses "same" padding, so only ``strides`` changes the
    spatial size of the output.
    """
    stem = keras.layers.Input(input_shape)
    conv = keras.layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
    )
    norm = keras.layers.BatchNormalization()
    return keras.Sequential([stem, conv, norm])

In [None]:
def cba_block(input_shape, filters, kernel_size, strides):
    """Build Conv2D -> BatchNormalization -> ReLU ("cba" = conv + batch norm + activation)."""
    stem = keras.layers.Input(input_shape)
    conv_bn = cb_block(input_shape, filters, kernel_size, strides)
    activation = keras.layers.ReLU()
    return keras.Sequential([stem, conv_bn, activation])

In [None]:
def shallow_feedforward_block(input_shape, filters, strides):
    """
    Residual branch of a basic ResNet block: 3x3 conv-BN-ReLU, then 3x3 conv-BN.

    The ReLU sits between the two convolutions. The branch ends without an
    activation because the residual block applies ReLU after adding the
    shortcut (He et al., 2015).

    Args:
        input_shape: (height, width, channels) of the incoming feature map.
        filters: number of output channels.
        strides: stride of the first convolution; the second uses stride 1.

    Returns:
        A ``keras.Sequential`` producing a
        (ceil(height / strides), ceil(width / strides), filters) feature map.
    """
    input_height, input_width, _ = input_shape
    # With padding="same", a strided convolution outputs ceil(size / strides),
    # not size // strides; the two differ for odd spatial sizes.
    output_shape = (
        math.ceil(input_height / strides),
        math.ceil(input_width / strides),
        filters,
    )
    layers = [
        cba_block(input_shape, filters, kernel_size=3, strides=strides),
        cb_block(output_shape, filters, kernel_size=3, strides=1),
    ]
    return keras.Sequential(layers)

In [None]:
def deep_feedforward_block(input_shape, filters, strides):
    """
    Residual branch of a bottleneck ResNet block:
    1x1 conv-BN-ReLU (reduce) -> 3x3 conv-BN-ReLU -> 1x1 conv-BN (expand).

    The inner layers use ``filters // 4`` channels. The branch ends without an
    activation because the residual block applies ReLU after adding the
    shortcut (He et al., 2015).

    Args:
        input_shape: (height, width, channels) of the incoming feature map.
        filters: number of output channels.
        strides: stride of the first (1x1) convolution; the others use stride 1.

    Returns:
        A ``keras.Sequential`` producing a
        (ceil(height / strides), ceil(width / strides), filters) feature map.
    """
    input_height, input_width, _ = input_shape
    bottleneck_filters = filters // 4
    # "same" padding rounds the strided output size up, not down.
    bottleneck_shape = (
        math.ceil(input_height / strides),
        math.ceil(input_width / strides),
        bottleneck_filters,
    )
    layers = [
        cba_block(input_shape, bottleneck_filters, kernel_size=1, strides=strides),
        cba_block(bottleneck_shape, bottleneck_filters, kernel_size=3, strides=1),
        cb_block(bottleneck_shape, filters, kernel_size=1, strides=1),
    ]
    return keras.Sequential(layers)

In [None]:
def dimesion_altering_residual_block(
    input_shape, filters, strides, feedforward_block
):
    """
    Residual block whose output may differ from its input in channel count
    and, when ``strides`` > 1, in spatial size: ReLU(F(x) + P(x)).

    The shortcut P is a 1x1 Conv2D + BatchNormalization projection with the
    same ``strides``, so it matches the residual branch's output shape.
    The (misspelled) function name is kept because other code calls it.
    """
    block_input = keras.layers.Input(input_shape)
    residual = feedforward_block(input_shape, filters, strides)(block_input)
    shortcut = cb_block(input_shape, filters, kernel_size=1, strides=strides)(
        block_input
    )
    merged = residual + shortcut
    block_output = keras.layers.ReLU()(merged)
    return keras.Model(block_input, block_output)

In [None]:
def constant_dimension_residual_block(
    input_shape, feedforward_block
):
    """
    Identity-shortcut residual block, ReLU(F(x) + x), that preserves the
    input shape; the branch keeps the input's channel count with stride 1.
    """
    channels = input_shape[-1]
    block_input = keras.layers.Input(input_shape)
    residual = feedforward_block(input_shape, channels, strides=1)(block_input)
    block_output = keras.layers.ReLU()(residual + block_input)
    return keras.Model(block_input, block_output)

In [None]:
def repeating_residual_blocks(
    input_shape, filters, strides, repetitions, feedforward_block,
):
    """
    Build one ResNet stage: a shape-changing residual block followed by
    ``repetitions - 1`` shape-preserving residual blocks.

    Args:
        input_shape: (height, width, channels) entering the stage.
        filters: output channels of every block in the stage.
        strides: stride of the first block (1 keeps resolution, 2 halves it).
        repetitions: total number of residual blocks; at least one is always built.
        feedforward_block: builder for the residual branch of each block.
    """
    input_height, input_width, _ = input_shape
    # Shape after the first (possibly strided) block; "same" padding rounds up,
    # so ceil is correct where floor division would be wrong for odd sizes.
    block_output_shape = (
        math.ceil(input_height / strides),
        math.ceil(input_width / strides),
        filters,
    )
    layers = [
        dimesion_altering_residual_block(
            input_shape, filters, strides, feedforward_block
        )
    ]
    layers.extend(
        constant_dimension_residual_block(
            block_output_shape, feedforward_block=feedforward_block
        )
        for _ in range(repetitions - 1)
    )
    return keras.Sequential(layers)

#### 2. Define a general ResNet architecture

In [None]:
def resnet(
    input_shape, 
    output_shape, 
    first_block_filters,
    repetitions_by_blocks,
    feedforward_block,
    name="ResNet",
):
    """
    Build a ResNet-style network as a ``keras.Sequential``.

    Layout: conv1 (3x3, 64 filters, stride 2) -> 3x3 max-pool (stride 2) ->
    conv2_x stage (stride 1) -> one stride-2 stage per remaining entry of
    ``repetitions_by_blocks`` (channels double each stage) -> 2x2 average
    pooling -> flatten -> dense.

    Args:
        input_shape: (height, width, channels) of the input images.
        output_shape: number of units of the final Dense layer.
        first_block_filters: output channels of the conv2_x stage.
        repetitions_by_blocks: residual blocks per stage, conv2_x first.
        feedforward_block: residual-branch builder, e.g.
            ``shallow_feedforward_block`` or ``deep_feedforward_block``.
        name: name of the returned model.
    """
    input_height, input_width, _ = input_shape

    def downsampled(factor):
        # "same"-padded strided layers round up, and
        # ceil(ceil(n / a) / b) == ceil(n / (a * b)), so every stage input
        # is ceil(n / total_stride).
        return math.ceil(input_height / factor), math.ceil(input_width / factor)

    layers = [
        # conv1
        keras.Input(input_shape),
        cba_block(input_shape, filters=64, kernel_size=3, strides=2),
        # conv2_x
        keras.layers.MaxPooling2D(pool_size=3, strides=2, padding="same"),
        repeating_residual_blocks(
            (*downsampled(4), 64),
            first_block_filters,
            strides=1,
            repetitions=repetitions_by_blocks[0],
            feedforward_block=feedforward_block,
        ),
    ]

    # conv3_x onwards: each stage halves the resolution and doubles the channels.
    for stage, repetitions in enumerate(repetitions_by_blocks[1:]):
        layers.append(
            repeating_residual_blocks(
                (*downsampled(2 ** (stage + 2)), first_block_filters * 2 ** stage),
                first_block_filters * 2 ** (stage + 1),
                strides=2,
                repetitions=repetitions,
                feedforward_block=feedforward_block,
            )
        )

    # average_pooling & fc (pool_size defaults to 2, so this is not global pooling)
    layers.extend(
        [
            keras.layers.AveragePooling2D(padding="same"),
            keras.layers.Flatten(),
            keras.layers.Dense(units=output_shape),
        ]
    )

    return keras.Sequential(layers, name=name)

#### 3. Define different versions of ResNet

In [None]:
def resnet18(input_shape, output_shape, name="ResNet-18"):
    """ResNet-18: basic (shallow) residual blocks, [2, 2, 2, 2] per stage, 64 base filters."""
    stage_repetitions = [2, 2, 2, 2]
    return resnet(
        input_shape,
        output_shape,
        name=name,
        feedforward_block=shallow_feedforward_block,
        repetitions_by_blocks=stage_repetitions,
        first_block_filters=64,
    )

In [None]:
def resnet34(input_shape, output_shape, name="ResNet-34"):
    """ResNet-34: basic (shallow) residual blocks, [3, 4, 6, 3] per stage, 64 base filters."""
    stage_repetitions = [3, 4, 6, 3]
    return resnet(
        input_shape,
        output_shape,
        name=name,
        feedforward_block=shallow_feedforward_block,
        repetitions_by_blocks=stage_repetitions,
        first_block_filters=64,
    )

In [None]:
def resnet50(input_shape, output_shape, name="ResNet-50"):
    """ResNet-50: bottleneck (deep) residual blocks, [3, 4, 6, 3] per stage, 256 base filters."""
    stage_repetitions = [3, 4, 6, 3]
    return resnet(
        input_shape,
        output_shape,
        name=name,
        feedforward_block=deep_feedforward_block,
        repetitions_by_blocks=stage_repetitions,
        first_block_filters=256,
    )

In [None]:
def resnet101(input_shape, output_shape, name="ResNet-101"):
    """ResNet-101: bottleneck (deep) residual blocks, [3, 4, 23, 3] per stage, 256 base filters."""
    stage_repetitions = [3, 4, 23, 3]
    return resnet(
        input_shape,
        output_shape,
        name=name,
        feedforward_block=deep_feedforward_block,
        repetitions_by_blocks=stage_repetitions,
        first_block_filters=256,
    )

In [None]:
def resnet152(input_shape, output_shape, name="ResNet-152"):
    """ResNet-152: bottleneck (deep) residual blocks, [3, 8, 36, 3] per stage, 256 base filters."""
    stage_repetitions = [3, 8, 36, 3]
    return resnet(
        input_shape,
        output_shape,
        name=name,
        feedforward_block=deep_feedforward_block,
        repetitions_by_blocks=stage_repetitions,
        first_block_filters=256,
    )

### II - Define the FaceNet model

#### 1. Define the augmentation and resizing layers

In [None]:
# Random augmentation for training images. RandomRotation's factor is a
# fraction of a full turn, so 0.8 allows rotations of up to +/- 0.8 * 360 deg.
data_augmentation = tf.keras.Sequential(
    [
        keras.layers.RandomFlip("horizontal_and_vertical"),
        keras.layers.RandomRotation(0.8),
        # keras.layers.RandomCrop(IMAGE_SIZE_H, IMAGE_SIZE_W),
    ]
)

# Resize any input image to the network's (IMAGE_SIZE_H, IMAGE_SIZE_W) resolution.
image_resize = tf.keras.Sequential(
    [keras.layers.Resizing(IMAGE_SIZE_H, IMAGE_SIZE_W)]
)

#### 2. Define the FaceNet model

In [None]:
class FaceNet(keras.Model):
    """
    FaceNet-style embedding model trained with a triplet loss.

    ``call`` resizes images with the module-level ``image_resize`` and feeds
    them to the backbone. ``train_step``/``test_step`` expect batches shaped
    (batch, 3, height, width, channels), holding anchor, positive and negative
    images along axis 1, as produced by ``ImageTriplets``.

    Args:
        backbone: Keras model mapping images to embedding vectors.
        loss_margin: margin of the triplet (hinge) loss.
    """

    def __init__(self, backbone, loss_margin=0.5, **kwargs):
        super().__init__(**kwargs)
        self._backbone = backbone
        self._loss_margin = loss_margin
        # Despite its name this is an accuracy metric: the fraction of
        # triplets whose anchor is not farther from the positive than from
        # the negative.
        self.loss_tracker = tf.keras.metrics.Accuracy(name='accuracy')
        # Running mean of the per-batch triplet loss, so the logged "loss" is
        # an epoch average rather than the last batch's value.
        self._loss_mean = tf.keras.metrics.Mean(name='loss')

    @property
    def metrics(self):
        # Listing the trackers lets Keras reset them at the start of each
        # epoch and each evaluation, so training and validation values never
        # accumulate into one another.
        return [self._loss_mean, self.loss_tracker]

    def call(self, image, training=None):
        # Forward `training` explicitly so BatchNormalization in the backbone
        # switches between batch and moving statistics correctly.
        return self._backbone(image_resize(image), training=training)

    def triplet_embeddings(self, images, training=False):
        """Embed a batch of triplets; returns a (batch, 3, embedding_dim) tensor."""
        anchor_images = images[:, 0]
        positive_images = images[:, 1]
        negative_images = images[:, 2]

        if training:
            # Pass training=True explicitly so the random layers are active
            # regardless of their default.
            anchor_images = data_augmentation(anchor_images, training=True)
            positive_images = data_augmentation(positive_images, training=True)
            negative_images = data_augmentation(negative_images, training=True)

        anchor_embeddings = self(anchor_images, training=training)
        positive_embeddings = self(positive_images, training=training)
        negative_embeddings = self(negative_images, training=training)

        return tf.stack(
            [anchor_embeddings, positive_embeddings, negative_embeddings], axis=1
        )

    def train_step(self, images):
        with tf.GradientTape() as tape:
            embeddings = self.triplet_embeddings(images, training=True)
            loss = self.compute_loss(embeddings)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self._loss_mean.update_state(loss)
        return {"loss": self._loss_mean.result(), "accuracy": self.loss_tracker.result()}

    def test_step(self, images):
        embeddings = self.triplet_embeddings(images, training=False)
        loss = self.compute_loss(embeddings)
        self._loss_mean.update_state(loss)
        return {"loss": self._loss_mean.result(), "accuracy": self.loss_tracker.result()}

    def compute_loss(self, embeddings):
        """
        Return the summed triplet loss
        max(d(a, p) - d(a, n) + margin, 0) over the batch, where d is the
        Euclidean distance, and update the accuracy tracker.
        """
        anchor_embeddings = embeddings[:, 0]
        positive_embeddings = embeddings[:, 1]
        negative_embeddings = embeddings[:, 2]

        ap_distance = tf.math.reduce_euclidean_norm(
            anchor_embeddings - positive_embeddings, axis=1
        )
        an_distance = tf.math.reduce_euclidean_norm(
            anchor_embeddings - negative_embeddings, axis=1
        )

        loss = tf.math.maximum(ap_distance - an_distance + self._loss_margin, 0.0)

        # A triplet counts as correct when ap_distance <= an_distance, i.e.
        # when the clipped difference equals zero. zeros_like follows the real
        # batch size; a hard-coded BATCH_SIZE breaks on any other batch size.
        violation = tf.math.maximum(ap_distance - an_distance, 0.0)
        self.loss_tracker.update_state(violation, tf.zeros_like(violation))

        return tf.reduce_sum(loss)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "backbone": self._backbone,
            "loss_margin": self._loss_margin,
        }

In [None]:
# ResNet-18 backbone whose final Dense layer outputs a 130-dimensional embedding.
backbone = resnet18(
    output_shape=130,
    input_shape=(IMAGE_SIZE_H, IMAGE_SIZE_W, 3),
)

In [None]:
# Wrap the backbone in the triplet-loss model (default loss margin).
model = FaceNet(backbone)

In [None]:
# Smoke test: batch 0 -> triplet 0 -> anchor image, given a leading batch axis.
test = np.expand_dims(train_triplets[0][0][0], axis=0)
print("test shape", test.shape)
model(test).shape

In [None]:
# No Keras loss is passed: FaceNet computes its triplet loss in train_step.
optimizer_config = {"learning_rate": 0.001}  # Adam's default learning rate
model.compile(optimizer=keras.optimizers.Adam(**optimizer_config))

In [None]:
# Early stopping and checkpointing follow validation accuracy (higher is
# better); the learning-rate schedule follows validation loss.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy', mode='max', patience=14, restore_best_weights=True
    ),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=3),
    keras.callbacks.ModelCheckpoint(
        CHECKPOINT_PATH, monitor="val_accuracy", mode='max', save_best_only=True
    ),
    keras.callbacks.TensorBoard(log_dir='logs'),
]

In [None]:
# Train on triplet batches; EarlyStopping normally ends training well before EPOCHS.
model.fit(
    x=train_triplets,
    epochs=EPOCHS,
    validation_data=val_triplets,
    callbacks=callbacks,
)

In [None]:
# Display the metric objects the model currently tracks.
model.metrics

In [None]:
# Full collection (generation 2 is the default) to release memory held after training.
gc.collect(generation=2)

In [None]:
# model.save('model', include_optimizer=False)

In [None]:
# Reload the best model written by ModelCheckpoint and run it on one image.
# Reuse CHECKPOINT_PATH so saving and loading cannot drift apart.
m = keras.models.load_model(CHECKPOINT_PATH)
test = np.stack([train_triplets[0][0][0]])  # batch 0, triplet 0, anchor image
print(test.shape)
pred = m.predict(test)
print(pred.shape)

In [None]:
import shutil

# Zip the Kaggle working directory (checkpoints, logs, ...) as output_kaggle.zip.
# NOTE(review): the archive is created in the current directory; if that is
# /kaggle/working itself, confirm the zip does not end up containing itself.
shutil.make_archive(
    base_name='output_kaggle',
    format='zip',
    root_dir='/kaggle/working/',
)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir "logs"