In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
from os.path import join, exists
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from pixyz.distributions import Normal, Bernoulli, Deterministic
from pixyz.losses import KullbackLeibler, CrossEntropy, AdversarialKullbackLeibler, Parameter
from pixyz.models import Model

from models import *
from utils import *

batch_size = 128
SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch so DataLoader shuffling and weight initialization are reproducible across runs.
torch.manual_seed(SEED)

# Load dataset
dataset_zip = np.load('../data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
imgs = dataset_zip['imgs']
# Insert a channel axis (imgs[:, None]) and cast to float32 for the network.
imgs = imgs[:,None,:,:].astype("float32")

# Fixed 1000-image held-out test set; random_state makes the split deterministic.
train_imgs, test_imgs = train_test_split(imgs, random_state=SEED, test_size=1000)

# NOTE(review): `transform` is not passed to the loaders below (the arrays are batched directly);
# kept only for reference.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Use the configured batch_size instead of repeating the literal.
train_loader = torch.utils.data.DataLoader(train_imgs, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_imgs, batch_size=batch_size, shuffle=False)

# training

In [None]:
# Train one capacity-annealed beta-VAE per gamma and save the encoder/decoder weights.
# Loop-invariant configuration (identical for every gamma):
z_dim = 10
epoch_num = 15
N = len(train_loader)  # number of batches per epoch
C_max = 25  # value the capacity C approaches at the end of training
log_dir = "./logs/"
# torch.save does not create missing directories.
os.makedirs(log_dir, exist_ok=True)

for gamma in [100, 500]:
    # prior model p(z) = Normal(0, 1) over z_dim dimensions
    loc = torch.tensor(0.).to(device)
    scale = torch.tensor(1.).to(device)
    prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

    E = Encoder(z_dim=z_dim).to(device) # q(z|x)
    D = Decoder(z_dim=z_dim).to(device) # p(x|z)

    # loss = reconstruction + gamma * |KL(q(z|x) || p(z)) - C|, with C fed in per step
    reconst = CrossEntropy(E, D)
    kl = KullbackLeibler(E, prior)
    C = Parameter("C")
    loss_cls = reconst.mean() + gamma*(kl.mean()-C).abs()
    model = Model(loss_cls, distributions=[E, D], optimizer=optim.Adam, optimizer_params={"lr":5e-4})

    loss_list = []
    for epoch in range(epoch_num):
        for batch_idx, x in enumerate(tqdm(train_loader, total=N)):
            x = x.to(device)
            # Linear schedule: C = 0 at the first step, approaching C_max at the last step.
            C_ = C_max*(batch_idx+epoch*N)/(epoch_num*N)
            loss = model.train({"x": x, "C": C_})
            # Store a Python float: plt.plot cannot convert a list of CUDA tensors to numpy.
            loss_list.append(loss.item())
        plt.plot(loss_list)
        plt.show()
        encoder_plot(test_loader, E, D)
        traverse_plot(test_loader, E, D, 1)

    experiment_name = "betavae_C_dsprites_z_dim{}_gamma{}".format(z_dim, gamma)
    torch.save(E.state_dict(), join(log_dir, 'E_{}.pkl'.format(experiment_name)))
    torch.save(D.state_dict(), join(log_dir, 'D_{}.pkl'.format(experiment_name)))

2576it [01:06, 38.83it/s]

# visualize

In [None]:
# Rebuild the encoder/decoder and load previously saved weights for visualization.
z_dim=10
gamma = 100

log_dir = "./logs/"
# Must match the file naming used when the weights were saved (z_dim and gamma).
# The original formatted an undefined name `beta` here, which raised NameError on a fresh kernel.
experiment_name = "betavae_C_dsprites_z_dim{}_gamma{}".format(z_dim, gamma)
E = Encoder(z_dim=z_dim).to(device) # q(z|x)
D = Decoder(z_dim=z_dim).to(device) # p(x|z)
# map_location lets weights saved from a GPU run be loaded on a CPU-only machine.
E.load_state_dict(torch.load(join(log_dir, 'E_{}.pkl'.format(experiment_name)), map_location=device))
D.load_state_dict(torch.load(join(log_dir, 'D_{}.pkl'.format(experiment_name)), map_location=device))

In [None]:
# Latent traversal plot for the loaded model.
# NOTE(review): the positional `1` and `scale=10` are passed straight to utils.traverse_plot;
# presumably `scale` controls the traversal range — confirm in utils.
traverse_plot(
    test_loader,
    E,
    D,
    1,
    scale=10,
)

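gifをつくる (make GIFs) — the first argument of `make_gif` selects the dSprites shape:  
0: 四角 (square)  
1: 楕円 (ellipse)  
2: ハート (heart)  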



In [None]:
# One GIF per dSprites shape index (0: square, 1: ellipse, 2: heart).
# Use the loaded model's z_dim rather than repeating the literal 10.
for shape in (0, 1, 2):
    make_gif(shape, test_loader, E, D, experiment_name, m=30, scale=3, z_dim=z_dim)

In [None]:
# Dev helper: pick up edits to models.py / utils.py without restarting the kernel.
# `imp` is deprecated (and removed in Python 3.12); importlib.reload is its replacement.
from importlib import reload
import models
import utils
reload(models)
reload(utils)
# Re-run the star imports so the notebook namespace is re-bound to the reloaded definitions.
from models import *
from utils import *