In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Define constants
IMG_SIZE = 64
LATENT_DIM = 100
BATCH_SIZE = 64
EPOCHS_CIFAR10 = 100
EPOCHS_STL10 = 100
LEARNING_RATE = 0.0002
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = IMG_SIZE // 4
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1)
        )

    def forward(self, img):
        validity = self.model(img)
        return validity
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

trainset_cifar10 = datasets.CIFAR10(root='../../data/cifar10', train=True, download=True, transform=transform)
trainloader_cifar10 = DataLoader(trainset_cifar10, batch_size=BATCH_SIZE, shuffle=True)

trainset_stl10 = datasets.STL10(root='../../data/stl10', split='train+unlabeled', download=True, transform=transform)
trainloader_stl10 = DataLoader(trainset_stl10, batch_size=BATCH_SIZE, shuffle=True)
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    alpha = torch.randn((real_samples.size(0), 1, 1, 1), device=device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.size(0), 1, requires_grad=False, device=device)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def train_wgan_gp(generator, discriminator, dataloader, epochs, save_interval, phase):
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0005, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            for _ in range(CRITIC_ITERATIONS):
                optimizer_D.zero_grad()

                z = torch.randn(batch_size, LATENT_DIM).to(device)
                fake_imgs = generator(z)

                real_validity = discriminator(real_imgs)
                fake_validity = discriminator(fake_imgs.detach())
                gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + LAMBDA_GP * gradient_penalty

                d_loss.backward()
                optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            gen_imgs = generator(z)
            fake_validity = discriminator(gen_imgs)
            g_loss = -torch.mean(fake_validity)

            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

        if epoch % save_interval == 0:
            save_image(gen_imgs.data[:25], f"images/{phase}_{epoch}_2.png", nrow=5, normalize=True)
            torch.save(generator.state_dict(), f"saved_model/generator2_{phase}_{epoch}.pt")
# Instantiate models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Training on CIFAR-10
train_wgan_gp(generator, discriminator, trainloader_cifar10, EPOCHS_CIFAR10, 5, 'cifar10')

# Fine-tuning on STL-10
train_wgan_gp(generator, discriminator, trainloader_stl10, EPOCHS_STL10, 5, 'stl10')


Files already downloaded and verified
Files already downloaded and verified
[Epoch 1/100] [Batch 0/782] [D loss: -3.5688352584838867] [G loss: 4.271856307983398]
[Epoch 1/100] [Batch 100/782] [D loss: -575.4093627929688] [G loss: 282.6064453125]
[Epoch 1/100] [Batch 200/782] [D loss: 19.078365325927734] [G loss: -456.93023681640625]
[Epoch 1/100] [Batch 300/782] [D loss: -10.034337997436523] [G loss: -462.74908447265625]
[Epoch 1/100] [Batch 400/782] [D loss: -16.070066452026367] [G loss: -433.5330810546875]
[Epoch 1/100] [Batch 500/782] [D loss: -101.57219696044922] [G loss: -313.2353210449219]
[Epoch 1/100] [Batch 600/782] [D loss: -620.7971801757812] [G loss: 137.49774169921875]
[Epoch 1/100] [Batch 700/782] [D loss: -23.38006591796875] [G loss: -465.18389892578125]
[Epoch 2/100] [Batch 0/782] [D loss: -105.55667114257812] [G loss: -395.8480224609375]
[Epoch 2/100] [Batch 100/782] [D loss: -7.579474449157715] [G loss: -439.79876708984375]
[Epoch 2/100] [Batch 200/782] [D loss: -14.5

KeyboardInterrupt: 