In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Hyperparameters
image_size = 32        # side length (pixels) passed to transforms.Resize
batch_size = 128       # samples per DataLoader batch
latent_dim = 100       # channel count of the generator's noise input
num_epochs = 20        # full passes over the dataloader
learning_rate = 2e-4   # Adam step size shared by both optimizers
beta1 = 0.5            # first Adam beta coefficient

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data loading and preprocessing
# Pipeline: resize -> convert to tensor -> per-channel (x - 0.5) / 0.5,
# which maps the tensor values into [-1, 1].
transform = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# CIFAR-10 training split, stored under ./data (downloaded when requested).
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# Reshuffled batches of `batch_size` samples.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Generator model
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # input is latent vector Z
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # output in range [-1,1]
        )

    def forward(self, x):
        return self.net(x)

# Discriminator model
class Discriminator(nn.Module):
    """DCGAN-style discriminator: 3-channel image -> probability score.

    Three stride-2 convolutions halve the spatial size each time and a final
    4x4 convolution with a sigmoid produces the score. For 32x32 inputs the
    spatial size goes 32 -> 16 -> 8 -> 4 -> 1, so ``forward`` returns one
    value per image.
    """

    def __init__(self):
        super().__init__()
        leak = 0.2  # negative slope shared by every LeakyReLU
        stages = [
            # No batch norm on the first stage.
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=leak, inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(negative_slope=leak, inplace=True),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=512),
            nn.LeakyReLU(negative_slope=leak, inplace=True),

            # Collapse the remaining 4x4 map to a single value.
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                      stride=1, padding=0, bias=False),
            nn.Sigmoid(),  # output probability
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the network output for ``x`` flattened to one dimension."""
        scores = self.net(x)
        return scores.view(-1)

# Initialize models (generator first, then discriminator) on the chosen device
G = Generator().to(device=device)
D = Discriminator().to(device=device)

# Loss and optimizers
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Both networks use Adam with identical settings.
adam_settings = dict(lr=learning_rate, betas=(beta1, 0.999))
optimizerD = optim.Adam(D.parameters(), **adam_settings)
optimizerG = optim.Adam(G.parameters(), **adam_settings)

# Fixed noise for monitoring progress: the same 64 latent vectors are
# rendered after every epoch.
fixed_noise = torch.randn((64, latent_dim, 1, 1), device=device)

def denorm(x):
    """Map values from the [-1, 1] range back to [0, 1].

    Inverts a normalization with mean 0.5 and std 0.5 by computing
    ``x * 0.5 + 0.5``.
    """
    half = 0.5
    scaled = x * half
    return scaled + half

# Training loop
# Every iteration performs one discriminator update (a real batch plus a
# detached fake batch) followed by one generator update on that same fake
# batch. After each epoch the fixed noise is rendered as an 8x8 image grid.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        real_images = images.to(device)
        # Use the actual batch length; it need not equal batch_size.
        batch_size_curr = real_images.size(0)

        # Real and fake labels
        real_labels = torch.ones(batch_size_curr, device=device)
        fake_labels = torch.zeros(batch_size_curr, device=device)

        # Train discriminator with real images.
        # Gradients are cleared first so they hold only this step's
        # real + fake contributions (the generator step below also
        # back-propagates through D).
        D.zero_grad()
        outputs = D(real_images)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()

        # Train discriminator with fake images
        noise = torch.randn(batch_size_curr, latent_dim, 1, 1, device=device)
        fake_images = G(noise)
        # detach(): the discriminator step must not back-propagate into G.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizerD.step()

        # Train generator
        G.zero_grad()
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)  # wants discriminator to mistake fakes for real
        g_loss.backward()
        optimizerG.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], '
                  f'D_loss: {(d_loss_real + d_loss_fake).item():.4f}, G_loss: {g_loss.item():.4f}')

    # Show generated images for monitoring
    with torch.no_grad():  # no autograd graph is built, so no detach() needed
        fake_images = G(fixed_noise).cpu()

    # Tile the 64 samples into one 8x8 grid using tensor ops only:
    # (64, 3, H, W) -> (64, H, W, 3) -> (8, 8, H, W, 3)
    #               -> (8, H, 8, W, 3) -> (8*H, 8*W, 3)
    grid = (
        denorm(fake_images)
        .permute(0, 2, 3, 1)
        .reshape(8, 8, image_size, image_size, 3)
        .permute(0, 2, 1, 3, 4)
        .reshape(8 * image_size, 8 * image_size, 3)
        .numpy()
    )

    plt.figure(figsize=(8,8))
    plt.axis('off')
    plt.title(f'Generated Images at Epoch {epoch+1}')
    plt.imshow(grid)
    plt.show()


Output hidden; open in https://colab.research.google.com to view.