In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

In [29]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d): 
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64 
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32x32
            nn.LeakyReLU(0.2),
            self.block(features_d, features_d*2, 4, 2, 1), # 16x16
            self.block(features_d*2, features_d*4, 4, 2, 1), # 8x8
            self.block(features_d*4, features_d*8, 4, 2, 1), # 4x4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1x1
            nn.Sigmoid(),
        )

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels, 
                out_channels, 
                kernel_size, 
                stride, 
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self, x):
        return self.disc(x)


In [20]:
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # input: N x z_dim x 1 x 1
            self.block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self.block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self.block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self.block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1), # 64x64
            nn.Tanh(),
        )
            
    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, 
                out_channels, 
                kernel_size, 
                stride, 
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.net(x)

In [21]:
def initalize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [27]:
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))

    disc = Discriminator(in_channels, features_d=8)
    initalize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"

    gen = Generator(z_dim, in_channels, features_g=8)
    initalize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"

    print("All tests passed")

test()

All tests passed


In [37]:
# Hyperparameters
lr = 0.0002
batch_size = 128
image_size = 64
img_channels = 1  # For grayscale images, use 1; for RGB, use 3.
noise_dim = 100
feature_gen = 64
feature_disc = 64
num_epochs = 5

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator(noise_dim, img_channels, feature_gen).to(device)
disc = Discriminator(img_channels, feature_disc).to(device)
initalize_weights(gen)
initalize_weights(disc)

# Optimizers
optimizer_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss
criterion = nn.BCELoss()


In [33]:
# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(img_channels)], [0.5 for _ in range(img_channels)])  # Normalize to [-1, 1]
])

dataset = datasets.MNIST(root="data/", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [38]:
fixed_noise = torch.randn((32, noise_dim, 1, 1)).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [43]:
from torchvision import datasets, transforms
import torchvision
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn((batch_size, noise_dim, 1, 1)).to(device)
        fake = gen(noise)
        ## Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        # disc_real = disc(real)
        # print(disc_real.shape)
        disc_real = disc(real).reshape(-1)
        # print(disc_real.shape)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # disc_fake = disc(gen(noise))
        # print(disc_fake.shape)
        disc_fake = disc(gen(noise)).reshape(-1)
        # print(disc_fake.shape)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        optimizer_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

        if batch_idx % 100 == 0:
            # print(
            #     f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} \
            #       Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            # )

            with torch.no_grad():
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], 
                    normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], 
                    normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1


In [44]:
# Save model
torch.save(gen.state_dict(), "gen.pth")
torch.save(disc.state_dict(), "disc.pth")