<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Variational_Autoencoders_(VAEs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the VAE class
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer.

    The encoder maps a flat input of size ``input_dim`` to ``hidden_dim``
    ReLU features. Two linear heads read those features and produce the mean
    and log-variance of a diagonal Gaussian over a ``latent_dim``-dimensional
    latent code. The decoder maps a latent sample back to ``input_dim``
    values squashed by a sigmoid.

    Args:
        input_dim: size of each flattened input vector.
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Submodules are created in a fixed order: it determines both the
        # state_dict layout and the sequence of random weight initialisations.
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # Parallel heads over the shared hidden features.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # squashes reconstructions into the [0, 1] range
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Drawing standard-normal noise and shifting/scaling it keeps the sample
        differentiable with respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode *x*, draw a latent sample, and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        features = self.encoder(x)
        mu, logvar = self.fc_mu(features), self.fc_logvar(features)
        reconstruction = self.decoder(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

# Define the loss function for VAE
def loss_function(recon_x, x, mu, logvar):
    """Return the (negative) ELBO summed over the whole batch.

    The loss is the sum of two terms:

    * the binary cross-entropy between the reconstruction ``recon_x`` and the
      target ``x``, both expected in [0, 1], summed over every element;
    * the closed-form KL divergence between N(mu, exp(logvar)) and N(0, I),
      summed over every latent dimension.

    Returns:
        A scalar tensor.
    """
    bce = nn.functional.binary_cross_entropy
    reconstruction_term = bce(recon_x, x, reduction='sum')
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction_term + kl_term

# Define the train_vae function
def train_vae(vae, dataloader, epochs, lr=0.001):
    """Train *vae* with Adam and print the mean per-sample loss of each epoch.

    Args:
        vae: model whose forward returns ``(reconstruction, mu, logvar)``.
        dataloader: iterable of ``(images, labels)`` batches; labels are
            ignored. Images are assumed to be normalized to [-1, 1] (e.g. by
            ``transforms.Normalize((0.5,), (0.5,))``) — they are mapped back
            to [0, 1] here because the BCE reconstruction term needs targets
            in that range.
        epochs: number of full passes over *dataloader*.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        List with one entry per epoch: the summed loss over the epoch divided
        by the number of samples seen.

    Raises:
        ValueError: if *dataloader* yields no batches in an epoch.
    """
    optimizer = optim.Adam(vae.parameters(), lr=lr)
    # Batches are moved to wherever the model's parameters live (CPU or GPU).
    device = next(vae.parameters()).device
    # Make sure layers with train/eval differences are in training mode, in
    # case the caller left the model in eval() mode.
    vae.train()

    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_samples = 0
        for data, _ in dataloader:
            # Flatten each image; reshape also copes with non-contiguous input.
            data = data.reshape(data.size(0), -1).to(device)
            data = (data + 1) / 2  # Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1]
            optimizer.zero_grad()
            recon_data, mu, logvar = vae(data)
            loss = loss_function(recon_data, data, mu, logvar)  # summed over the batch
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_samples += data.size(0)

        if num_samples == 0:
            raise ValueError("dataloader yielded no batches")
        # The loss is a batch sum, so normalise by samples seen rather than
        # reporting only the last (possibly partial) batch.
        avg_loss = total_loss / num_samples
        history.append(avg_loss)
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}')

    return history

# Example usage: train the VAE on the MNIST training split.
# input_dim=784 matches a 28x28 MNIST image flattened by train_vae;
# the model has 128 hidden units and a 20-dimensional latent space.
vae = VAE(input_dim=784, hidden_dim=128, latent_dim=20)

# Define the transformation for the dataset.
# ToTensor yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1]. NOTE: train_vae undoes this with (data + 1) / 2 so the BCE
# targets are back in [0, 1] — if this normalization changes, that rescale
# in train_vae must change with it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST training split, downloading it into 'mnist_data' (relative to
# the working directory) if it is not already there.
dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)
# Mini-batches of 64 images, reshuffled every epoch.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Train the VAE for 20 epochs; train_vae prints one loss line per epoch.
train_vae(vae, dataloader, epochs=20)