In [1]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
class discriminator(nn.Module):
  def __init__(self, img_channels, features_d):
    super(discriminator, self).__init__()
    self.disc = nn.Sequential(
        nn.Conv2d(in_channels=img_channels,
                  out_channels=features_d,
                  kernel_size=4,
                  stride=2,
                  padding=1),
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d * 2, 4, 2, 1),
        self._block(features_d * 2, features_d * 4, 4, 2, 1),
        self._block(features_d * 4, features_d * 8, 4, 2, 1),
        nn.Conv2d(features_d * 8, 1, 4, 2, 0),
        nn.Sigmoid()
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=kernel_size,
                  stride=stride,
                  padding=padding,
                  bias=False),
        nn.BatchNorm2d(num_features=out_channels),
        nn.LeakyReLU(0.2),
    )
  def forward(self, x):
    return self.disc(x)

In [3]:
class Generator(nn.Module):
  def __init__(self, z_dim, img_channels,features_g):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        self._block(z_dim, features_g * 16, 4, 1, 0),
        self._block(features_g * 16, features_g * 8, 4, 2, 1),
        self._block(features_g * 8, features_g * 4, 4, 2, 1),
        self._block(features_g * 4, features_g * 2, 4, 2, 1),
        nn.ConvTranspose2d(features_g * 2, img_channels, 4, 2, 1),
        nn.Tanh()
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           stride=stride,
                           padding=padding,
                           bias=False),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU()
    )
  def forward(self, x):
    return self.gen(x)

In [4]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)

In [5]:
def test():
  N, in_channels, H, W = 8, 3, 64, 64
  z_dim = 100
  x = torch.randn((N, in_channels, H, W))
  disc = discriminator(in_channels, 8)
  initialize_weights(disc)
  assert disc(x).shape == (N, 1, 1, 1), f"Disc Failed, Shape: {disc(x).shape}"
  noise = torch.randn((N, z_dim, 1, 1))
  gen = Generator(z_dim, in_channels, 8)
  initialize_weights(gen)
  assert gen(noise).shape == (N, in_channels, H, W), f'Gen Failed, Shape: {gen(noise).shape}'



In [6]:
test()

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
EPOCHS = 5
lr = 2e-3
IMG_SIZE = 64
Z_DIM = 100
IMG_CHANNELS = 1
FEATURES_D = 64
FEATURES_G = 64

In [8]:
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(IMG_CHANNELS)], [0.5 for _ in range(IMG_CHANNELS)])
])

In [9]:
dataset = datasets.MNIST(root='datasets/', train=True, transform=transform, download=True)
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_G).to(device)
initialize_weights(gen)
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

disc = discriminator(IMG_CHANNELS, FEATURES_D).to(device)
initialize_weights(disc)

In [10]:
opt_D = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))

In [11]:
loss_fn = nn.BCELoss()

In [12]:
writer_real = SummaryWriter('logs/real')
writer_fake = SummaryWriter('logs/fake')
step = 0

In [13]:
gen.train()
disc.train()

discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [14]:
for epoch in range(EPOCHS):
  for batch_idx, (real, _) in enumerate(loader):

    real = real.to(device)
    noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)

    disc_real = disc(real).reshape(-1)
    fake = gen(noise)
    disc_fake = disc(fake).reshape(-1)

    disc.zero_grad()

    Dloss_real = loss_fn(disc_real, torch.ones_like(disc_real))
    Dloss_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))

    Dloss = (Dloss_real + Dloss_fake) / 2
    Dloss.backward(retain_graph=True)
    opt_D.step()

    gen.zero_grad()

    output = disc(fake).reshape(-1)
    Gloss = loss_fn(output, torch.ones_like(output))

    Gloss.backward()
    opt_G.step()

    if batch_idx == 0:
      print(f"Epoch: {epoch}, Dloss: {Dloss:.4f}, Gloss: {Gloss:.4f}")

      with torch.inference_mode():

        fake = gen(fixed_noise)

        image_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        image_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

        writer_real.add_image('Real', image_grid_real, global_step=step)
        writer_fake.add_image('Fake', image_grid_fake, global_step=step)

        step += 1


Epoch: 0, Dloss: 0.6927, Gloss: 0.8946


KeyboardInterrupt: 