<a href="https://colab.research.google.com/github/Pmilivojevic/PyTorch/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's global RNG (all devices) so weight initialisation, noise sampling and
# DataLoader shuffling repeat across runs. Bitwise determinism on GPU is still not
# guaranteed by this alone.
SEED = 0
torch.manual_seed(SEED)

In [None]:
class Discriminator(nn.Module):
  """Fully connected GAN discriminator mapping a flattened image to a probability.

  Architecture: Linear(im_dim -> hidden_dim) -> LeakyReLU -> Linear(hidden_dim -> 1) -> Sigmoid.

  Args:
    im_dim: number of features in each flattened input image.
    hidden_dim: width of the single hidden layer (defaults to the original 128).
    negative_slope: LeakyReLU slope for negative inputs (defaults to the original 0.1).
  """

  def __init__(self, im_dim, hidden_dim=128, negative_slope=0.1):
    super().__init__()

    self.fc1 = nn.Linear(im_dim, hidden_dim)
    self.non_lin1 = nn.LeakyReLU(negative_slope)
    self.fc2 = nn.Linear(hidden_dim, 1)
    # Sigmoid squashes the score into (0, 1), as required by nn.BCELoss.
    self.non_lin2 = nn.Sigmoid()

  def forward(self, x):
    """Return the probability that each input row is a real image.

    Args:
      x: tensor whose last dimension has size im_dim, e.g. (batch, im_dim).

    Returns:
      Tensor with the same leading dimensions as x and a trailing dimension of 1.
    """
    x = self.non_lin1(self.fc1(x))

    return self.non_lin2(self.fc2(x))

In [None]:
class Generator(nn.Module):
  """Fully connected GAN generator mapping latent noise to a flattened image.

  Architecture: Linear(z_dim -> hidden_dim) -> LeakyReLU -> Linear(hidden_dim -> im_dim) -> Tanh.

  Args:
    z_dim: length of each latent noise vector.
    im_dim: number of features in each generated (flattened) image.
    hidden_dim: width of the single hidden layer (defaults to the original 256).
    negative_slope: LeakyReLU slope for negative inputs (defaults to the original 0.1).
  """

  def __init__(self, z_dim, im_dim, hidden_dim=256, negative_slope=0.1):
    super().__init__()

    self.fc1 = nn.Linear(z_dim, hidden_dim)
    self.non_lin1 = nn.LeakyReLU(negative_slope)
    self.fc2 = nn.Linear(hidden_dim, im_dim)
    # Tanh bounds every output pixel to (-1, 1); real images fed to the
    # discriminator should be normalized to the same range.
    self.non_lin2 = nn.Tanh()

  def forward(self, x):
    """Generate one flattened image per latent vector.

    Args:
      x: tensor whose last dimension has size z_dim, e.g. (batch, z_dim).

    Returns:
      Tensor with the same leading dimensions as x and a trailing dimension of
      im_dim, with values in (-1, 1).
    """
    x = self.non_lin1(self.fc1(x))

    return self.non_lin2(self.fc2(x))

In [None]:
# Hyperparameters.
lr = 3e-4
z_dim = 64        # length of the latent noise vector fed to the generator
im_dim = 784      # flattened MNIST image: 1 x 28 x 28
batch_size = 32
num_epochs = 50

disc = Discriminator(im_dim).to(device)
gen = Generator(z_dim, im_dim).to(device)
# Latent batch sampled once and kept fixed (presumably for comparing generator
# samples over the course of training).
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1],
# the same range as the generator's Tanh output. The MNIST mean/std previously used
# here, (0.1307, 0.3081), put real pixels in about [-0.42, 2.82], a range the
# generator cannot produce, so the discriminator could tell real from fake by
# intensity alone.
# Named `transform` (not `transforms`) so the torchvision.transforms module is not
# shadowed; shadowing it made this cell raise AttributeError when re-run.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

train_dataset = datasets.MNIST(
    root='/content/drive/MyDrive/Colab Notebooks/Dataset',
    transform=transform,
    download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCELoss expects probabilities, matching the discriminator's Sigmoid output.
criterion = nn.BCELoss()
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 148212607.02it/s]

Extracting /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/train-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 93324879.68it/s]


Extracting /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 33411233.21it/s]

Extracting /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 19741480.59it/s]


Extracting /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/Dataset/MNIST/raw



In [None]:
for epoch in range(num_epochs):
  for batch_ind, (real, _) in enumerate(train_loader):
    # Flatten each image into an im_dim-long vector for the fully connected networks.
    real = real.view(-1, im_dim).to(device)
    # Local name so the configured global `batch_size` is not overwritten when the
    # final batch is smaller.
    cur_batch_size = real.shape[0]

    # Discriminator step: minimize BCE(D(real), 1) + BCE(D(G(z)), 0), i.e.
    # maximize log D(real) + log(1 - D(G(z))).
    noise = torch.randn(cur_batch_size, z_dim, device=device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossDreal = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): the discriminator loss should not backpropagate into the generator.
    # That also leaves the generator's graph unused here, so retain_graph is not needed.
    disc_fake = disc(fake.detach()).view(-1)
    lossDfake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossDreal + lossDfake) / 2
    opt_disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    # Generator step: minimize BCE(D(G(z)), 1), i.e. maximize log D(G(z)), scored by
    # the just-updated discriminator. This also writes gradients into disc's
    # parameters, but they are cleared by opt_disc.zero_grad() before the next D step.
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    opt_gen.zero_grad()
    lossG.backward()
    opt_gen.step()