In [46]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from torchvision.utils import save_image

# Reproducibility: seed torch's RNGs (CPU and every CUDA device) so weight
# initialisation, noise vectors and DataLoader shuffling repeat across runs.
# Some CUDA kernels can still be nondeterministic.
seed = 42
torch.manual_seed(seed)

# Hyperparameters (later cells read these as module-level globals)
latent_dim = 100   # length of each generator noise vector z
lr = 0.0002        # Adam learning rate for both generator and discriminator
num_classes = 10   # MNIST digit classes 0-9 (used for label embeddings / sampling)
batch_size = 128
img_size = 28      # side length of the square MNIST images
num_epochs = 100
n_images = 128     # samples saved by generate_images() after every epoch

## Device Setup and MNIST Dataset

In [47]:
# Device configuration: use the GPU when one is available, otherwise fall back
# to the CPU (a hard-coded torch.device('cuda') makes every later .to(device)
# fail on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# generate_images() saves its sample grids into ./data/mnist/cgan, so create
# that exact folder (creating only ./data/mnist left save_image without one).
os.makedirs('./data/mnist/cgan', exist_ok=True)

# MNIST dataset: ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them
# to [-1, 1], the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# drop_last=True keeps every batch at exactly `batch_size` samples.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)


## Models

In [48]:

class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to an image.

    The label is embedded into a ``num_classes``-long vector and concatenated
    in front of the noise. A linear layer projects that to a 128-channel
    ``init_size x init_size`` map (``init_size = img_size // 4``). Two
    x2-upsampling conv stages then refine it and a final conv emits
    ``img_channels`` channels, squashed into [-1, 1] by Tanh.
    """

    def __init__(self, img_channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Side length of the feature map before the two x2 upsamplings.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128 * self.init_size ** 2)
        )

        stages = [nn.BatchNorm2d(128)]
        # Two upsample -> conv -> BN -> LeakyReLU stages (128 -> 128 -> 64 channels).
        for c_in, c_out in ((128, 128), (128, 64)):
            stages += [
                nn.Upsample(scale_factor=2),
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Project to image channels and bound the output to [-1, 1].
        stages += [
            nn.ConvTranspose2d(64, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Return images of shape (N, img_channels, 4 * init_size, 4 * init_size).

        ``noise`` must be (N, latent_dim) and ``labels`` (N,) integer class ids,
        as required by the concatenation feeding ``self.l1``.
        """
        conditioned = torch.cat([self.label_emb(labels), noise], dim=-1)
        projected = self.l1(conditioned)
        n_samples = projected.shape[0]
        feature_map = projected.view(n_samples, 128, self.init_size, self.init_size)
        return self.model(feature_map)

class Discriminator(nn.Module):
    """Conditional discriminator scoring how real an image looks for its label.

    The label is embedded into an ``img_size * img_size`` vector, reshaped into
    one extra image channel and concatenated with the input image. Four
    stride-2 conv blocks downsample the result; a linear layer plus Sigmoid
    turns the flattened features into a probability in (0, 1).
    """

    def __init__(self, img_channels=1):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)

        self.model = nn.Sequential(
            nn.Conv2d(img_channels + 1, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        # Each Conv2d(kernel_size=3, stride=2, padding=1) maps side n to ceil(n / 2).
        # Derive the final side from img_size instead of hard-coding 2, which only
        # holds for 28 -> 14 -> 7 -> 4 -> 2.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(
            nn.Linear(512 * ds_size * ds_size, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Return validity scores of shape (N, 1) for ``img`` paired with ``labels``.

        ``img`` must be (N, img_channels, img_size, img_size) so it can be
        concatenated channel-wise with the one-channel label map.
        """
        # One img_size x img_size "label image" per sample.
        label_embeddings = self.label_embedding(labels).view(labels.size(0), 1, img_size, img_size)
        d_in = torch.cat((img, label_embeddings), 1)
        out = self.model(d_in)
        out = out.view(out.size(0), -1)
        validity = self.adv_layer(out)
        return validity

In [49]:
def generate_images(generator, epoch, n_images, latent_dim, digit=6,
                    out_dir='./data/mnist/cgan', nrow=5):
    """Sample ``n_images`` images of one ``digit`` and save them as a PNG grid.

    Args:
        generator: conditional generator, called as ``generator(z, labels)``.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        n_images: number of samples in the grid.
        latent_dim: length of each noise vector.
        digit: class label used for every sample.
        out_dir: output directory, created if it does not exist.
        nrow: number of images per row in the saved grid.

    Sampling runs in eval mode under ``torch.no_grad()``. BatchNorm therefore
    uses its running statistics instead of this single-digit batch, its
    buffers are left untouched, and no autograd graph is built. The
    generator's previous train/eval mode is restored afterwards.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Sample on whatever device the generator's weights live on.
    sample_device = next(generator.parameters()).device
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_images, latent_dim, device=sample_device)
            labels = torch.full((n_images,), digit, dtype=torch.long, device=sample_device)
            gen_imgs = generator(z, labels)
    finally:
        generator.train(was_training)
    # Map the Tanh range [-1, 1] to [0, 1]. normalize=True is not passed, because
    # its min-max rescaling would make this mapping redundant.
    gen_imgs = gen_imgs * 0.5 + 0.5
    save_image(gen_imgs, os.path.join(out_dir, f'fake_image_{epoch+1:03d}_digit_{digit}.png'), nrow=nrow)

## Training Stage

In [50]:
generator = Generator(img_channels=1).to(device)
discriminator = Discriminator(img_channels=1).to(device)

# The discriminator ends in Sigmoid, so BCELoss is applied to probabilities.
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        # Keep the per-batch size local: assigning to `batch_size` here would
        # overwrite the global hyperparameter that the data cell reads.
        cur_batch = imgs.size(0)
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        valid = torch.ones(cur_batch, 1, device=device, dtype=torch.float32)
        fake = torch.zeros(cur_batch, 1, device=device, dtype=torch.float32)

        # Generator step: push D to score generated images (with random labels) as real.
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # Discriminator step: real -> 1, generated (detached from G) -> 0.
        # zero_grad also clears the D gradients left behind by g_loss.backward().
        optimizer_D.zero_grad()
        d_real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        d_fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(train_loader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
    # Save a grid of samples of digit 2 after each epoch.
    generate_images(generator, epoch, n_images, latent_dim, digit=2)

print("Training finished.")

[Epoch 0/100] [Batch 0/468] [D loss: 0.6942203044891357] [G loss: 0.6911160945892334]
[Epoch 0/100] [Batch 100/468] [D loss: 0.5465178489685059] [G loss: 1.1285897493362427]
[Epoch 0/100] [Batch 200/468] [D loss: 0.690596878528595] [G loss: 0.6928708553314209]
[Epoch 0/100] [Batch 300/468] [D loss: 0.6494032144546509] [G loss: 0.8030210733413696]
[Epoch 0/100] [Batch 400/468] [D loss: 0.6812285780906677] [G loss: 0.7501981258392334]
[Epoch 1/100] [Batch 0/468] [D loss: 0.6362183690071106] [G loss: 0.8183027505874634]
[Epoch 1/100] [Batch 100/468] [D loss: 0.6693100929260254] [G loss: 0.8404824733734131]
[Epoch 1/100] [Batch 200/468] [D loss: 0.6495645046234131] [G loss: 0.8573331832885742]
[Epoch 1/100] [Batch 300/468] [D loss: 0.5559887886047363] [G loss: 0.9519342184066772]
[Epoch 1/100] [Batch 400/468] [D loss: 0.5796309113502502] [G loss: 0.9852864742279053]
[Epoch 2/100] [Batch 0/468] [D loss: 0.540535569190979] [G loss: 0.9175381660461426]
[Epoch 2/100] [Batch 100/468] [D loss: 0

KeyboardInterrupt: 