In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import time

from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
device

device(type='cuda')

In [4]:

#The model of the Decoder
class GenerativeModel(nn.Module):

    def __init__(self, latent_dim=50):
        super(GenerativeModel, self).__init__()
        self.latent_dim = latent_dim
        self.net = torch.nn.Sequential(
                    torch.nn.Linear(latent_dim, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, 784),
                    torch.nn.Sigmoid()
                    )
        
    def forward(self, x):
        return self.net(x)

    def sample(self, M, N=None):
        device = next(self.parameters()).device
        if N is None:
            x = torch.randn(M, self.latent_dim).to(device)
        else:
            x = torch.randn(M, N, self.latent_dim).to(device)
        return self.forward(x)
    
    def conditional_log_likelihood(self, x, y):
        recon_x = torch.clamp(self.forward(x), 1e-6, 1.-1e-6)
        return torch.log(recon_x) * y + torch.log(1 - recon_x) * (1 - y)
        
class SimpleVAE(nn.Module):

    def __init__(self, latent_dim=50):
        super(SimpleVAE, self).__init__()
        self.latent_dim = latent_dim
        self.G = GenerativeModel(latent_dim)
        self.encoder = torch.nn.Sequential(
                    torch.nn.Linear(784, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, latent_dim * 2)
                    )
    
    def forward(self, x, y):
        device = next(self.parameters()).device
        M = x.shape[0]
        N = y.shape[0]
        dW = torch.zeros((M, N, 1)).to(device)
        mean_std = self.encoder(y)
        mean = mean_std[:, :self.latent_dim]
        std = torch.abs(mean_std[:, self.latent_dim:]) + 1e-6
        x1 = x * std + mean
        dW = dW + (x**2).sum(axis=2, keepdims=True) / 2
        dW = dW - (x1**2).sum(axis=2, keepdims=True) / 2
        dW = dW + self.G.conditional_log_likelihood(x1, y).sum(axis=2, keepdims=True)
        dW = dW + torch.log(std).sum(axis=1, keepdims=True)
        return x1, dW

    def log_likelihood(self, y, M):
        device = next(self.parameters()).device
        x0 = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        x, dW = self.forward(x0, y.view(-1, 784))
        return torch.mean(dW, axis=0, keepdims=False)

class LangevinVAE(nn.Module):

    def __init__(self, latent_dim=50, nsteps=30, stepsize=0.01):
        super().__init__()
        self.latent_dim = latent_dim
        self.G = GenerativeModel(latent_dim)
        self.nsteps = nsteps
        stepsize_list = torch.FloatTensor([stepsize,] * nsteps)
        lambda_list = (np.array(range(1,nsteps + 1))/nsteps).tolist()
        lambda_list = torch.FloatTensor(lambda_list)
        self.stepsize_para_list, self.lambda_para_list = self.stepsize_lambda_2_para(stepsize_list, lambda_list)
        self.stepsize_para_list = nn.Parameter(torch.FloatTensor(self.stepsize_para_list), requires_grad=True)
        self.lambda_para_list = nn.Parameter(torch.FloatTensor(self.lambda_para_list), requires_grad=True)
        
    def stepsize_lambda_2_para(self, stepsize_list, lambda_list):
        stepsize_para_list = torch.clamp(torch.abs(stepsize_list), min=1e-6)
        lambda_para_list = lambda_list
        return stepsize_para_list, lambda_para_list
    
    def para_2_stepsize_lambda(self, stepsize_para_list, lambda_para_list):
        stepsize_list = torch.abs(stepsize_para_list) + 1e-6
        lambda_list = lambda_para_list
        return stepsize_list, lambda_list

    def energy_0(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2

    def force_0(self, x, y):
        return -x
    
    def sample_energy_0(self, y, M):
        device = next(self.parameters()).device
        x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        return x
        
    def energy_1(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2 - self.G.conditional_log_likelihood(x, y).sum(axis=2, keepdims=True)

    def force_1(self, x, y):
        x0 = x.clone().detach().requires_grad_(True)
        e = self.energy_1(x0, y)
        return -torch.autograd.grad(e.sum(), x0, create_graph=True)[0]

    def interpolated_energy(self, x, y, lambda_=1.):
        return self.energy_0(x, y) * (1 - lambda_) + self.energy_1(x, y) * lambda_

    def interpolated_force(self, x, y, lambda_=1.):
        return self.force_0(x, y) * (1 - lambda_) + self.force_1(x, y) * lambda_

    def forward(self, x, y):
        stepsize_list, lambda_list = self.para_2_stepsize_lambda(self.stepsize_para_list, self.lambda_para_list)
        dW = self.energy_0(x, y)
        for i in range(self.nsteps):
            lambda_ = lambda_list[i]
            stepsize = stepsize_list[i]
            # forward step
            x1 = x + stepsize * self.interpolated_force(x, lambda_) + torch.sqrt(2*stepsize) * torch.randn_like(x)
            tmp_dW = self.interpolated_energy(x1, y, lambda_) - self.interpolated_energy(x, y, lambda_)
            A = torch.exp(torch.clamp(-tmp_dW, - math.inf, 0.))
            u = torch.rand_like(A)
            acc = (u <= A).float()
            x = (1 - acc) * x + acc * x1
            dW += acc * tmp_dW
        dW = dW - self.energy_1(x, y)
        return x, dW

    def log_likelihood(self, y, M):
        x0 = self.sample_energy_0(y.view(-1, 784), M)
        x, dW = self.forward(x0, y.view(-1, 784))
        return torch.mean(dW, axis=0, keepdims=False)


class CouplingLayer(nn.Module):
    def __init__(self, input_dim, hid_dim, mask, cond_dim=None, s_tanh_activation=True, smooth_activation=False):
        super().__init__()
        
        if cond_dim is not None:
            total_input_dim = input_dim + cond_dim
        else:
            total_input_dim = input_dim

        self.s_fc1 = nn.Linear(total_input_dim, hid_dim)
        self.s_fc2 = nn.Linear(hid_dim, hid_dim)
        self.s_fc3 = nn.Linear(hid_dim, input_dim)
        self.t_fc1 = nn.Linear(total_input_dim, hid_dim)
        self.t_fc2 = nn.Linear(hid_dim, hid_dim)
        self.t_fc3 = nn.Linear(hid_dim, input_dim)
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.s_tanh_activation = s_tanh_activation
        self.smooth_activation = smooth_activation

    def forward(self, x, cond_x=None, mode='direct'):
        x_m = x * self.mask
        if cond_x is not None:
            x_m = torch.cat([x_m, cond_x.expand(x_m.shape[0], -1, -1)], -1)
        if self.smooth_activation:
            if self.s_tanh_activation:
                s_out = torch.tanh(self.s_fc3(F.elu(self.s_fc2(F.elu(self.s_fc1(x_m)))))) * (1-self.mask)
            else:
                s_out = self.s_fc3(F.elu(self.s_fc2(F.elu(self.s_fc1(x_m))))) * (1-self.mask)
            t_out = self.t_fc3(F.elu(self.t_fc2(F.elu(self.t_fc1(x_m))))) * (1-self.mask)
        else:
            if self.s_tanh_activation:
                s_out = torch.tanh(self.s_fc3(F.relu(self.s_fc2(F.relu(self.s_fc1(x_m)))))) * (1-self.mask)
            else:
                s_out = self.s_fc3(F.relu(self.s_fc2(F.relu(self.s_fc1(x_m))))) * (1-self.mask)
            t_out = self.t_fc3(F.relu(self.t_fc2(F.relu(self.t_fc1(x_m))))) * (1-self.mask)
        if mode == 'direct':
            y = x * torch.exp(s_out) + t_out
            log_det_jacobian = s_out.sum(-1, keepdim=True)
        else:
            y = (x - t_out) * torch.exp(-s_out)
            log_det_jacobian = -s_out.sum(-1, keepdim=True)
        return y, log_det_jacobian

class RealNVP(nn.Module):
    def __init__(self, input_dim, hid_dim = 256, n_layers = 2, cond_dim = None, s_tanh_activation = True, smooth_activation=False):
        super().__init__()
        assert n_layers >= 2, 'num of coupling layers should be greater or equal to 2'
        
        self.input_dim = input_dim
        mask = (torch.arange(0, input_dim) % 2).float()
        self.modules = []
        self.modules.append(CouplingLayer(input_dim, hid_dim, mask, cond_dim, s_tanh_activation, smooth_activation))
        for _ in range(n_layers - 2):
            mask = 1 - mask
            self.modules.append(CouplingLayer(input_dim, hid_dim, mask, cond_dim, s_tanh_activation, smooth_activation))
        self.modules.append(CouplingLayer(input_dim, hid_dim, 1 - mask, cond_dim, s_tanh_activation, smooth_activation))
        self.module_list = nn.ModuleList(self.modules)
        
    def forward(self, x, cond_x=None, mode='direct'):
        """ Performs a forward or backward pass for flow modules.
        Args:
            x: a tuple of inputs and logdets
            mode: to run direct computation or inverse
        """
        logdets = torch.zeros(x.size(), device=x.device).sum(-1, keepdim=True)

        assert mode in ['direct', 'inverse']
        if mode == 'direct':
            for module in self.module_list:
                x, logdet = module(x, cond_x, mode)
                logdets += logdet
        else:
            for module in reversed(self.module_list):
                x, logdet = module(x, cond_x, mode)
                logdets += logdet

        return x, logdets

    def log_probs(self, x, cond_x = None):
        u, log_jacob = self(x, cond_x)
        log_probs = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(
            -1, keepdim=True)
        return (log_probs + log_jacob).sum(-1, keepdim=True)

    def sample(self, num_samples, noise=None, cond_x=None):
        if noise is None:
            noise = torch.Tensor(num_samples, self.input_dim).normal_()
        device = next(self.parameters()).device
        noise = noise.to(device)
        if cond_x is not None:
            cond_x = cond_x.to(device)
        samples = self.forward(noise, cond_x, mode='inverse')[0]
        return samples
    
class RealNVPVAE(nn.Module):

    def __init__(self, latent_dim=50):
        super().__init__()
        self.latent_dim = latent_dim
        self.G = GenerativeModel(latent_dim)
        self.F = RealNVP(latent_dim, hid_dim=64, n_layers=6, cond_dim=784)

    def energy_0(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2
    
    def sample_energy_0(self, y, M):
        device = next(self.parameters()).device
        x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        return x
        
    def energy_1(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2 - self.G.conditional_log_likelihood(x, y).sum(axis=2, keepdims=True)

    def forward(self, x, y):
        dW = self.energy_0(x, y)
        x, tmp_dW = self.F(x, y)
        dW += tmp_dW
        dW = dW - self.energy_1(x, y)
        return x, dW

    def log_likelihood(self, y, M):
        x0 = self.sample_energy_0(y.view(-1, 784), M)
        x, dW = self.forward(x0, y.view(-1, 784))
        return torch.mean(dW, axis=0, keepdims=False)

class RealNVPVAE_eval(nn.Module):

    def __init__(self, G):
        super().__init__()
        latent_dim = G.latent_dim
        self.latent_dim = latent_dim
        self.G = G
        self.F = RealNVP(latent_dim, hid_dim=256, n_layers=12, cond_dim=784)

    def energy_0(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2
    
    def sample_energy_0(self, y, M):
        device = next(self.parameters()).device
        x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        return x
        
    def energy_1(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2 - self.G.conditional_log_likelihood(x, y).sum(axis=2, keepdims=True)

    def forward(self, x, y):
        dW = self.energy_0(x, y)
        x, tmp_dW = self.F(x, y)
        dW += tmp_dW
        dW = dW - self.energy_1(x, y)
        return x, dW

    def log_likelihood(self, y, M):
        x0 = self.sample_energy_0(y.view(-1, 784), M)
        x, dW = self.forward(x0, y.view(-1, 784))
        return torch.logsumexp(dW, axis=0, keepdims=False) - math.log(M)

def ModelEval(G, sample_size, data_file):
    start = time.process_time()
    
#    device = torch.device("cuda")
    latent_dim = 50
    batch_size = 128
    n_epochs = 40
    log_interval = 10

    if data_file == 'mnist_data':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_file, train=True, download=False,
                           transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_file, train=False, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(data_file, train=True, download=False,
                           transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(data_file, train=False, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False)
        
    flow = RealNVPVAE_eval(G).to(device)
    optim = torch.optim.Adam(flow.F.parameters(), lr=1e-3)

    M = 1
    for epoch in range(1, n_epochs + 1):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(test_loader):
            data = ((torch.rand_like(data) <= data) + 0.).float()
            data = data.to(device)
            loss = -flow.log_likelihood(data, M).mean()
            optim.zero_grad()
            loss.backward()
            train_loss += loss.item()*len(data)
            optim.step()
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(test_loader.dataset),
                    100. * batch_idx / len(test_loader),
                    loss.item()))
    
    with torch.no_grad():
        test_loss = 0
        M = sample_size
        K = 10
        for kk in range(K):
            for batch_idx, (data, _) in enumerate(test_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M).mean()
                test_loss += loss.item()*len(data)
                if batch_idx % log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        kk, batch_idx * len(data), len(test_loader.dataset),
                        100. * batch_idx / len(test_loader),
                        loss.item()))
        test_loss /= len(test_loader.dataset)*K
    print('====> Test set NLL: {:.4f}'.format(test_loss))

    return test_loss
    

class SNFVAE(nn.Module):

    def __init__(self, latent_dim=50, unit_num=3, nsteps=10, stepsize=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.unit_num = unit_num
        self.G = GenerativeModel(latent_dim)
        self.F_list = []
        for _ in range(unit_num):
            self.F_list.append(RealNVP(latent_dim, hid_dim=64, n_layers=2, cond_dim=784))
        self.F_list = nn.ModuleList(self.F_list)
        self.nsteps = nsteps
        stepsize_list = torch.FloatTensor([stepsize,] * nsteps * unit_num)
        lambda_list = (np.array(range(1,nsteps * unit_num + 1))/nsteps / unit_num).tolist()
        lambda_list = torch.FloatTensor(lambda_list)
        self.stepsize_para_list, self.lambda_para_list = self.stepsize_lambda_2_para(stepsize_list, lambda_list)
        self.stepsize_para_list = nn.Parameter(torch.FloatTensor(self.stepsize_para_list), requires_grad=True)
        self.lambda_para_list = nn.Parameter(torch.FloatTensor(self.lambda_para_list))
        
    def stepsize_lambda_2_para(self, stepsize_list, lambda_list):
        stepsize_para_list = torch.clamp(torch.abs(stepsize_list), min=1e-6)
        lambda_para_list = lambda_list
        return stepsize_para_list, lambda_para_list
    
    def para_2_stepsize_lambda(self, stepsize_para_list, lambda_para_list):
        stepsize_list = torch.abs(stepsize_para_list) + 1e-6
        lambda_list = lambda_para_list
        return stepsize_list, lambda_list

    def energy_0(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2

    def force_0(self, x, y):
        return -x
    
    def sample_energy_0(self, y, M):
        device = next(self.parameters()).device
        x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
        return x
        
    def energy_1(self, x, y):
        return (x**2).sum(axis=2, keepdims=True) / 2 - self.G.conditional_log_likelihood(x, y).sum(axis=2, keepdims=True)

    def force_1(self, x, y):
        x0 = x.clone().detach().requires_grad_(True)
        e = self.energy_1(x0, y)
        return -torch.autograd.grad(e.sum(), x0, create_graph=True)[0]

    def interpolated_energy(self, x, y, lambda_=1.):
        return self.energy_0(x, y) * (1 - lambda_) + self.energy_1(x, y) * lambda_

    def interpolated_force(self, x, y, lambda_=1.):
        return self.force_0(x, y) * (1 - lambda_) + self.force_1(x, y) * lambda_

    def forward(self, x, y, flow_disable=False):
        stepsize_list, lambda_list = self.para_2_stepsize_lambda(self.stepsize_para_list, self.lambda_para_list)
        dW = self.energy_0(x, y)
        for i in range(self.nsteps * self.unit_num):
            if i % self.nsteps == 0:
                x, tmp_dW = self.F_list[int(i/self.nsteps)](x, y)
                dW += tmp_dW                
            if flow_disable:
                continue
            lambda_ = lambda_list[i]
            stepsize = stepsize_list[i]
            # forward step
            x1 = x + stepsize * self.interpolated_force(x, lambda_) + torch.sqrt(2*stepsize) * torch.randn_like(x)
            tmp_dW = self.interpolated_energy(x1, y, lambda_) - self.interpolated_energy(x, y, lambda_)
            A = torch.exp(torch.clamp(-tmp_dW, - math.inf, 0.))
            u = torch.rand_like(A)
            acc = (u <= A).float()
            x = (1 - acc) * x + acc * x1
            dW += acc * tmp_dW
        dW = dW - self.energy_1(x, y)
        return x, dW

    def log_likelihood(self, y, M, flow_disable=False):
        x0 = self.sample_energy_0(y.view(-1, 784), M)
        x, dW = self.forward(x0, y.view(-1, 784), flow_disable)
        return torch.mean(dW, axis=0, keepdims=False)

In [5]:
def train(model_name, data_file, M):
    start = time.process_time()
    
    latent_dim = 50
    batch_size = 128
    log_interval = 100

    if data_file == 'mnist_data':
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('mnist_data', train=True, download=True,
                           transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('mnist_data', train=False, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('fashionmnist_data', train=True, download=True,
                           transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('fashionmnist_data', train=False, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=True)
    
    if model_name in ['SimpleVAE','RealNVPVAE','LangevinVAE']:
        n_epochs = 40
        if model_name == 'SimpleVAE':
            flow = SimpleVAE(latent_dim).to(device)
        if model_name == 'RealNVPVAE':
            flow = RealNVPVAE(latent_dim).to(device)
        if model_name == 'LangevinVAE':
            flow = LangevinVAE(latent_dim).to(device)
        optim = torch.optim.Adam(flow.parameters(), lr=1e-3)
        #perform training
        for epoch in range(1, n_epochs + 1):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M).mean()
                optim.zero_grad()
                loss.backward()
                train_loss += loss.item() * len(data)
                optim.step()
                if batch_idx % log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        loss.item() * len(data) / len(data)))

            test_loss = 0
            for i, (data, _) in enumerate(test_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M).sum()
                test_loss += loss.item()

            test_loss /= len(test_loader.dataset)
            print('====> Test set loss: {:.4f}'.format(test_loss))
    else:
        flow = SNFVAE(latent_dim, nsteps=10, stepsize=1e-2).to(device)
        optim = torch.optim.Adam(flow.parameters(), lr=1e-3)
        n_epochs = 20
        flow_disable = True
        for epoch in range(1, n_epochs + 1):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M, flow_disable).mean()
                optim.zero_grad()
                loss.backward()
                train_loss += loss.item()
                optim.step()
                if batch_idx % log_interval == 0:
                    print(flow.stepsize_para_list.mean())
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        loss.item()))

            test_loss = 0
            for i, (data, _) in enumerate(test_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M, flow_disable).sum()
                test_loss += loss.item()

            test_loss /= len(test_loader.dataset)
            print('====> Test set loss: {:.4f}'.format(test_loss))

        optim = torch.optim.Adam(flow.parameters(), lr=1e-3)
        flow_disable = False
        n_epochs = 20
        flow_disable = True
        for epoch in range(1, n_epochs + 1):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M, flow_disable).mean()
                optim.zero_grad()
                loss.backward()
                train_loss += loss.item()
                optim.step()
                if batch_idx % log_interval == 0:
                    print(flow.stepsize_para_list.mean())
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch+20, batch_idx * len(data), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        loss.item()))

            test_loss = 0
            for i, (data, _) in enumerate(test_loader):
                data = ((torch.rand_like(data) <= data) + 0.).float()
                data = data.to(device)
                loss = -flow.log_likelihood(data, M, flow_disable).sum()
                test_loss += loss.item()

            test_loss /= len(test_loader.dataset)
            print('====> Test set loss: {:.4f}'.format(test_loss))


    #calculate the marginal log-likelihood
    loss = ModelEval(flow.G, 2000, data_file)

    print('Running time: %s Seconds'%(time.process_time()-start))

In [6]:
M = 5
for model in ['SimpleVAE', 'RealNVPVAE', 'LangevinVAE', 'SNFVAE']:
    for data_file in ['mnist_data', 'fashionmnist_data']:
        train(model, data_file, M)

====> Test set loss: 137.8264
====> Test set loss: 115.8454
====> Test set loss: 108.8558
====> Test set loss: 105.4472
====> Test set loss: 103.1052
====> Test set loss: 101.7620
====> Test set loss: 100.5393
====> Test set loss: 99.0912
====> Test set loss: 98.3424
====> Test set loss: 98.2897
====> Test set loss: 97.4921
====> Test set loss: 97.1838
====> Test set loss: 97.0271
====> Test set loss: 96.5876
====> Test set loss: 96.6738
====> Test set loss: 96.0730
====> Test set loss: 96.0118
====> Test set loss: 95.9934
====> Test set loss: 95.7344
====> Test set loss: 95.3972
====> Test set loss: 95.3941
====> Test set loss: 95.2695
====> Test set loss: 95.1676
====> Test set loss: 95.1743
====> Test set loss: 95.0058
====> Test set loss: 95.0852
====> Test set loss: 95.1467
====> Test set loss: 94.7399


====> Test set loss: 94.6370
====> Test set loss: 94.6565
====> Test set loss: 94.7044
====> Test set loss: 94.5726
====> Test set loss: 94.5309
====> Test set loss: 94.2874
====> Test set loss: 94.2958
====> Test set loss: 106.8035
====> Test set loss: 104.5031
====> Test set loss: 113.2632
====> Test set loss: 109.6053
====> Test set loss: 113.9040




====> Test set NLL: 96.8194
Running time: 7943.390625 Seconds
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to fashionmnist_data\FashionMNIST\raw\train-images-idx3-ubyte.gz


96.5%

Extracting fashionmnist_data\FashionMNIST\raw\train-images-idx3-ubyte.gz to fashionmnist_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to fashionmnist_data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100.6%


Extracting fashionmnist_data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to fashionmnist_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to fashionmnist_data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting fashionmnist_data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to fashionmnist_data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to fashionmnist_data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


119.3%


Extracting fashionmnist_data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to fashionmnist_data\FashionMNIST\raw

====> Test set loss: 258.2698
====> Test set loss: 250.9579
====> Test set loss: 247.9019
====> Test set loss: 245.6281
====> Test set loss: 244.0494
====> Test set loss: 246.0073
====> Test set loss: 258.9976
====> Test set loss: 270.6588
====> Test set loss: 258.2443
====> Test set loss: 264.0515
====> Test set loss: 266.3973
====> Test set loss: 260.9879
====> Test set loss: 262.4643
====> Test set loss: 271.1788
====> Test set loss: 276.6870
====> Test set loss: 271.7624
====> Test set loss: 273.6716
====> Test set loss: 273.4859
====> Test set loss: 271.5053
====> Test set loss: 283.0969
====> Test set loss: 289.6180
====> Test set loss: 286.6373
====> Test set loss: 284.4428
====> Test set loss: 284.0972
====> Test set loss: 287.4722
====> Test set loss: 304.6278
====> Test set loss: 300.3455
====> Test set loss: 294.1909


====> Test set loss: 287.6420
====> Test set loss: 291.7893
====> Test set loss: 324.5279
====> Test set loss: 317.0475
====> Test set loss: 311.3809
====> Test set loss: 308.5773
====> Test set loss: 311.3539
====> Test set loss: 321.2123
====> Test set loss: 359.3398
====> Test set loss: 322.3713
====> Test set loss: 304.8240
====> Test set loss: 332.7388




====> Test set NLL: 259.6149
Running time: 7664.09375 Seconds
====> Test set loss: 126.4996
====> Test set loss: 109.5897


====> Test set loss: 105.1388
====> Test set loss: 102.5701
====> Test set loss: 101.1181
====> Test set loss: 99.7554
====> Test set loss: 98.7246
====> Test set loss: 97.9893
====> Test set loss: 97.4544
====> Test set loss: 96.7191
====> Test set loss: 96.4141
====> Test set loss: 96.3260
====> Test set loss: 95.8211
====> Test set loss: 95.8275
====> Test set loss: 95.6023
====> Test set loss: 95.0727
====> Test set loss: 94.7467
====> Test set loss: 94.7971
====> Test set loss: 94.6296
====> Test set loss: 94.4563
====> Test set loss: 94.2299
====> Test set loss: 94.1343
====> Test set loss: 93.9956
====> Test set loss: 93.7358
====> Test set loss: 93.7071
====> Test set loss: 93.6895
====> Test set loss: 93.9093
====> Test set loss: 93.5024
====> Test set loss: 93.7688
====> Test set loss: 93.3593
====> Test set loss: 93.1843


====> Test set loss: 93.3536
====> Test set loss: 93.2102
====> Test set loss: 93.2277
====> Test set loss: 93.0816
====> Test set loss: 92.8832
====> Test set loss: 92.7914
====> Test set loss: 92.8634
====> Test set loss: 92.9022
====> Test set loss: 92.9281




====> Test set NLL: 88.2087
Running time: 10527.03125 Seconds
====> Test set loss: 255.9460
====> Test set loss: 248.4758
====> Test set loss: 245.3227
====> Test set loss: 243.5221
====> Test set loss: 241.9737
====> Test set loss: 240.9382
====> Test set loss: 240.3649


====> Test set loss: 239.6782
====> Test set loss: 238.8288
====> Test set loss: 238.2480
====> Test set loss: 237.9232
====> Test set loss: 237.6144
====> Test set loss: 237.6311
====> Test set loss: 237.1084
====> Test set loss: 237.1187
====> Test set loss: 236.7414
====> Test set loss: 236.2719
====> Test set loss: 236.2683
====> Test set loss: 236.4014
====> Test set loss: 235.9385
====> Test set loss: 235.6694
====> Test set loss: 235.6479
====> Test set loss: 235.7227
====> Test set loss: 235.3376
====> Test set loss: 235.2327
====> Test set loss: 235.3947
====> Test set loss: 235.0274
====> Test set loss: 235.0982
====> Test set loss: 235.1374
====> Test set loss: 234.8654
====> Test set loss: 234.8595
====> Test set loss: 234.7641
====> Test set loss: 234.8782
====> Test set loss: 234.8629
====> Test set loss: 234.7095


====> Test set loss: 234.6987
====> Test set loss: 234.7446
====> Test set loss: 234.6060
====> Test set loss: 234.2604
====> Test set loss: 234.3314




====> Test set NLL: 231.6458
Running time: 13436.21875 Seconds
====> Test set loss: 207.0557
====> Test set loss: 206.6956
====> Test set loss: 206.5808
====> Test set loss: 206.1211
====> Test set loss: 206.2860
====> Test set loss: 203.4679
====> Test set loss: 206.2150
====> Test set loss: 205.8312
====> Test set loss: 202.6040
====> Test set loss: 211.6887


====> Test set loss: 195.4048
====> Test set loss: 195.6504
====> Test set loss: 223.8770
====> Test set loss: 204.4399
====> Test set loss: 196.2484
====> Test set loss: 205.3531
====> Test set loss: 208.2405
====> Test set loss: 206.7349
====> Test set loss: 207.0803
====> Test set loss: 207.9969
====> Test set loss: 209.1292
====> Test set loss: 209.4120
====> Test set loss: 206.1982
====> Test set loss: 207.1072
====> Test set loss: 208.5181
====> Test set loss: 203.7684
====> Test set loss: 203.6482
====> Test set loss: 206.6702
====> Test set loss: 205.8361
====> Test set loss: 207.5423
====> Test set loss: 206.9615
====> Test set loss: 206.9622
====> Test set loss: 205.7766
====> Test set loss: 207.6150
====> Test set loss: 206.7956
====> Test set loss: 206.2006
====> Test set loss: 205.5860
====> Test set loss: 206.6784


====> Test set loss: 206.4566
====> Test set loss: 204.1159




====> Test set NLL: 209.9231
Running time: 35052.671875 Seconds
====> Test set loss: 368.3632
====> Test set loss: 368.6836
====> Test set loss: 380.2391
====> Test set loss: 389.2225
====> Test set loss: 388.1197
====> Test set loss: 387.9507
====> Test set loss: 395.4375
====> Test set loss: 387.9137
====> Test set loss: 385.1315
====> Test set loss: 383.7395
====> Test set loss: 383.4373
====> Test set loss: 386.3138


====> Test set loss: 383.9717
====> Test set loss: 384.2430
====> Test set loss: 383.8242
====> Test set loss: 384.5674
====> Test set loss: 385.9729
====> Test set loss: 385.8953
====> Test set loss: 384.5959
====> Test set loss: 386.4113
====> Test set loss: 384.8155
====> Test set loss: 384.9977
====> Test set loss: 387.1582
====> Test set loss: 385.4830
====> Test set loss: 387.7874
====> Test set loss: 386.6260
====> Test set loss: 386.3531
====> Test set loss: 387.0647
====> Test set loss: 387.0857
====> Test set loss: 385.4707
====> Test set loss: 386.4599
====> Test set loss: 384.6655
====> Test set loss: 384.6605
====> Test set loss: 385.5905
====> Test set loss: 385.4129
====> Test set loss: 386.6488
====> Test set loss: 387.0078
====> Test set loss: 386.7946
====> Test set loss: 386.6037
====> Test set loss: 384.8957






====> Test set NLL: 389.7032
Running time: 34978.6875 Seconds
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 124.9433
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 110.0039
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 10

====> Test set loss: 98.0813
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 97.9425
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 97.1815
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 96.7787
tensor(0.0100, device='cuda:0

tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 94.2508
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 93.8671
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 93.7271
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
t

====> Test set loss: 92.8116
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 92.7133
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 92.7519
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 92.9703




====> Test set NLL: 88.0589
Running time: 12391.953125 Seconds
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 255.1780
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 247.6545
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 2

tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 241.7536
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 240.8195
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 240.1794
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>

tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 236.0834
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 236.0842
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 235.8736
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>

tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 234.8010
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 234.8451
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
====> Test set loss: 234.7984
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0100, device='cuda:0', grad_fn=<MeanBackward0>



====> Test set NLL: 231.9653
Running time: 11725.34375 Seconds
