# 1. Img2Emb Ensemble 

In [1]:
import os
import sys
import pickle
import json
import glob
import gc
import random
import time
import unicodedata
import traceback
import datetime
import copy
import itertools

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt 
from tqdm.notebook import tqdm
from pathlib import Path
from scipy.spatial import distance
from collections import defaultdict
from PIL import Image

import torch
import torchvision
from torchvision.models import (
    vit_b_16, ViT_B_16_Weights, 
    regnet_y_32gf, RegNet_Y_32GF_Weights,
    regnet_y_16gf, RegNet_Y_16GF_Weights,
)
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [2]:
batch_size_config = {
    "vit_b_16": {
        True: 256,
        False: 16
    },
    "vit_b_16_linear": {
        True: 256,
        False: 48
    },
    "regnet_y_16gf": {
        True: 64,
        False: 26
    },
    "regnet_y_16gf_linear": {
        True: 64,
        False: 26
    },
    "regnet_y_32gf": {
        True: 16,
        False: 6
    },
    "regnet_y_32gf_linear": {
        True: 16,
        False: 20
    },
}

In [3]:
class CFG:
    seed = 42
    text_emb_size = 384
    is_kaggle = (os.environ.get('PWD') == '/kaggle/working')
    
    train_files_dir = "img2emb-data"
    
    test_flip = True
    
    # RESOURCES
    num_workers = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model_names = [
        "vit_b_16", "vit_b_16_linear", "regnet_y_16gf", "regnet_y_32gf", "regnet_y_16gf_linear", "regnet_y_32gf_linear", 
    ]
    model_scores = [
        0.52504, 0.52275, 0.53276, 0.52972, 0.53928, 0.54452
    ]
    
    model_alphas = [
        0.2, 0.2, 0.4, 0.8, 0.2, 0.4
    ]

set_seed(CFG.seed)
CFG.train_files_dir

'img2emb-data'

In [4]:
def get_img_model(img_model_name: str, load_weight: bool, head_emb_size: int):
    if img_model_name == "regnet_y_16gf":
        if not load_weight:
            model = regnet_y_16gf()
        else:
            weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_16gf(weights=weights)
        model.fc = torch.nn.Linear(3024, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    if img_model_name == "regnet_y_16gf_linear":
        if not load_weight:
            model = regnet_y_16gf()
        else:
            weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = regnet_y_16gf(weights=weights)
        model.fc = torch.nn.Linear(3024, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "regnet_y_32gf":
        if not load_weight:
            model = regnet_y_32gf()
        else:
            weights = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_32gf(weights=weights)
        model.fc = torch.nn.Linear(3712, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
        
    elif img_model_name == "regnet_y_32gf_linear":
        if not load_weight:
            model = regnet_y_32gf()
        else:
            weights = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = regnet_y_32gf(weights=weights)
        model.fc = torch.nn.Linear(3712, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
        
    elif img_model_name == "vit_b_16":
        if not load_weight:
            model = vit_b_16(image_size=384)
        else:
            weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
            model = vit_b_16(weights=weights)
        model.heads.head = torch.nn.Linear(768, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    elif img_model_name == "vit_b_16_linear":
        if not load_weight:
            model = vit_b_16(image_size=224)
        else:
            weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = vit_b_16(weights=weights)
        model.heads.head = torch.nn.Linear(768, head_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    return model, preprocess

def create_submission(pred_arr, img_names, text_emb_size):
    imgIds = [i.split('.')[0] for i in img_names]

    EMBEDDING_LENGTH = text_emb_size
    eIds = list(range(EMBEDDING_LENGTH))

    imgId_eId = [
        '_'.join(map(str, i)) for i in zip(
            np.repeat(imgIds, EMBEDDING_LENGTH),
            np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]
    
    submission = pd.DataFrame(
                    index=imgId_eId,
                    data=np.array(pred_arr).flatten(),
                    columns=['val']).rename_axis('imgId_eId')
    return submission

class CustomDataSet(Dataset):
    def __init__(self, data_dir, img2prompt, img_preprocess):
        self.data_dir = data_dir
        self.img_names = list(img2prompt.keys())
        self.img2prompt = img2prompt
        self.img_preprocess = img_preprocess

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        img = Image.open(img_path)
        img_emb = self.img_preprocess(img)
        
        prompt = str(self.img2prompt[img_name])
        
        return img_name, img_emb, prompt

In [5]:
test_data_dir = Path("../input/stable-diffusion-image-to-prompts/images/")
test_image_names = os.listdir(test_data_dir)
test_prompt_dict = {img_name: "" for img_name in test_image_names}

pred_arr_models_list = []
for model_name in CFG.model_names:
    img_model, img_preprocess = get_img_model(img_model_name=model_name, 
                                              load_weight=False, 
                                              head_emb_size=CFG.text_emb_size)
    
    model_path = f"../input/{CFG.train_files_dir}/dataset_duplwords_5_model_{model_name}_sch_None_lr_1e_05.torch"
    img_model.load_state_dict(torch.load(model_path))
    img_model.to(CFG.device)
    img_model.eval()
    
    test_dataset = CustomDataSet(   
        data_dir=test_data_dir, 
        img2prompt=test_prompt_dict, 
        img_preprocess=img_preprocess
    )
    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=batch_size_config[model_name][True], 
        shuffle=False, 
        num_workers=CFG.num_workers
    )
    
    pred_arr_model = [] 
    with torch.no_grad():
        for img_names, img_embs, prompts in tqdm(test_dataloader):
            img_embs = img_embs.to(CFG.device)
            pred = img_model(img_embs) # (batch, emb_size)
            
            if CFG.test_flip:
                img_embs_flip = transforms.functional.hflip(img_embs)
                pred_flip = img_model(img_embs_flip)
                pred = (pred + pred_flip) / 2
            
            pred = pred.cpu().detach().numpy()
                
            pred_arr_model.extend(pred) 
                
    pred_arr_model = np.array(pred_arr_model) # (images, emb_size)
    pred_arr_models_list.append(pred_arr_model) # (models, images, emb_size)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [6]:
n_images = pred_arr_models_list[0].shape[0]
n_models = len(pred_arr_models_list)

img2emb_pred_arr = np.zeros((n_images, CFG.text_emb_size))
for j in range(n_models):
    img2emb_pred_arr += pred_arr_models_list[j] * CFG.model_alphas[j] / sum(CFG.model_alphas)
img2emb_pred_arr.shape

(7, 384)

In [7]:
img2emb_pred_arr = img2emb_pred_arr / np.linalg.norm(img2emb_pred_arr, ord=2, axis=1)[:, np.newaxis]

In [8]:
pred_img2emb = img2emb_pred_arr.flatten()
pred_img2emb.shape

(2688,)

# 2. CLIP knn-regressor

In [9]:
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"
! pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

[0m

In [10]:
import os, glob, gc
import random
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.preprocessing import normalize
import torch
from torch import nn
from torchvision import transforms
import open_clip
import warnings
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt

In [11]:
class CFG:
    clip_model_path = "/kaggle/input/laion-vit-h-14-model/ViT-H-14_laion2b_s32b_b79k.pt"
    clip_preproc_pkl = "/kaggle/input/laion-vit-h-14-model/preprocess.pkl"
    input_size = 224
    batch_size = 64
    seed = 42
    knn_topk = 100
    knn_interval = 1000
    knn_dim = 6

In [12]:
def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

seed_everything(CFG.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
FILES_OBJECTIVE_EMB = []
FILES_CLIP_EMB = []

### DiffusionDB-14M (https://huggingface.co/datasets/poloclub/diffusiondb)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part1/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part2/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part3/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part1/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part2/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part3/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# MSCOCO 2017(train data) (https://cocodataset.org/#download)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# dataset80k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/390674)
# Objective embeddings
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset80k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text tembeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset80k/vith14/prompt_embedding_vith14_001.npy")
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# Dataset30k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/391500)
# Objective embeddings
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset30k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text tembeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset30k/vith14/prompt_embedding_vith14_001.npy")
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# Dataset900k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/399699)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# Conceptual Captions
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# SD2GPT2 (https://www.kaggle.com/datasets/xiaozhouwang/sd2gpt2)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# SD2Hardcode (https://www.kaggle.com/datasets/xiaozhouwang/sd2hardcode)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# ChatGPT (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/402146)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

# Laion2B-en (https://huggingface.co/datasets/laion/laion2B-en)
# (part0000-part0050 of 2000)

# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0000-0004/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0005-0009/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0010-0014/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0015-0019/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0020-0024/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0025-0029/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0030-0034/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0035-0039/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0040-0044/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0045-0049/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0000-0004/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0005-0009/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0010-0014/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0015-0019/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0020-0024/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0025-0029/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0030-0034/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0035-0039/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0040-0044/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0045-0049/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

140 140
146 146
147 147
148 148
158 158
192 192
193 193
194 194
195 195
665 665


In [14]:
test_images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
print(len(test_images))

7


In [15]:
# load CLIP vision model (Laion/ViTH14)
clip_model = torch.jit.load(CFG.clip_model_path).cuda()
clip_model = clip_model.cuda().eval().half();

In [16]:
# load transforms from pickele
with open(CFG.clip_preproc_pkl, "rb") as fp:
    saved_preprocess = pickle.load(fp)
saved_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7cb8c6f74170>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [17]:
# Inference with CLIP vision encoder
test_vision_embeddings = []
for test_image in tqdm(test_images):
    image = Image.open( test_image )
    prep = saved_preprocess(image).unsqueeze(0).to(device)
    embedding = clip_model(prep.half())
    test_vision_embeddings.append( embedding.detach().cpu().numpy() )
test_vision_embeddings = np.concatenate(test_vision_embeddings).astype(np.float16)
print(test_vision_embeddings.shape)

del clip_model, prep, image, embedding
gc.collect()

  0%|          | 0/7 [00:00<?, ?it/s]

(7, 1024)


5251

In [18]:
# Reduce the number of files as debugging.
if test_vision_embeddings.shape[0] == 7:   # 7 public test images
    FILES_OBJECTIVE_EMB = FILES_OBJECTIVE_EMB[:5]
    FILES_CLIP_EMB = FILES_CLIP_EMB[:5]
    print(f"!!! Debug mode !!!")
    print(f"len(FILES_OBJECTIVE_EMB)={len(FILES_OBJECTIVE_EMB)}")
    print(f"len(FILES_CLIP_EMB)={len(FILES_CLIP_EMB)}")

!!! Debug mode !!!
len(FILES_OBJECTIVE_EMB)=5
len(FILES_CLIP_EMB)=5


In [19]:
def predict_local_knn(
    ref_x_embeddings, test_x_embeddings, 
    n_neighbors=CFG.knn_topk,
    interval=CFG.knn_interval,
    distance_dim=CFG.knn_dim,
    coef=1.0, # a coef to prevent from overflow
):
    
    # convert to tensor
    ref_x_embeddings = torch.from_numpy(ref_x_embeddings).to('cuda')
    ref_x_embeddings /= ref_x_embeddings.norm(dim=-1, keepdim=True)
    
    n_iter = test_x_embeddings.shape[0]//interval
    if test_x_embeddings.shape[0]%interval != 0:
        n_iter += 1
        
    dist_topk_store = []
    idxs_topk_store = []
    weights_store = []
    preds = []
    delta = 0.0001
    for i in range(n_iter):
        batch_test_embeddings = torch.from_numpy(
            test_x_embeddings[i*interval:(i+1)*interval, :].copy()
        ).to('cuda')
        batch_test_embeddings /= batch_test_embeddings.norm(dim=-1, keepdim=True)
        
        # calc distance matrix
        dists = 1 - torch.mm(batch_test_embeddings, ref_x_embeddings.T) # dists.shape=[N_test, N_ref]
        del batch_test_embeddings
        gc.collect()
        
        # get topk indecies and distance
        dist_topk, idxs_topk = torch.topk(dists, n_neighbors, largest=False, dim=-1)
        dist_topk = dist_topk.to(torch.float64)        
        
        # calc weights from distance
        weights = 1/(dist_topk**distance_dim+delta)*coef
        weights[ dist_topk < 0 ] = delta
        
        dist_topk_store.append( dist_topk.to('cpu').detach().numpy().copy() )
        idxs_topk_store.append( idxs_topk.to('cpu').detach().numpy().copy() )
        weights_store.append( weights.to('cpu').detach().numpy().copy() )
                
        del dists, weights, dist_topk, idxs_topk
        torch.cuda.empty_cache()
        gc.collect()
        
    del ref_x_embeddings
    torch.cuda.empty_cache()
    gc.collect()
    return np.concatenate(dist_topk_store), np.concatenate(idxs_topk_store), np.concatenate(weights_store)

In [20]:
for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
    # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
    ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
    with torch.no_grad():
        local_dists, local_emb_indecies, local_weights = predict_local_knn(
            ref_clip_embeddings, test_vision_embeddings,
            n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
            coef=0.001
        )
    local_files = np.zeros(local_dists.shape, dtype=np.int32) + i_file
    
    # merge local k-NN into global k-NN
    if i_file == 0:
        global_files = local_files
        global_dists = local_dists
        global_emb_indecies = local_emb_indecies
        global_weights = local_weights
    else:
        global_files = np.concatenate([global_files, local_files], axis=-1)
        global_dists = np.concatenate([global_dists, local_dists], axis=-1)
        global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
        global_weights = np.concatenate([global_weights, local_weights], axis=-1)

        unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

        global_files = np.vstack( [ global_files[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_dists = np.vstack( [ global_dists[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_emb_indecies = np.vstack( [ global_emb_indecies[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_weights = np.vstack( [ global_weights[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
    
    gc.collect()

  0%|          | 0/5 [00:00<?, ?it/s]

In [21]:
df_knn = pd.DataFrame()
df_knn["file"] = global_files.flatten()
df_knn["file"] = df_knn['file'].apply(lambda x: FILES_OBJECTIVE_EMB[x])
df_knn["dist"] = global_dists.flatten()
df_knn["emb_index"] = global_emb_indecies.flatten()
df_knn["test_index"] = np.array([ [val]*CFG.knn_topk for val in range(test_vision_embeddings.shape[0])]).flatten()
df_knn["weight"] = global_weights.flatten()

gc.collect()

18

In [22]:
# k-NN regression
test_prompt_embeddings = np.zeros( (test_vision_embeddings.shape[0], 384))
for (objective_emb_file, gdf) in tqdm(df_knn.groupby("file")):
    ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16) 
    for _, r in gdf.iterrows():
        test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]

  0%|          | 0/5 [00:00<?, ?it/s]

In [23]:
# L2 norm
BS=1000
num = test_prompt_embeddings.shape[0] // BS
if test_prompt_embeddings.shape[0] % BS != 0:
    num+=1
for i in range(num):
    embeddings = test_prompt_embeddings[i*BS:(i+1)*BS, :]
    embeddings = embeddings / ( np.abs(embeddings).max(axis=-1, keepdims=True) + 0.0000001)
    embeddings = normalize( embeddings )
    test_prompt_embeddings[i*BS:(i+1)*BS, :] = embeddings
    
gc.collect()

122

In [24]:
pred_clip_knn = test_prompt_embeddings.flatten()

# 3. CLIP Interrogator

In [25]:
import inspect
import importlib

from blip.models import blip
from clip_interrogator import clip_interrogator

In [26]:
# replace tokenizer path to prevent downloading
blip_path = inspect.getfile(blip)

fin = open(blip_path, "rt")
data = fin.read()
data = data.replace(
    "BertTokenizer.from_pretrained('bert-base-uncased')", 
    "BertTokenizer.from_pretrained('/kaggle/input/clip-interrogator-models-x/bert-base-uncased')"
)
fin.close()

fin = open(blip_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(blip)

<module 'blip.models.blip' from '/opt/conda/lib/python3.7/site-packages/blip/models/blip.py'>

In [27]:
# fix clip_interrogator bug
clip_interrogator_path = inspect.getfile(clip_interrogator.Interrogator)

fin = open(clip_interrogator_path, "rt")
data = fin.read()
data = data.replace(
    'open_clip.get_tokenizer(clip_model_name)', 
    'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)
fin.close()

fin = open(clip_interrogator_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(clip_interrogator)

<module 'clip_interrogator.clip_interrogator' from '/opt/conda/lib/python3.7/site-packages/clip_interrogator/clip_interrogator.py'>

In [28]:
import os
import sys
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 

import numpy as np
import pandas as pd
import torch
import open_clip

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

In [29]:
class CFG:
    device = "cuda"
    seed = 42
    embedding_length = 384
    sentence_model_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
    blip_model_path = "/kaggle/input/clip-interrogator-models-x/model_large_caption.pth"
    ci_clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
    clip_model_name = "ViT-H-14"
    clip_model_path = "/kaggle/input/clip-interrogator-models-x/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    cache_path = "/kaggle/input/clip-interrogator-models-x"

In [30]:
df_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
df_submission.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [31]:
images = os.listdir(comp_path / 'images')
imgIds = [i.split('.')[0] for i in images]

eIds = list(range(CFG.embedding_length))

imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, CFG.embedding_length),
        np.tile(range(CFG.embedding_length), len(imgIds))
    )
]

assert sorted(imgId_eId) == sorted(df_submission.index)

In [32]:
st_model = SentenceTransformer(CFG.sentence_model_path)

In [33]:
model_config = clip_interrogator.Config(clip_model_name=CFG.ci_clip_model_name)
model_config.cache_path = CFG.cache_path

In [34]:
configs_path = os.path.join(os.path.dirname(os.path.dirname(blip_path)), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip.blip_decoder(
    pretrained=CFG.blip_model_path,
    image_size=model_config.blip_image_eval_size, 
    vit=model_config.blip_model_type, 
    med_config=med_config
)
blip_model.eval()
blip_model = blip_model.to(model_config.device)
model_config.blip_model = blip_model

load checkpoint from /kaggle/input/clip-interrogator-models-x/model_large_caption.pth


In [35]:
clip_model = open_clip.create_model(CFG.clip_model_name, precision='fp16' if model_config.device == 'cuda' else 'fp32')
open_clip.load_checkpoint(clip_model, CFG.clip_model_path)
clip_model.to(model_config.device).eval()
model_config.clip_model = clip_model

In [36]:
clip_preprocess = open_clip.image_transform(
    clip_model.visual.image_size,
    is_train = False,
    mean = getattr(clip_model.visual, 'image_mean', None),
    std = getattr(clip_model.visual, 'image_std', None),
)
model_config.clip_preprocess = clip_preprocess
ci = clip_interrogator.Interrogator(model_config)

Loaded CLIP model and data in 3.24 seconds.


In [37]:
cos = torch.nn.CosineSimilarity(dim=1)

mediums_features_array = torch.stack([torch.from_numpy(t) for t in ci.mediums.embeds]).to(ci.device)
movements_features_array = torch.stack([torch.from_numpy(t) for t in ci.movements.embeds]).to(ci.device)
flavors_features_array = torch.stack([torch.from_numpy(t) for t in ci.flavors.embeds]).to(ci.device)

In [38]:
def interrogate(image: Image) -> str:
    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)
    
    medium = [ci.mediums.labels[i] for i in cos(image_features, mediums_features_array).topk(1).indices][0]
    movement = [ci.movements.labels[i] for i in cos(image_features, movements_features_array).topk(1).indices][0]
    flaves = ", ".join([ci.flavors.labels[i] for i in cos(image_features, flavors_features_array).topk(3).indices])

    if caption.startswith(medium):
        prompt = f"{caption}, {movement},  {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, {flaves}"

    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

In [39]:
prompts = []

images_path = "../input/stable-diffusion-image-to-prompts/images/"
for image_name in images:
    img = Image.open(images_path + image_name).convert("RGB")

    generated = interrogate(img)
    
    prompts.append(generated)

In [40]:
pred_clip_interrogate = st_model.encode(prompts).flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

# 4. ViT 

In [41]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.preprocessing import normalize

In [42]:
class CFG:
    model_path1 = '/kaggle/input/k/shoheiazuma/stable-diffusion-vit-baseline-train/vit_base_patch16_224.pth'
    model_name1 = 'vit_base_patch16_224'
    input_size1 = 224 
    input_size = 384
    batch_size = 64

In [43]:
class DiffusionTestDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        image = self.transform(image)
        return image
    
def predict(
    images,
    model_path,
    model_name,
    input_size,
    batch_size
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=384
    )
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    
    tta_preds = None
    for _ in range(2):
        preds = []
        for X in tqdm(dataloader, leave=False):
            X = X.to(device)

            with torch.no_grad():
                X_out = model(X).cpu().numpy()
                # L2 normalize -- Start
                X_out = X_out / ( np.abs(X_out).max(axis=-1, keepdims=True))  # To avoid to overflow at normalize()
                X_out = normalize( X_out )
                # L2 normalize -- End
                preds.append(X_out)
                
        if tta_preds is None:
            tta_preds = np.vstack(preds).flatten()
        else:
            tta_preds += np.vstack(preds).flatten()
    
    return tta_preds / 2

In [44]:
images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
pred_vit = predict(images, CFG.model_path1, CFG.model_name1, CFG.input_size1, CFG.batch_size)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

# FINAL 

In [45]:
pred_final = 0.5 * (pred_clip_knn * 0.60 + pred_clip_interrogate * 0.15 + pred_vit * 0.25) + 0.5 * pred_img2emb

In [46]:
imgIds = [i.stem for i in test_images]
EMBEDDING_LENGTH = 384
imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, EMBEDDING_LENGTH),
        np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]

submission = pd.DataFrame(
    index=imgId_eId,
    data=pred_final,
    columns=['val']
).rename_axis('imgId_eId')
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
f27825b2c_0,-0.051940
f27825b2c_1,0.071797
f27825b2c_2,0.012923
f27825b2c_3,-0.018633
f27825b2c_4,-0.065897
...,...
c98f79f71_379,0.002945
c98f79f71_380,0.068411
c98f79f71_381,0.032352
c98f79f71_382,-0.033045


In [47]:
submission.to_csv('submission.csv')