In [None]:
import torch
from torch import nn
from torch import optim
import numpy as np
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
# Root directory for the per-run TensorBoard logs; the training cell appends
# the run index to this prefix when it creates each SummaryWriter.
writer_path = "./tb_log/nf/"

from nflows.flows.base import Flow
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.flows.autoregressive import MaskedAutoregressiveFlow
from nflows.transforms.permutations import ReversePermutation
from nflows.distributions.normal import ConditionalDiagonalNormal

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the simulated data set from its .npz archive.
training_data = np.load("../../Data/n1000000_0910_all_flat.npz")

# Context features: the first 36 columns of each of the four arrays, stacked
# column-wise in the order ve, vu, vebar, vubar.
spectrum_keys = ('ve_dune', 'vu_dune', 'vebar_dune', 'vubar_dune')
data_all = np.column_stack([training_data[key][:, :36] for key in spectrum_keys])

# Targets, in column order: theta13, theta23, delta.
# Each is rescaled by pi/180 (presumably degrees -> radians; confirm the units
# stored in the archive).
angle_keys = ("theta13", "theta23", "delta")
target = np.column_stack([training_data[key] / 180 * np.pi for key in angle_keys])

# The first `split` rows are used for training, the remainder for validation.
split = 900000
x_train, x_val = data_all[:split], data_all[split:]
y_train, y_val = target[:split], target[split:]

# Poisson-fluctuated copies of the features, scaled down by 1000.
# The training draw comes first so the global NumPy RNG is consumed in the
# same order as before.
x_train_poisson = np.random.poisson(x_train) / 1000
x_val_poisson = np.random.poisson(x_val) / 1000

In [None]:
def flow_generator(num_layers=4, hidden_features=8, num_blocks=10,
                   features=3, context_features=144):
    """Build a conditional normalizing flow together with an Adam optimizer.

    The flow is a stack of ``num_layers`` pairs of
    (``ReversePermutation``, ``MaskedAffineAutoregressiveTransform``) on top of
    a ``ConditionalDiagonalNormal`` base distribution whose parameters are
    produced from the context by a single linear layer.

    Args:
        num_layers: Number of (permutation, autoregressive transform) pairs.
        hidden_features: Hidden width of each autoregressive transform.
        num_blocks: Number of blocks in each autoregressive transform.
        features: Dimensionality of the modelled variable. Defaults to 3,
            the value that was previously hard-coded.
        context_features: Dimensionality of the conditioning vector. Defaults
            to 144, the value that was previously hard-coded.

    Returns:
        A ``(flow, optimizer)`` tuple: the ``Flow`` instance and an
        ``optim.Adam`` optimizer (default hyper-parameters) over its parameters.
    """
    # The context encoder must emit two values per modelled dimension
    # (a mean and a log-std), hence 2 * features outputs.
    base_dist = ConditionalDiagonalNormal(
        shape=[features],
        context_encoder=nn.Linear(context_features, 2 * features))

    transforms = []
    for _ in range(num_layers):
        # Reverse the feature order between autoregressive layers so that the
        # conditioning order alternates from layer to layer.
        transforms.append(ReversePermutation(features=features))
        transforms.append(MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks))
    transform = CompositeTransform(transforms)

    flow = Flow(transform, base_dist)
    optimizer = optim.Adam(flow.parameters())
    return flow, optimizer

In [None]:
# Grid search over the flow hyper-parameters. Every configuration is trained
# for `num_iter` full-batch Adam steps; its training loss is logged to its own
# TensorBoard run and the final loss is recorded in a shared hparams run.
num_iter = 1000

# The model-info files and saved models below are written under "nf/"; create
# the directory up front so the first open() cannot fail on a fresh checkout.
os.makedirs("nf", exist_ok=True)
hparam_writer = SummaryWriter("tb_log/nf/hparam")

# The whole training set is used as one batch at every step, so the tensors are
# loop-invariant: build and move them to the device once instead of on every
# iteration of every configuration.
# Naming follows the flow's point of view: `x` is the modelled variable (the
# targets) and `y` is the conditioning context (the features scaled by 1/1000).
x = torch.tensor(y_train, dtype=torch.float32).to(device)
y = torch.tensor(x_train/1000, dtype=torch.float32).to(device)

for num_layers in [2, 3, 4, 5]:
    for hidden_features in [8, 16, 32, 64]:
        for num_blocks in [2, 4, 6, 8]:
            # Pick the first run index that has no model-info file yet, so
            # earlier results are never overwritten.
            index = 1
            while os.path.isfile("nf/model_info_{}.txt".format(index)):
                index += 1

            flow, optimizer = flow_generator(num_layers, hidden_features, num_blocks)
            flow = flow.to(device)
            writer = SummaryWriter(writer_path + str(index))

            # Record which configuration this run index corresponds to.
            with open("nf/model_info_{}.txt".format(index), 'w') as f:
                f.write('num_layers = {}\n'.format(num_layers))
                f.write('hidden_features = {}\n'.format(hidden_features))
                f.write('num_blocks = {}\n'.format(num_blocks))

            for i in tqdm(range(num_iter)):
                optimizer.zero_grad()
                # Negative mean log-likelihood of the targets given the context.
                loss = -flow.log_prob(inputs=x, context=y).mean()
                loss.backward()
                optimizer.step()
                # Log a plain float rather than a tensor attached to the graph.
                writer.add_scalar("training_loss", loss.item(), i)
            # Flush and release this run's event file before starting the next.
            writer.close()

            torch.save(flow, "./nf/test_{}.pt".format(index))
            hparam_writer.add_hparams({
                'num_layers': num_layers,
                'hidden_features': hidden_features,
                'num_blocks': num_blocks},
                {'hparam/loss': loss.item()})

hparam_writer.close()

# Draw samples from the flow conditioned on the first row of `data_all` and
# plot the 68% contour of columns 1 and 2 (theta23 and delta in the target
# column order).
# NOTE(review): this block tests the loop counter `i` and so reads like a
# fragment meant to run inside the training loop, yet it sits at top level.
# `plot_graph`, `corner` and `j` are not defined or imported in the visible
# cells -- confirm where they come from before running this cell.
if plot_graph and i%30 == 0:
    colors =['green', 'blue', 'red', 'black', 'yellow']
    n_sample = 1000000
    # The context is a single row scaled by 1/1000, matching the scaling used
    # for the training context.
    samples = flow.sample(num_samples=n_sample,
                context=torch.tensor(np.array([data_all[0]/1000]), dtype=torch.float32).to(device)
                ).cpu().detach().numpy().reshape(-1, 3)
    # Undo the pi/180 rescaling applied to the targets (back to degrees, per
    # the axis labels below).
    samples = samples*180/np.pi
    # Raw strings keep the LaTeX backslashes literal; the previous "\c" and
    # "\d" were invalid escape sequences (SyntaxWarning on recent Python).
    figure = corner.hist2d(samples[:, 1], samples[:, 2],
                        levels=(0.68,),
                        scale_hist=True,
                        plot_datapoints=False,
                        color = colors[j],
                        labels= [r"$\theta_{23} $($^\circ$)", r"$\delta_{cp} $($^\circ$)"],
                        # range=[[48,50], [170, 220]],
                        plot_contours = True,
                        plot_density = False,
                        fontsize=30,
                        bins = [200, 200],
                        label_kwargs={"fontsize": 30},
                        smooth=True
                    )
    # Move on to the next colour; `colors` has five entries, so `j` must stay
    # below 5 or the lookup above raises IndexError.
    j += 1
