In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Image pipeline: convert to a tensor, then normalise the single channel with
# mean 0.5 and std 0.5 (eval() undoes this with `* 0.5 + 0.5` before plotting).
_to_tensor = transforms.ToTensor()
_normalise = transforms.Normalize((0.5,),(0.5,))
preprocess = transforms.Compose([_to_tensor, _normalise])

In [0]:
class Discriminator(nn.Module):
  """Convolutional critic: maps a 1-channel 28x28 image to one unbounded score.

  The network is four (Conv2d -> BatchNorm2d -> LeakyReLU) stages followed by
  a final Conv2d that collapses the remaining 3x3 map to a single value.
  No sigmoid is applied, so the output is a raw score rather than a probability.
  """

  def __init__(self):
    super().__init__()

    # (in_channels, out_channels, kernel_size, stride) per hidden stage.
    # Trailing comments give the spatial size for a 28x28 input.
    hidden_stages = [
        (1,  16,  (2,2), (2,2)),  # 28 -> 14
        (16, 32,  (2,2), (2,2)),  # 14 -> 7
        (32, 64,  (3,3), (1,1)),  # 7  -> 5
        (64, 128, (3,3), (1,1)),  # 5  -> 3
    ]

    layers = []
    for in_ch, out_ch, kernel, stride in hidden_stages:
      layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride))
      layers.append(nn.BatchNorm2d(out_ch))
      layers.append(nn.LeakyReLU())

    # Scoring head: 3 -> 1, one channel, no normalisation or activation.
    layers.append(nn.Conv2d(128, 1, kernel_size=(3,3), stride=(1,1)))

    # Kept as a flat Sequential named `model` so parameter names stay
    # `model.<index>.*`, which is what saved checkpoints use.
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the critic scores for `x`, flattened to one column per value."""
    scores = self.model(x)
    return scores.view(-1,1)

In [0]:
class Generator(nn.Module):
  """Maps a latent vector of length `nz` to a 1-channel 28x28 image in [-1, 1].

  A linear layer expands the noise to a 256x7x7 feature map, which two
  stride-2 transposed convolutions upsample to 14x14 and then 28x28.
  The final Tanh bounds every pixel to [-1, 1].
  """

  # Shape (channels, height, width) of the feature map produced by `linear`.
  _SEED_SHAPE = (256, 7, 7)

  def __init__(self,nz):
    super().__init__()
    channels, height, width = self._SEED_SHAPE
    self.linear = nn.Linear(nz, channels*height*width)

    upsample_to_14 = [
        nn.ConvTranspose2d(channels, 128, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ]
    upsample_to_28 = [
        nn.ConvTranspose2d(128, 1, kernel_size=(4,4), stride=(2,2), padding=(1,1)),
        nn.Tanh(),
    ]
    # Flat Sequential named `model`, so parameter names stay `model.<index>.*`.
    self.model = nn.Sequential(*upsample_to_14, *upsample_to_28)

  def forward(self, noise):
    """Generate one image per row of `noise`."""
    n_samples = noise.size(0)
    feature_map = self.linear(noise).view(n_samples, *self._SEED_SHAPE)
    return self.model(feature_map)

In [0]:
def train_with_wasserstein(dataloader, noise_size, discriminator, generator,D_optimizer, G_optimizer,c=0.01,n_critic=5):
  """Run one WGAN generator iteration: `n_critic` critic updates, then one generator update.

  Args:
    dataloader: iterable yielding `(images, targets)` batches; targets are ignored.
    noise_size: length of each latent vector fed to the generator.
    discriminator: the critic module; its parameters are clamped to [-c, c]
      after every critic step.
    generator: the generator module.
    D_optimizer: optimizer over the critic's parameters.
    G_optimizer: optimizer over the generator's parameters.
    c: weight-clipping bound applied to every critic parameter.
    n_critic: number of critic updates per generator update (must be >= 1).

  Returns:
    A tuple `(d_loss, g_loss)`: the mean over the critic updates of
    `mean(D(real)) - mean(D(fake))`, and the generator loss `-mean(D(fake))`
    from the single generator update.

  Raises:
    ValueError: if `n_critic` is smaller than 1.
  """
  if n_critic < 1:
    # The generator step reuses the batch size of the last critic batch,
    # so at least one critic update is required.
    raise ValueError("n_critic must be at least 1, got {}".format(n_critic))

  # Put inputs wherever the critic's weights live instead of relying on a
  # module-level global.
  model_device = next(discriminator.parameters()).device

  def sample_noise(n_samples):
    # Uniform noise in [0, 1), the same distribution eval() samples from.
    return torch.rand(n_samples, noise_size, device=model_device)

  critic_estimates = []
  batch_size = None
  for _ in range(n_critic):
    # A fresh iterator is built for every draw, so each critic step sees the
    # first batch of a new pass over `dataloader`.
    img, _target = next(iter(dataloader))
    img = img.to(model_device)
    batch_size = img.size(0)

    # The critic update needs no generator gradients, so skip building and
    # back-propagating through the generator graph.
    with torch.no_grad():
      fake = generator(sample_noise(batch_size))

    D_optimizer.zero_grad()
    d_real = discriminator(img)
    d_fake = discriminator(fake)

    # The critic maximises mean(D(real)) - mean(D(fake)); minimise its negative.
    wasserstein_estimate = torch.mean(d_real) - torch.mean(d_fake)
    critic_estimates.append(wasserstein_estimate.item())
    (-wasserstein_estimate).backward()
    D_optimizer.step()

    # Weight clipping keeps the critic's parameters inside [-c, c].
    with torch.no_grad():
      for p in discriminator.parameters():
        p.clamp_(-c, c)

  # Generator step: push the critic's score on generated samples upwards.
  G_optimizer.zero_grad()
  d_fake = discriminator(generator(sample_noise(batch_size)))
  G_loss = -torch.mean(d_fake)
  G_loss.backward()
  G_optimizer.step()

  # Leave both models without stale gradients (the generator backward pass
  # also accumulated gradients into the critic).
  D_optimizer.zero_grad()
  G_optimizer.zero_grad()

  return np.mean(np.asarray(critic_estimates)), G_loss.item()
    

In [0]:
def train(n_iter=10000, batch_size=64, noise_size=100, lr=5e-5, log_every=500):
  """Train the WGAN critic and generator on MNIST and checkpoint their weights.

  Called with no arguments this reproduces the original fixed configuration
  (10000 generator iterations, batch size 64, noise size 100, RMSprop with
  lr 5e-5, logging and checkpointing every 500 iterations).

  Args:
    n_iter: number of generator iterations to run.
    batch_size: MNIST batch size used by the data loader.
    noise_size: length of the latent vector fed to the generator. eval()
      builds its generator with noise size 100, so checkpoints trained with a
      different value will not load there.
    lr: learning rate for both RMSprop optimizers.
    log_every: print the losses and save both state dicts every this many
      generator iterations (must be >= 1).

  Side effects:
    Downloads MNIST under "./" if needed and writes the state dicts to
    "D_params" and "G_params" in the working directory.

  Raises:
    ValueError: if `log_every` is smaller than 1.
  """
  if log_every < 1:
    raise ValueError("log_every must be at least 1, got {}".format(log_every))

  # Models, both in training mode on the selected device.
  discriminator = Discriminator()
  discriminator = discriminator.to(device)
  discriminator.train()
  # To resume: discriminator.load_state_dict(torch.load("D_params"))
  generator = Generator(noise_size)
  generator = generator.to(device)
  generator.train()
  # To resume: generator.load_state_dict(torch.load("G_params"))

  # Data.
  trainset = torchvision.datasets.MNIST(root="./",transform=preprocess,download=True)
  dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

  # Optimizers.
  D_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)
  G_optimizer = optim.RMSprop(generator.parameters(), lr=lr)

  def save_checkpoints():
    torch.save(discriminator.state_dict(),"D_params")
    torch.save(generator.state_dict(),"G_params")

  for gen_iter in range(1,n_iter+1):
    is_log_step = gen_iter % log_every == 0
    if is_log_step:
      print("gen_iter: [{}]/[{}]".format(gen_iter,n_iter))
    d_loss,g_loss = train_with_wasserstein(dataloader, noise_size, discriminator, generator, D_optimizer, G_optimizer)

    if is_log_step:
      print("d_loss:[{}] g_loss:[{}]".format(d_loss,g_loss))
      save_checkpoints()

  # Persist the final weights when the last iteration was not a checkpoint
  # step; with the defaults the loop has already saved them.
  if n_iter > 0 and n_iter % log_every != 0:
    save_checkpoints()
  

In [0]:
def eval():
  """Sample one image from the saved generator and display it.

  Builds a Generator with noise size 100, loads its weights from "G_params"
  in the working directory, feeds it one uniform-noise vector and shows the
  result with matplotlib. Prints the shape of the generated array before
  plotting.

  Note: inside this notebook the name shadows Python's built-in `eval`.
  """
  noise_size = 100

  # Model, in inference mode on the selected device.
  generator = Generator(noise_size)
  generator = generator.to(device)
  # map_location remaps the stored tensors onto `device`, so a checkpoint
  # written during a CUDA run still loads on a CPU-only machine.
  generator.load_state_dict(torch.load("G_params", map_location=device))
  generator.eval()

  nz = torch.rand(1, noise_size).to(device)
  with torch.no_grad():
    outputs = generator(nz)
    outputs = outputs[0].to("cpu").numpy()

  # Channel-first (1, 28, 28) -> channel-last (28, 28, 1).
  outputs = outputs.transpose(1,2,0)
  print(outputs.shape)
  # Map the Tanh range [-1, 1] back to [0, 1] for display.
  outputs = outputs*0.5+0.5
  plt.imshow(outputs.reshape(28,28))
  plt.show()

In [0]:
if __name__ == "__main__":
    # Training is switched off here; only sample from the generator weights
    # already saved in "G_params". Re-enable the next line to train first.
    # train()
    eval()