In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime

In [2]:
# Preprocessing: ToTensor gives floats in [0, 1]; Normalize with mean=std=0.5 maps them to
# [-1, 1], matching the range of the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5, single grayscale channel
])
                        

In [3]:
# MNIST training split, stored under the current directory (downloaded if not already present).
train_dataset = torchvision.datasets.MNIST(root='.', train=True, download=True, transform=transform)

In [4]:
# MNIST test split with the same preprocessing as the training split.
# NOTE(review): not referenced by any later cell in this notebook — confirm it is still needed.
test_dataset = torchvision.datasets.MNIST(root='.', train=False, download=True, transform=transform)

In [5]:
# Number of training examples, shown as this cell's output.
len(train_dataset)

60000

In [6]:
# Shuffled mini-batches of the training set. drop_last is left at its default (False), so the
# last batch of each epoch is smaller (60000 % 128 = 96); the training loop handles that.
batch_size = 128
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [7]:
#Discriminator: flattened 28x28 image (784 values) -> one unnormalized real/fake score (logit).
# There is deliberately no final Sigmoid: the BCEWithLogitsLoss criterion applies it internally,
# which is more numerically stable than Sigmoid followed by BCELoss.
D = nn.Sequential(
    nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)

In [9]:
#Generator: latent noise vector of size latent_dim -> flattened 28x28 image with values in [-1, 1].
# Every hidden stage is Linear -> LeakyReLU -> BatchNorm.
# NOTE(review): PyTorch's BatchNorm `momentum` is the weight given to the *new* batch statistic
# (the opposite of Keras' convention). If 0.7 was ported from Keras, the equivalent here is 0.3.
# Confirm the intended value. It only affects the running stats used in eval() mode.
latent_dim = 100
G = nn.Sequential(
    nn.Linear(latent_dim, 256),
    # This activation was missing: without it, Linear -> BatchNorm -> Linear has no
    # nonlinearity, so the first hidden stage adds almost no expressive power.
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(256, momentum=0.7),
    nn.Linear(256, 512),
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(512, momentum=0.7),
    nn.Linear(512, 1024),
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(1024, momentum=0.7),
    nn.Linear(1024, 784),
    nn.Tanh(),  # [-1, 1], the same range as the normalized real images
)

In [10]:
# Use the first CUDA GPU when one is available, otherwise the CPU; move both networks there.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
D, G = D.to(device), G.to(device)

In [11]:
# Loss: D outputs raw logits, so use the fused sigmoid + binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()
# One Adam optimizer per network, with identical hyperparameters (lr=2e-4, betas=(0.5, 0.999)).
d_optimizer, g_optimizer = (
    torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for model in (D, G)
)

In [12]:
def scale_image(img):
    """Map values from [-1, 1] (the Normalize / Tanh range) back to [0, 1].

    Applied elementwise, so it accepts tensors, arrays, or plain numbers.
    Values outside [-1, 1] are shifted the same way and are not clipped.
    """
    shifted = img + 1  # [-1, 1] -> [0, 2]
    return shifted / 2  # [0, 2] -> [0, 1]

In [14]:
#create a folder to store generated images.
# exist_ok=True makes this a no-op when the folder exists and avoids the race between an
# exists() check and makedirs().
os.makedirs('gan_images', exist_ok=True)

In [15]:
#Training loop
# BCE target tensors for a full batch: 1 = "real", 0 = "fake". They are created once on the
# training device; the loop slices them to the actual batch size (the last batch is smaller).
ones_ = torch.ones((batch_size, 1), device=device)
zeros_ = torch.zeros((batch_size, 1), device=device)

In [None]:
#save losses: one value per epoch, namely the losses of that epoch's final batch
d_losses = []
g_losses = []

n_epochs = 200
g_steps = 2  # generator updates per discriminator update

for epoch in range(n_epochs):
    for inputs, _ in data_loader:
        # class labels are not needed for this unconditional GAN
        # flatten each image to 784 values and move the batch to the training device
        n = inputs.size(0)
        inputs = inputs.reshape(n, 784).to(device)

        # the last batch of an epoch can be smaller than batch_size, so trim the targets to match
        ones = ones_[:n]
        zeros = zeros_[:n]

        # --- TRAIN DISCRIMINATOR: real images -> 1, generated images -> 0 ---
        real_outputs = D(inputs)
        d_loss_real = criterion(real_outputs, ones)

        noise = torch.randn(n, latent_dim).to(device)
        # detach(): this step only updates D, so there is no need to backpropagate into G
        fake_image = G(noise).detach()
        fake_outputs = D(fake_image)
        d_loss_fake = criterion(fake_outputs, zeros)

        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        # Zeroing D's grads also discards anything the previous generator steps accumulated in D
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- TRAIN GENERATOR: g_steps updates, each pushing D to label fresh fakes as real ---
        for _ in range(g_steps):
            noise = torch.randn(n, latent_dim).to(device)
            fake_image = G(noise)
            fake_outputs = D(fake_image)
            # reverse the goal: G wants D to output "real" (1) on its samples
            g_loss = criterion(fake_outputs, ones)

            # backward/step must run inside this loop. Previously they sat after it, so only the
            # last iteration's loss was used and G got one update instead of g_steps.
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())
    print(f'Epochs {epoch}, d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

    # save the epoch's last generated batch as an image grid: (n, 784) -> (n, 1, 28, 28), then
    # map [-1, 1] -> [0, 1]. Previously the un-reshaped 2-D tensor was passed to save_image.
    fake_images = fake_image.detach().reshape(-1, 1, 28, 28)
    save_image(scale_image(fake_images), f'gan_images/{epoch+1}.png')

In [None]:
#plot d_loss and g_loss: one point per epoch, each the BCE loss of that epoch's final batch
fig, ax = plt.subplots()
ax.plot(g_losses, label='g_losses')
ax.plot(d_losses, label='d_losses')
ax.set(title='GAN training losses', xlabel='epoch', ylabel='BCE loss (last batch)')
ax.legend();

In [None]:
from skimage.io import imread
a = imread('gan_images/1.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/150.png')
plt.imshow(a)