In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, MobileNetV2
from tensorflow.keras.models import Model
from sklearn.metrics.pairwise import cosine_similarity
from tensorflow.keras.layers import Flatten, Dense
import random

cifar100 = tf.keras.datasets.cifar100
(x_train, y_train), (x_test, y_test) = cifar100.load_data()

x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
input_shape = x_train.shape[1:]  # (32, 32, 3)
ground_truth_labels = y_test.flatten()




Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
[1m169001437/169001437[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


In [3]:
def build_resnet_encoder(input_shape):
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    x = Flatten()(base_model.output)
    x = Dense(512, activation='relu')(x)
    encoder = Model(inputs=base_model.input, outputs=x)
    return encoder

def build_mobilenet_encoder(input_shape):
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    x = Flatten()(base_model.output)
    x = Dense(512, activation='relu')(x)
    encoder = Model(inputs=base_model.input, outputs=x)
    return encoder

resnet_encoder = build_resnet_encoder(input_shape)
mobilenet_encoder = build_mobilenet_encoder(input_shape)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [3]:
def evaluate_similarity(encoder, query_image, dataset, ground_truth_labels, query_label, top_k=5):
    query_embedding = encoder.predict(np.expand_dims(query_image, axis=0))
    dataset_embeddings = encoder.predict(dataset)

    query_embedding = query_embedding.reshape(1, -1)
    dataset_embeddings = dataset_embeddings.reshape(dataset_embeddings.shape[0], -1)

    similarities = cosine_similarity(query_embedding, dataset_embeddings).flatten()

    sorted_indices = np.argsort(similarities)[::-1]

    relevant_labels = [1 if ground_truth_labels[i] == query_label else 0 for i in sorted_indices]

    retrieved_labels = relevant_labels[:top_k]

    precision = np.sum(retrieved_labels) / top_k if top_k > 0 else 0
    recall = np.sum(retrieved_labels) / np.sum(np.array(ground_truth_labels) == query_label) if np.sum(np.array(ground_truth_labels) == query_label) > 0 else 0
    retrieval_accuracy = np.mean(retrieved_labels)

    return precision, recall, retrieval_accuracy


In [4]:
query_idx = random.randint(0, len(x_test) - 1)
query_image = x_test[query_idx]
query_label = ground_truth_labels[query_idx]

precision_resnet, recall_resnet, retrieval_accuracy_resnet = evaluate_similarity(
    resnet_encoder, query_image, x_test, ground_truth_labels, query_label, top_k=5
)

precision_mobilenet, recall_mobilenet, retrieval_accuracy_mobilenet = evaluate_similarity(
    mobilenet_encoder, query_image, x_test, ground_truth_labels, query_label, top_k=5
)

print("ResNet Results:")
print(f"Precision: {precision_resnet:.4f}")
print(f"Recall: {recall_resnet:.4f}")
print(f"Retrieval Accuracy: {retrieval_accuracy_resnet:.4f}")

print("\nMobileNet Results:")
print(f"Precision: {precision_mobilenet:.4f}")
print(f"Recall: {recall_mobilenet:.4f}")
print(f"Retrieval Accuracy: {retrieval_accuracy_mobilenet:.4f}")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 6s/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 12ms/step
ResNet Results:
Precision: 0.2000
Recall: 0.0100
Retrieval Accuracy: 0.2000

MobileNet Results:
Precision: 0.2000
Recall: 0.0100
Retrieval Accuracy: 0.2000
