In [27]:
# prerequisites
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image


In [28]:
# Alternative: Load raw MNIST data directly
import struct
import numpy as np
import torch

def load_mnist_images(filename):
    """Load MNIST images from a raw (uncompressed) IDX3 binary file.

    The file begins with a 16-byte big-endian header holding four unsigned
    32-bit integers (magic number, image count, rows, columns), followed by
    one unsigned byte per pixel.

    Args:
        filename: Path to the IDX3 image file.

    Returns:
        A ``numpy.uint8`` array of shape ``(num, rows, cols)``.

    Raises:
        struct.error: If the file is shorter than the 16-byte header.
        ValueError: If the magic number is not 2051 (the IDX3 image marker),
            or the pixel payload does not contain ``num * rows * cols`` bytes.
    """
    with open(filename, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        # 2051 == 0x00000803: unsigned-byte data with three dimensions.
        # Catching this early stops a label file (or a still-gzipped file)
        # from being silently reinterpreted as pixel data.
        if magic != 2051:
            raise ValueError(
                f"{filename}: unexpected magic number {magic} "
                f"(expected 2051 for an IDX3 image file)"
            )
        images = np.frombuffer(f.read(), dtype=np.uint8)

    expected_size = num * rows * cols
    if images.size != expected_size:
        raise ValueError(
            f"{filename}: header declares {num}x{rows}x{cols} = {expected_size} "
            f"pixels but the file contains {images.size}"
        )
    return images.reshape(num, rows, cols)

def load_mnist_labels(filename):
    """Load MNIST labels from a raw (uncompressed) IDX1 binary file.

    The file begins with an 8-byte big-endian header holding two unsigned
    32-bit integers (magic number, label count), followed by one unsigned
    byte per label.

    Args:
        filename: Path to the IDX1 label file.

    Returns:
        A 1-D ``numpy.uint8`` array with one entry per label.

    Raises:
        struct.error: If the file is shorter than the 8-byte header.
        ValueError: If the magic number is not 2049 (the IDX1 label marker),
            or the number of label bytes differs from the declared count.
    """
    with open(filename, 'rb') as f:
        magic, num = struct.unpack('>II', f.read(8))
        # 2049 == 0x00000801: unsigned-byte data with one dimension.
        if magic != 2049:
            raise ValueError(
                f"{filename}: unexpected magic number {magic} "
                f"(expected 2049 for an IDX1 label file)"
            )
        labels = np.frombuffer(f.read(), dtype=np.uint8)

    # A mismatch means the file is truncated or has trailing bytes; either
    # way the labels would no longer line up with the images.
    if labels.size != num:
        raise ValueError(
            f"{filename}: header declares {num} labels but the file contains {labels.size}"
        )
    return labels

# Load raw data (uncompressed IDX files under ./mnist_data/raw)
train_images = load_mnist_images('./mnist_data/raw/train-images-idx3-ubyte')
train_labels = load_mnist_labels('./mnist_data/raw/train-labels-idx1-ubyte')
test_images = load_mnist_images('./mnist_data/raw/t10k-images-idx3-ubyte')
test_labels = load_mnist_labels('./mnist_data/raw/t10k-labels-idx1-ubyte')

print(f"Train images shape: {train_images.shape}")
print(f"Train labels shape: {train_labels.shape}")
print(f"Test images shape: {test_images.shape}")
print(f"Test labels shape: {test_labels.shape}")

# Convert to PyTorch tensors and normalize pixel values from [0, 255] to [0, 1].
# torch.tensor(..., dtype=...) is used instead of the legacy
# torch.FloatTensor / torch.LongTensor constructors: it states the target
# dtype explicitly and copies the NumPy data into a new tensor.
train_images_tensor = torch.tensor(train_images, dtype=torch.float32) / 255.0
test_images_tensor = torch.tensor(test_images, dtype=torch.float32) / 255.0
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)

print(f"Train images tensor shape: {train_images_tensor.shape}")
print(f"Train images tensor range: [{train_images_tensor.min():.3f}, {train_images_tensor.max():.3f}]")


Train images shape: (60000, 28, 28)
Train labels shape: (60000,)
Test images shape: (10000, 28, 28)
Test labels shape: (10000,)
Train images tensor shape: torch.Size([60000, 28, 28])
Train images tensor range: [0.000, 1.000]


In [29]:
# Convert labels to one-hot encoding
import torch.nn.functional as F

# Method 1: the built-in encoder, torch.nn.functional.one_hot (PyTorch 1.6+).
# Both splits are encoded the same way and cast to float.
train_labels_onehot, test_labels_onehot = (
    F.one_hot(split_labels, num_classes=10).float()
    for split_labels in (train_labels_tensor, test_labels_tensor)
)

# Sanity check: shapes before/after encoding and the first five rows.
print(f"Original train labels shape: {train_labels_tensor.shape}")
print(f"One-hot train labels shape: {train_labels_onehot.shape}")
print(f"Original train labels (first 5): {train_labels_tensor[:5]}")
print(f"One-hot train labels (first 5):", train_labels_onehot[:5], sep="\n")

# Method 2: Manual one-hot encoding
def to_one_hot(labels, num_classes=10):
    """Convert a 1-D tensor of class indices to a one-hot float matrix.

    Args:
        labels: 1-D integer tensor of class indices; any integer dtype is
            accepted and the tensor may live on any device.
        num_classes: Width of the encoding (number of columns).

    Returns:
        A tensor of shape ``(labels.size(0), num_classes)`` in the default
        floating dtype, on the same device as ``labels``, with a single 1 in
        each row at the column given by the corresponding label.
    """
    # scatter_ only accepts int64 indices, so cast (a no-op for LongTensor)
    # and add a trailing dimension so each row scatters into one column.
    index = labels.long().unsqueeze(1)
    # Allocate on the labels' device; scatter_ needs both tensors co-located.
    one_hot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    one_hot.scatter_(1, index, 1)
    return one_hot

# Alternative manual method: run both splits through to_one_hot.
train_labels_onehot_manual, test_labels_onehot_manual = map(
    to_one_hot, (train_labels_tensor, test_labels_tensor)
)

print(f"\nManual one-hot train labels shape: {train_labels_onehot_manual.shape}")
print(f"Manual one-hot train labels (first 5):", train_labels_onehot_manual[:5], sep="\n")

# Cross-check: the manual encoding must match the built-in one exactly.
print(f"\nMethods are equivalent: {torch.equal(train_labels_onehot, train_labels_onehot_manual)}")


Original train labels shape: torch.Size([60000])
One-hot train labels shape: torch.Size([60000, 10])
Original train labels (first 5): tensor([5, 0, 4, 1, 9])
One-hot train labels (first 5):
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

Manual one-hot train labels shape: torch.Size([60000, 10])
Manual one-hot train labels (first 5):
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

Methods are equivalent: True


In [30]:
# Concatenate one-hot labels with image data for conditional generation
import torch

# Flatten every 28x28 image into a 784-element row vector:
# train -> (60000, 784), test -> (10000, 784).
train_images_flat, test_images_flat = (
    split_images.view(-1, 784)
    for split_images in (train_images_tensor, test_images_tensor)
)

# Append each sample's one-hot label after its pixels (feature axis = dim 1).
train_conditional_data = torch.cat((train_images_flat, train_labels_onehot), 1)  # (60000, 794)
test_conditional_data = torch.cat((test_images_flat, test_labels_onehot), 1)     # (10000, 794)

print(f"Original train images shape: {train_images_flat.shape}")
print(f"One-hot train labels shape: {train_labels_onehot.shape}")
print(f"Concatenated train data shape: {train_conditional_data.shape}")
print(f"Concatenated test data shape: {test_conditional_data.shape}")

# Spot-check the first sample: leading pixels, trailing one-hot label, and
# the integer label that one-hot row should correspond to.
first_sample = train_conditional_data[0]
print(f"\nFirst sample - Image part (first 5 pixels): {first_sample[:5]}")
print(f"First sample - Label part (last 10 values): {first_sample[-10:]}")
print(f"Original label for first sample: {train_labels_tensor[0]}")
del first_sample  # keep the notebook namespace identical to before

# Create DataLoaders for conditional data
from torch.utils.data import TensorDataset, DataLoader

# Pair each 794-dim conditional vector with its integer label.
train_dataset_conditional, test_dataset_conditional = (
    TensorDataset(train_conditional_data, train_labels_tensor),
    TensorDataset(test_conditional_data, test_labels_tensor),
)

# Batches of 100; only the training split is shuffled.
train_loader_conditional = DataLoader(
    dataset=train_dataset_conditional, shuffle=True, batch_size=100
)
test_loader_conditional = DataLoader(
    dataset=test_dataset_conditional, shuffle=False, batch_size=100
)

print(f"\nTrain loader batches: {len(train_loader_conditional)}")
print(f"Test loader batches: {len(test_loader_conditional)}")

# Peek at the first training batch only, to confirm the layout
# (pixels first, one-hot label in the last 10 columns).
for batch_data, batch_labels in train_loader_conditional:
    head_sample = batch_data[0]
    print(f"\nBatch data shape: {batch_data.shape}")
    print(f"Batch labels shape: {batch_labels.shape}")
    print(f"First sample in batch - Image part: {head_sample[:5]}")
    print(f"First sample in batch - Label part: {head_sample[-10:]}")
    print(f"First sample original label: {batch_labels[0]}")
    del head_sample  # keep the notebook namespace identical to before
    break


Original train images shape: torch.Size([60000, 784])
One-hot train labels shape: torch.Size([60000, 10])
Concatenated train data shape: torch.Size([60000, 794])
Concatenated test data shape: torch.Size([10000, 794])

First sample - Image part (first 5 pixels): tensor([0., 0., 0., 0., 0.])
First sample - Label part (last 10 values): tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])
Original label for first sample: 5

Train loader batches: 600
Test loader batches: 100

Batch data shape: torch.Size([100, 794])
Batch labels shape: torch.Size([100])
First sample in batch - Image part: tensor([0., 0., 0., 0., 0.])
First sample in batch - Label part: tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])
First sample original label: 5


In [31]:
class ConditionalVAE(nn.Module):
    """Fully connected conditional VAE.

    The encoder consumes an image vector concatenated with a label vector
    (``x_dim + label_dim`` features) and produces the mean and log-variance
    of a ``z_dim``-dimensional Gaussian. The decoder consumes a latent sample
    concatenated with the same label vector and reconstructs only the image
    part (``x_dim`` features, squashed to (0, 1) by a sigmoid).

    Args:
        x_dim: Number of image features per sample.
        label_dim: Number of label features per sample.
        h_dim1: Width of the outer hidden layer.
        h_dim2: Width of the inner hidden layer.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim, label_dim, h_dim1, h_dim2, z_dim):
        super(ConditionalVAE, self).__init__()

        # Remembered so forward() can split its input at the right column
        # instead of assuming a fixed image size.
        self.x_dim = x_dim
        self.label_dim = label_dim

        # encoder part - takes concatenated image + label
        self.fc1 = nn.Linear(x_dim + label_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head

        # decoder part - takes z + label for conditional generation
        self.fc4 = nn.Linear(z_dim + label_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)  # output only the image part

    def encoder(self, x):
        """Map a concatenated (image, label) batch to ``(mu, log_var)``."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterization)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z, labels):
        """Reconstruct images from latent samples ``z`` conditioned on ``labels``."""
        # Concatenate z with labels for conditional generation
        z_with_labels = torch.cat([z, labels], dim=1)
        h = F.relu(self.fc4(z_with_labels))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))  # output only image part

    def forward(self, x):
        """Encode, sample and decode a concatenated (image, label) batch.

        Args:
            x: Tensor of shape ``(batch, x_dim + label_dim)``; the first
                ``x_dim`` columns are the image, the rest are the label.

        Returns:
            Tuple ``(recon_images, mu, log_var)`` where ``recon_images`` has
            shape ``(batch, x_dim)``.
        """
        mu, log_var = self.encoder(x)
        z = self.sampling(mu, log_var)

        # Everything after the image columns is the conditioning label.
        labels = x[:, self.x_dim:]

        recon_images = self.decoder(z, labels)
        return recon_images, mu, log_var

# Instantiate the conditional VAE for flattened 28x28 images with 10-way
# one-hot labels and a 2-D latent space; use the GPU when one is present.
vae = ConditionalVAE(
    x_dim=784,
    label_dim=10,
    h_dim1=512,
    h_dim2=256,
    z_dim=2,
)
if torch.cuda.is_available():
    vae = vae.cuda()

In [32]:
# Show the module's repr to inspect the layer structure and feature sizes.
vae

ConditionalVAE(
  (fc1): Linear(in_features=794, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=12, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

In [33]:
# Adam over every parameter of the VAE; no hyperparameters are overridden here.
optimizer = optim.Adam(vae.parameters())
# return reconstruction error + KL divergence losses
def loss_function(recon_images, x, mu, log_var):
    """Compute the summed VAE loss for one batch.

    Args:
        recon_images: Decoder output of shape ``(batch, image_dim)`` with
            values in [0, 1].
        x: Model input of shape ``(batch, image_dim + label_dim)``; the
            leading ``image_dim`` columns are the target image, any
            remaining columns (the label) are ignored.
        mu: Latent means, shape ``(batch, z_dim)``.
        log_var: Latent log-variances, shape ``(batch, z_dim)``.

    Returns:
        Scalar tensor: binary cross-entropy reconstruction loss plus the KL
        divergence to a standard normal, both summed over the batch.
    """
    # Take as many leading columns of x as the decoder produced, rather than
    # a fixed 784, so the loss works for any image size.
    target_images = x[:, :recon_images.size(1)]
    BCE = F.binary_cross_entropy(recon_images, target_images, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [34]:
def train(epoch):
    """Run one training epoch over ``train_loader_conditional``.

    Puts the model in training mode, performs one optimizer step per batch,
    logs the per-sample loss every 100 batches and prints the epoch's average
    per-sample loss at the end.

    Args:
        epoch: Epoch number, used only in the log messages.
    """
    vae.train()
    # Follow the model's device instead of calling .cuda() unconditionally:
    # the model is only moved to the GPU when one is available, so this also
    # works on CPU-only machines.
    device = next(vae.parameters()).device
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader_conditional):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            # loss is summed over the batch, so divide by the batch size.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader_conditional.dataset),
                100. * batch_idx / len(train_loader_conditional), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader_conditional.dataset)))

In [35]:
def test():
    """Evaluate the model on ``test_loader_conditional``.

    Runs in eval mode with gradients disabled, accumulates the summed batch
    losses and prints the average per-sample loss over the test dataset.
    """
    vae.eval()
    # Follow the model's device instead of calling .cuda() unconditionally,
    # so evaluation also works when the model stayed on the CPU.
    device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader_conditional:
            data = data.to(device)
            recon, mu, log_var = vae(data)

            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()

    test_loss /= len(test_loader_conditional.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [36]:
# Train for 50 epochs (numbered 1..50), evaluating on the test split after each one.
for epoch in range(1, 51):
    train(epoch)
    test()

====> Epoch: 1 Average loss: 160.9827
====> Test set loss: 141.0363
====> Epoch: 2 Average loss: 138.3065
====> Test set loss: 136.4774
====> Epoch: 3 Average loss: 135.0846
====> Test set loss: 134.3047
====> Epoch: 4 Average loss: 133.3477
====> Test set loss: 133.0705
====> Epoch: 5 Average loss: 132.2238
====> Test set loss: 132.1766
====> Epoch: 6 Average loss: 131.4938
====> Test set loss: 132.0188
====> Epoch: 7 Average loss: 130.9228
====> Test set loss: 131.5359
====> Epoch: 8 Average loss: 130.4381
====> Test set loss: 130.9532
====> Epoch: 9 Average loss: 130.0297
====> Test set loss: 130.4505
====> Epoch: 10 Average loss: 129.6611
====> Test set loss: 130.4030
====> Epoch: 11 Average loss: 129.3520
====> Test set loss: 130.0123
====> Epoch: 12 Average loss: 129.0826
====> Test set loss: 129.8942
====> Epoch: 13 Average loss: 128.7928
====> Test set loss: 129.7332
====> Epoch: 14 Average loss: 128.5410
====> Test set loss: 129.4924
====> Epoch: 15 Average loss: 128.2835
====

In [37]:
# Conditional generation function
def generate_conditional_samples(digit, num_samples=64):
    """Generate images of a specific digit with the trained decoder.

    Latent vectors are drawn from a standard normal and decoded together
    with a one-hot label selecting ``digit``.

    Args:
        digit: Class index to condition on (column set to 1 in the label).
        num_samples: Number of images to generate.

    Returns:
        Tensor of decoder outputs with one row per generated sample, on the
        same device as the model.
    """
    # Read device and layer sizes from the model rather than hard-coding
    # .cuda(), 2 and 10: fc31 outputs the latent mean (z_dim wide) and fc4
    # consumes z concatenated with the label.
    device = next(vae.parameters()).device
    z_dim = vae.fc31.out_features
    label_dim = vae.fc4.in_features - z_dim

    with torch.no_grad():
        # Create random latent vectors (sampled on CPU, then moved, exactly
        # as before, so a fixed seed yields the same draws).
        z = torch.randn(num_samples, z_dim).to(device)

        # Create one-hot encoding for the desired digit
        target_labels = torch.zeros(num_samples, label_dim, device=device)
        target_labels[:, digit] = 1.0

        # Generate samples using the decoder
        generated_images = vae.decoder(z, target_labels)

        return generated_images

# Generate samples for different digits
# save_image does not create missing directories, so make sure the output
# folder exists before writing into it.
os.makedirs('./samples', exist_ok=True)
for digit in [0, 1, 2, 3, 4]:
    samples = generate_conditional_samples(digit, num_samples=16)
    # Reshape flat 784-vectors to (N, channels=1, 28, 28) for the image grid.
    save_image(samples.view(16, 1, 28, 28), f'./samples/digit_{digit}.png')
    print(f"Generated samples for digit {digit}")

# # Also generate a larger 64-sample grid for one fixed digit (digit 5 here)
# with torch.no_grad():
#     z = torch.randn(64, 2).cuda()
#     random_labels = torch.zeros(64, 10).cuda()
#     random_labels[:, 5] = 1.0  # Generate digit 5
#     sample = vae.decoder(z, random_labels)
#     save_image(sample.view(64, 1, 28, 28), './samples/random_samples.png')

Generated samples for digit 0
Generated samples for digit 1
Generated samples for digit 2
Generated samples for digit 3
Generated samples for digit 4
