In [1]:
import altair as alt  # 5.1.2
import numpy as np
import pandas as pd

In [2]:
# Load the three exported tables; the first CSV column holds the saved pandas index.
stats = pd.read_csv("stats.csv", index_col=0)  # threshold / CI columns per (x_pos, y_pos)
curves = pd.read_csv("curves.csv", index_col=0)  # psychometric-curve points (x_val, y_val) per id
stims = pd.read_csv("stims.csv", index_col=0)  # shown stimulus level and `correct` per trial row

display(stats, curves, stims)

Unnamed: 0,x_pos,y_pos,threshold,ci_low,ci_high,ci_size,width,image,id
0,0,0,14.173614,10.170197,22.854961,12.684764,19.917327,https://raw.githubusercontent.com/LukasManinge...,0
1,0,1,19.602239,11.682305,40.286658,28.604353,41.205434,https://raw.githubusercontent.com/LukasManinge...,5
2,0,2,29.633428,22.29248,41.568355,19.275875,25.572502,https://raw.githubusercontent.com/LukasManinge...,10
3,0,3,34.559038,29.09721,42.204487,13.107277,14.818578,https://raw.githubusercontent.com/LukasManinge...,15
4,1,0,33.398395,25.44,48.936208,23.496208,44.902404,https://raw.githubusercontent.com/LukasManinge...,1
5,1,1,10.806923,6.547298,18.983421,12.436123,17.12504,https://raw.githubusercontent.com/LukasManinge...,6
6,1,2,24.282214,17.283092,36.247479,18.964386,23.136406,https://raw.githubusercontent.com/LukasManinge...,11
7,1,3,23.318828,18.878393,32.598655,13.720262,22.798651,https://raw.githubusercontent.com/LukasManinge...,16
8,2,0,26.289698,21.199701,34.378853,13.179152,19.203435,https://raw.githubusercontent.com/LukasManinge...,2
9,2,1,24.85288,15.013901,29.866417,14.852516,1.970603,https://raw.githubusercontent.com/LukasManinge...,7


Unnamed: 0,x_pos,y_pos,x_val,y_val,id
0,0,0,9.0,0.901795,0
1,0,0,10.0,0.877349,0
2,0,0,11.0,0.849962,0
3,0,0,12.0,0.820103,0
4,0,0,13.0,0.788425,0
...,...,...,...,...,...
447,4,3,16.0,0.500000,19
448,4,3,17.0,0.500000,19
449,4,3,18.0,0.500000,19
450,4,3,19.0,0.500000,19


Unnamed: 0,x_pos,y_pos,level,correct,id
0,0,0,9,1,0
1,0,0,9,1,0
2,0,0,10,1,0
3,0,0,10,1,0
4,0,0,10,1,0
...,...,...,...,...,...
1594,4,3,12,1,19
1595,4,3,12,1,19
1596,4,3,12,0,19
1597,4,3,13,0,19


In [3]:
# Compute limits
# Largest curve x value, rounded to an integer: upper end of the stimulus-level axes.
x_max = round(curves["x_val"].max())
# Tallest per-level response count over all stimuli: fixed cap for the histogram y axis.
count_max = stims.groupby("level")["correct"].agg("count").max()

# Interactivity
# Point selection keyed on grid position; every panel below reads or adds it.
selection = alt.selection_point(fields=["x_pos", "y_pos"])


# Threshold and confidence-interval heatmaps
# Both panels share one layout: a rect per (x_pos, y_pos) cell coloured by a metric,
# overlaid with that metric's value as text. Cells outside the shared point
# selection are drawn at half opacity.
def _heatmap_layers(field, title, tooltip):
    """Build the base, rect and text layers of a labelled heatmap of ``stats[field]``.

    Parameters
    ----------
    field : str
        Column of ``stats`` used for the rect colour and the printed cell label.
    title : str
        Chart title.
    tooltip : list of str
        Fields shown in the rect tooltip.

    Returns
    -------
    tuple
        ``(base, heatmap, text)`` Altair charts; layer ``heatmap + text`` to draw the panel.
    """
    base = (
        alt.Chart(stats, width=500, height=400)
        .encode(
            x="x_pos:O",
            y="y_pos:O",
            opacity=alt.condition(selection, alt.value(1), alt.value(0.5)),
        )
        .add_params(selection)
        .properties(title=title)
    )
    heatmap = base.mark_rect().encode(color=f"{field}:Q", tooltip=tooltip)

    # Cast to a built-in float: Altair embeds this value in the Vega expression via
    # repr(), and NumPy >= 2 reprs scalars as "np.float64(...)", which Vega cannot parse.
    label_cutoff = float(stats[field].mean())
    text = base.mark_text(baseline="middle").encode(
        text=alt.Text(f"{field}:Q", format=".1f"),
        # Black labels on cells below the column mean, white labels otherwise.
        color=alt.condition(
            getattr(alt.datum, field) < label_cutoff, alt.value("black"), alt.value("white")
        ),
    )
    return base, heatmap, text


# Threshold
base_threshold, heatmap_threshold, text_threshold = _heatmap_layers(
    "threshold", "Thresholds", ["image"]
)
chart_threshold = heatmap_threshold + text_threshold


# Confidence interval
base_ci, heatmap_ci, text_ci = _heatmap_layers(
    "ci_size", "Threshold 0.95 Confidence Interval Sizes", ["image", "ci_low", "ci_high"]
)
chart_ci = heatmap_ci + text_ci


# Psychometric curve
# Horizontal reference line at proportion correct 0.75 (tooltip: "Threshold").
rule_curve = (
    alt.Chart(pd.DataFrame({"y_val": [0.75]}))
    .encode(y=alt.Y("y_val", scale=alt.Scale(zero=False)), tooltip=alt.value("Threshold"))
    .mark_rule()
)

# One monotone line per curve id (split via `detail`), limited to the selected position(s).
lines_curve = (
    alt.Chart(curves, width=600, height=400, title="Psychometric Curves")
    .mark_line(interpolate="monotone", point=alt.OverlayMarkDef(size=15))
    .encode(
        x=alt.X("x_val:O", scale=alt.Scale(domain=np.arange(2, x_max + 1))),
        y=alt.Y("y_val:Q", scale=alt.Scale(domain=(0.5, 1.0))),
        color=alt.value("#4daf4a"),  # third Set1 colour (green)
        detail="id:N",
        tooltip=["x_val", "y_val", "x_pos", "y_pos"],
    )
    .transform_filter(selection)
    .add_params(selection)
)

chart_curve = alt.layer(lines_curve, rule_curve)


# Error intervals
# One unlabelled row per stats id; rows in the selection are green, the rest light grey.
base_err = (
    alt.Chart(stats, width=600, height=100)
    .add_params(selection)
    .encode(
        y=alt.Y("id:N", axis=alt.Axis(labels=False, ticks=False)),
        color=alt.condition(selection, alt.value("#4daf4a"), alt.value("lightgray")),
    )
)

# Horizontal bar from ci_low to ci_high for each row.
errorbars_err = base_err.mark_errorbar().encode(
    x=alt.X(
        "ci_low:Q", scale=alt.Scale(domain=(2, x_max)), title="ci_low, threshold, ci_high"
    ),
    x2="ci_high:Q",
    tooltip=["ci_low", "threshold", "ci_high", "x_pos", "y_pos"],
)

# Hollow marker at the threshold estimate; this layer carries the panel title.
points_err = (
    base_err.mark_point(fill="white", opacity=1)
    .encode(x="threshold:Q")
    .properties(title="Threshold CIs")
)

chart_err = alt.layer(errorbars_err, points_err)


# Shown stimuli
# Histogram of shown stimulus levels for the selected position(s), coloured by `correct`.
# The y axis is capped at `count_max` so bar heights stay comparable across selections.
hist_stim = (
    alt.Chart(stims, width=600, height=300, title="Shown Stimuli")
    .transform_filter(selection)
    .mark_bar(stroke="white")
    .encode(
        x=alt.X("level:O", scale=alt.Scale(domain=np.arange(2, x_max + 1))),
        y=alt.Y("count():Q", scale=alt.Scale(domainMax=count_max)),
        color=alt.Color("correct:N", scale=alt.Scale(scheme="set1")),
        tooltip=["level", "count()"],
    )
)

# Horizontal rule at the mean of `correct` within the current selection (0-1 scale).
rule_stim = (
    alt.Chart(stims)
    .transform_filter(selection)
    .mark_rule()
    .encode(
        y=alt.Y("mean(correct):Q", scale=alt.Scale(domain=(0, 1))),
        tooltip=["mean(correct)"],
    )
)

chart_stim = alt.layer(hist_stim, rule_stim)


# Everything
# Left column: threshold heatmap above CI heatmap, each keeping its own colour scale.
chart_left = (chart_threshold & chart_ci).resolve_scale(color="independent")

# Right column, top: curves above the stimulus histogram, sharing one x scale.
# The histogram layer puts count() and mean(correct) on y, hence independent y there.
chart_right_top = (chart_curve & chart_stim.resolve_scale(y="independent")).resolve_scale(
    x="shared"
)
# Nest the CI panel explicitly. `chart_right_top & chart_err` would append chart_err to
# the same VConcatChart, so it would inherit x="shared" and its quantitative
# ci_low/threshold x would have to merge with the ordinal x_val/level scale above.
# A continuous scale cannot be merged with discrete ones, so the requested
# curve/histogram sharing would not hold.
chart_right = alt.vconcat(chart_right_top, chart_err)

chart = chart_left | chart_right

display(chart)