In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'd:/samples'

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [5]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5))])

In [6]:
mnist = torchvision.datasets.MNIST(root='d:/MNIST/',
                                   train=True,
                                   transform=transform,
                                   download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [7]:
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [11]:
D = nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid()).to(device)

G = nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh()).to(device)

In [12]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [13]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [None]:
total_step = len(data_loader)
for epoch in range(num_epochs):
    for idx,(images,_) in enumerate(data_loader):
        images = images.reshape(batch_size,-1).to(device)
        
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        #------------DISCRIMINATOR-------------------------
        outputs = D(images)
        real_loss = criterion(outputs,real_labels)
        
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        fake_loss = criterion(outputs,fake_labels)
        
        d_loss = real_loss + fake_loss
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #-------------Generator---------------------------
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        g_loss = criterion(outputs,real_labels)
        g_loss.backward()
        g_optimizer.step()
        
        if (idx+1) % 100 == 0:
            print(idx+1, total_step, d_loss.item(), g_loss.item())
            
    if epoch%5==0:
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

100 600 0.012241052463650703 5.728859901428223
200 600 0.01560187991708517 5.394524574279785
300 600 0.01006968505680561 6.615293979644775
400 600 0.007493121549487114 6.205960750579834
500 600 0.015997696667909622 5.115780830383301
600 600 0.019362976774573326 6.4097208976745605
100 600 0.02448098734021187 6.327150821685791
200 600 0.028336849063634872 5.604171276092529
300 600 0.082874596118927 7.147561550140381
400 600 0.0808338075876236 4.517849922180176
500 600 0.34115099906921387 6.846116542816162
600 600 0.0476398840546608 6.767890453338623
100 600 0.17276516556739807 3.8496320247650146
200 600 0.2278597503900528 4.556034088134766
300 600 0.25600844621658325 4.8330559730529785
400 600 0.44862455129623413 5.88062858581543
500 600 0.6797982454299927 3.5120224952697754
600 600 0.7296677231788635 3.0891616344451904
100 600 0.3494322896003723 3.349128484725952
200 600 0.5206717252731323 5.005690097808838
300 600 0.6268661618232727 4.6981682777404785
400 600 0.21782223880290985 4.5921

In [17]:
images.size()

torch.Size([100])