In [8]:
import polars as pl
import numpy as np
from scipy import stats
import altair as alt

In [2]:
df = pl.scan_parquet("data/data.parquet")

In [3]:
def ci_lower(x, confidence = .95):
    x = x.to_numpy()
    n = len(x)

    mean = np.mean(x)
    se = stats.sem(x)
    ci = se * stats.t.pdf((1 + confidence) / 2, n-1)

    return mean - ci

def ci_upper(x, confidence = .95):
    x = x.to_numpy()
    n = len(x)

    mean = np.mean(x)
    se = stats.sem(x)
    ci = se * stats.t.pdf((1 + confidence) / 2, n-1)

    return mean + ci

In [4]:
cols = ["train acc", "train f1", "test acc", "test f1", 
               "adv acc", "adv f1", "adv distance"]

def agg_metrics(cols):
    res = []
    for col in cols:
        res.extend([
            pl.col(col).mean().alias(f"{col}_mean"),
            pl.col(col).std().alias(f"{col}_std"),
            pl.col(col).map_batches(ci_lower, return_dtype=pl.Float32, returns_scalar=True).alias(f"{col}_ci_lb"),
            pl.col(col).map_batches(ci_upper, return_dtype=pl.Float32, returns_scalar=True).alias(f"{col}_ci_ub"),
        ])
    return res

In [5]:
df_stats = df.group_by(
    pl.col("distribution"),
    pl.col("seed"),
    pl.col("depth")
).agg(
    agg_metrics(cols)
).sort([
    pl.col("distribution"),
    pl.col("seed"),
    pl.col("depth")
])

# df.with_columns(pl.col("train acc").gr)

In [13]:
df_stats.filter(
    pl.col("distribution") == "check"
    ).collect()

distribution,seed,depth,train acc_mean,train acc_std,train acc_ci_lb,train acc_ci_ub,train f1_mean,train f1_std,train f1_ci_lb,train f1_ci_ub,test acc_mean,test acc_std,test acc_ci_lb,test acc_ci_ub,test f1_mean,test f1_std,test f1_ci_lb,test f1_ci_ub,adv acc_mean,adv acc_std,adv acc_ci_lb,adv acc_ci_ub,adv f1_mean,adv f1_std,adv f1_ci_lb,adv f1_ci_ub,adv distance_mean,adv distance_std,adv distance_ci_lb,adv distance_ci_ub
str,i32,i32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32
"""check""",8,1,0.5241875,0.005241,0.523672,0.524703,0.589687,0.035311,0.586212,0.593162,0.5135,0.014853,0.512038,0.514962,0.576092,0.052866,0.570889,0.581295,0.4865,0.014853,0.485038,0.487962,0.376677,0.092297,0.367594,0.38576,58.512655,8.275574,57.698242,59.327068
"""check""",8,2,0.55325,0.010443,0.552222,0.554278,0.590798,0.015746,0.589249,0.592348,0.54175,0.02036,0.539746,0.543754,0.578737,0.02983,0.575801,0.581673,0.45825,0.02036,0.456246,0.460254,0.399438,0.046004,0.394911,0.403966,73.338011,2.968117,73.045914,73.630112
"""check""",8,3,0.5656875,0.006634,0.565035,0.56634,0.623196,0.030317,0.620213,0.62618,0.55175,0.015376,0.550237,0.553263,0.606391,0.047018,0.601764,0.611018,0.44825,0.015376,0.446737,0.449763,0.339037,0.091834,0.33,0.348075,81.58065,3.385377,81.24749,81.913811
"""check""",8,4,0.604625,0.017488,0.602904,0.606346,0.615942,0.075103,0.608551,0.623333,0.5805,0.029456,0.577601,0.583399,0.590349,0.0854,0.581944,0.598753,0.4195,0.029456,0.416601,0.422399,0.375631,0.097287,0.366056,0.385205,66.556573,15.176135,65.063065,68.050079
"""check""",8,5,0.6545,0.019362,0.652595,0.656405,0.667106,0.019108,0.665226,0.668987,0.61975,0.042181,0.615599,0.623901,0.633036,0.02567,0.630509,0.635562,0.38025,0.042181,0.376099,0.384401,0.353503,0.033859,0.350171,0.356835,68.347189,9.344633,67.427567,69.266808
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""check""",105,19,0.9986875,0.001537,0.998536,0.998839,0.998682,0.001543,0.99853,0.998834,0.8075,0.008927,0.806621,0.808379,0.806337,0.010746,0.805279,0.807395,0.1925,0.008927,0.191621,0.193378,0.19385,0.019678,0.191913,0.195786,80.219568,6.87216,79.543266,80.895866
"""check""",105,20,0.99925,0.000784,0.999173,0.999327,0.999248,0.000787,0.99917,0.999325,0.80775,0.008263,0.806937,0.808563,0.806474,0.009603,0.805529,0.807419,0.19225,0.008263,0.191437,0.193063,0.194271,0.017096,0.192588,0.195953,80.215741,6.865171,79.540131,80.89135
"""check""",105,21,0.9996875,0.000541,0.999634,0.999741,0.999686,0.000543,0.999633,0.99974,0.80825,0.006993,0.807562,0.808938,0.806539,0.009286,0.805626,0.807453,0.19175,0.006993,0.191062,0.192438,0.195207,0.021058,0.193134,0.197279,80.219589,6.868377,79.543663,80.895515
"""check""",105,22,0.9999375,0.00014,0.999924,0.999951,0.999937,0.00014,0.999923,0.999951,0.80875,0.009143,0.80785,0.80965,0.807639,0.01103,0.806553,0.808724,0.19125,0.009143,0.19035,0.19215,0.192401,0.019931,0.19044,0.194363,80.218413,6.867646,79.542557,80.894272


In [48]:
base = alt.Chart(
    df_stats.collect()
).encode(
    x=alt.X("depth").title("depth")
)

line0 = base.mark_line(interpolate="basis").encode(
    y=alt.Y("train acc_mean:Q").axis(title="Accuracy", orient="left").scale(domain=[0,1]),
)

line1 = base.mark_line(interpolate="basis", color="red").encode(
    y=alt.Y("test acc_mean:Q").axis(title="Accuracy", orient="left").scale(domain=[0,1])
)

line2 = base.mark_line(interpolate="basis", color="purple", strokeDash=[5,5]).encode(
    y=alt.Y("adv distance_mean:Q").axis(title="Distance", orient="right")
)



ci_band0 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
).encode(
    y=alt.Y("train acc_ci_lb:Q").axis(title="Accuracy", orient="left").scale(domain=[0,1]),
    y2=alt.Y2("train acc_ci_ub:Q")
)

ci_band1 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
    color="red"
).encode(
    y=alt.Y("test acc_ci_lb:Q").axis(title="Accuracy", orient="left").scale(domain=[0,1]),
    y2=alt.Y2("test acc_ci_ub:Q")
)

ci_band2 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
    color="purple"
).encode(
    y=alt.Y("adv distance_ci_lb:Q").axis(title="Distance", orient="right"),
    y2=alt.Y2("adv distance_ci_ub:Q")
)

alt.layer(line0 + line1 + ci_band0, ci_band1, line2 + ci_band2).resolve_scale(y="independent").facet(
    column="seed", 
    row="distribution"
).resolve_scale(
    x="independent",
    y="independent"
)

# chart.show()