In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

class CVAE(nn.Module):
    def __init__(self, input_dim=784, num_classes=10, latent_dim=20):
        super(CVAE, self).__init__()
        
        # --- ENCODER ---
        # Input = Image (784) + Label (10)
        self.fc1 = nn.Linear(input_dim + num_classes, 400)
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)
        
        # --- DECODER ---
        # Input = Latent (20) + Label (10)
        self.fc3 = nn.Linear(latent_dim + num_classes, 400)
        self.fc4 = nn.Linear(400, input_dim)

    def encode(self, x, c):
        # Concatenate image (x) and label (c)
        inputs = torch.cat([x, c], dim=1)
        h1 = F.relu(self.fc1(inputs))
        return self.fc_mu(h1), self.fc_logvar(h1)

    def reparameterize(self, mu, logvar):
        # Standard VAE reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        # Concatenate latent vector (z) and label (c)
        inputs = torch.cat([z, c], dim=1)
        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3)) # Sigmoid to output pixel values 0-1

    def forward(self, x, c):
        mu, logvar = self.encode(x.view(-1, 784), c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar

In [2]:
def loss_function(recon_x, x, mu, logvar):
    # Binary Cross Entropy for reconstruction
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    
    # KL Divergence
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [3]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

model = CVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def one_hot(labels, num_classes):
    # Turns integer labels [3, 7] into one-hot vectors
    return torch.eye(num_classes)[labels].to(device)

# Training
model.train()
for epoch in range(10): # Run for 10 epochs
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        
        # Create one-hot labels
        c = one_hot(labels, 10) 
        
        optimizer.zero_grad()
        
        # Forward pass
        recon_batch, mu, logvar = model(data, c)
        
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
    print(f'Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

    Found GPU0 NVIDIA GeForce RTX 5080 which is of cuda capability 12.0.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (5.0) - (9.0)
    
  queued_call()
    Please install PyTorch with a following CUDA
    configurations:  12.8 13.0 following instructions at
    https://pytorch.org/get-started/locally/
    
  queued_call()
NVIDIA GeForce RTX 5080 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90.
If you want to use the NVIDIA GeForce RTX 5080 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/

  queued_call()


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)