In [7]:
import os
import glob
import warnings
import h5py as h5
import numpy as np
import pandas as pd
import matplotlib as mpl
mpl.use('ps')
import matplotlib.pyplot as plt

from copy import deepcopy
from tqdm.auto import tqdm
from matplotlib.ticker import MultipleLocator
# from IPython.core.display import display, HTML

from gensit.config import Config
from gensit.inputs import Inputs
from gensit.outputs import Outputs
from gensit.utils.misc_utils import *
from gensit.utils.math_utils import *

from gensit.utils.probability_utils import *
from gensit.contingency_table import instantiate_ct
from gensit.contingency_table.ContingencyTable_MCMC import ContingencyTableMarkovChainMonteCarlo

from gensit.config import Config
from gensit.inputs import Inputs
from gensit.utils.misc_utils import *
from gensit.static.plot_variables import *
from gensit.static.global_variables import *
from gensit.outputs import Outputs,OutputSummary

In [8]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): the import cell calls mpl.use('ps') before this magic runs —
# confirm which backend is actually intended for interactive display.
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
# Mode 2 reloads imported modules before each cell executes, so edits to the
# project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [9]:
# LaTeX font configuration
# Applies the LATEX_RC_PARAMETERS overrides to matplotlib's global rcParams.
# The name is not defined in this notebook, so it presumably arrives through
# one of the star imports in the import cell — confirm.
mpl.rcParams.update(LATEX_RC_PARAMETERS)

# SRMSE & CP vs (iterations, ensemble size) by method and table constraint
variable = table

sigma = low,high,learned

constraints = total,rowsums,doubly,doubly_10percent_cells,doubly_20percent_cells

In [53]:
# Path prefix of the exported figure data; the '_data.json' and
# '_settings.json' suffixes are appended where the files are read.
datapath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/paper_figures/figure3/exploration_exploitation_tradeoff_srmse_cp_vs_method_epoch_seed_label_sigma&title_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0"

# Prefix under which the regenerated figures are written.
outputbasepath ="../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/paper_figures/figure3_v2/"

# Filename stem placed between the metric name ("srmse"/"cp") and the sweep id.
figuretitle = "_vs_iterations_x_ensemble_size_by_constraints_method"

# Maps a sweep id (appended to the output filename) to the record labels that
# are selected for that sweep.
# Raw string: the LaTeX backslashes (\sigma, \mytabletotal) are literal text,
# not Python escape sequences; a non-raw literal triggers an invalid-escape
# warning on recent Python versions.
sweep_data = {
    "|total|high_noise": [
        r'$\sigma = 0.141$, $\mytabletotal$',
    ]
}

In [54]:
# Font size shared by the tick labels and axis labels of the plots defined below.
fontsize = 14

In [55]:
def srmse_plot(data,output_path,group_size=6):
    """Plot SRMSE against N as grouped marker/line series and write the figure.

    Consecutive runs of ``group_size`` records form one series. Every record
    is drawn as a labelled scatter point and the points of a series are joined
    by a line. Duplicate legend labels are collapsed to a single entry.

    Parameters
    ----------
    data : dict
        Must provide equally long sequences under the keys ``'label'`` and
        ``'newlabel'`` (strings joined into the legend entry), ``'newx'``
        (values convertible with ``int``; the x coordinates) and ``'y'``
        (sequences whose first element is the plotted value).
    output_path : str
        Forwarded unchanged to ``write_figure``.
    group_size : int, optional
        Number of consecutive records per series. Defaults to 6, the value
        that used to be hard-coded.

    Notes
    -----
    At most ``len(cs)`` series (one per colour/marker pair) are drawn. When
    fewer than ``len(cs) * group_size`` records are supplied, only the
    available records are plotted instead of indexing past the end.
    """
    fig,ax = plt.subplots()#figsize=(10,15)
    ax.set_box_aspect(1)

    # One colour/marker pair per series.
    cs = ['#8ebeda', '#8ebeda', '#a6c858', '#a6c858']#'#8ebeda' '#a6c858', '#ca4a58', '#e0ad41'
    markers = ['o','^','o','^']

    # Convert the x values once rather than on every loop iteration.
    xs = list(map(int,data['newx']))
    n_points = min(len(xs), len(cs)*group_size)
    # Take the leading element of each y record with plain indexing: records
    # may be ragged (e.g. [value, []]), which np.array rejects on recent NumPy.
    ys = [record[0] for record in data['y'][:n_points]]

    for jindx, j in enumerate(range(0,n_points,group_size)):
        stop = min(j+group_size, n_points)
        for i in range(j,stop):
            # Raw string keeps the LaTeX thin-space command "\;" literal.
            plot_label = (data['label'][i]+r', \;'+data['newlabel'][i]).replace(", $\\sigma = 0.014$","")
            print(i,plot_label)
            ax.scatter(
                xs[i],
                ys[i],
                linewidth = 1.0,
                alpha=1.0,
                c = cs[jindx],
                marker = markers[jindx],
                label = plot_label,
            )
        # Unlabelled connecting line, so it adds no legend entry of its own.
        ax.plot(
            xs[j:stop],
            ys[j:stop],
            linewidth = 1.0,
            c = cs[jindx],
            marker = markers[jindx]
        )

    ax.tick_params(labelsize=fontsize)
    ax.xaxis.set_major_locator(MultipleLocator(20000))
    ax.set_xlabel(r'$N$',fontsize=fontsize)
    ax.set_ylabel(r'SRMSE',fontsize=fontsize)

    # Keying by label keeps one handle per distinct legend text.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(),fontsize=9,ncol=1)

    write_figure(
        fig,
        output_path,
        filename_ending='ps',
        pad_inches=0.0,
        bbox_inches='tight'
    )

In [56]:
def cp_plot(data,output_path,group_size=6):
    """Plot CP against N as grouped marker/line series and write the figure.

    Consecutive runs of ``group_size`` records form one series. Every record
    is drawn as a labelled scatter point and the points of a series are joined
    by a line. Duplicate legend labels are collapsed to a single entry.

    The plotted value is ``100 * (log(marker_size) + 2) / 8``.
    NOTE(review): this presumably inverts the scaling that turned coverage
    probability into a marker size when the source figure was exported —
    confirm against the exporting code.

    Parameters
    ----------
    data : dict
        Must provide equally long sequences under the keys ``'label'`` and
        ``'newlabel'`` (strings joined into the legend entry), ``'newx'``
        (values convertible with ``int``; the x coordinates) and
        ``'marker_size'`` (positive numbers).
    output_path : str
        Forwarded unchanged to ``write_figure``.
    group_size : int, optional
        Number of consecutive records per series. Defaults to 6, the value
        that used to be hard-coded.

    Notes
    -----
    At most ``len(cs)`` series (one per colour/marker pair) are drawn. When
    fewer than ``len(cs) * group_size`` records are supplied, only the
    available records are plotted instead of indexing past the end.
    """
    fig,ax = plt.subplots(figsize=(10,4))
    ax.set_box_aspect(1)

    # One colour/marker pair per series.
    cs = ['#8ebeda', '#8ebeda', '#a6c858', '#a6c858']#'#8ebeda' '#a6c858', '#ca4a58', '#e0ad41'
    markers = ['o','^','o','^']

    # Convert the x values once rather than on every loop iteration.
    xs = list(map(int,data['newx']))
    n_points = min(len(xs), len(cs)*group_size)
    cps = [100*(np.log(size)+2)/8 for size in data['marker_size'][:n_points]]

    for jindx, j in enumerate(range(0,n_points,group_size)):
        stop = min(j+group_size, n_points)
        for i in range(j,stop):
            # Raw string keeps the LaTeX thin-space command "\;" literal.
            plot_label = (data['label'][i]+r', \;'+data['newlabel'][i]).replace(", $\\sigma = 0.014$","")
            print(i,plot_label)
            ax.scatter(
                xs[i],
                cps[i],
                linewidth = 1.0,
                alpha=1.0,
                c = cs[jindx],
                marker = markers[jindx],
                label = plot_label,
            )
        # Unlabelled connecting line, so it adds no legend entry of its own.
        ax.plot(
            xs[j:stop],
            cps[j:stop],
            linewidth = 1.0,
            c = cs[jindx],
            marker = markers[jindx]
        )
    ax.tick_params(labelsize=fontsize)
    ax.xaxis.set_major_locator(MultipleLocator(20000))
    ax.set_xlabel(r'$N$',fontsize=fontsize)
    ax.set_ylabel(r'CP',fontsize=fontsize)

    # Keying by label keeps one handle per distinct legend text.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(),
              fontsize=9,ncol=1,loc='best', bbox_to_anchor=(0.5, 0., 0.5, 0.5))

    write_figure(
        fig,
        output_path,
        filename_ending='ps',
        pad_inches=0.0,
        bbox_inches='tight'
    )

In [None]:
# Load the exported figure data once; every sweep slices the same records.
data = read_json(datapath+'_data.json')
settings = read_json(datapath+'_settings.json')

# Keys that are not sliced per record and are carried over untouched.
IGNORED_COLUMNS = ['outputs','x_group','y_group','z_group','annotate','hatch','x_id','y_id','z_id']

for sweep_id, slice_vals in sweep_data.items():
    srmseoutputpath = outputbasepath+"srmse"+figuretitle+sweep_id
    cpoutputpath = outputbasepath+"cp"+figuretitle+sweep_id

    # Positions of the records whose label belongs to this sweep.
    slice_key = 'label'
    slice_index = [i for i,v in enumerate(data[slice_key]) if v in slice_vals]

    assert slice_index, f"no record has a '{slice_key}' in {slice_vals!r}"

    data_slice = deepcopy(data)

    # Plain list indexing instead of np.array(...)[slice_index].tolist():
    # the columns hold ragged records (e.g. [name, [a, b]] or [value, []]),
    # which np.array rejects on recent NumPy and may coerce on older versions.
    for k in [j for j in data.keys() if j not in IGNORED_COLUMNS]:
        data_slice[k] = [data_slice[k][i] for i in slice_index]
    print(data_slice)

    # Split each x record into its first element (legend text) and its second
    # element (x coordinate).
    # NOTE(review): the plot functions call int() on every 'newx' value, while
    # the recorded output of this cell shows x records whose second element is
    # a pair such as [10, 1000] — confirm the expected schema of 'x'.
    data_slice['newlabel'] = [record[0] for record in data_slice['x']]
    data_slice['newx'] = [record[1] for record in data_slice['x']]
    for k,v in data_slice.items():
        # len() rather than np.shape(): np.shape would have to convert the
        # (possibly ragged) column to an array first.
        print(k,len(v))

    srmse_plot(data_slice,srmseoutputpath)
    # cp_plot(data_slice,cpoutputpath)

{'x': [['Disjoint', [10, 1000]], ['Disjoint', [50, 200]], ['Disjoint', [100, 100]], ['Disjoint', [500, 20]], ['Disjoint', [1000, 10]], ['Disjoint', [5000, 2]], ['Disjoint', [10000, 1]], ['Joint', [10, 1000]], ['Joint', [50, 200]], ['Joint', [100, 100]], ['Joint', [500, 20]], ['Joint', [1000, 10]], ['Joint', [5000, 2]], ['Joint', [10000, 1]]], 'y': [[1.5763506889343262, []], [1.4428880214691162, []], [1.4113309383392334, []], [1.34134042263031, []], [1.3379679918289185, []], [1.8632210493087769, []], [0.726523756980896, []], [1.5772409439086914, []], [1.4432698488235474, []], [1.4094140529632568, []], [1.3465971946716309, []], [1.3223631381988525, []], [1.8881617784500122, []], [0.7227044701576233, []]], 'z': [[[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []], [[], []]], 'x_group': [], 'y_group': [], 'z_group': [], 'marker_size': [185.6903065644204, 195.89758031567, 189.45417168302507, 147.91762313932662, 139