In [None]:
from huggingface_hub import notebook_login
# Interactively log in to the Hugging Face Hub. The model download in the next
# cell passes use_auth_token=True, which presumably relies on the token saved
# by this login — run this cell first.
notebook_login()

In [None]:
import torch
from diffusers import StableDiffusionPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Half precision saves GPU memory, but float16 is often unsupported or very
# slow on CPU, so only use it when a GPU is present. The "fp16" revision's
# weights are upcast automatically when float32 is requested.
model_dtype = torch.float16 if device == 'cuda' else torch.float32

# This function loads several neural networks to use together.
# The individual networks and preprocessors loaded are listed in this file:
# https://huggingface.co/CompVis/stable-diffusion-v1-4/blob/main/model_index.json
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=model_dtype,
    use_auth_token=True
).to(device)


In [None]:
# Attribute names held by the pipeline object (iterating a dict yields its keys).
list(vars(pipe))

In [None]:
# Summarize every torch.nn.Module held by the pipeline: total parameter
# count, number of parameter tensors, and number of modules.
for name, component in vars(pipe).items():
    if isinstance(component, torch.nn.Module):
        # Enumerate the parameters once and reuse the list for both counts.
        params = list(component.parameters())
        n_params = sum(p.numel() for p in params)
        # modules() yields the component itself plus every nested submodule.
        n_modules = sum(1 for _ in component.modules())
        print(f'{type(component).__name__} "{name}" has:')
        print(f'{n_params} parameters in {len(params)} tensors in {n_modules} modules')

In [None]:
def count_parameters(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements across all of ``model``'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
    
# Parameter totals for the text encoder, the U-Net and the VAE, in that order.
for component in (pipe.text_encoder, pipe.unet, pipe.vae):
    print(count_parameters(component))

In [None]:
from baukit import show, renormalize, pbar
from torch import autocast
import numpy

prompt = "Photo of Chewbacca and Angela Merkel solving a Rubik's cube on Boston Common"
seed = 2

# Sampling settings.
num_inference_steps = 33  # number of denoising steps requested from the scheduler
latent_scale = 0.18215    # latents are divided by this before VAE decoding
guidance_strength = 5.0   # classifier-free guidance weight; 1.0 gives the prompt-conditioned prediction alone

# Stable Diffusion inference devised by Robin Rombach et al. (CVPR 2022, https://arxiv.org/abs/2112.10752)
# Derived from the Huggingface Stable Diffusion pipeline by Suraj Patil and others
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L16-L171
with autocast(device), torch.no_grad():
    # Encode two prompts as one batch: "" (unconditional) and the real prompt.
    text_tokens = pipe.tokenizer(["", prompt], padding="max_length", return_tensors="pt")['input_ids']
    text_vectors = pipe.text_encoder(text_tokens.to(device))[0]
    # Seeded Gaussian noise of shape (1, 4, 64, 64) is the starting latent.
    image_vectors = torch.from_numpy(numpy.random.RandomState(seed).randn(1, 4, 64, 64)).float().to(device)

    # pipe.scheduler is whichever scheduler model_index.json names (presumably
    # PNDM/PLMS for v1-4 — verify). Linear multistep samplers are also
    # implemented in Katherine Crowson's k-diffusion: https://github.com/crowsonkb/k-diffusion
    scheduler = pipe.scheduler
    scheduler.set_timesteps(num_inference_steps)
    intermediates = []  # decoded image at the start of every step, plus the final result
    for t in pbar(scheduler.timesteps):
        # Decode the current latent so the denoising progression can be shown afterwards.
        intermediates.extend(renormalize.as_image(pipe.vae.decode(image_vectors / latent_scale)))
        # Pass two copies into the network, one to process with "" and the other with the prompt.
        image_vector_input = torch.cat([image_vectors] * 2)
        # pipe.unet takes the latent copies, the timestep t and the text vectors,
        # and outputs one update per copy.
        update = pipe.unet(image_vector_input, t, text_vectors)["sample"]
        # Classifier-free guidance: see Jonathan Ho and Tim Salimans
        # (Neurips 2021 Workshop, https://arxiv.org/abs/2207.12598)
        strong_guidance = update[0] + guidance_strength * (update[1] - update[0])
        image_vectors = scheduler.step(strong_guidance, t, image_vectors)["prev_sample"]

    # pipe.vae decodes the final latent back to RGB.
    rgb_vectors = pipe.vae.decode(image_vectors / latent_scale)
    intermediates.extend(renormalize.as_image(rgb_vectors))
    show(show.WRAP, [[show.style(width=144), im] for im in intermediates])
    # Fixed: this previously referenced an undefined name `tokens`.
    print('Text tokens are', text_tokens.shape, text_tokens.dtype)
    print('Text vectors are', text_vectors.shape, text_vectors.dtype)
    print('Image vectors are', image_vectors.shape, image_vectors.dtype)
    print('RGB vectors are', rgb_vectors.shape, rgb_vectors.dtype)
    

In [None]:
import os
from baukit import ImageFolderSet
from torchvision.transforms import Compose, Normalize, ToTensor
from torchvision.datasets.utils import download_and_extract_archive

# Run every COCO-humans image through the pipeline's safety checker and
# display the ones it flags.
if not os.path.isdir('coco_humans'):
    download_and_extract_archive('https://cs7150.baulab.info/2022-Fall/data/coco_humans.zip', 'coco_humans')

# NOTE(review): pipe.feature_extractor presumably applies CLIP's own
# normalization, so ImageNet-normalizing here may double-normalize (and makes
# show(image) render shifted colors) — confirm this transform is intended.
images = ImageFolderSet('coco_humans', transform=Compose([
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]))

# The pipeline may have been loaded in float16; the CLIP input must match the
# safety checker's parameter dtype.
checker_dtype = next(pipe.safety_checker.parameters()).dtype
with torch.no_grad():
    for image in images:
        # assumes each item is a single (C, H, W) image tensor — TODO confirm
        # ImageFolderSet does not yield (image, label) tuples here.
        safety_checker_input = pipe.feature_extractor(image, return_tensors="pt")
        clip_input = safety_checker_input.pixel_values.to(device=device, dtype=checker_dtype)
        # The checker may overwrite flagged entries of `images` in place, so give
        # it a private batched numpy copy instead of the tensor we display.
        _, has_nsfw_concept = pipe.safety_checker(
            images=image[None].numpy().copy(), clip_input=clip_input)
        # One flag per batch entry; a non-empty list is always truthy, so test
        # the flags themselves.
        if any(has_nsfw_concept):
            show(image)
