In [None]:
# Cell 1: Load and inspect the dataset directory structure
import os
# Root folder of the training images.
base_dir = 'data/train'

# One print call, newline-separated: the heading first, then the raw listing
# returned by os.listdir (entries come back in arbitrary order).
print("Contents of the dataset directory:", os.listdir(base_dir), sep="\n")

In [None]:
# Cell 2: Load and preprocess dataset, split into train/test, create data loaders
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalize each channel with the mean/std used for ResNet18.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset from directory structure (cats/ and dogs/ folders)
full_dataset = datasets.ImageFolder(root=base_dir, transform=transform)

# Split dataset (80% train, 20% held-out test)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Seed the split explicitly: without a generator, random_split draws from the
# global RNG and every kernel restart would produce a different partition.
split_seed = 42
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create data loaders for batch processing.
# Only the training loader shuffles; the test loader keeps a fixed order.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Total images: {len(full_dataset)}")
print(f"Training images: {len(train_dataset)}")
print(f"Test images: {len(test_dataset)}")

In [None]:
# Cell 3: Load pre-trained ResNet18 model as feature extractor
import torch.nn as nn
import torchvision.models as models

# Use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ResNet18 with its ImageNet weights. The `pretrained=` flag is deprecated in
# recent torchvision releases in favour of `weights=`; fall back to the old flag
# only when the weights enum is not available (older torchvision).
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Drop only the last child module (the fully connected classification layer).
# Every earlier child, including the final pooling layer, is kept, so the
# network now returns pooled features instead of class scores.
model = nn.Sequential(*list(model.children())[:-1])
model = model.to(device)
model.eval()  # Set to evaluation mode (no dropout, batch norm uses running stats)

print("Model loaded and ready.")

In [None]:
# Cell 4: Extract and cache embeddings for all images in the dataset
import numpy as np
from tqdm import tqdm
from PIL import Image
import pickle

image_paths = []
embeddings = []

# Inference only: autograd stays switched off for the whole pass over the dataset.
with torch.no_grad():
    for path, _ in tqdm(full_dataset.samples, desc="Extracting embeddings"):
        # Read the file from disk and run it through the shared preprocessing pipeline;
        # unsqueeze(0) adds the batch dimension the model expects.
        image = Image.open(path).convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)

        # squeeze() removes every size-1 dimension (the batch axis among them)
        # before the result is copied to host memory as a NumPy array.
        features = model(input_tensor)
        embeddings.append(features.squeeze().cpu().numpy())
        image_paths.append(path)

# Stack the per-image arrays; row i corresponds to image_paths[i].
embeddings = np.array(embeddings)

# Persist paths and embeddings together so they can be reloaded without re-running the model.
cache = {'paths': image_paths, 'embeddings': embeddings}
with open('image_embeddings.pkl', 'wb') as f:
    pickle.dump(cache, f)

print("Embeddings saved to image_embeddings.pkl")

In [None]:
# Cell 5: Define helper function to preprocess query images
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch for the model.

    The notebook-level ``transform`` pipeline is reused rather than rebuilt
    here, so a query image always receives exactly the same resize and
    normalization as the images whose embeddings were extracted with it.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The transformed image with a leading batch dimension of size 1,
        moved to the notebook-level ``device``.
    """
    # convert() produces a fully loaded copy, so the underlying file handle can
    # be released as soon as the with-block exits.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    return transform(image).unsqueeze(0).to(device)

In [None]:
# Cell 6: Find similar images using cosine similarity
def find_similar_images(query_image_path, model, all_embeddings, all_paths, top_n=5):
    """Find the images most similar to a query image using cosine similarity.

    Args:
        query_image_path: Path of the query image. Its lower-cased form must
            contain "dog" or "cat"; "dog" is checked first.
        model: Callable that maps the preprocessed image batch to its embedding.
        all_embeddings: Embeddings of the indexed images, aligned with ``all_paths``.
        all_paths: File paths of the indexed images.
        top_n: Maximum number of matches to return. Values <= 0 give empty
            results; values larger than the number of candidates return all
            candidates.

    Returns:
        A ``(similar_paths, scores)`` tuple of two lists ordered from most to
        least similar: the matching paths and their cosine similarity scores.

    Raises:
        ValueError: If the query path contains neither "dog" nor "cat", or if
            no indexed path contains the detected class name.
    """
    query_image = preprocess_image(query_image_path)

    # Extract embedding for query image (no gradients needed for inference)
    with torch.no_grad():
        query_embedding = model(query_image).squeeze().cpu().numpy()

    # Detect query class (cat or dog) from the path text.
    # NOTE(review): this is a plain substring test on the whole path, so a
    # directory or file name that happens to contain "dog"/"cat" also matches
    # - confirm the dataset layout makes this unambiguous.
    lowered_query_path = query_image_path.lower()
    if "dog" in lowered_query_path:
        query_class = "dog"
    elif "cat" in lowered_query_path:
        query_class = "cat"
    else:
        raise ValueError("Query image path does not contain 'dog' or 'cat'.")

    # Keep only the indexed images whose path mentions the same class
    filtered_embeddings = []
    filtered_paths = []
    for emb, path in zip(all_embeddings, all_paths):
        if query_class in path.lower():
            filtered_embeddings.append(emb)
            filtered_paths.append(path)

    if not filtered_embeddings:
        raise ValueError(f"No images found in dataset for class '{query_class}'.")

    # Compute cosine similarity between query and all images of same class
    similarities = cosine_similarity([query_embedding], filtered_embeddings)[0]

    # Clamp the requested count before slicing. The previous `[-top_n:]` slice
    # returned *every* image for top_n == 0 (since -0 == 0) and dropped items
    # from the wrong end for negative values.
    result_count = max(0, min(top_n, len(similarities)))
    # Reverse the ascending argsort so the best match comes first, then truncate
    top_indices = np.argsort(similarities)[::-1][:result_count]

    similar_paths = [filtered_paths[i] for i in top_indices]
    scores = [similarities[i] for i in top_indices]

    return similar_paths, scores

In [None]:
# Cell 7: Visualize results - display query image with top similar images
def display_images(query_path, similar_paths, scores):
    """Plot the query image and its matches side by side in a single row.

    Args:
        query_path: Path of the query image, drawn in the first panel.
        similar_paths: Paths of the matched images, drawn in the following panels.
        scores: Similarity scores paired with ``similar_paths``; each one is
            shown with two decimals as the title of its panel.
    """
    panel_count = len(similar_paths) + 1
    fig = plt.figure(figsize=(15, 8))

    # First panel: the query itself
    query_ax = fig.add_subplot(1, panel_count, 1)
    query_ax.imshow(Image.open(query_path))
    query_ax.set_title("Query Image")
    query_ax.axis('off')

    # Remaining panels: one per (path, score) pair; subplot positions are
    # 1-based and slot 1 is already taken, so counting starts at 2.
    for position, (path, score) in enumerate(zip(similar_paths, scores), start=2):
        match_ax = fig.add_subplot(1, panel_count, position)
        match_ax.imshow(Image.open(path))
        match_ax.set_title(f"Similarity: {score:.2f}")
        match_ax.axis('off')

    fig.tight_layout()
    plt.show()


# Read the cached paths and embeddings back from disk.
# Unpickling can execute arbitrary code, so only load a cache file you created yourself.
with open('image_embeddings.pkl', 'rb') as f:
    data = pickle.load(f)
all_paths, all_embeddings = data['paths'], data['embeddings']

# Image to search with - edit this path to try a different query
query_image_path = "dog.jpg"  # Change to your query image
print(f"Using query image: {query_image_path}")

# Look up the closest matches (default top_n) and plot them next to the query
similar_paths, scores = find_similar_images(
    query_image_path=query_image_path,
    model=model,
    all_embeddings=all_embeddings,
    all_paths=all_paths,
)
display_images(query_image_path, similar_paths, scores)