In [2]:
import torch
from torch.nn import init
from torch.autograd import Variable
import torchvision
from torchvision.utils import save_image
import matplotlib.pyplot as plt 
import numpy as np
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [3]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(100,128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128,256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),\
            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024,28*28),
            nn.Tanh()
            
        )
    def forward(self, x):
        x= self.linear(x)
        x = x.view(-1, 1,28,28)
        return x

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()      
        self.linear = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )
    def forward(self, img):
        img_flat = img.view(-1,28*28)
        validity = self.linear(img_flat)
        return validity

In [10]:
# build network
G = Generator()
D = Discriminator()
G.cuda()
D.cuda()

G_opt = torch.optim.RMSprop(G.parameters(), lr=0.00005)
D_opt = torch.optim.RMSprop(D.parameters(), lr=0.00005)

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
realimages = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader= torch.utils.data.DataLoader(realimages, batch_size=100,shuffle=True, num_workers=2)

In [11]:
fixed_z = torch.Tensor(100, 100).uniform_(-1, 1)
fixed_z = Variable(fixed_z.cuda())
for e in range(200):

    for n, (x_,_) in enumerate(train_loader):
        
        x_ = Variable(x_.cuda())
        #print(x_[0].shape)

        # run real input on Discriminator
        D_result_real = D(x_)

        # run Generator input on Discriminator
        z1_ = torch.Tensor(100 ,100).uniform_(-1, 1)
        z1_ = Variable(z1_.cuda())
        x_fake = G(z1_)
        D_result_fake = D(x_fake)
        
        D_loss = -(torch.mean(D_result_real) - torch.mean(D_result_fake))

        # optimize Discriminator
        D.zero_grad()
        D_loss.backward()
        D_opt.step() 
        
        #Genrator
        
        z2_ = torch.Tensor(100, 100).uniform_(-1, 1)
        z2_ = Variable(z2_.cuda())
        G_result = G(z2_)       
        output_fake = D(G_result)
        G_loss = -torch.mean(output_fake)

        G.zero_grad()
        G_loss.backward()
        G_opt.step()
        for p in D.parameters():
            p.data.clamp_(-0.01, 0.01)
 
    print("Epoch {}/{}...".format(e+1, 100),"Discriminator Loss: {:.4f}...".format(D_loss.data[0]),"Generator Loss: {:.4f}".format(G_loss.data[0])) 
    test_images = G(fixed_z)
    #print(test_images[0].shape)
    save_image(test_images.data,'./samples/WGAN/output.png'.format(e),nrow=10)

Epoch 1/100... Discriminator Loss: -0.0132... Generator Loss: -4.8579




Epoch 2/100... Discriminator Loss: -0.0103... Generator Loss: -2.9883
Epoch 3/100... Discriminator Loss: -0.0070... Generator Loss: -1.8536
Epoch 4/100... Discriminator Loss: -0.0122... Generator Loss: -0.7711
Epoch 5/100... Discriminator Loss: -0.0101... Generator Loss: -0.5260
Epoch 6/100... Discriminator Loss: -0.0140... Generator Loss: -0.3988
Epoch 7/100... Discriminator Loss: -0.0167... Generator Loss: -0.3485
Epoch 8/100... Discriminator Loss: -0.0190... Generator Loss: -0.1979
Epoch 9/100... Discriminator Loss: -0.0272... Generator Loss: -0.2762
Epoch 10/100... Discriminator Loss: -0.0191... Generator Loss: -0.2125
Epoch 11/100... Discriminator Loss: -0.0191... Generator Loss: -0.2609
Epoch 12/100... Discriminator Loss: -0.0182... Generator Loss: -0.2196
Epoch 13/100... Discriminator Loss: -0.0203... Generator Loss: -0.1684
Epoch 14/100... Discriminator Loss: -0.0155... Generator Loss: -0.1857
Epoch 15/100... Discriminator Loss: -0.0192... Generator Loss: -0.1846
Epoch 16/100..

Epoch 118/100... Discriminator Loss: -0.0006... Generator Loss: -0.0575
Epoch 119/100... Discriminator Loss: -0.0024... Generator Loss: -0.0088
Epoch 120/100... Discriminator Loss: -0.0013... Generator Loss: -0.0500
Epoch 121/100... Discriminator Loss: 0.0010... Generator Loss: 0.0249
Epoch 122/100... Discriminator Loss: -0.0007... Generator Loss: -0.0181
Epoch 123/100... Discriminator Loss: -0.0007... Generator Loss: -0.0087
Epoch 124/100... Discriminator Loss: -0.0008... Generator Loss: -0.0084
Epoch 125/100... Discriminator Loss: -0.0012... Generator Loss: -0.0018
Epoch 126/100... Discriminator Loss: -0.0008... Generator Loss: 0.0090
Epoch 127/100... Discriminator Loss: 0.0032... Generator Loss: -0.0148
Epoch 128/100... Discriminator Loss: -0.0035... Generator Loss: -0.1170
Epoch 129/100... Discriminator Loss: -0.0025... Generator Loss: -0.0229
Epoch 130/100... Discriminator Loss: 0.0024... Generator Loss: -0.0864
Epoch 131/100... Discriminator Loss: -0.0008... Generator Loss: -0.05