In [1]:
# Install dependencies
#!pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Define Generator
class Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )
        
    def forward(self, noise, labels):
        labels = self.label_emb(labels)
        input = torch.cat((noise, labels), -1)
        return self.model(input).view(-1, 1, 28, 28)


In [3]:
# Define Discriminator
class Discriminator(nn.Module):
    def __init__(self, num_classes=10):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(28*28 + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img, labels):
        img_flat = img.view(img.size(0), -1)
        labels = self.label_emb(labels)
        input = torch.cat((img_flat, labels), -1)
        return self.model(input)

In [4]:
# Hyperparameters
batch_size = 64
lr = 0.0002
epochs = 50
noise_dim = 100
num_classes = 10

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Models
generator = Generator(noise_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)

# Optimizers
optim_G = optim.Adam(generator.parameters(), lr=lr)
optim_D = optim.Adam(discriminator.parameters(), lr=lr)

# Loss
adversarial_loss = nn.BCELoss()

for epoch in range(epochs):
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # Train Discriminator
        optim_D.zero_grad()
        real_validity = discriminator(imgs, labels)
        real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))
        
        noise = torch.randn(imgs.size(0), noise_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (imgs.size(0),), device=device)
        fake_imgs = generator(noise, gen_labels)
        
        fake_validity = discriminator(fake_imgs.detach(), gen_labels)
        fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))
        
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optim_D.step()
        
        # Train Generator
        optim_G.zero_grad()
        gen_validity = discriminator(fake_imgs, gen_labels)
        g_loss = adversarial_loss(gen_validity, torch.ones_like(gen_validity))
        g_loss.backward()
        optim_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch [1/50] | D Loss: 0.0841 | G Loss: 4.7763
Epoch [2/50] | D Loss: 0.0689 | G Loss: 5.2225
Epoch [3/50] | D Loss: 0.0220 | G Loss: 7.4886
Epoch [4/50] | D Loss: 0.0189 | G Loss: 5.4285
Epoch [5/50] | D Loss: 0.0369 | G Loss: 6.1684
Epoch [6/50] | D Loss: 0.1342 | G Loss: 4.8662
Epoch [7/50] | D Loss: 0.0948 | G Loss: 6.7557
Epoch [8/50] | D Loss: 0.0304 | G Loss: 6.3684
Epoch [9/50] | D Loss: 0.1932 | G Loss: 5.0937
Epoch [10/50] | D Loss: 0.1661 | G Loss: 5.6210
Epoch [11/50] | D Loss: 0.1838 | G Loss: 3.4968
Epoch [12/50] | D Loss: 0.3494 | G Loss: 3.5595
Epoch [13/50] | D Loss: 0.0957 | G Loss: 3.7373
Epoch [14/50] | D Loss: 0.1876 | G Loss: 3.4896
Epoch [15/50] | D Loss: 0.1953 | G Loss: 2.4167
Epoch [16/50] | D Loss: 0.1921 | G Loss: 4.1897
Epoch [17/50] | D Loss: 0.0925 | G Loss: 4.1587
Epoch [18/50] | D Loss: 0.3554 | G Loss: 3.3412
Epoch [19/50] | D Loss: 0.1441 | G Loss: 3.8558
Epoch [20/50] | D Loss

In [5]:
torch.save(generator.state_dict(), "mnist_cgan_generator.pth")
print("Saved Conditional MNIST Generator!")

Saved Conditional MNIST Generator!
