In [73]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset=mnist_dataset, batch_size=100, shuffle=True)

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(32, 64, 5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Flatten(),
            nn.Linear(7*7*64, 1)
        )

    def forward(self, x):
        return self.main(x)

# Generator network
class Generator(nn.Module):
    def __init__(self, num_classes):
        super(Generator, self).__init__()
        self.num_classes = num_classes
        self.main = nn.Sequential(
            nn.Linear(100 + num_classes, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 7*7*64),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(7*7*64),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Expand labels to one-hot encoding
        labels_onehot = F.one_hot(labels.to(torch.int64), num_classes=self.num_classes).float()
        # Concatenate noise and class labels
        z = torch.cat((z, labels_onehot), dim=1)
        return self.main(z)

# Initialize networks
num_classes = 10
generator = Generator(num_classes).to(device)
discriminator = Discriminator().to(device)

# Loss functions
criterion = nn.BCEWithLogitsLoss()
criterion_class = nn.CrossEntropyLoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
num_epochs = 50

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader):
        batch_size = images.size(0)
        images = images.to(device)
        labels = labels.to(device)

        # Adversarial ground truths
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_pred = discriminator(images)
        d_real_loss = criterion(real_pred, real_labels)

        z = torch.randn(batch_size, 100).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_images = generator(z, gen_labels)
        fake_pred = discriminator(gen_images.detach())
        d_fake_loss = criterion(fake_pred, fake_labels)

        d_loss = 0.5 * (d_real_loss + d_fake_loss)
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, 100).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_images = generator(z, gen_labels)
        validity = discriminator(gen_images)
        gen_labels_onehot = F.one_hot(gen_labels, num_classes=num_classes).float()
        g_loss = criterion(validity, real_labels) + 0.5 * criterion_class(gen_labels_onehot, labels)
        g_loss.backward()
        optimizer_G.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], '
                  f'D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')

# Save the model checkpoints
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')


Epoch [1/50], Step [100/600], D_loss: 0.0199, G_loss: 5.5069
Epoch [1/50], Step [200/600], D_loss: 0.0212, G_loss: 5.3127


KeyboardInterrupt: 