In [1]:
import os
import gc
import sys
import json
import gzip
import glob
import itertools
import numpy as np
import pandas as pd
import sklearn.manifold
import matplotlib as mpl
import matplotlib.pyplot as plt

from math import comb
from tqdm import tqdm
from copy import deepcopy
from argparse import Namespace
from collections import OrderedDict
from functools import partial
from scipy.stats import multinomial

from ticodm.utils import *
from ticodm.spatial_interaction_model import ProductionConstrained
from ticodm.contingency_table_mcmc import ContingencyTableMarkovChainMonteCarlo
from ticodm.contingency_table import instantiate_ct

# Let the Agg backend split very long paths into chunks of this many vertices
# so that long MCMC traces can be rendered
mpl.rcParams.update({'agg.path.chunksize': 10000})

# Import samples

In [2]:
# Experiment identifiers; together they name the output directory of the run
experiment_data = 'synthetic_2x3_N_5000'
experiment_id = 'exp3_direct_sampling'
experiment_type = 'TableMCMC'   # Experiment type
date = '27_06_2022'             # Experiment date
# (another run id used previously: exp8c_TableMCMC_19_04_2022)
# Tag appended to the names of derived artefacts (figures, estimators).
# Alternatives: 'using_degree_higher_markov_basis', 'using_degree_one_markov_basis',
#               'using_direct_sampling'
comment = 'using_direct_sampling'

# Directory and file patterns of the run
run_name = '_'.join([experiment_data, experiment_id, experiment_type, date])
dirpath = f'../data/outputs/{run_name}/'
metadata_filename = os.path.join(dirpath, f'{run_name}_metadata.json')
# Sample batches are stored as .npy files matching these glob patterns
table_filename = os.path.join(dirpath, 'samples/table*_samples.npy')
table_filenames = glob.glob(table_filename)
destination_dem_filename = os.path.join(dirpath, 'samples/destination_demand*_samples.npy')
destination_dem_filenames = glob.glob(destination_dem_filename)

In [3]:
# Load the run metadata (JSON) stored next to the samples
with open(metadata_filename, 'r') as metadata_file:
    metadata = json.load(metadata_file)
# Number of MCMC iterations of the run
N = int(metadata['mcmc']['N'])
# Number of samples per stored batch (store_progress scaled by N)
batch_size = int(N * metadata['store_progress'])

FileNotFoundError: [Errno 2] No such file or directory: '../data/outputs/synthetic_2x3_N_5000_exp3_direct_sampling_TableMCMC_27_06_2022/synthetic_2x3_N_5000_exp3_direct_sampling_TableMCMC_27_06_2022_metadata.json'

In [None]:
# Report the run's execution time split into minutes and seconds
minutes, seconds = divmod(metadata['execution_time'], 60)
print('MCMC took %s minutes and %s seconds' % (minutes, seconds))
print(f"{N} samples taken")

In [None]:
# Reconstruct the contingency table, spatial interaction model and table MCMC
# objects from the stored run configuration.
# Work on a deep copy so this cell is idempotent. Previously `metadata` was
# mutated in place, so every re-run prepended another '.' to the dataset path.
metadata_copy = deepcopy(metadata)
# Prefix '.' to the stored dataset path. Presumably the path is relative to
# the project root (e.g. './data/...') and this makes it resolve from the
# notebook directory. TODO confirm.
metadata_copy['inputs']['dataset'] = '.'+metadata_copy['inputs']['dataset']
# del metadata_copy['mcmc']['contingency_table']['column_sum_proposal']
metadata_copy['mcmc']['contingency_table']['proposal'] = 'direct_sampling'
dummy_config = Namespace(**{'settings':metadata_copy})
ct = instantiate_ct(dummy_config)
sim = ProductionConstrained(dummy_config)
ct_mcmc = ContingencyTableMarkovChainMonteCarlo(ct)

In [None]:
# Read important metadata (true latent values)
true_table = ct.table
# Margins of the true table
colsums = true_table.sum(axis=0)
rowsums = true_table.sum(axis=1)
I, J = len(rowsums), len(colsums)
# Ground-truth log cell intensities stored with the run
log_lambdas = np.asarray(metadata['log_true_intensities'], dtype='float32')
# Per-column log-sum-exp of the log intensities (log of each column's total intensity)
log_colsum_lambdas = np.ones(J)
for j in range(J):
    log_colsum_lambdas[j] = logsumexp(log_lambdas[:, j])

# Size of the support over tables.
# Enumerating it is far too slow for large tables, so it is disabled:
# table_support_size = 1
# for i in tqdm(range(I)):
#     supp = [x for x in itertools.product(range(1,ct.rowsums[i]-J+2), repeat=J) if sum(x) == ct.rowsums[i]]
#     assert not np.any([0 in x for x in supp])
#     table_support_size *= len(supp)
# I*np.prod([comb(ct.rowsums[i]+J-1,J-1) for i in range(I)])

recompute = True        # flag for recomputing sample statistics
figure_format = 'eps'   # figure format: 'eps' or 'png'

In [None]:
# Burn-in / thinning settings for the table running-mean statistics:
# statistics are computed every `table_sample_step` MCMC iterations
table_sample_step = 1
assert N % table_sample_step == 0
table_burnin = 0
table_chain_length = 20_000
# Never go beyond the number of samples actually drawn
maxN = int(min(N, table_chain_length))
table_sample_sizes = list(range(table_burnin, maxN, table_sample_step))

In [None]:
# f = gzip.GzipFile("../data/outputs/synthetic_2x3_exp8_degree_one_TableMCMC_31_05_2022/samples/table_samples.npy.gz", "r")
# samples = np.load(f)

In [None]:
# write_npy(
#     samples,
#     "../data/outputs/synthetic_2x3_exp8_degree_one_TableMCMC_31_05_2022/samples/destination_demand_samples.npy"
# )

In [None]:
# Running mean of the table samples, accumulated one batch file at a time and
# thinned so that entry k is the mean of the first
# table_burnin + k*table_sample_step + 1 samples. This lines up one-to-one with
# `table_sample_sizes`. The previous version stored each batch's final mean
# twice and skipped the first mean of every later batch, which shifted all
# later entries.

def _first_kept_index(n_seen, burnin, step):
    """Local index of the first running mean of a batch that should be kept.

    A mean with global index g (the mean of g+1 samples) is kept when
    g >= burnin and (g - burnin) is a multiple of step; `n_seen` is the
    global index of the batch's first mean.
    """
    return max(burnin - n_seen, (burnin - n_seen) % step)

kept_table_means = []
# Mean carried over from one batch to the next (zero before any sample)
previous_table_mean = np.zeros((I,J),dtype='float32')
total_samples = 0

for i,file in tqdm(enumerate(sorted(table_filenames)),total=len(table_filenames)):
    remaining_samples = table_chain_length - total_samples
    if remaining_samples <= 0:
        print('Breaking early. Found more samples than required.')
        break
    print(f'Reading batch {i}')
    # Memory-map the batch and read only the samples that are still needed
    table_sample_batch = np.load(file,mmap_mode='r')[:remaining_samples]
    if table_sample_batch.shape[0] == 0:
        continue

    # Compute running means for the current batch, continuing from the last mean
    print('Running average multivariate')
    latest_running_means = running_average_multivariate(
                                table_sample_batch,
                                previous_table_mean,
                                total_samples
                            )
    # Clear memory
    del table_sample_batch
    gc.collect()
    print('Appending results and clearing memory')

    # Copy so the (possibly large) batch array can be freed below
    previous_table_mean = np.array(latest_running_means[-1])
    start = _first_kept_index(total_samples,table_burnin,table_sample_step)
    kept_table_means.append(np.array(latest_running_means[start::table_sample_step]))
    total_samples += latest_running_means.shape[0]

    # Clear memory
    del latest_running_means
    gc.collect()

if kept_table_means:
    running_table_mean = np.concatenate(kept_table_means,axis=0)
else:
    running_table_mean = np.zeros((0,I,J),dtype='float32')
del kept_table_means
gc.collect()

In [None]:
# Relative L1 / L2 error between the running table mean and the true cell
# intensities exp(log_lambdas), one value per entry of `table_sample_sizes`.
true_cell_intensities = np.exp(log_lambdas)  # loop invariant, computed once
n_table_points = len(table_sample_sizes)
assert running_table_mean.shape[0] >= n_table_points, (
    f'Only {running_table_mean.shape[0]} running means available for '
    f'{n_table_points} sample sizes'
)
# Size the error arrays by the x-axis so the plots below get equal-length series
sample_mean_error_l1_norms = np.zeros(n_table_points)
sample_mean_error_l2_norms = np.zeros(n_table_points)
for i in tqdm(range(n_table_points),total=n_table_points):
    sample_mean_error_l1_norms[i] = relative_l1(
                        tab0=true_cell_intensities,
                        tab=running_table_mean[i]
    )
    sample_mean_error_l2_norms[i] = relative_l2_norm(
                        tab0=true_cell_intensities,
                        tab=running_table_mean[i]
    )

In [None]:
# # Define experiment directory
# experiment_directory = os.path.basename(metadata['inputs']['dataset'])+'_'+metadata['experiment_id']+'_'+metadata['type']+'_'+metadata['datetime']
# # Create subdirectory for estimator statistics
# makedir(os.path.join(
#         metadata['outputs']['directory'],
#         experiment_directory,
#         'sample_estimators')
# )
# print('Writing compressed npy. This is going to take a while...')
# # Writing estimator data to compressed npy
# write_compressed_npy(
#     running_table_mean,
#     os.path.join(
#         metadata['outputs']['directory'],
#         experiment_directory,
#         'sample_estimators',
#         f"table_mean_burnin_{table_burnin}_N_{table_chain_length}_step_{table_sample_step}_{comment}.npy.gz"
#     )
# )

In [None]:
# Convergence of the running table mean: relative L1 error vs MCMC iteration
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes,sample_mean_error_l1_norms)
ax.set_xlabel('MCMC iteration',fontsize=16)
ax.set_ylabel(r'Relative $L_1$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$',fontsize=16)
ax.locator_params(axis='x', nbins=10)
# Reference line at zero error
ax.axhline(y=0,color='red')
# plt.savefig(os.path.join(dirpath,f'figures/expected_table_relative_l1_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),format=figure_format)

In [None]:
# Convergence of the running table mean: relative L2 error vs MCMC iteration
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes,sample_mean_error_l2_norms)
ax.set_xlabel('MCMC iteration',fontsize=16)
ax.set_ylabel(r'Relative $L_2$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$',fontsize=16)
ax.locator_params(axis='x', nbins=20)
# Reference line at zero error
ax.axhline(y=0,color='red')
# plt.savefig(os.path.join(dirpath,f'figures/expected_table_relative_l2_norm_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}'),format=figure_format)

In [None]:
# Burn-in / thinning settings for the column-sum running-mean statistics:
# statistics are computed every `colsum_sample_step` MCMC iterations
colsum_sample_step = 1
assert N % colsum_sample_step == 0
colsum_burnin = 0
colsum_chain_length = int(1e6)
# The chain length must be a whole number of stored batches
assert colsum_chain_length % batch_size == 0
maxN = int(min(colsum_burnin + colsum_chain_length, N))
colsum_sample_sizes = list(range(colsum_burnin, maxN + 1, colsum_sample_step))
# The first x-value is overwritten with 1
colsum_sample_sizes[0] = 1

In [None]:
# Running mean of the column-sum (destination demand) samples, accumulated one
# batch file at a time and thinned by `colsum_sample_step`.
running_colsum_mean = np.zeros((1,J))
total_colsum_samples = 0

for i,file in tqdm(enumerate(sorted(destination_dem_filenames)),total=len(destination_dem_filenames)):
    print(f'Reading file {i}')
    # Memory-map the column-sum batch instead of loading it eagerly
    colsum_sample_batch = np.load(file,mmap_mode='r')

    # Continue the running average from the most recent mean only (one row),
    # as the table loop does. Passing the whole history of means was a bug.
    print('Running average multivariate')
    latest_colsum_running_means = running_average_multivariate(
                                        colsum_sample_batch,
                                        running_colsum_mean[-1],
                                        total_colsum_samples
                                )
    # Clear memory
    del colsum_sample_batch
    gc.collect()

    print('Appending results and clearing memmory')
    # Keep every `colsum_sample_step`-th mean and append the batch's final mean
    # so the next batch can continue from it.
    # NOTE(review): that final mean is stored twice when it is already in the
    # thinned slice, and later batches skip their first mean, so entries drift
    # relative to `colsum_sample_sizes`. Confirm the intended alignment.
    if i == 0:
        running_colsum_mean = np.append(
                                latest_colsum_running_means[0::colsum_sample_step],
                                np.expand_dims(latest_colsum_running_means[-1],axis=0),
                                axis=0
                            )
    else:
        running_colsum_mean = np.concatenate(
                                [running_colsum_mean,
                                latest_colsum_running_means[colsum_sample_step::colsum_sample_step],
                                np.expand_dims(latest_colsum_running_means[-1],axis=0)],
                                axis=0
                            )

    # Update total number of samples
    total_colsum_samples += latest_colsum_running_means.shape[0]

    # Clear memory
    del latest_colsum_running_means
    gc.collect()

    # Stop once enough samples have been read.
    # NOTE(review): if every file holds `batch_size` samples this stops after
    # colsum_chain_length - batch_size samples, one batch short. Confirm.
    if total_colsum_samples >= (colsum_chain_length-batch_size):
        print('Breaking early. Found more samples than required.')
        break

In [None]:
# Relative L1 / L2 error between the running column-sum mean and the true
# expected column sums (column totals of exp(log_lambdas)).
ground_truth_colsum_intensities = np.exp(log_lambdas).sum(axis=0)
n_colsum_means = running_colsum_mean.shape[0]
colsum_sample_mean_error_l1_norms = np.zeros(n_colsum_means)
colsum_sample_mean_error_l2_norms = np.zeros(n_colsum_means)
# Iterate over the stored means themselves. `colsum_sample_sizes` is a list
# (it has no `.shape`), and its length need not match the number of means.
for i in tqdm(range(n_colsum_means),total=n_colsum_means):
    colsum_sample_mean_error_l1_norms[i] = relative_l1(
                        tab0=ground_truth_colsum_intensities,
                        tab=running_colsum_mean[i]
    )
    colsum_sample_mean_error_l2_norms[i] = relative_l2_norm(
                        tab0=ground_truth_colsum_intensities,
                        tab=running_colsum_mean[i]
    )

In [None]:
# Store the running column-sum means under
# <outputs directory>/<experiment directory>/sample_estimators
experiment_directory = '_'.join([
    os.path.basename(metadata['inputs']['dataset']),
    metadata['experiment_id'],
    metadata['type'],
    metadata['datetime'],
])
estimator_directory = os.path.join(
    metadata['outputs']['directory'],
    experiment_directory,
    'sample_estimators'
)
# Create subdirectory for estimator statistics
makedir(estimator_directory)
print('Writing compressed npy.')
colsum_estimator_filename = f"colsum_mean_burnin_{colsum_burnin}_N_{colsum_chain_length}_step_{colsum_sample_step}_{comment}.npy.gz"
write_compressed_npy(running_colsum_mean, os.path.join(estimator_directory, colsum_estimator_filename))

In [None]:
# Convergence of the running column-sum mean: relative L1 error vs MCMC iteration.
# Plot only the prefix where both x-values and errors exist (`stop` was undefined).
stop = min(len(colsum_sample_sizes),len(colsum_sample_mean_error_l1_norms))
plt.figure(figsize=(7,5))
plt.plot(colsum_sample_sizes[0:stop],colsum_sample_mean_error_l1_norms[0:stop])
plt.xlabel('MCMC iteration',fontsize=16)
plt.ylabel(r'Relative $L_1$ of $\mathbb{E}[\mathbf{n}_{+,\cdot}|\mathbf{n}_{+,+}]$',fontsize=16)
plt.locator_params(axis='x', nbins=10)
plt.axhline(y=0,color='red')
# Make sure the figures directory exists before saving
os.makedirs(os.path.join(dirpath,'figures'),exist_ok=True)
plt.savefig(os.path.join(dirpath,f'figures/expected_colsum_relative_l1_with_mcmc_iteration_N_{colsum_chain_length}_{comment}.{figure_format}'),format=figure_format)

In [None]:
# Convergence of the running column-sum mean: relative L2 error vs MCMC iteration.
# Plot only the prefix where both x-values and errors exist (`stop` was undefined).
stop = min(len(colsum_sample_sizes),len(colsum_sample_mean_error_l2_norms))
plt.figure(figsize=(7,5))
plt.plot(colsum_sample_sizes[0:stop],colsum_sample_mean_error_l2_norms[0:stop])
plt.xlabel('MCMC iteration',fontsize=16)
plt.ylabel(r'Relative $L_2$ of $\mathbb{E}[\mathbf{n}_{+,\cdot}|\mathbf{n}_{+,+}]$',fontsize=16)
plt.locator_params(axis='x', nbins=10)
plt.axhline(y=0,color='red')
# Make sure the figures directory exists before saving
os.makedirs(os.path.join(dirpath,'figures'),exist_ok=True)
plt.savefig(os.path.join(dirpath,f'figures/expected_colsum_relative_l2_norm_with_mcmc_iteration_N_{colsum_chain_length}_{comment}.{figure_format}'),format=figure_format)

# Postprocessing 
Compute summary statistics over the samples.

In [None]:
# Load every stored batch of table samples and stack them along the sample axis
table_samples = np.concatenate(
    [read_npy(batch_file) for batch_file in sorted(glob.glob(table_filename))],
    axis=0
)
N = table_samples.shape[0]

# Same for the column-sum (destination demand) samples; N ends up as their count
colsum_samples = np.concatenate(
    [read_npy(batch_file) for batch_file in sorted(glob.glob(destination_dem_filename))],
    axis=0
)
N = colsum_samples.shape[0]

In [None]:
# One row per sample, all table cells flattened into columns
table_samples_flattened = np.reshape(table_samples, (table_samples.shape[0], -1))

In [None]:
# String key of every sampled table (used to deduplicate and count tables)
# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/table_strings.txt.gz')):
table_strings = np.empty(table_samples.shape[0],dtype=object)
table_strings[:] = [table_to_str(sample) for sample in tqdm(table_samples,total=table_samples.shape[0])]
#     np.savetxt(fname=os.path.join(dirpath,f'samples/table_strings.txt.gz'),X=np.asarray(table_strings), fmt="%s")
# else:
#     table_strings = np.loadtxt(os.path.join(dirpath,f'samples/table_strings.txt.gz'),dtype=str)

In [None]:
# String key of every sampled column-sum vector (cast to int32 first)
# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/column_sum_strings.txt.gz')):
colsum_strings = np.empty(colsum_samples.shape[0],dtype=object)
colsum_strings[:] = [table_to_str(sample.astype('int32')) for sample in tqdm(colsum_samples,total=colsum_samples.shape[0])]
# np.savetxt(fname=os.path.join(dirpath,f'samples/column_sum_strings.txt.gz'),X=np.asarray(colsum_strings), fmt="%s")
# else:
#     colsum_strings = np.loadtxt(os.path.join(dirpath,f'samples/column_sum_strings.txt.gz'),dtype=str)

In [None]:
# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/table_similarity.txt.gz')):
# Table similarity (log target difference) of every sample w.r.t. the true table,
# whose log product-multinomial pmf serves as the reference value
tab0_evaluation = log_product_multinomial_pmf(true_table,log_lambdas)
table_similarity_measure_partial = partial(
                                        table_similarity_measure,
                                        tab0=tab0_evaluation,
                                        log_cell_intensities=np.ascontiguousarray(log_lambdas)
                                    )
table_similarity = np.fromiter(
    (table_similarity_measure_partial(sample) for sample in tqdm(table_samples,total=table_samples.shape[0])),
    dtype=float,
    count=table_samples.shape[0]
)
# np.savetxt(fname=os.path.join(dirpath,f'samples/table_similarity.txt.gz'),X=np.asarray(table_similarity))
# else:
#     tab0_evaluation = log_product_multinomial_pmf(np.asarray(true_table),np.asarray(log_lambdas))
#     table_similarity = np.loadtxt(os.path.join(dirpath,f'samples/table_similarity.txt.gz'))

In [None]:
# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/colsum_similarity.txt.gz')):
# Column-sum similarity (log target difference) of every sample w.r.t. the true column sums.
# NOTE(review): the reference value uses log(ct.colsums) as intensities, while the
# per-sample measure uses log_lambdas. Confirm this baseline is intended.
colsum0_evaluation = log_multinomial_pmf(ct.colsums,np.log(ct.colsums))
colsum_similarity_measure_partial = partial(
                                            column_sum_similarity_measure,
                                            csum0=colsum0_evaluation,
                                            log_cell_intensities=np.ascontiguousarray(log_lambdas)
                                    )
colsum_similarity = np.fromiter(
    (colsum_similarity_measure_partial(csum=sample.astype('int32')) for sample in tqdm(colsum_samples,total=colsum_samples.shape[0])),
    dtype=float,
    count=colsum_samples.shape[0]
)
# np.savetxt(fname=os.path.join(dirpath,f'samples/colsum_similarity.txt.gz'),X=np.asarray(colsum_similarity))
# else:
#     colsum0_evaluation = log_multinomial_pmf(ct.table,np.log(ct.true_colsums))
#     colsum_similarity = np.loadtxt(os.path.join(dirpath,f'samples/colsum_similarity.txt.gz'))

In [None]:
# %%time

# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/table_discrepancy.json.gz')):

#     # Compute table differences
#     with Pool(processes=3) as pool:
#         table_difference_histogram_partial = partial(ct.table_difference_histogram,tab0=true_table)
#         table_discrepancy = list(tqdm(pool.imap(table_difference_histogram_partial, table_samples), position=0, leave=True, total=len(table_samples)))

#     with open(os.path.join(dirpath,f'samples/table_discrepancy.json.gz'), 'w') as outfile:
#         outfile.write(json.dumps(table_discrepancy))
# else:
#     with open(os.path.join(dirpath,f'samples/table_discrepancy.json.gz')) as json_file:
#         table_discrepancy = json.load(json_file)

In [None]:
# Collate all per-sample statistics into one dataframe (one row per table sample)
# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/table_statistics.csv.gz')):
var_names = ['table_similarity','colsum_similarity','table_discrepancy',
             'table_distances','table_strings','colsum_strings']
var_titles = ['table_similarity','colsum_similarity','discrepancy',
              'distance','table_strings','colsum_strings']
# Include only the statistics that exist in the session (discrepancy/distance
# are currently not computed). Membership and lookup both use globals(); the
# previous locals()/globals() mix only agreed at notebook top level.
table_columns = {}
for attr,attr_name in zip(var_names,var_titles):
    if attr in globals():
        print(f'Adding {attr}')
        table_columns[attr_name] = np.asarray(globals()[attr])
table_data_names = list(table_columns)
# Build the frame column by column. Stacking numeric and string arrays into a
# single 2-D array forced every column through object dtype.
table_stats = pd.DataFrame(table_columns,index=range(len(table_samples)))
print('Built df')
# if 'discrepancy' in table_stats.columns.values:
#     table_stats[['+','-','0']] = pd.json_normalize(table_stats['discrepancy'])
#     table_stats['not0'] = table_stats['+']+table_stats['-']
#     table_stats.drop(['discrepancy'], axis=1,inplace=True)
# if 'distance' in table_stats.columns.values:
#     table_stats['updated_distance'] = table_stats['not0']*table_stats['distance']

# Recover the log target from the log-target difference and the reference evaluation
table_stats['table_similarity'] = table_stats['table_similarity'].astype('float32')
table_stats['colsum_similarity'] = table_stats['colsum_similarity'].astype('float32')
table_stats['table_log_target'] = -(table_stats['table_similarity'] - tab0_evaluation)
table_stats['colsum_log_target'] = -(table_stats['colsum_similarity'] - colsum0_evaluation)

# Export table to csv
# table_stats.to_csv(os.path.join(dirpath,f'samples/table_statistics.csv.gz'))
# else:
#     table_stats = pd.read_csv(os.path.join(dirpath,f'samples/table_statistics.csv.gz'))

# if recompute or not os.path.exists(os.path.join(dirpath,f'samples/unique_table_statistics.csv.gz')):
print('Getting unique table')
# One row per distinct table / distinct column-sum vector
table_stats_unique = table_stats.drop_duplicates(subset=['table_strings'])
column_stats_unique = table_stats.drop_duplicates(subset=['colsum_strings'])
print('Adding table frequencies')
# Attach how often each distinct table / column sum was sampled
table_stats_unique = pd.merge(table_stats_unique,table_stats.groupby('table_strings')['table_similarity'].count().reset_index(name='frequency'),on='table_strings',how='left')
column_stats_unique = pd.merge(column_stats_unique,table_stats.groupby('colsum_strings')['colsum_similarity'].count().reset_index(name='frequency'),on='colsum_strings',how='left')
# print('Storing dfs')
# table_stats_unique.to_csv(os.path.join(dirpath,f'samples/unique_table_statistics.csv.gz'))
# column_stats_unique.to_csv(os.path.join(dirpath,f'samples/unique_column_sum_statistics.csv.gz'))
# else:
#     table_stats_unique = pd.read_csv(os.path.join(dirpath,f'samples/unique_table_statistics.csv.gz'))
#     table_stats_unique.drop(columns=['Unnamed: 0','Unnamed: 0.1'],inplace=True,errors='ignore')
#     column_stats_unique = pd.read_csv(os.path.join(dirpath,f'samples/unique_column_sum_statistics.csv.gz'))
#     column_stats_unique.drop(columns=['Unnamed: 0','Unnamed: 0.1'],inplace=True,errors='ignore')

# Table sampling convergence

In [None]:
## Get maximum a posteriori tables based on empirical distribution and log target distribution

In [None]:
# MAP table under the log target: the sampled table with the largest `table_log_target`
tables_by_log_target = table_stats_unique.sort_values('table_log_target',ascending=False)
maximum_a_posteriori_table = str_to_array(tables_by_log_target['table_strings'].iloc[0],dims=(I,J))

In [None]:
# Display the table with the highest log target among the sampled tables
maximum_a_posteriori_table

In [None]:
# Empirical MAP table: the most frequently sampled table (first row after sorting
# by frequency). The previous `.values[1]` picked the second most frequent table.
emprirical_maximum_a_posteriori_table = str_to_array(
    table_stats_unique.sort_values('frequency',ascending=False)['table_strings'].iloc[0],dims=(I,J)
)

In [None]:
# Display the empirical MAP table
emprirical_maximum_a_posteriori_table

In [None]:
# Display the ground-truth table for comparison
true_table

## Compare empirical frequency of samples versus target measure

In [None]:
# The 10 most frequently sampled distinct tables
table_stats_unique.sort_values('frequency',ascending=False).head(10)

In [None]:
# Column-sum vector with the highest column-sum log target, shown as a 1 x J table
best_colsum_string = table_stats_unique.sort_values('colsum_log_target',ascending=False)['colsum_strings'].iloc[0]
str_to_table(best_colsum_string,dims=(1,J))

In [None]:
# The 10 distinct sampled tables with the highest log target
table_stats_unique.sort_values('table_log_target',ascending=False).head(10)

## Table distance/similarity measures

In [None]:
# _ = table_stats.plot.scatter('distance','similarity', figsize=(10, 10))

In [None]:
# Metric used to rank and plot the sampled tables. Other choices:
#   'normed_inner_product', 'colsum_similarity', 'distance',
#   'updated_distance', 'difference_frobenious_norm'
metric = 'table_similarity'
# Axis label for `metric`. Labels used for some of the alternatives:
#   'Normalised inner product wrt true table'
#   'Normalised Manhattan distance from true table'
#   'Log target difference from true column sums'
#   'Sum of differences from true table'
metric_title = 'Log target difference from true table'
# Re-sort the unique-table statistics by the chosen metric
table_stats_unique = table_stats_unique.sort_values(by=metric)

In [None]:
# Trace of the chosen metric over MCMC iterations, after burn-in
table_stats = table_stats.sort_index()
metric_trace = table_stats.reset_index().iloc[table_burnin:,:]
metric_trace.plot(y=metric,x='index',figsize=(20,10),legend=False)
plt.xlabel('MCMC iteration',fontsize=20)
plt.ylabel(metric_title,fontsize=20)
plt.title(f'Burnin = {table_burnin}')
# plt.savefig(os.path.join(dirpath,f'figures/{metric}_metric_versus_mcmc_iterations_{comment}.{figure_format}'),format=figure_format)

In [None]:
# REALLY EXPENSIVE COMPUTATION
# fig, (ax1,ax2) = plt.subplots(1,2, figsize=(10,10))
# table_stats_unique[['0','+','-']].plot.barh(color={"0": "green", "+": "red", "-": "orange"},stacked=True,ax=ax1,width=1.0)
# table_stats_unique[metric].plot.barh(color="blue",stacked=True,ax=ax2,width=1.0)
# ax1.set_title('Proportions of +ve,-ve and zero table differences from true table',fontsize=12)
# _ = ax1.set_xticks([], [])
# _ = ax1.set_yticks([], [])
# _ = ax1.get_legend().remove()
# _ = ax1.set_ylabel('Table samples',fontsize=16)
# ax2.set_title(metric_title,fontsize=12)
# _ = ax2.set_xticks([], [])
# _ = ax2.set_yticks([], [])
# fig.tight_layout()
# fig.savefig(os.path.join(dirpath,f'figures/{metric}_metric_barplot_with_table_differences.eps'),format='eps')

## Table histograms

In [None]:
# Histogram over sampled tables (after burn-in and thinning), keyed by table string
table_histogram = {}
for tab in tqdm(table_samples[table_burnin::table_sample_step]):
    tab_str = table_to_str(tab)  # compute the key once per sample
    table_histogram[tab_str] = table_histogram.get(tab_str,0) + 1

# Sort histogram by decreasing frequency
# (lexicographic alternative: key=lambda item: item[0])
table_histogram = dict(sorted(table_histogram.items(), key=lambda item: -item[1]))
# Normalise by the number of samples actually counted (after burn-in and
# thinning), so the empirical probabilities sum to one
n_table_samples = sum(table_histogram.values())
table_probabilities = {k: v / n_table_samples for k, v in table_histogram.items()}

true_table_str = table_to_str(true_table)
if str_in_list(true_table_str,table_histogram.keys()):
    print('True table frequency',table_histogram[true_table_str])

In [None]:
# # THIS WILL NOT WORK BECAUSE I AM NOT OBSERVING THE ENTIRE SUPPORT OF TABLES!!!
# ## Use table histogram to get support over all tables
# true_log_target_distribution = {}
# for str_tab in tqdm(table_histogram.keys()):
#     true_log_target_distribution[str_tab] = log_product_multinomial_pmf(
#             table=str_to_table(str_tab,dims=(I,J)),
#             log_cell_intensities=log_lambdas
#     ) + log_factorial(ct.margins[range(self.ndims())])
# # Convert to pandas
# # true_log_target_distribution_keys, true_log_target_distribution_values = true_target_distribution.keys(), true_target_distribution.values()
# # true_log_target_distribution = pd.DataFrame({
# #     "strings":true_log_target_distribution_keys,
# #     "frequency":true_log_target_distribution_values
# # })

In [None]:
# Position of the true table in the frequency-sorted histogram (-1 if never
# sampled), plus every sampled table whose column sums match the true ones
true_table_index = -1
true_colsums_indices = []
true_colsums_table_strings = []
true_table_colsums = true_table.sum(axis=0)
for rank,(tab_str,_count) in tqdm(enumerate(table_histogram.items()),total=len(table_histogram)):
    candidate = str_to_table(tab_str,dims=(I,J))
    if np.array_equal(candidate,true_table):
        true_table_index = rank
    if np.all(abs(candidate.sum(axis=0) - true_table_colsums) <= 1e-9):
        true_colsums_indices.append(rank)
        true_colsums_table_strings.append(tab_str)

In [None]:
print(len(table_histogram.keys()),f'distinct tables sampled.')
# print(table_support_size,f'distinct tables exist in support.')
# true_table_index is -1 when the true table was never sampled; avoid reporting "0 most frequently"
if true_table_index >= 0:
    print(f'True table was the {true_table_index+1} most frequently sampled table')
else:
    print('True table was never sampled')

In [None]:
# Frequency of each sampled table; true-colsum tables in green, the true table in red
plt.figure(figsize=(20,10))
table_frequencies = list(table_histogram.values())
plt.bar(range(len(table_frequencies)), table_frequencies, align='center', label='samples')
if true_colsums_indices:
    plt.bar(true_colsums_indices, np.array(table_frequencies)[true_colsums_indices], align='center',color='green',label='true colsums')
if true_table_index >= 0:
    plt.bar(true_table_index, table_frequencies[true_table_index], align='center',color='red',label='true table')
# Only the 200 most frequent tables are shown
_ = plt.xlim(-1,200)
_ = plt.ylabel('Frequency',fontsize=20)
_ = plt.xlabel('Table string',fontsize=20)
_ = plt.legend(fontsize=15)
# plt.savefig(os.path.join(dirpath,f'figures/table_histogram_{comment}.{figure_format}'),format=figure_format)

In [None]:
# Most frequently sampled table (the histogram is sorted by decreasing frequency)
print('MAP table')
str_to_table(list(table_histogram.keys())[0],dims=(I,J))

In [None]:
# Ground-truth table held by the contingency-table object
print('True table')
ct.table

In [None]:
# Cell-wise difference between the most frequently sampled table and the true table
print('Difference')
str_to_table(list(table_histogram.keys())[0],dims=(I,J))-ct.table

In [None]:
# Restrict the histogram to tables whose column sums match the true ones,
# ordered lexicographically by table string
true_table_histogram = {k: table_histogram[k] for k in sorted(true_colsums_table_strings)}
# Position of the true table in that ordering (-1 if it was never sampled)
true_table_new_index = -1
if str_in_list(table_to_str(true_table),true_table_histogram.keys()):
    true_table_new_index = list(true_table_histogram.keys()).index(table_to_str(true_table))

n_true_colsum_tables = np.sum(list(true_table_histogram.values()))
n_counted_tables = len(table_samples[table_burnin::table_sample_step])
print(f"{int(100*n_true_colsum_tables/n_counted_tables)}% of table samples matched true column sums")
# Total frequency of all tables with the true column sums, not of the true table alone
print('Frequency of tables matching true column sums',n_true_colsum_tables)

In [None]:
# Frequencies of the sampled tables with the true column sums; the true table in red
plt.figure(figsize=(20,10))
if len(true_table_histogram) > 0:
    plt.bar(range(len(true_table_histogram)), np.fromiter(true_table_histogram.values(),dtype=int), align='center',color='green',label='true colsums')
# Index 0 is a valid position (first string in lexicographic order); -1 means not sampled
if true_table_new_index >= 0:
    plt.bar(true_table_new_index, true_table_histogram[table_to_str(true_table)], align='center',color='red',label='true table')
_ = plt.xticks(range(len(true_table_histogram)), list(true_table_histogram.keys()),rotation=50)
_ = plt.xlim(-1,len(true_table_histogram))
_ = plt.ylabel('Frequency',fontsize=20)
_ = plt.xlabel('Table string',fontsize=20)
_ = plt.legend(fontsize=20)
# plt.savefig(os.path.join(dirpath,f'figures/true_table_histogram_{comment}.{figure_format}'),format=figure_format)

# Margin sampling convergence

In [None]:
# Heuristic estimate of the mode of the (multinomial) column-sum distribution,
# passing the true column sums as `colsum_prev`; only the first output is used
mode_estimate_outputs = ct_mcmc.mode_estimate_proposal_1way_table_multinomial(
    colsum_prev=colsums,
    log_cell_intensities=log_lambdas
)
colsum_distribution_mode, _, _, _, _ = mode_estimate_outputs
print('Column sums mode',colsum_distribution_mode)

In [None]:
# Difference between the estimated column-sum mode and the true column sums
print('Difference with true column sums',np.asarray(colsum_distribution_mode - colsums))

In [None]:
# Histogram over sampled column-sum vectors, keyed by their string representation
colsum_histogram = {}
for col in tqdm(colsum_samples):
    col_str = table_to_str(col.astype('int32'))
    colsum_histogram[col_str] = colsum_histogram.get(col_str,0) + 1

# Sort histogram by decreasing frequency
# (lexicographic alternative: key=lambda item: item[0])
colsum_histogram = dict(sorted(colsum_histogram.items(), key=lambda item: -item[1]))

n_colsum_samples = sum(colsum_histogram.values(), 0.0)
colsum_probabilities = {k: v / n_colsum_samples for k, v in colsum_histogram.items()}
total_flow = np.sum(colsums)

if str_in_list(table_to_str(colsums),colsum_histogram.keys()):
    print('True colsums frequency',colsum_histogram[table_to_str(colsums)])

# Expected count of each observed column-sum vector under a multinomial with
# `total_flow` trials and probabilities exp(log_probs)
_,log_probs,_ = ct_mcmc.log_intensities_to_multinomial_log_probabilities(log_lambdas)
multinomial_probabilities = np.exp(log_probs)
multinomial_colsums_histogram = {
    k: round(n_colsum_samples*multinomial.pmf(str_to_array(k,dims=(J)),n=total_flow,p=multinomial_probabilities))
    for k in tqdm(colsum_histogram.keys())
}

In [None]:
# Position of the true column sums in the frequency-sorted histogram (-1 if never sampled)
true_colsums_index = -1
for rank,colsum_str in enumerate(colsum_histogram):
    if np.all(abs(str_to_array(colsum_str,dims=(J)) - colsums) <= 1e-9):
        true_colsums_index = rank

In [None]:
# Frequency of each sampled column-sum vector, with the true column sums outlined in red
width = 1
plt.figure(figsize=(15,10))
colsum_frequencies = list(colsum_histogram.values())
plt.bar(range(len(colsum_frequencies)), colsum_frequencies, align='center',label='samples', width = width)
# plt.bar(np.array(range(len(multinomial_colsums_histogram)))-width, list(multinomial_colsums_histogram.values()), align='center',label='multinomial', width = width)
# _ = plt.xticks(range(len(colsum_histogram)), list(colsum_histogram.keys()),rotation=80)
plt.ylabel('Frequency',fontsize=20)
plt.xlabel('Table string',fontsize=20)
plt.xlim(-1,len(colsum_histogram))
# Outline the true column sums only if they were sampled; index -1 would
# otherwise highlight the last (least frequent) bar
if true_colsums_index >= 0:
    plt.bar(true_colsums_index, colsum_frequencies[true_colsums_index], fill=False, linewidth=width, align='center',edgecolor='red',label='true', width = width)
plt.legend()
# Make sure the figures directory exists before saving
os.makedirs(os.path.join(dirpath,'figures'),exist_ok=True)
plt.savefig(os.path.join(dirpath,f'figures/true_colsums_histogram_{comment}.{figure_format}'),format=figure_format)

# Table and margin convergence
## Define metrics
1. Total variation distance between the empirical measure and the target measure
1. Kullback–Leibler divergence between the empirical measure and the target measure

In [None]:
# Total target probability of the distinct tables observed, plus the smallest
# and largest single-table probabilities among them.
# The row-sum log-factorial term is the same for every table, so compute it once.
log_rowsum_factorial = np.sum(log_factorial_vectorised(1,ct.rowsums))
total_probability = 0
min_probability,max_probability = 1,0
for k in np.unique(table_strings):
    prob = np.exp(log_product_multinomial_pmf(str_to_table(k,dims=(I,J)),log_lambdas) + log_rowsum_factorial)
    # Independent checks: with `elif`, a value that lowered the minimum (e.g.
    # the first table) could never raise the maximum
    if prob < min_probability:
        min_probability = prob
    if prob > max_probability:
        max_probability = prob
    total_probability += prob

In [None]:
# Define total variation
def total_variation(table_statistics,support,log_cell_intensities,total_prob,column_name:str='table_strings'):
    """Largest pointwise gap between the empirical and the target table distribution.

    NOTE(review): this returns max_k |p(k) - q(k)| over `support`, which is a lower
    bound on (not equal to) the usual total variation distance
    1/2 * sum_k |p(k) - q(k)|. Confirm this is the intended metric.

    Reads the notebook globals `I`, `J` and `ct`.

    Parameters
    ----------
    table_statistics : pandas.DataFrame
        One row per sample, with table strings in column `column_name`.
    support : iterable of str
        Table strings over which the distributions are compared.
    log_cell_intensities : array
        Log intensities passed to `log_product_multinomial_pmf`.
    total_prob : float
        Normaliser for the target probabilities.
    column_name : str
        Column of `table_statistics` holding the table strings.

    Returns
    -------
    (tv, max_prob) : the largest absolute gap found, and the largest normalised
        target probability over `support`.
    """
    # Empirical distribution of tables: sample frequency of each distinct string
    empirical_log_distribution = table_statistics.groupby(
                                [column_name]
                            ).agg(
                                {column_name:'count'}
                            ).rename(
                                columns={column_name:"frequency"}
                            ).reset_index().values
    # Convert the (string, frequency) pairs into a dict of log probabilities
    log_distribution = np.log(empirical_log_distribution[:,1].astype('float32')) - np.log(np.sum(empirical_log_distribution[:,1]))
    empirical_log_distribution = dict(zip(empirical_log_distribution[:,0],log_distribution))
    # The row-sum log-factorial term is identical for every table: compute once
    log_rowsum_factorial = np.sum(log_factorial_vectorised(1,ct.rowsums))
    tv = 0
    max_prob = 0
    for k in support:
        true_probability = np.exp(log_product_multinomial_pmf(str_to_table(k,dims=(I,J)),log_cell_intensities) + log_rowsum_factorial) / total_prob
        if true_probability > max_prob:
            max_prob = true_probability
        # Tables never sampled have empirical probability zero
        if k in empirical_log_distribution:
            gap = abs(true_probability - np.exp(empirical_log_distribution[k]))
        else:
            gap = true_probability
        if gap > tv:
            tv = gap
    return tv,max_prob

# Define kl divergence
def kl_divergence(table_statistics,support,log_cell_intensities,total_prob,overflow:float=1e-50,column_name:str='table_strings'):
    # Get emprical distribution of tables from samples
    empirical_log_distribution = table_statistics.groupby(
                                [column_name]
                            ).agg(
                                {column_name:'count'}
                            ).rename(
                                columns={column_name:"frequency"}
                            ).reset_index().values
    # Convert two lists into dict
    log_distribution = np.log(empirical_log_distribution[:,1].astype('float32')) - np.log(np.sum(empirical_log_distribution[:,1]))
    empirical_log_distribution = dict(zip(empirical_log_distribution[:,0],log_distribution))
    # Compute kullback leibler divergence
    kl = 0
    max_prob = 0
    for k in support:
        true_probability = np.exp(log_product_multinomial_pmf(str_to_table(k,dims=(I,J)),log_cell_intensities) + np.sum(log_factorial_vectorised(1,ct.rowsums))) / total_prob
        if true_probability > max_prob:
            max_prob = true_probability
        if k in empirical_log_distribution.keys():
            kl += true_probability*(np.log(true_probability) - empirical_log_distribution[k])
        else:
            kl += true_probability*(np.log(true_probability) - np.log(overflow))
    return kl,max_prob

# Define posterior probability mass coverage %
def mass_coverage(table_statistics,log_cell_intensities,column_name:str='table_strings'):
    # Get emprical distribution of tables from samples
    empirical_log_distribution = table_statistics.groupby(
                                [column_name]
                            ).agg(
                                {column_name:'count'}
                            ).rename(
                                columns={column_name:"frequency"}
                            ).reset_index().values
    # Convert two lists into dict
    log_distribution = np.log(empirical_log_distribution[:,1].astype('float32')) - np.log(np.sum(empirical_log_distribution[:,1]))
    empirical_log_distribution = dict(zip(empirical_log_distribution[:,0],log_distribution))
    # Compute total probability of explored support under true distribution
    total_prob = 0
    for k in empirical_log_distribution.keys():
        total_prob += np.exp(log_product_multinomial_pmf(str_to_table(k,dims=(I,J)),log_cell_intensities) + np.sum(log_factorial_vectorised(1,ct.rowsums)))
    return total_prob

In [None]:
# Sample sizes at which the distribution metrics are evaluated: every `table_sample_step`
# samples after burn-in, up to the number of samples actually available.
table_sample_step = 1000
table_burnin = 0
table_chain_length = 1e7
maxN = int(min(table_burnin+table_chain_length,table_samples.shape[0]))
# Stop at maxN inclusive. The previous stop `maxN+step` produced a final size above maxN
# whenever (maxN - burnin) was not a multiple of the step.
table_sample_sizes = range(table_burnin+table_sample_step,maxN+1,table_sample_step)

In [None]:
# Track convergence of the empirical table distribution as the number of samples grows.
# Entries left at -1 were not computed. Total variation and KL divergence score every table
# in the observed support at every sample size and are slow, so they are opt-in.
compute_divergences = False
table_tvs = np.ones(len(table_sample_sizes)) * (-1)
table_kls = np.ones(len(table_sample_sizes)) * (-1)
table_mass_coverage = np.ones(len(table_sample_sizes)) * (-1)
# The observed support does not change inside the loop; compute it once (only if needed)
observed_support = np.unique(table_strings) if compute_divergences else None
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    samples_so_far = table_stats.iloc[table_burnin:s,:]
    if compute_divergences:
        table_tvs[i],max_probability = total_variation(
                                            table_statistics=samples_so_far,
                                            log_cell_intensities=log_lambdas,
                                            support = observed_support,
                                            total_prob = total_probability,
                                            column_name='table_strings'
                                        )
        table_kls[i],max_probability = kl_divergence(
                                            table_statistics=samples_so_far,
                                            log_cell_intensities=log_lambdas,
                                            support = observed_support,
                                            total_prob = total_probability,
                                            overflow = 0.1*min_probability,
                                            column_name='table_strings'
                                        )
    table_mass_coverage[i] = mass_coverage(
                                samples_so_far,
                                log_lambdas,
                                column_name='table_strings'
                            )

In [None]:
# Persist the TV / KL convergence curves as 2-row arrays (sample sizes, metric values).
# Both metrics are non-negative when computed. An array that is still all -1 placeholders
# (not computed) is skipped, so it never overwrites previously saved results.
if np.any(table_tvs >= 0):
    write_npy(
        np.array([table_sample_sizes,table_tvs]),
        os.path.join(dirpath,f'samples/table_empirical_distribution_total_variation_with_sample_size_max_chain_length_{table_chain_length}_{comment}.gz.npy')
    )
else:
    print('Total variation was not computed; nothing written.')
if np.any(table_kls >= 0):
    write_npy(
        np.array([table_sample_sizes,table_kls]),
        os.path.join(dirpath,f'samples/table_empirical_distribution_kl_divergence_with_sample_size_max_chain_length_{table_chain_length}_{comment}.gz.npy')
    )
else:
    print('KL divergence was not computed; nothing written.')

In [None]:
# Target probability mass of the distinct tables visited so far, against the number of samples.
fig, ax = plt.subplots(figsize=(10,5))
ax.plot(table_sample_sizes, table_mass_coverage)
ax.set_title('Table empirical distribution convergence rate', fontsize=18)
ax.set_xlabel('Number of samples', fontsize=16)
ax.set_ylabel('% of probability mass explored', fontsize=16)
ax.locator_params(axis='x', nbins=20)
coverage_figure_path = os.path.join(dirpath, f'figures/table_total_probability_coverage_with_sample_size_max_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(coverage_figure_path, format=figure_format)

In [None]:
# Snapshot the coverage curve of the current run under a sampler-specific name, for use by the
# comparison plot in the next cell. After running the notebook with a given `comment`,
# uncomment only the matching line below.
# table_mass_coverage_direct_sampling = deepcopy(table_mass_coverage)
# table_mass_coverage_degree_higher = deepcopy(table_mass_coverage)
# table_mass_coverage_degree_one = deepcopy(table_mass_coverage)

In [None]:
# Compare the coverage curves of the three samplers. Each curve must first have been
# snapshotted by the previous cell (one notebook run per sampler). Missing runs are skipped
# with a message instead of failing with NameError.
coverage_runs = (
    ('table_mass_coverage_direct_sampling', 'direct sampling'),
    ('table_mass_coverage_degree_higher', 'MBMCMC step size > 1'),
    ('table_mass_coverage_degree_one', 'MBMCMC step size = 1'),
)
plt.figure(figsize=(10,5))
for variable_name, run_label in coverage_runs:
    if variable_name in globals():
        plt.plot(table_sample_sizes,globals()[variable_name],label=run_label)
    else:
        print(f'Skipping "{run_label}": {variable_name} is not defined.')
# plt.title('Table empirical distribution convergence rate',fontsize=18)
plt.xlabel('Number of samples',fontsize=16)
plt.ylabel('% of probability mass explored',fontsize=16)
plt.locator_params(axis='x', nbins=20)
plt.legend()
plt.savefig(os.path.join(dirpath,f'figures/table_total_probability_coverage_with_sample_size_max_chain_length_{table_chain_length}_comparison.{figure_format}'),format=figure_format)
plt.show()

In [None]:
# Total-variation metric against sample size; red marks 0 and purple marks `max_probability`.
fig, ax = plt.subplots(figsize=(10,5))
ax.plot(table_sample_sizes, table_tvs)
ax.set_title('Table empirical distribution convergence rate', fontsize=18)
ax.axhline(y=0, color='red')
ax.axhline(y=max_probability, color='purple')
ax.set_xlabel('Number of samples', fontsize=16)
ax.set_ylabel('Total variation', fontsize=16)
ax.locator_params(axis='x', nbins=20)
tv_figure_path = os.path.join(dirpath, f'figures/table_empirical_distribution_total_variation_with_sample_size_max_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(tv_figure_path, format=figure_format)

In [None]:
# THIS IS PROBLEMATIC (author's flag). Possibly because kl_divergence assigns the floor
# probability `overflow` to unsampled tables, so the curve depends on that arbitrary floor.
fig, ax = plt.subplots(figsize=(10,5))
ax.plot(table_sample_sizes, table_kls)
ax.set_title('Table empirical distribution convergence rate', fontsize=18)
ax.axhline(y=0, color='red')
ax.set_xlabel('Number of samples', fontsize=16)
ax.set_ylabel('KL divergence', fontsize=16)
ax.locator_params(axis='x', nbins=20)
kl_figure_path = os.path.join(dirpath, f'figures/table_empirical_distribution_kl_divergence_with_sample_size_max_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(kl_figure_path, format=figure_format)

### Check for convergence in probability of table mean estimator
This checks whether the weak law of large numbers holds by verifying that

$$\lim_{M\to \infty} \|\mathbf{\bar{n}}^{(0:M)} - \boldsymbol{\lambda}\|_p = 0$$

for $p=1,2$. The norm is defined as follows:

$$\|\mathbf{n}\|_p = \left( \sum_{i,j}^{I,J} |n_{ij}|^p \right)^{1/p}.$$

The goal is to establish that 
$$\mathbf{\bar{n}}^{(0:M)} = \frac{1}{M}\sum_{m=1}^M \mathbf{n}^{(m)} \to \left( \frac{O_i \lambda_{ij}}{\sum_{l}^J \lambda_{il}} \right)_{i,j}^{I,J} = \boldsymbol{\lambda},$$
i.e. that the estimator $\mathbf{\bar{n}}^{(0:M)}$ converges in probability to the ground truth mean.

In [None]:
# Sample sizes (prefix lengths of the chain) at which the running table mean is evaluated:
# every `table_sample_step` samples after burn-in, up to the available number of samples.
table_sample_step = 1000
table_burnin = 0
table_chain_length = N
maxN = int(min(table_burnin+table_chain_length,table_samples.shape[0]))
# Stop at maxN inclusive; `maxN+step` could produce a final size beyond the available samples
table_sample_sizes = range(table_burnin+table_sample_step,maxN+1,table_sample_step)

In [None]:
# Relative L1 / L2 error of the running table mean (over the first `s` samples after burn-in)
# against the ground-truth mean exp(log_lambdas), for each evaluated sample size.
ground_truth_table_mean = np.exp(log_lambdas)
sample_mean_error_l1_norms = np.zeros(len(table_sample_sizes))
sample_mean_error_l2_norms = np.zeros(len(table_sample_sizes))
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    # The estimator must depend on `s`. The previous code compared a fixed `latest_mean`,
    # which gave the same error at every iteration.
    running_table_mean = np.mean(table_samples[table_burnin:s],axis=0)
    sample_mean_error_l1_norms[i] = relative_l1(
                        tab0=ground_truth_table_mean,
                        tab=running_table_mean
    )
    sample_mean_error_l2_norms[i] = relative_l2_norm(
                        tab0=ground_truth_table_mean,
                        tab=running_table_mean
    )

In [None]:
# Relative L1 error of the running table mean per MCMC iteration (red line: zero error).
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, sample_mean_error_l1_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_1$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
mean_l1_figure_path = os.path.join(dirpath, f'figures/expected_table_relative_l1_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(mean_l1_figure_path, format=figure_format)

In [None]:
# Relative L2 error of the running table mean per MCMC iteration (red line: zero error).
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, sample_mean_error_l2_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_2$ of $\mathbb{E}[\mathbf{n}|\mathbf{n}_{\cdot,+}]$', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
mean_l2_figure_path = os.path.join(dirpath, f'figures/expected_table_relative_l2_norm_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(mean_l2_figure_path, format=figure_format)

### Check for convergence in probability of margin mean estimator
This checks whether the weak law of large numbers holds for the column sums by verifying that

$$\lim_{M\to \infty} \|\mathbf{\bar{n}}_{+,\cdot}^{(0:M)} - \boldsymbol{\lambda}_{+,\cdot}\|_p = 0$$

for $p=1,2$, where $\boldsymbol{\lambda}_{+,\cdot} = \left(\sum_{i=1}^I \lambda_{ij}\right)_{j=1}^{J}$ are the ground-truth column totals. The norm is defined as follows:

$$\|\mathbf{x}\|_p = \left( \sum_{j=1}^{J} |x_{j}|^p \right)^{1/p}.$$

The goal is to establish that 
$$\mathbf{\bar{n}}_{+,\cdot}^{(0:M)} = \frac{1}{M}\sum_{m=1}^M \mathbf{n}_{+,\cdot}^{(m)} \to \left( \frac{N \sum_{i=1}^I\lambda_{ij}}{\sum_{i',j'}^{I,J} \lambda_{i'j'}} \right)_{j=1}^{J} = \boldsymbol{\lambda}_{+,\cdot},$$
where the last equality uses $\sum_{i',j'}^{I,J} \lambda_{i'j'} = N$. That is, the estimator $\mathbf{\bar{n}}_{+,\cdot}^{(0:M)}$ converges in probability to the ground truth mean.

In [None]:
# Sample sizes at which the running column-sum mean is evaluated, never exceeding the number
# of column-sum samples available.
table_sample_step = 1
table_burnin = 0
table_chain_length = 2000
maxN = int(min(table_burnin+table_chain_length,colsum_samples.shape[0]))
# Start one step after burn-in (consistent with the other sample-size cells) and stop at maxN
# inclusive; `maxN+step` could overshoot when step > 1.
table_sample_sizes = range(table_burnin+table_sample_step,maxN+1,table_sample_step)

In [None]:
# Relative L1 / L2 error of the running mean of the sampled column sums against the
# ground-truth column totals of exp(log_lambdas).
colsum_sample_mean_error_l1_norms = np.zeros(len(table_sample_sizes))
colsum_sample_mean_error_l2_norms = np.zeros(len(table_sample_sizes))
ground_truth_colsum_intensities = np.exp(log_lambdas).sum(axis=0)
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    running_colsum_mean = np.mean(colsum_samples[table_burnin:s],axis=0)
    colsum_sample_mean_error_l1_norms[i] = relative_l1(tab0=ground_truth_colsum_intensities, tab=running_colsum_mean)
    colsum_sample_mean_error_l2_norms[i] = relative_l2_norm(tab0=ground_truth_colsum_intensities, tab=running_colsum_mean)

In [None]:
# Relative L1 error of the running column-sum mean per MCMC iteration (red line: zero error).
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, colsum_sample_mean_error_l1_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_1$ of $\mathbb{E}[\mathbf{n}_{+,\cdot}|\mathbf{n}_{+,+}]$', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
colsum_l1_figure_path = os.path.join(dirpath, f'figures/expected_colsum_relative_l1_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(colsum_l1_figure_path, format=figure_format)

In [None]:
# Relative L2 error of the running column-sum mean per MCMC iteration (red line: zero error).
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, colsum_sample_mean_error_l2_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_2$ of $\mathbb{E}[\mathbf{n}_{+,\cdot}|\mathbf{n}_{+,+}]$', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
colsum_l2_figure_path = os.path.join(dirpath, f'figures/expected_colsum_relative_l2_norm_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.{figure_format}')
fig.savefig(colsum_l2_figure_path, format=figure_format)

### Check for convergence in probability of table variance estimator [problematic]
NOTE: THERE IS SOMETHING WRONG IN THE DEFINITION OF THE VARIANCE OF THE TABLE
Let the empirical standard deviation be

$$\hat{\sigma}^{(M)}_{\mathbf{n}|\mathbf{n_{\cdot,+}}} = \sqrt{\frac{1}{M-1}\sum_{m=1}^{M} (\mathbf{n}^{(m)} - \mathbf{\bar{n}})^2},$$
where $\mathbf{\bar{n}} = \frac{1}{M} \sum_{m=1}^M \mathbf{n}^{(m)}$, $\mathbf{n}^{(m)} \sim p(\mathbf{n}|\mathbf{n_{\cdot,+}})$.

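The checks below compare against $\boldsymbol{\lambda} = \left( \frac{O_i \lambda_{ij}}{\sum_{q=1}^J \lambda_{iq}} \right)_{i,j=1}^{I,J}$, i.e. they assume $\mathbb{V}\left[\mathbf{n}|\mathbf{n_{\cdot,+}} \right] = \boldsymbol{\lambda}$. That is the Poisson (unconstrained) variance and is likely the source of the NOTE above. Conditional on the row sums, each row is multinomial, $\mathbf{n}_{i,\cdot} \mid O_i \sim \text{Multinomial}(O_i, \mathbf{p}_i)$ with $p_{ij} = \lambda_{ij}/\sum_{q=1}^J \lambda_{iq}$. The true conditional variance is therefore
$$\mathbb{V}\left[n_{ij}|\mathbf{n_{\cdot,+}} \right] = O_i\, p_{ij}(1-p_{ij}) = \boldsymbol{\lambda}_{ij}\left(1 - \frac{\boldsymbol{\lambda}_{ij}}{O_i}\right),$$
which is strictly smaller than $\boldsymbol{\lambda}_{ij}$ whenever $p_{ij} > 0$.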

The following checks if the weak law of large numbers holds by ensuring that

$$\lim_{M\to \infty} \|(\hat{\sigma}^{(M)}_{\mathbf{n}|\mathbf{n_{\cdot,+}}})^{2} - \boldsymbol{\lambda}\|_r = 0$$

for $r=1,2$. The norm is defined as follows:

$$\|\mathbf{n}\|_r = \left( \sum_{i,j}^{I,J} |n_{ij}|^r \right)^{1/r}.$$

The goal is to establish that 
$$\lim_{M\to\infty} p\left(\|(\hat{\sigma}^{(M)}_{\mathbf{n}|\mathbf{n_{\cdot,+}}})^{2} - \boldsymbol{\lambda}\|_r \leq \epsilon\right) = 1$$
for arbitrary $\epsilon > 0$,
i.e. that the estimator $(\hat{\sigma}^{(M)}_{\mathbf{n}|\mathbf{n_{\cdot,+}}})^{2}$ converges in probability to the ground truth variance. (The code below compares $\hat{\sigma}$ with $\sqrt{\boldsymbol{\lambda}}$ element-wise, i.e. standard deviations rather than variances.)

The problem in the above is that we cannot get an unbiased estimator of $\mathbb{V}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]$ easily.

In [None]:
# Sample sizes at which the running sample standard deviation is evaluated.
table_sample_step = 1
table_burnin = 0
table_chain_length = 500  # append two zeros (50000) for the full run
maxN = int(min(table_burnin+table_chain_length,table_samples.shape[0]))
# np.std(..., ddof=1) needs at least two samples (one sample gives NaN and a warning), so the
# first size is at least burnin+2. Stop at maxN inclusive so no size exceeds the samples.
table_sample_sizes = range(table_burnin+max(table_sample_step,2),maxN+1,table_sample_step)

In [None]:
# Relative L1 / L2 error of the running per-cell sample standard deviation (ddof=1) against
# sqrt(exp(log_lambdas)). Despite the variable names, standard deviations are compared here.
# NOTE(review): sqrt(lambda) is the Poisson standard deviation. Under row-constrained
# (product-multinomial) sampling the conditional variance is lambda_ij * (1 - lambda_ij / O_i),
# so this error may not go to zero. Confirm the intended target.
target_table_std = np.sqrt(np.exp(log_lambdas))
sample_variance_error_l1_norms = np.zeros(len(table_sample_sizes))
sample_variance_error_l2_norms = np.zeros(len(table_sample_sizes))
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    running_table_std = np.std(table_samples[table_burnin:s],axis=0,ddof=1)
    sample_variance_error_l1_norms[i] = relative_l1(tab0=target_table_std, tab=running_table_std)
    sample_variance_error_l2_norms[i] = relative_l2_norm(tab0=target_table_std, tab=running_table_std)

In [None]:
# Relative L2 error of the running sample standard deviation per MCMC iteration.
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, sample_variance_error_l2_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_2$ error of sample variance', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
var_l2_figure_path = os.path.join(dirpath, f'figures/relative_l2_error_sample_var_with_mcmc_iteration_{comment}.{figure_format}')
fig.savefig(var_l2_figure_path, format=figure_format)

In [None]:
# Relative L1 error of the running sample standard deviation per MCMC iteration.
fig, ax = plt.subplots(figsize=(7,5))
ax.plot(table_sample_sizes, sample_variance_error_l1_norms)
ax.set_xlabel('MCMC iteration', fontsize=16)
ax.set_ylabel(r'Relative $L_1$ error of sample variance', fontsize=16)
ax.locator_params(axis='x', nbins=20)
ax.axhline(y=0, color='red')
var_l1_figure_path = os.path.join(dirpath, f'figures/relative_l1_error_sample_var_with_mcmc_iteration_{comment}.{figure_format}')
fig.savefig(var_l1_figure_path, format=figure_format)

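## Convergence of second largest eigenvalue of transition matrix
The cells below do not compute eigenvalues yet. They build a random $10\times 10$ table, compute a lattice basis of the kernel of its row-margin matrix via the Hermite normal form, compare the size of that basis with the Markov basis, and test the admissibility of a few lattice moves.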

In [None]:
from hsnf import column_style_hermite_normal_form, row_style_hermite_normal_form, smith_normal_form
from itertools import product
from ticodm.markov_basis import instantiate_markov_basis

In [None]:
# Build a random 10x10 contingency table whose row and column margins each total 1000
# (drawn uniformly via multinomials), then its Markov basis.
ct_copy = deepcopy(ct)
ct_copy.I, ct_copy.J = 10, 10
# Draw order (rows, then columns, then the table) matches the RNG consumption of the original
ct_copy.rowsums = list(np.random.multinomial(1000, np.full(ct_copy.I, 1/ct_copy.I)))
ct_copy.colsums = list(np.random.multinomial(1000, np.full(ct_copy.J, 1/ct_copy.J)))
ct_copy.table = ct_copy.table_monte_carlo_sample()
# All (row, column) cells in lexicographic order
ct_copy.cells = [tuple(cell) for cell in sorted(product(range(ct_copy.I), range(ct_copy.J)))]
mb = instantiate_markov_basis(ct_copy)

In [None]:
def T(x):
    """Row indicator of cell `x`: length-`ct_copy.I` 0/1 vector with a 1 at position `x[0]`."""
    return np.array([int(x[0] == row) for row in range(ct_copy.I)])

def T2(x):
    """Row-and-column indicator of cell `x`: length `ct_copy.I + ct_copy.J`, with 1s at
    position `x[0]` (its row) and position `ct_copy.I + x[1]` (its column)."""
    return np.array([int(x[0] == pos or (x[1] + ct_copy.I) == pos) for pos in range(ct_copy.I + ct_copy.J)])

In [None]:
# Row-margin design matrix: column k is T(cell_k). With cells in lexicographic order,
# A @ table.flatten() gives the row sums. To constrain rows AND columns, use T2 with
# ct_copy.I + ct_copy.J rows instead.
# NOTE(review): assumes len(ct_copy) == number of cells; np.empty leaves any extra column uninitialised.
A = np.empty((ct_copy.I, len(ct_copy)))
for column_index, cell in enumerate(ct_copy.cells):
    A[:, column_index] = T(cell)

In [None]:
# Find hermite normal form of A
# Column-style HNF: H is the normal form and U the accompanying transformation. The next cell
# reads zero columns of H as kernel directions in U, which assumes the hsnf convention
# H = A @ U (confirm against the hsnf docs).
H, U = column_style_hermite_normal_form(A)

In [None]:
# Lattice basis of ker(A): every all-zero column j of H selects column j of U. The selected
# columns are stacked as the rows of `basis`.
basis_indices = [j for j in range(H.shape[1]) if not np.any(H[:, j])]
basis = np.array([U[:, j] for j in basis_indices])

print('Markov basis',len(mb))
print('Lattice basis',basis.shape[0])

In [None]:
# Sanity check: the sampled table itself is expected to be admissible.
ct_copy.table_admissible(ct_copy.table)

In [None]:
# Apply the first lattice-basis move. It lies in ker(A), so row sums are preserved. A was built
# from T (rows only), so column sums may change and entries may turn negative, and the result
# may be inadmissible.
ct_copy.table_admissible((ct_copy.table.flatten() + basis[0]).reshape(ct_copy.I,ct_copy.J))

In [None]:
# Combine two lattice moves (requires at least 5 basis vectors) and re-check admissibility.
ct_copy.table_admissible((ct_copy.table.flatten() + basis[0] + 2*basis[4]).reshape(ct_copy.I,ct_copy.J))

# Intensity convergence

In [None]:
# Identifiers of the ensemble experiment analysed in the "Intensity convergence" section.
# They determine the output directory and the sample/metadata filenames read below.
experiment_id = 'synthetic_2x3_exp9_K200_direct_sampling'
experiment_type = 'TableMCMCConvergence'
date = '26_05_2022'  # earlier run for reference: exp8c_TableMCMC_19_04_2022
# Label appended to saved figures/arrays. Alternatives:
#   'using_degree_higher_markov_basis', 'using_degree_one_markov_basis', 'using_direct_sampling'
comment = 'using_direct_sampling'

experiment_prefix = f'{experiment_id}_{experiment_type}_{date}'
dirpath = f'../data/outputs/{experiment_prefix}/'
ensemble_metadata_filename = os.path.join(dirpath, f'{experiment_prefix}_metadata.json')
# Glob patterns: one compressed sample file per sampler in the ensemble
ensemble_table_filename = os.path.join(dirpath, 'samples/table_samples*.npy.gz')
ensemble_destination_dem_filename = os.path.join(dirpath, 'samples/destination_demand_samples*.npy.gz')

In [None]:
# Read the ensemble metadata and stack every sampler's table samples (one file per sampler,
# sorted by filename) into one array whose first axis indexes the sampler.
with open(ensemble_metadata_filename, 'r') as fin:
    ensemble_metadata = json.load(fin)
ensemble_table_samples = np.array(
    [read_npy(sample_file) for sample_file in sorted(glob.glob(ensemble_table_filename))]
)

# Rebuild the contingency-table object from a copy of the metadata with the proposal forced
# to 'direct_sampling'
ensemble_metadata_copy = deepcopy(ensemble_metadata)
ensemble_metadata_copy['mcmc']['contingency_table']['proposal'] = 'direct_sampling'
ensemble_dummy_config = Namespace(settings=ensemble_metadata_copy)
ct = instantiate_ct(ensemble_dummy_config)

# Ground-truth quantities derived from the true table
true_table = ct.table
colsums = true_table.sum(axis=0)
rowsums = true_table.sum(axis=1)
I, J = len(rowsums), len(colsums)
# NOTE(review): log_lambdas is the log of the true table itself (zero cells become -inf).
# Confirm this is the intended ground-truth intensity.
log_lambdas = np.asarray(np.log(ct.table))
# Per-column log of the column totals of exp(log_lambdas)
log_colsum_lambdas = np.ones(J)
for col in range(J):
    log_colsum_lambdas[col] = logsumexp(log_lambdas[:, col])
N = ensemble_table_samples.shape[1]

## Convergence of mean intensity estimator
This follows [Sheldon's approach](https://proceedings.neurips.cc/paper/2011/file/fccb3cdc9acc14a6e70a12f74560c026-Paper.pdf), which checks that 

$\mathbb{E}\left[\mathbb{E}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]\right] \to \mathbb{E}\left[\mathbf{n}\right] = \boldsymbol{\lambda}$

This establishes the weak law of large numbers using the mean estimator
$$\bar{\mathbf{n}} = \frac{1}{KM}\sum_{k=1}^K\sum_{m=1}^M \mathbf{n}^{(k,m)}$$

where $\mathbf{n}^{(k,m)} \sim p(\mathbf{n}|\mathbf{n}_{\cdot,+}^{k})$,$\;\; \mathbf{n}_{\cdot,+}^{k}\sim p(\mathbf{n}_{\cdot,+})$.

In [None]:
# Iteration counts at which the ensemble mean estimator is evaluated, never exceeding the
# number of iterations stored per sampler.
table_sample_step = 1
table_burnin = 0
table_chain_length = 100
maxN = int(min(table_burnin+table_chain_length,ensemble_table_samples.shape[1]))
# Stop at maxN inclusive; `maxN+step` could overshoot when step > 1
table_sample_sizes = range(table_burnin+table_sample_step,maxN+1,table_sample_step)

In [None]:
# Relative L1 / L2 error of the ensemble mean estimator against exp(log_lambdas). The
# estimator averages over samplers (axis 0) and then over the first `s` iterations.
ensemble_l1_norms = np.zeros(len(table_sample_sizes))
ensemble_l2_norms = np.zeros(len(table_sample_sizes))
ensemble_mean = np.mean(ensemble_table_samples,axis=0)
ensemble_ground_truth = np.exp(log_lambdas).astype('float32')
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    ensemble_estimate = np.mean(ensemble_mean[table_burnin:s,:,:],axis=0).astype('float32')
    ensemble_l1_norms[i] = relative_l1(tab=ensemble_estimate, tab0=ensemble_ground_truth)
    ensemble_l2_norms[i] = relative_l2_norm(tab=ensemble_estimate, tab0=ensemble_ground_truth)

In [None]:
# Save (iteration counts, relative L1 error of the ensemble mean) as a 2-row array.
ensemble_mean_l1_curve = np.array([table_sample_sizes,ensemble_l1_norms])
ensemble_mean_l1_path = os.path.join(dirpath, f'samples/ensemble_k{ensemble_table_samples.shape[0]}_relative_l1_sample_mean_with_mcmc_iteration_{comment}.gz.npy')
write_npy(ensemble_mean_l1_curve, ensemble_mean_l1_path)

In [None]:
# Relative L1 error of the ensemble mean estimator per MCMC iteration (red line: zero error).
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,ensemble_l1_norms)
plt.axhline(y=0,color='red')
plt.title(f'Ensemble of {ensemble_table_samples.shape[0]} samplers {comment.replace("_"," ")}',fontsize=14)
plt.xlabel('MCMC iteration',fontsize=16)
plt.ylabel(r'Relative $L_1$ norm of $\mathbb{E}\left[\mathbb{E}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]\right]$',fontsize=16)
plt.locator_params(axis='x', nbins=20)
# The filename names the plotted metric (L1); it previously said `relative_l2_norm`
plt.savefig(os.path.join(dirpath,f'figures/ensemble_k{ensemble_table_samples.shape[0]}_relative_l1_sample_mean_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.eps'),format='eps')
plt.show()

## Plot all rates together

## Convergence of variance intensity estimator [this is problematic]
This checks that

$\mathbb{V}\left[\mathbf{n}\right] = \boldsymbol{\lambda} = \mathbb{E}_{\mathbf{n}_{\cdot,+}}\left[\mathbb{V}\left[\mathbf{n}|\mathbf{n}_{\cdot,+} \right] \right] + \mathbb{V}_{\mathbf{n}_{\cdot,+}} \left[ \mathbb{E}\left[ \mathbf{n}|\mathbf{n}_{\cdot,+}\right] \right]$

This establishes the weak law of large numbers using the mean estimators
$$\bar{\mathbf{n}}= \frac{1}{KM}\sum_{k=1}^K\sum_{m=1}^M \mathbf{n}^{(k,m)}$$
$$\bar{\mathbf{n}}^{(k)} = \frac{1}{M}\sum_{m=1}^M \mathbf{n}^{(k,m)}$$
$$\bar{\mathbf{n}}^{(m)} = \frac{1}{K}\sum_{k=1}^K \mathbf{n}^{(k,m)}$$
and the standard deviation estimator of $\sqrt{\mathbb{V} \left[ \mathbf{n}|\mathbf{n}_{\cdot,+}^{(k)} \right]}$
$$\hat{\sigma}^{(k,M)}_{\mathbf{n}|\mathbf{n_{\cdot,+}}} = \sqrt{\frac{1}{M-1}\sum_{m=1}^{M} (\mathbf{n}^{(k,m)} - \mathbf{\bar{n}}^{(k)})^2}$$

where $\mathbf{n}^{(k,m)} \sim p(\mathbf{n}|\mathbf{n}_{\cdot,+}^{k})$,$\;\; \mathbf{n}_{\cdot,+}^{k}\sim p(\mathbf{n}_{\cdot,+})$. Therefore, we estimate the total variance (element-wise) by

$$\mathbb{V}\left[\mathbf{n}\right] \approx \frac{1}{K}\frac{1}{M-1}\sum_{k=1}^{K}\sum_{m=1}^{M} (\mathbf{n}^{(k,m)} - \mathbf{\bar{n}}^{(k)})^2  + \frac{1}{K-1}\sum_{k=1}^K\left(\mathbf{\bar{n}}^{(k)} - \mathbf{\bar{n}} \right)^2$$

The problem in the above is that we cannot easily get an unbiased estimator of $\mathbb{V}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]$. For example, using 
$$\mathbb{V}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right] = \mathbb{E}\left[\mathbf{n}\odot\mathbf{n}|\mathbf{n}_{\cdot,+}\right]- \left(\mathbb{E}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]\right)^{\odot 2}$$
(where $\odot$ is the element-wise product), one cannot easily obtain an unbiased estimator for $\left(\mathbb{E}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]\right)^{\odot 2}$ and guarantee that the conditional variance estimator is always non-negative.

In [None]:
# Iteration counts at which the ensemble variance estimator is evaluated.
table_sample_step = 1
table_burnin = 0
table_chain_length = 1000
maxN = int(min(table_burnin+table_chain_length,ensemble_table_samples.shape[1]))
# Within-chain variance uses ddof=1, so start from at least two iterations. Stop at maxN
# inclusive: the previous stop `maxN+step+1` evaluated s = maxN+1, one past the chain length.
table_sample_sizes = range(table_burnin+max(table_sample_step,2),maxN+1,table_sample_step)

In [None]:
# Relative L1 / L2 error of the law-of-total-variance estimate
#   mean_k[ var_m(n^(k,m)) ] + var_k[ mean_m(n^(k,m)) ]   (both with ddof=1)
# against exp(log_lambdas), using the first `s` iterations of every sampler.
ensemble_variance_l1_norms = np.zeros(len(table_sample_sizes))
ensemble_variance_l2_norms = np.zeros(len(table_sample_sizes))
ensemble_mean = np.mean(ensemble_table_samples,axis=0)
sample_mean = np.mean(ensemble_table_samples,axis=1)
variance_ground_truth = np.exp(log_lambdas).astype('float32')
for i,s in tqdm(enumerate(table_sample_sizes),total=len(table_sample_sizes)):
    chains_prefix = ensemble_table_samples[:,table_burnin:s,:,:]
    within_chain_term = np.mean(np.var(chains_prefix,axis=1,ddof=1),axis=0)
    between_chain_term = np.var(np.mean(chains_prefix,axis=1),axis=0,ddof=1)
    total_variance_estimate = (within_chain_term + between_chain_term).astype('float32')
    ensemble_variance_l1_norms[i] = relative_l1(tab=total_variance_estimate, tab0=variance_ground_truth)
    ensemble_variance_l2_norms[i] = relative_l2_norm(tab=total_variance_estimate, tab0=variance_ground_truth)

In [None]:
# Save (iteration counts, relative L1 error of the ensemble variance estimate) as a 2-row array.
ensemble_var_l1_curve = np.array([table_sample_sizes,ensemble_variance_l1_norms])
ensemble_var_l1_path = os.path.join(dirpath, f'samples/ensemble_k{ensemble_table_samples.shape[0]}_relative_l1_sample_var_with_mcmc_iteration_{comment}.gz.npy')
write_npy(ensemble_var_l1_curve, ensemble_var_l1_path)

In [None]:
# Relative L1 error of the ensemble variance estimate per MCMC iteration (red line: zero error).
plt.figure(figsize=(7,5))
plt.plot(table_sample_sizes,ensemble_variance_l1_norms)
plt.axhline(y=0,color='red')
plt.title(f'Ensemble of {ensemble_table_samples.shape[0]} samplers {comment.replace("_"," ")}',fontsize=14)
plt.xlabel('MCMC iteration',fontsize=16)
plt.ylabel(r'Relative $L_1$ norm of $\mathbb{V}\left[\mathbf{n}|\mathbf{n}_{\cdot,+}\right]$',fontsize=16)
plt.locator_params(axis='x', nbins=20)
# Saved under a variance-specific L1 filename. The previous `relative_l2_norm_sample_mean`
# name described neither this metric nor this estimator.
plt.savefig(os.path.join(dirpath,f'figures/ensemble_k{ensemble_table_samples.shape[0]}_relative_l1_sample_var_with_mcmc_iteration_chain_length_{table_chain_length}_{comment}.eps'),format='eps')
plt.show()

# Convergence of multinomially distributed random variable
Script for checking the asymptotic shape of the multinomial distribution for large $N$ (the code cells below are currently commented out).

In [None]:
# Enumerate the support {x : x_i >= 1, sum(x) = t} for several totals t and store the
# renormalised multinomial log-probabilities per table string.
# NOTE(review): only `multinomial` is imported from scipy.stats (not `scipy`), so
# `scipy.stats.multinomial.logpmf` below would raise NameError if uncommented; use `multinomial.logpmf`.
# NOTE(review): itertools.product(range(1,t), repeat=size) is O(t**size); with t=1000 and
# size=3 that is ~1e9 tuples. Also range(1,t) excludes zero counts from the support.
# size = 3
# probs = np.array([0.6,0.3,0.1])
# #np.ones(size)*(1/size)
# totals = np.linspace(10,1000,5,dtype=int)
# histograms = {}
# for i,t in enumerate(totals):
#     histograms[str(t)] = {}
#     print(f'Finding support, i = {i}')
#     supp = [x for x in itertools.product(range(1,t), repeat=size) if sum(x) == t]
#     normalisation = 0
#     for s in tqdm(supp,leave=True):
#         histograms[str(t)][table_to_str(s)] = scipy.stats.multinomial.logpmf(x=s,n=t,p=probs).flatten()[0]
#         normalisation += np.exp(histograms[str(t)][table_to_str(s)])
#     # Renormalise probability
#     for s in supp:
#         histograms[str(t)][table_to_str(s)] -= np.log(normalisation)

In [None]:
# Plot the renormalised probabilities over the enumerated support, one panel per total t.
# Requires `totals` and `histograms` from the commented-out cell above.
# fig,axs = plt.subplots(1,len(totals),figsize=(20,10))
# for i,t in enumerate(totals):
#     axs[i].set_title(f'N = {t}, support size = {len(histograms[str(t)])}')
#     axs[i].plot( range(len(histograms[str(t)])), np.exp(list(histograms[str(t)].values()))) 
# plt.savefig(os.path.join(dirpath,f'figures/asymptotic_multinomial_distribution_large_N_unequal_probs.eps'),format='eps')