In [1]:
# Install pinned wheels from the attached Kaggle input datasets; --no-index stops pip from querying PyPI.
# transformers / huggingface_hub wheels come from the LAVIS dataset (installed without their dependencies);
# open-clip-torch and safetensors are resolved only from the wheels in their --find-links directories.
!pip install --no-index --no-deps /kaggle/input/lavis-pretrained/salesforce-lavis/transformers* 
!pip install --no-index --no-deps /kaggle/input/lavis-pretrained/salesforce-lavis/hugging*
!pip install --no-index --find-links /kaggle/input/skt-clip-interrogator/skt-clip-interrogator/open-clip-torch-2.16.0/ open-clip-torch==2.16.0
!pip install --no-index --find-links /kaggle/input/skt-clip-interrogator/skt-clip-interrogator/safetensors-0.3.0/ safetensors==0.3.0

Processing /kaggle/input/lavis-pretrained/salesforce-lavis/transformers-4.26.1-py3-none-any.whl
transformers is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.
[0mProcessing /kaggle/input/lavis-pretrained/salesforce-lavis/huggingface_hub-0.12.0-py3-none-any.whl
Installing collected packages: huggingface-hub
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface-hub 0.12.1
    Uninstalling huggingface-hub-0.12.1:
      Successfully uninstalled huggingface-hub-0.12.1
Successfully installed huggingface-hub-0.12.0
[0mLooking in links: /kaggle/input/skt-clip-interrogator/skt-clip-interrogator/open-clip-torch-2.16.0/
Processing /kaggle/input/skt-clip-interrogator/skt-clip-interrogator/open-clip-torch-2.16.0/open_clip_torch-2.16.0-py3-none-any.whl
Processing /kaggle/input/skt-clip-interrogator/skt-clip-interrogator/open-clip-torch-2.16.0/ftfy-6.1.1-py3-none-any.whl
Installin

In [2]:
import os
import sys
import time
from functools import lru_cache
from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# sentence-transformers is vendored in an input dataset; extend sys.path before importing it.
sys.path.append('/kaggle/input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

import open_clip
from safetensors.numpy import load_file
from transformers import AutoProcessor, BlipForConditionalGeneration

In [3]:
class CLIP_Dataset(Dataset):
    """Map-style dataset yielding ``(preprocessed_image, caption)`` pairs for CLIP encoding.

    ``image_paths[i]`` and ``captions[i]`` must describe the same image; items are
    returned in index order, so a non-shuffling DataLoader preserves that pairing.
    """

    def __init__(self, image_paths, captions, preprocess):
        """
        Args:
            image_paths: sequence of image file paths.
            captions: sequence of captions, one per entry of ``image_paths``.
            preprocess: callable applied to each RGB ``PIL.Image`` (e.g. a CLIP image transform).

        Raises:
            ValueError: if ``image_paths`` and ``captions`` have different lengths.
        """
        # A mismatch would otherwise surface later as an IndexError or silently drop images.
        if len(image_paths) != len(captions):
            raise ValueError(
                f"image_paths and captions must have the same length "
                f"({len(image_paths)} != {len(captions)})"
            )
        self.image_paths = image_paths
        self.captions = captions
        self.preprocess = preprocess

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # Context manager releases the file handle even if decoding/conversion fails.
        with Image.open(self.image_paths[idx]) as image:
            rgb_image = image.convert("RGB")
        return self.preprocess(rgb_image), self.captions[idx]


def setup_blip_model(model_dir="/kaggle/input/blip-pretrained-model/blip-image-captioning-large"):
    """Load the BLIP image-captioning processor and model from ``model_dir``.

    Args:
        model_dir: directory of a saved BLIP captioning checkpoint; defaults to the
            ``blip-image-captioning-large`` copy in the attached Kaggle dataset.

    Returns:
        ``(blip_processor, blip_model)``. The model is not moved to any device here;
        the caller is responsible for that.
    """
    # Processor and model must come from the same checkpoint; one path keeps them in sync.
    blip_processor = AutoProcessor.from_pretrained(model_dir)
    blip_model = BlipForConditionalGeneration.from_pretrained(model_dir)
    return blip_processor, blip_model


def set_up_clip_model(device):
    """Build OpenCLIP ViT-H-14, load the LAION-2B checkpoint, and create its tokenizer/preprocess.

    Args:
        device: ``torch.device`` or device string (``"cuda"``, ``"cuda:0"``, ``"cpu"``, ...).
            It only selects the weight precision (fp16 on CUDA, fp32 otherwise); the model
            is not moved to ``device`` here.

    Returns:
        ``(clip_model, clip_tokenizer, clip_preprocess)`` where ``clip_preprocess`` is the
        non-training image transform sized for the model's visual tower.
    """
    # Compare the device *type*: callers pass a torch.device, and
    # torch.device("cuda") == "cuda" evaluates to False, which silently forced fp32 on GPU.
    is_cuda = torch.device(device).type == 'cuda'
    clip_model = open_clip.create_model('ViT-H-14', precision='fp16' if is_cuda else 'fp32')
    open_clip.load_checkpoint(
        clip_model,
        "/kaggle/input/skt-clip-interrogator/skt-clip-interrogator/models/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
    )
    clip_tokenizer = open_clip.get_tokenizer('ViT-H-14')
    clip_preprocess = open_clip.image_transform(
        clip_model.visual.image_size,
        is_train=False,
        # None (attribute absent) defers to image_transform's own default normalization stats.
        mean=getattr(clip_model.visual, 'image_mean', None),
        std=getattr(clip_model.visual, 'image_std', None),
    )
    return clip_model, clip_tokenizer, clip_preprocess

In [4]:
def load_labels(file_path):
    """Return the lines of a UTF-8 label file, each stripped of surrounding whitespace.

    Blank lines are kept as empty strings so positions stay aligned with any
    per-line data indexed in the same order. Undecodable bytes become U+FFFD.
    """
    with open(file_path, 'r', encoding='utf-8', errors='replace') as label_file:
        return [raw_line.strip() for raw_line in label_file]


@lru_cache(maxsize=None)
def load_labels_and_features(label_dir, device):
    """Load the medium / movement / flavor label lists and their precomputed embeddings.

    ``interrogate`` calls this once per image, so results are memoized per
    ``(label_dir, device)``: the label files and ``.safetensors`` embeddings are read and
    copied to ``device`` only once. The returned dict is shared between calls, so treat
    it as read-only.

    Args:
        label_dir: path prefix (ending in a separator) containing ``mediums``, ``movements``
            and ``flavors`` as ``.txt`` label lists plus matching ``.safetensors`` files
            with an ``embeds`` entry.
        device: ``torch.device`` or device string the embedding tensors are placed on.

    Returns:
        ``{"labels": {...}, "features": {...}}``, each keyed by ``"medium"``, ``"movements"``
        and ``"flavors"``; row ``i`` of ``features[k]`` is the embedding of ``labels[k][i]``.
    """
    # (result key, file stem). The singular "medium" key is kept for existing callers.
    categories = (("medium", "mediums"), ("movements", "movements"), ("flavors", "flavors"))
    labels, features = {}, {}
    for key, stem in categories:
        labels[key] = load_labels(f"{label_dir}{stem}.txt")
        embeds = load_file(f"{label_dir}{stem}.safetensors")['embeds']
        # One copy of the whole array; same dtype/shape as stacking its rows one by one.
        features[key] = torch.tensor(embeds, device=device)
    return {"labels": labels, "features": features}


def prompt_at_max_len(text, tokenize):
    """Report whether ``text`` fills the tokenizer's fixed-length output row.

    ``tokenize`` is called on a one-element batch; the result is truthy when the last
    position of that row holds a non-zero token id (0 is presumably the padding id).
    """
    token_rows = tokenize([text])
    final_token = token_rows[0][-1]
    return final_token != 0

def truncate_to_fit(text, tokenize):
    """Drop trailing ``", "``-separated parts of ``text`` so it fits the tokenizer's context.

    Parts are appended greedily in order, starting from the first part (always kept).
    A part is accepted only if the text *including* the ``", "`` separator it adds still
    passes ``prompt_at_max_len``. Accumulation stops at the first part that does not fit.

    Args:
        text: comma-space separated prompt.
        tokenize: tokenizer callable forwarded to ``prompt_at_max_len``.

    Returns:
        The accepted leading parts joined by ``", "``.
    """
    first, *rest = text.split(', ')
    new_text = first
    for part in rest:
        candidate = f"{new_text}, {part}"
        # Check the exact string that would be kept. Checking ``new_text + part`` (no
        # separator) under-counts tokens and can let the prompt overflow the context.
        if prompt_at_max_len(candidate, tokenize):
            break
        new_text = candidate
    return new_text


def _top_labels(image_features, features, labels, k):
    """Return the ``k`` labels whose embeddings are most cosine-similar to ``image_features``, best first."""
    # (D,) against (N, D) broadcasts to one similarity per label row.
    similarity = torch.nn.functional.cosine_similarity(image_features, features, dim=1)
    return [labels[i] for i in similarity.topk(k).indices]


def interrogate(image_features, caption, clip_tokenizer, label_dir, device, labels_and_features=None):
    """Extend a caption with the best-matching medium, movement and top-3 flavors.

    Args:
        image_features: CLIP embedding of a single image (one row of ``encode_image``
            output at this notebook's call site).
        caption: BLIP caption for the same image.
        clip_tokenizer: tokenizer used to truncate the final prompt to the CLIP context.
        label_dir: label directory passed to ``load_labels_and_features``.
        device: device the label embeddings live on.
        labels_and_features: optional preloaded result of ``load_labels_and_features``;
            pass it to avoid reloading labels for every image.

    Returns:
        ``"caption, [medium, ]movement, flavor1, flavor2, flavor3"`` truncated to fit. The
        medium is omitted when the caption already starts with it.
    """
    if labels_and_features is None:
        labels_and_features = load_labels_and_features(label_dir, device)
    labels = labels_and_features["labels"]
    features = labels_and_features["features"]

    medium = _top_labels(image_features, features["medium"], labels["medium"], 1)[0]
    movement = _top_labels(image_features, features["movements"], labels["movements"], 1)[0]
    flaves = ", ".join(_top_labels(image_features, features["flavors"], labels["flavors"], 3))

    # Avoid repeating the medium when BLIP already opened the caption with it.
    if caption.startswith(medium):
        prompt = f"{caption}, {movement}, {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, {flaves}"

    return truncate_to_fit(prompt, clip_tokenizer)


def get_prompt_embeddings(comp_path, model_path='/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2'):
    """Embed the competition's example prompts and verify them against ``sample_submission.csv``.

    Args:
        comp_path: competition data directory (``str`` or ``Path``) containing ``prompts.csv``
            (indexed by ``imgId``, with a ``prompt`` column) and ``sample_submission.csv``
            (indexed by ``imgId_eId``, with a ``val`` column).
        model_path: sentence-transformers model used to embed the prompts.

    Returns:
        The prompt embeddings flattened into one 1-D array (prompt 0's values, then prompt 1's, ...).

    Raises:
        AssertionError: if the embeddings are not close to ``sample_submission['val']``
            (``atol=1e-07``).
    """
    comp_path = Path(comp_path)
    prompts = pd.read_csv(comp_path / 'prompts.csv', index_col='imgId')
    sample_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')

    st_model = SentenceTransformer(model_path)
    # Pass a plain list so encode() never has to index the imgId-indexed Series by position.
    # flatten() is row-major, i.e. one prompt's full vector after another.
    prompt_embeddings = st_model.encode(prompts['prompt'].tolist()).flatten()

    assert np.allclose(sample_submission['val'].values, prompt_embeddings, atol=1e-07), \
        "prompt embeddings do not match sample_submission.csv"
    return prompt_embeddings

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

blip_processor, blip_model = setup_blip_model()
clip_model, clip_tokenizer, clip_preprocess = set_up_clip_model(device)
# TODO: document how these label files and their precomputed embeddings were produced.
label_dir = "/kaggle/input/skt-clip-interrogator/skt-clip-interrogator/labels/CLIP-ViT-H-14-laion2B-s32B-b79K/"
images_root = '/kaggle/input/stable-diffusion-image-to-prompts/images'

# glob() yields directory order, which is arbitrary; sort so outputs come out in a stable order.
image_paths = sorted(glob(f'{images_root}/*'))

start_time = time.time()

# Stage 1: BLIP captions, one per entry of image_paths (shuffle=False keeps the order).
blip_model.to(device)

blip_data_loader = DataLoader(image_paths, batch_size=128, shuffle=False)
gen_kwargs = {"max_length": 20, "min_length": 5}
blip_captions = []

for batch in tqdm(blip_data_loader):
    images = [Image.open(image_path).convert("RGB") for image_path in batch]
    pixel_values = blip_processor(images=images, return_tensors="pt").pixel_values.to(device)
    out = blip_model.generate(pixel_values=pixel_values, **gen_kwargs)
    blip_captions.extend(blip_processor.batch_decode(out, skip_special_tokens=True))

# Move BLIP off the accelerator before CLIP goes there, and release its cached GPU blocks.
blip_model.to("cpu")
if device.type == "cuda":
    torch.cuda.empty_cache()

# Stage 2: CLIP image features, then label ranking to build the final prompt per image.
clip_model.to(device)

clip_dataset = CLIP_Dataset(image_paths, blip_captions, clip_preprocess)
clip_data_loader = DataLoader(clip_dataset, batch_size=128, shuffle=False)

pred_prompts = []

for processed_images, captions in tqdm(clip_data_loader):
    processed_images = processed_images.to(device)

    # Mixed precision only on CUDA; on CPU autocast is disabled, as the old cuda-only context was.
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        image_features = clip_model.encode_image(processed_images)

    for image_feature, caption in zip(image_features, captions):
        pred_prompts.append(interrogate(image_feature, caption, clip_tokenizer, label_dir, device))

print("--- %s seconds ---" % (time.time() - start_time))
print(pred_prompts)

100%|██████████| 1/1 [00:10<00:00, 10.60s/it]
100%|██████████| 1/1 [00:10<00:00, 10.11s/it]

--- 32.01980996131897 seconds ---
['arafed image of a man standing in front of a counter top, a digital rendering, conceptual art, the mighty donut, donut, at the counter', 'cartoon dinosaur - like dinosaur - like dinosaur with a cheeseburger, an illustration of, sumatraism, mmmmm, buttercup eating pizza, pastry lizard', 'a drawing of a robot robot with a robot on it, a screenprint, art brut, ((robot)), robot cat, robot design', 'a close up of a circular shaped object with a circular object in the middle, concept art, conceptual art, crater, studying a hell open rift portal, abstract holescape', 'a man in a white astronaut suit standing in front of a tree, a portrait, space art, american astronaut in the forest, astronaut walking, american astronaut', 'a close up of a circular wooden sculpture of a rose, a woodcut, op art, whorl, swirling around, wood art', 'painting of a man with a dragonfly on his head and a lizard on his head, a surrealist painting, magic realism, magic realism pain




In [6]:
comp_path = Path("/kaggle/input/stable-diffusion-image-to-prompts")
prompts = pd.read_csv(comp_path / 'prompts.csv', index_col='imgId')
# Only a cell's last expression is rendered automatically, so show previews explicitly.
display(prompts.head(7))

# One row per (image, embedding dimension): 7 images * 384 (feature_dim) rows.
sample_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
display(sample_submission.head())

st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')
# Pass a plain list so encode() never has to index the imgId-indexed Series by position.
prompt_embeddings = st_model.encode(prompts['prompt'].tolist()).flatten()

# Sanity check: the local model reproduces the reference embeddings in sample_submission.
assert np.allclose(sample_submission['val'].values, prompt_embeddings, atol=1e-07)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
images = os.listdir(comp_path / 'images')
imgIds = [file_name.split('.')[0] for file_name in images]

# Each image contributes one submission row per sentence-embedding dimension.
EMBEDDING_LENGTH = 384
eIds = list(range(EMBEDDING_LENGTH))

# Row ids are "<imgId>_<dim>", image-major: every dimension of the first image, then the next, ...
imgId_eId = [f"{img_id}_{dim}" for img_id in imgIds for dim in eIds]

# The generated ids must match sample_submission's index exactly (compared as sorted lists).
assert sorted(imgId_eId) == sorted(sample_submission.index)

In [8]:
# Submit the interrogator's predictions (`pred_prompts`) instead of a constant placeholder sentence.
# `pred_prompts` follows `image_paths` (glob) order while submission rows follow `imgIds`
# (os.listdir order), so look prompts up by image id instead of assuming both orders agree.
prompt_by_img_id = {
    os.path.basename(path).split('.')[0]: prompt
    for path, prompt in zip(image_paths, pred_prompts)
}
submission_prompts = [prompt_by_img_id[img_id] for img_id in imgIds]
prompt_embeddings = st_model.encode(submission_prompts).flatten()

# One row per (image, embedding dimension), aligned with the image-major `imgId_eId` ids.
submission = pd.DataFrame(
    index=imgId_eId,
    data=prompt_embeddings,
    columns=['val'],
).rename_axis('imgId_eId')

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [9]:
# Write the submission (index `imgId_eId`, column `val`) to submission.csv in the working directory.
submission.to_csv('submission.csv')