In [None]:
import torch
from torch import nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


## Loading the database

In [2]:
import sys

# Make the modules in ../Data/ importable. The path is relative to the
# current working directory, and it must be inserted before the
# `data_utils` import on the next line can succeed.
sys.path.insert(1, '../Data/')
from data_utils import get_dataloaders

In [3]:
# Location of the database archive, relative to the working directory.
data_path = '../Data/db.npz'

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only open archives from a trusted source.
with np.load(data_path, allow_pickle = True) as f:
    # Pull the arrays out while the archive is still open.
    param_labels = f['param_labels'] # last one is cosmo sigma_8
    redshifts = f['redshifts']
    LF_zs = f['LF_zs']
    M_UV = f['M_UV']

# Number of labelled parameters; used below as the input width of every network.
N_params = len(param_labels)

We interact with the data using `DataLoader` objects that supply batches of data. We split the database into three parts: 
- training: database used to update the weights of the NN (most of the data)
- validation: database used to assess the NN's performance during training (e.g. to detect over- or under-fitting and to decide when to change the learning rate).
- test: database used to assess the NN's performance after the training is complete.

In [4]:
# Split fractions: 0.8 for training and 0.1 for validation, served in batches of 64.
# `norms` holds the normalisation constants returned alongside the loaders.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
    norms,
) = get_dataloaders(f_train=0.8, f_valid=0.1, batch_size=64)

In [5]:
# Unpack the eight normalisation constants: one (bias, scale) pair per quantity.
(
    Tb_bias, Tb_scale,
    Ts_bias, Ts_scale,
    UVLFs_bias, UVLFs_scale,
    tau_bias, tau_scale,
) = norms

We define a class that will produce a fully connected neural network for us to train:

In [6]:
class FeedForward(nn.Module):
    """Feed forward network architecture with optional activation at the end.

    Parameters
    ----------
    dims : array-like
        Dictates the network number of nodes and number of layers:
        dims = [size of input,
            number of nodes in hidden layer 1, ...,
            number of nodes in hidden layer N,
            size of output]
        Must contain at least two entries (input and output size). With
        exactly two entries the network is a single nn.Linear layer.
    act : callable, optional
        The activation function that goes between each nn.Linear layer.
        Default is nn.ReLU(). The same module instance is reused after every
        hidden layer.
    final_act : callable, optional
        The activation function that goes at the very end of the network.
        Default is None.

    Raises
    ------
    ValueError
        If `dims` has fewer than two entries.
    """

    def __init__(self, dims, act=nn.ReLU(), final_act=None):
        super(FeedForward, self).__init__()
        dims = list(dims)
        if len(dims) < 2:
            raise ValueError(
                "dims must contain at least the input size and the output size"
            )

        layers = []
        # Hidden part: every consecutive pair of sizes except the last one
        # becomes a Linear layer followed by the shared activation.
        for n_in, n_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(act)

        # Output layer. Indexing from the end (rather than reusing the loop
        # variable) keeps this valid when there are no hidden layers at all.
        layers.append(nn.Linear(dims[-2], dims[-1]))

        if final_act is not None:
            layers.append(final_act)

        # Positional construction numbers the sub-modules 0, 1, 2, ... so the
        # state_dict keys are 'net.<index>.weight' / 'net.<index>.bias'.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to `x` and return the result."""
        return self.net(x)

In [7]:
class Emulator(nn.Module):
    """Collection of independent FeedForward networks sharing one input.

    Parameters
    ----------
    params_dict : dict
        Must provide the keys 'xhi', 'tb', 'ts', 'lfs' and 'tau'. Each value is
        a dict with the entries 'dims' and 'final_act', which are passed on to
        FeedForward. The 'lfs' entry is used for seven separate networks
        (lfs6, lfs7, lfs8, lfs9, lfs10, lfs12, lfs15).
        NOTE(review): the numeric suffixes presumably match the redshifts in
        `LF_zs` -- confirm against the database.
    """

    # Attribute names of the luminosity-function heads, in the order in which
    # their outputs are placed along the last axis of the 'lfs' prediction.
    _LFS_HEADS = ('lfs6', 'lfs7', 'lfs8', 'lfs9', 'lfs10', 'lfs12', 'lfs15')

    def __init__(self,params_dict):
        super().__init__()

        def make_net(key):
            # One FeedForward per call, configured from params_dict[key].
            cfg = params_dict[key]
            return FeedForward(dims=cfg['dims'], final_act=cfg['final_act'])

        # Sub-networks are created in a fixed order so that parameter
        # initialisation consumes the random stream the same way every time.
        self.xhi = make_net('xhi')
        self.tb_shape = 1
        self.tb = make_net('tb')
        self.ts = make_net('ts')
        for head in self._LFS_HEADS:
            setattr(self, head, make_net('lfs'))
        self.tau = make_net('tau')

    def forward(self, theta):
        """Evaluate every sub-network on `theta`.

        Returns
        -------
        tuple
            (xhi, tb, ts, lfs, tau) predictions, each passed through
            `.squeeze()`. The 'lfs' entry stacks the seven head outputs along
            a new trailing axis before squeezing.
        """
        xhi_out = self.xhi(theta)
        tb_out = self.tb(theta)
        ts_out = self.ts(theta)
        # Stacking on a new last axis is the same as adding a trailing unit
        # axis to each head output and concatenating along it.
        lfs_out = torch.stack(
            [getattr(self, head)(theta) for head in self._LFS_HEADS], dim=-1
        )
        tau_out = self.tau(theta)
        # NOTE(review): squeeze() without a dim removes every size-1 axis,
        # which includes a leading batch axis when the batch holds one sample.
        outputs = (xhi_out, tb_out, ts_out, lfs_out, tau_out)
        return tuple(out.squeeze() for out in outputs)

In [8]:
# Per-network configuration consumed by Emulator: layer widths ('dims') and the
# output activation ('final_act'). Every network takes the N_params inputs and
# ends in its own Sigmoid instance.
emu_params = dict(
    lfs=dict(dims=[N_params, 1024, 1024, 1024, len(M_UV)], final_act=nn.Sigmoid()),
    tau=dict(dims=[N_params, 128, 128, 128, 128, 1], final_act=nn.Sigmoid()),
    ts=dict(dims=[N_params, 1024, 1024, len(redshifts)], final_act=nn.Sigmoid()),
    tb=dict(dims=[N_params, 1024, 1024, len(redshifts)], final_act=nn.Sigmoid()),
    xhi=dict(dims=[N_params, 1024, 1024, len(redshifts)], final_act=nn.Sigmoid()),
)

In [9]:
# Instantiate the emulator from the configuration above.
model = Emulator(emu_params)
# Optional warm start from a saved checkpoint (disabled):
#model.load_state_dict(torch.load(str(results_folder) + '/model_pt10'))
# Cast parameters to single precision; float() returns the module itself.
model = model.float()
# Move the parameters to the selected device. As the last expression of the
# cell, this also displays the module summary.
model.to(device)

Emulator(
  (xhi): FeedForward(
    (net): Sequential(
      (0): Linear(in_features=11, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=93, bias=True)
      (5): Sigmoid()
    )
  )
  (tb): FeedForward(
    (net): Sequential(
      (0): Linear(in_features=11, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=93, bias=True)
      (5): Sigmoid()
    )
  )
  (ts): FeedForward(
    (net): Sequential(
      (0): Linear(in_features=11, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=93, bias=True)
      (5): Sigmoid()
    )
  )
  (lfs6): FeedForward(
    (net): Sequential(
      (0): Linear(in_feature

In [10]:
# A single Adam optimizer updates the parameters of every sub-network.
optimizer = Adam(
    params=list(model.parameters()),
    lr=1e-3,
)

We define the loss as a weighted sum of the loss of individual summaries:

\begin{equation*}
\mathcal{L} = \sum_{i=0}^4 w_i||\vec{s}_{\rm{pred, }i} - \vec{s}_{\rm{true, }i}||^2_2,
\end{equation*} 

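where $i = 0, 1, \ldots, 4$ labels each of the five summary statistics we are learning (in the code: `xhi`, `tb`, `ts`, `lfs`, `tau`, in that order) and $w_i$ is the weight of the $i$-th summary. In the implementation below each term is evaluated with `F.mse_loss` by default, i.e. the squared error is averaged rather than summed over elements, and all weights default to 1.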

In [11]:
def loss(true, pred, loss_fnc = F.mse_loss,weights=None):
    """Weighted sum of the per-summary losses.

    Parameters
    ----------
    true : sequence of length 5
        Targets, ordered as (xhi, tb, ts, lfs, tau).
    pred : sequence of length 5
        Predictions in the same order as `true`.
    loss_fnc : callable, optional
        Called as ``loss_fnc(true_i, pred_i)`` for each summary.
        Default is F.mse_loss.
    weights : indexable, optional
        Weights of the five terms, read as ``weights[0]`` ... ``weights[4]``.
        Default is ``np.ones(len(true))``.

    Returns
    -------
    The sum over the five summaries of ``weights[i] * loss_fnc(true_i, pred_i)``.
    """
    if weights is None:
        weights = np.ones(len(true))

    # Unpacking enforces that exactly five summaries are supplied on each side.
    p_xhi, p_tb, p_ts, p_lfs, p_tau = pred
    t_xhi, t_tb, t_ts, t_lfs, t_tau = true

    pairs = (
        (t_xhi, p_xhi),
        (t_tb, p_tb),
        (t_ts, p_ts),
        (t_lfs, p_lfs),
        (t_tau, p_tau),
    )
    # Evaluate every individual loss first, then combine them left to right.
    terms = [loss_fnc(target, guess) for target, guess in pairs]

    total = weights[0] * terms[0]
    for k in range(1, len(terms)):
        total = total + weights[k] * terms[k]
    return total
    

In [12]:
from training import train, validate, lr_schedule

In [None]:
import os

nepochs = 1000
epoch = 0
results_folder = 'results'
# torch.save does not create missing directories, so make sure the target
# exists before the first checkpoint is written.
os.makedirs(results_folder, exist_ok=True)
epoch_vloss = []   # validation loss of every finished epoch
epoch_tloss = []   # training loss of every finished epoch
# NOTE(review): this scheduler object is replaced by the one returned from
# lr_schedule() after the first epoch; it is kept in case lr_schedule relies
# on the state StepLR attaches to the optimizer -- confirm in training.py.
scheduler = lr_scheduler.StepLR(optimizer, step_size = 1, gamma=0.1)
# Plateau counter that is threaded through lr_schedule() from epoch to epoch.
plateau = 0

while epoch < nepochs:
    tloss, model, optimizer = train(model, train_dataloader, optimizer, loss, epoch, device=device)
    epoch_tloss.append(tloss)
    vloss = validate(model, valid_dataloader, optimizer, loss, epoch, device)
    epoch_vloss.append(vloss)
    scheduler, plateau = lr_schedule(optimizer, epoch_vloss, plateau)
    this_loss = epoch_vloss[-1]
    if epoch < 5:
        # The first five epochs are always checkpointed as model_0 ... model_4.
        torch.save(model.state_dict(), str(results_folder) + '/model_'+str(epoch))
    # Later epochs are checkpointed only when their validation loss is no worse
    # than the fifth-lowest loss of all previous epochs. (The original code
    # referred to an undefined name `epoch_vlosses` here.)
    # NOTE(review): assumes validate() returns values np.sort can order (e.g.
    # Python floats) -- confirm in training.py.
    if (epoch >= 5 and this_loss <= np.sort(epoch_vloss[:-1])[4]):
        # Rank (0 = best) of this epoch among all validation losses so far.
        # NOTE(review): the file is written under its rank, but checkpoints
        # already saved at other ranks are not renamed, so model_k is not
        # guaranteed to hold the current k-th best model.
        num = int(np.where(np.sort(epoch_vloss) == this_loss)[0][0])
        torch.save(model.state_dict(), str(results_folder) + '/model_'+str(num))
    epoch += 1