In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
import os
import re
import plotly.express as px

In [2]:
def load_experiment_data(path):
    result = []
    for dirpath in glob(os.path.join(path, '*')):
        if not os.path.isdir(dirpath):
            continue
        exp = dirpath.split('/')[-1].replace('-hf', '')
        size = exp.split('_')[1].split('-')[-2]
        model = '-'.join(exp.split('_')[1].split('-')[:-2])
        ngram = exp.split('_')[2].split('-')[0]
        contextfunc = exp.split('contextfunc_')[-1]
        d_ = {
            'model': model,
            'size': size,
            'ngram': ngram,
            'contextfunc': contextfunc
        }
        
        with open(os.path.join(dirpath, 'likelihood.txt')) as f:
            lines = f.readlines()
            for line in lines[:3]:
                key, value = line.strip().split(': ')
                d_[key] = float(value)
        
        
        with open(os.path.join(dirpath, 'eval.txt')) as f:
            line = f.readlines()[0].strip("{}").replace('"', "")
            key, value = line.strip().split(': ')
            d_[key] = float(value)
        
        result.append(d_)
        
    df = pd.DataFrame(result).drop(columns=['linear_fit_logLik', 'delta_linear_fit_chi_p']).rename(columns={'delta_linear_fit_logLik': 'PPP'})
    return df

In [3]:
mamba_df = load_experiment_data('surprisals/DC-mamba')
pythia_df = load_experiment_data('surprisals/DC-pythia')
pretrained_mamba_df = load_experiment_data('surprisals/DC-pretrained-mamba')

In [4]:
mamba_df.sort_values('PPP', ascending=False)

Unnamed: 0,model,size,ngram,contextfunc,PPP,PPL
65,mamba,370m,2,delete,0.007435,359.931618
136,mamba,790m,2,delete,0.007300,353.142378
130,mamba,1.4b,2,delete,0.007292,336.875576
116,mamba,130m,2,delete,0.007224,387.745736
53,mamba,2.8b,2,delete,0.007010,340.448331
...,...,...,...,...,...,...
120,mamba,2.8b,20,lossy-0.25,0.005122,52.695219
54,mamba,2.8b,20,lossy-0.5,0.005120,52.983243
47,mamba,2.8b,20,delete,0.005120,53.128593
124,mamba,2.8b,1000,delete,0.005115,51.075546


In [5]:
pythia_df.sort_values('PPP', ascending=False)

Unnamed: 0,model,size,ngram,contextfunc,PPP,PPL
55,pythia,70m,20,lossy-0.5,0.007652,134.616018
193,pythia,70m,1000,delete,0.007646,132.258720
149,pythia,70m,3,lossy-0.5,0.007644,269.023969
102,pythia,70m,7,lossy-0.125,0.007635,151.197154
12,pythia,70m,20,delete,0.007634,134.858502
...,...,...,...,...,...,...
152,pythia,6.9b,20,lossy-0.25,0.005258,50.646119
147,pythia,6.9b,20,lossy-0.125,0.005250,50.375851
188,pythia,6.9b,20,lossy-0.5,0.005247,50.928076
98,pythia,6.9b,20,lossy-0.0625,0.005245,49.996341


In [23]:
pretrained_mamba_df.model = 'wiki-mamba'
pretrained_mamba_df.sort_values('PPP', ascending=False)

Unnamed: 0,model,size,ngram,contextfunc,PPP,PPL
23,wiki-mamba,130m,1000,delete,0.008445,728.232237
8,wiki-mamba,130m,20,lossy-0.0625,0.008441,728.834778
22,wiki-mamba,130m,20,lossy-0.125,0.008432,730.452812
0,wiki-mamba,130m,20,lossy-0.25,0.008425,731.802576
30,wiki-mamba,130m,20,lossy-0.5,0.008422,732.813376
2,wiki-mamba,130m,20,delete,0.00842,733.403457
19,wiki-mamba,130m,10,lossy-0.125,0.008393,755.76776
10,wiki-mamba,130m,10,lossy-0.25,0.008388,766.871106
6,wiki-mamba,130m,10,lossy-0.0625,0.008384,741.772073
20,wiki-mamba,130m,10,lossy-0.5,0.008378,775.146044


In [24]:
df = pd.concat([mamba_df, pythia_df, pretrained_mamba_df], ignore_index=True)

In [25]:
def to_num(s):
    """Convert size strings like '130m', '1.4b' to numeric values."""
    m = re.match(r"([\d.]+)\s*([mbMB])", str(s))
    if not m:
        return float("nan")
    val, unit = float(m.group(1)), m.group(2).lower()
    return val * (1e6 if unit == "m" else 1e9)

In [26]:
df['size_num'] = df['size'].apply(to_num)
df = df[df['contextfunc'] == 'delete'].drop(columns=['contextfunc'])
df.ngram = df.ngram.astype(int)

In [27]:
df

Unnamed: 0,model,size,ngram,PPP,PPL,size_num
2,mamba,370m,7,0.005961,97.389515,3.700000e+08
6,mamba,790m,3,0.006774,201.594884,7.900000e+08
20,mamba,2.8b,10,0.005203,66.334283,2.800000e+09
30,mamba,790m,10,0.005691,74.003092,7.900000e+08
32,mamba,130m,3,0.006891,231.205478,1.300000e+08
...,...,...,...,...,...,...
381,wiki-mamba,130m,3,0.008207,1179.907805,1.300000e+08
384,wiki-mamba,130m,2,0.008099,1687.761690,1.300000e+08
390,wiki-mamba,130m,5,0.008302,913.519282,1.300000e+08
395,wiki-mamba,130m,1000,0.008445,728.232237,1.300000e+08


In [28]:
agg_df = df.groupby(['model', 'size_num']).agg({'PPP': 'mean'}).sort_index().reset_index()

fig = px.scatter(
    agg_df,
    x="size_num",
    y="PPP",
    color="model",
    title="Average PPP vs Model Size across Models",
    log_x=True,
).update_traces(mode='lines+markers')

fig.update_layout(
    legend_title_text="Model",
    template="plotly_white",
    xaxis_title="Model Size (Number of Parameters in log scale)",
    yaxis_title="Average PPP",
)
fig.show()

In [29]:
#check duplicates
df[df.duplicated(subset=['model', 'size', 'ngram'], keep=False)]

Unnamed: 0,model,size,ngram,PPP,PPL,size_num


In [30]:
df.ngram.value_counts().sort_index()

ngram
2       13
3       13
5       13
7       13
10      13
20      13
1000    13
Name: count, dtype: int64

In [31]:
df.size_num.value_counts()

size_num
1.300000e+08    14
2.800000e+09    14
1.400000e+09    14
3.700000e+08     7
7.900000e+08     7
1.000000e+09     7
6.900000e+09     7
4.100000e+08     7
7.000000e+07     7
1.600000e+08     7
Name: count, dtype: int64

In [32]:
df.model.value_counts()

model
pythia        49
mamba         35
wiki-mamba     7
Name: count, dtype: int64

In [64]:
df_sorted = df.sort_values(["size_num", "model",  "ngram"])

df_sorted = df_sorted[df_sorted.model.isin([
    'wiki-mamba',
    'mamba',
    'pythia'
])]

ngram_to_size = {
    2: 7,
    3: 8,
    5: 9,
    7: 10,
    10: 11,
    20: 12,
    1000: 15
}
df_sorted['marker_size'] = df_sorted['ngram'].map(ngram_to_size)

# comparible sizes together 
size_to_group = {
    '130m': '130m - 160m',
    '160m': '130m - 160m',
    '370m': '370m - 410m',
    '410m': '370m - 410m',
    '790m': '790m - 1b',
    '1b': '790m - 1b',
    '1.4b': '1.4b',
    '2.8b': '2.8b',
    '6.9b': '6.9b',
}
df_sorted['size_group'] = df_sorted['size'].map(size_to_group)

fig = px.scatter(
    df_sorted,
    x="PPL",
    y="PPP",
    color="size_group",           # unique color for each (model + size)
    symbol="model",               # marker shape per model type
    # facet_col="model",
    hover_data=["model", "size", "ngram", "PPP", "PPL"],
    title="PPP vs PPL across Models, Sizes, and N-grams",
    log_x=True,
    log_y=True,
    color_continuous_scale=px.colors.sequential.Viridis,
    width=1000,
    height=800
).update_traces(mode='lines+markers', line=dict(width=2))

fig.update_traces(marker=dict(size=df_sorted['marker_size'], opacity=0.8, line=dict(width=1, color='DarkSlateGrey')))

symbol_map = {
    'mamba': 'star',
    'pythia': 'square',
    'wiki-mamba': 'diamond'
}

opacity_map = {
    'mamba': 0.9,
    'pythia': 0.7,
    'wiki-mamba': 0.5
}

for trace in fig.data:
    model = trace.name.split(',')[1].strip()  # Extract model name from trace name
    if model in symbol_map:
        trace.marker.symbol = symbol_map[model]
    if model in opacity_map:
        trace.marker.opacity = opacity_map[model]
    
for trace in fig.data:
    # Split the trace name into parts
    parts = [p.strip() for p in trace.name.split(',')]
    if len(parts) < 2:
        continue  # skip if unexpected format

    group_name, model_name = parts[0], parts[1]

    # Filter by both size_group and model
    sizes_in_group = df_sorted.loc[
        (df_sorted['size_group'] == group_name) &
        (df_sorted['model'] == model_name),
        'size'
    ].unique()

    # Use just the single size (should almost always be one)
    if len(sizes_in_group) == 1:
        label = f"{sizes_in_group[0]}, {model_name}"
    else:
        # fallback if somehow multiple sizes exist
        label = f"{'/'.join(sizes_in_group)}, {model_name}"

    trace.name = label

# Clean layout
fig.update_layout(
    legend_title_text="Size, Model",
    template="plotly_white",
    xaxis_title="PPL in log scale",
    yaxis_title="PPP in log scale",
    # legend in bottom right corner
    legend=dict(
        x=1,
        y=0,
        xanchor='right',
        yanchor='bottom',
        bgcolor='rgba(255,255,255,0.8)',
        bordercolor='rgba(0,0,0,0.2)',
        borderwidth=1
    )
)

fig.show()