In [1]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt

In [2]:
batch_size = 64
learning_rate = 0.001
num_epoch = 10
hidden_size = 2
hidden_size_in = 256

# 데이터 로드

In [3]:
mnist_train = dset.MNIST("./", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_test = dset.MNIST("./", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)

In [4]:
train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size, shuffle=True,num_workers=2,drop_last=True)
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size, shuffle=False,num_workers=2,drop_last=True)

# 모델 정의

In [5]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.fc1 = nn.Linear(794, hidden_size_in)
        self.fc1_1 = nn.Linear(hidden_size_in, hidden_size)
        self.fc1_2 = nn.Linear(hidden_size_in, hidden_size)
        self.relu = nn.ReLU()
                        
    def encode(self,x):
        
        x = self.relu(self.fc1(x))
        mu = self.fc1_1(x)
        log_var = self.fc1_2(x)
                
        return mu,log_var
    
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        
        eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps).cuda()
        
        return eps.mul(std).add_(mu)
    
    def forward(self,x):
        mu, logvar = self.encode(x)
        reparam = self.reparametrize(mu,logvar)
        
        return mu,logvar,reparam
        
encoder = Encoder().cuda()

In [6]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.fc2 = nn.Linear(hidden_size+10, 256)
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size_in, 784)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,x):
        out = self.fc2(x)
        out = self.relu(out)
        out = self.fc3(out)
        out = self.sigmoid(out)
        out = out.view(batch_size,28,28,1)
        
        return out
        
decoder = Decoder().cuda()

# Q는 인코딩

# P는 디코딩

# one_hot은 label을 one-hot으로

In [7]:
def Q(x,c):
    x = x.view(batch_size,-1)
    
    inputs = torch.cat([x,c],1)
    #print(inputs)
    mu,log_var,reparam = encoder(inputs)
    return mu,log_var,reparam

In [8]:
def P(z,c):
    inputs=torch.cat([z,c],1)
    sample_x = decoder(inputs)
    return sample_x

In [9]:
def one_hot(label,size):
    c = torch.zeros(batch_size,size).cuda()
    if type(label) == int:
        c = torch.zeros(1,size).cuda()
        c[0][label] = 1
        return c
    for i, v in enumerate(label) :
        c[i][v] = 1
    return c

# loss 정의

In [10]:
def loss_function(recon_x, x, mean, log_var):
    BCE = torch.nn.functional.binary_cross_entropy(recon_x, x, size_average=False)

    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return BCE , KLD

In [12]:
parameters = list(encoder.parameters())+ list(decoder.parameters())
optimizer = torch.optim.Adam(parameters, lr=learning_rate)

# 학습

## losses, bces, klds는 loss확인용

In [15]:
try:
    encoder, decoder = torch.load('./model/conditional_variational_autoencoder_.pkl')
    print("\n--------model restored--------\n")
except:
    print("\n--------model not restored--------\n")
    pass
losses = []
bces = []
klds = []
for i in range(num_epoch):
    for j,[image,label] in enumerate(train_loader):
        c = one_hot(label,size=10)
        optimizer.zero_grad()
        
        c = Variable(c).cuda()
        image = Variable(image).cuda()

        mu,log_var,reparam = Q(image,c)
        #print(reparam.size())
        output = P(reparam,c)
        #print(output.size())
        BCE , KLD = loss_function(output, image, mu, log_var)
        loss = BCE + KLD
        loss.backward()
        optimizer.step()

        if j % 100 == 0:
            #torch.save([encoder,decoder],'./model/conditional_variational_autoencoder1.pkl')
            losses.append(float(loss.cpu().data)/batch_size)
            bces.append(float(BCE.cpu().data)/batch_size)
            klds.append(float(KLD.cpu().data)/batch_size)
            np.savez(L=losses,file='loss.npz')
            np.savez(L=bces,file='bces.npz')
            np.savez(L=klds,file='klds.npz') 
            print('epoch: '+str(i)+' loss: '+str(float(loss.cpu().data)))


--------model not restored--------



  "Please ensure they have the same size.".format(target.size(), input.size()))


epoch: 0 loss: 33778.47265625
epoch: 0 loss: 12839.0849609375
epoch: 0 loss: 11305.08984375


KeyboardInterrupt: 

# 결과확인

## z는 random, c는 0~9까지

In [None]:
batch_size = 1
cv = 9 #test_loader.dataset.test_labels[10]
fix=torch.randn([1, hidden_size])

In [None]:
batch_size = 1
for i in range(10):
    c = Variable(one_hot(i,size=10)).type(torch.FloatTensor).cuda()
    z = Variable(fix, volatile=False).cuda()    
    output = P(z,c)
    out_img = torch.squeeze(output.cpu().data)
    print(i)
    plt.imshow(out_img.numpy(),cmap='gray')
    plt.show()