<a href="https://colab.research.google.com/github/MHosseinHashemi/Image_Similarity/blob/main/Image_Simmilarity_CenterLoss_TF_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, RandomFlip, RandomRotation, Dense, Dropout, Lambda

from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score

In [2]:
# Load the Oxford Flowers 102 dataset as (image, label) pairs together with its metadata.
# NOTE(review): the splits are requested in the order train / validation / test but are
# unpacked as train / test / validation, so `test_data` receives the "validation" split
# and `validation_data` receives the "test" split — confirm this is intended.
(train_data, test_data, validation_data), info = tfds.load("oxford_flowers102", split=['train', 'validation', 'test'], as_supervised=True, with_info=True)

In [3]:
# Target spatial size fed to the network.
height = 128
width = 128

def preprocess_images(image, label, height, width):
    """Resize an image to (height, width) and rescale its pixel values.

    Args:
        image: image tensor to resize.
        label: label paired with the image; returned unchanged.
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        A tuple (image, label) where image is float32 and divided by 255.0.
    """
    # tf.image.resize expects size as [new_height, new_width]; passing
    # [width, height] only worked because both values happen to be equal.
    image = tf.image.resize(image, [height, width])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


In [4]:
# Resize and rescale every training image; labels pass through untouched.
train_ds = train_data.map(
    lambda img, lbl: preprocess_images(img, lbl, height=height, width=width)
)

In [5]:
# Apply the same resize/rescale step to the held-out images.
test_ds = test_data.map(
    lambda img, lbl: preprocess_images(img, lbl, height=height, width=width)
)

In [8]:
def data_loader(data):
  """Materialise a dataset of (image, label) pairs into two parallel Python lists.

  Args:
    data: object exposing `as_numpy_iterator()` that yields 2-element items.

  Returns:
    (images, labels): two lists in iteration order.
  """
  # tqdm only wraps the iterator to show progress while everything is pulled into memory.
  pairs = list(tqdm(data.as_numpy_iterator()))
  images = [img for img, _ in pairs]
  labels = [label for _, label in pairs]
  return images, labels

In [9]:
# Pull the preprocessed training set into memory as two parallel lists (images, labels).
x_train, y_train = data_loader(train_ds)

1020it [00:01, 561.39it/s]


In [10]:
# Pull the preprocessed held-out set into memory as two parallel lists (images, labels).
x_test, y_test = data_loader(test_ds)

1020it [00:01, 560.34it/s]


In [11]:
# Base Model
MODEL_URL = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/feature_vector/2"

# Pipeline: augmentation -> TF-Hub backbone -> dropout -> 128-d embedding
# -> L2 normalisation -> 102-way softmax classifier.
input_layer = Input(shape=(height, width, 3))
flipped = RandomFlip()(input_layer)
rotated = RandomRotation(0.3)(flipped)
backbone_features = hub.KerasLayer(MODEL_URL, trainable=True)(rotated)
regularized = Dropout(0.25)(backbone_features)
embedding = Dense(128, activation=None)(regularized)
# L2 normalize embeddings
normalized_embedding = Lambda(lambda t: tf.math.l2_normalize(t, axis=1))(embedding)
output_layer = Dense(102, activation='softmax')(normalized_embedding)

model = Model(inputs=input_layer, outputs=output_layer)

model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 random_flip (RandomFlip)    (None, 128, 128, 3)       0         
                                                                 
 random_rotation (RandomRota  (None, 128, 128, 3)      0         
 tion)                                                           
                                                                 
 keras_layer (KerasLayer)    (None, 1280)              20331360  
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 128)               163968    
                                                             

In [10]:
def batch_me(images, labels, batch_size, samples_per_class):
  """Endlessly yield class-balanced batches.

  Images are grouped by label. Each batch is filled by visiting the classes in a
  freshly shuffled order and drawing up to `samples_per_class` random images from
  each visited class until `batch_size` items have been collected.

  Args:
    images: iterable of images.
    labels: iterable of hashable labels, parallel to `images`.
    batch_size: number of items per yielded batch.
    samples_per_class: maximum number of images drawn from a class per visit.

  Yields:
    (batch_x, batch_y): numpy arrays of images and their labels.

  Raises:
    ValueError: if there are no (image, label) pairs or `samples_per_class` < 1
      (either would make the fill loop spin forever).
  """
  examples_by_class = defaultdict(list)  # label -> list of images with that label
  for img, label in zip(images, labels):
    examples_by_class[label].append(img)

  if not examples_by_class:
    raise ValueError("batch_me requires at least one (image, label) pair")
  if samples_per_class < 1:
    raise ValueError("samples_per_class must be at least 1")

  categories = list(examples_by_class)
  while True:
    batch_x = []
    batch_y = []
    while len(batch_x) < batch_size:
      # Visit the classes in a new random order on every pass. Iterating the dict
      # in its fixed insertion order would fill every batch from the same first
      # few classes and never show the remaining ones to the model.
      for category in random.sample(categories, k=len(categories)):
        remaining = batch_size - len(batch_x)
        if remaining == 0:
          break
        examples = examples_by_class[category]
        # Never ask for more images than the class holds or the batch has room for.
        n_samples = min(samples_per_class, remaining, len(examples))
        # Pick randomly among the images of this category.
        samples = random.sample(examples, k=n_samples)
        batch_x.extend(samples)
        batch_y.extend([category] * len(samples))

    # The generator never terminates; callers decide how many batches to draw.
    yield np.array(batch_x), np.array(batch_y)


In [11]:
def center_loss(feature_vector, center):
    """Mean over rows of the squared distance between features and centers.

    Args:
        feature_vector: feature values; reduced along axis 1.
        center: values subtracted from `feature_vector` (must broadcast against it).

    Returns:
        Scalar tensor: sum of squared differences along axis 1, averaged over the rest.
    """
    squared_distance = tf.reduce_sum(tf.square(feature_vector - center), axis=1)
    return tf.reduce_mean(squared_distance)

# Feature Extraction

In [12]:
# Sub-model sharing the classifier's input but stopping at its second-to-last layer,
# i.e. it outputs whatever feeds the final Dense(102) classification head.
feature_extraction_model = Model(inputs=model.input, outputs=model.layers[-2].output)

In [None]:
# Embed the whole training set in one batched predict() call. Calling predict()
# once per image (and twice over the dataset) is far slower for the same result.
train_label_array = np.asarray(y_train)
train_embeddings = feature_extraction_model.predict(np.stack(x_train), batch_size=32, verbose=0)

# Per-class mean embedding, one vector per label. The mean is taken over the
# samples of a class (axis 0) so the embedding dimensions are preserved; reducing
# over every element would collapse each embedding to a single scalar.
class_feature_vectors = {
    int(label): train_embeddings[train_label_array == label].mean(axis=0)
    for label in np.unique(train_label_array)
}

raw_features = dict(class_feature_vectors)
all_features = {key: value for key, value in raw_features.items()}

# Clean up memory
del class_feature_vectors, train_embeddings


# Training Loop

In [None]:
# Number of passes of the outer training loop.
epochs = 30
# Weight applied to the center loss before it is added to the classification loss.
alpha = 0.5
# Number of samples requested from the batch generator per step.
batch_size = 32
# Passed to batch_me as `samples_per_class`.
n_examples_per_class = 4
# Weight given to the freshly computed class center in the per-epoch center update
# (the previous center is weighted by 1 - EMA_lr).
EMA_lr = 0.9
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
# Batches drawn from the generator per epoch.
num_steps_per_epoch = len(x_train) // batch_size

# Calculate initial centers
# Random values act as a fallback for any class without a precomputed feature vector.
centers = tf.Variable(initial_value=tf.random.normal((102, 128), mean=0.0, stddev=0.5))

# Seed each class center from its precomputed feature vector.
# `all_features` is a dict keyed by class label, so it must be indexed with the label
# itself. Indexing it with `y_train == index` compares a list to an int, which is
# simply False, and False resolves to key 0 — every class would get class 0's vector.
for index in range(102):
    if index in all_features:
        category_center = np.asarray(all_features[index], dtype=np.float32)
        if category_center.size > 0:
            centers[index].assign(category_center)

# Training
num_classes = 102

# A single forward pass has to produce both the embedding (for the center loss) and
# the class probabilities (for cross-entropy). Calling the classifier and the feature
# extractor separately would run the backbone twice and, in training mode, apply
# different random augmentation/dropout to each call.
joint_model = Model(inputs=model.input, outputs=[model.layers[-2].output, model.output])
cross_entropy = tf.keras.losses.CategoricalCrossentropy()

# batch_me returns an endless generator, so it is built once rather than per epoch.
batch_generator = batch_me(images=x_train, labels=y_train, batch_size=batch_size, samples_per_class=n_examples_per_class)
val_batch_generator = batch_me(images=x_test, labels=y_test, batch_size=batch_size, samples_per_class=n_examples_per_class)
num_val_steps = max(1, len(x_test) // batch_size)

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    num_batches = 0
    train_preds = []
    train_labels = []

    # Running per-class sums of the embeddings seen this epoch; used after the
    # epoch to move each class center towards the mean of its current embeddings.
    class_feature_sums = np.zeros(centers.shape.as_list(), dtype=np.float32)
    class_counts = np.zeros(num_classes, dtype=np.int64)

    for batch_idx in range(num_steps_per_epoch):
        if batch_idx == 0:
            print("\n")
        # Get the next batch
        batch_x, batch_y = next(batch_generator)
        batch_onehot = tf.one_hot(batch_y, num_classes)

        # Capture Gradients
        with tf.GradientTape() as tape:
            # training=True so the augmentation and dropout layers are actually active.
            features, predictions = joint_model(batch_x, training=True)

            # Center loss: distance of every embedding to the center of its own class.
            batch_centers = tf.gather(centers, batch_y)
            c_loss = center_loss(features, batch_centers)
            # Combine it with the categorical cross-entropy loss
            cls_loss = cross_entropy(batch_onehot, predictions)
            # Total Loss
            loss = (c_loss * alpha) + cls_loss

        # Calculate Gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Book-keeping happens outside the tape and on numpy copies so that no
        # graph tensors are kept alive for the whole epoch.
        train_preds.extend(predictions.numpy())
        train_labels.extend(batch_onehot.numpy())
        np.add.at(class_feature_sums, batch_y, features.numpy())
        np.add.at(class_counts, batch_y, 1)

        # Update training loss
        total_loss += float(loss)
        num_batches += 1
        print(f"Epoch: {epoch} -step: {num_batches} -running loss: {loss:.3f}")

    # Calculate training Loss (guard against an epoch with zero steps)
    training_loss = total_loss / max(num_batches, 1)

    # Validation loop: the loss is accumulated over every validation batch and
    # averaged over the number of validation steps.
    val_preds = []
    val_labels = []
    total_val_loss = 0.0
    for step in range(num_val_steps):
        val_batch_x, val_batch_y = next(val_batch_generator)
        # Make predictions in inference mode (no augmentation, no dropout)
        val_features, val_predictions = joint_model(val_batch_x, training=False)
        val_preds.extend(val_predictions.numpy())
        val_labels.extend(val_batch_y)

        # Loss Calculation
        val_batch_centers = tf.gather(centers, val_batch_y)
        val_c_loss = center_loss(val_features, val_batch_centers)
        val_cls_loss = cross_entropy(tf.one_hot(val_batch_y, num_classes), val_predictions)
        total_val_loss += (val_c_loss * alpha) + val_cls_loss
    val_loss = total_val_loss / num_val_steps

    # AUC over all validation predictions collected this epoch
    val_auc = tf.keras.metrics.AUC()(tf.one_hot(val_labels, num_classes), np.stack(val_preds)).numpy()

    print(f"\nEpoch {epoch + 1}/{epochs} - Training Loss: {training_loss:.3f} - Validation Loss: {val_loss.numpy():.3f}")
    print(f"Validation AUC: {val_auc:.3f}")

    # Centers update: exponential moving average towards the mean embedding each
    # class produced during this epoch. Classes that were not sampled keep their center.
    for index in range(num_classes):
        if class_counts[index] > 0:
            category_center = class_feature_sums[index] / class_counts[index]
            centers[index].assign((1.0 - EMA_lr) * tf.cast(centers[index], tf.float32) + (EMA_lr * tf.cast(category_center, tf.float32)))
    print(f"Centers updated - Step : {epoch}")
    print("="*100)
