In [None]:
import torch
import torch.nn as nn
from modules.vq_vae import Encoder, Decoder, VQVAE
from modules.discriminator import Discriminator
from torchinfo import summary

# Seed PyTorch's RNGs (all devices) so weight initialisation and the random
# test input used below are repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Image encoder. Per-stage settings are passed as parallel lists, presumably
# one entry per resolution stage (confirm against modules.vq_vae.Encoder).
# NOTE(review): 3 activations vs 4 depth/multiplier entries — presumably one
# activation per transition between stages; confirm.
encoder = Encoder(
    in_channels=3,
    out_channels=3,
    embed_dim=64,
    depths=[2] * 4,
    channel_multipliers=[1, 2, 4, 8],
    activations=[nn.GELU] * 3,
).to(device)


In [None]:
# Layer-by-layer summary for a single 3x256x256 input (batch size 1).
summary(encoder, depth=5, input_size=(1, 3, 256, 256))

In [None]:
# Decoder uses the same per-stage depths as the encoder, with the channel
# multipliers reversed (8 -> 1) and SELU activations.
# NOTE(review): in_channels=3 — presumably must match the encoder's
# out_channels (also 3); confirm.
decoder = Decoder(
    in_channels=3,
    out_channels=3,
    embed_dim=64,
    depths=[2] * 4,
    channel_multipliers=[8, 4, 2, 1],
    activations=[nn.SELU] * 3,
).to(device)

In [None]:
# Layer-by-layer summary of the decoder, probed with a 3x256x256 input (batch 1).
# NOTE(review): confirm this matches the latent shape the decoder actually
# receives from the encoder.
summary(decoder, depth=5, input_size=(1, 3, 256, 256))

In [None]:
# Full VQ-VAE model.
# NOTE(review): `latentd_dim` looks like a misspelling of `latent_dim`. It is
# kept verbatim because it must match the keyword modules.vq_vae.VQVAE
# accepts — confirm the signature and fix both sides together if it is a typo.
vq_vae = VQVAE(
    in_channels=3,
    out_channels=3,
    latentd_dim=4,
    embed_dim=48,
    depths=[2] * 4,
    channel_multipliers=[1, 2, 4, 8],
).to(device)

In [None]:
# Smoke test: push one random 3x256x256 image through VQVAE.encode.
# encode returns a 3-tuple, unpacked here as (x, loss, perplexity); see
# modules.vq_vae for what each element means.
x, loss, perplexity = vq_vae.encode(
    torch.randn(1, 3, 256, 256).to(device)
)

In [None]:
# Discriminator over 3-channel images.
# NOTE(review): channel_multipliers has 3 entries for num_layers=3 —
# presumably one per layer; confirm in modules.discriminator.
discriminator = Discriminator(
    in_channels=3,
    embed_dim=32,
    num_layers=3,
    channel_multipliers=[1, 2, 4],
).to(device)

In [None]:
# Show the module tree via the notebook's rich display (bare last expression
# of the cell) instead of print().
discriminator

In [None]:
# torchinfo's `input_size` is the full input shape INCLUDING the batch
# dimension (unlike torchsummary, which prepends one). Passing (3, 256, 256)
# would feed an unbatched 3-D tensor, which 4-D-only layers (e.g. BatchNorm2d)
# reject. Use a batch of 1, matching the encoder/decoder summaries.
summary(discriminator, input_size=(1, 3, 256, 256), depth=5)