In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

In [None]:
# Preprocessing pipeline for CIFAR-10:
#   * Resize(64) scales the shorter side to 64 px (CIFAR images are square -> 64x64),
#     which is the spatial size the discriminator's 1024 * 4 * 4 heads assume.
#   * ToTensor gives values in [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1],
#     the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

batch_size = 128
# Downloads CIFAR-10 into ./data on first run; labels come back as class indices.
dataset = datasets.CIFAR10(root='./data', transform=transform, download=True)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
class Generator(nn.Module):
    """Label-conditioned DCGAN-style generator used by the ACGAN.

    The class label is embedded into a ``z_dim`` vector and concatenated with the
    noise vector; the resulting ``2 * z_dim`` vector is treated as a 1x1 feature map
    and upsampled by transposed convolutions to a 64x64 image. 64x64 matches both the
    ``Resize(64)`` applied to real images and the 1024 * 4 * 4 flattened size the
    discriminator's heads expect.

    Args:
        z_dim: length of the noise vector (and of the label embedding).
        num_classes: number of distinct class labels.
        img_channels: number of channels in the generated images.
    """

    def __init__(self, z_dim, num_classes, img_channels):
        super(Generator, self).__init__()
        # One learned z_dim-dimensional vector per class label.
        self.label_emb = nn.Embedding(num_classes, z_dim)
        self.gen = nn.Sequential(
            # (N, 2*z_dim, 1, 1) -> (N, 512, 4, 4)
            nn.ConvTranspose2d(z_dim * 2, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # -> (N, 256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # -> (N, 128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # -> (N, 64, 32, 32)
            # Extra upsampling stage: without it the output is only 32x32, which does
            # not match the 64x64 real images or the discriminator's linear heads.
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # -> (N, img_channels, 64, 64); Tanh keeps pixels in [-1, 1] like the
            # normalized real images.
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images conditioned on class labels.

        Args:
            noise: (N, z_dim) latent vectors.
            labels: (N,) integer class indices in [0, num_classes).

        Returns:
            (N, img_channels, 64, 64) images with values in [-1, 1].
        """
        c = self.label_emb(labels)                          # (N, z_dim)
        x = torch.cat([noise, c], 1).unsqueeze(2).unsqueeze(3)  # (N, 2*z_dim, 1, 1)
        return self.gen(x)

In [None]:
class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and an auxiliary class head.

    Expects images of shape (N, img_channels, 64, 64): the four stride-2 convolutions
    reduce 64x64 to 4x4, which is the 1024 * 4 * 4 input size both heads assume.

    Args:
        num_classes: number of classes predicted by the auxiliary head.
        img_channels: number of channels in the input images.
    """

    def __init__(self, num_classes, img_channels):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # (N, C, 64, 64) -> (N, 128, 32, 32); the only conv with a bias and no BatchNorm
            nn.Conv2d(img_channels, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 256, 16, 16)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 512, 8, 8)
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1024, 4, 4)
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Real/fake probability in (0, 1), intended for nn.BCELoss.
        self.classifier = nn.Sequential(
            nn.Linear(1024 * 4 * 4, 1),
            nn.Sigmoid(),
        )
        # Raw class logits. nn.CrossEntropyLoss applies log-softmax internally, so a
        # Softmax here would normalize twice and flatten the classification gradients.
        # Kept inside nn.Sequential so state_dict keys ("aux_classifier.0.*") are unchanged.
        self.aux_classifier = nn.Sequential(
            nn.Linear(1024 * 4 * 4, num_classes),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: (N, img_channels, 64, 64) image batch.

        Returns:
            validity: (N, 1) probability that each image is real.
            label: (N, num_classes) unnormalized class logits (apply softmax for probabilities).
        """
        out = self.disc(img).flatten(1)  # (N, 1024 * 4 * 4)
        validity = self.classifier(out)
        label = self.aux_classifier(out)
        return validity, label

In [None]:
# Hyperparameters
z_dim = 100          # length of the generator's noise vector (and of its label embedding)
num_classes = 10     # CIFAR-10 class count; size of the auxiliary classifier output
img_channels = 3     # RGB images
lr = 0.0002          # Adam learning rate for both networks
beta1 = 0.5          # Adam first-moment decay for both networks
seed = 42            # RNG seed for reproducible runs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch (CPU and CUDA generators) before building the models so weight init,
# DataLoader shuffling, and sampled noise/labels repeat across runs. Bit-exact GPU
# reproducibility may additionally require deterministic cuDNN settings.
torch.manual_seed(seed)

# Model
generator = Generator(z_dim, num_classes, img_channels).to(device)
discriminator = Discriminator(num_classes, img_channels).to(device)

# Optimizers
opt_gen = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss functions: BCE on the sigmoid real/fake output, cross-entropy on the class head.
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()


In [None]:
num_epochs = 200
for epoch in range(num_epochs):
    for imgs, labels in dataloader:
        # Size of this batch (the final batch can be smaller). Kept in a separate name
        # so the global `batch_size` hyperparameter from the data cell is not clobbered.
        cur_batch = imgs.shape[0]
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Adversarial targets, shaped (N, 1) like the discriminator's validity output.
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Train Generator: fool D on real/fake AND get the requested class predicted.
        opt_gen.zero_grad()
        z = torch.randn(cur_batch, z_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        gen_imgs = generator(z, gen_labels)
        validity, pred_label = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)
        g_loss.backward()
        opt_gen.step()

        # Train Discriminator. zero_grad also discards the gradients that the generator's
        # backward pass accumulated in D's parameters.
        opt_disc.zero_grad()
        validity_real, pred_label_real = discriminator(imgs)
        d_real_loss = (adversarial_loss(validity_real, valid) +
                       auxiliary_loss(pred_label_real, labels)) / 2

        # detach(): the discriminator update must not backpropagate into the generator.
        validity_fake, pred_label_fake = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(validity_fake, fake) +
                       auxiliary_loss(pred_label_fake, gen_labels)) / 2

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        opt_disc.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {d_loss.item()}, loss G: {g_loss.item()}")
