In [2]:
# Install and Import
!pip install torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt




In [3]:
# Load MNIST
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)


100%|█████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [00:03<00:00, 2.91MB/s]
100%|█████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 83.0kB/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:05<00:00, 322kB/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 746kB/s]


In [4]:
# One-hot encoding for labels
def one_hot(labels, num_classes=10):
    return F.one_hot(labels, num_classes=num_classes).float()

In [5]:
# Define CVAE Model
class CVAE(nn.Module):
    def __init__(self):
        super(CVAE, self).__init__()
        self.fc1 = nn.Linear(784 + 10, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20 + 10, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        h = F.relu(self.fc3(torch.cat([z, y], dim=1)))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar

In [6]:
# Define Loss Function
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


In [7]:
# Train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for x, y in train_loader:
        x = x.view(-1, 784).to(device)
        y_oh = one_hot(y).to(device)

        recon_batch, mu, logvar = model(x, y_oh)
        loss = loss_function(recon_batch, x, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.2f}")

Epoch 1, Loss: 9773375.67
Epoch 2, Loss: 7169785.74
Epoch 3, Loss: 6745463.47
Epoch 4, Loss: 6549287.74
Epoch 5, Loss: 6422558.67
Epoch 6, Loss: 6334965.83
Epoch 7, Loss: 6272662.45
Epoch 8, Loss: 6225541.69
Epoch 9, Loss: 6185143.23
Epoch 10, Loss: 6150972.87


In [8]:
# STEP 6: Save Model
torch.save(model.state_dict(), "cvae_model.pt")