In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [3]:
class Discriminator(nn.Module):
  """Two-layer fully connected discriminator.

  Maps each flattened input sample to a single value in (0, 1) produced by a
  final sigmoid.

  Args:
    in_features: number of features in each flattened input sample.
  """

  def __init__(self, in_features):
    super().__init__()
    hidden_units = 256
    layers = [
        nn.Linear(in_features, hidden_units),
        nn.LeakyReLU(negative_slope=0.2),
        nn.Linear(hidden_units, 1),
        nn.Sigmoid(),
    ]
    # Keep the attribute name `disc` so state_dict keys stay `disc.<idx>.*`.
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Run `x` through the layer stack and return one score per row."""
    scores = self.disc(x)
    return scores

In [14]:
class Generator(nn.Module):
  """Two-layer fully connected generator.

  Maps a latent vector to a flattened sample whose values lie in [-1, 1]
  because of the final tanh.

  Args:
    z_dim: size of the latent (noise) vector.
    img_dim: number of features in each generated flattened sample.
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    hidden_units = 256
    layers = [
        nn.Linear(z_dim, hidden_units),
        nn.LeakyReLU(negative_slope=0.2),
        nn.Linear(hidden_units, img_dim),
        nn.Tanh(),
    ]
    # Keep the attribute name `gen` so state_dict keys stay `gen.<idx>.*`.
    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    """Run latent batch `x` through the layer stack and return the samples."""
    samples = self.gen(x)
    return samples

In [15]:
# Run configuration: everything a reader might want to tune lives here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32       # samples per training batch
img_dim = 28 * 28     # flattened size of one 28x28 image
epochs = 50           # full passes over the dataset
z_dim = 64            # size of the generator's latent noise vector
# NOTE(review): 2e-3 is 10x the 2e-4 commonly paired with beta1=0.5 — confirm intended.
lr = 2e-3

# Seed PyTorch's global RNG so weight init, noise sampling and shuffling are
# repeatable across Restart-and-Run-All.
seed = 42
torch.manual_seed(seed)


In [16]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
loss_fn = nn.BCELoss().to(device)

# The discriminator consumes flattened images, so its input width is `img_dim`
# (previously a hard-coded 784 that could drift from the generator's output size).
disc = Discriminator(in_features=img_dim).to(device)
d_optim = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

# The generator maps `z_dim` noise to `img_dim` flattened pixels.
gen = Generator(z_dim=z_dim, img_dim=img_dim).to(device)
g_optim = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))

In [30]:
# A fixed latent batch, reused at every logging step so the logged generator
# samples are comparable across epochs.
fixed_noise = torch.randn(batch_size, z_dim).to(device=device)

In [31]:
# Convert images to tensors, then normalize the single channel with
# mean 0.5 / std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [32]:
# MNIST dataset (downloaded into data/ if missing), preprocessed by `transform`.
data = datasets.MNIST(root='data/', transform=transform, download=True)
# Use the configured `batch_size` rather than a hard-coded 32 so that changing
# the config actually changes the loader.
loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True)

In [33]:
# Counter used as the TensorBoard global_step for logged image grids.
step = 0
# Separate TensorBoard runs for generated and real image grids.
writer_fake = SummaryWriter(log_dir='runs/GANS_MNIST/fake')
writer_real = SummaryWriter(log_dir='runs/GANS_MNIST/real')

In [35]:
# Alternating GAN training: for every batch, update the discriminator on
# real + generated samples, then update the generator to fool the freshly
# updated discriminator. On the first batch of each epoch, log the losses and
# image grids to TensorBoard.
for epoch in range(epochs):
  for batch_idx, (real_imgs, _) in enumerate(loader):

    # Flatten each image to a vector of `img_dim` pixels.
    real = real_imgs.view(-1, img_dim).to(device)

    # Loop-local name: do not overwrite the `batch_size` config value.
    cur_batch_size = real.shape[0]

    noise = torch.randn(cur_batch_size, z_dim).to(device)
    fake = gen(noise)

    # ---- Discriminator step ----
    disc.train()

    d_optim.zero_grad()

    disc_real = disc(real).view(-1)
    Dloss_real = loss_fn(disc_real, torch.ones_like(disc_real))

    # Detach so the discriminator loss does not backpropagate into the
    # generator; this removes the need for retain_graph=True.
    disc_fake = disc(fake.detach()).view(-1)
    Dloss_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))

    Dloss = (Dloss_real + Dloss_fake) / 2

    Dloss.backward()
    d_optim.step()

    # ---- Generator step ----
    gen.train()

    g_optim.zero_grad()

    # Re-score the (non-detached) fakes with the updated discriminator; the
    # generator is trained to make them be classified as real (target = 1).
    output = disc(fake).view(-1)
    Gloss = loss_fn(output, torch.ones_like(output))

    Gloss.backward()
    g_optim.step()

    if batch_idx == 0:
      print(f'Epoch: {epoch}/{epochs}, Disc Loss: {Dloss}, Gen Loss: {Gloss}')

      with torch.inference_mode():
        # Distinct names so neither the training `fake` batch nor the `data`
        # dataset object is rebound by the logging code.
        fake_imgs = gen(fixed_noise).view(-1, 1, 28, 28)
        real_grid_imgs = real.view(-1, 1, 28, 28)

        img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
        img_grid_real = torchvision.utils.make_grid(real_grid_imgs, normalize=True)

        writer_fake.add_image('Mnist Fake Images', img_grid_fake, global_step=step)
        writer_real.add_image('Mnist Real Images', img_grid_real, global_step=step)

        step += 1


Epoch: 0/50, Disc Loss: 0.7623282670974731, Gen Loss: 0.4016761779785156
Epoch: 1/50, Disc Loss: 0.6630624532699585, Gen Loss: 0.8851467370986938
Epoch: 2/50, Disc Loss: 0.6129927635192871, Gen Loss: 0.6792729496955872
Epoch: 3/50, Disc Loss: 0.5401791334152222, Gen Loss: 1.0731406211853027
Epoch: 4/50, Disc Loss: 0.6210968494415283, Gen Loss: 1.8617959022521973
Epoch: 5/50, Disc Loss: 0.6433295011520386, Gen Loss: 0.9691511392593384
Epoch: 6/50, Disc Loss: 0.6595112681388855, Gen Loss: 0.9242030382156372
Epoch: 7/50, Disc Loss: 0.6463932394981384, Gen Loss: 1.5764861106872559
Epoch: 8/50, Disc Loss: 0.7339122295379639, Gen Loss: 1.8166894912719727
Epoch: 9/50, Disc Loss: 0.5658252239227295, Gen Loss: 1.343628168106079
Epoch: 10/50, Disc Loss: 0.567017674446106, Gen Loss: 1.1390042304992676
Epoch: 11/50, Disc Loss: 0.6110943555831909, Gen Loss: 0.6938178539276123
Epoch: 12/50, Disc Loss: 0.5839024782180786, Gen Loss: 1.2852401733398438
Epoch: 13/50, Disc Loss: 0.6414852142333984, Gen L

In [38]:
!tensorboard --logdir runs/GANS_MNIST/

^C
