In [None]:
import os
from os.path import join, exists, splitext, basename
from imp import reload
from glob import glob

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader

from pixyz.distributions import Normal, Bernoulli, Categorical
from pixyz.losses import KullbackLeibler, CrossEntropy
from pixyz.models import Model, VAE


from models import *
from utils import *

from mnist_A_data_loader import get_mnist_A_loader

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Every attribute combination [i, j, k, l] with i < 2, j < 3, k < 4, l < 10.
# np.ndindex walks the index grid with the last axis varying fastest, i.e. in
# the same order as four nested loops over i, j, k, l.
unique_label = np.array(list(np.ndindex(2, 3, 4, 10)))

# Split the attribute combinations (not the images) into three disjoint sets:
# first hold out 40 combinations, then divide those 40 into 20 validation and
# 20 test combinations. Both splits use the same fixed seed.
train_label_attr, heldout_label_attr = train_test_split(
    unique_label, test_size=40, random_state=42)
valid_label_attr, test_label_attr = train_test_split(
    heldout_label_attr, test_size=20, random_state=42)

# One loader per split. Each call receives the split's "X" directory, its "y"
# .npy file and the attribute combinations assigned to that split.
# (How get_mnist_A_loader uses the attribute list is defined in
# mnist_A_data_loader - presumably as a filter; verify there.)
train_sources = ("../data/MNIST_A/train_X/", "../data/MNIST_A/train_y.npy")
train_loader = get_mnist_A_loader(*train_sources, train_label_attr)

test_sources = ("../data/MNIST_A/test_X/", "../data/MNIST_A/test_y.npy")
test_loader = get_mnist_A_loader(*test_sources, test_label_attr)

valid_sources = ("../data/MNIST_A/valid_X/", "../data/MNIST_A/valid_y.npy")
valid_loader = get_mnist_A_loader(*valid_sources, valid_label_attr)

# Directory holding the pretrained classifier and, later, this notebook's
# own checkpoints.
log_dir = "./logs"

# Pretrained attribute classifier, used further down to read the attributes
# back from generated images.
classifier = MNIST_A_Classifier().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded when only the CPU is available.
classifier_state = torch.load(join(log_dir, 'MNIST_A_classifier.pkl'), map_location=device)
classifier.load_state_dict(classifier_state)
# Inference only: switch the classifier to eval mode.
classifier.eval()

In [None]:
class Encoder_XY(Normal):
    """Joint encoder q(z | x, y1, y2, y3, y4): a Normal over ``z``.

    The image goes through three stride-2 convolutions (each halves the
    spatial size) and is flattened to ``128 * 8 * 8`` features, which implies
    a single-channel 64x64 input. Two linear heads are computed from these
    features: a 40-unit head used as is, and a ``y_dim``-unit head that is
    multiplied element-wise by the concatenated label vector. Their
    concatenation is mapped to ``2 * z_dim`` values, split into the location
    and the pre-softplus scale.

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: width of the concatenated label vector ``[y1, y2, y3, y4]``.
    """

    # Width of the flattened convolutional output.
    _FLAT_FEATURES = 128 * 8 * 8

    def __init__(self, z_dim=64, y_dim=1+3+4+10):
        super().__init__(cond_var=["x", "y1", "y2", "y3", "y4"], var=["z"])

        self.z_dim = z_dim

        # Convolutional trunk: 1 -> 32 -> 64 -> 128 channels.
        widths = (1, 32, 64, 128)
        conv_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            conv_layers.append(nn.BatchNorm2d(c_out))
            conv_layers.append(nn.LeakyReLU(0.2))
        self.conv_e = nn.Sequential(*conv_layers)

        # Head whose output is used unmodified.
        self.fc1 = nn.Sequential(nn.Linear(self._FLAT_FEATURES, 40))
        # Head whose output is multiplied by the label vector in forward().
        self.fc2 = nn.Sequential(nn.Linear(self._FLAT_FEATURES, y_dim))

        # Maps both heads to [loc | raw scale].
        self.fc = nn.Sequential(
            nn.Linear(40 + y_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2 * self.z_dim),
        )

    def forward(self, x, y1, y2, y3, y4):
        """Return the ``loc`` and ``scale`` parameters of q(z | x, y)."""
        features = self.conv_e(x).view(-1, self._FLAT_FEATURES)
        image_part = self.fc1(features)
        label_head = self.fc2(features)
        # y1 is expected to be 1-D, so it gets a feature axis before the concat.
        labels = torch.cat((y1.unsqueeze(1), y2, y3, y4), dim=1)
        hidden = self.fc(torch.cat((image_part, label_head * labels), dim=1))
        loc, raw_scale = hidden.chunk(2, dim=1)
        # softplus keeps the scale non-negative.
        return {"loc": loc, "scale": F.softplus(raw_scale)}
    

class Encoder_X(Normal):
    """Image-only encoder q(z | x): a Normal over ``z``.

    Three stride-2 convolutions (each halves the spatial size) are followed
    by a flatten to ``128 * 8 * 8`` features, which implies a single-channel
    64x64 input, and a two-layer MLP that outputs ``2 * z_dim`` values: the
    location and the pre-softplus scale.

    Args:
        z_dim: size of the latent variable ``z``.
    """

    # Width of the flattened convolutional output.
    _FLAT_FEATURES = 128 * 8 * 8

    def __init__(self, z_dim=64):
        super().__init__(cond_var=["x"], var=["z"])

        self.z_dim = z_dim

        # Convolutional trunk: 1 -> 32 -> 64 -> 128 channels.
        widths = (1, 32, 64, 128)
        conv_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            conv_layers.append(nn.BatchNorm2d(c_out))
            conv_layers.append(nn.LeakyReLU(0.2))
        self.conv_e = nn.Sequential(*conv_layers)

        # Maps the flattened features to [loc | raw scale].
        self.fc = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2 * self.z_dim),
        )

    def forward(self, x):
        """Return the ``loc`` and ``scale`` parameters of q(z | x)."""
        features = self.conv_e(x).view(-1, self._FLAT_FEATURES)
        loc, raw_scale = self.fc(features).chunk(2, dim=1)
        # softplus keeps the scale non-negative.
        return {"loc": loc, "scale": F.softplus(raw_scale)}
    

class Encoder_Y(Normal):
    """Label-only encoder q(z | y1, y2, y3, y4): a Normal over ``z``.

    The four label variables are concatenated into one vector of width
    ``y_dim`` and passed through a two-layer MLP that outputs ``2 * z_dim``
    values: the location and the pre-softplus scale.

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: width of the concatenated label vector ``[y1, y2, y3, y4]``.
    """

    def __init__(self, z_dim=64, y_dim=1+3+4+10):
        super().__init__(cond_var=["y1", "y2", "y3", "y4"], var=["z"])

        self.z_dim = z_dim

        mlp_layers = [
            nn.Linear(y_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2 * self.z_dim),
        ]
        self.fc = nn.Sequential(*mlp_layers)

    def forward(self, y1, y2, y3, y4):
        """Return the ``loc`` and ``scale`` parameters of q(z | y)."""
        # y1 is expected to be 1-D, so it gets a feature axis before the concat.
        labels = torch.cat((y1.unsqueeze(1), y2, y3, y4), dim=1)
        loc, raw_scale = self.fc(labels).chunk(2, dim=1)
        # softplus keeps the scale non-negative.
        return {"loc": loc, "scale": F.softplus(raw_scale)}
    

class Decoder_X(Bernoulli):
    """Image decoder p(x | z): per-pixel Bernoulli probabilities.

    ``z`` is expanded by an MLP to ``128 * 8 * 8`` features, reshaped to a
    128-channel 8x8 map and upsampled by three stride-2 transposed
    convolutions (each doubles the spatial size) to a single-channel map that
    is squashed into (0, 1) by a sigmoid.

    Args:
        z_dim: size of the latent variable ``z``.
    """

    def __init__(self, z_dim=64):
        super().__init__(cond_var=["z"], var=["x"])

        self.z_dim = z_dim

        # Dense expansion of z to the flattened 128 x 8 x 8 feature map.
        dense_layers = [
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 128 * 8 * 8),
            nn.LeakyReLU(0.2),
        ]
        self.fc_d = nn.Sequential(*dense_layers)

        # Transposed-convolution stack: 128 -> 64 -> 32 -> 1 channels.
        upsample_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.conv_d = nn.Sequential(*upsample_layers)

    def forward(self, z):
        """Return the Bernoulli ``probs`` of every pixel given ``z``."""
        feature_map = self.fc_d(z).view(-1, 128, 8, 8)
        pixel_probs = self.conv_d(feature_map)
        return {"probs": pixel_probs}


class Decoder_Y1(Bernoulli):
    """Decoder p(y1 | z) for the first attribute, modelled as a Bernoulli.

    A two-layer MLP maps ``z`` to ``y_dim`` values, which a sigmoid turns
    into probabilities.

    NOTE(review): ``probs`` has shape (batch, y_dim), i.e. (batch, 1) with
    the default, whereas the training cell feeds ``y1`` as ``y[:, 0]`` (1-D).
    Confirm that the log-probability computed for this pair does not
    broadcast to (batch, batch).

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: number of outputs.
    """

    def __init__(self, z_dim=64, y_dim=1):
        super().__init__(cond_var=["z"], var=["y1"])
        self.y_dim = y_dim
        self.z_dim = z_dim

        mlp_layers = [
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, y_dim),
        ]
        self.fc_d = nn.Sequential(*mlp_layers)

    def forward(self, z):
        """Return the Bernoulli ``probs`` of ``y1`` given ``z``."""
        logits = self.fc_d(z)
        return {"probs": logits.sigmoid()}

    
class Decoder_Y2(Categorical):
    """Decoder p(y2 | z) for the second attribute, modelled as a Categorical.

    A two-layer MLP maps ``z`` to ``y_dim`` values, which a softmax over the
    class axis turns into probabilities.

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: number of classes.
    """

    def __init__(self, z_dim=64, y_dim=3):
        super().__init__(cond_var=["z"], var=["y2"])
        self.y_dim = y_dim
        self.z_dim = z_dim

        mlp_layers = [
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, y_dim),
        ]
        self.fc_d = nn.Sequential(*mlp_layers)

    def forward(self, z):
        """Return the class ``probs`` of ``y2`` given ``z``."""
        logits = self.fc_d(z)
        return {"probs": logits.softmax(dim=1)}
    

class Decoder_Y3(Categorical):
    """Decoder p(y3 | z) for the third attribute, modelled as a Categorical.

    A two-layer MLP maps ``z`` to ``y_dim`` values, which a softmax over the
    class axis turns into probabilities.

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: number of classes.
    """

    def __init__(self, z_dim=64, y_dim=4):
        super().__init__(cond_var=["z"], var=["y3"])
        self.y_dim = y_dim
        self.z_dim = z_dim

        mlp_layers = [
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, y_dim),
        ]
        self.fc_d = nn.Sequential(*mlp_layers)

    def forward(self, z):
        """Return the class ``probs`` of ``y3`` given ``z``."""
        logits = self.fc_d(z)
        return {"probs": logits.softmax(dim=1)}

    
class Decoder_Y4(Categorical):
    """Decoder p(y4 | z) for the fourth attribute, modelled as a Categorical.

    A two-layer MLP maps ``z`` to ``y_dim`` values, which a softmax over the
    class axis turns into probabilities.

    Args:
        z_dim: size of the latent variable ``z``.
        y_dim: number of classes.
    """

    def __init__(self, z_dim=64, y_dim=10):
        super().__init__(cond_var=["z"], var=["y4"])
        self.y_dim = y_dim
        self.z_dim = z_dim

        mlp_layers = [
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, y_dim),
        ]
        self.fc_d = nn.Sequential(*mlp_layers)

    def forward(self, z):
        """Return the class ``probs`` of ``y4`` given ``z``."""
        logits = self.fc_d(z)
        return {"probs": logits.softmax(dim=1)}

In [None]:
# Latent size. It is passed explicitly to the prior AND to every network so
# that changing this one value keeps all of them (and experiment_name) in sync.
z_dim = 64

# Prior p(z): Normal with loc 0 and scale 1 over a z_dim-dimensional z.
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

# Decoders: one for the image and one per attribute.
p_x = Decoder_X(z_dim=z_dim).to(device)
p_y1 = Decoder_Y1(z_dim=z_dim).to(device)
p_y2 = Decoder_Y2(z_dim=z_dim).to(device)
p_y3 = Decoder_Y3(z_dim=z_dim).to(device)
p_y4 = Decoder_Y4(z_dim=z_dim).to(device)

# Encoders: joint (image + labels), image-only and label-only.
q = Encoder_XY(z_dim=z_dim).to(device)
q_x = Encoder_X(z_dim=z_dim).to(device)
q_y = Encoder_Y(z_dim=z_dim).to(device)

# Joint decoder p(x, y1, y2, y3, y4 | z) as a product of the five decoders.
p = p_x * p_y1 * p_y2 * p_y3 * p_y4
print(p)


# Regulariser: KL from the joint encoder to the prior, plus KL from the joint
# encoder to each single-modality encoder. The latter two terms are the only
# place where q_x and q_y enter the objective.
kl = KullbackLeibler(q, prior)
kl_x = KullbackLeibler(q, q_x)
kl_y = KullbackLeibler(q, q_y)

regularizer = kl + kl_x + kl_y
print(regularizer)

# q_x and q_y are listed as other_distributions so that the model also
# manages them - presumably including their parameters in the optimizer;
# confirm against the pixyz VAE implementation.
model = VAE(q, p, other_distributions=[q_x, q_y],
            regularizer=regularizer, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

# Used below as the suffix of every checkpoint file name.
experiment_name =  'JMVAE_expert_MNIST_A_z_dim{}'.format(z_dim)

# train

In [None]:
# Number of full passes over the training loader.
n_epochs = 10

for epoch in range(n_epochs):
    # Wrap the loader itself rather than the enumerate object: enumerate has
    # no length, so tqdm could not show a total or an ETA otherwise.
    progress = tqdm(train_loader, desc="epoch {}/{}".format(epoch + 1, n_epochs))
    for batch_idx, (x, y) in enumerate(progress):
        x = x.to(device)
        y = y.to(device)
        # The label tensor is sliced column-wise into the four attribute
        # variables: column 0 -> y1, 1:4 -> y2, 4:8 -> y3, 8: -> y4.
        loss = model.train({"x": x, "y1": y[:, 0], "y2": y[:, 1:4], "y3": y[:, 4:8], "y4": y[:, 8:]})

        # Every 1000 batches, call encoder_plot on the training and validation
        # loaders with the image-only encoder and the image decoder.
        if (batch_idx+1)%1000==0:
            print("train")
            encoder_plot(train_loader, q_x, p_x, conditional=False)
            print("valid")
            encoder_plot(valid_loader, q_x, p_x, conditional=False)


In [None]:
# Each network is written to its own file in log_dir; the file name is the
# template below filled in with experiment_name.
checkpoint_templates = [
    (p_x, 'px_{}.pkl'),
    (p_y1, 'py1_{}.pkl'),
    (p_y2, 'py2_{}.pkl'),
    (p_y3, 'py3_{}.pkl'),
    (p_y4, 'py4_{}.pkl'),
    (q_x, 'qx_{}.pkl'),
    (q_y, 'qy_{}.pkl'),
    (q, 'q_{}.pkl'),
]
for network, template in checkpoint_templates:
    torch.save(network.state_dict(), join(log_dir, template.format(experiment_name)))

# 検証

- train dataにある属性なら生成できるのか
- test dataにある属性は生成できないのか

In [None]:
# Restore every trained network from the files written by the save cell.
# There is one attribute decoder per label (p_y1..p_y4, saved as py1..py4);
# no combined `p_y` network or `py_*.pkl` file exists.
checkpoint_templates = [
    (p_x, 'px_{}.pkl'),
    (p_y1, 'py1_{}.pkl'),
    (p_y2, 'py2_{}.pkl'),
    (p_y3, 'py3_{}.pkl'),
    (p_y4, 'py4_{}.pkl'),
    (q_x, 'qx_{}.pkl'),
    (q_y, 'qy_{}.pkl'),
    (q, 'q_{}.pkl'),
]
for network, template in checkpoint_templates:
    # map_location remaps the stored tensors onto `device`, so checkpoints
    # saved on a GPU can still be loaded when only the CPU is available.
    state = torch.load(join(log_dir, template.format(experiment_name)), map_location=device)
    network.load_state_dict(state)

## testデータ

In [None]:
# Print the label-only encoder q_y, which the cells below sample z from.
print(q_y)

In [None]:
# Generate images for attribute combinations from the TEST split: sample z
# from the label-only encoder q_y, decode it with p_x, and let the pretrained
# classifier read the attributes back from the generated images.
p_x.eval()
q_y.eval()

sample_size = 5   # images generated per attribute combination
attr_size = 10    # number of combinations used (the first rows of test_label_attr)

# Repeat each selected combination sample_size times, keeping equal rows adjacent.
label = np.repeat(test_label_attr[:attr_size], sample_size, axis=0)
onehot_label = torch.FloatTensor(label2onehot(label))
onehot_label = onehot_label.to(device)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    z = q_y.sample({"y1": onehot_label[:, 0], "y2": onehot_label[:, 1:4], "y3": onehot_label[:, 4:8], "y4": onehot_label[:, 8:]})["z"]
    samples = p_x.sample_mean({"z": z})
    # Named pred_y* on purpose: assigning to p_y1..p_y4 would overwrite the
    # decoder networks and break a later run of the save/load cells.
    pred_y1, pred_y2, pred_y3, pred_y4 = classifier(samples)
pred = onehot2label(torch.cat([pred_y1[:, None], pred_y2, pred_y3, pred_y4], 1))

# (N, C, H, W) -> (N, H, W, C), then drop the size-1 channel axis for imshow.
samples_ = samples.cpu().data.numpy().transpose(0, 2, 3, 1).squeeze()
plt.figure(figsize=(sample_size*3, attr_size*3))
for i in range(sample_size*attr_size):
    # One row per attribute combination, one column per sample.
    plt.subplot(attr_size, sample_size, i+1)
    plt.xticks([])
    plt.yticks([])
    # T = requested attributes, P = attributes predicted by the classifier.
    plt.title("T: {}, P: {}".format(label[i], pred[i]))
    plt.imshow(samples_[i], plt.cm.gray)
plt.show()

In [None]:
# accuracy
# A generated image counts as correct only when all 4 attribute columns
# predicted by the classifier equal the requested ones.
matches_per_sample = (pred == label).sum(1)
accuracy_data = matches_per_sample == 4
accuracy_data.sum() / len(accuracy_data)

## 訓練データ

In [None]:
# Generate images for attribute combinations from the TRAINING split: sample z
# from the label-only encoder q_y, decode it with p_x, and let the pretrained
# classifier read the attributes back from the generated images.
p_x.eval()
q_y.eval()

sample_size = 5   # images generated per attribute combination
attr_size = 10    # number of combinations used (the first rows of train_label_attr)

# Repeat each selected combination sample_size times, keeping equal rows adjacent.
label = np.repeat(train_label_attr[:attr_size], sample_size, axis=0)
onehot_label = torch.FloatTensor(label2onehot(label))
onehot_label = onehot_label.to(device)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    z = q_y.sample({"y1": onehot_label[:, 0], "y2": onehot_label[:, 1:4], "y3": onehot_label[:, 4:8], "y4": onehot_label[:, 8:]})["z"]
    samples = p_x.sample_mean({"z": z})
    # Named pred_y* on purpose: assigning to p_y1..p_y4 would overwrite the
    # decoder networks and break a later run of the save/load cells.
    pred_y1, pred_y2, pred_y3, pred_y4 = classifier(samples)
pred = onehot2label(torch.cat([pred_y1[:, None], pred_y2, pred_y3, pred_y4], 1))

# (N, C, H, W) -> (N, H, W, C), then drop the size-1 channel axis for imshow.
samples_ = samples.cpu().data.numpy().transpose(0, 2, 3, 1).squeeze()
plt.figure(figsize=(sample_size*3, attr_size*3))
for i in range(sample_size*attr_size):
    # One row per attribute combination, one column per sample.
    plt.subplot(attr_size, sample_size, i+1)
    plt.xticks([])
    plt.yticks([])
    # T = requested attributes, P = attributes predicted by the classifier.
    plt.title("T: {}, P: {}".format(label[i], pred[i]))
    plt.imshow(samples_[i], plt.cm.gray)
plt.show()

In [None]:
# accuracy
# A generated image counts as correct only when all 4 attribute columns
# predicted by the classifier equal the requested ones.
matches_per_sample = (pred == label).sum(1)
accuracy_data = matches_per_sample == 4
accuracy_data.sum() / len(accuracy_data)