In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm


plt.style.use('ggplot')

In [2]:
from versatileorbits import *
#from RealNVP import *
import masks

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal

In [4]:
class RealNVPNode(nn.Module):
    def __init__(self, mask, hidden_size):
        super(RealNVPNode, self).__init__()
        self.dim = len(mask)
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.s_func = nn.Sequential(nn.Linear(in_features=self.dim, out_features=hidden_size), nn.LeakyReLU(),
                                    nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.LeakyReLU(),
                                    nn.Linear(in_features=hidden_size, out_features=self.dim))

        self.scale = nn.Parameter(torch.Tensor(self.dim))

        self.t_func = nn.Sequential(nn.Linear(in_features=self.dim, out_features=hidden_size), nn.LeakyReLU(),
                                    nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.LeakyReLU(),
                                    nn.Linear(in_features=hidden_size, out_features=self.dim))

    def forward(self, x):
        x_mask = x*self.mask
        s = self.s_func(x_mask) * self.scale
        t = self.t_func(x_mask)

        y = x_mask + (1 - self.mask) * (x*torch.exp(s) + t)

        # Sum for -1, since for every batch, and 1-mask, since the log_det_jac is 1 for y1:d = x1:d.
        log_det_jac = ((1 - self.mask) * s).sum(-1)
        return y, log_det_jac

    def inverse(self, y):
        print(self.mask.shape)
        print(y.shape)
        y_mask = y * self.mask
        s = self.s_func(y_mask) * self.scale
        t = self.t_func(y_mask)

        x = y_mask + (1-self.mask)*(y - t)*torch.exp(-s)

        inv_log_det_jac = ((1 - self.mask) * -s).sum(-1)

        return x, inv_log_det_jac


class RealNVP(nn.Module):
    def __init__(self, masks, hidden_size):
        super(RealNVP, self).__init__()

        self.dim = len(masks[0])
        self.hidden_size = hidden_size

        self.masks = nn.ParameterList([nn.Parameter(torch.Tensor(mask), requires_grad=False) for mask in masks])
        self.layers = nn.ModuleList([RealNVPNode(mask, self.hidden_size) for mask in self.masks])

        self.distribution = MultivariateNormal(torch.zeros(self.dim), torch.eye(self.dim))

    def log_probability(self, x):
        log_prob = torch.zeros(x.shape[0])
        for layer in reversed(self.layers):
            x, inv_log_det_jac = layer.inverse(x)
            log_prob += inv_log_det_jac
        log_prob += self.distribution.log_prob(x)

        return log_prob

    def rsample(self, num_samples):
        x = self.distribution.sample((num_samples,))
        log_prob = self.distribution.log_prob(x)

        for layer in self.layers:
            x, log_det_jac = layer.forward(x)
            log_prob += log_det_jac

        return x, log_prob

    def sample_each_step(self, num_samples):
        samples = []

        x = self.distribution.sample((num_samples,))
        samples.append(x.detach().numpy())

        for layer in self.layers:
            x, _ = layer.forward(x)
            samples.append(x.detach().numpy())

        return samples


In [5]:
# Very simple training loop
def train(model, num_epochs = 100, batch_size = 64):
    NF_dataset = OrbitsDataset_NF(num_samples = 1000, phi0 = 1, H = -0.3, L = 0.5)
    train_loader = torch.utils.data.DataLoader(NF_dataset, batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters())
    
    losses = []
    for epoch in tqdm(range(num_epochs)):    
        epoch_loss = 0
        for orbit_position in train_loader:
            log_probability = model.log_probability(orbit_position) #(batch_size)
            loss = - torch.mean(log_probability, dim = 0)
            
            
            loss.backward()
            optimizer.step()
            model.zero_grad()
            
            epoch_loss += loss
            
        
        epoch_loss /= len(train_loader)
        losses.append(epoch_loss.detach())
    
    return model, losses


In [6]:
num_layers= 4
masks = masks.mask2(input_output_size = 4, num_layers = num_layers)
hidden_size = 32

In [8]:
model = RealNVP(masks, hidden_size)

model, losses = train(model, num_epochs = 100)

  0%|          | 0/100 [00:00<?, ?it/s]

[ 0.92195445  1.66666667  1.         -0.3         0.5       ]
torch.Size([1000, 4])
It took 0.005406856536865234 time to finish the job.
torch.Size([4])
torch.Size([64, 4])
torch.Size([4])
torch.Size([64, 4])
torch.Size([4])
torch.Size([64, 4])
torch.Size([4])
torch.Size([64, 4])





ValueError: Expected value argument (Tensor of shape (64, 4)) to be within the support (IndependentConstraint(Real(), 1)) of the distribution MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4])), but found invalid values:
tensor([[-1.3766, -2.6055,     inf,     nan],
        [ 0.2442, -0.7375,     inf,  1.5533],
        [-1.7779, -2.2471,    -inf,     nan],
        [-1.0293, -2.4765,     inf,     nan],
        [-0.4612, -1.9820,     inf,     nan],
        [-1.6706, -2.5040,     inf,     nan],
        [-1.0524, -0.4670,    -inf,     nan],
        [-1.3049, -2.5938,     inf,     nan],
        [-0.8901, -0.2716,    -inf,     nan],
        [-0.8724, -2.3687,     inf,     nan],
        [-1.6938, -1.5370,    -inf,     nan],
        [-1.6399, -2.5338,     inf,     nan],
        [-0.9675, -0.3650,    -inf,     nan],
        [-1.7279, -1.6503,    -inf,     nan],
        [-0.5030, -2.0249,     inf,     nan],
        [-1.7653, -2.3119,     inf,     nan],
        [-1.5143, -2.5967,     inf,     nan],
        [-0.7981, -2.3093,     inf,     nan],
        [-1.5476, -1.1969,    -inf,     nan],
        [-1.7580, -1.7859,    -inf,     nan],
        [-1.6397, -1.3892,    -inf,     nan],
        [ 0.3186, -0.3903,     inf,  2.1158],
        [-1.2210, -0.6899,    -inf,     nan],
        [ 0.2325, -0.7770,     inf,     nan],
        [-1.4044, -2.6072,     inf,     nan],
        [-1.1166, -0.5504,    -inf,     nan],
        [ 0.1051, -1.1269,     inf,     nan],
        [ 0.0429, -1.2590,     inf,     nan],
        [-1.0033, -0.4079,    -inf,     nan],
        [-1.1890, -0.6470,    -inf,     nan],
        [-0.5973,  0.0421,    -inf,     nan],
        [-1.5993, -2.5627,     inf,     nan],
        [ 0.1863, -0.9219,     inf,     nan],
        [-1.7846, -2.0045,    -inf,     nan],
        [-1.1833, -0.6394,    -inf,     nan],
        [-0.2298, -1.7122,     inf,     nan],
        [-0.9103, -0.2956,    -inf,     nan],
        [-1.4834, -2.6027,     inf,     nan],
        [-1.5114, -2.5974,     inf,     nan],
        [-1.7239, -1.6362,    -inf,     nan],
        [ 0.2093, -0.8536,     inf,     nan],
        [-1.6132, -2.5539,     inf,     nan],
        [-1.6371, -1.3833,    -inf,     nan],
        [-0.9012, -0.2848,    -inf,     nan],
        [-1.3779, -2.6056,     inf,     nan],
        [-1.7225, -1.6311,    -inf,     nan],
        [-0.4819, -2.0034,     inf,     nan],
        [-1.0319, -0.4421,    -inf,     nan],
        [-0.7196, -2.2376,     inf,     nan],
        [-0.3009, -1.8035,     inf,     nan],
        [-1.5682, -2.5785,     inf,     nan],
        [-0.8249, -0.2006,    -inf,     nan],
        [ 0.3171, -0.0735,    -inf,  2.8707],
        [-1.6499, -2.5250,     inf,     nan],
        [-0.0234, -1.3912,     inf,     nan],
        [-0.7482, -2.2644,     inf,     nan],
        [-1.6892, -2.4813,     inf,     nan],
        [-1.7881, -2.1086,    -inf,     nan],
        [-1.2741, -2.5860,     inf,     nan],
        [-1.7528, -1.7570,    -inf,     nan],
        [-1.7839, -2.1975,    -inf,     nan],
        [-1.4500, -2.6064,     inf,     nan],
        [-0.4091, -1.9276,     inf,     nan],
        [-1.7799, -1.9532,    -inf,     nan]], grad_fn=<AddBackward0>)