In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Seed PyTorch's global RNG so weight initialisation and sampling repeat run to run.
torch.manual_seed(42)
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder with a class-conditioned decoder.

    The decoder receives the latent code concatenated with a condition vector
    ``c`` (in this notebook, one-hot digit labels). By default the encoder does
    NOT use ``c``, which matches the original architecture. Pass
    ``condition_encoder=True`` to also concatenate ``c`` to the encoder input,
    as in a standard conditional VAE. That option changes the shape of ``fc1``
    and therefore the ``state_dict``.

    Args:
        input_dim: Size of one flattened input sample (784 for 28x28 images).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
        num_classes: Width of the condition vector ``c``.
        condition_encoder: If True, the encoder also receives ``c``.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20,
                 num_classes=10, condition_encoder=False):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.condition_encoder = condition_encoder
        encoder_in = input_dim + num_classes if condition_encoder else input_dim
        # Layers are created in the original order so that, under a fixed seed,
        # default-configured models get identical initial weights.
        # Encoder
        self.fc1 = nn.Linear(encoder_in, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # Mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # Log variance
        # Decoder: latent code concatenated with the condition vector
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c):
        """Return ``(mu, logvar)`` of q(z | x), each of shape (batch, latent_dim).

        ``x`` is expected to be flattened to (batch, input_dim). ``c`` is only
        used when ``condition_encoder`` is True; it is assumed to be a float
        (batch, num_classes) tensor on the same device as ``x``.
        """
        if self.condition_encoder:
            x = torch.cat([x, c], dim=1)
        h = self.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Map ``[z, c]`` to a (batch, input_dim) reconstruction in (0, 1)."""
        h = self.relu(self.fc3(torch.cat([z, c], dim=1)))
        return self.sigmoid(self.fc4(h))

    def forward(self, x, c):
        """Encode, sample and decode; return ``(recon, mu, logvar)``.

        ``x`` may have any shape whose elements flatten to (batch, input_dim),
        e.g. (batch, 1, 28, 28) images for the default configuration.
        """
        # Use the configured input_dim instead of a hard-coded 784;
        # reshape (unlike view) also accepts non-contiguous inputs.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim), c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar

# Loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the summed negative ELBO for a batch: BCE reconstruction + KL term.

    Args:
        recon_x: Decoder output in [0, 1], shape (batch, D).
        x: Targets with batch * D elements (e.g. (batch, 1, 28, 28) images);
            reshaped to ``recon_x``'s shape before comparison.
        mu: Mean of the diagonal Gaussian q(z | x).
        logvar: Log-variance of the diagonal Gaussian q(z | x).

    Returns:
        Scalar tensor summed over batch and features (not averaged), so divide
        by the batch size for a per-sample value.
    """
    # Match the reconstruction's shape instead of hard-coding 784, so the loss
    # works for any input size; reshape also tolerates non-contiguous x.
    bce = nn.functional.binary_cross_entropy(
        recon_x, x.reshape(recon_x.shape), reduction='sum'
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with logvar = log(sigma^2).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# Data: the MNIST training split, converted to tensors with ToTensor and
# served in shuffled mini-batches of 128.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    'data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)

# Model on the selected device, Adam optimizer, and number of training epochs.
model = VAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
num_epochs = 10

# Training loop: one optimizer step per mini-batch. The loss is a batch sum,
# so it is divided by sample counts for the printed per-sample figures.
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        # Integer digit labels -> float one-hot rows used as the condition vector.
        labels = nn.functional.one_hot(labels, num_classes=10).float().to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data, labels)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        # loss was computed before the step, so reading it afterwards is safe.
        train_loss += loss.item()
        if not batch_idx % 100:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item() / data.size(0):.4f}')
    print(f'Epoch {epoch}, Average Loss: {train_loss / len(train_loader.dataset):.4f}')

# Persist the learned parameters (state_dict) rather than the module object.
torch.save(model.state_dict(), 'vae_mnist.pth')
print("Model saved as 'vae_mnist.pth'")

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 501kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.4MB/s]


Epoch 0, Batch 0, Loss: 546.4749
Epoch 0, Batch 100, Loss: 171.6604
Epoch 0, Batch 200, Loss: 153.6828
Epoch 0, Batch 300, Loss: 137.1764
Epoch 0, Batch 400, Loss: 132.0851
Epoch 0, Average Loss: 163.0095
Epoch 1, Batch 0, Loss: 123.2530
Epoch 1, Batch 100, Loss: 124.2630
Epoch 1, Batch 200, Loss: 119.0038
Epoch 1, Batch 300, Loss: 115.4319
Epoch 1, Batch 400, Loss: 114.2737
Epoch 1, Average Loss: 120.6249
Epoch 2, Batch 0, Loss: 113.5933
Epoch 2, Batch 100, Loss: 114.9692
Epoch 2, Batch 200, Loss: 116.1663
Epoch 2, Batch 300, Loss: 113.3066
Epoch 2, Batch 400, Loss: 114.3529
Epoch 2, Average Loss: 113.5771
Epoch 3, Batch 0, Loss: 113.1621
Epoch 3, Batch 100, Loss: 112.1343
Epoch 3, Batch 200, Loss: 114.0702
Epoch 3, Batch 300, Loss: 113.9510
Epoch 3, Batch 400, Loss: 111.5906
Epoch 3, Average Loss: 110.2997
Epoch 4, Batch 0, Loss: 105.2716
Epoch 4, Batch 100, Loss: 110.3065
Epoch 4, Batch 200, Loss: 111.5732
Epoch 4, Batch 300, Loss: 108.4896
Epoch 4, Batch 400, Loss: 110.3743
Epoch 4