# Predictions

This notebook computes the prediction metrics, i.e., the images from the test split are passed to the model, and the model predicts the top-k classes.
These predictions are used to compute the following metrics:
- `error_rate`
- `hier_dist_mistake`
- `hier_dist`

For each model, we have 5 runs (5 different models trained with 5 different random seeds, for a total of 25 models trained on each dataset), which are used to estimate the error bars. The prediction metric values can be organized into tables or visualized on a scatter plot (error_rate vs. hier_dist_mistake).

The prediction metrics are computed by the `predictions.py` script on the server, and the results are saved with pickle as pandas dataframes.
This notebook reads these dataframes and produces the tables and plots.

In [None]:
import re
from pathlib import Path
from functools import partial
import subprocess

import matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import networkx as nx
import matplotlib.pyplot as plt

from pandas.io.formats.style import Styler

import utils

# Workaround for Firefox horizontal scroll for Dataframes
# https://github.com/jupyterlab/jupyterlab/issues/14625#issuecomment-1722137537
from IPython.display import display, HTML, Image, SVG
display(HTML("<style>.jp-OutputArea-output {display:flex}</style>"))

# Matplotlib theme: seaborn "paper" context with a serif (Times) font and a
# uniform size of 13 for tick, axis and legend labels.
sns.set_theme(context='paper', style='ticks', palette='colorblind')
plt.rc('font', family='serif', serif='Times')
# NOTE(review): usetex=True routes all figure text through LaTeX, so a working
# LaTeX installation is presumably needed to render the plots below - confirm.
plt.rc('text', usetex=True)
plt.rc('xtick', labelsize=13)
plt.rc('ytick', labelsize=13)
plt.rc('axes', labelsize=13)
plt.rc('legend',fontsize=13)

*Utility functions to generate highlighted tables in HTML and TeX format*

In [None]:
# Experiment key -> display name used in the TeX tables and in the plot legends.
# The insertion order is also the column order of the exported multi-index table.
tex_experiments_names = {
    "xe-onehot": "XE One-hot",
    "xe-mbm": "XE MBM",
    "xe-b3p": "XE B3P",
    "cd-bd": "CD BD",
    "cd-desc": "CD Desc.",
}

# Metric key -> row label used in the TeX tables.
tex_metrics_names = {
    "error_rate": "Error Rate",
    "hier_dist_mistake": "Hier. Dist. M.",
    "hier_dist": "Hier. Dist.",
}


def highlight_predictions(dfs: Styler, axis: int = 0) -> Styler:
    """
    Apply the "lower is better" styling shared by all prediction-metric tables.

    A reversed green background gradient is added and the minimum entry is
    set in bold; both are computed along ``axis``.

    Args:
        dfs (Styler): The pandas Styler wrapping the metrics table.
        axis (int, optional): Axis forwarded unchanged to
        ``background_gradient`` and ``highlight_min``. In the pandas
        convention 0 compares the values within each column and 1 the
        values within each row. Defaults to 0.

    Returns:
        Styler: The Styler returned by the last styling call.
    """
    with_gradient = dfs.background_gradient("Greens_r", axis=axis, low=1)
    return with_gradient.highlight_min(props="font-weight: bold", axis=axis)


def table_html(df: pd.DataFrame, axis: int = 0) -> Styler:
    """Build the highlighted HTML view of a prediction-metrics table.

    Values are shown with 3 decimals and styled by ``highlight_predictions``
    along ``axis``. The table is meant for picking the best experiment of
    each type while exploring; it is not used in the final paper.
    """
    rounded = df.style.format(precision=3)
    return highlight_predictions(rounded, axis=axis)


# Multi-index utils functions

def table_tex(
    df: pd.DataFrame,
    tex_experiments_names: dict[str, str],
    tex_metrics_names: dict[str, str],
) -> str:
    """Create the TeX table with the mean ± std of the prediction metrics for
    each experiment type. This table will be included in the paper.

    Besides its arguments the function reads two notebook-level globals:
    ``df_experiments`` (its "name" column is the grouping key) and
    ``hierarchy`` (its length decides how many row separators are kept).

    Args:
        df: Per-run prediction metrics, grouped through
            ``df_experiments["name"]``.
        tex_experiments_names: Experiment key -> column header. The key order
            is the column order of the table.
        tex_metrics_names: Metric key -> row label (index level 1).

    Returns:
        The table as a ``tabularx`` environment spanning ``\\linewidth``.

    Raises:
        AssertionError: If the LaTeX produced by pandas contains no
            ``tabular`` environment.
    """
    experiment_keys = list(tex_experiments_names.keys())
    grouped = df.groupby(df_experiments["name"])
    df_mean = grouped.mean()
    df_std = grouped.std().fillna(0)

    # Sort columns according to thesis
    # Transposing the table scales better to more levels
    styler = df_mean.T[experiment_keys].style
    df_std = df_std.T[experiment_keys]

    # The Styler hands the formatter only the cell value, so the matching std
    # is taken from a row-major stream that is consumed in rendering order.
    # The stream lives in the closure (no module-level counter) and the flat
    # array is built once instead of once per cell.
    std_values = iter(df_std.to_numpy().ravel())

    def fmt(value: float, precision: int) -> str:
        std = next(std_values)
        return rf"{value:.{precision}f} \mdseries ± {std:.{precision}f}"

    styler = highlight_predictions(styler, axis=1)
    # It's ok to use precision=3 for all predictions metrics
    styler = styler.format(partial(fmt, precision=3))

    # Headers Style
    styler = styler.hide(names=True, axis=0)
    styler = styler.hide(names=True, axis=1)
    styler = styler.format_index(lambda m: tex_metrics_names[m], axis=0, level=1)
    styler = styler.format_index(lambda m: tex_experiments_names[m], axis=1)

    # Convert to TeX
    n_experiments = len(df_mean)
    latex = styler.to_latex(
        hrules=True,
        column_format=f"X r *{{{n_experiments}}}{{c}}",
        convert_css=True,
        multirow_align="c",
        clines="skip-last;index",
    )
    latex = latex.replace(
        r"\cline{1-2}",  # clines don't work with colored background
        rf"\hhline{{{'-' * (n_experiments + 2)}}}",  # so replace with hhline
        len(hierarchy) - 2,  # of n occurrences replace only the first n - 1
        # because the last one is replaced by ""
    ).replace(r"\cline{1-2}", "")
    # Extract the tabular environment.
    # I prefer to work with tabular instead of table env for the following reasons:
    # - Better control of table position and dimensions
    # - Better control of caption position
    pattern = r"\\begin\{tabular\}(.*?)\\end\{tabular\}"
    match = re.search(pattern, latex, re.DOTALL)
    assert match is not None, "Pattern not found"
    return r"\begin{tabularx}{\linewidth}" + match.group(1) + r"\end{tabularx}"


# Single-index utils functions

def df_to_df_s(df):
    """Split a (level, metric)-indexed table into one single-index table per metric.

    For every entry of the notebook-level ``metrics`` list, the rows of that
    metric are selected and re-indexed by "level" only. The axis names are
    blanked and the "metric" column is dropped.

    Returns:
        dict mapping each metric to its single-index DataFrame.
    """

    def single_metric_table(metric):
        # slice(None) plays the role of ":" inside the multi-index selector
        rows = df.loc[(slice(None), metric), slice(None)]
        table = rows.reset_index().set_index("level")
        # maybe set the axis names to the metric name
        table = table.rename_axis("", axis=1).rename_axis(index="")
        return table.drop('metric', axis=1)

    return {metric: single_metric_table(metric) for metric in metrics}


def table_tex_s(
    df: pd.DataFrame,
    tex_experiments_names: dict[str, str],
    tex_metrics_names: dict[str, str],
) -> dict[str, str]:
    """Create one single-index TeX table (mean ± std) per prediction metric.

    The runs in ``df`` are grouped by ``df_experiments["name"]``, aggregated,
    and split with ``df_to_df_s`` into one table per entry of ``metrics``;
    each table has one row per level and one column per experiment.
    Reads the notebook-level globals ``df_experiments``, ``metrics`` and
    ``hierarchy``.

    Args:
        df: Per-run prediction metrics.
        tex_experiments_names: Experiment key -> column header.
        tex_metrics_names: Unused here; kept so the signature matches
            ``table_tex``.

    Returns:
        dict mapping each metric to its ``tabularx`` environment.

    Raises:
        AssertionError: If the LaTeX produced by pandas contains no
            ``tabular`` environment.
    """

    def make_formatter(std_table: pd.DataFrame, precision: int):
        # The Styler hands the formatter only the cell value, so the matching
        # std comes from a row-major stream consumed in rendering order. The
        # stream is private to this formatter (no module-level counter).
        std_values = iter(std_table.to_numpy().ravel())

        def fmt(value: float) -> str:
            std = next(std_values)
            return rf"{value:.{precision}f} \mdseries ± {std:.{precision}f}"

        return fmt

    # aggregate and convert to single-index tables
    grouped = df.groupby(df_experiments["name"])
    df_mean_s = df_to_df_s(grouped.mean().T)
    df_std_s = df_to_df_s(grouped.std().fillna(0).T)

    tables = {}
    for metric in metrics:
        _df_mean = df_mean_s[metric]
        _df_std = df_std_s[metric]

        styler = highlight_predictions(_df_mean.style, axis=1)
        # It's ok to use precision=3 for all predictions metrics
        styler = styler.format(make_formatter(_df_std, precision=3))

        # Headers Formatting
        level_names = {i: f"level {i}" for i in _df_mean.index}
        styler = styler.hide(names=True, axis=0)
        styler = styler.hide(names=True, axis=1)
        styler = styler.format_index(lambda e: tex_experiments_names[e], axis=1)
        styler = styler.format_index(lambda l: level_names[l], axis=0)

        # NOTE(review): len(_df_mean) is the number of levels (rows) here,
        # not the number of experiment columns as in table_tex, and "X r"
        # reserves two index columns for a single-index table. Kept as is so
        # the exported files do not change - confirm the intended layout.
        latex = styler.to_latex(
            hrules=True,
            column_format=f"X r *{{{len(_df_mean)}}}{{c}}",
            convert_css=True,
            multirow_align="c",
            clines="skip-last;index",
        )

        latex = latex.replace(
            r"\cline{1-2}",  # clines don't work with colored background
            rf"\hhline{{{'-' * (len(_df_mean) + 2)}}}",  # so replace with hhline
            len(hierarchy) - 2,  # of n occurrences replace only the first n - 1
            # because the last one is replaced by ""
        ).replace(r"\cline{1-2}", "")
        # Extract the tabular environment.
        # I prefer to work with tabular instead of table env for the following reasons:
        # - Better control of table position and dimensions
        # - Better control of caption position
        pattern = r"\\begin\{tabular\}(.*?)\\end\{tabular\}"
        match = re.search(pattern, latex, re.DOTALL)
        assert match is not None, "Pattern not found"

        # add to the dict of per-metric tables
        tables[metric] = (
            r"\begin{tabularx}{\linewidth}" + match.group(1) + r"\end{tabularx}"
        )

    return tables

In [None]:
def plot_hists():
    """Plot, for every level but the last, how severe the mistakes of one run are.

    Uses the first run listed in ``df_experiments`` and the one-hot encoding
    (both hard-coded). For each level it prints the error rate, the mean
    LCA weight of the mistakes and the count per weight, then shows a bar
    chart (count and count * weight) next to a pie chart of each weight's
    contribution. Levels without any mistake are reported and skipped.

    Reads the notebook-level globals ``path_root``, ``path_dataset``,
    ``DATASET``, ``df_experiments`` and ``hierarchy``.
    """
    # For a single experiment
    path_experiments = path_root / "experiments" / DATASET
    path_experiment = path_experiments / df_experiments.iloc[0].name  # hard-coded
    path_encoding = path_dataset / "encodings" / "onehot.npy"  # hard-coded
    assert path_experiments.exists()
    assert path_encoding.exists()

    outputs = np.load(path_experiment / "results" / "outputs.npy")
    targets = np.load(path_experiment / "results" / "targets.npy")
    encodings = np.load(path_encoding)
    labels = utils.get_labels(targets, encodings)
    predictions = utils.get_predictions(outputs, encodings).argmax(axis=1)

    n_levels = len(hierarchy)
    for level in range(n_levels - 1):
        lca_at_lvl = utils.hierarchy_to_lca(hierarchy[level:])

        # A prediction is wrong at this level when its ancestor differs
        wrong_mask = hierarchy[level][predictions] != hierarchy[level][labels]
        wrong_predictions = predictions[wrong_mask]
        wrong_labels = labels[wrong_mask]

        weights = lca_at_lvl[wrong_predictions, wrong_labels]

        # Metrics
        print("Error Rate:   ", wrong_mask.sum() / len(wrong_mask))
        if len(weights) == 0:
            # Nothing to average or plot: avoids 0/0 and an all-zero pie chart
            print(f"No mistakes at level {level}")
            print("-" * 80)
            continue
        # Mean weight over the mistakes only (previously mislabelled "Error Rate")
        print("Hier. Dist. M.:", weights.sum() / len(weights))

        # Count the occurrences of each value (index 0 is dropped)
        n_distances = n_levels - level
        counts = np.bincount(weights, minlength=n_distances + 1)[1:]
        print("Counts:       ", counts)

        # Create a bar plot for the occurrences
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        x = np.arange(1, n_distances + 1)
        ax1.bar(x, counts)

        # Create another bar plot for the count * value
        values = x * counts
        ax1.bar(x, values, alpha=0.5)

        # Set the x-axis ticks and labels
        ax1.set_xticks(x)
        ax1.set_xticklabels(x)

        # Add labels and title
        ax1.set_xlabel('Distance')
        ax1.set_ylabel('Count')
        ax1.set_title(f'Level {level}')

        # Pie chart plot
        ax2.pie(values, labels=x, autopct='%1.1f%%', startangle=90)
        ax2.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
        ax2.set_title('Contribution')

        # Show the plot
        plt.show()
        print("-" * 80)


def plot_tree():
    """Draw the class hierarchy stored in the global ``hierarchy`` as a graph.

    Nodes are named "<level>.<class id>". For every pair of consecutive
    levels an edge links each class of the lower level to the class it maps
    to one level up. The layout comes from graphviz "dot".
    """
    graph = nx.DiGraph()

    level_pairs = zip(hierarchy[:-1], hierarchy[1:])
    for upper_idx, (lower, upper) in enumerate(level_pairs, 1):
        lower_idx = upper_idx - 1
        for parent in np.unique(upper):
            children = np.unique(lower[upper == parent])
            # TODO: G.add_weighted_edges_from
            graph.add_edges_from(
                [(f"{lower_idx}.{child}", f"{upper_idx}.{parent}") for child in children]
            )

    # Draw the graph: small unlabeled nodes, no arrow heads
    _, ax = plt.subplots(figsize=(13, len(hierarchy)))
    layout = nx.nx_agraph.graphviz_layout(graph, prog="dot")
    nx.draw(
        graph,
        layout,
        node_size=10,
        arrows=False,
        font_weight='bold',
        ax=ax,
        alpha=0.5,
    )

    plt.axis('off')  # Remove axes
    plt.show()

## CIFAR100

In [None]:
DATASET = "CIFAR100"

path_root = Path("..")
path_evals = path_root / "evals"
path_dataset = path_root / "datasets" / "datasets" / DATASET
path_results = path_evals / DATASET / "results"
path_results.mkdir(parents=True, exist_ok=True)
hierarchy = np.load(path_dataset / "hierarchy" / "hierarchy.npy")
metrics = ["error_rate", "hier_dist_mistake", "hier_dist"]

# Evaluation directories of the selected configuration of each experiment type
path_experiments = {
    "xe-onehot": sorted((path_evals / DATASET).glob("*_xe-onehot")),
    "xe-mbm": sorted((path_evals / DATASET).glob("*_xe-mbm-beta5.0")),
    "xe-b3p": sorted((path_evals / DATASET).glob("*_xe-b3p-beta0.4")),
    "cd-bd": sorted((path_evals / DATASET).glob("*_cd-barz-denzler")),
    "cd-desc": sorted((path_evals / DATASET).glob("*_cd-desc-pca-d100")),
}

# Collect everything first and build the frames once: concatenating inside
# the loop copies all previously loaded rows at every step.
run_to_experiment = {}
run_metrics = []
for name, paths in path_experiments.items():
    assert len(paths) == 5, f"expected 5 runs for {name!r}, found {len(paths)}"
    for path in paths:
        run_to_experiment[path.name] = name
        # NOTE: pickle executes code on load; only read files produced by
        # our own predictions.py runs.
        run_metrics.append(pd.read_pickle(path / "results" / "predictions.pkl"))

# Run directory name -> experiment type (grouping key for the metrics)
df_experiments = pd.DataFrame(
    {"name": list(run_to_experiment.values())},
    index=list(run_to_experiment.keys()),
)
df_metrics = pd.concat(run_metrics)

df_mean = df_metrics.groupby(df_experiments["name"]).mean()
df_std = df_metrics.groupby(df_experiments["name"]).std().fillna(0)

### Tree

In [None]:
# Draw the class hierarchy of the dataset selected above (reads the global `hierarchy`).
plot_tree()

### Confusion Matrices

In [None]:
# For a single experiment
path_experiments = path_root / "experiments" / DATASET
path_experiment = path_experiments / df_experiments.iloc[0].name  # hard-coded
path_encoding = path_dataset / "encodings" / "onehot.npy"  # hard-coded
assert path_experiments.exists()
assert path_encoding.exists()

outputs = np.load(path_experiment / "results" / "outputs.npy")
targets = np.load(path_experiment / "results" / "targets.npy")
encodings = np.load(path_encoding)
labels = utils.get_labels(targets, encodings)
predictions = utils.get_predictions(outputs, encodings)

# get top-1 predictions at the given level of the hierarchy
k = 1
level = 2
predictions = hierarchy[level][np.argsort(predictions, axis=1)[:, -k:]]
labels = hierarchy[level][labels.reshape(-1, 1)]
assert labels.shape == predictions.shape
assert predictions.max() == hierarchy[level].max()
assert labels.max() == hierarchy[level].max()

# lca matrix at a specific level
lca = utils.hierarchy_to_lca(hierarchy[level:])

# construct confusion matrix (rows: labels, columns: predictions)
size = labels.max() + 1
M = np.zeros((size, size), dtype=int)
print(M.shape)
# np.add.at accumulates repeated (label, prediction) pairs, like the
# element-by-element "+= 1" loop it replaces
np.add.at(M, (labels[:, 0], predictions[:, 0]), 1)
# remove diag values (i.e. correct results) for better viz
np.fill_diagonal(M, 0)

# confusion matrix next to the lca matrix
fig, axes = plt.subplots(nrows=1, ncols=2)

axes[0].imshow(M)
axes[1].imshow(lca)
plt.show()

In [None]:
from collections import Counter

# For a single experiment
path_experiments = path_root / "experiments" / DATASET
path_experiment = path_experiments / df_experiments.iloc[0].name  # hard-coded
path_encoding = path_dataset / "encodings" / "onehot.npy"  # hard-coded
assert path_experiments.exists()
assert path_encoding.exists()

# Everything below is the same for all levels, so it is loaded and prepared once
outputs = np.load(path_experiment / "results" / "outputs.npy")
targets = np.load(path_experiment / "results" / "targets.npy")
encodings = np.load(path_encoding)

# top-1 predictions and labels as (n_samples, 1) columns
k = 1
labels = utils.get_labels(targets, encodings).reshape(-1, 1)
predictions = np.argsort(utils.get_predictions(outputs, encodings), axis=1)[:, -k:]
assert labels.shape == predictions.shape
true_cls, pred_cls = labels[:, 0], predictions[:, 0]

size = labels.max() + 1
# idx to hier sorting
idx = np.lexsort(hierarchy)

for level in range(len(hierarchy)):

    print(f"level {level}")

    # lca matrix at a specific level
    lca = utils.hierarchy_to_lca(hierarchy[level:])

    # confusion matrix restricted to the samples that are wrong at this level
    wrong = hierarchy[level][true_cls] != hierarchy[level][pred_cls]
    M = np.zeros((size, size), dtype=int)
    np.add.at(M, (true_cls[wrong], pred_cls[wrong]), 1)

    # confusion matrix next to the lca matrix, both in hierarchy order
    fig, axes = plt.subplots(nrows=1, ncols=2)

    axes[0].imshow(M[idx, :][:, idx]**0.001)
    axes[1].imshow(lca[idx, :][:, idx])
    plt.show()

    W = lca * M

    # occurrences of every non-zero entry value of W
    count = Counter(W.ravel())
    count.pop(0, None)
    plt.bar(range(len(count)), count.values())
    plt.show()

    # same bars, scaled by the entry value
    plt.bar(range(len(count)), np.array(list(count.values())) * np.array(list(count.keys())))
    plt.show()

    print("-" * 80)

In [None]:
# Display the raw hierarchy array loaded in the setup cell.
hierarchy

### Hist plots

In [None]:
# Per-level mistake histograms for the first run listed in `df_experiments`.
plot_hists()

### Multi-index single-table

In [None]:
# Export the paper table. UTF-8 is set explicitly because the cells contain
# the non-ASCII "±" sign and the default encoding depends on the locale.
with open(path_evals / DATASET / "results" / "predictions.tex", "w", encoding="utf-8") as f:
    f.write(table_tex(df_metrics, tex_experiments_names, tex_metrics_names))

table_html(df_mean.T, axis=1)

### Single-index multi-tables

- `slice(None)` is equivalent to `:`, but the syntax for multi-index selection only supports `slice(None)` inside parentheses.
- In the following, the suffix `_s` indicates a dict of single-index tables.
- The "metric dimension" is then encoded as a key-value pair (metric name → table).


In [None]:
df_mean_s = df_to_df_s(df_mean.T)
df_std_s = df_to_df_s(df_std.T)
dfs_s = table_tex_s(df_metrics, tex_experiments_names, tex_metrics_names)

for metric in metrics:
    print(metric)
    display(table_html(df_mean_s[metric], axis=1))
    print("-" * 80)

    # UTF-8 is set explicitly because the cells contain the non-ASCII "±"
    # sign and the default encoding depends on the locale.
    path_table = path_evals / DATASET / "results" / f"predictions_{metric}.tex"
    with open(path_table, "w", encoding="utf-8") as f:
        f.write(dfs_s[metric])

### Scatter plots

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# One dotted curve per experiment type; each point is one row (level) of the
# single-index tables, with the std over the runs as error bars.
err_mean, err_std = df_mean_s["error_rate"], df_std_s["error_rate"]
hdm_mean, hdm_std = df_mean_s["hier_dist_mistake"], df_std_s["hier_dist_mistake"]
for name in tex_experiments_names:
    ax.errorbar(
        err_mean[name],
        hdm_mean[name],
        xerr=err_std[name],
        yerr=hdm_std[name],
        capsize=2,
        label=tex_experiments_names[name],
        elinewidth=1,
        markeredgewidth=1.5,
        linestyle='dotted',
    )

ax.set_xlabel('Error Rate', labelpad=15)
ax.set_ylabel('Hierarchical Distance Mistake', labelpad=15)

# Single legend row centred above the axes
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=5)

path = path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plot.pdf"
fig.savefig(path, bbox_inches='tight')
plt.show()

In [None]:
fig, axes = plt.subplots(ncols=4, nrows=1, figsize=(2.3 * 4, 2.7))

# One panel per hierarchy level (x: error rate, y: hier. dist. mistake).
# zip() stops at the shorter input, so levels beyond the number of panels
# are not drawn.
for ax, lvl in zip(axes, range(len(hierarchy))):
    for name, tex in tex_experiments_names.items():
        ax.errorbar(
            df_mean_s["error_rate"][name][lvl],
            df_mean_s["hier_dist_mistake"][name][lvl],
            xerr=df_std_s["error_rate"][name][lvl],
            yerr=df_std_s["hier_dist_mistake"][name][lvl],
            fmt='o',
            markersize=2,
            capsize=2,
            label=tex,
        )

    # Set once per panel instead of once per experiment
    ax.set_xlabel(f'Level {lvl}', labelpad=15)

# All panels hold the same entries, so the last one provides the shared legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(
    loc='upper center',
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 1.2),
    ncol=5,
)

fig.tight_layout()
path = (path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plots.pdf")
plt.savefig(path, bbox_inches='tight')
plt.show()

## iNaturalist19

In [None]:
DATASET = "iNaturalist19"

path_root = Path("..")
path_evals = path_root / "evals"
path_dataset = path_root / "datasets" / "datasets" / DATASET
path_results = path_evals / DATASET / "results"
path_results.mkdir(parents=True, exist_ok=True)
hierarchy = np.load(path_dataset / "hierarchy" / "hierarchy.npy")
metrics = ["error_rate", "hier_dist_mistake", "hier_dist"]

# Evaluation directories of the selected configuration of each experiment type
path_experiments = {
    "xe-onehot": sorted((path_evals / DATASET).glob("*_xe-onehot")),
    "xe-mbm": sorted((path_evals / DATASET).glob("*_xe-mbm-beta15.0")),
    "xe-b3p": sorted((path_evals / DATASET).glob("*_xe-b3p-beta0.5")),
    "cd-bd": sorted((path_evals / DATASET).glob("*_cd-barz-denzler")),
    "cd-desc": sorted((path_evals / DATASET).glob("*_cd-desc-pca-d300")),
}

# Collect everything first and build the frames once: concatenating inside
# the loop copies all previously loaded rows at every step.
run_to_experiment = {}
run_metrics = []
for name, paths in path_experiments.items():
    # A missing experiment type would only surface later as a KeyError in
    # the tables and plots, so fail here with a clear message instead.
    assert paths, f"no evaluation runs found for {name!r}"
    for path in paths:
        run_to_experiment[path.name] = name
        # NOTE: pickle executes code on load; only read files produced by
        # our own predictions.py runs.
        run_metrics.append(pd.read_pickle(path / "results" / "predictions.pkl"))

# Run directory name -> experiment type (grouping key for the metrics)
df_experiments = pd.DataFrame(
    {"name": list(run_to_experiment.values())},
    index=list(run_to_experiment.keys()),
)
df_metrics = pd.concat(run_metrics)


df_mean = df_metrics.groupby(df_experiments["name"]).mean()
df_std = df_metrics.groupby(df_experiments["name"]).std().fillna(0)

### Tree

In [None]:
# Draw the class hierarchy of the dataset selected above (reads the global `hierarchy`).
plot_tree()

### Confusion Matrices

In [None]:
from collections import Counter

# For a single experiment
path_experiments = path_root / "experiments" / DATASET
path_experiment = path_experiments / df_experiments.iloc[0].name  # hard-coded
path_encoding = path_dataset / "encodings" / "onehot.npy"  # hard-coded
assert path_experiments.exists()
assert path_encoding.exists()

# Everything below is the same for all levels, so it is loaded and prepared once
outputs = np.load(path_experiment / "results" / "outputs.npy")
targets = np.load(path_experiment / "results" / "targets.npy")
encodings = np.load(path_encoding)

# top-1 predictions and labels as (n_samples, 1) columns
k = 1
labels = utils.get_labels(targets, encodings).reshape(-1, 1)
predictions = np.argsort(utils.get_predictions(outputs, encodings), axis=1)[:, -k:]
assert labels.shape == predictions.shape
true_cls, pred_cls = labels[:, 0], predictions[:, 0]

size = labels.max() + 1
# idx to hier sorting
idx = np.lexsort(hierarchy)

for level in range(len(hierarchy)):

    print(f"level {level}")

    # lca matrix at a specific level
    lca = utils.hierarchy_to_lca(hierarchy[level:])

    # confusion matrix restricted to the samples that are wrong at this level
    wrong = hierarchy[level][true_cls] != hierarchy[level][pred_cls]
    M = np.zeros((size, size), dtype=int)
    np.add.at(M, (true_cls[wrong], pred_cls[wrong]), 1)

    # confusion matrix next to the lca matrix, both in hierarchy order
    fig, axes = plt.subplots(nrows=1, ncols=2)

    axes[0].imshow(M[idx, :][:, idx]**0.5)
    axes[1].imshow(lca[idx, :][:, idx])
    plt.show()

    W = lca * M

    # occurrences of every non-zero entry value of W
    count = Counter(W.ravel())
    count.pop(0, None)
    plt.bar(range(len(count)), count.values())
    plt.show()

    print("-" * 80)

### Hist plots

In [None]:
# Per-level mistake histograms for the first run listed in `df_experiments`.
plot_hists()

### Multi-index single-table

In [None]:
# Export the paper table. UTF-8 is set explicitly because the cells contain
# the non-ASCII "±" sign and the default encoding depends on the locale.
with open(path_evals / DATASET / "results" / "predictions.tex", "w", encoding="utf-8") as f:
    f.write(table_tex(df_metrics, tex_experiments_names, tex_metrics_names))

table_html(df_mean.T, axis=1)

### Single-index multi-tables

In [None]:
df_mean_s = df_to_df_s(df_mean.T)
df_std_s = df_to_df_s(df_std.T)
dfs_s = table_tex_s(df_metrics, tex_experiments_names, tex_metrics_names)

for metric in metrics:
    print(metric)
    display(table_html(df_mean_s[metric], axis=1))
    print("-" * 80)

    # UTF-8 is set explicitly because the cells contain the non-ASCII "±"
    # sign and the default encoding depends on the locale.
    path_table = path_evals / DATASET / "results" / f"predictions_{metric}.tex"
    with open(path_table, "w", encoding="utf-8") as f:
        f.write(dfs_s[metric])

### Scatter plots

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# One dotted curve per experiment type; each point is one row (level) of the
# single-index tables, with the std over the runs as error bars.
err_mean, err_std = df_mean_s["error_rate"], df_std_s["error_rate"]
hdm_mean, hdm_std = df_mean_s["hier_dist_mistake"], df_std_s["hier_dist_mistake"]
for name in tex_experiments_names:
    ax.errorbar(
        err_mean[name],
        hdm_mean[name],
        xerr=err_std[name],
        yerr=hdm_std[name],
        capsize=2,
        label=tex_experiments_names[name],
        elinewidth=1,
        markeredgewidth=1.5,
        linestyle='dotted',
    )

ax.set_xlabel('Error Rate', labelpad=15)
ax.set_ylabel('Hierarchical Distance Mistake', labelpad=15)

# Single legend row centred above the axes
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=5)

path = path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plot.pdf"
fig.savefig(path, bbox_inches='tight')
plt.show()

In [None]:
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(2.3 * 3, 2.3 * 2 + 0.2))

# One panel per hierarchy level (x: error rate, y: hier. dist. mistake).
# zip() stops at the shorter input, so levels beyond the number of panels
# are not drawn and surplus panels stay empty.
for ax, lvl in zip(axes.ravel(), range(len(hierarchy))):
    for name, tex in tex_experiments_names.items():
        ax.errorbar(
            df_mean_s["error_rate"][name][lvl],
            df_mean_s["hier_dist_mistake"][name][lvl],
            xerr=df_std_s["error_rate"][name][lvl],
            yerr=df_std_s["hier_dist_mistake"][name][lvl],
            fmt='o',
            markersize=2,
            capsize=2,
            label=tex,
        )

    # Set once per panel instead of once per experiment
    ax.set_xlabel(f'Level {lvl}', labelpad=15)

# All drawn panels hold the same entries, so the last one provides the shared legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(
    loc='upper center',
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 1.1),
    ncol=5,
)

fig.tight_layout()
path = (path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plots.pdf")
plt.savefig(path, bbox_inches='tight')
plt.show()

## TieredImageNet

In [None]:
DATASET = "tieredImageNet"

path_root = Path("..")
path_evals = path_root / "evals"
path_dataset = path_root / "datasets" / "datasets" / DATASET
path_results = path_evals / DATASET / "results"
path_results.mkdir(parents=True, exist_ok=True)
hierarchy = np.load(path_dataset / "hierarchy" / "hierarchy.npy")
metrics = ["error_rate", "hier_dist_mistake", "hier_dist"]

# Evaluation directories of the selected configuration of each experiment type
path_experiments = {
    "xe-onehot": sorted((path_evals / DATASET).glob("*_xe-onehot")),
    "xe-mbm": sorted((path_evals / DATASET).glob("*_xe-mbm-beta15.0")),
    "xe-b3p": sorted((path_evals / DATASET).glob("*_xe-b3p-beta0.5")),
    "cd-bd": sorted((path_evals / DATASET).glob("*_cd-barz-denzler")),
    "cd-desc": sorted((path_evals / DATASET).glob("*_cd-desc-pca-d300")),
}

# Collect everything first and build the frames once: concatenating inside
# the loop copies all previously loaded rows at every step.
run_to_experiment = {}
run_metrics = []
for name, paths in path_experiments.items():
    # A missing experiment type would only surface later as a KeyError in
    # the tables and plots, so fail here with a clear message instead.
    assert paths, f"no evaluation runs found for {name!r}"
    for path in paths:
        run_to_experiment[path.name] = name
        # NOTE: pickle executes code on load; only read files produced by
        # our own predictions.py runs.
        run_metrics.append(pd.read_pickle(path / "results" / "predictions.pkl"))

# Run directory name -> experiment type (grouping key for the metrics)
df_experiments = pd.DataFrame(
    {"name": list(run_to_experiment.values())},
    index=list(run_to_experiment.keys()),
)
df_metrics = pd.concat(run_metrics)


df_mean = df_metrics.groupby(df_experiments["name"]).mean()
df_std = df_metrics.groupby(df_experiments["name"]).std().fillna(0)

### Tree

In [None]:
# Draw the class hierarchy of the dataset selected above (reads the global `hierarchy`).
plot_tree()

### Confusion Matrices

In [None]:
from collections import Counter

# For a single experiment
path_experiments = path_root / "experiments" / DATASET
path_experiment = path_experiments / df_experiments.iloc[0].name  # hard-coded
path_encoding = path_dataset / "encodings" / "onehot.npy"  # hard-coded
assert path_experiments.exists()
assert path_encoding.exists()

# Everything below is the same for all levels, so it is loaded and prepared once
outputs = np.load(path_experiment / "results" / "outputs.npy")
targets = np.load(path_experiment / "results" / "targets.npy")
encodings = np.load(path_encoding)

# top-1 predictions and labels as (n_samples, 1) columns
k = 1
labels = utils.get_labels(targets, encodings).reshape(-1, 1)
predictions = np.argsort(utils.get_predictions(outputs, encodings), axis=1)[:, -k:]
assert labels.shape == predictions.shape
true_cls, pred_cls = labels[:, 0], predictions[:, 0]

size = labels.max() + 1
# idx to hier sorting
idx = np.lexsort(hierarchy)

for level in range(len(hierarchy)):

    print(f"level {level}")

    # lca matrix at a specific level
    lca = utils.hierarchy_to_lca(hierarchy[level:])

    # confusion matrix restricted to the samples that are wrong at this level
    wrong = hierarchy[level][true_cls] != hierarchy[level][pred_cls]
    M = np.zeros((size, size), dtype=int)
    np.add.at(M, (true_cls[wrong], pred_cls[wrong]), 1)

    # confusion matrix next to the lca matrix, both in hierarchy order
    fig, axes = plt.subplots(nrows=1, ncols=2)

    axes[0].imshow(M[idx, :][:, idx]**0.5)
    axes[1].imshow(lca[idx, :][:, idx])
    plt.show()

    W = lca * M

    # occurrences of every non-zero entry value of W
    count = Counter(W.ravel())
    count.pop(0, None)
    plt.bar(range(len(count)), count.values())
    plt.show()

    print("-" * 80)

### Hist plots

In [None]:
# Per-level mistake histograms for the first run listed in `df_experiments`.
plot_hists()

### Multi-index single-table

In [None]:
# Export the paper table. UTF-8 is set explicitly because the cells contain
# the non-ASCII "±" sign and the default encoding depends on the locale.
with open(path_evals / DATASET / "results" / "predictions.tex", "w", encoding="utf-8") as f:
    f.write(table_tex(df_metrics, tex_experiments_names, tex_metrics_names))

table_html(df_mean.T, axis=1)

### Single-index multi-tables

In [None]:
df_mean_s = df_to_df_s(df_mean.T)
df_std_s = df_to_df_s(df_std.T)
dfs_s = table_tex_s(df_metrics, tex_experiments_names, tex_metrics_names)

for metric in metrics:
    print(metric)
    display(table_html(df_mean_s[metric], axis=1))
    print("-" * 80)

    # UTF-8 is set explicitly because the cells contain the non-ASCII "±"
    # sign and the default encoding depends on the locale.
    path_table = path_evals / DATASET / "results" / f"predictions_{metric}.tex"
    with open(path_table, "w", encoding="utf-8") as f:
        f.write(dfs_s[metric])

### Scatter plots

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# One dotted curve per experiment type; each point is one row (level) of the
# single-index tables, with the std over the runs as error bars.
err_mean, err_std = df_mean_s["error_rate"], df_std_s["error_rate"]
hdm_mean, hdm_std = df_mean_s["hier_dist_mistake"], df_std_s["hier_dist_mistake"]
for name in tex_experiments_names:
    ax.errorbar(
        err_mean[name],
        hdm_mean[name],
        xerr=err_std[name],
        yerr=hdm_std[name],
        capsize=2,
        label=tex_experiments_names[name],
        elinewidth=1,
        markeredgewidth=1.5,
        linestyle='dotted',
    )

ax.set_xlabel('Error Rate', labelpad=15)
ax.set_ylabel('Hierarchical Distance Mistake', labelpad=15)

# Single legend row centred above the axes
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=5)

path = path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plot.pdf"
fig.savefig(path, bbox_inches='tight')
plt.show()

In [None]:
fig, axes = plt.subplots(ncols=3, nrows=3, figsize=(2.3 * 3, 2.3 * 3 + 0.2))

# One panel per hierarchy level (x: error rate, y: hier. dist. mistake).
# zip() stops at the shorter input, so levels beyond the number of panels
# are not drawn and surplus panels stay empty.
for ax, lvl in zip(axes.ravel(), range(len(hierarchy))):
    for name, tex in tex_experiments_names.items():
        ax.errorbar(
            df_mean_s["error_rate"][name][lvl],
            df_mean_s["hier_dist_mistake"][name][lvl],
            xerr=df_std_s["error_rate"][name][lvl],
            yerr=df_std_s["hier_dist_mistake"][name][lvl],
            fmt='o',
            markersize=2,
            capsize=2,
            label=tex,
        )

    # Set once per panel instead of once per experiment
    ax.set_xlabel(f'Level {lvl}', labelpad=15)

# All drawn panels hold the same entries, so the last one provides the shared legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(
    loc='upper center',
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 1.1),
    ncol=5,
)

fig.tight_layout()
path = (path_evals / DATASET / "results" / "predictions_hier_dist_mistake_error_rate_plots.pdf")
plt.savefig(path, bbox_inches='tight')
plt.show()