In [None]:
import glob
import json

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd

In [None]:
def load_json(fname):
    """Read and parse the JSON file at ``fname``.

    This is a best-effort loader: when the file cannot be opened or read, or
    its content is not valid JSON, an empty list is returned instead of
    raising, so callers can iterate over the result unconditionally.

    Args:
        fname: Path of the JSON file to load.

    Returns:
        The decoded JSON value, or ``[]`` if the file is missing, unreadable
        or malformed.
    """
    try:
        with open(fname, 'r') as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError. Anything else
        # (KeyboardInterrupt, programming errors) is allowed to propagate.
        return []

In [None]:
def plot_with_err(x, y_mean, y_std, label=None, plot_kwargs=None, ax=None):
    """Plot ``y_mean`` against ``x`` (sorted by ``x``) with an optional +/- std band.

    Args:
        x: Sequence of x positions; the points are re-ordered by ascending x.
        y_mean: Sequence of y values, same length as ``x``.
        y_std: Sequence of band half-widths, or ``None`` to draw no band.
        label: Legend label for the line.
        plot_kwargs: Optional dict of keyword arguments forwarded to
            ``ax.plot`` (e.g. an entry of ``style_dict``). ``None`` means no
            extra styling.
        ax: Axes to draw on; defaults to the current axes.

    Returns:
        The axes that were drawn on.
    """
    # Copy so the caller's style dict is never mutated; None -> no extra styling.
    plot_kwargs = dict(plot_kwargs) if plot_kwargs else {}
    if ax is None:
        ax = plt.gca()

    x_arr = np.asarray(x)
    sort_idx = np.argsort(x_arr)
    x_sorted = x_arr[sort_idx]
    y_mean_sorted = np.asarray(y_mean)[sort_idx]

    lines = ax.plot(x_sorted, y_mean_sorted, label=label, **plot_kwargs)

    if y_std is not None:
        y_std_sorted = np.asarray(y_std)[sort_idx]
        # Match the band to the line: use the requested color if given,
        # otherwise whatever color the line actually received.
        band_color = plot_kwargs.get('color')
        if band_color is None and lines:
            band_color = lines[0].get_color()
        ax.fill_between(
            x_sorted,
            y_mean_sorted - y_std_sorted,
            y_mean_sorted + y_std_sorted,
            alpha=0.1,
            color=band_color,
        )
    return ax


def plot_with_err_(x, y_mean, y_std, label=None, ax=None, plot_kwargs=None, alpha=0.1):
    """Plot ``y_mean`` against ``x`` in the given order with an optional +/- std band.

    Unlike ``plot_with_err`` this variant does not sort the points by ``x``.

    Args:
        x: Sequence of x positions.
        y_mean: Sequence of y values, same length as ``x``.
        y_std: Sequence of band half-widths, or ``None`` to draw no band.
        label: Legend label for the line.
        ax: Axes to draw on; defaults to the current axes.
        plot_kwargs: Optional dict of keyword arguments forwarded to
            ``ax.plot``. ``None`` means no extra styling.
        alpha: Opacity of the shaded band.

    Returns:
        The axes that were drawn on.
    """
    plot_kwargs = dict(plot_kwargs) if plot_kwargs else {}
    if ax is None:
        ax = plt.gca()

    # Arrays so that +/- is element-wise even when plain lists are passed.
    y_mean_arr = np.asarray(y_mean)

    lines = ax.plot(x, y_mean_arr, label=label, **plot_kwargs)

    if y_std is not None:
        y_std_arr = np.asarray(y_std)
        band_color = plot_kwargs.get('color')
        if band_color is None and lines:
            band_color = lines[0].get_color()
        ax.fill_between(
            x,
            y_mean_arr - y_std_arr,
            y_mean_arr + y_std_arr,
            alpha=alpha,
            color=band_color,
        )
    return ax

In [None]:
# Detailed line/marker styles, one entry per wrapper key. Each entry is meant
# to be forwarded as keyword arguments to a matplotlib line plot.
# Tuple linestyles use matplotlib's (offset, (on, off, ...)) dash notation.
# NOTE(review): a later cell assigns `style_dict` again with a simpler style
# set, so whichever of the two cells ran last decides the styles in use.
style_dict = dict(
    laplace=dict(
        color="#000000",
        linestyle=(0, (5, 2)),  # dashed
        marker="o",
        linewidth=2.4,
        markersize=6.5,
        markerfacecolor="none",
        markeredgewidth=1.6,
    ),
    mle=dict(
        color="#D55E00",
        linestyle="solid",
        marker="s",
        linewidth=2.6,
        markersize=6.2,
        markerfacecolor="#D55E00",
        markeredgecolor="white",
        markeredgewidth=0.8,
    ),
    map=dict(
        color="#7F7F7F",
        linestyle=(0, (1, 1)),  # dotted
        marker="D",
        linewidth=2.2,
        markersize=6.0,
        markerfacecolor="none",
        markeredgewidth=1.4,
    ),
    tempscale=dict(
        color="#0072B2",
        linestyle="dashdot",
        marker="^",
        linewidth=2.4,
        markersize=6.8,
        markerfacecolor="#0072B2",
        markeredgecolor="white",
        markeredgewidth=0.8,
    ),
    blob=dict(
        color="#CC79A7",
        linestyle=(0, (3, 1, 1, 1)),  # short dash-dot pattern
        marker="P",  # filled plus
        linewidth=2.2,
        markersize=7.0,
        markerfacecolor="#CC79A7",
        markeredgecolor="white",
        markeredgewidth=0.8,
    ),
    scalabl=dict(
        color="#009E73",
        linestyle=(0, (7, 2)),  # long dashed
        marker="v",
        linewidth=2.6,
        markersize=6.8,
        markerfacecolor="none",
        markeredgewidth=1.6,
    ),
    tfb=dict(
        color="#56B4E9",
        linestyle=(0, (2, 2)),  # evenly dashed
        marker="X",
        linewidth=2.3,
        markersize=6.8,
        markerfacecolor="#56B4E9",
        markeredgecolor="black",
        markeredgewidth=0.6,
    ),
    mcdropout=dict(
        color="#E69F00",
        linestyle=(0, (1, 2)),  # spaced dots
        marker="<",
        linewidth=2.4,
        markersize=6.8,
        markerfacecolor="#E69F00",
        markeredgecolor="white",
        markeredgewidth=0.8,
    ),
    deepensemble=dict(
        color="#332288",
        linestyle=(0, (4, 1, 1, 1, 1, 1)),  # dash-dot-dot
        marker=">",
        linewidth=2.4,
        markersize=6.8,
        markerfacecolor="none",
        markeredgewidth=1.6,
    ),
)


In [None]:
# Basic per-wrapper styles (color, linestyle, marker only).
# NOTE(review): an earlier cell also assigns `style_dict` with a more detailed
# style set; running this cell replaces it.
# The original TODO notes here listed deepens, mcdropout, sgld?, map and
# zeroshot?; of those only sgld and zeroshot have no entry below.
style_dict = {
    name: {'color': color, 'linestyle': linestyle, 'marker': marker}
    for name, color, linestyle, marker in [
        ('laplace', 'black', '--', '.'),
        ('mle', 'red', ':', 'v'),
        ('map', 'grey', ':', 'v'),
        ('tempscale', 'blue', 'dashdot', 'o'),
        ('blob', 'purple', '--', 's'),
        ('scalabl', 'green', 'solid', '^'),
        ('tfb', 'blue', 'dashdot', '^'),
        ('mcdropout', 'orange', 'dashdot', 'v'),
        ('deepensemble', 'teal', 'dashdot', 'v'),
    ]
}
# Arrow shown next to each metric name on axis labels:
# up = higher is better, down = lower is better.
metric2arrow = dict(
    ACC='↑',
    ECE='↓',
    NLL='↓',
    Brier='↓',
    peak_memory='↓',
    latency='↓',
)

# Display label used in legends for each wrapper key.
# (The 'tempscale' key used to be listed twice; a dict literal silently keeps
# only the last value for a repeated key, so the duplicate was removed.)
wrapper2label = {
    'mle': 'MLE',
    'blob': 'BLoB',
    'scalabl': 'ScalaBL',
    'laplace': 'Laplace',
    'tfb': 'TFB',
    'mcdropout': 'MCDropout',
    'tempscale': 'TempScale',
    'deepensemble': 'ENS',
    'map': 'MAP',
}

In [None]:

# Root directory that is searched recursively for experiment logs.
# The trailing slash matters: the loading cells strip this exact prefix from
# each discovered path before splitting the remainder on '/'.
# This cell previously also assigned '/workspace1/csamplawski/src/BayesAdapt/logs/'
# and immediately overwrote it; only the effective value is kept.
root = '/project/synthesis/bayesadapt/logs/'


In [None]:
# Example results directory (a bare path is not valid Python, so it is kept as a comment):
# /project/synthesis/bayesadapt/logs/Qwen/Qwen3-VL-4B-Instruct/16bit/mle/rank8/vlm/seed0/srqa/results/active_learn

In [None]:
# Discover every active-learning results file below `root` and build one row
# per run (one run = one seed of one experiment configuration).
json_fnames = glob.glob(f'{root}/**/active_learn/results.json', recursive=True)
# Directory containing each results file; sorted so the row order is
# reproducible between runs (a bare set has arbitrary order).
expdirs = sorted({fname.rsplit('/', 1)[0] for fname in json_fnames})

# Path components after the first one map onto these keys, in order.
keys = ['model', 'quant', 'wrapper', 'rank', 'prompt_type', 'seed', 'dataset']

df = []
for edir in expdirs:
    tokens = edir.replace(root, '').split('/')
    row = dict(zip(keys, tokens[1:]))
    row['rank'] = int(tokens[4].replace('rank', ''))
    # Parse the whole number after the 'seed' prefix. Taking only the last
    # character would map e.g. 'seed10' to 0 and collide with 'seed0'.
    row['seed'] = int(tokens[6].replace('seed', ''))
    data = load_json(f'{edir}/results.json')
    # One value per entry of the results list, taken from the first element
    # of that entry's 'test_metrics'.
    for metric in ['ACC', 'ECE', 'NLL', 'Brier']:
        row[metric] = [item['test_metrics'][0][metric] for item in data]
    df.append(row)
df = pd.DataFrame(df)

def mean_std_vectors(series: pd.Series) -> pd.Series:
    """Element-wise mean and sample std across the vectors of one group.

    Args:
        series: Metric column within one group (one value per seed), where
            each value is a vector (list/ndarray) of y-values. All vectors
            must have the same shape.

    Returns:
        A Series with two entries, ``"mean"`` and ``"std"``, each an array
        with the shape of a single input vector. The std uses ``ddof=1``;
        with a single seed it is all zeros.

    Raises:
        ValueError: If the group is empty or the vector shapes differ.
    """
    arrs = [np.asarray(v, dtype=float) for v in series]
    if not arrs:
        raise ValueError("Cannot aggregate an empty group of vectors")
    shapes = {a.shape for a in arrs}
    if len(shapes) != 1:
        raise ValueError(f"Vector shapes differ within group: {shapes}")
    stacked = np.stack(arrs, axis=0)          # (n_seeds, *vector_shape)
    if stacked.shape[0] > 1:
        std = stacked.std(axis=0, ddof=1)
    else:
        # A sample std is undefined for one observation; report zeros with
        # the same shape as the mean (works for any vector dimensionality).
        std = np.zeros(stacked.shape[1:])
    return pd.Series({
        "mean": stacked.mean(axis=0),
        "std": std,
    })

# Aggregate the per-seed metric vectors into mean/std vectors per configuration.
group_cols = ['model', 'quant', 'wrapper', 'rank', 'prompt_type', 'dataset']
# Defined here so this cell runs on a fresh kernel: `metrics` was previously
# only assigned in later plotting cells. Must match the metric columns built
# when the per-run rows were loaded.
metrics = ['ACC', 'ECE', 'NLL', 'Brier']

agg_parts = []
for m in metrics:
    tmp = (
        df.groupby(group_cols)[m]
          .apply(mean_std_vectors)
          .unstack()  # columns: mean/std
    )
    # rename to ACC_mean / ACC_std, etc.
    tmp.columns = [f"{m}_{c}" for c in tmp.columns]
    agg_parts.append(tmp)

active_df = pd.concat(agg_parts, axis=1).reset_index()
active_df = active_df.set_index(group_cols)

In [None]:
# Display the aggregated active-learning table (indexed by the group columns,
# with <metric>_mean / <metric>_std columns).
active_df

In [None]:
# Active-learning curves: one panel per metric, one line per wrapper.
fig, axes = plt.subplots(1, 4, figsize=(25, 5), sharey=False)
plt.rcParams.update({'font.size': 12})
metrics = ['ACC', 'ECE', 'NLL', 'Brier']

dataset = 'srqa'
prompt_type = 'vlm'
quant = '16bit'
rank = 8

# x value of the i-th entry is (i + 1) * labels_per_round.
# presumably 10 labels are acquired per round -- TODO confirm against the experiment config
labels_per_round = 10
# The std vector used to be read and then immediately overwritten with None,
# i.e. the error band was switched off. Make that an explicit toggle.
show_std_band = False

for ax, metric in zip(axes, metrics):
    arrow = metric2arrow[metric]
    for wrapper in ['mle', 'scalabl', 'mcdropout', 'map', 'blob']:
        label = wrapper2label[wrapper]
        # NOTE(review): no model filter is applied, so the first matching row
        # is used; confirm it corresponds to the model named in the title.
        qdf = query(active_df, prompt_type=prompt_type, quant=quant, rank=rank,
                    wrapper=wrapper, dataset=dataset)
        if qdf.empty:
            # No results for this wrapper yet; skip instead of failing on the lookup.
            continue
        y_mean = qdf[f'{metric}_mean'].iloc[0]
        y_std = qdf[f'{metric}_std'].iloc[0] if show_std_band else None
        x = (np.arange(len(y_mean)) + 1) * labels_per_round
        ax = plot_with_err(x, y_mean, y_std, plot_kwargs=style_dict[wrapper], label=label, ax=ax)

    ax.set_xlabel('# Labels Acquired')
    ax.set_ylabel(f"{metric} ({arrow})")
    ax.grid()

# One shared legend below the panels (every panel shows the same wrappers).
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncols=9,          # adjust for readability
    frameon=True
)
fig.suptitle('Qwen3-8B-VL-Instruct Active Learning on SymbolicRegressionQA', y=0.95)
fig.subplots_adjust(bottom=0.15)

for ax in axes:
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))

In [None]:
# Columns identifying one experiment configuration, and the metric columns
# that are averaged over seeds.
exp_keys = ['model', 'quant', 'wrapper', 'rank', 'prompt_type', 'dataset', 'num_base', 'num_trainable_params', 'num_total_params']
metric_keys = ['ACC', 'ECE', 'NLL', 'Brier', 'peak_memory', 'latency']

json_fnames = glob.glob(f'{root}/**/id/metrics.json', recursive=True)
# Directory containing each metrics file; sorted for a reproducible row order.
expdirs = sorted({fname.rsplit('/', 1)[0] for fname in json_fnames})

# Path components after the first one map onto these keys, in order.
keys = ['model', 'quant', 'wrapper', 'rank', 'prompt_type', 'seed', 'dataset']

df = []
for edir in expdirs:
    tokens = edir.replace(root, '').split('/')
    row = dict(zip(keys, tokens[1:]))
    row['rank'] = int(tokens[4].replace('rank', ''))
    # Parse the whole number after the 'seed' prefix. Taking only the last
    # character would map e.g. 'seed10' to 0 and collide with 'seed0'.
    row['seed'] = int(tokens[6].replace('seed', ''))
    data = load_json(f'{edir}/metrics.json')
    if not data:
        # load_json returns [] for unreadable files; explode() would turn
        # that into a record-less placeholder row, so skip it here.
        continue
    row['results'] = data
    df.append(row)
df = pd.DataFrame(df)
# One row per record in each metrics file.
df_exploded = df.explode('results').reset_index(drop=True)
# The per-record 'seed' field is dropped; the seed parsed from the path is kept.
metrics_df = pd.json_normalize(df_exploded['results']).drop(columns=['seed'])
id_df_seeds = pd.concat([df_exploded.drop(columns=['results']), metrics_df], axis=1)
# Mean/std over seeds; columns become a (metric, 'mean'|'std') MultiIndex.
id_df = id_df_seeds.groupby(exp_keys)[metric_keys].agg(['mean', 'std'])

In [None]:
# query(   # NOTE(review): unfinished scratch call; left as a comment because an unterminated call is a SyntaxError

In [None]:
# Aggregated in-distribution ('id') metrics for dataset 'expression_logic',
# using query()'s default prompt_type/quant/rank filters.
# NOTE(review): query() is defined in a later cell, so this cell fails on a fresh top-to-bottom run.
query(id_df, dataset='expression_logic')

In [None]:
# Aggregated 'ood' metrics for dataset 'expression_logic',
# using query()'s default prompt_type/quant/rank filters.
# NOTE(review): both query() and ood_df are defined in later cells, so this cell fails on a fresh top-to-bottom run.
query(ood_df, dataset='expression_logic')

In [None]:
# Aggregated in-distribution ('id') metrics for dataset 'circuit_logic',
# using query()'s default prompt_type/quant/rank filters.
# NOTE(review): query() is defined in a later cell, so this cell fails on a fresh top-to-bottom run.
query(id_df, dataset='circuit_logic')

In [None]:
# Aggregated 'ood' metrics for dataset 'circuit_logic',
# using query()'s default prompt_type/quant/rank filters.
# NOTE(review): both query() and ood_df are defined in later cells, so this cell fails on a fresh top-to-bottom run.
query(ood_df, dataset='circuit_logic')

In [None]:
import pandas as pd

# Table display name -> wrapper key, in the order rows appear in the LaTeX table.
METHODS_MAP = dict([
    ("MLE", "mle"),
    ("MAP", "map"),
    ("MC-Dropout", "mcdropout"),
    ("Ensemble", "deepensemble"),
    ("Laplace", "laplace"),
    ("BLoB", "blob"),
    ("ScalaBL", "scalabl"),
    ("TFB", "tfb"),
])

def latex_escape(s: str) -> str:
    """Return ``str(s)`` with LaTeX special characters replaced by escapes.

    Handles backslash, ``& % $ # _ { }``, tilde and caret; every other
    character is passed through unchanged. Each input character is mapped
    independently, so inserted escape sequences are never escaped again.
    """
    escapes = str.maketrans({
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    })
    return str(s).translate(escapes)

def _flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if isinstance(df.columns, pd.MultiIndex):
        flat = []
        for a, b in df.columns.to_flat_index():
            if b is None or b == "":
                flat.append(str(a))
            else:
                flat.append(f"{a}_{b}")
        df.columns = flat
    return df

def _prep_df(id_df: pd.DataFrame) -> pd.DataFrame:
    """Turn an aggregated metrics frame into a flat, column-queryable frame.

    Index levels of a MultiIndex become ordinary columns, and two-level
    metric columns are flattened to names like ``ACC_mean`` / ``ACC_std``.
    The input frame is left untouched.
    """
    if isinstance(id_df.index, pd.MultiIndex):
        frame = id_df.reset_index()
    else:
        frame = id_df.copy()
    return _flatten_columns(frame)

def _format_pm(mean: float, std: float, metric: str) -> str:
    # If ACC/ECE look like fractions, auto-convert to percent
    if metric in ("ACC", "ECE") and mean <= 1.0:
        mean *= 100.0
        std *= 100.0

    # per-metric formatting (tweak if you want)
    if metric in ("ACC", "ECE"):
        return f"${mean:.2f}_{{\\pm {std:.1f}}}$"
    else:
        return f"${mean:.3f}_{{\\pm {std:.3f}}}$"

def make_latex_table(
    id_df: pd.DataFrame,
    model: str,
    rank: int,
    datasets: list[str],
    *,
    prompt_type: str = "instruct",
    quant: str = "16bit",
    metrics: list[str] = ("ACC", "ECE", "NLL"),
    methods_map: dict[str, str] = METHODS_MAP,
    caption: str | None = None,
) -> str:
    """Render a ``table*`` comparing methods across datasets for one model/rank.

    The table has one block of rows per metric (separated by ``\\midrule``),
    one row per entry of ``methods_map`` and one column per dataset. Cells
    show ``mean +/- std`` as formatted by ``_format_pm``; a cell is ``TBD``
    when no row matches or the metric columns are absent.

    Args:
        id_df: Aggregated metrics frame; passed through ``_prep_df`` first.
        model, rank, prompt_type, quant: Values a row must match exactly.
        datasets: Dataset names, one table column each (LaTeX-escaped in the header).
        metrics: Metric names, one row block each.
        methods_map: Display name -> wrapper key; defines the rows and their order.
        caption: Caption text inserted as given; a default naming the model
            and rank is generated when ``None``.

    Returns:
        The LaTeX source as a single newline-joined string.

    Raises:
        ValueError: If a required column is missing after preparation.
    """
    table_df = _prep_df(id_df)

    # Every column the row filter below relies on must exist.
    needed = {"model", "rank", "prompt_type", "quant", "dataset", "wrapper"}
    missing = needed.difference(table_df.columns)
    if missing:
        raise ValueError(f"id_df is missing required columns after prep: {sorted(missing)}")

    # Two leading columns (metric, method) followed by one per dataset.
    col_spec = "@{}" + ("c" * (2 + len(datasets))) + "@{}"

    header = (
        "\\textbf{Metric} & \\textbf{Method} & "
        + " & ".join([f"\\textbf{{{latex_escape(d)}}}" for d in datasets])
        + " \\\\\n"
    )

    if caption is None:
        caption = f"Performance Comparison ({latex_escape(model)}, rank {rank})"

    def _cell(metric: str, wrapper_code: str, ds: str) -> str:
        """Return the '& ...' cell text for one (metric, method, dataset)."""
        match = table_df[
            (table_df["model"] == model)
            & (table_df["rank"] == rank)
            & (table_df["prompt_type"] == prompt_type)
            & (table_df["quant"] == quant)
            & (table_df["dataset"] == ds)
            & (table_df["wrapper"] == wrapper_code)
        ]
        if match.empty:
            return "& TBD"

        # Several rows can match (e.g. because num_* params differ): pick a
        # stable one, preferring the largest num_trainable_params.
        if len(match) > 1:
            if "num_trainable_params" in match.columns:
                match = match.sort_values("num_trainable_params", ascending=False)
            match = match.iloc[[0]]

        mean_col = f"{metric}_mean"
        std_col = f"{metric}_std"
        if mean_col not in match.columns or std_col not in match.columns:
            return "& TBD"

        mean = float(match[mean_col].iloc[0])
        std = float(match[std_col].iloc[0])
        return f"& {_format_pm(mean, std, metric)}"

    lines = [
        "\\begin{table*}[h!]",
        "\\centering",
        f"\\caption{{{caption}}}",
        f"\\begin{{tabular}}{{{col_spec}}}",
        "\\toprule",
        header.rstrip("\n"),
        "\\midrule",
    ]

    up_metrics = {"ACC"}  # everything else is treated as lower-is-better
    n_methods = len(list(methods_map.keys()))
    last_metric_idx = len(metrics) - 1

    for mi, metric in enumerate(metrics):
        arrow = "\\uparrow" if metric in up_metrics else "\\downarrow"
        # '$$' is written in the f-string and collapsed to '$' afterwards.
        lines.append(f"\\multirow{{{n_methods}}}{{*}}{{\\textbf{{{latex_escape(metric)} ($${arrow}$$)}}}}".replace("$$", "$"))

        for display_name, wrapper_code in methods_map.items():
            row = [f"& {latex_escape(display_name)}"]
            row.extend(_cell(metric, wrapper_code, ds) for ds in datasets)
            lines.append(" ".join(row) + " \\\\")

        if mi != last_metric_idx:
            lines.append("\\midrule")

    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table*}"])
    return "\n".join(lines)

# --- Example ---
# latex = make_latex_table(
#     id_df,
#     model="Qwen3-8B",
#     rank=8,
#     datasets=["winogrande_s", "ARC-Challenge", "ARC-Easy", "obqa", "boolq"],
#     prompt_type="instruct",
#     quant="16bit",
#     metrics=["ACC", "ECE", "NLL"],
# )
# print(latex)


In [None]:
# Render and print one LaTeX table per model family. Previously the first
# table was computed and then overwritten before it was ever printed.
table_specs = [
    dict(
        model="Qwen3-8B",
        datasets=["winogrande_xs", "winogrande_s", "winogrande_m", "winogrande_l"],
        prompt_type="instruct",
    ),
    dict(
        model="Qwen3-VL-8B-Instruct",
        datasets=["slake", "mmstar", "MathVerse"],
        prompt_type="vlm",
    ),
]
for spec in table_specs:
    latex = make_latex_table(
        id_df,
        rank=8,
        quant="16bit",
        metrics=["ACC", "ECE", "NLL", "Brier"],
        **spec,
    )
    print(latex)
    print()

In [None]:
# Display the normalized per-record metrics frame. `metrics_df` is assigned by
# both the id and the ood loading cells, so this shows whichever ran last.
metrics_df

In [None]:
def query(df, model=None, dataset=None, wrapper=None, prompt_type='instruct', quant='16bit', rank=8):
    """Filter ``df`` by experiment keys and return the matches with a reset index.

    Each argument is compared for equality against the column (or index
    level) of the same name. An argument set to ``None`` is not filtered on.

    Args:
        df: Frame whose columns or index levels include the filtered names.
        model, dataset, wrapper: Optional filters; unset by default.
        prompt_type, quant, rank: Filters with notebook-wide defaults; pass
            ``None`` to disable one.

    Returns:
        The matching rows, with the index reset so former index levels are
        available as columns.
    """
    filters = {
        'prompt_type': prompt_type,
        'quant': quant,
        'rank': rank,
        'model': model,
        'dataset': dataset,
        'wrapper': wrapper,
    }
    active = [name for name, value in filters.items() if value is not None]
    if not active:
        return df.reset_index()
    # '@name' makes pandas read the value from this function's local variable
    # of that name, so values are never spliced into the expression text and
    # quotes or other special characters in them cannot break the query.
    query_str = ' and '.join(f"{name} == @{name}" for name in active)
    return df.query(query_str).reset_index()

In [None]:
# ---- example ----
# tables = generate_multidataset_metric_tables(
#     id_df,
#     model="Qwen3-8B",
#     rank=8,
#     datasets=["winogrande_s","ARC-Challenge","ARC-Easy","winogrande_m","obqa","boolq"],
#     prompt_type="instruct",
#     quant="16bit",
#     scale_acc_ece_to_percent=False,  # flip if needed
#     caption="Performance Comparison (Qwen3-8B, rank=8)",
#     label_prefix="tab:qwen3_8b_r8",
# )

In [None]:
# Leftover debugging output for a single (metric, method, dataset) lookup.
# NOTE(review): `wrapper_key`, `ds`, `sub` and `get_stat` are not defined in
# any visible cell of this notebook, so this cell raises NameError unless they
# exist in the kernel from elsewhere -- confirm whether it can be removed.
print(metric, wrapper_key, ds,
      "sub_rows=", len(sub),
      "picked_std=", None if len(sub)==0 else float(get_stat(sub.iloc[0], metric, "std")),
      "all_stds=", [] if len(sub)==0 else [float(x) for x in sub[(metric,"std")].tolist()])

In [None]:
# Aggregated in-distribution metrics for MLE on 'winogrande_s' with model 'Qwen3-8B' at rank 8.
query(id_df, dataset='winogrande_s', wrapper='mle', model='Qwen3-8B', rank=8)

In [None]:
# Per-seed (un-aggregated) in-distribution rows for the 'deepensemble' wrapper with the 'vlm' prompt type.
query(id_df_seeds, prompt_type='vlm', wrapper='deepensemble')

In [None]:
# Discover every ood metrics file below `root` and build one row per
# (run, evaluation dataset) pair.
json_fnames = glob.glob(f'{root}/**/ood/**/metrics.json', recursive=True)
# Directory containing each metrics file; sorted for a reproducible row order.
expdirs = sorted({fname.rsplit('/', 1)[0] for fname in json_fnames})

# Leading path components (after the first one) map onto these keys, in order.
keys = ['model', 'quant', 'wrapper', 'rank', 'prompt_type', 'seed']

df = []
for edir in expdirs:
    tokens = edir.replace(root, '').split('/')
    row = dict(zip(keys, tokens[1:-2]))
    row['rank'] = int(tokens[4].replace('rank', ''))
    # Parse the whole number after the 'seed' prefix. Taking only the last
    # character would map e.g. 'seed10' to 0 and collide with 'seed0'.
    row['seed'] = int(tokens[6].replace('seed', ''))
    # The evaluation dataset is the name of the directory holding metrics.json.
    row['dataset'] = tokens[-1]
    data = load_json(f'{edir}/metrics.json')
    if not data:
        # load_json returns [] for unreadable files; explode() would turn
        # that into a record-less placeholder row, so skip it here.
        continue
    row['results'] = data
    df.append(row)
df = pd.DataFrame(df)
# One row per record in each metrics file.
df_exploded = df.explode('results').reset_index(drop=True)
# The per-record 'seed' field is dropped; the seed parsed from the path is kept.
metrics_df = pd.json_normalize(df_exploded['results']).drop(columns=['seed'])
ood_df_seeds = pd.concat([df_exploded.drop(columns=['results']), metrics_df], axis=1)
# Mean/std over seeds; columns become a (metric, 'mean'|'std') MultiIndex.
ood_df = ood_df_seeds.groupby(exp_keys)[metric_keys].agg(['mean', 'std'])

In [None]:
# Metric vs. input-noise level: one panel per metric, one line per wrapper.
fig, axes = plt.subplots(1, 4, figsize=(25, 5), sharey=False)
plt.rcParams.update({'font.size': 12})
metrics = ['ACC', 'ECE', 'NLL', 'Brier']

noise_stds = [0,1,2,4,8,16,32,64,128]
# Evenly spaced positions; the actual noise values are used as tick labels.
x = np.arange(len(noise_stds))

dataset = 'slake'
prompt_type = 'vlm'
quant = '16bit'
rank = 8

for ax, metric in zip(axes, metrics):
    arrow = metric2arrow[metric]

    for wrapper in ['mle', 'blob', 'scalabl']:
        label = wrapper2label[wrapper]
        y_mean, y_std = [], []
        for std in noise_stds:
            # Noise level 0 is the clean dataset from the in-distribution
            # table; every other level is a 'noisy_<dataset><std>' ood entry.
            # A separate name is used so `dataset` is not clobbered by the loop.
            if std == 0:
                eval_dataset, metric_df = dataset, id_df
            else:
                eval_dataset, metric_df = f'noisy_{dataset}{std}', ood_df

            metric_vals = metric_df.query(f"dataset == '{eval_dataset}' and prompt_type == '{prompt_type}' and wrapper == '{wrapper}' and quant == '{quant}' and rank == {rank} and model == 'Qwen3-VL-8B-Instruct'" ).reset_index()[metric]
            # .item() requires exactly one matching row and raises otherwise.
            y_mean.append(metric_vals['mean'].item())
            y_std.append(metric_vals['std'].item())
        # The style dict must be passed as `plot_kwargs`; unpacking it with **
        # raises TypeError because plot_with_err has no such keyword parameters.
        ax = plot_with_err(x, y_mean, y_std, plot_kwargs=style_dict[wrapper], label=label, ax=ax)

    ax.set_xlabel('Noise STD (pixel units)')
    ax.set_ylabel(f"{metric} ({arrow})")
    ax.legend(
        loc='upper center',          # Anchor point on the legend box itself
        bbox_to_anchor=(0.5, -0.15), # (x, y) coordinates relative to the plot axes
        ncols=2,                     # Two legend entries per row
        frameon=True                 # Draw a frame around the legend
    )

    #ax.set_xscale('log', base=2)
    ax.set_xticks(x)
    ax.set_xticklabels(noise_stds)
    ax.grid()


In [None]:
# Per-seed (un-aggregated) ood rows for ScalaBL on 'noisy_slake8' with the 'vlm' prompt type.
query(ood_df_seeds, wrapper='scalabl', prompt_type='vlm', dataset='noisy_slake8')

In [None]:
# Metric vs. base-model size on one dataset: one panel per metric, one line per wrapper.
fig, axes = plt.subplots(1, 4, figsize=(25, 5), sharey=False)
plt.rcParams.update({'font.size': 12})
metrics = ['ACC', 'ECE', 'NLL', 'Brier']

dataset = 'obqa'
prompt_type = 'instruct'
quant = '16bit'
rank = 8

for ax, metric in zip(axes, metrics):
    arrow = metric2arrow[metric]
    for wrapper in ['mle', 'scalabl', 'blob', 'mcdropout','tfb','laplace','deepensemble', 'map', 'tempscale']:
        label = wrapper2label[wrapper]
        # One row per model matching the filters; no model filter is applied.
        q = query(id_df, prompt_type=prompt_type, quant=quant, rank=rank, wrapper=wrapper, dataset=dataset)
        # x: base parameter count in billions; plot_with_err sorts by x.
        ax = plot_with_err(q['num_base'] / 10**9, q[(metric, 'mean')], None, plot_kwargs=style_dict[wrapper], label=label, ax=ax)

    ax.set_xlabel('# Base Parameters (billions)')
    ax.set_ylabel(f"{metric} ({arrow})")
    ax.grid()

# One shared legend below the panels (every panel shows the same wrappers).
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncols=9,          # adjust for readability
    frameon=True
)
fig.suptitle(f'Qwen3 Family on In-Distribution {dataset}',y=0.95)
fig.subplots_adjust(bottom=0.15)

for ax in axes:
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))

In [None]:
# Inspect the per-seed in-distribution rows for TFB on 'mmstar' ('vlm' prompt
# type): print the shape, then show the frame as the cell output.
# NOTE(review): `display` is imported but not used in this cell.
from IPython.display import display
qdf = query(id_df_seeds, dataset='mmstar', prompt_type='vlm', wrapper='tfb')
print(qdf.shape)
qdf

In [None]:
# Metric vs. training-set size on the winogrande_{xs,s,m,l} splits:
# one panel per metric, one line per wrapper.
fig, axes = plt.subplots(1, 4, figsize=(25, 5), sharey=False)
plt.rcParams.update({'font.size': 12})
metrics = ['ACC', 'ECE', 'NLL', 'Brier']

dataset_sizes = ['xs','s','m','l']
x_vals = [160,640,2558,10234]   # x position used for each entry of dataset_sizes
prompt_type = 'instruct'
quant = '16bit'
model = 'Qwen3-8B'
rank = 8

base_query_str = f"model == '{model}' and prompt_type == '{prompt_type}' and quant == '{quant}' and rank == {rank}"

for ax, metric in zip(axes, metrics):
    arrow = metric2arrow[metric]
    for wrapper in ['mle', 'scalabl', 'blob', 'mcdropout', 'laplace','tfb', 'deepensemble', 'map', 'tempscale']:
        label = wrapper2label[wrapper]
        xs, y_mean, y_std = [], [], []
        for size, n_train in zip(dataset_sizes, x_vals):
            query_str = base_query_str + f" and wrapper == '{wrapper}' and dataset == 'winogrande_{size}'"
            q = id_df.query(query_str).reset_index()
            if len(q) != 1:
                # Missing (or ambiguous) result for this split: skip it. The x
                # value is collected together with y, so a gap can no longer
                # shift later points onto the wrong x position.
                continue
            xs.append(n_train)
            y_mean.append(q[(metric, 'mean')].item())
            y_std.append(q[(metric, 'std')].item())
        ax = plot_with_err(xs, y_mean, None, plot_kwargs=style_dict[wrapper], label=label, ax=ax)
    ax.grid()
    ax.set_ylabel(f"{metric} ({arrow})")
    ax.set_xlabel('Training Set Size (# of Instances)')

# One shared legend below the panels (every panel shows the same wrappers).
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncols=9,          # adjust for readability
    frameon=True
)
fig.suptitle('Qwen3-8B In-Distribution Winogrande',y=0.95)
fig.subplots_adjust(bottom=0.15)
