In [1]:
# import libraries
import torch
import torch.nn as nn


In [2]:
# define CVAE model
class Encoder(nn.Module):
    def __init__(self, input_size, cond_size, latent_size):
        super(Encoder, self).__init__()
        self.mu_network = nn.Sequential(
            nn.Linear(input_size+cond_size, 512), nn.PReLU(), nn.Dropout(),
            nn.Linear(512, 256), nn.PReLU(), nn.Dropout(),
            nn.Linear(256, 32), nn.PReLU(),
            nn.Linear(32, latent_size)
        )
        # we set the covariance matrix to be diag([sigma_1,...,sigma_k])
        self.log_sigma_pow2_network = nn.Sequential(
            nn.Linear(input_size+cond_size, 512), nn.PReLU(), nn.Dropout(),
            nn.Linear(512, 256), nn.PReLU(), nn.Dropout(),
            nn.Linear(256, 32), nn.PReLU(),
            nn.Linear(32, latent_size)
        )
        self.output_size = latent_size
        self.latent_size = latent_size
        self.input_size = input_size
        self.cond_size = cond_size
    def forward(self, x, y):
        # input tensor shape: BxK1, BxK2
        input = torch.cat([x,y], dim=1)
        mu = self.mu_network(input)
        log_sigma_pow2 = self.log_sigma_pow2_network(input)
        return mu, log_sigma_pow2

    def sample(self, mu, log_sigma_pow2, L):
        # given the computed mu, and sigma, obtain L samples by reparameterization
        # draw standard normal distribution
        # input: Bxk
        # return: LxBxk
        eps = torch.randn((L,len(mu),self.latent_size))
        if log_sigma_pow2.is_cuda:
            eps = eps.cuda()
        eps = eps * torch.exp(log_sigma_pow2/2)
        eps = eps + mu
        return eps

    def kl_divergence(self, mu, log_sigma_pow2):
        # given mu and log(sigma^2), obtain the KL divergence relative to N(0,I)
        # using formula from https://stats.stackexchange.com/questions/318748/deriving-the-kl-divergence-loss-for-vaes
        # input: BxK
        # output: B
        res = 1.0 / 2 * (-torch.sum(log_sigma_pow2, dim=1)-self.output_size+\
                         torch.sum(torch.exp(log_sigma_pow2), dim=1)+torch.sum(mu*mu, dim=1))
        return res


class Decoder(nn.Module):
    def __init__(self, latent_size, cond_size, output_size, sigma=0.1):
        super(Decoder, self).__init__()
        self.mu_network = nn.Sequential(
            nn.Linear(latent_size+cond_size, 512), nn.PReLU(), nn.Dropout(),
            nn.Linear(512, 256), nn.PReLU(), nn.Dropout(),
            nn.Linear(256, 32), nn.PReLU(),
            nn.Linear(32, output_size)
        )
        self.latent_size = latent_size
        self.cond_size = cond_size
        self.sigma = sigma
        self.sigma_2 = sigma*sigma

    def forward(self, z, y):
        input = torch.cat([z,y],dim=1)
        mu = self.mu_network(input)
        return mu

    def sample(self, mu):
        # use the computed mu to generate sample
        # input: BxN
        eps = torch.randn(mu.size())
        if mu.is_cuda:
            eps = eps.cuda()
        eps = eps * self.sigma + mu
        return eps

    def generation_loss(self, x, mu):
        # input:
        # - mu: LxBxN
        # - x: BxN
        # output: BxN
        # formula: 1/L * sum_z log(N(x; mu, sigma^2I))
        #       => -1/L \sum_z 1/2*(x-mu)^T(x-mu)/(sigma^2)
        res = - 1.0/2*(x-mu)*(x-mu)/self.sigma_2
        res = torch.sum(res, dim=2) # sum up (x-mu)^2
        # calculate the mean w.r.t. first dimension (L)
        res = torch.mean(res, dim=0)
        return res

class CVAE(nn.Module):
    def __init__(self, input_size, latent_size, cond_size):
        super(CVAE, self).__init__()
        self.encoder = Encoder(input_size, cond_size, latent_size)
        self.decoder = Decoder(latent_size, cond_size, input_size, sigma=0.1)
        self.input_size = input_size
        self.latent_size = latent_size
        self.cond_size = cond_size

    def train_forward(self, x, y, L=10):
        # get necessary signals from the input
        z_mu, z_log_sigma_pow2 = self.encoder(x, y)
        # generate samples of z using the mean and variance
        z = self.encoder.sample(z_mu, z_log_sigma_pow2, L)
        y_extended = y.repeat(len(z),1).view(-1,self.cond_size)
        z = z.view(-1,self.latent_size)  # B and L together first
        if x.is_cuda:
            z.cuda()
        # copy y so we have shape: LxBxc
        x_mu = self.decoder(z, y_extended).view(L,-1,self.input_size)
        return z_mu,z_log_sigma_pow2, z, x_mu

    def gen_forward(self, y):
        # randomly sample a latent z
        z = torch.randn(len(y),self.latent_size)
        if y.is_cuda:
            z = z.cuda()
        x_mean = self.decoder(z, y)
        x = self.decoder.sample(x_mean)
        return x


In [3]:
# unit testing for encoder
# test sample

# def sample(self, mu, log_sigma_pow2, L):
#     # given the computed mu, and sigma, obtain L samples by reparameterization
#     # draw standard normal distribution
#     # input: Bxk
#     # return: LxBxk
input_size = 20
cond_size = 2
latent_size = 5
encoder = Encoder(input_size, cond_size, latent_size)
mu = torch.ones((10,5))
sigma = torch.abs(torch.randn(10,5))
log_sigma_pow2 = torch.log(sigma*sigma)
L = 1000
z = encoder.sample(mu, log_sigma_pow2, L)


# test KL divergence




In [4]:
#print(z)
unit_testing = False
if unit_testing:
    print('estimated mean: ')
    print(torch.mean(z, dim=0))
    print('estimated std: ')
    print(torch.std(z, dim=0))
    print('true mean:')
    print(mu)
    print('true std:')
    print(sigma)

In [9]:
# training
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import defaultdict
import matplotlib.pyplot as plt
def idx2onehot(idx, n):

    assert torch.max(idx).item() < n
    if idx.dim() == 1:
        idx = idx.unsqueeze(1)

    onehot = torch.zeros(idx.size(0), n)
    onehot.scatter_(1, idx, 1)

    return onehot

def main():
    batch_size = 32
    input_size = 28*28
    latent_size = 16
    cond_size = 10
    learning_rate = 0.001
    num_epoch = 20
    print_every = 100

    device = torch.device('cuda:0')

    dataset = MNIST(
            root='data', train=True, transform=transforms.ToTensor(),
            download=True)
    data_loader = DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=True)
    
    cvae = CVAE(input_size, latent_size, cond_size)
    cvae.cuda()
    optimizer = torch.optim.Adam(cvae.parameters(), learning_rate)

    logs = defaultdict(list)
    print('start training...')
    for epoch in range(num_epoch):
        for iter, (x, y) in enumerate(data_loader):
            y = idx2onehot(y, n=10)
            x = x.view(-1,28*28)
            x = x.cuda()
            y = y.cuda()
            #x, y = x.to(device), y.to(device)
            z_mu,z_log_sigma_pow2, z, x_mu = cvae.train_forward(x, y, L=10)
            kl_divergence = cvae.encoder.kl_divergence(z_mu, z_log_sigma_pow2)
            generation_loss = cvae.decoder.generation_loss(x, x_mu)
            loss_i = -generation_loss + kl_divergence
            loss_i = torch.mean(loss_i)

            optimizer.zero_grad()
            loss_i.backward()
            optimizer.step()

            logs['loss'].append(loss_i.item())
            if iter % print_every == 0:
                print('epoch: %d, batch: %d, loss: %f' % (epoch, iter, loss_i.item()))
                # save the reconstructed inference

                y = torch.arange(0, 10).long().unsqueeze(1)
                y_onehot = idx2onehot(y, n=10)
                y_onehot = y_onehot.cuda()
                x = cvae.gen_forward(y_onehot)

                plt.figure()
                plt.figure(figsize=(5, 10))
                for p in range(10):
                    plt.subplot(5, 2, p+1)
                    plt.text(
                        0, 0, "y={:d}".format(y[p].item()), color='black',
                        backgroundcolor='white', fontsize=8)
                    plt.imshow(x[p].cpu().view(28, 28).data.numpy())
                    plt.axis('off')
                fig_root = 'plots'
                os.makedirs(os.path.join(fig_root), exist_ok=True)
                plt.savefig(
                    os.path.join(fig_root,
                                 "epoch{:d}batch{:d}.png".format(epoch, iter)),
                    dpi=300)
                plt.clf()
                plt.close('all')


In [10]:
main()

start training...
epoch: 0, batch: 0, loss: 4777.126953


TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.