In [None]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipywidgets import Layout
from collections import OrderedDict

# Input tables: one row of metrics per trained model, and the voting table
# that is later indexed by model_id column names.
all_models = pd.read_csv("MODELS_METRICS_IDS_TF_CM_AG_FT.csv")
votes = pd.read_csv("voting_table_TF_CM_AG_FT.csv")

# Metrics offered in the "Sort by" dropdown; only columns that actually
# exist in the metrics table are kept, in this preferred order.
METRICS = [
    name
    for name in (
        "accuracy", "precision", "recall",
        "f_1.0", "f_0.7", "f_0.5", "f_0.3",
        "TN", "TP", "FP", "FN",
    )
    if name in all_models.columns
]

# 1) Widgets.
# Ranking controls: the slider's upper bound and the label text are
# placeholders until the refresh callbacks compute the real values.
metric_w = widgets.Dropdown(description="Sort by:", options=METRICS)
topk_w = widgets.IntSlider(description="Top K:", value=1, min=1, max=1, step=1)
max_w = widgets.Label(value="Max Top K: ?")

# Model-category selector; the option values are the model_id prefixes.
prefix_w = widgets.SelectMultiple(
    description="Model\nCategories:",
    options=[
        ("Classical ML", "CM"),
        ("Attention Agg", "AG"),
        ("Finetuned BERT", "FT"),
        ("TF-IDF", "TF"),
    ],
    value=["CM", "AG", "FT", "TF"],
    layout=Layout(width="270px", height="120px"),
    style={"description_width": "130px"},
)

# Dependent selectors: created empty and filled in by the refresh callbacks.
model_w = widgets.SelectMultiple(
    description="Models:",
    options=[],
    value=[],
    layout=Layout(width="300px", height="300px"),
)
embedding_w = widgets.SelectMultiple(
    description="Embedding type:",
    options=[],
    value=[],
    layout=Layout(width="400px", height="300px"),
    style={"description_width": "130px"},
)
beta_w = widgets.SelectMultiple(
    description="Betas:",
    options=[],
    value=[],
    layout=Layout(width="150px", height="100px"),
)

# Output area that receives the plot and the preview table.
out = widgets.Output()


# refresh functions

def refresh_embedding_options(*_):
    """Repopulate the embedding selector from the category/model selection.

    Collects the distinct ``embedding`` values of the rows whose
    ``model_id`` starts with a selected category prefix and whose ``model``
    is selected, offers them (sorted) as options, selects all of them, and
    then refreshes the beta selector. Positional arguments are ignored so
    the function can serve as a widget observer.
    """
    in_category = all_models['model_id'].str.startswith(tuple(prefix_w.value))
    is_chosen_model = all_models['model'].isin(tuple(model_w.value))

    matching = all_models.loc[in_category & is_chosen_model, 'embedding']
    available = sorted(matching.unique())

    embedding_w.options = available
    embedding_w.value = tuple(available)

    # Betas depend on the embedding selection, so continue the cascade.
    refresh_beta_options()


def refresh_model_options(*_):
    """Repopulate the model selector from the selected categories.

    Model names are grouped by category in the fixed order CM, AG, FT, TF
    and sorted alphabetically inside each group. Every listed model is
    selected, then the embedding selector is refreshed. Positional
    arguments are ignored so the function can serve as a widget observer.
    """
    active = list(prefix_w.value)
    grouped_names = [
        name
        for category in ("CM", "AG", "FT", "TF")
        if category in active
        for name in sorted(
            all_models.loc[
                all_models["model_id"].str.startswith(category), "model"
            ].unique()
        )
    ]

    model_w.options = grouped_names
    model_w.value = tuple(grouped_names)

    # Embeddings depend on the model selection, so continue the cascade.
    refresh_embedding_options()


def refresh_topk(*_):
    """Update the Top-K slider bound and label for the current selection.

    Counts the rows of ``all_models`` that match the selected categories,
    models, embeddings and betas, shows that count in the label, and uses
    it (or 1 when nothing matches) as the slider maximum. Positional
    arguments are ignored so the function can serve as a widget observer.
    """
    in_category = all_models['model_id'].str.startswith(tuple(prefix_w.value))
    is_chosen_model = all_models['model'].isin(tuple(model_w.value))
    is_chosen_embedding = all_models['embedding'].isin(tuple(embedding_w.value))
    is_chosen_beta = all_models['beta'].isin(tuple(beta_w.value))

    n_ids = (
        in_category & is_chosen_model & is_chosen_embedding & is_chosen_beta
    ).sum()

    max_w.value = f"Max Top K: {n_ids}"
    # The slider minimum is 1, so never let the maximum drop to 0.
    topk_w.max = n_ids if n_ids else 1

    # Pull the current value back inside the new range if necessary.
    if topk_w.max < topk_w.value:
        topk_w.value = topk_w.max


def refresh_beta_options(*_):
    """Repopulate the beta selector from the upstream selections.

    Collects the distinct ``beta`` values of the rows matching the selected
    categories, models and embeddings, offers them (sorted) as options,
    selects all of them, and then refreshes the Top-K bound. Positional
    arguments are ignored so the function can serve as a widget observer.
    """
    in_category = all_models['model_id'].str.startswith(tuple(prefix_w.value))
    is_chosen_model = all_models['model'].isin(tuple(model_w.value))
    is_chosen_embedding = all_models['embedding'].isin(tuple(embedding_w.value))

    matching = all_models.loc[
        in_category & is_chosen_model & is_chosen_embedding, 'beta'
    ]
    available = sorted(matching.unique())

    beta_w.options = available
    beta_w.value = tuple(available)

    # The number of candidate models may have changed; update the slider.
    refresh_topk()


# Cascade of dependent selectors: each widget's value change triggers the
# refresh of the next stage (categories -> models -> embeddings -> betas ->
# Top-K bound).
for _source, _refresh in (
    (prefix_w, refresh_model_options),
    (model_w, refresh_embedding_options),
    (embedding_w, refresh_beta_options),
    (beta_w, refresh_topk),
):
    _source.observe(_refresh, names='value')

# Run the cascade once so every selector starts populated.
refresh_model_options()

# update the plot
def update_plot(*_):
    """Redraw the vote-threshold plot for the current widget selection.

    Filters ``all_models`` by the selected categories, model types,
    embeddings and betas, keeps the Top-K rows by the chosen metric, and
    plots, for vote-fraction thresholds 0.0 to 1.0 in steps of 0.1, how
    many rows of ``votes`` have a row-wise sum over the selected models'
    columns of at least ``ceil(threshold * K)`` ("Positive") versus fewer
    ("Negative"). The first 10 selected model rows are displayed below the
    plot.

    Positional arguments are ignored so the function can be used directly
    as an ipywidgets observer callback. All output goes to ``out``.
    """
    with out:
        # wait=True defers clearing until new output arrives (less flicker).
        clear_output(wait=True)

        sort_by = metric_w.value
        topk = topk_w.value
        prefs = tuple(prefix_w.value)
        models = tuple(model_w.value)
        betas = tuple(beta_w.value)
        embeddings = tuple(embedding_w.value)

        m = all_models['model_id'].str.startswith(prefs)
        m &= all_models['model'].isin(models)
        m &= all_models['beta'].isin(betas)
        m &= all_models['embedding'].isin(embeddings)
        df_sel = all_models[m].sort_values(sort_by, ascending=False).head(topk)
        chosen = df_sel['model_id'].tolist()

        # vote counts
        vote_df = votes[chosen]
        N, M = len(vote_df), len(chosen)
        if N == 0 or M == 0:
            # With no models every row would trivially count as "positive",
            # and with no rows the percentage labels would divide by zero.
            print("Nothing to plot: no models match the current selection "
                  "or the voting table is empty.")
            return

        ths = np.arange(0, 1.01, 0.1)
        # np.arange accumulates float error (e.g. 0.30000000000000004), so
        # round the product before ceil; otherwise 30% of 10 models would
        # require 4 votes instead of 3.
        required = np.ceil(np.round(ths * M, 9))
        row_votes = vote_df.sum(axis=1)  # computed once, reused per threshold
        pos = [(row_votes >= r).sum() for r in required]
        neg = [N - p for p in pos]

        # plot
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(ths, pos, 'o-', label="Positive", color='green')
        ax.plot(ths, neg, 'o-', label="Negative", color='red')
        # Annotate every point with its share of all rows.
        for counts in (pos, neg):
            for x, c in zip(ths, counts):
                ax.text(x, c + N * 0.02, f"{c/N*100:.1f}%", ha="center")

        ax.set(
            xlabel="Threshold (fraction required)",
            ylabel="Count",
            title=(
                f"Model Categories={'+'.join(prefs)} | Model Types={len(models)} | "
                f"Betas={'+'.join(map(str,betas))} | Top {topk} by {sort_by}"
            )
        )
        ax.legend()
        ax.grid(True)
        plt.show()

        display(df_sel.head(10))




# 4) Layout: selectors on the first row, ranking controls on the second.
controls = widgets.VBox(children=[
    widgets.HBox(children=[prefix_w, model_w, embedding_w, beta_w]),
    widgets.HBox(children=[metric_w, topk_w, max_w]),
])
display(controls, out)

# Draw once for the initial selection before any observer is attached.
update_plot()

# Redraw whenever one of these widgets changes its value.
for w in [metric_w, topk_w, model_w, beta_w, embedding_w]:
    w.observe(update_plot, names='value')

def get_selected_ids(threshold):
    """Return the PMID/PXDID pairs passing the vote threshold.

    The models are chosen exactly as in the plot: rows of ``all_models``
    matching the live widget selection (categories, models, embeddings,
    betas), sorted by the selected metric in descending order and cut to
    the Top-K slider value.

    Args:
        threshold: Fraction (0.0-1.0) of the chosen models whose votes a
            row needs; the required count is ``ceil(threshold * K)``.

    Returns:
        A DataFrame with columns ``PMID`` and ``PXDID`` (fresh 0..n-1
        index) holding the rows of ``votes`` whose row-wise sum over the
        chosen model columns reaches the required count. Empty when no
        model matches the current selection.
    """
    id_cols = ['PMID', 'PXDID']

    # grab the live widget values
    prefs = tuple(prefix_w.value)
    models = tuple(model_w.value)
    embeddings = tuple(embedding_w.value)
    betas = tuple(beta_w.value)
    sort_by = metric_w.value
    topk = topk_w.value

    # top-K model_ids
    mask = (
        all_models['model_id'].str.startswith(prefs) &
        all_models['model'].isin(models) &
        all_models['embedding'].isin(embeddings) &
        all_models['beta'].isin(betas)
    )
    df_sel = (
        all_models[mask]
        .sort_values(sort_by, ascending=False)
        .head(topk)
    )
    chosen = df_sel['model_id'].tolist()

    if not chosen:
        # Without any model the required count would be 0 and every row
        # would pass; no votes means no selected IDs.
        return votes[id_cols].iloc[0:0].reset_index(drop=True)

    # Round before ceil: products such as 0.7 * 10 evaluate to
    # 7.000000000000001 and would otherwise demand one vote too many.
    required = math.ceil(round(threshold * len(chosen), 9))
    mask_pass = votes[chosen].sum(axis=1) >= required

    # return only the ID columns
    return votes.loc[mask_pass, id_cols].reset_index(drop=True)


VBox(children=(HBox(children=(SelectMultiple(description='Model\nCategories:', index=(0, 1, 2, 3), layout=Layo…

Output()