In [1]:
import torch
from models import select

In [2]:
# Reproducibility: seed torch's global RNG before anything random happens.
# The dummy batch below, the randomly initialised models built in later cells,
# and (presumably) the latent sampling in their forward passes all draw from
# this RNG, so seeding here makes Restart & Run All repeatable.
SEED = 0
torch.manual_seed(SEED)

img_size = (3, 64, 64)  # (channels, height, width) of a single image
batch_size = 2
# Random input batch of shape (batch_size, *img_size), used only to check
# that every model accepts this image size and returns consistent shapes.
dummy_data = torch.randn(batch_size, *img_size)

print("Dummy data shape:", dummy_data.shape)

Dummy data shape: torch.Size([2, 3, 64, 64])


# Locatello

In [3]:
# Instantiate the Locatello VAE for the dummy image size; the bare `model`
# on the last line renders the module's layer summary as the cell output.
model = select(
    'vae_locatello',
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv4): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (lin): Linear(in_features=1024, out_features=256, bias=True)
    (dist_statistics): Linear(in_features=256, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (lin1): Linear(in_features=10, out_features=256, bias=True)
    (lin2): Linear(in_features=256, out_features=1024, bias=True)
    (convT1): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT3): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT4): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

In [4]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])


# Burgess

In [5]:
# Instantiate the Burgess VAE for the dummy image size; the bare `model`
# on the last line renders the module's layer summary as the cell output.
model = select(
    'vae_burgess',
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv_64): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (lin1): Linear(in_features=512, out_features=256, bias=True)
    (lin2): Linear(in_features=256, out_features=256, bias=True)
    (dist_statistics): Linear(in_features=256, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (lin1): Linear(in_features=10, out_features=256, bias=True)
    (lin2): Linear(in_features=256, out_features=256, bias=True)
    (lin3): Linear(in_features=256, out_features=512, bias=True)
    (convT_64): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT1): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT2): ConvTranspose2d(32, 32, kernel_size

In [6]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])


# Chen MLP

In [8]:
# Instantiate the MLP-based Chen VAE with an explicit 10-dimensional latent
# space; the bare `model` on the last line renders its layer summary.
model = select(
    'vae_chen_mlp',
    latent_dim=10,
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (fc1): Linear(in_features=12288, out_features=1200, bias=True)
    (fc2): Linear(in_features=1200, out_features=1200, bias=True)
    (act): ReLU(inplace=True)
    (dist_statistics): Linear(in_features=1200, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (net): Sequential(
      (0): Linear(in_features=10, out_features=1200, bias=True)
      (1): Tanh()
      (2): Linear(in_features=1200, out_features=1200, bias=True)
      (3): Tanh()
      (4): Linear(in_features=1200, out_features=1200, bias=True)
      (5): Tanh()
      (6): Linear(in_features=1200, out_features=12288, bias=True)
    )
  )
)

In [10]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])


# Locatello SBD

In [11]:
# Instantiate the Locatello SBD variant: per the printed summary it shares the
# Locatello encoder but uses a decoder of stride-1 Conv2d layers. The bare
# `model` on the last line renders that layer summary as the cell output.
model = select(
    'vae_locatello_sbd',
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv4): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (lin): Linear(in_features=1024, out_features=256, bias=True)
    (dist_statistics): Linear(in_features=256, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (conv1): Conv2d(12, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv5): Conv2d(64, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
)

In [12]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])


# Montero Large

In [13]:
# Instantiate the large Montero VAE for the dummy image size; the bare
# `model` on the last line renders the module's layer summary.
model = select(
    'vae_montero_large',
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv4): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (lin): Linear(in_features=1024, out_features=256, bias=True)
    (dist_statistics): Linear(in_features=256, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (lin1): Linear(in_features=10, out_features=256, bias=True)
    (lin2): Linear(in_features=256, out_features=1024, bias=True)
    (convT1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT2): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    

In [14]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])


# Montero Small

In [15]:
# Instantiate the small Montero VAE for the dummy image size; the bare
# `model` on the last line renders the module's layer summary.
model = select(
    'vae_montero_small',
    img_size=img_size,
)
model

Model(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv4): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (lin): Linear(in_features=512, out_features=256, bias=True)
    (dist_statistics): Linear(in_features=256, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (lin1): Linear(in_features=10, out_features=256, bias=True)
    (lin2): Linear(in_features=256, out_features=512, bias=True)
    (convT1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT2): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (convT4): 

In [16]:
# Smoke-test one forward pass: print the shape of every tensor the model
# returns, then assert the invariants those shapes are expected to satisfy.
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_data)

for name, tensor in output.items():
    print(f"{name}: {tensor.shape}")

# Autoencoder invariant: reconstructions have exactly the input's shape.
assert output["reconstructions"].shape == dummy_data.shape
# samples_qzx drops the trailing statistics dimension of stats_qzx
# (size 2 in the printed output; presumably mean / log-variance — confirm).
assert output["samples_qzx"].shape == output["stats_qzx"].shape[:-1]

reconstructions: torch.Size([2, 3, 64, 64])
stats_qzx: torch.Size([2, 10, 2])
samples_qzx: torch.Size([2, 10])
