In [21]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchinfo

In [26]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
        )

        self.fc = nn.Sequential(
            nn.Linear(256 * 8 * 8, 2048),
            nn.BatchNorm1d(2048, momentum=0.9),
            nn.LeakyReLU(0.2),
        )

        self.mu_layer = nn.Linear(2048, 128)
        self.var_layer = nn.Linear(2048, 128)

    def forward(self, imgs):
        out = self.conv_layers(imgs)
        out = nn.Flatten()(out)
        out = self.fc(out)
        mu = self.mu_layer(out)
        logvar = self.var_layer(out)
        return mu, logvar

In [27]:
# TEST ENCODER
encoder = Encoder()
print(encoder)
imgs = torch.randn(64, 3, 64, 64)
mu, logvar = encoder(imgs)
print(mu.shape, logvar.shape)

Encoder(
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2)
  )
  (fc): Sequential(
    (0): Linear(in_features=16384, out_features=2048, bias=True)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (mu_layer): Linear(in_features=2048, out_features=128, bias=True)
  (var_layer): Linear(in_features=2048, out_features=128, 

In [28]:
torchinfo.summary(encoder, input_size=(64, 3, 64, 64))

Layer (type:depth-idx)                   Output Shape              Param #
Encoder                                  --                        --
├─Sequential: 1-1                        [64, 256, 8, 8]           --
│    └─Conv2d: 2-1                       [64, 64, 32, 32]          4,864
│    └─BatchNorm2d: 2-2                  [64, 64, 32, 32]          128
│    └─LeakyReLU: 2-3                    [64, 64, 32, 32]          --
│    └─Conv2d: 2-4                       [64, 128, 16, 16]         204,928
│    └─BatchNorm2d: 2-5                  [64, 128, 16, 16]         256
│    └─LeakyReLU: 2-6                    [64, 128, 16, 16]         --
│    └─Conv2d: 2-7                       [64, 256, 8, 8]           819,456
│    └─BatchNorm2d: 2-8                  [64, 256, 8, 8]           512
│    └─LeakyReLU: 2-9                    [64, 256, 8, 8]           --
├─Sequential: 1-2                        [64, 2048]                --
│    └─Linear: 2-10                      [64, 2048]                33

In [29]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(128, 256 * 8 * 8, bias=False),
            nn.BatchNorm1d(256 * 8 * 8, momentum=0.9),
            nn.LeakyReLU(0.2),
        )

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False),
            nn.Tanh(),            
        )

    def forward(self, z):
        out = self.fc(z)
        out = out.view(-1, 256, 8, 8)
        recon_imgs = self.deconv_layers(out)
        return recon_imgs

In [30]:
decoder = Decoder()
print(decoder)
sigma = torch.exp(0.5 * logvar)
epsilon = torch.randn_like(sigma)
z = mu + sigma * epsilon
print(z.shape)
recon_imgs = decoder(z)
print(recon_imgs.shape)

Decoder(
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=16384, bias=False)
    (1): BatchNorm1d(16384, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (deconv_layers): Sequential(
    (0): ConvTranspose2d(256, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slo

In [41]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
        )

        self.fc = nn.Sequential(
            nn.Linear(256 * 8 * 8, 512),
            nn.BatchNorm1d(512, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, imgs):
        out = self.conv_layers(imgs)
        out = nn.Flatten()(out)
        bottleneck = out
        out = self.fc(out)
        return out, bottleneck

In [42]:
discriminator = Discriminator()
print(discriminator)
imgs = torch.randn(64, 3, 64, 64)
out, bottleneck = discriminator(imgs)
print(out.shape, bottleneck.shape)

Discriminator(
  (conv_layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(32, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
  )
  (fc): Sequential(
    (0): Linear(in_features=16384, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3):

In [43]:
class VAE_GAN(nn.Module):
    def __init__(self):
        super(VAE_GAN, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.discriminator = Discriminator()
        
    def forward(self, imgs):
        batch_size = imgs.shape[0]
        mu, logvar = self.encoder(imgs)
        sigma = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        recon_imgs = self.decoder(z)
        return mu, logvar, recon_imgs

In [44]:
imgs = torch.randn(64, 3, 64, 64)
vae_gan = VAE_GAN()
mu, logvar, recon_imgs = vae_gan(imgs)
print(mu.shape, logvar.shape, recon_imgs.shape)

torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 3, 128, 128])
