In [None]:
import os
import asyncio
import concurrent.futures
import torch
import cv2
import pickle
from io import BytesIO
from torchvision import models, transforms
from PIL import Image, ImageFile
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import faiss
import pandas as pd
import math
import matplotlib.pyplot as plt
from IPython.display import display
# Let PIL load truncated image files instead of raising an error on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

### Listing objects in the AWS S3 bucket

In [None]:
import boto3
from botocore.exceptions import NoCredentialsError  

# Dataset location and (optional) credentials, used when the images live in S3.
# Leave the keys empty to use boto3's default credential chain (environment
# variables, ~/.aws/credentials, instance role) instead of pasting secrets here.
access_key = ''
secret_key = ''
bucket_name = ''
folder_path = ''

# boto3 uses explicitly passed strings as-is, even empty ones; map "" to None so
# an unset key falls back to the default credential chain.
s3 = boto3.client(
    's3',
    aws_access_key_id=access_key or None,
    aws_secret_access_key=secret_key or None,
)

In [None]:
def list_objects_page(page, extensions=(".jpg",)):
    """Return the S3 keys of image objects listed in one ``list_objects_v2`` page.

    Args:
        page: one response dict yielded by the ``list_objects_v2`` paginator.
            Pages without any objects have no ``'Contents'`` entry.
        extensions: filename suffix (or tuple of suffixes) to keep; matched
            case-insensitively against the end of the key.

    Returns:
        list of full object keys, usable directly as ``Key=`` in ``get_object``.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    # Use the key exactly as S3 reports it. Rebuilding it as
    # folder_path + basename dropped sub-folders below the prefix (and used '\'
    # on Windows), producing keys that do not exist in the bucket.
    return [
        obj["Key"]
        for obj in page.get("Contents", [])
        if obj["Key"].lower().endswith(suffixes)
    ]

# ListObjectsV2 returns at most 1000 keys per call; the paginator follows the
# continuation tokens. Pages are requested lazily as page_iterator is consumed.
paginator = s3.get_paginator("list_objects_v2")
page_iterator = paginator.paginate(Prefix=folder_path, Bucket=bucket_name)

In [None]:
aws_files = []
page_workers = 8
# The paginator fetches pages one after another while it is iterated (each page
# needs the previous page's continuation token), so the pool only parallelises
# the cheap per-page filtering. `executor.map` yields results in page order,
# which keeps `aws_files` identical across runs (`as_completed` did not).
with concurrent.futures.ThreadPoolExecutor(max_workers=page_workers) as executor:
    pages = tqdm(page_iterator, desc="Listing")
    for page_files in executor.map(list_objects_page, pages):
        aws_files.extend(page_files)

In [None]:
# Save the listing so later sessions can reload it without re-listing the bucket.
# Create the output directory first; open() does not create missing parents.
os.makedirs('indexes', exist_ok=True)
with open('indexes/aws_file_list.pkl', 'wb') as f:
    pickle.dump(aws_files, f)

### Generating embeddings with ResNet-18 (ImageNet-1K weights)

In [None]:
# ImageNet-1K pretrained ResNet-18, used as a frozen feature extractor.
weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=weights)
model.eval()
# Replace the classification head with a pass-through, so forward() returns the
# 512-d features that previously fed the fc layer instead of class logits.
model.fc = nn.Identity()

# Preprocessing: resize the shorter side to 256, centre-crop 224x224, convert to
# a [0, 1] float tensor, then standardise with the ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Inference device for the embedding model. Switch to 'cuda' only if every
# caller also moves its inputs to `device` and brings results back with `.cpu()`.
device = 'cpu'
print(device)
# The trailing ';' stops the notebook from echoing the whole ResNet module repr
# that `model.to()` returns.
model.to(device);

In [None]:
def process_batch(batch_paths_images):
    """Embed a batch of decoded images with the global `model`.

    Args:
        batch_paths_images: iterable of ``(path, image)`` pairs. ``image`` is an
            array accepted by ``PIL.Image.fromarray``; in this notebook it is
            the output of ``cv2.imdecode`` (BGR channel order).

    Returns:
        list of ``(path, embedding)`` tuples in input order, one 1-D numpy row
        per image. An empty batch returns an empty list.
    """
    pairs = list(batch_paths_images)
    if not pairs:
        # zip(*[]) cannot be unpacked into two names; there is nothing to embed.
        return []
    paths, images = zip(*pairs)

    # NOTE(review): cv2.imdecode yields BGR, but PIL/torchvision assume RGB, so
    # the model sees red/blue swapped. The query path encodes the same way, so
    # search stays self-consistent; change both together and rebuild the index.
    images_tensor = torch.stack([transform(Image.fromarray(img)) for img in images]).to(device)

    with torch.no_grad():
        # One device->host copy for the whole batch instead of one per row.
        out = model(images_tensor).cpu().numpy()

    return list(zip(paths, out))

In [None]:
batch_size = 10
download_workers = 4
embeddings = []
representations = []  # not referenced by the cells shown here; kept for compatibility

def download_and_process_batch(batch_paths):
    """Download the S3 objects in ``batch_paths``, decode them and embed them.

    Args:
        batch_paths: list of object keys in ``bucket_name``.

    Returns:
        The ``(path, embedding)`` list returned by ``process_batch``.

    Raises:
        ValueError: if an object's bytes cannot be decoded as an image
            (``cv2.imdecode`` returns None). The message names the offending key.
    """
    batch_images = []
    for path in batch_paths:
        response = s3.get_object(Bucket=bucket_name, Key=path)
        raw = response['Body'].read()
        # frombuffer wraps the bytes directly, without the extra bytearray copy.
        img = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # Otherwise the failure surfaces later as an opaque PIL error with no key.
            raise ValueError(f"Could not decode image: {path}")
        batch_images.append((path, img))
    return process_batch(batch_images)



### Parallelising image download and encoding to save GPU runtime

In [None]:
# One task per batch, so each forward pass in process_batch sees up to
# `batch_size` images. Submitting single-path lists made batch_size a no-op.
batch_paths_list = [aws_files[i:i + batch_size] for i in range(0, len(aws_files), batch_size)]
embeddings = []
failed_batches = []

with concurrent.futures.ThreadPoolExecutor(max_workers=download_workers) as download_executor:
    # Submit everything up front. Waiting on each batch before starting the next
    # kept at most one batch's images in flight, so every worker stalled on the
    # slowest image.
    future_to_batch = {
        download_executor.submit(download_and_process_batch, batch_paths): batch_paths
        for batch_paths in batch_paths_list
    }
    for future in tqdm(concurrent.futures.as_completed(future_to_batch),
                       total=len(future_to_batch), desc="Downloading"):
        try:
            embeddings.extend(future.result())
        except Exception as e:
            print(f"An error occurred: {e}")
            failed_batches.append(future_to_batch[future])

# One bad object fails its whole batch. Retry those keys one at a time so only
# the broken images are dropped.
for batch_paths in failed_batches:
    for failed_path in batch_paths:
        try:
            embeddings.extend(download_and_process_batch([failed_path]))
        except Exception as e:
            print(f"An error occurred: {e}")

# Persist the (path, embedding) pairs so the FAISS section can reload them
# without re-downloading. NOTE: the reload cells read from index_server/indexes/;
# align these paths with the notebook's working directory.
os.makedirs('indexes', exist_ok=True)
with open('indexes/aws_representations.pkl', 'wb') as f:
    pickle.dump(embeddings, f)

In [None]:
# Reload the cached S3 key listing. pickle can execute code while loading, so
# only open files this notebook produced.
# NOTE(review): the listing is saved under indexes/ but read from
# index_server/indexes/ here; confirm which working directory the notebook uses.
with open('index_server/indexes/aws_file_list.pkl', mode='rb') as f:
    aws_files = pickle.load(f)

In [None]:
# Reload the cached embeddings. They are unpacked below as (path, embedding)
# pairs, presumably as produced by process_batch (TODO confirm). Only load
# trusted pickle files.
with open('index_server/indexes/aws_representations.pkl', mode='rb') as f:
    aws_rep = pickle.load(f)

In [None]:
# Split the (path, embedding) pairs into parallel sequences: path[i] is the S3
# key whose vector is embeddings[i]. This rebinds `embeddings` (the list built
# during download) to a tuple of arrays.
path, embeddings = zip(*aws_rep)
path = [*path]
emb = [*embeddings]  # superseded by the float32 matrix built in the next cell

In [None]:
# Stack the per-image vectors into a single float32 matrix (one row per image);
# FAISS works on float32.
emb = np.asarray(embeddings, dtype=np.float32)

In [None]:
# Sanity check before indexing: a 2-D float32 matrix with one row per S3 key.
assert emb.ndim == 2 and emb.shape[0] == len(path), (emb.shape, len(path))
assert emb.dtype == np.float32, emb.dtype
emb.shape

### Training the FAISS Index for Similarity Search

In [None]:
# IVF-PQ index. A flat L2 coarse quantiser splits the space into `ncentroids`
# inverted lists, and each vector is stored as `m` product-quantiser codes of
# `nbits` bits each.
dimensions = emb.shape[1]  # 512 for the ResNet-18 features used here
ncentroids = 10
m = 16
nbits = 8
assert dimensions % m == 0, "PQ needs the vector dimension to be divisible by m"
# k-means training needs at least as many points as clusters: ncentroids for the
# coarse quantiser and 2**nbits for each PQ sub-quantiser.
assert emb.shape[0] >= max(ncentroids, 2 ** nbits), "not enough training vectors"
quantiser = faiss.IndexFlatL2(dimensions)
index = faiss.IndexIVFPQ(quantiser, dimensions, ncentroids, m, nbits)
# Unit-normalise in place so L2 distance ranks neighbours like cosine similarity.
faiss.normalize_L2(emb)
index.train(emb)

In [None]:
# Fail loudly if training did not succeed; adding vectors requires a trained IVF index.
assert index.is_trained, "FAISS index is not trained"
index.is_trained

In [None]:
# Checkpoint the trained index before any vectors are added (training only).
faiss.write_index(index, "indexes/trained.index")

In [None]:
# Make this cell idempotent. normalize_L2 is a no-op on already-unit vectors,
# and reset() empties the inverted lists while keeping the trained quantisers,
# so re-running does not append a second copy of every vector.
faiss.normalize_L2(emb)
index.reset()
index.add(emb)
faiss.write_index(index, "indexes/jewel_trained.index")

In [None]:
# Reload the populated index from disk (e.g. in a fresh session, skipping training).
index = faiss.read_index("indexes/jewel_trained.index")

In [None]:
def encode(image):
    """Embed a single PIL image with the global `model`.

    Args:
        image: a PIL image. It is closed before this function returns, on every
            path including exceptions.

    Returns:
        numpy array of shape (1, feature_dim) holding the model output, or None
        when the image does not have exactly three bands.
    """
    try:
        # Check the band count up front. The old post-transform channel check
        # never fired for non-3-band images, because Normalize's 3-channel
        # mean/std raises first, and the image was then never closed.
        if len(image.getbands()) != 3:
            return None
        # Move the input to the model's device and the result back to host memory
        # so this also works when `device` is not 'cpu'.
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            out_tensor = model(input_tensor)
        return out_tensor.cpu().numpy()
    finally:
        image.close()

### Load a target image and test the search results

In [None]:
# Fetch a query image (here simply the first indexed key) and embed it the same
# way the index was built.
response = s3.get_object(Bucket=bucket_name, Key=path[0]) ## change this part to get the target image from the app
img_np_array = np.frombuffer(response['Body'].read(), dtype=np.uint8)
target_img = cv2.imdecode(img_np_array, cv2.IMREAD_COLOR)
if target_img is None:
    raise ValueError(f"Could not decode target image: {path[0]}")
# Encode the raw cv2 (BGR) array exactly as process_batch did for the index.
# Converting to RGB only here would make the query inconsistent with the index.
target_rep = encode(Image.fromarray(target_img))
if target_rep is None:
    raise ValueError("Target image does not have 3 channels")
# cv2 decodes to BGR; swap to RGB for display only.
display(Image.fromarray(cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)).resize((300, 300)))
faiss.normalize_L2(target_rep)

In [None]:
%%time
k = 20
D, I = index.search(target_rep, k)
# FAISS pads the result row with -1 when it finds fewer than k neighbours (e.g.
# when the probed inverted lists hold fewer than k vectors). path[-1] would then
# silently return the last key, so drop the padding.
I = I[0]
I = I[I != -1]
[path[i] for i in I]

In [None]:
# Show each hit: print its S3 key, then a 200x200 thumbnail.
for idx in I:
    if idx < 0:
        continue  # FAISS uses -1 as padding for "no result"
    key = path[idx]
    response = s3.get_object(Bucket=bucket_name, Key=key)
    img = cv2.imdecode(np.frombuffer(response['Body'].read(), dtype=np.uint8), cv2.IMREAD_COLOR)
    print(key)
    if img is None:
        print("  could not decode image")
        continue
    # cv2 decodes to BGR; convert so PIL displays the true colours.
    display(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).resize((200, 200)))