In [None]:
import torch
from torch.autograd import Variable

In [None]:
def to_var(x):
  """Wrap a tensor in a Variable, placing it on the GPU first when CUDA is available.

  Args:
    x: tensor to wrap.

  Returns:
    Variable built from x (from its CUDA copy when a GPU is usable).
  """
  placed = x.cuda() if torch.cuda.is_available() else x
  return Variable(placed)

In [None]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [None]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalize with mean 0.5 / std 0.5 (denorm() later in this notebook inverts it).
# The statistics are given as one-element sequences (one entry per channel),
# which is the form Normalize documents; bare scalars are only accepted by
# torchvision releases that special-case 0-dim statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# MNIST training split stored under ./data; download is requested here.
train_dataset = dsets.MNIST(
    root='./data', train=True, transform=transform, download=True)

# MNIST test split read from the same root; no download is requested for it.
test_dataset = dsets.MNIST(
    root='./data', train=False, transform=transform)

In [None]:
import torch.utils.data as Data

In [None]:
# Number of samples per mini-batch requested from the DataLoader.
batch_size = 100

In [None]:
# Shuffled mini-batch iterator over the MNIST training split.
data_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [None]:
import torch.nn as nn

In [None]:
# Fully connected discriminator: flattened 28*28 image -> probability in (0, 1).
D = nn.Sequential(
    # input layer
    nn.Linear(in_features=28*28, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    # hidden layer
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    # single-unit head squashed by a sigmoid
    nn.Linear(in_features=256, out_features=1),
    nn.Sigmoid(),
)

In [None]:
# Fully connected generator: 64-dim latent vector -> flattened 28*28 image in [-1, 1].
G = nn.Sequential(
    # latent projection
    nn.Linear(in_features=64, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    # hidden layer
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    # pixel head; tanh bounds the output to [-1, 1]
    nn.Linear(in_features=256, out_features=28*28),
    nn.Tanh(),
)

In [None]:
# Put both fully connected networks on the GPU when one is usable.
# Module.cuda() moves the parameters in place and returns the module itself,
# so rebinding the names keeps them pointing at the same objects.
if torch.cuda.is_available():
  D, G = D.cuda(), G.cuda()


In [None]:
# Binary cross-entropy between discriminator outputs and real/fake labels.
loss_fn = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 3e-4.
d_opt = torch.optim.Adam(params=D.parameters(), lr=3e-4)
g_opt = torch.optim.Adam(params=G.parameters(), lr=3e-4)

In [None]:
from torchvision.utils import save_image

In [None]:
def denorm(x):
  """Rescale values from [-1, 1] to [0, 1], clipping anything that falls outside.

  Args:
    x: tensor to rescale.

  Returns:
    A new tensor equal to (x + 1) / 2 limited to [0, 1]; x itself is not modified.
  """
  return ((x + 1) / 2).clamp(0, 1)

In [None]:
# Train the fully connected GAN (D against G) on flattened MNIST images.
# Every iteration runs one discriminator update followed by one generator update;
# at the end of each epoch the last generated batch is written to ./data.
for epoch in range(200):
  for i, (images, _) in enumerate(data_loader):
    # Loop-local size: the notebook-level `batch_size` that configures the
    # DataLoader must not be overwritten by whichever batch happened to run last.
    n_samples = images.size(0)
    images = to_var(images.view(n_samples, -1))  # one flat vector per image

    real_labels = to_var(torch.ones(n_samples, 1))
    fake_labels = to_var(torch.zeros(n_samples, 1))

    # ---------------- discriminator update ----------------
    outputs = D(images)
    d_loss_real = loss_fn(outputs, real_labels)
    real_score = outputs

    z = to_var(torch.randn(n_samples, 64))
    # detach(): only D is updated in this step, so the graph back into G is cut.
    # This skips a useless backward pass through G and leaves no stale grads there.
    fake_images = G(z).detach()
    outputs = D(fake_images)
    d_loss_fake = loss_fn(outputs, fake_labels)
    fake_score = outputs

    d_loss = d_loss_real + d_loss_fake
    D.zero_grad()
    d_loss.backward()
    d_opt.step()

    # ---------------- generator update ----------------
    z = to_var(torch.randn(n_samples, 64))
    fake_images = G(z)
    outputs = D(fake_images)

    # G is rewarded when D labels its samples as real.
    g_loss = loss_fn(outputs, real_labels)
    D.zero_grad()
    G.zero_grad()
    g_loss.backward()
    g_opt.step()

    # Progress line at batches 271 and 571 of each epoch (i == 270, 570, ...).
    if (i+30)%300 ==0 :
      print("Epoch %d, batch %d, d_loss: %.4f, g_loss: %.4f,"
      "D(x): %.2f, D(G(z)): %.2f"
      %(epoch, i+1, d_loss.item(), g_loss.item(),
        real_score.mean().item(), fake_score.mean().item()))

  # Save a reference grid of real images once, then this epoch's generated samples.
  if (epoch ==0):
    images = images.view(images.size(0), 1, 28, 28)
    save_image(denorm(images), "./data/real_images.png")
  fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(denorm(fake_images), "./data/fake_images-%d.png"%(epoch+1))




Epoch 0, batch 271, d_loss: 0.9523, g_loss: 3.3826,D(x): 0.71, D(G(z)): 0.30
Epoch 0, batch 571, d_loss: 0.1914, g_loss: 3.3210,D(x): 0.97, D(G(z)): 0.15
Epoch 1, batch 271, d_loss: 1.6304, g_loss: 3.2802,D(x): 0.85, D(G(z)): 0.49
Epoch 1, batch 571, d_loss: 0.7206, g_loss: 6.0111,D(x): 0.77, D(G(z)): 0.11
Epoch 2, batch 271, d_loss: 0.3023, g_loss: 4.5459,D(x): 0.88, D(G(z)): 0.09
Epoch 2, batch 571, d_loss: 1.6474, g_loss: 1.0663,D(x): 0.56, D(G(z)): 0.50
Epoch 3, batch 271, d_loss: 1.9637, g_loss: 0.6577,D(x): 0.46, D(G(z)): 0.55
Epoch 3, batch 571, d_loss: 0.7762, g_loss: 1.6382,D(x): 0.74, D(G(z)): 0.28
Epoch 4, batch 271, d_loss: 0.3329, g_loss: 3.4993,D(x): 0.87, D(G(z)): 0.13
Epoch 4, batch 571, d_loss: 0.8928, g_loss: 2.8431,D(x): 0.77, D(G(z)): 0.27
Epoch 5, batch 271, d_loss: 0.5828, g_loss: 2.8135,D(x): 0.84, D(G(z)): 0.23
Epoch 5, batch 571, d_loss: 0.4490, g_loss: 2.1374,D(x): 0.87, D(G(z)): 0.22
Epoch 6, batch 271, d_loss: 0.7843, g_loss: 2.5340,D(x): 0.77, D(G(z)): 0.31

In [112]:
class DisCriminatoR(nn.Module):
  """Convolutional discriminator for 1x28x28 images.

  Two stride-2 5x5 convolutions (each followed by LeakyReLU and 2D dropout with
  p=0.3) reduce the input to 128 feature maps of size 7x7; a linear layer plus
  sigmoid then yields one probability per image.
  """

  def __init__(self):
    super().__init__()
    # 1 x 28 x 28 -> 64 x 14 x 14
    self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2, bias=True)
    # Activation and dropout modules are shared by both conv stages.
    self.leaky_relu = nn.LeakyReLU()
    self.dropout_2d = nn.Dropout2d(0.3)
    # 64 x 14 x 14 -> 128 x 7 x 7
    self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2, bias=True)
    # Flattened feature maps -> single logit.
    self.linear1 = nn.Linear(128*7*7, 1, bias=True)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return a (batch, 1) tensor of probabilities for the images in x."""
    feats = self.dropout_2d(self.leaky_relu(self.conv1(x)))
    feats = self.dropout_2d(self.leaky_relu(self.conv2(feats)))
    flat = feats.view(-1, 128*7*7)
    return self.sigmoid(self.linear1(flat))

In [113]:
class Generator(nn.Module):
  """Convolutional generator: latent vector -> 1x28x28 image with values in [-1, 1].

  A bias-free linear layer projects the latent vector to 256 maps of size 7x7,
  a 5x5 convolution reduces them to 128 maps, and two stride-2 transposed
  convolutions upsample 7x7 -> 14x14 -> 28x28. Batch normalization after the
  first three layers is optional.

  Args:
    latent_dim: size of the input noise vector.
    batchnorm: when False the batch-norm attributes are None and are skipped.
  """

  def __init__(self, latent_dim = 100, batchnorm=True):
    super().__init__()
    self.latent_dim = latent_dim
    self.batchnorm = batchnorm

    # latent -> 256 * 7 * 7 features
    self.linear1 = nn.Linear(latent_dim, 7*7*256, bias=False)
    self.bn1d1 = nn.BatchNorm1d(7*7*256) if batchnorm else None
    self.leaky_relu = nn.LeakyReLU()

    # 256 x 7 x 7 -> 128 x 7 x 7 (spatial size preserved)
    self.conv1 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
    self.bn2d1 = nn.BatchNorm2d(128) if batchnorm else None

    # 128 x 7 x 7 -> 64 x 14 x 14
    self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
    self.bn2d2 = nn.BatchNorm2d(64) if batchnorm else None

    # 64 x 14 x 14 -> 1 x 28 x 28
    self.conv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False)

    self.tanh = nn.Tanh()

  def forward(self, x):
    """Generate a (batch, 1, 28, 28) image tensor from latent vectors x."""
    h = self.linear1(x)
    h = self.bn1d1(h) if self.batchnorm else h
    h = self.leaky_relu(h).view(-1, 256, 7, 7)

    h = self.conv1(h)
    h = self.bn2d1(h) if self.batchnorm else h

    h = self.conv2(self.leaky_relu(h))
    h = self.bn2d2(h) if self.batchnorm else h

    return self.tanh(self.conv3(self.leaky_relu(h)))

In [114]:
# Instantiate the convolutional GAN pair; the generator arguments spell out
# the class defaults (100-dim latent input, batch normalization enabled).
DCG = Generator(latent_dim=100, batchnorm=True)
DCD = DisCriminatoR()

In [115]:
# Move the convolutional networks to the GPU only when CUDA is actually usable.
# `is_available` has to be *called*: the bare function object is always truthy,
# which forced .cuda() (and a failure) on CPU-only machines.
if torch.cuda.is_available():
  DCG.cuda()
  DCD.cuda()

In [116]:
# Adam optimizers for the convolutional pair: lr 2e-4 with beta1 lowered to 0.5.
dcd_opt = torch.optim.Adam(params=DCD.parameters(), lr=2e-4, betas=(0.5, 0.999))
dcg_opt = torch.optim.Adam(params=DCG.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [117]:
import os

# Train the convolutional GAN (DCD against DCG) on MNIST images.
# Every iteration runs one discriminator update followed by one generator update;
# at the end of each epoch the last generated batch is written to ./data1.

# Create the output directory once, up front, instead of re-checking every epoch.
os.makedirs('./data1', exist_ok=True)

for epoch in range(40):
  for i, (images, _) in enumerate(data_loader):
    # Loop-local size: the notebook-level `batch_size` that configures the
    # DataLoader must not be overwritten by whichever batch happened to run last.
    n_samples = images.size(0)
    images = to_var(images)

    real_labels = to_var(torch.ones(n_samples, 1))
    fake_labels = to_var(torch.zeros(n_samples, 1))

    # ===================== Train DCD =====================
    outputs = DCD(images)
    d_loss_real = loss_fn(outputs, real_labels)
    real_score = outputs

    z = to_var(torch.randn(n_samples, 100))
    # detach(): only DCD is updated in this step, so the graph back into DCG is
    # cut. This skips a useless backward pass through the generator.
    fake_images = DCG(z).detach()
    outputs = DCD(fake_images) # Discriminator's output on fake images
    d_loss_fake = loss_fn(outputs, fake_labels)
    fake_score = outputs

    d_loss = d_loss_real + d_loss_fake
    DCD.zero_grad()
    d_loss.backward()
    dcd_opt.step()

    # ===================== Train DCG =====================
    z = to_var(torch.randn(n_samples, 100))
    fake_images = DCG(z)
    outputs = DCD(fake_images) # Discriminator's output on newly generated fake images

    # DCG is rewarded when DCD labels its samples as real.
    g_loss = loss_fn(outputs, real_labels)
    DCD.zero_grad() # backward() below also writes grads into DCD; start them from zero
    DCG.zero_grad()
    g_loss.backward()
    dcg_opt.step()

    # Progress line at batches 271 and 571 of each epoch (i == 270, 570, ...).
    if (i+30)%300 ==0 :
      print("Epoch %d, batch %d, d_loss: %.4f, g_loss: %.4f,"
      "D(x): %.2f, D(G(z)): %.2f"
      %(epoch, i+1, d_loss.item(), g_loss.item(),
        real_score.mean().item(), fake_score.mean().item()))

  # Save a reference grid of real images once, then this epoch's generated samples.
  if (epoch ==0):
    images = images.view(images.size(0), 1, 28, 28)
    save_image(denorm(images), "./data1/real_images.png")
  fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
  save_image(denorm(fake_images), "./data1/fake_images-%d.png"%(epoch+1))

Epoch 0, batch 271, d_loss: 1.2430, g_loss: 0.8829,D(x): 0.54, D(G(z)): 0.45
Epoch 0, batch 571, d_loss: 1.3117, g_loss: 0.7868,D(x): 0.52, D(G(z)): 0.46
Epoch 1, batch 271, d_loss: 1.3092, g_loss: 0.7959,D(x): 0.54, D(G(z)): 0.49
Epoch 1, batch 571, d_loss: 1.3124, g_loss: 0.7757,D(x): 0.54, D(G(z)): 0.49
Epoch 2, batch 271, d_loss: 1.3311, g_loss: 0.7434,D(x): 0.51, D(G(z)): 0.47
Epoch 2, batch 571, d_loss: 1.2798, g_loss: 0.8129,D(x): 0.52, D(G(z)): 0.44
Epoch 3, batch 271, d_loss: 1.2858, g_loss: 0.8180,D(x): 0.54, D(G(z)): 0.47
Epoch 3, batch 571, d_loss: 1.3021, g_loss: 0.8229,D(x): 0.54, D(G(z)): 0.48
Epoch 4, batch 271, d_loss: 1.2674, g_loss: 0.8622,D(x): 0.55, D(G(z)): 0.47
Epoch 4, batch 571, d_loss: 1.2988, g_loss: 0.8576,D(x): 0.53, D(G(z)): 0.46
Epoch 5, batch 271, d_loss: 1.3347, g_loss: 0.8604,D(x): 0.55, D(G(z)): 0.50
Epoch 5, batch 571, d_loss: 1.2110, g_loss: 0.8440,D(x): 0.55, D(G(z)): 0.44
Epoch 6, batch 271, d_loss: 1.2382, g_loss: 0.9251,D(x): 0.55, D(G(z)): 0.44