In [None]:
import sys, os, math
import numpy as np
import torch
import matplotlib.pyplot as plt

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split

import importlib
sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data

In [None]:
# Choose where tensors and models live: CUDA when a GPU is visible to torch, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device: %s' % (device,))

## Load data

In [None]:
# Load the maps and their generating parameters (paths are placeholders).
# A singleton channel axis is inserted at position 2, so maps is indexed (set, view, 1, ...).
maps      = np.load(...)[:, :, None, :, :]
dset_size = maps.shape[0] # data set size
splits    = maps.shape[1] # number of augmentations/views per parameter set

# params is stored as (set, n_params); add a view axis and repeat it so every view carries its labels
params  = np.load(...)[:, None, :]
params  = np.repeat(params, splits, axis = 1)

# pre-processing switches
rescale     = True   # compress the dynamic range with log(1 + maps)
standardize = True   # global zero-mean / unit-variance scaling
verbose     = True

if rescale:
    # log1p(x) == log(x + 1) but stays accurate when x is small
    maps = np.log1p(maps)
if standardize:
    # accumulate the global statistics in float64 to limit round-off on large arrays
    maps_mean = np.mean(maps, dtype=np.float64)
    maps_std  = np.std(maps, dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 0].min(), params[:, :, 0].max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(params[:, :, 1].min(), params[:, :, 1].max()))

    if rescale: print('Rescale: ', rescale)
    if standardize: print('Standardize: ', standardize)

# Convert straight to float32 on the target device; torch.tensor(x).float() would first
# materialise an extra full-size float64 copy of the (possibly large) map array.
maps   = torch.as_tensor(maps, dtype=torch.float32, device=device)
params = torch.as_tensor(params, dtype=torch.float32, device=device)

In [None]:
# Divide the data into train, validation, and test sets.
# NOTE(review): `seed` was referenced here (and in the plotting cell) but never defined in this
# notebook, so a fresh "Restart & Run All" failed with NameError. It is now set explicitly;
# change the value to the one used for any published split if results must be reproduced.
seed = 42
torch.manual_seed(seed)  # also makes DataLoader shuffling and the head's weight init reproducible

batch_size = 200
train_frac, valid_frac, test_frac = 0.8, 0.1, 0.1
assert abs(train_frac + valid_frac + test_frac - 1.0) < 1e-9, 'split fractions must sum to 1'

train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params,
                                                               train_frac, valid_frac, test_frac,
                                                               seed = seed,
                                                               rotations=True)

train_loader = DataLoader(train_dset, batch_size, shuffle = True)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)

## Load the encoder model

In [6]:
# pretrained encoder weights and its log (placeholder paths)
fmodel = ...
fout   = ...

# encoder architecture -- must match the checkpoint being loaded
hidden     = 8
last_layer = 2*hidden  # passed to SummaryNet and used as the input width of the downstream head

n_params   = 2
n_tril     = n_params * (n_params + 1) // 2  # entries of a lower-triangular n_params x n_params matrix
n_out      = n_params + n_tril               # head output: n_params means followed by n_tril covariance entries

# load the encoder model; map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model = SummaryNet(hidden = hidden, last_layer = last_layer).to(device)
model.load_state_dict(torch.load(fmodel, map_location=device))
model.eval();

## Convert maps into summaries

In [7]:
def _encode_loader(loader):
    """Run every batch of `loader` through the frozen encoder `model`.

    Returns the encoder outputs and the matching targets, each concatenated
    over all batches and placed on `device`.
    """
    summaries, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            summaries.append(model(x.to(device=device)))
            targets.append(y.to(device=device))
    return torch.cat(summaries), torch.cat(targets)


# Replace the map-level loaders with loaders over the (fixed) encoder summaries, so the
# downstream head below trains on pre-computed features.
# NOTE: this overwrites train_loader/valid_loader, so re-running this cell without first
# re-running the data-split cell would feed summaries back into the encoder.
x_train, y_train = _encode_loader(train_loader)
x_valid, y_valid = _encode_loader(valid_loader)

train_dset   = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dset, batch_size, shuffle = True)

valid_dset   = TensorDataset(x_valid, y_valid)
valid_loader = DataLoader(valid_dset, batch_size, shuffle = True)

## Downstream task: Parameter Inference

In [8]:
# where the trained downstream head and its per-epoch loss log are written (placeholders)
fmodel_lr, fout_lr = ..., ...

# Downstream head: maps the last_layer-wide summary to n_out numbers
# (n_params means + covariance entries). The list gives the layer widths handed to the
# project-defined Expander; `bn=True` presumably enables batch normalisation -- confirm in models.py.
mlp_lr_width = 16*last_layer
mlp_lr_units = [mlp_lr_width, mlp_lr_width, n_out]
lr_net = Expander(mlp_lr_units, last_layer, bn = True).to(device)


In [9]:
# optimisation hyperparameters
epochs     = 200
lr         = 5e-4  # 1e-3 was also tried

optimizer = torch.optim.AdamW(lr_net.parameters(), lr=lr, betas=(0.9, 0.999),
                              eps=1e-8, amsgrad=False)
# multiply the learning rate by 0.2 when the validation loss plateaus (default patience)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2)

In [None]:
def _gaussian_nll(net_out, target):
    """Mean negative log-likelihood of `target` under the Gaussian predicted by the head.

    The first `n_params` columns of `net_out` are the predicted means; the remaining
    columns are turned into a covariance matrix by the project helper `vector_to_Cov`.
    """
    y_pred, cov_pred = net_out[:, :n_params], net_out[:, n_params:]
    # vector_to_Cov is called on a CPU tensor and its result moved back to `device`
    # (presumably vector_to_Cov requires CPU input -- confirm in models.py)
    Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
    return -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(target).mean()


def _mean_loss(net, loader):
    """Per-sample average of `_gaussian_nll` over `loader`, in eval mode and without gradients."""
    net.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device=device), y.to(device=device)
            total += _gaussian_nll(net(x), y).item()*x.shape[0]
            count += x.shape[0]
    return total/count


min_valid_loss = _mean_loss(lr_net, valid_loader)
if verbose:
    print('Initial valid loss = %.3e'%min_valid_loss)
# Save the untrained head as well, so fmodel_lr always holds a checkpoint from this run even
# if no epoch improves on the initial loss (otherwise a stale file could be loaded later).
torch.save(lr_net.state_dict(), fmodel_lr)

# loop over the epochs
for epoch in range(epochs):

    # training
    train_loss, num_points = 0.0, 0
    lr_net.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = _gaussian_nll(lr_net(x), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()*x.shape[0]
        num_points += x.shape[0]
    train_loss = train_loss/num_points

    # validation (_mean_loss switches lr_net back to eval mode)
    valid_loss = _mean_loss(lr_net, valid_loader)

    # keep the checkpoint with the lowest validation loss
    if valid_loss<min_valid_loss:
        min_valid_loss = valid_loss
        torch.save(lr_net.state_dict(), fmodel_lr)
        print('Epoch %d: %.3e %.3e (saving)'%(epoch, train_loss, valid_loss))
    else:
        print('Epoch %d: %.3e %.3e '%(epoch, train_loss, valid_loss))

    # loss log: truncated at epoch 0, appended afterwards; `with` closes the file even on error
    with open(fout_lr, 'w' if epoch == 0 else 'a') as f:
        f.write('%d %.4e %.4e\n'%(epoch, train_loss, valid_loss))

    scheduler.step(valid_loss)

In [None]:
# Training curves read back from the per-epoch log (columns: epoch, train loss, valid loss).
# ndmin=2 keeps the table 2-D even when only one epoch was logged.
losses = np.loadtxt(fout_lr, ndmin=2)
start_epoch = 0
end_epoch = 200  # slicing past the last logged epoch is harmless

fig, ax = plt.subplots(figsize = (10, 6))
logged_epochs = losses[start_epoch:end_epoch, 0]
ax.plot(logged_epochs, losses[start_epoch:end_epoch, 1], label = 'Training loss')
ax.plot(logged_epochs, losses[start_epoch:end_epoch, 2], label = 'Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (mean Gaussian NLL)')
ax.legend(loc = 'best');

In [None]:
# Reload the encoder and the best downstream head saved during training
# (map_location lets GPU-saved checkpoints load on any device).
model.load_state_dict(torch.load(fmodel, map_location=device))
model.eval();

lr_net.load_state_dict(torch.load(fmodel_lr, map_location=device))
lr_net.eval();

test_loss, num_points = 0., 0
params_true = []
params_pred = []
errors_pred = []  # predicted covariance matrices, one (n_params x n_params) per sample
with torch.no_grad():
    for x, y in test_loader:
        # move to `device` explicitly, as the training and validation loops do
        x = x.float().to(device)
        y = y.float()[:, :n_params].to(device)

        y_NN = lr_net(model(x))

        y_pred, cov_pred = y_NN[:, :n_params], y_NN[:, n_params:]
        Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
        # same likelihood as in training: the true parameters under N(y_pred, Cov)
        loss = -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(y).mean()

        test_loss += loss.item()*x.shape[0]
        num_points += x.shape[0]

        params_true.append(y)
        params_pred.append(y_pred)
        errors_pred.append(Cov)

    test_loss = test_loss/num_points
print('Test loss: ', test_loss)

params_true = torch.cat(params_true)
params_pred = torch.cat(params_pred)
errors_pred = torch.cat(errors_pred)


In [None]:
# Mean-squared error of the point predictions: both parameters, then each one separately.
MSE_error = F.mse_loss(params_true[:, :2], params_pred[:, :2]).cpu().numpy()
print('MSE error: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, :1], params_pred[:, :1]).cpu().numpy()
print('MSE error on OmegaM: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, 1:], params_pred[:, 1:]).cpu().numpy()
print('MSE error on sigma8: {:}'.format(MSE_error))

# Mean relative errors in percent. Numerator and denominator are both indexed with an integer
# column so they are (N,) vectors: dividing an (N,) tensor by an (N, 1) slice would broadcast
# to an (N, N) matrix and average the wrong quantity.
print('\nActual errors on A, B (relative, %)')
print((torch.abs(params_pred[:, 0] - params_true[:, 0])/params_true[:, 0]).mean()*100)
print((torch.abs(params_pred[:, 1] - params_true[:, 1])/params_true[:, 1]).mean()*100)

# predicted 1-sigma = sqrt of the covariance diagonal, relative to the predicted value
print('\nPredicted errors on A, B (relative, %)')
print((torch.sqrt(errors_pred[:, 0, 0])/params_pred[:, 0]).mean()*100)
print((torch.sqrt(errors_pred[:, 1, 1])/params_pred[:, 1]).mean()*100)

## Plot predicted vs. true parameters (with predicted 1σ error bars)

In [None]:
params_pred_plot = params_pred.cpu().numpy()
params_true_plot = params_true.cpu().numpy()
errors_pred_plot = errors_pred.cpu().numpy()  # predicted covariance matrices, (N, n_params, n_params)

# Pick up to 100 test samples with distinct true parameters. Deduplicate on the *true* values:
# parameters were repeated across views at load time, so the test set presumably holds several
# samples per parameter set, whereas continuous predictions are essentially never repeated.
_, indices_unique = np.unique(params_true_plot[:, 0], return_index=True)
np.random.seed(seed)
np.random.shuffle(indices_unique)
indices_unique = indices_unique[:100]

# one panel per parameter: (column, LaTeX symbol, end points of the identity line)
panels = [(0, r'$\Omega_M$', [0.15, 0.45]),
          (1, r'$\sigma_8$', [0.65, 0.95])]

fig, axs = plt.subplots(1, 2, figsize=(12, 5))
for ax, (col, symbol, diag) in zip(axs, panels):
    ax.set_ylabel('Predicted ' + symbol)
    ax.set_xlabel('True ' + symbol)
    ax.plot(diag, diag, c = 'k', lw = 2, ls = '--')
    ax.errorbar(params_true_plot[indices_unique, col], params_pred_plot[indices_unique, col],
                yerr=np.sqrt(errors_pred_plot[indices_unique, col, col]),
                linestyle = '', capsize = 2, label = r'$1\sigma$')
    ax.set_aspect('equal')
    ax.legend(loc = 'upper left')
