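# 6. CLIP
Zero-shot prompt prediction with OpenCLIP (https://github.com/mlfoundations/open_clip): candidate label phrases are ranked against each image by CLIP similarity, and the top-k labels are joined into the predicted prompt.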

In [79]:
import os
import sys
import pickle
import json
import glob
import gc
import random
import time
import unicodedata
import traceback
import datetime
import copy

import numpy as np
import pandas as pd
import torch

from torch.utils.data import Dataset, DataLoader

from matplotlib import pyplot as plt 
from tqdm.notebook import tqdm
from pathlib import Path
from scipy.spatial import distance
from collections import defaultdict
from PIL import Image
from collections import Counter

from sklearn.model_selection import train_test_split

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy and torch (CPU and current CUDA device),
    exports PYTHONHASHSEED, and switches cuDNN to deterministic mode with
    autotuning (benchmark) turned off.

    Args:
        seed: integer seed applied to all generators.
    """
    # Interpreter-level and NumPy generators.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # torch generators.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # cuDNN: favour reproducibility over speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
import open_clip

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

In [142]:
class CFG_CLASS:
    """Notebook-wide configuration; every attribute is a class-level constant."""

    # --- general ---
    seed = 42
    text_emb_size = 384  # passed to create_submission as the per-image embedding length
    is_kaggle = os.environ.get('PWD') == '/kaggle/working'
    train_files_dir: str = "img2prompt-data"

    # --- CLIP checkpoint and number of labels joined into a prompt ---
    clip_model = 'ViT-B-16'
    pretrained = "laion2b_s34b_b88k"
    k = 6
    # Dashes are replaced so the name can be embedded in artifact file names.
    model_name = f"model_{clip_model}_{pretrained}_k_{k}".replace("-", "_")

    # --- resources ---
    batch_size = 16
    num_workers = 2 if is_kaggle else batch_size
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- data ---
    dataset_dupl_word = 2
    metadata_path = f"../input/metadata/metadata_duplwords_{dataset_dupl_word}.parquet"

    # Suffix used in the names of all saved artifacts.
    train_name = str(model_name)


CFG = CFG_CLASS()
CFG.train_name

'model_ViT_B_16_laion2b_s34b_b88k_k_6'

# Functions 

In [143]:
def create_submission(pred_arr, img_names, text_emb_size):
    """Flatten per-image embeddings into a long, single-column DataFrame.

    Args:
        pred_arr: array-like holding one embedding of length `text_emb_size`
            per image, in the same order as `img_names`.
        img_names: image file names; the part before the first '.' is the id.
        text_emb_size: number of embedding components per image.

    Returns:
        DataFrame with one 'val' column, indexed by '<imgId>_<component>'
        (index name 'imgId_eId'), rows ordered image by image.
    """
    stems = [name.split('.')[0] for name in img_names]

    # Image id varies slowest, component index fastest, matching the
    # row-major order produced by flatten() below.
    row_ids = [
        f"{stem}_{component}"
        for stem in stems
        for component in range(text_emb_size)
    ]
    flat_values = np.array(pred_arr).flatten()

    submission = pd.DataFrame(data=flat_values, index=row_ids, columns=['val'])
    return submission.rename_axis('imgId_eId')

def get_sim(emb1, emb2):
    """Mean cosine similarity between row-aligned embeddings.

    Args:
        emb1: a 1-D vector or a 2-D array of row vectors (must expose `.shape`).
        emb2: same layout as `emb1`; row i is compared with row i of `emb1`.

    Returns:
        Average of `1 - cosine_distance` over all row pairs.

    Raises:
        AssertionError: if either input has more than two dimensions.
        ValueError: if there are no rows to compare.
    """
    # The original checked emb1 twice and never validated emb2.
    assert len(emb1.shape) <= 2 and len(emb2.shape) <= 2, "False shape"

    # Promote single vectors to a one-row batch.
    if len(emb1.shape) == 1:
        emb1 = [emb1]
    if len(emb2.shape) == 1:
        emb2 = [emb2]

    n_pairs = len(emb1)
    if n_pairs == 0:
        # Previously this fell through to an UnboundLocalError on the loop index.
        raise ValueError("get_sim needs at least one embedding pair")

    sim_res = 0
    for i in range(n_pairs):
        sim_res += 1 - distance.cosine(emb1[i], emb2[i])
    return sim_res / n_pairs

class CustomDataSet(Dataset):
    """Dataset yielding `(image name, preprocessed image, prompt)` triples.

    Args:
        data_dir: directory the image names are resolved against.
        img2prompt: mapping from image file name to its prompt; its key order
            defines the item order.
        img_preprocess: callable applied to the opened PIL image.
    """

    def __init__(self, data_dir, img2prompt, img_preprocess):
        self.img_preprocess = img_preprocess
        self.img2prompt = img2prompt
        self.img_names = list(img2prompt.keys())
        self.data_dir = data_dir

    def __len__(self):
        """Number of images in the mapping."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Open image `idx`, preprocess it and return it with its name and prompt."""
        name = self.img_names[idx]
        image = Image.open(os.path.join(self.data_dir, name))
        # The prompt is always returned as a string, whatever type the mapping stores.
        return name, self.img_preprocess(image), str(self.img2prompt[name])

# Get labels

In [144]:
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# with a trusted copy of clip_prompts.pickle.
with open("../input/img2prompt-data/clip_prompts.pickle", "rb") as f:
    ci = pickle.load(f)

# Size of every phrase group stored in the pickle.
for group_name in ("mediums", "movements", "flavors", "negative", "artists"):
    print(len(ci[group_name]))

# Candidate labels: every group except "artists", concatenated in this order.
labels = ci["mediums"]
for group_name in ("movements", "flavors", "negative"):  # + "artists"
    labels = labels + ci[group_name]
len(labels)

95
200
100970
41
10530


101306

In [145]:
# metadata = pd.read_parquet("../input/metadata/metadata.parquet")
# metadata["prompt"] = metadata["prompt"].str.replace(".", ",").str.strip(" ,.") + ","
# prompts = metadata["prompt"].astype(str).tolist()

# general_prompt = ", ".join(prompts)

# counter = Counter()
# words = general_prompt.split(", ")
# for word in words:
#     counter[word] += 1
# counter.most_common()[:10]

# key_words = []
# for word, cnt in counter.most_common():
#     word = word.strip(", .")
#     if word and cnt / len(prompts) > 0.0001:
#         key_words.append(word)
# len(key_words)

# labels = list(set(labels).union(set(key_words)) - set([""]))
# len(labels)

# CLIP model

In [146]:
def get_labels_features(labels, batch_size, device, text_tokenizer, model):
    """Encode `labels` with the model's text encoder, one slice at a time.

    Args:
        labels: sequence of label strings.
        batch_size: number of labels tokenized and encoded per step.
        device: device the token tensors are moved to before encoding.
        text_tokenizer: callable turning a list of strings into a token tensor.
        model: object exposing `encode_text(tokens, normalize=True)`.

    Returns:
        The per-slice outputs of `encode_text` concatenated along dim 0,
        in the order of `labels`.
    """
    slice_starts = range(0, len(labels), batch_size)
    encoded_slices = []

    # No gradients are needed; encoding runs under CUDA autocast.
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            for start in tqdm(slice_starts, disable=True):
                tokens = text_tokenizer(labels[start:start + batch_size]).to(device)
                encoded_slices.append(model.encode_text(tokens, normalize=True))

    return torch.concat(encoded_slices)

In [147]:
# Build the configured CLIP checkpoint together with its image transforms.
model, img_train_transform, img_eval_transform = open_clip.create_model_and_transforms(
    CFG.clip_model,
    pretrained=CFG.pretrained,
    device=CFG.device
)
# The model is only used for inference in this notebook, so switch it to eval
# mode before encoding the label texts (and before it gets pickled).
model.eval()
text_tokenizer = open_clip.get_tokenizer(CFG.clip_model)

# Text features of every candidate label, computed once and reused for all images.
labels_features = get_labels_features(
    labels=labels,
    batch_size=CFG.batch_size,
    device=CFG.device,
    text_tokenizer=text_tokenizer,
    model=model)

In [157]:
# Persist everything the inference section reloads from disk.
artifacts_dir = f"../input/{CFG.train_files_dir}"

torch.save(labels_features, f"{artifacts_dir}/labels_features_{CFG.train_name}.torch")

# The label list, the whole model object and the eval transform are pickled as-is.
pickled_artifacts = (
    ("labels", labels),
    ("model", model),
    ("img_eval_transform", img_eval_transform),
)
for artifact_name, artifact in pickled_artifacts:
    with open(f"{artifacts_dir}/{artifact_name}_{CFG.train_name}.pickle", "wb") as f:
        pickle.dump(artifact, f)

# Validate 

In [131]:
# Sentence encoder used to embed both predicted and true prompts.
# (A stray "1" after the closing parenthesis made this line a SyntaxError.)
st_model = SentenceTransformer('../input/sentence-transformers-222/all-MiniLM-L6-v2/')

train_data_dir = Path("../input/")

# Validate on a reproducible 5% sample of the metadata.
metadata = pd.read_parquet(CFG.metadata_path).sample(frac=0.05, random_state=CFG.seed)
full_prompt = metadata[["image_name", "prompt"]].values
val_prompt_dict = {img_name: prompt for img_name, prompt in full_prompt}

val_dataset = CustomDataSet(
    data_dir=train_data_dir,
    img2prompt=val_prompt_dict,
    img_preprocess=img_eval_transform,
)

# shuffle=False keeps the batch order identical across runs.
val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False,
                            num_workers=CFG.num_workers)

In [134]:
# Grid search: for every pretrained checkpoint of the selected architectures and
# every k, score the prompt built from the top-k labels against the true prompt.
EVAL_CLIP_MODELS = ["ViT-B-16"]
K_VALUES = list(range(1, 11, 1)) + [15, 20, 30]
MAX_K = max(K_VALUES)

stats_records = []
for clip_model, pretrained in open_clip.list_pretrained():
    if clip_model not in EVAL_CLIP_MODELS:
        continue
    print(clip_model, pretrained)

    model, img_train_transform, img_eval_transform = open_clip.create_model_and_transforms(
        clip_model,
        pretrained=pretrained,
        device=CFG.device
    )
    # Eval mode must be active before ANY encoding, including the label texts.
    model.eval()
    text_tokenizer = open_clip.get_tokenizer(clip_model)

    labels_features = get_labels_features(
        labels=labels,
        batch_size=CFG.batch_size,
        device=CFG.device,
        text_tokenizer=text_tokenizer,
        model=model)

    # Image features do not depend on k, so the validation set is encoded once
    # per checkpoint; the top-MAX_K ranking is sliced for each smaller k
    # (topk returns indices sorted by descending score).
    sim_sums = {k: 0.0 for k in K_VALUES}
    n_batches = 0
    with torch.no_grad(), torch.cuda.amp.autocast():
        for img_names, img_embs, true_prompts in val_dataloader:
            img_embs = img_embs.to(CFG.device)
            img_features = model.encode_image(img_embs, normalize=True)
            labels_probs = (img_features @ labels_features.T)

            top_label_inds = labels_probs.topk(MAX_K).indices.tolist()
            true_prompts_emb = st_model.encode(true_prompts)

            for k in K_VALUES:
                pred_prompts = [
                    ", ".join(labels[label_idx] for label_idx in row[:k])
                    for row in top_label_inds
                ]
                pred_prompts_emb = st_model.encode(pred_prompts)
                sim_sums[k] += get_sim(pred_prompts_emb, true_prompts_emb)
            n_batches += 1

    if n_batches == 0:
        raise ValueError("val_dataloader produced no batches")

    for k in K_VALUES:
        # Mean of per-batch means (a smaller final batch carries the same
        # weight as a full one), as in the originally reported numbers.
        sim_mean = sim_sums[k] / n_batches
        print(k, sim_mean)
        stats_records.append({
            "clip_model": clip_model,
            "pretrained": pretrained,
            "k": k,
            "sim": sim_mean,
        })

    # Free GPU memory before the next checkpoint is loaded.
    del model, labels_features
    torch.cuda.empty_cache()
    gc.collect()

# Build the frame once instead of concatenating inside the loop.
stats = pd.DataFrame(stats_records, columns=["clip_model", "pretrained", "k", "sim"])

ViT-B-16 openai


100%|███████████████████████████████████████| 351M/351M [00:30<00:00, 11.7MiB/s]


1 0.34292937544097357
2 0.37567607023765476
3 0.39391289046517214
4 0.40432104179281225
5 0.4096818123959584
6 0.4118226677193785
7 0.41008556852124267
8 0.4062471311125276
9 0.402223092339739
10 0.3989175229902085
15 0.3939210467398633
20 0.3956030323250394
30 0.41309099773044594
ViT-B-16 laion400m_e31
1 0.35359673752181403
2 0.38629430394666486
3 0.4019524737545568
4 0.4105097735574202
5 0.4159754098087788
6 0.4171906698684565
7 0.41677811248459995
8 0.413901172398339
9 0.41046555497087234
10 0.40726665706527926
15 0.4019971742322967
20 0.40260604316933674
30 0.42082532284108953
ViT-B-16 laion400m_e32


100%|███████████████████████████████████████| 599M/599M [06:05<00:00, 1.64MiB/s]


1 0.35375026160956863
2 0.3876241106618103
3 0.4021223114503954
4 0.4110082137943597
5 0.4166425102412417
6 0.417786881096193
7 0.41711169019862643
8 0.41405286187379603
9 0.41095706333843857
10 0.4075957465231994
15 0.4017884175614433
20 0.40307631065726285
30 0.4203051190691735
ViT-B-16 laion2b_s34b_b88k


Downloading (…)ip_pytorch_model.bin:   0%|          | 0.00/599M [00:00<?, ?B/s]

1 0.3616643027156491
2 0.3895919983828296
3 0.40344584570884046
4 0.4139548495796117
5 0.4199239736589559
6 0.4221211696777065
7 0.42108609098928457
8 0.4179800748570618
9 0.4140232437101883
10 0.41067171977828115
15 0.40465672128124264
20 0.4057926514798736
30 0.4235488012796502


In [140]:
# Rank every (checkpoint, k) combination by mean similarity, best first.
stats.sort_values(by="sim", ascending=False)

Unnamed: 0,clip_model,pretrained,k,sim
51,ViT-B-16,laion2b_s34b_b88k,30,0.423549
44,ViT-B-16,laion2b_s34b_b88k,6,0.422121
45,ViT-B-16,laion2b_s34b_b88k,7,0.421086
25,ViT-B-16,laion400m_e31,30,0.420825
38,ViT-B-16,laion400m_e32,30,0.420305
43,ViT-B-16,laion2b_s34b_b88k,5,0.419924
46,ViT-B-16,laion2b_s34b_b88k,8,0.41798
31,ViT-B-16,laion400m_e32,6,0.417787
18,ViT-B-16,laion400m_e31,6,0.417191
32,ViT-B-16,laion400m_e32,7,0.417112


# Inference 

In [160]:
# Sentence encoder that turns predicted prompts into submission embeddings.
st_model = SentenceTransformer('../input/sentence-transformers-222/all-MiniLM-L6-v2/')

# map_location lets a tensor that was saved from a CUDA device be restored on
# whatever device this kernel actually has (e.g. a CPU-only session).
labels_features = torch.load(
    f"../input/{CFG.train_files_dir}/labels_features_{CFG.train_name}.torch",
    map_location=CFG.device,
)

# NOTE(review): pickle.load executes arbitrary code from the file; these
# artifacts must come from a trusted run of the saving cell of this notebook.
with open(f"../input/{CFG.train_files_dir}/labels_{CFG.train_name}.pickle", "rb") as f:
    labels = pickle.load(f)
with open(f"../input/{CFG.train_files_dir}/model_{CFG.train_name}.pickle", "rb") as f:
    model = pickle.load(f)
with open(f"../input/{CFG.train_files_dir}/img_eval_transform_{CFG.train_name}.pickle", "rb") as f:
    img_eval_transform = pickle.load(f)

text_tokenizer = open_clip.get_tokenizer(CFG.clip_model)

In [161]:
# Competition test images, processed in sorted file-name order.
test_data_dir = Path("../input/stable-diffusion-image-to-prompts/images/")
test_image_names = sorted(os.listdir(test_data_dir))

# Test images have no known prompt, so each name maps to an empty string.
test_prompt_dict = dict.fromkeys(test_image_names, "")

test_dataset = CustomDataSet(
    data_dir=test_data_dir,
    img2prompt=test_prompt_dict,
    img_preprocess=img_eval_transform,
)

# shuffle=False keeps predictions aligned with test_image_names.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=CFG.batch_size,
    num_workers=CFG.num_workers,
    shuffle=False,
)

In [166]:
model.to(CFG.device)
model.eval()

pred_arr = []       # one sentence embedding per test image, in dataloader order
clip_prompts = {}   # image name -> predicted prompt text
with torch.no_grad():
    with torch.cuda.amp.autocast():
        for img_names, img_embs, prompts in test_dataloader:
            img_features = model.encode_image(img_embs.to(CFG.device), normalize=True)
            labels_probs = img_features @ labels_features.T
            top_inds_per_image = labels_probs.topk(CFG.k).indices

            # Join the CFG.k best-matching labels of each image into one prompt.
            pred_prompts = []
            for img_name, top_label_ind in zip(img_names, top_inds_per_image):
                pred_prompt = ", ".join(labels[label_idx] for label_idx in top_label_ind)
                clip_prompts[img_name] = pred_prompt
                pred_prompts.append(pred_prompt)

            pred_arr.extend(st_model.encode(pred_prompts, show_progress_bar=False))

pred_arr = np.array(pred_arr)
pred_arr.shape

(7, 384)

In [118]:
# Flatten the per-image embeddings into the long '<imgId>_<component>' format.
submission = create_submission(
    pred_arr=pred_arr,
    img_names=test_image_names,
    text_emb_size=CFG.text_emb_size,
)
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.006771
20057f34d_1,-0.003735
20057f34d_2,0.040284
20057f34d_3,0.107507
20057f34d_4,-0.015699
...,...
f27825b2c_379,0.149944
f27825b2c_380,0.044683
f27825b2c_381,-0.004398
f27825b2c_382,-0.081970


In [None]:
# The submission file is only written when running inside the Kaggle working directory.
if CFG.is_kaggle:
    submission.to_csv(path_or_buf="submission.csv")