In [15]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

torch.cuda.set_device(0) # set pytorch running on GPU0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#parameters
sizeImage = 784
Z_size = 64
sizeHidden = 256
batch_size = 100
epochsNum = 100 #number of epochs

sampleLocation = 'samplesOfGAN'
if not os.path.exists(sampleLocation):
    os.makedirs(sampleLocation)

In [None]:
# Download MNIST dataset
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]) # Image processing
mnist = torchvision.datasets.MNIST(root='dataMNIST',train=True,transform=trans,download=True)#If already downloaded, download=False
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

In [16]:
# Construct Discriminator
Discriminator = nn.Sequential(
    nn.Linear(sizeImage, sizeHidden),
    nn.LeakyReLU(0.2),
    nn.Linear(sizeHidden, sizeHidden),
    nn.LeakyReLU(0.2),
    nn.Linear(sizeHidden, 1),
    nn.Sigmoid())
Discriminator = Discriminator.to(device)

# Construct Generator
Generator = nn.Sequential(
    nn.Linear(Z_size, sizeHidden),
    nn.ReLU(),
    nn.Linear(sizeHidden, sizeHidden),
    nn.ReLU(),
    nn.Linear(sizeHidden, sizeImage),
    nn.Tanh())
Generator = Generator.to(device)

In [17]:
# Define criterion and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(Discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(Generator.parameters(), lr=0.0002)

#
def ImageNormalize(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [18]:
for epoch in range(epochsNum):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        ######### Discriminator training#########

        # Loss function
        outputs = Discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        z = torch.randn(batch_size, Z_size).to(device)
        fake_images = Generator(z)
        outputs = Discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs        
        d_loss = d_loss_real + d_loss_fake #Total loss
        
        #Backward       
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        ######### Generator training#########

        # Loss function
        z = torch.randn(batch_size, Z_size).to(device)
        fake_images = Generator(z)
        outputs = Discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        #Backward 
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, real_score: {:.2f}, fake_score: {:.2f}' 
                  .format(epoch, epochsNum, i+1, len(data_loader), d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real_image
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(ImageNormalize(images), os.path.join(sampleLocation, 'real_images.png'))
    
    # Save fake_image
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(ImageNormalize(fake_images), os.path.join(sampleLocation, 'fake_images-{}.png'.format(epoch+1)))

# save model
torch.save(Generator.state_dict(), 'Generator.ckpt')
torch.save(Discriminator.state_dict(), 'Discriminator.ckpt')

Epoch [0/2], Step [100/600], d_loss: 0.2846, g_loss: 2.5114, real_score: 0.94, fake_score: 0.20
Epoch [0/2], Step [200/600], d_loss: 0.0556, g_loss: 3.9674, real_score: 0.99, fake_score: 0.05
Epoch [0/2], Step [300/600], d_loss: 0.5458, g_loss: 2.5319, real_score: 0.80, fake_score: 0.24
Epoch [0/2], Step [400/600], d_loss: 0.0310, g_loss: 5.7486, real_score: 1.00, fake_score: 0.03
Epoch [0/2], Step [500/600], d_loss: 0.0188, g_loss: 6.3908, real_score: 0.99, fake_score: 0.01
Epoch [0/2], Step [600/600], d_loss: 0.0845, g_loss: 4.3738, real_score: 0.97, fake_score: 0.05
Epoch [1/2], Step [100/600], d_loss: 0.0597, g_loss: 4.8717, real_score: 0.98, fake_score: 0.03
Epoch [1/2], Step [200/600], d_loss: 0.0564, g_loss: 4.9853, real_score: 0.98, fake_score: 0.03
Epoch [1/2], Step [300/600], d_loss: 0.1076, g_loss: 4.2582, real_score: 0.97, fake_score: 0.07
Epoch [1/2], Step [400/600], d_loss: 0.2820, g_loss: 5.6203, real_score: 0.91, fake_score: 0.11
Epoch [1/2], Step [500/600], d_loss: 0.1