In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use("ggplot")
import os
from os.path import join, exists
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from pixyz.distributions import DataDistribution
from pixyz.models import GAN

from models import *
from utils import *

# Training configuration and MNIST data pipeline.
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST images have a single channel, so Normalize takes one-element mean/std.
# (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
# Three-element stats only worked on older torchvision, which zipped over the
# channels; newer versions broadcast in place and raise a shape-mismatch error.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split: shuffled every epoch.
dataset = datasets.MNIST('../data/mnist', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Test split: fixed order.
dataset_test = datasets.MNIST('../data/mnist', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

# Output directory for logs/samples; exist_ok makes this cell safe to re-run.
log_dir = "./logs/mnist_gif"
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Dimensionality of the latent vector fed to the generator.
z_dim = 64

# Prior p(z): Normal with loc 0 and scale 1 over the z_dim-dimensional variable "z".
# NOTE(review): `Normal` is not imported explicitly in this notebook; it is
# presumably re-exported via `from models import *` / `from utils import *` — confirm.
loc, scale = (torch.tensor(value, device=device) for value in (0., 1.))
prior = Normal(
    loc=loc,
    scale=scale,
    var=["z"],
    dim=z_dim,
    name="p_prior",
)

# Generator (presumably p(x|z)) combined with the prior; integrating out "z"
# leaves the implicit model distribution over x used by the GAN.
p_g = generator(input_dim=z_dim)
p = (p_g * prior).marginalize_var("z").to(device)

# Empirical data distribution over "x" (samples come from the training batches).
p_data = DataDistribution(["x"]).to(device)

# Discriminator network.
d = discriminator().to(device)

# Generator and discriminator each get their own Adam optimizer, lr = 2e-4.
model = GAN(
    p_data,
    p,
    d,
    optimizer=optim.Adam,
    optimizer_params=dict(lr=2e-4),
    d_optimizer=optim.Adam,
    d_optimizer_params=dict(lr=2e-4),
)
print(model)

In [None]:
epoch_num = 100

# Per-iteration loss histories. Values are stored as Python floats so they can
# be plotted directly and do not keep (possibly GPU-resident) tensors alive.
train_loss = []
train_d_loss = []
for epoch in range(epoch_num):
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        # model.train returns (generator loss, discriminator loss).
        loss, d_loss = model.train({"x": x})
        # float() accepts both plain numbers and 0-d tensors.
        train_loss.append(float(loss))
        train_d_loss.append(float(d_loss))

    # One axes per curve. Calling plt.title twice on a single axes would
    # overwrite the first title and leave both curves under one label.
    fig, (ax_g, ax_d) = plt.subplots(1, 2, figsize=(12, 4))
    ax_g.plot(train_loss)
    ax_g.set(title="generator loss", xlabel="iteration", ylabel="loss")
    ax_d.plot(train_d_loss)
    ax_d.set(title="discriminator loss", xlabel="iteration", ylabel="loss")
    plt.show()
    plt.close(fig)

    # Use the configured latent size instead of a hard-coded 64.
    plot_sample(p_g, epoch, z_dim=z_dim)


In [None]:
# Development helper: pick up edits to the local models.py / utils.py without
# restarting the kernel. `imp` is deprecated (removed in Python 3.12);
# importlib.reload is the supported replacement.
from importlib import reload

import models
import utils

reload(utils)
reload(models)

# Re-run the star imports so names in this namespace point at the reloaded
# definitions rather than the stale ones.
from models import *
from utils import *