In [1]:
import pandas as pd
import numpy as np
import glob

import utils_greedy_histogram as greedy_hists
import utils_probabilistic_histogram as prob_hists

from utils_io import print_sep, read_json, persist_histograms
from default_vars import BIN_CENTER, BIN_OFFSET, UNCERTAINTY_EXPRESSIONS

def parse_verifiable(df: pd.DataFrame, gen_study=False) -> pd.DataFrame:
    """Return a copy of ``df`` with normalised statement-type and truth columns.

    The raw ``statement_type`` value is preserved in ``_statement_type_orig``.
    ``statement_type`` is then replaced by the text preceding its first
    underscore (with surrounding whitespace stripped), and ``statement_truth``
    is set to ``"true"`` when ``statement_id`` contains the substring
    ``"true"`` and to ``"false"`` otherwise.

    Args:
        df: Frame holding string-valued ``statement_type`` and
            ``statement_id`` columns.
        gen_study: When truthy, exactly 2 distinct raw statement types are
            expected; otherwise exactly 6.

    Returns:
        A new DataFrame; the input frame is not modified.

    Raises:
        AssertionError: If the number of distinct raw statement types differs
            from the expected count (the distinct values are the message).
    """
    expected_n_types = 2 if gen_study else 6
    assert df["statement_type"].nunique() == expected_n_types, df["statement_type"].unique()

    def _type_prefix(raw_type):
        # Keep only what precedes the first underscore.
        return raw_type.split("_")[0].strip()

    def _truth_label(statement_id):
        # Substring match on the id; yields the lowercase strings "true"/"false".
        return "true" if "true" in statement_id else "false"

    parsed = df.copy()
    parsed["_statement_type_orig"] = parsed["statement_type"]
    parsed["statement_type"] = parsed["_statement_type_orig"].map(_type_prefix)
    parsed["statement_truth"] = parsed["statement_id"].map(_truth_label)
    return parsed

In [2]:
# Root of every `results_folder` path handed to `persist_histograms` below.
OUTPUT_DIR = "../../results"

# Values substituted for `{n_shot}` in the `...-{n_shot}-shot` input folders
# and the `models-{n_shot}shot` output folders.
N_SHOTS = (0, 2)

# 1. Greedy Histogram

In this section, we parse the results obtained via the sampling-based approach.
In general, we generate `k` samples from the same prompt and take the majority vote across the generations.

If the model generates a number among the top-20 log probabilities, then it will be placed in the column `number_1`. Therefore, a greedy histogram can be constructed by assuming this to be the most likely number.

In the general case, we compute the histogram by following the steps: 
1. Compute the empirical frequency of each number.
2. Pick the continuation that maximizes the empirical frequency (i.e., majority vote)
3. Determine its bin.
4. Add to the current count of that bin.
5. Repeat steps 1-4 for every statement.
6. Finally, normalize by the number of unique statements.


In this study, we use the sample-based algorithm with greedy decoding (`n=1`) for the following models:
- `meta-llama/Llama-3-70b-chat-hf`, run using TogetherAI for efficiency purposes.
- `mistralai/Mixtral-8x7B-Instruct-v0.1`
- `mistralai/Mixtral-8x22B-Instruct-v0.1`
- `gemini-pro`

**Note**: As a heuristic, we treat the first number in the model's response as its answer.
This heuristic warrants further investigation to assess its validity.

In [3]:
# Keyword arguments shared by every call to
# `greedy_hists.create_histogram_for_sampling_approach` in this notebook.
hist_sampl_kwargs = {
    "id_cols": ["statement_uuid", "statement_type", "uncertainty_expression"],
    "number_col": "response_first_number",
    "unc_col": "uncertainty_expression",
    "uncertainty_expressions": UNCERTAINTY_EXPRESSIONS,
    "bin_center": BIN_CENTER,
    "bin_offset": BIN_OFFSET,
}

## 1.1. Non-verifiable

In [4]:
# Build and persist greedy histograms for the non-verifiable statements:
# overall, per gender, and per statement type, for each setting in N_SHOTS.
for n_shot in N_SHOTS:
    # glob returns matches in arbitrary order; sort so the processing order
    # (and therefore the log printed below) is reproducible across runs.
    sample_filepaths = sorted(
        glob.glob(f"../../results/outputs/non-verifiable-{n_shot}-shot/**/sample_completions.csv", recursive=True)
    )

    for fp in sample_filepaths:
        # NOTE(review): gemini outputs are skipped in this section only;
        # the reason is not visible here -- confirm this is intended.
        if "gemini" in fp:
            continue
        print("Processing", fp)
        df = pd.read_csv(fp)
        model_name = df.loc[0, "model"]
        prefix = "sampling__" + model_name.replace("/", "__")
        # The model recorded in the CSV must match the folder it was read from.
        assert model_name in fp, (model_name, fp)

        # Overall
        histograms = greedy_hists.create_histogram_for_sampling_approach(df, **hist_sampl_kwargs)
        persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/all/non_verifiable/models-{n_shot}shot", prefix=prefix)

        # By gender
        for gender in ("male", "female"):
            df_gender_subset = df[df["gender"] == gender].copy()
            # Each gender must be a strict subset of the data.
            assert len(df_gender_subset) < len(df), gender
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_gender_subset, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_gender/non_verifiable/models-{n_shot}shot", prefix=prefix+"_"+gender)

        # By statement type
        assert 4 == df["statement_type"].nunique(), df["statement_type"].unique()
        for st_type in df["statement_type"].unique():
            df_st_type_subset = df[df["statement_type"] == st_type].copy()
            assert df_st_type_subset["statement_type"].nunique() == 1, st_type
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_st_type_subset, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_statement_type/non_verifiable/models-{n_shot}shot", prefix=prefix+"_"+st_type)

Processing ../../results/outputs/non-verifiable-0-shot/mistralai/Mixtral-8x22B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/non-verifiable-0-shot/mistralai/Mixtral-8x7B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/non-verifiable-2-shot/meta-llama/Llama-3-70b-chat-hf/sample_completions.csv
Processing ../../results/outputs/non-verifiable-2-shot/mistralai/Mixtral-8x22B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/non-verifiable-2-shot/mistralai/Mixtral-8x7B-Instruct-v0.1/sample_completions.csv


## 1.2. Verifiable


In [5]:
# Build and persist greedy histograms for the verifiable statements: overall,
# per gender, per statement type, and per statement truth value, for each
# setting in N_SHOTS.
for n_shot in N_SHOTS:
    # glob returns matches in arbitrary order; sort so the processing order
    # (and therefore the log printed below) is reproducible across runs.
    sample_filepaths = sorted(
        glob.glob(f"../../results/outputs/verifiable-FT-{n_shot}-shot/**/sample_completions.csv", recursive=True)
    )

    for fp in sample_filepaths:
        print("Processing", fp)
        df = parse_verifiable(pd.read_csv(fp))
        model_name = df.loc[0, "model"]
        # The model recorded in the CSV must match the folder it was read from.
        assert model_name in fp, (model_name, fp)
        prefix = "sampling__" + model_name.replace("/", "__")

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Overall
        # ---- ---- ---- ---- ---- ---- ---- ----
        histograms = greedy_hists.create_histogram_for_sampling_approach(df, **hist_sampl_kwargs)
        persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/all/verifiable/models-{n_shot}shot", prefix=prefix)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Gender
        # ---- ---- ---- ---- ---- ---- ---- ----
        for gender in ("male", "female"):
            df_gender_subset = df[df["gender"] == gender].copy()
            # Each gender must be a strict subset of the data.
            assert len(df_gender_subset) < len(df), gender
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_gender_subset, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_gender/verifiable/models-{n_shot}shot", prefix=prefix+"_"+gender)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Statement type
        # ---- ---- ---- ---- ---- ---- ---- ----
        # `parse_verifiable` collapses the 6 raw types to their prefix before "_".
        assert 3 == df["statement_type"].nunique(), df["statement_type"].unique()
        for st_type in df["statement_type"].unique():
            df_st_type_subset = df[df["statement_type"] == st_type].copy()
            assert df_st_type_subset["statement_type"].nunique() == 1, st_type
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_st_type_subset, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_statement_type/verifiable/models-{n_shot}shot", prefix=prefix+"_"+st_type)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Statement truth/falsity
        # ---- ---- ---- ---- ---- ---- ---- ----
        assert 2 == df["statement_truth"].nunique(), df["statement_truth"].unique()
        for st_truth in df["statement_truth"].unique():
            # Explicit copy, consistent with the gender / statement-type subsets,
            # so the helper always receives an independent frame.
            df_subset_v = df[df["statement_truth"] == st_truth].copy()
            assert len(df_subset_v) < len(df), st_truth
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_subset_v, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_statement_truth/verifiable/models-{n_shot}shot", prefix=prefix+"_"+st_truth)

Processing ../../results/outputs/verifiable-FT-0-shot/meta-llama/Llama-3-70b-chat-hf/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-0-shot/mistralai/Mixtral-8x22B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-0-shot/mistralai/Mixtral-8x7B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-0-shot/models/gemini-pro/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot/allenai/OLMo-7B-Instruct/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot/google/gemma-1.1-2b-it/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot/meta-llama/Llama-3-70b-chat-hf/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot/mistralai/Mixtral-8x22B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot/mistralai/Mixtral-8x7B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outp

## 1.3. Verifiable (AI2-Arc)

In [6]:
# Build and persist greedy histograms for the two AI2-Arc subsets: overall,
# per gender, per statement type, and per statement truth value.
# Only the 2-shot setting is processed in this section.
n_shot = 2

for ai2arc_subset in ("ai2arc-easy", "ai2arc-challenge"):
    # glob returns matches in arbitrary order; sort so the processing order
    # (and therefore the log printed below) is reproducible across runs.
    sample_filepaths = sorted(
        glob.glob(f"../../results/outputs/verifiable-FT-{n_shot}-shot-{ai2arc_subset}/**/sample_completions.csv", recursive=True)
    )

    for fp in sample_filepaths:
        print("Processing", fp)
        df = parse_verifiable(pd.read_csv(fp), gen_study=True)
        model_name = df.loc[0, "model"]
        # The model recorded in the CSV must match the folder it was read from.
        assert model_name in fp, (model_name, fp)
        prefix = "sampling__" + model_name.replace("/", "__")

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Overall
        # ---- ---- ---- ---- ---- ---- ---- ----
        histograms = greedy_hists.create_histogram_for_sampling_approach(df, **hist_sampl_kwargs)
        persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/all/verifiable-{ai2arc_subset}/models-{n_shot}shot", prefix=prefix)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Gender
        # ---- ---- ---- ---- ---- ---- ---- ----
        for gender in ("male", "female"):
            df_gender_subset = df[df["gender"] == gender].copy()
            # Each gender must be a strict subset of the data.
            assert len(df_gender_subset) < len(df), gender
            # NOTE(review): `ok_non_symmetric=True` is passed only for this
            # gender split; its semantics live in the helper -- confirm that
            # the other calls in this cell do not need it as well.
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_gender_subset, **hist_sampl_kwargs, ok_non_symmetric=True)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_gender/verifiable-{ai2arc_subset}/models-{n_shot}shot", prefix=prefix+"_"+gender)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Statement type
        # ---- ---- ---- ---- ---- ---- ---- ----
        # `parse_verifiable` collapses the 2 raw types to a single prefix here.
        assert 1 == df["statement_type"].nunique(), df["statement_type"].unique()
        for st_type in df["statement_type"].unique():
            df_st_type_subset = df[df["statement_type"] == st_type].copy()
            assert df_st_type_subset["statement_type"].nunique() == 1, st_type
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_st_type_subset, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_statement_type/verifiable-{ai2arc_subset}/models-{n_shot}shot", prefix=prefix+"_"+st_type)

        # ---- ---- ---- ---- ---- ---- ---- ----
        # Statement truth/falsity
        # ---- ---- ---- ---- ---- ---- ---- ----
        assert 2 == df["statement_truth"].nunique(), df["statement_truth"].unique()
        for st_truth in df["statement_truth"].unique():
            # Explicit copy, consistent with the gender / statement-type subsets,
            # so the helper always receives an independent frame.
            df_subset_v = df[df["statement_truth"] == st_truth].copy()
            assert len(df_subset_v) < len(df), st_truth
            histograms = greedy_hists.create_histogram_for_sampling_approach(df_subset_v, **hist_sampl_kwargs)
            persist_histograms(*histograms, results_folder=f"{OUTPUT_DIR}/greedy/by_statement_truth/verifiable-{ai2arc_subset}/models-{n_shot}shot", prefix=prefix+"_"+st_truth)

Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/allenai/OLMo-7B-Instruct/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/google/gemma-1.1-2b-it/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/meta-llama/Llama-3-70b-chat-hf/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/mistralai/Mixtral-8x22B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/mistralai/Mixtral-8x7B-Instruct-v0.1/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-easy/models/gemini-pro/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-challenge/allenai/OLMo-7B-Instruct/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2arc-challenge/google/gemma-1.1-2b-it/sample_completions.csv
Processing ../../results/outputs/verifiable-FT-2-shot-ai2a

Because of budget constraints, we could not run this method with multiple samples (i.e., `k >> 1`). Instead, we only ran the sampling methodology with `k=1` and greedy decoding. Consequently, the probabilistic histogram and the greedy histogram would be identical in this setting. Future work should repeat the experiment with multiple samples and determine whether significant differences exist between greedy decoding and sampling.