In [68]:
import os
import re
import subprocess
import warnings
from collections import defaultdict
from copy import deepcopy

import numpy as np
import matplotlib as mpl
mpl.use('ps')
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from matplotlib.lines import Line2D
from matplotlib.ticker import MultipleLocator
from matplotlib.legend_handler import HandlerTuple
# from IPython.core.display import display, HTML

from gensit.utils.misc_utils import *
from gensit.static.plot_variables import LATEX_RC_PARAMETERS

In [69]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
# LaTeX font configuration: apply the project's rcParams (LATEX_RC_PARAMETERS,
# imported from gensit.static.plot_variables) globally before any figure is created.
# NOTE(review): the labels used further down rely on custom LaTeX macros
# (e.g. \mytablecolsums, \frameworktag); presumably they are defined in the
# preamble carried by LATEX_RC_PARAMETERS — confirm.
mpl.rcParams.update(LATEX_RC_PARAMETERS)

# Cumulative SRMSE & coverage probability (CP) vs. iteration, by constraints and method
variable = table

sigma = high

constraints = rowsums, doubly_and_20percent_cells

In [75]:
# Input: prefix of the exported plot data; '_data.json' is appended when it is read.
datapath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/paper_figures/figure2/cumulative_srmse_and_cp_by_method_label_title&sigma_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0"

# Output: directory prefix and filename infix shared by all generated figures.
outputbasepath = "../../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/paper_figures/figure2_v2/"
figuretitle = "_by_constraints_method"

# Sweep id (used as the output filename suffix) -> labels of the data rows that
# belong to that sweep. Every label is a raw string: "\m", "\s" and "\t" are LaTeX
# macros here, and in a normal literal they are invalid escapes (SyntaxWarning on
# Python 3.12+) or, for "\t", a tab character.
sweep_data = {
    # "|total|": [
    #     r"$\mytabletotal$, $\sigma = 0.014$",
    #     r"$\mytabletotal$, $\sigma = 0.141$",
    #     r"$\mytabletotal$, $\sigma = \text{learned}$"
    # ],
    "|colsums|": [
        r"$\mytablecolsums$, $\sigma = 0.014$",
        r"$\mytablecolsums$, $\sigma = 0.141$",
        r"$\mytablecolsums$, $\sigma = \text{learned}$"
    ],
    "|doubly|": [
        r"$\mytablecolsums,\mytablerowsums$, $\sigma = 0.014$",
        r"$\mytablecolsums,\mytablerowsums$, $\sigma = 0.141$",
        r"$\mytablecolsums,\mytablerowsums$, $\sigma = \text{learned}$"
    ],
    "|doubly_10percent_cells|": [
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_1}$, $\sigma = 0.014$",
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_1}$, $\sigma = 0.141$",
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_1}$, $\sigma = \text{learned}$"
    ],
    "|doubly_20percent_cells|": [
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_2}$, $\sigma = 0.014$",
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_2}$, $\sigma = 0.141$",
        r"$\mytablecolsums,\mytablerowsums,\mytablecells{_2}$, $\sigma = \text{learned}$"
    ]
}

In [87]:
# Font size of tick labels and axis titles.
fontsize = 14

# The keys below are LaTeX snippets that must match the data labels exactly.
# They are all raw strings so that backslashes are never interpreted as Python
# escapes ("\s" and "\z" are invalid escapes, "\f" is a form feed, "\t" a tab).

# Sigma label -> colour.
SIGMA_COLORS = {
    r"$\sigma = 0.014$":"#A2CFFE",
    r"$\sigma = 0.141$":"#9B59B6",
    r"$\sigma = \text{learned}$":"#2ECC40"
}
# Method label -> marker.
METHOD_MARKERS = {
    r"\zachosframeworktag":"+",
    r"Joint (\frameworktag)":"*",
    r"Disjoint (\frameworktag)":"o",
}
# Method label -> line style (same keys as METHOD_MARKERS).
METHOD_LINESTYLES = {
    r"\zachosframeworktag":":",
    r"Joint (\frameworktag)":"-",
    r"Disjoint (\frameworktag)":"--",
}

# Columns copied into each tree leaf when rows are grouped for plotting.
KEEP_KEYS = ["iteration","srmse","cp","sigma","method"]

In [91]:
def myplot(data,output_path,groups,x_label,y_label):
    """Plot ``y_label`` against ``x_label`` and export the figure as PS and PDF.

    Every row of ``data`` is drawn as a scatter point whose colour is looked up
    in ``SIGMA_COLORS`` by its ``'sigma'`` entry and whose marker is looked up in
    ``METHOD_MARKERS`` by its ``'method'`` entry. Each row is also inserted into
    a tree with ``add_leaf`` under the path given by its ``groups`` values; every
    leaf returned by ``traverse_tree`` is then drawn as one line styled by the
    leaf's first ``'method'`` / ``'sigma'`` entry. A custom legend with one entry
    per sigma (markers only) and one entry per method (line samples) is added.

    Parameters
    ----------
    data : mapping of str -> sequence
        Column-oriented records. Must provide ``x_label``, ``y_label``,
        ``'sigma'``, ``'method'`` and every key listed in ``groups``.
    output_path : str
        Output path without extension. The figure is written through
        ``write_figure`` with ``filename_ending='ps'``; ``output_path + '.ps'``
        is then converted to ``output_path + '.pdf'`` with ``ps2pdf`` and
        cropped in place with ``pdfcrop``.
    groups : sequence of str
        Keys of ``data`` whose per-row values form the tree path of that row.
    x_label, y_label : str
        Keys of ``data`` for the two axes. ``y_label.upper()`` is the y-axis
        title; ``'cp'`` additionally fixes the y-range to [0, 100], and any
        ``y_label`` other than ``'srmse'`` gets a two-column legend.

    Returns
    -------
    None
    """
    fig,ax = plt.subplots()

    num_data = np.shape(data[y_label])[0]

    # Group rows by `groups` while scattering them one by one.
    data_tree = defaultdict(lambda: defaultdict())
    for i in range(num_data):
        ax.scatter(
            data[x_label][i],
            data[y_label][i],
            linewidth = 1.0,
            alpha=1.0,
            c = SIGMA_COLORS[data['sigma'][i]],
            marker = METHOD_MARKERS[data['method'][i]],
        )
        add_leaf(data_tree, [data[g][i] for g in groups], {k:v[i] for k,v in data.items() if k in KEEP_KEYS})

    # One line per leaf of the tree.
    for _path, leaf in traverse_tree(data_tree,KEEP_KEYS):
        ax.plot(
            leaf[x_label],
            leaf[y_label],
            linewidth = 1.0,
            linestyle = METHOD_LINESTYLES[leaf['method'][0]],
            c = SIGMA_COLORS[leaf['sigma'][0]],
            marker = METHOD_MARKERS[leaf['method'][0]],
        )

    ax.tick_params(labelsize=fontsize)
    ax.set_xlabel(r'$N$',fontsize=fontsize)
    ax.set_ylabel(y_label.upper(),fontsize=fontsize)

    if y_label == 'cp':
        ax.set_ylim(0,100)

    # Custom dash patterns for the legend line samples. The keys are raw strings:
    # in a normal literal "\f" is a form feed, so 'Disjoint (\frameworktag)' never
    # matched the METHOD_MARKERS key and its pattern was silently skipped.
    legend_dashes = {
        r'Disjoint (\frameworktag)': [2,1],
        r'\zachosframeworktag': [1,1],
    }

    # Legend label -> handles. Sigma entries are filled first so that they are
    # listed before the method entries (dicts keep insertion order).
    legend_tree = {}
    for noise,color in SIGMA_COLORS.items():
        for marker in METHOD_MARKERS.values():
            # Marker-only sample: the colour in every method's marker.
            legend_tree.setdefault(noise,[]).append(
                Line2D(
                    xdata=[],
                    ydata=[],
                    color=color,
                    marker=marker,
                    linewidth=0,
                    markersize=5,
                )
            )
    for color in SIGMA_COLORS.values():
        for method,marker in METHOD_MARKERS.items():
            # Line sample: the method's line style and marker in every colour.
            style = Line2D(
                xdata=[0,100],
                ydata=[0,0],
                color=color,
                marker=marker,
                linestyle=METHOD_LINESTYLES[method],
                linewidth=1,
                markersize=5,
            )
            if method in legend_dashes:
                style.set_dashes(legend_dashes[method])
            legend_tree.setdefault(method,[]).append(style)

    # Each legend row shows all handles of its group side by side (HandlerTuple).
    ax.legend(
        handles=[tuple(lines) for lines in legend_tree.values()],
        labels=list(legend_tree),
        handler_map={tuple: HandlerTuple(ndivide=None)},
        handlelength=6,
        handleheight=0.5,
        markerscale=1,
        loc='best',
        fontsize=9,
        frameon=True,
        ncol=1 if y_label == 'srmse' else 2
    )

    makedir(os.path.dirname(output_path))
    write_figure(
        fig,
        output_path,
        filename_ending='ps',
        pad_inches=0.0,
        bbox_inches='tight'
    )

    # Convert to PDF, then crop in place. Tool output stays suppressed, but a
    # non-zero exit status is reported instead of being dropped silently.
    ps_path,pdf_path = output_path+'.ps',output_path+'.pdf'
    for command in (
        ["ps2pdf", ps_path, pdf_path],
        ["pdfcrop", "--margins", "0 0 0 0", pdf_path, pdf_path],
    ):
        return_code = subprocess.call(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        if return_code != 0:
            warnings.warn(
                f"'{command[0]}' exited with status {return_code} for {output_path}",
                stacklevel=2
            )

In [96]:
group_keys = ['method','sigma']
sort_keys = ['iteration']
slice_key = 'label'

# Columns that are not sliced per row and are left out of the plotted data.
IGNORED_COLUMNS = ['outputs','x_group','y_group','z_group','annotate','hatch','x_id','y_id','z_id','z']

# Matches every "$...$" segment of a label.
label_pattern = r'\$(.*?)\$'

# The input does not depend on the sweep: read it once instead of once per sweep.
# It is never modified below, so no defensive copy is needed either.
data = read_json(datapath+'_data.json')

for sweep_id, slice_vals in sweep_data.items():
    srmseoutputpath = outputbasepath+"cumulative_srmse"+figuretitle+sweep_id
    cpoutputpath = outputbasepath+"cumulative_cp"+figuretitle+sweep_id

    # Rows whose label belongs to this sweep.
    slice_index = [i for i,v in enumerate(data[slice_key]) if v in slice_vals]
    assert slice_index, f"no rows of '{slice_key}' match sweep {sweep_id}"

    data_slice = {
        k: np.array(v)[slice_index].tolist()
        for k,v in data.items()
        if k not in IGNORED_COLUMNS
    }

    # 'x' holds (method, iteration) pairs; the first entry of each 'y' row is the SRMSE.
    data_slice['method'] = np.array(data_slice['x'])[:,0].tolist()
    data_slice['iteration'] = list(map(int,np.array(data_slice['x'])[:,1].tolist()))
    data_slice['srmse'] = np.array(data_slice['y'])[:,0].tolist()
    # NOTE(review): presumably this inverts the marker-size encoding applied when
    # the data were exported, giving a percentage — confirm the constants 2 and 8.
    data_slice['cp'] = np.array(100*(np.log(data_slice['marker_size'])+2)/8).tolist()

    # Split every label "$<constraints>$, $<sigma>$" into its two LaTeX parts.
    data_slice['sigma'] = []
    data_slice['table_constraints'] = []
    for item in data_slice['label']:
        matches = re.findall(label_pattern, item)
        if len(matches) != 2:
            # Skipping the row silently would misalign these two columns with
            # every other column of data_slice.
            raise ValueError(f"expected two '$...$' segments in label {item!r}, found {len(matches)}")
        data_slice['table_constraints'].append(f"${matches[0]}$")
        data_slice['sigma'].append(f"${matches[1]}$")

    for k in ('x','y'):
        del data_slice[k]

    # Sort all columns by the values of the sort keys.
    sorted_indices = sorted(range(len(data_slice['cp'])), key=lambda i: tuple([data_slice[k][i] for k in sort_keys]))

    # Reorder each column; entries that are not row-aligned sequences are left as they are.
    for k, v in list(data_slice.items()):
        try:
            data_slice[k] = np.array([v[i] for i in sorted_indices])
        except (TypeError, IndexError, KeyError, ValueError):
            pass

    print(sweep_id,data_slice['cp'].shape[0])
    print(data_slice['iteration'])
    for y_label, figure_path in (('srmse',srmseoutputpath),('cp',cpoutputpath)):
        myplot(
            data_slice,
            output_path=figure_path,
            groups=group_keys,
            x_label='iteration',
            y_label=y_label
        )

|colsums| 48
[ 10000  10000  10000  10000  10000  10000  10000  10000  20000  20000
  20000  20000  20000  20000  20000  20000  40000  40000  40000  40000
  40000  40000  40000  40000  60000  60000  60000  60000  60000  60000
  60000  60000  80000  80000  80000  80000  80000  80000  80000  80000
 100000 100000 100000 100000 100000 100000 100000 100000]


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


|doubly| 48
[ 10000  10000  10000  10000  10000  10000  10000  10000  20000  20000
  20000  20000  20000  20000  20000  20000  40000  40000  40000  40000
  40000  40000  40000  40000  60000  60000  60000  60000  60000  60000
  60000  60000  80000  80000  80000  80000  80000  80000  80000  80000
 100000 100000 100000 100000 100000 100000 100000 100000]


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


|doubly_10percent_cells| 48
[ 10000  10000  10000  10000  10000  10000  10000  10000  20000  20000
  20000  20000  20000  20000  20000  20000  40000  40000  40000  40000
  40000  40000  40000  40000  60000  60000  60000  60000  60000  60000
  60000  60000  80000  80000  80000  80000  80000  80000  80000  80000
 100000 100000 100000 100000 100000 100000 100000 100000]


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


|doubly_20percent_cells| 48
[ 10000  10000  10000  10000  10000  10000  10000  10000  20000  20000
  20000  20000  20000  20000  20000  20000  40000  40000  40000  40000
  40000  40000  40000  40000  60000  60000  60000  60000  60000  60000
  60000  60000  80000  80000  80000  80000  80000  80000  80000  80000
 100000 100000 100000 100000 100000 100000 100000 100000]


The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
