In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.nn.utils import spectral_norm
# Hyperparameters
batch_size = 64
image_size = 64
nc = 3  # Number of channels in the training images. For STL-10, it's 3 (RGB)
nz = 100  # Size of z latent vector (i.e. size of generator input)
ngf = 64  # Size of feature maps in generator
ndf = 64  # Size of feature maps in discriminator
num_epochs = 5
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
lambda_gp = 10
n_generator = 5  # Number of generator updates per discriminator update

# Dataset preparation
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.STL10(root='../../data/stl10', split='train+unlabeled', download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# Initialize models
netG = Generator().to(device)
netD = Discriminator().to(device)

# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

def compute_0_gp(D, real_samples, fake_samples, lambda_gp):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(d_interpolates.size(), device=real_samples.device)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradient_penalty = ((gradients.norm(2, dim=1) ** 2).mean()) * lambda_gp
    return gradient_penalty

# Optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))


# Training Loop
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # Update Discriminator
        netD.zero_grad()
        real_images = data[0].to(device)
        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = netG(noise)

        real_output = netD(real_images)
        fake_output = netD(fake_images.detach())
        d_loss_real = torch.mean(real_output)
        d_loss_fake = torch.mean(fake_output)
        gradient_penalty = compute_0_gp(netD, real_images, fake_images, lambda_gp)

        d_loss = d_loss_fake - d_loss_real + gradient_penalty
        d_loss.backward()
        optimizerD.step()

        # Update Generator n_generator times for each discriminator update
        for _ in range(n_generator):
            netG.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = netG(noise)
            output = netD(fake_images)
            g_loss = -torch.mean(output)
            g_loss.backward()
            optimizerG.step()

        # Print losses occasionally
        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}')

        # Save fake images occasionally
        if i % 1000 == 0:
            vutils.save_image(fake_images.data[:64], f'images/fake_samples_epoch_{epoch:03d}_{i:04d}.png', normalize=True)
            torch.save(netG.state_dict(), f'saved_model/generator_{epoch}.pth')

    # Save the final model
torch.save(netG.state_dict(), 'generator.pth')
torch.save(netD.state_dict(), 'discriminator.pth')

import matplotlib.pyplot as plt

# Load the models
netG.load_state_dict(torch.load('generator.pth'))
netG.eval()

# Generate samples
noise = torch.randn(64, nz, 1, 1, device=device)
fake_images = netG(noise)

# Display the generated images
grid = vutils.make_grid(fake_images, padding=2, normalize=True)
plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))
plt.show()

Files already downloaded and verified
[0/5][0/1641] Loss_D: 0.0977 Loss_G: -0.4770
[0/5][100/1641] Loss_D: 0.0000 Loss_G: -1.0000
[0/5][200/1641] Loss_D: 0.0000 Loss_G: -1.0000


KeyboardInterrupt: 