In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps a flattened input through one hidden layer to the mean
    and log-variance of a diagonal Gaussian over the latent space. The decoder
    maps a latent vector back through one hidden layer and a sigmoid, so every
    reconstructed element lies in (0, 1).

    Args:
        input_dim: Number of features per flattened sample (default 784).
        hidden_dim: Width of the hidden layer in encoder and decoder
            (default 400).
        latent_dim: Size of the latent vector (default 20).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Encoder: shared hidden layer followed by two parallel heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mean, log_var)`` for a batch of flattened inputs ``x``."""
        h1 = torch.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mean, log_var):
        """Sample ``mean + eps * std`` with ``eps`` drawn from N(0, I).

        Drawing the noise separately and combining it arithmetically keeps the
        sample differentiable with respect to ``mean`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions with values in (0, 1)."""
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x`` to rows of ``input_dim`` features, encode, sample, decode.

        Returns:
            Tuple ``(reconstruction, mean, log_var)``.
        """
        mean, log_var = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mean, log_var)
        return self.decode(z), mean, log_var

In [None]:
def train(model, dataloader, optimizer):
    """Run one training epoch and return the average loss per sample.

    The per-batch loss is the summed binary cross-entropy between the
    reconstruction and the input plus the KL term
    ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))``.

    Args:
        model: Module whose forward pass returns
            ``(reconstruction, mean, log_var)``; the reconstruction must hold
            values in [0, 1] because it is fed to binary cross-entropy.
        dataloader: Iterable yielding ``(data, label)`` batches. Labels are
            ignored.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        Sum of all batch losses divided by the number of samples actually
        processed, or ``0.0`` when the loader yields no samples.
    """
    model.train()

    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-global `device`; a parameter-free model leaves data as-is.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_samples = 0
    for data, _ in dataloader:
        if model_device is not None:
            data = data.to(model_device)
        optimizer.zero_grad()
        reconstruction, mean, log_var = model(data)
        # Flatten the input to the reconstruction's shape rather than a fixed
        # feature count, so any input size the model supports works.
        target = data.view_as(reconstruction)
        # Calculate loss
        recon_loss = nn.functional.binary_cross_entropy(reconstruction, target, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        loss = recon_loss + kld_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Count what was really seen: stays correct with drop_last, samplers,
        # or plain iterables that have no `.dataset` attribute.
        num_samples += target.size(0)
    return total_loss / num_samples if num_samples else 0.0

In [None]:
# Seed PyTorch's global RNG before the model and loader are created so weight
# initialisation, batch shuffling and the model's sampling noise are repeatable
# on a fresh Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Load the MNIST training split (downloaded into ./data when missing) and
# convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=128, shuffle=True)

# `device` is selected once in the configuration cell near the top of the
# notebook; it is reused here instead of being re-defined.

# Initialize the VAE model
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop: one pass over the data per epoch, reporting the average
# per-sample loss returned by `train`.
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, dataloader, optimizer)
    print(f'Epoch {epoch+1}, Loss: {train_loss}')