# Contrastive
In questo file eseguo un piccolo test per verificare che le funzioni per la loss contrastiva siano corrette.

In [None]:
import keras
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from keras import layers, datasets, ops

epochs = 3  # Number of training epochs passed to siamese.fit().
batch_size = 16  # Mini-batch size for fit(); also passed to make_pairs(), which draws batch_size // 2 positive and negative pairs per sample.
margin = 1.0  # Margin for contrastive loss.
temperature = 1.0  # Temperature dividing the predicted distance inside contrastive_SNN_loss.

### Dataset
Prendo il dataset dall'esempio di [keras](https://keras.io/examples/vision/siamese_contrastive/)

In [None]:
import random
def make_pairs(x, y, batch):
    """Build labelled positive/negative sample pairs for siamese training.

    For every sample in ``x`` the function draws ``batch // 2`` positive pairs
    (partner from the same class, label ``0.0``) and as many negative pairs
    (partner from a different class, label ``1.0``), alternating
    positive/negative in the output.

    Args:
        x: Indexable collection of samples; ``x[i]`` is sample ``i``.
        y: Integer class labels aligned with ``x`` (converted with
            ``np.asarray``).
        batch: Twice the number of positive (and negative) pairs to draw per
            sample; ``batch // 2`` draws of each kind are made.

    Returns:
        A tuple ``(pairs, labels)`` of NumPy arrays, where ``pairs[k]`` holds
        the two samples of pair ``k`` and ``labels`` is ``float32``.

    Raises:
        ValueError: If fewer than two classes contain samples, because no
            negative pair could ever be drawn.
    """
    y = np.asarray(y)
    num_classes = int(y.max()) + 1
    digit_indices = [np.where(y == i)[0] for i in range(num_classes)]

    # Without at least two populated classes the negative-label search below
    # could never succeed, so fail loudly instead of looping forever.
    populated_classes = sum(1 for indices in digit_indices if len(indices) > 0)
    if populated_classes < 2:
        raise ValueError(
            "make_pairs needs samples from at least two classes to build negative pairs"
        )

    pairs = []
    labels = []

    for idx1 in tqdm(range(len(x))):
        for _ in range(batch // 2):
            x1 = x[idx1]
            label1 = y[idx1]

            # Positive pair: a random sample of the same class (may be x1 itself).
            idx2 = random.choice(digit_indices[label1])
            pairs.append([x1, x[idx2]])
            labels.append(0.0)

            # Negative pair: redraw until the class differs from label1 and
            # actually contains samples (an empty class would make
            # random.choice raise IndexError).
            label2 = random.randint(0, num_classes - 1)
            while label2 == label1 or len(digit_indices[label2]) == 0:
                label2 = random.randint(0, num_classes - 1)
            idx2 = random.choice(digit_indices[label2])

            pairs.append([x1, x[idx2]])
            labels.append(1.0)
    return np.array(pairs), np.array(labels, dtype=np.float32)

# Load MNIST and convert the images to float32.
(x_train_val, y_train_val), (x_test, y_test) = datasets.mnist.load_data()

x_train_val = x_train_val.astype("float32")
x_test = x_test.astype("float32")

# The first 30000 samples are used for training, the remainder for validation.
x_train, x_val = x_train_val[:30000], x_train_val[30000:]
y_train, y_val = y_train_val[:30000], y_train_val[30000:]

# Build the training pairs from the training split only: using the combined
# train+validation arrays here would leak validation samples into training.
pairs_train, labels_train = make_pairs(x_train, y_train, batch_size)
pairs_val, labels_val = make_pairs(x_val, y_val, batch_size)
pairs_test, labels_test = make_pairs(x_test, y_test, batch_size)

# Split all pairs into two sets (first and second element of every pair).
# make_pairs emits 2 * (batch_size // 2) pairs per input sample.
x_train_1 = pairs_train[:, 0]
x_train_2 = pairs_train[:, 1]
x_val_1 = pairs_val[:, 0]
x_val_2 = pairs_val[:, 1]
x_test_1 = pairs_test[:, 0]
x_test_2 = pairs_test[:, 1]

print("Training pairs shape:", pairs_train.shape)
print("Training labels shape:", labels_train.shape)

# Count the samples per class in the training split and plot the distribution.
clazz = np.bincount(y_train)
plt.bar(range(len(clazz)), clazz)
plt.title("MNIST Class Distribution")
plt.xlabel("Class")
plt.ylabel("Number of samples")
plt.show()

Visualizzo alcune coppie di esempio per verificare che siano state generate correttamente.

In [None]:
# Function adapted from the Keras siamese-contrastive example.
def visualize(pairs, labels, to_show=6, num_col=3, predictions=None, test=False):
    """Plot image pairs side by side, titled with their label (and prediction).

    Args:
        pairs: Indexable collection where ``pairs[i][0]`` and ``pairs[i][1]``
            are the two images of pair ``i``.
        labels: Indexable collection with the label of each pair.
        to_show: Requested number of pairs; rounded down to a whole number of
            rows of ``num_col`` plots (at least one row is always drawn).
        num_col: Number of plots per row.
        predictions: Indexable collection where ``predictions[i][0]`` is the
            predicted value for pair ``i``. Required when ``test`` is true.
        test: When true, titles show both the true label and the prediction.

    Raises:
        ValueError: If ``test`` is true but ``predictions`` is ``None``.
    """
    # Validate before creating a figure so a bad call leaves no empty plot behind.
    if test and predictions is None:
        raise ValueError("predictions must be provided when test=True")

    num_row = max(to_show // num_col, 1)
    to_show = num_row * num_col

    # squeeze=False keeps `axes` two-dimensional for every grid size, including
    # a single row or a single cell, so one indexing scheme covers all cases.
    _, axes = plt.subplots(num_row, num_col, figsize=(5, 5), squeeze=False)
    for i in range(to_show):
        ax = axes[i // num_col, i % num_col]

        # Join the two images horizontally so the pair appears as one picture.
        ax.imshow(ops.concatenate([pairs[i][0], pairs[i][1]], axis=1), cmap="gray")
        ax.set_axis_off()
        if test:
            ax.set_title("True: {} | Pred: {:.5f}".format(labels[i], predictions[i][0]))
        else:
            ax.set_title("Label: {}".format(labels[i]))
    if test:
        plt.tight_layout(rect=(0, 0, 1.9, 1.9), w_pad=0.0)
    else:
        plt.tight_layout(rect=(0, 0, 1.5, 1.5))
    plt.show()

# Preview a single row of six training pairs together with their labels.
visualize(
    pairs=pairs_train[:-1],
    labels=labels_train[:-1],
    to_show=6,
    num_col=6,
)


### Funzioni
Ora mettiamo le distanze e la loss.\
Una loss aggiuntiva è la [Soft Nearest Neighbors Loss](https://lilianweng.github.io/posts/2021-05-31-contrastive/#soft-nearest-neighbors-loss)

In [None]:
def euclidean_distance(vects):
    """Return the Euclidean distance between two batches of vectors.

    Args:
        vects: A pair ``(x, y)`` of tensors that can be subtracted from each
            other.

    Returns:
        The square root of the squared differences summed over axis 1, with
        that axis kept (``keepdims=True``).
    """
    first, second = vects
    squared_diff = ops.square(first - second)
    squared_norm = ops.sum(squared_diff, axis=1, keepdims=True)
    # Clamp to epsilon so the square root never receives an exact zero.
    clamped_norm = ops.maximum(squared_norm, keras.backend.epsilon())
    return ops.sqrt(clamped_norm)

def loss(margin=1):
    """Build the loss function used to compile the siamese model.

    Two candidate losses are defined here, but only the soft-nearest-neighbours
    variant is returned. The classic contrastive loss is kept as an
    alternative (presumably so the ``return`` can be switched to it).

    Label convention in both losses: ``y_true == 0`` marks a positive
    (same-class) pair and ``y_true == 1`` a negative pair.

    Args:
        margin: Margin of the classic contrastive loss. It is not used by the
            function that is currently returned.

    Returns:
        ``contrastive_SNN_loss``, a callable ``(y_true, y_pred) -> scalar``.
    """

    # Contrastive loss = mean( (1-true_value) * square(prediction) +
    #                         true_value * square( max(margin-prediction, 0) ))
    def contrastive_loss(y_true, y_pred):
        square_pred = ops.square(y_pred)
        margin_square = ops.square(ops.maximum(margin - (y_pred), 0))
        return ops.mean((1 - y_true) * square_pred + (y_true) * margin_square)

    def contrastive_SNN_loss(y_true, y_pred):
        """Negative log of the share of exp(-distance / temperature) held by positive pairs."""
        # `temperature` is read from module scope.
        exp_similarity = ops.exp(ops.negative(y_pred / temperature))

        # Keep only the positive-pair terms. ops.where is used instead of
        # multiplying by the boolean mask, because a float * bool product is
        # rejected by backends without implicit bool-to-float promotion.
        positive_terms = ops.where(
            ops.equal(y_true, 0), exp_similarity, ops.zeros_like(exp_similarity)
        )

        numerator = ops.sum(positive_terms)
        denominator = ops.sum(exp_similarity) + keras.backend.epsilon()  # Add epsilon to avoid division by zero

        safe_ratio = numerator / denominator
        safe_ratio = ops.maximum(safe_ratio, keras.backend.epsilon())  # Keep the log argument strictly positive

        return ops.negative(ops.mean(ops.log(safe_ratio)))

    return contrastive_SNN_loss

### Modello
Qui il modello viene definito e addestrato.

In [None]:
# Shared embedding network: two conv/average-pool stages followed by a
# 10-unit tanh dense layer that produces the embedding.
# NOTE(review): `input` shadows the Python builtin of the same name for the
# rest of the session; consider renaming if no other cell depends on it.
input = layers.Input((28, 28, 1))
x = layers.BatchNormalization()(input)
x = layers.Conv2D(4, (5, 5), activation="tanh")(x)
x = layers.AveragePooling2D(pool_size=(2, 2))(x)
x = layers.Conv2D(16, (5, 5), activation="tanh")(x)
x = layers.AveragePooling2D(pool_size=(2, 2))(x)
x = layers.Flatten()(x)

x = layers.BatchNormalization()(x)
x = layers.Dense(10, activation="tanh")(x)
embedding_network = keras.Model(input, x)


# Siamese model: both inputs go through the same embedding_network instance,
# so the two towers share their weights.
input_1 = layers.Input((28, 28, 1))
input_2 = layers.Input((28, 28, 1))
tower_1 = embedding_network(input_1)
tower_2 = embedding_network(input_2)
# Distance between the two embeddings, then batch-normalised and mapped to a
# single sigmoid output.
merge_layer = layers.Lambda(euclidean_distance, output_shape=(1,))([tower_1, tower_2])
normal_layer = layers.BatchNormalization()(merge_layer)
output_layer = layers.Dense(1, activation="sigmoid")(normal_layer)
siamese = keras.Model(inputs=[input_1, input_2], outputs=output_layer)

# loss(...) returns the loss callable; training uses RMSprop and tracks accuracy.
siamese.compile(loss=loss(margin=margin), optimizer="RMSprop", metrics=["accuracy"])

In [None]:
# Train the siamese model on the training pairs and validate on the validation
# pairs after every epoch; the returned History object is kept in `history`.
history = siamese.fit(
    x=[x_train_1, x_train_2],
    y=labels_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=([x_val_1, x_val_2], labels_val),
)

### Risultati
Qui vedremo i risultati dell'addestramento

In [None]:
def plt_metric(history, metric, title, has_valid=True):
    """Plot one training metric per epoch, optionally with its validation curve.

    Args:
        history: Mapping from metric name to its per-epoch values.
        metric: Key of the training metric; the validation curve is read from
            ``"val_" + metric``.
        title: Title of the plot.
        has_valid: Whether to also draw the validation curve and a legend.
    """
    curve_keys = [metric]
    if has_valid:
        curve_keys.append("val_" + metric)

    for key in curve_keys:
        plt.plot(history[key])

    # The legend is only meaningful when both curves are drawn.
    if has_valid:
        plt.legend(["train", "validation"], loc="upper left")

    plt.xlabel("epoch")
    plt.ylabel(metric)
    plt.title(title)
    plt.show()

# Score the trained model on the test pairs.
results = siamese.evaluate([x_test_1, x_test_2], labels_test)
print("test loss, test acc:", results)

# Show test pairs together with the model output for each of them.
predictions = siamese.predict([x_test_1, x_test_2])
visualize(pairs_test, labels_test, to_show=16, predictions=predictions, test=True)

# Training curves recorded by fit(): accuracy first, then loss.
for metric_name, plot_title in (
    ("accuracy", "Model accuracy"),
    ("loss", "Contrastive Loss"),
):
    plt_metric(history=history.history, metric=metric_name, title=plot_title)
