# In [None]:
# Flowers102 dataset. Besides using the Wasserstein loss, this run trains the
# Discriminator 5 times as often as the Generator, which should in theory yield
# a better Discriminator and let the Generator train faster. To keep the
# Generator from being overpowered by the Discriminator, its learning rate is
# 2.5x the Discriminator's.
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Training configuration
IMG_SIZE = 64            # side length (pixels) images are resized/cropped to
LATENT_DIM = 100         # length of the noise vector fed to the generator
BATCH_SIZE = 64          # images per DataLoader batch
EPOCHS = 100             # full passes over the dataset
CRITIC_ITERATIONS = 5    # the generator is updated on every 5th batch index
CLIP_VALUE = 0.01        # discriminator weights are clamped to [-CLIP_VALUE, CLIP_VALUE]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """Map latent noise vectors to 3-channel images with values in [-1, 1].

    A linear layer projects the latent vector onto a 128-channel feature map
    of side ``img_size // 4``. Two 2x upsampling stages, each followed by a
    3x3 convolution, then bring the map to side ``4 * (img_size // 4)``, and a
    final convolution + ``Tanh`` produces the image.

    Args:
        img_size: Side length of the generated images. ``None`` (the default)
            uses the module-level ``IMG_SIZE``.
        latent_dim: Length of the input noise vector. ``None`` (the default)
            uses the module-level ``LATENT_DIM``.
    """

    def __init__(self, img_size=None, latent_dim=None):
        super().__init__()
        # Fall back to the script-wide configuration when no override is given.
        if img_size is None:
            img_size = IMG_SIZE
        if latent_dim is None:
            latent_dim = LATENT_DIM

        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        # The 0.8 below used to be passed positionally, which BatchNorm2d
        # interprets as ``eps`` (the variance stabiliser), not as momentum.
        # It is now passed by keyword so ``eps`` keeps its default.
        # NOTE(review): if 0.8 was ported from a Keras-style momentum, the
        # PyTorch equivalent would be ``momentum=0.2`` - confirm intent.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor whose last dimension matches the latent size the
                generator was built with.

        Returns:
            Tensor of generated images with 3 channels, squashed by ``Tanh``.
        """
        out = self.l1(z)
        # Reshape the flat projection into a (batch, 128, H, W) feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
class Discriminator(nn.Module):
    """Critic that maps a 3-channel image to a single unbounded score.

    Four stride-2 3x3 convolutions shrink the image, after which the feature
    map is flattened and reduced to one value by a linear layer. No sigmoid is
    applied to the output.

    Args:
        img_size: Side length of the (square) input images. ``None`` (the
            default) uses the module-level ``IMG_SIZE``.
    """

    def __init__(self, img_size=None):
        super().__init__()
        # Fall back to the script-wide configuration when no override is given.
        if img_size is None:
            img_size = IMG_SIZE

        # Each 3x3 / stride-2 / padding-1 convolution maps a side of n to
        # ceil(n / 2). Derive the final side instead of hard-coding 4, which
        # is only correct for 64x64 inputs.
        feat_size = img_size
        for _ in range(4):
            feat_size = (feat_size + 1) // 2

        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * feat_size * feat_size, 1)
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Score a batch of images.

        Args:
            img: Batch of 3-channel images of the side length the critic was
                built for.

        Returns:
            Tensor with one score per image (second dimension of size 1).
        """
        validity = self.model(img)
        return validity
# Preprocessing: resize and centre-crop to IMG_SIZE, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# All three Flowers102 splits are concatenated into one shuffled training set.
dataloader = DataLoader(
    ConcatDataset([
        datasets.Flowers102(root='../../data/flowers', split=split, download=True, transform=transform)
        for split in ('train', 'val', 'test')
    ]),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build both networks and move their parameters to the selected device.
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Optimizers: Adam for both networks with betas=(0.5, 0.999). The generator's
# learning rate (5e-4) is 2.5x the discriminator's (2e-4).
# NOTE(review): the original WGAN paper trains with RMSProp and reports
# instability with momentum-based optimizers under weight clipping - confirm
# that Adam is a deliberate choice here.
optimizer_G = optim.Adam(params=generator.parameters(), lr=5e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Training loop
# save_image does not create missing parent directories, so make sure the
# output folder exists before training starts rather than failing at the first
# snapshot after a full epoch of work.
os.makedirs("images", exist_ok=True)

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---- Train Discriminator (runs on every batch) ----
        optimizer_D.zero_grad()

        # Sample noise directly on the target device and generate fakes.
        # detach() keeps the discriminator loss from backpropagating into the
        # generator.
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_imgs = generator(z).detach()

        # Wasserstein discriminator loss: push real scores up, fake scores down.
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator to enforce Lipschitz constraint.
        # Done under no_grad so the in-place clamp is not tracked by autograd.
        with torch.no_grad():
            for p in discriminator.parameters():
                p.clamp_(-CLIP_VALUE, CLIP_VALUE)

        # ---- Train Generator on every CRITIC_ITERATIONS-th batch index ----
        # i == 0 always qualifies, so loss_G and gen_imgs are defined before
        # the first progress print and the first snapshot.
        if i % CRITIC_ITERATIONS == 0:
            optimizer_G.zero_grad()

            # Re-run the generator on the same noise, this time with gradients.
            gen_imgs = generator(z)

            # Generator loss: raise the discriminator's score on generated images.
            loss_G = -torch.mean(discriminator(gen_imgs))
            loss_G.backward()
            optimizer_G.step()

        # Print the progress (loss_G is the value from the latest generator step)
        print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {loss_D.item()}] [G loss: {loss_G.item()}]")

    # Save sample images and generator weights every 10th epoch.
    # NOTE(review): the file names say "stl" although the data is Flowers102;
    # left unchanged so existing output paths keep working.
    if epoch % 10 == 0:
        save_image(gen_imgs.detach()[:25], f"images/{epoch}_stl_wasserstein.png", nrow=5, normalize=True)
        # Save the model
        torch.save(generator.state_dict(), f"saved_model_wasserstein_stl_{epoch}.pth")

# Final snapshot after the last epoch (epoch keeps its last loop value).
save_image(gen_imgs.detach()[:25], f"images/{epoch}_stl_wasserstein.png", nrow=5, normalize=True)
# Save the model
torch.save(generator.state_dict(), f"saved_model_wasserstein_stl_{epoch}.pth")