# Install Dependencies

In [None]:
!pip install -q medmnist faiss-cpu torchvision
## if GPU FAISS is available
#!pip install faiss-gpu


# Extract Embeddings

In [None]:
import os
import torch
import numpy as np
import faiss
from torchvision import models, transforms
from medmnist import PneumoniaMNIST

OUTPUT_DIR = "/content/task3_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Load test dataset. No transform is attached here because the visual demo
# later displays the raw images from this same dataset object.
test_dataset = PneumoniaMNIST(split='test', download=True)

# Pretrained ResNet18 backbone used as a frozen feature extractor: replacing
# the classification head with Identity makes the forward pass return the
# features that would otherwise have been fed to the classifier.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` loads.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Identity()
model = model.to(DEVICE)
model.eval()

# Adapt grayscale images to the ImageNet-trained backbone: replicate the single
# channel to 3 channels, upsample to 224x224, normalize with ImageNet stats.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

BATCH_SIZE = 64  # images per forward pass; lower it if device memory is tight

embeddings = []
labels = []

print("Extracting embeddings...")

# Batched inference. In eval() mode BatchNorm uses its running statistics, so
# an image's embedding does not depend on which other images share its batch.
with torch.no_grad():
    for start in range(0, len(test_dataset), BATCH_SIZE):
        batch = []
        for j in range(start, min(start + BATCH_SIZE, len(test_dataset))):
            img, label = test_dataset[j]
            batch.append(transform(img))
            labels.append(label.item())
        emb = model(torch.stack(batch).to(DEVICE))
        embeddings.append(emb.cpu().numpy())

# One row per image, stacked in dataset order; FAISS requires float32.
embeddings = np.concatenate(embeddings).astype("float32")

# L2-normalize in place so that inner product equals cosine similarity.
faiss.normalize_L2(embeddings)

np.save(os.path.join(OUTPUT_DIR, "embeddings.npy"), embeddings)
np.save(os.path.join(OUTPUT_DIR, "labels.npy"), np.array(labels))

print("Embeddings saved.")


# Build FAISS Index

In [None]:
# Reload the L2-normalized embeddings from disk; FAISS only accepts float32.
embeddings = np.load(
    os.path.join(OUTPUT_DIR, "embeddings.npy")
).astype(np.float32)
dim = embeddings.shape[1]  # embedding width = index dimensionality

# IndexFlatIP performs exhaustive inner-product search; because the stored
# vectors were normalized before saving, inner product equals cosine similarity.
index = faiss.IndexFlatIP(dim)
index.add(embeddings)

# Persist the index so the evaluation cell can reload it independently.
faiss.write_index(
    index,
    os.path.join(OUTPUT_DIR, "faiss_index.bin"),
)

print("Index built successfully.")


# Precision@K and Recall@K (each test embedding queried against the index, excluding itself)

In [None]:
# Reload the artifacts saved by the extraction and indexing cells so this
# cell can run on its own.
embeddings, labels = (
    np.load(os.path.join(OUTPUT_DIR, fname))
    for fname in ("embeddings.npy", "labels.npy")
)
index = faiss.read_index(os.path.join(OUTPUT_DIR, "faiss_index.bin"))

def compute_metrics_at_k(k, *, embs=None, labs=None, search_index=None):
    """Return mean Precision@k and Recall@k of leave-one-out retrieval.

    Every stored embedding is used in turn as a query against the index. The
    query itself is excluded from its own results, and a retrieved item counts
    as a hit when it has the same label as the query.

    Args:
        k: Number of neighbours retrieved per query (must be >= 1).
        embs: Query embeddings, one row per indexed item, where row ``i`` is
            index id ``i``. Defaults to the module-level ``embeddings``.
        labs: Labels aligned with ``embs``. Defaults to module-level ``labels``.
        search_index: Object with a FAISS-style ``search(x, k)`` method that
            returns ``(distances, ids)``. Defaults to module-level ``index``.

    Returns:
        ``(mean_precision, mean_recall)`` as floats. Precision@k is
        ``hits / k``, averaged over all queries. Recall@k is
        ``hits / (number of other items with the query's label)``, averaged
        over queries that have at least one such item. Either value is 0.0
        when there is nothing to average.

    Raises:
        ValueError: If ``k`` < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if embs is None:
        embs = embeddings
    if labs is None:
        labs = labels
    if search_index is None:
        search_index = index

    embs = np.ascontiguousarray(embs, dtype="float32")  # FAISS needs float32
    labs = np.asarray(labs).ravel()
    n = len(embs)
    if n == 0:
        return 0.0, 0.0

    # Relevant items per query = items sharing its label minus the query itself,
    # which is excluded from retrieval and therefore can never be a hit.
    # Computed once instead of rescanning all labels for every query.
    _, inverse, counts = np.unique(labs, return_inverse=True, return_counts=True)
    relevant = counts[inverse] - 1

    # One batched search; request k + 1 so k results remain after removing self.
    _, ids = search_index.search(embs, k + 1)

    total_precision = 0.0
    total_recall = 0.0
    recall_queries = 0
    for i, row in enumerate(ids):
        # Remove self by id, not by position: with duplicate or tied embeddings
        # the query is not guaranteed to be ranked first. Also drop FAISS's -1
        # padding, which would otherwise index labs[-1].
        retrieved = [j for j in row if j != i and j >= 0][:k]
        hits = sum(labs[j] == labs[i] for j in retrieved)
        total_precision += hits / k
        if relevant[i] > 0:
            total_recall += hits / relevant[i]
            recall_queries += 1

    mean_precision = total_precision / n
    mean_recall = total_recall / recall_queries if recall_queries else 0.0
    return float(mean_precision), float(mean_recall)


# Report retrieval quality at a few cut-offs.
for k in (1, 5, 10):
    precision_at_k, recall_at_k = compute_metrics_at_k(k)
    print(
        f"K={k}\n"
        f"  Precision@{k}: {precision_at_k:.4f}\n"
        f"  Recall@{k}: {recall_at_k:.4f}"
    )


# Visual Retrieval Demo (query index and k are hard-coded; no argparse)

In [None]:
import matplotlib.pyplot as plt

# Show the top-k neighbours of one hard-coded test image next to the query.
query_index = 10
k = 5

query_emb = embeddings[query_index].reshape(1, -1)
# Ask for k + 1 results so that k remain after the query itself is removed.
distances, indices = index.search(query_emb, k + 1)
# Remove the query by id rather than assuming it is ranked first (duplicate or
# tied embeddings can push it down), and drop FAISS's -1 padding, if any.
retrieved_indices = [
    int(idx) for idx in indices[0] if idx != query_index and idx >= 0
][:k]

query_img, query_label = test_dataset[query_index]

plt.figure(figsize=(15, 3))
plt.subplot(1, k + 1, 1)
plt.imshow(query_img, cmap='gray')
# Labels are converted with .item() (as in the extraction cell) so the title
# shows a scalar; presumably they are 1-element arrays, which would otherwise
# print as "[1]".
plt.title(f"Query (Label {query_label.item()})")
plt.axis('off')

for i, idx in enumerate(retrieved_indices):
    img, label = test_dataset[idx]
    plt.subplot(1, k + 1, i + 2)
    plt.imshow(img, cmap='gray')
    plt.title(f"Rank {i+1} (L {label.item()})")
    plt.axis('off')

plt.tight_layout()
plt.show()
