# Representation Generator for the h-space similarity explorer

In [None]:
from diffusers import AutoPipelineForText2Image, AutoencoderKL
import torch
import numpy as np
from pathlib import Path
import json
from PIL import Image
from tqdm.notebook import tqdm

## Config

In [None]:
# One entry per diffusion model to run. Fields:
#   short          - label used for the output directory and the progress bar
#   name           - identifier handed to AutoPipelineForText2Image.from_pretrained
#   steps          - number of inference steps used for every prompt
#   guidance_scale - guidance scale used for every prompt
#   vae            - 'default' to decode with the pipeline's own VAE, otherwise
#                    an identifier handed to AutoencoderKL.from_pretrained
# Key order is kept stable because these dicts are written out as JSON.
models = [
    {
        'short': 'SD-1.5',
        'name': 'runwayml/stable-diffusion-v1-5',
        'steps': 50,
        'guidance_scale': 7.5,
        'vae': 'default',
    },
    {
        'short': 'SD-Turbo',
        'name': 'stabilityai/sd-turbo',
        'steps': 2,
        'guidance_scale': 0.0,
        'vae': 'default',
    },
    {
        'short': 'SDXL-Turbo',
        'name': 'stabilityai/sdxl-turbo',
        'steps': 4,
        'guidance_scale': 0.0,
        'vae': 'stabilityai/sdxl-vae',
    },
]

# Prompt name -> prompt text. The name doubles as the per-prompt output
# directory and is stored as 'prompt_name' in each config.json, so it must
# stay filesystem-friendly. Long prompts are wrapped with implicit string
# concatenation; the resulting text is a single line.
prompts = {
    "Cat": "A photo of a cat.",
    "Dog": "A photograph of a husky, dog, looking friendly and cute.",
    "Polarbear": "A photo of a polar bear.",
    "ConstructionWorker": "A photo of a hard working construction worker.",
    "Woman": "A photo of a beautiful, slightly smiling woman in the city.",
    "OldMan": "A portrait of an old man with a long beard and a hat.",
    "FuturisticCityscape": (
        "A futuristic cityscape at sunset, with flying cars and towering "
        "skyscrapers, in the style of cyberpunk."
    ),
    "MountainLandscape": (
        "A serene mountain landscape with a crystal-clear lake in the "
        "foreground, reflecting the snow-capped peaks under a bright blue sky."
    ),
    "SpaceAstronaut": (
        "A high-res photo of an astronaut floating in the vastness of space, "
        "with a colorful nebula and distant galaxies in the background."
    ),
    "MajesticLion": (
        "A close-up portrait of a majestic lion, with detailed fur and "
        "piercing eyes, set against the backdrop of the African savannah "
        "at dusk."
    ),
    "MagicalForest": (
        "A magical forest filled with glowing plants, mythical creatures, "
        "and a pathway leading to an enchanted castle."
    ),
    "JapaneseGarden": (
        "A traditional Japanese garden in spring, complete with cherry "
        "blossoms, a koi pond, and a wooden bridge."
    ),
}

# Seed for the random generator of every run; also recorded in config.json.
seed = 0

## Generate Representations

In [None]:
def get_reprs(pipe, prompt, steps, guidance_scale, vae=None):
    """Run ``pipe`` once and collect mid-block outputs plus per-step images.

    Args:
        pipe: Pipeline to run. It must be callable with the keyword arguments
            used below and expose ``unet.mid_block``, ``vae`` and
            ``numpy_to_pil``.
        prompt: Text prompt forwarded to the pipeline.
        steps: Forwarded as ``num_inference_steps``.
        guidance_scale: Forwarded as ``guidance_scale``.
        vae: Decoder used for the intermediate images. ``None`` means
            ``pipe.vae``.

    Returns:
        ``(reprs, imgs)``: ``reprs`` has one NumPy array per forward call of
        ``pipe.unet.mid_block``; ``imgs`` has the PIL images decoded from the
        latents passed to the step callback.

    The generator is created on ``"cuda"`` and seeded with the module-level
    ``seed``.
    """
    decoder = pipe.vae if vae is None else vae
    captured = []
    frames = []

    def record_mid_block(module, inputs, output):
        # Keep element 0 of the mid-block output, copied to host memory.
        # NOTE(review): if the pipeline batches [unconditional, conditional]
        # inputs for classifier-free guidance, index 0 would be the
        # unconditional branch -- confirm that is the one wanted here.
        captured.append(output[0].cpu().numpy())

    def decode_step(step, timestep, latents):
        # Undo the VAE scaling factor in the decoder's own dtype, then decode.
        unscaled = 1 / decoder.config.scaling_factor * latents.to(dtype=decoder.dtype)
        decoded = decoder.decode(unscaled).sample[0]
        # Rescale with x / 2 + 0.5 and clamp into [0, 1].
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        # Move axis 0 last before handing the array to numpy_to_pil
        # (presumably channels-first -> channels-last; verify against the VAE).
        pixels = decoded.cpu().permute(1, 2, 0).numpy()
        frames.extend(pipe.numpy_to_pil(pixels))

    # The hook handle is used as a context manager so the hook is removed
    # again even if the pipeline call raises.
    with pipe.unet.mid_block.register_forward_hook(record_mid_block):
        pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            callback=decode_step,
            callback_steps=1,
            generator=torch.Generator("cuda").manual_seed(seed),
        )
    return captured, frames

# setup output directory
base_path = Path('representations')
base_path.mkdir(exist_ok=True)
with open(base_path / '.gitignore', 'w') as f:
    f.write('*')

# run the models
for model_dict in models:
    model_name = model_dict['name']
    model_name_short = model_dict['short']
    vae_name = model_dict['vae']

    # load model
    pipe = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.float16).to('cuda')
    pipe.set_progress_bar_config(disable=True)  # disable progress bar
    vae = AutoencoderKL.from_pretrained(vae_name).to('cuda') if vae_name != 'default' else None

    # note model h-space dimensions
    hspace_shape = get_reprs(pipe, '', 1, 0, vae)[0][0].shape
    model_dict['hspace_channels'] = hspace_shape[0]
    model_dict['hspace_spatial'] = hspace_shape[1]

    # go through prompts
    for i, (prompt_name, prompt) in enumerate(tqdm(prompts.items(), desc=f'Running {model_name_short}')):

        # setup save path
        save_path = base_path / model_name_short / prompt_name
        save_path.mkdir(exist_ok=True, parents=True)

        # run the model
        reprs, imgs = get_reprs(pipe, prompt, model_dict['steps'], model_dict['guidance_scale'], vae)

        # save representations
        with open(save_path / 'repr.bin', 'wb') as f:
            f.write(np.array(np.stack(reprs), dtype=np.float32).tobytes())

        # save result
        for j, img in enumerate(imgs, 1):
            img.save(save_path / f'{j}.png')

        # save config
        git_hash = !git rev-parse main
        with open(save_path / 'config.json', 'w') as f:
            f.write(json.dumps({**model_dict, 'prompt_name': prompt_name, 'prompt': prompt, 'git_hash': git_hash[0], 'seed': seed}))

with open(base_path/'prompts.json', 'w') as f:
    f.write(json.dumps(prompts))
with open(base_path/'models.json', 'w') as f:
    f.write(json.dumps(models))