In [None]:
from astropy.visualization import simple_norm
import corner
import jax
import jax.numpy as jnp
from matplotlib.patches import ConnectionPatch, FancyBboxPatch, Ellipse
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats, signal

from paltax import input_pipeline
from paltax import train

# Reproduce the paper Figures
__To generate all of the figures in this notebook, you will have to download figure_data.zip from [TODO](https://google.com) and expand it into this folder.__

This notebook is built to work with `paltax` version `1.0.0`. Later versions __may not work__ with this code.

Figures:
- [Figure 1](#figure_1)
- [Figure 2](#figure_2)
- [Figure 3](#figure_3)
- [Figure 4](#figure_4)
- [Figure 5](#figure_5)
- [Figure 6](#figure_6)
- [Figure 7](#figure_7)
- [Figure 8](#figure_8)
- [Figure 9](#figure_9)

## Figure 1 <a id='figure_1'></a>

In [None]:
# Load the data we need for the figures. The .npy files are read from the
# figure_data/ folder, relative to where this notebook is run.
# Index into image_hier of the image that is drawn as the observed (mock) image.
mock_index = 5
image_hier = np.load('figure_data/image_hier.npy')
# Indexed below as image_draws_list[0] (draws shown for the prior) and
# image_draws_list[1] (draws shown for the sequential proposal).
image_draws_list = np.load('figure_data/image_draws_list_5.npy')
def create_flexible_grid(images, n_row, n_col):
    """Tile the leading images into one mosaic of n_row by n_col panels.

    Args:
        images: Array whose first axis indexes the individual images. Only the
            first n_row * n_col entries are used; any extras are ignored.
        n_row: Number of image rows in the mosaic.
        n_col: Number of images placed side by side in each row.

    Returns:
        A single array in which the images of each row are joined with
        np.hstack and the resulting rows are stacked with np.vstack.
    """
    n_tiles = n_row * n_col
    tile_shape = tuple(images.shape[1:])
    tiled = images[:n_tiles].reshape((n_row, n_col) + tile_shape)
    mosaic_rows = []
    for row_tiles in tiled:
        # Join the images that belong to this row side by side.
        mosaic_rows.append(np.hstack(row_tiles))
    # Stack the finished rows on top of each other.
    return np.vstack(mosaic_rows)
    
# Create the tableau on which we will add our plots.
# Every axes below is placed by hand in figure-fraction coordinates. Widths are
# multiplied by figsize[1]/figsize[0] so that a box whose width and height use
# the same multiple of image_box_size is physically square on this canvas.
figsize = (16,10)
fig = plt.figure(figsize=figsize, dpi=100)
fontsize = 15
spine_width = 2
# method_colors[0] is used for the NPE elements, method_colors[1] for the SNPE elements.
method_colors = ['#d95f02', '#1b9e77']

# Use the image mock to set the normalization for images that will be drawn.
# The mock image is multiplied by the mean, over the draws in image_draws_list[0],
# of each draw's pixel standard deviation. The same im_norm is then passed to
# every imshow call in this figure.
# NOTE(review): presumably image_hier is stored in normalized units -- confirm.
image_mock = image_hier[mock_index] * np.mean(np.std(image_draws_list[0].reshape(image_draws_list[0].shape[0],-1), axis=1))
im_norm = simple_norm(image_mock,stretch='asinh')
    
# Add the samples from the prior distribution.
# image_box_size is the height of one image tile as a fraction of the figure
# height; the mosaic is n_row tiles tall and n_col tiles wide.
image_box_size = 0.08
n_row = 2
n_col = 3
prior_samples_ax = fig.add_axes([0.1, 0.7, figsize[1]/figsize[0]*image_box_size*n_col, image_box_size*n_row])
prior_samples_ax.imshow(create_flexible_grid(image_draws_list[0], n_row, n_col), norm=im_norm, cmap='plasma')
prior_samples_ax.get_xaxis().set_visible(False)
prior_samples_ax.get_yaxis().set_visible(False)
# The caption is positioned in the image's data (pixel) coordinates: a negative
# y puts it above the top edge of the mosaic.
# NOTE(review): 128*n_col/2 assumes each tile is 128 pixels wide -- confirm.
prior_samples_ax.text(128*n_col/2, -20, 
                      r'Draw from $\theta_n \sim p(\theta|\Omega_0)$,' + '\n' + r'Simulate $x_n \sim g(x|\theta_n)$',
                      fontsize=fontsize, ha='center', color='black')

# Add the samples from the sequential proposal distribution.
# Same layout as the prior mosaic, but placed lower on the canvas (bottom edge
# at 0.40 instead of 0.7) and filled from image_draws_list[1].
seq_samples_ax = fig.add_axes([0.30 + figsize[1]/figsize[0]*image_box_size*1.5, 0.40, 
                               figsize[1]/figsize[0]*image_box_size*n_col , image_box_size*n_row])
seq_samples_ax.imshow(create_flexible_grid(image_draws_list[1], n_row, n_col), norm=im_norm, cmap='plasma')
seq_samples_ax.get_xaxis().set_visible(False)
seq_samples_ax.get_yaxis().set_visible(False)
seq_samples_ax.text(128*n_col/2, -20, 
                    r'Draw from $\theta_n \sim p(\theta|\Omega_i)$,' + '\n' + r'Simulate $x_n \sim g(x|\theta_n)$', 
                    fontsize=fontsize, ha='center', color='black')

# Add initial prior distribution
# prior_dist_ax is the outer framed box that carries the text labels; the
# cartoon of the distribution itself is drawn in a smaller axes (prior_dist)
# placed inside it.
prior_dist_ax = fig.add_axes([-0.04, 0.7, figsize[1]/figsize[0]*image_box_size*2, image_box_size*2])
prior_dist_ax.get_xaxis().set_visible(False)
prior_dist_ax.get_yaxis().set_visible(False)
for spine in prior_dist_ax.spines.values():
    spine.set_linewidth(spine_width)
# A y value above 1.0 places the title just outside the top of the box.
prior_dist_ax.text(0.5, 1.0 + 20/256, r'Specify prior', fontsize=fontsize, ha='center', color='black')
prior_dist_ax.text(0.5, 0.85, r'$p(\theta|\Omega_0)$', fontsize=fontsize, ha='center', color='black')
# Inner axes: a grey circle (equal width and height) stands in for the prior.
# Only the left and bottom spines are kept, with no ticks.
prior_dist = fig.add_axes([-0.04 + figsize[1]/figsize[0]*image_box_size*0.4, 0.72, figsize[1]/figsize[0]*image_box_size*1.2, image_box_size*1.2])
ellipse = Ellipse(xy=(1.0, 1.0), width=1.5, height=1.5, angle=0, edgecolor='none', facecolor='grey')
prior_dist.add_patch(ellipse)
prior_dist.set_xlim([-0.1,2.1])
prior_dist.set_ylim([-0.1,2.1])
prior_dist.set_yticks([])
prior_dist.set_xticks([])
prior_dist.spines['top'].set_visible(False)
prior_dist.spines['right'].set_visible(False)

# Add the two types of loss function.
# A single framed box split by a horizontal line at y=0.5: the NPE loss is
# written in the upper half and the SNPE loss in the lower half, each in its
# method color.
loss_ax = fig.add_axes([0.29, 0.7, figsize[1]/figsize[0]*image_box_size*5 , image_box_size*2])
loss_ax.text(0.5, 1.0 + 20/256, r'Train model to predict posterior:' + '\n' + r'$q_{\phi}(\theta|x,\Omega_0) \ \to \ p(\theta|x,\Omega_0)$', fontsize=fontsize, ha='center', color='black')
loss_ax.get_xaxis().set_visible(False)
loss_ax.get_yaxis().set_visible(False)
# The divider spans from 10% to 90% of the box width.
loss_ax.axhline(0.5, 0.1, 0.9, c='k')
loss_ax.text(0.5, 0.75, 
             r'NPE: $\mathcal{L}(\phi) = - \sum \log q_\phi$',
             fontsize=fontsize*0.9, ha='center', va='center', color=method_colors[0])
loss_ax.text(0.5, 0.25, 
             r'SNPE: $\mathcal{L}(\phi) = - \sum \log \left[ q_\phi \times \mathrm{Reweighting} \right]$', 
             fontsize=fontsize*0.9, ha='center', va='center', color=method_colors[1])
for spine in loss_ax.spines.values():
    spine.set_linewidth(spine_width)
    
# Add the prediction on the observed image
# pred_img_ax is the outer framed box. Three smaller axes are laid out inside
# it from left to right: the observed image, a grey "Trained Model" box, and a
# cartoon of the predicted posterior.
pred_img_ax = fig.add_axes([0.58, 0.7, figsize[1]/figsize[0]*image_box_size*5.5, image_box_size*2])
pred_img_ax.get_xaxis().set_visible(False)
pred_img_ax.get_yaxis().set_visible(False)
pred_img_ax.text(0.5, 1.0 + 20/256, 'Predict posterior on \n' + r'observed image $x_\mathrm{obs}$', fontsize=fontsize, ha='center', color='black')
for spine in pred_img_ax.spines.values():
    spine.set_linewidth(spine_width)
# Add the observed image
obs_img_ax = fig.add_axes([0.59, 0.7 + image_box_size*0.25, figsize[1]/figsize[0]*image_box_size*1.5, image_box_size*1.5])
obs_img_ax.imshow(image_mock, norm=im_norm, cmap='plasma')
obs_img_ax.get_xaxis().set_visible(False)
obs_img_ax.get_yaxis().set_visible(False)
# Add the model
model_ax = fig.add_axes([0.695, 0.7 + image_box_size*0.5, figsize[1]/figsize[0]*image_box_size, image_box_size],
                        facecolor='grey')
model_ax.get_xaxis().set_visible(False)
model_ax.get_yaxis().set_visible(False)
model_ax.text(0.5, 0.5, 'Trained\nModel',
              fontsize=fontsize, ha='center', va='center', color='white')
# Add the prediction
# The label is written on the outer box; the black tilted ellipse that stands
# in for the predicted posterior is drawn in its own small axes (pred_q).
pred_img_ax.text(0.835, 0.85, r'$q_{\phi}(\theta|x_\mathrm{obs},\Omega_0)$', fontsize=fontsize*0.9, ha='center', va='center', color='k')
pred_q = fig.add_axes([0.782, 0.72, figsize[1]/figsize[0]*image_box_size*1.2, image_box_size*1.2])
ellipse = Ellipse(xy=(0.9, 1.2), width=0.7, height=1.2, angle=30, edgecolor='none', facecolor='k')
pred_q.add_patch(ellipse)
pred_q.set_xlim([-0.1,2.1])
pred_q.set_ylim([-0.1,2.1])
pred_q.set_yticks([])
pred_q.set_xticks([])
pred_q.spines['top'].set_visible(False)
pred_q.spines['right'].set_visible(False)

# Add the sequential proposal stage
seq_prop_ax = fig.add_axes([0.59, 0.40, figsize[1]/figsize[0]*image_box_size*3.5, image_box_size*2])
seq_prop_ax.get_xaxis().set_visible(False)
seq_prop_ax.get_yaxis().set_visible(False)
seq_prop_ax.text(0.5, 1.0 + 20/256, 'Generate new proposal\n' + r'distribution: $p(\theta | \Omega_i)$', fontsize=fontsize, ha='center', color='black')
for spine in seq_prop_ax.spines.values():
    spine.set_linewidth(spine_width)
# Add the proposal transition
# Two small axes with the same tilted ellipse are created inside seq_prop_ax.
# The first (appended as snpe_dist_axes[0]) is the right-hand one, drawn in
# black; the second (snpe_dist_axes[1]) is the left-hand one, drawn in grey.
# Note that the loop reuses the name pred_q, so after this loop pred_q no
# longer refers to the axes created for the prediction box.
snpe_dist_axes = []
for x_coord, y_coord, color in zip(
    [0.59 + figsize[1]/figsize[0]*image_box_size*2.05, 0.59 + figsize[1]/figsize[0]*image_box_size*0.25],
    [0.42, 0.42], ['k', 'grey']):
    pred_q = fig.add_axes([x_coord, y_coord, figsize[1]/figsize[0]*image_box_size*1.2, image_box_size*1.2])
    snpe_dist_axes.append(pred_q)
    ellipse = Ellipse(xy=(0.9, 1.2), width=0.7, height=1.2, angle=30, edgecolor='none', facecolor=color)
    pred_q.add_patch(ellipse)
    pred_q.set_xlim([-0.1,2.1])
    pred_q.set_ylim([-0.1,2.1])
    pred_q.set_yticks([])
    pred_q.set_xticks([])
    pred_q.spines['top'].set_visible(False)
    pred_q.spines['right'].set_visible(False)
    
# Set NPE boundary
# A transparent, spine-less axes placed behind the other axes (zorder=-1). It
# only serves as a canvas for the rounded background boxes and the method label.
npe_bound_ax = fig.add_axes([0.08, 0.65, 0.79, 0.29], zorder=-1)
npe_bound_ax.patch.set_alpha(0.0)
npe_bound_ax.get_xaxis().set_visible(False)
npe_bound_ax.get_yaxis().set_visible(False)
for spine in npe_bound_ax.spines.values():
    # spine.set_linewidth(spine_width*2)
    # spine.set_color(method_colors[0])
    spine.set_visible(False)
npe_bound_ax.text(0.112, 0.07, 'NPE', fontsize=fontsize*1.5, ha='center', va='center', color=method_colors[0], weight='bold')
# Two rounded boxes are stacked here: an opaque white one first, then a
# translucent one (alpha=0.15) in the NPE color on top of it.
# NOTE(review): the white box presumably keeps the SNPE tint underneath from
# showing through the NPE region -- confirm.
fill_box = FancyBboxPatch((0, 0), 1, 1,
                        boxstyle="round,pad=-0.0040,rounding_size=0.1",
                        ec='white', fc='white', clip_on=False, lw=spine_width*2,
                        mutation_aspect=1,
                        transform=npe_bound_ax.transAxes)
npe_bound_ax.add_patch(fill_box)
fill_box = FancyBboxPatch((0, 0), 1, 1,
                        boxstyle="round,pad=-0.0040,rounding_size=0.1",
                        ec=method_colors[0], fc=method_colors[0], clip_on=False, lw=spine_width*2,
                        mutation_aspect=1, alpha=0.15,
                        transform=npe_bound_ax.transAxes)
npe_bound_ax.add_patch(fill_box)
    
# Set SNPE boundary
# Larger background axes placed behind the NPE one (zorder=-2).
snpe_bound_ax = fig.add_axes([0.07, 0.39, 0.81, 0.57], zorder=-2)
snpe_bound_ax.patch.set_alpha(0.0)
snpe_bound_ax.get_xaxis().set_visible(False)
snpe_bound_ax.get_yaxis().set_visible(False)
for spine in snpe_bound_ax.spines.values():
    # spine.set_linewidth(spine_width*2)
    # spine.set_color(method_colors[1])
    spine.set_visible(False)
# The 'SNPE' label is added through npe_bound_ax (not snpe_bound_ax); the
# negative y value places it below the NPE axes, inside the SNPE region.
npe_bound_ax.text(0.112, -0.8, 'SNPE', fontsize=fontsize*1.5, ha='center', va='center', color=method_colors[1], weight='bold')
fill_box = FancyBboxPatch((0, 0), 1, 1,
                        boxstyle="round,pad=-0.0040,rounding_size=0.1",
                        ec=method_colors[1], fc=method_colors[1], clip_on=False, lw=spine_width*2,
                        mutation_aspect=1, alpha=0.15,
                        transform=snpe_bound_ax.transAxes)
snpe_bound_ax.add_patch(fill_box)

# Add the arrows
# All arrows are ConnectionPatch artists added to the figure, with both end
# points given in axes-fraction coordinates of the axes they attach to. Values
# outside [0, 1] place an end point just outside the corresponding box.
# The first three arrows are each drawn twice at the same position: once filled
# in the NPE color and once as an unfilled outline in the SNPE color.
arrowstyle = 'simple,head_length=1.0,head_width=2.0,tail_width=0.7'
# Arrow from prior to prior sims
arrow = ConnectionPatch((1 + 0.02*6/2,0.5), (-0.02*6/4,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=prior_dist_ax, axesB=prior_samples_ax, color=method_colors[0], lw=3, arrowstyle=arrowstyle, fill=True)
fig.add_artist(arrow)
arrow = ConnectionPatch((1 + 0.02*6/2,0.5), (-0.02*6/4,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=prior_dist_ax, axesB=prior_samples_ax, color=method_colors[1], lw=3, arrowstyle=arrowstyle, fill=False)
fig.add_artist(arrow)
# Arrow from prior sims to loss box
arrow = ConnectionPatch((1 + 0.02*6/4,0.5), (-0.02,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=prior_samples_ax, axesB=loss_ax, color=method_colors[0], lw=3, arrowstyle=arrowstyle, fill=True)
fig.add_artist(arrow)
arrow = ConnectionPatch((1 + 0.02*6/4,0.5), (-0.02,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=prior_samples_ax, axesB=loss_ax, color=method_colors[1], lw=3, arrowstyle=arrowstyle, fill=False)
fig.add_artist(arrow)
# Arrow from loss box to predicted posterior
arrow = ConnectionPatch((1.02,0.5), (-0.02 * 6/5.5,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=loss_ax, axesB=pred_img_ax, color=method_colors[0], lw=3, arrowstyle=arrowstyle, fill=True)
fig.add_artist(arrow)
arrow = ConnectionPatch((1.02,0.5), (-0.02 * 6/5.5,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=loss_ax, axesB=pred_img_ax, color=method_colors[1], lw=3, arrowstyle=arrowstyle, fill=False)
fig.add_artist(arrow)
# The remaining large arrows are drawn once, in the SNPE color only.
# Arrow from predicted posterior to sequential proposal.
connectionstyle = 'angle3,angleA=-90,angleB=0'
arrow = ConnectionPatch((0.90,-0.05), (1.05,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=pred_img_ax, axesB=seq_prop_ax, color=method_colors[1], lw=2, arrowstyle=arrowstyle, connectionstyle=connectionstyle, fill=True)
fig.add_artist(arrow)
# Add arrow from sequential proposal to sequential samples.
arrow = ConnectionPatch((-0.04,0.5), (1.05,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=seq_prop_ax, axesB=seq_samples_ax, color=method_colors[1], lw=2, arrowstyle=arrowstyle, fill=True)
fig.add_artist(arrow)
# Arrow from sequential samples back to the loss box.
connectionstyle = 'angle3,angleA=0,angleB=90'
arrow = ConnectionPatch((-0.04,0.5), (0.12,-0.05), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=seq_samples_ax, axesB=loss_ax, color=method_colors[1], lw=3, arrowstyle=arrowstyle, connectionstyle=connectionstyle, fill=True)
fig.add_artist(arrow)
# The small black arrows inside the boxes use a thinner arrow style.
# Arrow from observed image to trained model.
arrowstyle = 'simple,head_length=0.5,head_width=0.8,tail_width=0.2'
arrow = ConnectionPatch((1.1,0.5), (-0.15,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=obs_img_ax, axesB=model_ax, color='k', lw=2, arrowstyle=arrowstyle)
fig.add_artist(arrow)
# Arrow from trained model to posterior.
# Both end points are expressed relative to model_ax: the arrow starts just
# right of the model box and ends at 1.55 box-widths from its left edge.
arrow = ConnectionPatch((1.15,0.5), (1.55,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=model_ax, axesB=model_ax, color='k', lw=2, arrowstyle=arrowstyle)
fig.add_artist(arrow)
# Arrow from posterior to proposal.
# snpe_dist_axes[0] is the black (right-hand) cartoon and snpe_dist_axes[1] the
# grey (left-hand) one, so this arrow points from right to left.
arrow = ConnectionPatch((-0.15,0.5), (1.05,0.5), coordsA='axes fraction', coordsB='axes fraction',
                        axesA=snpe_dist_axes[0], axesB=snpe_dist_axes[1], color='k', lw=2, arrowstyle=arrowstyle)
fig.add_artist(arrow)

# Labels above the two cartoons inside the sequential proposal box.
seq_prop_ax.text(0.25, 0.85, r'$p(\theta | \Omega_i)$', 
                 fontsize=fontsize*0.9, ha='center', va='center', color='black')
seq_prop_ax.text(0.75, 0.85, r'$q_{\phi}(\theta|x_\mathrm{obs}, \Omega_{0})$', 
                 fontsize=fontsize*0.9, ha='center', va='center', color='k')

plt.show()

## Figure 2 <a id='figure_2'></a>

In [None]:
# Load and separate the data.
from paltax.TrainConfigs import train_config_npe_base
config = train_config_npe_base.get_config()
n_gpus_train = 4 #Number of gpus that were used in parallel when training the models.

# The saved file holds a pickled dictionary of metric dictionaries, each keyed
# by model name.
save_metrics = np.load('figure_data/npe_comparison_metrics.npy', allow_pickle=True).item()
# One entry per saved epoch: the training step count at the end of epochs 1..500.
steps = np.arange(1, 501) * config.steps_per_epoch
# Only loss_ss_metrics is plotted in this cell; the other three are loaded but
# not used below.
rmse_metrics = save_metrics['rmse_metrics']
rho_metrics = save_metrics['rho_metrics']
loss_metrics = save_metrics['loss_metrics']
loss_ss_metrics = save_metrics['loss_ss_metrics']

# Plot the loss as a function of the number of images in the dataset.
fontsize = 15

# models_to_plot holds the keys into loss_ss_metrics; model_names holds the
# matching legend labels, in the same order.
models_to_plot = ['Fiducial' , '50k', '500k', '5M']
model_names = [r'Fiducial (Infinite Unique Images)', r'$5 \times 10^{4}$ Unique Images', 
               r'$5 \times 10^{5}$ Unique Images', r'$5 \times 10^{6}$ Unique Images']
line_styles = ['.-', '.-', '.-', '.-']
colors = ['#969696','#cbc9e2', '#9e9ac8', '#6a51a3']
# step_scaling is defined but not used in this cell.
step_scaling = [1.0] * len(models_to_plot)
# Number of unique images available to each model. The first (fiducial) entry
# is never read because the fiducial model is skipped in the right-hand panel.
total_images = [jnp.max(steps) * config.batch_size, 5e4, 5e5, 5e6]
    
fig, ax = plt.subplots(1, 2, figsize=(18,8), sharey=True, gridspec_kw={'hspace': 0.02,'wspace':0.04},dpi=100)

# Kernel size of the median filter used to smooth each loss curve.
filter_size = 5

for model_index, model_key in enumerate(models_to_plot):
    loss_array = jnp.array(jax.tree_util.tree_leaves(loss_ss_metrics[model_key]))
    loss_array = signal.medfilt(loss_array, kernel_size=filter_size)
    # Left panel: loss against images seen (steps * batch size * number of gpus).
    ax[0].plot(steps[:len(loss_array)] * config.batch_size * n_gpus_train, loss_array, line_styles[model_index], 
               c=colors[model_index], lw=5, ms = 10, alpha=0.9)
    # Right panel: loss against images seen divided by the dataset size, i.e.
    # passes through the dataset. The fiducial model (index 0) is not drawn.
    if model_index > 0:
        ax[1].plot(steps[:len(loss_array)] * config.batch_size * n_gpus_train / total_images[model_index], loss_array, '.-', 
                   c=colors[model_index], lw=5, ms = 10, alpha=0.9)

# The legend labels are assigned to the left panel's lines in plotting order.
ax[0].legend(model_names, fontsize=fontsize)
ax[0].set_ylabel(r'Mean Loss on $\Sigma_\mathrm{sub}$', fontsize=fontsize)
ax[0].set_xlabel('Images Seen', fontsize=fontsize)
ax[0].set_xscale('log')
ax[1].set_xlabel('Passes Through Dataset', fontsize=fontsize)
ax[1].set_xscale('log')
for axis in ax:
    axis.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)
    axis.set_ylim([0.42,0.65])
plt.show()

## Figure 3 <a id='figure_3'></a>

In [None]:
# Load and separate the data.
from paltax.TrainConfigs import train_config_npe_base
config = train_config_npe_base.get_config()
n_gpus_train = 4 #Number of gpus that were used in parallel when training the models.

# The saved file holds a pickled dictionary of metric dictionaries, each keyed
# by model name.
save_metrics = np.load('figure_data/npe_comparison_metrics.npy', allow_pickle=True).item()
# One entry per saved epoch: the training step count at the end of epochs 1..500.
steps = np.arange(1, 501) * config.steps_per_epoch
# Only loss_ss_metrics is plotted in this cell; the other three are loaded but
# not used below.
rmse_metrics = save_metrics['rmse_metrics']
rho_metrics = save_metrics['rho_metrics']
loss_metrics = save_metrics['loss_metrics']
loss_ss_metrics = save_metrics['loss_ss_metrics']

# Configurations of the alternative learning rate schedules being compared.
from paltax.TrainConfigs import train_config_exp_fast
from paltax.TrainConfigs import train_config_exp_slow
from paltax.TrainConfigs import train_config_linear
from paltax.TrainConfigs import train_config_linear_0p001
from paltax.TrainConfigs import train_config_constant

fontsize = 15
# models_to_plot holds the keys into loss_ss_metrics. model_names (legend
# labels) and learning_rate_schedule_list below follow the same order.
models_to_plot = ['Fiducial', 'Exponential Slow', 'Exponential Fast', 'Constant', 'Linear', 'Linear Small Base']
model_names = ['Fiducial (Warmup + Cosine Decay)', 'Exponential Decay: 0.99', 'Exponential Decay: 0.98', 
               r'Constant Learning Rate', r'Linear Decay, Learning Rate: $10^{-2}$', r'Linear Decay, Learning Rate: $10^{-3}$']
# Rebuild each schedule from its config. All but lr_lin_001 are given the
# fiducial config's learning rate as the second argument.
learning_rate_schedule = train.get_learning_rate_schedule(config, config.learning_rate)
lr_exp_fast = train.get_learning_rate_schedule(train_config_exp_fast.get_config(), config.learning_rate)
lr_exp_slow = train.get_learning_rate_schedule(train_config_exp_slow.get_config(), config.learning_rate)
lr_lin_01 = train.get_learning_rate_schedule(train_config_linear.get_config(), config.learning_rate)
# NOTE(review): this combines the train_config_linear config with the learning
# rate taken from train_config_linear_0p001 -- confirm the two configs differ
# only in their learning rate.
lr_lin_001 = train.get_learning_rate_schedule(train_config_linear.get_config(), train_config_linear_0p001.get_config().learning_rate)
lr_const = train.get_learning_rate_schedule(train_config_constant.get_config(), config.learning_rate)
learning_rate_schedule_list = [learning_rate_schedule, lr_exp_slow, lr_exp_fast, lr_const, lr_lin_01, lr_lin_001]
colors = ['#969696', '#fdd0a2', '#fdae6b', '#fd8d3c', '#e6550d', '#a63603']
line_styles = ['.-', '.-', '.-', '.-', '.-', '.-', '.-']

fig, ax = plt.subplots(1, 2, figsize=(18,8), sharey=True, gridspec_kw={'hspace': 0.02,'wspace':0.04},dpi=100)

# Kernel size of the median filter used to smooth each loss curve.
filter_size = 5

for model_index, model_key in enumerate(models_to_plot):
    loss_array = jnp.array(jax.tree_util.tree_leaves(loss_ss_metrics[model_key]))
    loss_array = signal.medfilt(loss_array, kernel_size=filter_size)
    # Left panel: loss against images seen (steps * batch size * number of gpus).
    ax[0].plot(steps[:len(loss_array)] * config.batch_size * n_gpus_train, loss_array, line_styles[model_index],
               c=colors[model_index], lw=3, ms = 10, alpha=0.9)
    # Right panel: loss against the schedule's value at each saved step,
    # multiplied by batch_size / 256.
    schedule = learning_rate_schedule_list[model_index]
    ax[1].plot(schedule(steps[:len(loss_array)])* config.batch_size / 256, loss_array, '.-', 
               c=colors[model_index], lw=3, ms = 10, alpha=0.9)

# The legend labels are assigned to the left panel's lines in plotting order.
ax[0].legend(model_names, fontsize=fontsize)
ax[0].set_ylabel(r'Mean Loss on $\Sigma_\mathrm{sub}$', fontsize=fontsize)
ax[0].set_xlabel('Images Seen', fontsize=fontsize)
ax[1].set_xlabel('Learning Rate', fontsize=fontsize)
# Flip the learning-rate axis so that larger learning rates are on the left.
ax[1].invert_xaxis()
ax[0].set_xscale('log')
for axis in ax:
    axis.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)
plt.show()

## Figure 4 <a id='figure_4'></a>

In [None]:
# Load and separate the data.
from paltax.TrainConfigs import train_config_npe_base
config = train_config_npe_base.get_config()
n_gpus_train = 4 #Number of gpus that were used in parallel when training the models.

# The saved file holds a pickled dictionary of metric dictionaries, each keyed
# by model name.
save_metrics = np.load('figure_data/npe_comparison_metrics.npy', allow_pickle=True).item()
# One entry per saved epoch: the training step count at the end of epochs 1..500.
steps = np.arange(1, 501) * config.steps_per_epoch
# Only loss_ss_metrics is plotted in this cell; the other three are loaded but
# not used below.
rmse_metrics = save_metrics['rmse_metrics']
rho_metrics = save_metrics['rho_metrics']
loss_metrics = save_metrics['loss_metrics']
loss_ss_metrics = save_metrics['loss_ss_metrics']

fontsize = 15

# models_to_plot holds the keys into loss_ss_metrics; model_names holds the
# matching legend labels, in the same order.
models_to_plot = ['Fiducial', 'Resnet 18 Very Small', 'Resnet 18 Small', 'Resnet 18', 'Resnet 34', 'Resnet-D 50']
model_names = ['Fiducial (Resnet 50)', 'Resnet 18 Very Small', 'Resnet 18 Small', 'Resnet 18', 'Resnet 34', 'Resnet-D 50']
# colors and line_styles hold seven entries; only the first six are used.
colors = ['#969696', '#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#084594']
line_styles = ['.-', '.-', '.-', '.-', '.-', '.-', '.-']
# Per-model factor applied to flops_50 on the right-hand x-axis.
# NOTE(review): presumably each model's cost relative to the fiducial
# Resnet 50, with flops_50 a reference FLOP count for that model -- confirm
# values and units.
step_scaling = [1.0, 0.0303, 0.0605, 0.242, 0.355, 1.0]
flops_50 = 5.24e8

fig, ax = plt.subplots(1, 2, figsize=(18,8), sharey=True, gridspec_kw={'hspace': 0.02,'wspace':0.04},dpi=100)

# Kernel size of the median filter used to smooth each loss curve.
filter_size = 5

for model_index, model_key in enumerate(models_to_plot):
    loss_array = jnp.array(jax.tree_util.tree_leaves(loss_ss_metrics[model_key]))
    loss_array = signal.medfilt(loss_array, kernel_size=filter_size)
    # Left panel: loss against images seen (steps * batch size * number of gpus).
    ax[0].plot(steps[:len(loss_array)] * config.batch_size * n_gpus_train, loss_array, line_styles[model_index],
               c=colors[model_index], lw=3, ms = 10, alpha=0.9)
    # Right panel: loss against steps * flops_50 * per-model scaling * number
    # of gpus. Unlike the left panel, batch_size does not enter this x-axis.
    ax[1].plot(steps[:len(loss_array)] * flops_50 * step_scaling[model_index]* n_gpus_train, loss_array, 
               line_styles[model_index], c=colors[model_index], lw=3, ms = 10, alpha=0.9)

# The legend labels are assigned to the left panel's lines in plotting order.
ax[0].legend(model_names, fontsize=fontsize)
ax[0].set_ylabel(r'Mean Loss on $\Sigma_\mathrm{sub}$', fontsize=fontsize)
ax[0].set_xlabel('Images Seen', fontsize=fontsize)
ax[1].set_xlabel('FLOPs', fontsize=fontsize)
for axis in ax:
    axis.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)
    axis.set_xscale('log')
plt.show()

## Figure 5 <a id='figure_5'></a>

In [None]:
# Load and process the loss data
# The saved file holds a pickled dictionary of metric dictionaries. The keys
# used here are 'Fiducial' and 'Image 0' ... 'Image 29'.
save_metrics = np.load('figure_data/snpe_comparison_metrics.npy', allow_pickle=True).item()
log_post_metrics = save_metrics['log_post_metrics']
log_post_sub_metrics = save_metrics['log_post_sub_metrics']

# Stack the per-image sequential results into arrays with one row per image.
log_post_sub_array = []
for mn in [f'Image {mod_num}' for mod_num in range(30)]:
    log_post_sub_array += [jnp.array(jax.tree_util.tree_leaves(log_post_sub_metrics[mn]))]
log_post_sub_array = jnp.array(log_post_sub_array)

# log_post_array is built the same way but is not used in the plot below.
log_post_array = []
for mn in [f'Image {mod_num}' for mod_num in range(30)]:
    log_post_array += [jnp.array(jax.tree_util.tree_leaves(log_post_metrics[mn]))]
log_post_array = jnp.array(log_post_array)

# Pull the relevant configuration files.
from paltax.TrainConfigs import train_config_npe_base
from paltax.TrainConfigs import train_config_snpe_base
config = train_config_npe_base.get_config()
snpe_config = train_config_snpe_base.get_config()
n_gpus_train = 4 #Number of gpus that were used in parallel when training the models.

# Step counts at the end of each saved epoch. Both use the NPE config's
# steps_per_epoch; the SNPE axis has 49 entries.
steps = np.arange(1,501) * config.steps_per_epoch
steps_snpe = np.arange(1,50) * config.steps_per_epoch

fontsize = 15

# colors: [NPE curve, SNPE curve, power-law fit, shaded band].
colors = ['#d95f02', '#1b9e77', '#252525', 'grey']

# fit_cut: first epoch index included in the power-law fit.
# snpe_cut: number of SNPE epochs kept for the plot.
# filter_size: kernel size of the median filter used to smooth the curves.
fit_cut = 50
snpe_cut = 40
filter_size = 5
    
fig, ax = plt.subplots(1, 1, figsize=(15,8), sharey=True, gridspec_kw={'hspace': 0.02,'wspace':0.04},dpi=100)

# Extract the npe / snpe results and make a linear fit to the power-law performance
# The fit regresses the unfiltered loss on log10(steps * batch_size) and is made
# before the median filter is applied. lin_fit[0] is the slope and lin_fit[1]
# the intercept.
npe_loss_array = jax.tree_util.tree_leaves(log_post_sub_metrics['Fiducial'])
lin_fit = stats.linregress(np.log10(steps[fit_cut:] * config.batch_size), npe_loss_array[fit_cut:])
npe_loss_array = signal.medfilt(npe_loss_array, kernel_size=filter_size)
# The sequential curve is the mean over the 30 images, then median filtered.
snpe_loss_array = signal.medfilt(np.mean(log_post_sub_array[:,:snpe_cut], axis=0), kernel_size=filter_size)

# Plot the npe results and the power law.
ax.plot(steps[:len(npe_loss_array)] * config.batch_size * n_gpus_train, npe_loss_array, '.-', 
        c=colors[0], lw=3, ms = 10, alpha=0.9)
# The fitted line is drawn from the first fitted point to the log10 abscissa at
# which it reaches the final (filtered) sequential loss.
start_step = np.log10(steps[fit_cut] * config.batch_size)
final_step = (snpe_loss_array[-1] - lin_fit[1])/lin_fit[0]
p_law_steps = np.linspace(start_step, final_step, 100)
p_law = lin_fit[0]*p_law_steps + lin_fit[1]
# The fit abscissa does not include n_gpus_train, so it is applied here to put
# the line on the same 'Images Seen' axis as the curves.
ax.plot(10**(p_law_steps)*n_gpus_train, p_law, '--', c=colors[2], lw=3)

# Plot the snpe results
ax.plot(steps_snpe[:snpe_cut] * config.batch_size * n_gpus_train, snpe_loss_array, '.-', c=colors[1], lw=3, ms = 10, 
        alpha=0.9)
# Shaded band spanning 0.95x to 1.05x of the final sequential loss, with a label.
ax.axhspan(snpe_loss_array[-1]*0.95, snpe_loss_array[-1]*1.05, alpha=0.2, color=colors[3])
ax.text(5.2e5,snpe_loss_array[-1]*1.018, r'Sequential Performance', color='white', fontsize=fontsize,
        bbox=dict(boxstyle="round",ec='#969696',fc='#969696'))

ax.set_ylabel(r'Mean Loss on $\Sigma_\mathrm{sub}$', fontsize=fontsize)
ax.set_xlabel('Images Seen', fontsize=fontsize)
ax.set_xscale('log')
ax.set_xlim([10**5*n_gpus_train,10**(11.2)*n_gpus_train])
ax.set_ylim([-0.3,0.17])
ax.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)
# Labels follow the order in which the three curves were plotted above.
ax.legend(['Fiducial', 'Fiducial Power-law', 'Sequential'], fontsize=fontsize, loc='upper right')
plt.show()

## Figure 6 <a id='figure_6'></a>

In [None]:
# Load and process the data
# The saved file holds a pickled dictionary of metric dictionaries. The keys
# used here are 'Fiducial' and 'Image 0' ... 'Image 29'.
save_metrics = np.load('figure_data/snpe_comparison_metrics.npy', allow_pickle=True).item()
mean_metrics = save_metrics['mean_metrics']
log_var_metrics = save_metrics['log_var_metrics']
truth_hier = np.load('figure_data/truth_hier.npy')

# Load up the encodings used to train the model to calculate the apt loss normalized to the 
# distribution used to train the NPE model.
from paltax.TrainConfigs import train_config_npe_base
from paltax.TrainConfigs import train_config_snpe_base
config = train_config_npe_base.get_config()
snpe_config = train_config_snpe_base.get_config()
n_gpus_train = 4 #Number of gpus that were used in parallel when training the models.
mu_prior = snpe_config.mu_prior
prec_prior = snpe_config.prec_prior
mu_prop_init = snpe_config.mu_prop_init
prec_prop_init = snpe_config.prec_prop_init
# Per-parameter standard deviation from the diagonal of the precision matrix
# (std = precision ** -0.5), then one encoding per parameter.
std_prop_init = jnp.power(jnp.diag(prec_prop_init), -0.5)
prop_encoding = jax.vmap(input_pipeline.encode_normal)(
        mu_prop_init, std_prop_init
)
# Step counts at the end of each saved epoch. Both use the NPE config's
# steps_per_epoch; the SNPE axis has 49 entries.
steps = np.arange(1,501) * config.steps_per_epoch
steps_snpe = np.arange(1,50) * config.steps_per_epoch

# Use the mean, variance, and truths to build the loss for the plot.
# The loss is mapped over the leading axis of the network outputs only; the
# truth is shared across that axis.
gaussian_loss_vmap = jax.jit(jax.vmap(train.gaussian_loss, in_axes=[0,None]))

mean_fid_array = jnp.array(jax.tree_util.tree_leaves(mean_metrics['Fiducial']))
log_var_fid_array = jnp.array(jax.tree_util.tree_leaves(log_var_metrics['Fiducial']))
log_post_fid_all_array = []
# Means and log variances are concatenated along the last axis, so for parameter
# pi the pair [pi, pi+11] selects its mean and its log variance.
# NOTE(review): this assumes the mean array has 11 entries on its last axis -- confirm.
output = jnp.concatenate([mean_fid_array, log_var_fid_array], axis=-1)
# The loss is evaluated for 32 images here, but only the first 30 are averaged
# when plotting below.
for mod_num in range(32):
    log_post_model = []
    for pi in range(11):
        log_post_model += [gaussian_loss_vmap(
            output[:,mod_num:mod_num+1,[pi,pi+11]], truth_hier[mod_num:mod_num+1,pi:pi+1]
        )]
    log_post_fid_all_array += [jnp.stack(log_post_model, axis=1)]
# Final indexing order: (image, leading axis of the outputs, parameter).
log_post_fid_all_array = jnp.stack(log_post_fid_all_array, axis=0)

# Repeat the same for the sequential loss.
# Here the loss additionally takes the proposal encoding and the prior mean and
# precision for the single parameter being evaluated.
snpe_c_loss_vmap = jax.jit(jax.vmap(train.snpe_c_loss, in_axes=[0,None,None,None,None]))
mean_seq_array = []
log_seq_array = []
log_post_all_array = []
for mn, mod_num in [(f'Image {mod_num}', mod_num) for mod_num in range(30)]:
    # Each sequential model has its own metrics entry; index 0 is taken on the
    # second axis of its saved outputs.
    mean_seq_array += [jnp.array(jax.tree_util.tree_leaves(mean_metrics[mn]))[:,0]]
    log_seq_array += [jnp.array(jax.tree_util.tree_leaves(log_var_metrics[mn]))[:,0]]
    output = jnp.concatenate([mean_seq_array[-1], log_seq_array[-1]], axis=-1)
    log_post_model = []
    for pi in range(11):
        log_post_model += [snpe_c_loss_vmap(
            output[:,jnp.newaxis,[pi,pi+11]], truth_hier[mod_num:mod_num+1,pi:pi+1], prop_encoding[pi:pi+1], 
            mu_prior[pi:pi+1], prec_prior[pi:pi+1,pi:pi+1]
        )]
    log_post_all_array += [jnp.stack(log_post_model, axis=1)]
  
mean_seq_array = jnp.stack(mean_seq_array, axis=1)
log_seq_array = jnp.stack(log_seq_array, axis=1)
log_post_all_array = jnp.stack(log_post_all_array, axis=0)

# Plot the loss comparison for desired parameters. By default plots theta_E, the parameter
# used in Figure 6.
fontsize = 20
colors = ['#d95f02', '#1b9e77', 'grey']
# snpe_cut: number of SNPE epochs kept for the plot.
# filter_size: kernel size of the median filter used to smooth the curves.
snpe_cut = 40
filter_size = 5
# Axis-label names for the 11 parameters, in the order used to index the loss arrays.
parameter_print_names = [r'$\theta_\mathrm{E}$', r'$\gamma_\mathrm{lens}$', r'$x_\mathrm{lens}$',
                         r'$y_\mathrm{lens}$', r'$e_1$', r'$e_2$', r'$\gamma_1$',
                         r'$\gamma_2$', r'$x_\mathrm{source}$',
                         r'$y_\mathrm{source}$', r'$\Sigma_\mathrm{sub}$']
# Add further indices to this list to produce one figure per parameter.
for param_i in [0]:

    fig, ax = plt.subplots(1, 1, figsize=(9,8), sharey=True, gridspec_kw={'hspace': 0.02,'wspace':0.04},dpi=100)

    # Average the fiducial loss over the first 30 images, then median filter it.
    npe_loss_array = np.mean(log_post_fid_all_array[:30,:,param_i], axis=0)
    npe_loss_array = signal.medfilt(npe_loss_array, kernel_size=filter_size)
    ax.plot(steps[:len(npe_loss_array)] * config.batch_size * n_gpus_train, npe_loss_array, '.-', c=colors[0], 
             lw=3, ms = 10, alpha=0.9)
    # Average the sequential loss over the 30 models, keeping the first snpe_cut entries.
    snpe_loss_array = signal.medfilt(np.mean(log_post_all_array[:,:snpe_cut,param_i], axis=0), kernel_size=filter_size)
    ax.plot(steps_snpe[:snpe_cut] * config.batch_size * n_gpus_train, snpe_loss_array, '.-', c=colors[1], lw=3, ms = 10, 
            alpha=0.9)

    ax.set_ylabel(r'Mean Loss on ' + parameter_print_names[param_i], fontsize=fontsize)
    ax.set_xlabel('Images Seen', fontsize=fontsize)
    ax.set_xscale('log')
    ax.set_xlim([10**5*n_gpus_train,10**(7.9)*n_gpus_train])
    # ax.set_ylim([-0.3,0.17])
    ax.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)
    ax.legend(['Fiducial', 'Sequential'], fontsize=fontsize, loc='upper right')
    plt.show()

## Figure 7 <a id='figure_7'></a>

In [None]:
# Load the data for this figure. The .npy files are read from the figure_data/
# folder, relative to where this notebook is run.
# Index into image_hier of the image shown as the mock observation.
mock_index = 5
image_hier = np.load('figure_data/image_hier.npy')
# Entries 0, 1 and 2 of this array are shown in the three image-grid panels below.
image_draws_list = np.load('figure_data/image_draws_list_4.npy')
def create_grid(images, n_row_col):
    """Tile the leading images of a stack into one square mosaic.

    Args:
        images: Array whose first axis indexes individual images. It must hold at least
            n_row_col ** 2 images; any further images are ignored.
        n_row_col: Number of images along each side of the mosaic.

    Returns:
        A single array in which the first n_row_col ** 2 images are laid out row by row
        (image 0 in the top-left corner).
    """
    n_tiles = n_row_col ** 2
    tile_shape = images.shape[1:]
    # Arrange the selected images on an (n_row_col, n_row_col) grid of tiles.
    tiles = images[:n_tiles].reshape((n_row_col, n_row_col, *tile_shape))
    mosaic_rows = []
    for row_of_tiles in tiles:
        # Join the tiles of one grid row side by side.
        mosaic_rows.append(np.hstack(row_of_tiles))
    # Stack the finished rows on top of each other.
    return np.vstack(mosaic_rows)
    

# 2x2 tableau: the mock observation (top left) plus one mosaic of drawn images per panel.
fig, axes = plt.subplots(2, 2, figsize=(10,10), gridspec_kw={'hspace': 0.025,'wspace':0.02},dpi=100)
box_colors = ['#bdd7e7', '#6baed6', '#2171b5']
n_row_col = 4
spine_width = 8
fontsize = 20

# The normalization of the mock image was not saved, so rescale it by the mean per-image
# standard deviation of the first set of draws. This lets one norm be shared by all panels.
per_image_std = np.std(image_draws_list[0].reshape(image_draws_list[0].shape[0],-1), axis=1)
image_mock = image_hier[mock_index] * np.mean(per_image_std)
im_norm = simple_norm(image_mock,stretch='asinh')

# Caption styling shared by all four panels.
caption_style = dict(fontsize=fontsize, ha='center', color='white',
                     bbox=dict(boxstyle="round",ec='grey',fc='grey', alpha=0.8))

# Top-left panel: the mock observation itself.
axes[0,0].imshow(image_mock, norm=im_norm, cmap='plasma')
axes[0,0].text(64, 14, 'Mock Observation', **caption_style)

# Remaining panels (row-major order): mosaic k shows image_draws_list[k] and gets a colored frame.
# NOTE(review): the third caption reads Omega_3 although it labels image_draws_list[2] --
# confirm whether Omega_2 was intended.
panel_captions = [r'$p(\theta|\Omega_0)$', r'$p(\theta|\Omega_1)$', r'$p(\theta|\Omega_3)$']
for i, caption in enumerate(panel_captions):
    panel = axes.flatten()[i+1]
    panel.imshow(create_grid(image_draws_list[i], n_row_col), norm=im_norm, cmap='plasma')
    # The mosaic is n_row_col tiles wide, so the caption position scales with n_row_col.
    panel.text(64 * n_row_col, 14 * n_row_col, caption, **caption_style)
    for spine in panel.spines.values():
        spine.set_edgecolor(box_colors[i])
        spine.set_linewidth(spine_width)

# Hide ticks and tick labels on every panel.
for ax in axes.flatten():
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.show()

## Figure 8 <a id='figure_8'></a>

In [None]:
corner_param_print=[r'$\Sigma_\mathrm{sub,pop}$' + r' $[\mathrm{kpc}^{-2}]$',
                    r'$\Sigma_\mathrm{sub,pop,\sigma}$' + r' $[\mathrm{kpc}^{-2}]$']
fontsize = 20
figsize = (12,12)  # (width, height) in inches

truth_color = 'k'
truths = np.array([1.5e-3, 2e-4])
burnin = 1000
n_lenses_list = [10, 30]


def plot_loose_corner(model_name, n_lenses, chain_index, color):
    """Draw the corner plot of the 'loose' population chains for one model and lens count.

    Reads the module-level settings corner_param_print, fontsize, figsize, truths,
    truth_color and burnin.

    Args:
        model_name: Model tag that appears in the chain file name.
        n_lenses: Number of lenses that appears in the chain file name.
        chain_index: Index along the first axis of the saved chain array selecting which
            set of chains to plot.
        color: Matplotlib color used for the histograms and contours.

    Returns:
        The figure created by corner.corner, after resizing and label formatting.
    """
    chains_path = f'figure_data/m_{model_name}_{n_lenses}_lenses_loose_chains.npy'
    print(chains_path)

    # Discard the first `burnin` entries of the third axis and flatten to (n_samples, 2).
    chain = np.load(chains_path)[chain_index,:,burnin:].reshape((-1,2))
    # Renormalize
    # NOTE(review): presumably this undoes a normalization applied when the chains were
    # sampled (scale 1.1e-3, offset 2e-3 on the first parameter) -- confirm.
    chain[:,0] *= 1.1e-3
    chain[:,0] += 2e-3
    chain[:,1] *= 1.1e-3

    hist_kwargs = {'density':True,'color':color,'lw':3}
    plot_range = [(0.0, 4e-3),(0,0.005)]

    f = corner.corner(chain, range=plot_range, labels=corner_param_print,bins=20, show_titles=True, plot_datapoints=False,
                      label_kwargs=dict(fontsize=fontsize),levels=[0.68,0.95],
                      color=color,fill_contours=True,hist_kwargs=hist_kwargs,title_fmt='.4f',truths=truths,
                      truth_color=truth_color,max_n_ticks=3)

    # corner builds its own figure, so size and label fonts are adjusted after the fact.
    # figsize is (width, height); the original code passed the two entries the other way round.
    f.set_figwidth(figsize[0])
    f.set_figheight(figsize[1])
    n_params = len(truths)
    for i in range(n_params):
        for j in range(n_params):
            # f.axes is indexed row-major over the n_params x n_params grid of panels.
            corn_axis = f.axes[i*n_params+j]
            corn_axis.title.set_fontsize(fontsize)
            if j == 0 and i > 0:
                # Left column (except the top histogram) carries the y labels.
                corn_axis.tick_params(axis='y', labelsize=fontsize,labelrotation=75)
                corn_axis.set_ylabel(corner_param_print[i],fontsize=fontsize * 1.5, labelpad=10.0)
            if i == n_params-1:
                # Bottom row carries the x labels.
                corn_axis.tick_params(axis='x', labelsize=fontsize,labelrotation=15)
                corn_axis.set_xlabel(corner_param_print[j],fontsize=fontsize * 1.5, labelpad=10.0)
    return f


# Plot the corner plots for the fiducial model (chain set 3 of the saved array).
model_name = 'fiducial_1950000'
colors = ['#d95f02']
for n_lenses in n_lenses_list:
    f = plot_loose_corner(model_name, n_lenses, 3, colors[0])
    plt.show()

# Plot the corner plots for the sequential model (chain set 0 of the saved array).
colors = ['#1b9e77']
model_name = 'sequential_117000'
for n_lenses in n_lenses_list:
    f = plot_loose_corner(model_name, n_lenses, 0, colors[0])
    plt.show()

## Figure 9 <a id='figure_9'></a>

In [None]:
# Lens counts for which chain files are read, for the fiducial (npe) and sequential (snpe)
# models, and the number of chain sets expected along the first axis of each file.
npe_n_lenses_list = [5,10,15,20,25,30,40,50,60,70,80,90,100,110,120]
snpe_n_lenses_list = [5,10,15,20,25,30]
n_permutation = 10

def extract_pop_std(n_lenses_list, n_permutation, model_type, fit_type, burnin):
    """Measure the spread of the first chain parameter for every lens count.

    For each entry of n_lenses_list the file
    figure_data/m_{model_type}_{n_lenses}_lenses_{fit_type}_chains.npy is loaded, the first
    `burnin` entries of its third axis are dropped, and the standard deviation of the
    renormalized first parameter is taken separately for each entry of the first axis.

    Args:
        n_lenses_list: Lens counts to load.
        n_permutation: Size of the first axis of each saved chain array.
        model_type: Model tag used in the file name.
        fit_type: Fit tag used in the file name.
        burnin: Number of leading entries of the third axis to discard.

    Returns:
        Array of shape (len(n_lenses_list), n_permutation) holding the standard deviations.
    """
    spread = np.zeros((len(n_lenses_list), n_permutation))

    for row, lens_count in enumerate(n_lenses_list):
        samples = np.load(f'figure_data/m_{model_type}_{lens_count}_lenses_{fit_type}_chains.npy')
        samples = samples[:,:,burnin:]
        # Merge the two middle axes so each first-axis entry becomes a flat list of 2-vectors.
        samples = samples.reshape((samples.shape[0],-1,2))

        # Renormalize both parameters (scale 1.1e-3; the first one is also shifted by 2e-3).
        samples[...,0] *= 1.1e-3
        samples[...,0] += 2e-3
        samples[...,1] *= 1.1e-3

        spread[row] = np.std(samples[...,0],axis=1)

    return spread

# The 'tight' fits (the fits where the scatter is fixed) feed the power-law line fits;
# the 'loose' fits are loaded with the same settings.
burnin = 1000
fit_npe_pop_std = extract_pop_std(npe_n_lenses_list, n_permutation, 'fiducial_1950000', 'tight', burnin)
fit_snpe_pop_std = extract_pop_std(snpe_n_lenses_list, n_permutation, 'sequential_117000', 'tight', burnin)
npe_pop_std = extract_pop_std(npe_n_lenses_list, n_permutation, 'fiducial_1950000', 'loose', burnin)
snpe_pop_std = extract_pop_std(snpe_n_lenses_list, n_permutation, 'sequential_117000', 'loose', burnin)

# One straight-line fit in log-log space per column; each row stores (slope, intercept).
npe_lin_fits = np.zeros((npe_pop_std.shape[1], 2))
snpe_lin_fits = np.zeros((snpe_pop_std.shape[1], 2))
log_npe_n_lenses = np.log10(npe_n_lenses_list)
log_snpe_n_lenses = np.log10(snpe_n_lenses_list)

for column in range(npe_pop_std.shape[1]):
    # The smallest lens counts are left out of the fit (first three for the fiducial model,
    # first one for the sequential model).
    fiducial_fit = stats.linregress(log_npe_n_lenses[3:], np.log10(fit_npe_pop_std[3:,column]))
    sequential_fit = stats.linregress(log_snpe_n_lenses[1:], np.log10(fit_snpe_pop_std[1:,column]))
    npe_lin_fits[column] = fiducial_fit.slope, fiducial_fit.intercept
    snpe_lin_fits[column] = sequential_fit.slope, sequential_fit.intercept

# Plot the constraint as a function of number of lenses.
fontsize=15
fig, ax = plt.subplots(1, 1, figsize=(10,7), dpi=150)

colors = ['#d95f02', '#1b9e77', 'grey']

# Median uncertainty per lens count, with asymmetric error bars reaching down to the 16%
# quantile and up to the 84% quantile of the values in each row.
npe_median = np.median(npe_pop_std, axis=1)
snpe_median = np.median(snpe_pop_std, axis=1)
npe_quantiles = np.stack(
    [npe_median - np.quantile(npe_pop_std, 0.16, axis=1),
     np.quantile(npe_pop_std, 0.84, axis=1) - npe_median],
    axis=0)
snpe_quantiles = np.stack(
    [snpe_median - np.quantile(snpe_pop_std, 0.16, axis=1),
     np.quantile(snpe_pop_std, 0.84, axis=1) - snpe_median],
    axis=0)
ax.errorbar(npe_n_lenses_list, npe_median, yerr=npe_quantiles,
            fmt='.', c=colors[0], ms = 15, label='Fiducial')
ax.errorbar(snpe_n_lenses_list, snpe_median, yerr=snpe_quantiles,
            fmt='.', c=colors[1], ms = 15, label='Sequential')

# large_n_alpha = 0.3
# ax.errorbar(npe_n_lenses_list, np.median(fit_npe_pop_std, axis=1),
#             fmt='.', c=colors[0], ms = 15, alpha=large_n_alpha)
# ax.errorbar(snpe_n_lenses_list, np.median(fit_snpe_pop_std, axis=1),
#             fmt='.', c=colors[1], ms = 15, alpha=large_n_alpha)

# Plot the large N power law. Column 0 of the fit arrays is the slope, column 1 the intercept.
mean_npe_lin_fits = np.mean(npe_lin_fits, axis=0)
min_npe_lin_fits = np.quantile(npe_lin_fits, 0.16, axis=0)
max_npe_lin_fits = np.quantile(npe_lin_fits, 0.84, axis=0)
mean_snpe_lin_fits = np.mean(snpe_lin_fits, axis=0)
min_snpe_lin_fits = np.quantile(snpe_lin_fits, 0.16, axis=0)
max_snpe_lin_fits = np.quantile(snpe_lin_fits, 0.84, axis=0)

# The lines are evaluated in log10 space and converted back to linear values for plotting.
# The labels are raw strings so that '\m' in '\mathrm' is not parsed as an (invalid) escape.
npe_p_law_steps = np.linspace(np.log10(3), np.log10(280), 100)
npe_p_law = mean_npe_lin_fits[0]*npe_p_law_steps + mean_npe_lin_fits[1]
ax.plot(10**(npe_p_law_steps), 10**npe_p_law, '-.', c=colors[0], lw=3, label=r'Fiducial: Large $N_\mathrm{lens}$ Scaling')

snpe_p_law_steps = np.linspace(np.log10(3), np.log10(55), 100)
snpe_p_law = mean_snpe_lin_fits[0]*snpe_p_law_steps + mean_snpe_lin_fits[1]
ax.plot(10**(snpe_p_law_steps), 10**snpe_p_law, '--', c=colors[1], lw=3, label=r'Sequential: Large $N_\mathrm{lens}$ Scaling')

# Describe the power law lines
# NOTE(review): the exponent in these annotations is hard-coded rather than taken from the
# fitted slopes -- confirm it matches mean_npe_lin_fits[0] / mean_snpe_lin_fits[0].
ax.text(10**(npe_p_law_steps[5]), 10**(npe_p_law[5]), r'$\propto N^{-0.5}$',fontsize=fontsize, rotation=-25,
        horizontalalignment='center', verticalalignment='bottom', rotation_mode='anchor')
ax.text(10**(snpe_p_law_steps[8]), 10**(snpe_p_law[8]), r'$\propto N^{-0.5}$',fontsize=fontsize, rotation=-25,
        horizontalalignment='center', verticalalignment='bottom', rotation_mode='anchor')

# Label the gap between the two power laws and the regions of interest.
# Horizontal band at 20% of the 1.4e-3 to 1.6e-3 range.
percent_measurement = 0.2
ax.axhspan(1.4e-3*percent_measurement,1.6e-3*percent_measurement, color=colors[2], alpha=0.3)
ax.text(3.0,1.5e-3*percent_measurement*0.965, r'$5\sigma$ Detection', color='white', fontsize=fontsize,
        bbox=dict(boxstyle="round",ec='#969696',fc='#969696'))

# Invert both mean power laws to find the lens count at which each reaches an uncertainty of
# c_vert, then draw a double-headed arrow between the two and print their ratio.
c_vert = 1.5e-4
x_snpe = 10**((np.log10(c_vert)-mean_snpe_lin_fits[1])/mean_snpe_lin_fits[0])
x_npe = 10**((np.log10(c_vert)-mean_npe_lin_fits[1])/mean_npe_lin_fits[0])
ax.annotate(text='', xy=(x_npe,c_vert), xytext=(x_snpe,c_vert), arrowprops=dict(arrowstyle='<->', lw=2))
ax.text(95,c_vert + 0.1e-4,f'~{x_npe/x_snpe:.0f}x',fontsize=fontsize)

# Horizontal band at 10% of the 1.4e-3 to 1.6e-3 range.
percent_measurement = 0.1
ax.axhspan(1.4e-3*percent_measurement,1.6e-3*percent_measurement, color=colors[2], alpha=0.3)
ax.text(3.0,1.455e-4, r'$10\%$ Measurement', color='white', fontsize=fontsize,
        bbox=dict(boxstyle="round",ec='#969696',fc='#969696'))

# Format the plot.
ax.set_ylabel('Uncertainty on ' + r'$\Sigma_\mathrm{sub,pop}$',fontsize=fontsize)
ax.set_xlabel('Number of Analyzed Lenses', fontsize=fontsize)
ax.set_yscale('log')
ax.set_xscale('log')
ax.tick_params(axis='both', which='both', labelsize=fontsize, length=fontsize/2, width=1.5)

# Reorder the legend entries: indices 2 and 3 of the returned handles are listed first.
handles, labels = ax.get_legend_handles_labels()
order = [2,3,0,1]
ax.legend([handles[idx] for idx in order],[labels[idx] for idx in order], fontsize=fontsize, loc='upper right')
ax.set_ylim([1.2e-4,3e-3])
ax.set_xlim([2.5,340])

plt.show()