In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load MNIST
# The training split is downloaded into ./data on first use. Each image is
# passed through ToTensor before being batched.
transform = transforms.ToTensor()
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Reshuffle every epoch and serve mini-batches of 128 samples.
train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=True)

# CVAE Components
class CVAE(nn.Module):
    """Fully connected conditional variational autoencoder.

    Both the encoder and the decoder are conditioned on an integer class
    label, which is mapped to a dense vector by a single embedding table
    shared between the two halves of the model.

    Args:
        latent_dim: Size of the latent code ``z``.
        input_dim: Number of features per flattened sample (default
            ``28 * 28``).
        num_classes: Number of distinct labels the embedding accepts.
        embed_dim: Width of the label embedding vector.
        hidden_dim: Width of the single hidden layer in encoder and decoder.

    The defaults reproduce the original fixed architecture, and the submodule
    names (``label_embedding``, ``encoder``, ``mu``, ``logvar``, ``decoder``)
    are unchanged so previously saved ``state_dict`` files still load.
    """

    def __init__(self, latent_dim=20, input_dim=28 * 28, num_classes=10,
                 embed_dim=10, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        # Encoder trunk: flattened sample + label embedding -> hidden features.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + embed_dim, hidden_dim),
            nn.ReLU(),
        )
        # Two heads produce the mean and log-variance of q(z | x, y).
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent code + label embedding -> per-feature value in (0, 1).
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x, y):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x`` given ``y``.

        ``x`` is flattened to ``(-1, input_dim)``; ``y`` holds integer class
        indices, one per sample.
        """
        y_embed = self.label_embedding(y)
        # reshape (not view) so non-contiguous inputs, e.g. transposed image
        # batches, are accepted instead of raising a RuntimeError.
        flat = x.reshape(-1, self.input_dim)
        h = self.encoder(torch.cat([flat, y_embed], dim=1))
        return self.mu(h), self.logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps the result differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent codes ``z`` conditioned on labels ``y``.

        Returns a ``(batch, input_dim)`` tensor of sigmoid activations.
        """
        y_embed = self.label_embedding(y)
        return self.decoder(torch.cat([z, y_embed], dim=1))

    def forward(self, x, y):
        """Encode, sample a latent code, and decode.

        Returns:
            Tuple ``(recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, y)
        return recon, mu, logvar

# Loss Function
def loss_fn(recon_x, x, mu, logvar):
    """Return the summed negative ELBO for a batch.

    The loss is the sum of two terms, both summed (not averaged) over the
    batch:

    * binary cross-entropy between the reconstruction and the input, and
    * the KL divergence between ``N(mu, exp(logvar))`` and ``N(0, I)``.

    Args:
        recon_x: Reconstruction with values in [0, 1].
        x: Target with the same number of elements as ``recon_x``; it is
            reshaped to ``recon_x``'s shape, so image-shaped batches are
            accepted.
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # Reshape to the reconstruction's shape rather than a hard-coded
    # (-1, 784); reshape also tolerates non-contiguous targets, unlike view.
    target = x.reshape(recon_x.shape)
    bce = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# Initialize and Train
# Build the model on the selected device and optimise all of its parameters with Adam.
model = CVAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

epochs = 10  # You can adjust
# loss_fn sums over the batch, so dividing the epoch total by the dataset
# size gives the average loss per sample.
num_samples = len(train_loader.dataset)

for epoch in range(epochs):
    model.train()
    running_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(images, labels)
        loss = loss_fn(recon, images, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / num_samples
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

# Save model
# Only the weights are stored; loading them requires re-creating CVAE first.
torch.save(model.state_dict(), 'cvae_mnist.pth')

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.17MB/s]


Epoch 1, Loss: 158.0420
Epoch 2, Loss: 117.8744
Epoch 3, Loss: 110.7497
Epoch 4, Loss: 107.5139
Epoch 5, Loss: 105.6369
Epoch 6, Loss: 104.3360
Epoch 7, Loss: 103.3819
Epoch 8, Loss: 102.7071
Epoch 9, Loss: 102.0911
Epoch 10, Loss: 101.5719
