In [2]:
# Install PyTorch with the %pip magic so the package lands in the environment
# of the kernel that is running this notebook.
%pip install torch

Defaulting to user installation because normal site-packages is not writeable
Collecting torch
  Downloading torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl.metadata (25 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Downloading nvidia_cufft_cu

In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image


In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Training configuration
num_classes = 10   # number of class labels used for conditioning
latent_dim = 100   # length of the noise vector fed to the generator
batch_size = 64    # samples per DataLoader batch
epochs = 50        # number of passes over the dataloader
lr = 2e-4          # learning rate shared by both Adam optimizers


In [4]:
# Conditional generator: (noise vector, class label) -> 3x32x32 image
class Generator(nn.Module):
    """Fully connected generator conditioned on a class label.

    The label is embedded, placed in front of the noise vector, and the joined
    vector is pushed through an MLP that ends in ``Tanh``. The flat output is
    reshaped to ``(batch, 3, 32, 32)``.

    The layer sizes read the module-level ``latent_dim`` and ``num_classes``
    when the network is constructed.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        hidden_widths = (128, 256, 512, 1024)
        layers = []
        in_features = latent_dim + num_classes
        for depth, out_features in enumerate(hidden_widths):
            layers.append(nn.Linear(in_features, out_features))
            if depth > 0:
                # Every hidden layer except the first one is batch-normalised.
                layers.append(nn.BatchNorm1d(out_features))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = out_features

        # Output head: one value per pixel/channel, squashed by Tanh.
        layers.append(nn.Linear(in_features, 32 * 32 * 3))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Return generated images of shape (batch, 3, 32, 32).

        ``labels`` is looked up in ``label_emb`` and concatenated with
        ``noise`` along the last dimension before entering the MLP.
        """
        conditioned = torch.cat([self.label_emb(labels), noise], dim=-1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), 3, 32, 32)

In [5]:
# Conditional discriminator: (image, class label) -> score in (0, 1)
class Discriminator(nn.Module):
    """Fully connected discriminator conditioned on a class label.

    The image is flattened, the label embedding is appended to it, and the
    joined vector goes through an MLP that ends in ``Sigmoid``, giving one
    value per sample.

    The embedding and input sizes read the module-level ``num_classes`` when
    the network is constructed.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        stack = [
            # Input is the label embedding plus the 32*32*3 flattened pixels.
            nn.Linear(num_classes + 32 * 32 * 3, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return one validity score per sample, shape (batch, 1)."""
        pixels = img.view(img.size(0), -1)
        joined = torch.cat([pixels, self.label_emb(labels)], dim=-1)
        return self.model(joined)


In [6]:
# Build both networks (generator first), then move each one onto the selected device
generator = Generator()
discriminator = Discriminator()
generator, discriminator = generator.to(device), discriminator.to(device)


In [7]:
# Binary cross-entropy on the discriminator's sigmoid output, averaged over the batch
adversarial_loss = nn.BCELoss(reduction='mean')

In [8]:
# One Adam optimizer per network, both configured identically
# (shared learning rate, betas=(0.5, 0.999)).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

In [9]:
# CIFAR-10 training split, shuffled and served in mini-batches.
# Each channel is normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 65643083.56it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [11]:
# Training loop
# Each iteration performs one discriminator update followed by one generator
# update on the same batch of generated images.

# The sample grids below are written to images/. save_image does not create
# missing directories, so make sure the folder exists before training starts.
os.makedirs("images", exist_ok=True)

for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Adversarial ground truths: 1 for "real", 0 for "fake", one per sample
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        # Configure input: move the batch to the device and draw fresh noise
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        z = torch.randn(imgs.size(0), latent_dim, device=device)

        # Generate a batch of images conditioned on the real batch's labels
        gen_imgs = generator(z, labels)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        # detach() keeps the discriminator loss from back-propagating into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: it is rewarded when the discriminator labels its
        # images as "valid"
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(gen_imgs, labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # Print progress every 100 batches
        if i % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

    # Save generated images every 10th epoch.
    # gen_imgs is the last batch of the epoch, which may hold fewer than 25
    # samples; the slice then simply returns what is available.
    if epoch % 10 == 0:
        save_image(gen_imgs.detach()[:25], "images/%d.png" % epoch, nrow=5, normalize=True)


[Epoch 0/50] [Batch 0/782] [D loss: 0.750015] [G loss: 1.667734]
[Epoch 0/50] [Batch 100/782] [D loss: 0.555477] [G loss: 1.561500]
[Epoch 0/50] [Batch 200/782] [D loss: 0.510027] [G loss: 1.866041]
[Epoch 0/50] [Batch 300/782] [D loss: 0.696911] [G loss: 1.489877]
[Epoch 0/50] [Batch 400/782] [D loss: 0.623170] [G loss: 1.563631]
[Epoch 0/50] [Batch 500/782] [D loss: 0.594169] [G loss: 1.497278]
[Epoch 0/50] [Batch 600/782] [D loss: 0.582114] [G loss: 1.755124]
[Epoch 0/50] [Batch 700/782] [D loss: 0.558877] [G loss: 1.682624]
[Epoch 1/50] [Batch 0/782] [D loss: 0.551541] [G loss: 1.738785]
[Epoch 1/50] [Batch 100/782] [D loss: 0.616140] [G loss: 1.585098]
[Epoch 1/50] [Batch 200/782] [D loss: 0.555593] [G loss: 1.671649]
[Epoch 1/50] [Batch 300/782] [D loss: 0.611367] [G loss: 1.342386]
[Epoch 1/50] [Batch 400/782] [D loss: 0.609334] [G loss: 1.654291]
[Epoch 1/50] [Batch 500/782] [D loss: 0.598365] [G loss: 1.259780]
[Epoch 1/50] [Batch 600/782] [D loss: 0.563802] [G loss: 1.764005]