<a href="https://colab.research.google.com/github/ZS4MLDL/learn_pytorch/blob/main/04_Gans_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import time

In [18]:
# Output folder for the sample grids saved during training; no error if it already exists.
os.makedirs(name='./images/gan', exist_ok=True)

In [19]:
BATCH_SIZE = 64
N_EPOCHS = 25
IMAGE_SIZE = 28 * 28   # flattened 28x28 single-channel MNIST image
LATENT_DIM = 100       # length of the generator's input noise vector
PRINT_EVERY = 5        # report losses / save a sample grid every N epochs
N_SHOW = 5             # sample grid is N_SHOW x N_SHOW generated images
SEED = 1234

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All is reproducible
# (weight init, DataLoader shuffling, dropout and latent noise all draw from torch).
# NOTE: bitwise-identical GPU runs may additionally need deterministic kernels.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [20]:
# NOTE: this rebinds the name `transforms`, shadowing the `torchvision.transforms`
# module alias imported at the top. Reaching the module through
# `torchvision.transforms` keeps this cell safe to re-run (a second run of the
# original `transforms.Compose(...)` raised AttributeError on the Compose object).
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),                 # image -> float tensor in [0, 1]
    torchvision.transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1], same range as the generator's Tanh
])

In [21]:
# MNIST training split, downloaded into `.data` on first use and passed
# through the `transforms` pipeline defined above.
data = datasets.MNIST(root='.data', train=True, download=True, transform=transforms)

In [22]:
# Display the dataset summary (number of datapoints, root, split, transform).
data

Dataset MNIST
    Number of datapoints: 60000
    Root location: .data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [23]:
# Shuffled mini-batches. drop_last=True discards the final partial batch so
# every batch holds exactly BATCH_SIZE images, which the fixed-length
# real/fake label tensors used during training rely on.
iterator = DataLoader(
    data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [24]:
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to a flattened image.

    Args:
        latent_dim: length of the input noise vector.
        image_size: number of output features (flattened pixels).

    The final Tanh bounds every output to [-1, 1], the same range the real
    images are normalized to.
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        widths = [latent_dim, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], image_size), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


In [25]:
class Discriminator(nn.Module):
    """MLP discriminator scoring flattened images with a probability in (0, 1).

    Args:
        image_size: number of input features (flattened pixels).

    forward(x) maps a (batch, image_size) tensor to a (batch,) tensor taken
    from the final Sigmoid.
    """

    def __init__(self, image_size):
        super().__init__()
        widths = [image_size, 1024, 512, 256]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)])
        blocks.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
        self.main = nn.Sequential(*blocks)

    def forward(self, x):
        # (batch, 1) -> (batch,) so scores line up with the 1-D label tensors.
        return self.main(x).squeeze(dim=1)

In [26]:
# Generator instance: LATENT_DIM noise -> IMAGE_SIZE flattened image, moved to `device`.
G = Generator(latent_dim=LATENT_DIM, image_size=IMAGE_SIZE).to(device)

In [27]:
# Discriminator instance over IMAGE_SIZE flattened images, moved to `device`.
D = Discriminator(image_size=IMAGE_SIZE).to(device)

In [28]:
# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [29]:
# Adam settings widely used for GANs (lr = 2e-4, beta1 = 0.5, as in DCGAN).
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

G_optimizer = optim.Adam(G.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
# Bug fix: the discriminator optimizer must own D's parameters. It was built on
# G.parameters(), so D was never updated and G received an extra update per step.
D_optimizer = optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [30]:
# BCE targets: 1 for real images, 0 for generated ones. Their fixed length
# relies on the DataLoader dropping the last partial batch (drop_last=True).
real_labels = torch.ones(BATCH_SIZE, device=device)
fake_labels = torch.zeros(BATCH_SIZE, device=device)

In [31]:
def epoch_time(start_time, end_time):
    """Split the time between two ``time.time()`` readings into minutes and seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints; for non-negative durations
        ``0 <= seconds < 60``.
    """
    # The previous int(elapsed)/60 produced float minutes, which forced the
    # seconds term to 0; divmod on whole seconds gives the intended split.
    elapsed_mins, elapsed_secs = divmod(int(end_time - start_time), 60)
    return elapsed_mins, elapsed_secs

In [32]:
for epoch in range(1, N_EPOCHS + 1):
    start_time = time.time()

    # Walk the DataLoader once per epoch. Calling next(iter(iterator)) built and
    # reshuffled a fresh iterator on every draw, which is slow and gives no
    # guarantee that each image is visited once per epoch.
    for x, _ in iterator:
        # [BATCH_SIZE, 1, 28, 28] -> [BATCH_SIZE, IMAGE_SIZE]
        x = x.to(device).view(-1, IMAGE_SIZE)

        # --- Train Discriminator: push D(real) -> 1 and D(fake) -> 0 ---
        # zero_grad is required: otherwise D's grads (including those left by
        # the generator step's backward pass through D) keep accumulating.
        D.zero_grad()
        z = torch.randn(x.shape[0], LATENT_DIM, device=device)
        with torch.no_grad():
            generated_images = G(z)  # no graph: G is not updated in this step
        D_error_real = criterion(D(x), real_labels)
        D_error_fake = criterion(D(generated_images), fake_labels)
        D_error = D_error_real + D_error_fake
        D_error.backward()
        D_optimizer.step()

        # --- Train Generator: make D label fresh fakes as real ---
        # Only noise is needed here; no second batch of real images is loaded.
        G.zero_grad()
        z = torch.randn(x.shape[0], LATENT_DIM, device=device)
        pred_fake = D(G(z))
        G_error = criterion(pred_fake, real_labels)
        G_error.backward()
        G_optimizer.step()

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Per-epoch reporting now lives inside the epoch loop; it previously sat at
    # column 0 and therefore ran only once, after training finished.
    if epoch % PRINT_EVERY == 0:
        print(f"| Epoch: {epoch:03} | D_error: {D_error.item():.03f} G_error: {G_error.item():.03f} | Time: {epoch_mins}m {epoch_secs}s")

        with torch.no_grad():
            z = torch.randn(N_SHOW * N_SHOW, LATENT_DIM, device=device)
            samples = G(z).view(-1, 1, 28, 28)

        sample_path = f"images/gan/epoch{epoch:03}.png"
        torchvision.utils.save_image(samples, sample_path, nrow=N_SHOW, normalize=True)

        plt.imshow(plt.imread(sample_path))
        plt.axis("off")
        plt.show()


KeyboardInterrupt: 