In [None]:
import os
import time
import json
import argparse
from os.path import join, exists, splitext, basename
from imp import reload
from glob import glob
import shutil
from itertools import cycle
from imp import reload

import numpy as np

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt

from pixyz.models import Model
from pixyz.losses import ELBO, NLL

import models
from models import *
import utils 
from utils import * 

# Global configuration for the notebook.
# NOTE(review): `batch_size` is reassigned to 100 in the data-loader cell, and
# `epochs` is not referenced by the training loop visible in this notebook
# (which iterates over a literal range(100)) -- confirm which values are intended.
batch_size = 128
epochs = 10

plt.style.use("ggplot")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make sure the log directory exists. exist_ok=True keeps the cell idempotent
# and avoids the check-then-create race of `if not exists(...): makedirs(...)`.
log_dir = "./logs/mnist_gif_m1"
os.makedirs(log_dir, exist_ok=True)

In [None]:
# data loader
# NOTE: despite its name, `seed` is not a random seed -- it is the index in the
# MNIST training set at which the scan for labeled examples starts.
seed = 4200

# Rebuild the labeled-image folder from scratch: one sub-directory per digit.
mnist_labeled_path  = "./data/labeled_mnist_image"
if exists(mnist_labeled_path):
    shutil.rmtree(mnist_labeled_path)
for i in range(10):
    os.makedirs(join(mnist_labeled_path, "{}".format(i)))

batch_size = 100

transform = transforms.Compose([
    transforms.ToTensor()
])

# Full MNIST training set, served as the "unlabeled" stream during training.
dataset = datasets.MNIST('data/mnist', train=True, download=True, transform=transform)
unlabel_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Held-out test split, iterated in a fixed order.
dataset_test = datasets.MNIST('data/mnist', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

# Re-open the training set without a transform so that each item's image is the
# raw object that can be written to disk with `.save(...)`.
dataset = datasets.MNIST('data/mnist', train=True, download=True)

# Starting at index `seed`, save the first 10 examples of every digit as PNGs
# (100 files in total). The per-class counts are tracked in memory: the folder
# was just recreated empty, so the counters always match what is on disk, and
# this avoids globbing the filesystem and re-decoding the image for every
# candidate sample.
saved_per_class = [0] * 10
for i in range(seed, len(dataset)):
    image, label = dataset[i]
    label = int(label)  # the label may be a plain int or a 0-dim tensor
    if saved_per_class[label] < 10:
        image.save(join(mnist_labeled_path, "{}/{}.png".format(label, i)))
        saved_per_class[label] += 1
    if sum(saved_per_class) == 100:
        break


# Load the saved PNGs back as a labeled dataset (class = sub-directory name).
labeled_dataset = datasets.ImageFolder(mnist_labeled_path, transform=transform)
label_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Display the labeled data
# Pull a single batch with the built-in next(); Python 3 iterators do not
# expose a `.next()` method.
samples, labels = next(iter(label_loader))
# Reorder the batch so that samples are sorted by their label.
argsort = torch.argsort(labels)
samples = samples[argsort]
# Move the second axis to the end for imshow and drop any singleton axes.
samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1).squeeze()

# Original comment (translated): "horizontal axis: z fixed, y varied;
# vertical axis: z varied, y fixed".
# NOTE(review): nothing in this cell varies z or y -- the grid simply shows the
# label-sorted samples in row-major order; the comment looks like a leftover.
plt.figure(figsize=(10, 10))
for i in range(100):
    plt.subplot(10, 10, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(samples[i], cmap=plt.cm.gray)
# Remove the gaps between tiles; applied once to the whole figure rather than
# on every loop iteration.
plt.subplots_adjust(wspace=0., hspace=0.)
plt.show()


In [None]:
# Size passed as `dim` to the prior over the latent variable z.
z_dim = 63

# Prior p(z): a Normal with loc 0 and scale 1; the parameters are created
# directly on the target device.
loc = torch.tensor(0., device=device)
scale = torch.tensor(1., device=device)
prior = Normal(
    loc=loc,
    scale=scale,
    var=["z"],
    dim=z_dim,
    name="p",
)

# Instantiate the networks and move each one to the target device.
E = Encoder_m1().to(device)
D = Decoder_m1().to(device)
C = LatentClassifier().to(device)

# Product of the decoder and the prior
# (presumably the joint p(x, z) -- confirm against models.py).
D_j = D * prior
D_j.to(device)

In [None]:
# Loss terms: an ELBO built from D_j and E, and the negative log-likelihood
# of the classifier C.
elbo = ELBO(D_j, E)
nll = NLL(C)

# Weight on the NLL term:
# (#unlabeled batches + #labeled batches) / #labeled batches.
rate = (len(unlabel_loader) + len(label_loader)) / len(label_loader)

# Objective to minimise: negative mean ELBO plus the weighted mean NLL.
loss_cls = -elbo.mean() + (rate * nll).mean()

print(loss_cls)

In [None]:
# Optimization: wrap the training loss, the test loss and the three networks
# in a pixyz Model that is trained with Adam (lr = 5e-4).
model = Model(
    loss_cls,
    test_loss=nll.mean(),
    distributions=[E, D, C],
    optimizer=optim.Adam,
    optimizer_params={"lr": 5e-4},
)
print(model)

In [None]:
# Per-epoch history; only the test-set precision is recorded.
train_hist = {"precision": []}

# NOTE(review): the epoch count is the literal 100; the `epochs` variable set
# in the first cell is not used here -- confirm which one is intended.
for epoch in range(100):
    train_loss = 0
    # cycle() repeats the labeled loader so that every unlabeled batch is
    # paired with a labeled batch; zip() stops once unlabel_loader is exhausted.
    batch_pairs = zip(cycle(label_loader), unlabel_loader)
    for (x, y), (x_u, y_u) in tqdm(batch_pairs, total=len(unlabel_loader)):
        # Keep only index 0 along dim 1 of the labeled images
        # (presumably the first colour channel -- confirm).
        x = x[:, 0:1].to(device)
        # One-hot encode the labels over 10 classes.
        y = torch.eye(10)[y].to(device)
        x_u = x_u.to(device)
        # z is sampled from E given the labeled images.
        z = E.sample({"x": x})["z"]
        loss = model.train({"y": y, "x": x_u, "z": z})
        train_loss += loss

    # Rescale the summed per-batch losses by batch_size / dataset size.
    train_loss = train_loss * unlabel_loader.batch_size / len(unlabel_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))

    # Evaluate on the test loader, then refresh the sample and history plots.
    precision = compute_precision_m1(C, E, test_loader)
    train_hist["precision"].append(precision)
    plot_sample_m1(D, epoch)
    plot_loss(train_hist)


In [None]:
# Re-import the local helper modules after editing them on disk, then refresh
# the star-imported names so this notebook sees the new definitions.
# importlib.reload replaces the deprecated `imp.reload` (the `imp` module was
# removed in Python 3.12).
from importlib import reload

reload(models)
reload(utils)
from models import * 
from utils import *