In [4]:
# torch_cvae_train.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


## Checking your data folder
Let's list the contents of the `./data` directory to confirm the MNIST files are where `torchvision.datasets.MNIST(root="./data")` expects them, i.e. in `./data/MNIST/raw`.

In [8]:
import os

# torchvision.datasets.MNIST(root="./data") keeps its files under <root>/MNIST/raw.
DATA_ROOT = './data'
RAW_DIR = os.path.join(DATA_ROOT, 'MNIST', 'raw')

# Guard each listing so this diagnostic cell reports a missing folder
# instead of raising FileNotFoundError on a fresh checkout.
if os.path.isdir(DATA_ROOT):
    print(f'Contents of {DATA_ROOT}:')
    # sorted() makes the printed order deterministic across filesystems.
    print(sorted(os.listdir(DATA_ROOT)))
else:
    print(f'No {DATA_ROOT} directory found')

if os.path.isdir(RAW_DIR):
    print('Contents of ./data/MNIST/raw:')
    print(sorted(os.listdir(RAW_DIR)))
else:
    print('No MNIST/raw directory found in ./data')

Contents of ./data:
['MNIST']
Contents of ./data/MNIST/raw:
['train-images.idx3-ubyte', 'train-labels.idx1-ubyte']


In [10]:
# 1. Dataset and Dataloader
# ToTensor turns each MNIST image into a float tensor; the training cell
# later flattens it before feeding the model.
transform = transforms.ToTensor()

# MNIST training split under ./data, fetched if the files are not present.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Mini-batches of 128, reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
)

100%|██████████| 9.91M/9.91M [01:40<00:00, 99.0kB/s]

100%|██████████| 28.9k/28.9k [00:00<00:00, 165kB/s]

100%|██████████| 1.65M/1.65M [00:01<00:00, 1.59MB/s]

100%|██████████| 4.54k/4.54k [00:00<00:00, 9.48MB/s]



In [None]:



# 2. Conditional VAE
class CVAE(nn.Module):
    """Conditional VAE with one hidden layer in the encoder and in the decoder.

    Both halves are conditioned on a class vector ``y`` (one-hot in the
    training cell) that is concatenated to their input.

    Args:
        latent_dim: size of the latent code ``z``.
        input_dim: size of one flattened input sample (784 = 28*28 for MNIST).
        num_classes: length of the condition vector ``y``.
        hidden_dim: width of the hidden layer in encoder and decoder.

    The submodule names (fc1, fc21, fc22, fc3, fc4) are kept unchanged so
    previously saved state dicts still load.
    """

    def __init__(self, latent_dim=20, input_dim=784, num_classes=10, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.num_classes = num_classes
        # Encoder: [x, y] -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # posterior mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance
        # Decoder: [z, y] -> hidden -> reconstruction
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x, y):
        """Map inputs and their conditions to posterior parameters.

        Args:
            x: flattened inputs, shape (batch, input_dim).
            y: condition vectors, shape (batch, num_classes).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        h1 = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling through this deterministic transform keeps ``z``
        differentiable w.r.t. ``mu`` and ``logvar``. Noise is drawn regardless
        of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Map latent codes and conditions to per-feature probabilities.

        Returns:
            Tensor of shape (batch, input_dim) with values in [0, 1] (sigmoid).
        """
        h3 = F.relu(self.fc3(torch.cat([z, y], dim=1)))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, y):
        """Encode, sample a latent code, and decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar



In [None]:
# 3. Loss function
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a VAE with a Bernoulli decoder, summed over the batch.

    Args:
        recon_x: decoder output, probabilities in [0, 1], same shape as ``x``.
        x: target values in [0, 1].
        mu: posterior means produced by the encoder.
        logvar: posterior log-variances produced by the encoder.

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction term plus the
        closed-form KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL between a diagonal Gaussian and the standard normal prior.
    kl_term = (1 + logvar - mu.pow(2) - logvar.exp()).sum().mul(-0.5)
    total = reconstruction_term + kl_term
    return total



In [None]:
# 4. Training
NUM_EPOCHS = 10   # adjust epochs based on results
NUM_CLASSES = 10  # digits 0-9, one-hot encoded as the CVAE condition
INPUT_DIM = 784   # flattened 28x28 MNIST image
SEED = 0

# Seed the global RNG before building the model so weight init, the
# reparameterization noise and DataLoader shuffling (which draws its
# seed from the global RNG) repeat across Restart & Run All.
# Some CUDA kernels remain nondeterministic.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model to its device *before* constructing the optimizer,
# as the torch.optim documentation recommends.
model = CVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    for data, labels in train_loader:
        data = data.view(-1, INPUT_DIM).to(device)
        labels = F.one_hot(labels, NUM_CLASSES).float().to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data, labels)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # loss_function sums over the batch, so dividing by the dataset size
    # reports the average per-image loss for the epoch.
    print(f'Epoch {epoch}, Loss: {train_loss / len(train_loader.dataset):.4f}')

torch.save(model.state_dict(), 'cvae_mnist.pth')
