In [21]:
import torch
from torch import nn

In [22]:
class BasicConv2d(nn.Module):
    """Convolution block: ``Conv2d`` (bias disabled) -> ``BatchNorm2d`` -> ``LeakyReLU``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 2).
        padding: Padding added to both sides of the input (default 1).
        **kwargs: Extra keyword arguments forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, **kwargs) -> None:
        super().__init__()
        # No convolution bias: the BatchNorm that follows carries its own affine shift.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            **kwargs,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(inplace=True)
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through convolution, batch normalisation and activation."""
        return self.layers(x)

# Instantiate one block so the cell displays its layer layout (module repr).
BasicConv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2)

BasicConv2d(
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
  )
)

In [23]:
class BasicConvTranspose2d(nn.Module):
    """Transposed-convolution block: ``ConvTranspose2d`` (bias disabled) -> ``BatchNorm2d`` -> ``LeakyReLU``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size (default 3).
        stride: Stride (default 2).
        padding: Padding argument of ``nn.ConvTranspose2d`` (default 1).
        output_padding: Extra size added to one side of the output (default 1).
        **kwargs: Extra keyword arguments forwarded unchanged to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1, **kwargs) -> None:
        super().__init__()
        # No bias on the transposed convolution: the BatchNorm that follows has its own affine shift.
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False,
            **kwargs,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(inplace=True)
        self.layers = nn.Sequential(deconv, norm, activation)

    def forward(self, x):
        """Run ``x`` through transposed convolution, batch normalisation and activation."""
        return self.layers(x)
# Instantiate one block so the cell displays its layer layout (module repr).
BasicConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2)

BasicConvTranspose2d(
  (layers): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
  )
)

In [24]:
class FeatureEncoder(nn.Module):
    """Five chained ``BasicConv2d`` blocks (each with that block's default settings).

    Channel widths run ``in_channels -> 32 -> 64 -> 128 -> 256 -> out_channels``.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the final feature map (default 512).
    """

    def __init__(self, in_channels=3, out_channels=512) -> None:
        super().__init__()
        widths = (in_channels, 32, 64, 128, 256, out_channels)
        # Pair each width with its successor to get the (input, output) channels of every stage.
        stages = [BasicConv2d(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature map produced by the stacked convolution blocks."""
        return self.layers(x)

# Show the layer layout, then feed a dummy batch of four 3x64x64 inputs to check the output shape.
print(FeatureEncoder())
FeatureEncoder()(torch.ones(4, 3, 64, 64, dtype=torch.float)).shape

FeatureEncoder(
  (layers): Sequential(
    (0): BasicConv2d(
      (layers): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (1): BasicConv2d(
      (layers): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (2): BasicConv2d(
      (layers): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (3): BasicConv2d(
      (layers): Seq

torch.Size([4, 512, 2, 2])

In [25]:
class FeatureDecoder(nn.Module):
    """Five chained ``BasicConvTranspose2d`` blocks (each with that block's default settings).

    Channel widths run ``in_channels -> 256 -> 128 -> 64 -> 32 -> out_channels``.

    Args:
        in_channels: Channels of the incoming feature map (default 512).
        out_channels: Channels of the final feature map (default 32).
    """

    def __init__(self, in_channels=512, out_channels=32) -> None:
        super().__init__()
        widths = (in_channels, 256, 128, 64, 32, out_channels)
        # Pair each width with its successor to get the (input, output) channels of every stage.
        stages = [BasicConvTranspose2d(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature map produced by the stacked transposed-convolution blocks."""
        return self.layers(x)

# Show the layer layout, then feed a dummy batch of four 512x2x2 feature maps to check the output shape.
print(FeatureDecoder())
FeatureDecoder()(torch.ones(4, 512, 2, 2, dtype=torch.float)).shape

FeatureDecoder(
  (layers): Sequential(
    (0): BasicConvTranspose2d(
      (layers): Sequential(
        (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (1): BasicConvTranspose2d(
      (layers): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (2): BasicConvTranspose2d(
      (layers): Sequential(
        (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

torch.Size([4, 32, 64, 64])

In [26]:
class LatentVectorConverter(nn.Module):
    """Turn a feature tensor into a sampled latent vector plus its mean and std.

    Two linear heads read the flattened features. The first gives ``mu``. The
    output of the second is passed through ``exp(0.5 * .)`` to obtain ``std``,
    so that head effectively produces a log-variance.

    Args:
        in_features: Size of the flattened feature vector per sample.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, in_features, latent_dim) -> None:
        super().__init__()
        self.mu_linear = nn.Linear(in_features, latent_dim)
        self.var_linear = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        """Return ``(z, mu, std)`` where ``z = mu + std * noise`` with standard-normal noise."""
        flat = x.flatten(1)  # keep the batch dimension, merge everything else
        mu = self.mu_linear(flat)
        log_var = self.var_linear(flat)
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        z = mu + std * noise
        return z, mu, std


In [27]:
class Encoder(nn.Module):
    """Image encoder: ``FeatureEncoder`` followed by ``LatentVectorConverter``.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the feature map produced by the feature
            encoder (default 512).
        latent_dim: Size of the latent vector (default 128).
        out_img_size: Height/width of the (square) feature map coming out of
            the feature encoder. It sizes the linear heads of the latent
            converter. The default of 2 matches 64x64 inputs, for which the
            feature encoder yields a 2x2 map. Pass a different value for other
            input resolutions.
    """

    def __init__(self, in_channels=3, out_channels=512, latent_dim=128, out_img_size=2) -> None:
        super().__init__()
        self.feature_encoder = FeatureEncoder(in_channels, out_channels)
        # Flattened feature size seen by the latent heads: H * W * C of the encoder output.
        self.latent_vector_converter = LatentVectorConverter(
            out_img_size * out_img_size * out_channels, latent_dim
        )

    def forward(self, x):
        """Return ``(z, mu, std)`` for the input batch ``x``."""
        x = self.feature_encoder(x)
        x, mu, std = self.latent_vector_converter(x)
        return x, mu, std


# Show the layer layout, then encode a dummy batch and list the shape of each returned tensor (z, mu, std).
print(Encoder())
outs = Encoder()(torch.ones(4, 3, 64, 64, dtype=torch.float))
for out in outs:
    print(out.shape)


Encoder(
  (feature_encoder): FeatureEncoder(
    (layers): Sequential(
      (0): BasicConv2d(
        (layers): Sequential(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): BasicConv2d(
        (layers): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): BasicConv2d(
        (layers): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01,

In [28]:
class Decoder(nn.Module):
    """Latent-vector decoder: linear projection -> ``FeatureDecoder`` -> 3x3 conv + ``Tanh``.

    Args:
        latent_dim: Size of the incoming latent vector (default 128).
        in_img_size: Height/width of the square feature map the latent vector
            is projected to before upsampling (default 2).
        in_channels: Channels of that projected feature map (default 512).
        out_channels: Channels of the produced image (default 3).
    """

    def __init__(self, latent_dim=128, in_img_size=2, in_channels=512, out_channels=3) -> None:
        super().__init__()
        self.in_img_size = in_img_size
        self.in_channels = in_channels
        # Width of the feature map handed from the feature decoder to the output convolution.
        feature_channels = 32
        projected_size = in_img_size * in_img_size * in_channels
        self.z_linear = nn.Linear(latent_dim, projected_size)
        self.feature_decoder = FeatureDecoder(in_channels, feature_channels)
        self.last_layer = nn.Sequential(
            nn.Conv2d(feature_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode latent vectors ``x`` into images; ``Tanh`` bounds the values to [-1, 1]."""
        projected = self.z_linear(x)
        # Fold the flat projection back into a (batch, channels, height, width) feature map.
        feature_map = projected.reshape(-1, self.in_channels, self.in_img_size, self.in_img_size)
        return self.last_layer(self.feature_decoder(feature_map))

# Show the layer layout, then decode a dummy batch of sixteen 128-dim latent vectors to check the output shape.
print(Decoder())
Decoder()(torch.ones(16, 128, dtype=torch.float)).shape

Decoder(
  (z_linear): Linear(in_features=128, out_features=2048, bias=True)
  (feature_decoder): FeatureDecoder(
    (layers): Sequential(
      (0): BasicConvTranspose2d(
        (layers): Sequential(
          (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): BasicConvTranspose2d(
        (layers): Sequential(
          (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): BasicConvTranspose2d(
        (layers): Sequential(
          (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2

torch.Size([16, 3, 64, 64])

In [29]:
class Vae(nn.Module):
    """Variational autoencoder built from a default ``Encoder`` and a default ``Decoder``."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, decode the sampled latent vector and return ``(reconstruction, mu, std)``."""
        z, mu, std = self.encoder(x)
        reconstruction = self.decoder(z)
        return reconstruction, mu, std

# Build the full model, show its layout, and run a dummy batch to check the shape of every output.
model = Vae()
print(model)
x, mu, std = model(torch.ones(16, 3, 64, 64, dtype=torch.float))
print('  x:', x.shape)
print(' mu:', mu.shape)
print('std:', std.shape)

Vae(
  (encoder): Encoder(
    (feature_encoder): FeatureEncoder(
      (layers): Sequential(
        (0): BasicConv2d(
          (layers): Sequential(
            (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (1): BasicConv2d(
          (layers): Sequential(
            (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (2): BasicConv2d(
          (layers): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running

In [30]:
class KLDivergenceLoss(nn.Module):
    """VAE training loss: reconstruction error plus a KL-divergence regulariser.

    The value returned is meant to be *minimised*:

        loss = MSE(pred, target) + 0.5 * sum(std^2 + mu^2 - log(std^2) - 1)

    The second term is the closed-form KL divergence between the diagonal
    Gaussian ``N(mu, std^2)`` and the standard normal ``N(0, 1)``.

    Note: the reconstruction term uses ``nn.MSELoss`` with its default
    reduction (mean over all elements), while the KL term is summed over every
    element of ``mu``/``std``, so the two terms are on different scales.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target, mu, std):
        """Compute the loss.

        Args:
            pred: Reconstructed tensor.
            target: Tensor to reconstruct; same shape as ``pred``.
            mu: Mean of the approximate posterior.
            std: Standard deviation of the approximate posterior (must be > 0).

        Returns:
            Scalar tensor ``reconstruction_error + kl_divergence``.
        """
        reconst_err = self.mse(pred, target)
        # log(std^2) is written as 2 * log(std) so a very small std is not
        # squared (and possibly flushed to zero) before the logarithm.
        regular_err = 0.5 * torch.sum(std ** 2 + mu ** 2 - 2.0 * torch.log(std) - 1)
        # Both terms are penalties, so they are added. Negating the
        # reconstruction error would reward bad reconstructions when minimising.
        return reconst_err + regular_err
    
# Smoke-test the loss: prediction and target are the same tensor, so only the regulariser contributes.
loss = KLDivergenceLoss()
loss(pred=x, target=x, mu=mu, std=std)

tensor(302.3123, grad_fn=<AddBackward0>)