In [1]:
from pathlib import PosixPath

from PIL import Image
import numpy as np
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

## 1 - Load inference model

In [2]:
# Directory holding the fine-tuned checkpoint; the processor is loaded from
# the same location below.
model_dir = '../data/models/fine_tuned/'

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the weights as float16 and move the whole model onto `device`.
# NOTE(review): max_memory is normally consulted together with device_map;
# no device_map is passed here, so it may have no effect — confirm.
model = Blip2ForConditionalGeneration.from_pretrained(
    model_dir, 
    max_memory={"cpu": "1GIB"}, 
    offload_state_dict=True, 
    # dtorch_device=device,
    torch_dtype=torch.float16).to(device)
# Processor used both to prepare model inputs and to decode generated ids.
processor = AutoProcessor.from_pretrained(model_dir)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
def inference(image, pre_prompt=None, dtype=torch.float16):
    """Generate one caption/prompt for a single image.

    Args:
        image: Image handed to the module-level ``processor``.
        pre_prompt: Optional text passed to the processor as ``text``.
        dtype: Dtype forwarded to ``.to(device, dtype)`` on the processor
            output.

    Returns:
        The first decoded sequence, stripped of surrounding whitespace.
    """
    encoded = processor(image, text=pre_prompt, return_tensors="pt")
    encoded = encoded.to(device, dtype)

    # Beam search with 3 beams, capped at 128 newly generated tokens.
    token_ids = model.generate(**encoded, max_new_tokens=128, num_beams=3)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)

    first_caption = decoded[0]
    return first_caption.strip()

In [4]:
def batch_inference(images, pre_prompt=None, dtype=torch.float16):
    """Generate one caption/prompt per image in a batch.

    Args:
        images: A list/tuple of images (anything the module-level
            ``processor`` accepts is passed through unchanged).
        pre_prompt: Optional text passed to the processor as ``text``. A
            single string is repeated once per image so that the text batch
            matches the image batch; a list is forwarded as given.
        dtype: Dtype forwarded to ``.to(device, dtype)`` on the processor
            output.

    Returns:
        A list of decoded, whitespace-stripped strings, one per generated
        sequence. An empty list/tuple of images yields ``[]``.
    """
    is_sequence = isinstance(images, (list, tuple))

    # Nothing to caption: skip the processor/model entirely rather than
    # handing them an empty batch.
    if is_sequence and len(images) == 0:
        return []

    # One shared string prompt would otherwise produce a text batch of size 1
    # next to an image batch of size N.
    if is_sequence and isinstance(pre_prompt, str):
        pre_prompt = [pre_prompt] * len(images)

    inputs = processor(images, text=pre_prompt, return_tensors="pt").to(device, dtype)

    # Beam search with 3 beams, capped at 128 newly generated tokens.
    generated_ids = model.generate(**inputs, max_new_tokens=128, num_beams=3)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return [text.strip() for text in decoded]

## 2 - Setup inference images

In [5]:
# Root directory that is searched recursively for PNG inputs.
data_dir = PosixPath('../data/images')
# Lazy, single-use generator over every *.png below data_dir.
# NOTE(review): no other cell visible here consumes `images` (ImageDataset
# runs its own glob) — confirm whether this is still needed.
images = data_dir.glob('**/*.png')

In [6]:
class ImageDataset(Dataset):
    """Dataset over every ``*.png`` found recursively below a directory.

    Each item is an ``(image, image_id)`` pair, where ``image_id`` is the
    file name up to its first dot.
    """

    def __init__(self, path: str):
        """Collect the PNG paths below ``path``.

        Args:
            path: Directory searched recursively for ``*.png`` files.
        """
        self.path = PosixPath(path)
        # glob() yields entries in filesystem order, which varies between
        # machines and runs; sort so item indices are reproducible.
        self.images = sorted(self.path.glob('**/*.png'))

    def __len__(self):
        """Return the number of PNG files found."""
        return len(self.images)

    def __getitem__(self, idx: int):
        """Load the image at position ``idx``.

        Returns:
            A tuple ``(image, image_id)``. The image is fully read into
            memory and detached from the file on disk.
        """
        image_path = self.images[idx]
        # Image.open is lazy; copy the pixel data while the file is open so
        # the handle is released here instead of whenever a consumer happens
        # to trigger the load.
        with Image.open(image_path) as handle:
            image = handle.copy()
        return image, image_path.name.split('.')[0]
    
def collate_fn(data):
    """Turn a sequence of ``(image, id)`` rows into a dict of two lists.

    Args:
        data: Iterable of rows where index 0 is the image and index 1 is
            its id.

    Returns:
        ``{'images': [...], 'ids': [...]}`` with entries in input order.
    """
    rows = list(data)
    return {
        'images': [row[0] for row in rows],
        'ids': [row[1] for row in rows],
    }

In [7]:
# Dataset and loader over all PNGs; order is kept stable (shuffle=False) so
# generated prompts line up with their image ids.
image_dataset = ImageDataset('../data/images')
# pin_memory is intentionally left at its default: collate_fn returns PIL
# images and strings, not tensors, so there is nothing to pin and enabling it
# only adds a warning on CPU-only machines.
image_loader = DataLoader(image_dataset, shuffle=False, batch_size=4, collate_fn=collate_fn)

## 3 - Run image2prompt inference on all data

In [8]:
# Caption every batch and collect (image id, generated prompt) pairs.
model.eval()  # switch the model to evaluation mode before generating
prompt_records = []
with torch.no_grad():
    for batch in image_loader:
        prompts = batch_inference(batch['images'])
        # Ids and prompts come back in the same order as the batch images.
        prompt_records.extend(zip(batch['ids'], prompts))

# Build the frame once instead of growing it row by row; an empty loader
# still yields a frame with both columns.
prompt_df = pd.DataFrame(prompt_records, columns=['imgId', 'prompt'])



In [9]:
# Display the generated prompts (one row per image id).
prompt_df

Unnamed: 0,imgId,prompt
0,20057f34d,"a black hole in the center of the earth, by mi..."
1,227ef0887,"wooden sculpture, intricate detail, 8 k"
2,92e911621,a dinosaur eating a cheese pizza
3,a4e1c55a9,"a drawing of a robot, in a style of a cartoon"
4,c98f79f71,"portrait of a man in a dinosaur costume, by gr..."
5,d8edf2e40,an astronaut in a space suit standing in front...
6,f27825b2c,a donut in a donut shop


## 4 - Load sentence embedding model

In [10]:
from sentence_transformers import SentenceTransformer, models

In [11]:
# Sentence-embedding model used to turn the generated prompts into vectors.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

In [12]:
# Length of one prompt embedding vector.
# NOTE(review): must match the output width of embedding_model — confirm if
# the model is ever swapped.
embedding_size = 384

def embed_prompts(prompts_df: pd.DataFrame) -> np.ndarray:
    """Embed the ``prompt`` column with the module-level ``embedding_model``.

    Args:
        prompts_df: Frame with a ``prompt`` column of strings.

    Returns:
        Whatever ``embedding_model.encode`` returns for the list of prompts,
        in the row order of ``prompts_df``.
    """
    # Pass a plain list: a Series is indexed by label, so positional lookups
    # made by the encoder would fail or pick the wrong row whenever the frame
    # does not have a default 0..n-1 index.
    return embedding_model.encode(prompts_df['prompt'].tolist())

def build_submission(prompts_df: pd.DataFrame) -> pd.DataFrame:
    """Flatten prompt embeddings into a long ``(imgId_eId, val)`` table.

    Every prompt contributes one row per embedding component, keyed as
    ``"<imgId>_<component index>"``. Rows are ordered prompt by prompt and,
    within a prompt, by component index.

    Args:
        prompts_df: Frame with ``imgId`` and ``prompt`` columns.

    Returns:
        A frame with columns ``imgId_eId`` and ``val`` and a 0..n-1 index.
        An empty input yields an empty frame with the same two columns.
    """
    columns = ['imgId_eId', 'val']
    prompt_embedding = np.asarray(embed_prompts(prompts_df))
    if prompt_embedding.size == 0:
        return pd.DataFrame(columns=columns)

    # Take the vector width from the data rather than a fixed constant, so a
    # different embedding model cannot cause gaps or overwritten rows.
    prompt_embedding = prompt_embedding.reshape(len(prompts_df), -1)
    n_dims = prompt_embedding.shape[1]

    img_ids = prompts_df['imgId'].tolist()
    keys = [
        f'{img_id}_{embed_idx}'
        for img_id in img_ids
        for embed_idx in range(n_dims)
    ]

    # ravel() is row-major, which matches the key order above. Building the
    # frame in one go avoids per-cell DataFrame growth.
    return pd.DataFrame({columns[0]: keys, columns[1]: prompt_embedding.ravel()})

In [13]:
# Expand each prompt's embedding into one (imgId_eId, val) row per component.
output_submission_df = build_submission(prompt_df)

In [14]:
# Display the long-format submission table.
output_submission_df

Unnamed: 0,imgId_eId,val
0,20057f34d_0,-0.023036
1,20057f34d_1,0.021516
2,20057f34d_2,0.010190
3,20057f34d_3,0.064952
4,20057f34d_4,0.015292
...,...,...
2683,f27825b2c_379,0.001352
2684,f27825b2c_380,-0.061341
2685,f27825b2c_381,-0.027693
2686,f27825b2c_382,0.035682


In [15]:
# output_submission_df.to_csv('submission.csv', index=False)