In [None]:
import os
import glob
import subprocess
import numpy as np
import pandas as pd
import matplotlib as mpl
mpl.use('ps')
import matplotlib.pyplot as plt

from copy import deepcopy
from tqdm.auto import tqdm
from matplotlib.lines import Line2D
from matplotlib.ticker import MultipleLocator
from matplotlib.legend_handler import HandlerTuple
# from IPython.core.display import display, HTML

from gensit.utils.misc_utils import *
from gensit.utils.math_utils import *
from gensit.static.plot_variables import LATEX_RC_PARAMETERS

In [2]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
%load_ext autoreload
%autoreload 2

In [3]:
# LaTeX font configuration
mpl.rcParams.update(LATEX_RC_PARAMETERS)

# SRMSE & CP vs (iterations, ensemble size) by sigma, scheme and table constraint
variable = table

sigma = low,high,learned

constraints = total,rowsums,doubly,doubly_10percent_cells,doubly_20percent_cells

In [32]:
datapath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/paper_figures/figure3/exploration_exploitation_tradeoff_srmse_cp_vs_method_epoch_seed_label_sigma&title_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0"

outputbasepath ="../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/paper_figures/figure3_v2/"

figuretitle = "_vs_iterations_x_ensemble_size_by_constraints_method"

sweep_data = {
    "|total|": [
        "$\sigma = 0.014$, $\mytabletotal$",
        "$\sigma = 0.141$, $\mytabletotal$",
        r"$\sigma = \text{learned}$, $\mytabletotal$"
    ],
    "|colsums|": [
        "$\sigma = 0.014$, $\mytablecolsums$",
        "$\sigma = 0.141$, $\mytablecolsums$",
        r"$\sigma = \text{learned}$, $\mytablecolsums$"
    ],
    "|doubly|": [
        "$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums$",
        "$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums$"
    ],
    "|doubly_10percent_cells|": [
        r"$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$",
        r"$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$"
    ],
    "|doubly_20percent_cells|": [
        r"$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$",
        r"$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$"
    ]
}

In [33]:
fontsize = 14

In [34]:
def add_leaf(tree, path, values_dict):
    """
    Add a leaf node to the tree where each leaf is a dictionary of lists.
    If any intermediate branches don't exist, they will be created automatically.
    
    Args:
        tree (defaultdict): The tree structure.
        path (list): A list of keys specifying the path to the leaf node.
        values_dict (dict): A dictionary where the keys are the list names and the values are the elements to be appended.
    """
    node = tree
    for key in path[:-1]:  # Traverse until the second-to-last node
        # If the branch doesn't exist, create it (automatically adds intermediate branches)
        if key not in node:
            node[key] = defaultdict(lambda: defaultdict())  # Create a new branch if not already existing
        node = node[key]  # Navigate deeper into the tree
    # If the node is a leaf, initialize it as a dictionary of lists if not already created
    if path[-1] not in node:
        node[path[-1]] = {k: [] for k in values_dict.keys()}  # Initialize empty lists for the keys in the values_dict
    # Append values to each list in the leaf node's dictionary
    for key, value in values_dict.items():
        node[path[-1]][key].append(value)

# Traversal function that collects paths and leaf data
def traverse_tree(tree,leaf_keys):
    """
    Traverse the tree and collect all leaf node data (path and lists).
    
    Args:
        tree (defaultdict): The tree structure to traverse.
        
    Returns:
        list: A list of tuples, where each tuple contains the path and leaf node data (dictionary of lists).
    """
    paths_and_leaves = []

    def traverse_and_collect(node, path=[]):
        """
        Recursively traverse the tree and collect leaf nodes with their paths.
        """
        if any([lk == nk for lk in leaf_keys for nk in node.keys()]):
            # We are at a leaf node, which is a dictionary of lists
            paths_and_leaves.append((path, node))
        else:
            for key, child_node in node.items():
                traverse_and_collect(child_node, path + [key])

    traverse_and_collect(tree)
    return paths_and_leaves


In [43]:

SIGMA_COLORS = {
    "$\sigma = 0.014$":"#A2CFFE",
    "$\sigma = 0.141$":"#9B59B6",
    r"$\sigma = \text{learned}$":"#2ECC40"
}
SCHEME_MARKERS = {
    "Joint":"*",
    "Disjoint":"o"
}
SCHEME_LINESTYLES = {
    "Joint":"solid",
    "Disjoint":"dashed"
}

KEEP_KEYS = ["iteration_ensemble","srmse","cp","sigma","scheme"]

def myplot(data,output_path,groups,y_label):
    fig,ax = plt.subplots()#figsize=(10,15)
    # ax.set_box_aspect(1)

    num_data = np.shape(data[y_label])[0]
    
    # Get unique elements without changing their order of appearance
    xlabels = list(dict.fromkeys(data['iteration_ensemble']))
    x = np.arange(1,len(xlabels)*5,5)
    xlabel_dict = dict(zip(xlabels,x))

    # Create a defaultdict that defaults to another defaultdict
    data_tree = defaultdict(lambda: defaultdict())
    for i in range(num_data):
        _ = ax.scatter(
            xlabel_dict[data['iteration_ensemble'][i]],
            data[y_label][i],
            linewidth = 1.0,
            alpha=1.0,
            c = SIGMA_COLORS[data['sigma'][i]],
            marker = SCHEME_MARKERS[data['scheme'][i]],
        )
        # print(' > '.join([data[g][i] for g in groups]))
        add_leaf(data_tree, [data[g][i] for g in groups], {k:v[i] for k,v in data.items() if k in KEEP_KEYS})
        
    # Collect paths and leaf data
    leaf_data = traverse_tree(data_tree,KEEP_KEYS)

    # Print out the collected paths and leaf node data
    for path, leaf in leaf_data:
        print(f"Path: {' -> '.join(path)}")
        # print(json.dumps({k:v for k,v in leaf.items() if k in ['iteration_ensemble','srmse']},indent=2))
        _ = ax.plot(
            [xlabel_dict[v] for v in leaf['iteration_ensemble']],
            leaf[y_label],
            linewidth = 1.0,
            linestyle = SCHEME_LINESTYLES[leaf['scheme'][0]],
            c = SIGMA_COLORS[leaf['sigma'][0]],
            marker = SCHEME_MARKERS[leaf['scheme'][0]],
        )

    ax.tick_params(labelsize=fontsize)
    ax.set_xticks(x)
    ax.set_xticklabels(xlabels,rotation=30, ha="right")
    ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(x))
    
    plt.xlabel(r'$(N,E)$',fontsize=fontsize)
    plt.ylabel(y_label.upper(),fontsize=fontsize)

    if y_label == 'cp':
        plt.ylim(0,100)

    legend_tree = {}
    # Adding custom legend entries: one for each noise regime with all variations
    for m in range(2):
        for noise,color in SIGMA_COLORS.items():
            for scheme in SCHEME_MARKERS.keys():
                line_style = SCHEME_LINESTYLES[scheme]
                marker = SCHEME_MARKERS[scheme]
                
                if m == 0:
                    style = Line2D(
                        xdata=[0,100],
                        ydata=[0,0], 
                        color=color, 
                        marker=marker, 
                        linestyle=line_style, 
                        label=noise,
                        linewidth=1,
                        markersize=5,
                    )
                    legend_tree.setdefault(noise,[]).append(style)
                else:
                    style = Line2D(
                        xdata=[0,100],
                        ydata=[0,0], 
                        color=color, 
                        marker=marker, 
                        linestyle=line_style, 
                        label=scheme,
                        linewidth=1,
                        markersize=5, 
                    )
                    legend_tree.setdefault(scheme,[]).append(style)

    # Manually create legend handles using the Line2D objects
    handles,labels = [],[]
    for label, lines in legend_tree.items():
        handles.append(tuple(lines))
        labels.append(label)

    # Add the custom legend (only use one label per group)
    plt.legend(
        handles=handles, 
        labels=labels, 
        handler_map={tuple: HandlerTuple(ndivide=None)},
        handlelength=4,
        handleheight=0.5,
        markerscale=1,
        loc='best', 
        fontsize=9, 
        frameon=True,
        ncol=2,
    )

    # fig.tight_layout(rect=(0, 0, 0.7, 1.1))
    # fig.tight_layout()

    # plt.show()
    makedir(os.path.dirname(output_path))
    write_figure(
        fig,
        output_path,
        filename_ending='ps',
        pad_inches=0.0,
        bbox_inches='tight'
    )

    subprocess.run(["ps2pdf", output_path+'.ps', output_path+'.pdf'], check=True)
    subprocess.run(["pdfcrop", "--margins", "0 0 0 0", output_path+'.pdf', output_path+'_cropped.pdf'], check=True)



In [None]:
group_keys = ['scheme','sigma'] 
sort_keys = ['iteration','ensemble']

for sweep_id, slice_vals in sweep_data.items():
    srmseoutputpath = outputbasepath+"srmse"+figuretitle+sweep_id
    cpoutputpath = outputbasepath+"cp"+figuretitle+sweep_id

    data = read_json(datapath+'_data.json')
    settings = read_json(datapath+'_settings.json')

    slice_key = 'label'
    slice_index = []
    for i,v in enumerate(data[slice_key]):
        if any([v == sv for sv in slice_vals]):
            # print(v)
            slice_index.append(i)
    
    assert slice_index

    data_slice = deepcopy(data)
    IGNORED_COLUMNS = ['outputs','x_group','y_group','z_group','annotate','hatch','x_id','y_id','z_id','z']

    dropped_keys = []
    for k in data_slice.keys():
        if k not in IGNORED_COLUMNS:
            data_slice[k] = np.array(data_slice[k])[slice_index].tolist()
        else:
            dropped_keys.append(k)
    
    data_slice['scheme'] = np.array(data_slice['x'])[:,0].tolist()
    data_slice['iteration_ensemble'] = [str((n,e)) for n,e in np.array(data_slice['x'])[:,1]]
    data_slice['iteration'] = [n for n,e in np.array(data_slice['x'])[:,1]]
    data_slice['ensemble'] = [e for n,e in np.array(data_slice['x'])[:,1]]
    data_slice['srmse'] = np.array(data_slice['y'])[:,0].tolist()
    data_slice['cp'] = np.array(100*(np.log(data['marker_size'])+2)/8).tolist()
    

    # Pattern to match LaTeX-style expressions
    label_pattern = r'\$(.*?)\$'
    # Separate into two lists
    data_slice['sigma'] = []
    data_slice['table_constraints'] = []

    for item in data_slice['label']:
        matches = re.findall(label_pattern, item)
        if len(matches) == 2:
            data_slice['sigma'].append(f"${matches[0]}$")
            data_slice['table_constraints'].append(f"${matches[1]}$")
    
    for k in dropped_keys+['x','y','colour']:
        del data_slice[k]

    # for k,v in data_slice.items():
        # if len(np.shape(v)) > 0 and np.shape(v)[0] > 0:
            # print(k,np.shape(v))
        # else:
            # data

    # Sort all lists based on the values of the selected key
    sorted_indices = sorted(range(len(data_slice['scheme'])), key=lambda i: tuple([data_slice[k][i] for k in sort_keys]))

    # Reorder each list in the dictionary
    for k, v in data_slice.items():
        try:
            data_slice[k] = np.array([v[i] for i in sorted_indices])
        except:
            pass
    
    print(sweep_id,data_slice['cp'].shape[0])
    print(data_slice['cp'])
    # print(data_slice['table_constraints'])
    # myplot(data_slice,srmseoutputpath,group_keys,y_label='srmse')
    # myplot(data_slice,cpoutputpath,group_keys,y_label='cp')

|total| 42
[88.46153846 95.81939799 78.01003344 84.36454849 85.45150502 90.30100334
 86.4548495  92.80936455 28.34448161 35.45150502 42.47491639 90.96989967
 84.7826087  91.55518395 28.92976589 36.12040134 43.31103679 90.55183946
 81.77257525 88.87959866 33.94648829 40.55183946 49.41471572 87.45819398
 80.68561873 88.21070234 44.23076923 51.83946488 60.95317726 86.70568562
 79.76588629 27.5083612  68.72909699 82.35785953 84.61538462 81.18729097
 72.82608696 85.03344482 34.7826087  41.80602007 89.54849498 89.96655518]
|colsums| 42
[88.46153846 95.81939799 78.01003344 84.36454849 85.45150502 90.30100334
 86.4548495  92.80936455 28.34448161 35.45150502 42.47491639 90.96989967
 84.7826087  91.55518395 28.92976589 36.12040134 43.31103679 90.55183946
 81.77257525 88.87959866 33.94648829 40.55183946 49.41471572 87.45819398
 80.68561873 88.21070234 44.23076923 51.83946488 60.95317726 86.70568562
 79.76588629 85.03344482 68.72909699 82.35785953 84.61538462 81.18729097
 72.82608696 27.5083612  3

: 