In [None]:
import json
from pprint import pprint

In [None]:
# Load the numbers reported in the original paper and the numbers from our own runs
with open("notebook_utils/ori_results.json", "r") as f:
    all_results = json.load(f)
with open("outputs/final_results.json", "r") as f:
    final_results = json.load(f)

# Expose our runs under the source names used in the tables
# (insertion order is preserved, so SOURCES below lists them in this order)
for source_name, run_key in (
    ("reproduction", "reprod"),
    ("lmbda tuned", "lmbda_reprod"),
    ("Lipschitz test", "Lipz_replic"),
):
    all_results[source_name] = final_results[run_key]

# Normalise every metric entry to {"mean": ..., "std": ...}; bare scalars
# (no std reported) get std=None so downstream code can treat all entries alike
for algs_of_source in all_results.values():
    for datasets_of_alg in algs_of_source.values():
        for metrics_of_dataset in datasets_of_alg.values():
            for metric_name, metric_value in list(metrics_of_dataset.items()):
                if isinstance(metric_value, dict):
                    continue
                metrics_of_dataset[metric_name] = {"mean": metric_value, "std": None}

# Axes of the results hierarchy, taken from the "baseline" source
SOURCES  = all_results.keys()
ALGS     = all_results["baseline"].keys()
DATASETS = all_results["baseline"]["kmeans"].keys()
METRICS  = all_results["baseline"]["kmeans"]["Synthetic"].keys()

# Bold LaTeX titles shown in the header row of each algorithm's table
ALG_TITLES = {
    "kmeans": r"\textbf{Fair $\boldsymbol{K}$-means}",
    "kmedian": r"\textbf{Fair $\boldsymbol{K}$-medians}",
    "ncut": r"\textbf{Fair Ncut}",
}

In [None]:
# Inspect the merged and normalised results before building the tables
pprint(all_results, indent=1)

In [None]:
def fill_template(
    column_count: int,
    alg: str,
    sources: list,
    rows: list,
    caption: str="WRITE CAPTION",
    ) -> str:
    """
    Fill LaTeX table template with custom values.

    The table has a "Datasets" column followed by two equally wide column
    groups, "Objective" and "Fairness error / Balance". Each group has one
    column per source, so column_count is normally 1 + 2 * len(sources).

    :param column_count: total number of columns, including the dataset column
    :param alg: the cluster option / algorithm; must be a key of ALG_TITLES
    :param sources: list of sources, repeated once per column group in the header
    :param rows: list of lists of (already LaTeX-formatted) column values
    :param caption: text placed in the table caption

    :raises ValueError: if the data columns cannot be split into two equal groups

    :return LaTeX formatted table as string
    """
    data_columns = column_count - 1
    # The header splits the data columns into two multicolumn groups, so they
    # must be a positive even number or the LaTeX would be malformed.
    if data_columns < 2 or data_columns % 2:
        raise ValueError(
            f"column_count must be 1 plus an even, positive number of data columns; got {column_count}"
        )
    group_width = data_columns // 2

    layout = f"|l|{'|'.join('c' * data_columns)}|"
    title = ALG_TITLES[alg]
    source_header = "& " + " & ".join(sources * 2) + r" \\"
    content = "\n".join(" & ".join(columns) + r" \\" for columns in rows)
    label = f"tab:comparison_{'VS'.join(sources)}_{alg}"

    return r"""
\begin{table}[]
    \centering
    \begin{tabular}{"""+ layout +r"""}
        \hline
            & \multicolumn{"""+ str(data_columns) +r"""}{c|}{"""+ title +r"""} \\
            \cline{2-"""+ str(column_count) +r"""}
        \textbf{Datasets} & \multicolumn{"""+ str(group_width) +r"""}{c|}{Objective} &   \multicolumn{"""+ str(group_width) +r"""}{c|}{Fairness error / Balance} \\ \cline{2-"""+ str(column_count) +r"""}
        """+ source_header +r"""
        \hline
        """+ content +r"""
        \hline
    \end{tabular}
    \caption{"""+ caption +r"""}
    \label{"""+ label +r"""}
\end{table}
"""

In [None]:
def bold_best(values: list, lower_is_better: bool=True) -> list:
    r"""
    Return list of strings where the best one is bolded, LaTeX style.

    Only finite numeric means (int/float, negatives included) take part in
    choosing the best value. Non-numeric entries such as "-" are rendered as-is
    and never bolded. If there is no numeric entry, nothing is bolded. Means are
    rounded to 2 decimals before comparison, so every value that ties at that
    precision is bolded.

    :param values: list of dicts with values, eg [{"mean": 1, "std": None}, {"mean": 2, "std": 0.5}]
    :param lower_is_better: bolds the lower value if True, the higher value otherwise

    :return list of LaTeX strings like ["\textbf{1}", "2($\pm 0.50$)"]
    """
    import math

    numeric_means = [
        v["mean"] for v in values
        if isinstance(v["mean"], (int, float))
        and not isinstance(v["mean"], bool)
        and math.isfinite(v["mean"])
    ]
    best = None
    if numeric_means:
        best = round(min(numeric_means) if lower_is_better else max(numeric_means), 2)

    str_values = []
    for v in values:
        mean = v["mean"]
        std = v["std"]

        if not isinstance(mean, str):
            mean = round(mean, 2)
        str_value = str(mean)
        if std is not None:
            str_value += r"($\pm " + f"{std:.2f}$)"
        if best is not None and mean == best:
            str_value = r"\textbf{" + str_value + "}"
        str_values.append(str_value)

    return str_values

def red_not_reprod(values: list, lower_is_better: bool=True, rel_tol: float=0.01) -> list:
    r"""
    Return list of color-coded strings, LaTeX style.

    The first entry is the reference value (e.g. the original paper). The
    second entry is the reproduced value. Only the first string is colored:
    - black (no color) if the reproduction is better than, equal to, or within
      rel_tol (relative to the reproduced mean) of the first value, indicating
      reproducibility;
    - orange if the reproduction is worse, but by at most one std of the
      reproduction;
    - red if the reproduction is worse by more than one std. When the
      reproduction has no std, any shortfall beyond the tolerance is red.

    The reproduced mean is rounded to 2 decimals before comparing. The input
    dicts are not modified.

    :param values: list of (at least) two dicts with values, eg [{"mean": 1, "std": None}, {"mean": 2, "std": 0.5}]
    :param lower_is_better: direction of the metric; if False, higher values are better
    :param rel_tol: relative tolerance below which a worse reproduction still counts as reproduced

    :return list of LaTeX strings like ["\color{red}1.00\color{black}", "2.00($\pm 0.50$)"]
    """
    # Work on shallow copies so the caller's dicts (which live in all_results)
    # are not rewritten with the rounded mean.
    display = [dict(v) for v in values]
    display[1]["mean"] = round(display[1]["mean"], 2)
    original = display[0]
    reprod = display[1]

    # Positive difference means the reproduction is at least as good as the original.
    difference = original["mean"] - reprod["mean"]
    if not lower_is_better:
        difference = -difference
    reprod_std = reprod["std"] if reprod["std"] is not None else 0.0

    if difference >= -rel_tol * abs(reprod["mean"]):
        color = None
    elif difference >= -reprod_std:
        color = "orange"
    else:
        color = "red"

    str_values = []
    for i, v in enumerate(display):
        mean = v["mean"]
        std = v["std"]
        str_value = mean
        if not isinstance(mean, str):
            str_value = f"{mean:.2f}"
        if std is not None:
            str_value += r"($\pm " + f"{std:.2f}$)"
        if i == 0 and color:
            str_value = r"\color{" + color + "}" + str_value + r"\color{black}"
        str_values.append(str_value)

    return str_values

In [None]:
def _lookup_metric(source: str, alg: str, dataset: str, metric: str) -> dict:
    """Return the {"mean", "std"} entry for one table cell; the pseudo-source "lmbda tuned all" falls back to "reproduction"."""
    if source == "lmbda tuned all":
        try:
            return all_results["lmbda tuned"][alg][dataset][metric]
        except KeyError:
            source = "reproduction"
    return all_results[source][alg][dataset][metric]


def generate_table_by_alg(alg: str, sources: list, bold_best_flag: bool=True) -> str:
    """
    Generate LaTeX table for the specified algorithm/cluster_option.

    Datasets that are missing from any requested source are skipped. The
    pseudo-source "lmbda tuned all" is exempt from this check. If every dataset
    is skipped, a table with no data rows is returned.

    :param alg: algorithm/cluster_option to generate table for
    :param sources: list consisting of a subset of the keys for all_results, eg ['original', 'reproduction'].
        The pseudo-source "lmbda tuned all" uses "lmbda tuned" results where available and
        "reproduction" results otherwise.
    :param bold_best_flag: if True, bolds the best values per metric for each row. If False, colour-codes
        the first source against the second with red_not_reprod (intended for exactly two sources).

    :return LaTeX table as string
    """
    for source in sources:
        assert source in SOURCES or source == "lmbda tuned all", f"Unknown source: {source!r}"

    # One dataset column, then one Objective and one fairness/balance column per source.
    column_count = 1 + 2 * len(sources)
    real_sources = [s for s in sources if s != "lmbda tuned all"]
    format_metric = bold_best if bold_best_flag else red_not_reprod

    rows = []
    for dataset in DATASETS:
        # Skip datasets for which some (real) source has no results.
        if any(alg not in all_results[s] or dataset not in all_results[s][alg] for s in real_sources):
            continue

        results_by_metric = {}
        for metric in METRICS:
            values = [_lookup_metric(source, alg, dataset, metric) for source in sources]
            # Balance is the only metric where higher is better.
            results_by_metric[metric] = format_metric(values, metric != "balance")

        columns = [dataset]
        columns.extend(results_by_metric["Objective"])
        columns.extend(
            f"{f} / {b}"
            for f, b in zip(results_by_metric["fairness error"], results_by_metric["balance"])
        )
        rows.append(columns)

    return fill_template(column_count, alg, sources, rows)

In [None]:
# Sources to compare. Other combinations used for the report:
#   ["baseline", "original", "reproduction"]
#   ["baseline", "reproduction"]
#   ["baseline", "lmbda tuned"]
#   ["lmbda tuned all", "Lipschitz test"]
sources = ["original", "reproduction"]

# The original-vs-reproduction comparison is colour-coded (red_not_reprod);
# every other comparison bolds the best value per metric (bold_best).
bold_best_flag = sources != ["original", "reproduction"]

for alg in ALGS:
    print()
    print(generate_table_by_alg(alg, sources, bold_best_flag))