<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Variational_Autoencoders_(VAEs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt
import os

# Define the Variational Autoencoder (VAE)
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    Encoder: input_dim -> hidden_dim (ReLU) -> (mu, logvar), each z_dim wide.
    Decoder: z_dim -> hidden_dim (ReLU) -> input_dim (sigmoid), so every
    reconstructed feature lies in [0, 1].

    Args:
        input_dim: Number of features in one flattened input sample.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        z_dim: Dimension of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, z_dim):
        super().__init__()
        # Kept so forward() can flatten inputs of any size instead of assuming 784.
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, z_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs of shape (batch, input_dim) to (mu, logvar)."""
        h1 = torch.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (batch, z_dim) to reconstructions in [0, 1]."""
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Tensor whose elements flatten into rows of ``input_dim``
                (e.g. (batch, 1, 28, 28) when ``input_dim == 784``).

        Returns:
            Tuple ``(reconstruction, mu, logvar)`` with shapes
            (batch, input_dim), (batch, z_dim) and (batch, z_dim).
        """
        # reshape (not view) also accepts non-contiguous inputs.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Define the loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the negative ELBO of a batch, summed over samples and features.

    Args:
        recon_x: Decoder output in [0, 1], shape (batch, features).
        x: Targets in [0, 1] with the same number of elements as ``recon_x``
            (any shape, e.g. unflattened images); they are reshaped to match.
        mu: Posterior means, shape (batch, z_dim).
        logvar: Posterior log-variances, shape (batch, z_dim).

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction term plus the
        closed-form KL divergence KL(N(mu, sigma^2) || N(0, I)).
    """
    # Match the target layout to the reconstruction instead of assuming 784 features.
    target = x.reshape_as(recon_x)
    bce = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    # Analytic KL for a diagonal Gaussian against a standard normal prior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kl

# Hyperparameters
batch_size = 128
epochs = 10
learning_rate = 1e-3
z_dim = 20  # Dimension of latent space
hidden_dim = 400
input_dim = 28 * 28  # Flattened MNIST image size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST dataset.
# ToTensor() already scales pixels to [0, 1], which is exactly what the BCE
# reconstruction loss expects, so no Normalize / inverse-normalize round trip.
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the VAE model and optimizer
vae = VAE(input_dim, hidden_dim, z_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Create directory for saving generated images
os.makedirs("generated_images", exist_ok=True)

# Training loop
for epoch in range(epochs):
    vae.train()
    train_loss = 0.0
    for data, _ in dataloader:
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # loss_function sums over samples, so divide by the dataset size for a per-image value.
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {train_loss / len(dataloader.dataset):.4f}")

    # Generate new images from the prior after each epoch
    vae.eval()
    with torch.no_grad():
        sample_z = torch.randn(64, z_dim, device=device)  # allocate directly on the target device
        generated_images = vae.decode(sample_z).view(-1, 1, 28, 28)

        # Build the 8x8 grid once and reuse it for both the saved PNG and the preview.
        grid_img = torchvision.utils.make_grid(generated_images, nrow=8, normalize=True)
        save_path = f"generated_images/epoch_{epoch + 1}.png"
        torchvision.utils.save_image(grid_img, save_path)
        print(f"Generated images saved to {save_path}")

        plt.figure(figsize=(8, 8))
        plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy(), cmap="gray")
        plt.axis("off")
        plt.show()