<a href="https://colab.research.google.com/github/SonnetSaif/vanilla-GAN-from-scratch_PyTorch/blob/main/vanilla_GAN_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4   # 3e-4 is a common starting learning rate for Adam
z_dim = 64
img_dim = 28 * 28 * 1   # flattened MNIST image: 28x28 pixels, 1 channel = 784
batch_size = 32
num_epochs = 50
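Both networks are built from `nn.Linear` layers, so each MNIST image (1 channel, 28x28 pixels) has to be flattened to a vector of length 28 · 28 · 1 = 784 before it reaches the discriminator. A minimal sketch of that reshape, assuming a random tensor as a stand-in for one batch from the `DataLoader`:

```python
import torch

# MNIST images are single-channel 28x28 grayscale
channels, height, width = 1, 28, 28
img_dim = height * width * channels  # 784 input features for the linear layers
batch_size = 32

# Random stand-in for one DataLoader batch, shape (batch, C, H, W)
real = torch.rand(batch_size, channels, height, width)

# Flatten every image into a row vector so nn.Linear(img_dim, ...) can consume it
flat = real.view(-1, img_dim)
print(flat.shape)  # torch.Size([32, 784])
```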

In [4]:
class Generator(nn.Module):
  def __init__(self, z_dim, img_dim):
    super().__init__()
    self.gen = nn.Sequential(
      nn.Linear(z_dim, 256),
      nn.LeakyReLU(0.1),
      nn.Linear(256, img_dim),
      nn.Tanh()
    )

  def forward(self, x):
    return self.gen(x)
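A quick sanity check for the generator: a batch of Gaussian noise vectors of length `z_dim` should come out as a batch of flattened images of length `img_dim`, with every value in [-1, 1] because of the final `Tanh`. The sketch repeats the class so it runs on its own and assumes `z_dim = 64` and `img_dim = 784` (a 28x28 single-channel image); the batch size of 16 is arbitrary.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)

z_dim, img_dim = 64, 28 * 28 * 1
gen = Generator(z_dim, img_dim)

noise = torch.randn(16, z_dim)  # 16 latent vectors sampled from N(0, I)
with torch.no_grad():
    fake = gen(noise)

print(fake.shape)  # torch.Size([16, 784])
print(fake.min().item() >= -1.0, fake.max().item() <= 1.0)  # True True
```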

In [5]:
class Discriminator(nn.Module):
  def __init__(self, img_dim):
    super().__init__()
    self.disc = nn.Sequential(
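      nn.Linear(img_dim, 256),   # input: a flattened image, not a noise vector
      nn.LeakyReLU(0.1),
      nn.Linear(256, 1),         # output: a single real/fake score per image
      nn.Sigmoid()               # score -> probability in (0, 1)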
    )

  def forward(self, x):
    return self.disc(x)
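The discriminator does the reverse: it takes a flattened image of length `img_dim` and returns one probability that the image is real. A minimal shape check, written with the input layer sized `img_dim` and a one-unit output layer (the sizes a binary real/fake classifier needs). The class is repeated so the snippet runs on its own, and the input batch is random data scaled to [-1, 1] as a stand-in for real images.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 256),  # input size must match the flattened image
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),        # one score per image
            nn.Sigmoid(),             # score -> probability in (0, 1)
        )

    def forward(self, x):
        return self.disc(x)

img_dim = 28 * 28 * 1
disc = Discriminator(img_dim)

images = torch.rand(8, img_dim) * 2 - 1  # stand-in batch in [-1, 1]
with torch.no_grad():
    probs = disc(images)

print(probs.shape)  # torch.Size([8, 1])
```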

In [13]:
gen = Generator(z_dim, img_dim).to(device)
disc = Discriminator(img_dim).to(device)
# Named `transform` so it does not shadow the torchvision.transforms module
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optim_gen = optim.Adam(gen.parameters(), lr=lr)
optim_disc = optim.Adam(disc.parameters(), lr=lr)
criterion = nn.BCELoss()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 65800178.59it/s]
Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 30799820.45it/s]
Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 23575964.81it/s]
Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 5280080.04it/s]
Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw
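`ToTensor` scales MNIST pixels to [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - mean) / std` per channel, i.e. `(x - 0.5) / 0.5`. That maps the data to [-1, 1], the same range the generator's `Tanh` produces, so real and fake images reach the discriminator on a common scale. The arithmetic, sketched with plain tensor operations instead of the torchvision transform and a handful of example pixel values:

```python
import torch

# Example pixel values as ToTensor would produce them: floats in [0, 1]
pixels = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

mean, std = 0.5, 0.5
normalized = (pixels - mean) / std  # same formula Normalize applies per channel

print(normalized.tolist())  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```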



In [12]:
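# TensorBoard: separate writers so fake and real image grids appear as two runs
summaryWriter_fake = SummaryWriter("GAN-MNIST/fake")
summaryWriter_real = SummaryWriter("GAN-MNIST/real")
step = 0  # global step, used as the x-axis when logging to TensorBoard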
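The writers and the `step` counter are set up for a training loop. One iteration of the standard vanilla-GAN update with `nn.BCELoss` can be sketched as follows. This is an assumed loop, not code from this notebook: random tensors replace the MNIST batch, small `nn.Sequential` stand-ins mirror the layer sizes of `Generator` and `Discriminator`, and the TensorBoard calls are left out. The generator step uses the common non-saturating objective (maximize log D(G(z))) rather than minimizing log(1 - D(G(z))).

```python
import math
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
z_dim, img_dim, batch_size, lr = 64, 28 * 28 * 1, 32, 3e-4

# Stand-ins with the same layer sizes as the Generator / Discriminator classes
gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.1),
                    nn.Linear(256, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 256), nn.LeakyReLU(0.1),
                     nn.Linear(256, 1), nn.Sigmoid())
optim_gen = optim.Adam(gen.parameters(), lr=lr)
optim_disc = optim.Adam(disc.parameters(), lr=lr)
criterion = nn.BCELoss()

# Hypothetical stand-in for one flattened, normalized MNIST batch in [-1, 1]
real = torch.rand(batch_size, img_dim) * 2 - 1

# Discriminator step: maximize log D(x) + log(1 - D(G(z)))
noise = torch.randn(batch_size, z_dim)
fake = gen(noise)
disc_real = disc(real).view(-1)
loss_d_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake.detach()).view(-1)  # detach: no gradients into gen here
loss_d_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
loss_disc = (loss_d_real + loss_d_fake) / 2
optim_disc.zero_grad()
loss_disc.backward()
optim_disc.step()

# Generator step: push D(G(z)) toward 1 (non-saturating loss)
output = disc(fake).view(-1)
loss_gen = criterion(output, torch.ones_like(output))
optim_gen.zero_grad()
loss_gen.backward()
optim_gen.step()

print(f"loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}")
```

Detaching `fake` in the discriminator step keeps that backward pass from touching the generator, so the same `fake` batch can be reused for the generator update afterwards.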