# Neural Network Flows

ReLU network transport map on the synthetic examples

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import time

import pyro 
import pyro.distributions as dist
import pyro.distributions.transforms as T
from pyro.nn import AutoRegressiveNN

from MTKSD.loss import KSD_U, KSD_V, KSD_gammaU , Wasserstein, ELBO, KSD_U_nograd
from MTKSD.get_score import get_score
from MTKSD.plot import plot_dist2D, plot_loss, plot_scatter, get_distvals
from MTKSD.toy_distributions import MOG2D, Banana2D, Sinusoidal2D
from MTKSD.polynomial_transport import theta_init
from MTKSD.neural_net_transport import ReLU_transport, transform_dist
from MTKSD.utils import save_output, train_KSD, train_ELBO, get_metric, load_output

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import style

# Global matplotlib configuration for every figure in this notebook.
# Default figure size: 6 x 6.
plt.rcParams['figure.figsize'] = [6,6]

# ggplot stylesheet, with STIX fonts for both regular text and math text.
style.use("ggplot")
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'


## Defining test problems

### Mixture of Gaussians

In [None]:
# Mixture-of-Gaussians test problem: four component means at (+-1, +-1).
means = torch.Tensor([[1,1],[-1,1],[1,-1],[-1,-1]])
# Second MOG2D argument; presumably a variance shared by all components
# (see MTKSD.toy_distributions.MOG2D to confirm).
var = 0.2

mog2D = MOG2D(means, var)

# Seed torch's global RNG before drawing the 10000 reference samples, which
# are passed to get_metric later in the notebook.
torch.manual_seed(0)
mog2D_sample = mog2D.sample(10000)

# Plot the target over the ranges [-3,3] and [-3,3].
plot_dist2D(mog2D,[-3,3],[-3,3])

### Banana

In [None]:
# Banana test problem. The three values are handed to Banana2D as a list;
# their individual meanings are defined in MTKSD.toy_distributions.Banana2D.
# NOTE: `var` is rebound here (it was a scalar in the MoG cell).
a,v1,v2 = 0.5,1,0.1
var = [a,v1,v2]

banana2D = Banana2D(var)

# Seed torch's global RNG before drawing the 10000 reference samples, which
# are passed to get_metric later in the notebook.
torch.manual_seed(0)
banana2D_sample = banana2D.sample(10000)

# Plot the target over the ranges [-5,5] and [-5,10].
plot_dist2D(banana2D, [-5,5], [-5,10])

### Sinusoidal

In [None]:
# Sinusoidal test problem. The three values are handed to Sinusoidal2D as a
# list; their individual meanings are defined in
# MTKSD.toy_distributions.Sinusoidal2D.
a,v1,v2 = 1.2,1.3,0.001
var = [a,v1,v2]

sinusoidal2D = Sinusoidal2D(var)

# Seed torch's global RNG before drawing the 10000 reference samples.
# NOTE: this variable is named `_samp`, not `_sample` as for the other targets.
torch.manual_seed(0)
sinusoidal2D_samp = sinusoidal2D.sample(10000)

# Plot the target over the ranges [-3,3] and [-1.5,1.5]; n_steps=1200 is
# forwarded to plot_dist2D (presumably the plotting grid resolution).
plot_dist2D(sinusoidal2D, [-3,3],[-1.5,1.5],n_steps=1200)

## ReLU Network

In [None]:
class ReLU_Network(nn.Module):
    """Fully connected feed-forward network with ReLU activations.

    The network is ``Linear -> ReLU`` for the input layer and for every hidden
    layer, followed by a final ``Linear`` layer with no activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
        hidden_dims: Non-empty sequence of hidden-layer widths, e.g.
            ``[20, 20]``. An empty sequence raises ``IndexError``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        n_hidden = len(hidden_dims)
        self.input = nn.Linear(input_dim, hidden_dims[0])
        # The hidden layers must live in an nn.ModuleList rather than a plain
        # Python list. A plain list is invisible to nn.Module, so the layers
        # would be missing from parameters() / state_dict() and skipped by
        # .to(device); an optimizer built from parameters() would never
        # update them. Construction order is unchanged, so seeded
        # initialisation yields the same initial weights as before.
        self.linears = nn.ModuleList(
            [nn.Linear(hidden_dims[i], hidden_dims[i + 1]) for i in range(n_hidden - 1)]
        )
        self.output = nn.Linear(hidden_dims[-1], output_dim)

    def forward(self, x):
        """Apply the network to ``x`` and return the un-activated output."""
        out = F.relu(self.input(x))
        for layer in self.linears:
            out = F.relu(layer(out))
        return self.output(out)


### Mixture of Gaussians

In [None]:
# Seed torch's global RNG before the network is constructed so that its
# random weight initialisation is repeatable.
torch.manual_seed(11)

# Base distribution: independent standard normals of size input_dim (4).
input_dim = 4
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
# ReLU network from 4 inputs to 2 outputs with hidden widths [20, 20].
relu_mog = ReLU_Network(input_dim,2,[20,20])
# Combine the base distribution with the network via MTKSD's transform_dist;
# the result is sampled with .sample((n,)) in later cells.
transform_relu_mog = transform_dist(base_dist, [relu_mog])

In [None]:
class mv():
    """Thin adapter around a distribution object.

    Forwards ``log_prob`` and ``sample`` to the wrapped object and adds a
    ``score`` method computed by ``get_score``.
    """

    def __init__(self, mvn):
        # Wrapped distribution; every method below delegates to it.
        self.mvn = mvn

    def log_prob(self, x):
        """Return the wrapped distribution's log-probability at ``x``."""
        wrapped = self.mvn
        return wrapped.log_prob(x)

    def sample(self, n):
        """Draw samples from the wrapped distribution; ``n`` is passed through unchanged."""
        wrapped = self.mvn
        return wrapped.sample(n)

    def score(self, x):
        """Return ``get_score(x, <wrapped distribution>)``."""
        wrapped = self.mvn
        return get_score(x, wrapped)
        
# Zero-mean 2-D Gaussian with covariance 2*I, wrapped in `mv`. It is the
# target of the "#pretrain" train_KSD calls further down.
MV = mv(dist.MultivariateNormal(torch.zeros(2),2*torch.eye(2)))

In [None]:
# Pretraining: run train_KSD for 10000 steps against the Gaussian MV rather
# than the real target. An empty name is passed with save_out=False
# (presumably nothing is written to disk; see MTKSD.utils.train_KSD).
train_KSD(transform_relu_mog, MV, relu_mog, "", save_out = False, n_steps = 10000) #pretrain

In [None]:
# Visual check after pretraining: plot the MoG target, then scatter 10000
# samples from the current transport map in faint cyan.
plot_dist2D(mog2D,[-3,3],[-3,3])
plot_scatter(transform_relu_mog.sample((10000,)).detach(), color="cyan",alpha=0.03)
# Restrict the view to [-3,3] on both axes.
plt.xlim([-3,3])
plt.ylim([-3,3])

In [None]:
# Main training run against the MoG target for 50000 steps.
# NOTE(review): unlike the banana and sinusoidal runs, this one passes an
# empty name and save_out=False - confirm that not saving is intentional.
train_KSD(transform_relu_mog, mog2D, relu_mog, "", save_out = False, n_steps = 50000)

In [None]:
# Evaluate the trained map against the MoG target and its reference samples.
# get_metric returns a pair; judging by the variable names, presumably a
# Wasserstein distance and a KSD value (see MTKSD.utils.get_metric).
wass_relu_mog, KSD_relu_mog = get_metric(transform_relu_mog, mog2D, mog2D_sample)
print(wass_relu_mog, KSD_relu_mog)

In [None]:
# Visual check after training: plot the MoG target, then scatter 10000
# samples from the trained transport map in faint cyan.
plot_dist2D(mog2D,[-3,3],[-3,3])
plot_scatter(transform_relu_mog.sample((10000,)).detach(), color="cyan",alpha=0.03)
# Restrict the view to [-3,3] on both axes.
plt.xlim([-3,3])
plt.ylim([-3,3])

### Banana

In [None]:
# Seed torch's global RNG before the network is constructed so that its
# random weight initialisation is repeatable.
torch.manual_seed(0)

# Base distribution: independent standard normals of size input_dim (4).
# NOTE: `input_dim` and `base_dist` are rebound here, shadowing the MoG cell's.
input_dim = 4
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
# ReLU network from 4 inputs to 2 outputs with hidden widths [20, 20].
relu_banana = ReLU_Network(input_dim,2,[20,20])
# Combine the base distribution with the network via MTKSD's transform_dist.
transform_relu_banana = transform_dist(base_dist, [relu_banana])

In [None]:
# Pretraining: run train_KSD for 10000 steps against the Gaussian MV rather
# than the banana target. An empty name is passed with save_out=False.
train_KSD(transform_relu_banana, MV, relu_banana, "", save_out = False, n_steps = 10000) #pretrain

In [None]:
# Visual check after pretraining: plot the banana target over [-5,5] x [-5,10].
plot_dist2D(banana2D, [-5,5], [-5,10])

# Narrow the visible window to [-3.5,3.5] x [-2.5,7.5].
plt.xlim([-3.5,3.5])
plt.ylim([-2.5,7.5])

# Scatter 10000 samples from the current transport map in faint cyan.
plot_scatter(transform_relu_banana.sample((10000,)).detach(), color="cyan", alpha=0.03)

In [None]:
# Main training run against the banana target for 50000 steps. The name
# "relu_banana" is passed with save_out=True (presumably the label under
# which the output is saved; see MTKSD.utils.train_KSD).
train_KSD(transform_relu_banana, banana2D, relu_banana, "relu_banana", save_out = True, n_steps = 50000)

In [None]:
# Evaluate the trained map against the banana target and its reference
# samples. get_metric returns a pair; judging by the variable names,
# presumably a Wasserstein distance and a KSD value.
wass_relu_banana, KSD_relu_banana = get_metric(transform_relu_banana, banana2D, banana2D_sample)
print(wass_relu_banana, KSD_relu_banana)

In [None]:
# Visual check after training: plot the banana target over [-5,5] x [-5,10].
plot_dist2D(banana2D, [-5,5], [-5,10])

# Narrow the visible window to [-3.5,3.5] x [-2.5,7.5].
plt.xlim([-3.5,3.5])
plt.ylim([-2.5,7.5])

# Scatter 10000 samples from the trained map. This cell uses alpha=0.05 and
# s=2, unlike the other scatter cells (alpha=0.03, no s).
plot_scatter(transform_relu_banana.sample((10000,)).detach(), color="cyan", alpha=0.05,s=2)

### Sinusoidal

In [None]:
# Seed torch's global RNG before the network is constructed so that its
# random weight initialisation is repeatable.
torch.manual_seed(90)

# Base distribution: independent standard normals of size input_dim (4).
# NOTE: `input_dim` and `base_dist` are rebound here, shadowing earlier cells'.
input_dim = 4
base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
# ReLU network from 4 inputs to 2 outputs with hidden widths [20, 20].
relu_sinusoidal = ReLU_Network(input_dim,2,[20,20])
# Combine the base distribution with the network via MTKSD's transform_dist.
transform_relu_sinusoidal = transform_dist(base_dist, [relu_sinusoidal])

In [None]:
# Pretraining: run train_KSD for 10000 steps against the Gaussian MV rather
# than the sinusoidal target. An empty name is passed with save_out=False.
train_KSD(transform_relu_sinusoidal, MV, relu_sinusoidal, "", save_out = False, n_steps = 10000) #pretrain

In [None]:
# Visual check after pretraining: plot the sinusoidal target over
# [-5,5] x [-3,3]. No n_steps is passed here, unlike the n_steps=1200 plot
# in the problem-definition cell.
plot_dist2D(sinusoidal2D, [-5,5], [-3,3])

# Narrow the visible window to [-4,4] x [-1.5,1.5].
plt.xlim([-4,4])
plt.ylim([-1.5,1.5])

# Scatter 10000 samples from the current transport map in faint cyan.
plot_scatter(transform_relu_sinusoidal.sample((10000,)).detach(), color="cyan", alpha=0.03)

In [None]:
# Main training run against the sinusoidal target for 50000 steps. The name
# "relu_sinusoidal" is passed with save_out=True (presumably the label under
# which the output is saved; see MTKSD.utils.train_KSD).
train_KSD(transform_relu_sinusoidal, sinusoidal2D, relu_sinusoidal, "relu_sinusoidal", save_out = True, n_steps = 50000)

In [None]:
# Evaluate the trained map against the sinusoidal target and its reference
# samples, mirroring the MoG and banana metric cells. This cell previously
# referenced `MV2`, which is not defined anywhere in the notebook and raised
# NameError on a fresh kernel.
wass_relu_sinusoidal, KSD_relu_sinusoidal = get_metric(transform_relu_sinusoidal, sinusoidal2D, sinusoidal2D_samp)
print(wass_relu_sinusoidal, KSD_relu_sinusoidal)

In [None]:
# NOTE(review): this repeats the banana training call from the Banana
# section verbatim, inside the Sinusoidal section. Running it performs
# another 50000 train_KSD steps on relu_banana with the same "relu_banana"
# name and save_out=True. It looks like a leftover cell - confirm whether
# it is still needed.
train_KSD(transform_relu_banana, banana2D, relu_banana, "relu_banana", save_out = True, n_steps = 50000)

In [None]:
# Evaluate KSD_U_nograd on 1000 all-zero 2-D points using the sinusoidal
# target's score function and gamma=0.1. The value is not stored; it only
# shows up as the cell's output.
# NOTE(review): this looks like an ad-hoc sanity check rather than part of
# the experiment.
KSD_U_nograd(torch.zeros(1000,2), sinusoidal2D.score,gamma=0.1)

In [None]:
# Evaluate the trained map against the sinusoidal target and its reference
# samples. get_metric returns a pair; judging by the variable names,
# presumably a Wasserstein distance and a KSD value.
wass_relu_sinusoidal, KSD_relu_sinusoidal = get_metric(transform_relu_sinusoidal, sinusoidal2D, sinusoidal2D_samp)
print(wass_relu_sinusoidal, KSD_relu_sinusoidal)

In [None]:
# Visual check after training: plot the sinusoidal target over
# [-5,5] x [-3,3].
plot_dist2D(sinusoidal2D, [-5,5], [-3,3])

# Narrow the visible window to [-4,4] x [-1.5,1.5].
plt.xlim([-4,4])
plt.ylim([-1.5,1.5])

# Scatter 10000 samples from the trained transport map in faint cyan.
plot_scatter(transform_relu_sinusoidal.sample((10000,)).detach(), color="cyan", alpha=0.03)