# Conditional variational autoencoder (using the VAE class)

# Setup

In [1]:
!python --version


Python 3.6.7


In [2]:
%%bash
git clone https://github.com/masa-su/pixyz.git

fatal: destination path 'pixyz' already exists and is not an empty directory.


In [3]:
# http://pytorch.org/
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
import torch
!pip install tensorboardX
!pip install -e pixyz --process-dependency-links
torch.cuda.is_available()

Obtaining file:///content/pixyz
[33mDEPRECATION: Dependency Links processing has been deprecated and will be removed in a future release. A possible replacement is PEP 508 URL dependencies. You can find discussion regarding this at https://github.com/pypa/pip/issues/4187.[0m
Installing collected packages: pixyz
  Found existing installation: pixyz 0.0.2
    Can't uninstall 'pixyz'. No files were found to uninstall.
  Running setup.py develop for pixyz
Successfully installed pixyz


True

# CVAE

In [0]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [0]:
from pixyz.distributions import Normal, Bernoulli
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE

In [0]:
x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x,y)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x","y"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], 1)))
        h = F.relu(self.fc2(h))        
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z,y)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
#         print(z)
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
# class Generator(Bernoulli):
#     def __init__(self):
#         super(Generator, self).__init__(cond_var=["z","y"], var=["x"], name="p")

#         self.fc1 = nn.Linear(z_dim+y_dim, 512)
#         self.fc2 = nn.Linear(512, 512)
#         self.fc3 = nn.Linear(512, x_dim)

#     def forward(self, z, y):
#         print(z)
#         print(y)
#         h = F.relu(self.fc1(torch.cat([z, y], 1)))
#         h = F.relu(self.fc2(h))
#         return {"probs": torch.sigmoid(self.fc3(h))}

    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [5]:
p = Generator()
q = Inference()

p.to(device)
q.to(device)

print(p)
print(q)

Distribution:
  p(x|z) (Bernoulli)
Network architecture:
  Generator(
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )
Distribution:
  q(z|x,y) (Normal)
Network architecture:
  Inference(
    (fc1): Linear(in_features=794, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=64, bias=True)
    (fc32): Linear(in_features=512, out_features=64, bias=True)
  )


In [6]:
kl = KullbackLeibler(q, prior)
print(kl)

KL[q(z|x,y)||p_prior(z)]


In [7]:
model = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  q(z|x,y), p(x|z) 
Loss function: 
  mean(-E_q(z|x,y)[log p(x|z)] + KL[q(z|x,y)||p_prior(z)]) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [0]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
#         print(x.size())
#         print(y.size())
        loss = model.train({"x": x, "y": y})
        train_loss += loss
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [10]:
len(train_loader.dataset)

60000

In [0]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [0]:
def plot_reconstrunction(x, y):
    with torch.no_grad():
        z = q.sample({"x": x, "y": y}, return_all=False)
        z.update({"y": y})
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        recon = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return recon
    
def plot_image_from_latent(z, y):
    with torch.no_grad():
        sample = p.sample_mean({"z": z, "y": y}).view(-1, 1, 28, 28).cpu()
        return sample
    
def plot_reconstrunction_changing_y(x, y):
    y_change = torch.eye(10)[range(7)].to(device)
    batch_dummy = torch.ones(x.size(0))[:, None].to(device)    
    recon_all = []
    
    with torch.no_grad():
        for _y in y_change:
            z = q.sample({"x": x, "y": y}, return_all=False)
            z.update({"y": batch_dummy * _y[None,:]})
            recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
            recon_all.append(recon_batch)
    
        recon_changing_y = torch.cat(recon_all)
        recon_changing_y = torch.cat([x.view(-1, 1, 28, 28), recon_changing_y]).cpu()
        return recon_changing_y

In [13]:
writer = SummaryWriter()

plot_number = 1

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
y_sample = torch.eye(10)[[plot_number]*64].to(device)

_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
#     recon = plot_reconstrunction(_x[:8], _y[:8])
#     sample = plot_image_from_latent(z_sample, y_sample)
#     recon_changing_y = plot_reconstrunction_changing_y(_x[:8], _y[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
#     writer.add_image('Image_from_latent', sample, epoch)
#     writer.add_image('Image_reconstrunction', recon, epoch)
#     writer.add_image('Image_reconstrunction_change_y', recon_changing_y, epoch)
    
writer.close()

100%|██████████| 469/469 [00:09<00:00, 48.68it/s]

Epoch: 1 Train loss: 178.6360



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 141.5825


100%|██████████| 469/469 [00:09<00:00, 48.71it/s]

Epoch: 2 Train loss: 129.6660



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 122.3040


100%|██████████| 469/469 [00:09<00:00, 47.79it/s]

Epoch: 3 Train loss: 117.7766



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.0121


100%|██████████| 469/469 [00:09<00:00, 47.95it/s]

Epoch: 4 Train loss: 112.7396



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 111.4485


100%|██████████| 469/469 [00:09<00:00, 47.82it/s]

Epoch: 5 Train loss: 109.9924



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 109.1090


100%|██████████| 469/469 [00:09<00:00, 48.01it/s]

Epoch: 6 Train loss: 108.1604



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.8838


100%|██████████| 469/469 [00:10<00:00, 46.90it/s]

Epoch: 7 Train loss: 106.7672



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.9981


100%|██████████| 469/469 [00:09<00:00, 47.74it/s]

Epoch: 8 Train loss: 105.6078



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.2075


100%|██████████| 469/469 [00:15<00:00, 31.21it/s]


Epoch: 9 Train loss: 104.6505


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.8227


100%|██████████| 469/469 [00:17<00:00, 26.83it/s]

Epoch: 10 Train loss: 103.6828





Test loss: 104.5401
