In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.neighbors import NearestNeighbors
import os
from PIL import Image


In [2]:
# Convolutional autoencoder used as the feature extractor
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: two stride-2 3x3 convolutions (1 -> 8 -> 16 channels), each
    followed by ReLU, then a flatten.

    Decoder: a linear layer down to 64 units, a linear layer back up to
    ``16 * 7 * 7`` units, an unflatten to ``(16, 7, 7)``, and two stride-2
    transposed convolutions (16 -> 8 -> 1 channels) ending in a sigmoid.

    The hard-coded ``16 * 7 * 7`` is the flattened encoder output for a
    28x28 input. Note that the 64-unit bottleneck lives in the decoder, so
    the ``encoded`` tensor returned by ``forward`` is the flattened
    convolutional feature map, not the 64-unit vector.
    """

    def __init__(self):
        super().__init__()

        # Flattened size of the encoder output (channels * height * width).
        flat_features = 16 * 7 * 7
        bottleneck_features = 64

        encoder_layers = [
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(flat_features, bottleneck_features),
            nn.ReLU(),
            nn.Linear(bottleneck_features, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (16, 7, 7)),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input batch ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

In [3]:
# Load images from the gallery and query sets
class ImageDataset(Dataset):
    """Dataset serving the files of one flat folder as grayscale images.

    Args:
        image_folder: Path of the directory that holds the image files.
        transform: Optional callable applied to each grayscale PIL image
            before it is returned.

    Each item is a tuple ``(image, filename)`` where ``filename`` is the
    entry name inside ``image_folder`` (not the full path).
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        # Keep regular files only: a sub-directory in the listing would make
        # Image.open fail later in __getitem__. Sorting fixes the item order.
        self.image_files = sorted(
            name
            for name in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load item ``idx`` as a grayscale image and return it with its filename."""
        filename = self.image_files[idx]
        img_path = os.path.join(self.image_folder, filename)
        # Image.open is lazy and keeps the file open; convert() reads the
        # pixels into a new image, and the with-block then closes the file
        # instead of leaving the handle to the garbage collector.
        with Image.open(img_path) as source:
            image = source.convert('L')  # Convert to grayscale
        if self.transform is not None:
            image = self.transform(image)
        return image, filename


In [4]:
# Preprocessing pipeline: resize each image to 28x28, then convert it to a tensor
preprocessing_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)


In [5]:
# Build the gallery and query datasets with the shared preprocessing pipeline
gallery_dataset, query_dataset = (
    ImageDataset(folder, transform=transform)
    for folder in ("gallery_set", "query_set")
)

# Both loaders iterate in dataset order (shuffle=False): the gallery in
# batches of 8 images, the query set one image at a time.
gallery_loader = DataLoader(dataset=gallery_dataset, batch_size=8, shuffle=False)
query_loader = DataLoader(dataset=query_dataset, batch_size=1, shuffle=False)

In [6]:
# Initialize model and optimizer
# Seed PyTorch's random number generator before the model is constructed so
# that the initial weights (and therefore the training run) can be reproduced
# on the same hardware and library versions.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)  # Use 4 threads for CPU intra-op parallelism
autoencoder = Autoencoder().to(device)
# Adam with learning rate 1e-3; reconstruction quality is measured with
# mean squared error between the decoded output and the input image.
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
criterion = nn.MSELoss()


In [7]:
# Train Autoencoder
# Each epoch makes one pass over gallery_loader, minimising the reconstruction
# loss between the decoder output and the input batch.
num_epochs = 5
for epoch in range(num_epochs):
    autoencoder.train()
    total_loss = 0.0
    num_samples = 0
    for images, _filenames in gallery_loader:
        images = images.to(device)
        optimizer.zero_grad()
        # Only the reconstruction is needed for the loss; the encoding is ignored.
        _, decoded = autoencoder(images)
        loss = criterion(decoded, images)
        loss.backward()
        optimizer.step()
        # criterion returns the mean over the batch, so weight it by the batch
        # size; otherwise a smaller final batch would count as much as a full one.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise RuntimeError("gallery_loader yielded no images; cannot train the autoencoder")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / num_samples:.4f}")


Epoch [1/5], Loss: 0.0221
Epoch [2/5], Loss: 0.0147
Epoch [3/5], Loss: 0.0133
Epoch [4/5], Loss: 0.0126
Epoch [5/5], Loss: 0.0122


In [8]:
# Encode every gallery image and record the filename behind each feature row
autoencoder.eval()
gallery_image_names = []
feature_batches = []
with torch.no_grad():  # inference only, no gradients needed
    for batch, batch_names in gallery_loader:
        latent, _ = autoencoder(batch.to(device))
        feature_batches.append(latent.cpu().numpy())
        gallery_image_names.extend(batch_names)

# Stack the per-batch arrays into one (num_images, num_features) matrix whose
# row order matches gallery_image_names.
gallery_features = np.vstack(feature_batches)


In [10]:
# Build a nearest-neighbour index over the gallery features for image retrieval.
# NOTE: despite the variable name `lsh` (kept so existing references still
# work), this is NOT Locality Sensitive Hashing: sklearn's NearestNeighbors
# returns exact neighbours. algorithm='auto' lets scikit-learn pick the search
# strategy; a KD-tree is a poor fit for the 784-dimensional encoder output.
# Never ask for more neighbours than there are gallery images.
num_matches = min(5, len(gallery_features))
lsh = NearestNeighbors(n_neighbors=num_matches, algorithm='auto').fit(gallery_features)

# Retrieve similar images for query set
# Put the model in eval mode here rather than relying on an earlier cell, and
# disable autograd because only forward passes are needed.
autoencoder.eval()
with torch.no_grad():
    for images, filename in query_loader:
        images = images.to(device)
        encoded, _ = autoencoder(images)
        encoded_np = encoded.cpu().numpy()
        # indices[0] holds the gallery row numbers of the closest matches for
        # the single query image in this batch.
        distances, indices = lsh.kneighbors(encoded_np)
        print(f"Query: {filename[0]} -> Closest Matches: {[gallery_image_names[i] for i in indices[0]]}")

Query: img_0.png -> Closest Matches: ['img_42764.png', 'img_6310.png', 'img_29736.png', 'img_43763.png', 'img_40121.png']
Query: img_1.png -> Closest Matches: ['img_19437.png', 'img_15874.png', 'img_13334.png', 'img_17929.png', 'img_35351.png']
Query: img_10.png -> Closest Matches: ['img_40119.png', 'img_17368.png', 'img_23650.png', 'img_18958.png', 'img_11651.png']
Query: img_100.png -> Closest Matches: ['img_39396.png', 'img_46007.png', 'img_34770.png', 'img_12692.png', 'img_15121.png']
Query: img_1000.png -> Closest Matches: ['img_34556.png', 'img_5114.png', 'img_25293.png', 'img_37542.png', 'img_46669.png']
Query: img_10000.png -> Closest Matches: ['img_3700.png', 'img_18439.png', 'img_17693.png', 'img_17448.png', 'img_42882.png']
Query: img_10001.png -> Closest Matches: ['img_26281.png', 'img_37679.png', 'img_30255.png', 'img_41169.png', 'img_32485.png']
Query: img_10002.png -> Closest Matches: ['img_45115.png', 'img_12824.png', 'img_34216.png', 'img_20978.png', 'img_20102.png']
Q