In [1]:
import os, os.path as osp
import shutil
from typing import Optional, Dict

import pandas as pd
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
from tqdm.cli import tqdm

In [2]:
# Folder that holds the input CSV; the image output folder is placed inside it.
data_dir = '../data'
# CSV file (inside data_dir) that is loaded into the prompt table.
filename = 'low_cosine_text.csv'
# Name of the folder (inside data_dir) that receives the generated images.
out_directory = 'generated_images'
# When False, prompts whose output folder already contains files are skipped.
overwrite_images = False

In [3]:
# Resolve the CSV path and the image output folder relative to data_dir.
file_path = osp.join(data_dir, filename)
# Re-running this cell must not prepend data_dir a second time, so only join
# when out_directory is not already located under data_dir.
if not osp.normpath(out_directory).startswith(osp.normpath(data_dir) + os.sep):
    out_directory = osp.join(data_dir, out_directory)

In [4]:
# Load the prompt table; each row supplies one prompt to generate images for.
df = pd.read_csv(file_path)
# Fail fast, before the long generation loop, if the prompt column is missing.
assert 'prompt' in df.columns, f"'prompt' column not found in {file_path}"

In [5]:
def get_input(prompt: str, batch_size: Optional[int]=1, num_inference_steps: int=50,
              height: int=768, width: int=768, device: str='cuda') -> Dict:
    """Build the keyword arguments for one batched pipeline call.

    The prompt is repeated ``batch_size`` times and paired with one
    ``torch.Generator`` per batch element, seeded 0, 1, ..., batch_size - 1,
    so repeated calls always produce the same seeds.

    Args:
        prompt: Text prompt to repeat across the batch.
        batch_size: Number of images to request; ``None`` is treated as 1.
        num_inference_steps: Value stored under ``'num_inference_steps'``.
        height: Value stored under ``'height'``.
        width: Value stored under ``'width'``.
        device: Device on which the generators are created.

    Returns:
        Dict with the keys ``'prompt'``, ``'generator'``,
        ``'num_inference_steps'``, ``'height'`` and ``'width'``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # The annotation allows None, but range(None) would raise a TypeError.
    if batch_size is None:
        batch_size = 1
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    generator = [torch.Generator(device).manual_seed(seed) for seed in range(batch_size)]

    return {
        'prompt': batch_size * [prompt],
        'generator': generator,
        'num_inference_steps': num_inference_steps,
        'height': height,
        'width': width,
    }

In [6]:
# Load the Stable Diffusion 2 checkpoint in half precision and move it to the GPU.
ckpt = 'stabilityai/stable-diffusion-2'
pipe = StableDiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16).to('cuda')
# Turn on attention slicing for the loaded pipeline.
pipe.enable_attention_slicing()
# Optional (left disabled): swap in the DPM-Solver multistep scheduler.
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [7]:
# Peek at the first row to check the columns before generating images.
df.iloc[0]

imageid          00049706-717d-437b-b1b0-16cc39dab0dc
score                                        0.250753
prompt     vector mouse cursor in the style of diablo
Name: 0, dtype: object

In [8]:
# Create the output folder (and any missing parents); a no-op if it already exists.
os.makedirs(out_directory, exist_ok=True)

In [10]:
# Generate a batch of images for every prompt in the table.
# Each row gets its own sub-folder of out_directory, named after the row's
# positional index. Rows whose folder already holds files are skipped unless
# overwrite_images is set, so an interrupted run can be resumed.
for i in tqdm(range(len(df))): # Not vectorising on purpose
    prompt = df.iloc[i].prompt

    # out_directory already includes data_dir, so it is not joined in again.
    generated_dir = osp.join(out_directory, str(i))
    if osp.exists(generated_dir):
        if overwrite_images:
            # os.rmdir only removes empty folders; rmtree also clears the
            # images left behind by a previous run.
            shutil.rmtree(generated_dir)
        elif len(os.listdir(generated_dir)) > 0:
            # Skip to next prompt
            continue
    # Covers both a missing folder and one that was just removed above.
    os.makedirs(generated_dir, exist_ok=True)

    # Generate images
    inputs = get_input(prompt, batch_size=4)
    images = pipe(**inputs)

    for idx, image in enumerate(images.images):
        image.save(osp.join(generated_dir, f'image_{idx}.png'))

  0%|          | 0/2661 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 1/2661 [00:32<24:15:36, 32.83s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 2/2661 [01:05<23:59:46, 32.49s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 3/2661 [01:37<23:49:42, 32.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 4/2661 [02:09<23:45:09, 32.18s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 5/2661 [02:41<23:45:31, 32.20s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 5/2661 [02:51<25:16:03, 34.25s/it]


KeyboardInterrupt: 