# Variational Autoencoder (VAE) Refinement - Session Notes

## 1. Initial Code, Problems, and First Refinements
- Started with a VAE implementation in PyTorch.
- Identified issues:
    - Hardcoded flattening/reshaping sizes in encoder/decoder.
    - No flexibility for loss type or activation function.
    - Documentation and maintainability could be improved in places.

### First Refinements Highlights
- Used utility methods for encoder/decoder construction.
- Consistently type-annotated and documented all methods.
- Provided flexibility in model design and loss computation.


## 2. Further Refinement Points
- Dynamically calculate and use encoder output shape in flatten/reshape.
- Parameterize activation function and loss type.
- Avoid overuse of `nn.Sequential`.
- Provide comprehensive input shape/type checking.
- Improve testing practices for model flexibility and error handling.

In [None]:

import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple, Callable, Any, Dict, Optional
from dataclasses import dataclass

@dataclass
class VAEOutput:
    """Container for the loss terms returned by ``VAE.loss_function``."""

    # Total objective: recon_loss + kld_weight * kld (the value to backpropagate).
    loss: torch.Tensor
    # Reconstruction term; ``VAE.loss_function`` stores it detached, for logging.
    recon_loss: torch.Tensor
    # KL-divergence term; ``VAE.loss_function`` stores it detached, for logging.
    kld: torch.Tensor

class VAE(nn.Module):
    """Convolutional variational autoencoder for fixed-size image tensors.

    Encoder: one ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> activation`` stage
    per entry of ``hidden_dims``; every stage halves H and W (rounding up).
    The flattened encoder output feeds two linear heads that produce the mean
    and the log-variance of the approximate posterior.

    Decoder: a linear projection back to the encoder output shape, mirrored
    ``ConvTranspose2d`` stages that each double H and W, and a final
    transposed convolution back to ``x_dim`` channels. The result is cropped
    to the ``(H, W)`` given in ``input_shape`` so that reconstructions match
    the input even when H or W is not divisible by ``2 ** len(hidden_dims)``.

    The output non-linearity follows ``recon_loss_type``:

    * ``"mse"`` -> ``Tanh`` (outputs in [-1, 1]).
    * ``"bce"`` -> ``Sigmoid`` (outputs in [0, 1]); binary cross-entropy
      requires probabilities, and the targets must lie in [0, 1] as well.

    Args:
        x_dim: Number of input/output channels. Must equal ``input_shape[0]``.
        input_shape: ``(C, H, W)`` of a single input sample.
        hidden_dims: Output channels of each encoder stage. Defaults to
            ``[32, 64, 128, 256, 512]``.
        latent_dim: Size of the latent vector.
        activation: Zero-argument factory returning the activation module.
            Defaults to ``nn.LeakyReLU(0.2)``. The single instance it returns
            is reused by every encoder and decoder stage.
        recon_loss_type: ``"mse"`` or ``"bce"`` (case-insensitive).

    Raises:
        AssertionError: If ``input_shape`` is not a 3-element tuple/list.
        ValueError: If ``x_dim`` disagrees with ``input_shape[0]``, if
            ``hidden_dims`` is empty, or if ``recon_loss_type`` is unknown.
    """

    _RECON_LOSS_TYPES = ("mse", "bce")

    def __init__(
        self,
        x_dim: int,
        input_shape: Tuple[int, int, int],   # (C, H, W)
        hidden_dims: Optional[List[int]] = None,
        latent_dim: int = 16,
        activation: Optional[Callable[[], nn.Module]] = None,
        recon_loss_type: str = "mse",
    ):
        super().__init__()

        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type is kept for compatibility.
        if not (isinstance(input_shape, (tuple, list)) and len(input_shape) == 3):
            raise AssertionError("input_shape must be a tuple (channels, H, W)")
        if input_shape[0] != x_dim:
            raise ValueError(
                f"x_dim ({x_dim}) must equal the channel count in input_shape "
                f"({input_shape[0]})"
            )

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one channel size")

        recon_loss_type = recon_loss_type.lower()
        if recon_loss_type not in self._RECON_LOSS_TYPES:
            # A typo used to fall back to MSE silently; fail loudly instead.
            raise ValueError(
                f"recon_loss_type must be one of {self._RECON_LOSS_TYPES}, "
                f"got {recon_loss_type!r}"
            )

        self.x_dim = x_dim
        self.input_shape = tuple(input_shape)
        self.latent_dim = latent_dim
        self.recon_loss_type = recon_loss_type
        # Copy so later mutation of the caller's list cannot affect the model.
        self.hidden_dims = list(hidden_dims)

        if activation is None:
            self.act = nn.LeakyReLU(0.2)
        else:
            self.act = activation()

        # Encoder
        encoder_layers: List[nn.Module] = []
        in_channels = x_dim
        for h_dim in self.hidden_dims:
            encoder_layers += [
                nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                self.act,
            ]
            in_channels = h_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # Probe the encoder output shape with a dummy sample. This runs in
        # eval mode: in training mode BatchNorm rejects a batch of one once
        # the feature map collapses to 1x1, and it would also update its
        # running statistics using the all-zero dummy input.
        self.encoder.eval()
        with torch.no_grad():
            probe = self.encoder(torch.zeros(1, *input_shape))
        self.encoder.train()
        self.enc_out_shape = probe.shape[1:]  # (C, H, W)
        self.enc_flat_dim = probe.numel() // probe.shape[0]

        self.fc_mu = nn.Linear(self.enc_flat_dim, latent_dim)
        self.fc_var = nn.Linear(self.enc_flat_dim, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, self.enc_flat_dim)
        decoder_layers: List[nn.Module] = []
        in_channels = self.enc_out_shape[0]
        # Walk the hidden dims backwards, skipping the deepest one, which is
        # already the channel count of the encoder output.
        for out_channels in list(reversed(self.hidden_dims))[1:]:
            decoder_layers += [
                nn.ConvTranspose2d(
                    in_channels, out_channels,
                    kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                nn.BatchNorm2d(out_channels),
                self.act,
            ]
            in_channels = out_channels
        self.decoder = nn.Sequential(*decoder_layers)

        # BCE needs probabilities in [0, 1]; Tanh would produce negative
        # values that F.binary_cross_entropy rejects.
        output_activation = nn.Sigmoid() if recon_loss_type == "bce" else nn.Tanh()
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, x_dim, kernel_size=3, stride=2,
                padding=1, output_padding=1
            ),
            output_activation,
        )

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` into the posterior parameters ``(mu, log_var)``.

        Args:
            x: Batch shaped ``(N, *input_shape)``.

        Returns:
            Two ``(N, latent_dim)`` tensors: the mean and the log-variance.
        """
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling ``eps`` separately keeps ``z`` differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors ``z`` of shape ``(N, latent_dim)`` to images.

        Returns:
            Tensor shaped ``(N, x_dim, H, W)`` with ``H, W`` from ``input_shape``.
        """
        result = self.decoder_input(z)
        result = result.view(-1, *self.enc_out_shape)
        result = self.decoder(result)
        result = self.final_layer(result)
        # Each encoder stage rounds H/W up, so the decoder output is never
        # smaller than the input; cropping is a no-op for divisible sizes.
        _, height, width = self.input_shape
        return result[..., :height, :width]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a full encode -> sample -> decode pass.

        Returns:
            ``(recon_x, x, mu, log_var)``, i.e. the positional arguments
            expected by :meth:`loss_function`.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon_x = self.decode(z)
        return recon_x, x, mu, log_var

    def loss_function(
        self, recons: torch.Tensor, input: torch.Tensor,
        mu: torch.Tensor, log_var: torch.Tensor, kld_weight: float = 1.0
    ) -> "VAEOutput":
        """Compute ``recon_loss + kld_weight * kld``.

        Args:
            recons: Reconstruction produced by :meth:`forward`.
            input: Original batch. For ``'bce'`` it must lie in [0, 1].
            mu: Posterior mean, ``(N, latent_dim)``.
            log_var: Posterior log-variance, ``(N, latent_dim)``.
            kld_weight: Multiplier applied to the KL term.

        Returns:
            ``VAEOutput`` whose ``loss`` carries gradients; ``recon_loss`` and
            ``kld`` are detached copies intended for logging.
        """
        # The reconstruction term is averaged over every element ...
        if self.recon_loss_type == 'bce':
            recon_loss = F.binary_cross_entropy(recons, input, reduction='mean')
        else:
            recon_loss = F.mse_loss(recons, input)

        # ... while the KL term is summed over latent dims and averaged over
        # the batch, so ``kld_weight`` balances the two scales.
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1).mean()
        loss = recon_loss + kld_weight * kld
        return VAEOutput(loss=loss, recon_loss=recon_loss.detach(), kld=kld.detach())

    def sample(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Decode ``batch_size`` latent vectors drawn from the N(0, I) prior.

        The noise is cast to the dtype of the model parameters so sampling
        also works after ``.double()`` / ``.half()``.
        """
        z = torch.randn(batch_size, self.fc_mu.out_features).to(
            device=device, dtype=self.fc_mu.weight.dtype
        )
        samples = self.decode(z)
        return samples

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the reconstruction of ``x``."""
        return self.forward(x)[0]
    


## 3. Example Test Suite for the Refined VAE

- Shows how to use `unittest` to check initialization, forward pass, loss, and shape adaptiveness.
- Demonstrates best practices for flexible PyTorch model testing.


In [None]:

import torch
import unittest

class TestVAE(unittest.TestCase):
    """Smoke tests for ``VAE``: construction, forward shapes, loss, and a second input size."""

    def setUp(self):
        """Build a three-stage VAE for 3x64x64 inputs with a 10-dimensional latent space."""
        self.input_channels, self.input_size, self.latent_dim = 3, 64, 10
        side = self.input_size
        self.model = VAE(
            x_dim=self.input_channels,
            input_shape=(self.input_channels, side, side),
            hidden_dims=[32, 64, 128],
            latent_dim=self.latent_dim,
        )

    def test_vae_initialization(self):
        """Encoder and decoder are registered as torch modules."""
        for submodule in (self.model.encoder, self.model.decoder):
            self.assertIsInstance(submodule, torch.nn.Module)

    def test_vae_forward(self):
        """A forward pass returns tensors with the expected shapes."""
        batch = 8
        image_shape = (batch, self.input_channels, self.input_size, self.input_size)
        latent_shape = (batch, self.latent_dim)

        recon_x, inp_x, mu, log_var = self.model(torch.randn(*image_shape))

        self.assertEqual(recon_x.shape, image_shape)
        self.assertEqual(inp_x.shape, image_shape)
        self.assertEqual(mu.shape, latent_shape)
        self.assertEqual(log_var.shape, latent_shape)

    def test_vae_loss(self):
        """The loss is exposed on the result, is differentiable, and is non-negative."""
        side = self.input_size
        sample_batch = torch.randn(4, self.input_channels, side, side)
        forward_out = self.model(sample_batch)
        loss_out = self.model.loss_function(*forward_out, kld_weight=0.01)

        self.assertTrue(hasattr(loss_out, 'loss'))
        self.assertTrue(loss_out.loss.requires_grad)
        self.assertGreaterEqual(loss_out.loss.item(), 0)

    def test_enc_dec_shapes(self):
        """A model built for 32x32 inputs reconstructs at 32x32."""
        n, channels, side, latent = 2, 3, 32, 5
        model32 = VAE(
            x_dim=channels,
            input_shape=(channels, side, side),
            hidden_dims=[32, 64],
            latent_dim=latent,
        )
        out, _, mu, log_var = model32(torch.randn(n, channels, side, side))

        self.assertEqual(out.shape, (n, channels, side, side))
        self.assertEqual(mu.shape, (n, latent))
        self.assertEqual(log_var.shape, (n, latent))

# If running interactively, remove the following block or use: unittest.main(argv=[''], exit=False)
# if __name__ == '__main__':
#     unittest.main()



## 4. Summary

- The refined VAE is modular, robust, and shape-agnostic, supporting various channel counts, heights, and widths.
- The tests cover initialization, the forward pass, the loss, and adaptation to a different input shape.
- Always provide `input_shape` as a tuple (C, H, W).

### Further Reading

- [VAE Theory - Kingma & Welling 2014](https://arxiv.org/abs/1312.6114)
- [PyTorch VAE Example](https://github.com/pytorch/examples/tree/main/vae)

---

**End of Session Export**
