In [None]:
!pip -q install torch_snippets

In [None]:
from torch_snippets import *
from torchvision.utils import make_grid
import torch.nn as nn
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim import Adam

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to
# [-1, 1], the same range as the generator's final Tanh activation.
# Bound to `transform` (not `transforms`) so the torchvision `transforms`
# module is not shadowed and this cell can be re-run safely.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('/content/', download=True, train=True, transform=transform),
    shuffle=True,
    drop_last=True,  # every batch has exactly batch_size samples
    batch_size=128,
)

In [None]:
class Discriminator(nn.Module):
  """MLP that scores flattened 784-feature inputs as real or fake.

  The final Sigmoid yields one probability per sample, shape (batch, 1),
  suitable as input to nn.BCELoss.

  Args:
    dropout: drop probability used after every hidden layer (default 0.3).
  """

  def __init__(self, dropout: float = 0.3):
    super().__init__()
    self.model = nn.Sequential(
        nn.Linear(784, 1024),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(512, 256),
        nn.ReLU(),
        # Previously a bare nn.Dropout() (p=0.5); use the same rate as the
        # other hidden layers.
        nn.Dropout(dropout),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Return real/fake probabilities for a (batch, 784) input."""
    return self.model(x)

class Generator(nn.Module):
  """MLP that decodes a latent vector into a flattened 784-pixel image.

  The final Tanh keeps outputs in [-1, 1], matching real images after
  Normalize((0.5,), (0.5,)).

  Args:
    latent_dim: width of the input noise vector (default 100).
  """

  # Was `nn.module` (lowercase), which does not exist and made the class
  # definition itself raise AttributeError.
  def __init__(self, latent_dim: int = 100):
    super().__init__()
    self.model = nn.Sequential(
        nn.Linear(latent_dim, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )

  def forward(self, x):
    """Map a (batch, latent_dim) noise tensor to (batch, 784) images."""
    return self.model(x)

def noise(size):
  """Return `size` latent vectors for the generator.

  Each row has 100 entries drawn from a standard normal (`torch.randn`);
  the tensor is created first and then moved to the global `device`.
  """
  latent_width = 100
  samples = torch.randn(size, latent_width)
  return samples.to(device)

def discriminator_train_step(real_data, fake_data):
  """Take one optimizer step on the global `discriminator`.

  Real samples are scored against target 1 and fake samples against target 0
  with the global `loss`. Both backward passes accumulate into the same
  gradients before `d_optimizer.step()` applies them.

  Args:
    real_data: batch of real inputs for the discriminator.
    fake_data: batch of generated inputs; pass it detached if the generator
      should not receive gradients from this step.

  Returns:
    The summed real + fake loss tensor.
  """
  d_optimizer.zero_grad()
  per_source = []
  for batch, make_target in ((real_data, torch.ones), (fake_data, torch.zeros)):
    scores = discriminator(batch)
    source_loss = loss(scores, make_target(len(batch), 1).to(device))
    source_loss.backward()
    per_source.append(source_loss)
  d_optimizer.step()
  loss_real, loss_fake = per_source
  return loss_fake + loss_real

def generator_train_step(fake_data):
  """Take one optimizer step on the generator via `g_optimizer`.

  `fake_data` must still be attached to the generator's autograd graph, so
  the loss gradient can flow back into the generator's parameters.

  Args:
    fake_data: generator output for this step (not detached).

  Returns:
    The generator loss tensor.
  """
  g_optimizer.zero_grad()
  prediction = discriminator(fake_data)
  # Target "real" (1) for the fakes: the generator is rewarded for fooling the
  # discriminator. The batch size comes from fake_data itself. It previously
  # read the loop's global `real_data`, which only worked by accident.
  error = loss(prediction, torch.ones(len(fake_data), 1).to(device))
  # This also accumulates gradients in the discriminator's parameters.
  # d_optimizer.zero_grad() at the start of the next discriminator step
  # clears them.
  error.backward()
  g_optimizer.step()
  return error

# Seed before building the models so weight init and noise draws are repeatable.
torch.manual_seed(0)

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()
num_epochs = 200
log = Report(num_epochs)

N = len(dataloader)  # batches per epoch; loop-invariant, so computed once
for epoch in range(num_epochs):
  for i, (image, _) in enumerate(tqdm(dataloader)):
    # Flatten each image to a 784-long row for the MLP discriminator.
    real_data = image.view(len(image), -1).to(device)

    # Discriminator step on detached fakes, so no gradient reaches the generator.
    fake_data = generator(noise(len(real_data))).detach()
    # Was `discriminator(real_data, fake_data)`, a forward call with one argument
    # too many that raised TypeError and never trained anything.
    d_loss = discriminator_train_step(real_data, fake_data)

    # Generator step on fresh, non-detached fakes so its weights get gradients.
    fake_data = generator(noise(len(real_data)))
    g_loss = generator_train_step(fake_data)

    # '\r' (carriage return) keeps the progress line in place; was the typo '/r'.
    log.record(epoch + (1 + i) / N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
  log.report_avgs(epoch + 1)
log.plot_epochs(['d_loss', 'g_loss'])

In [None]:
# Decode 64 latent vectors with the trained generator and show them as an 8x8 grid.
z = torch.randn(64, 100).to(device)
# torch.no_grad() replaces the discouraged `.data` access: no autograd graph
# is recorded, so the images come back without gradient history.
with torch.no_grad():
  sample_images = generator(z).cpu().view(64, 1, 28, 28)
# normalize=True rescales the grid to [0, 1] using the batch's min/max
# (the generator's Tanh output lies in [-1, 1]).
grid = make_grid(sample_images, nrow=8, normalize=True)
# The grid is already on the CPU and carries no grad, so no .cpu()/.detach() is needed.
show(grid.permute(1, 2, 0), sz=5)