In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [2]:
# Load data: read both W&B exports and stack them into one frame with a fresh 0..n-1 index
df = pd.read_csv("wandb_export_2022-08-31T16_24_21.563-04_00.csv")
df2 = pd.read_csv("wandb_export_2022-09-13T14_36_54.025-04_00.csv")
df = pd.concat([df, df2], ignore_index=True)

# Fill hyper-parameters that are missing from a run:
#   * a missing Chebyshev order is labelled "exact"
#     (presumably these runs used no Chebyshev approximation -- TODO confirm)
#   * a missing power is set to 2
CHEB_ORDER_COL = "datamodule/transform_args/cheb_order"
POWER_COL = "datamodule/transform_args/power"
df[CHEB_ORDER_COL] = df[CHEB_ORDER_COL].fillna("exact")
df[POWER_COL] = df[POWER_COL].fillna(2)

# Number of runs per power setting
df[POWER_COL].value_counts()

1.0    1750
2.0     600
Name: datamodule/transform_args/power, dtype: int64

In [3]:
# Clean data
def get_model(x):
    r"""Return the LaTeX display label for one run.

    Args:
        x: a two-item sequence ``(power, tau)``. ``tau`` is either the string
            ``"exact"`` or a value accepted by ``int()``.

    Returns:
        ``$\wtwo$`` when ``power == 2`` (``tau`` is ignored in that case);
        otherwise ``$\wone$ (exact)`` when ``tau == "exact"`` and
        ``$\wone$ ($\tau=<int(tau)>$)`` for any other ``tau``.
    """
    kernel_power, cheb_order = x
    # Power 2 gets a single label regardless of the Chebyshev order
    if kernel_power == 2:
        return r"$\wtwo$"
    return (
        r"$\wone$ (exact)"
        if cheb_order == "exact"
        else r"$\wone$ ($\tau=%d$)" % (int(cheb_order))
    )


# Short display names for the columns used in the tables below
column_labels = {
    "datamodule/transform_args/alpha": r"$\alpha$",
    "datamodule/transform_args/power": "power",
    "datamodule/dataset": "dataset",
    "test/acc": "acc",
    "datamodule/transform_args/cheb_order": r"$\tau$",
}
clean_df = (
    df.rename(columns=column_labels)
    .assign(
        **{
            # power becomes an integer, alpha is rounded to two decimals
            "power": lambda frame: frame["power"].astype(int),
            r"$\alpha$": lambda frame: frame[r"$\alpha$"].round(2),
        }
    )
    # Every cell equal to "PTC_MR" is rewritten as "PTC MR"
    .replace("PTC_MR", "PTC MR")
)
# One display label per run, derived from its (power, tau) pair
clean_df["model"] = clean_df[["power", r"$\tau$"]].apply(get_model, axis=1)
clean_df

Unnamed: 0,Name,State,Notes,Tags,acc,$\alpha$,dataset,ckpt_path,model/net/hidden_dims,seed,power,$\tau$,model
0,cool-feather-1662,finished,-,"lr, power1, v6",0.367045,0.5,IMDB-MULTI,,,9,1,exact,$\wone$ (exact)
1,eager-voice-1661,finished,-,"lr, power1, v6",0.366477,0.5,IMDB-MULTI,,,8,1,exact,$\wone$ (exact)
2,proud-silence-1660,finished,-,"lr, power1, v6",0.409659,0.5,IMDB-MULTI,,,7,1,exact,$\wone$ (exact)
3,spring-firefly-1659,finished,-,"lr, power1, v6",0.397159,0.5,IMDB-MULTI,,,6,1,exact,$\wone$ (exact)
4,comic-paper-1658,finished,-,"lr, power1, v6",0.376136,0.5,IMDB-MULTI,,,5,1,exact,$\wone$ (exact)
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2345,driven-salad-1668,finished,-,"cheb, lr, power1, v7",0.683048,0.0,NCI1,,,4,1,10.0,$\wone$ ($\tau=10$)
2346,elated-oath-1667,finished,-,"cheb, lr, power1, v7",0.680199,0.0,NCI1,,,3,1,10.0,$\wone$ ($\tau=10$)
2347,radiant-pyramid-1666,finished,-,"cheb, lr, power1, v7",0.627315,0.0,NCI1,,,2,1,10.0,$\wone$ ($\tau=10$)
2348,cerulean-eon-1665,finished,-,"cheb, lr, power1, v7",0.675392,0.0,NCI1,,,1,1,10.0,$\wone$ ($\tau=10$)


In [4]:
# Table processing
def process_line(means, highlight, highlight_index, highlight_max):
    r"""Build a formatter that renders a (mean, std) pair as a LaTeX table cell.

    Args:
        means: Series of mean values; it is grouped by ``highlight_index`` to find
            the entries to emphasise.
        highlight: when falsy, no entry is emphasised and ``means`` is not grouped.
        highlight_index: grouping key(s) passed to ``means.groupby``.
        highlight_max: emphasise the per-group maximum when truthy, otherwise the
            per-group minimum.

    Returns:
        A function taking an object ``x`` that exposes ``x.name``, ``x['mean']`` and
        ``x['std']``. It returns ``"<mean> $\pm$ <std>"`` with three decimals,
        wrapped in ``\textbf{...}`` when ``x.name`` is one of the emphasised labels.
    """
    winners = set()
    if highlight:
        per_group = means.groupby(highlight_index)
        # idxmax/idxmin give the label of the best entry within each group
        winners = set(per_group.idxmax() if highlight_max else per_group.idxmin())

    def format_cell(stats):
        if stats.name not in winners:
            return rf"{stats['mean']:0.3f} $\pm$ {stats['std']:0.3f}"
        return rf"\textbf{{{stats['mean']:0.3f} $\pm$ {stats['std']:0.3f}}}"

    return format_cell


def mean_pm_std(
    data, index, columns, value, highlight=True, highlight_cols=True, highlight_max=True
):
    r"""Pivot ``data`` into a table of ``"mean $\pm$ std"`` strings.

    Rows of ``data`` are grouped by ``index + columns``; the mean and standard
    deviation of ``value`` within each group are formatted with three decimals.

    Args:
        data: non-empty DataFrame containing the ``index``, ``columns`` and ``value`` columns.
        index: list of column names that become the rows of the result.
        columns: list of column names that become the columns of the result.
        value: name of the column to summarise.
        highlight: wrap the best cell of each group in ``\textbf{...}``.
        highlight_cols: when truthy the best cell is picked within each distinct
            ``columns`` key, otherwise within each distinct ``index`` key.
        highlight_max: "best" means largest mean when truthy, smallest otherwise.

    Returns:
        DataFrame of strings indexed by ``index`` with one column per ``columns`` key.

    Raises:
        AssertionError: if ``data`` is empty.
    """
    assert len(data) > 0
    # Select the value column *before* aggregating: reducing the whole frame first
    # raises on pandas >= 2.0 as soon as `data` carries a non-numeric column, and
    # does needless work on every other column in any version.
    grouped_values = data.groupby([*index, *columns])[value]
    means = grouped_values.mean().rename("mean")
    stds = grouped_values.std().rename("std")
    # Two rows ("mean", "std"), one column per group, so that apply() sees one group at a time
    ddf = pd.concat([means, stds], axis=1).T
    highlight_index = columns if highlight_cols else index
    ddf = ddf.apply(process_line(means, highlight, highlight_index, highlight_max))
    ddf = ddf.reset_index().pivot(index=index, columns=columns)
    # pivot() stacks the (unnamed) value column on top of the column keys; drop that level
    ddf.columns = ddf.columns.droplevel(level=0)
    return ddf

In [5]:
# Table 2. alpha vs. power
# Average every (seed, model, alpha) combination over its remaining rows, keeping only
# rows whose tau is "exact" and leaving the ENZYMES dataset out.
# numeric_only=True: the frame also carries text columns (run name, tags, ...), which
# pandas >= 2.0 refuses to average instead of silently dropping them.
cdf = (
    clean_df[(clean_df["dataset"] != "ENZYMES") & (clean_df[r"$\tau$"] == "exact")]
    .groupby(["seed", "model", r"$\alpha$"])
    .mean(numeric_only=True)
    .reset_index()
)

# W1 vs W2 kernels across alpha: mean +/- std over seeds, best model per alpha in bold
results = mean_pm_std(cdf, index=["model"], columns=[r"$\alpha$"], value="acc").T
print(
    results.style.to_latex(
        hrules=True,
    )
)
results

\begin{tabular}{lll}
\toprule
{model} & {$\wone$ (exact)} & {$\wtwo$} \\
{$\alpha$} & {} & {} \\
\midrule
-0.5 & \textbf{0.617 $\pm$ 0.007} & 0.616 $\pm$ 0.012 \\
-0.25 & \textbf{0.640 $\pm$ 0.005} & 0.626 $\pm$ 0.009 \\
0.0 & \textbf{0.626 $\pm$ 0.012} & 0.623 $\pm$ 0.006 \\
0.25 & 0.619 $\pm$ 0.010 & \textbf{0.638 $\pm$ 0.008} \\
0.5 & \textbf{0.626 $\pm$ 0.009} & 0.616 $\pm$ 0.009 \\
\bottomrule
\end{tabular}



model,$\wone$ (exact),$\wtwo$
$\alpha$,Unnamed: 1_level_1,Unnamed: 2_level_1
-0.5,\textbf{0.617 $\pm$ 0.007},0.616 $\pm$ 0.012
-0.25,\textbf{0.640 $\pm$ 0.005},0.626 $\pm$ 0.009
0.0,\textbf{0.626 $\pm$ 0.012},0.623 $\pm$ 0.006
0.25,0.619 $\pm$ 0.010,\textbf{0.638 $\pm$ 0.008}
0.5,\textbf{0.626 $\pm$ 0.009},0.616 $\pm$ 0.009


In [6]:
# Table 3. Count best performing settings
def rename(x):
    r"""Flatten one row of the best-setting tally into ``power``, ``$\alpha$`` and ``count``.

    Args:
        x: a row whose ``"acc"`` entry is a ``(power, alpha)`` pair and whose last
            entry is the number of times that pair occurred.

    Returns:
        Series indexed by ``"power"``, ``$\alpha$`` and ``"count"``.
    """
    best_setting = x["acc"]
    # Read the tally by position rather than through the label 0: the count column
    # produced by value_counts().reset_index() is labelled 0 on older pandas but
    # "count" on pandas >= 2.0, where x[0] no longer reaches it.
    tally = int(x.iloc[-1])
    return pd.Series(
        (best_setting[0], best_setting[1], tally), index=["power", r"$\alpha$", "count"]
    )


index = ["dataset", "seed"]
columns = ["power", r"$\alpha$"]
value = "acc"
res = (
    clean_df[
        ((clean_df[r"$\tau$"] == "exact") | (clean_df["power"] == 2))
        & (clean_df["dataset"] != "ENZYMES")
    ]
    # Pick the accuracy column before averaging: averaging the whole frame first
    # raises on pandas >= 2.0 because of the text columns it carries.
    .groupby([*index, *columns])[value]
    .mean()
    .reset_index()
    .set_index(columns)
    # For every (dataset, seed): the (power, alpha) pair with the highest accuracy
    .groupby(index)
    .idxmax()
    # How often each (power, alpha) pair came out on top
    .value_counts()
    .reset_index()
    .apply(rename, axis=1)
    .pivot(index=["power"], columns=[r"$\alpha$"])
    .T
)
res.index = res.index.droplevel(level=0)
# Append a "sum" row, then a "sum" column (the column total includes the new row)
res_with_sum = pd.concat([res, pd.DataFrame(res.sum(axis=0).rename("sum")).T])
res_with_sum = pd.concat(
    [res_with_sum, pd.DataFrame(res_with_sum.sum(axis=1).rename("sum"))], axis=1
).astype(int)

print(
    res_with_sum.style.to_latex(
        hrules=True,
    )
)
res_with_sum

\begin{tabular}{lrrr}
\toprule
{} & {1.0} & {2.0} & {sum} \\
\midrule
-0.5 & 12 & 2 & 14 \\
-0.25 & 12 & 16 & 28 \\
0.0 & 8 & 5 & 13 \\
0.25 & 13 & 14 & 27 \\
0.5 & 19 & 9 & 28 \\
sum & 64 & 46 & 110 \\
\bottomrule
\end{tabular}



Unnamed: 0,1.0,2.0,sum
-0.5,12,2,14
-0.25,12,16,28
0.0,8,5,13
0.25,13,14,27
0.5,19,9,28
sum,64,46,110


In [7]:

# Accuracy per dataset and tau for the power == 1 runs at alpha == 0;
# the best tau of each dataset is set in bold
w1_alpha0_runs = clean_df[(clean_df["power"] == 1) & (clean_df[r"$\alpha$"] == 0)]
results = mean_pm_std(
    w1_alpha0_runs, index=[r"$\tau$"], columns=["dataset"], value="acc"
).T
results

$\tau$,10.0,100.0,exact
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
COLLAB,0.686 $\pm$ 0.010,\textbf{0.698 $\pm$ 0.005},0.690 $\pm$ 0.007
DD,0.699 $\pm$ 0.017,0.696 $\pm$ 0.028,\textbf{0.714 $\pm$ 0.015}
ENZYMES,0.221 $\pm$ 0.032,\textbf{0.242 $\pm$ 0.028},
IMDB-BINARY,0.683 $\pm$ 0.019,0.695 $\pm$ 0.012,\textbf{0.713 $\pm$ 0.021}
IMDB-MULTI,\textbf{0.444 $\pm$ 0.030},0.402 $\pm$ 0.014,0.421 $\pm$ 0.027
MUTAG,\textbf{0.750 $\pm$ 0.035},0.735 $\pm$ 0.034,0.675 $\pm$ 0.063
NCI1,\textbf{0.663 $\pm$ 0.022},0.637 $\pm$ 0.017,0.638 $\pm$ 0.006
NCI109,\textbf{0.675 $\pm$ 0.009},0.631 $\pm$ 0.013,0.641 $\pm$ 0.004
PROTEINS,0.780 $\pm$ 0.024,\textbf{0.797 $\pm$ 0.013},0.778 $\pm$ 0.024
PTC MR,0.292 $\pm$ 0.090,0.289 $\pm$ 0.069,\textbf{0.384 $\pm$ 0.097}


In [8]:
# Table 4. Chebyshev approximation
# Shared input: all power == 1 runs except those on ENZYMES
w1_runs = clean_df[(clean_df["dataset"] != "ENZYMES") & (clean_df["power"] == 1)]
tau_table_kwargs = dict(index=[r"$\tau$"], value="acc", highlight=False)

# One row per dataset, one column per tau
results = mean_pm_std(w1_runs, columns=["dataset"], **tau_table_kwargs).T

# Pooled "Mean" row: grouping on ckpt_path instead of dataset.
# NOTE(review): the single-label index assignment below only works if ckpt_path
# holds exactly one distinct value across these runs -- confirm against the export.
results1 = mean_pm_std(w1_runs, columns=["ckpt_path"], **tau_table_kwargs).T
results1.index = ["Mean"]

results = pd.concat([results, results1])
print(results.style.to_latex(hrules=True))

results

\begin{tabular}{llll}
\toprule
{$\tau$} & {10.0} & {100.0} & {exact} \\
\midrule
COLLAB & 0.692 $\pm$ 0.010 & 0.683 $\pm$ 0.012 & 0.702 $\pm$ 0.009 \\
DD & 0.685 $\pm$ 0.025 & 0.698 $\pm$ 0.024 & 0.695 $\pm$ 0.039 \\
IMDB-BINARY & 0.688 $\pm$ 0.040 & 0.666 $\pm$ 0.050 & 0.691 $\pm$ 0.031 \\
IMDB-MULTI & 0.406 $\pm$ 0.046 & 0.414 $\pm$ 0.029 & 0.398 $\pm$ 0.027 \\
MUTAG & 0.761 $\pm$ 0.074 & 0.733 $\pm$ 0.063 & 0.688 $\pm$ 0.070 \\
NCI1 & 0.653 $\pm$ 0.016 & 0.659 $\pm$ 0.029 & 0.645 $\pm$ 0.030 \\
NCI109 & 0.668 $\pm$ 0.022 & 0.628 $\pm$ 0.013 & 0.658 $\pm$ 0.023 \\
PROTEINS & 0.772 $\pm$ 0.023 & 0.799 $\pm$ 0.016 & 0.782 $\pm$ 0.022 \\
PTC MR & 0.325 $\pm$ 0.095 & 0.333 $\pm$ 0.097 & 0.387 $\pm$ 0.089 \\
REDDIT-BINARY & 0.805 $\pm$ 0.021 & 0.822 $\pm$ 0.021 & 0.814 $\pm$ 0.021 \\
REDDIT-MULTI-5K & 0.410 $\pm$ 0.013 & 0.422 $\pm$ 0.013 & 0.417 $\pm$ 0.016 \\
Mean & 0.624 $\pm$ 0.163 & 0.624 $\pm$ 0.160 & 0.626 $\pm$ 0.151 \\
\bottomrule
\end{tabular}



$\tau$,10.0,100.0,exact
COLLAB,0.692 $\pm$ 0.010,0.683 $\pm$ 0.012,0.702 $\pm$ 0.009
DD,0.685 $\pm$ 0.025,0.698 $\pm$ 0.024,0.695 $\pm$ 0.039
IMDB-BINARY,0.688 $\pm$ 0.040,0.666 $\pm$ 0.050,0.691 $\pm$ 0.031
IMDB-MULTI,0.406 $\pm$ 0.046,0.414 $\pm$ 0.029,0.398 $\pm$ 0.027
MUTAG,0.761 $\pm$ 0.074,0.733 $\pm$ 0.063,0.688 $\pm$ 0.070
NCI1,0.653 $\pm$ 0.016,0.659 $\pm$ 0.029,0.645 $\pm$ 0.030
NCI109,0.668 $\pm$ 0.022,0.628 $\pm$ 0.013,0.658 $\pm$ 0.023
PROTEINS,0.772 $\pm$ 0.023,0.799 $\pm$ 0.016,0.782 $\pm$ 0.022
PTC MR,0.325 $\pm$ 0.095,0.333 $\pm$ 0.097,0.387 $\pm$ 0.089
REDDIT-BINARY,0.805 $\pm$ 0.021,0.822 $\pm$ 0.021,0.814 $\pm$ 0.021


In [9]:
# Table 5. Full results (averaged over seeds)
# Every run except those on ENZYMES
non_enzymes_df = clean_df[(clean_df["dataset"] != "ENZYMES")]

# Average every (seed, model, alpha) combination over its remaining rows.
# numeric_only=True: the frame also carries text columns (run name, tags, ...), which
# pandas >= 2.0 refuses to average instead of silently dropping them.
cdf = (
    non_enzymes_df
    .groupby(["seed", "model", r"$\alpha$"])
    .mean(numeric_only=True)
    .reset_index()
)
# Constant dummy key so that mean_pm_std can pool everything into a single "Mean" column
cdf["ckpt_path"] = "None"

# Pooled column: mean +/- std over seeds, overall best (alpha, model) in bold
results1 = mean_pm_std(
    cdf,
    index=["ckpt_path"],
    columns=[r"$\alpha$", "model"],
    highlight=True,
    highlight_cols=False,
    value="acc",
).T
results1.columns = ["Mean"]

# Per-dataset columns: best (alpha, model) of each dataset in bold
results = mean_pm_std(
    non_enzymes_df,
    index=["dataset"],
    columns=[r"$\alpha$", "model"],
    highlight=True,
    highlight_cols=False,
    value="acc",
).T
results = pd.concat([results, results1], axis=1)

# Split into two parts
print(results.iloc[:, :6].style.to_latex(hrules=True))
print(results.iloc[:, 6:].style.to_latex(hrules=True))
results

\begin{tabular}{llllllll}
\toprule
{} & {} & {COLLAB} & {DD} & {IMDB-BINARY} & {IMDB-MULTI} & {MUTAG} & {NCI1} \\
{$\alpha$} & {model} & {} & {} & {} & {} & {} & {} \\
\midrule
\multirow[c]{4}{*}{-0.5} & $\wone$ ($\tau=10$) & 0.699 $\pm$ 0.005 & 0.654 $\pm$ 0.021 & \textbf{0.740 $\pm$ 0.022} & 0.384 $\pm$ 0.032 & 0.795 $\pm$ 0.064 & 0.657 $\pm$ 0.011 \\
 & $\wone$ ($\tau=100$) & 0.676 $\pm$ 0.008 & 0.721 $\pm$ 0.017 & 0.727 $\pm$ 0.021 & 0.438 $\pm$ 0.026 & 0.725 $\pm$ 0.035 & 0.681 $\pm$ 0.022 \\
 & $\wone$ (exact) & 0.706 $\pm$ 0.006 & 0.727 $\pm$ 0.013 & 0.698 $\pm$ 0.026 & 0.397 $\pm$ 0.023 & 0.600 $\pm$ 0.000 & 0.622 $\pm$ 0.006 \\
 & $\wtwo$ & 0.699 $\pm$ 0.009 & 0.660 $\pm$ 0.014 & 0.588 $\pm$ 0.012 & 0.335 $\pm$ 0.043 & 0.775 $\pm$ 0.049 & 0.651 $\pm$ 0.006 \\
\multirow[c]{4}{*}{-0.25} & $\wone$ ($\tau=10$) & 0.684 $\pm$ 0.008 & 0.683 $\pm$ 0.013 & 0.655 $\pm$ 0.036 & 0.438 $\pm$ 0.033 & 0.660 $\pm$ 0.070 & 0.642 $\pm$ 0.018 \\
 & $\wone$ ($\tau=100$) & 0.683 $\pm$ 0.007 & 0.69

Unnamed: 0_level_0,Unnamed: 1_level_0,COLLAB,DD,IMDB-BINARY,IMDB-MULTI,MUTAG,NCI1,NCI109,PROTEINS,PTC MR,REDDIT-BINARY,REDDIT-MULTI-5K,Mean
$\alpha$,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
-0.5,$\wone$ ($\tau=10$),0.699 $\pm$ 0.005,0.654 $\pm$ 0.021,\textbf{0.740 $\pm$ 0.022},0.384 $\pm$ 0.032,0.795 $\pm$ 0.064,0.657 $\pm$ 0.011,0.631 $\pm$ 0.008,0.772 $\pm$ 0.012,0.422 $\pm$ 0.106,0.794 $\pm$ 0.009,0.403 $\pm$ 0.007,0.631 $\pm$ 0.016
-0.5,$\wone$ ($\tau=100$),0.676 $\pm$ 0.008,0.721 $\pm$ 0.017,0.727 $\pm$ 0.021,0.438 $\pm$ 0.026,0.725 $\pm$ 0.035,0.681 $\pm$ 0.022,0.627 $\pm$ 0.009,0.790 $\pm$ 0.014,0.292 $\pm$ 0.057,0.821 $\pm$ 0.021,0.412 $\pm$ 0.012,0.627 $\pm$ 0.008
-0.5,$\wone$ (exact),0.706 $\pm$ 0.006,0.727 $\pm$ 0.013,0.698 $\pm$ 0.026,0.397 $\pm$ 0.023,0.600 $\pm$ 0.000,0.622 $\pm$ 0.006,0.624 $\pm$ 0.010,0.798 $\pm$ 0.013,0.423 $\pm$ 0.056,0.789 $\pm$ 0.014,0.403 $\pm$ 0.007,0.617 $\pm$ 0.007
-0.5,$\wtwo$,0.699 $\pm$ 0.009,0.660 $\pm$ 0.014,0.588 $\pm$ 0.012,0.335 $\pm$ 0.043,0.775 $\pm$ 0.049,0.651 $\pm$ 0.006,0.663 $\pm$ 0.010,0.802 $\pm$ 0.016,0.366 $\pm$ 0.072,0.814 $\pm$ 0.011,0.416 $\pm$ 0.007,0.616 $\pm$ 0.012
-0.25,$\wone$ ($\tau=10$),0.684 $\pm$ 0.008,0.683 $\pm$ 0.013,0.655 $\pm$ 0.036,0.438 $\pm$ 0.033,0.660 $\pm$ 0.070,0.642 $\pm$ 0.018,0.661 $\pm$ 0.007,0.790 $\pm$ 0.027,0.251 $\pm$ 0.044,0.793 $\pm$ 0.008,0.406 $\pm$ 0.012,0.607 $\pm$ 0.008
-0.25,$\wone$ ($\tau=100$),0.683 $\pm$ 0.007,0.694 $\pm$ 0.026,0.635 $\pm$ 0.051,0.389 $\pm$ 0.016,0.700 $\pm$ 0.062,0.621 $\pm$ 0.014,0.615 $\pm$ 0.015,0.803 $\pm$ 0.013,0.286 $\pm$ 0.096,0.820 $\pm$ 0.023,0.424 $\pm$ 0.009,0.609 $\pm$ 0.014
-0.25,$\wone$ (exact),0.701 $\pm$ 0.010,0.711 $\pm$ 0.026,0.692 $\pm$ 0.038,0.373 $\pm$ 0.014,0.770 $\pm$ 0.026,0.650 $\pm$ 0.004,0.676 $\pm$ 0.004,0.791 $\pm$ 0.009,\textbf{0.456 $\pm$ 0.020},0.808 $\pm$ 0.017,0.410 $\pm$ 0.016,\textbf{0.640 $\pm$ 0.005}
-0.25,$\wtwo$,0.693 $\pm$ 0.007,0.674 $\pm$ 0.011,0.579 $\pm$ 0.017,0.401 $\pm$ 0.045,0.740 $\pm$ 0.039,0.649 $\pm$ 0.009,0.680 $\pm$ 0.008,0.791 $\pm$ 0.038,0.402 $\pm$ 0.051,0.841 $\pm$ 0.007,\textbf{0.435 $\pm$ 0.009},0.626 $\pm$ 0.009
0.0,$\wone$ ($\tau=10$),0.686 $\pm$ 0.010,0.699 $\pm$ 0.017,0.683 $\pm$ 0.019,\textbf{0.444 $\pm$ 0.030},0.750 $\pm$ 0.035,0.663 $\pm$ 0.022,0.675 $\pm$ 0.009,0.780 $\pm$ 0.024,0.292 $\pm$ 0.090,0.791 $\pm$ 0.013,0.419 $\pm$ 0.006,0.624 $\pm$ 0.010
0.0,$\wone$ ($\tau=100$),0.698 $\pm$ 0.005,0.696 $\pm$ 0.028,0.695 $\pm$ 0.012,0.402 $\pm$ 0.014,0.735 $\pm$ 0.034,0.637 $\pm$ 0.017,0.631 $\pm$ 0.013,0.797 $\pm$ 0.013,0.289 $\pm$ 0.069,0.814 $\pm$ 0.014,0.430 $\pm$ 0.013,0.620 $\pm$ 0.009
