In [33]:
from diffusers import AutoencoderKL
import torch
import torch.nn.functional as F

### 1. Figure out encoder size

In [66]:
# Candidate encoder configuration. A full AutoencoderKL is instantiated, but
# this section only exercises its `.encode()` path / `.encoder` sub-module.
# Eight down/up stages, 9 input/output channels, 512 latent channels.
encoder = AutoencoderKL(
    in_channels=9,
    out_channels=9,
    down_block_types=("DownEncoderBlock2D",) * 8,
    up_block_types=("UpDecoderBlock2D",) * 8,
    block_out_channels=(9, 9, 16, 32, 64, 128, 256, 512),
    norm_num_groups=1,
    latent_channels=512,
)

In [67]:
# Probe the latent shape produced for a dummy (1, 9, 432, 432) input.
# Only the shape is inspected, so run under no_grad to avoid building an
# autograd graph for a throwaway forward pass.
with torch.no_grad():
    out = encoder.encode(torch.randn((1, 9, 432, 432))).latent_dist.mode()
out.shape

torch.Size([1, 512, 3, 3])

In [68]:
# Average-pool the latent with a 3x3 window and show the resulting shape.
F.avg_pool2d(out, kernel_size=3).shape

torch.Size([1, 512, 1, 1])

In [70]:
# Display the layer structure of the encoder sub-module.
encoder.encoder

Encoder(
  (conv_in): Conv2d(9, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down_blocks): ModuleList(
    (0-1): 2 x DownEncoderBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(1, 9, eps=1e-06, affine=True)
          (conv1): LoRACompatibleConv(9, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(1, 9, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): LoRACompatibleConv(9, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): LoRACompatibleConv(9, 9, kernel_size=(3, 3), stride=(2, 2))
        )
      )
    )
    (2): DownEncoderBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(1, 9, eps=1e-06, affine=True)
          (conv1): LoRACompatibleConv(9, 16, kernel_size=(3, 3), stride=(1,

In [71]:
# Total number of trainable scalar parameters in the encoder sub-module only.
sum(
    weight.numel()
    for weight in encoder.encoder.parameters()
    if weight.requires_grad
)

20905384

### 2. Figure out decoder size

In [81]:
# Candidate decoder configuration. A full AutoencoderKL is instantiated, but
# this section only exercises its `.decode()` path / `.decoder` sub-module.
# Four down/up stages, single-channel input/output.
decoder = AutoencoderKL(
    in_channels=1,
    out_channels=1,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(8, 16, 32, 64),
    norm_num_groups=1,
    # TODO: maybe try raising this later to allow a deeper representation.
    latent_channels=1,
)

In [82]:
# Probe the output shape produced when decoding a dummy (1, 1, 54, 54) latent.
# Only the shape is inspected, so run under no_grad to avoid building an
# autograd graph for a throwaway forward pass.
with torch.no_grad():
    out = decoder.decode(torch.randn((1, 1, 54, 54))).sample
out.shape

torch.Size([1, 1, 432, 432])

In [83]:
# Total number of trainable scalar parameters in the decoder sub-module only.
sum(
    weight.numel()
    for weight in decoder.decoder.parameters()
    if weight.requires_grad
)

426449