# Variational Autoencoder

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus an identity skip.

    Both convolutions keep the channel count and (with padding=1) the spatial
    size, so the skip connection is a plain element-wise addition.

    Args:
        channels: Number of input and output channels.
        num_groups: Group count for both GroupNorm layers (GroupNorm requires
            it to divide ``channels``).
    """

    def __init__(self, channels, num_groups):
        super().__init__()
        # Tuples are evaluated and assigned left to right, so modules are
        # created and registered in the order gn1, gn2, conv1, conv2.
        self.gn1, self.gn2 = nn.GroupNorm(num_groups, channels), nn.GroupNorm(num_groups, channels)
        self.conv1, self.conv2 = (
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two norm/activation/conv stages."""
        h = self.conv1(F.silu(self.gn1(x)))
        h = self.conv2(F.silu(self.gn2(h)))
        return h + x

class CVAE(nn.Module):
    """Convolutional Variational Autoencoder with ResNet connections.

    The encoder applies four stride-2 stages. A square RGB input of side
    ``img_size`` therefore becomes a
    ``(base_channels * 8, img_size // 16, img_size // 16)`` feature map. That
    map is flattened and projected to ``latent_dim``-sized mean and
    log-variance vectors. The decoder mirrors the encoder with four stride-2
    transposed-convolution stages and ends in a sigmoid, so reconstructions
    lie in [0, 1].

    Args:
        base_channels: Channel width of the first stage. Later stages use 2x,
            4x and 8x this, and the last decoder stage uses ``base_channels // 2``.
        num_groups: GroupNorm group count used throughout. GroupNorm requires
            it to divide every channel count above.
        dropout_prob: Dropout probability at the end of every down/up-sampling
            block.
        latent_dim: Size of the latent vector (default 128).
        img_size: Side length of the square inputs the model accepts. It must
            be a positive multiple of 16 so each stride-2 stage halves it
            exactly (default 256).

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16, or
            ``latent_dim`` is not positive.
    """

    def __init__(self, base_channels, num_groups=8, dropout_prob=0.5, latent_dim=128, img_size=256):
        super(CVAE, self).__init__()
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
        if latent_dim <= 0:
            raise ValueError(f"latent_dim must be positive, got {latent_dim}")
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Must be set before the encoder/decoder are built: the block helpers read it.
        self.dropout_prob = dropout_prob
        # Spatial side of the bottleneck after four stride-2 stages (256 -> 16 by default).
        bottleneck = img_size // 16

        # Encoder: every stride-2 _conv_block halves the spatial size (S -> S/2).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, base_channels, kernel_size=3, padding=1),
            self._conv_block(base_channels, base_channels, stride=2, num_groups=num_groups),  # S -> S/2
            ResBlock(base_channels, num_groups),
            self._conv_block(base_channels, base_channels * 2, stride=2, num_groups=num_groups),  # S/2 -> S/4
            ResBlock(base_channels * 2, num_groups),
            self._conv_block(base_channels * 2, base_channels * 4, stride=2, num_groups=num_groups),  # S/4 -> S/8
            ResBlock(base_channels * 4, num_groups),
            self._conv_block(base_channels * 4, base_channels * 8, stride=2, num_groups=num_groups),  # S/8 -> S/16
            ResBlock(base_channels * 8, num_groups),
            nn.Flatten(),
        )

        self.flattened_dims = base_channels * 8 * bottleneck * bottleneck
        self.fc_mean = nn.Linear(self.flattened_dims, self.latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_dims, self.latent_dim)

        self.decoder_input = nn.Linear(self.latent_dim, self.flattened_dims)

        # Decoder: every _conv_transpose_block doubles the spatial size (S -> 2S).
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (base_channels * 8, bottleneck, bottleneck)),
            ResBlock(base_channels * 8, num_groups),
            self._conv_transpose_block(base_channels * 8, base_channels * 4, num_groups=num_groups),  # S/16 -> S/8
            ResBlock(base_channels * 4, num_groups),
            self._conv_transpose_block(base_channels * 4, base_channels * 2, num_groups=num_groups),  # S/8 -> S/4
            ResBlock(base_channels * 2, num_groups),
            self._conv_transpose_block(base_channels * 2, base_channels, num_groups=num_groups),  # S/4 -> S/2
            ResBlock(base_channels, num_groups),
            self._conv_transpose_block(base_channels, base_channels // 2, num_groups=num_groups),  # S/2 -> S
            nn.GroupNorm(num_groups, base_channels // 2),
            nn.SiLU(),
            nn.Conv2d(base_channels // 2, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def _conv_block(self, in_channels, out_channels, stride=1, num_groups=8):
        """GroupNorm -> SiLU -> conv(stride) -> SiLU -> conv -> Dropout."""
        return nn.Sequential(
            nn.GroupNorm(num_groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout(self.dropout_prob),
        )

    def _conv_transpose_block(self, in_channels, out_channels, num_groups=8):
        """GroupNorm -> SiLU -> 2x transposed conv -> SiLU -> conv -> Dropout."""
        return nn.Sequential(
            nn.GroupNorm(num_groups, in_channels),
            nn.SiLU(),
            # stride=2, padding=1, output_padding=1 with kernel 3 maps S -> 2S exactly.
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout(self.dropout_prob),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of ``img_size`` x ``img_size`` RGB images."""
        x = self.encoder(x)
        return self.fc_mean(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to images in [0, 1]."""
        z = self.decoder_input(z)
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    @staticmethod
    def recommend_num_groups(base_channels):
        """Recommend the number of groups for GroupNorm based on base channels."""
        if base_channels < 16:
            return 4
        elif base_channels < 64:
            return 8
        elif base_channels < 256:
            return 16
        else:
            return 32


# Smoke test: build a CVAE and push one random 256x256 RGB image through it,
# printing the reconstruction's shape (expected: torch.Size([1, 3, 256, 256])).
# Only the shape is inspected, so run under no_grad to skip building an
# autograd graph for the forward pass.
base_channels = 32
num_groups = CVAE.recommend_num_groups(base_channels)
dropout_prob = 0.1
img_size = 256

model = CVAE(base_channels, num_groups=num_groups, dropout_prob=dropout_prob)
input_image = torch.randn(1, 3, img_size, img_size)
with torch.no_grad():
    output_image, mu, logvar = model(input_image)
print(output_image.shape)

torch.Size([1, 3, 256, 256])


# Dataset
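- **Hugging Face dataset**: Ryan-sjtu/celebahq-caption (note: the training cell below currently loads CIFAR-10 from torchvision instead of this dataset)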

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

class CVAELightning(pl.LightningModule):
    """PyTorch Lightning wrapper that trains :class:`CVAE` on CIFAR-10.

    CIFAR-10 is downloaded to ``./data`` in :meth:`prepare_data`. The official
    training set is split 45000/5000 into train/validation, and the official
    test set is used for testing. The loss is summed binary cross-entropy
    (reconstruction) plus the KL divergence of ``N(mu, exp(logvar))`` from
    ``N(0, I)``.

    Args:
        img_size: Side length images are resized to when loaded. It is also
            the resolution at which the reconstruction loss is computed.
        latent_dim: Expected latent size. It is checked against the CVAE's
            own ``latent_dim`` so that a mismatch fails loudly instead of
            being ignored.
        lr: Adam learning rate.
        base_channels: Channel width passed to :class:`CVAE`.

    Raises:
        ValueError: If ``latent_dim`` differs from the CVAE's latent size.
    """

    def __init__(self, img_size=32, latent_dim=128, lr=1e-3, base_channels=32):
        super().__init__()
        self.save_hyperparameters()
        # CVAE's first positional parameter is the channel width, not an image
        # size; pass the width explicitly instead of ``img_size``.
        self.model = CVAE(base_channels, num_groups=CVAE.recommend_num_groups(base_channels))
        if latent_dim != self.model.latent_dim:
            raise ValueError(
                f"latent_dim={latent_dim} does not match the CVAE latent size "
                f"({self.model.latent_dim})"
            )
        # The CVAE flattens a (8 * base_channels, s, s) bottleneck after four
        # stride-2 stages, so the only input side it accepts is 16 * s.
        bottleneck_side = round((self.model.flattened_dims // (8 * base_channels)) ** 0.5)
        self.model_img_size = 16 * bottleneck_side
        self.lr = lr

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``; the reconstruction matches ``x``'s spatial size.

        If ``x`` does not have the CVAE's native resolution, it is resized to
        that resolution for the model and the reconstruction is resized back.
        NOTE: for 32x32 CIFAR images this upsamples to the CVAE's native size,
        which costs far more compute than a CVAE built for 32x32 would.
        """
        out_size = tuple(x.shape[-2:])
        native = (self.model_img_size, self.model_img_size)
        if out_size == native:
            return self.model(x)
        recon, mu, logvar = self.model(F.interpolate(x, size=native, mode="bilinear", align_corners=False))
        # "area" averages sigmoid outputs, so values stay in [0, 1] as BCE requires.
        recon = F.interpolate(recon, size=out_size, mode="area")
        return recon, mu, logvar

    def _shared_step(self, batch, stage):
        """Compute and log ``f"{stage}_loss"`` for one batch; return the loss."""
        x, _ = batch  # class labels are not used by the autoencoder
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed BCE reconstruction loss plus summed KL divergence to N(0, I)."""
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def configure_optimizers(self):
        """Adam, with LR reduced 5x when ``val_loss`` plateaus for 5 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, min_lr=1e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def prepare_data(self):
        # Download the CIFAR10 dataset
        CIFAR10(root='./data', train=True, download=True)
        CIFAR10(root='./data', train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ('fit', 'validate', 'test' or None)."""
        img_size = self.hparams.img_size
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

        # trainer.validate() calls setup with stage='validate', which also needs cifar_val.
        if stage in ('fit', 'validate', None):
            cifar_full = CIFAR10(root='./data', train=True, transform=transform)
            # A fixed generator keeps the train/val partition identical across
            # setup() calls, so validation never scores images used for training.
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [45000, 5000], generator=torch.Generator().manual_seed(42)
            )

        if stage in ('test', None):
            self.cifar_test = CIFAR10(root='./data', train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=64, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=64, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=64, num_workers=4)

# Example usage:
model = CVAELightning(img_size=32, latent_dim=128)
# Trainer(gpus=...) was deprecated in Lightning 1.7 and removed in 2.0;
# accelerator/devices is the supported way to choose the hardware.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
trainer.fit(model)