In [11]:
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import faiss
from PIL import Image
import torch
from torchvision import transforms


# --- Configuration -----------------------------------------------------------
# Root directory of the filtered image set (the filtered.csv manifest is read
# from the same location).
IMG_DIR = "../../data/filtered/"
# Side length, in pixels, of the square centre crop fed to the extractor.
IMG_SIZE = 224
# Dimensionality the FAISS index is created with; must match the extractor output.
EMB_SIZE = 384
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is loaded.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Serialized TorchScript feature extractor.
CHECKPOINT_PATH = "../checkpoints/extractor.torchscript"
# Path prefix for the outputs: ".index" (FAISS index) and ".npz" (labels and
# image paths) are appended when dumping.
DUMP_PATH = "../checkpoints/landmarks_filtered_db"

In [2]:
# Load the image manifest (columns: landmark_id, id). The path is derived from
# IMG_DIR rather than repeated as a literal so the manifest and the image root
# cannot drift apart when the data location changes.
image_df = pd.read_csv(os.path.join(IMG_DIR, "filtered.csv"))
print(image_df.count())
image_df.head()

landmark_id    556884
id             556884
dtype: int64


Unnamed: 0,landmark_id,id
0,126637,61a922bc87eade27
1,126637,deceb9e5b5cb1f68
2,126637,232c227f4002b3f5
3,126637,1b7f6ab4d250c671
4,126637,be277cf05cb58cc0


In [3]:
def img_path_from_id(id, data_dir):
    """Return the on-disk path of the JPEG for image `id`.

    Images are sharded three directory levels deep by the first three
    characters of their id: ``<data_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.jpg``.

    Args:
        id: Image identifier; must have at least three characters.
        data_dir: Root directory of the image set.

    Returns:
        The joined path as a string.

    Raises:
        IndexError: If `id` has fewer than three characters.
    """
    shard_dirs = (id[0], id[1], id[2])
    file_name = f"{id}.jpg"
    return os.path.join(data_dir, *shard_dirs, file_name)

# Pair every landmark label with the resolved path of its image, keeping the
# row order of the manifest.
images = [
    (landmark, img_path_from_id(image_id, IMG_DIR))
    for landmark, image_id in zip(image_df["landmark_id"].values, image_df["id"].values)
]

In [4]:
# Split the (label, path) pairs into two parallel lists that share the same order.
labels = [landmark for landmark, _ in images]
img_names = [path for _, path in images]

In [None]:
# Load the serialized TorchScript extractor onto DEVICE and switch it to
# evaluation mode; eval() returns the module itself, so the calls are chained.
embedding_model = torch.jit.load(CHECKPOINT_PATH, map_location=DEVICE).eval()
# Leave the module as the last expression so the cell still displays it.
embedding_model

In [7]:
def get_image(path):
    """Read the image at `path` and return it as an RGB PIL image.

    The file is opened in a context manager so the underlying handle is
    released deterministically, even when decoding fails part-way (for
    example on a truncated file). `convert` returns a new, fully loaded
    image, so the result stays usable after the source file is closed.

    Args:
        path: Filesystem path of the image.

    Returns:
        A `PIL.Image.Image` in "RGB" mode.
    """
    with Image.open(path) as source:
        return source.convert("RGB")


def transform_image(image):
    """Turn a PIL image into a normalized tensor ready for the extractor.

    Steps: resize to (IMG_SIZE + 32) x (IMG_SIZE + 32), centre-crop to
    IMG_SIZE x IMG_SIZE, convert to a tensor, then normalize each channel.

    Args:
        image: The PIL image to preprocess.

    Returns:
        The transformed image tensor.
    """
    resized_side = IMG_SIZE + 32
    # Per-channel statistics; these look like the usual ImageNet mean/std —
    # TODO confirm they match what the extractor was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((resized_side, resized_side)),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image)


def get_image_embedding(model, image):
    """Run `model` on a single preprocessed image and return its output.

    The image is moved to DEVICE and given a leading batch axis of length 1
    before the forward pass. The pass runs under `torch.no_grad()`: this is
    inference only, so recording an autograd graph would just waste memory
    and time. The returned tensor therefore does not require grad.

    Args:
        model: Callable module mapping a batched image tensor to embeddings.
        image: Preprocessed image tensor without a batch axis.

    Returns:
        The model output for the batch of one, still on DEVICE.
    """
    batch = image.to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        return model(batch)


In [8]:
# Flat (exhaustive-search) FAISS index over EMB_SIZE-dimensional vectors,
# scored by inner product.
# NOTE(review): inner product only equals cosine similarity when the vectors
# are L2-normalized — confirm the extractor normalizes its output.
faiss_index = faiss.IndexFlatIP(EMB_SIZE)

In [12]:
# Embed every image one at a time, in the order of `img_names`, so that entry i
# of `all_embeddings` belongs to img_names[i] / labels[i].
all_embeddings = []
# Inference only: under no_grad() no autograd graph is recorded, so GPU memory
# does not build up between iterations. That makes a per-iteration
# torch.cuda.empty_cache() unnecessary; it was dropped because releasing the
# allocator cache on every step only slows the loop down.
with torch.no_grad():
    for im_name in tqdm(img_names):
        image = transform_image(get_image(im_name))
        image_emb = get_image_embedding(embedding_model, image)
        # Store a host-side NumPy copy; each entry keeps its leading batch
        # axis of length 1.
        all_embeddings.append(image_emb.detach().cpu().numpy())

100%|██████████| 556884/556884 [3:37:36<00:00, 42.65it/s]  


In [16]:
# Drop the leading batch axis (length 1) from every stored embedding.
all_embeddings_squeeze = [np.squeeze(emb, axis=0) for emb in all_embeddings]

In [17]:
# Assemble the per-image vectors into one 2-D float32 matrix (one row per
# image), which is what gets added to the FAISS index.
embeddings = np.asarray(all_embeddings_squeeze, dtype=np.float32)

In [18]:
# Sanity check instead of eyeballing the output: there must be exactly one
# EMB_SIZE-dimensional row per input image before anything is indexed.
assert embeddings.shape == (len(img_names), EMB_SIZE), embeddings.shape
embeddings.shape

(556884, 384)

In [19]:
# Empty the index first so this cell is idempotent: without the reset,
# re-running it would append the same vectors again and silently misalign
# index ids with `labels` / `img_names`.
faiss_index.reset()
faiss_index.add(embeddings)
assert faiss_index.ntotal == len(embeddings), faiss_index.ntotal

In [20]:
# Persist the search index and, next to it, the metadata needed to map an
# index row back to its landmark label and image path.
index_path = DUMP_PATH + ".index"
metadata_path = DUMP_PATH + ".npz"
faiss.write_index(faiss_index, index_path)
np.savez_compressed(metadata_path, labels=labels, img_names=img_names)