In [1]:
import os
import sys
import numpy as np
import pandas as pd
from pathlib import Path

import torch
from torch import nn
import torchvision
from torchvision.models import resnet34
from torchvision import transforms, datasets

from PIL import Image

sys.path.append('../../data/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

comp_path = Path('../../data/stable-diffusion-image-to-prompts/')

# The sample submission is only used below as the reference set of row ids.
sample_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')
sample_submission.head()

# Image ids are the file names up to the first dot.
images = os.listdir(comp_path / 'images')
imgIds = [file_name.split('.')[0] for file_name in images]

EMBEDDING_LENGTH = 384
TOTAL_IMAGES = len(imgIds)
eIds = list(range(EMBEDDING_LENGTH))

# One "<imgId>_<eId>" row id per (image, embedding slot), grouped by image.
imgId_eId = [f'{img_id}_{e_id}' for img_id in imgIds for e_id in eIds]

# The generated ids must be exactly the ids the sample submission lists.
assert sorted(imgId_eId) == sorted(sample_submission.index)

In [2]:
# Load the sentence-embedding model from the local copy, then move it to DEVICE.
st_model = SentenceTransformer('../../data/sentence-transformers-222/all-MiniLM-L6-v2')
st_model = st_model.to(DEVICE)

In [3]:
class DiffusionImageDataset(torch.utils.data.Dataset):
    """Dataset over every entry of a directory, optionally paired with prompts.

    Args:
        data_path: directory object exposing ``iterdir()`` (e.g. a
            ``pathlib.Path``). Every entry it yields is kept, in the order
            ``iterdir()`` reports it; no filtering by extension is done.
        prompts: optional indexable; ``prompts[idx]`` is returned together
            with the image at position ``idx``. ``None`` or an empty
            container means "no prompts".
        transform: optional callable applied to the opened image.
    """

    def __init__(self, data_path, prompts=None, transform=None):
        self.image_paths = [str(entry) for entry in data_path.iterdir()]
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        """Return the number of directory entries collected at construction."""
        return len(self.image_paths)

    def _has_prompts(self):
        """Tell whether prompts were supplied, without relying on truthiness.

        ``bool()`` on a multi-element numpy array or pandas Series raises
        ``ValueError``, so presence is decided by ``None``/length instead.
        Objects without ``__len__`` are treated as present, matching the
        default truthiness they had before.
        """
        if self.prompts is None:
            return False
        try:
            return len(self.prompts) > 0
        except TypeError:
            return True

    def __getitem__(self, idx):
        """Return the image at ``idx``; ``(image, prompt)`` when prompts exist."""
        x = Image.open(self.image_paths[idx])
        if self.transform is not None:
            x = self.transform(x)

        if self._has_prompts():
            return x, self.prompts[idx]
        return x

In [4]:
BATCH_SIZE = 64
data_path = comp_path.joinpath('images')

# The only preprocessing step is converting each opened image to a tensor.
test_transform = transforms.Compose([transforms.ToTensor()])
test_dataset = DiffusionImageDataset(data_path, transform=test_transform)

# shuffle=False keeps batches in the dataset's own file order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
)

In [5]:
from transformers import AutoProcessor, AutoModelForCausalLM

# Local checkpoint directory shared by the processor and the captioning model.
CAPTION_MODEL_DIR = "../../data/image-caption-models/git-large-r"

processor = AutoProcessor.from_pretrained(CAPTION_MODEL_DIR)
# Use the notebook-wide DEVICE instead of a hard-coded "cuda" so the cell also
# runs on machines without a GPU.
model = AutoModelForCausalLM.from_pretrained(CAPTION_MODEL_DIR).to(DEVICE)

In [6]:
def get_prompts(model, processor, data_iter, max_length=50):
    """Generate one caption per image for every batch in ``data_iter``.

    Args:
        model: captioning model exposing ``generate(pixel_values=..., max_length=...)``.
        processor: callable that turns a batch of images into an object with a
            ``pixel_values`` tensor, and that offers ``batch_decode``.
        data_iter: iterable of image batches, consumed in order.
        max_length: forwarded to ``model.generate``; defaults to the previously
            hard-coded value of 50.

    Returns:
        list[str]: decoded captions with trailing whitespace removed, in the
        same order as the images were yielded.
    """
    # Follow the model's own device rather than assuming "cuda", so inputs and
    # weights always live on the same device (and CPU-only runs work).
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    prompts = []
    with torch.no_grad():
        for images in data_iter:
            # NOTE(review): if data_iter yields float tensors already scaled to
            # [0, 1] (e.g. torchvision ToTensor output), the processor may
            # rescale them a second time — confirm, and pass do_rescale=False
            # if that is the case.
            batch = processor(images=images, return_tensors="pt")
            pixel_values = batch.pixel_values.to(device)
            generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
            decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)

            prompts.extend(text.rstrip() for text in decoded)

    return prompts

In [7]:
prompts = get_prompts(model, processor, test_loader)
# encode() output is flattened so each prompt contributes one row per embedding value.
prompt_embeddings = st_model.encode(prompts).flatten()

# Build the row ids from the dataset's own file order. test_loader does not
# shuffle, so prompts follow test_dataset.image_paths; deriving the ids from the
# same list keeps every embedding attached to the image it was computed from,
# instead of relying on a second directory listing returning the same order.
ordered_img_ids = [Path(p).name.split('.')[0] for p in test_dataset.image_paths]
submission_index = [
    f'{img_id}_{e_id}'
    for img_id in ordered_img_ids
    for e_id in range(EMBEDDING_LENGTH)
]

submission = pd.DataFrame(
                index=submission_index,
                data=prompt_embeddings,
                columns=['val']).rename_axis('imgId_eId')

submission.head()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
f27825b2c_0,-0.007689
f27825b2c_1,0.101176
f27825b2c_2,-0.031015
f27825b2c_3,-0.044353
f27825b2c_4,-0.032887


In [9]:
# Write the submission; the imgId_eId index is written as the first column.
submission.to_csv(path_or_buf='submission.csv', index=True)