In [None]:
from diffusers import AutoPipelineForText2Image, AutoencoderKL
from torchvision import transforms
from datasets import load_dataset, load_from_disk, Dataset, Features, Array3D, concatenate_datasets, DatasetDict
from datasets.arrow_dataset import logging
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from typing import Optional
from pathlib import Path
import numpy as np
from contextlib import ExitStack
from functools import partial
import tempfile
import itertools

# Turn off the `datasets` library's built-in progress bars; this notebook reports progress
# with its own tqdm bars (see the generation loops below).
# NOTE(review): `logging` is imported via the internal `datasets.arrow_dataset` module; the
# public location is presumably `datasets.utils.logging` -- confirm before upgrading `datasets`.
logging.disable_progress_bar()

In [None]:
def print_shape(x):
    '''Describe the shape of ``x`` as a string (nothing is printed despite the name).

    Tuples are handled recursively and rendered as ``(a, b, ...)``; anything else is
    expected to expose a ``.shape`` attribute (numpy array, torch tensor, ...).
    '''
    if not isinstance(x, tuple):
        return str(x.shape)
    inner = ', '.join(map(print_shape, x))
    return f'({inner})'


def batch_process_images(images, pipe, extract_positions: list[str], noise: float, device: str = "cuda") -> dict[str, np.ndarray]:
    '''Encode a batch of images with ``pipe.vae`` and capture intermediate U-Net outputs.

    The images are encoded to VAE latents, optionally blended with Gaussian noise, and run
    through a single ``pipe.unet`` forward pass conditioned on the empty prompt. Forward
    hooks record the output of every module named in ``extract_positions``.

    Notes:
        Mutates ``pipe.unet.config.addition_embed_type`` (see below). Module paths are
        resolved with ``eval``: this code is unsafe, do not use in production.

    Args:
        images: Image batch tensor with values in [0, 1], as produced by
            ``transform_to_tensor`` / ``transforms.ToTensor``. It is rescaled to [-1, 1]
            before VAE encoding.
        pipe: A diffusers text-to-image pipeline exposing ``vae``, ``unet``, ``scheduler``
            and ``encode_prompt``.
        extract_positions: Attribute paths below ``pipe.unet`` (e.g. ``'mid_block'``) whose
            forward outputs are captured. If a module returns a tuple, only its first
            element is kept.
        noise: Blend weight in [0, 1] between the latents and standard normal noise.
        device: Device the inputs are moved to; must be the device ``pipe`` lives on.

    Returns:
        Mapping from each extract position to its captured output as a CPU numpy array
        (callers index its leading axis as the batch).
    '''
    # Presumably disables the U-Net's extra "addition" embedding branch (e.g. SDXL's
    # 'text_time' conditioning), which would need `added_cond_kwargs` this call never passes.
    # NOTE(review): this permanently mutates the pipeline's config.
    pipe.unet.config.addition_embed_type = 'nothing_at_all_lol'

    representations = {}

    def hook_fn(module, inputs, output, extract_position):
        # TODO: is it good to always take the first output and ignore the rest?
        if isinstance(output, tuple):
            output = output[0]
        # Convert every output (not just tuple ones) so the result is always CPU numpy.
        representations[extract_position] = output.detach().cpu().numpy()

    # Pure inference: keep the VAE and text-encoder calls out of autograd as well.
    with ExitStack() as stack, torch.no_grad():
        # The SD VAE expects pixels in [-1, 1]; ToTensor produces [0, 1].
        pixels = (images * 2 - 1).to(device=device, dtype=pipe.vae.dtype)
        latents = pipe.vae.encode(pixels).latent_dist.sample()
        # The U-Net operates on latents multiplied by the VAE's scaling factor.
        latents = (latents * pipe.vae.config.scaling_factor).half()
        # TODO: is normal(0,1) the right noise distribution?
        if noise > 0:
            latents = (1 - noise) * latents + noise * torch.randn_like(latents)
        # TODO: is an empty prompt the right way to do this?
        prompt_embeds, *_ = pipe.encode_prompt(prompt="", device=device, num_images_per_prompt=latents.shape[0], do_classifier_free_guidance=False)

        for extract_position in extract_positions:
            # eval is unsafe. Do not use in production.
            module = eval(f'pipe.unet.{extract_position}', {'__builtins__': {}, 'pipe': pipe})
            # Hooks are removed again when the ExitStack closes.
            stack.enter_context(module.register_forward_hook(partial(hook_fn, extract_position=extract_position)))
        # TODO: is this the right number of timesteps?
        # NOTE(review): training timesteps usually run 0..num_train_timesteps-1, so this
        # value is one past the last trained step -- confirm this is intended.
        # TODO: setup sdxl-turbo
        pipe.unet(latents, pipe.scheduler.config.num_train_timesteps, encoder_hidden_states=prompt_embeds, return_dict=False)

    return representations




def transform_to_tensor(ds, size: int = 512):
    '''Replace one example's PIL ``'image'`` with a square ``size`` x ``size`` float tensor.

    The image is converted to RGB if needed, center-cropped to its shorter side, resized and
    converted with ``transforms.ToTensor`` (values in [0, 1]).

    Args:
        ds: A single dataset example (despite the name) holding a PIL image under ``'image'``.
        size: Output side length in pixels; defaults to 512, the previously hard-coded value.

    Returns:
        The same example dict, modified in place, with ``'image'`` replaced by the tensor.
    '''
    # TODO: maybe batch this
    image = ds['image']
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform_pipeline = transforms.Compose([
        # Square crop on the shorter side so the resize below does not distort the aspect ratio.
        transforms.CenterCrop(min(image.size)),
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    ds['image'] = transform_pipeline(image)
    return ds


def sd_dataset_generator(dataloader, pipe, extract_positions, noise, max_samples=None):
    '''Yield one example dict per image: its ``label`` plus each captured representation.

    Args:
        dataloader: Iterable of batches with ``'image'`` and ``'label'`` entries.
        pipe: Pipeline passed through to ``batch_process_images``.
        extract_positions: U-Net module paths to capture (see ``batch_process_images``).
        noise: Latent noise level in [0, 1].
        max_samples: Only used as the progress-bar total; it does NOT limit how many
            examples are yielded.
    '''
    # The context manager closes the progress bar when the generator is exhausted or closed.
    with tqdm(total=max_samples, desc="Generating representations") as status:
        for batch in dataloader:
            status.update(len(batch['image']))
            representations = batch_process_images(batch['image'], pipe, extract_positions, noise)
            for i, label in enumerate(batch['label']):
                yield {"label": label} | {pos: x[i] for pos, x in representations.items()}


def _load_sd_pipeline(model_name: str, vae=None):
    '''Load ``model_name`` as an fp16 text-to-image pipeline on CUDA, optionally with a custom VAE.'''
    if vae is None and 'sdxl' in model_name:
        # sdxl needs a 32-bit vae: load it without torch_dtype so it is not cast to fp16
        vae = AutoencoderKL.from_pretrained('stabilityai/sdxl-vae')
    # Only override the VAE when one is given; otherwise the model's own VAE is used.
    extra_components = {} if vae is None else {'vae': vae}
    return AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.float16, **extra_components).to("cuda")


def _generate_split_shards(source_split, split, features, pipe, extract_positions, noise, batch_size, max_samples, column_rename, shard_dir):
    '''Compute representations for one split batch by batch, saving each batch as a shard; return the shard paths.'''
    dataset = source_split.to_iterable_dataset().take(max_samples).rename_columns(column_rename).map(transform_to_tensor)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    # Ceiling division: the last batch may be partial.
    num_batches = (max_samples + batch_size - 1) // batch_size
    shard_paths = []
    for i, batch in enumerate(tqdm(dataloader, total=num_batches, desc="Generating representations")):
        representations = batch_process_images(batch['image'], pipe, extract_positions, noise)
        shard = Dataset.from_dict({'label': batch['label']} | representations, features=features)
        path = Path(shard_dir) / f'{split}-{i:05d}'
        shard.save_to_disk(path)
        shard_paths.append(path)
    return shard_paths


def _concatenate_shards(shard_paths, features):
    '''Concatenate saved shards into one dataset; no shards yields an empty dataset with ``features``.'''
    if not shard_paths:
        # concatenate_datasets([]) raises, e.g. when a split is capped to 0 samples.
        return Dataset.from_dict({name: [] for name in features}, features=features)
    return concatenate_datasets([load_from_disk(path) for path in shard_paths])


def create_sd_dataset(
        output_path: str = 'sd_representations_dataset',
        model_name: str = 'runwayml/stable-diffusion-v1-5',
        dataset_name: str = 'imagenet-1k',
        extract_positions: list[str] = ['mid_block'],
        batch_size: int = 8,
        max_samples_per_split: dict[str, int] = {},
        noise: float = 0.0,
        dataset_column_rename: dict[str, str] = {},
        vae = None,
    ) -> DatasetDict:
    '''Create a dataset of representations from a Stable Diffusion model.

    Every split of ``dataset_name`` is (optionally truncated and) passed through the model in
    batches. The outputs captured at ``extract_positions`` are stored together with each
    example's label, saved to ``output_path`` and returned.

    Notes:
        This code is unsafe, do not use in production (U-Net module paths are resolved with ``eval``).
        The list/dict defaults are never mutated here, so sharing them across calls is safe.

    Args:
        output_path: Path to save the dataset to. Must not exist yet.
        model_name: Name of the Stable Diffusion model to use.
        dataset_name: Name of the source dataset; all of its splits are processed.
        extract_positions: List of positions in the U-Net to extract representations from. If the output is a tuple, only the first element is used.
        batch_size: Batch size to use.
        max_samples_per_split: Maximum number of samples per split name. Splits not listed are used in full.
        noise: Noise to add to the latent space. Must be between 0 and 1.
        dataset_column_rename: Column renames applied to the source dataset so it has ``image`` and ``label`` columns.
        vae: Optional pre-loaded VAE to use instead of the model's own. If None and ``model_name``
            contains 'sdxl', ``stabilityai/sdxl-vae`` is loaded.

    Returns:
        A ``DatasetDict`` with one split per source split, loaded back from ``output_path``.

    Raises:
        ValueError: If ``output_path`` already exists or ``noise`` is outside [0, 1].
    '''
    # check arguments
    if Path(output_path).exists():
        raise ValueError('Output path already exists')
    if not 0 <= noise <= 1:
        raise ValueError('Noise must be between 0 and 1')

    print(f'loading model `{model_name}`')
    pipe = _load_sd_pipeline(model_name, vae)

    print(f'loading source dataset `{dataset_name}`')
    source_splits = load_dataset(dataset_name)
    # Samples actually taken per split: take() stops early on splits shorter than the cap.
    samples_per_split = {
        split: min(max_samples_per_split.get(split, len(split_data)), len(split_data))
        for split, split_data in source_splits.items()
    }

    # Run a single sample through the model to discover each representation's shape.
    print(f'Setting up new dataset with features `{extract_positions+["label"]}`')
    probe_split = next(iter(source_splits.values())).rename_columns(dataset_column_rename)
    probe_batch = next(iter(DataLoader(probe_split.to_iterable_dataset().take(1).map(transform_to_tensor), batch_size=1)))
    probe = batch_process_images(probe_batch['image'], pipe, extract_positions, noise)
    bytes_per_sample = sum(x.nbytes for x in probe.values())
    print(f'The new dataset will be roughly {bytes_per_sample * sum(samples_per_split.values()) / 1e9:.2f} GB')
    features = Features({
        # NOTE(review): assumes each captured output is (batch, C, H, W), i.e. 3-D per sample.
        **{pos: Array3D(shape=x.shape[1:], dtype="float16") for pos, x in probe.items() if pos != "label"},
        "label": probe_split.features["label"],
    })

    with tempfile.TemporaryDirectory() as shard_dir:
        shard_paths = {}
        for split, source_split in source_splits.items():
            print(f'Generating new dataset for split `{split}`')
            shard_paths[split] = _generate_split_shards(
                source_split, split, features, pipe, extract_positions, noise,
                batch_size, samples_per_split[split], dataset_column_rename, shard_dir,
            )

        print('Concatenating generated datasets')
        new_dataset = DatasetDict({
            split: _concatenate_shards(paths, features)
            for split, paths in shard_paths.items()
        })

        print(f'saving new dataset to `{output_path}`')
        new_dataset.save_to_disk(output_path)

    # `new_dataset` is memory-mapped from shards in `shard_dir`, which has just been deleted;
    # reload from `output_path` so the returned dataset is backed by files that still exist.
    return load_from_disk(output_path)


In [None]:
# sd_dataset = create_sd_dataset(
#     output_path='../repr_dataset_test123',
#     model_name='runwayml/stable-diffusion-v1-5',
#     dataset_name='zh-plus/tiny-imagenet',
#     extract_positions=['mid_block'],
#     batch_size=4,
#     noise=0.,
#     # dataset_column_rename={'img':'image', 'fine_label': 'label'},  # for cifar100
# )
# sd_dataset

In [None]:
model_names = ['runwayml/stable-diffusion-v1-5','stabilityai/sd-turbo']#,'stabilityai/sdxl-turbo']
noise_levels = [0.0]#, 0.01, 0.1,0.2,0.5,0.8]
dataset_names_and_column_renames = [
    ('mnist', {}),
    ('cifar10', {'img':'image'}),
    ('cifar100', {'img':'image', 'fine_label': 'label'}),
    ('zh-plus/tiny-imagenet', {}),
]
# Settings shared by every run of the sweep.
OUTPUT_ROOT = Path('../datasets-tmp')
EXTRACT_POSITIONS = ['mid_block']
BATCH_SIZE = 4
MAX_SAMPLES_PER_SPLIT = {'train': 60, 'test': 10, 'valid': 10}

count = np.prod([len(x) for x in [model_names, noise_levels, dataset_names_and_column_renames]])
sweep = itertools.product(model_names, noise_levels, dataset_names_and_column_renames)
for model_name, noise_level, (dataset_name, column_rename) in tqdm(sweep, total=count, desc='Generating datasets'):
    # Replace '/' from hub ids so each run gets a single flat directory.
    output_path = OUTPUT_ROOT / f'{model_name}-{dataset_name}-{noise_level}'.replace('/', '-')
    if output_path.exists():
        # create_sd_dataset refuses to overwrite; reuse the earlier result so re-running the notebook completes.
        print(f'`{output_path}` already exists, loading it instead of regenerating')
        sd_dataset = load_from_disk(str(output_path))
        continue
    print('#'*80)
    print(f'Creating dataset for model `{model_name}`, dataset `{dataset_name}`, noise `{noise_level}`')
    print('#'*80)
    sd_dataset = create_sd_dataset(
        output_path=str(output_path),
        model_name=model_name,
        dataset_name=dataset_name,
        extract_positions=EXTRACT_POSITIONS,
        batch_size=BATCH_SIZE,
        max_samples_per_split=MAX_SAMPLES_PER_SPLIT,
        noise=noise_level,
        dataset_column_rename=column_rename,
    )