# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

# Define the VAE class
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of ``input_dim`` features through
    256 -> 64 hidden units (ReLU) to ``2 * latent_dim`` values, which are split
    into the posterior mean ``mu`` and log-variance ``logvar``. The decoder
    mirrors this (latent_dim -> 64 -> 256 -> input_dim) and ends in a Sigmoid,
    so reconstructions lie in (0, 1).

    Args:
        latent_dim: Size of the latent code ``z``.
        input_dim: Number of features per flattened input sample. Defaults to
            784 (a 28x28 image), which matches the original hard-coded size.
    """

    def __init__(self, latent_dim, input_dim=784):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        # Encoder: the final layer emits mu and logvar concatenated along the
        # feature axis (first latent_dim values are mu, the rest logvar).
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim * 2),
        )
        # Decoder: Sigmoid keeps outputs in (0, 1), as required by a BCE loss.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Encode flattened inputs ``x`` of shape (N, input_dim).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (N, latent_dim).
        """
        hidden = self.encoder(x)
        # The encoder output has exactly 2 * latent_dim features, so chunking
        # into two equal halves yields (mu, logvar).
        mu, logvar = hidden.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` of shape (N, latent_dim) to (N, input_dim)."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to (N, input_dim); any layout whose element count is
        a multiple of ``input_dim`` is accepted (e.g. (N, 1, 28, 28) images).
        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (such as transposed tensors) do not raise.

        Returns:
            Tuple ``(reconstruction, mu, logvar)`` with shapes
            (N, input_dim), (N, latent_dim), (N, latent_dim).
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Define the loss function for VAE
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli-decoder VAE, summed over the batch.

    Args:
        recon_x: Decoder output, values in [0, 1] (e.g. shape (N, 784)).
        x: Target batch; reshaped to ``recon_x``'s shape, so image-shaped
            targets like (N, 1, 28, 28) are accepted. Values must lie in
            [0, 1] for binary cross-entropy.
        mu: Posterior means, shape (N, latent_dim).
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: sum-reduced binary cross-entropy plus the closed-form
        KL divergence KL(N(mu, exp(logvar)) || N(0, I)). The value is a batch
        *sum*, not a per-sample mean, so it scales with batch size.
    """
    # reshape (not view) tolerates non-contiguous targets, and matching
    # recon_x's shape avoids hard-coding the flattened input size.
    target = x.reshape(recon_x.shape)
    # Functional form avoids constructing a new nn.BCELoss module per call.
    recon_loss = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    # Analytic KL to a standard normal: -0.5 * sum(1 + logvar - mu^2 - var).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# FashionMNIST training split; download=True fetches it into ./data if missing.
train_dataset = FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)
# Mini-batches of 128 samples, reshuffled every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)

# Model with a 20-dimensional latent space, optimized with Adam.
latent_dim = 20
vae = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# Train; every 100th step print the loss returned by vae_loss for that batch.
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        # Only backward() writes gradients, so clearing them before the forward
        # pass is equivalent to clearing them just before backward().
        optimizer.zero_grad()
        recon_data, mu, logvar = vae(data)
        loss = vae_loss(recon_data, data, mu, logvar)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 != 0:
            continue
        print(
            f'Epoch [{epoch + 1}/{num_epochs}], '
            f'Step [{batch_idx + 1}/{len(train_loader)}], '
            f'Loss: {loss.item():.4f}'
        )

# Persist the learned parameters (state_dict) rather than the module object.
torch.save(vae.state_dict(), 'vae_model.pth')
