## (1) Packages and settings

In [None]:
import sys, os, math
import numpy as np
import torch

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split

import importlib
sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data

In [None]:
# Global matplotlib appearance shared by every figure in this notebook.
import matplotlib
import matplotlib.pyplot as plt

# base font applied through the 'font' rc group
font = dict(family='serif', weight='normal', size=10)
matplotlib.rc('font', **font)

# rcParams overrides, grouped by the element they affect
rcnew = {
    # math text and font sizes
    "mathtext.fontset": "cm",
    "axes.titlesize": 26,
    "axes.labelsize": 22,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 22,
    "figure.titlesize": 30,
    # tick lengths
    "xtick.major.size": 8,
    "xtick.minor.size": 4,
    "ytick.major.size": 8,
    "ytick.minor.size": 4,
    # error-bar caps and axis margins
    "errorbar.capsize": 4,
    "axes.xmargin": 0.05,
    "axes.ymargin": 0.05,
}
plt.rcParams.update(rcnew)

# the style sheet is applied after the rcParams overrides, as before
plt.style.use('tableau-colorblind10')
#plt.rcParams.update({"text.usetex": True,})

%config InlineBackend.figure_format = 'retina'


In [None]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device: %s'%(device))

In [None]:
from tqdm.notebook import tqdm
from scipy.stats import uniform

from nflows import distributions as distributions_
from nflows import flows, transforms
from nflows.nn import nets

## (2) Load and pre-process data

In [None]:
# ---- load maps and parameters ----
# a singleton axis is inserted at position 2 of the loaded maps array
maps = np.load(...)[:, :, None, :, :]
dset_size, splits = maps.shape[:2]  # data set size, number of augmentations/views per parameter set

# give the parameters a view axis, then copy them once per view so that
# their leading two axes line up with those of the maps
params = np.repeat(np.load(...)[:, None, :], splits, axis = 1)

# ---- pre-processing switches ----
rescale, standardize, verbose = True, True, True

# log(1 + x) rescaling of the maps
if rescale:
    maps = np.log(maps + 1)

# shift/scale by a single global mean and standard deviation
# (both statistics are accumulated in float64)
if standardize:
    maps_mean = np.mean(maps, dtype=np.float64)
    maps_std  = np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(np.min(params[:, :, 0]), np.max(params[:, :, 0])))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(np.min(params[:, :, 1]), np.max(params[:, :, 1])))

    if rescale:
        print('Rescale: ', rescale)
    if standardize:
        print('Standardize: ', standardize)

# convert both arrays to float32 tensors living on `device`
maps   = torch.tensor(maps).float().to(device)
params = torch.tensor(params).float().to(device)

In [None]:
# divide the data into train, validation, and test sets
batch_size = 256
train_frac, valid_frac, test_frac = 0.8, 0.1, 0.1

# `seed` is handed to create_datasets below but was never assigned in this
# notebook, which made the cell fail with a NameError.
# NOTE(review): 1 is a placeholder -- presumably this should be the same seed
# that was used when the encoder was trained, so that the train/valid/test
# membership is identical; confirm before trusting test-set results.
seed = 1

train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params,
                                                               train_frac, valid_frac, test_frac,
                                                               seed = seed,
                                                               rotations=False)

# wrap the three splits in loaders; only the test loader keeps a fixed order
train_loader = DataLoader(train_dset, batch_size, shuffle = True)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

if verbose: print('Split the data into train, validation, and test sets.')


## (3) Load the encoder model

In [None]:
# the Encoder Network: file locations (placeholders to be filled in)
fmodel = ...
fout   = ...

# architecture parameters; `last_layer` is set to twice `hidden`
hidden     = 8
last_layer = 2*hidden

# instantiate the network, move it to the target device, then restore the
# saved weights (mapped onto that same device while loading)
model = SummaryNet(hidden = hidden, last_layer = last_layer)
model = model.to(device)
model.load_state_dict(torch.load(fmodel, map_location=torch.device(device)))

# evaluation mode; the trailing semicolon suppresses the notebook echo
model.eval();

## (4) Get the summaries from the maps

In [None]:
def _encode_loader(loader):
    """Pass every batch of `loader` through the encoder `model`.

    Gradients are disabled. Each batch is moved to `device`, the inputs are
    fed to `model`, and the outputs are collected together with their targets.

    Parameters
    ----------
    loader : iterable of (inputs, targets) batches

    Returns
    -------
    (summaries, targets) : pair of tensors, each the concatenation along
        dim 0 of the per-batch results, in the order the loader yielded them.
    """
    summaries, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            summaries.append(model(x))
            targets.append(y)
    return torch.cat(summaries), torch.cat(targets)


# encode each split once; the same helper replaces three copy-pasted loops
x_train, y_train = _encode_loader(train_loader)
x_valid, y_valid = _encode_loader(valid_loader)
x_test,  y_test  = _encode_loader(test_loader)

# from here on the loaders serve (summary, parameters) pairs instead of maps
############################
train_dset   = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dset, batch_size, shuffle = True)
############################
valid_dset   = TensorDataset(x_valid, y_valid)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)
############################
test_dset   = TensorDataset(x_test, y_test)
test_loader = DataLoader(test_dset, batch_size, shuffle = False)


## (5) Build and train a normalizing flow to predict $\Omega_M$ and $\sigma_8$ from the summaries
### (without assuming that the posterior distribution of these parameters is Gaussian)

In [12]:
def build_maf(dim=1, num_transforms=8, context_features=None, hidden_features=128):
    """Assemble a masked autoregressive flow over a `dim`-dimensional variable.

    The flow is a stack of `num_transforms` layers. Each layer is a masked
    affine autoregressive transform followed by a random permutation of the
    `dim` features. The base distribution is a standard normal of shape
    `(dim,)`.

    Parameters
    ----------
    dim : number of features of the modelled variable.
    num_transforms : number of (autoregressive transform, permutation) layers.
    context_features : passed through to every autoregressive transform as
        its `context_features` argument (None by default).
    hidden_features : passed through to every autoregressive transform as
        its `hidden_features` argument.

    Returns
    -------
    A `flows.Flow` wrapping the composed transform and the base distribution.
    """
    layers = []
    for _ in range(num_transforms):
        # the autoregressive transform is created before the permutation,
        # matching the original construction order within each layer
        autoregressive = transforms.MaskedAffineAutoregressiveTransform(
            features=dim,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=2,
            use_residual_blocks=False,
            random_mask=False,
            activation=torch.tanh,
            dropout_probability=0.0,
            use_batch_norm=False,
        )
        permutation = transforms.RandomPermutation(features=dim)
        layers.append(transforms.CompositeTransform([autoregressive, permutation]))

    stacked = transforms.CompositeTransform(layers)
    base = distributions_.StandardNormal((dim,))
    return flows.Flow(stacked, base)


In [None]:
# output files (placeholders to be filled in)
fmodel = ...
fout   = ...

# optimisation hyper-parameters
lr         = 5e-4
max_epochs = 300

# flow architecture
num_transforms  = 8
hidden_features = 128

# two-dimensional target conditioned on a two-dimensional context
# NOTE(review): context_features presumably has to equal the width of the
# summaries produced by the encoder -- confirm that it is 2.
flow_net = build_maf(
    dim=2,
    context_features=2,
    num_transforms=num_transforms,
    hidden_features=hidden_features,
)
flow_net = flow_net.to(device=device)

# AdamW with a cosine learning-rate schedule spanning all epochs
optimizer = torch.optim.AdamW(
    flow_net.parameters(),
    lr=lr,
    weight_decay=1e-5,
)
# NOTE(review): `verbose` looks to be deprecated in recent PyTorch releases --
# confirm against the installed version.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_epochs,
    verbose=True,
)


In [None]:
def _mean_nll(loader):
    """Per-sample mean negative log-likelihood of `flow_net` over `loader`.

    Each batch loss is a batch mean, so it is weighted by the batch size
    before being accumulated; the total is divided by the number of samples
    exactly once, after the last batch. Gradients are disabled. The caller
    is responsible for putting `flow_net` in the desired train/eval mode.

    Parameters
    ----------
    loader : iterable of (context, target) batches

    Returns
    -------
    float : sum over samples of -log p(target | context), divided by the
        number of samples.

    Raises
    ------
    ZeroDivisionError : if the loader yields no samples.
    """
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x  = x.float().to(device=device)
            y  = y.float().to(device=device)
            bs = x.shape[0]

            loss = -flow_net.log_prob(y, context = x).mean()

            total += loss.item()*bs
            count += bs
    return total/count


# Baseline validation loss of the untrained flow. Previously the division by
# the sample count sat inside the batch loop, which shrank the baseline by
# roughly the number of batches and delayed (or prevented) checkpointing.
flow_net.eval()
min_valid_loss = _mean_nll(valid_loader)

print('Initial valid loss = %.3e'%min_valid_loss)
# do a loop over all the epochs
for epoch in range(max_epochs):

    # training
    train_loss, num_points = 0.0, 0
    flow_net.train()
    for x, y in train_loader:
        x  = x.float().to(device=device)
        y  = y.float().to(device=device)
        bs = x.shape[0]

        loss = -flow_net.log_prob(y, context = x).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()*bs
        num_points += bs

    train_loss = train_loss/num_points

    # validation
    flow_net.eval()
    valid_loss = _mean_nll(valid_loader)

    # verbose: checkpoint whenever the validation loss improves
    if valid_loss<min_valid_loss:
        min_valid_loss = valid_loss
        torch.save(flow_net.state_dict(), fmodel)
        print('Epoch %d: %.3e %.3e (saving)'%(epoch, train_loss, valid_loss))
    else:
        print('Epoch %d: %.3e %.3e '%(epoch, train_loss, valid_loss))

    # the log is truncated on the first epoch and appended to afterwards;
    # it is reopened every epoch so the file is up to date while training runs
    with open(fout, 'w' if epoch == 0 else 'a') as f:
        f.write('%d %.4e %.4e\n'%(epoch, train_loss, valid_loss))

    scheduler.step()

In [None]:
# plot losses
plt.figure(figsize = (10, 6))

# columns of the log: epoch, training loss, validation loss
# ndmin=2 keeps the array two-dimensional even if the log has a single row
losses = np.loadtxt(fout, ndmin = 2)

start_epoch = 5     # skip the first epochs so they do not dominate the y-range
end_epoch   = None  # None keeps the final epoch (a stop of -1 dropped it)

epochs = losses[start_epoch:end_epoch, 0]
plt.plot(epochs, losses[start_epoch:end_epoch, 1], label = 'Training loss')
plt.plot(epochs, losses[start_epoch:end_epoch, 2], label = 'Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc = 'best')