In [None]:
from diffusers import AutoPipelineForText2Image
from diffusers.models import AutoencoderKL
from torchvision import transforms
from datasets import load_dataset, Dataset, Features, Array3D, concatenate_datasets
from datasets.arrow_dataset import logging
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
from typing import Optional
from pathlib import Path
import numpy as np
from contextlib import ExitStack
from functools import partial
from trainplot.trainplot import TrainPlotPlotlyExperimental as TrainPlot
from datetime import datetime
from torch import nn, optim
from time import sleep

logging.disable_progress_bar()

In [None]:
class SimpleClassifier(nn.Module):
    """Linear probe on a feature map: average-pool spatially, then one linear layer.

    Args:
        channel_size: Number of channels of the incoming feature map; this is
            also the number of input features of the linear layer.
        spatial_size: Kernel size of the average pooling. The pooled map must
            collapse to one value per channel, otherwise the flattened tensor
            will not match the linear layer's input size.
        num_classes: Number of output logits.
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.fc = nn.Linear(in_features=channel_size, out_features=num_classes)
        self.pool = nn.AvgPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Return the logits for feature map ``x``."""
        pooled = self.pool(x)
        # Keep the leading (batch) dimension and merge everything behind it.
        features = pooled.flatten(start_dim=1)
        return self.fc(features)


class SimpleCNNClassifier(nn.Module):
    """Convolutional probe: a 1x1 convolution to class maps, then spatial max-pooling.

    Args:
        channel_size: Number of channels of the incoming feature map.
        spatial_size: Kernel size of the max pooling applied to the class maps.
        num_classes: Number of output channels of the convolution, i.e. the
            number of logits per pooled location.
    """

    def __init__(self, channel_size=1280, spatial_size=8, num_classes=1000):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channel_size,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
        )
        self.pool = nn.MaxPool2d(kernel_size=spatial_size)

    def forward(self, x):
        """Return the max-pooled class maps of ``x``, flattened per sample."""
        class_maps = self.conv(x)
        strongest = self.pool(class_maps)
        # Keep the leading (batch) dimension and merge everything behind it.
        return strongest.flatten(start_dim=1)


def check_nans(*args, **kwargs):
    """Raise if any of the given tensors contains a NaN.

    Positional tensors are identified by their index, keyword tensors by their
    keyword. Tensors are checked in that order and the first offender is reported.

    Raises:
        ValueError: ``"Nans in <index or keyword>"`` for the first tensor
            containing at least one NaN.
    """
    labelled = list(enumerate(args)) + list(kwargs.items())
    for label, tensor in labelled:
        contains_nan = torch.isnan(tensor).any()
        if contains_nan:
            raise ValueError(f"Nans in {label}")

In [None]:
def batch_process_images(images, pipe, extract_positions: list[str], noise: float) -> dict[str, torch.Tensor]:
    """Run one UNet forward pass on a batch of images and capture intermediate outputs.

    The images are encoded with ``pipe.vae``, optionally blended with Gaussian
    noise, and passed once through ``pipe.unet`` together with the embedding of
    an empty prompt. Forward hooks record the output of every requested
    sub-module. Everything runs under ``torch.no_grad()`` and on CUDA.

    Args:
        images: Batch of image tensors accepted by ``pipe.vae.encode``.
            NOTE(review): the usual Stable Diffusion VAE convention is inputs in
            [-1, 1]; confirm what the caller provides.
        pipe: Diffusers pipeline exposing ``vae``, ``unet``, ``scheduler`` and
            ``encode_prompt``. Its ``unet.config.addition_embed_type`` is
            overwritten as a side effect (see below).
        extract_positions: Attribute expressions relative to ``pipe.unet``
            (e.g. ``'down_blocks[0]'``, ``'mid_block'``) naming the modules to hook.
        noise: Blend weight in ``(1 - noise) * latents + noise * randn``; the
            blend is skipped when ``noise`` is not greater than 0.

    Returns:
        Mapping from each entry of ``extract_positions`` to the output of that
        module (the first element when the module returns a tuple).

    Raises:
        ValueError: If the latents or the prompt embeddings contain NaNs.
    """
    # Replace the UNet's `addition_embed_type` with a placeholder value. The change
    # stays on `pipe` after this function returns.
    # NOTE(review): presumably this is meant to bypass the UNet's additional-embedding
    # branch (SDXL-style models) — confirm against the diffusers UNet implementation.
    pipe.unet.config.addition_embed_type = 'nothing_at_all_lol'

    representations = {}

    def hook_fn(module, input, output, extract_position):
        # print(extract_position, print_shape(output))
        if isinstance(output, tuple):
            output = output[0]  # TODO: is it good to always take the first output and ignore the rest?
        representations[extract_position] = output

    # Nothing here is ever back-propagated, so the VAE and text-encoder passes are
    # kept under no_grad as well; otherwise they build and hold an autograd graph.
    with torch.no_grad():
        if next(pipe.vae.modules()).dtype == torch.float16:
            latents = pipe.vae.encode(images.half().cuda()).latent_dist.sample()
        else:
            latents = pipe.vae.encode(images.cuda()).latent_dist.sample().half()
        # NOTE(review): the latents are not multiplied by the VAE scaling factor before
        # entering the UNet — confirm whether that is intended.
        check_nans(latents=latents)
        # TODO: is normal(0,1) the right noise distribution?
        if noise > 0:
            latents = (1 - noise) * latents + noise * torch.randn_like(latents)
        # TODO: is an empty prompt the right way to do this?
        prompt_embeds, *_ = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=latents.shape[0], do_classifier_free_guidance=False)
        check_nans(prompt_embeds=prompt_embeds)

        # Run inference with representation extraction hooks; the ExitStack removes
        # every hook again when the block is left, even on error.
        with ExitStack() as stack:
            for extract_position in extract_positions:
                # eval is unsafe. Do not use in production.
                target_module = eval(f'pipe.unet.{extract_position}', {'__builtins__': {}, 'pipe': pipe})
                handle = target_module.register_forward_hook(partial(hook_fn, extract_position=extract_position))
                stack.enter_context(handle)
            # TODO: is this the right number of timesteps?
            # TODO: setup sdxl-turbo
            pipe.unet(latents, pipe.scheduler.config.num_train_timesteps, encoder_hidden_states=prompt_embeds, return_dict=False)

    return representations


def transform_to_tensor(ds):
    """Replace ``ds['image']`` with a 512x512 RGB tensor and return ``ds``.

    The image is converted to RGB if it is in another mode, center-cropped to a
    square whose side is its shorter dimension, resized to 512x512 and turned
    into a tensor. ``ds`` is modified in place.

    Args:
        ds: Mapping with an ``'image'`` entry exposing ``size``, ``mode`` and
            ``convert`` (presumably a PIL image — confirm against the dataset).

    Returns:
        The same mapping, with ``'image'`` replaced by the tensor.
    """
    # TODO: maybe batch this
    source_image = ds['image']
    shortest_side = min(source_image.size)
    to_square_tensor = transforms.Compose([
        transforms.CenterCrop(shortest_side),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    if source_image.mode != "RGB":
        rgb_image = source_image.convert("RGB")
    else:
        rgb_image = source_image
    ds['image'] = to_square_tensor(rgb_image)
    return ds


def sd_dataset_generator(dataset, pipe, extract_positions, noise_levels, batch_size):
    """Yield ``(label, representations)`` for every sample of ``dataset``.

    Images are transformed with ``transform_to_tensor``, batched, and each batch
    is passed through ``batch_process_images`` once per noise level. The batched
    results are then split up again and yielded sample by sample.

    Args:
        dataset: Dataset providing ``to_iterable_dataset()`` with ``'image'`` and
            ``'label'`` entries.
        pipe: Pipeline forwarded to ``batch_process_images``.
        extract_positions: Module positions whose outputs are extracted.
        noise_levels: Noise values; one forward pass per value and batch.
        batch_size: Number of images per forward pass.

    Yields:
        ``(label, data)`` where ``data[noise][position]`` is the extracted
        representation of that single sample.
    """
    tensor_stream = dataset.to_iterable_dataset().map(transform_to_tensor)
    loader = DataLoader(tensor_stream, batch_size=batch_size)
    for batch in loader:
        per_noise = {}
        for noise in noise_levels:
            per_noise[noise] = batch_process_images(batch['image'], pipe, extract_positions, noise)
        # Split the batched outputs back into one entry per sample.
        for sample_idx, label in enumerate(batch['label']):
            sample = {
                noise: {pos: per_noise[noise][pos][sample_idx] for pos in extract_positions}
                for noise in noise_levels
            }
            yield label, sample



In [None]:
# Plot objects fed by the training cell: it calls them with `step=` plus one keyword
# per classifier, carrying the mean of that classifier's last 100 losses / accuracies.
# NOTE(review): `update_period=1` presumably controls how often the plot refreshes —
# confirm against the trainplot documentation.
tp_loss = TrainPlot(update_period=1)
tp_accuracy = TrainPlot(update_period=1)

In [None]:
# config
generation_batch_size = 4   # images per UNet forward pass
training_batch_size = 64    # samples per classifier optimisation step
dataset_name = "cifar100"
dataset_column_rename = {'img':'image', 'fine_label': 'label'}
model_name = 'runwayml/stable-diffusion-v1-5'  # e.g. stabilityai/sdxl-turbo or runwayml/stable-diffusion-v1-5
model_classes = [SimpleClassifier, SimpleCNNClassifier]
optimizer_name = optim.Adam
extract_positions = ['down_blocks[0]','down_blocks[3]','mid_block','up_blocks[0]','up_blocks[3]']
noise_levels = [0., .5]

# load model and dataset
dataset = load_dataset(dataset_name, split='train')
dataset = dataset.rename_columns(dataset_column_rename)
# AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")  # 32 bit AutoencoderKL is required for sdxl
pipe = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
representation_iter = sd_dataset_generator(dataset, pipe, extract_positions, noise_levels, generation_batch_size)

# setup classification models to train
# Probe a single sample to learn the shape of every extracted representation. The
# already loaded `dataset` is reused; the generator builds its own iterable from it,
# so this does not consume anything from `representation_iter`.
tmp = next(sd_dataset_generator(dataset, pipe, extract_positions, [0.], 1))[1][0.]
representation_shapes = {k: v.shape for k, v in tmp.items() if k != 'label'}
# One classifier (and optimizer) per (model class, extract position, noise level).
# NOTE(review): the classifiers keep their default `num_classes=1000` although the
# configured dataset is cifar100 — confirm this is intended before changing it, since
# it determines the shape of the saved weights.
models = {}
for model_cls in model_classes:
    for pos in extract_positions:
        for noise in noise_levels:
            models[f'{model_cls.__name__}-{pos}-{noise}'] = model_cls(
                channel_size = representation_shapes[pos][0],
                spatial_size = representation_shapes[pos][-1],
            ).to("cuda")
optimizers = {name: optimizer_name(model.parameters(), lr=1e-3) for name, model in models.items()}
criterion = nn.CrossEntropyLoss()  # stateless, so one instance serves every classifier

# train
losses = {name: [] for name in models.keys()}
accuracies = {name: [] for name in models.keys()}
# Bound before the loop so the `finally` block can always read it, even when the run
# is interrupted (or fails) before the first iteration.
step = 0
try:
    for step in trange(dataset.num_rows // training_batch_size):
        labels, representations = zip(*[next(representation_iter) for _ in range(training_batch_size)])
        # The labels are the same for every noise level and position of this step.
        y = torch.tensor(labels).cuda()
        for noise in noise_levels:
            for pos in extract_positions:
                x = torch.stack([r[noise][pos] for r in representations]).float()
                check_nans(x=x, y=y)
                for cls_name in model_classes:
                    name = f'{cls_name.__name__}-{pos}-{noise}'
                    model = models[name]
                    check_nans(*model.parameters())
                    optimizers[name].zero_grad()
                    model.train()
                    y_hat = model(x)
                    loss = criterion(y_hat, y)
                    check_nans(y_hat=y_hat, loss=loss)
                    loss.backward()
                    optimizers[name].step()
                    check_nans(*model.parameters())
                    losses[name].append(loss.item())
                    accuracies[name].append((y_hat.argmax(dim=1) == y).float().mean().item())
                    # Plot the running mean over the last 100 steps.
                    tp_loss(step=step, **{name: np.mean(losses[name][-100:])})
                    tp_accuracy(step=step, **{name: np.mean(accuracies[name][-100:])})
except KeyboardInterrupt:
    print('Got Keyboard Interrupt')
finally:
    # Save whatever has been trained so far, also after an interrupt or an error;
    # runs that did not get past step index 1 are not saved.
    if step > 1:
        print('Saving models...')
        folder = Path(f'../classifier-models/onfly-{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}')
        print(f'Saving models to `{folder}`')
        folder.mkdir(parents=True, exist_ok=True)
        for name, model in models.items():
            torch.save(model.state_dict(), folder / f'{name}.pth')