# 3.ImgDBSearch 

In [1]:
import os
import sys
import pickle
import json
import glob
import faiss
import gc
import random
import time
import unicodedata


import numpy as np
import pandas as pd

from matplotlib import pyplot as plt 
from tqdm.notebook import tqdm
from pathlib import Path
from scipy.spatial import distance
from collections import defaultdict
from PIL import Image

from sklearn.model_selection import train_test_split

import torch
import torchvision
from torchvision.io import read_image
from torchvision.models import vit_b_16, ViT_B_16_Weights, regnet_y_32gf, RegNet_Y_32GF_Weights
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Config 

In [2]:
batch_size_config = {
    "regnet_y_32gf": {
        True: 16,
        False: 8
    },
    "vit_b_16": {
        True: 16,
        False: 8
    },
    
}
    
class CFG:
    seed = 42
    text_emb_size = 384
    is_kaggle = (os.environ.get('PWD') == '/kaggle/working')
    
    img_dataset_parts = 100
    img_model_test_size = 0.01
    
    img_model_name = "regnet_y_32gf" # "vit_b_16", "regnet_y_32gf"
    img_model_del_head = False
    img_emb_size = 1000
    
    index_name = "faiss_flat_ip" # "faiss_flat_l2", "faiss_flat_ip"
    normalize_emb = True
    
    sim_img_k = 400
    weight_sim_mode = "standart_scaler" # "standart_scaler", minmax_scaler", "mean"
    
    batch_size = batch_size_config[img_model_name][is_kaggle]
    num_workers = batch_size if not is_kaggle else 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_files_dir = "imgdbsearch-data"
    train_name = f"parts_{img_dataset_parts}_model_{img_model_name}_index_{index_name}"
    version_name = f"{train_name}_sims_{sim_img_k}"

CFG.version_name

'parts_100_model_regnet_y_32gf_index_faiss_flat_ip_sims_400'

# Functions 

In [3]:
def get_sim(emb1, emb2):
    sim_res = 0
    for i in range(len(emb1)):
        sim_res += 1 - distance.cosine(emb1[i], emb2[i])
    return sim_res / (i + 1)

def get_img_model(img_model_name = CFG.img_model_name):
    if img_model_name == "regnet_y_32gf":
        if CFG.is_kaggle:
            model = regnet_y_32gf()
            model.load_state_dict(torch.load(f"../input/{CFG.train_files_dir}/regnet_y_32gf_swag-04fdfa75.pth"))
        else:
            weights = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_32gf(weights=weights)

        if CFG.img_model_del_head:
            model.fc = torch.nn.Identity()

        model.to(CFG.device)
        model.eval()

        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "vit_b_16":
        if CFG.is_kaggle:
            model = vit_b_16(image_size=384)
            model.load_state_dict(torch.load(f"../input/{CFG.train_files_dir}/vit_b_16_swag-9ac1b537.pth"))
        else:
            weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
            model = vit_b_16(weights=weights)

        if CFG.img_model_del_head:
            model.fc = torch.nn.Identity()

        model.to(CFG.device)
        model.eval()

        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    return model, preprocess

def create_index(index_name = CFG.index_name):
    if index_name == "faiss_flat_l2":
        index = faiss.IndexFlatL2(CFG.img_emb_size)
    elif index_name == "faiss_flat_ip":
        index = faiss.IndexFlatIP(CFG.img_emb_size)
    return index

def get_sim_weight(sim_dist_arr, weight_sim_mode = CFG.weight_sim_mode):
    k = CFG.sim_img_k
    batch_size = sim_dist_arr.shape[0]
    
    if weight_sim_mode == "standart_scaler":
        m = sim_dist_arr.mean(axis=1).repeat(k).reshape(-1, k)
        s = sim_dist_arr.std(axis=1).repeat(k).reshape(-1, k)
        sim_norm = (sim_dist_arr - m) / s
        e = np.exp(-sim_norm).sum(axis=1).repeat(k).reshape(-1, k)
        w = np.exp(-sim_norm) / e
        
    elif weight_sim_mode == "minmax_scaler":
        max_ = sim_dist_arr.max(axis=1).repeat(k).reshape(-1, k)
        min_ = sim_dist_arr.min(axis=1).repeat(k).reshape(-1, k)
        sim_norm = (sim_dist_arr - min_) / (max_ - min_)
        e = np.exp(-sim_norm).sum(axis=1).repeat(k).reshape(-1, k)
        w = np.exp(-sim_norm) / e
        
    elif weight_sim_mode == "mean":
        w = np.ones((batch_size, k)) / k

    return w

class CustomDataSet(Dataset):
    def __init__(self, data_dir, img_names, transform):
        self.data_dir = data_dir
        self.img_names = img_names
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        image = Image.open(img_path)
        tensor_image = self.transform(image)
        return img_name, tensor_image
    
def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def create_submission(pred_arr, img_names):
    imgIds = [i.split('.')[0] for i in img_names]

    EMBEDDING_LENGTH = CFG.text_emb_size
    eIds = list(range(EMBEDDING_LENGTH))

    imgId_eId = [
        '_'.join(map(str, i)) for i in zip(
            np.repeat(imgIds, EMBEDDING_LENGTH),
            np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]
    
    submission = pd.DataFrame(
                    index=imgId_eId,
                    data=np.array(pred_arr).flatten(),
                    columns=['val']).rename_axis('imgId_eId')
    return submission

def is_english_only(string):
    for s in string:
        cat = unicodedata.category(s)         
        if (cat not in ['Ll', 'Lu', 'Nd', 'Po', 'Pd', 'Zs']) or (not cat.isascii()):
            return False
        
    return True

def filter_metadata(df, 
                    img_size_min, img_size_max, 
                    img_max_ratio_diff, 
                    prompt_words_min, prompt_words_max):
    df = df.copy()
    
    df["size_ratio"] = df["height"] / df["width"]
    df['prompt'] = df['prompt'].str.strip()
    df["num_words"] = df['prompt'].str.split(" ").apply(len)
    df["is_english"] = df["prompt"].apply(is_english_only)
    
    img_hw_cond = (
        df["width"].between(img_size_min, img_size_max) & 
        df["height"].between(img_size_min, img_size_max)
    )
    img_ratio_cond = df["size_ratio"].between(1/img_max_ratio_diff, img_max_ratio_diff)
    prompt_empty_cond = (df["prompt"] != "")
    prompt_num_words_cond = df["num_words"].between(prompt_words_min, prompt_words_max)
    prompt_eng_cond = df["is_english"]

    return df[
        img_hw_cond &
        img_ratio_cond &
        prompt_empty_cond &
        prompt_num_words_cond &
        prompt_eng_cond
    ]

set_seed(CFG.seed)

# Train test split 

In [4]:
train_data_dir = Path("../input/DiffusionDB_2M/")

metadata = pd.read_parquet(train_data_dir / "metadata.parquet")
metadata = metadata[metadata["part_id"] <= CFG.img_dataset_parts]

full_prompt = metadata[["image_name", "prompt"]].values
train_prompt, val_prompt = train_test_split(
    full_prompt, 
    test_size=CFG.img_model_test_size, 
    random_state=CFG.seed,
    shuffle=True
)

train_prompt_dict = {img_name: prompt for img_name, prompt in train_prompt}
val_prompt_dict = {img_name: prompt for img_name, prompt in val_prompt}

# Get img emb 

In [5]:
model, preprocess = get_img_model(img_model_name=CFG.img_model_name)
train_dataset = CustomDataSet(train_data_dir, list(train_prompt_dict.keys()), preprocess)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=CFG.batch_size, 
    shuffle=True, 
    num_workers=CFG.num_workers
)

In [7]:
train_prompts = []
index = create_index(index_name=CFG.index_name)
for img_names, img_arr in tqdm(train_dataloader):  
    img_arr = img_arr.to(CFG.device)

    img_emb_arr = model(img_arr).cpu().detach().numpy()
    
    if CFG.normalize_emb:
        img_emb_arr = img_emb_arr / np.linalg.norm(img_emb_arr)
    index.add(img_emb_arr)
    
    for i in range(img_emb_arr.shape[0]):
        train_prompts.append(train_prompt_dict[img_names[i]])

    gc.collect()

  0%|          | 0/12375 [00:00<?, ?it/s]

In [8]:
with open(f"../input/{CFG.train_files_dir}/train_prompts_{CFG.train_name}.pickle", "wb") as f:
    pickle.dump(train_prompts, f)
    
faiss.write_index(index, f"../input/{CFG.train_files_dir}/train_index_{CFG.train_name}.faiss")

# Validate 

In [14]:
with open(f"../input/{CFG.train_files_dir}/train_prompts_{CFG.train_name}.pickle", "rb") as f:
    train_prompts = pickle.load(f)

index = faiss.read_index(f"../input/{CFG.train_files_dir}/train_index_{CFG.train_name}.faiss")

In [15]:
val_dataset = CustomDataSet(train_data_dir, list(val_prompt_dict.keys()), preprocess)
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=CFG.batch_size, 
    shuffle=False, 
    num_workers=CFG.num_workers
)

In [16]:
st_model = SentenceTransformer('../input/sentence-transformers-222/all-MiniLM-L6-v2/')

sim_sum = 0
img_count = 0
batch_count = 0

time_full = 0
time_get_img_emb = 0
time_get_sim = 0
time_get_true_prompts = 0
time_get_pred_prompts = 0
time_get_sim_prompt_emb = 0
time_get_sim_emb_mean = 0
time_calc_sim = 0
for img_names, img_arr in tqdm(val_dataloader):  
    start_full = time.time()
    
    start = time.time()
    img_arr = img_arr.to(CFG.device)
    img_emb_arr = model(img_arr).cpu().detach().numpy()
    time_get_img_emb += time.time() - start
    
    start = time.time()
    if CFG.normalize_emb:
        img_emb_arr = img_emb_arr / np.linalg.norm(img_emb_arr)
    index.add(img_emb_arr)
    sim_dist_arr, sim_index_arr = index.search(img_emb_arr, k=CFG.sim_img_k)
    sim_weight_arr = get_sim_weight(sim_dist_arr, weight_sim_mode=CFG.weight_sim_mode)
    time_get_sim += time.time() - start
    
    start = time.time()
    true_prompts = [val_prompt_dict[img_name] for img_name in img_names]       
    true_prompt_emb = st_model.encode(true_prompts)
    time_get_true_prompts += time.time() - start
    
    start = time.time()
    for i in range(len(img_names)):
        start_i = time.time()
        sim_prompts = [train_prompts[sim_i] for sim_i in sim_index_arr[i]]
        sim_prompt_emb_arr = st_model.encode(sim_prompts, show_progress_bar=False)
        time_get_sim_prompt_emb += time.time() - start_i
        
        start_i = time.time()
        sim_prompt_emb_arr *= sim_weight_arr[i].repeat(CFG.text_emb_size).reshape(-1, CFG.text_emb_size)
        sim_prompt_emb_mean = sim_prompt_emb_arr.sum(axis=0)
        time_get_sim_emb_mean += time.time() - start_i
    
        start_i = time.time()
        sim_sum += get_sim(true_prompt_emb[i], sim_prompt_emb_mean)
        time_calc_sim += time.time() - start_i
        
        img_count += 1
    time_get_pred_prompts += time.time() - start
    
    time_full += time.time() - start_full
    
    batch_count += 1
    gc.collect()
    

print("Full time: ", time_full * 1000 // batch_count , " ms")
print("Get img emb: ", time_get_img_emb * 1000 // batch_count , " ms")
print("Get sim emb: ", time_get_sim * 1000 // batch_count, " ms")
print("Get true emb: ", time_get_true_prompts * 1000 // batch_count, " ms")
print("Get pred emb, full: ", time_get_pred_prompts * 1000 // batch_count, " ms")
print("\t batch size: ", CFG.batch_size)
print("\t sim emb: ", time_get_sim_prompt_emb * 1000 // img_count, " ms")
print("\t emb mean: ", time_get_sim_emb_mean * 1000 // img_count, " ms")
print("\t calc sim: ", time_calc_sim * 1000 // img_count, " ms")
print("Val score: ", sim_sum / img_count)

  0%|          | 0/125 [00:00<?, ?it/s]

IndexError: list index out of range

# Infer 

In [12]:
model, preprocess = get_img_model()

index = faiss.read_index(f"../input/{CFG.train_files_dir}/train_index_{CFG.train_name}_l2.faiss")

st_model = SentenceTransformer('../input/sentence-transformers-222/all-MiniLM-L6-v2/').to(CFG.device)

with open(f"../input/{CFG.train_files_dir}/train_prompts_{CFG.train_name}.pickle", "rb") as f:
    train_prompts = pickle.load(f)

In [13]:
test_data_dir = Path("../input/stable-diffusion-image-to-prompts/images/")
test_image_names = sorted(os.listdir(test_data_dir))

test_dataset = CustomDataSet(test_data_dir, test_image_names, preprocess)
test_dataloader = DataLoader(
    test_dataset, 
    batch_size=CFG.batch_size, 
    shuffle=False, 
    num_workers=CFG.num_workers
)

In [28]:
pred_arr = []

for img_names, img_arr in tqdm(test_dataloader):  
    img_arr = img_arr.to(CFG.device)

    img_emb_arr = model(img_arr).cpu().detach().numpy()
    
    sim_dist_arr, sim_index_arr = index.search(img_emb_arr, k=CFG.sim_img_k)
    sim_weight_arr = get_sim_weight(sim_dist_arr, weight_sim_mode=CFG.weight_sim_mode)
    
    for i in range(len(img_names)):
        sim_prompts = [train_prompts[sim_i] for sim_i in sim_index_arr[i]]
        sim_prompt_emb_arr = st_model.encode(sim_prompts, show_progress_bar=False)
        
        sim_prompt_emb_arr *= sim_weight_arr[i].repeat(CFG.text_emb_size).reshape(-1, CFG.text_emb_size)
        sim_prompt_emb_mean = sim_prompt_emb_arr.sum(axis=0)
        
        pred_arr.append(sim_prompt_emb_mean)
    
    gc.collect()
        
pred_arr = np.array(pred_arr)

  0%|          | 0/1 [00:00<?, ?it/s]

In [33]:
submission = create_submission(pred_arr, test_image_names)
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,-0.039952
20057f34d_1,0.008582
20057f34d_2,0.011542
20057f34d_3,0.007509
20057f34d_4,-0.019421
...,...
f27825b2c_379,0.015785
f27825b2c_380,0.061255
f27825b2c_381,0.006290
f27825b2c_382,-0.025300


In [None]:
if CFG.is_kaggle:
    submission.to_csv("submission.csv")