In [None]:
import torch
from torch import nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:
import sys

# add path to Data dir
sys.path.insert(1, '../Data/')
from data_utils import get_dataloaders

In [None]:
# Read the axis/label arrays stored alongside the training database.
data_path = '../Data/db.npz'
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code -- only load archives from a trusted source.
with np.load(data_path, allow_pickle = True) as f:
    # param_labels: last one is cosmo sigma_8
    param_labels, redshifts, LF_zs, M_UV = (
        f[key] for key in ('param_labels', 'redshifts', 'LF_zs', 'M_UV')
    )
# Number of labelled parameters in the database.
N_params = len(param_labels)

In [None]:
# Build the train / validation / test loaders plus the normalisation constants.
# f_train and f_valid are presumably dataset fractions -- confirm in data_utils.
train_dataloader, valid_dataloader, test_dataloader, norms = get_dataloaders(
    f_train=0.8,
    f_valid=0.1,
    batch_size=64,
    lstm=True,
)

In [None]:
# Split the normalisation constants into one (bias, scale) pair per quantity.
(
    Tb_bias, Tb_scale,
    Ts_bias, Ts_scale,
    UVLFs_bias, UVLFs_scale,
    tau_bias, tau_scale,
) = norms

In [None]:
class LSTM(nn.Module):
    """Stacked ``nn.LSTM`` followed by a linear head applied at every step.

    Args:
        dims: Sequence of layer sizes. Only the first three entries are read:
            ``dims[0]`` is the LSTM input size, ``dims[1]`` the LSTM hidden
            size and ``dims[2]`` the output size of the linear head. Any
            further entries are ignored.
        num_layers: Number of stacked LSTM layers.
        final_act: Optional callable applied to the output of the linear head.
        init: When true, ``_init_weights`` is applied to both sub-modules.
    """

    def __init__(self,dims,num_layers=2, final_act=None, init=False):
        super().__init__()
        n_in, n_hidden, n_out = dims[0], dims[1], dims[2]
        self.lstm = nn.LSTM(n_in, n_hidden, num_layers=num_layers, batch_first=True)
        self.lin = nn.Linear(in_features=n_hidden, out_features=n_out)
        self.final_act = final_act
        if init:
            # _init_weights only acts on nn.Linear modules, so the call on
            # self.lstm leaves the recurrent weights at their defaults.
            for block in (self.lstm, self.lin):
                block.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw ``nn.Linear`` weights from N(0, 1) and zero the bias; skip other modules."""
        if not isinstance(module, nn.Linear):
            return
        module.weight.data.normal_(mean=0.0, std=1.0)
        if module.bias is not None:
            module.bias.data.zero_()

    def forward(self,x):
        """Run ``x`` through the LSTM, the linear head and the optional activation."""
        # The recurrent (h_n, c_n) state returned by nn.LSTM is not used.
        seq_features, _ = self.lstm(x)
        y = self.lin(seq_features)
        return y if self.final_act is None else self.final_act(y)

class Emulator(nn.Module):
    """Collection of independent ``LSTM`` heads that all receive the same input.

    Args:
        params_dict: Mapping with the keys ``'xhi'``, ``'tb'``, ``'ts'``,
            ``'lfs'`` and ``'tau'``. Every entry provides ``'dims'`` and
            ``'final_act'``; all but ``'tau'`` also provide ``'num_layers'``
            (the ``tau`` head uses the ``LSTM`` default). The seven ``lfs*``
            heads are all built from the single ``'lfs'`` entry.
    """

    def __init__(self,params_dict):
        super().__init__()

        def build(key):
            # One LSTM head configured from params_dict[key].
            spec = params_dict[key]
            return LSTM(dims=spec['dims'], num_layers=spec['num_layers'], final_act=spec['final_act'])

        self.xhi = build('xhi')
        self.tb_shape = 1
        self.tb = build('tb')
        self.ts = build('ts')
        # Attribute names lfs6 ... lfs15 are kept so state_dict keys do not change.
        # Presumably one head per luminosity-function redshift -- confirm.
        for tag in (6, 7, 8, 9, 10, 12, 15):
            setattr(self, 'lfs' + str(tag), build('lfs'))
        tau_spec = params_dict['tau']
        self.tau = LSTM(dims = tau_spec['dims'], final_act = tau_spec['final_act'])

    def forward(self, theta):
        """Evaluate every head on ``theta``.

        Returns:
            Tuple ``(xhi, tb, ts, lfs, tau)``; the ``lfs`` entry is the seven
            ``lfs*`` outputs concatenated along the last axis. Each entry is
            passed through ``.squeeze()``.
        """
        xhi_pred = self.xhi(theta)
        tb_pred = self.tb(theta)
        ts_pred = self.ts(theta)
        lfs_heads = (self.lfs6, self.lfs7, self.lfs8, self.lfs9,
                     self.lfs10, self.lfs12, self.lfs15)
        lfs_pred = torch.cat([head(theta) for head in lfs_heads], axis = -1)
        tau_pred = self.tau(theta)
        # NOTE(review): a bare squeeze() removes every size-1 axis, so an input
        # with batch size 1 also loses its batch axis -- confirm this is intended.
        return tuple(
            p.squeeze() for p in (xhi_pred, tb_pred, ts_pred, lfs_pred, tau_pred)
        )

In [None]:
# Network spec for every Emulator head. 'dims' is read by the LSTM class as
# [input size, hidden size, output size].
# The input size 11 is presumably the number of input parameters -- confirm.
emu_params = {
    'lfs': dict(dims=[11, len(M_UV), 1], num_layers=2, final_act=nn.Sigmoid()),
    # NOTE(review): the LSTM class in this notebook only reads dims[0], dims[1]
    # and dims[2], so this six-entry spec gives the tau head a 128-wide output
    # and the trailing entries (including the final 1) are ignored -- confirm
    # whether [11, 128, 1] was intended.
    'tau': dict(dims=[11, 128, 128, 128, 128, 1], final_act=nn.Sigmoid()),
    'ts': dict(dims=[11, len(redshifts), 1], num_layers=2, final_act=nn.Sigmoid()),
    'tb': dict(dims=[11, len(redshifts), 1], num_layers=2, final_act=nn.Sigmoid()),
    'xhi': dict(dims=[11, len(redshifts), 1], num_layers=2, final_act=nn.Sigmoid()),
}

In [None]:
# Build the emulator in float32; nn.Module.float() returns the module itself.
model = Emulator(emu_params).float()
# Optional warm start from a saved checkpoint. results_folder is only assigned
# in the training cell further down, so define it first if this is enabled.
#model.load_state_dict(torch.load(str(results_folder) + '/model_pt10'))
# Move the parameters to the selected device; as the last expression of the
# cell this also displays the module summary.
model.to(device)

In [None]:
# A single Adam optimizer over every parameter of all emulator heads.
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [None]:
def loss(true, pred, loss_fnc = F.mse_loss,weights=None):
    """Weighted sum of the five per-quantity losses.

    Args:
        true: Sequence of five targets ordered ``(xhi, tb, ts, lfs, tau)``.
        pred: Sequence of five predictions in the same order.
        loss_fnc: Criterion following the ``torch.nn.functional`` convention
            ``loss_fnc(input, target)``; it is called with the prediction as
            ``input`` and the true value as ``target``.
        weights: Optional indexable of at least five multipliers, one per
            term in the order above. ``None`` weights every term by 1.

    Returns:
        ``sum_i weights[i] * loss_fnc(pred[i], true[i])``.

    Raises:
        ValueError: If ``true`` or ``pred`` does not hold exactly five items.
        IndexError: If ``weights`` holds fewer than five items.
    """
    if weights is None:
        # Plain Python floats avoid depending on numpy-scalar * tensor dispatch.
        weights = (1.0,) * 5
    xhi_pred, tb_pred, ts_pred, lfs_pred, tau_pred = pred
    xhi_true, tb_true, ts_true, lfs_true, tau_true = true

    # Prediction first, target second: the order torch criteria expect. For
    # the symmetric default (MSE) the value is identical either way.
    terms = (
        loss_fnc(xhi_pred, xhi_true),
        loss_fnc(tb_pred, tb_true),
        loss_fnc(ts_pred, ts_true),
        loss_fnc(lfs_pred, lfs_true),
        loss_fnc(tau_pred, tau_true),
    )

    # Index explicitly (instead of zip) so a too-short ``weights`` still fails
    # loudly rather than silently dropping terms.
    total = weights[0] * terms[0]
    for i in range(1, 5):
        total = total + weights[i] * terms[i]
    return total
    

In [None]:
from training import train, validate, lr_schedule

In [None]:
# Training configuration and per-epoch bookkeeping.
nepochs = 1000
epoch = 0
# torch.save does not create directories: this folder must already exist.
results_folder = 'results'
epoch_vloss = []  # validation loss of every finished epoch
epoch_tloss = []  # training loss of every finished epoch
# NOTE(review): this StepLR instance is never stepped; `scheduler` is rebound
# to whatever lr_schedule returns on the first epoch -- confirm it is needed.
scheduler = lr_scheduler.StepLR(optimizer, step_size = 1, gamma=0.1)
# Plateau state threaded through lr_schedule from one epoch to the next.
plateau = 0

while epoch < nepochs:
    tloss, model, optimizer = train(model, train_dataloader, optimizer, loss, epoch, device=device)
    epoch_tloss.append(tloss)
    vloss = validate(model, valid_dataloader, optimizer, loss, epoch, device)
    epoch_vloss.append(vloss)
    scheduler, plateau = lr_schedule(optimizer, epoch_vloss, plateau)
    this_loss = epoch_vloss[-1]
    if epoch < 5:
        # The first five epochs are always checkpointed, named by epoch index.
        torch.save(model.state_dict(), str(results_folder) + '/model_'+str(epoch))
    elif this_loss <= np.sort(epoch_vloss[:-1])[4]:
        # The new loss is no worse than the fifth-best earlier one: store the
        # weights under its 0-based rank among all validation losses so far.
        # NOTE(review): checkpoints already on disk are not shifted to their
        # new ranks, so a file index is not guaranteed to match the current
        # ranking -- confirm this is acceptable.
        num = int(np.where(np.sort(epoch_vloss) == this_loss)[0][0])
        torch.save(model.state_dict(), str(results_folder) + '/model_'+str(num))
    epoch += 1