In [1]:
import torch
from torch import nn, Tensor
import torch.optim as optim
from denoising_diffusion_pytorch import Unet

from utils import get_data_generator, get_data_tensor
from config import load_config

In [2]:
class UNetGenerator(nn.Module):
    """Generator that maps a latent vector to a single-channel image via a U-Net.

    The latent ``z`` is linearly projected to an ``img_size x img_size``
    single-channel map, which is then passed through a
    ``denoising_diffusion_pytorch.Unet`` with a constant all-zero timestep.
    In effect, the diffusion U-Net is reused as a plain image-to-image network.

    Args:
        lat_dim: Dimensionality of the latent vector ``z``.
        img_size: Side length of the square map fed to the U-Net. Defaults to
            32, which is the input size ``Discriminator`` is built for.
            NOTE(review): presumably must be divisible by
            2 ** (len(dim_mults) - 1) = 8 for the U-Net's downsampling path;
            TODO confirm against the installed Unet version.
    """

    def __init__(self, lat_dim: int = 32, img_size: int = 32):
        super().__init__()
        self.latent_dim = lat_dim
        self.img_size = img_size
        # Project z into an (img_size x img_size) single-channel feature map.
        self.project_z = nn.Linear(lat_dim, img_size * img_size)
        self.unet = Unet(
            dim=64,
            dim_mults=(1, 2, 4, 4),
            channels=1,
            # flash_attn=True,  # left disabled, as in the original notebook
        )

    def forward(self, z: Tensor) -> Tensor:
        """Generate a batch of images from latent vectors.

        Args:
            z: Latent batch of shape (B, lat_dim).

        Returns:
            The U-Net output for a (B, 1, img_size, img_size) input. No output
            activation is applied here, so the value range is whatever the
            U-Net produces. NOTE(review): confirm it matches the range of the
            real training data.
        """
        batch = z.shape[0]
        # Reshape the projection into (B, C, H, W) for the U-Net.
        x = self.project_z(z).view(batch, 1, self.img_size, self.img_size)
        # Unet.forward takes a per-sample timestep; a constant 0 makes the
        # network time-independent.
        t = torch.zeros(batch, device=z.device)
        return self.unet(x, t)


class Discriminator(nn.Module):
    """Convolutional real/fake classifier for single-channel 32x32 images.

    Three stride-2 4x4 convolutions halve the spatial size (32 -> 16 -> 8 -> 4).
    A final 4x4 valid convolution reduces it to 1x1, and a sigmoid turns the
    logit into a probability. The layers are stored in ``self.model`` as an
    ``nn.Sequential`` with indices 0..9.
    """

    def __init__(self) -> None:
        super().__init__()
        # First block: no normalisation on the input layer.
        layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # (B, 64, 16, 16)
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two downsampling blocks with BatchNorm: (B, 128, 8, 8) then (B, 256, 4, 4).
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse 4x4 -> 1x1 and squash to a probability: (B, 1, 1, 1).
        layers += [
            nn.Conv2d(256, 1, kernel_size=4, stride=1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img: Tensor) -> Tensor:
        """Return one probability per image, shaped (B, 1) for 32x32 inputs.

        For other input sizes the final map is larger than 1x1, and
        ``view(-1, 1)`` would flatten it into extra rows.
        """
        probs = self.model(img)
        return probs.view(-1, 1)


class GAN(nn.Module):
    """Container that owns a ``UNetGenerator`` and a ``Discriminator``.

    Registering both as submodules means ``.to(device)``, ``.parameters()``
    and ``state_dict()`` cover the pair together. No ``forward`` is defined;
    use ``generate`` / ``discriminate`` (or the submodules directly).
    """

    def __init__(self, lat_dim: int = 32):
        super().__init__()
        # Build (and register) the generator before the discriminator.
        gen = UNetGenerator(lat_dim)
        disc = Discriminator()
        self.generator = gen
        self.discriminator = disc

    def generate(self, z: Tensor) -> Tensor:
        """Map latent vectors ``z`` to images through the generator."""
        images = self.generator(z)
        return images

    def discriminate(self, img: Tensor) -> Tensor:
        """Score images with the discriminator (probability of being real)."""
        scores = self.discriminator(img)
        return scores


In [None]:
# --- Reproducibility ---
# Seed before data loading and model construction so that
# Restart & Run All reproduces the same initial weights and noise.
SEED = 0
torch.manual_seed(SEED)

# --- Data ---
config = load_config()
data = get_data_tensor(config)
data_generator = get_data_generator(data)

# --- Hyperparameters ---
latent_dim = 32
lr = 2e-4
# NOTE(review): batch_size is not passed to get_data_generator; the training
# loop uses whatever batch size the generator yields. TODO: wire it through
# (if the helper supports it) or remove it.
batch_size = 64
total_iters = 10000

# --- Models ---
device = "cuda" if torch.cuda.is_available() else "cpu"
gan = GAN(latent_dim).to(device)
generator, discriminator = gan.generator, gan.discriminator

# --- Optimizers: separate Adam instances for G and D (beta1 = 0.5) ---
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# --- Loss ---
# The discriminator ends in a Sigmoid, so it outputs probabilities -> plain BCE.
criterion = nn.BCELoss()

# Training loop
log_every = 100  # print losses every N iterations instead of every step

for it in range(total_iters):
    real_imgs = next(data_generator).to(device)
    # Use the batch size actually yielded; don't overwrite the batch_size hyperparameter.
    cur_batch_size = real_imgs.size(0)

    # Labels and noise must live on the same device as the models,
    # otherwise BCELoss / the generator fail on CUDA.
    real_labels = torch.ones(cur_batch_size, 1, device=device)
    fake_labels = torch.zeros(cur_batch_size, 1, device=device)

    # --- Train discriminator: push D(real) -> 1 and D(G(z)) -> 0 ---
    optimizer_D.zero_grad()

    real_loss = criterion(discriminator(real_imgs), real_labels)
    noise = torch.randn(cur_batch_size, latent_dim, device=device)
    fake_imgs = generator(noise)
    # detach() keeps the discriminator loss from backpropagating into the generator.
    fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)

    loss_D = real_loss + fake_loss
    loss_D.backward()
    optimizer_D.step()

    # --- Train generator: want fake images classified as real (non-saturating loss) ---
    optimizer_G.zero_grad()
    loss_G = criterion(discriminator(fake_imgs), real_labels)
    # This backward also fills the discriminator's .grad buffers; they are
    # cleared by optimizer_D.zero_grad() at the start of the next iteration.
    loss_G.backward()
    optimizer_G.step()

    # .item() forces a device sync, so only do it when logging.
    if it % log_every == 0 or it == total_iters - 1:
        print(f"Iter {it}: D Loss = {loss_D.item():.4f}, G Loss = {loss_G.item():.4f}")

Iter 0: D Loss = 1.5383, G Loss = 1.1410
