In [1]:
import torch
import torch.nn as nn
from modules.vq_vae import Encoder, Decoder, VQVAE
from modules.discriminator import Discriminator
from torchinfo import summary

# Pick the compute device once: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Instantiate the encoder configuration under inspection and move it onto
# `device` in the same expression, so `encoder` is bound exactly once.
encoder = Encoder(
    in_channels=3,
    out_channels=3,
    embed_dim=64,
    depths=[2, 2, 2, 2],
    channel_multipliers=[1, 2, 4, 8],
).to(device)


In [3]:
# torchinfo report for the encoder on a single 3x256x256 input (batch size 1),
# expanding nested modules up to five levels deep.
summary(
    encoder,
    input_size=(1, 3, 256, 256),
    depth=5,
)

Layer (type:depth-idx)                        Output Shape              Param #
Encoder                                       [1, 3, 32, 32]            --
├─Conv2d: 1-1                                 [1, 64, 256, 256]         256
├─ModuleList: 1-2                             --                        --
│    └─ResidualBlock: 2-1                     [1, 128, 128, 128]        --
│    │    └─ModuleList: 3-1                   --                        --
│    │    │    └─ResidualLayer: 4-1           [1, 64, 256, 256]         --
│    │    │    │    └─Conv2d: 5-1             [1, 64, 256, 256]         36,928
│    │    │    │    └─GroupNorm: 5-2          [1, 64, 256, 256]         128
│    │    │    │    └─GELU: 5-3               [1, 64, 256, 256]         --
│    │    │    │    └─Conv2d: 5-4             [1, 64, 256, 256]         36,928
│    │    │    │    └─GroupNorm: 5-5          [1, 64, 256, 256]         128
│    │    │    │    └─GELU: 5-6               [1, 64, 256, 256]         --
│    │   

In [2]:
# Decoder with the same keyword settings as the encoder cell, except that the
# channel multipliers are listed in the opposite order (8, 4, 2, 1).
decoder = Decoder(
    in_channels=3,
    out_channels=3,
    embed_dim=64,
    depths=[2, 2, 2, 2],
    channel_multipliers=[8, 4, 2, 1],
).to(device)

In [2]:
# Full VQ-VAE model, built and moved onto `device` in one expression.
# Note that embed_dim is 48 here, whereas the stand-alone encoder/decoder cells use 64.
# NOTE(review): the keyword is spelled `latentd_dim` — confirm this matches the
# VQVAE constructor's parameter name rather than being a typo for `latent_dim`.
vq_vae = VQVAE(
    in_channels=3,
    out_channels=3,
    latentd_dim=4,
    embed_dim=48,
    depths=[2, 2, 2, 2],
    channel_multipliers=[1, 2, 4, 8],
).to(device)

In [3]:
# Smoke-test the encode path with one random 3x256x256 input placed on `device`;
# the call returns three values, unpacked here as (x, loss, perplexity).
x, loss, perplexity = vq_vae.encode(
    torch.randn(1, 3, 256, 256).to(device)
)

In [2]:
# Build the discriminator configuration under inspection and place it on
# `device` in the same expression.
discriminator = Discriminator(
    in_channels=3,
    embed_dim=32,
    num_layers=3,
    channel_multipliers=[1, 2, 4],
).to(device)

In [3]:
# Show the discriminator's module tree (sub-module names, layer types and their
# constructor arguments) as plain text.
print(discriminator)

Discriminator(
  (activation): ReLU()
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): GroupNorm(1, 64, eps=1e-05, affine=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): GroupNorm(1, 128, eps=1e-05, affine=True)
      (2): ReLU()
    )
  )
  (initial_conv): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (initial_norm): GroupNorm(1, 32, eps=1e-05, affine=True)
  (final_conv): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
)


In [4]:
# Per-layer report for the discriminator on one 3x256x256 image.
# `summary` is imported from torchinfo in this notebook, and torchinfo's
# `input_size` must include the batch dimension, so pass (1, 3, 256, 256) by
# keyword, the same way the encoder summary cell does. The previous batch-less
# tuple (3, 256, 256) follows the torchsummary calling convention instead.
# NOTE(review): the output recorded under this cell uses the `[-1, C, H, W]`
# table layout, which looks like it came from a torchsummary-style `summary`
# rather than the torchinfo one imported at the top; re-run the cell on a fresh
# kernel to refresh it.
summary(discriminator, input_size=(1, 3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]           1,568
         GroupNorm-2         [-1, 32, 128, 128]              64
              ReLU-3         [-1, 32, 128, 128]               0
            Conv2d-4           [-1, 64, 64, 64]          32,832
         GroupNorm-5           [-1, 64, 64, 64]             128
              ReLU-6           [-1, 64, 64, 64]               0
            Conv2d-7          [-1, 128, 32, 32]         131,200
         GroupNorm-8          [-1, 128, 32, 32]             256
              ReLU-9          [-1, 128, 32, 32]               0
           Conv2d-10            [-1, 1, 32, 32]             129
Total params: 166,177
Trainable params: 166,177
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 21.01
Params size (MB): 0.63
Estimated 