In [None]:
import polars as pl
import numpy as np
from scipy import stats
import altair as alt

In [None]:
df = pl.scan_parquet("data/data.parquet")

In [None]:
def ci_lower(x, confidence = .95):
    x = x.to_numpy()
    n = len(x)

    mean = np.mean(x)
    se = stats.sem(x)
    ci = se * stats.t.pdf((1 + confidence) / 2, n-1)

    return mean - ci

def ci_upper(x, confidence = .95):
    x = x.to_numpy()
    n = len(x)

    mean = np.mean(x)
    se = stats.sem(x)
    ci = se * stats.t.pdf((1 + confidence) / 2, n-1)

    return mean + ci

In [None]:
cols = ["train acc", "train f1", "test acc", "test f1", 
               "adv acc", "adv f1", "adv distance"]

def agg_metrics(cols):
    res = []
    for col in cols:
        res.extend([
            pl.col(col).mean().alias(f"{col}_mean"),
            pl.col(col).std().alias(f"{col}_std"),
            pl.col(col).map_batches(ci_lower, return_dtype=pl.Float32, returns_scalar=True).alias(f"{col}_ci_lb"),
            pl.col(col).map_batches(ci_upper, return_dtype=pl.Float32, returns_scalar=True).alias(f"{col}_ci_ub"),
        ])
    return res

In [None]:
df.collect()

In [None]:
df_stats = df.select(pl.exclude("seed")).group_by(
    pl.col("distribution"),
    pl.col("seed"),
    pl.col("depth")
).agg(
    agg_metrics(cols)
).sort([
    pl.col("distribution"),
    # pl.col("seed"),
    pl.col("depth")
])

# df.with_columns(pl.col("train acc").gr)

In [6]:
df_stats.filter(
    pl.col("distribution") == "check"
    ).collect()

distribution,depth,train acc_mean,train acc_std,train acc_ci_lb,train acc_ci_ub,train f1_mean,train f1_std,train f1_ci_lb,train f1_ci_ub,test acc_mean,test acc_std,test acc_ci_lb,test acc_ci_ub,test f1_mean,test f1_std,test f1_ci_lb,test f1_ci_ub,adv acc_mean,adv acc_std,adv acc_ci_lb,adv acc_ci_ub,adv f1_mean,adv f1_std,adv f1_ci_lb,adv f1_ci_ub,adv distance_mean,adv distance_std,adv distance_ci_lb,adv distance_ci_ub
str,i32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32,f64,f64,f32,f32
"""check""",1,0.511042,0.004573,0.510906,0.511179,0.476569,0.215943,0.470123,0.483016,0.489588,0.012946,0.489202,0.489975,0.454418,0.217463,0.447926,0.46091,0.510412,0.012946,0.510025,0.510798,0.433731,0.166766,0.428753,0.43871,0.385867,0.045263,0.384515,0.387218
"""check""",2,0.516436,0.007718,0.516245,0.516626,0.554647,0.127982,0.551489,0.557805,0.49918,0.016164,0.498781,0.499579,0.536796,0.134271,0.533483,0.54011,0.50082,0.016164,0.500421,0.501219,0.346705,0.205774,0.341627,0.351783,0.378351,0.045614,0.377225,0.379476
"""check""",3,0.529042,0.01434,0.528688,0.529396,0.578263,0.135095,0.574929,0.581597,0.50758,0.019385,0.507102,0.508058,0.558019,0.142372,0.554506,0.561533,0.49242,0.019385,0.491942,0.492898,0.31369,0.193842,0.308906,0.318473,0.37123,0.045834,0.370099,0.372361
"""check""",4,0.552204,0.025587,0.551573,0.552836,0.572374,0.118746,0.569443,0.575304,0.53246,0.029205,0.531739,0.533181,0.551097,0.127711,0.547945,0.554248,0.46758,0.02916,0.46686,0.4683,0.352866,0.178191,0.348468,0.357263,0.378064,0.04758,0.37689,0.379239
"""check""",5,0.582113,0.037549,0.581187,0.58304,0.585612,0.142007,0.582108,0.589117,0.55778,0.039194,0.556813,0.558747,0.562267,0.145083,0.558687,0.565847,0.44222,0.039226,0.441252,0.443188,0.354668,0.155864,0.350821,0.358514,0.394037,0.060081,0.392555,0.39552
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""check""",29,0.999389,0.001474,0.999342,0.999436,0.99939,0.001469,0.999344,0.999437,0.7252,0.019156,0.724592,0.725808,0.724368,0.021349,0.72369,0.725046,0.282833,0.040357,0.281552,0.284115,0.290376,0.071292,0.288112,0.292639,0.417777,0.060536,0.415855,0.419699
"""check""",30,0.999636,0.000988,0.999601,0.99967,0.999637,0.000984,0.999602,0.999671,0.72584,0.020113,0.725142,0.726538,0.725148,0.022135,0.724379,0.725916,0.2838,0.043824,0.282278,0.285322,0.292134,0.07716,0.289455,0.294813,0.422119,0.054047,0.420242,0.423995
"""check""",31,0.999644,0.000782,0.99961,0.999679,0.999645,0.00078,0.999611,0.99968,0.727067,0.020026,0.726175,0.727958,0.72667,0.022442,0.725671,0.727669,0.284467,0.043703,0.282521,0.286412,0.292837,0.079996,0.289276,0.296398,0.426485,0.060163,0.423807,0.429163
"""check""",32,0.9998,0.000536,0.999776,0.999824,0.999801,0.000534,0.999777,0.999824,0.726133,0.020718,0.725211,0.727056,0.725815,0.023046,0.724789,0.726841,0.285333,0.043831,0.283382,0.287284,0.293473,0.079966,0.289913,0.297032,0.426825,0.060432,0.424135,0.429515


In [18]:
base = alt.Chart(
    df_stats.collect()
)

scale = alt.Scale(
    domain=["train acc_mean", "test acc_mean", "adv distance_mean"], 
    range=["blue", "red", "purple"]
)


dash_scale = alt.Scale(
    domain=["train acc_mean", "test acc_mean", "adv distance_mean"],
    range=[[2, 4], [1, 0], [8, 4]]  # [dash_length, gap_length]
)

# Chart for accuracy metrics (left y-axis)
acc_chart = base.mark_line(strokeWidth=2).transform_fold(
    fold=["train acc_mean", "test acc_mean"],
    as_=["variable", "value"]
).encode(
    x=alt.X("depth:Q").title("Depth"),
    y=alt.Y('value:Q').title("Accuracy").scale(zero=False),
    color=alt.Color('variable:N', scale=scale, title="Metric"),
    strokeDash=alt.StrokeDash('variable:N', scale=dash_scale)
)


# Chart for adversarial distance (right y-axis)
adv_chart = base.mark_line(strokeWidth=2).transform_fold(
    fold=["adv distance_mean"],
    as_=["variable", "value"]
).encode(
    x=alt.X("depth:Q").title("Depth"),
    y=alt.Y('value:Q').title("Adversarial Distance").axis(orient='right').scale(zero=False),
    color=alt.Color('variable:N', scale=scale, title="Metric"),
    strokeDash=alt.StrokeDash('variable:N', scale=dash_scale)

)

ci_band0 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
).encode(
    x="depth:Q",
    y=alt.Y("train acc_ci_lb:Q").axis(title="Accuracy", orient="left"),
    y2=alt.Y2("train acc_ci_ub:Q"),
    color = alt.value("blue"),
)

ci_band1 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
).encode(
    x="depth:Q",
    y=alt.Y("test acc_ci_lb:Q").axis(title="Accuracy", orient="left"),
    y2=alt.Y2("test acc_ci_ub:Q"),
    color = alt.value("red")
)

ci_band2 = base.mark_area(
    opacity=0.3,
    interpolate="basis",
).encode(
    x="depth:Q",
    y=alt.Y("adv distance_ci_lb:Q").axis(orient="right"),
    y2=alt.Y2("adv distance_ci_ub:Q"),
    color = alt.value("purple")
)

alt.layer(acc_chart + ci_band0 + ci_band1, adv_chart + ci_band2).resolve_scale(
    y='independent'
).facet(
    row="distribution"
).resolve_scale(
    x="independent",
    y="independent"
)


ERROR:tornado.general:Uncaught exception in ZMQStream callback
Traceback (most recent call last):
  File "/home/aidanjared/Documents/Code/ml_projects/Auditing/.venv/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py", line 565, in _log_error
    f.result()
  File "/home/aidanjared/Documents/Code/ml_projects/Auditing/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 341, in dispatch_control
    await self.process_control(msg)
  File "/home/aidanjared/Documents/Code/ml_projects/Auditing/.venv/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 347, in process_control
    idents, msg = self.session.feed_identities(msg, copy=False)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aidanjared/Documents/Code/ml_projects/Auditing/.venv/lib/python3.12/site-packages/jupyter_client/session.py", line 993, in feed_identities
    raise ValueError(msg)
ValueError: DELIM not in msg_list
ERROR:tornado.general:Uncaught exception in ZMQStream call