## Import

In [2]:
import torch
from torch import nn
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torchvision.datasets as dsets
import itertools
from torch.autograd import Variable

In [3]:
batchSize = 100
z_size = 100
h_size = 128

# Data Set

In [4]:
transform = transforms.Compose([          
        transforms.ToTensor()
])

cudnn.benchmark = True

train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform) 
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


# Build Model

In [5]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.fc0 = nn.Sequential(
        
                nn.Linear(28*28,1024),
                nn.BatchNorm1d(1024),
                nn.LeakyReLU(0.1),
            
                nn.Linear(1024,512),
                nn.BatchNorm1d(512),
                nn.LeakyReLU(0.1),
            
                nn.Linear(512,256),
                nn.BatchNorm1d(256),
                nn.LeakyReLU(0.1),
            
        )
        
        self.fc1 = nn.Sequential(
                nn.Linear(256,100),
                nn.LeakyReLU(0.1)
        )
        
        self.fc2 = nn.Sequential(
                nn.Linear(256,100),
                nn.LeakyReLU(0.1)
        )
        
    def forward(self,x):
        x = x.view(batchSize,-1)
        x = self.fc0(x)
        z_mu = self.fc1(x)
        z_log_sigma = self.fc2(x)
        
        return z_mu, z_log_sigma
        
        

In [6]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.model = nn.Sequential(
        
                nn.Linear(100,256),
                nn.BatchNorm1d(256),
                nn.LeakyReLU(0.1),
            
                nn.Linear(256,512),
                nn.BatchNorm1d(512),
                nn.LeakyReLU(0.1),
            
                nn.Linear(512,1024),
                nn.BatchNorm1d(1024),
                nn.LeakyReLU(0.1),
            
                nn.Linear(1024,28*28),
                nn.Sigmoid()
        )
        
    def forward(self,x):
        x = self.model(x)
        return x
        
        

In [7]:
def sample_z(z_mu, z_log_sigma):
    epsilon = torch.FloatTensor(100*100).normal_(0,1).view((100,100))
    z_samples = z_mu + torch.mul(z_log_sigma, Variable(epsilon))
    
    return z_samples

In [8]:
encoder = Encoder()
decoder = Decoder()

In [12]:
criterion1 = nn.BCELoss()
criterion2 = nn.KLDivLoss()

In [13]:
optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(),decoder.parameters()),
                             lr=1e-4,
                             betas = (0.5,0.999)
                            )


In [14]:
niter = 20

for epoch in range(21,100):
    for i, (data,_) in enumerate(tqdm_notebook(data_loader)):
        encoder.zero_grad()
        decoder.zero_grad()
   
        data_v = Variable(data)
        z_mu, z_log_sigma = encoder(data_v)
        z_samples = sample_z(z_mu,z_log_sigma)
        fake = decoder(z_samples)
        
        reconstruction_loss = criterion1(fake,data_v)
        KLD_element = z_mu.pow(2).add_(z_log_sigma.pow(2).exp()).mul_(-1).add_(1).add_(z_log_sigma.pow(2))
        KLD = torch.sum(KLD_element).mul_(-0.5)
        
        loss = reconstruction_loss + KLD
        
        loss.backward()
        
        optimizer.step()
        
        if i%100 == 0:
            print("epoch: {}, step: {}, loss: {}".format(epoch+1,i+1,loss.data[0]))
         
        
        
        
    # 결고ㅏ 이미지 저장
    Z_v = Variable(torch.FloatTensor(8*8*z_size).normal_(0,1).view(64,-1))

    samples = decoder(Z_v).data.numpy()
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(8, 8)
    gs.update(wspace=0.05, hspace=0.05)
    for j, sample in enumerate(samples):
        ax = plt.subplot(gs[j])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    fig.savefig("test_imgs_{}_{}.png".format(epoch,i))


        
        

epoch: 22, step: 1, loss: 154.88479614257812



KeyboardInterrupt: 

In [None]:
torch.save(encoder.state_dict(),"encoder.pth")
torch.save(decoder.state_dict(),"decoder.pth")

In [61]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
% matplotlib inline

In [None]:
Z_v = Variable(torch.FloatTensor(8*8*z_size).normal_(0,1).view(64,-1))

samples = decoder(Z_v).data.numpy()
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(8, 8)
gs.update(wspace=0.05, hspace=0.05)
for j, sample in enumerate(samples):
    ax = plt.subplot(gs[j])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
fig.savefig("test_imgs_{}_{}.png".format(epoch,i))



#### 18th Epoch Image

<a href="https://imgur.com/Ri1HUA2"><img src="https://i.imgur.com/Ri1HUA2.png" title="source: imgur.com" /></a>