# Representation Generator for the h-space similarity explorer

In [None]:
import json
import subprocess
from pathlib import Path
from typing import Any

import numpy as np
from tqdm.notebook import tqdm

from sdwrapper import SD

## Config

In [None]:
# Models to run, identified by the short name passed to `SD(...)`.
# `extract_positions` lists the positions whose representations are written to
# disk; `None` means "every position the model exposes"
# (`sd.available_extract_positions`).
models: list[dict[str, Any]] = [
    {'short': 'SDXL-Turbo', 'extract_positions': None},
    {
        'short': 'SDXL-Lightning',
        'extract_positions': [
            'down_blocks[0]',
            'down_blocks[1]',
            'down_blocks[2]',
            'mid_block',
            'conv_out',
        ],
    },
    {'short': 'SD-Turbo', 'extract_positions': None},
    {
        'short': 'SD-1.5',
        'extract_positions': [
            'down_blocks[2]',
            'down_blocks[3]',
            'mid_block',
            'up_blocks[0]',
            'conv_out',
        ],
    },
]

# Prompt name (used as the per-prompt output folder name) -> prompt text.
prompts: dict[str, str] = {
    'Cat': 'A photo of a cat.',
    'Dog': 'A photograph of a husky, dog, looking friendly and cute.',
    'Woman': 'A photo of a beautiful, slightly smiling woman in the city.',
    'OldMan': 'A portrait of an old man with a long beard and a hat.',
    'ConstructionWorker': 'A photo of a hard working construction worker.',
    'FuturisticCityscape': 'A futuristic cityscape at sunset, with flying cars and towering skyscrapers, in the style of cyberpunk.',
    'MountainLandscape': 'A serene mountain landscape with a crystal-clear lake in the foreground, reflecting the snow-capped peaks under a bright blue sky.',
    'SpaceAstronaut': 'A high-res photo of an astronaut floating in the vastness of space, with a colorful nebula and distant galaxies in the background.',
    'MagicalForest': 'A magical forest filled with glowing plants, mythical creatures, and a pathway leading to an enchanted castle.',
    'JapaneseGarden': 'A traditional Japanese garden in spring, complete with cherry blossoms, a koi pond, and a wooden bridge.',
}

# Single fixed seed shared by every model/prompt run.
seed: int = 0

## Generate Representations

In [None]:
def _git_hash(ref: str = 'main') -> str:
    """Return the commit hash that git ref `ref` currently resolves to.

    Plain-Python replacement for the IPython-only ``!git rev-parse main`` shell
    escape. With ``check=True`` a git failure raises ``CalledProcessError``
    (``FileNotFoundError`` if git is not installed) instead of the error text
    being recorded as the hash.
    """
    # NOTE(review): records the hash of `main`, not of the checked-out HEAD,
    # exactly as the original cell did — confirm this is intended on branches.
    completed = subprocess.run(
        ['git', 'rev-parse', ref],
        capture_output=True,
        text=True,
        check=True,
    )
    return completed.stdout.strip()


def _get_reprs(sd, prompt, positions, seed):
    """Run `sd` on `prompt` and return ``(representations, images)``.

    `representations` maps every position key the wrapper returned to a list of
    numpy arrays, one per entry in ``result.representations[pos]``. Each entry
    is reduced to a rank-3 tensor, and its axis 0 is moved to the end (the
    caller treats the last axis as channels). `images` is passed through from
    ``result.images`` unchanged.
    """
    result = sd(prompt, seed=seed, extract_positions=positions)
    representations = {}
    for pos, tensors in result.representations.items():
        arrays = []
        for tensor in tensors:
            # Unwrap tuples and drop leading axes (e.g. the classifier-free
            # guidance / batch dimension) by taking element 0 until rank 3.
            while isinstance(tensor, tuple) or len(tensor.shape) > 3:
                tensor = tensor[0]
            arrays.append(tensor.cpu().permute(1, 2, 0).numpy())
        representations[pos] = arrays
    return representations, result.images


def _describe_positions(representations, extract_positions):
    """Summarise each position's representation size for the model config.

    Returns ``{pos: {'channels', 'spatial', 'available'}}``, read from the first
    array of each position. 'available' tells whether that position's data is
    written to disk.
    """
    return {
        pos: {
            'channels': reprs[0].shape[-1],
            # Only one spatial axis is recorded — presumably maps are square; TODO confirm.
            'spatial': reprs[0].shape[-2],
            'available': pos in extract_positions,
        }
        for pos, reprs in representations.items()
    }


def _save_prompt_outputs(save_path, representations, images, positions, config):
    """Write one prompt's representations, intermediate images and config.json."""
    save_path.mkdir(exist_ok=True, parents=True)

    # Raw float16 bytes in C order; the array shape is recorded in models.json.
    for pos in positions:
        for j, array in enumerate(representations[pos]):
            (save_path / f'repr-{pos}-{j}.bin').write_bytes(
                np.asarray(array, dtype=np.float16).tobytes()
            )

    # Intermediate images are numbered starting at 1.
    for j, img in enumerate(images, 1):
        img.save(save_path / f'{j}.jpg')

    with open(save_path / 'config.json', 'w') as f:
        json.dump(config, f)


# setup output directory; its contents are excluded from git
base_path = Path('representations')
base_path.mkdir(exist_ok=True)
(base_path / '.gitignore').write_text('*')

# Resolved once, before any model is loaded, so a git problem fails fast.
git_hash = _git_hash('main')

# Output copies of the model configs. `models` itself is left untouched so the
# cell can be re-run (deleting 'extract_positions' in place made a second run
# raise KeyError).
model_infos: list[dict[str, Any]] = []

for model_dict in models:
    name_short = model_dict['short']
    sd = SD(name_short, disable_progress_bar=True)
    # None / missing / empty -> every position the model exposes
    extract_positions = model_dict.get('extract_positions') or sd.available_extract_positions

    # Same keys, in the same order, as the original in-place mutation produced.
    model_info = {k: v for k, v in model_dict.items() if k != 'extract_positions'}
    model_info |= {x: sd.config[x] for x in ['name', 'steps', 'guidance_scale']}

    # Probe once with an empty prompt at every position to note the dimensions.
    probe, _ = _get_reprs(sd, '', sd.available_extract_positions, seed)
    model_info['representations'] = _describe_positions(probe, extract_positions)
    model_infos.append(model_info)

    # Only the configured positions are extracted and saved for the real prompts.
    for prompt_name, prompt in tqdm(prompts.items(), desc=f'Running {name_short}'):
        representations, images = _get_reprs(sd, prompt, extract_positions, seed)
        _save_prompt_outputs(
            base_path / name_short / prompt_name,
            representations,
            images,
            extract_positions,
            {**model_info, 'prompt_name': prompt_name, 'prompt': prompt,
             'git_hash': git_hash, 'seed': seed},
        )

    # Drop this cell's reference before the next model is loaded, so the old
    # model can be garbage-collected (assuming nothing else still holds it).
    del sd

# save global config files
(base_path / 'prompts.json').write_text(json.dumps(prompts))
(base_path / 'models.json').write_text(json.dumps(model_infos))