In [1]:
from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader
import torch
from torch import nn
from torch import optim
import os
import sys

In [2]:
import numpy as np

In [3]:
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt

In [4]:
# Make the project's `src/` directory importable. The path is resolved to an absolute path
# (so a later os.chdir cannot break it) and only appended once, so re-running this cell does
# not keep adding duplicate entries to sys.path.
SRC_DIR = os.path.abspath(os.path.join('..', 'src'))
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from DeepGenerativeModels.AutoEncoders import FlowVAE, VAE
from DeepGenerativeModels.RealNVP import RealNVP

In [5]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [6]:
def load_flat_mnist(train: bool, root: str = 'mnist') -> TensorDataset:
    """Download (if needed) one MNIST split and return it as a flat TensorDataset.

    Each 28x28 uint8 image is flattened to a 784-element float vector scaled to [0, 1];
    the second tensor holds the digit labels.

    Args:
        train: True for the training split, False for the test split.
        root: directory MNIST is downloaded to / read from.
    """
    mnist = MNIST(root, download=True, train=train)
    # `.data` / `.targets` replace the deprecated `train_data`/`train_labels`/
    # `test_data`/`test_labels` aliases, which emit warnings in recent torchvision.
    images = mnist.data.view(-1, 28 * 28).float() / 255
    return TensorDataset(images, mnist.targets)


train_data = load_flat_mnist(train=True)
test_data = load_flat_mnist(train=False)



# Training

## VAE and FlowVAE

In [7]:
import utils

In [8]:
# Seed torch's global RNG so weight initialisation (and any later torch-RNG draws made in
# notebook order) is reproducible under Restart & Run All.
torch.manual_seed(0)

# Presumably the constructors take (latent_dim, input_dim) — TODO confirm against
# DeepGenerativeModels.AutoEncoders. The posterior plot below builds 2-component z
# vectors, and the data cell flattens images to 28*28 features.
LATENT_DIM = 2
INPUT_DIM = 28 * 28

flow_vae_model = FlowVAE(LATENT_DIM, INPUT_DIM, device=device)
vae_model = VAE(LATENT_DIM, INPUT_DIM, device=device)

In [9]:
# Fit the baseline VAE: Adam with learning rate 1e-3, 5 epochs over the training set,
# mini-batches of 64, progress shown with the notebook tqdm bar.
optimizer = optim.Adam(vae_model.parameters(), lr=1e-3)
utils.trainer(
    model=vae_model,
    optimizer=optimizer,
    dataset=train_data,
    count_of_epoch=5,
    batch_size=64,
    callback=None,
    progress=tqdm,
)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))




In [10]:
# Fit the FlowVAE with exactly the same optimiser and schedule as the baseline VAE,
# so the two posteriors drawn below are directly comparable.
optimizer = optim.Adam(flow_vae_model.parameters(), lr=1e-3)
utils.trainer(
    model=flow_vae_model,
    optimizer=optimizer,
    dataset=train_data,
    count_of_epoch=5,
    batch_size=64,
    callback=None,
    progress=tqdm,
)

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))




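# Draw the posterior

Evaluate each trained model's posterior density over a grid of 2-D latent points for the first training image, then plot both grids side by side.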

In [11]:
# Take the first training example as a batch of size 1 (DataLoader defaults: batch_size=1,
# no shuffling). `next(iter(...))` stops after that single batch instead of collating and
# storing every batch of the dataset in a list only to keep element 0.
batch_x, batch_y = next(iter(DataLoader(train_data)))

In [None]:
# Evaluate each model's `posterior_z(z, x)` on a square grid of 2-D latent points
# for the single observation `batch_x`.
Z_MIN, Z_MAX, GRID_SIZE = -4.0, 4.0, 100

x = np.linspace(Z_MIN, Z_MAX, GRID_SIZE)
y = np.linspace(Z_MIN, Z_MAX, GRID_SIZE)

# Default 'xy' indexing: xx[i, j] == x[j], yy[i, j] == y[i] (rows follow y).
xx, yy = np.meshgrid(x, y)

probas1 = np.zeros(xx.shape)  # VAE
probas2 = np.zeros(xx.shape)  # FlowVAE

# The models were built on `device`, so inputs must live there too (a no-op on CPU).
x_obs = batch_x.to(device)

# Pure evaluation: skip building autograd graphs for the many forward passes.
with torch.no_grad():
    for i in tqdm(range(xx.shape[0])):
        for j in range(xx.shape[1]):
            z = torch.tensor([[xx[i, j], yy[i, j]]], dtype=torch.float32, device=device)
            # float() pulls a plain Python number out of a 1-element result (and off the GPU)
            # before it is stored in the numpy grid.
            probas1[i, j] = float(vae_model.posterior_z(z, x_obs))
            probas2[i, j] = float(flow_vae_model.posterior_z(z, x_obs))

HBox(children=(IntProgress(value=0), HTML(value='')))

In [None]:
# Show the two posterior grids side by side in latent coordinates.
# origin='lower' puts grid row 0 (the smallest y) at the bottom, matching the y axis;
# extent labels the axes with z values instead of pixel indices.
extent = (x[0], x[-1], y[0], y[-1])
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
panels = ((probas1, 'VAE posterior'), (probas2, 'FlowVAE posterior'))
for ax, (grid, title) in zip(axes, panels):
    ax.imshow(grid, origin='lower', extent=extent)
    ax.set(title=title, xlabel='$z_1$', ylabel='$z_2$')
plt.show()