In [1]:
# prerequisites
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Show which device was selected for training.
print(str(device))

cpu


In [2]:
bs = 1000  # paper 64, blog 256 -- ideal batch size ranges from 32 to 128

# MNIST Dataset
# Images are converted to tensors and then normalised with mean 0.5 and std dev 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transform,
    download=False,
)

print(len(train_dataset), len(test_dataset))

# Data Loader (Input Pipeline): both splits are shuffled and batched with the same batch size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=bs,
    shuffle=True,
)
print(len(train_loader), len(test_loader))

60000 10000
60 10


In [3]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat sample.

    Three hidden layers widen from 256 to 512 to 1024 units, each followed by
    LeakyReLU with negative slope 0.2. The output layer is squashed by tanh,
    so every output value lies in [-1, 1].

    Args:
        g_input_dim: Size of the latent input vector.
        g_output_dim: Size of the generated (flattened) sample.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        base_width = 256
        self.fc1 = nn.Linear(g_input_dim, base_width)
        self.fc2 = nn.Linear(base_width, base_width * 2)
        self.fc3 = nn.Linear(base_width * 2, base_width * 4)
        self.fc4 = nn.Linear(base_width * 4, g_output_dim)

    def forward(self, x):
        """Return generated samples for a batch of latent vectors ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return torch.tanh(self.fc4(hidden))
    
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flat samples as real or fake.

    Three hidden layers narrow from 1024 to 512 to 256 units, each followed by
    LeakyReLU (negative slope 0.2) and dropout (p=0.3). The single output unit
    goes through a sigmoid, so the result is a probability-like value in (0, 1).

    Args:
        d_input_dim: Size of each flattened input sample.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for the batch ``x``.

        Dropout follows the module's train/eval mode. ``F.dropout`` defaults to
        ``training=True``, so without the explicit flag it would keep dropping
        units even after ``.eval()``. Behaviour in training mode is unchanged.
        """
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)  # regularise D against overfitting
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [5]:
# Show the training set dimensions: number of images, then the two image dimensions.
print(*train_dataset.data.shape[:3])

60000 28 28


# Size of the latent noise vector fed to the generator.
z_dim = 1000
# Flattened image size: product of the two per-image dimensions of the training data.
mnist_dim = train_dataset.data.shape[1] * train_dataset.data.shape[2]

# Build both networks and move them to the selected device.
G = Generator(z_dim, mnist_dim).to(device)
D = Discriminator(d_input_dim=mnist_dim).to(device)

In [7]:
# Display the generator's layer structure.
G

Generator(
  (fc1): Linear(in_features=1000, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [8]:
# Display the discriminator's layer structure.
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [10]:
# loss: binary cross-entropy between D's output and the real/fake labels
criterion = nn.BCELoss()

# optimizer: one Adam instance per network, both using the same learning rate
lr = 0.0002
G_optimizer = optim.Adam(params=G.parameters(), lr=lr)
D_optimizer = optim.Adam(params=D.parameters(), lr=lr)

In [11]:
def G_train(x):
    """Run one generator update step and return the generator loss.

    The generator is trained to make the discriminator label its samples as
    real: fresh noise is pushed through G, scored by D, and compared against
    an all-ones target.

    Args:
        x: The current batch of real images. Only its first dimension is read,
            so the generator trains on as many samples as the real batch holds.

    Returns:
        float: The BCE loss of D's scores on the generated batch.
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    # Size everything from the incoming batch instead of the global `bs`, and
    # create the tensors directly on the target device.
    batch_size = x.size(0)
    z = torch.randn(batch_size, z_dim, device=device)
    y = torch.ones(batch_size, 1, device=device)

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [12]:
def D_train(x):
    """Run one discriminator update step and return the discriminator loss.

    The discriminator is scored on a real batch (target 1) and on a freshly
    generated batch of the same size (target 0); both losses are summed before
    the backward pass.

    Args:
        x: A batch of real images; it is flattened to ``(batch, mnist_dim)``.

    Returns:
        float: Sum of the BCE losses on the real and the fake batch.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real = x.view(-1, mnist_dim).to(device)
    # Size the labels from the actual batch rather than the global `bs`, so a
    # smaller final batch does not cause a shape mismatch in the loss.
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)

    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake
    z = torch.randn(batch_size, z_dim, device=device)
    # Detach: only D is updated in this step, so no gradients are needed in G.
    x_fake = G(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)

    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
# Train for n_epoch epochs, alternating one discriminator step and one
# generator step per batch, and keep the per-epoch mean losses for plotting.
n_epoch = 100
st_losses_g = [] #store losses for plotting
st_losses_d = [] #store losses for plotting
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    # Compute each epoch mean once and keep it as a plain Python float, so the
    # history holds numbers rather than 0-d tensors.
    epoch_loss_d = torch.mean(torch.FloatTensor(D_losses)).item()
    epoch_loss_g = torch.mean(torch.FloatTensor(G_losses)).item()
    st_losses_g.append(epoch_loss_g)
    st_losses_d.append(epoch_loss_d)

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, epoch_loss_d, epoch_loss_g))

In [None]:
# Sample a batch from the generator and save it as a single image grid.
os.makedirs('./samples', exist_ok=True)  # make sure the output directory exists before writing

with torch.no_grad():
    test_z = torch.randn(bs, z_dim, device=device)
    generated = G(test_z)

    # G ends in tanh, so values lie in [-1, 1]; map them back to [0, 1] (the
    # inverse of the Normalize(0.5, 0.5) applied to the training data) so that
    # negative values are not clipped when the image is written. The batch
    # dimension is inferred with -1 instead of hard-coding 1000.
    images = (generated.view(-1, 1, 28, 28) + 1) / 2
    save_image(images, './samples/GAN_100epoch' + '.png')

In [15]:
# Continue training for another n_epoch epochs. The logged epoch numbers are
# offset by 100 so they follow on from the earlier 100-epoch run.
# NOTE: these losses are not appended to st_losses_g / st_losses_d, so the loss
# plot does not include this run.
n_epoch = 100
for epoch in range(1, n_epoch + 1):
    D_losses = []
    G_losses = []
    for batch_idx, (x, _) in enumerate(train_loader):
        # One discriminator step, then one generator step, per batch.
        d_loss = D_train(x)
        D_losses.append(d_loss)
        g_loss = G_train(x)
        G_losses.append(g_loss)

    mean_d = torch.mean(torch.FloatTensor(D_losses))
    mean_g = torch.mean(torch.FloatTensor(G_losses))
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (100 + epoch, 100 + n_epoch, mean_d, mean_g))

[101/200]: loss_d: 0.430, loss_g: 3.236
[102/200]: loss_d: 0.425, loss_g: 3.248
[103/200]: loss_d: 0.376, loss_g: 3.411
[104/200]: loss_d: 0.467, loss_g: 3.298
[105/200]: loss_d: 0.462, loss_g: 3.157
[106/200]: loss_d: 0.422, loss_g: 3.121
[107/200]: loss_d: 0.480, loss_g: 3.101
[108/200]: loss_d: 0.465, loss_g: 3.239
[109/200]: loss_d: 0.445, loss_g: 3.164
[110/200]: loss_d: 0.504, loss_g: 2.892
[111/200]: loss_d: 0.508, loss_g: 2.759
[112/200]: loss_d: 0.468, loss_g: 2.893
[113/200]: loss_d: 0.497, loss_g: 2.903
[114/200]: loss_d: 0.474, loss_g: 2.985
[115/200]: loss_d: 0.505, loss_g: 2.893
[116/200]: loss_d: 0.501, loss_g: 2.944
[117/200]: loss_d: 0.535, loss_g: 2.920
[118/200]: loss_d: 0.520, loss_g: 2.845
[119/200]: loss_d: 0.534, loss_g: 2.767
[120/200]: loss_d: 0.533, loss_g: 2.670
[121/200]: loss_d: 0.519, loss_g: 2.663
[122/200]: loss_d: 0.506, loss_g: 2.766
[123/200]: loss_d: 0.500, loss_g: 2.686
[124/200]: loss_d: 0.509, loss_g: 2.693
[125/200]: loss_d: 0.514, loss_g: 2.714


In [17]:
# Sample a batch from the generator and save it as a single image grid.
os.makedirs('./samples', exist_ok=True)  # make sure the output directory exists before writing

with torch.no_grad():
    test_z = torch.randn(bs, z_dim, device=device)
    generated = G(test_z)

    # G ends in tanh, so values lie in [-1, 1]; map them back to [0, 1] (the
    # inverse of the Normalize(0.5, 0.5) applied to the training data) so that
    # negative values are not clipped when the image is written. The batch
    # dimension is inferred with -1 instead of hard-coding 1000.
    images = (generated.view(-1, 1, 28, 28) + 1) / 2
    save_image(images, './samples/GAN_200epoch' + '.png')

In [None]:
# plot and save the generator and discriminator loss
import matplotlib
import matplotlib.pyplot as plt
# Draw both loss histories on one set of axes and write the figure to disk.
fig, ax = plt.subplots()
ax.plot(st_losses_g, label='Generator loss')
ax.plot(st_losses_d, label='Discriminator Loss')
ax.legend()
fig.savefig('./samples/GAN_loss.png')