In [None]:
# Software Name : TimeStress
# SPDX-FileCopyrightText: Copyright (c) Orange SA
# SPDX-License-Identifier: MIT

# This software is distributed under the MIT License,
# see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html

# Authors: see CONTRIBUTORS.md
# Software description: Evaluating the Consistency of the Temporal Representation of Facts in Large Language Models

# **Goal of notebook**: Analysis of the language models' predictions on TimeStress

This notebook is used to generate the figures and tables present in our [paper](https://arxiv.org/abs/2502.01220)

**How to use this notebook?**

1. Create a folder `outputs/temporality_paper` in the same folder as this notebook.
1. Put all the language models' predictions (the `*__experiment=<EXPERIMENT>.pkl` files generated using the `run.py` script) inside `outputs/temporality_paper`.
1. Set the `EXPERIMENT` constant accordingly (see below for details), and run all the cells in this notebook. 
1. The plots will be generated in a new folder called `plots`.

Choose the `EXPERIMENT` value from the following list:
   - `classic`: Most tables and figures can be generated with this setting
   - `explain_granularity_generalization` and `explain_date`: These 2 settings are used to generate the *Explanation Prompts* results in the paper. 

In [None]:
# Selects which prediction files are loaded: outputs/temporality_paper/*__experiment=<EXPERIMENT>.pkl
EXPERIMENT = 'classic'  # choose from VALID_EXPERIMENTS below
VALID_EXPERIMENTS = ('classic', 'explain_granularity_generalization', 'explain_date')
# Fail fast on a typo instead of silently matching no prediction file.
assert EXPERIMENT in VALID_EXPERIMENTS, 'Unknown EXPERIMENT %r; choose from %s' % (EXPERIMENT, VALID_EXPERIMENTS)

In [None]:
import pandas as pd
from pathlib import Path
from models_metadata import is_instruct, get_model_num_params, get_instruct_version_of_model, get_classical_version_of_model, get_model_family, extract_metadata_from_filename
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from ke_utils.glob_core import TimeUnit
from wikidata_tools.core import Date, enable_level_heterogeneous_comparison

# Apply seaborn's default theme to every figure produced by this notebook.
sns.set_theme()

In [None]:

# Models used by the granularity-transfer and date-consistency cells below; set to None to fall back
# to the five models with the highest mean global robustness.
KEPT_MODELS = [
    'google_gemma-2-27b-it',
    'google_gemma-2-9b-it',
    'meta-llama_Llama-3.1-70B-Instruct',
    'mistralai_Mistral-7B-Instruct-v0.3',
]

# Folder holding the pickled predictions written by `run.py`.
OUTPUT_DIR = 'outputs/temporality_paper'


def load_data(output_dir=OUTPUT_DIR, experiment=None) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load every model's predictions for one experiment.

    Args:
        output_dir: folder containing files named ``*__experiment=<experiment>.pkl``.
        experiment: experiment name; defaults to the notebook-level ``EXPERIMENT`` constant.

    Returns:
        ``(predictions, models)``: ``predictions`` concatenates all pickled frames, each with an added
        ``Model`` column; ``models`` has one row per file with columns Model, Instruct, NumParams.

    Raises:
        FileNotFoundError: if no prediction file matches (instead of pandas' opaque
            "No objects to concatenate" error).
    """
    if experiment is None:
        experiment = EXPERIMENT
    # Sorted so that the row order does not depend on the filesystem's listing order.
    paths = sorted(Path(output_dir).glob('*__experiment=%s.pkl' % experiment))
    if not paths:
        raise FileNotFoundError(
            'No prediction file matching *__experiment=%s.pkl in %s' % (experiment, output_dir))

    frames = []
    models_rows = []
    for p in paths:
        print('Loading %s' % p.name)
        model = extract_metadata_from_filename(p.name)['model']
        models_rows.append((model, is_instruct(model), get_model_num_params(model)))
        # NOTE: unpickling can execute arbitrary code; only load prediction files you produced yourself.
        df = pd.read_pickle(p)
        df['Model'] = model
        frames.append(df)
    return pd.concat(frames), pd.DataFrame(data=models_rows, columns=['Model', 'Instruct', 'NumParams'])

In [None]:
logprobs, models_data = load_data()

# Remove future dates: keep tests whose datetime64 'Time' is strictly before REFERENCE_TIME;
# any non-datetime64 'Time' value is always kept.
REFERENCE_TIME = Date(np.datetime64('2021-01-04'))


def _is_not_in_future(time_value) -> bool:
    """True unless `time_value` is a datetime64 that is not before REFERENCE_TIME."""
    if not isinstance(time_value, np.datetime64):
        return True
    return Date(time_value) < REFERENCE_TIME


# The comparison must run inside this context manager (dates may have different precisions).
with enable_level_heterogeneous_comparison():
    keep_mask = logprobs['Time'].apply(_is_not_in_future)
    logprobs = logprobs[keep_mask]

### Verify that for each Fact, there is the same number of tests (times) per granularity.

In [None]:
# Summary statistics of `Alpha` for the tests labelled 'Transitional'.
transitional_mask = logprobs['IsCorrect'] == 'Transitional'
logprobs.loc[transitional_mask, 'Alpha'].describe()

In [None]:
logprobs_dedup = logprobs.copy()

# Associate all Transitional dates to an alpha=+-0.5 to have a better visualization in alpha vs. logprob plots:
# strictly positive alphas map to +0.5; zero (tie) and NaN map to -0.5, as the nearest-value rule did.
is_transitional = logprobs_dedup['IsCorrect'] == 'Transitional'
transitional_alpha = logprobs_dedup.loc[is_transitional, 'Alpha']
logprobs_dedup.loc[is_transitional, 'Alpha'] = np.where(transitional_alpha > 0, 0.5, -0.5)

In [None]:
# Print a LaTeX table with two random example tests for each correctness label (IsCorrect).
SAMPLE_SEED = 0  # fixed so the example table is reproducible across runs

pd.set_option('display.max_colwidth', None)
example_tests = logprobs_dedup.groupby(['IsCorrect'])[['Fact', 'Time', 'IsCorrect', 'Statement']].sample(2, random_state=SAMPLE_SEED)
print(example_tests.to_latex(index=False))

In [None]:
# Sanity check: number of distinct test dates per (Fact, Granularity), pooled over all models.
# Ideally every fact has the same number of dates for each granularity.
# NOTE(review): the cell building `logprobs_dedup` only remaps Alpha for Transitional tests and never
# touches 'Time', so both histograms below should be identical — confirm whether a real
# deduplication step was intended.
plt.figure()
# After reset_index, the 'Time' column holds the count of distinct dates.
unique_dates_per_fact = logprobs.groupby(['Fact', 'Granularity'], sort=False)['Time'].unique().apply(len).reset_index()
sns.histplot(unique_dates_per_fact, x='Time', hue='Granularity')
plt.title('Before deduplication')

plt.figure()
unique_dates_per_fact = logprobs_dedup.groupby(['Fact', 'Granularity'], sort=False)['Time'].unique().apply(len).reset_index()
sns.histplot(unique_dates_per_fact, x='Time', hue='Granularity')
plt.title('After deduplication')

In [None]:
# Add order score: 'TimeWin' is, for each test, the fraction of the other tests of the same
# (Model, FactStr, Granularity) group whose CondLogProb it strictly exceeds.
def _time_win_rate(cond_logprobs: pd.Series) -> pd.Series:
    """Within-group win rate of every test against the group's other tests.

    Returns NaN for a single-test group (nothing to compare with) instead of dividing 0 by 0.
    """
    values = cond_logprobs.to_numpy()
    n_others = len(values) - 1
    if n_others == 0:
        return pd.Series(np.nan, index=cond_logprobs.index)
    wins = (values[:, None] > values[None, :]).sum(-1) / n_others
    return pd.Series(wins, index=cond_logprobs.index)


logprobs_dedup['FactStr'] = logprobs_dedup['Fact'].apply(str)
logprobs_dedup = logprobs_dedup.sort_values(by=['Model', 'FactStr', 'Granularity', 'Alpha']).reset_index(drop=True)
# transform aligns results on the original index instead of relying on the group order matching the row order.
logprobs_dedup['TimeWin'] = logprobs_dedup.groupby(['Model', 'FactStr', 'Granularity'], sort=False)['CondLogProb'].transform(_time_win_rate)

In [None]:
# Per-model metadata returned by load_data: Model, Instruct, NumParams.
models_data

In [None]:
def _pairwise_scores(df: pd.DataFrame, score_col: str, correct_mask: pd.Series, incorrect_mask: pd.Series):
    """Compare every Correct test of `df` with every Incorrect test on `score_col`.

    Returns:
        soft: fraction of (correct, incorrect) pairs where the correct test scores strictly higher
            (NaN, with a numpy warning, when either side is empty).
        hard: whether every pair is won (``soft == 1``).
        alpha_correct, alpha_incorrect: `Alpha` of the correct / incorrect test of each pair that is
            not won.
    """
    matches = df.loc[correct_mask, score_col].values[:, None] > df.loc[incorrect_mask, score_col].values[None, :]
    lost_correct_idx, lost_incorrect_idx = np.where(~matches)
    alpha_correct = df.loc[correct_mask, 'Alpha'].iloc[lost_correct_idx].values
    alpha_incorrect = df.loc[incorrect_mask, 'Alpha'].iloc[lost_incorrect_idx].values
    soft = matches.mean()
    return soft, soft == 1, alpha_correct, alpha_incorrect


def f(df: pd.DataFrame) -> pd.DataFrame:
    """Score one (Model, Fact, Granularity) group of tests.

    Returns two rows: raw-text queries (Instruct=False, from CondLogProb) and instruction queries
    (Instruct=True, from CondLogProbInstruct; NaN scores when that column is entirely missing).
    AccSoft is the win rate W, AccHard the robustness R (all pairs won).
    """
    correct_mask = df['IsCorrect'] == 'Correct'
    incorrect_mask = df['IsCorrect'] == 'Incorrect'
    soft_score, hard_score, alpha1, alpha2 = _pairwise_scores(df, 'CondLogProb', correct_mask, incorrect_mask)
    if not df['CondLogProbInstruct'].isna().all():
        soft_score_inst, hard_score_inst, alpha1_inst, alpha2_inst = _pairwise_scores(
            df, 'CondLogProbInstruct', correct_mask, incorrect_mask)
    else:
        # No instruction-format log-probs for this group.
        soft_score_inst = float('nan')
        hard_score_inst = float('nan')
        alpha1_inst, alpha2_inst = None, None

    model = df['Model'].iloc[0]
    fact = df['Fact'].iloc[0]
    granularity = df['Granularity'].iloc[0]
    return pd.DataFrame(data={
        'Model': [model] * 2,
        'Granularity': [granularity.name.capitalize()] * 2,  # e.g. YEAR -> 'Year'
        'Fact': [fact] * 2,
        'Instruct': [False, True],
        'AccSoft': [soft_score, soft_score_inst],
        'AccHard': [hard_score, hard_score_inst],
        'WrongAlphaIncorrect': [alpha2, alpha2_inst],
        'WrongAlphaCorrect': [alpha1, alpha1_inst],
    })


score_per_fact = logprobs_dedup.groupby(['Model', 'Fact', 'Granularity'], sort=False).apply(f).reset_index(drop=True)

# "Global" pseudo-granularity per (Model, Fact, Instruct): robust only if robust at every precision
# (NaN when no precision has a score), and win rate averaged over precisions.
GLOBALGR = "Global"
all_hard = score_per_fact.groupby(['Model', 'Fact', 'Instruct'], sort=False)['AccHard'].apply(lambda x: float('nan') if x.isna().all() else (x == 1).all())
all_soft = score_per_fact.groupby(['Model', 'Fact', 'Instruct'], sort=False)['AccSoft'].mean()
all_ = pd.concat([all_hard, all_soft], axis=1).reset_index()
all_['Granularity'] = GLOBALGR

score_per_fact = pd.concat([score_per_fact, all_])
score_per_fact['AccHard'] = score_per_fact['AccHard'].astype(float)

In [None]:
# Random baseline for the win rate W and robustness R: take one model's tests, replace its
# log-probabilities with uniform random scores and run the same per-fact scoring `f` as above.
RANDOM_BASELINE_MODEL = 'apple_OpenELM-3B'  # only its rows are used; their scores are overwritten below
RANDOM_SEED = 0  # fixed so the reported baseline is reproducible

baseline_model = RANDOM_BASELINE_MODEL
if baseline_model not in set(logprobs_dedup['Model']):
    # Presumably all models share the same tests, so any loaded model works; avoids scoring an empty frame.
    baseline_model = logprobs_dedup['Model'].iloc[0]
logprobs_random = logprobs_dedup[logprobs_dedup['Model'] == baseline_model].copy()
rng = np.random.default_rng(RANDOM_SEED)
logprobs_random['CondLogProb'] = rng.random(len(logprobs_random))
logprobs_random['CondLogProbInstruct'] = rng.random(len(logprobs_random))

score_per_fact_random = logprobs_random.groupby(['Model', 'Fact', 'Granularity'], sort=False).apply(f).reset_index(drop=True)
print('Random baseline performance:')
score_per_fact_random[['AccSoft', 'AccHard']].mean()

In [None]:
# Size of the per-fact score table (rows per model, fact, granularity incl. Global, and query format).
score_per_fact.shape

In [None]:
# Models present in the score table (full "org_model" identifiers).
score_per_fact['Model'].unique()

# How many popular facts do LLMs know?

## non-instruct queries

In [None]:
import os
sns.set(font_scale=1.1)

# Lenient metric: average win rate W per model and granularity, raw-text queries only.
# .copy() so that renaming the models below does not write into a view of score_per_fact.
data = score_per_fact[~score_per_fact['Instruct']].copy()
data['Model'] = data['Model'].apply(lambda x: x.split('_', 1)[-1])  # drop the organisation prefix
order = data.groupby('Model')['AccSoft'].mean().sort_values().index.tolist()
plt.figure(figsize=(5,7))
sns.barplot(data=data, y='Model', x='AccSoft', hue='Granularity', order=order)
plt.axvline(x=0.5, color='black', linestyle='--');
plt.xlim((0.4,1))
# Raw string: a backslash followed by 'm' is an invalid escape sequence in a regular string literal.
plt.xlabel(r'Average $\mathcal{W}$')
plt.ylabel('')
# NOTE(review): y=10 is a categorical row position (the 11th model); update it if the model list changes.
plt.text(0.53, 10, 'Random', color='black', fontsize=15,  ha='center', va='bottom', rotation=90);
plt.tight_layout()
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/lenient_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

In [None]:
sns.set(font_scale=1.1)

# Robust metric: average R per model and granularity, raw-text queries only; models are ordered
# by their Global robustness.
data = score_per_fact[~score_per_fact['Instruct']].copy()
data['Model'] = data['Model'].apply(lambda x: x.split('_', 1)[-1])  # drop the organisation prefix
order = data[data['Granularity'] == GLOBALGR].groupby('Model')['AccHard'].mean().sort_values().index.tolist()
plt.figure(figsize=(5,7))
sns.barplot(data=data, y='Model', x='AccHard', hue='Granularity', order=order)
plt.xlim((0,0.15));
# Raw string avoids the invalid-escape SyntaxWarning.
plt.xlabel(r'Average $\mathcal{R}$')
plt.ylabel('');
plt.tight_layout()
plt.savefig('plots/robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

### Instruct models vs. Non-instruct models

In [None]:
# Pair each instruction-tuned model with its pretrained ("classical") counterpart and compare their
# global scores fact by fact, on raw-text queries.
data = score_per_fact[~score_per_fact['Instruct'] & (score_per_fact['Granularity'] == GLOBALGR)]
classical_models = [x for x in data['Model'].apply(get_classical_version_of_model).unique() if x is not None]
instruct_models = [x for x in data['Model'].apply(get_instruct_version_of_model).unique() if x is not None]
data_classical = data[np.isin(data['Model'], classical_models)]
# .copy() because a column is added right below.
data_instruct: pd.DataFrame = data[np.isin(data['Model'], instruct_models)].copy()
data_instruct['ModelClassical'] = data_instruct['Model'].apply(get_classical_version_of_model)
assert data_instruct['ModelClassical'].shape[0] == data_classical.shape[0]
# Naming the pretrained side's 'Model' like the join key leaves a single 'ModelClassical' column
# (a left_on/right_on merge produced two columns with that name).
versus_data = data_instruct.merge(
    data_classical.rename(columns={'Model': 'ModelClassical'}),
    on=['ModelClassical', 'Fact', 'Granularity'],
    suffixes=('', 'Classical'),
)


def _versus_label(instruct_score, classical_score):
    """'Instruct-tuned Wins' / 'Pretrained Wins' on a strict difference; otherwise (equal or NaN) 'Tie'."""
    if instruct_score > classical_score:
        return 'Instruct-tuned Wins'
    if instruct_score < classical_score:
        return "Pretrained Wins"
    return 'Tie'


# Raw strings avoid the invalid-escape SyntaxWarning; the column names are unchanged.
versus_data[r'Versus $\mathcal{W}$'] = [_versus_label(a, b) for a, b in zip(versus_data['AccSoft'], versus_data['AccSoftClassical'])]
versus_data[r'Versus $\mathcal{R}$'] = [_versus_label(a, b) for a, b in zip(versus_data['AccHard'], versus_data['AccHardClassical'])]
versus_data['ModelClassical'] = versus_data['ModelClassical'].apply(lambda x: x.split('_', 1)[-1])

In [None]:
sns.set(font_scale=1.2)

# Stacked share of facts where the instruction-tuned model beats / ties / loses to its pretrained
# counterpart on the global win rate W. Each bar sums to 1 only if every model pair covers the same
# number of facts (uniform weights).
plt.figure(figsize=(6,4))
n_unique_models = len(versus_data['ModelClassical'].unique())
sns.histplot(data=versus_data, y='ModelClassical', hue=r'Versus $\mathcal{W}$', weights=(1/versus_data.shape[0])*n_unique_models, multiple='stack', hue_order=['Pretrained Wins', 'Tie', 'Instruct-tuned Wins'])
plt.axvline(x=0.5, color='black', linestyle='--');
plt.ylabel('')
plt.xlabel('Proportion')
plt.savefig('plots/inst_vs_noninst_lenient_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

## instruct queries

Does the instruction format help predict the right answer?

In [None]:
# Mean win rate (AccSoft) and robustness (AccHard) per model, granularity and query format.
# NOTE(review): `data` is still the raw-text, Global-only frame built in the "Instruct models vs.
# Non-instruct models" cell, so Instruct is always False here; this was possibly meant to run on
# `score_per_fact` — confirm.
data.drop(columns=["WrongAlphaIncorrect", "WrongAlphaCorrect"]).groupby(['Model', 'Granularity', 'Instruct'])[['AccSoft', 'AccHard']].mean()

In [None]:
sns.set(font_scale=1.2)

# Instruction-tuned models only: scores with raw-text (Instruct=False) vs. instruction-formatted
# (Instruct=True) queries, plus an 'Average' pseudo-model averaged per (Instruct, Fact, Granularity).
instruct_models = models_data.loc[models_data['Instruct'], 'Model']
data = score_per_fact[np.isin(score_per_fact['Model'], instruct_models)]
avg = data.groupby(['Instruct', 'Fact', 'Granularity'], sort=False)[['AccHard', 'AccSoft']].mean().reset_index()
avg['Model'] = 'Average'
data = pd.concat([data, avg], ignore_index=True)
data['Model'] = data['Model'].apply(lambda x: x.split('_', 1)[-1])
# Raw strings avoid the invalid-escape SyntaxWarning; the column names are unchanged.
data[r'Average $\mathcal{R}$'] = data['AccHard']
data[r'Average $\mathcal{W}$'] = data['AccSoft']
order = data.groupby('Model')['AccHard'].mean().sort_values().index.tolist()
order.remove('Average')
order.append('Average')  # show the average as the last row
# NOTE: sns.set applies the whole default theme, which also resets font_scale back to 1.
sns.set(rc={"figure.figsize":(12, 5)})
sns.catplot(data=data, y='Model', x=r'Average $\mathcal{R}$', hue='Instruct', order=order, kind='bar', col='Granularity', height=4, aspect=0.6);
plt.savefig('plots/inst_vs_noninst_inst_vs_noninst_query_robust_all_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

In [None]:
sns.set(font_scale=1.2)

# Same comparison for the win rate W; models ordered by mean AccSoft with 'Average' last.
order = data.groupby('Model')['AccSoft'].mean().sort_values().index.tolist()
order.remove('Average')
order.append('Average')
# NOTE: sns.set applies the whole default theme, which also resets font_scale back to 1.
sns.set(rc={"figure.figsize":(12, 5)})
sns.catplot(data=data, y='Model', x=r'Average $\mathcal{W}$', hue='Instruct', order=order, kind='bar', col='Granularity', height=4, aspect=0.6);
plt.savefig('plots/inst_vs_noninst_inst_vs_noninst_query_winrate_all_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

It is impossible to draw a conclusion here, because instruct prompts are phrased as questions while classical prompts are affirmations.

In [None]:
plt.figure(figsize=(3,5))
sns.set(font_scale=1.1)

# Global robustness only, raw-text vs. instruction input; models keep the order of the previous cell.
# .copy() so the columns added below are not written into a view.
data = data[data['Granularity'] == GLOBALGR].copy()
data['Input type'] = data['Instruct'].apply(lambda x: "Instruction" if x else "Raw text")
data[r'Average $\mathcal{R}_G$'] = data[r'Average $\mathcal{R}$']
sns.barplot(data=data, y='Model', x=r'Average $\mathcal{R}_G$', hue='Input type', order=order);
plt.ylabel('')
plt.savefig('plots/inst_vs_noninst_inst_vs_noninst_query_robust_global_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

# Model size and family vs. Performance

In [None]:
sns.set(font_scale=1.2)
# Global robustness vs. model size (raw-text queries): colour per model family and line style per
# whether the model is instruction-tuned (InstructInfo, from models_data).
data = score_per_fact[~score_per_fact['Instruct']].reset_index(drop=True).merge(models_data[['Model', 'NumParams', 'Instruct']], on='Model', suffixes=['', 'Info'])
data = data[data['Granularity'] == GLOBALGR]
family = data['Model'].apply(get_model_family)
plt.figure(figsize=(5,4))
sns.lineplot(data, x='NumParams', y='AccHard', hue=family, style=data['InstructInfo'], markers='o', hue_order=sorted(family.unique()), errorbar=None)
plt.xlabel('Number of parameters')
plt.ylabel(r'Average $\mathcal{R}_G$')
plt.savefig('plots/numparams_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

In [None]:
# Global robustness vs. size for the instruction-tuned Gemma-2 models only (log-scaled x axis).
data = score_per_fact[~score_per_fact['Instruct']].reset_index(drop=True).merge(models_data[['Model', 'NumParams', 'Instruct']], on='Model', suffixes=['', 'Info'])
family = data['Model'].apply(get_model_family)
data = data[(data['Granularity'] == GLOBALGR) & (family == 'Gemma-2') & (data['InstructInfo'])]
plt.figure(figsize=(5,4))
sns.lineplot(data, x='NumParams', y='AccHard', style=data['InstructInfo'], markers='o', errorbar=None)
plt.xlim((1,1000000000000000))
plt.ylim((0,1))
plt.xscale('log')
plt.xlabel('Number of parameters')
plt.ylabel(r'Average $\mathcal{R}_G$')
# NOTE(review): unlike the other figures, this file name has no EXPERIMENT suffix, so every
# experiment overwrites the same file.
plt.savefig('plots/numparams_vs_robust_gemma.pdf', bbox_inches='tight', pad_inches=0);

# Measuring temporal coherence

The proportion of facts that are known in one granularity but not in the others

In [None]:
from itertools import chain, combinations


def powerset(iterable):
    """Every subset of `iterable` as a tuple, smallest first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = list(iterable)
    subsets_by_size = (combinations(items, size) for size in range(len(items) + 1))
    return chain.from_iterable(subsets_by_size)


# Non-empty combinations of date precisions, in the order used for heatmap rows/columns;
# the first three entries are the single precisions.
all_combinations_gran_str = ['Year', 'Month', 'Day', 'Year+Month', 'Year+Day', 'Month+Day', 'Year+Month+Day']
all_combinations_gran = [set(label.split('+')) for label in all_combinations_gran_str]

In [None]:
# For each (Model, Fact) known robustly (AccHard == 1) in at least one precision with raw-text queries,
# record which precisions are known, as a sorted '_'-joined label such as 'Day_Month_Year'.
data = score_per_fact[(score_per_fact['AccHard'] == 1) & ~score_per_fact['Instruct'] & (score_per_fact['Granularity'] != GLOBALGR)]
known_gran = (
    data.groupby(['Model', 'Fact'], sort=False)['Granularity']
    .agg(lambda grans: '_'.join(sorted(grans)))
    .rename('KnownGran')
    .reset_index()
)
# Categories ordered from the least to the most frequent combination.
order = known_gran.groupby('KnownGran').size().sort_values().index.tolist()
known_gran['KnownGran'] = pd.Categorical(known_gran['KnownGran'], order)
sns.histplot(data=known_gran, y="KnownGran")

In [None]:
def propci_wilson_cc(count, nobs, alpha=0.05):
    """Wilson score interval with continuity correction for the proportion count / nobs.

    Method 4 in Newcombe [1], verified via its Table 1 (per the original author).

    Returns:
        (centre, half_width) of the (1 - alpha) interval. The centre is the interval midpoint, which
        in general differs from count / nobs.
    """
    from scipy import stats
    n = nobs
    p = count/n
    q = 1.-p
    z = stats.norm.isf(alpha / 2.)
    z2 = z**2
    denom = 2*(n+z2)
    num = 2.*n*p+z2-1.-z*np.sqrt(z2-2-1./n+4*p*(n*q+1))
    ci_l = num/denom
    num = 2.*n*p+z2+1.+z*np.sqrt(z2+2-1./n+4*p*(n*q-1))
    ci_u = num/denom
    if p == 0:
        ci_l = 0.
    elif p == 1:
        ci_u = 1.
    return (ci_l+ci_u)/2, (ci_u-ci_l)/2


def _expand_known_combinations(df: pd.DataFrame) -> pd.DataFrame:
    """One row per entry of `all_combinations_gran` for this group.

    Columns: Fact (the group key), CombGran (the combination), and Success (True when every precision
    of the combination is among the group's robustly known granularities).
    """
    known = set(df['Granularity'].to_list())
    expanded = pd.DataFrame(data={'Fact': [df.name] * len(all_combinations_gran)})
    expanded['CombGran'] = all_combinations_gran
    expanded['Success'] = [comb.issubset(known) for comb in all_combinations_gran]
    return expanded


def _transfer_matrix(known_gran: pd.DataFrame):
    """P(A is known | B is known) for every pair of precision combinations, plus annotations.

    Cell [i, j] uses A = comb_i ∪ comb_j and B = comb_j. Annotation is 'centre±\\nhalf-width' of the
    Wilson CI, or '1' when A ⊆ B (the ratio is trivially 1).
    """
    n = len(all_combinations_gran)
    # Count the facts knowing each combination once, instead of rescanning for every cell.
    success_counts = {
        frozenset(comb): known_gran.loc[known_gran['CombGran'] == comb, 'Success'].sum()
        for comb in all_combinations_gran
    }
    heatmap = np.zeros((n, n))
    cis = np.zeros_like(heatmap).astype(str)
    for i, comb_i in enumerate(all_combinations_gran):
        for j, B in enumerate(all_combinations_gran):
            A = comb_i.union(B)
            B_occ = success_counts[frozenset(B)]
            # Facts knowing both A and B: since B ⊆ A, this is simply the count for A.
            A_B_occ = success_counts[frozenset(A)]
            cis[i, j] = '%.2f±\n%.2f' % propci_wilson_cc(A_B_occ, B_occ) if not A.issubset(B) else 1
            heatmap[i, j] = A_B_occ / B_occ
    return heatmap, cis


# sns.set applies the whole default theme with a larger font (a prior set_style call would be overridden).
sns.set(font_scale=1.5)
for instruct in (True, False):
    mask = score_per_fact['Instruct'] if instruct else ~score_per_fact['Instruct']
    for USE_ALL_COMBINATIONS in (False,):
        if KEPT_MODELS is not None:
            kept_models = KEPT_MODELS.copy()
        else:
            # Top-5 models by mean global robustness for this query format.
            kept_models = score_per_fact[(score_per_fact['Granularity'] == GLOBALGR) & mask].groupby('Model')['AccHard'].mean().sort_values(ascending=False).index.tolist()[:5]
        for model in kept_models:
            print(model, instruct)
            data = score_per_fact[(score_per_fact['Model'] == model) & (score_per_fact['AccHard'] == 1) & mask & (score_per_fact['Granularity'] != GLOBALGR)]
            known_gran = data.groupby(['Fact'], sort=False)[['Granularity']].apply(_expand_known_combinations).reset_index(drop=True)
            if len(known_gran) == 0:
                print('Skip %s' % model)
                continue

            heatmap, cis = _transfer_matrix(known_gran)
            plt.figure(figsize=(5,4))
            if USE_ALL_COMBINATIONS:
                sns.heatmap(heatmap, vmin=0, vmax=1, xticklabels=all_combinations_gran_str, yticklabels=all_combinations_gran_str, cmap="rocket", fmt='', annot=cis, annot_kws={"size": 10})
            else:
                # Single precisions only (Year, Month, Day): the first three combinations.
                sns.heatmap(heatmap[:3,:3], vmin=0, vmax=1, xticklabels=all_combinations_gran_str[:3], yticklabels=all_combinations_gran_str[:3], cmap="rocket", fmt='', annot=cis[:3,:3], annot_kws={"size": 20})
            plt.ylabel('A')
            plt.xlabel('B')
            plt.title('P(A is known | B is known)')
            plt.savefig('plots/gran_vs_%s_%s_%s_%s.pdf' % ('allcomb' if USE_ALL_COMBINATIONS else 'smallcomb', 'inst' if instruct else 'noninst', model, EXPERIMENT), bbox_inches='tight', pad_inches=0);
            plt.show()


In [None]:
# Global robustness
# Same P(A is known | B is known) heatmap as the per-model cell, but pooling the robustly known facts of
# all kept models (one row group per (Model, Fact)). For the 3x3 Year/Month/Day view it also prints the
# mean off-diagonal value.
# NOTE(review): `f` and `propci_wilson_cc` duplicate the per-model cell and are redefined on every
# loop iteration; consider defining them once.
for instruct in (True, False):
    for USE_ALL_COMBINATIONS in (False,):
        print(instruct)
        def f(df : pd.DataFrame):
            # One row per precision combination, with Success=True when every precision of the
            # combination is among this (Model, Fact) group's robustly known granularities.
            l = set(df['Granularity'].to_list())
            df2 = pd.DataFrame(data={
                'Fact': [df.name]*len(all_combinations_gran)
            })
            df2['CombGran'] = all_combinations_gran
            success = [x.issubset(l) for x in all_combinations_gran]
            df2['Success'] = success
            return df2
        mask = ~score_per_fact['Instruct'] if not instruct else score_per_fact['Instruct']
        if KEPT_MODELS is not None:
            kept_models = KEPT_MODELS.copy()
        else:
            # Top-5 models by mean global robustness for this query format.
            kept_models = score_per_fact[(score_per_fact['Granularity'] == GLOBALGR) & mask].groupby('Model')['AccHard'].mean().sort_values(ascending=False).index.tolist()[:5]

        data = score_per_fact[(score_per_fact['AccHard'] == 1) & mask & (score_per_fact['Granularity'] != GLOBALGR) & np.isin(score_per_fact['Model'], kept_models)]
        known_gran = data.groupby(['Model','Fact'], sort=False)[['Granularity']].apply(f).reset_index(drop=True)

        def propci_wilson_cc(count, nobs, alpha=0.05):
            # get confidence limits for proportion
            # using wilson score method w/ cont correction
            # i.e. Method 4 in Newcombe [1]; 
            # verified via Table 1
            # Returns (midpoint, half-width) of the interval; the midpoint is not count/nobs in general.
            from scipy import stats
            n = nobs
            p = count/n
            q = 1.-p
            z = stats.norm.isf(alpha / 2.)
            z2 = z**2   
            denom = 2*(n+z2)
            num = 2.*n*p+z2-1.-z*np.sqrt(z2-2-1./n+4*p*(n*q+1))    
            ci_l = num/denom
            num = 2.*n*p+z2+1.+z*np.sqrt(z2+2-1./n+4*p*(n*q-1))
            ci_u = num/denom
            if p == 0:
                ci_l = 0.
            elif p == 1:
                ci_u = 1.
            return (ci_l+ci_u)/2, (ci_u-ci_l)/2 

        heatmap = np.zeros((len(all_combinations_gran), len(all_combinations_gran)))
        cis = np.zeros_like(heatmap).astype(str)
        for i in range(len(all_combinations_gran)):
            for j in range(len(all_combinations_gran)):
                A = all_combinations_gran[i].union(all_combinations_gran[j])
                B = all_combinations_gran[j]
                A_B = A.intersection(B)
                B_occ = known_gran.loc[known_gran['CombGran'] == B, 'Success'].sum()
                A_occ = known_gran.loc[known_gran['CombGran'] == A, 'Success'].sum()
                # A ∩ B == B here, so A_B_occ reduces to A_occ (groups knowing both A and B).
                A_B_occ = A_occ + B_occ - known_gran.loc[known_gran['CombGran'] == A_B, 'Success'].sum()
                # Annotation: CI midpoint ± half-width, or 1 when A ⊆ B (the ratio is trivially 1).
                cis[i,j] = '%.2f±\n%.2f' % propci_wilson_cc(A_B_occ, B_occ) if not A.issubset(B) else 1
                heatmap[i,j] = A_B_occ / B_occ
        plt.figure(figsize=(5,4))
        if USE_ALL_COMBINATIONS:
            sns.heatmap(heatmap, vmin=0, vmax=1, xticklabels=all_combinations_gran_str, yticklabels=all_combinations_gran_str, cmap="rocket", fmt='', annot=cis, annot_kws={"size": 10})
        else:
            sns.heatmap(heatmap[:3,:3], vmin=0, vmax=1, xticklabels=all_combinations_gran_str[:3], yticklabels=all_combinations_gran_str[:3], cmap="rocket", fmt='', annot=cis[:3,:3], annot_kws={"size": 20})
            # Mean of the 6 off-diagonal cells, assuming the 3 diagonal cells are exactly 1.
            avg_transfer = (heatmap[:3,:3].sum()-3)/6
            print('Average=%s' % avg_transfer)
        plt.ylabel('A')
        plt.xlabel('B')
        plt.title('P(A is known | B is known)')
        plt.savefig('plots/gran_vs_%s_%s_alllms_%s.pdf' % ('allcomb' if USE_ALL_COMBINATIONS else 'smallcomb', 'inst' if instruct else 'noninst', EXPERIMENT), bbox_inches='tight', pad_inches=0);
        plt.show()

### What about date-to-date consistency?

In [None]:
def _correct_win_rate_per_incorrect(df: pd.DataFrame) -> pd.DataFrame:
    """Score each Incorrect test of one (Fact, Model, Granularity) group against its Correct tests.

    Returns the group's Incorrect rows with columns:
        Position: rank of the row among all of the group's rows sorted by Alpha.
        Alpha.
        CorrectWinRateOverIncorrect(Instruct): share of Correct tests whose (instruction-format)
            log-prob strictly exceeds this test's; NaN when instruction log-probs are missing.
    """
    df = df.copy()  # write the new columns into a private copy rather than the groupby slice
    correct_mask = df['IsCorrect'] == 'Correct'
    incorrect_mask = df['IsCorrect'] == 'Incorrect'
    matches = df.loc[correct_mask, 'CondLogProb'].values[:, None] > df.loc[incorrect_mask, 'CondLogProb'].values[None, :]
    df['CorrectWinRateOverIncorrect'] = float('nan')
    df['CorrectWinRateOverIncorrectInstruct'] = float('nan')
    # Column-wise mean: one value per Incorrect test.
    df.loc[incorrect_mask, 'CorrectWinRateOverIncorrect'] = matches.mean(0)
    if not df['CondLogProbInstruct'].isna().all():
        matches_inst = df.loc[correct_mask, 'CondLogProbInstruct'].values[:, None] > df.loc[incorrect_mask, 'CondLogProbInstruct'].values[None, :]
        df.loc[incorrect_mask, 'CorrectWinRateOverIncorrectInstruct'] = matches_inst.mean(0)
    df = df.sort_values('Alpha')
    df['Position'] = np.arange(len(df))
    df = df[df['IsCorrect'] == 'Incorrect']
    return df[['Position', "Alpha", 'CorrectWinRateOverIncorrect', 'CorrectWinRateOverIncorrectInstruct']]


# Same models as the granularity-transfer cells (raw-text ranking when KEPT_MODELS is None), computed
# here explicitly instead of relying on the `kept_models` those cells leave behind.
if KEPT_MODELS is not None:
    kept_models = KEPT_MODELS.copy()
else:
    kept_models = score_per_fact[(score_per_fact['Granularity'] == GLOBALGR) & ~score_per_fact['Instruct']].groupby('Model')['AccHard'].mean().sort_values(ascending=False).index.tolist()[:5]
logprobs_dedup2 = logprobs_dedup[~logprobs_dedup['Alpha'].isna() & np.isin(logprobs_dedup['Model'], kept_models)].groupby(['Fact', 'Model', 'Granularity'], sort=False)[['CondLogProbInstruct', 'CondLogProb', 'Alpha', 'IsCorrect']].apply(_correct_win_rate_per_incorrect)

## Naive consistency between date precisions

In [None]:
# Flag Incorrect tests beaten by every Correct test (raw text / instruction), then count, per
# (Model, Fact, Position), how many precision rows are flagged, alongside the mean Alpha.
logprobs_dedup2['CorrectWinRateOverIncorrect2'] = logprobs_dedup2['CorrectWinRateOverIncorrect'] == 1
logprobs_dedup2['CorrectWinRateOverIncorrectInstruct2'] = logprobs_dedup2['CorrectWinRateOverIncorrectInstruct'] == 1
data = logprobs_dedup2.groupby(['Model', 'Fact', 'Position'], sort=False).agg({
    'CorrectWinRateOverIncorrect2': 'sum',
    'CorrectWinRateOverIncorrectInstruct2': 'sum',
    'Alpha': 'mean',
})

In [None]:
# Among dates where the Correct tests win in at least one precision, the share where they win in all
# precisions, per Alpha bin, for raw-text and instruction queries.
N_PRECISIONS = 3  # Year, Month, Day
data['AlphaBin'] = pd.cut(data['Alpha'], bins=np.arange(-5,5.001,0.5))


def _precision_consistency(df: pd.DataFrame, count_col: str, fmt: str) -> pd.DataFrame:
    """Rows of `df` with `count_col` > 0, with `count_col` turned into "wins in all precisions".

    Also adds FactStr, a per-(Model, FactStr) Weight (share of the rows), a Format label, and copies the
    flag into 'CorrectWinRateOverIncorrect2' so both formats are plotted from the same column.
    """
    subset = df[df[count_col] > 0].copy()
    subset[count_col] = subset[count_col] == N_PRECISIONS
    subset = subset.reset_index()
    subset['FactStr'] = subset['Fact'].apply(str)
    weight = subset.value_counts(["Model", "FactStr"], normalize=True)
    weight.name = "Weight"
    subset = subset.merge(weight, on=["Model", "FactStr"])
    subset['Format'] = fmt
    subset['CorrectWinRateOverIncorrect2'] = subset[count_col]
    return subset


data2 = _precision_consistency(data, 'CorrectWinRateOverIncorrect2', "Raw text")
data3 = _precision_consistency(data, 'CorrectWinRateOverIncorrectInstruct2', "Instruction")
data4 = pd.concat([data3,data2], ignore_index=True)

# NOTE(review): the per-fact 'Weight' column is computed but not passed to barplot — confirm intended.
sns.barplot(data4, x='AlphaBin', y='CorrectWinRateOverIncorrect2', hue='Format')
plt.xlabel('$\\alpha$')
plt.ylabel('Precision-wise consistency')
plt.xticks(rotation=90)
# plt.savefig('plots/granularity_consistency_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

### More sophisticated consistency through correlation

In [None]:
# Per-row flags: was this Incorrect test beaten by every Correct test (raw text / instruction)?
logprobs_dedup2['CorrectWinRateOverIncorrect2'] = logprobs_dedup2['CorrectWinRateOverIncorrect'] == 1
logprobs_dedup2['CorrectWinRateOverIncorrectInstruct2'] = logprobs_dedup2['CorrectWinRateOverIncorrectInstruct'] == 1
data = logprobs_dedup2.copy().reset_index()
data['AlphaBin'] = pd.cut(data['Alpha'], bins=np.arange(-5,5.001,0.5))
# Drop Alpha values outside (-5, 5]; .copy() so the column added next is not written into a view.
data = data[~data['AlphaBin'].isna()].copy()
data['FactStr'] = data['Fact'].apply(str)
data = data.sort_values(['FactStr', 'Position'])

In [None]:
from sklearn.metrics import matthews_corrcoef

# (column suffix, precision) pairs, concatenated in this order below.
_PRECISION_COLUMNS = (('Month', TimeUnit.MONTH), ('Year', TimeUnit.YEAR), ('Day', TimeUnit.DAY))
_FLAG_COLUMNS = ['CorrectWinRateOverIncorrect2', 'CorrectWinRateOverIncorrectInstruct2']


def _mean_pairwise_mcc(df: pd.DataFrame):
    """Average Matthews correlation between the per-date success flags of the three precisions.

    Rows are matched across precisions on (Position, FactStr). Only rows whose flags are present in
    all three precisions are kept.

    Returns:
        (raw, instruct): mean of MCC(Month, Year), MCC(Day, Year) and MCC(Month, Day), for raw-text
        and instruction queries respectively.
    """
    indexed = df.set_index(['Position', 'FactStr'])
    per_precision = [
        indexed.loc[indexed['Granularity'] == unit, _FLAG_COLUMNS].rename(columns=lambda col, s=suffix: col + s)
        for suffix, unit in _PRECISION_COLUMNS
    ]
    wide = pd.concat(per_precision, axis=1)
    # Rows present in all precisions, as 0/1 ints; .copy() avoids writing into a view.
    wide = wide.loc[~wide.isna().any(axis=1)].copy().astype(int)

    def mean_mcc(flag):
        return 1/3*(matthews_corrcoef(wide[flag + 'Month'], wide[flag + 'Year'])
                    + matthews_corrcoef(wide[flag + 'Day'], wide[flag + 'Year'])
                    + matthews_corrcoef(wide[flag + 'Month'], wide[flag + 'Day']))

    return mean_mcc('CorrectWinRateOverIncorrect2'), mean_mcc('CorrectWinRateOverIncorrectInstruct2')


data2 = data.groupby(['Model', 'AlphaBin']).apply(_mean_pairwise_mcc)
data2.name = 'MCC'
data2 = data2.reset_index()
# Each MCC cell is a (raw, instruct) pair: explode it into one row per query format.
data2['Format'] = [('Raw text', 'Instruction')]*len(data2)
data3 = data2.explode(['Format', 'MCC'])
data3['AlphaBin2'] = data3['AlphaBin'].apply(lambda x: x.mid)
plt.figure(figsize=(10,3))
sns.barplot(data3, x='AlphaBin2', y='MCC', hue='Format', errorbar=None)
plt.xlabel('$\\alpha$')
plt.ylabel('Average correlation between\ndate precisions')
plt.xticks(range(len(data3['AlphaBin2'].cat.categories)),data3['AlphaBin'].cat.categories, rotation=90)
plt.ylim((0.3,0.6))
plt.savefig('plots/granularity_consistency_corr_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

### Distribution of win rate

In [None]:
# Distribution of per-fact win rates (AccSoft) pooled over all models, query formats and granularities
# (including the Global rows). The white style set here persists for the following figures.
sns.set_style('white')
plt.figure(figsize=(6,4))
ax = sns.histplot(data=score_per_fact.reset_index(), x='AccSoft', bins=100)

In [None]:
# Best robustness model
# Win-rate distribution for google_gemma-2-27b-it only (both query formats, all granularities).
plt.figure(figsize=(6,4))
sns.histplot(data=score_per_fact.loc[score_per_fact['Model'] == 'google_gemma-2-27b-it', 'AccSoft'].reset_index(), x='AccSoft', bins=100)

In [None]:
# Best win rate model
# Win-rate distribution for meta-llama_Llama-3.1-70B-Instruct only (both query formats, all granularities).
plt.figure(figsize=(6,4))
sns.histplot(data=score_per_fact.loc[score_per_fact['Model'] == 'meta-llama_Llama-3.1-70B-Instruct', 'AccSoft'].reset_index(), x='AccSoft', bins=100)

# Where do language models fail with respect to the distance to the validity period?

### Where do LMs fail when the win rate is high?

In [None]:
def t(alpha: float):
    """Build an aggregator giving the share of values whose magnitude reaches ``alpha``.

    Parameters
    ----------
    alpha : float
        Threshold compared (inclusively) against ``abs(value)``.

    Returns
    -------
    callable
        ``t_(series) -> float``: the mean of ``abs(series) >= alpha``, i.e. the
        proportion of entries with ``|x| >= alpha`` (NaN for an empty series).
        Intended for ``groupby(...).apply``.
    """
    def t_(series : pd.Series):
        # The former unused `p` weighting branch was dead code and has been removed.
        return (np.abs(series) >= alpha).mean()
    return t_
# For each model type (instruction-tuned vs raw-text), keep the 5 models with the best
# mean AccHard at the global granularity. Then, for year-level facts whose win rate
# (AccSoft) lies in [th, 1), plot the weighted share of each fact's `WrongAlphaIncorrect`
# values with |value| >= 1, 2, 3, 4 (see `t`).
to_plots = []
alpha_range = np.arange(1,4.999,1)  # -> [1., 2., 3., 4.]
for instruct, mask in zip(('instruct', 'noninstruct'), (score_per_fact['Instruct'], ~score_per_fact['Instruct'])):
    # The top-5 selection depends only on `mask`, not on the threshold, so compute it
    # once per model type instead of once per threshold.
    kept_models = (score_per_fact[(score_per_fact['Granularity'] == GLOBALGR) & mask]
                   .groupby('Model')['AccHard'].mean()
                   .sort_values(ascending=False)
                   .index.tolist()[:5])
    for th in (0.9,0.95,0.97,0.99):
        print(instruct, th)
        data = score_per_fact.loc[mask & (score_per_fact['Granularity'] == "Year") & np.isin(score_per_fact['Model'], kept_models),
                                ['Model', 'AccSoft', 'WrongAlphaIncorrect', 'Fact']].copy()
        # Facts with a high but imperfect win rate.
        data = data[(1 > data['AccSoft']) & (data['AccSoft'] >= th)]
        print('len(data)', len(data))
        data['FactStr'] = data['Fact'].apply(str)
        # Each (model, fact) pair is weighted by its relative frequency in `data`.
        weights = data.value_counts(["Model", "FactStr"], normalize=True)
        weights.name = 'Weight'
        data = data.explode('WrongAlphaIncorrect')
        for alpha in alpha_range:
            mu = data.groupby(["Model", 'FactStr'], sort=False)['WrongAlphaIncorrect'].apply(t(alpha)).to_frame()
            mu['Alpha'] = int(alpha)
            mu['Threshold'] = th
            mu = mu.merge(weights, on=['Model', 'FactStr'])
            mu['Instruct'] = instruct
            to_plots.append(mu)
to_plot = pd.concat(to_plots)
# "midpoint±half-width" strings of each bar's error bar, per model type; reused by the
# LaTeX table in the next cell.
values = []
for instruct in ('instruct', 'noninstruct'):
    to_plot2 = to_plot[to_plot['Instruct'] == instruct]
    plt.figure(figsize=(5,4))
    ax = sns.barplot(to_plot2, x='Alpha', y='WrongAlphaIncorrect', weights='Weight', hue='Threshold')
    # NOTE(review): assumes ax.lines holds only the error-bar lines, each with [low, high]
    # as its y-values — TODO confirm for the installed seaborn version.
    cis = np.stack([l.get_xydata()[:, 1] for l in ax.lines])
    mid = cis.mean(1)
    error = (cis[:,1] - cis[:,0]) / 2
    values.append(["%.3f±%.3f" % (x,y) for x,y in zip(mid, error)])
    plt.xlabel('$\\alpha$')
    plt.xticks(list(range(len(alpha_range))), ["$\\geq%s$" % int(x) for x in alpha_range])
    plt.ylabel('Proportion of incorrect dates\nfavored over a correct date')
    plt.savefig('plots/incorrect_date_position__th_%s_%s.pdf' % (instruct, EXPERIMENT), bbox_inches='tight', pad_inches=0);
    plt.show()

In [None]:
# LaTeX summary table for threshold 0.95: per alpha bin, the weighted proportion of
# incorrect dates favored over a correct date, with Raw text / Instruction columns.
alpha2succ = to_plot[to_plot['Threshold'] == 0.95].groupby(['Instruct', 'Alpha'])[['WrongAlphaIncorrect', 'Weight']].apply(lambda x : (x['WrongAlphaIncorrect'] * x['Weight']).sum())
alpha2succ = alpha2succ.to_frame().reset_index()
# The unnamed Series produced by the groupby becomes column 0; give it a readable name.
alpha2succ["Incorrect dates' proportion"] = alpha2succ[0]
alpha2succ = alpha2succ.drop(columns=0)
alpha2succ['Alpha'] = alpha2succ['Alpha'].apply(lambda x : "$\\geq %s$" % int(x))
alpha2succ = alpha2succ.set_index(['Alpha'])
# Put the non-instruct and instruct results side by side; both 'Instruct' columns are dropped.
alpha2succ = pd.concat([alpha2succ[alpha2succ['Instruct'] != 'instruct'], alpha2succ[alpha2succ['Instruct'] == 'instruct']], axis=1).drop(columns='Instruct')
alpha2succ.columns = ['Raw text', 'Instruction']
# NOTE(review): the weighted sums computed above are overwritten here by the "mean±error"
# strings collected from the bar plots of the previous cell. `values` is reshaped to
# (model type, threshold, alpha): index 0 = 'instruct' and 1 = 'noninstruct' (the loop
# order of the previous cell). Index 1 on the threshold axis presumably corresponds to
# 0.95, which assumes seaborn drew the bars hue-major (threshold, then alpha) — TODO confirm.
values = np.array(values).reshape(2,4,4)
alpha2succ['Raw text'] = values[1, 1, :]
alpha2succ['Instruction'] = values[0, 1, :]
print(alpha2succ.to_latex(index=True, float_format='%.3f'))

In [None]:
# for instruct, mask in zip(('instruct', 'noninstruct'), (score_per_fact['Instruct'], ~score_per_fact['Instruct'])):
#     print(instruct)
#     kept_models = score_per_fact[(score_per_fact['Granularity'] == GLOBALGR) & mask].groupby('Model')['AccHard'].mean().sort_values(ascending=False).index.tolist()[:5] 
#     data = score_per_fact.loc[mask & (score_per_fact['Granularity'] != GLOBALGR) & np.isin(score_per_fact['Model'], kept_models),
#                             ['Model', 'AccSoft', 'WrongAlphaIncorrect']].copy()
#     data = data[(data['AccSoft'] > 0.95)]
#     data = np.concatenate(data['WrongAlphaIncorrect'].tolist())
#     plt.figure(figsize=(5,4))
#     plt.xlabel('$\\alpha$')
#     plt.ylabel('Proportion')
#     weights = np.ones_like(data) / len(data)
#     plt.hist(data, bins=30, weights=weights)
#     plt.savefig('plots/incorrect_date_position_%s_%s.pdf' % (instruct, EXPERIMENT), bbox_inches='tight', pad_inches=0);
#     plt.show()
#     for alpha in (0,0.5,1,2,3,4,5):
#         w = len(data[np.abs(data) >= alpha])  / len(data)
#         print('Alpha >= %s --> %s' % (alpha, w))

# Are dynamic facts less likely to be memorized?

In [None]:
from wikidata_tools.core import TimedTriple, TripleQuery
from wikidata_tools.wikidata import TempWikidata, WikidataPrepStage


all_facts = score_per_fact['Fact'].unique()
wd = TempWikidata("20210104", WikidataPrepStage.PREPROCESSED)
# For every fact, count the triples `wd.find` returns for a query on the fact's
# (subject, relation) pair; this count is used as the fact's history size.
history_size = [
    len(list(wd.find(TripleQuery(fact.subject, fact.relation))))
    for fact in all_facts
]

In [None]:
fact_history_size = pd.DataFrame({
    'Fact' : all_facts,
    # log(1 + n) keeps an empty history (n = 0) at a finite value.
    'HistorySize' : np.log(1+np.array(history_size))
})
data = score_per_fact.merge(fact_history_size, on='Fact')
data = data[data['Granularity'] != GLOBALGR]

# Long format: one row per (fact, metric) so both metrics can share one lmplot.
data_long = pd.melt(data, id_vars='HistorySize', value_vars=['AccSoft', 'AccHard'],
                    var_name='Type', value_name='Performance')

# Replace the metric names with their LaTeX symbols. Raw strings avoid the invalid '\m'
# escape sequence, which triggers a SyntaxWarning on recent Python versions.
data_long = data_long.replace({'AccSoft': r'$\mathcal{W}$', 'AccHard': r'$\mathcal{R}$'})

# Plot the data
sns.lmplot(data=data_long, x="HistorySize", y='Performance', hue='Type', x_bins=list(range(0, 11)))
plt.xlabel('Log History length of the fact (s,r,o)')

plt.savefig('plots/perf_vs_historysize_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

The longer the history, the larger the chance that the fact is memorized:
- A longer history probably means the training dataset mentioned the fact and its validity boundaries more often (positive corr.)
- However, the longer the history, the more contradictions there are in the training dataset (negative corr.)

# Are older facts better known when the granularity is year?

In [None]:
data = score_per_fact.copy()
data = data[data['Granularity'] != GLOBALGR]
start_col = "Log distance in years between\nstart date of the validity period and present"
# "Present" is taken as 2021, presumably matching the 20210104 Wikidata snapshot used here.
# NOTE(review): facts starting in 2021 or later would give log(0) = -inf or NaN — confirm none exist.
data[start_col] = np.log(data['Fact'].apply(lambda x : 2021 - x.valid_between.start.year))
sns.lmplot(data=data, x=start_col, y='AccSoft', hue='Granularity', x_bins=range(0, 10, 1))
# Raw strings avoid the invalid '\m' escape sequence.
plt.ylabel(r'Average $\mathcal{W}$')
plt.savefig('plots/start_vs_lenient_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);
sns.lmplot(data=data, x=start_col, y='AccHard', hue='Granularity', x_bins=range(0, 10, 1))
plt.ylabel(r'Average $\mathcal{R}$')
plt.savefig('plots/start_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

Because there is more data about recent events than about older ones, the model is biased towards recent knowledge.

# Is the length of the validity period related to performance?

In [None]:
data = score_per_fact.copy()
data = data[data['Granularity'] != GLOBALGR]
# NOTE(review): a fact starting and ending in the same year yields log(0) = -inf — confirm
# such facts are absent or intentionally kept.
data['Log Duration of fact'] = np.log(data['Fact'].apply(lambda x : x.valid_between.end.year - x.valid_between.start.year))
sns.lmplot(data=data, x="Log Duration of fact", y='AccSoft', hue='Granularity', x_bins=range(0, 10), order=1)
# Raw strings avoid the invalid '\m' escape sequence.
plt.ylabel(r'Average $\mathcal{W}$')
# Saved under plots/ with the experiment suffix, like every other figure of the notebook
# (it previously went to the working directory and was overwritten across experiments).
plt.savefig('plots/duration_vs_lenient_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);
sns.lmplot(data=data, x="Log Duration of fact", y='AccHard', hue='Granularity', x_bins=range(0, 10), order=1)
plt.ylabel(r'Average $\mathcal{R}$')
plt.savefig('plots/duration_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

- The longer the validity period, the less the fact is known, because frequent changes to a fact imply more occurrences of it in the dataset (positive correlation)
- The longer the validity period, the fewer contradictions there are in the dataset (negative correlation)

In [None]:
plt.figure(figsize=(3,6))
data = score_per_fact.copy()
data = data[data['Granularity'] == GLOBALGR]
data['Relation'] = data['Fact'].apply(lambda x : x.relation.label)
# The 10 relations with the highest mean AccHard at the global granularity.
order = data.groupby('Relation')['AccHard'].mean().sort_values(ascending=False).index.tolist()[:10]
sns.barplot(data=data, y='Relation', x='AccHard', order=order)
# Raw string avoids the invalid '\m' escape sequence.
plt.xlabel(r'Average $\mathcal{R}_G$')
plt.ylabel('')
plt.savefig('plots/relation_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

# Is popularity related to performance?

In [None]:
from wikidata_tools.wikidata import WikidataPopularity


# Attach each fact's popularity, keep the global granularity only, then add the
# (exponentiated) popularity of the fact's object and subject.
fact_popularity = logprobs_dedup[['Fact', 'FactPop']].drop_duplicates()
data = score_per_fact.merge(fact_popularity, on='Fact')
data = data[data['Granularity'] == GLOBALGR]

wikipop = WikidataPopularity('20210104')
fact_objects = data['Fact'].apply(lambda x : x.object)
data['ObjectPop'] = np.exp(wikipop.get_popularity(fact_objects))
fact_subjects = data['Fact'].apply(lambda x : x.subject)
data['SubjectPop'] = np.exp(wikipop.get_popularity(fact_subjects))

Subject popularity

In [None]:
# Robustness vs. subject popularity, binned every 1e5 up to 1.2e6, quadratic fit.
sns.lmplot(data=data, y='AccHard', x='SubjectPop', x_bins=range(0, 12*10**5+1, 10**5), order=2)
# Raw string avoids the invalid '\m' escape sequence.
plt.ylabel(r'Average $\mathcal{R}_G$')
plt.xlabel('Popularity of $s$ in fact $(s,r,o)$')
plt.ylim((-0.03,0.6))
plt.savefig('plots/pop_sub_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

Object popularity

In [None]:
# Robustness vs. object popularity, binned every 1e5 up to 1.2e6, quadratic fit.
sns.lmplot(data=data, y='AccHard', x='ObjectPop', x_bins=range(0, 12*10**5+1, 10**5), order=2)
# Raw string avoids the invalid '\m' escape sequence.
plt.ylabel(r'Average $\mathcal{R}_G$')
plt.xlabel('Popularity of $o$ in fact $(s,r,o)$')
plt.ylim((-0.03,0.6))
plt.savefig('plots/pop_obj_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

Fact popularity

In [None]:
# Robustness vs. fact popularity, binned every 1e5 up to 1.2e6, quadratic fit.
# sns.lmplot builds its own figure, so no plt.figure() call is needed beforehand
# (one would only leave a stray empty figure behind).
sns.lmplot(data=data, y='AccHard', x='FactPop', x_bins=range(0, 12*10**5+1, 10**5), order=2)
# Raw string avoids the invalid '\m' escape sequence.
plt.ylabel(r'Average $\mathcal{R}_G$')
plt.xlabel('Popularity of fact')
plt.ylim((-0.03,0.6))
plt.savefig('plots/pop_fact_vs_robust_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

# Alpha vs. Performance

## using Delta LogProb

In [None]:
data = logprobs_dedup.copy()
# Normalise the granularity's `.name` to capitalised form (e.g. 'Year').
data['Granularity'] = data['Granularity'].apply(lambda x : x.name.lower().capitalize())

data['Date validity'] = data['IsCorrect']
data = data[np.isin(data['IsCorrect'], ['Correct', 'Incorrect', 'Transitional'])]
# LaTeX-labelled copies of the columns, used as axis labels by the following plots.
data['$\\alpha$'] = data['Alpha']
# Raw string: same column name as before, without the invalid '\ ' escape sequence.
data[r'$log\ P(o \mid f,d)$'] = data['CondLogProb']

In [None]:
# Binned alpha vs. log P(o | f, d), one column per granularity, without regression fit.
# Raw string for the y column avoids the invalid '\ ' escape (same column name).
sns.lmplot(data, x='$\\alpha$', y=r'$log\ P(o \mid f,d)$', hue="Date validity", x_bins=np.linspace(-5,5,100), col='Granularity', fit_reg=False, ci=None)

In [None]:
# Binned alpha vs. log P(o | f, d) at year granularity, coloured by date validity.
# Raw string for the y column avoids the invalid '\ ' escape (same column name).
sns.lmplot(data[data['Granularity'] == 'Year'], x='$\\alpha$', y=r'$log\ P(o \mid f,d)$', hue="Date validity", x_bins=np.linspace(-5,5,100), fit_reg=False, ci=None, 
           hue_order=['Correct', 'Incorrect', 'Transitional'], palette=['#007f4e', '#e12729', '#f37324'], markers=['*', 'x', 'o'],
           height=5, aspect=1.2)
plt.savefig('plots/alpha_vs_logprob_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0);

In [None]:
plt.figure(figsize=(7,4))
sns.set(font_scale=1.2)

alpha_col = '$\\alpha$'
# Raw string: same column name as created earlier, without the invalid '\ ' escape.
logprob_col = r'$log\ P(o \mid f,d)$'

data_ = data.copy()
# Bin alpha, then give the exact values -0.5 and 0.5 their own point categories: the
# neighbouring (-1, -0.5] and (0.25, 0.5] bins become open intervals, and rows with
# Alpha exactly equal to +-0.5 are moved to the point categories.
data_[alpha_col] = pd.cut(data_['Alpha'], bins=[-5,-3,-1,-0.5,-0.25,0,0.25,0.5,1,3,5]).astype('category')
data_[alpha_col] = data_[alpha_col].cat.add_categories([-0.5,0.5,pd.Interval(-1.0, -0.5, 'neither'), pd.Interval(0.25, 0.5, 'neither')])
data_.loc[data_[alpha_col] == pd.Interval(-1.0, -0.5, 'right'), alpha_col] = pd.Interval(-1.0, -0.5, 'neither')
data_.loc[data_[alpha_col] == pd.Interval(0.25, 0.5, 'right'), alpha_col] = pd.Interval(0.25, 0.5, 'neither')
data_.loc[data_['Alpha'] == 0.5, alpha_col] = 0.5
data_.loc[data_['Alpha'] == -0.5, alpha_col] = -0.5
data_[alpha_col] = data_[alpha_col].astype(str).astype('category')
categories = [
    "(-5.0, -3.0]",
    "(-3.0, -1.0]",
    "(-1.0, -0.5)",
    "-0.5",
    "(-0.5, -0.25]",
    "(-0.25, 0.0]",
    "(0.0, 0.25]",
    "(0.25, 0.5)",
    "0.5",
    "(0.5, 1.0]",
    "(1.0, 3.0]",
    "(3.0, 5.0]"
]
# Bars grow from y=0, so shift the (non-positive) log-probabilities up by `reference`.
# The tick labels below subtract it again.
reference = 10
data_[alpha_col] = data_[alpha_col].cat.set_categories(categories, ordered=True)
data_[logprob_col] = data_[logprob_col] + reference
ax = sns.barplot(data_[data_['Granularity'] == 'Year'], x=alpha_col, y=logprob_col, hue="Date validity", 
           hue_order=['Correct', 'Incorrect', 'Transitional'], palette=['#007f4e', '#e12729', '#f37324'],
           dodge=False, 
           width=1)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

plt.xticks(rotation=90)  # Rotate x-axis labels for better readability
plt.tight_layout()
plt.ylim(0,3.5)
# Label ticks with the unshifted log-probability. A formatter follows the ticks actually
# drawn, whereas set_yticklabels() on the pre-ylim ticks pinned labels that could be
# misaligned once the limits (and thus the tick positions) changed.
ax.yaxis.set_major_formatter(lambda y, pos: f'{y - reference}')
plt.savefig('plots/alpha_vs_logprob_bar_%s.pdf' % EXPERIMENT, bbox_inches='tight', pad_inches=0)

- The model assigns a higher log-probability to the correct answer when the date is close to the start of the validity period than when it is close to the end

## using order

In [None]:
# Binned alpha vs. TimeWin, one column per granularity, coloured by IsCorrect (no fit).
sns.lmplot(data=data, x='Alpha', y='TimeWin', hue="IsCorrect", col="Granularity",
           x_bins=np.linspace(-5, 5, 100), fit_reg=False, ci=None)