In [1]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()

        # in the paper they don't use the batchnorm at the earlier layers
        self.net = nn.Sequential(
            # Input: N x channel_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, kernel_size=4, stride=2, padding=1), # 16 x 16
            self._block(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1), # 8 x 8
            self._block(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1), # 4 x 4
            nn.Conv2d(features_d*8, out_channels=1, kernel_size=4, stride=2, padding=0), # the output is only 1 channel, N x 1 x 1 x 1
            nn.Sigmoid() # N x 1 x 1 x 1
        )


    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False # bias is set to False because we want to use BatchNorm
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self, x):
        return self.net(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channel_img, feature_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, feature_g*16, kernel_size=4, stride=1, padding=0), # N x feature_g*16 x 4 x 4
            self._block(feature_g*16, feature_g*8, kernel_size=4, stride=2, padding=1), # 8 x 8
            self._block(feature_g*8, feature_g*4, kernel_size=4, stride=2, padding=1), # 16 x 16
            self._block(feature_g*4, feature_g*2, kernel_size=4, stride=2, padding=1), # 32 x 32
            nn.ConvTranspose2d( # for input to the discriminator, 64 x 64
                feature_g*2,
                channel_img,
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.Tanh() # good for image generation output ranges: [-1, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(), # why ReLU? because based on paper
        )
    
    def forward(self, x):
        return self.net(x)
    
def initialize_weights(models):
    for m in models.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

        if isinstance(m, nn.ConvTranspose2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 3, 3, 64, 64
    z_dim = 100
    X = torch.randn((N, in_channels, H, W))
    z = torch.randn((N, z_dim, 1, 1))

    disc = Discriminator(channels_img=in_channels, features_d=8)
    initialize_weights(disc)
    assert disc(X).shape == (N, 1, 1, 1)

    gen = Generator(z_dim=z_dim, channel_img=in_channels, feature_g=8)
    initialize_weights(gen)
    assert gen(z).shape == (N, in_channels, H, W)

    print("Success")

In [2]:
test()

Success


In [3]:
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [4]:
# Hyperparamters, etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 1
FEATURES_DISC = 64
FEATURES_GEN = 64

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)], 
        [0.5 for _ in range(CHANNELS_IMG)]
    ),
])

dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model Initialization
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

criterion = nn.BCELoss()

fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)

        # generate noise
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
        fake = gen(noise)

        ### Train Discriminator max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1) # N
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) * 0.5

        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator min log(1 - D(G(z))) <---> max log(D(G(z)))
        fake = gen(noise)
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )

                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_real.add_image("Fake", img_grid_fake, global_step=step)

        step += 1

Epoch [0/1] Batch 0/469 Loss D: 0.6881, Loss G: 0.7881
Epoch [0/1] Batch 100/469 Loss D: 0.0145, Loss G: 4.1124
Epoch [0/1] Batch 200/469 Loss D: 0.0991, Loss G: 2.8517
Epoch [0/1] Batch 300/469 Loss D: 0.5433, Loss G: 0.8265
Epoch [0/1] Batch 400/469 Loss D: 0.5789, Loss G: 0.9421
