In [0]:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms, datasets, models
from torchvision.utils import save_image
import os

In [0]:
# Select the GPU when CUDA is actually usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always truthy,
# which would pick 'cuda' even on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [0]:
# Training hyperparameters.
learning_rate = 2e-4  # step size handed to the Adam optimizers
batch_size = 100      # number of images per mini-batch
num_epochs = 30       # passes over the training set
n_noise = 100         # length of the noise vector fed to the generator

In [0]:
# Directory that receives the generated sample grids.
sample_dir = 'samples'

# exist_ok=True makes the cell idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(sample_dir, exist_ok=True)

In [0]:
# Per-image preprocessing: convert every MNIST image to a tensor.
transform = transforms.Compose([
  transforms.ToTensor(),
])

# MNIST training split, downloaded under ../../data/ when not already present.
dataset = datasets.MNIST(
  root='../../data/',
  train=True,
  transform=transform,
  download=True,
)

# Shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(
  dataset=dataset,
  batch_size=batch_size,
  shuffle=True,
)

  0%|          | 0/9912422 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


 99%|█████████▉| 9854976/9912422 [00:12<00:00, 433719.73it/s]

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz



0it [00:00, ?it/s][A
32768it [00:00, 258501.82it/s][A
[A
0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz



  1%|          | 16384/1648877 [00:00<00:11, 141820.35it/s][A
  3%|▎         | 49152/1648877 [00:00<00:09, 170516.32it/s][A
  5%|▌         | 90112/1648877 [00:00<00:07, 202685.40it/s][A
  8%|▊         | 139264/1648877 [00:00<00:06, 242333.65it/s][A
 12%|█▏        | 204800/1648877 [00:00<00:04, 292403.47it/s][A
 17%|█▋        | 278528/1648877 [00:00<00:03, 350574.94it/s][A
 22%|██▏       | 360448/1648877 [00:00<00:03, 419609.59it/s][A
 28%|██▊       | 458752/1648877 [00:00<00:02, 491686.91it/s][A
 34%|███▍      | 565248/1648877 [00:01<00:01, 574438.90it/s][A
 41%|████      | 679936/1648877 [00:01<00:01, 662001.75it/s][A
 49%|████▊     | 802816/1648877 [00:01<00:01, 699692.49it/s][A
 56%|█████▌    | 925696/1648877 [00:01<00:00, 803173.70it/s][A
 62%|██████▏   | 1024000/1648877 [00:01<00:00, 823384.22it/s][A
 68%|██████▊   | 1122304/1648877 [00:01<00:00, 861894.76it/s][A
 75%|███████▍  | 1228800/1648877 [00:01<00:00, 898021.28it/s][A
 81%|████████▏ | 1343488/1648877 [00:01

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
class Discriminator(nn.Module):
  """Fully connected network that scores a flattened 28*28 image.

  Three Linear layers (784 -> 256 -> 256 -> 1) with ReLU activations; the
  final Sigmoid maps the single output unit to a value between 0 and 1.
  """

  def __init__(self):
    super().__init__()
    in_features, hidden = 28 * 28, 256
    self.layer1 = nn.Sequential(
      nn.Linear(in_features, hidden),
      nn.ReLU(inplace=True),
      nn.Linear(hidden, hidden),
      nn.ReLU(inplace=True),
      nn.Linear(hidden, 1),
      nn.Sigmoid(),
    )

  def forward(self, x):
    """Return one score per row of `x` (each row a flattened image)."""
    return self.layer1(x)
  
  
class Generator(nn.Module):
  """Fully connected network that maps a noise vector to a flattened image.

  Three Linear layers (n_noise -> 256 -> 256 -> 784) with LeakyReLU(0.2)
  activations; the final Tanh keeps every output value within [-1, 1].
  The input width is read from the module-level `n_noise` at construction.
  """

  def __init__(self):
    super().__init__()
    hidden, out_features = 256, 28 * 28
    self.layer1 = nn.Sequential(
      nn.Linear(n_noise, hidden),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(hidden, hidden),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(hidden, out_features),
      nn.Tanh(),
    )

  def forward(self, x):
    """Return one flattened 784-value image per row of noise in `x`."""
    return self.layer1(x)

In [0]:
# Build the two networks (discriminator first, then generator) and move
# their parameters onto the selected device.
D = Discriminator()
D = D.to(device)
G = Generator()
G = G.to(device)

In [0]:
# Binary cross-entropy between the discriminator's output and the 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same settings.
d_optimizer, g_optimizer = (
  torch.optim.Adam(net.parameters(), lr=learning_rate, betas=(0.5, 0.999))
  for net in (D, G)
)

def zero_grad():
  """Reset the gradients held by the discriminator and generator optimizers."""
  for optimizer in (d_optimizer, g_optimizer):
    optimizer.zero_grad()
  
def norm(images):
  """Rescale values linearly via x -> 2*x - 1 (so 0 maps to -1 and 1 maps to 1)."""
  doubled = images * 2
  return doubled - 1
  
def denorm(images):
  """Undo `x -> 2*x - 1` via (x + 1) / 2, then clamp the result to [0, 1]."""
  return ((images + 1) / 2).clamp(0, 1)


In [0]:
# Adversarial training: every mini-batch performs one discriminator update
# followed by one generator update; after each epoch the most recent batch of
# generated images is written to `sample_dir`.
for epoch in range(num_epochs):
  for i, (images, _) in enumerate(data_loader):
    # Flatten each image to one row, move it to the training device and apply
    # norm() (x -> 2*x - 1).
    images = images.view(images.size(0), -1).to(device)
    images = norm(images)
    cur_batch = images.size(0)

    # Create the targets directly on the device instead of building them on
    # the CPU and copying them over on every iteration.
    real_labels = torch.ones(cur_batch, 1, device=device)
    fake_labels = torch.zeros(cur_batch, 1, device=device)
    #===================================================
    #                 Train Discriminator
    # ==================================================
    outputs = D(images)
    d_loss_real = criterion(outputs, real_labels)

    z = torch.randn(cur_batch, n_noise, device=device)
    # detach(): this step only updates D, so cutting the graph here avoids a
    # wasted backward pass through G (those gradients were computed and then
    # thrown away by the next zero_grad()).
    fake_image = G(z).detach()
    outputs = D(fake_image)
    d_loss_fake = criterion(outputs, fake_labels)

    d_loss = d_loss_real + d_loss_fake

    zero_grad()
    d_loss.backward()
    d_optimizer.step()

    #===================================================
    #                 Train Generator
    # ==================================================
    # Fresh noise; the generator is trained to make D label its output as real.
    z = torch.randn(cur_batch, n_noise, device=device)
    fake_image = G(z)
    outputs_fake = D(fake_image)
    g_loss = criterion(outputs_fake, real_labels)

    # g_loss.backward() also deposits gradients in D; they are cleared by the
    # zero_grad() call at the start of the next discriminator step.
    zero_grad()
    g_loss.backward()
    g_optimizer.step()

    if (i+1) % 200 == 0:
      print('Epoch [{}/{}], Step [{}/{}], d_loss [{:.4f} : {:.4f}], g_loss : {:.4f}'
            .format(epoch+1, num_epochs, i+1, len(data_loader), d_loss_real.item(), d_loss_fake.item(), g_loss.item()))

  # Save the last generated batch of the epoch as an image grid; the tensor is
  # detached because writing the image needs no autograd graph.
  fake_image = fake_image.view(fake_image.size(0), 1, 28, 28)
  save_image(denorm(fake_image.detach()), os.path.join(sample_dir, 'fake_images={}.png'.format(str(epoch).zfill(3))))

Epoch [1/100], Step [200/600], d_loss [0.4238 : 0.4978], g_loss : 1.0786


9920512it [00:30, 433719.73it/s]                             

Epoch [1/100], Step [400/600], d_loss [0.2877 : 0.3072], g_loss : 1.4526
Epoch [1/100], Step [600/600], d_loss [0.2537 : 0.5774], g_loss : 1.8779
Epoch [2/100], Step [200/600], d_loss [0.2351 : 0.3904], g_loss : 2.0214
Epoch [2/100], Step [400/600], d_loss [0.1665 : 0.4161], g_loss : 2.7382
Epoch [2/100], Step [600/600], d_loss [0.2484 : 0.2180], g_loss : 1.6600
Epoch [3/100], Step [200/600], d_loss [0.0729 : 0.3930], g_loss : 2.8472
Epoch [3/100], Step [400/600], d_loss [0.1772 : 0.1493], g_loss : 2.9024
Epoch [3/100], Step [600/600], d_loss [0.9166 : 0.0102], g_loss : 2.3262
Epoch [4/100], Step [200/600], d_loss [0.1069 : 0.1965], g_loss : 3.0008
Epoch [4/100], Step [400/600], d_loss [0.3030 : 0.0640], g_loss : 1.9908
Epoch [4/100], Step [600/600], d_loss [0.2475 : 0.1399], g_loss : 3.5527
Epoch [5/100], Step [200/600], d_loss [0.0767 : 0.3446], g_loss : 4.4673
Epoch [5/100], Step [400/600], d_loss [0.0419 : 0.5310], g_loss : 5.8923
Epoch [5/100], Step [600/600], d_loss [0.1108 : 0.1

KeyboardInterrupt: ignored