<a href="https://colab.research.google.com/github/SonnetSaif/vanilla-GAN-from-scratch_PyTorch/blob/main/vanilla_GAN_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4   # common starting learning rate for Adam
z_dim = 64  # length of the generator's input noise vector
# Flattened MNIST image: 28x28 pixels, 1 channel (= 784). This previously
# read 18 * 28, which did not match the 784-wide rows fed to the models.
img_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 50

In [None]:
class Generator(nn.Module):
  """Map a batch of noise vectors (width ``z_dim``) to flat images (width ``img_dim``).

  Architecture: Linear -> LeakyReLU(0.1) -> Linear -> Tanh, so every
  output value lies in [-1, 1].
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    hidden = 256
    layers = [
      nn.Linear(z_dim, hidden),
      nn.LeakyReLU(0.1),
      nn.Linear(hidden, img_dim),
      nn.Tanh(),
    ]
    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    """Run the noise batch ``x`` through the network and return the samples."""
    out = self.gen(x)
    return out

In [None]:
class Discriminator(nn.Module):
  """Score a batch of flat images (width ``img_dim``) as real vs. fake.

  Architecture: Linear -> LeakyReLU(0.1) -> Linear -> Sigmoid. The output
  has shape ``(N, 1)``: one probability in [0, 1] per input row, which is
  what ``nn.BCELoss`` expects.
  """

  def __init__(self, img_dim):
    super().__init__()
    self.disc = nn.Sequential(
      # Input width is the image size. It previously used the global
      # ``z_dim``, which broke on image-sized rows and coupled the class to a global.
      nn.Linear(img_dim, 256),
      nn.LeakyReLU(0.1),
      # A single real/fake score per sample (previously img_dim outputs).
      nn.Linear(256, 1),
      nn.Sigmoid(),
    )

  def forward(self, x):
    """Return the probability that each row of ``x`` is a real image."""
    return self.disc(x)

In [None]:
gen = Generator(z_dim, img_dim).to(device)
disc = Discriminator(img_dim).to(device)
# Named `transform` (not `transforms`) so the torchvision.transforms module
# is not shadowed; rebinding it made re-running this cell raise AttributeError.
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps pixels from [0, 1] to [-1, 1], the range of the
    # generator's Tanh output.
    transforms.Normalize((0.5, ), (0.5, )),
])
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optim_gen = optim.Adam(gen.parameters(), lr=lr)
optim_disc = optim.Adam(disc.parameters(), lr=lr)
# The discriminator ends in Sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()

In [None]:
# TensorBoard: one SummaryWriter per log directory ("fake" and "real")
# under the shared GAN-MNIST root.
step = 0  # logging step counter; presumably advanced by later cells (not visible here)
summaryWriter_fake = SummaryWriter("GAN-MNIST/fake")
summaryWriter_real = SummaryWriter("GAN-MNIST/real")

In [None]:
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    # Flatten each image to one row; inferring the width avoids hard-coding 784.
    real_img = real.view(real.shape[0], -1).to(device)
    batch_size = real.shape[0]

    # Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
    noise = torch.randn(batch_size, z_dim, device=device)
    fake_img = gen(noise)
    disc_real = disc(real_img).view(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): this loss must update only the discriminator. Without it the
    # backward pass also writes gradients into the generator and frees the
    # generator graph that a later generator step on fake_img would reuse.
    disc_fake = disc(fake_img.detach()).view(-1)
    # Fake samples are labelled 0 (previously ones_like, which rewarded the
    # discriminator for calling fakes real).
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc_total = (loss_disc_real + loss_disc_fake) / 2
    disc.zero_grad()
    loss_disc_total.backward()
    optim_disc.step()