In [1]:
import os
import re
import subprocess
import numpy as np
import pandas as pd
import matplotlib as mpl
mpl.use('ps')
import matplotlib.pyplot as plt

from copy import deepcopy
from tqdm.auto import tqdm
from collections import defaultdict
from matplotlib.lines import Line2D
from matplotlib.ticker import MultipleLocator
from matplotlib.legend_handler import HandlerTuple
# from IPython.core.display import display, HTML

from gensit.static.plot_variables import COLOR_NAMES
from gensit.utils.misc_utils import write_figure,makedir
from gensit.utils.misc_utils import add_leaf,traverse_tree,read_json
from gensit.static.plot_variables import LATEX_RC_PARAMETERS

In [2]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
%load_ext autoreload
%autoreload 2

In [3]:
# Apply the project's LATEX_RC_PARAMETERS (imported from
# gensit.static.plot_variables) to matplotlib's global rcParams, so every
# figure created after this cell uses them.
# NOTE(review): presumably these are LaTeX text/font settings - confirm in
# gensit.static.plot_variables.
mpl.rcParams.update(LATEX_RC_PARAMETERS)

# SRMSE & CP vs loss operator by sigma and table constraint
variable = table

sigma = low,high,learned

constraints = total,rowsums,doubly,doubly_10percent_cells,doubly_20percent_cells

In [14]:
# Input: prefix of the exported plot data; the sweep cell appends '_data.json'.
# Swap in the commented "tractable" path to reproduce the tractable-ODM variant.
# datapath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp3/paper_figures/figure4_v1/loss_function_validation_tractable_odms_label_sigma&title_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0"

datapath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp3/paper_figures/figure4_v1/loss_function_validation_intractable_odms_label_sigma&title_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0"

# Output: directory prefix and common file-name stem of the figures written
# by the sweep cell ("srmse"/"cp" + figuretitle + sweep id).
outputbasepath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp3/paper_figures/figure4_v2/"

figuretitle = "_vs_loss_operator_by_constraints_method_sigma"

# Maps a sweep id (used as the output file-name suffix) to the exact 'label'
# strings of the rows that belong to that sweep.
# All labels are raw strings: they contain LaTeX backslashes (\sigma,
# \mytable...), which are invalid escape sequences in ordinary string literals.
sweep_data = {
    # "|total|": [
    #     r"$\sigma = 0.014$, $\mytabletotal$",
    #     r"$\sigma = 0.141$, $\mytabletotal$",
    #     r"$\sigma = \text{learned}$, $\mytabletotal$"
    # ],
    # "|colsums|": [
    #     r"$\sigma = 0.014$, $\mytablecolsums$",
    #     r"$\sigma = 0.141$, $\mytablecolsums$",
    #     r"$\sigma = \text{learned}$, $\mytablecolsums$"
    # ],
    "|doubly|": [
        r"$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums$",
        r"$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums$"
    ],
    "|doubly_10percent_cells|": [
        r"$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$",
        r"$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums,\mytablecells{_1}$"
    ],
    "|doubly_20percent_cells|": [
        r"$\sigma = 0.014$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$",
        r"$\sigma = 0.141$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$",
        r"$\sigma = \text{learned}$, $\mytablecolsums,\mytablerowsums,\mytablecells{_2}$"
    ]
}

In [15]:
# Font size used by myplot for tick labels and axis labels.
fontsize = 14

# Marker and colour per noise regime, keyed by the exact LaTeX sigma label that
# the sweep cell extracts from each row's 'label'.
# Raw strings: the keys contain LaTeX backslashes (\sigma, \text), which are
# invalid escape sequences in ordinary string literals.
SIGMA_MARKERS = {
    r"$\sigma = 0.014$": "v",
    r"$\sigma = 0.141$": "^",
    r"$\sigma = \text{learned}$": ">"
}
SIGMA_COLORS = {
    r"$\sigma = 0.014$": "#A2CFFE",
    r"$\sigma = 0.141$": "#9B59B6",
    r"$\sigma = \text{learned}$": "#2ECC40"
}

# TABLE_CONSTRAINT_COLORS = {
#     r"$\mytabletotal$":COLOR_NAMES["tab20b_purple"],
#     r"$\mytablecolsums$":COLOR_NAMES["tab20b_green"],
#     r"$\mytablecolsums,\mytablerowsums$":COLOR_NAMES["tab20b_orange"],
#     r"$\mytablecolsums,\mytablerowsums,\mytablecells{_1}$":COLOR_NAMES["tab20b_red"],
#     r"$\mytablecolsums,\mytablerowsums,\mytablecells{_2}$":COLOR_NAMES["tab20c_blue"],
# }

# Columns copied into every leaf of the grouping tree built by myplot.
KEEP_KEYS = ["loss_operator", "srmse", "cp", "sigma", "table_constraints"]

In [None]:
def myplot(data,output_path,groups,x_label,y_label):
    """Plot ``y_label`` against the categorical ``x_label`` and write it to disk.

    Every row is drawn as a scatter point, then rows are grouped by the
    columns named in ``groups`` and each group is drawn as a solid line.
    Colour and marker are taken from ``SIGMA_COLORS`` / ``SIGMA_MARKERS``
    using the row's ``'sigma'`` value.

    Parameters
    ----------
    data : mapping of column name -> indexable sequence
        Must contain ``x_label``, ``y_label``, ``'sigma'`` and every column in
        ``groups``. Columns listed in ``KEEP_KEYS`` are copied into the
        grouping tree. All columns are indexed with the same row index, so
        they are assumed to be equally long and row-aligned.
    output_path : str
        Destination passed to ``write_figure``; its parent directory is
        created with ``makedir`` first.
    groups : list of str
        Column names whose values define one line per distinct combination.
    x_label, y_label : str
        Column names for the x (categorical) and y axes. When ``y_label`` is
        ``'cp'`` the y-axis is fixed to [0, 100].

    Returns
    -------
    None
        The figure is written in PostScript format as a side effect. The
        figure is not closed here.
    """
    fig, ax = plt.subplots()

    n_rows = np.shape(data[y_label])[0]

    # Unique x categories in order of first appearance, placed 5 units apart.
    xlabels = list(dict.fromkeys(data[x_label]))
    x = np.arange(1, len(xlabels)*5, 5)
    xlabel_dict = dict(zip(xlabels, x))

    # Nested mapping filled by add_leaf: one path per combination of `groups`.
    data_tree = defaultdict(lambda: defaultdict())
    for i in range(n_rows):
        sigma = data['sigma'][i]
        ax.scatter(
            xlabel_dict[data[x_label][i]],
            data[y_label][i],
            linewidth=1.0,
            alpha=1.0,
            c=SIGMA_COLORS[sigma],
            marker=SIGMA_MARKERS[sigma],
        )
        add_leaf(
            data_tree,
            [data[g][i] for g in groups],
            {k: v[i] for k, v in data.items() if k in KEEP_KEYS},
        )

    # One line per group; each leaf holds that group's rows, keyed by column.
    # NOTE(review): relies on traverse_tree returning (path, leaf) pairs whose
    # leaf maps each KEEP_KEYS column to a list - confirm in gensit.utils.
    for _path, leaf in traverse_tree(data_tree, KEEP_KEYS):
        sigma = leaf['sigma'][0]
        ax.plot(
            [xlabel_dict[v] for v in leaf[x_label]],
            leaf[y_label],
            linewidth=1.0,
            linestyle="solid",
            c=SIGMA_COLORS[sigma],
            marker=SIGMA_MARKERS[sigma],
        )

    ax.tick_params(labelsize=fontsize)
    ax.set_xticks(x)
    ax.set_xticklabels(xlabels, rotation=40, ha="right")
    ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(x))

    # Use the Axes created above rather than pyplot's implicit "current axes".
    ax.set_xlabel(x_label.replace('_', ' ').capitalize(), fontsize=fontsize)
    ax.set_ylabel(y_label.upper(), fontsize=fontsize)

    if y_label == 'cp':
        ax.set_ylim(0, 100)

    # Custom legend: one proxy line per noise regime.
    handles = [
        Line2D(
            xdata=[0, 100],
            ydata=[0, 0],
            color=SIGMA_COLORS[noise],
            marker=marker,
            linestyle="solid",
            label=noise,
            linewidth=1,
            markersize=5,
        )
        for noise, marker in SIGMA_MARKERS.items()
    ]
    labels = list(SIGMA_MARKERS)

    ax.legend(
        handles=handles,
        labels=labels,
        handlelength=2,
        handleheight=0.5,
        markerscale=1,
        loc='best',
        fontsize=9,
        frameon=True,
        # PostScript cannot render transparency and draws the frame opaque
        # anyway; asking for an opaque frame gives the same picture and is
        # expected to avoid the backend's transparency warning.
        framealpha=1.0,
        ncol=1,
    )

    makedir(os.path.dirname(output_path))
    write_figure(
        fig,
        output_path,
        figure_format='ps',
        pad_inches=0.0,
        bbox_inches='tight'
    )


In [17]:
group_keys = ['sigma']
sort_keys = ['loss_operator']

# Column whose value decides which sweep a row belongs to.
slice_key = 'label'
# Columns of the exported plot data that are not needed here and are dropped.
IGNORED_COLUMNS = ['outputs','x_group','y_group','z_group','annotate','hatch','x_id','y_id','z_id','z']
# Pattern to match LaTeX-style expressions ($...$) inside a label.
label_pattern = r'\$(.*?)\$'

# The input does not depend on the sweep, so it is read once up front.
data = read_json(datapath+'_data.json')

for sweep_id, slice_vals in tqdm(sweep_data.items(), total=len(sweep_data)):
    srmseoutputpath = outputbasepath+"srmse"+figuretitle+sweep_id
    cpoutputpath = outputbasepath+"cp"+figuretitle+sweep_id

    # Rows whose label is one of this sweep's labels.
    slice_index = [i for i, v in enumerate(data[slice_key]) if v in slice_vals]
    assert slice_index, f"no row of {datapath}_data.json has a label in sweep {sweep_id}"

    # Keep only the selected rows of every retained column. `data` itself is
    # left untouched (np.array copies) so later sweeps see the original.
    data_slice = {
        k: np.array(v)[slice_index].tolist()
        for k, v in data.items()
        if k not in IGNORED_COLUMNS
    }

    data_slice['loss_operator'] = np.array(data_slice['x'])[:,0].tolist()
    data_slice['srmse'] = np.array(data_slice['y'])[:,0].tolist()
    # Recover CP from the *sliced* marker sizes so it stays row-aligned with
    # the other columns of this sweep.
    # NOTE(review): presumably inverts the marker-size scaling applied when
    # the figure data was exported (marker_size = exp(8*cp/100 - 2)) - confirm.
    data_slice['cp'] = (100*(np.log(data_slice['marker_size'])+2)/8).tolist()

    # Split each label "$<sigma>$, $<constraints>$" into its two parts.
    data_slice['sigma'] = []
    data_slice['table_constraints'] = []
    for item in data_slice[slice_key]:
        matches = re.findall(label_pattern, item)
        if len(matches) != 2:
            # Skipping the row would misalign these columns with the others.
            raise ValueError(
                f"expected two $...$ groups in label {item!r}, found {len(matches)}"
            )
        data_slice['sigma'].append(f"${matches[0]}$")
        data_slice['table_constraints'].append(f"${matches[1]}$")

    # Raw plot columns are no longer needed; tolerate their absence.
    for k in ('x','y','colour'):
        data_slice.pop(k, None)

    # Row order that sorts the slice by the columns in sort_keys.
    sorted_indices = sorted(
        range(len(data_slice['srmse'])),
        key=lambda i: tuple(data_slice[k][i] for k in sort_keys)
    )

    # Reorder every column; a column that cannot be reordered is left as is.
    for k, v in data_slice.items():
        try:
            data_slice[k] = np.array([v[i] for i in sorted_indices])
        except (IndexError, KeyError, TypeError, ValueError):
            pass

    print(sweep_id,data_slice['cp'].shape[0])

    for y_label, output_path in (('srmse', srmseoutputpath), ('cp', cpoutputpath)):
        myplot(
            data=data_slice,
            output_path=output_path,
            groups=group_keys,
            x_label='loss_operator',
            y_label=y_label
        )

  0%|          | 0/3 [00:00<?, ?it/s]

|doubly| 36


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


|doubly_10percent_cells| 36


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


|doubly_20percent_cells| 36


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
