In [10]:
!nvidia-smi -L

GPU 0: Tesla K80 (UUID: GPU-b02f1f96-8a18-9b22-664f-858d43192d42)


In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Seed PyTorch's global random number generator so that weight initialisation
# and noise sampling are repeatable after a kernel restart.
torch.manual_seed(0)

In [12]:
def show_tensor_images(fake, real, num_images=25, size=(1, 28, 28), value_range=(0.0, 1.0)):
    """Plot a grid of generated images next to a grid of real images.

    Args:
        fake: tensor of generated images; reshaped to ``(-1, *size)``.
        real: tensor of real images; reshaped to ``(-1, *size)``.
        num_images: maximum number of images taken from each tensor.
        size: shape of one image as ``(channels, height, width)``.
        value_range: ``(low, high)`` interval the pixel values live in. The
            default ``(0, 1)`` matches this notebook, where the dataset uses
            ``transforms.ToTensor()`` and the generator ends in ``nn.Sigmoid``.
            Pass ``(-1, 1)`` for Tanh / ``Normalize((0.5,), (0.5,))`` pipelines.

    Raises:
        ValueError: if ``value_range`` has zero width.
    """
    low, high = value_range
    if high == low:
        raise ValueError("value_range must span a non-empty interval")

    def to_grid(images):
        # Move to CPU and restore the image layout before tiling.
        unflat = images.detach().cpu().view(-1, *size)[:num_images]
        # Map [low, high] -> [0, 1]; the previous unconditional (x + 1) / 2
        # squeezed [0, 1] data into [0.5, 1] and washed out the plots.
        unit = ((unflat - low) / (high - low)).clamp(0, 1)
        return make_grid(unit, nrow=5)

    plt.figure(figsize=(10, 3))
    for position, images in enumerate((fake, real), start=1):
        plt.subplot(1, 2, position)
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(to_grid(images).permute(1, 2, 0).squeeze())
    plt.show()

## Generator

In [13]:
def get_generator_block(input_dim, output_dim):
    """Build one hidden stage of the generator.

    The stage is a linear projection from ``input_dim`` to ``output_dim``
    features, followed by 1-D batch normalisation and an in-place ReLU.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    stage = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stage)

In [14]:
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to flattened images.

    Args:
        z_dim: dimension of the noise vector, a scalar.
        im_dim: dimension of the flattened output image, a scalar.
        hidden_dim: base width of the hidden layers, a scalar.
    """

    def __init__(self, z_dim = 10, im_dim = 784, hidden_dim = 128):
        super().__init__()
        # Four hidden stages whose width doubles each time: h, 2h, 4h, 8h.
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        hidden_stages = [
            get_generator_block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *hidden_stages,
            nn.Linear(widths[-1], im_dim),
            # Sigmoid keeps every output value inside (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Generate images from ``noise`` of shape (n_samples, z_dim)."""
        images = self.gen(noise)
        return images

    def get_gen(self):
        """Return the underlying ``nn.Sequential`` network."""
        return self.gen

In [15]:
def get_noise(n_samples, z_dim, device='cuda'):
    """Sample standard-normal noise vectors.

    Args:
        n_samples: number of noise vectors to draw.
        z_dim: dimension of each noise vector.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape (n_samples, z_dim) drawn from N(0, 1).
    """
    noise_shape = (n_samples, z_dim)
    return torch.randn(*noise_shape, device=device)

## Discriminator

In [16]:
def get_discriminator_block(input_dim, output_dim):
    """Build one hidden stage of the discriminator.

    The stage is a linear projection from ``input_dim`` to ``output_dim``
    features followed by an in-place LeakyReLU with negative slope 0.2.

    Returns:
        An ``nn.Sequential`` holding the two layers in that order.
    """
    projection = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(projection, activation)

In [17]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real or fake.

    Args:
        im_dim: dimension of the flattened input image, a scalar.
        hidden_dim: base width of the hidden layers, a scalar.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        # Hidden widths shrink towards the output: 4h -> 2h -> h.
        widths = (im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        hidden_stages = [
            get_discriminator_block(n_in, n_out)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        # The last layer has no activation, so the network emits raw logits.
        self.disc = nn.Sequential(*hidden_stages, nn.Linear(hidden_dim, 1))

    def forward(self, image):
        """Return one score per row of ``image`` (a batch of flattened images)."""
        score = self.disc(image)
        return score

    def get_disc(self):
        """Return the underlying ``nn.Sequential`` network."""
        return self.disc

In [None]:
# Loss applied to the discriminator's raw (pre-sigmoid) outputs.
criterion = nn.BCEWithLogitsLoss()

# Training hyper-parameters.
n_epochs = 100        # passes over the training set
z_dim = 64            # dimension of the generator's noise input
display_step = 500    # report losses / show samples every this many steps
batch_size = 128
lr = 0.00001

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# MNIST training split; ToTensor() yields float images scaled to [0, 1].
dataloader = DataLoader(
    MNIST('/files/', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

In [19]:
# Build both networks on the target device (generator first, then
# discriminator), then give each its own Adam optimiser at the shared rate.
gen = Generator(z_dim=z_dim).to(device)
disc = Discriminator().to(device)

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [20]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Compute the discriminator loss for one batch.

    Args:
        gen: the generator model; maps noise of shape (num_images, z_dim) to images.
        disc: the discriminator model; maps images to one score each.
        criterion: loss comparing discriminator scores with 0/1 targets.
        real: batch of real images, already in the form ``disc`` expects.
        num_images: number of fake images to generate.
        z_dim: dimension of the noise vector.
        device: device on which the noise is created.

    Returns:
        The mean of the loss on fake images (target 0) and the loss on real
        images (target 1).
    """
    # The generator is not updated from this loss, so skip building its
    # autograd graph entirely instead of building it and detaching afterwards.
    with torch.no_grad():
        fake = gen(get_noise(num_images, z_dim, device=device))

    disc_fake_pred = disc(fake)
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

    return (disc_fake_loss + disc_real_loss) / 2

In [21]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Compute the generator loss for one batch.

    Fresh noise is pushed through ``gen`` and the result is scored by
    ``disc``; the loss compares those scores with an all-ones target, i.e. the
    generator is rewarded when the discriminator labels its output as real.

    Args:
        gen: the generator model.
        disc: the discriminator model.
        criterion: loss comparing discriminator scores with 0/1 targets.
        num_images: number of fake images to generate.
        z_dim: dimension of the noise vector.
        device: device on which the noise is created.

    Returns:
        The loss value, still attached to the autograd graph.
    """
    generated = gen(get_noise(num_images, z_dim, device=device))
    verdict = disc(generated)
    wants_real = torch.ones_like(verdict)
    return criterion(verdict, wants_real)

In [None]:
# Alternating GAN training: one discriminator step followed by one generator
# step per batch, with running-mean losses and sample images shown every
# `display_step` steps.
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
# NOTE(review): the next two names are not read inside this cell; they are kept
# in case later cells depend on them.
test_generator = True
gen_loss = False
for epoch in range(n_epochs):

    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        # --- Discriminator update ---
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # No retain_graph: the fake batch is cut off from the generator inside
        # get_disc_loss and the generator step below builds its own graph.
        disc_loss.backward()
        disc_opt.step()

        # --- Generator update ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Accumulate running means over the current display window
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        # Count the step before testing the window boundary so that every
        # window (including the first) averages exactly `display_step` losses.
        cur_step += 1
        if cur_step % display_step == 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            # Sampling for display needs no gradients.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device=device)
                fake = gen(fake_noise)
            show_tensor_images(fake, real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0