# Overview

This method uses k-NN regression with distances of CLIP embeddings.  
***Note: This method does not use generated images, only prompts.***
![](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F8163878%2Fbf35d4beb7867bc163f28abb647128ae%2FSDIP_method.PNG?generation=1682513923845611&alt=media)

# Library

In [1]:
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"
!pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

[0m

In [2]:
import os, glob, gc
import random
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.preprocessing import normalize
import torch
from torch import nn
from torchvision import transforms
import open_clip
import warnings
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt

# Config

In [3]:
class CFG:
    clip_model_path = "/kaggle/input/laion-vit-h-14-model/ViT-H-14_laion2b_s32b_b79k.pt"
    clip_preproc_pkl = "/kaggle/input/laion-vit-h-14-model/preprocess.pkl"
    input_size = 224
    batch_size = 64
    seed = 42
    knn_topk = 100
    knn_interval = 1000
    knn_dim = 12

In [4]:
def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

seed_everything(CFG.seed)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> ref_embeddings </b></div>


In [6]:
FILES_OBJECTIVE_EMB = []
FILES_CLIP_EMB = []

In [7]:
### DiffusionDB-14M (https://huggingface.co/datasets/poloclub/diffusiondb)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part1/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part2/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part3/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part1/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part2/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-diffusiondb-14m-part3/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

140 140


In [8]:
# MSCOCO 2017(train data) (https://cocodataset.org/#download)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-mscoco/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

146 146


In [9]:
# dataset80k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/390674)
# Objective embeddings
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset80k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text tembeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset80k/vith14/prompt_embedding_vith14_001.npy")
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

147 147


In [10]:
# Dataset30k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/391500)
# Objective embeddings
FILES_OBJECTIVE_EMB.append("/kaggle/input/pub-embeddings-dataset30k/all_minilm_l6_v2/prompt_embeddings_allminilm_001.npy")

# CLIP text tembeddings
FILES_CLIP_EMB.append("/kaggle/input/pub-embeddings-dataset30k/vith14/prompt_embedding_vith14_001.npy")
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

148 148


In [11]:
# Dataset900k (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/399699)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-dataset900k/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

158 158


In [12]:
# Conceptual Captions
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-conceptual-captions/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

192 192


In [13]:
# SD2GPT2 (https://www.kaggle.com/datasets/xiaozhouwang/sd2gpt2)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2gpt2/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

193 193


In [14]:
# SD2Hardcode (https://www.kaggle.com/datasets/xiaozhouwang/sd2hardcode)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-sd2hardcode/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

194 194


In [15]:
# ChatGPT (https://www.kaggle.com/competitions/stable-diffusion-image-to-prompts/discussion/402146)
# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob("/kaggle/input/pub-embeddings-chatgpt/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))


195 195


In [16]:
# Laion2B-en (https://huggingface.co/datasets/laion/laion2B-en)
# (part0000-part0050 of 2000)

# Objective embeddings
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0000-0004/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0005-0009/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0010-0014/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0015-0019/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0020-0024/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0025-0029/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0030-0034/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0035-0039/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0040-0044/all_minilm_l6_v2/*.npy"))
FILES_OBJECTIVE_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0045-0049/all_minilm_l6_v2/*.npy"))

# CLIP text tembeddings
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0000-0004/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0005-0009/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0010-0014/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0015-0019/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0020-0024/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0025-0029/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0030-0034/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0035-0039/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0040-0044/vith14/*.npy"))
FILES_CLIP_EMB += sorted(glob.glob(f"/kaggle/input/pub-embeddings-laion2b-part0045-0049/vith14/*.npy"))
print(len(FILES_OBJECTIVE_EMB), len(FILES_CLIP_EMB))

665 665


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> Generate CLIP vision embeddings </b></div>

In [17]:
test_images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
print(len(test_images))

7


In [18]:
# load CLIP vision model (Laion/ViTH14)
clip_model = torch.jit.load(CFG.clip_model_path).cuda()
clip_model = clip_model.cuda().eval().half();

In [19]:
# load transforms from pickele
with open(CFG.clip_preproc_pkl, "rb") as fp:
    saved_preprocess = pickle.load(fp)
saved_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7cdce1f90290>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [20]:
# Inference with CLIP vision encoder
test_vision_embeddings = []
for test_image in tqdm(test_images):
    image = Image.open( test_image )
    prep = saved_preprocess(image).unsqueeze(0).to(device)
    embedding = clip_model(prep.half())
    test_vision_embeddings.append( embedding.detach().cpu().numpy() )
test_vision_embeddings = np.concatenate(test_vision_embeddings).astype(np.float16)
print(test_vision_embeddings.shape)

  0%|          | 0/7 [00:00<?, ?it/s]

(7, 1024)


In [21]:
del clip_model, prep, image, embedding
gc.collect()

5257

In [22]:
# Reduce the number of files as debugging.
if test_vision_embeddings.shape[0] == 7:   # 7 public test images
    FILES_OBJECTIVE_EMB = FILES_OBJECTIVE_EMB[:5]
    FILES_CLIP_EMB = FILES_CLIP_EMB[:5]
    print(f"!!! Debug mode !!!")
    print(f"len(FILES_OBJECTIVE_EMB)={len(FILES_OBJECTIVE_EMB)}")
    print(f"len(FILES_CLIP_EMB)={len(FILES_CLIP_EMB)}")

!!! Debug mode !!!
len(FILES_OBJECTIVE_EMB)=5
len(FILES_CLIP_EMB)=5


#### <a id="top"></a>
# <div style="box-shadow: rgb(60, 121, 245) 0px 0px 0px 3px inset, rgb(255, 255, 255) 10px -10px 0px -3px, rgb(31, 193, 27) 10px -10px, rgb(255, 255, 255) 20px -20px 0px -3px, rgb(255, 217, 19) 20px -20px, rgb(255, 255, 255) 30px -30px 0px -3px, rgb(255, 156, 85) 30px -30px, rgb(255, 255, 255) 40px -40px 0px -3px, rgb(255, 85, 85) 40px -40px; padding:20px; margin-right: 40px; font-size:30px; font-family: consolas; text-align:center; display:fill; border-radius:15px; color:rgb(60, 121, 245);"><b> k-NN Regression (CUDA) </b></div>


In [23]:
def predict_local_knn(
    ref_x_embeddings, test_x_embeddings, 
    n_neighbors=CFG.knn_topk,
    interval=CFG.knn_interval,
    distance_dim=CFG.knn_dim,
    coef=1.0, # a coef to prevent from overflow
):
    
    # convert to tensor
    ref_x_embeddings = torch.from_numpy(ref_x_embeddings).to('cuda')
    ref_x_embeddings /= ref_x_embeddings.norm(dim=-1, keepdim=True)
    
    n_iter = test_x_embeddings.shape[0]//interval
    if test_x_embeddings.shape[0]%interval != 0:
        n_iter += 1
        
    dist_topk_store = []
    idxs_topk_store = []
    weights_store = []
    preds = []
    delta = 0.0001
    for i in range(n_iter):
        batch_test_embeddings = torch.from_numpy(
            test_x_embeddings[i*interval:(i+1)*interval, :].copy()
        ).to('cuda')
        batch_test_embeddings /= batch_test_embeddings.norm(dim=-1, keepdim=True)
        
        # calc distance matrix
        dists = 1 - torch.mm(batch_test_embeddings, ref_x_embeddings.T) # dists.shape=[N_test, N_ref]
        del batch_test_embeddings
        gc.collect()
        
        # get topk indecies and distance
        dist_topk, idxs_topk = torch.topk(dists, n_neighbors, largest=False, dim=-1)
        dist_topk = dist_topk.to(torch.float64)        
        
        # calc weights from distance
        weights = 1/(dist_topk**distance_dim+delta)*coef
        weights[ dist_topk < 0 ] = delta
        
        dist_topk_store.append( dist_topk.to('cpu').detach().numpy().copy() )
        idxs_topk_store.append( idxs_topk.to('cpu').detach().numpy().copy() )
        weights_store.append( weights.to('cpu').detach().numpy().copy() )
                
        del dists, weights, dist_topk, idxs_topk
        torch.cuda.empty_cache()
        gc.collect()
        
    del ref_x_embeddings
    torch.cuda.empty_cache()
    gc.collect()
    return np.concatenate(dist_topk_store), np.concatenate(idxs_topk_store), np.concatenate(weights_store)

In [24]:
for i_file, file_clip_emb in enumerate(tqdm(FILES_CLIP_EMB)):
    # Local k-NN (for each CLIP embeddings file VS CLIP vision embeddings of test images)
    ref_clip_embeddings = np.load(file_clip_emb).astype(np.float16)
    with torch.no_grad():
        local_dists, local_emb_indecies, local_weights = predict_local_knn(
            ref_clip_embeddings, test_vision_embeddings,
            n_neighbors=CFG.knn_topk, interval=CFG.knn_interval, distance_dim=CFG.knn_dim,
            coef=0.001
        )
    local_files = np.zeros(local_dists.shape, dtype=np.int32) + i_file
    
    # merge local k-NN into global k-NN
    if i_file == 0:
        global_files = local_files
        global_dists = local_dists
        global_emb_indecies = local_emb_indecies
        global_weights = local_weights
    else:
        global_files = np.concatenate([global_files, local_files], axis=-1)
        global_dists = np.concatenate([global_dists, local_dists], axis=-1)
        global_emb_indecies = np.concatenate([global_emb_indecies, local_emb_indecies], axis=-1)
        global_weights = np.concatenate([global_weights, local_weights], axis=-1)

        unsorted_min_indices = np.argpartition(global_dists, CFG.knn_topk, axis=1)[:, :CFG.knn_topk]

        global_files = np.vstack( [ global_files[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_dists = np.vstack( [ global_dists[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_emb_indecies = np.vstack( [ global_emb_indecies[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
        global_weights = np.vstack( [ global_weights[i, unsorted_min_indices[i,:]] for i in range(unsorted_min_indices.shape[0]) ])
    
    gc.collect()

  0%|          | 0/5 [00:00<?, ?it/s]

In [25]:
df_knn = pd.DataFrame()
df_knn["file"] = global_files.flatten()
df_knn["file"] = df_knn['file'].apply(lambda x: FILES_OBJECTIVE_EMB[x])
df_knn["dist"] = global_dists.flatten()
df_knn["emb_index"] = global_emb_indecies.flatten()
df_knn["test_index"] = np.array([ [val]*CFG.knn_topk for val in range(test_vision_embeddings.shape[0])]).flatten()
df_knn["weight"] = global_weights.flatten()
df_knn

Unnamed: 0,file,dist,emb_index,test_index,weight
0,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.716797,35943,0,0.054062
1,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.726562,60517,0,0.045996
2,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98421,0,0.060610
3,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98420,0,0.060610
4,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.709961,98419,0,0.060610
...,...,...,...,...,...
695,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.644531,22745,6,0.190855
696,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23578,6,0.118830
697,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23582,6,0.118830
698,/kaggle/input/pub-embeddings-diffusiondb-14m-p...,0.670898,23575,6,0.118830


In [26]:
gc.collect()

39

In [27]:
# k-NN regression
test_prompt_embeddings = np.zeros( (test_vision_embeddings.shape[0], 384))
for (objective_emb_file, gdf) in tqdm(df_knn.groupby("file")):
    ref_objective_embeddings = np.load(objective_emb_file).astype(np.float16) 
    for _, r in gdf.iterrows():
        test_prompt_embeddings[int(r.test_index), :] += r.weight * ref_objective_embeddings[int(r.emb_index), :]

  0%|          | 0/5 [00:00<?, ?it/s]

In [28]:
# L2 norm
BS=1000
num = test_prompt_embeddings.shape[0] // BS
if test_prompt_embeddings.shape[0] % BS != 0:
    num+=1
for i in range(num):
    embeddings = test_prompt_embeddings[i*BS:(i+1)*BS, :]
    embeddings = embeddings / ( np.abs(embeddings).max(axis=-1, keepdims=True) + 0.0000001)
    embeddings = normalize( embeddings )
    test_prompt_embeddings[i*BS:(i+1)*BS, :] = embeddings
    
gc.collect()

122

In [29]:
test_prompt_embeddings = test_prompt_embeddings.flatten()

In [30]:
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"

In [31]:
!pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

[0m

In [32]:
import inspect
import importlib

from blip.models import blip
from clip_interrogator import clip_interrogator

In [33]:
# replace tokenizer path to prevent downloading
blip_path = inspect.getfile(blip)

fin = open(blip_path, "rt")
data = fin.read()
data = data.replace(
    "BertTokenizer.from_pretrained('bert-base-uncased')", 
    "BertTokenizer.from_pretrained('/kaggle/input/clip-interrogator-models-x/bert-base-uncased')"
)
fin.close()

fin = open(blip_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(blip)

<module 'blip.models.blip' from '/opt/conda/lib/python3.7/site-packages/blip/models/blip.py'>

In [34]:
# fix clip_interrogator bug
clip_interrogator_path = inspect.getfile(clip_interrogator.Interrogator)

fin = open(clip_interrogator_path, "rt")
data = fin.read()
data = data.replace(
    'open_clip.get_tokenizer(clip_model_name)', 
    'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)
fin.close()

fin = open(clip_interrogator_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(clip_interrogator)

<module 'clip_interrogator.clip_interrogator' from '/opt/conda/lib/python3.7/site-packages/clip_interrogator/clip_interrogator.py'>

In [35]:
import os
import sys
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 

import numpy as np
import pandas as pd
import torch
import open_clip

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

In [36]:
class CFG:
    device = "cuda"
    seed = 42
    embedding_length = 384
    sentence_model_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
    blip_model_path = "/kaggle/input/clip-interrogator-models-x/model_large_caption.pth"
    ci_clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
    clip_model_name = "ViT-H-14"
    clip_model_path = "/kaggle/input/clip-interrogator-models-x/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    cache_path = "/kaggle/input/clip-interrogator-models-x"

In [37]:
df_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
df_submission.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [38]:
images = os.listdir(comp_path / 'images')
imgIds = [i.split('.')[0] for i in images]

eIds = list(range(CFG.embedding_length))

imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, CFG.embedding_length),
        np.tile(range(CFG.embedding_length), len(imgIds))
    )
]

assert sorted(imgId_eId) == sorted(df_submission.index)

In [39]:
st_model = SentenceTransformer(CFG.sentence_model_path)

In [40]:
model_config = clip_interrogator.Config(clip_model_name=CFG.ci_clip_model_name)
model_config.cache_path = CFG.cache_path

In [41]:
configs_path = os.path.join(os.path.dirname(os.path.dirname(blip_path)), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip.blip_decoder(
    pretrained=CFG.blip_model_path,
    image_size=model_config.blip_image_eval_size, 
    vit=model_config.blip_model_type, 
    med_config=med_config
)
blip_model.eval()
blip_model = blip_model.to(model_config.device)
model_config.blip_model = blip_model

load checkpoint from /kaggle/input/clip-interrogator-models-x/model_large_caption.pth


In [42]:
clip_model = open_clip.create_model(CFG.clip_model_name, precision='fp16' if model_config.device == 'cuda' else 'fp32')
open_clip.load_checkpoint(clip_model, CFG.clip_model_path)
clip_model.to(model_config.device).eval()
model_config.clip_model = clip_model

In [43]:
clip_preprocess = open_clip.image_transform(
    clip_model.visual.image_size,
    is_train = False,
    mean = getattr(clip_model.visual, 'image_mean', None),
    std = getattr(clip_model.visual, 'image_std', None),
)
model_config.clip_preprocess = clip_preprocess

In [44]:
ci = clip_interrogator.Interrogator(model_config)

Loaded CLIP model and data in 2.26 seconds.


In [45]:
cos = torch.nn.CosineSimilarity(dim=1)

mediums_features_array = torch.stack([torch.from_numpy(t) for t in ci.mediums.embeds]).to(ci.device)
movements_features_array = torch.stack([torch.from_numpy(t) for t in ci.movements.embeds]).to(ci.device)
flavors_features_array = torch.stack([torch.from_numpy(t) for t in ci.flavors.embeds]).to(ci.device)
#trending_features_array = torch.stack([torch.from_numpy(t) for t in ci.trending.embeds]).to(ci.device)

In [46]:
def interrogate(image: Image) -> str:
    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)
    
    #medium = [ci.mediums.labels[i] for i in cos(image_features, mediums_features_array).topk(1).indices][0]
    #movement = [ci.movements.labels[i] for i in cos(image_features, movements_features_array).topk(1).indices][0]
    #trending = [ci.movements.labels[i] for i in cos(image_features, trending_features_array).topk(1).indices][0]
    #flaves = ", ".join([ci.flavors.labels[i] for i in cos(image_features, flavors_features_array).topk(3).indices])
    
    medium = [ci.mediums.labels[i] for i in cos(image_features, mediums_features_array).topk(1).indices][0]
    movement = [ci.movements.labels[i] for i in cos(image_features, movements_features_array).topk(1).indices][0]
    flaves = ", ".join([ci.flavors.labels[i] for i in cos(image_features, flavors_features_array).topk(3).indices])

    if caption.startswith(medium):
        prompt = f"{caption}, {movement},  {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, {flaves}"

    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

In [47]:
import re
from gensim.parsing.preprocessing import remove_stopwords
prompts = []


images_path = "../input/stable-diffusion-image-to-prompts/images/"
for image_name in images:
    #print(image_name)
    img = Image.open(images_path + image_name).convert("RGB")

    generated = interrogate(img)
    #print(generated)
    generated = re.sub(r'[^A-Za-z ]+', '', generated)
    generated = ' '.join(dict.fromkeys(generated.split()))
    prompts.append(generated)
#print(".....")
#print(prompts)

In [48]:
def add_text_limiters(text: str) -> str:
    return " ".join([
        word + "\n" if i % 15 == 0 else word 
        for i, word in enumerate(text.split(" "), start=1)
    ])

def plot_image(image: np.ndarray, original_prompt: str, generated_prompt: str) -> None:
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.annotate(
        "Original prompt:\n" + add_text_limiters(original_prompt) + "\n\nGenerated prompt:\n" + add_text_limiters(generated_prompt), 
        xy=(1.05, 0.5), xycoords='axes fraction', ha='left', va='center', 
        fontsize=16, rotation=0, color="#104a6e"
    )

In [49]:
prompt_embeddings = st_model.encode(prompts).flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [50]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.preprocessing import normalize

In [51]:
class CFG:
    model_path = '/kaggle/input/swin-large-finetune-stablediffusion-textimage-pair/vit_large_patch16_384_1_64_0.0001_0.6564.pth/vit_large_patch16_384_1_64_0.0001_0.6564.pth'
    model_name = 'vit_large_patch16_384'
    input_size = 384
    batch_size = 64

In [52]:
class DiffusionTestDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        image = self.transform(image)
        return image

In [53]:
def predict(
    images,
    model_path,
    model_name,
    input_size,
    batch_size
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomRotation(degrees=10),

        #transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=384
    )
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    
    tta_preds = None
    for _ in range(2):
        preds = []
        for X in tqdm(dataloader, leave=False):
            X = X.to(device)

            with torch.no_grad():
                X_out = model(X).cpu().numpy()
                # L2 normalize -- Start
                X_out = X_out / ( np.abs(X_out).max(axis=-1, keepdims=True) + 0.0000001)  # To avoid to overflow at normalize()
                X_out = normalize( X_out )
                # L2 normalize -- End
                preds.append(X_out)
                
        if tta_preds is None:
            tta_preds = np.vstack(preds).flatten()
        else:
            tta_preds += np.vstack(preds).flatten()
    
    return tta_preds / 2

In [54]:
images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
embeddings2 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [55]:
class CFG:
    model_path = '/kaggle/input/swin-large-finetune-stablediffusion-textimage-pair/swin_large_patch4_window7_224_10_epochs.pth'
    model_name = 'swin_large_patch4_window7_224'
    input_size = 224
    batch_size = 64
embeddings3 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [56]:
class CFG:
    model_path = '/kaggle/input/vit-large/vit_large_patch16_224.pth'
    model_name = 'vit_large_patch16_224'
    input_size = 224
    batch_size = 64
#embeddings4 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)

In [57]:
class CFG:
    model_path = '/kaggle/input/swin-large-finetune-stablediffusion-textimage-pair/swin_large_patch4_window7_224_400k_11_epoch.pth'
    model_name = 'swin_large_patch4_window7_224'
    input_size = 224
    batch_size = 64
#embeddings4 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)

# prompt generator - OFA large

In [58]:
!pip install /kaggle/input/clip-interragtor2/promptcap-1.0.3-py3-none-any.whl

Processing /kaggle/input/clip-interragtor2/promptcap-1.0.3-py3-none-any.whl
Installing collected packages: promptcap
Successfully installed promptcap-1.0.3
[0m

In [59]:

import torch
from promptcap import PromptCap

model = PromptCap("/kaggle/input/stable-diffusion-data/OFA-large-caption/")  # also support OFA checkpoints. e.g. "OFA-Sys/ofa-large"

if torch.cuda.is_available():
    model.cuda()
    #model.cpu()
prompt = "what does the image describe?"
#image = "/kaggle/input/stable-diffusion-image-to-prompts/images/20057f34d.png"

#print(model.caption(prompt, image))


/kaggle/input/stable-diffusion-data/OFA-large-caption/
<super: <class 'OFATokenizer'>, <OFATokenizer object>>


In [60]:
prompts = []

for image_name in images:
    generated = model.caption(prompt, image_name)
    generated = re.sub(r'[^A-Za-z ]+', '', generated)
    generated = ' '.join(dict.fromkeys(generated.split()))
    prompts.append(generated)
ofa_embeddings = st_model.encode(prompts).flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [61]:
# knn_clip + clip + vit large 
            # 0.58398: 0.45 0.15 0.4
            # 0.58331: 0.5 0.1 0.4
            # 0.58349: 0.4 0.2 0.4 
            # 0.58276: 0.45 0.1 0.45
# knn_clip + clip + vit large + swin
            # 0.58720: 0.45 0.15 0.3 0.1 
            # 0.58531: 0.4 0.1 0.4 0.1
            # 0.58641: 0.5 0.1 0.3 0.1
            # 0.58627: 0.45 0.1 0.35 0.1
            #          0.45 0.2 0.25 0.1
            #           0.4 0.2 0.3 0.1 
# knn_clip + clip + vit large + swin + vit large 224
            # 0.58639: 0.4 0.2 0.3 0.05 0.05
            # 0.58680: 0.45 0.15 0.3 0.05 0.05
            # 0.58758: 0.4 0.15 0.3 0.1 0.05
            # 
# knn_clip + clip + vit large + ofa_embeddings 
            # 0.4 0.2 0.3 0.1 
            # 0.4 0.15 0.3 0.15
            # 0.45 0.15 0.3 0.1
            
# knn_clip + clip + vit large + ofa_embeddings + swin
            # 0.4 0.15 0.3 0.05 0.1
            # best: 0.4 0.15 0.3 0.1 0.05
            # 0.35 0.15 0.3 0.1 0.1
            # 0.35 0.2 0.3 0.1 0.05
            # 0.5 0.15 0.4 0.1 0.05
            # 0.5 0.2 0.4 0.1 0.05
            # 0.5 0.3 0.5 0.1 0.05
            # 0.45 0.15 0.35 0.1 0.05
            # 0.375 0.15 0.325 0.1 0.05
            # 0.39 0.15 0.31 0.1 0.05
            # 0.38 0.15 0.32 0.1 0.05
            # 0.41 0.15 0.29 0.1 0.05

#test_prompt_embeddings = test_prompt_embeddings*0.4+prompt_embeddings*0.15+embeddings2*0.3+embeddings3*0.1+embeddings4*0.05
test_prompt_embeddings = test_prompt_embeddings*0.41+prompt_embeddings*0.15+embeddings2*0.29+ ofa_embeddings*0.1 + embeddings3 * 0.05
#test_prompt_embeddings = test_prompt_embeddings*0.3+prompt_embeddings*0.15+embeddings2*0.25+ ofa_embeddings*0.1 + embeddings3 * 0.1
#test_prompt_embeddings = test_prompt_embeddings*0.4+prompt_embeddings*0.15+embeddings2*0.3+ ofa_embeddings*0.1 + gpt_embeddings * 0.05

In [62]:
imgIds = [i.stem for i in test_images]
EMBEDDING_LENGTH = 384
imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, EMBEDDING_LENGTH),
        np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]

submission = pd.DataFrame(
    index=imgId_eId,
    data=test_prompt_embeddings,
    columns=['val']
).rename_axis('imgId_eId')
submission.to_csv('submission.csv')

In [63]:
import pandas as pd
df = pd.read_csv('/kaggle/input/swin-large-finetune-stablediffusion-textimage-pair/submission_0.59016.csv')
import numpy as np
from numpy.linalg import norm
 
A = submission['val']
B = df['val']
 
# compute cosine similarity
cosine = np.dot(A,B)/(norm(A)*norm(B))
print("Cosine Similarity:", cosine)

Cosine Similarity: 0.9999709642896993
