In [1]:
import torch 
import sys
sys.path.insert(0, 'src/')

from src.model.vae import VAE_Encoder, VAE_Decoder
from src.model.config import StableDiffusionConfig
from src.model.unet import UNet
from src.model.clip import CLIPEncoder

from src.model.diffusion import StableDiffusion

from torchinfo import summary

config = StableDiffusionConfig()

Stable Diffusion

In [2]:
img = torch.randn((1, config.img_channels, config.img_size, config.img_size))
tokens = torch.randint(low=0, high=config.vocab_size, size=(1, config.clip_seq_len))
time = torch.randn((1, config.unet_time_emb_dim))

In [3]:
model = StableDiffusion(config)

In [4]:
img_size = (1, config.img_channels, config.img_size, config.img_size)

In [5]:
model(img, tokens, time)

torch.Size([1, 1280, 4, 4])
torch.Size([1, 1280, 4, 4])


RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [2560] and input of shape [1, 1280, 4, 4]

In [None]:
summary(model, [img_size, (1, config.clip_seq_len)])

Layer (type:depth-idx)                                  Output Shape              Param #
StableDiffusion                                         [32, 4, 32, 32]           882,807,447
├─VAE_Encoder: 1-1                                      [32, 4, 32, 32]           --
│    └─Conv2d: 2-1                                      [1, 128, 256, 256]        3,456
│    └─ModuleList: 2-2                                  --                        --
│    │    └─VAE_Block: 3-1                              [1, 128, 128, 128]        442,880
│    │    └─VAE_Block: 3-2                              [1, 256, 64, 64]          1,508,352
│    │    └─VAE_Block: 3-3                              [1, 512, 32, 32]          6,031,360
│    └─Sequential: 2-3                                  [1, 512, 32, 32]          --
│    │    └─VAE_Block: 3-4                              [1, 512, 32, 32]          4,720,640
│    │    └─VAE_Block: 3-5                              [1, 512, 32, 32]          4,720,640
│    │    └─VAE

VAE

In [8]:
enc = VAE_Encoder(config.img_channels, config.vae_features_dims, config.vae_num_groups, config.vae_num_heads, config.vae_dropout, config.vae_latent_dim)

print(sum(p.numel() for p in enc.parameters()))

23237584


UNET

In [4]:
model = UNet(config.vae_latent_dim, config.unet_features_dims, config.unet_attn_num_heads, config.unet_attn_dim, config.unet_time_emb_dim, config.unet_time_emb_dim_scale_factor)

In [5]:
sum(p.numel() for p in model.parameters())

852620804

In [17]:
[img.shape, noise.shape]

[torch.Size([1, 3, 512, 512]), torch.Size([1, 4, 64, 64])]

In [9]:
summary(enc, [(1, 3, 512, 512), (1, 4, 64, 64)])

Layer (type:depth-idx)                             Output Shape              Param #
VAE_Encoder                                        [1, 4, 64, 64]            --
├─Conv2d: 1-1                                      [1, 128, 512, 512]        3,456
├─ModuleList: 1-2                                  --                        --
│    └─VAE_Block: 2-1                              [1, 128, 256, 256]        --
│    │    └─PrenormResidualConnection: 3-1         [1, 128, 512, 512]        147,712
│    │    └─PrenormResidualConnection: 3-2         [1, 128, 512, 512]        147,712
│    │    └─Conv2d: 3-3                            [1, 128, 256, 256]        147,456
│    └─VAE_Block: 2-2                              [1, 256, 128, 128]        --
│    │    └─PrenormResidualConnection: 3-4         [1, 256, 256, 256]        328,192
│    │    └─PrenormResidualConnection: 3-5         [1, 256, 256, 256]        590,336
│    │    └─Conv2d: 3-6                            [1, 256, 128, 128]        589,824
│ 

In [None]:
summary()

In [6]:
time = torch.randn((1, config.unet_time_emb_dim))
model(out, time)

TypeError: SwitchSequential.forward() missing 1 required positional argument: 'time'

In [2]:

tokens.shape

torch.Size([1, 77])

In [3]:
clip = CLIPEncoder(config.vocab_size, config.clip_emb_dim, config.clip_seq_len,
                   config.clip_attn_num_heads, config.clip_emb_dim_scale_factor, 
                   config.clip_num_layers, config.clip_dropout).to('cuda')

In [4]:
sum(p.numel() for p in clip.parameters())

123060480

In [1]:
import torch

torch.save({
    'epoch' : 1
}, "src/model/weights/stable_diffusion_7.pth")

In [None]:
UNet_AttentionBlock

In [6]:
text_emb = clip(tokens)

In [7]:
text_emb.shape

torch.Size([1, 77, 768])

In [4]:
from torch import nn

In [5]:
out.shape

torch.Size([1, 4, 64, 64])

In [6]:
dec = VAE_Decoder(config.vae_latent_dim, config.vae_features_dims, config.vae_num_groups, config.vae_dropout, config.vae_num_heads, config.img_channels)

In [7]:
print(sum(p.numel() for p in dec.parameters()))

32367379


In [8]:
dec = dec.to('cuda')

In [9]:
dec(out).shape

torch.Size([1, 4, 64, 64])
torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 128, 128])
torch.Size([1, 256, 256, 256])
torch.Size([1, 128, 512, 512])


torch.Size([1, 3, 512, 512])