<a href="https://colab.research.google.com/github/MehrdadDastouri/MNIST-VAE/blob/main/MNIST-VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

class VAEConfig:
    """Hyper-parameters and output locations shared by the model and trainer.

    Everything is a class attribute, so values can be read either from the
    class itself (``VAEConfig.latent_dim``) or from an instance.
    """

    # --- Model architecture -------------------------------------------------
    input_dim = 784          # one flattened 28x28 image
    hidden_dims = [400]      # only hidden_dims[0] is read by VAE
    latent_dim = 20

    # --- Optimisation -------------------------------------------------------
    epochs = 30
    batch_size = 128
    lr = 1e-3

    # --- Runtime / output locations ----------------------------------------
    device = (torch.device("cuda") if torch.cuda.is_available()
              else torch.device("cpu"))
    sample_dir = "vae_samples"
    model_dir = "vae_models"

class VAE(nn.Module):
    """Fully connected variational auto-encoder for flattened images.

    Layer widths come from ``VAEConfig``. The encoder maps
    ``input_dim -> hidden_dims[0] -> hidden_dims[0] // 2``. Two linear heads
    produce the latent mean and log-variance. The decoder mirrors the encoder
    back to ``input_dim`` and ends in a sigmoid.
    """

    def __init__(self):
        super().__init__()

        wide = VAEConfig.hidden_dims[0]
        narrow = wide // 2
        pixels = VAEConfig.input_dim
        code = VAEConfig.latent_dim

        # Layers are created encoder -> heads -> decoder. Creation order
        # determines how the global RNG is consumed during weight init.
        self.encoder = nn.Sequential(
            nn.Linear(pixels, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
        )

        # Heads producing the parameters of the approximate posterior.
        self.fc_mu = nn.Linear(narrow, code)
        self.fc_logvar = nn.Linear(narrow, code)

        # Decoder: code -> narrow -> wide -> pixels, squashed into [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(code, narrow),
            nn.ReLU(),
            nn.Linear(narrow, wide),
            nn.ReLU(),
            nn.Linear(wide, pixels),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` computed from the encoder features of ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        ``sigma`` is recovered as ``exp(0.5 * logvar)``. Drawing the noise
        separately keeps ``z`` differentiable with respect to ``mu`` and
        ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions via the decoder MLP."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` is first flattened with ``view(-1, VAEConfig.input_dim)``.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        flat = x.view(-1, VAEConfig.input_dim)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        return reconstruction, mu, logvar

class VAETrainer:
    """Train ``VAE`` on the MNIST training split.

    Construction does three things: creates the output directories, builds the
    model with an Adam optimiser, and prepares a shuffled MNIST ``DataLoader``.
    ``train`` then runs ``config.epochs`` epochs. After each epoch it writes a
    grid of generated samples to ``config.sample_dir``, and it writes
    checkpoints to ``config.model_dir``.
    """

    def __init__(self):
        self.config = VAEConfig()
        self._init_dirs()

        # Initialize model and optimizer
        self.model = VAE().to(self.config.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)

        # Dataset and dataloader.
        # Pixels are kept in ToTensor's [0, 1] range on purpose. The decoder
        # ends in a sigmoid, and the reconstruction loss is binary
        # cross-entropy, which requires targets in [0, 1]. Mean/std
        # normalisation with (0.1307, 0.3081) would map pixels to roughly
        # [-0.42, 2.82] and invalidate the BCE term.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.train_loader = self._get_dataloader()

    def _init_dirs(self):
        """Create the sample and checkpoint directories if they are missing."""
        os.makedirs(self.config.sample_dir, exist_ok=True)
        os.makedirs(self.config.model_dir, exist_ok=True)

    def _get_dataloader(self):
        """Return a shuffled loader over MNIST train (downloaded into ./data)."""
        train_set = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=self.transform)
        return DataLoader(
            train_set, batch_size=self.config.batch_size, shuffle=True, num_workers=2)

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the negative ELBO, summed over the batch.

        Args:
            recon_x: Decoder output with values in [0, 1], shape
                ``(batch, config.input_dim)``.
            x: Target images with values in [0, 1]. Any layout is accepted;
                it is reshaped to ``(-1, config.input_dim)``.
            mu: Posterior means from the encoder.
            logvar: Posterior log-variances from the encoder.

        Returns:
            A scalar tensor: the summed binary cross-entropy plus the
            closed-form KL divergence ``KL(N(mu, exp(logvar)) || N(0, I))``.
        """
        bce = F.binary_cross_entropy(
            recon_x, x.reshape(-1, self.config.input_dim), reduction='sum')
        # Analytic KL between a diagonal Gaussian and the standard normal.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training set.

        Returns:
            The summed loss divided by the number of training samples.
        """
        self.model.train()
        train_loss = 0.0
        for data, _ in tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.config.epochs}'):
            data = data.to(self.config.device)

            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss = self.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()

        return train_loss / len(self.train_loader.dataset)

    def generate_samples(self, epoch, num_samples=64):
        """Decode ``num_samples`` draws from N(0, I) and save them as a PNG grid.

        The image is written to ``config.sample_dir`` as
        ``sample_epoch_{epoch+1}.png``.
        """
        # Switch to inference mode for sampling; train_epoch() calls
        # model.train() again at the start of the next epoch.
        self.model.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.config.latent_dim, device=self.config.device)
            sample = self.model.decode(z).cpu()
        save_image(sample.view(num_samples, 1, 28, 28),
                   os.path.join(self.config.sample_dir, f'sample_epoch_{epoch+1}.png'))

    def save_model(self, epoch):
        """Write the model's state_dict to ``config.model_dir``."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.config.model_dir, f'vae_epoch_{epoch+1}.pth'))

    def train(self):
        """Train for ``config.epochs`` epochs, sampling and checkpointing along the way."""
        last_epoch = self.config.epochs - 1
        for epoch in range(self.config.epochs):
            loss = self.train_epoch(epoch)
            print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
            self.generate_samples(epoch)
            # Checkpoint every 10 epochs, and always after the final epoch so
            # the trained weights are kept even when epochs % 10 != 0.
            if (epoch + 1) % 10 == 0 or epoch == last_epoch:
                self.save_model(epoch)

if __name__ == "__main__":
    # Entry point. Constructing the trainer creates the output directories
    # and builds the MNIST loader (download=True, root='./data'). train() then
    # runs the full training loop. `trainer` stays bound at module level so it
    # can be inspected afterwards (e.g. from a later notebook cell).
    trainer = VAETrainer()
    trainer.train()