In [1]:
import sys
import numpy as np
import torch

from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.functional import F
import torch.nn as nn
import torch.distributions as dist


sys.path.append('../')
from utils_modules.models import Expander, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data
import utils_modules.baryons_toy_Pk as utils_toy_Pk



In [2]:
# plot formatting 
import matplotlib
import matplotlib.pyplot as plt
font = {'family' : 'serif',
        'weight' : 'normal',
        'size'   : 10}
matplotlib.rc('font', **font)

rcnew = {"mathtext.fontset" : "cm", 
         "xtick.labelsize" : 18,
         "ytick.labelsize" : 18,
         "axes.titlesize" : 26, 
         "axes.labelsize" : 22,
         "xtick.major.size" : 8,      # major tick size in points
         "xtick.minor.size" : 4,      # minor tick size in points
         "ytick.major.size" : 8,      # major tick size in points
         "ytick.minor.size" : 4,      # minor tick size in points
         "legend.fontsize" : 22
        }
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
plt.rcParams.update({
  "text.usetex": True,
})

%config InlineBackend.figure_format = 'retina'


In [4]:
# Compute device for every model/tensor below: a CUDA GPU when one is visible, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device: %s' % device)

Device: cuda


## Generate parameters and power spectra

In [5]:
# --- k-binning of the toy power spectrum (h/Mpc) ---
kmin, kmax = 7e-3, 1   # h/Mpc
kF = kmin              # bin width (fundamental frequency)
k_bins = int((kmax-kmin)/kF)
k = np.arange(3, k_bins+2)*kF
# number of modes in each k-bin: shell volume 4*pi*k^2*kF divided by kF^3
Nk = 4.0*np.pi*k**2*kF/kF**3

# --- toy-model switches (passed through to utils_toy_Pk) ---
predict_D, Pk_continuous = True, True   # Pk_continuous: whether to fix A_value at kpivot or not

# --- dataset size and train/valid/test fractions ---
dset_size = 1000
train_frac, valid_frac, test_frac = 0.8, 0.1, 0.1

seed, splits = 1, 10

In [6]:
# Draw the toy parameter sets, reshaped to (dset_size, splits, n_param_columns),
# then the corresponding power spectra on the k grid defined above.
params = utils_toy_Pk.generate_params(
    dset_size, splits,
    predict_D=predict_D,
    Pk_continuous=Pk_continuous,
    seed=seed,
).reshape(dset_size, splits, -1)

Pk = utils_toy_Pk.get_Pk_arr(k, Nk, params, predict_D=predict_D, seed=seed)

## Load the model

In [7]:
# Checkpoint / log paths of the pretrained VICReg encoder; inv, var, cov are
# presumably the invariance/variance/covariance loss weights the run used (TODO confirm).
inv, var, cov = 15, 15, 1
fmodel = 'trained_models/VICReg_{:d}_{:d}_{:d}.pt'.format(inv, var, cov)
fout   = 'trained_models/VICReg_{:d}_{:d}_{:d}.txt'.format(inv, var, cov)

# define the expander model: input is the len(k)-bin log P(k) vector
hidden = 16
last_layer = 32
args_net = [hidden,
            last_layer, last_layer, last_layer,
            last_layer, last_layer, last_layer]
net = Expander(args_net, k.shape[0], bn = True).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
net.load_state_dict(torch.load(fmodel, map_location=device))
net.eval();


In [8]:
# inference network: maps encoder summaries to a Gaussian over n_params
# parameters (predicted mean + covariance entries).
n_params = 3
n_tril = n_params * (n_params + 1) // 2  # independent entries of a symmetric n_params x n_params matrix (lower triangle)
n_out = n_params + n_tril  # network output width: mean vector + covariance entries

# architecture parameters
# The inference network consumes the encoder's output, so its input width must
# match the encoder's last layer. Take it from args_net rather than re-hard-coding
# it (assumes Expander's output width is the last entry of its layer list).
last_layer = args_net[-1]
mlp_lr_units = [4*last_layer, 4*last_layer, n_out]
lr_net = Expander(mlp_lr_units, last_layer, bn = True).to(device)

fmodel_lr = fmodel[:-3] + '_inference_network.pt'
fout_lr   = fout[:-4] + '_inference_network.txt'

# get optimizer and scheduler parameters
lr     = 5e-4
epochs = 300
optimizer = torch.optim.AdamW(lr_net.parameters(), lr=lr, betas=(0.9, 0.999),
                             eps=1e-8, amsgrad=False)
# Cut the learning rate by `factor` when the validation loss plateaus.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.3, verbose=True)


## Dataset

In [9]:
# Full dataset, one item per parameter set: log power spectra as inputs, toy
# parameters as targets.
dset = TensorDataset(torch.tensor(np.log(Pk)), torch.tensor(params))
num_params = 4  # width of each params row (assumed 4 from generate_params -- TODO confirm)
batch_size= 256

# random_split needs sizes that sum exactly to len(dset). Truncating every
# int(frac * N) can drop samples, so the test set takes the remainder instead.
assert abs(train_frac + valid_frac + test_frac - 1.0) < 1e-8, 'split fractions must sum to 1'
n_train = int(train_frac*dset_size)
n_valid = int(valid_frac*dset_size)
n_test  = dset_size - n_train - n_valid

# Split at the level of parameter sets, so all `splits` realisations of one
# set land in the same subset.
train_subset, valid_subset, test_subset = torch.utils.data.random_split(
    dset, [n_train, n_valid, n_test],
    generator=torch.Generator().manual_seed(seed))


def _flatten_subset(subset):
    """Flatten a Subset of parameter sets into one row per realisation."""
    idx = subset.indices
    return TensorDataset(dset.tensors[0][idx].reshape(-1, len(k)),
                         dset.tensors[1][idx].reshape(-1, num_params))


train_dset = _flatten_subset(train_subset)
valid_dset = _flatten_subset(valid_subset)
test_dset  = _flatten_subset(test_subset)

train_loader = DataLoader(train_dset, batch_size, shuffle = True)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

## Obtain the summaries from the encoder network

In [10]:
def _encode(loader):
    """Run the frozen encoder `net` over `loader`.

    Returns (summaries, targets): the encoder outputs for every batch, and the
    float32 target columns [0, 1, 3] (the A, B, D parameters), each concatenated
    over all batches and kept on `device`.
    """
    summaries, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device=device).float()
            labels = labels.to(device=device).float()[:, [0, 1, 3]]
            summaries.append(net(inputs).to(device=device))
            targets.append(labels)
    return torch.cat(summaries), torch.cat(targets)


x_train, y_train = _encode(train_loader)
x_valid, y_valid = _encode(valid_loader)

# From here on, the loaders yield (encoder summary, [A, B, D]) pairs.
train_dset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dset, batch_size, shuffle = True)

valid_dset = TensorDataset(x_valid, y_valid)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)

## Train the inference network

In [None]:
verbose = False  # if True, also print the validation loss of the untrained network


def _gaussian_nll(y_NN, y):
    """Mean Gaussian negative log-likelihood of targets `y` under lr_net's output.

    The first n_params columns of `y_NN` are the predicted mean. The remaining
    columns are converted to a covariance matrix by `vector_to_Cov`, which is
    called on CPU tensors here as in the original code.
    """
    y_pred, cov_pred = y_NN[:, :n_params], y_NN[:, n_params:]
    Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
    return -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(y).mean()


def _validation_loss():
    """Sample-weighted mean NLL of lr_net over valid_loader, without gradients."""
    lr_net.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in valid_loader:
            x = x.to(device)
            y = y.to(device)[:, :n_params]
            loss = _gaussian_nll(lr_net(x).to(device=device), y)
            total += loss.item() * x.shape[0]
            count += x.shape[0]
    return total / count


min_valid_loss = _validation_loss()
if verbose:
    print('Initial valid loss = %.3e'%min_valid_loss)

# Save the untrained weights so fmodel_lr always exists and matches this run's
# architecture. Without this, if no epoch beats the initial loss, the evaluation
# cell would fail or silently load a stale checkpoint from an earlier run.
torch.save(lr_net.state_dict(), fmodel_lr)

# loop over all the epochs
for epoch in range(epochs):

    # training
    train_loss, num_points = 0.0, 0
    lr_net.train()
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)[:, :n_params]
        loss = _gaussian_nll(lr_net(x).to(device=device), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.shape[0]
        num_points += x.shape[0]
    train_loss = train_loss/num_points

    # validation
    valid_loss = _validation_loss()

    # keep the checkpoint with the lowest validation loss
    if valid_loss < min_valid_loss:
        min_valid_loss = valid_loss
        torch.save(lr_net.state_dict(), fmodel_lr)
        print('Epoch %d: %.3e %.3e (saving)'%(epoch, train_loss, valid_loss))
    else:
        print('Epoch %d: %.3e %.3e '%(epoch, train_loss, valid_loss))

    # Per-epoch log line "epoch train_loss valid_loss". The file is truncated on
    # the first epoch and appended afterwards; `with` closes it even if write fails.
    with open(fout_lr, 'w' if epoch == 0 else 'a') as f:
        f.write('%d %.4e %.4e\n'%(epoch, train_loss, valid_loss))

    scheduler.step(valid_loss)

In [None]:
# Training / validation loss curves from the per-epoch log
# (columns: epoch, training loss, validation loss).
# ndmin=2 keeps a 2-D array even when only one epoch was logged, so the
# column slicing below still works.
losses = np.loadtxt(fout_lr, ndmin=2)
fig, ax = plt.subplots()
ax.plot(losses[:, 0], losses[:, 1], label = 'Training loss')
ax.plot(losses[:, 0], losses[:, 2], label = 'Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (Gaussian NLL)')
ax.legend(loc = 'best');

In [None]:
# Reload the encoder and the best inference-network checkpoint. map_location
# lets GPU-saved weights be restored on a CPU-only machine.
net.load_state_dict(torch.load(fmodel, map_location=device))
net.eval();

lr_net.load_state_dict(torch.load(fmodel_lr, map_location=device))
lr_net.eval();

test_loss, num_points = 0., 0
params_true = []
params_pred = []
errors_pred = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.float().to(device)
        # Same float32 cast and [A, B, D] column selection as for the training
        # targets. Without the cast, float64 targets meet the float32
        # predictions in the Gaussian log-prob and in mse_loss below.
        y = y.to(device).float()[:, [0, 1, 3]]
        y_NN = lr_net(net(x))

        y_pred, cov_pred = y_NN[:, :n_params], y_NN[:, n_params:]
        Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
        # NLL of the true parameters under the predicted Gaussian, as in training
        loss = -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(y).mean()

        test_loss += loss.item() * x.shape[0]
        num_points += x.shape[0]

        params_true.append(y)
        params_pred.append(y_pred)
        errors_pred.append(Cov)  # full predicted covariance per sample

    test_loss = test_loss/num_points
print(test_loss)
params_true = torch.cat(params_true)
params_pred = torch.cat(params_pred)
errors_pred = torch.cat(errors_pred)

# Columns 0 and 1 are the A and B parameters.
MSE_error = F.mse_loss(params_true[:, :2], params_pred[:, :2]).cpu().numpy()
print('MSE error: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, :1], params_pred[:, :1]).cpu().numpy()
print('MSE error on A: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, 1:2], params_pred[:, 1:2]).cpu().numpy()
print('MSE error on B: {:}'.format(MSE_error))

# For B the result is negated. The plots below put B in [-1, 0], so dividing by
# B flips the sign; presumably B < 0 throughout (TODO confirm).
print('\nActual errors on A, B (relative, %)')
print((torch.abs(params_pred[:, :1] - params_true[:, :1])/params_true[:, :1]).mean()*100)
print(-(torch.abs(params_pred[:, 1:2] - params_true[:, 1:2])/params_true[:, 1:2]).mean()*100)

print('\nPredicted errors on A, B (relative, %)')
print((torch.sqrt(errors_pred[:, 0, 0])/params_pred[:, :1]).mean()*100)
print(-(torch.sqrt(errors_pred[:, 1, 1])/params_pred[:, 1:2]).mean()*100)


In [None]:
# Predicted vs. true value for each inferred parameter. Error bars are 1-sigma
# (square root of the predicted covariance diagonal); the black line is y = x.
fig, axs = plt.subplots(1, 3, figsize=(21, 7))

panels = (
    # (y label, x label, column, extent of the y = x reference line)
    ('A, predicted', 'A, true', 0, [0.1, 1.]),
    ('B, predicted', 'B, true', 1, [-1., 0.]),
    # D previously had no reference line; use the same range as the later figure
    ('D, predicted', 'D, true', 2, [-.5, 0.5]),
)
for ax, (ylabel, xlabel, col, ref) in zip(axs, panels):
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.errorbar(params_true[:, col].cpu(), params_pred[:, col].cpu(),
                yerr=torch.sqrt(errors_pred[:, col, col]).cpu(),
                linestyle = '', capsize = 2, label = r'$1\sigma$')
    ax.plot(ref, ref, c = 'k', lw = 2)
    ax.legend(loc = 'upper left')

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(24, 7))

# One square panel per inferred parameter: predicted vs. true with 1-sigma error
# bars (sqrt of the predicted covariance diagonal) and a black y = x guide.
panel_specs = (
    # (y label, x label, column, x/y extent of the guide line)
    ('Predicted A', 'True A', 0, [0.1, 1.]),
    ('Predicted B', 'True B', 1, [-1., 0.]),
    ('Predicted D', 'True D', 2, [-.5, 0.5]),
)
for ax, (ylabel, xlabel, col, guide) in zip(axs, panel_specs):
    sigma = torch.sqrt(errors_pred[:, col, col]).cpu()
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.errorbar(params_true[:, col].cpu(), params_pred[:, col].cpu(),
                yerr=sigma,
                linestyle='', capsize=2, label=r'$1\sigma$')
    ax.plot(guide, guide, c='k', lw=2)
    ax.legend(loc='upper left')
    ax.set_aspect('equal')