# End to End Inference Tutorial
Here we implement a complete end-to-end use of paltas. This notebook is intended as a 'minimal reproducible example': it doesn't use the full extent of the package, but should be a useful starting point. \
A number of the code blocks simply run command-line instructions. This is intentional, as paltas is designed to be run from the command line. Running these steps as command-line statements also makes it easier to transfer them to remote computing clusters or to parallelise them.
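Because these cells use IPython's `!` shell escape, they won't run verbatim in a plain Python script (for example one submitted to a cluster queue). A minimal sketch of building the same `generate.py` call for `subprocess.run`; `build_generate_command` is a hypothetical helper, and it only reproduces the flags used in this notebook (`--n`, `--tf_record`, `--h5`), so check `generate.py` for the full set of options.

```python
import shlex
import subprocess

def build_generate_command(config_path, save_folder, n_images, tf_record=True, h5=True,
                           script='./paltas/generate.py'):
    """Return the argument list for one call to generate.py (hypothetical helper).

    Only the flags used in this notebook are reproduced.
    """
    cmd = ['python3', script, config_path, save_folder, '--n', str(n_images)]
    if tf_record:
        cmd.append('--tf_record')
    if h5:
        cmd.append('--h5')
    return cmd

cmd = build_generate_command('./paltas/Configs/Examples/config_simple_tutorial.py',
                             '/path/to/training/1', 100)
print(shlex.join(cmd))  # the equivalent shell command, for checking
# subprocess.run(cmd, check=True)  # uncomment to actually launch the image generation
```

Passing a list (rather than one string with `shell=True`) avoids shell-quoting problems with directory names.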


# Goals

1. To be able to implement a simple end-to-end example of Paltas
2. To understand how the different components of paltas interact, and which scripts need to be run (and when) to perform hierarchical inference.

# Import Packages
Here we import the required packages and define the training and model directories (where the training images and model weights are stored, respectively).\
The '/home/runner/work' path referred to here is required to run this notebook as a GitHub Action, but should be changed to a preferred directory when running this notebook locally.\
Although tensorflow, emcee and ipython are not among paltas's requirements, they are needed for this notebook.

In [None]:
#paltas_directory = './'
#training_directory = '/home/runner/work/notebooks/End_to_End_Tutorial_Files/' #For github actions
#model_directory = '/home/runner/work/notebooks/End_to_End_Tutorial_Files/'
#training_directory = '/global/u2/p/phil1884/paltas/notebooks/End_to_End_Tutorial_Files' #For NERSC
#model_directory = '/global/u2/p/phil1884/paltas/notebooks/End_to_End_Tutorial_Files'
#paltas_directory = '/global/u2/p/phil1884/paltas/'
training_directory = '/mnt/extraspace/hollowayp/paltas_data/Example_SL_12/' #For Glamdring
model_directory = '/mnt/extraspace/hollowayp/paltas_data/Example_SL_12/'
paltas_directory = '/mnt/zfsusers/hollowayp/paltas/'
import os
os.chdir(paltas_directory)
from paltas.Analysis import hierarchical_inference,dataset_generation, loss_functions, conv_models
from IPython.display import display,Pretty
import matplotlib.patches as mpatches
import matplotlib.pyplot as pl
from scipy.stats import norm
import tensorflow as tf
from tqdm import tqdm
import pandas as pd
import numpy as np
import corner
import emcee
import numba
import h5py
import glob
import sys
import datetime
random_seed = 4
np.random.seed(random_seed)
tf.random.set_seed(random_seed)


In [None]:
#!python3 -m pip install arviz --upgrade
import arviz as az
az.style.use("arviz-plasmish")
#az.style.available

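# Load Kernel
Optionally restore a session previously saved with `dill.dump_session` (see 'Save Kernel'), instead of re-running every cell.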

In [None]:
import os
import dill
paltas_directory = '/mnt/zfsusers/hollowayp/paltas/'
os.chdir(paltas_directory)
dill.load_session('/mnt/zfsusers/hollowayp/paltas/notebooks/End_to_End_Example_Inference_Working_File.db')

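# Save Kernel
Uncomment the line below to save the current session to disk with `dill`, so that it can be restored later.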

In [None]:
import dill
#dill.dump_session('/mnt/zfsusers/hollowayp/paltas/notebooks/End_to_End_Example_Inference_Working_File.db')

# Generate Images
We start by generating lensed images, divided into training and validation sets. The images from each run of generate.py are saved in a single h5 file. We'll first look at the configuration file that determines the properties of the generated images:

In [None]:
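#Display the configuration file passed to generate.py below (a copy is also saved alongside the generated images)
display(Pretty('./paltas/Configs/Examples/config_simple_tutorial.py'))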

We then run the image generation:

In [None]:
!python3 ./paltas/generate.py ./paltas/Configs/Examples/config_simple_tutorial.py $training_directory/training/1 --n 100 --tf_record --h5
!python3 ./paltas/generate.py ./paltas/Configs/Examples/config_simple_tutorial.py $training_directory/validation/1 --n 100 --tf_record --h5

# Plot the Generated Distributions

In [None]:
metadata_files = glob.glob(f'{training_directory}/training/**/metadata.csv',recursive=True)
#Combine the metadata from every training folder into one DataFrame
generated_image_params_db = pd.concat([pd.read_csv(m_i) for m_i in tqdm(metadata_files)],
                                      ignore_index=True)

#Drop columns that aren't of interest here (cosmology, detector and PSF settings, seeds, etc.) so they don't clutter the corner plot.
#errors='ignore' skips any column that isn't present in the metadata.
generated_image_params_db=generated_image_params_db.drop(['cosmology_parameters_cosmology_name',
                                    'detector_parameters_background_noise',
                                    'detector_parameters_exposure_time',
                                    'detector_parameters_magnitude_zero_point',
                                    'detector_parameters_num_exposures',
                                    'detector_parameters_pixel_scale',
                                    'detector_parameters_read_noise',
                                    'detector_parameters_sky_brightness',
                                    'lens_light_parameters_output_ab_zeropoint',
                                    'main_deflector_parameters_M200',
                                    'detector_parameters_ccd_gain',
                                    'psf_parameters_fwhm',
                                    'psf_parameters_psf_type',
                                    'seed',
                                    'source_parameters_output_ab_zeropoint',
                                    'main_deflector_parameters_dec_0',
                                    'main_deflector_parameters_ra_0',
                                    'source_parameters_n_sersic',
                                    'lens_light_parameters_n_sersic'],axis=1,errors='ignore')


In [None]:
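#corner.corner(generated_image_params_db,labels=generated_image_params_db.columns)
#learning_params is imported from the training configuration further down; until then, fall back to the main-deflector columns
plot_params = globals().get('learning_params',
                            [c for c in generated_image_params_db.columns if c.startswith('main_deflector_parameters_')])
corner.corner(generated_image_params_db[plot_params],
              labels=[elem.replace('main_deflector_parameters_','') for elem in plot_params])
pl.show()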

# Train Model
The neural network is then trained. The --h5 flag indicates that the images were saved as h5 files and should be read as such. Again, most of the work is done by the configuration file, so we'll inspect that first. This file defines the learning parameters (the lens properties the network should infer); in this example we chose the Einstein radius, shear, power-law slope, lens centre position and ellipticity.

In [None]:
display(Pretty("./paltas/Analysis/AnalysisConfigs/train_config_examp_tutorial.py"))

In [None]:
!python3 ./paltas/Analysis/train_model.py ./paltas/Analysis/AnalysisConfigs/train_config_examp_tutorial.py --h5

# Generate Model Predictions
Having trained the model, we locate the filename of the final epoch (this can be hardcoded instead if desired).
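If the weights file names aren't zero-padded, or other files end up in `model_weights/`, the lookup below can fail. A slightly more defensive sketch, assuming (as `return_final_epoch_weights` below does) that each file name starts with the epoch number followed by `-`; `final_epoch_weights` is a hypothetical helper, not part of paltas.

```python
import os
import re

def final_epoch_weights(filenames):
    """Return the weights file with the largest epoch number (hypothetical helper).

    Assumes names of the form '<epoch>-<rest of name>'; files that don't match are skipped.
    """
    epoch_files = {}
    for path in filenames:
        match = re.match(r'(\d+)-', os.path.basename(path))
        if match:
            epoch_files[int(match.group(1))] = path
    if not epoch_files:
        raise FileNotFoundError('No weights files of the form <epoch>-* found')
    return epoch_files[max(epoch_files)]

# Usage in this notebook:
# final_epoch_weights(glob.glob(f'{model_directory}/model_weights/*.h5'))
```

Comparing parsed integers doesn't depend on zero-padding, and stray files are ignored rather than raising a `ValueError` in `int()`.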

In [None]:
def load_model_weights_list(directory):
    """ Function to return a list of weights filenames from the network
    args: Directory containing the training, validation and weights files """
    weights_list = glob.glob(f'{directory}/model_weights/*')
    weights_list = [elem.split('model_weights/')[1] for elem in weights_list]
    return weights_list

def return_final_epoch_weights(directory):
    """ File to return the weight filename of the final trained epoch
    args: Directory containing the training, validation and weights files """
    weights_list = load_model_weights_list(directory)
    print(weights_list)
    final_epoch =  np.max([int(elem.split('-')[0]) for elem in weights_list])
    w_filename = [x for x in weights_list if x.startswith("{:02d}".format(final_epoch)+'-')][0]
    print('FINAL EPOCH',w_filename)
    return directory+'/model_weights/'+w_filename

def return_list_of_weight_files(directory):
    '''Returns list of weight files, ordered by their creation date'''
    files = list(filter(os.path.isfile, glob.glob(f'{directory}/model_weights/*h5')))
    files.sort(key=lambda x: os.path.getmtime(x))
    return files

final_weights_filename = return_final_epoch_weights(model_directory)

def retrieve_training_prior():
    prior_path = glob.glob(f'{training_directory}/**/norm*',recursive=True)[0]
    print(f'Retrieving prior path from {prior_path}')
    training_prior_db = pd.read_csv(prior_path)
    return training_prior_db

prior_db = retrieve_training_prior()
prior_db_indx = prior_db.set_index(prior_db['parameter'])
print('NOTE: This training prior should really encompass all the training images, not just a subset (i.e. not just one folder of them)')

In [None]:
print(list(prior_db['mean']))
print(list(prior_db['std']))
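
These means and standard deviations are the normalisation applied to the labels during training, so the network's outputs have to be mapped back to physical units. A minimal sketch of that mapping, assuming each label was normalised as `(y - mean) / std`; `unnormalize` is a hypothetical re-implementation for illustration, not paltas's `dataset_generation.unnormalize_outputs`.

```python
import numpy as np

def unnormalize(y_norm, std_norm, cov_norm, mean, std):
    """Map normalised network outputs back to physical units (hypothetical helper).

    y_norm, std_norm: (n_lenses, n_params); cov_norm: (n_lenses, n_params, n_params);
    mean, std: (n_params,), e.g. the 'mean' and 'std' columns of norms.csv.
    """
    y = y_norm * std + mean            # undo the shift and scale
    std_phys = std_norm * std          # standard deviations only rescale
    # Cov(a_i x_i, a_j x_j) = a_i a_j Cov(x_i, x_j), applied to every lens at once
    cov_phys = cov_norm * np.outer(std, std)
    return y, std_phys, cov_phys

toy_mean = np.array([1.0, 2.0])
toy_std = np.array([0.5, 0.1])
toy_y, toy_s, toy_cov = unnormalize(np.array([[2.0, -1.0]]), np.array([[0.2, 1.0]]),
                                    np.diag([0.04, 1.0])[None, :, :], toy_mean, toy_std)
```

The precision matrix in physical units is then the inverse of the rescaled covariance, which is the order of operations used in `gen_network_predictions` below.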

## Configuration
The trained model is loaded (along with the network weights from the final epoch).
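The number of outputs the network must produce depends on the loss type, and `load_model` below uses it to size the final layer. A small sketch of that count (a hypothetical helper mirroring the arithmetic in `load_model`): for `diag`, one mean and one log-variance per parameter; for `full`, the means plus the p(p+1)/2 independent entries of a symmetric p×p matrix.

```python
def num_network_outputs(num_params, loss_type):
    """Number of network outputs for num_params learnt parameters (hypothetical helper)."""
    if loss_type == 'diag':
        # one mean and one log-variance per parameter
        return 2 * num_params
    if loss_type == 'full':
        # means plus the independent entries of a symmetric num_params x num_params matrix
        return num_params + num_params * (num_params + 1) // 2
    raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'diag' or 'full'")

for loss_type in ('diag', 'full'):
    print(loss_type, num_network_outputs(8, loss_type))  # prints: diag 16, then full 44
```

For the eight lens parameters in this tutorial, the full covariance therefore almost triples the number of outputs relative to the diagonal case.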

In [None]:
def load_model(model_weights_filename,loss_type,model_type,learning_params,log_learning_params,img_size):
    """ Loads the trained model
    args: 
    model_weights_filename (str): .h5 file containing the weights of the trained model.
    loss_type (str): 'full' or 'diag', depending on the type of covariance matrix chosen
    model type (str): 'xresnet34' or 'xresnet101', according to the choice of network
    learning_params (list of str): Parameters learnt by the network
    img_size (int): Dimensions of the input images"""
    num_params = len(learning_params+log_learning_params)
    if loss_type == 'full':
        num_outputs = num_params + int(num_params*(num_params+1)/2)
        loss_func = loss_functions.FullCovarianceLoss(num_params)
    elif loss_type == 'diag':
        num_outputs = 2*num_params
        loss_func = loss_functions.DiagonalCovarianceLoss(num_params)
    if model_type == 'xresnet101':
        model = conv_models.build_xresnet101(img_size,num_outputs)
    if model_type == 'xresnet34':
        model = conv_models.build_xresnet34(img_size,num_outputs)
    model.load_weights(model_weights_filename,by_name=True,skip_mismatch=True)
    return model,loss_func,num_params

#Import training configs
print("NB: Need to make sure the following is importing the correct training configuration file: Currently loading train_config_Simpipeline")
from paltas.Analysis.AnalysisConfigs.train_config_Simpipeline import learning_params,batch_size,flip_pairs,\
                                                               n_epochs,random_seed,norm_images,\
                                                               loss_function,model_type,\
                                                               npy_folders_train,img_size


#Raw strings avoid invalid escape sequences such as '\T' and '\g'
corner_param_print= [elem.replace('main_deflector_parameters_','').replace('subhalo_parameters_','').\
                     replace('theta',r'\Theta').replace('gamma',r'\gamma') for elem in learning_params]

model_dict = {}
#Just loading the most recent weights file (by modification time); extend the list to load more epochs
for ii,epoch_i_weights in tqdm(enumerate([return_list_of_weight_files(model_directory)[-1]])):
    model,loss_func,num_params = load_model(epoch_i_weights,loss_function,learning_params=learning_params,
                                            log_learning_params=[],model_type=model_type,img_size=img_size)
    model_dict[ii]=model

In [None]:
return_list_of_weight_files(model_directory)

## Generate Network Predictions
The network predictions are then generated for the validation set created above.
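`gen_network_predictions` below turns the predicted log-variances into a stack of diagonal covariance matrices with a Python loop, and later inverts them with `np.linalg.inv`. For a diagonal covariance the inverse is simply 1/σ² on the diagonal. A vectorised sketch of both steps; `diag_cov_and_prec` is a hypothetical helper, assuming `log_var_pred` has shape (n_lenses, n_params).

```python
import numpy as np

def diag_cov_and_prec(log_var_pred):
    """Standard deviations, covariance and precision matrices from log-variances.

    log_var_pred: (n_lenses, n_params). Returns std (n_lenses, n_params) and two
    stacks of diagonal matrices of shape (n_lenses, n_params, n_params).
    """
    std_pred = np.exp(log_var_pred / 2)
    eye = np.eye(std_pred.shape[1])
    # result[i, j, k] = value[i, j] * eye[j, k], i.e. one diagonal matrix per lens
    cov = np.einsum('ij,jk->ijk', std_pred**2, eye)
    # the inverse of a diagonal matrix is the diagonal matrix of reciprocals
    prec = np.einsum('ij,jk->ijk', std_pred**-2, eye)
    return std_pred, cov, prec

toy_log_var = np.log(np.array([[0.04, 0.25], [1.0, 4.0]]))
toy_std_pred, toy_cov_pred, toy_prec_pred = diag_cov_and_prec(toy_log_var)
```

This is only valid for the `diag` loss; with a full covariance the general matrix inverse is still needed.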

In [None]:
def gen_network_predictions(test_folder,norm_path,learning_params,log_learning_params,loss_type,
                            loss_func,model,shuffle=True,
                            norm_images=True,log_norm_images=False):
    """
    Generate neural network predictions given a paltas generated folder of images

    Args:
        test_folder (string): Path to folder of paltas generated images,
            containing a data.tfrecord file
        norm_path (string): Path to .csv containing normalization of parameters
            applied during training of network
        learning_params (list(string)): Names of parameters learned
        log_learning_params (list(string)): Names of parameters whose log was learned
        loss_type (string): only 'diag' currently supported for this notebook
        loss_func (paltas.Analysis.loss_function): Loss function object, (needs
            draw_samples() and convert_output() functionality)
        model (paltas.Analysis.conv_models): Trained neural network with weights
            loaded
        shuffle (bool, default=True): If True, the order of the test set is shuffled
            when generating predictions
        norm_images (bool, default=True): If True, normalize test set images
        log_norm_images (bool, default=False): If True, test set images are
            log-normalized and rescaled to range (0,1). NB: currently not passed
            on to generate_tf_dataset, so it has no effect.

    Returns:
        y_test, y_pred, std_pred, prec_pred
    """

    tfr_test_path = os.path.join(test_folder,'data.tfrecord')
    input_norm_path = norm_path
    #The following code implementation here and in the hierarchical inference function below assumes a diagonal covariance matrix
    if loss_type !='diag':
        raise ValueError('loss_type not supported in this notebook')
    tf_dataset_test = dataset_generation.generate_tf_dataset(tf_record_path = tfr_test_path,\
                                                             learning_params = learning_params,
                                                             batch_size = 3,\
                                                             n_epochs = 1,\
                                                             norm_images=norm_images,
                                                             kwargs_detector=None,\
                                                             input_norm_path=input_norm_path,
                                                             log_learning_params=log_learning_params,\
                                                             shuffle=shuffle)

    y_test_list = [];y_pred_list = []
    std_pred_list = [];cov_pred_list = []
    predict_samps_list = []

    for batch in tf_dataset_test:
        images = batch[0].numpy()
        y_test = batch[1].numpy()
        
        # use unrotated output for covariance matrix
        output = model.predict(images)
        y_pred, log_var_pred = loss_func.convert_output(output)

        # compute std. dev.
        std_pred = np.exp(log_var_pred/2)
        cov_mat = np.empty((len(std_pred),len(std_pred[0]),len(std_pred[0])))
        for i in range(len(std_pred)):
            cov_mat[i] = np.diag(std_pred[i]**2)

        y_test_list.append(y_test)
        y_pred_list.append(y_pred)
        std_pred_list.append(std_pred)
        cov_pred_list.append(cov_mat)

    y_test = np.concatenate(y_test_list)
    y_pred = np.concatenate(y_pred_list)
    std_pred = np.concatenate(std_pred_list)
    cov_pred = np.concatenate(cov_pred_list)

    if input_norm_path is not None:
        dataset_generation.unnormalize_outputs(input_norm_path,learning_params+log_learning_params,
                                        y_pred,standard_dev=std_pred,cov_mat=cov_pred)
        dataset_generation.unnormalize_outputs(input_norm_path,learning_params+log_learning_params,
                                        y_test)
    prec_pred = np.linalg.inv(cov_pred)
   
    return y_test, y_pred, std_pred, prec_pred

In [None]:
learning_params

In [None]:
network_predictions_dict = {}
#Retrieve (up to) 2 epochs, including the last; np.unique avoids predicting twice when only one model is loaded
key_indx = np.unique(np.linspace(0,len(model_dict.keys())-1,2).astype('int'))
for epoch_i in tqdm(np.array(list(model_dict.keys()))[key_indx]):
    network_predictions_dict[epoch_i] = gen_network_predictions(
                        test_folder=training_directory+'/validation/1',
                        norm_path=glob.glob(f'{training_directory}/**/norm*',recursive=True)[0],
                        #norm_path = training_directory+'/training/1/norms.csv',
                        learning_params=learning_params,
                        log_learning_params = [],
                        loss_type=loss_function,
                        loss_func=loss_func,
                        model=model_dict[epoch_i],
                        shuffle=False, #NOT shuffling here, so the network outputs can be compared with other parameters in the test set.
                        norm_images=norm_images,
                        log_norm_images=False)

In [None]:
model_dict.keys()

In [None]:
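#Plot the network posteriors (the Gaussians given by the network outputs), highlighting in red those whose predicted mean lies
#outside the 1-sigma range of the training prior, i.e. the lenses for which the network may be extrapolating.
#network_pred_mu_db and network_std_db are defined here so this cell doesn't depend on the next one having been run first.
final_epoch_n = np.max(list(network_predictions_dict.keys()))
network_pred_mu_db = pd.DataFrame(network_predictions_dict[final_epoch_n][1],columns=learning_params)
network_std_db = pd.DataFrame(network_predictions_dict[final_epoch_n][2],columns=learning_params)
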
X_plot_dict = {'main_deflector_parameters_theta_E':np.linspace(-3,3,100),
'main_deflector_parameters_gamma1':np.linspace(-0.5,0.5,100),
'main_deflector_parameters_gamma2':np.linspace(-0.5,0.5,100),
'main_deflector_parameters_gamma':np.linspace(0,3,100),
'main_deflector_parameters_e1':np.linspace(-0.5,0.5,100),
'main_deflector_parameters_e2':np.linspace(-0.5,0.5,100),
'main_deflector_parameters_center_x':np.linspace(-0.5,0.5,100),
'main_deflector_parameters_center_y':np.linspace(-0.5,0.5,100),
}

N_cols=4
fig,ax = pl.subplots(2,N_cols,figsize=(20,10))
fig2,ax2 = pl.subplots(1,figsize=(8,5))
for n_i,p_i in enumerate(prior_db['parameter']):
    x = n_i%N_cols
    y = np.floor(n_i/N_cols).astype('int')
    outside_prior = (network_pred_mu_db[p_i].to_numpy()<(prior_db_indx.loc[p_i,'mean']-prior_db_indx.loc[p_i,'std'])) |\
                    (network_pred_mu_db[p_i].to_numpy()>(prior_db_indx.loc[p_i,'mean']+prior_db_indx.loc[p_i,'std']))
    ax[y,x].plot(X_plot_dict[p_i],norm.pdf(np.array([X_plot_dict[p_i]]*sum(1-outside_prior)).T,
                                                loc=network_pred_mu_db[p_i].to_numpy()[~outside_prior],
                                                scale=network_std_db[p_i].to_numpy()[~outside_prior]),c='k',alpha=0.1)
    ax[y,x].plot(X_plot_dict[p_i],norm.pdf(np.array([X_plot_dict[p_i]]*sum(outside_prior)).T,
                                                loc=network_pred_mu_db[p_i].to_numpy()[outside_prior],
                                                scale=network_std_db[p_i].to_numpy()[outside_prior]),c='red',alpha=0.1)
    #Multiplying the prior by 5 so it can be seen:
    ax[y,x].plot(X_plot_dict[p_i],5*norm.pdf(X_plot_dict[p_i],prior_db_indx.loc[p_i,'mean'],
                                    prior_db_indx.loc[p_i,'std']),c='blue')
    ax[y,x].legend(handles=[
    mpatches.Patch(color='k', label='Paltas Posterior'),
    mpatches.Patch(color='red', label=rf'Paltas Posterior ({int(np.round(100*sum(outside_prior)/len(outside_prior)))}% outside $1\sigma$ prior)'),
    mpatches.Patch(color='blue', label='Training Prior')],fontsize=8)
    ax[y,x].set_xlabel(p_i)
    ax2.scatter(np.median(network_std_db[p_i].to_numpy()),8-n_i,label=p_i)
    ax2.scatter(prior_db_indx.loc[p_i,'std'],8-n_i,marker='x',label='_nolegend_',c='red')

ax2.set_xlim(left=0)
ax2.set_title(r'$\sigma_{learnt}$ vs $\sigma_{prior}$',fontsize=18)
ax2.legend()
pl.show()
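
The red curves above mark lenses whose predicted mean falls outside the training prior's 1σ range. A minimal sketch to tabulate that fraction for every parameter at once; `fraction_outside_prior` is a hypothetical helper, assuming `prior_db` has the `parameter`, `mean` and `std` columns used above.

```python
import numpy as np
import pandas as pd

def fraction_outside_prior(pred_mu_db, prior_db, n_sigma=1.0):
    """Fraction of predicted means outside mean +/- n_sigma*std of the training prior.

    pred_mu_db: DataFrame with one column per parameter;
    prior_db: DataFrame with 'parameter', 'mean' and 'std' columns.
    """
    prior = prior_db.set_index('parameter')
    fractions = {}
    for param in pred_mu_db.columns:
        offset = np.abs(pred_mu_db[param].to_numpy() - prior.loc[param, 'mean'])
        fractions[param] = np.mean(offset > n_sigma * prior.loc[param, 'std'])
    return pd.Series(fractions)

# Usage in this notebook: fraction_outside_prior(network_pred_mu_db, prior_db)
toy_prior = pd.DataFrame({'parameter': ['a', 'b'], 'mean': [0.0, 1.0], 'std': [1.0, 0.5]})
toy_pred = pd.DataFrame({'a': [0.5, -2.0, 1.5, 0.0], 'b': [1.0, 1.2, 2.0, 0.9]})
toy_frac = fraction_outside_prior(toy_pred, toy_prior)
```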

In [None]:
metadata_val=pd.read_csv(training_directory+'/validation/1/metadata.csv')
bright_source_indx = np.where(metadata_val['source_parameters_mag_app']<25)[0]
final_epoch_n = np.max(list(network_predictions_dict.keys()))
network_truth_db = pd.DataFrame(network_predictions_dict[final_epoch_n][0],columns=learning_params)
network_pred_mu_db = pd.DataFrame(network_predictions_dict[final_epoch_n][1],columns=learning_params)
network_std_db = pd.DataFrame(network_predictions_dict[final_epoch_n][2],columns=learning_params)

#Assert that the true parameters from the metadata are equal to those I'm getting from gen_network_predictions, to make sure there hasn't been any reshuffling
#and to make sure I'm comparing the properties of the same objects. The 'round' is to remove the problem of some being saved as float32 and others as float64 files.
assert (np.round(network_truth_db[learning_params],3)==np.round(metadata_val[learning_params],3).astype('float32')).all().all()

property_list = ['source_parameters_mag_app','lens_light_parameters_mag_app','main_deflector_parameters_theta_E',
#                 'main_deflector_parameters_z_lens','lens_light_parameters_z_source','source_parameters_z_source',
                'main_deflector_parameters_e1','main_deflector_parameters_e2']
fig,ax = pl.subplots(len(property_list),2,figsize=(8,3*len(property_list)))
for p_i,image_property in enumerate(property_list):
    ax[p_i,0].scatter(
                metadata_val[image_property],
                network_pred_mu_db['main_deflector_parameters_theta_E']-network_truth_db['main_deflector_parameters_theta_E'],s=1)
    ax[p_i,0].set_ylabel('Pred-Truth',fontsize=15)
    ax[p_i,1].scatter(metadata_val[image_property],network_std_db['main_deflector_parameters_theta_E'],s=1)
    ax[p_i,1].set_ylabel(r'$\sigma_{network}$',fontsize=15)
    for r_i in range(2): ax[p_i,r_i].set_xlabel(image_property,fontsize=10)

pl.tight_layout()
pl.show()

In [None]:
bins_dict = {'main_deflector_parameters_theta_E':np.arange(-3,3.5,0.5),
'main_deflector_parameters_gamma1':np.arange(-2.5,3,0.5),
'main_deflector_parameters_gamma2':np.arange(-2.5,3,0.5),
'main_deflector_parameters_gamma':np.arange(-3,3.5,0.5),
'main_deflector_parameters_e1':np.arange(-2,2.5,0.5),
'main_deflector_parameters_e2':np.arange(-2,2.5,0.5),
'main_deflector_parameters_center_x':np.arange(-3,3.5,0.5),
'main_deflector_parameters_center_y':np.arange(-3,3.5,0.5),
}
fig,ax = pl.subplots(1,len(network_pred_mu_db.columns),figsize=(5*len(network_pred_mu_db.columns),5))
for n_ii,c_i in enumerate(network_pred_mu_db.columns):
    hist_dict = {'density':True,'bins':bins_dict[c_i]}
    ax[n_ii].hist(((network_pred_mu_db[c_i]-prior_db_indx.loc[c_i,'mean'])/prior_db_indx.loc[c_i,'std']),**hist_dict)
    ax[n_ii].hist(((network_pred_mu_db[c_i]-prior_db_indx.loc[c_i,'mean'])/prior_db_indx.loc[c_i,'std'])[bright_source_indx],fill=False,edgecolor='k',**hist_dict)
    ax[n_ii].set_xlabel(c_i,fontsize=12)
    ax[n_ii].set_ylabel(r'(Pred-$\mu_{train}$)/$\sigma_{train}$',fontsize=12)
    ax[n_ii].legend(['Full','$m_{source}$<25'])
pl.tight_layout()
pl.show()

bins_dict2 = {'main_deflector_parameters_theta_E':np.arange(-0.2,0.22,0.02),
'main_deflector_parameters_gamma1':np.arange(-0.16,0.18,0.02),
'main_deflector_parameters_gamma2':np.arange(-0.16,0.18,0.02),
'main_deflector_parameters_gamma':np.arange(-0.5,0.6,0.1),
'main_deflector_parameters_e1':np.arange(-0.16,0.18,0.02),
'main_deflector_parameters_e2':np.arange(-0.16,0.18,0.02),
'main_deflector_parameters_center_x':np.arange(-0.16,0.18,0.02),
'main_deflector_parameters_center_y':np.arange(-0.16,0.18,0.02),
}
fig,ax = pl.subplots(1,len(network_pred_mu_db.columns),figsize=(5*len(network_pred_mu_db.columns),5))
for n_ii,c_i in enumerate(network_pred_mu_db.columns):
    hist_dict = {'density':True,'bins':bins_dict2[c_i]}
    ax[n_ii].hist((network_pred_mu_db[c_i]-network_truth_db[c_i]),**hist_dict)
    ax[n_ii].hist((network_pred_mu_db[c_i]-network_truth_db[c_i])[bright_source_indx],fill=False,edgecolor='k',**hist_dict)
    ax[n_ii].set_xlabel(f'Pred-Truth \n {c_i}',fontsize=12)
    ax[n_ii].set_ylabel('Probability Density',fontsize=12)
    ax[n_ii].legend(['Full','$m_{source}$<25'])
pl.tight_layout()
pl.show()

In [None]:
prop_e_list = ['main_deflector_parameters_e1','main_deflector_parameters_e2']
fig,ax=pl.subplots(1,2,figsize=(10,5))
for prop in prop_e_list:
    ax[0].scatter(network_truth_db[prop],network_pred_mu_db[prop],label=prop,alpha=0.5)
    ax[0].set_xlabel('Truth');ax[0].set_ylabel(r'Prediction $\mu$')
    ax[1].scatter(network_truth_db[prop],network_std_db[prop],label=prop,alpha=0.5)
    ax[1].set_xlabel('Truth');ax[1].set_ylabel(r'Prediction $\sigma$')
    ax[0].set_xlim(-0.4,0.4)
    ax[0].set_ylim(-0.4,0.4)
    ax[0].axis('equal')
ax[0].legend()
pl.show()
corner.corner(network_truth_db[prop_e_list],labels=prop_e_list,range=[(-0.3,0.3),(-0.3,0.3)],
             truths=[0,0]) #Highlighting zero-ellipticity on the plot
pl.tight_layout()
pl.show()
corner.corner(network_truth_db)
pl.show()
print(network_truth_db.columns)

In [None]:
prior_db

## Plot Network Output Distributions
We now quantify the accuracy of the network (the mean absolute error and RMS error for each parameter), and compare the distributions of the network predictions with the ground truth.
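For reference, with residuals r_k = pred_k - truth_k over N lenses, MAE = (1/N)Σ|r_k| and RMS = sqrt((1/N)Σr_k²). The RMS is never smaller than the MAE, and weights large errors (outliers) more heavily. A compact sketch computing both per parameter; `error_summary` is a hypothetical helper equivalent to `MAE_func` and `RMS_error_func` below.

```python
import numpy as np

def error_summary(pred, truth, names):
    """Per-parameter mean absolute error and RMS error (hypothetical helper).

    pred, truth: arrays of shape (n_lenses, n_params); names: parameter names.
    """
    residual = pred - truth
    mae = np.mean(np.abs(residual), axis=0)
    rms = np.sqrt(np.mean(residual**2, axis=0))
    return {name: {'MAE': m, 'RMS': r} for name, m, r in zip(names, mae, rms)}

toy_summary = error_summary(np.array([[1.0, 0.0], [3.0, 0.0]]),
                            np.array([[0.0, 0.0], [0.0, 0.0]]), ['theta_E', 'gamma'])
```

Here the theta_E residuals are 1 and 3, giving an MAE of 2 and an RMS of sqrt(5) ≈ 2.24.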

In [None]:
db_columns = [elem.replace('main_deflector_parameters_','') for elem in learning_params]

def RMS_error_func(pred,truth,db_columns):
    #Calculates the RMS error for each parameter
    RMS_error = np.sqrt(np.mean((pred-truth)**2,axis=0))
    print(f'{len(RMS_error)} dimensional output')
    return {db_columns[i]:RMS_error[i] for i in range(len(db_columns))}

def MAE_func(pred,truth,db_columns):
    #Calculates the mean absolute error for each parameter
    MAE_error = np.mean(abs(pred-truth),axis=0)
    print(f'{len(MAE_error)} dimensional output')
    return {db_columns[i]:MAE_error[i] for i in range(len(db_columns))}

#network_predictions_dict[epoch_i] = (y_test, y_pred, std_pred, prec_pred)
Error_db = {}
for epoch_i in network_predictions_dict.keys():
    y_truth, y_pred = network_predictions_dict[epoch_i][0], network_predictions_dict[epoch_i][1]
    Error_db[epoch_i] = pd.DataFrame([MAE_func(y_pred,y_truth,db_columns),
                                      RMS_error_func(y_pred,y_truth,db_columns)],
                                     index=['MAE','RMS'])

Error_db[epoch_i]

In [None]:
import imageio
label_kwargs = {'fontsize': 20}
range_dict = {'main_deflector_parameters_theta_E':(0,3),
              'main_deflector_parameters_gamma':(1,3),
              'main_deflector_parameters_gamma1':(-0.5,0.5),
              'main_deflector_parameters_gamma2':(-0.5,0.5),
              'main_deflector_parameters_e1':(-0.5,0.5),
              'main_deflector_parameters_e2':(-0.5,0.5),
              'main_deflector_parameters_center_x':(-0.5,0.5),
              'main_deflector_parameters_center_y':(-0.5,0.5)}
bins_corner=20
gif_images = []
for epoch_i in tqdm(network_predictions_dict.keys()):
    fig = pl.figure(figsize=(3*len(learning_params),3*len(learning_params)))
    corner_kwargs_dict = {'fig':fig,'bins':bins_corner,'range':[range_dict[elem] for elem in learning_params]}
    corner.corner(network_predictions_dict[epoch_i][0],color='k',**corner_kwargs_dict)
    corner.corner(network_predictions_dict[epoch_i][1],color='red',\
                labels=['$'+elem+'$' for elem in corner_param_print],\
                label_kwargs=label_kwargs,
                **corner_kwargs_dict)
    pl.legend(['Truth','Pred'])
    pl.tight_layout()
#To save as a gif:
    pl.suptitle(f'Epoch {epoch_i}',fontsize=25,fontweight='bold')
    os.makedirs(f'{model_directory}/corner_plots/',exist_ok=True)
    pl.savefig(f'{model_directory}/corner_plots/corner_plot_evolution_{epoch_i}.png')
#    pl.show()
    pl.close()
    corner_i = imageio.imread(f'{model_directory}/corner_plots/corner_plot_evolution_{epoch_i}.png')
    gif_images.append(corner_i)

imageio.mimsave(f'{model_directory}/corner_plots/corner_plot_evolution.gif', gif_images,duration=2)


## Load Model Outputs
The hyperparameters of the training set are loaded (to be used as the interim prior in the hierarchical inference), along with the network predictions for the validation set.

In [None]:
train_norms = pd.read_csv(glob.glob(f'{training_directory}/**/norm*',recursive=True)[0])
train_mean = np.array(train_norms['mean'])
train_scatter = np.array(train_norms['std'])

#Since we are using a diagonal covariance matrix, the precision matrix is the diagonal matrix of
#the (elementwise) values of 1/std^2. In general however it is inv(cov_matrix).
final_epoch = max(list(network_predictions_dict.keys()))
network_means = network_predictions_dict[final_epoch][1][:,:].astype('float64')              
network_prec = network_predictions_dict[final_epoch][3][:,:,:].astype('float64')

In [None]:
np.shape(network_means)

# Hierarchical Inference
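The following performs hierarchical inference to retrieve the population hyperparameters of the validation set, assuming a diagonal covariance matrix. The MCMC itself is run by ./run_mcmc.py, submitted as a batch job in the final cell of this notebook; the diagnostic plots below assume its output chain has been loaded into `chain_orig`, with shape [walker, step, hyperparameter].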
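The idea behind the hierarchical step can be sketched in a few lines. Each network posterior was learnt under the training distribution, which acts as an interim prior Ω_int; the population likelihood reweights posterior samples by the ratio of a candidate population distribution to that interim prior:

$$\log p(\{d_k\}\mid\Omega) \approx \sum_k \log\left[\frac{1}{S}\sum_{s=1}^{S}\frac{p(\xi_{ks}\mid\Omega)}{p(\xi_{ks}\mid\Omega_{\rm int})}\right] + \text{const},\qquad \xi_{ks}\sim p(\xi\mid d_k,\Omega_{\rm int})$$

Below is a minimal importance-sampling sketch of that estimator, assuming independent Gaussian population distributions parameterised by means and natural-log sigmas (the ordering the chain plots below assume) and diagonal Gaussian network posteriors. It is a swapped-in Monte Carlo illustration, not the code in `run_mcmc.py` or paltas's `hierarchical_inference` module, and it omits the hyperprior p(Ω) that a sampler such as emcee would add. The function names and toy data are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def draw_posterior_samples(means, stds, n_samps, rng):
    """Samples from each lens's diagonal Gaussian posterior: (n_lenses, n_samps, n_params)."""
    noise = rng.standard_normal((means.shape[0], n_samps, means.shape[1]))
    return means[:, None, :] + stds[:, None, :] * noise

def log_likelihood_omega(hyperparams, samples, interim_mean, interim_scatter):
    """Importance-sampling estimate of log p(data | Omega), up to an Omega-independent constant.

    hyperparams: population means followed by natural-log population sigmas.
    """
    n_params = samples.shape[-1]
    mu_pop, log_sigma_pop = hyperparams[:n_params], hyperparams[n_params:]
    # log p(xi | Omega) - log p(xi | Omega_int), summed over the (independent) parameters
    log_weights = (norm.logpdf(samples, mu_pop, np.exp(log_sigma_pop))
                   - norm.logpdf(samples, interim_mean, interim_scatter)).sum(axis=-1)
    # average the weights over samples (in log space), then sum over lenses
    return np.sum(logsumexp(log_weights, axis=1) - np.log(samples.shape[1]))

# Toy check: a population with mean 1.0 and sigma 0.1, observed with per-lens noise 0.05
toy_rng = np.random.default_rng(0)
toy_truths = toy_rng.normal(1.0, 0.1, size=(200, 1))
toy_net_std = np.full_like(toy_truths, 0.05)
toy_net_means = toy_truths + toy_net_std * toy_rng.standard_normal(toy_truths.shape)
toy_samples = draw_posterior_samples(toy_net_means, toy_net_std, 500, toy_rng)
toy_interim_mean, toy_interim_scatter = np.array([1.2]), np.array([0.5])
logl_true = log_likelihood_omega(np.array([1.0, np.log(0.1)]), toy_samples,
                                 toy_interim_mean, toy_interim_scatter)
logl_offset = log_likelihood_omega(np.array([1.3, np.log(0.1)]), toy_samples,
                                   toy_interim_mean, toy_interim_scatter)
```

As expected, the hyperparameters used to generate the toy data score a much higher log-likelihood than a population mean shifted by 0.3.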

In [None]:
learning_params[5]

In [None]:
np.sum(np.isnan(chain_orig))

In [None]:
import warnings
def plot_sampler_properties(plot_evolution=False,plot_post_burnin=False,burnin=1000,burnout=None,
                           learning_params_for_HI=None):
    warnings.filterwarnings('ignore',category=DeprecationWarning)
    #[walker,evolution_number,property]
    N_cols = len(learning_params_for_HI) #Needs to stay as N_params, as one row is the mean, the other is the sigma (and the titles are defined as such.)
    N_rows = np.ceil(2*len(learning_params_for_HI)/N_cols).astype('int')
    print(N_cols,N_rows)
    if plot_evolution:
        fig,ax = pl.subplots(N_rows,N_cols,figsize=(5*N_cols,5*N_rows))
        for prop_i in range(2*len(learning_params_for_HI)):
            x = prop_i%N_cols
            y = np.floor(prop_i/N_cols).astype('int')
            param_i = learning_params_for_HI[prop_i%len(learning_params_for_HI)]
            for c_i in range(min(40,chain_orig.shape[0])): #Plot (up to) the first 40 walkers
                ax[y,x].plot(chain_orig[c_i,:,prop_i],alpha=0.1,c='k')
            ax[y,x].set_title(r'$\mu$ ('*(y==0)+r'$\sigma$ ('*(y==1)+learning_params_for_HI[prop_i%N_cols].replace('main_deflector_parameters_','')+')',
                              fontsize=16)
            ax[y,x].set_xlim(burnin,burnout)
            if prop_i<len(learning_params_for_HI):
                    ax[y,x].plot(ax[y,x].get_xlim(),
                                 [prior_db_indx.loc[param_i]['mean'],
                                  prior_db_indx.loc[param_i]['mean']],
                                 c='red',label='prior mean',linewidth=5)
            else:
                    ax[y,x].plot(ax[y,x].get_xlim(),
                                 np.log(
                                 [prior_db_indx.loc[param_i]['std'],
                                  prior_db_indx.loc[param_i]['std']]),
                                 c='red',label='prior std',linewidth=5)
            ax[y,x].legend(loc='lower right')
#            ax[y,x].set_ylim(-1,1)
        pl.tight_layout()
        pl.show()
    if plot_post_burnin:
        print("NOTE: The sigma are in natural logarithms, not in log10.")
        fig,ax = pl.subplots(N_rows,N_cols,figsize=(5*N_cols,5*N_rows))
        for prop_i in range(2*len(learning_params_for_HI)):
            prop_i_val = learning_params_for_HI[prop_i%len(learning_params_for_HI)]
            x = prop_i%N_cols
            y = np.floor(prop_i/N_cols).astype('int')
            if prop_i<len(learning_params_for_HI):
                bins=np.linspace(
                             prior_db_indx.loc[prop_i_val]['mean']-prior_db_indx.loc[prop_i_val]['std'],
                             prior_db_indx.loc[prop_i_val]['mean']+prior_db_indx.loc[prop_i_val]['std'],
                             50)
                ax[y,x].hist(chain_orig[:,burnin:burnout,prop_i].flatten(),bins=bins)
                ax_ylim = ax[y,x].get_ylim()
                ax[y,x].plot([prior_db_indx.loc[prop_i_val]['mean']]*2,ax_ylim)
            else:
                bins=50
                ax[y,x].hist(chain_orig[:,burnin:burnout,prop_i].flatten(),bins=bins)
                ax_ylim = ax[y,x].get_ylim()
                ax[y,x].plot(np.log([prior_db_indx.loc[prop_i_val]['std']]*2),ax_ylim)
            ax[y,x].set_title(r'$\mu$ ('*(y==0)+r'$\sigma$ ('*(y==1)+learning_params_for_HI[prop_i%N_cols].replace('main_deflector_parameters_','')+')')
        pl.tight_layout()
        pl.show()

plot_sampler_properties(plot_evolution=True,
                        plot_post_burnin=True,
                        burnin=1000,burnout=None,
                        learning_params_for_HI = ['main_deflector_parameters_theta_E',
                                                  'main_deflector_parameters_gamma',
#                                                 'main_deflector_parameters_gamma1',
#                                                 'main_deflector_parameters_gamma2'])
#                                                  'main_deflector_parameters_center_x',
#                                                  'main_deflector_parameters_center_y',
#                                                  'main_deflector_parameters_e1'])
                                                  'main_deflector_parameters_e2'])
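The post-burn-in histograms above show σ in natural-log space. The following is a minimal sketch of turning the chain into summary numbers. It assumes the chain layout used in this notebook, `[walker, step, 2*num_params]`, with the population means (μ) in the first `num_params` columns and the natural-log widths (ln σ) in the remaining columns. `summarise_chain` is a hypothetical helper, not part of paltas.

```python
import numpy as np

def summarise_chain(chain, burnin, param_names):
    """Median and 16th/84th-percentile offsets of each hyperparameter after burn-in.

    Assumes chain has shape [walker, step, 2*n_params]: mu columns first,
    natural-log sigma columns second (hypothetical helper, not from paltas).
    """
    n_params = len(param_names)
    # Drop the burn-in steps and pool all walkers together
    samples = chain[:, burnin:, :].reshape(-1, 2 * n_params)
    mu = samples[:, :n_params]
    sigma = np.exp(samples[:, n_params:])  # undo the natural log (not log10)
    summary = {}
    for k, name in enumerate(param_names):
        for label, values in (("mu", mu[:, k]), ("sigma", sigma[:, k])):
            lo, med, hi = np.percentile(values, [16, 50, 84])
            # (median, lower error, upper error)
            summary[f"{label}_{name}"] = (med, med - lo, hi - med)
    return summary
```

In this notebook a call would look like `summarise_chain(chain_orig, burnin=1000, param_names=learning_params)`, provided `num_params == len(learning_params)`.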

In [None]:
from importlib import reload  # Python 3.4+
hierarchical_inference = reload(hierarchical_inference)

In [None]:
#Due to the current implementation of hierarchical_inference.ProbabilityClassAnalytical.log_post_omega, we assert
#that the covariance matrix must be diagonal for now.
assert loss_function=='diag'
burnin = int(1000)
n_samps = 3000
#Saving arrays for hierarchical inference:
os.makedirs(f'{model_directory}/mcmc_files/',exist_ok=True)
np.save(model_directory+'/mcmc_files/network_means.npy',network_means)
np.save(model_directory+'/mcmc_files/network_prec.npy',network_prec)
np.save(model_directory+'/mcmc_files/train_mean.npy',train_mean)
np.save(model_directory+'/mcmc_files/train_scatter.npy',train_scatter)
np.save(model_directory+'/mcmc_files/learning_params.npy',learning_params)
prior_db_indx.to_csv(model_directory+'/mcmc_files/prior_db_indx.csv')

#The MCMC hierarchical inference now takes place in ./run_mcmc.py:
print(f"addqueue -c '45min' -m  3 /mnt/users/hollowayp/python11_env/bin/python3.11 ./run_mcmc.py {model_directory} {n_samps} {num_params}")
!addqueue -c '45min' -m  3 /mnt/users/hollowayp/python11_env/bin/python3.11 ./run_mcmc.py $model_directory $n_samps $num_params
print('Note: run_mcmc assumes there are 1000 lenses in the test set - otherwise the reshaping will be wrong.')
print('Settings in run_mcmc.py:'+\
        '\n1) The mu values must be within 5 sigma of the training prior. The prior mu and sigma values are HARDCODED and need to be UPDATED for each new training set.'+\
        '\n2) The sigma values must be between 0.001 and 2 (natural-logged, i.e. in np.log space).'+\
        '\n3) The mu walkers are initialised within 5 sigma of the training prior.'+\
        '\n4) The sigma walkers are initialised between 1% of sigma and 3 sigma.')

print(
    '\n A few notes on MCMC:'+
    '\n 1) I think this is broadly evaluating C5 in the appendix of Wagner-Carena et al. (2021). Are the "normalising factor" terms in C6 '+
        'ignored? Or is it evaluating C6 without integrating at all, just taking the product of the 3 Gaussians? gaus_prod_analytical does look like an integral, though.'+
    '\n 2) mu_omega_i: I think this comes from the mean of the parameters in the training set, i.e. train_mean above. '+
        'This acts as the prior.'+
    '\n 3) mu_omega: I think this is updated during the MCMC sampling. It starts out as cur_state_mu, then evolves over '+
        'time. Note that p(omega|{d}) is proportional to p(xi_i|Omega) times extra terms, so the equation must be solved implicitly, i.e. omega '+
        'appears on both sides. I think the MCMC evolves the values of omega, sampling the parameter space.'+
    '\n 4) mu_pred: I think these are the predictions from the network, dependent on the particular image being considered. The '+
        'network outputs a mean and an uncertainty; mu_pred is the list of all the means.'+
    '\n 5) eval_func_omega has also been set to return 0, not the sum of the hyperparameters.'
    )
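Note 1 above asks whether the per-lens term is an integral. The following is a minimal sketch under an assumption: following the importance-sampling form of hierarchical inference in Wagner-Carena et al. (2021), each lens is assumed to contribute ∫ p(ξ|d) p(ξ|Ω) / p(ξ|ν_train) dξ, where the network posterior, the population model and the training prior are all Gaussian. In that case the integral has a closed form. With diagonal covariances (the case asserted above), the densities factorise across parameters, so the multi-parameter log-integral is the sum of the 1D terms computed here. `log_reweighted_gaussian_integral` is a hypothetical name. This sketch does not claim to be what `gaus_prod_analytical` computes, although it can serve as a check against it.

```python
import numpy as np

def log_reweighted_gaussian_integral(a, A, b, B, c, C):
    """log of the integral over x of N(x; a, A) * N(x; b, B) / N(x; c, C).

    a, b, c are means and A, B, C variances of three 1D Gaussians, read here as
    network posterior (a, A), population model (b, B) and training prior (c, C).
    The integral only converges if 1/A + 1/B - 1/C > 0.
    """
    P = 1.0 / A + 1.0 / B - 1.0 / C       # precision of the combined quadratic
    if P <= 0:
        raise ValueError("divergent integral: need 1/A + 1/B - 1/C > 0")
    eta = a / A + b / B - c / C           # coefficient of x in the exponent
    K = a**2 / A + b**2 / B - c**2 / C    # x-independent part of the exponent
    # Completing the square: the 2*pi factors cancel, leaving a log-determinant
    # term and the remainder of the quadratic.
    return 0.5 * (np.log(C) - np.log(A) - np.log(B) - np.log(P)) + 0.5 * (eta**2 / P - K)

def log_normal_pdf(x, mean, var):
    return -0.5 * ((x - mean) ** 2 / var + np.log(2 * np.pi * var))

# Numerical check on a grid, using toy values (not taken from this notebook).
# Working in log space avoids 0/0 where the densities underflow.
x = np.linspace(-5.0, 5.0, 200_001)
log_f = (log_normal_pdf(x, 0.3, 0.04) + log_normal_pdf(x, 0.1, 0.09)
         - log_normal_pdf(x, 0.0, 0.25))
numeric = np.log(np.sum(np.exp(log_f)) * (x[1] - x[0]))
analytic = log_reweighted_gaussian_integral(0.3, 0.04, 0.1, 0.09, 0.0, 0.25)
print(numeric, analytic)
```

The convergence condition matters in practice. If the network posterior for a lens is wider than the training prior, the ratio can blow up, and the closed form correctly refuses to return a value.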

In [None]:
list(prior_db['parameter'])
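The settings printed earlier note that the prior μ and σ are hardcoded in `run_mcmc.py` and must be updated for each training set. The following is a minimal sketch of the described walker initialisation, driven by arrays of prior means and standard deviations (for example `prior_db['mean']` and `prior_db['std']`) rather than hardcoded values. `init_walkers` is a hypothetical helper, not part of `run_mcmc.py`. Drawing ln σ uniformly, rather than σ itself, is an assumption.

```python
import numpy as np

def init_walkers(prior_mean, prior_std, n_walkers, rng=None):
    """Hypothetical helper: initial walker positions laid out as [mu..., ln_sigma...]."""
    rng = np.random.default_rng(rng)
    prior_mean = np.asarray(prior_mean, dtype=float)
    prior_std = np.asarray(prior_std, dtype=float)
    n_params = prior_mean.size
    # mu walkers: uniform within +/- 5 prior sigma of the prior mean
    mu0 = rng.uniform(prior_mean - 5 * prior_std, prior_mean + 5 * prior_std,
                      size=(n_walkers, n_params))
    # ln(sigma) walkers: uniform in natural log between 1% of and 3x the prior sigma
    # (uniform in log space is an assumption)
    ln_sig0 = rng.uniform(np.log(0.01 * prior_std), np.log(3 * prior_std),
                          size=(n_walkers, n_params))
    # Same column order as the saved chains: means first, then log-widths
    return np.hstack([mu0, ln_sig0])

# Toy example: two parameters with illustrative prior means and widths
p0 = init_walkers([1.1, 2.0], [0.15, 0.1], n_walkers=32, rng=0)
print(p0.shape)
```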

In [None]:
glob.glob(f'{model_directory}/mcmc_files/mcmc_chains*.npy')

In [None]:
#Loading up the latest mcmc chain:
latest_chain = np.max([int(elem.split('mcmc_chains_')[1].replace('.npy','')) for elem in glob.glob(f'{model_directory}/mcmc_files/mcmc_chains*.npy')])
#Loading the complete chains:
latest_chain_filename = f'{model_directory}/mcmc_files/mcmc_chains_{latest_chain}.npy'
#latest_chain_filename=f'{model_directory}/mcmc_files/mcmc_chains_1700045646.npy'
print(f'Loading {latest_chain_filename}, last modified at {datetime.datetime.utcfromtimestamp(os.path.getmtime(latest_chain_filename))}')
chain_orig = np.load(latest_chain_filename) #[walker,step_number,parameters_x2]
#Removing the burn-in:
burnin=1000
chain = chain_orig[:,burnin:,:].reshape((-1,2*num_params))[:,0:num_params] #Just using the means for now
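The cell above loads the chain with the largest integer suffix. The following is a minimal sketch of a slightly more defensive version. It assumes the suffix in `mcmc_chains_<suffix>.npy` is a Unix timestamp in seconds; 1700045646, for example, falls in November 2023. `latest_chain_file` is a hypothetical helper. It matches on file names only, skips files whose suffix is not an integer, and uses `datetime.fromtimestamp` with an explicit UTC timezone, since `datetime.utcfromtimestamp` is deprecated from Python 3.12.

```python
import datetime
import re
from pathlib import Path

def latest_chain_file(mcmc_dir):
    """Return (path, utc_datetime) for the mcmc_chains_<int>.npy file with the largest suffix.

    Assumes the integer suffix is a Unix timestamp in seconds (hypothetical helper).
    """
    pattern = re.compile(r"mcmc_chains_(\d+)\.npy$")
    candidates = []
    for path in Path(mcmc_dir).glob("mcmc_chains_*.npy"):
        match = pattern.search(path.name)
        if match:  # ignore files whose suffix is not a plain integer
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no mcmc_chains_<int>.npy files in {mcmc_dir}")
    stamp, path = max(candidates)
    written = datetime.datetime.fromtimestamp(stamp, tz=datetime.timezone.utc)
    return path, written
```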

In [None]:
#Plotting loss evolution:
def plot_loss_evolution(directory_to_save_model):
    loss_db = pd.read_csv(directory_to_save_model+'/loss_function_db.csv')
    pl.plot(loss_db['epoch'],loss_db['loss'],c='k')
    pl.plot(loss_db['epoch'],loss_db['val_loss'],c='blue')
    pl.xlabel('Epoch',fontsize=12)
    pl.ylabel('Loss',fontsize=12)
    pl.legend(['Training Loss','Validation Loss'])
    pl.tick_params(labelsize=10)
    pl.tight_layout()
    pl.ylim(-3,5)
    pl.show()

plot_loss_evolution(model_directory)
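The y-axis range of the loss plot above extends below zero. The following is a minimal sketch of why a negative loss is expected. It assumes the `'diag'` option corresponds to a Gaussian negative log-likelihood with a diagonal covariance parameterised by log-variances. The exact form, normalisation and parameterisation in `paltas.Analysis.loss_functions` may differ, and the constant `0.5*n_params*log(2*pi)` term is dropped here. `diag_gaussian_nll` is a hypothetical name.

```python
import numpy as np

def diag_gaussian_nll(y_true, y_mean, y_log_var):
    """Batch-averaged diagonal-Gaussian negative log-likelihood (constant term dropped).

    All inputs have shape [batch, n_params]; y_log_var is the predicted log-variance.
    """
    y_true, y_mean, y_log_var = map(np.asarray, (y_true, y_mean, y_log_var))
    # Per-parameter term: 0.5 * (log sigma^2 + (y - mu)^2 / sigma^2)
    per_param = 0.5 * (y_log_var + (y_true - y_mean) ** 2 * np.exp(-y_log_var))
    return per_param.sum(axis=1).mean()

# Accurate predictions with small predicted variances give a negative loss
print(diag_gaussian_nll([[1.1, 2.0]], [[1.1, 2.0]], [[-2.0, -2.0]]))
```

Because the log-variance term is unbounded below, a network that predicts accurately and confidently drives the loss negative. Its absolute value therefore matters less than whether the training and validation curves track each other.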


# Results
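We plot the results of the hierarchical inference here: a corner plot of the posterior on the population means ($\mu$) of the learned parameters, shown in orange with 68% and 95% contours. Each axis spans ±`f` training-prior standard deviations about the training-prior mean. To overlay the ground truth as solid black lines, pass the true values to the `truths` argument of `corner.corner` (commented out in the cell below).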

In [None]:
labels_kwargs = {'fontsize':20}
hist_kwargs = {'density':False,'color':'orange','lw':3}
fig,ax = pl.subplots(len(learning_params),len(learning_params),figsize=(3*len(learning_params),3*len(learning_params)))

f = 10 #Bounds of corner plot, in units of sigma

corner.corner(chain,
                labels=[r'$\mu_{'+elem+'}$' for elem in corner_param_print],
                range=[(prior_db_indx.loc[elem]['mean']-f*prior_db_indx.loc[elem]['std'],
                        prior_db_indx.loc[elem]['mean']+f*prior_db_indx.loc[elem]['std']) for elem in learning_params],
                fig=fig,
                show_titles=False,
                plot_datapoints=True,
                label_kwargs=labels_kwargs,\
                levels=[0.68,0.95],
                color='orange',
                fill_contours=True,
                hist_kwargs=hist_kwargs,
                title_fmt='.2f',\
                #truths=[1.1,0,0,0,0,0,0],
                truth_color='k',\
                max_n_ticks=3,
                bins=50,
                s=100
                )

sig_learning_params = np.array(prior_db['std'])#np.array([0.15,0.05,0.05,0.1,0.1,0.1,0.16,0.16])
loc_learning_params = np.array(prior_db['mean'])#np.array([1.1,0,0,2,0,0,0,0])
for i in range(len(learning_params)):
    for j in range(len(learning_params)):
        ax[i,j].set_xlim((loc_learning_params-f*sig_learning_params)[j],(loc_learning_params+f*sig_learning_params)[j])
        if i!=j:
            ax[i,j].set_ylim((loc_learning_params-f*sig_learning_params)[i],(loc_learning_params+f*sig_learning_params)[i])

pl.tight_layout()
pl.show()

# Endnote
We have now 1) generated simulated lensed images, 2) trained a neural network to model these images, 3) applied the trained network to a different set of images and 4) run hierarchical inference to infer population-level parameters.\
You may notice that the final results could be improved: the configuration settings used here were chosen for speed rather than precision. For science-level results, the following should be changed:
1) Increase the training-set size for the neural network,
2) Increase the number of epochs the network trains for,
3) Increase the number of iterations (and burn-in) of the MCMC used for hierarchical inference.
