In [None]:
import sys, os, math
import numpy as np
import torch
import matplotlib.pyplot as plt
import importlib

import torch.nn as nn
from torch.functional import F
import torch.distributions as dist
from torch.utils.data import DataLoader, TensorDataset, random_split
import torchvision

sys.path.append('../')
from utils_modules.models import SummaryNet, Expander, Net, NetEquivalent, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data

In [None]:
# Run on the GPU when torch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device: %s'%(device))

## Load maps

In [None]:
# Read the maps and their parameters, then shape both for training.
n_params  = 2
home_dir  = ... # maps and parameters directory
dset_size = 1000 # data set size
splits    = 15   # number of realizations per parameter set

maps      = np.load(home_dir + 'Maps_Mtot_SIMBA_LH_z=0.00.npy')
maps_size = maps.shape[-1]
# prepare maps for VICReg: (parameter set, realization, channel, height, width)
maps      = maps.reshape((dset_size, splits, 1, maps_size, maps_size))

# keep the first n_params columns and give every realization a copy of its parameters
# (assumes params_SIMBA.txt holds one row per parameter set -- TODO confirm)
params  = np.loadtxt(home_dir + 'params_SIMBA.txt')[:, :n_params]
params  = np.repeat(params[:, np.newaxis, :], splits, axis = 1)

# rescale parameters with fixed per-parameter bounds
minimum = np.array([0.1, 0.6])
maximum = np.array([0.5, 1.0])
params  = (params - minimum)/(maximum - minimum)

# pre-processing switches for the maps
rescale     = True
standardize = True
verbose     = True

if rescale:
    # work with the base-10 logarithm of the map values
    maps = np.log10(maps)
if standardize:
    # global (not per-map) statistics, accumulated in float64
    maps_mean = maps.mean(dtype=np.float64)
    maps_std  = maps.std(dtype=np.float64)
    maps = (maps - maps_mean)/maps_std

if verbose:
    print('Shape of parameters and maps:', params.shape, maps.shape)
    first_param, second_param = params[..., 0], params[..., 1]
    print('Parameter 1 range of values: [{:.3f}, {:.3f}]'.format(first_param.min(), first_param.max()))
    print('Parameter 2 range of values: [{:.3f}, {:.3f}]'.format(second_param.min(), second_param.max()))

    if rescale:
        print('Rescale: ', rescale)
    if standardize:
        print('Standardize: ', standardize)

# hand both arrays to torch as float32 tensors on the selected device
maps   = torch.tensor(maps).float().to(device)
params = torch.tensor(params).float().to(device)

In [None]:
# divide the data into train, validation, and test sets
batch_size = 150
train_frac, valid_frac, test_frac = 0.7, 0.2, 0.1
seed = 1
train_dset, valid_dset, test_dset = utils_data.create_datasets(maps, params,
                                                               train_frac, valid_frac, test_frac,
                                                               seed = seed,
                                                               rotations=True)

# Only the training batches are shuffled. A dedicated, seeded generator makes the
# batch order reproducible without touching torch's global random state.
train_loader = DataLoader(train_dset, batch_size, shuffle = True,
                          generator = torch.Generator().manual_seed(seed))
# Validation and test losses are sample-weighted means over all batches, so the
# batch order does not matter there; not shuffling keeps evaluation deterministic.
valid_loader = DataLoader(valid_dset, batch_size, shuffle = False)
test_loader  = DataLoader(test_dset, batch_size, shuffle = False)


## Run the training

In [6]:
save_dir    = ...  # directory that receives the model checkpoint and the loss log

if not os.path.exists(save_dir):
    # exist_ok guards against the directory appearing between the check and the call
    os.makedirs(save_dir, exist_ok=True)
    if verbose: print('\n Saving to: ', save_dir)

# Size of the network output: n_params values plus the n_params*(n_params+1)/2
# entries of a lower triangle (diagonal included) of an n_params x n_params matrix.
n_params = 2
n_tril = n_params * (n_params + 1) // 2
n_total = n_params + n_tril

lr         = 2e-4     # initial learning rate
eta_min    = lr/100   # floor of the learning-rate schedule
epochs     = 200

# os.path.join places the files inside save_dir whether or not it ends with a separator
fmodel = os.path.join(save_dir, 'model.pt')
fout   = os.path.join(save_dir, 'losses.txt')


In [None]:
# define the model: a ResNet-18 with n_total outputs
model = torchvision.models.resnet18(num_classes=n_total)
# swap the first convolution for one that takes a single input channel
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.to(device);

# define the optimizer, scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999))
# The `verbose` argument of the LR schedulers is deprecated in recent PyTorch
# releases, so it is not passed; use scheduler.get_last_lr() to inspect the
# current learning rate.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs,
                                                       eta_min=eta_min)


In [None]:
def _gaussian_nll(x, y):
    """Return the batch-mean negative log-likelihood of `y` under the predicted Gaussian.

    The network output is split into its first `n_params` entries, used as the
    mean, and the remaining entries, which `vector_to_Cov` converts into the
    covariance matrix of a multivariate normal.
    """
    y_NN = model(x)
    y_pred, cov_pred = y_NN[:, :n_params], y_NN[:, n_params:]
    # vector_to_Cov is given a CPU tensor and its result is moved back to `device`
    Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
    return -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(y).mean()


def _mean_loss(loader):
    """Return the sample-weighted mean loss over `loader` in eval mode, without gradients."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            # weight each batch mean by its size so a smaller last batch is handled correctly
            total += _gaussian_nll(x, y).item()*x.shape[0]
            count += x.shape[0]
    return total/count


# validation loss of the untrained network: the threshold the first checkpoint must beat
min_valid_loss = _mean_loss(valid_loader)
if verbose:
    print('Initial valid loss = %.3e'%min_valid_loss)

# loop over the epochs
for epoch in range(epochs):

    # training
    train_loss, num_points = 0.0, 0
    model.train()
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        loss = _gaussian_nll(x, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()*x.shape[0]
        num_points += x.shape[0]

    train_loss = train_loss/num_points

    # validation
    valid_loss = _mean_loss(valid_loader)

    # verbose: keep only the weights with the lowest validation loss seen so far
    if valid_loss<min_valid_loss:
        min_valid_loss = valid_loss
        torch.save(model.state_dict(), fmodel)
        print('Epoch %d: %.3e %.3e (saving)'%(epoch, train_loss, valid_loss))
    else:
        print('Epoch %d: %.3e %.3e '%(epoch, train_loss, valid_loss))

    # start a fresh log on the first epoch, append afterwards; the context manager
    # closes the file even if the write raises
    with open(fout, 'w' if epoch == 0 else 'a') as f:
        f.write('%d %.4e %.4e\n'%(epoch, train_loss, valid_loss))

    scheduler.step()

## Evaluate

In [None]:
# plot losses
# columns of the log: epoch, training loss, validation loss
# ndmin=2 keeps the array two-dimensional even when the log holds a single epoch
losses = np.loadtxt(fout, ndmin=2)
start_epoch = 0
end_epoch   = None  # None keeps the final epoch; a slice stop of -1 would drop it

plt.figure(figsize = (10, 6))
plt.ylim(-7, 7)
plt.plot(losses[start_epoch:end_epoch, 0], losses[start_epoch:end_epoch, 1], label = 'Training loss')
plt.plot(losses[start_epoch:end_epoch, 0], losses[start_epoch:end_epoch, 2], label = 'Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc = 'best');

In [None]:
# compute loss on the test set
# map_location lets a checkpoint written on one device be restored on another
model.load_state_dict(torch.load(fmodel, map_location=device))
model.eval();

test_loss, num_points = 0., 0
params_true = []
params_pred = []
errors_pred = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        y_NN = model(x)

        # first n_params outputs are the predicted mean, the rest encode the covariance
        y_pred, cov_pred = y_NN[:, :n_params], y_NN[:, n_params:]
        Cov = vector_to_Cov(cov_pred.cpu()).to(device=device)
        # same argument order as in training: predicted mean as loc, true value as sample
        loss = -dist.MultivariateNormal(loc=y_pred, covariance_matrix=Cov).log_prob(y).mean()

        test_loss += (loss.cpu().item())*x.shape[0]
        num_points += x.shape[0]

        params_true.append(y)
        params_pred.append(y_pred)
        errors_pred.append(Cov)

    test_loss = test_loss/num_points
print(test_loss)

params_true = torch.cat(params_true)
params_pred = torch.cat(params_pred)
errors_pred = torch.cat(errors_pred)

# undo the parameter rescaling; explicit tensors avoid implicit tensor/ndarray mixing
delta_scale = maximum - minimum
scale_t  = torch.as_tensor(delta_scale)
offset_t = torch.as_tensor(minimum)
params_true = params_true.cpu()*scale_t + offset_t
params_pred = params_pred.cpu()*scale_t + offset_t
# 1-sigma errors: square root of the covariance diagonal, scaled to physical units; shape (N, n_params)
errors_pred = torch.sqrt(torch.diagonal(errors_pred, dim1=1, dim2=2)).cpu().numpy()*delta_scale


MSE_error = F.mse_loss(params_true[:, :2], params_pred[:, :2]).cpu().numpy()
print('MSE error: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, :1], params_pred[:, :1]).cpu().numpy()
print('MSE error on OmegaM: {:}'.format(MSE_error))
MSE_error = F.mse_loss(params_true[:, 1:], params_pred[:, 1:]).cpu().numpy()
print('MSE error on sigma8: {:}'.format(MSE_error))

print('\nActual errors on A, B (relative, %)')
print((torch.abs(params_pred[:, :1] - params_true[:, :1])/params_true[:, :1]).mean()*100)
print((torch.abs(params_pred[:, 1:] - params_true[:, 1:])/params_true[:, 1:]).mean()*100)

print('\nPredicted errors on A, B (relative, %)')
# Divide 1-D by 1-D so each error is paired with its own prediction; dividing a
# shape-(N,) array by a shape-(N, 1) column would broadcast to an (N, N) matrix.
params_pred_np = params_pred.numpy()
print((errors_pred[:, 0]/params_pred_np[:, 0]).mean()*100)
print((errors_pred[:, 1]/params_pred_np[:, 1]).mean()*100)

In [None]:
# Predicted-versus-true scatter with 1-sigma error bars, one panel per parameter.
params_pred_plot = params_pred.cpu()
params_true_plot = params_true.cpu()

# keep one test point per distinct true value of the first parameter, then
# draw a reproducible random subset of at most 100 of them
params_unique, indices_unique = np.unique(params_true_plot[:, 0], return_index=True)
np.random.seed(seed)
np.random.shuffle(indices_unique)
indices_unique = indices_unique[:100]

fig, axs = plt.subplots(1, 2, figsize=(12, 5))

# (axis symbol, end points of the black one-to-one reference line) for each column
panels = [(r'$\Omega_M$', [0.1, 0.5]),
          (r'$\sigma_8$', [0.6, 1.])]

for col, (symbol, diagonal) in enumerate(panels):
    ax = axs[col]
    ax.set_ylabel(symbol+', predicted')
    ax.set_xlabel(symbol+', true')
    ax.errorbar(params_true_plot[indices_unique, col], params_pred_plot[indices_unique, col],
                yerr=errors_pred[indices_unique, col],
                linestyle = '', capsize = 2, label = r'$1\sigma$')
    ax.plot(diagonal, diagonal, c = 'k', lw = 2)
    ax.legend(loc='best')
    ax.set_aspect('equal')
plt.savefig(...)