In [1]:
from PIL import Image
import torch
import clip
import numpy as np
import faiss
import os
import pickle
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

def build_embeddings(model, preprocess, data_dir, save_dir): 
    # read all image names
    images = []
    for root, dirs, files in os.walk(data_dir): 
        for file in files: 
                images.append(root  + '/'+ file) 

    # define the emb variable to store embeddings
    embeddings = []

    # extract embeddings and store them in the emb variable
    for i in tqdm(range(len(images))):
        with torch.no_grad():
            image = preprocess(Image.open(images[i])).unsqueeze(0).to(device)
            image_features = model.encode_image(image)
            embeddings.append(image_features.detach().cpu().numpy()) 

    data = {"image_paths": images, "embeddings": embeddings}
    with open(save_dir, 'wb') as f:
        pickle.dump(data, f, protocol=4) 


def build_index(file_dir): 
    with open(file_dir, 'rb') as f:
        data = pickle.load(f) 

    # create Faiss index using FlatL2 
    embeddings = np.array(data["embeddings"], dtype=np.float32) 
    index = faiss.index_factory(embeddings.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT) 
    index.train(embeddings) 
    index.add(embeddings) 

    # store the index locally
    faiss.write_index(index, "vector.index") 

def text_search(model, preprocess, index_dir, model_input, input_type, top_k): 
    index  = faiss.read_index(index_dir) 

    if input_type == "image": 
        image = Image.open(model_input) 
        image = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image.to(device)) 
    else: 
        with torch.no_grad():
            text = clip.tokenize([model_input]).to(device)
            features = model.encode_text(text)

    # preprocess the tensor
    embeddings = features.detach().cpu().numpy().astype(np.float32)
    faiss.normalize_L2(embeddings)

    # search the top k images 
    probs, indices = index.search(embeddings, top_k) 

    return probs, indices 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
from PIL import Image
from PIL import ImageFile
# ImageFile.LOAD_TRUNCATED_IMAGES = True

# class imageDataset(Dataset):
#     def __init__(self, img_dir,  img_ext_list = ['.jpg', '.png', '.jpeg', '.tiff'], preprocess = None):    
#         self.preprocess = preprocess
#         self.img_path_list = []
#         self.walk_dir(img_dir, img_ext_list)
#         print(f'Found {len(self.img_path_list)} images in {img_dir}')

#     def walk_dir(self, dir_path, img_ext_list): # work for symbolic link
#         for root, dirs, files in os.walk(dir_path):
#             self.img_path_list.extend(
#                 os.path.join(root, file) for file in files 
#                 if os.path.splitext(file)[1].lower() in img_ext_list
#             )
            
#             for dir in dirs:
#                 full_dir_path = os.path.join(root, dir)
#                 if os.path.islink(full_dir_path):
#                     self.walk_dir(full_dir_path, img_ext_list)

#     def __len__(self):
#         return len(self.img_path_list)
    
#     def __getitem__(self, idx):
#         img_path = self.img_path_list[idx]
#         img = Image.open(img_path).convert('RGB')
#         img = self.preprocess(img)
#         return img, img_path

class ImageDataset(Dataset):
    def __init__(self, root_dir, preprocess):
        self.root_dir = root_dir
        self.preprocess = preprocess
        self.image_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_list[idx])
        image = Image.open(img_name)
        image = self.preprocess(image)

        return image, img_name

In [19]:
model, preprocess = clip.load('ViT-B/32', device) 

In [12]:
build_embeddings(model, preprocess, "train2017", "data")

  2%|▏         | 1992/118287 [02:19<2:16:04, 14.24it/s]


KeyboardInterrupt: 

In [37]:
dataset = ImageDataset(root_dir = os.getcwd() + "\\train2017", preprocess = preprocess)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)

images = []
embeddings = []
for image, img_path in tqdm(dataloader):
    with torch.no_grad():
        # features = model.encode_image(img.to(device))
        # features /= features.norm(dim=-1, keepdim=True)
        # embedding_list.extend(features.detach().cpu().numpy())
        # img_path_list.extend(img_path)

        # image = preprocess(Image.open(images[i]).to(device))
        image_features = model.encode_image(image.to(device))

        images.append(img_path) 
        embeddings.append(image_features.detach().cpu().numpy()) 

# data = {"image_paths": images, "embeddings": embeddings}
# with open(save_dir, 'wb') as f:
#     pickle.dump(data, f, protocol=4) 

100%|██████████| 463/463 [1:40:09<00:00, 12.98s/it]


In [38]:
len(images), len(embeddings) 

(463, 463)

In [52]:
image_list = [] 
for i in range(len(images)): 
    image_list.extend(images[i]) 

In [53]:
len(image_list) 

118287

In [54]:
embeddings[0].shape 

(256, 512)

In [55]:
embedding_list = [] 
for i in range(len(embeddings)): 
    embedding_list.extend(embeddings[i]) 

In [56]:
len(embedding_list) 

118287

In [57]:
data = {"image_paths": image_list, "embeddings": embedding_list}
with open("embeddings.pkl", 'wb') as f:
    pickle.dump(data, f, protocol=4) 

In [60]:
build_index("embeddings.pkl")

In [61]:
probs, indices = text_search(model, preprocess, "vector.index", "cat on table", "text", 1) 

In [62]:
probs, indices 

(array([[3.6320894]], dtype=float32), array([[22185]], dtype=int64))

In [64]:
indices[0, 0]

22185

In [3]:
with open("embeddings.pkl", 'rb') as f:
        data = pickle.load(f) 
image_list = data["image_paths"]

In [4]:
len(image_list)

118287

In [6]:
img = Image.open(image_list[86153]) 
img.show() 

In [14]:
os.getcwd()

'C:\\Users\\Stephen Ma\\Desktop\\llama-2-chatbot'

In [15]:
os.getcwd() + "/train2017"

'C:\\Users\\Stephen Ma\\Desktop\\llama-2-chatbot/train2017'

In [28]:
dataset.image_list[1]

IndexError: list index out of range

In [29]:
os.listdir(os.getcwd())

['.git',
 '.ipynb_checkpoints',
 '.vscode',
 'data',
 'README.md',
 'server',
 'train2017',
 'Untitled.ipynb']

In [34]:
arr = os.listdir(dir)

In [31]:
dir = os.getcwd() + "\\train2017"
dir

'C:\\Users\\Stephen Ma\\Desktop\\llama-2-chatbot\\train2017'

In [35]:
len(arr)

118287