The purpose of this notebook is to play around with the architecture of the diffusion models used by the ControlNet-XS authors: StableDiffusion 2.1 and StableDiffusionXL.

___

In [6]:
from util import public_attrs

___

In [7]:
from diffusers import StableDiffusionPipeline
import torch

In [8]:
pipe21 = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float32)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [9]:
from diffusers import StableDiffusionXLPipeline

In [10]:
pipexl = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float32)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [11]:
pipe15 = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


___

In [12]:
unet21, unetxl, unet15 = pipe21.unet, pipexl.unet, pipe15.unet
vae21, vaexl, vae15 = pipe21.vae, pipexl.vae, pipe15.vae

In [13]:
x = torch.rand(1,3,512,512)

In [9]:
latents21 = vae21(x).sample
latentsxl = vaexl(x).sample
latents15 = vae15(x).sample

In [10]:
latents21.shape, latentsxl.shape, latents15.shape

(torch.Size([1, 3, 512, 512]),
 torch.Size([1, 3, 512, 512]),
 torch.Size([1, 3, 512, 512]))

**Q:** This is wrong. A VAE transforms an image `3x512x512` to a latent `4x64x64`. What am I missing?

In [11]:
public_attrs(vae15, contains=['enc', 'cal', 'for'])

['call_super_init',
 'disable_xformers_memory_efficient_attention',
 'enable_xformers_memory_efficient_attention',
 'encode',
 'encoder',
 'forward',
 'ignore_for_config',
 'register_forward_hook',
 'register_forward_pre_hook',
 'set_use_memory_efficient_attention_xformers',
 'tiled_encode']
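
`public_attrs` comes from the local `util` module, which is not shown in this notebook. Below is a minimal sketch of what such a helper might look like, assuming it lists the non-underscore names from `dir()` and keeps those containing at least one of the given substrings. The name `public_attrs_sketch` and this behaviour are assumptions inferred only from the output above, not the actual implementation.

```python
def public_attrs_sketch(obj, contains=None):
    """Hypothetical stand-in for util.public_attrs (the real helper is not shown)."""
    # dir() returns names in alphabetical order; drop private and dunder names.
    names = [n for n in dir(obj) if not n.startswith("_")]
    if contains is not None:
        # Keep a name if it contains at least one of the substrings.
        names = [n for n in names if any(s in n for s in contains)]
    return names


class Demo:
    """Tiny dummy object so the sketch runs without loading a model."""

    def encode(self): ...
    def decode(self): ...
    def forward(self): ...
    def _private(self): ...


print(public_attrs_sketch(Demo(), contains=["enc", "for"]))  # ['encode', 'forward']
```

Matching on substrings rather than prefixes would explain why names like `enable_xformers_memory_efficient_attention` show up for `'for'`.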

In [12]:
vae15.encode??

Signature: vae15.encode(*args, **kwargs)
Docstring: <no docstring>
Source:   
    def wrapper(self, *args, **kwargs):
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

In [13]:
bla = vae15.forward(x)

In [14]:
bla.sample.shape

torch.Size([1, 3, 512, 512])

What the fuck?!

___

Let's go the opposite way & decode the latents

This line

    im_rec_21 =   vae21.decode(latents21).sample  # reconstructed image

produces an error
    
    --> 459 return F.conv2d(input, weight, bias, self.stride,
        460                 self.padding, self.dilation, self.groups)
    
    RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[1, 3, 512, 512] to have 4 channels, but got 3 channels instead

This means `vae21.decode` can't handle input of shape `bla, 3, 512, 512`. This is as expected, good!

Let's now decode an input of shape `blub, 4, 64, 64`

In [15]:
rand_im = torch.rand(1,3,512,512)
rand_lat = torch.rand(1,4,64,64)

In [16]:
vae21.decode(rand_lat).sample.shape

torch.Size([1, 3, 512, 512])

Works, i.e. it turns a shape `(1,4,64,64)` into `(1,3,512,512)`, good. But why does its opposite function then not turn `(1,3,512,512)` into `(1,4,64,64)`, but instead remain at `(1,3,512,512)`?

In [17]:
vae21.encode(rand_im).latent_dist.sample().shape

torch.Size([1, 4, 64, 64])

It doesn't remain at `(1,3,512,512)`: `encode` does return `(1,4,64,64)`. Aha!

In [18]:
vae21(x).sample.shape

torch.Size([1, 3, 512, 512])

Aha again! `__call__` is **not the same** as `encode`. What I want is `encode`, while `__call__` seems to first `encode` and then `decode`.
This actually makes sense, as the VAE's job is to reproduce a given image.
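
To make the difference concrete, here is a shape-only toy that mimics the behaviour observed above. `ToyVAE` is a hypothetical stand-in that passes shape tuples around instead of tensors; it is not the diffusers implementation. It also skips the fact that the real `encode` returns a distribution (`latent_dist`) that still has to be sampled.

```python
class ToyVAE:
    """Shape-only stand-in for a VAE (hypothetical, not the diffusers class)."""

    def __init__(self, latent_channels=4, factor=8):
        self.latent_channels = latent_channels
        self.factor = factor  # spatial downsampling factor: 512 / 64 = 8

    def encode(self, shape):
        # image shape -> latent shape
        b, _, h, w = shape
        return (b, self.latent_channels, h // self.factor, w // self.factor)

    def decode(self, shape):
        # latent shape -> image shape; rejects inputs with the wrong channel count
        b, c, h, w = shape
        if c != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} channels, got {c}")
        return (b, 3, h * self.factor, w * self.factor)

    def __call__(self, shape):
        # Calling the model chains both steps, so the output is image-shaped again.
        return self.decode(self.encode(shape))


toy = ToyVAE()
image_shape = (1, 3, 512, 512)
print(toy.encode(image_shape))  # (1, 4, 64, 64)
print(toy(image_shape))         # (1, 3, 512, 512)
```

Passing an image-shaped input to `toy.decode` raises an error, which mirrors the channel mismatch seen earlier with `vae21.decode(latents21)`.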

___

In [19]:
for n, vae in zip(('vae21','vaexl','vae15'), (vae21,vaexl,vae15)):
    print(n, vae.encode(rand_im).latent_dist.sample().shape)

vae21 torch.Size([1, 4, 64, 64])
vaexl torch.Size([1, 4, 64, 64])
vae15 torch.Size([1, 4, 64, 64])
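
A quick arithmetic check on how much smaller the latent is than the image, using only the two shapes printed in this notebook. The helper name `numel` is made up for this sketch.

```python
from math import prod


def numel(shape):
    """Number of values in a tensor of the given shape."""
    return prod(shape)


image_shape = (1, 3, 512, 512)
latent_shape = (1, 4, 64, 64)

ratio = numel(image_shape) / numel(latent_shape)
print(numel(image_shape), numel(latent_shape), ratio)  # 786432 16384 48.0
```

So all three VAEs map the image to a latent with 48 times fewer values: each spatial side shrinks by a factor of 8, while the channel count grows from 3 to 4.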


---

In [14]:
from util import simple_describe

### Description of SDXL

In [16]:
simple_describe(unetxl.conv_in)

 Conv2d (4, 320)


In [15]:
simple_describe(unetxl.down_blocks)

 ModuleList 
	 DownBlock2D 
		 ResnetBlock2D (320, 320)
		 ResnetBlock2D (320, 320)
		 Downsample2D (320, 320)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (320, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (640, 640)
		 Transformer2DModel (640, 640)
		 Downsample2D (640, 640)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (640, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (1280, 1280)
		 Transformer2DModel (1280, 1280)


### Description of SD21

In [17]:
simple_describe(unet21.conv_in)

 Conv2d (4, 320)


In [18]:
simple_describe(unet21.down_blocks)

 ModuleList 
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (320, 320)
		 Transformer2DModel (320, 320)
		 ResnetBlock2D (320, 320)
		 Transformer2DModel (320, 320)
		 Downsample2D (320, 320)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (320, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (640, 640)
		 Transformer2DModel (640, 640)
		 Downsample2D (640, 640)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (640, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (1280, 1280)
		 Transformer2DModel (1280, 1280)
		 Downsample2D (1280, 1280)
	 DownBlock2D 
		 ResnetBlock2D (1280, 1280)
		 ResnetBlock2D (1280, 1280)


### Description of SD15

In [19]:
simple_describe(unet15.conv_in)

 Conv2d (4, 320)


In [20]:
simple_describe(unet15.down_blocks)

 ModuleList 
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (320, 320)
		 Transformer2DModel (320, 320)
		 ResnetBlock2D (320, 320)
		 Transformer2DModel (320, 320)
		 Downsample2D (320, 320)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (320, 640)
		 Transformer2DModel (640, 640)
		 ResnetBlock2D (640, 640)
		 Transformer2DModel (640, 640)
		 Downsample2D (640, 640)
	 CrossAttnDownBlock2D 
		 ResnetBlock2D (640, 1280)
		 Transformer2DModel (1280, 1280)
		 ResnetBlock2D (1280, 1280)
		 Transformer2DModel (1280, 1280)
		 Downsample2D (1280, 1280)
	 DownBlock2D 
		 ResnetBlock2D (1280, 1280)
		 ResnetBlock2D (1280, 1280)
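
To compare the three printouts, the sketch below traces the spatial size of a `64x64` latent through each model's down blocks. The block lists are transcribed by hand from the outputs above. That every `Downsample2D` halves the spatial size is an assumption of this sketch, since the printouts only show channel counts. `DOWN_BLOCKS` and `trace` are names made up here.

```python
# (block type, output channels, has a Downsample2D at the end),
# transcribed from the simple_describe outputs above.
DOWN_BLOCKS = {
    "sdxl": [
        ("DownBlock2D", 320, True),
        ("CrossAttnDownBlock2D", 640, True),
        ("CrossAttnDownBlock2D", 1280, False),
    ],
    "sd21/sd15": [
        ("CrossAttnDownBlock2D", 320, True),
        ("CrossAttnDownBlock2D", 640, True),
        ("CrossAttnDownBlock2D", 1280, True),
        ("DownBlock2D", 1280, False),
    ],
}


def trace(blocks, size=64):
    """Return (block type, channels, spatial size) after each down block."""
    rows = []
    for name, channels, has_down in blocks:
        if has_down:
            size //= 2  # assumption: each Downsample2D halves height and width
        rows.append((name, channels, size))
    return rows


for model, blocks in DOWN_BLOCKS.items():
    print(model)
    for name, channels, size in trace(blocks):
        print(f"  {name:<22} -> ({channels}, {size}, {size})")
```

Under that assumption, SDXL bottoms out at `16x16` after two downsamples, while SD21 and SD15 reach `8x8` after three. The other visible differences are that SDXL's first block has no attention and that it has three down blocks instead of four.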
