In [None]:
import os
import time
import json
import argparse
from os.path import join, exists, splitext, basename
from imp import reload
from glob import glob
import shutil
from itertools import cycle
from imp import reload

import numpy as np

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt

from pixyz.models import Model
from pixyz.losses import ELBO, NLL

import models
from models import *
import utils 
from utils import * 

# Global run configuration.
# NOTE(review): in the cells visible here `batch_size` is reassigned to 100 by
# the data-loader cell and `epochs` is never read (the training loops use
# literal epoch counts) — confirm whether these two values are still needed.
batch_size = 128
epochs = 10

plt.style.use("ggplot")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the log directory. `exist_ok=True` replaces the check-then-create
# pattern, so re-running the cell (or a concurrent run) cannot fail because the
# directory appeared between the check and the creation.
log_dir = "./logs/mnist_gif_m1m2"
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Data loaders.
# NOTE: despite its name, `seed` is not an RNG seed. It is the index of the
# first training example that is considered when picking the labeled subset.
seed = 4200

n_classes = 10
n_labeled_per_class = 10

# Rebuild the labeled-image folder from scratch (one sub-folder per digit) so
# that re-running the cell always starts from an empty directory tree.
mnist_labeled_path  = "./data/labeled_mnist_image"
if exists(mnist_labeled_path):
    shutil.rmtree(mnist_labeled_path)
for i in range(n_classes):
    os.makedirs(join(mnist_labeled_path, "{}".format(i)))

batch_size = 100

transform = transforms.Compose([
    transforms.ToTensor()
])

# Full training set, used as the unlabeled stream.
dataset = datasets.MNIST('data/mnist', train=True, download=True, transform=transform)
unlabel_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataset_test = datasets.MNIST('data/mnist', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

# Reload the training set without a transform so that each item is an image
# object that can be written to disk with `.save(...)`.
dataset = datasets.MNIST('data/mnist', train=True, download=True)

# Starting at index `seed`, save the first `n_labeled_per_class` examples of
# every digit. Per-class counters are kept in memory instead of globbing the
# output folders for every example, and each example is decoded only once.
saved_per_class = [0] * n_classes
for i in range(seed, len(dataset)):
    image, label = dataset[i]
    label = int(label)  # plain int even if the dataset yields a 0-dim tensor
    if saved_per_class[label] < n_labeled_per_class:
        image.save(join(mnist_labeled_path, "{}/{}.png".format(label, i)))
        saved_per_class[label] += 1
    if sum(saved_per_class) == n_classes * n_labeled_per_class:
        break


# The saved PNGs are read back through ImageFolder, which derives the class
# label from the sub-folder name.
labeled_dataset = datasets.ImageFolder(mnist_labeled_path, transform=transform)
label_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Display the labeled data.
# `next(iter(...))` is the iterator protocol; the Python-2 style `.next()`
# method is not available on current PyTorch DataLoader iterators.
samples, labels = next(iter(label_loader))
# Order the batch by class label so that images of the same digit are adjacent.
argsort = np.argsort(labels)
samples = samples[argsort]
# (N, C, H, W) -> (N, H, W, C), the layout imshow expects; squeeze() drops any
# singleton axes (e.g. a single-channel axis).
samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1).squeeze()

# Original note (translated): "horizontal axis: z fixed, y varies /
# vertical axis: z varies, y fixed".
# NOTE(review): that describes a latent-traversal plot; this cell only draws
# the labeled images sorted by label — confirm and drop the note if stale.
plt.figure(figsize=(10, 10))
# The grid holds at most 10 x 10 images; never index past the end of the batch.
n_shown = min(len(samples), 100)
for i in range(n_shown):
    plt.subplot(10, 10, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(samples[i], cmap=plt.cm.gray)
# Remove the gaps between tiles once, after all axes exist.
plt.subplots_adjust(wspace=0., hspace=0.)
plt.show()


In [None]:
# Build the priors, networks and joint distributions used by the two training
# stages below (the "m1" objects work on images x, the "m2" objects on the
# latent z together with the label y).
# NOTE(review): `Normal` and the network classes are not imported explicitly in
# this notebook; presumably they arrive through `from models import *` /
# `from utils import *` — confirm.
z_dim = 63
z2_dim = 20

# prior model p(z2)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior_m2 = Normal(loc=loc, scale=scale, var=["z2"], dim=z2_dim)

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior_m1 = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

# initialize network
# NOTE(review): the networks are built with their default constructor
# arguments, so `z_dim` / `z2_dim` above must agree with the sizes hard-coded
# in `models` — verify.
E_m1 = Encoder_m1().to(device)
D_m1 = Decoder_m1().to(device)
E_m2 = LatentEncoder().to(device)
D_m2 = LatentDecoder().to(device)
C_m2 = LatentClassifier().to(device)

# Joint distributions: decoder multiplied by the matching prior.
# The return value of `.to(device)` is discarded here, i.e. the call is relied
# on to move the object in place.
D_j_m1 = D_m1 * prior_m1
D_j_m1.to(device)

D_j_m2 = D_m2 * prior_m2
D_j_m2.to(device)


# distributions for unsupervised learning
# `replace_var` presumably returns a view of the same network with its
# variables renamed (x -> x_u, z -> z_u, y -> y_u), so labeled and unlabeled
# inputs can be fed under different names — confirm against the pixyz docs.
Eu_m1 = E_m1.replace_var(x="x_u")
Du_m1 = D_m1.replace_var(x="x_u")
Eu_m2 = E_m2.replace_var(z="z_u", y="y_u")
Du_m2 = D_m2.replace_var(z="z_u", y="y_u")
Cu = C_m2.replace_var(z="z_u", y="y_u")
# Product of the unlabeled latent encoder and the unlabeled classifier; used
# as the second argument of the unlabeled M2 ELBO in the next cell.
ECu = Eu_m2 * Cu

Du_j_m1 = Du_m1 * prior_m1
Du_j_m1.to(device)

Du_j_m2 = Du_m2 * prior_m2
Du_j_m2.to(device)

ECu.to(device)
Cu.to(device)


In [None]:
# ELBO terms for both stages: the `_u` variants use the unlabeled variable
# names, the others the labeled ones. `elbo_m1` is only referenced by the
# commented-out part of `loss_m1` below.
elbo_u_m1 = ELBO(Du_j_m1, Eu_m1)
elbo_m1 = ELBO(D_j_m1, E_m1)
elbo_u_m2 = ELBO(Du_j_m2, ECu)
elbo_m2 = ELBO(D_j_m2, E_m2)
# Classification term built from the labeled classifier.
nll = NLL(C_m2)

# Weight of the classification term: total number of batches (unlabeled plus
# labeled) divided by the number of labeled batches.
n_unlabeled_batches = len(unlabel_loader)
n_labeled_batches = len(label_loader)
rate = (n_unlabeled_batches + n_labeled_batches) / n_labeled_batches

loss_m1 = -elbo_u_m1.mean() #-elbo_m1.mean()
loss_m2 = (
    -elbo_u_m2.mean()
    - elbo_m2.mean()
    + (rate * nll).mean()
)


# Optimization: stage 1 trains only the image encoder/decoder pair.
model1 = Model(
    loss_m1,
    distributions=[E_m1, D_m1],
    optimizer=optim.Adam,
    optimizer_params={"lr":1e-3},
)
print(model1)

# Optimization: stage 2 trains the latent encoder, decoder and classifier and
# uses the mean classification term as its test loss.
model2 = Model(
    loss_m2,
    test_loss=nll.mean(),
    distributions=[E_m2, D_m2, C_m2],
    optimizer=optim.Adam,
    optimizer_params={"lr":5e-4},
)
print(model2)

In [None]:
# History of per-epoch metrics; the "precision" list is appended to by the M2
# training loop.
train_hist = {"precision": []}

# M1 training: fit the image encoder/decoder on the unlabeled stream only.
samples_per_batch = unlabel_loader.batch_size
n_unlabeled = len(unlabel_loader.dataset)
for epoch in range(20):
    train_loss = 0
    for batch in tqdm(unlabel_loader):
        # The labels that come with each batch are ignored in this stage.
        x_u = batch[0].to(device)
        loss = model1.train({"x_u": x_u})
        train_loss += loss

    # Rescale the sum of per-batch losses by batch size over dataset size.
    train_loss = train_loss * samples_per_batch / n_unlabeled
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    plot_sample_m1(D_m1, epoch)

 -(mean(E_p(z2,y_u|z_u)[log p(z_u,z2|y_u)/p(z2,y_u|z_u)])) - mean(E_q(z2|z,y)[log p(z,z2|y)/q(z2|z,y)]) + mean(log q(y|z) * 601.0) 

In [None]:
# NOTE(review): leftover debugging probe. In the code visible here `y` and `z`
# are only assigned inside the M2 training loop further down, so this cell
# raises NameError when the notebook is run top to bottom on a fresh kernel.
E_m2.sample({"y": y, "z": z})

In [None]:
# Debugging aid: display the `prob_text` attribute of the joint built as
# D_m2 * prior_m2 (presumably a textual form of its factorization — confirm
# against the pixyz docs).
D_j_m2.prob_text

 (z_2*y).shape
torch.Size([100, 10])

ipdb> y.shape
torch.Size([100, 10])
ipdb> z_2.shape
torch.Size([1, 10])
ipdb> z_1.shape
torch.Size([1, 40])

In [None]:
# NOTE(review): leftover interactive debugging. `%debug` opens the post-mortem
# debugger and waits for input, which blocks an unattended "Run All"; `y` is
# also only assigned inside the M2 training loop further down. Remove or move
# this cell before sharing the notebook.
%debug
D_j_m2.sample({"y": y})

In [None]:
# M2 training: fit the latent encoder/decoder/classifier on codes produced by
# the M1 encoder.

# One-hot lookup table, built once on the target device instead of once per
# batch on the CPU.
one_hot = torch.eye(10, device=device)

for epoch in range(100):
    train_loss = 0
    # cycle() repeats the labeled batches so that every unlabeled batch is
    # paired with a labeled one; zip() stops when the unlabeled loader ends.
    # The labels of the unlabeled stream are discarded.
    for (x, y), (x_u, _) in tqdm(zip(cycle(label_loader), unlabel_loader), total=len(unlabel_loader)):
        # Keep only the first channel of the labeled images (slice 0:1 keeps
        # the channel axis).
        x = x[:, 0:1].to(device)
        y = one_hot[y.to(device)]
        x_u = x_u.to(device)
        # model2 only optimizes E_m2, D_m2 and C_m2, so the M1 encoder is used
        # as a fixed feature extractor: skip building its autograd graph rather
        # than back-propagating into parameters that are never updated here.
        with torch.no_grad():
            z = E_m1.sample({"x": x})["z"]
            z_u = Eu_m1.sample({"x_u": x_u})["z"]
        loss = model2.train({"y": y, "z": z, "z_u": z_u})
        train_loss += loss

    # Rescale the sum of per-batch losses by batch size over dataset size.
    train_loss = train_loss * unlabel_loader.batch_size / len(unlabel_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    train_hist["precision"].append(compute_precision_m1(C_m2, E_m1, test_loader))
    plot_sample_m1m2(D_m2, D_m1, epoch)
    plot_loss(train_hist)


In [None]:
# Development helper: pick up on-disk edits to the local `models` and `utils`
# modules without restarting the kernel.
# `importlib.reload` is the supported replacement for the deprecated
# `imp.reload` (the `imp` module is removed in Python 3.12).
from importlib import reload

reload(models)
reload(utils)
# Re-run the wildcard imports so that names bound by `from ... import *` point
# at the reloaded definitions instead of the stale ones.
from models import *
from utils import *