In [1]:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch import optim
from torch import nn
import torch
import torch.nn.functional as F
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import math
import copy

In [2]:
import torch.distributions as D
import itertools

In [3]:
dev = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")


In [4]:
dev

device(type='cpu')

In [5]:
#weight to point can be added

def get_all_vects(a, v, i, u):
    while v[i] != u:
        if i != 0:
            get_all_vects(a, v, i-1, u)
        else:
            a.append(v.copy())
        v[i] += 1
    v[i] = 0

class Gaussian_Grid(nn.Module):
    def __init__(self, dimensions, grid_size, variance, weighting = 'uniform', rand_seed = None):
        super(Gaussian_Grid, self).__init__()
        self.grid_size = grid_size
        self.variance = variance
        self.dimensions = dimensions
        self.grid_distr = list()
        #uniform distribution on knots
        if weighting == 'uniform':
            u = 1/(self.grid_size**self.dimensions)
            self.grid_distr = [u for i in range(self.grid_size**self.dimensions)]
        #random distribution on knots (seed can be specified)
        if weighting == 'random':
            if rand_seed is not None:
                torch.manual_seed(rand_seed)
            cur_sum = 0
            for i in range(self.grid_size**self.dimensions):
                u = D.Uniform(0, 1-cur_sum).sample().item()
                self.grid_distr.append(u)
                cur_sum += u
        m = D.Categorical(torch.tensor(self.grid_distr))

        all_vects = list()
        get_all_vects(all_vects, [0.0 for i in range(self.dimensions)], self.dimensions-1, grid_size)
        comp = D.Independent(D.Normal(
             torch.tensor(all_vects), 
             self.variance*torch.ones(self.grid_size**self.dimensions,self.dimensions)), 1)
        self.gmm = D.MixtureSameFamily(m, comp)


    def sampler(self, sample_amnt):
        return self.gmm.sample((sample_amnt,))

    def log_pdf(self, x):
        return self.gmm.log_prob(x)

In [6]:
class Simple_Gaussian(nn.Module):
    def __init__(self, dimensions, mean = None, variance = None):
        super(Simple_Gaussian, self).__init__()
        self.mean = None
        if mean is None:
            self.mean = torch.zeros(dimensions)
        else:
            self.mean = torch.FloatTensor(mean)
        self.variance = None
        if variance is None:
            self.variance = torch.eye(dimensions)
        else:
            self.variance = torch.FloatTensor(variance)
        self.norml = D.MultivariateNormal(self.mean.to(dev), self.variance.to(dev)) # * torch.eye(dimensions)

    def sampler(self, sample_amnt):
        return self.norml.sample((sample_amnt,))

    def log_pdf(self, x):
        return self.norml.log_prob(x)

In [7]:
class Banana_Gaussian(nn.Module):
    def __init__(self, p=100, b=0.1):
        super(Banana_Gaussian, self).__init__()
        self.mean = torch.FloatTensor([0, 0])
        self.variance = torch.FloatTensor([[p, 0], [0, 1]])
        self.norml = D.MultivariateNormal(self.mean.to(dev), self.variance.to(dev)) # * torch.eye(dimensions)
        self.p = p
        self.b = b
        self.samps = None

    def sampler(self, sample_amnt):
        samps = self.norml.sample((sample_amnt,))
        samps[:, 1] = samps[:, 1] + self.b*samps[:, 0]**2-self.p*self.b
        return samps

    def log_pdf(self, x):
        y = x.clone().detach()
        y[:, 1] = y[:, 1] - self.b*y[:, 0]**2 +self.p*self.b
        return self.norml.log_prob(y)

In [8]:
def func(x):
    #return (x*x*x).sum(dim=1)
    return x[:, 1]

In [9]:
def sigma(pred, old, f, p, q, log_dets, coef_ll = 1, coef_var = 0):
    ans = 0
    if coef_ll is not None:
        ans += -coef_ll*(q.log_pdf(pred)+log_dets).sum()
    if coef_var is not None:
        ans += coef_var*((f(old)**2 *torch.exp(p.log_pdf(old)-q.log_pdf(pred)-log_dets)).mean())
    return ans

In [10]:
from IPython.display import clear_output
import itertools
from tqdm.notebook import tqdm
import pandas as pd

In [11]:
def estimate_params(p, q, f, model, sample_size = 10**3, lear_rt = 1e-3, epoch_amnt = 2*10**3, lr_downing_num=1,
                   ll_coef = 1, var_coef=0):
    lr = lear_rt #learning rate
    max_epochs = epoch_amnt
    samples_amnt = sample_size

    xb = p.sampler(samples_amnt)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)

    lambda2 = lambda itnum: 0.9
    scheduler = optim.lr_scheduler.MultiplicativeLR(opt, lr_lambda=lambda2)

    #best_model = None
    #wanted to deepcopy.....
    
    best_min_loss = float('Inf')

    for itnum in range(lr_downing_num):
        epoch_cnt = 0
        inc_res = 0
        cur_res = 10**6
        scheduler.step()

        # heuristic of stopping
        while True:
            if max_epochs is not None:
                if epoch_cnt > max_epochs:
                    break
            zb, log_dets = model(xb)
            log_dets = log_dets.squeeze()
            loss = sigma(zb, xb, f, p, q, log_dets, ll_coef, var_coef)

            if torch.isnan(loss):
                print("nan loss")
                break
            
            loss.backward()
            opt.step()
            opt.zero_grad()

            is_nan_param = False
            for elem in model.parameters():
                if torch.any(torch.isnan(elem)).item():
                    print("nan param")
                    is_nan_param = True
                    break
            if is_nan_param:
                break

            if best_min_loss > loss.item():
                inc_res = 0
                best_min_loss = loss.item()
            
            if inc_res > 50:
                break

            prev_res = cur_res
            cur_res = loss
            epoch_cnt += 1
            inc_res += 1

            """
            if epoch_cnt % 100 == 0:
                clear_output(wait=True)
                print("Cur iter:", itnum)
                print("Curr_loss:", loss)
                print("Min_loss:", best_min_loss)
                print("Current epochs:", epoch_cnt) 
                print("Current tolerance:", torch.abs(cur_res-prev_res))
            """

In [12]:
def new_expectancy(f, p, q, test_amnt, sample_amnt, model):
    results = list()
    for i in range(test_amnt):
        mean = 0
        z = q.sampler(sample_amnt)
        x, log_det = model(z, mode='inverse')
        #plt.scatter(x[:, 0], x[:, 1])
        #plt.grid()
        #plt.show()
        """
        print("z", z)
        print("x", x)
        print("f(x)", f(x))
        print("p.pdf", p.pdf(x))
        print("tx", torch.exp(q_dist.log_pdf(z) - log_det.squeeze()) )
        print("1/tx", torch.exp(-q_dist.log_pdf(z) + log_det.squeeze()) )
        print("p.pdf* 1/tx", p.pdf(x)*torch.exp(-q_dist.log_pdf(z) + log_det.squeeze()) )
        print(f(x)*p.pdf(x)*torch.exp(-q_dist.log_pdf(z) + log_det.squeeze()) )
        """
        mean = (f(x)*torch.exp(p.log_pdf(x)-q_dist.log_pdf(z) + log_det.reshape(1, -1))).mean()
        results.append(mean.item())
        #if i%10 == 0:
        #    print(i)
    return results

In [13]:
import myfl as fnn


def build_model(data_dimensions, num_hidden, num_blocks, mtype = 'realnvp'):
    num_inputs = data_dimensions
    num_cond_inputs = None

    modules = []

    if mtype == "maf":
        for _ in range(num_blocks):
            modules += [
                    fnn.MADE(num_inputs, num_hidden, None, act='tanh'),
                    fnn.BatchNormFlow(num_inputs),
                    fnn.Reverse(num_inputs)
                ]
    elif mtype == 'realnvp':
        mask = torch.arange(0, num_inputs) % 2
        for _ in range(num_blocks):
            modules += [
                fnn.CouplingLayer(
                    num_inputs, num_hidden, mask, None,
                    s_act='tanh', t_act='relu'),
                fnn.BatchNormFlow(num_inputs)
            ]
            mask = 1 - mask
    elif mtype == 'maf-split':
        #does not work
        for _ in range(num_blocks):
            modules += [
                fnn.MADESplit(num_inputs, num_hidden, num_cond_inputs,
                            s_act='tanh', t_act='relu'),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]
    elif mtype == 'maf-split-glow':
        #does not work
        for _ in range(num_blocks):
            modules += [
                fnn.MADESplit(num_inputs, num_hidden, num_cond_inputs,
                            s_act='tanh', t_act='relu'),
                fnn.BatchNormFlow(num_inputs),
                fnn.InvertibleMM(num_inputs)
            ]


    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    return model

In [14]:
def regular_expectancy(f, p, test_amnt, sample_amnt):
    results = list()
    for i in range(test_amnt):
        samples = p.sampler(sample_amnt)
        mean = f(samples).mean()
        results.append(mean.item())
        #if i%10 == 0:
        #    print(i)
    return results

In [19]:
def box_comp(a, b):                                                    
    data = [a, b]                                             
    plt.figure(figsize=(12,8))                 
    plt.boxplot(data, showfliers = False, labels =                              
                ["MC Vanila", "norm_flow"])
    plt.grid()
    plt.show()

In [20]:
import os

data_dimensions = 2 # how many dimensions are there in our variable
# should work for >= 2
lrs = [1e-2, 1e-3, 1e-4]
layer_amnt = [3, 15, 25]
hidden_size = [3, 8, 15]
flow_type = ['realnvp', 'maf']
epoch_amnt = [5*10**3, 2*10**4]

column_names = ["learning rate", "layer amount", 
                "hidden layer size", "flow type", "max_amount of epochs",
               "MC Vanila", "norm_flow"]
if os.path.exists('./banana_res.csv.csv'):
    output = pd.read_csv('banana_res.csv.csv')
else:
    output = pd.DataFrame(columns = column_names)

for combination in itertools.product(lrs, layer_amnt, hidden_size, flow_type, epoch_amnt):
    cur_params = list(combination)
    was_comb = False
    for i in range(output.shape[0]):
        old_params = output.iloc[i, :len(cur_params)]
        if old_params.tolist() == cur_params:
            was_comb = True
            break
    if was_comb:
        continue
    print(cur_params)
    clr, clamnt, chidsz, cfltype, cepamnt = combination
    model = build_model(data_dimensions, clamnt, chidsz, cfltype)
    p = 100
    b = 0.01
    p_dist = Banana_Gaussian(p, b)
    q_dist = Simple_Gaussian(data_dimensions)
    estimate_params(p_dist, q_dist, func, model, sample_size = 10**3, lear_rt = clr, 
                    epoch_amnt = cepamnt, lr_downing_num = 5, ll_coef = None, var_coef=1)
    
    model.eval()
    reg = regular_expectancy(func, p_dist, 100, 10**3)
    newe = new_expectancy(func, p_dist, q_dist, 100, 10**3, model)
    cur_params.append(reg)
    cur_params.append(newe)
    cur_params = pd.Series(cur_params, index=output.columns)
    output = output.append(cur_params, ignore_index=True)
    output.to_csv('banana_res.csv', index=False)
    

[0.01, 3, 3, 'realnvp', 5000]
[0.01, 3, 3, 'realnvp', 20000]
[0.01, 3, 3, 'maf', 5000]
[0.01, 3, 3, 'maf', 20000]
[0.01, 3, 8, 'realnvp', 5000]
[0.01, 3, 8, 'realnvp', 20000]
[0.01, 3, 8, 'maf', 5000]
[0.01, 3, 8, 'maf', 20000]
[0.01, 3, 15, 'realnvp', 5000]
[0.01, 3, 15, 'realnvp', 20000]
[0.01, 3, 15, 'maf', 5000]
[0.01, 3, 15, 'maf', 20000]
[0.01, 15, 3, 'realnvp', 5000]
[0.01, 15, 3, 'realnvp', 20000]
[0.01, 15, 3, 'maf', 5000]
[0.01, 15, 3, 'maf', 20000]
[0.01, 15, 8, 'realnvp', 5000]
[0.01, 15, 8, 'realnvp', 20000]
[0.01, 15, 8, 'maf', 5000]
[0.01, 15, 8, 'maf', 20000]
[0.01, 15, 15, 'realnvp', 5000]
[0.01, 15, 15, 'realnvp', 20000]
[0.01, 15, 15, 'maf', 5000]
[0.01, 15, 15, 'maf', 20000]
[0.01, 25, 3, 'realnvp', 5000]
[0.01, 25, 3, 'realnvp', 20000]
[0.01, 25, 3, 'maf', 5000]
[0.01, 25, 3, 'maf', 20000]
[0.01, 25, 8, 'realnvp', 5000]
[0.01, 25, 8, 'realnvp', 20000]
nan param
nan loss
nan loss
nan loss
nan loss
[0.01, 25, 8, 'maf', 5000]
[0.01, 25, 8, 'maf', 20000]
nan param
nan 