# Example Notebook for model visualisation

# Importing required libraries

In [2]:
import sys

sys.path.append("../")

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from transformers import TFViTModel

from tensorflow.keras.applications import VGG19
from tensorflow.keras.layers import Input, Dense, Flatten, BatchNormalization, Dropout, Subtract, Activation, Conv2D, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.callbacks import EarlyStopping

from utils import *
from data_aug import *
from dataset import *

In [5]:
# Channels-last -> channels-first: Permute((3, 1, 2)) reorders the non-batch
# axes (H, W, C) into (C, H, W), the layout fed to the ViT `.vit(...)` calls
# in the model-building cells.
# NOTE: despite its name, this pipeline performs no resizing or rescaling.
resize_rescale_hf = tf.keras.Sequential(
    layers=[tf.keras.layers.Permute((3, 1, 2))],
)

# Loading raw data

## Streetview

In [None]:
# Load the Streetview image pairs and their labels. The names suggest the
# three returned arrays are the first images, second images, and labels,
# presumably aligned by index (TODO confirm in load_data).
(image1_array_Streetview,
 image2_array_Streetview,
 labels_Streetview) = load_data("../data/question_1/Streetview_dataaug")

## Mapillary

In [None]:
# Load the Mapillary image pairs and their labels. The names suggest the
# three returned arrays are the first images, second images, and labels,
# presumably aligned by index (TODO confirm in load_data).
(image1_array_Mapillary,
 image2_array_Mapillary,
 labels_Mapillary) = load_data("../Mapillary/mapillary_training_dataaug_contrast/")

# Model architectures

### Comparison model (VGG19)

In [6]:
def comparison_model(input_shape):
    """Create a siamese model for image comparison using VGG19 as base model.

    Both images go through one shared VGG19 backbone (ImageNet weights, all
    but its last four layers frozen). The two feature maps are concatenated
    along the channel axis and passed through two Conv2D/Dropout stages,
    then flattened into a 2-unit sigmoid output layer.

    Args:
        input_shape (tuple): Shape of the input images, e.g. (224, 224, 3).
    Returns:
        keras.models.Model: The compiled siamese model (binary cross-entropy
        loss, Adam optimizer with inverse-time learning-rate decay, accuracy
        metric).
    """
    base_model = VGG19(weights='imagenet', include_top=False, input_shape=input_shape)
    # Only the last four VGG19 layers are fine-tuned; the rest stay frozen.
    for layer in base_model.layers[:-4]:
        layer.trainable = False

    # Create inputs for pairs of images
    input_1 = Input(shape=input_shape)
    input_2 = Input(shape=input_shape)

    # Get embeddings of the images using the shared VGG19 model
    # (the same weights serve both branches).
    output_1 = base_model(input_1)
    output_2 = base_model(input_2)

    # Channel-wise concatenation (default axis=-1) of the two feature maps.
    concat = concatenate([output_1, output_2])

    # Classification head to predict similarity.
    x = Conv2D(32, (3, 3), activation="tanh", padding='same')(concat)
    x = Dropout(0.3)(x)
    x = Conv2D(32, (3, 3), activation="tanh", padding='same')(x)
    x = Dropout(0.3)(x)
    x = Flatten()(x)
    # NOTE(review): two sigmoid units trained with binary cross-entropy;
    # presumably labels are one-hot pairs of shape (batch, 2) - confirm
    # against prepare_dataset_generators.
    output = Dense(2, activation='sigmoid')(x)

    # Create the complete siamese model
    siamese_model = Model(inputs=[input_1, input_2], outputs=output)

    # The `decay=` optimizer argument is rejected by the Keras optimizers
    # shipped with TF >= 2.11. InverseTimeDecay with decay_steps=1 reproduces
    # the legacy behaviour lr_t = lr / (1 + decay * iteration).
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=1e-05, decay_steps=1, decay_rate=0.001)
    siamese_model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=lr_schedule), metrics=['accuracy'])

    # Print model summary
    siamese_model.summary()

    return siamese_model

### Ranking model (VGG19)

In [7]:
def create_ranking_network(img_size):
    """
    Create the ranking network, which gives a scalar score to an image.

    A VGG19 backbone (ImageNet weights, all but its last four layers frozen)
    is flattened and followed by two Dense(32, sigmoid) -> BatchNormalization
    -> Dropout(0.2) blocks and a single linear output unit.

    :param img_size: side length of the square input images; inputs are
        shaped (img_size, img_size, 3)
    :type img_size: int
    :return: ranking network model named 'Scoring_model'
    :rtype: keras.Model
    """
    image_shape = (img_size, img_size, 3)

    # VGG19 feature extractor; only its last four layers stay trainable.
    backbone = VGG19(weights="imagenet", include_top=False, input_shape=image_shape)
    for frozen_layer in backbone.layers[:-4]:
        frozen_layer.trainable = False

    image_input = Input(shape=image_shape, name='input_image')
    backbone_out = backbone(image_input)
    features = Flatten(name='Flatten')(backbone_out)

    # Two identical Dense/BN/Dropout blocks; layer names are kept explicit.
    block_names = (
        ('Dense_1', 'BN1', 'Drop_1'),
        ('Dense_2', 'BN2', 'Drop_2'),
    )
    for dense_name, bn_name, drop_name in block_names:
        features = Dense(32, activation='sigmoid', name=dense_name)(features)
        features = BatchNormalization(name=bn_name)(features)
        features = Dropout(0.2, name=drop_name)(features)

    # Single linear unit producing the image score.
    score = Dense(1, name="Dense_Output")(features)
    return Model(image_input, score, name='Scoring_model')


def create_meta_network(img_size, weights=None):
    """
    Create meta network which is used to teach the ranking network.

    One shared VGG19 ranking network (see create_ranking_network) scores the
    left and right images; the model outputs sigmoid(left_score -
    right_score) and is trained with binary cross-entropy.

    :param img_size: side length of the square input images; inputs are
        shaped (img_size, img_size, 3)
    :type img_size: int
    :param weights: optional path to weights loaded into the meta network
        before it is compiled
    :type weights: str
    :return: compiled meta network model
    :rtype: keras.Model
    """

    # Create the two input branches
    input_left = Input(shape=(img_size, img_size, 3), name='left_input')
    input_right = Input(shape=(img_size, img_size, 3), name='right_input')
    # A single ranking network instance, so both branches share weights.
    base_network = create_ranking_network(img_size)
    left_score = base_network(input_left)
    right_score = base_network(input_right)

    # Subtract scores
    diff = Subtract()([left_score, right_score])

    # Pass difference through sigmoid function.
    prob = Activation("sigmoid", name="Activation_sigmoid")(diff)
    model = Model(inputs=[input_left, input_right], outputs=prob, name="Meta_Model")

    if weights:
        print('Loading weights ...')
        model.load_weights(weights)

    # The `decay=` optimizer argument is rejected by the Keras optimizers
    # shipped with TF >= 2.11. InverseTimeDecay with decay_steps=1 reproduces
    # the legacy behaviour lr_t = lr / (1 + decay * iteration).
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-05)
    model.compile(optimizer=RMSprop(learning_rate=lr_schedule), loss="binary_crossentropy", metrics=['accuracy'])

    return model

### Comparison model (ViT from Google)

In [8]:
def comparison_vit_model(input_shape):
    """Create a siamese comparison model on top of a pretrained Google ViT.

    Both images pass through the same `google/vit-base-patch16-224-in21k`
    backbone. The first-token embeddings of the two images are compared by
    element-wise absolute difference and fed to a 2-unit softmax layer.

    Args:
        input_shape (tuple): Shape of each channels-last input image, e.g.
            (224, 224, 3).
    Returns:
        keras.Model: The compiled siamese model (binary cross-entropy loss,
        default Adam optimizer, binary accuracy reported as "accuracy").
    """
    input_1 = layers.Input(shape=input_shape)
    # resize_rescale_hf (defined earlier in this notebook) only permutes the
    # axes to channels-first; it does not resize or rescale.
    resized_input_1 = resize_rescale_hf(input_1)
    input_2 = layers.Input(shape=input_shape)
    resized_input_2 = resize_rescale_hf(input_2)

    # Load the pretrained ViT backbone (TFViTModel is the bare transformer,
    # without a classification head).
    base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    # Keep the whole backbone trainable so it is fine-tuned end to end.
    for layer in base_model.layers:
        layer.trainable = True

    # Take the first token of the first model output (the last hidden state),
    # i.e. the [CLS] embedding, for each image; both branches share weights.
    features_1 = base_model.vit(resized_input_1)[0][:, 0, :]
    features_2 = base_model.vit(resized_input_2)[0][:, 0, :]

    # Element-wise absolute difference of the two embeddings (an L1-style
    # comparison vector, not a scalar Euclidean distance).
    distance = layers.Lambda(lambda tensors: tf.math.abs(tensors[0] - tensors[1]))([features_1, features_2])
    # NOTE(review): 2-way softmax trained with BinaryCrossentropy; presumably
    # labels are one-hot pairs of shape (batch, 2) - confirm.
    outputs = layers.Dense(2, activation='softmax')(distance)

    # Create the Keras model
    siamese_network = tf.keras.Model(inputs=[input_1, input_2], outputs=outputs)
    optimizer = tf.keras.optimizers.Adam()

    siamese_network.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        ],
    )

    return siamese_network

### Ranking model (ViT from Google)

In [9]:
def create_ranking_network_vit(img_size):
    """
    Create the ViT ranking network, which gives a scalar score to an image.

    The [CLS] embedding of a pretrained `google/vit-base-patch16-224-in21k`
    backbone is followed by two Dense(32, sigmoid) -> BatchNormalization ->
    Dropout(0.2) blocks and a single linear output unit. Layer names match
    create_ranking_network (the VGG19 variant).

    :param img_size: side length of the square input images; inputs are
        shaped (img_size, img_size, 3)
    :type img_size: int
    :return: ranking network model named 'Scoring_model'
    :rtype: keras.Model
    """
    # Create feature extractor from the pretrained Google ViT
    feature_extractor = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    # Keep the whole backbone trainable so it is fine-tuned end to end.
    for layer in feature_extractor.layers:
        layer.trainable = True

    # Add dense layers on top of the feature extractor
    inp = Input(shape=(img_size, img_size, 3), name='input_image')
    # resize_rescale_hf (defined earlier in this notebook) only permutes the
    # axes to channels-first; it does not resize or rescale.
    resized_inp = resize_rescale_hf(inp)
    # First token of the last hidden state ([CLS] embedding).
    base = feature_extractor.vit(resized_inp)[0][:, 0, :]
    # Presumably already (batch, hidden), so Flatten is a no-op kept for
    # parity with the VGG19 ranking network's layer names.
    base = Flatten(name='Flatten')(base)

    # Block 1
    base = Dense(32, activation='sigmoid', name='Dense_1')(base)
    base = BatchNormalization(name='BN1')(base)
    base = Dropout(0.2, name='Drop_1')(base)

    # Block 2
    base = Dense(32, activation='sigmoid', name='Dense_2')(base)
    base = BatchNormalization(name='BN2')(base)
    base = Dropout(0.2, name='Drop_2')(base)

    # Final dense: single linear unit producing the image score.
    base = Dense(1, name="Dense_Output")(base)
    base_network = Model(inp, base, name='Scoring_model')
    return base_network


def create_meta_network_vit(img_size, weights=None):
    """
    Create the ViT meta network which is used to teach the ViT ranking network.

    One shared ViT ranking network (see create_ranking_network_vit) scores
    the left and right images; the model outputs sigmoid(left_score -
    right_score) and is trained with binary cross-entropy.

    :param img_size: side length of the square input images; inputs are
        shaped (img_size, img_size, 3)
    :type img_size: int
    :param weights: optional path to weights loaded into the meta network
        before it is compiled
    :type weights: str
    :return: compiled meta network model
    :rtype: keras.Model
    """

    # Create the two input branches
    input_left = Input(shape=(img_size, img_size, 3), name='left_input')
    input_right = Input(shape=(img_size, img_size, 3), name='right_input')
    # Use the ViT ranking network. The previous code called the VGG19
    # create_ranking_network here, so the "ViT" meta model was really VGG19.
    base_network = create_ranking_network_vit(img_size)
    left_score = base_network(input_left)
    right_score = base_network(input_right)

    # Subtract scores
    diff = Subtract()([left_score, right_score])

    # Pass difference through sigmoid function.
    prob = Activation("sigmoid", name="Activation_sigmoid")(diff)
    model = Model(inputs=[input_left, input_right], outputs=prob, name="Meta_Model")

    if weights:
        print('Loading weights ...')
        model.load_weights(weights)

    # The `decay=` optimizer argument is rejected by the Keras optimizers
    # shipped with TF >= 2.11. InverseTimeDecay with decay_steps=1 reproduces
    # the legacy behaviour lr_t = lr / (1 + decay * iteration).
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-05)
    model.compile(optimizer=RMSprop(learning_rate=lr_schedule), loss="binary_crossentropy", metrics=['accuracy'])

    return model

# Looking at the training and validation loss/accuracy of the trained models

### Comparison model (VGG19) trained on Streetview data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Streetview_Result/Comparison_Handpicked_DataAugmentation_With_contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Streetview_Result/Comparison_Handpicked_DataAugmentation_With_contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Ranking model (VGG19) trained on Streetview data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Streetview_Result/Ranking_Handpicked_DatAug_With_Contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Streetview_Result/Ranking_Handpicked_DatAug_With_Contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Comparison model (VGG19) trained on Mapillary data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Mapillary_Results/Best_ComparisonModel_From_Streetview_Trained_On_Mapillary_DataAug_Contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Mapillary_Results/Best_ComparisonModel_From_Streetview_Trained_On_Mapillary_DataAug_Contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Ranking model (VGG19) trained on Mapillary data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Mapillary_Results/Best_RankingModel_From_Streetview_Trained_On_Mapillary_DataAug_Contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Mapillary_Results/Best_RankingModel_From_Streetview_Trained_On_Mapillary_DataAug_Contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Comparison model (ViT from Google) trained on Streetview data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Transformer_Results/Streetview/Comparison_200E_dataaug_contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Transformer_Results/Streetview/Comparison_200E_dataaug_contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Ranking model (ViT from Google) trained on Streetview data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Transformer_Results/Streetview/Ranking_10E_dataaug_contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Transformer_Results/Streetview/Ranking_10E_dataaug_contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Comparison model (ViT from Google) trained on Mapillary data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Transformer_Results/Mapillary/Comparison_200E_dataaug_contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Transformer_Results/Mapillary/Comparison_200E_dataaug_contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

### Ranking model (ViT from Google) trained on Mapillary data with data augmentation

<div style="display: flex; justify-content: center;">
    <img src="../Result/Transformer_Results/Mapillary/Ranking_30E_dataaug_contrast/accuracy_curve.png" style="width: 40%; margin-right: 10px;">
    <img src="../Result/Transformer_Results/Mapillary/Ranking_30E_dataaug_contrast/loss_curve.png" style="width: 40%; margin-right: 10px;">
</div>

# Reproducing the training process

## Load the data for the comparison models

#### Mapillary data for the regular VGG19 comparison model

In [None]:
batch_size = 64

# Comparison-task generators from the Mapillary pairs (VGG19 models).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Mapillary, image2_array_Mapillary, labels_Mapillary,
    batch_size, "comparison")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Streetview data for the regular VGG19 comparison model

In [None]:
batch_size = 64

# Comparison-task generators from the Streetview pairs (VGG19 models).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Streetview, image2_array_Streetview, labels_Streetview,
    batch_size, "comparison")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Mapillary data for the Google ViT comparison model

In [None]:
batch_size = 16

# Comparison-task generators from the Mapillary pairs (ViT model, batch of 16).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Mapillary, image2_array_Mapillary, labels_Mapillary,
    batch_size, "comparison")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Streetview data for the Google ViT comparison model

In [None]:
batch_size = 16

# Comparison-task generators from the Streetview pairs (ViT model, batch of 16).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Streetview, image2_array_Streetview, labels_Streetview,
    batch_size, "comparison")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

## Load the data for the Ranking models

#### Mapillary data for the regular VGG19 ranking model

In [None]:
batch_size = 64

# Ranking-task generators from the Mapillary pairs (VGG19 ranking model).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Mapillary, image2_array_Mapillary, labels_Mapillary,
    batch_size, "ranking")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Streetview data for the regular VGG19 ranking model

In [None]:
batch_size = 64

# Ranking-task generators from the Streetview pairs (VGG19 ranking model).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Streetview, image2_array_Streetview, labels_Streetview,
    batch_size, "ranking")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Mapillary data for the Google ViT ranking model

In [None]:
batch_size = 16

# Ranking-task generators from the Mapillary pairs (ViT ranking model, batch of 16).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Mapillary, image2_array_Mapillary, labels_Mapillary,
    batch_size, "ranking")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

#### Streetview data for the Google ViT ranking model

In [None]:
batch_size = 16

# Ranking-task generators from the Streetview pairs (ViT ranking model, batch of 16).
# Overwrites any generators created by a previously run data cell.
(train_generator, valid_generator, test_generator,
 train_size, valid_size) = prepare_dataset_generators(
    image1_array_Streetview, image2_array_Streetview, labels_Streetview,
    batch_size, "ranking")

# Floor division: a final partial batch is not counted in these step totals.
train_steps_per_epoch, valid_steps_per_epoch = (
    train_size // batch_size, valid_size // batch_size)

## Training the models

#### Initialize the model

In [None]:
# Ranking builders take the side length only and expand it to
# (side, side, 3); comparison builders take the full image shape.
input_shape_ranking = 224
input_shape_comparison = (224, 224, 3)
num_epochs = 100

# Placeholder: replace the string with one of the builders defined above,
#   comparison_model(input_shape_comparison)
#   comparison_vit_model(input_shape_comparison)
#   create_meta_network(input_shape_ranking)
#   create_meta_network_vit(input_shape_ranking)
# and make sure the matching data-loading cell was the last one run, since
# every data cell overwrites the same generator variables.
model = "Initialize the desired model here"

#### Train the model

In [None]:
# Train on the generators and step counts from the most recently run
# data-loading cell; the returned History is plotted in the next cell.
history = model.fit(
    train_generator,
    epochs=num_epochs,
    steps_per_epoch=train_steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=valid_steps_per_epoch,
)

#### Save the model and look at its performance

In [None]:
# Accuracy curve first, then loss curve, then persist the trained model.
for plot_curve in (plot_accuracy, plot_loss):
    plot_curve(history)
model.save("Test_model.h5")

In [None]:
# NOTE(review): the test generator is evaluated for `valid_steps_per_epoch`
# batches (the validation step count), because the data cells do not get a
# test-set size back from prepare_dataset_generators. Confirm this covers
# the whole test split.
model.evaluate(x=test_generator, steps=valid_steps_per_epoch)