In [None]:
from diffusers import AutoPipelineForText2Image
import torch
import numpy as np
from pathlib import Path
import json
from PIL import Image

In [None]:
# Model identifier passed to `AutoPipelineForText2Image.from_pretrained`.
model = 'runwayml/stable-diffusion-v1-5'
# Step count recorded in each prompt's config.json.
# NOTE(review): presumably the number of denoising steps -- confirm that it is
# actually forwarded to the pipeline call, otherwise the saved config can
# disagree with what was really run.
steps = 50

In [None]:
# Load the pretrained text-to-image pipeline with float16 weights, then move it
# onto the CUDA device in a separate step.
pipe = AutoPipelineForText2Image.from_pretrained(
    model,
    torch_dtype=torch.float16,
)
pipe = pipe.to('cuda')

In [None]:
def get_reprs(prompt):
    """Run the pipeline on ``prompt`` and capture per-step internals.

    A forward hook on ``pipe.unet.mid_block`` stores the first element of the
    block's output each time the block runs, and a per-step callback decodes
    the current latents with the pipeline's VAE into a PIL image.

    Args:
        prompt: Text prompt passed to ``pipe``.

    Returns:
        A ``(reprs, imgs)`` tuple. ``reprs`` is a list of NumPy arrays, one per
        hooked forward pass of the mid block. ``imgs`` is a list of PIL images
        produced by the step callback.
    """
    reprs = []
    imgs = []

    def get_repr(module, input, output):
        # NOTE(review): index 0 selects the first entry of the mid-block output.
        # With classifier-free guidance the UNet batch presumably stacks the
        # unconditional and the conditional sample, in which case this would be
        # the unconditional branch -- confirm this is the intended sample.
        reprs.append(output[0].cpu().numpy())

    def latents_callback(i, t, latents):
        # Undo the latent scaling using the VAE's own configured factor instead
        # of a hard-coded constant, so other checkpoints decode correctly too.
        latents = 1 / pipe.vae.config.scaling_factor * latents
        image = pipe.vae.decode(latents).sample[0]
        # Rescale (presumably from [-1, 1]) and clamp into [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        # Reorder axes (assumed CHW -> HWC) before converting to PIL.
        image = image.cpu().permute(1, 2, 0).numpy()
        imgs.extend(pipe.numpy_to_pil(image))

    # The hook handle is used as a context manager so the hook is removed again
    # once the pipeline call finishes (or raises).
    with pipe.unet.mid_block.register_forward_hook(get_repr):
        # Pass the configured step count explicitly so the run matches the
        # `steps` value that is written to config.json.
        # NOTE(review): `callback` / `callback_steps` are deprecated in recent
        # diffusers releases in favour of `callback_on_step_end`.
        pipe(
            prompt,
            num_inference_steps=steps,
            callback=latents_callback,
            callback_steps=1,
        )
    return reprs, imgs

# Maps a short name (used as the per-prompt output directory name) to the text
# prompt that is fed to the pipeline.
prompts = {
    "Dog": "A photo of a dog.",
    "Cat": "A photo of a cat.",
    "Polarbear": "A photo of a polar bear.",
    "FuturisticCityscape": "A futuristic cityscape at sunset, with flying cars and towering skyscrapers, in the style of cyberpunk.",
    "MountainLandscape": "A serene mountain landscape with a crystal-clear lake in the foreground, reflecting the snow-capped peaks under a bright blue sky.",
    "SpaceAstronaut": "An astronaut floating in the vastness of space, with a colorful nebula and distant galaxies in the background.",
    "MedievalMarketplace": "A medieval marketplace bustling with people, stalls filled with fruits, vegetables, and handmade goods, with a castle in the distance.",
    "MajesticLion": "A close-up portrait of a majestic lion, with detailed fur and piercing eyes, set against the backdrop of the African savannah at dusk.",
    "AbstractCubism": "An abstract painting featuring swirling colors and geometric shapes, evoking the style of cubism.",
    "VintageStreet": "A vintage 1950s street scene, with classic cars, neon signs, and pedestrians dressed in period attire.",
    "MagicalForest": "A magical forest filled with glowing plants, mythical creatures, and a pathway leading to an enchanted castle.",
    "SoccerGoal": "A dynamic sports scene capturing the moment a soccer player scores a goal, with the crowd cheering in the background.",
    "JapaneseGarden": "A traditional Japanese garden in spring, complete with cherry blossoms, a koi pond, and a wooden bridge.",
}
# Root directory for all generated artefacts; each prompt gets its own
# sub-directory named after its key in `prompts`.
base_path = Path('representations')
base_path.mkdir(exist_ok=True)
outputs = []
for name, prompt in prompts.items():
    print(f'Prompt: {prompt}')
    save_path = base_path / name
    save_path.mkdir(exist_ok=True)
    reprs, imgs = get_reprs(prompt)
    # saving representations: all captured arrays are stacked along a new
    # leading axis and written as raw float32 bytes (no header).
    repr_array = np.stack(reprs).astype(np.float32)
    with open(save_path / 'repr.bin', 'wb') as f:
        f.write(repr_array.tobytes())
    # saving result: one PNG per callback image, numbered from 1
    for j, img in enumerate(imgs, 1):
        img.save(save_path / f'{j}.png')
    # save config; the shape and dtype are stored alongside the run settings
    # because the raw repr.bin dump cannot be reshaped without them.
    config = {
        'model': model,
        'steps': steps,
        'prompt': prompt,
        'repr_shape': list(repr_array.shape),
        'repr_dtype': str(repr_array.dtype),
    }
    with open(save_path / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(config, f)
    outputs.append(save_path)
# Keep the generated artefacts out of version control.
with open(base_path / '.gitignore', 'w', encoding='utf-8') as f:
    f.write('*')
# Index of all per-prompt output directories, one path per line.
with open(base_path / 'outputs.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(str(x) for x in outputs))