In [3]:
import pickle
import faiss
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torch.nn as nn

In [4]:
# Run the model on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [5]:
class HashingMLP(nn.Module):
    """MLP mapping an ``input_dim`` feature vector to ``output_dim`` values in [-1, 1].

    Three hidden stages of Linear -> BatchNorm1d -> ReLU (128, 64 and 32 units)
    feed a final Linear layer squashed by Tanh. Elsewhere in this notebook the
    sign pattern of the output (``torch.sign``) is used as a hash-bucket key.

    All layers live in ``self.model`` (an ``nn.Sequential``), so parameter
    names are ``model.<i>.weight`` / ``model.<i>.bias``, matching saved checkpoints.
    """

    def __init__(self, input_dim=128, output_dim=4):
        super().__init__()
        layers = []
        fan_in = input_dim
        for width in (128, 64, 32):
            layers += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU()]
            fan_in = width
        layers += [nn.Linear(fan_in, output_dim), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; BatchNorm1d requires 2-D ``(N, C)`` (or 3-D) input."""
        return self.model(x)

In [6]:
# Load artifacts produced by earlier steps of the pipeline.
# NOTE(review): pickle.load / torch.load(weights_only=False) can execute arbitrary
# code embedded in the file — only load artifacts you produced yourself.
with open("database/faiss.pkl", "rb") as f:
    faiss_indices = pickle.load(f)  # bucket code (tuple) -> faiss index

with open("data/train_dic.pkl", "rb") as f:
    train_dic = pickle.load(f)  # bucket code (tuple) -> vectors stored in that bucket

# Keep feature/label tensors on the CPU: later cells call .numpy() on them and
# move only the mini-batches they feed to the model onto `device`.
train_features = torch.load("data/train_features.pt", map_location="cpu")
test_features = torch.load("data/test_features.pt", map_location="cpu")
train_labels = torch.load("data/train_labels.pt", map_location="cpu")

# The checkpoint is presumably a whole pickled nn.Module (it is called like one
# later), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses such files, so opt out explicitly.
mlp = torch.load("models/improved_hash.pt", map_location=device, weights_only=False)
# eval(): BatchNorm must use its running statistics at inference time; in train
# mode a single-sample batch raises and results depend on batch composition.
mlp = mlp.to(device).eval()

# Without a transform, indexing these datasets yields (PIL image, int label) pairs.
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)


In [None]:
def search_similar_images(query_idx=0, k=5):
    """Plot test image ``query_idx`` next to its ``k`` nearest training images.

    The query's feature vector is hashed by ``mlp`` into a bucket (the sign
    pattern of the network output), and only that bucket's faiss index is
    searched. Each retrieved vector is mapped back to its row in
    ``train_features`` by exact equality so that the matching MNIST training
    image and its label can be shown.

    Args:
        query_idx: row of ``test_features`` / ``test_dataset`` used as the query.
        k: number of neighbours to retrieve and plot.
    """
    mlp.eval()  # BatchNorm must use running stats; train mode rejects a batch of one
    query_feature = test_features[query_idx].to(device)

    with torch.no_grad():
        # BatchNorm1d rejects 1-D input, so hash a batch of one and take its only row.
        code = torch.sign(mlp(query_feature.unsqueeze(0)))[0]
    predicted_bucket = tuple(code.long().cpu().numpy().tolist())
    print(predicted_bucket)

    if predicted_bucket not in faiss_indices:
        print(f"No training vectors were hashed to bucket {predicted_bucket}; nothing to show.")
        return

    query_np = query_feature.detach().cpu().float().numpy().reshape(1, -1)
    _, neighbour_ids = faiss_indices[predicted_bucket].search(query_np, k)
    neighbour_ids = neighbour_ids[0]
    neighbour_ids = neighbour_ids[neighbour_ids >= 0]  # faiss pads with -1 when the bucket has < k vectors

    # NOTE(review): assumes train_dic[bucket] lists vectors in the order they were added to the index.
    bucket_vectors = np.array(train_dic[predicted_bucket])
    similar_vectors = torch.as_tensor(bucket_vectors[neighbour_ids])

    all_train = train_features.cpu()  # hoisted: loop-invariant
    match_rows = []
    for vector in similar_vectors:
        hits = torch.nonzero(torch.all(all_train == vector.to(all_train.dtype), dim=1), as_tuple=True)[0]
        if hits.numel() > 0:
            match_rows.append(int(hits[0]))  # duplicates in train_features: keep the first row

    # squeeze=False keeps `axes` indexable even when k == 0.
    fig, axes = plt.subplots(1, k + 1, figsize=(12, 3), squeeze=False)
    axes = axes[0]

    query_image, query_label = test_dataset[query_idx]
    axes[0].imshow(np.asarray(query_image).squeeze(), cmap="gray")
    axes[0].set_title("Query Image")
    axes[0].set_xlabel(f"label: {query_label}")

    for panel, row in enumerate(match_rows, start=1):
        image, label = train_dataset[row]
        axes[panel].imshow(np.asarray(image).squeeze(), cmap="gray")
        axes[panel].set_title(f"Match {panel}")
        axes[panel].set_xlabel(f"label: {label}")

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    for ax in axes[len(match_rows) + 1:]:
        ax.axis("off")  # fewer matches than k: hide the empty panels

    plt.show()

search_similar_images(0)

In [None]:
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import torch
import faiss
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter

def _flat_index_vectors(index):
    """Return the vectors held by a flat faiss index as an ``(ntotal, d)`` float32 array.

    Row ``i`` is the vector faiss reports as result id ``i``. The array is a view
    into the index's own memory (no copy). NOTE(review): ``get_xb()`` exists only
    on flat (``IndexFlat*``) indices — confirm every bucket index is flat.
    """
    return faiss.rev_swig_ptr(index.get_xb(), index.ntotal * index.d).reshape(index.ntotal, index.d)


def _predict_bucket_codes(features, batch_size):
    """Hash every row of ``features`` with the global ``mlp``; return an ``(N, n_bits)`` int64 array.

    Rows are pushed through the model in mini-batches of ``batch_size`` on
    ``device``. ``torch.sign`` maps each output to -1, 0 or +1; a row converted
    with ``tuple(row.tolist())`` is the key used to look up ``faiss_indices``.
    """
    mlp.eval()  # BatchNorm must use running statistics, not per-batch ones
    codes = []
    with torch.no_grad():
        for start in tqdm(range(0, len(features), batch_size), desc="hashing queries"):
            batch = features[start:start + batch_size].to(device)
            codes.append(torch.sign(mlp(batch)).long().cpu())
    return torch.cat(codes).numpy()


def compute_precision_recall_batch(k=5, batch_size=100):
    """Evaluate hash-bucket + faiss retrieval as a k-NN digit classifier.

    Each test feature is hashed to a bucket by ``mlp``. Its ``k`` nearest training
    vectors in that bucket's faiss index then vote on the predicted digit: the
    majority label wins, and ties go to the label that appears first among the
    distance-ordered neighbours. Per-digit precision, recall and F1 are printed
    and plotted.

    Queries whose bucket has no faiss index, or that retrieve no neighbours, get
    no prediction (-1). Each one counts as a miss toward its true digit's recall
    and is excluded from every digit's precision.

    Args:
        k: neighbours retrieved per query.
        batch_size: mini-batch size used when hashing queries with ``mlp``.

    Returns:
        ``(precision, recall, f1)`` arrays of length 10, indexed by digit.

    Raises:
        ValueError: if the number of test labels differs from the number of test features.
    """
    vectors_by_bucket = {bucket: _flat_index_vectors(index) for bucket, index in faiss_indices.items()}

    # Map each training vector back to its row so retrieved vectors can be labelled.
    # Cast to float32 so the keys match the float32 vectors stored inside faiss.
    train_np = train_features.cpu().numpy().astype(np.float32)
    train_row_of = {tuple(vec): row for row, vec in enumerate(train_np)}
    train_labels_np = torch.as_tensor(train_labels).cpu().numpy()

    # NOTE(review): assumes test_features rows follow test_dataset order
    # (search_similar_images relies on the same alignment).
    y_true = torch.as_tensor(test_dataset.targets).cpu().numpy().astype(np.int64)
    if len(y_true) != len(test_features):
        raise ValueError(f"{len(test_features)} test features but {len(y_true)} test labels")
    y_pred = np.full(len(y_true), -1, dtype=np.int64)  # -1 = no prediction

    query_np = test_features.cpu().numpy().astype(np.float32)
    codes = _predict_bucket_codes(test_features, batch_size)

    # Group queries by bucket so each faiss index is searched once, with all its queries.
    rows_by_bucket = {}
    for row, code in enumerate(codes):
        rows_by_bucket.setdefault(tuple(code.tolist()), []).append(row)

    for bucket_key, rows in tqdm(rows_by_bucket.items(), desc="searching buckets"):
        if bucket_key not in faiss_indices:
            continue  # nothing was indexed under this code; these queries stay unpredicted
        _, neighbour_ids = faiss_indices[bucket_key].search(query_np[rows], k)
        bucket_vectors = vectors_by_bucket[bucket_key]
        for row, ids in zip(rows, neighbour_ids):
            ids = ids[ids >= 0]  # faiss pads with -1 when the bucket holds fewer than k vectors
            if ids.size == 0:
                continue
            train_rows = [train_row_of[tuple(vec)] for vec in bucket_vectors[ids]]
            y_pred[row] = Counter(train_labels_np[train_rows].tolist()).most_common(1)[0][0]

    answered = int((y_pred >= 0).sum())
    print(f"{answered}/{len(y_pred)} queries received a prediction")

    # labels=digits keeps -1 out of the class list, so unanswered queries only hurt recall.
    digits = list(range(10))
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=digits, average=None, zero_division=0
    )

    for digit in digits:
        print(f"Digit {digit}: Precision = {precision[digit]:.3f}, Recall = {recall[digit]:.3f}, F1-score = {f1[digit]:.3f}")

    fig, ax = plt.subplots(figsize=(8, 5))
    for scores, name, marker in ((precision, "Precision", "o"), (recall, "Recall", "s"), (f1, "F1-score", "^")):
        ax.plot(digits, scores, label=name, marker=marker, linestyle="dashed")
    ax.set_xticks(digits)
    ax.set_xlabel("Digit Class")
    ax.set_ylabel("Score")
    ax.set_title(f"Precision, Recall & F1-score for Top-{k} Retrieval")
    ax.legend()
    ax.grid()
    plt.show()

    return precision, recall, f1

compute_precision_recall_batch(k=5, batch_size=512);
