<a href="https://colab.research.google.com/github/151ali/lr-pytorch/blob/main/9_Simple_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class Discriminator(nn.Module):
    """Small MLP that scores flattened images as real (near 1) or fake (near 0).

    Args:
        img_dim: Number of features in each flattened input image.

    Layers: Linear(img_dim, 128) -> LeakyReLU(0.1) -> Linear(128, 1) -> Sigmoid,
    so every score is squashed into [0, 1] for use with ``nn.BCELoss``.
    """

    def __init__(self, img_dim):
        super().__init__()
        # Attribute name and layer order are kept as-is: together they form the
        # state_dict keys ("disc.0.weight", "disc.2.bias", ...).
        layers = [
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x``; an input of shape (N, img_dim) yields shape (N, 1)."""
        return self.disc(x)

In [None]:
class Generator(nn.Module):
    """Small MLP that maps latent noise vectors to flattened images.

    Args:
        z_dim: Length of each latent noise vector.
        img_dim: Number of features in each generated (flattened) image.

    Layers: Linear(z_dim, 256) -> LeakyReLU(0.1) -> Linear(256, img_dim) -> Tanh,
    so generated pixel values lie in [-1, 1].
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Attribute name and layer order define the state_dict keys; keep them.
        layers = [
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from ``x``; shape (N, z_dim) yields (N, img_dim)."""
        return self.gen(x)

In [None]:
def imshow(img):
    """Display an image tensor with matplotlib after undoing [-1, 1] scaling.

    Assumes ``img`` is a channels-first (C, H, W) tensor that lives on the CPU
    and does not require grad (``.numpy()`` is called on it). Presumably it was
    produced by the Normalize(0.5, 0.5) pipeline used in this notebook.
    """
    # Invert Normalize(mean=0.5, std=0.5): x -> x / 2 + 0.5 maps [-1, 1] to [0, 1].
    pixels = (img / 2 + 0.5).numpy()
    # matplotlib expects channels-last (H, W, C).
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Report which device the models and tensors will be moved to.
print(f"{device}")

cuda


In [None]:
# Training hyperparameters shared by the cells below.
lr = 3e-4  # Adam learning rate for both networks
z_dim = 64  # length of the generator's latent noise vector
image_size = 1 * 28 * 28  # flattened MNIST image: channels * height * width
batch_size = 16

In [None]:
disc = Discriminator(image_size).to(device)
gen = Generator(z_dim, image_size).to(device)

# One fixed latent batch, reused at every logging step so the generated
# samples written to TensorBoard are comparable across epochs.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], matching the generator's Tanh output range.
# NOTE: this assignment rebinds the name `transforms` (originally the
# torchvision.transforms module) to the pipeline object, because later code
# passes `transform=transforms` to the dataset. Going through
# `torchvision.transforms` here keeps the cell safe to re-run: before, a
# second run called `.Compose` on the Compose object and raised AttributeError.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # Single-channel (grayscale) mean/std given as 1-tuples, the documented form.
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# MNIST, downloaded into ./datasets/ on first use and preprocessed by the
# `transforms` pipeline defined earlier.
dataset = datasets.MNIST(root="datasets/", transform=transforms, download=True)
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# One Adam optimizer per network, both using the same learning rate.
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid scores.
criterion = nn.BCELoss()

In [None]:
# Separate TensorBoard run directories for generated and real sample grids.
writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")

In [None]:
# Alternating GAN training: one discriminator update, then one generator
# update per batch. Once per epoch (first batch) the losses are printed and
# image grids of fixed-noise samples and the current real batch are logged.
num_epochs = 300
step = 0  # TensorBoard global_step; advanced once per logged snapshot
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector to match the Linear input layer.
        real = real.view(-1, image_size).to(device)
        # Local name, so the global `batch_size` hyperparameter is not clobbered;
        # the last batch can be smaller than batch_size unless drop_last=True.
        cur_batch_size = real.shape[0]

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach(): the discriminator update must not backpropagate into the
        # generator. This also removes the need for retain_graph=True, since
        # the generator part of `fake`'s graph is untouched by this backward.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # ---- Generator step: push D(G(z)) -> 1 with the updated D ----
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            # .item() prints plain numbers instead of tensor reprs with grad_fn.
            print(
                f"epoch {epoch}/{num_epochs}  "
                f"loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                # Grid the reshaped (N, 1, 28, 28) batch; passing the flat (N, 784)
                # `real` tensor produced one 784-pixel-wide strip, not digits.
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                writer_fake.add_image("MNIST fake images", img_grid_fake, global_step=step)
                writer_real.add_image("MNIST real images", img_grid_real, global_step=step)
            step += 1