In [0]:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms, datasets, models
from torchvision.utils import save_image
import os

In [0]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')

In [0]:
learning_rate= 2e-4
batch_size = 100
num_epochs = 30

In [0]:
sample_dir = 'samples'

if not os.path.exists(sample_dir):
  os.makedirs(sample_dir)

In [0]:
transform = transforms.Compose([transforms.ToTensor()])

dataset = datasets.MNIST(root='../../data/', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [0]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.conv = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1, bias=False), 
                                nn.BatchNorm2d(64), # (n, 64, 14, 14)
                                nn.ReLU(True),
                                nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
                                nn.BatchNorm2d(128), # (n, 128, 7, 7)
                                nn.ReLU(True),
                                nn.Conv2d(128, 256, kernel_size =3, stride=2, padding=1, bias=False),
                                nn.BatchNorm2d(256), # (n, 256, 4, 4)
                                nn.ReLU(True),
                                nn.AvgPool2d(4))
    
    self.fc = nn.Sequential(nn.Linear(256,1),
                            nn.Sigmoid())
    
  def forward(self, x):
    out = self.conv(x)
    out = out.view(out.size(0), -1)
    out = self.fc(out)
    
    return out
  
  
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    
    self.layer1 = nn.Sequential(nn.Linear(64, 256*4*4))
    
    self.convT = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 1, 0), # (n, 128, 7, 7)
                              nn.BatchNorm2d(128),
                              nn.LeakyReLU(0.2, True),
                              nn.ConvTranspose2d(128, 64, 4, 2, 1),  # (n, 64, 14, 14)
                              nn.BatchNorm2d(64),
                              nn.LeakyReLU(0.2, True),
                              nn.ConvTranspose2d(64, 1, 4, 2, 1),    # (n, 1, 28, 28)
                              nn.Tanh())
    
  def forward(self, x):
    out = self.layer1(x)
    
    out = out.view(out.size(0), 256, 4, 4)
    out = self.convT(out)
    return out

In [0]:
D = Discriminator().to(device)
G = Generator().to(device)

In [0]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)

def zero_grad():
  d_optimizer.zero_grad()
  g_optimizer.zero_grad()
  
def norm(images):
  return images*2 - 1
  
def denorm(images):
  images = (images+1) / 2
  return images.clamp(0,1)


In [0]:
for epoch in range(num_epochs):
  for i, (images, _) in enumerate(data_loader):
    images = images.to(device)
    images = norm(images)
    
    real_labels = torch.ones(images.size(0), 1).to(device)
    fake_labels = torch.zeros(images.size(0), 1).to(device)
    #===================================================
    #                 Train Discriminator
    # ==================================================
    outputs = D(images)
    d_loss_real = criterion(outputs, real_labels)
    
    z = torch.randn(images.size(0), 64).to(device)
    fake_image = G(z)
    outputs = D(fake_image)
    d_loss_fake = criterion(outputs, fake_labels)
    
    
    d_loss = d_loss_real + d_loss_fake
    
    zero_grad()
    d_loss.backward()
    d_optimizer.step()
    
    #===================================================
    #                 Train Generator
    # ==================================================
    z = torch.randn(images.size(0), 64).to(device)
    fake_image = G(z)
    outputs_fake = D(fake_image)
    g_loss = criterion(outputs_fake, real_labels)
    
    zero_grad()
    g_loss.backward()
    g_optimizer.step()
    
    
    if (i+1) % 200 == 0:
      print('Epoch [{}/{}], Step [{}/{}], d_loss [{:.4f} : {:.4f}], g_loss : {:.4f}'
            .format(epoch, num_epochs, i+1, len(data_loader), d_loss_real.item(), d_loss_fake.item(), g_loss.item()))

    
  fake_image = fake_image.view(fake_image.size(0), 1, 28, 28)
  save_image(denorm(fake_image), os.path.join(sample_dir, 'fake_images={}.png'.format(str(epoch).zfill(3))))

Epoch [0/1000], Step [200/600], d_loss [0.0395 : 0.0519], g_loss : 3.3601
Epoch [0/1000], Step [400/600], d_loss [0.0277 : 0.0739], g_loss : 3.3675
Epoch [0/1000], Step [600/600], d_loss [0.0212 : 0.0197], g_loss : 4.5442
Epoch [1/1000], Step [200/600], d_loss [0.0642 : 0.0803], g_loss : 2.9044
Epoch [1/1000], Step [400/600], d_loss [0.0309 : 0.0326], g_loss : 4.3790
Epoch [1/1000], Step [600/600], d_loss [0.0332 : 0.0674], g_loss : 2.9637
Epoch [2/1000], Step [200/600], d_loss [0.0271 : 0.0775], g_loss : 3.4714
Epoch [2/1000], Step [400/600], d_loss [0.0333 : 0.0618], g_loss : 3.3065
Epoch [2/1000], Step [600/600], d_loss [0.0266 : 0.0414], g_loss : 2.9159
Epoch [3/1000], Step [200/600], d_loss [0.0647 : 0.0069], g_loss : 5.2051
Epoch [3/1000], Step [400/600], d_loss [0.0616 : 0.0547], g_loss : 3.3958
Epoch [3/1000], Step [600/600], d_loss [0.0413 : 0.0493], g_loss : 3.8480
Epoch [4/1000], Step [200/600], d_loss [0.4211 : 0.0882], g_loss : 2.1586
Epoch [4/1000], Step [400/600], d_loss

KeyboardInterrupt: ignored