In [None]:
import os
import json
import gzip
import glob
import itertools
import scipy.stats
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from collections import OrderedDict
from itertools import product
from functools import partial
from multiprocessing import Pool
from statsmodels.graphics.tsaplots import plot_acf
from multiresticodm.utils import *
from multiresticodm.sim_models.production_constrained import *
from multiresticodm.outputs import Outputs

In [None]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
%load_ext autoreload
%autoreload 2

# Import samples

In [None]:
# Folder (under the Cambridge LSOA-to-MSOA outputs directory) of the experiment inspected here
foldername = 'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_permuted_cells_10%_20_03_2023_10_29_54'

# Settings handed to Outputs; only the table total is specified
settings = dict(table_total=33704)

# Build the Outputs object for that folder, requesting the four sample sets named below
outputs = Outputs(
    os.path.join('../data/outputs/cambridge_work_commuter_lsoas_to_msoas/', foldername),
    settings,
    sample_names=[
        'intensity',
        'theta',
        'log_destination_attraction',
        'table',
    ],
)

In [None]:
# #  for sample_name in self.experiment.results.keys():
# sample_name = 'intensity'
# dims = np.shape(outputs.experiment.results['table'])[1:]
# print(sample_name,dims)
# cells = sorted([tuple(cell) for cell in product(*[range(dim) for dim in dims])])
# for cell in cells:
#     data = outputs.experiment.results[sample_name][(...,*cell)]
#     plt.figure(figsize=(10,10))
#     _ = plt.hist(data)
#     plt.show()

In [None]:
# Read important metadata (true latent values)
# NOTE(review): `sim` is not created by any cell visible in this notebook; it must
# already exist in the kernel, so this cell fails on a fresh Restart & Run All.
I = sim.I
J = sim.dims[1]
true_log_destination_attraction = sim.log_true_destination_attraction
log_destination_attraction_data = sim.log_destination_attraction

# Number of leading MCMC samples dropped by the `[burnin:]` slices used throughout
burnin = 5000
# Strides used when slicing table / column-sum samples (`[burnin::table_steps]`)
table_steps = 1#metadata['mcmc']['contingency_table']['table_steps']
colsum_steps = 1#metadata['mcmc']['contingency_table']['column_sum_steps']

# File format passed to every plt.savefig call below
figure_format = 'eps' #'eps','png'

# Preprocessing

In [None]:
# Intensities implied by the ground-truth latent values, when the simulation knows them.
# NOTE(review): `true_log_flows` is only assigned inside this branch, yet later cells
# use it unconditionally.
if sim.ground_truth_known:
    # beta is multiplied by sim.bmax before it is handed to sim.log_intensity
    theta_true = [sim.alpha_true,sim.beta_true*sim.bmax]
    true_log_flows = sim.log_intensity(sim.log_true_destination_attraction,theta_true)
    true_flows = np.exp(true_log_flows)
    print('Intensities based on ground truth for x')
    print(pd.DataFrame(true_flows))

# Exponential of the mean sampled log-intensity (taken over ALL samples; no burnin slice here)
lambda_sample_mean = np.exp(np.mean(log_lambda_samples,axis=0))

# Post-burnin means of theta and of the log destination attraction
theta_mean_scaled = np.mean(theta_samples[burnin:,:],axis=0)
theta_mean_scaled[1] = theta_mean_scaled[1]*sim.bmax
log_destination_attraction_mean = np.mean(log_destination_attraction_samples[burnin:,:],axis=0)

# Intensities evaluated at those means
expected_log_flows = sim.log_intensity(log_destination_attraction_mean,theta_mean_scaled)
expected_flows = np.exp(expected_log_flows)
print('Intensities based on mean x, theta')
print(pd.DataFrame(expected_flows))

# X, Theta Sampling
## Trace plots

In [None]:
# Post-burnin mean +/- standard deviation of each theta component
print(f"{np.mean(theta_samples[burnin:],axis=0)} +/- {np.std(theta_samples[burnin:],axis=0)}")

In [None]:
# Trace plots of the post-burnin theta samples: column 0 (alpha) left, column 1 (beta) right.
# Both panels are drawn by the same loop so that their legends agree (the second
# panel previously had no 'samples' entry).
# NOTE(review): `experiment_type` and `dirpath` are only assigned in a later cell
# ("Import samples from different chains"); on a fresh kernel this cell raises NameError.
fig,axs = plt.subplots(1,2,figsize=(20,10))

for ax,column,symbol,true_attr in zip(axs,(0,1),(r'$\alpha$',r'$\beta$'),('alpha_true','beta_true')):
    trace = theta_samples[burnin:, column]
    ax.plot(trace,label='samples')
    ax.set_ylabel(symbol,fontsize=18,rotation=0,labelpad=7)
    ax.set_xlabel('MCMC samples',fontsize=18,labelpad=7)
    # Reference line at the true value, when the simulation exposes one
    if hasattr(sim,true_attr):
        ax.axhline(y=getattr(sim,true_attr), color='black', linestyle='-',label='true')
    # Post-burnin sample mean
    ax.axhline(y=np.mean(trace),color='lime',label=r'$\mu$')
    ax.set_ylim(0,2.0)
    ax.legend()

fig.suptitle(fr'{experiment_type}',fontsize=20)
plt.savefig(os.path.join(dirpath,f'figures/parameter_mixing.{figure_format}'),format=figure_format)
plt.show()

In [None]:
# Post-burnin mean +/- standard deviation of each log destination attraction component
print(f"{np.mean(log_destination_attraction_samples[burnin:],axis=0)} +/- {np.std(log_destination_attraction_samples[burnin:],axis=0)}")

In [None]:
# Show sim.log_destination_attraction (drawn as the 'data' line in the trace plots below)
print(sim.log_destination_attraction)

In [None]:
# Difference between the post-burnin sample mean and sim.log_destination_attraction, per component
np.mean(log_destination_attraction_samples[burnin:],axis=0)-sim.log_destination_attraction

In [None]:
# Post-burnin sample mean of the log destination attraction
np.mean(log_destination_attraction_samples[burnin:],axis=0)

In [None]:
# Alias of sim.log_destination_attraction assigned in the metadata cell above
log_destination_attraction_data

In [None]:
# Evaluates 2/sim.gamma.
# NOTE(review): presumably meant to be compared with sim.noise_var shown in the next cell -- confirm
1/(sim.gamma/2)

In [None]:
# Noise variance held by `sim`; its square root sets the dashed bounds in the trace plot below
sim.noise_var

In [None]:
# One trace plot per log destination attraction component, after burn-in.
# squeeze=False keeps `axs` indexable even when J == 1 (plt.subplots would otherwise
# return a bare Axes object and `axs[j]` would fail).
fig,axs = plt.subplots(1,J,figsize=(20,10),squeeze=False)
axs = axs[0]

# Relative noise level, used to draw a +/- band around the data on the log scale
relative_noise = np.sqrt(sim.noise_var)/np.log(sim.dims[1])
relative_noise_percentage = round(100*relative_noise)
upper_bound = log_destination_attraction_data + np.log((1.0+relative_noise_percentage/100))
lower_bound = log_destination_attraction_data - np.log((1.0+relative_noise_percentage/100))

# The post-burnin mean does not depend on j, so compute it once outside the loop
log_destination_attraction_posterior_mean = np.mean(log_destination_attraction_samples[burnin:],axis=0)

for j in range(J):
    axs[j].plot(log_destination_attraction_samples[burnin:, j])
    # Triple braces emit literal {} around the index so that j >= 10 is typeset as one subscript
    axs[j].set_ylabel(fr'$x_{{{j}}}$',fontsize=18,rotation=0,labelpad=7)
    axs[j].set_xlabel('MCMC samples',fontsize=18,labelpad=7)
    # Raw string: '\m' is not a valid escape sequence in a normal string literal
    axs[j].axhline(y=log_destination_attraction_posterior_mean[j], color='lime', linestyle='-',label=r'$\mu$')
    axs[j].axhline(y=log_destination_attraction_data[j], color='r', linestyle='-',label='data')
    axs[j].axhline(y=sim.log_true_destination_attraction[j], color='black', linestyle='-',label='generated')
    axs[j].axhline(y=upper_bound[j], color='r', linestyle='--',label=f'data + {relative_noise_percentage}%')
    axs[j].axhline(y=lower_bound[j], color='r', linestyle='--',label=f'data - {relative_noise_percentage}%')
    axs[j].legend()

# NOTE(review): `experiment_type` and `dirpath` are only assigned in a later cell
fig.suptitle(fr'{experiment_type}',fontsize=20)
plt.savefig(os.path.join(dirpath,f'figures/log_destination_attraction_mixing.{figure_format}'),format=figure_format)
plt.show()

In [None]:
# Residuals (data minus post-burnin sample mean) plotted against the sample mean
posterior_mean_log_w = np.mean(log_destination_attraction_samples[burnin:],axis=0)

plt.figure(figsize=(7,5))
plt.scatter(
    sim.log_destination_attraction-posterior_mean_log_w,
    posterior_mean_log_w,
)
plt.ylabel(r'$E[\log(W_j)|Y_j]$',fontsize=16)
plt.xlabel(r'$\log(Y_j)-E[\log(W_j)|Y_j]$',fontsize=16)#,rotation=0,labelpad=90)
plt.tight_layout()
plt.savefig(
    os.path.join(dirpath,f'figures/log_destination_attraction_residuals.{figure_format}'),
    format=figure_format,
)
plt.show()

In [None]:
# Post-burnin sample mean plotted against the data it is conditioned on
plt.figure(figsize=(7,5))
plt.scatter(
    sim.log_destination_attraction,
    np.mean(log_destination_attraction_samples[burnin:],axis=0),
)
plt.ylabel(r'$E[\log(W_j)|Y_j]$',fontsize=16)
plt.xlabel(r'$\log(Y_j)$',fontsize=16)#,rotation=0,labelpad=40)
plt.tight_layout()
plt.savefig(
    os.path.join(dirpath,f'figures/log_destination_attraction_predictions.{figure_format}'),
    format=figure_format,
)
plt.show()

## 2D theta sampled space

In [None]:
# Path traced by the post-burnin theta samples in the (alpha, beta) plane
fig = plt.figure(figsize=(10,10))

alpha_trace = theta_samples[burnin:, 0]
beta_trace = theta_samples[burnin:, 1]

plt.plot(alpha_trace,beta_trace)#,label='samples',marker='x')
plt.xlabel(r'$\alpha$',fontsize=18,labelpad=7)
plt.ylabel(r'$\beta$',fontsize=18,rotation=0,labelpad=7)
plt.xlim(0,2)
plt.ylim(0,2)

# Mark the ground truth only when the simulation exposes both coordinates
if hasattr(sim,'alpha_true') and hasattr(sim,'beta_true'):
    plt.plot(sim.alpha_true,sim.beta_true,marker='o',color='r',label='ground truth')
# Mark the post-burnin sample mean
plt.plot(np.mean(alpha_trace),np.mean(beta_trace),marker='o',color='y',label='mean')

plt.legend()
plt.savefig(
    os.path.join(dirpath,f'figures/parameter_2d_exploration.{figure_format}'),
    format=figure_format,
)
plt.show()

## Histograms

In [None]:
# Histograms of the post-burnin theta samples: column 0 (alpha) left, column 1 (beta) right.
# The parameter value is on the x-axis (the reference lines are vertical, at parameter
# values) and the bin count on the y-axis, so the axis labels are set accordingly
# rather than reusing the trace-plot labels.
fig,axs = plt.subplots(1,2,figsize=(20,10))

for ax,column,symbol,true_attr in zip(axs,(0,1),(r'$\alpha$',r'$\beta$'),('alpha_true','beta_true')):
    values = theta_samples[burnin:, column]
    ax.hist(values,bins=100)
    ax.set_xlabel(symbol,fontsize=18,labelpad=7)
    ax.set_ylabel('Frequency',fontsize=18,labelpad=7)
    # Reference line at the true value, when the simulation exposes one
    if hasattr(sim,true_attr):
        ax.axvline(x=getattr(sim,true_attr), color='r', linestyle='-',label='generated')
    # Raw string: '\m' is not a valid escape sequence in a normal (f-)string literal
    ax.axvline(x=np.mean(values), color='lime', linestyle='-',label=r'posterior $\mu$')
    ax.set_xlim(0,2.0)
    ax.legend()

# NOTE(review): `experiment_type` and `dirpath` are only assigned in a later cell
fig.suptitle(fr'{experiment_type}',fontsize=20)
plt.savefig(os.path.join(dirpath,f'figures/parameter_histogram.{figure_format}'),format=figure_format)
plt.show()

## Autocorrelation Plots

In [None]:
# Autocorrelation of the post-burnin theta samples up to lag 200,
# with a red reference line drawn at 0.2 on each panel
fig,axs = plt.subplots(1,2,figsize=(10,5))
alpha_ax = axs[0]
beta_ax = axs[1]

# alpha (column 0)
plot_acf(theta_samples[burnin:, 0],lags=200,ax=alpha_ax)
alpha_ax.axhline(y=0.2,color='red')
alpha_ax.set_ylim(0,1.1)
alpha_ax.set_xlabel('Lags')
alpha_ax.set_ylabel(r'$\alpha$',rotation=0,labelpad=10,fontsize=14)

# beta (column 1)
plot_acf(theta_samples[burnin:, 1],lags=200,ax=beta_ax)
beta_ax.axhline(y=0.2,color='red')
beta_ax.set_ylim(0,1.1)
beta_ax.set_xlabel('Lags')
beta_ax.set_ylabel(r'$\beta$',rotation=0,labelpad=10,fontsize=14)

plt.savefig(
    os.path.join(dirpath,f'figures/parameter_acf.{figure_format}'),
    format=figure_format,
)
plt.show()

# Table sample inspection

In [None]:
# Display ct.table, which the cells below treat as the true table when matching samples
ct.table

In [None]:
# Count how often each distinct table occurs among the post-burnin table samples.
# Keys are the string form produced by table_to_str; a single pass both registers
# and counts each sample, so table_to_str is evaluated once per sample.
table_histogram = {}
for f in tqdm(table_samples[burnin::table_steps]):
    table_string = table_to_str(f)
    table_histogram[table_string] = table_histogram.get(table_string, 0) + 1

# Sort histogram by frequency
# table_histogram = {k: v for k, v in sorted(table_histogram.items(), key=lambda item: item[0])}
table_histogram = {k: v for k, v in sorted(table_histogram.items(), key=lambda item: -item[1])}
n_table_samples = sum(table_histogram.values(), 0.0)
table_probabilities = {k: v / n_table_samples for k, v in table_histogram.items()}

# The true table may never have been sampled; report 0 matches in that case instead of
# raising KeyError (the colsum cell further down guards the same lookup).
true_table_matches = table_histogram.get(table_to_str(ct.table), 0)
print(f"{int(100*true_table_matches/len(table_samples[burnin::table_steps]))}% of table samples matched true table")
print('Matches',true_table_matches)

In [None]:
# Locate the true table in the frequency-sorted histogram and collect every sampled
# table whose column sums equal those of the true table.
true_table_string = table_to_str(ct.table)
true_table_colsums = ct.table.sum(axis=0)

true_table_index = -1  # stays -1 when the true table was never sampled
true_colsums_indices = []
true_colsums_table_strings = []
for i,table_string in enumerate(table_histogram):
    # Monitor correct table samples
    if table_string == true_table_string:
        true_table_index = i
    # Monitor table samples with correct column sums (compared up to 1e-9)
    sampled_colsums = str_to_table(table_string,dims=(I,J)).sum(axis=0)
    if np.all(abs(sampled_colsums - true_table_colsums) <= 1e-9):
        true_colsums_indices.append(i)
        true_colsums_table_strings.append(table_string)

if true_table_index >= 0:
    print(f'True table is the {true_table_index+1} most commonly sampled out of {len(table_histogram)} distinct tables')
else:
    # Previously this case was reported as rank 0
    print(f'True table was not sampled among the {len(table_histogram)} distinct tables')

In [None]:
# Brute-force search for the mode of the multinomial distribution implied by the true
# intensities, renormalised to vectors that contain no zero entry.
if sim.ground_truth_known:
    # Get multinomial probabilities
    _,log_probs,_ = ct_mcmc.log_intensities_to_multinomial_log_probabilities(true_log_flows)
    probs = np.exp(log_probs)

    # NOTE(review): `self` is not defined at notebook level -- this lookup looks copied
    # from a method body and raises NameError as written. The intended object is
    # presumably `ct`; confirm against the contingency table class before changing it.
    n_total = ct.margins[range(self.ndims())]

    # All J-vectors of non-negative integers summing to n_total. The range must include
    # n_total itself, otherwise vectors with a single non-zero entry are dropped and the
    # probabilities below cannot sum to one.
    support = [x for x in itertools.product(range(0,n_total+1), repeat=J) if sum(x) == n_total]
    # Evaluate the pmf once per support vector and reuse it in every pass below
    support_pmf = [scipy.stats.multinomial.pmf(v,n=n_total,p=probs) for v in support]

    mode = np.zeros(J)
    multiple_modes = []
    max_prob = 0
    # Probability mass on vectors containing a zero; removed by the renormalisation
    normalisation = sum(pmf for v,pmf in zip(support,support_pmf) if 0 in v)
    total_prob = sum(support_pmf)
    assert abs(total_prob-1.0) <= 1e-7

    # Find mode (first vector attaining the strict maximum)
    for v,pmf in zip(support,support_pmf):
        if pmf/(1-normalisation) > max_prob:
            max_prob = pmf/(1-normalisation)
            mode = v
    # Find nearby modes: vectors whose renormalised probability is within 1e-4 of the maximum
    for v,pmf in zip(support,support_pmf):
        if abs(pmf/(1-normalisation)-max_prob)<=1e-4:
            multiple_modes.append(v)
# NOTE(review): `mode` is only assigned inside the branch above
mode

In [None]:
# Display ct.true_colsums for comparison with the mode found above
ct.true_colsums

In [None]:
# Display ct.table (the table treated as the true table in this section)
ct.table

In [None]:
# Among the tables sampled at least as often as the true table (histogram positions
# 0..true_table_index), find the one with the highest log target under the true intensities.
# The key list is materialised once instead of on every iteration.
sampled_table_strings = list(table_histogram.keys())

max_index = -1
# np.inf rather than np.infty: the `infty` alias was removed in NumPy 2.0
max_target = -np.inf
for i in range(true_table_index+1):
    lt = log_product_multinomial_pmf(str_to_table(sampled_table_strings[i],dims=(I,J)),true_log_flows)
    if lt > max_target:
        max_target = lt
        max_index = i

best_table = str_to_table(sampled_table_strings[max_index],dims=(I,J))
# `lt` holds the value from the last iteration, i.e. the table at position true_table_index.
# NOTE(review): print `max_target` instead if the maximum itself was intended; `lt` is
# undefined when true_table_index is -1 (true table never sampled).
print(lt)
print(best_table)
print(ct.table)
print(best_table.sum(axis=0))
print(ct.colsums)

In [None]:
# Bar chart of table frequencies (sorted by decreasing frequency); bars of tables with
# the true column sums are overdrawn in green and the true table's bar in red.
# Only positions -1..300 are shown.
table_counts = list(table_histogram.values())

plt.figure(figsize=(20,10))
plt.bar(range(len(table_histogram)), table_counts, align='center', label='samples')
plt.bar(
    true_colsums_indices,
    np.array(table_counts)[true_colsums_indices],
    align='center',color='green',label='true colsums',
)
plt.bar(true_table_index, table_counts[true_table_index], align='center',color='red',label='true table')
# _ = plt.xticks(range(len(table_histogram)), list(table_histogram.keys()),rotation=90,fontsize=5)
_ = plt.xlim(-1,300)#len()
_ = plt.xlabel('Table string',fontsize=20)
_ = plt.ylabel('Frequency',fontsize=20)
_ = plt.legend(fontsize=20)
_ = plt.title(f'{len(table_histogram)} different tables sampled')
plt.savefig(
    os.path.join(dirpath,f'figures/table_histogram.{figure_format}'),
    format=figure_format,
)

In [None]:
# Restrict the table histogram to tables with the true column sums and order the
# result by decreasing frequency (ties keep their existing relative order).
# true_table_histogram = {k: v for k, v in sorted(true_table_histogram.items(), key=lambda item: item[0])}
true_table_histogram = dict(
    sorted(
        ((k, table_histogram[k]) for k in true_colsums_table_strings),
        key=lambda item: -item[1],
    )
)
# Position of the true table inside the restricted histogram
true_table_new_index = list(true_table_histogram.keys()).index(table_to_str(ct.table))

true_colsums_matches = np.sum(list(true_table_histogram.values()))
print(f"{int(100*true_colsums_matches/len(table_samples[burnin::table_steps]))}% of table samples matched true column sums")
print('Matches',true_colsums_matches)

In [None]:
# Bar chart of the tables that reproduce the true column sums; the true table is red.
plt.figure(figsize=(20,10))
plt.bar(range(len(true_table_histogram)), np.fromiter(true_table_histogram.values(),dtype=int), align='center',color='green',label='true colsums')
plt.bar(true_table_new_index, true_table_histogram[table_to_str(ct.table)], align='center',color='red',label='true table')
# Tick positions must coincide with the bar positions 0..n-1; they were previously
# shifted by -1, which placed each table string under its left neighbour.
_ = plt.xticks(range(len(true_table_histogram)), list(true_table_histogram.keys()),rotation=45)
_ = plt.xlim(-1,len(true_table_histogram))
_ = plt.ylabel('Frequency',fontsize=20)
_ = plt.xlabel('Table string',fontsize=20)
_ = plt.legend(fontsize=20)
plt.savefig(os.path.join(dirpath,f'figures/true_table_histogram.{figure_format}'),format=figure_format)

In [None]:
# Most frequently sampled table among those with the true column sums
# (true_table_histogram is sorted by decreasing frequency)
str_to_table(list(true_table_histogram.keys())[0],dims=(I,J))

In [None]:
# Display ct.table for comparison with the most frequently sampled table shown above
ct.table

In [None]:
# Intensities implied by the ground-truth latent values
# (true_log_flows is assigned in the Preprocessing cell when sim.ground_truth_known)
np.exp(true_log_flows)

In [None]:
# Intensities evaluated at the post-burnin means of x and theta (Preprocessing cell)
expected_flows

# Colsums sample inspection

In [None]:
# Count how often each distinct column-sum vector occurs among the post-burnin
# colsum samples; keys are the string form produced by table_to_str.
colsum_histogram = {}
for f in tqdm(colsum_samples[burnin::colsum_steps]):
    colsum_string = table_to_str(f)
    colsum_histogram[colsum_string] = colsum_histogram.get(colsum_string, 0) + 1

# Sort histogram by frequency
# colsum_histogram = {k: v for k, v in sorted(colsum_histogram.items(), key=lambda item: item[0])}
colsum_histogram = dict(sorted(colsum_histogram.items(), key=lambda item: -item[1]))
n_colsum_samples = sum(colsum_histogram.values(), 0.0)
colsum_probabilities = {k: v / n_colsum_samples for k, v in colsum_histogram.items()}

# Report the share of samples equal to the true column sums, if they were sampled at all
if str_in_list(table_to_str(np.array(ct.true_colsums)),colsum_histogram.keys()):
    true_colsum_matches = colsum_histogram[table_to_str(np.array(ct.true_colsums))]
    print(f"{int(100*true_colsum_matches/len(colsum_samples[burnin::colsum_steps]))}% of colsum samples matched true colsums")
    print('Matches',true_colsum_matches)

In [None]:
# Position of the true column sums in the frequency-sorted colsum histogram
# (-1 when they were never sampled)
true_colsum_string = table_to_str(np.array(ct.true_colsums))
true_colsum_index = -1
for i,colsum_string in enumerate(colsum_histogram):
    if colsum_string == true_colsum_string:
        true_colsum_index = i
print(f'True colsum is the {true_colsum_index+1} most commonly sampled out of {len(colsum_histogram)} distinct colsums')

In [None]:
# Most frequently sampled column-sum vector (colsum_histogram is sorted by decreasing frequency)
str_to_table(list(colsum_histogram.keys())[0],dims=(1,J))[0]

In [None]:
# Column sums of the intensities implied by the ground-truth latent values
np.exp(true_log_flows).sum(axis=0)

In [None]:
# Display ct.true_colsums for comparison with the two outputs above
ct.true_colsums

In [None]:
# Bar chart of colsum frequencies (sorted by decreasing frequency); the bar at
# true_colsum_index is overdrawn in red.
colsum_counts = list(colsum_histogram.values())

plt.figure(figsize=(20,10))
plt.bar(range(len(colsum_histogram)), colsum_counts, align='center', label='samples')
plt.bar(true_colsum_index, colsum_counts[true_colsum_index], align='center',color='red',label='true colsums')
# _ = plt.xticks(range(len(colsum_histogram)), list(colsum_histogram.keys()),rotation=90)
_ = plt.xlim(-1,len(colsum_histogram))
_ = plt.xlabel('Colsums string',fontsize=20)
_ = plt.ylabel('Frequency',fontsize=20)
_ = plt.legend(fontsize=20)
plt.savefig(
    os.path.join(dirpath,f'figures/colsum_histogram.{figure_format}'),
    format=figure_format,
)

# Intensity convergence

In [None]:
# Define sample sizes so that statistics are computed at every MCMC interval:
# the running statistics below use samples[table_burnin:s] for each s in table_sample_sizes.
table_burnin = 0
table_sample_step = 1
table_chain_length = 20000
# Never read past the number of available table samples
maxN = int(min(table_samples.shape[0],table_burnin+table_chain_length))
table_sample_sizes = range(
    table_burnin+table_sample_step,
    maxN+table_sample_step,
    table_sample_step,
)

In [None]:
# Relative L1 / L2 error between the true log-intensities and the running mean of the
# sampled log-intensities, for every prefix length in table_sample_sizes.
lambda_sample_mean_error_l1_norms = np.zeros(len(table_sample_sizes))
lambda_sample_mean_error_l2_norms = np.zeros(len(table_sample_sizes))
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    # Running theta mean: only needed by the commented-out alternative below
    theta_running_mean_scaled = np.mean(theta_samples[table_burnin:s],axis=0)
    theta_running_mean_scaled[1] *= sim.bmax
    # Both error measures use the same running mean, so compute it once per prefix
    # (assumes relative_l1 does not modify its `tab` argument in place -- TODO confirm)
    log_lambda_running_mean = np.mean(log_lambda_samples[table_burnin:s],axis=0)
#     log_lambda_running_mean = sim.log_intensity(np.mean(log_destination_attraction_samples[table_burnin:s],axis=0),theta_running_mean_scaled)
    lambda_sample_mean_error_l1_norms[i] = relative_l1(
                        tab0=true_log_flows,
                        tab=log_lambda_running_mean
    )
    lambda_sample_mean_error_l2_norms[i] = relative_l2_norm(
                        tab0=true_log_flows,
                        tab=log_lambda_running_mean
    )

In [None]:
# Relative L1 error of the running intensity mean against the number of samples used.
# NOTE(review): `comment` is not assigned in any cell visible in this notebook.
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,lambda_sample_mean_error_l1_norms)
plt.axhline(y=0,color='red')
plt.ylabel(r'Relative $L_1$ of $\mathbb{E}[\lambda]$',fontsize=16)
plt.xlabel('MCMC iteration',fontsize=16)
plt.locator_params(axis='x', nbins=20)
plt.savefig(
    os.path.join(dirpath,f'figures/intensity_relative_l1_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),
    format=figure_format,
)

In [None]:
# Relative L2 error of the running intensity mean against the number of samples used.
# NOTE(review): `comment` is not assigned in any cell visible in this notebook.
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,lambda_sample_mean_error_l2_norms)
plt.xlabel('MCMC iteration',fontsize=16)
plt.ylabel(r'Relative $L_2$ of $\mathbb{E}[\lambda]$',fontsize=16)
plt.locator_params(axis='x', nbins=20)
plt.axhline(y=0,color='red')
# Saved under an `l2_norm` name (same pattern as the expected-table L2 figure); this cell
# previously reused the L1 filename and overwrote the L1 figure on disk.
plt.savefig(os.path.join(dirpath,f'figures/intensity_relative_l2_norm_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),format=figure_format)

# Table convergence

### Check for convergence in probability of table mean estimator
This checks whether the weak law of large numbers holds, i.e. that

$$\lim_{M\to \infty} \lVert \mathbf{\bar{n}}^{(1:M)} - \boldsymbol{\lambda} \rVert_p = 0$$

for $p=1,2$. The norm is defined as follows:

$$\lVert \mathbf{n} \rVert_p = \left( \sum_{i=1}^{I}\sum_{j=1}^{J} |n_{ij}|^p \right)^{1/p}.$$

The goal is to establish that 
$$\mathbf{\bar{n}}^{(1:M)} = \frac{1}{M}\sum_{m=1}^M \mathbf{n}^{(m)} \to \left( \frac{O_i \lambda_{ij}}{\sum_{l=1}^J \lambda_{il}} \right)_{i,j=1}^{I,J} = \boldsymbol{\lambda},$$
i.e. that the estimator $\mathbf{\bar{n}}^{(1:M)}$ converges in probability to the ground truth mean.

In [None]:
# Define sample sizes so that statistics are computed at every MCMC interval
# (same values as in the "Intensity convergence" section above; repeated here so the
# table-convergence cells can be run without it).
table_burnin = 0
table_sample_step = 1
table_chain_length = 20000
# Never read past the number of available table samples
maxN = int(min(table_samples.shape[0],table_burnin+table_chain_length))
table_sample_sizes = range(
    table_burnin+table_sample_step,
    maxN+table_sample_step,
    table_sample_step,
)

In [None]:
# Relative L1 / L2 error between the true intensities and the running mean of the
# sampled tables, for every prefix length in table_sample_sizes.
table_sample_mean_error_l1_norms = np.zeros(len(table_sample_sizes))
table_sample_mean_error_l2_norms = np.zeros(len(table_sample_sizes))
# The reference table does not change across iterations
true_flows_reference = np.exp(true_log_flows)
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    # Both error measures use the same running mean, so compute it once per prefix
    # (assumes relative_l1 does not modify its arguments in place -- TODO confirm)
    table_running_mean = np.mean(table_samples[table_burnin:s],axis=0)
    table_sample_mean_error_l1_norms[i] = relative_l1(
                        tab0=true_flows_reference,
                        tab=table_running_mean
    )
    table_sample_mean_error_l2_norms[i] = relative_l2_norm(
                        tab0=true_flows_reference,
                        tab=table_running_mean
    )

In [None]:
# Relative L1 error of the running table mean against the number of samples used
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,table_sample_mean_error_l1_norms)
plt.axhline(y=0,color='red')
plt.ylabel(r'Relative $L_1$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$',fontsize=16)
plt.xlabel('MCMC iteration',fontsize=16)
plt.locator_params(axis='x', nbins=20)
plt.savefig(
    os.path.join(dirpath,f'figures/expected_table_relative_l1_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),
    format=figure_format,
)

In [None]:
# Relative L2 error of the running table mean against the number of samples used
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,table_sample_mean_error_l2_norms)
plt.axhline(y=0,color='red')
plt.ylabel(r'Relative $L_2$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$',fontsize=16)
plt.xlabel('MCMC iteration',fontsize=16)
# plt.xticks(sample_sizes[::50])
plt.locator_params(axis='x', nbins=20)
plt.savefig(
    os.path.join(dirpath,f'figures/expected_table_relative_l2_norm_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),
    format=figure_format,
)

### X,Theta Convergence
## Import samples from different chains

In [None]:
# Experiment id, type and date; together they name the output directory
# NOTE(review): earlier cells already use `experiment_type` and `dirpath` (figure titles
# and savefig targets), so they depend on this cell having been run first.
experiment_id = 'synthetic_2x3_exp7'
experiment_type = 'JointTableSIMLatentLowNoiseMCMCConvergence'
date = '25_05_2022'
experiment_tag = '_'.join([experiment_id,experiment_type,date])

# Define directory
dirpath = '../data/outputs/' + experiment_tag + '/'
# Glob patterns for the per-chain sample files and path of the metadata file
log_dest_attraction_filenames = os.path.join(dirpath,'samples/log_destination_attraction_samples*.npy.gz')
theta_filenames = os.path.join(dirpath,'samples/theta_samples*.npy.gz')
convergence_metadata_filename = os.path.join(dirpath,experiment_tag + '_metadata.json')

In [None]:
# Parse the experiment's metadata JSON
with open(convergence_metadata_filename, mode='r') as metadata_file:
    convergence_metadata = json.load(metadata_file)

# Get number of chains
M = convergence_metadata['M']

In [None]:
# Read every matching sample file (in sorted filename order) and stack the results
# along a new leading axis, one entry per file.
theta_multiple_samples = np.array(
    [read_npy(file) for file in sorted(glob.glob(theta_filenames))]
)
log_destination_attraction_multiple_samples = np.array(
    [read_npy(file) for file in sorted(glob.glob(log_dest_attraction_filenames))]
)

## Trace plots for different chains

In [None]:
# Trace plots of the post-burnin theta samples of a single chain: alpha left, beta right.
# Both panels are drawn by the same loop so that their legends agree (the second
# panel previously had no 'samples' entry).
# Index of the chain to display; must be smaller than the number of loaded chains
chain_index = 8
fig,axs = plt.subplots(1,2,figsize=(20,10))

for ax,column,symbol,true_attr in zip(axs,(0,1),(r'$\alpha$',r'$\beta$'),('alpha_true','beta_true')):
    trace = theta_multiple_samples[chain_index,burnin:, column]
    ax.plot(trace,label='samples')
    ax.set_ylabel(symbol,fontsize=18,rotation=0,labelpad=7)
    ax.set_xlabel('MCMC samples',fontsize=18,labelpad=7)
    # Reference line at the true value, when the simulation exposes one
    if hasattr(sim,true_attr):
        ax.axhline(y=getattr(sim,true_attr), color='black', linestyle='-',label='true')
    # Post-burnin sample mean of this chain
    ax.axhline(y=np.mean(trace),color='lime',label=r'$\mu$')
    ax.set_ylim(0,2.0)
    ax.legend()

fig.suptitle(fr'{experiment_type}',fontsize=20)
plt.show()

In [None]:
# One trace plot per log destination attraction component for chain `chain_index`.
# squeeze=False keeps `axs` indexable even when J == 1 (plt.subplots would otherwise
# return a bare Axes object and `axs[j]` would fail).
fig,axs = plt.subplots(1,J,figsize=(20,10),squeeze=False)
axs = axs[0]

# Relative noise level, used to draw a +/- band around the data on the log scale
relative_noise = np.sqrt(sim.noise_var)/np.log(sim.dims[1])
relative_noise_percentage = round(100*relative_noise)
upper_bound = sim.log_destination_attraction + np.log((1.0+relative_noise_percentage/100))
lower_bound = sim.log_destination_attraction - np.log((1.0+relative_noise_percentage/100))

# The post-burnin mean of the chain does not depend on j, so compute it once
chain_posterior_mean = np.mean(log_destination_attraction_multiple_samples[chain_index,burnin:],axis=0)

for j in range(J):
    axs[j].plot(log_destination_attraction_multiple_samples[chain_index,burnin:, j])
    # Triple braces emit literal {} around the index so that j >= 10 is typeset as one subscript
    axs[j].set_ylabel(fr'$x_{{{j}}}$',fontsize=18,rotation=0,labelpad=7)
    axs[j].set_xlabel('MCMC samples',fontsize=18,labelpad=7)
    # Raw string: '\m' is not a valid escape sequence in a normal string literal
    axs[j].axhline(y=chain_posterior_mean[j], color='lime', linestyle='-',label=r'$\mu$')
    axs[j].axhline(y=sim.log_destination_attraction[j], color='r', linestyle='-',label='data')
#     axs[j].axhline(y=sim.log_true_destination_attraction[j], color='black', linestyle='-',label='generated')
    axs[j].axhline(y=upper_bound[j], color='r', linestyle='--',label=f'data + {relative_noise_percentage}%')
    axs[j].axhline(y=lower_bound[j], color='r', linestyle='--',label=f'data - {relative_noise_percentage}%')
    axs[j].legend()
fig.suptitle(fr'{experiment_type}',fontsize=20)
plt.show()

In [None]:
def gelman_rubin_criterion(samples,burnin:int,step:int=1,r_critical:float=1.1,prints:bool=True):
    """Evaluate the Gelman-Rubin R statistic of several chains for growing chain lengths.

    For every end index ``e`` in ``range(burnin+step, n, step)`` the statistic is
    computed separately for each parameter from the window ``samples[:, burnin:e, :]``.

    Parameters
    ----------
    samples : array-like of shape (m, n, p)
        ``m`` chains of ``n`` iterations over ``p`` parameters.
    burnin : int
        Number of leading iterations excluded from every window.
    step : int
        Increment between consecutive window end indices.
    r_critical : float
        A window counts as converged when the statistic of every parameter is
        strictly below this value.
    prints : bool
        Print progress messages. Progress printing is switched off after the first
        converged window has been reported. The final "NOT converged" report is
        printed regardless of this flag.

    Returns
    -------
    r_stats : np.ndarray of shape (number of windows, p)
        R statistic per window and parameter (``V / W``; no square root is taken).
    converged_chain_length : int
        Post-burnin length of the last (longest) window that satisfied the criterion,
        or 0 if none did.
    """
    # Convert to numpy
    samples = np.array(samples)

    # Get number of chains, number of chain iterations and number of parameters
    m,n,p = np.shape(samples)

    # End indices of the windows to evaluate (the full length n itself is excluded)
    possible_lengths = list(range(burnin+step,n,step))

    if prints: print(f'Gelman Rubin convergence criterion with M = {m}, N = {n-burnin}, P = {p}')

    r_stats = np.ones((len(possible_lengths),p))*1e9
    converged_chain_length = 0
    # Stays None when there is no window to evaluate (n <= burnin+step)
    r_stat = None
    for i,chain_length in enumerate(possible_lengths):
        if prints:
            print(f'Checking convergence with chain length = {chain_length-burnin}')

        window = samples[:,burnin:chain_length,:]
        window_length = chain_length-burnin
        chain_means = np.mean(window,axis=1)        # shape (m, p)
        grand_mean = np.mean(window,axis=(0,1))     # shape (p,)

        # Variance of the chain means around the grand mean (between-chain variance
        # divided by the window length)
        B_over_n = np.sum((chain_means-grand_mean)**2,axis=0) / (m - 1)

        # Average within-chain variance
        W = np.sum((window-chain_means[:,np.newaxis,:])**2,axis=(0,1)) / (m * (window_length - 1))

        # (over) estimate of variance
        s2 = W * (window_length-1) / window_length + B_over_n

        # Pooled posterior variance estimate
        V = s2 + B_over_n / m

        # Calculate PSRF
        r_stat = V / W
        r_stats[i] = r_stat

        # Record (and report once) that the chains have converged for this window
        if np.all(r_stat < r_critical):
            if prints:
                print(r'Vanilla MCMC chains have converged!')
                print(pd.DataFrame(r_stat))
                print(f'Chain length: {chain_length-burnin}')
                print(f'Burnin: {burnin}')
                prints = False
            converged_chain_length = chain_length-burnin

    if r_stat is None:
        # Nothing was evaluated; previously this fell through to an undefined `r_stat`
        if prints:
            print('No chain length to check: samples are too short for the given burnin and step')
    elif np.any(r_stat >= r_critical):
        # Based on the last (longest) window only
        print(r'Vanilla MCMC chains have NOT converged ...')
        print(pd.DataFrame(r_stat))

    return r_stats, converged_chain_length


In [None]:
# Settings for the Gelman-Rubin diagnostic: burn-in, spacing between the evaluated
# chain lengths and the critical value of the R statistic
burnin_period = 2000
chain_length_step_size = 1000
r_critical_value = 1.1

# Number of iterations available per chain
N = theta_multiple_samples.shape[1]
# NOTE(review): this cap uses `burnin` (set to 5000 near the top of the notebook) rather
# than `burnin_period` defined just above -- confirm which one is intended. It also
# reassigns `maxN`, which the table-convergence cells define with a different meaning.
maxN = min(burnin+8000,N)
# Chain end indices at which the R statistic is evaluated
chain_lengths = np.array([*range(burnin_period+chain_length_step_size,maxN,chain_length_step_size)])

In [None]:
# Gelman-Rubin diagnostic on the theta chains, restricted to the first maxN iterations
theta_chains_window = theta_multiple_samples[:,:maxN,:]
theta_r_stats,theta_chain_length = gelman_rubin_criterion(
    theta_chains_window,
    burnin=burnin_period,
    step=chain_length_step_size,
    r_critical=r_critical_value,
)

In [None]:
# Gelman-Rubin diagnostic on the log destination attraction chains, restricted to the
# first maxN iterations
log_destination_attraction_chains_window = log_destination_attraction_multiple_samples[:,:maxN,:]
destination_attraction_r_stats,destination_attraction_chain_length = gelman_rubin_criterion(
    log_destination_attraction_chains_window,
    burnin=burnin_period,
    step=chain_length_step_size,
    r_critical=r_critical_value,
)

# Plot R statistic against chain length

In [None]:
# R statistic of each theta component against the post-burnin chain length,
# with the critical value drawn as a red horizontal line
post_burnin_lengths = chain_lengths-burnin_period

plt.figure(figsize=(15,10))
for i in range(theta_r_stats.shape[1]):
    plt.plot(post_burnin_lengths,theta_r_stats[:,i],label=fr'${sim.parameter_names[i]}$')
plt.plot(
    post_burnin_lengths,
    np.full(len(chain_lengths),r_critical_value),
    label='$R_{critical}$',color='red',
)
plt.xlabel('MCMC chain length',fontsize=15)
plt.ylabel(r'$R$ statistic',fontsize=15)
plt.locator_params(axis='x', nbins=20)
# plt.xticks((chain_lengths-burnin_period)[::2],(chain_lengths-burnin_period)[::2])
plt.legend(frameon=False,prop={'size': 15})
plt.savefig(
    os.path.join(dirpath,f'figures/r_statistic_parameters_vs_mcmc_chain_length.{figure_format}'),
    format=figure_format,
)
plt.show()

In [None]:
# R statistic of each log destination attraction component against the post-burnin
# chain length, with the critical value drawn as a red horizontal line.
plt.figure(figsize=(15,10))
for i in range(np.shape(destination_attraction_r_stats)[1]):
    # Label the curves x_i, as in the destination attraction trace plots. The previous
    # label, sim.parameter_names[i], was copied from the theta plot.
    # NOTE(review): assumes sim.parameter_names holds the theta names only -- confirm.
    plt.plot(chain_lengths-burnin_period,destination_attraction_r_stats[:,i],label=fr'$x_{{{i}}}$')
plt.plot(chain_lengths-burnin_period,np.ones(len(chain_lengths))*r_critical_value,label='$R_{critical}$',color='red')
plt.ylabel(r'$R$ statistic',fontsize=15)
plt.xlabel('MCMC chain length',fontsize=15)
plt.locator_params(axis='x', nbins=20)
# plt.xticks((chain_lengths-burnin_period)[::2],(chain_lengths-burnin_period)[::2])
plt.legend(frameon=False,prop={'size': 15})
plt.savefig(os.path.join(dirpath,f'figures/r_statistic_log_destination_attraction_vs_mcmc_chain_length.{figure_format}'),format=figure_format)
plt.show()