### Preparation

In [None]:
import os
import time
import json
import copy
import argparse
import itertools
from os.path import join, exists, splitext, basename
from imp import reload
from glob import glob

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import RandomState
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader

from IPython.display import SVG
from IPython.display import display

import seaborn as sns

import Tars 
from Tars.distributions import RealNVP, Normal, Bernoulli
from Tars.models import ML
from Tars.utils import get_dict_values
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

from utils_tars import *
from model_tars_256 import *
from imp import reload

from solver import Solver
from data_loader_sparse import get_loader

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Read the attribute names from the second line of the CelebA annotation file.
# `with` guarantees the file is closed even if reading fails.
with open('./data/list_attr_celeba.txt') as f:
    lines2 = f.readlines()  # every line of the file, newline characters included
# str.split() with no argument splits on any whitespace and drops the trailing
# newline, so nothing has to be sliced off by hand. split(" ")[:-1] would
# silently lose the last attribute if the line had no trailing space.
attr = lines2[1].split()

In [None]:
# Preview the names of the 10 attributes selected for the data loaders below.
np.take(np.array(attr), [2, 4, 15, 18, 20, 21, 24, 26, 31, 39])

In [None]:
# Define the selected attribute names once so the train and test loaders
# cannot drift apart. They are presumably what get_loader turns into the label
# vector y.
selected_attrs = np.array(attr)[[2, 4, 15, 18, 20, 21, 24, 26, 31, 39]]
celebA_loader = get_loader("./data/CelebA_nocrop/images", './data/list_attr_celeba.txt', selected_attrs)
celebA_loader_test = get_loader("./data/CelebA_nocrop/images", './data/list_attr_celeba.txt', selected_attrs, mode="test")

### Model

In [None]:
class Encoder(Normal):
    """Conditional Gaussian encoder q(z | x, y).

    Three strided convolutions turn a 3-channel image into a 256x16x16 feature
    map (the layer comments assume a 128x128 input). The flattened features are
    projected twice:
    - to a 40-d image code;
    - to a `domain_num`-d code that is multiplied element-wise by the label
      vector y.
    The concatenated codes go through an MLP that outputs the location and the
    softplus-transformed scale of a Normal over z.

    Args:
        z_dim: dimensionality of the latent variable z.
        domain_num: length of the label vector y.
    """

    def __init__(self, z_dim=63, domain_num=10):
        super().__init__(cond_var=["x", "y"], var=["z"])
        self.z_dim = z_dim

        flat_dim = 256 * 16 * 16

        # Each stage is a 4x4 conv with stride 2 that halves H and W:
        # 128 -> 64 -> 32 -> 16.
        conv_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            conv_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        self.conv_e = nn.Sequential(*conv_layers)

        # Two heads on the shared conv features: image code and label-gated code.
        self.fc1 = nn.Sequential(nn.Linear(flat_dim, 40))
        self.fc2 = nn.Sequential(nn.Linear(flat_dim, domain_num))

        # Outputs 2 * z_dim values: first half -> loc, second half -> raw scale.
        self.fc = nn.Sequential(
            nn.Linear(40 + domain_num, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2 * self.z_dim),
        )

    def forward(self, x, y):
        """Return the Normal parameters {"loc", "scale"} of q(z | x, y)."""
        feats = self.conv_e(x).view(-1, 256 * 16 * 16)
        code = torch.cat([self.fc1(feats), self.fc2(feats) * y], dim=1)
        loc, raw_scale = torch.chunk(self.fc(code), 2, dim=1)
        return {"loc": loc, "scale": F.softplus(raw_scale)}


class Decoder(Bernoulli):
    """Conditional Bernoulli decoder p(x | z, y).

    z is projected twice:
    - to a 40-d code;
    - to a `domain_num`-d code that is multiplied element-wise by the label
      vector y.
    The concatenation is expanded by an MLP into a 256x16x16 feature map.
    Three transposed convolutions then upsample it to a 3-channel image of
    per-pixel probabilities.

    Args:
        z_dim: dimensionality of the latent variable z.
        domain_num: length of the label vector y.
    """

    def __init__(self, z_dim=63, domain_num=10):
        super().__init__(cond_var=["z", "y"], var=["x"])
        self.z_dim = z_dim

        self.fc1 = nn.Sequential(nn.Linear(self.z_dim, 40))
        self.fc2 = nn.Sequential(nn.Linear(self.z_dim, domain_num))

        self.fc_d = nn.Sequential(
            nn.Linear(40 + domain_num, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 256 * 16 * 16),
            nn.LeakyReLU(0.2),
        )

        # Each 4x4 / stride-2 transposed conv doubles H and W: 16 -> 32 -> 64 -> 128.
        # The last stage maps to 3 channels, squashed into (0, 1) by a sigmoid.
        deconv_layers = []
        for c_in, c_out in ((256, 128), (128, 64)):
            deconv_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        deconv_layers.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.conv_d = nn.Sequential(*deconv_layers)

    def forward(self, z, y):
        """Return the Bernoulli parameters {"probs"} of p(x | z, y)."""
        code = torch.cat([self.fc1(z), self.fc2(z) * y], dim=1)
        feature_map = self.fc_d(code).view(-1, 256, 16, 16)
        return {"probs": self.conv_d(feature_map)}
    
    


## CVAE training

In [None]:
def encoder_plot(data_loader, E, D, save_path=None):
    """Plot reconstructions of the first batch of `data_loader` above its inputs.

    Each image x (with its labels y) is encoded with ``E.sample_mean`` and the
    result is decoded with ``D.sample_mean``. Up to 10 reconstructions fill the
    top two rows of a 4x5 grid. The matching input images fill the bottom two
    rows. The single figure is saved to `save_path` and shown.

    Args:
        data_loader: iterable yielding (images, labels, _) batches; images are
            presumably (N, C, H, W) tensors, given the transpose below.
        E: conditional encoder, called as ``E.sample_mean({"x": ..., "y": ...})``.
        D: conditional decoder, called as ``D.sample_mean({"z": ..., "y": ...})``.
        save_path: output image file. Defaults to "./logs/{epoch + 1}.png",
            using the notebook-global ``epoch``.

    The train/eval mode of E and D is restored on exit. Calling this in the
    middle of an epoch therefore does not leave BatchNorm in eval mode.
    """
    if save_path is None:
        save_path = "./logs/{}.png".format(epoch + 1)

    n_show = 10  # images per half of the 4x5 grid

    was_training = (E.training, D.training)
    E.eval()
    D.eval()
    try:
        # `.next()` is a Python 2 idiom that recent PyTorch DataLoader
        # iterators no longer provide; use the builtin next().
        images, labels, _ = next(iter(data_loader))
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():  # visualisation only: no autograd graph needed
            z = E.sample_mean({"x": images, "y": labels})
            samples = D.sample_mean({"z": z, "y": labels})
    finally:
        E.train(was_training[0])
        D.train(was_training[1])

    def to_hwc(batch):
        # (N, C, H, W) tensor -> (N, H, W, C) array for imshow. The channel
        # axis is dropped for single-channel images. Unlike .squeeze(), this
        # never drops the batch axis when N == 1.
        arr = batch.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return arr[..., 0] if arr.shape[-1] == 1 else arr

    samples = to_hwc(samples)
    images = to_hwc(images)
    n = min(n_show, len(images))  # tolerate batches smaller than 10

    print("↓generate")
    # Draw both halves in ONE figure so savefig captures reconstructions and
    # inputs together.
    fig, axes = plt.subplots(4, 5, figsize=(10, 8))
    fig.subplots_adjust(wspace=0., hspace=0.)
    axes = axes.ravel()
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    for i in range(n_show):
        if i < n:
            axes[i].imshow(samples[i])
            axes[n_show + i].imshow(images[i])
        else:
            axes[i].axis("off")
            axes[n_show + i].axis("off")
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    print("↑true")
    
def manipulate_attribute(data_loader, E, D, domain_num, save_path=None, sample_idx=1, scale=10):
    """Show how the decoder responds to flipping each attribute of one image.

    Image `sample_idx` of the first batch of `data_loader` is encoded with
    ``E.sample_mean``. The same code is then decoded `domain_num` times with
    ``D.sample_mean``. In decoding k, label entry k is replaced by
    ``(1 - y_k) * scale`` and every other entry is kept. The decodings are
    drawn on a grid with 10 columns, then saved and shown.

    Args:
        data_loader: iterable yielding (images, labels, _) batches.
        E: conditional encoder, called as ``E.sample_mean({"x": ..., "y": ...})``.
        D: conditional decoder, called as ``D.sample_mean({"z": ..., "y": ...})``.
        domain_num: number of label entries to flip (one decoding each).
        save_path: output image file. Defaults to "./logs/{epoch + 1}_attr.png",
            using the notebook-global ``epoch``. The distinct suffix keeps it
            from overwriting the "./logs/{epoch + 1}.png" reconstruction plot.
        sample_idx: index of the image within the batch (default 1).
        scale: factor applied to the flipped label value (default 10).

    Raises:
        ValueError: if the first batch has no image at `sample_idx`.

    The train/eval mode of E and D is restored on exit.
    """
    if save_path is None:
        save_path = "./logs/{}_attr.png".format(epoch + 1)

    was_training = (E.training, D.training)
    E.eval()
    D.eval()
    try:
        # Builtin next(): recent PyTorch DataLoader iterators have no `.next()`.
        images, labels, _ = next(iter(data_loader))
        if images.size(0) <= sample_idx:
            raise ValueError(
                "batch of size {} has no image at index {}".format(images.size(0), sample_idx))
        images = images[sample_idx:sample_idx + 1].to(device)
        labels = labels[sample_idx:sample_idx + 1].to(device)
        with torch.no_grad():  # visualisation only: no autograd graph needed
            z = E.sample_mean({"x": images, "y": labels})
            # One copy of the label vector per attribute; row k flips entry k.
            flipped = labels.repeat(domain_num, 1)
            diag = torch.arange(domain_num, device=flipped.device)
            flipped[diag, diag] = (1 - flipped[diag, diag]) * scale
            samples = D.sample_mean({"z": z.repeat(domain_num, 1), "y": flipped})
    finally:
        E.train(was_training[0])
        D.train(was_training[1])

    # (N, C, H, W) -> (N, H, W, C); drop C only for single-channel images so
    # the batch axis survives even when domain_num == 1.
    arr = samples.detach().cpu().numpy().transpose(0, 2, 3, 1)
    if arr.shape[-1] == 1:
        arr = arr[..., 0]

    n_cols = 10
    # Ceil division. int(domain_num / 10) yields 0 rows for domain_num < 10
    # and too few cells when domain_num is not a multiple of 10.
    n_rows = max(1, -(-domain_num // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, n_rows), squeeze=False)
    fig.subplots_adjust(wspace=0., hspace=0.)
    for i, ax in enumerate(axes.ravel()):
        ax.set_xticks([])
        ax.set_yticks([])
        if i < domain_num:
            ax.imshow(arr[i])
        else:
            ax.axis("off")
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path)
    plt.show()

In [None]:
# E.load_state_dict(torch.load('../ICLR_DVAE_logs/E_{}.pkl'.format(experiment_name)))
# D.load_state_dict(torch.load('../ICLR_DVAE_logs/D_{}.pkl'.format(experiment_name)))

In [None]:
def train():
    """Run one training epoch of the CVAE over ``celebA_loader``.

    Relies on these notebook globals: E, D, D_, elbo, E_optimizer, D_optimizer,
    device, celebA_loader, celebA_loader_test, domain_num and epoch.
    Every 10000 batches it plots reconstructions and attribute manipulations
    on the test loader and prints the running reconstruction loss.

    Returns:
        float: sum of the per-batch ``elbo`` losses over the epoch.
    """
    E.train()
    D.train()

    train_reconst_loss = 0.0
    # Wrapping the loader (not the enumerate object) lets tqdm show a total.
    for batch_idx, (x, y, _) in enumerate(tqdm(celebA_loader)):
        x = x.to(device)
        y = y.to(device)

        # Reconstruction step: update encoder and decoder jointly on the loss.
        E_optimizer.zero_grad()
        D_optimizer.zero_grad()

        recon_loss = elbo({"x": x, "y": y}, E, D_)
        recon_loss.backward()

        E_optimizer.step()
        D_optimizer.step()

        # .item() yields a plain float. Accumulating the tensor itself would
        # keep every batch's autograd graph alive and steadily leak memory.
        train_reconst_loss += recon_loss.item()

        if (batch_idx + 1) % 10000 == 0:
            print(batch_idx + 1)
            encoder_plot(celebA_loader_test, E, D)
            manipulate_attribute(celebA_loader_test, E, D, domain_num)
            # The plotting helpers switch the models to eval mode; switch back
            # so BatchNorm keeps training on batch statistics for the rest of
            # the epoch.
            E.train()
            D.train()
            print("Epoch-SCUT: {}, train_recon_loss: {}".format((epoch + 1), train_reconst_loss))

    return train_reconst_loss


In [None]:
domain_num = 10
log_dir = "../ICLR_DVAE_logs/"
# torch.save below fails if the checkpoint directory does not exist.
os.makedirs(log_dir, exist_ok=True)

# Random search over learning rate and latent size: 5 trials of 3 epochs each.
# A seeded generator makes the sampled configurations reproducible.
HPARAM_SEED = 0
hparam_rng = RandomState(HPARAM_SEED)

for trial in range(5):
    lr = float(hparam_rng.uniform(1e-4, 1e-3))
    z_dim = int(hparam_rng.randint(50, 300))  # plain int for layer sizes

    experiment_name = 'CVAE_celebA_z{}_lr{:.5f}_attr10'.format(z_dim, lr)
    print(experiment_name)

    # prior model p(z): Normal with loc 0 and scale 1 over z_dim dimensions
    loc = torch.tensor(0., device=device)
    scale = torch.tensor(1., device=device)
    prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

    E = Encoder(z_dim=z_dim, domain_num=domain_num).to(device)
    D = Decoder(z_dim=z_dim, domain_num=domain_num).to(device)
    D_ = D * prior

    E_optimizer = optim.Adam(E.parameters(), lr=lr)
    D_optimizer = optim.Adam(D.parameters(), lr=lr)

    # `epoch` stays a module-level name: train() and the plot helpers read it.
    for epoch in range(3):
        train()

        # Checkpoint after every epoch (same file is overwritten each time).
        torch.save(E.state_dict(), join(log_dir, 'E_{}.pkl'.format(experiment_name)))
        torch.save(D.state_dict(), join(log_dir, 'D_{}.pkl'.format(experiment_name)))