# LP-GAN Implementation
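This notebook implements a Lipschitz Penalty GAN (LP-GAN): a Wasserstein-style GAN whose discriminator is regularized with a penalty on the norm of its input gradients. The generator and discriminator architectures are adapted from the decoder and encoder of the Locatello VAE model.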

In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset

In [None]:
from datasets import get_dataset

# Look up the dataset class registered under "shapes3d" and instantiate it,
# passing selected_factors='all' and no fixed index for non-selected factors.
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(selected_factors='all', not_selected_factors_index_value=None)

# Worker processes are kept alive between epochs (persistent_workers) and
# batches are placed in pinned memory (pin_memory).
num_workers_3dshapes = 4
loader_options = {
    "batch_size": 64,
    "shuffle": True,
    "num_workers": num_workers_3dshapes,
    "pin_memory": True,
    "persistent_workers": True,
}
data_loader = DataLoader(shapes3d_dataset, **loader_options)

In [None]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm

# Generator (based on Locatello's Decoder)
class Generator(nn.Module):
    """Map latent vectors to images with two linear and four transposed-conv layers.

    Args:
        img_size: Sequence whose first entry is used as the number of output
            channels (presumably ``(channels, height, width)`` -- confirm
            against callers). The spatial output size does not depend on it:
            the 4x4 feature map is doubled four times, giving 64x64.
        latent_dim: Size of the latent vector fed to ``forward``.
    """

    def __init__(self, img_size, latent_dim=10):
        super().__init__()
        self.img_size = img_size
        self.latent_dim = latent_dim

        kernel = 4
        out_channels = self.img_size[0]
        # (channels, height, width) of the feature map handed to the first
        # transposed convolution.
        self.reshape = (64, kernel, kernel)

        # Fully connected stem: latent_dim -> 256 -> 64 * 4 * 4.
        self.lin1 = nn.Linear(latent_dim, 256)
        self.lin2 = nn.Linear(256, np.prod(self.reshape))

        # Kernel 4 / stride 2 / padding 1 doubles the spatial size per layer:
        # 4 -> 8 -> 16 -> 32 -> 64.
        self.convT1 = nn.ConvTranspose2d(64, 64, kernel, stride=2, padding=1)
        self.convT2 = nn.ConvTranspose2d(64, 32, kernel, stride=2, padding=1)
        self.convT3 = nn.ConvTranspose2d(32, 32, kernel, stride=2, padding=1)
        self.convT4 = nn.ConvTranspose2d(32, out_channels, kernel, stride=2, padding=1)

    def forward(self, z):
        """Return images in (0, 1) of shape ``(batch, img_size[0], 64, 64)`` for latents ``z``."""
        act = torch.nn.functional.leaky_relu

        hidden = act(self.lin1(z))
        hidden = act(self.lin2(hidden))
        hidden = hidden.view(z.size(0), *self.reshape)

        for upsample in (self.convT1, self.convT2, self.convT3):
            hidden = act(upsample(hidden))

        # Sigmoid squashes the final layer into the image value range.
        return torch.sigmoid(self.convT4(hidden))

# Discriminator (based on Locatello's Encoder, with Lipschitz Penalty)
class Discriminator(nn.Module):
    """Score 64x64 images with four strided convolutions and two linear layers.

    Args:
        img_size: Sequence whose first entry is used as the number of input
            channels and whose last two entries must both equal 64.

    Raises:
        AssertionError: If the last two entries of ``img_size`` are not 64.
    """

    def __init__(self, img_size):
        super().__init__()
        self.img_size = img_size

        kernel = 4
        in_channels = self.img_size[0]

        height, width = self.img_size[-2], self.img_size[-1]
        assert height == width == 64, "This architecture requires 64x64 inputs."

        # Kernel 4 / stride 2 / padding 1 halves the spatial size per layer:
        # 64 -> 32 -> 16 -> 8 -> 4.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel, stride=2, padding=1)

        # 64 channels on a 4x4 map once the four convolutions have run.
        flat_features = 64 * (64 // 2 ** 4) ** 2
        self.lin = nn.Linear(flat_features, 256)
        # Single unbounded score per image (no sigmoid is applied).
        self.output_layer = nn.Linear(256, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for the image batch ``x``."""
        act = torch.nn.functional.leaky_relu
        n_samples = x.size(0)

        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = act(conv(features))

        features = features.view((n_samples, -1))
        features = act(self.lin(features))
        return self.output_layer(features)

# Lipschitz Penalty function (for LP-GAN)
def calculate_lipschitz_penalty(discriminator, real_images, fake_images, device,
                                *, one_sided=True):
    """Return the Lipschitz penalty of ``discriminator`` on real/fake interpolates.

    Each sample is a random convex combination of a real and a fake image.
    The discriminator's gradient with respect to that sample is computed and
    its per-sample L2 norm ``g`` is penalized.

    With ``one_sided=True`` (the default) the LP-GAN penalty
    ``mean(max(0, g - 1) ** 2)`` is used, so only gradient norms above 1 are
    penalized. With ``one_sided=False`` the two-sided WGAN-GP penalty
    ``mean((g - 1) ** 2)`` is used, which also pushes norms below 1 up to 1.

    Args:
        discriminator: Callable mapping a batch of images to a batch of scores.
        real_images: Batch of real samples; the first dimension is the batch.
        fake_images: Batch of generated samples with the same shape as
            ``real_images``.
        device: Device on which the interpolation coefficients are created.
        one_sided: Select the one-sided (LP) or two-sided (GP) penalty.

    Returns:
        Scalar tensor holding the mean penalty over the batch. The gradient
        computation is kept in the autograd graph (``create_graph=True``) so
        the penalty can be backpropagated into the discriminator parameters.
    """
    n_samples = real_images.size(0)

    # One mixing coefficient per sample, broadcast over every remaining
    # dimension (works for any input rank, not only 4-D image batches).
    alpha_shape = (n_samples,) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolated_images = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

    interpolated_output = discriminator(interpolated_images)

    gradients = torch.autograd.grad(
        outputs=interpolated_output,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(interpolated_output),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (rather than view) also handles non-contiguous gradients.
    grad_norm = gradients.reshape(n_samples, -1).norm(2, dim=1)
    excess = grad_norm - 1
    if one_sided:
        # Lipschitz penalty: gradient norms at or below 1 cost nothing.
        excess = excess.clamp(min=0)
    return excess.pow(2).mean()

In [None]:
# Build both networks for 3-channel 64x64 images and move them to the device.
img_size = (3, 64, 64)  # Assuming RGB images of size 64x64
gen = Generator(img_size, latent_dim=10).to(device)
disc = Discriminator(img_size).to(device)

# Both networks are optimized with Adam using identical hyperparameters.
adam_options = dict(lr=0.0001, betas=(0.5, 0.9))
gen_optimizer = torch.optim.Adam(gen.parameters(), **adam_options)
disc_optimizer = torch.optim.Adam(disc.parameters(), **adam_options)

# Training hyperparameters.
num_epochs = 5
critic_iterations = 10  # discriminator updates per generator update
lambda_lp = 10  # weight of the Lipschitz penalty in the discriminator loss

for epoch in range(num_epochs):
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (real_images, _) in enumerate(progress_bar):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # --- Discriminator: several updates on the same real batch, each
        # with freshly sampled fakes. ---
        for _ in range(critic_iterations):
            disc_optimizer.zero_grad()

            real_output = disc(real_images)

            # Fakes are detached so this step does not backpropagate into
            # the generator.
            z = torch.randn(batch_size, gen.latent_dim).to(device)
            fake_images = gen(z).detach()
            fake_output = disc(fake_images)

            lp = calculate_lipschitz_penalty(disc, real_images, fake_images, device)

            # Lower scores on fakes, higher scores on reals, plus the
            # weighted penalty term.
            d_loss = fake_output.mean() - real_output.mean() + lambda_lp * lp
            d_loss.backward()
            disc_optimizer.step()

        # --- Generator: one update that raises the discriminator's score on
        # new fakes. ---
        gen_optimizer.zero_grad()
        z = torch.randn(batch_size, gen.latent_dim).to(device)
        fake_images = gen(z)
        output = disc(fake_images)
        g_loss = -output.mean()
        g_loss.backward()
        gen_optimizer.step()

        # Refresh the progress-bar readout on every 100th batch.
        if i % 100 == 99:
            progress_bar.set_postfix(D_Loss=d_loss.item(), G_Loss=g_loss.item())

print("Training complete!")

In [None]:
import matplotlib.pyplot as plt

# Sample latent vectors and decode them without tracking gradients.
num_images_to_show = 16
with torch.no_grad():
    z_sample = torch.randn(num_images_to_show, gen.latent_dim).to(device)
    generated_images = gen(z_sample).cpu()

# Lay the samples out on a 4x4 grid, one image per axis.
fig = plt.figure(figsize=(8, 8))
for position, image in enumerate(generated_images, start=1):
    ax = fig.add_subplot(4, 4, position)
    # Move the channel axis last, as imshow expects (H, W, C).
    ax.imshow(image.permute(1, 2, 0).numpy())
    ax.axis('off')
plt.suptitle("Generated Images")
plt.show()