# VAE encoder

> A variational autoencoder (VAE) module for perception: encoder, decoder and full VAE for 32×32 RGB images.

In [None]:
#| default_exp models.vae

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
"""
Variational encoder model, used as a visual model
for our model of the world.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """VAE encoder for 32×32 RGB images.

    Four stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    32 → 16 → 8 → 4 → 2, giving a (batch, 256, 2, 2) feature map. That map is
    flattened and fed to two linear heads. The heads are sized for
    256 * 2 * 2 features, so inputs must be 32×32.

    Args:
        img_channels: number of input image channels (3 for RGB).
        latent_size: dimensionality of the latent code.

    ``forward(x)`` returns ``(mu, logsigma)``, each of shape
    (batch, latent_size). ``VAE.forward`` treats them as the mean and the log
    standard deviation of the latent Gaussian (``sigma = logsigma.exp()``).
    """
    def __init__(self, img_channels, latent_size):
        super().__init__()

        self.conv1 = nn.Conv2d(img_channels,  32, 4, stride=2, padding=1)   # 32→16
        self.conv2 = nn.Conv2d(32,  64, 4, stride=2, padding=1)             # 16→8
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2, padding=1)             # 8→4
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2, padding=1)            # 4→2

        self.fc_mu        = nn.Linear(256 * 2 * 2, latent_size)
        self.fc_logsigma  = nn.Linear(256 * 2 * 2, latent_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # (B, 32, 16, 16)
        x = F.relu(self.conv2(x))  # (B, 64, 8, 8)
        x = F.relu(self.conv3(x))  # (B, 128, 4, 4)
        x = F.relu(self.conv4(x))  # (B, 256, 2, 2)
        # Flatten everything but the batch dim. Unlike .view, this also
        # works for non-contiguous input and an empty batch.
        x = torch.flatten(x, start_dim=1)  # (B, 1024)

        mu = self.fc_mu(x)
        logsigma = self.fc_logsigma(x)
        return mu, logsigma





In [None]:
#| hide
# Smoke test: a batch of 16 random 32×32 RGB images -> 32-dim mu / logsigma.
enc = Encoder(3, 32)
x = torch.randn(16, 3, 32, 32)
mu, logsigma = enc(x)
assert mu.shape == (16, 32) and logsigma.shape == (16, 32)
mu.shape, logsigma.shape

torch.Size([16, 256, 2, 2])


(torch.Size([16, 32]), torch.Size([16, 32]))

In [None]:
# A random 4×32 tensor to experiment with softmax / argmax on.
ten = torch.randn(size=(4, 32))
ten.size()

torch.Size([4, 32])

In [None]:
# Display the raw values of the random tensor.
ten

tensor([[-2.5277e-02, -8.6355e-01, -2.4682e+00, -1.8421e+00,  1.3121e-01,
          2.3854e+00, -8.8746e-01, -4.6337e-01,  4.9301e-01,  6.7903e-01,
          5.4017e-01, -1.8188e+00,  9.8022e-01,  1.0151e+00, -2.5583e-01,
         -3.9285e-01, -8.1354e-01,  1.0092e+00,  2.0673e-01,  6.0336e-01,
          5.1131e-01,  5.8135e-01,  1.8519e+00,  1.4168e+00, -4.1078e-01,
         -2.5738e-01, -1.1851e+00, -3.6064e-01,  1.1961e-01, -2.6908e-01,
          9.0220e-02,  5.6336e-01],
        [ 1.5776e-03, -4.2801e-01, -2.1140e-01,  4.2696e-01, -5.6730e-01,
          1.4228e+00,  1.4029e+00, -2.1389e+00, -1.2331e+00, -1.1081e-01,
         -3.6752e-01,  2.0156e+00, -2.2204e+00, -2.0780e-01,  1.0776e+00,
         -7.0024e-01, -3.1071e-01,  1.2378e-01,  1.6334e-01, -3.1914e-01,
         -1.7702e-01, -3.9492e-02,  1.3265e-01, -1.0066e-01,  6.1435e-01,
         -2.1366e-01, -4.9868e-01,  2.3917e-01, -1.1005e+00,  5.3529e-01,
          1.3505e+00, -3.2555e-01],
        [ 1.3088e+00, -8.0877e-01, -4.48

In [None]:
# Normalise each row into a probability distribution over its 32 entries.
ten.softmax(dim=-1)

tensor([[0.0179, 0.0077, 0.0016, 0.0029, 0.0210, 0.1996, 0.0076, 0.0116, 0.0301,
         0.0362, 0.0315, 0.0030, 0.0490, 0.0507, 0.0142, 0.0124, 0.0081, 0.0504,
         0.0226, 0.0336, 0.0306, 0.0329, 0.1171, 0.0758, 0.0122, 0.0142, 0.0056,
         0.0128, 0.0207, 0.0140, 0.0201, 0.0323],
        [0.0221, 0.0144, 0.0178, 0.0338, 0.0125, 0.0914, 0.0896, 0.0026, 0.0064,
         0.0197, 0.0153, 0.1653, 0.0024, 0.0179, 0.0647, 0.0109, 0.0161, 0.0249,
         0.0259, 0.0160, 0.0185, 0.0212, 0.0251, 0.0199, 0.0407, 0.0178, 0.0134,
         0.0280, 0.0073, 0.0376, 0.0850, 0.0159],
        [0.0550, 0.0066, 0.0095, 0.0215, 0.0033, 0.0160, 0.0261, 0.0070, 0.0125,
         0.0153, 0.0165, 0.0829, 0.0116, 0.0067, 0.0313, 0.0107, 0.0132, 0.2676,
         0.0068, 0.0270, 0.1244, 0.0211, 0.0017, 0.0152, 0.0077, 0.0123, 0.0073,
         0.0145, 0.0619, 0.0046, 0.0245, 0.0578],
        [0.0120, 0.0029, 0.0059, 0.0421, 0.0404, 0.0282, 0.0078, 0.0217, 0.0025,
         0.0055, 0.0068, 0.0981, 0.0158,

In [None]:
# softmax keeps the input's shape.
ten.softmax(dim=-1).size()

In [None]:
# Greedy pick: the index of the most probable entry in each row, i.e. the mode
# of the softmax distribution. This is deterministic, NOT a random sample.
# Softmax is monotonic, so this equals torch.argmax(ten, dim=-1).
# To actually sample, use torch.multinomial(probs, num_samples=1).
torch.argmax(torch.softmax(ten, dim=-1), dim=-1)

tensor([ 5, 11, 17, 26])

In [None]:
#| export
class Decoder(nn.Module):
    """VAE decoder for 32×32 RGB images.

    A linear layer expands the latent vector into a 256×2×2 feature map.
    Four stride-2 transposed convolutions (kernel 4, padding 1) then upsample
    it 2 → 4 → 8 → 16 → 32.

    Args:
        img_channels: number of channels of the produced image.
        latent_size: dimensionality of the input latent vector.

    ``forward(z)`` returns a (batch, img_channels, 32, 32) tensor. Its values
    are in [0, 1] because the last layer is followed by a sigmoid.
    """
    def __init__(self, img_channels, latent_size):
        super().__init__()

        seed_features = 256 * 2 * 2  # flattened size of the 256×2×2 seed map
        self.fc1 = nn.Linear(latent_size, seed_features)

        # Each transposed conv doubles height and width.
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)           # 2→4
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)            # 4→8
        self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)             # 8→16
        self.deconv4 = nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1)   # 16→32

    def forward(self, x):
        h = F.relu(self.fc1(x)).view(-1, 256, 2, 2)
        for upsample in (self.deconv1, self.deconv2, self.deconv3):
            h = F.relu(upsample(h))
        # The output layer uses a sigmoid instead of ReLU to bound pixel values.
        return torch.sigmoid(self.deconv4(h))



In [None]:
#| hide
# Smoke test: 16 random 32-dim latents -> 16 reconstructed 3×32×32 images.
dec = Decoder(3, 32)
z = torch.randn(16, 32)
recon = dec(z)
assert recon.shape == (16, 3, 32, 32)
recon.shape

torch.Size([16, 3, 32, 32])

In [None]:
#| export
class VAE(nn.Module):
    """Variational autoencoder: Encoder, reparameterised sample, Decoder.

    Args:
        img_channels: number of image channels (encoder input / decoder output).
        latent_size: dimensionality of the latent code ``z``.

    ``forward(x)`` returns ``(recon_x, mu, logsigma)``: the decoded
    reconstruction, plus the encoder outputs used to draw ``z``.
    """
    def __init__(self, img_channels, latent_size):
        super().__init__()
        self.encoder = Encoder(img_channels, latent_size)
        self.decoder = Decoder(img_channels, latent_size)

    def forward(self, x): # pylint: disable=arguments-differ
        mu, logsigma = self.encoder(x)
        # Reparameterisation: z = mu + sigma * eps with eps ~ N(0, I), so the
        # randomness is isolated in eps and gradients flow to mu / logsigma.
        sigma = torch.exp(logsigma)
        z = mu + sigma * torch.randn_like(sigma)
        return self.decoder(z), mu, logsigma

In [None]:
#| hide
# Smoke test for the full VAE.
model = VAE(3, 32)
x = torch.randn(16, 3, 32, 32)
mu, logsigma = model.encoder(x)
assert mu.shape == (16, 32) and logsigma.shape == (16, 32)
# Also exercise VAE.forward (encode -> sample z -> decode), which the
# encoder-only call above never runs.
recon, mu, logsigma = model(x)
assert recon.shape == x.shape
assert mu.shape == (16, 32) and logsigma.shape == (16, 32)
mu.shape


torch.Size([16, 32])

In [None]:
#| hide
# Write the `#| export` cells of this notebook out to the library module.
import nbdev
nbdev.nbdev_export()