## Simulated knowledge

In [67]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from pathlib import Path
from elk_generalization.elk.elk_utils import SplitConfig

# Ignore pandas PerformanceWarnings, since these DataFrames are small anyway
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

def cfg_sort_key(descriptor):
    """Return a sort key for a split-config descriptor.

    The key is the string form of the alignment between the config's
    objective-label key and quirky-label key, followed by the descriptor
    itself, so descriptors are grouped by alignment first and then ordered
    by their own text.
    """
    split_cfg = SplitConfig.from_descriptor(descriptor)
    objective_key, quirky_key = (
        split_cfg.column_to_key[column]
        for column in ("objective_labels", "quirky_labels")
    )
    alignment_prefix = str(split_cfg.get_alignment(key1=objective_key, key2=quirky_key))
    return alignment_prefix + descriptor

def alignment_str(descriptor):
    """Describe, in words, how the objective and quirky label keys of a config align.

    Returns a sentence of the form ``"<key1>, <key2> are ... aligned"``.
    Raises ``KeyError`` if the alignment reported by the config is not one
    of -1, 0 or 1.
    """
    split_cfg = SplitConfig.from_descriptor(descriptor)
    key1 = split_cfg.column_to_key["objective_labels"]
    key2 = split_cfg.column_to_key["quirky_labels"]

    wording_by_alignment = {
        -1: "NEGATIVELY aligned",
        0: "NOT aligned",
        1: "POSITIVELY aligned"
    }
    wording = wording_by_alignment[split_cfg.get_alignment(key1=key1, key2=key2)]
    return f"{key1}, {key2} are {wording}"

In [68]:
# Options: where the summary CSV lives and where figures for it are written
csv_dir = Path(r"..\elk-generalization\experiments")
csv_filename = Path("summary_aligning_20240331_deduplicated.csv")
# Figures are grouped in a sub-directory named after the CSV file (without extension)
fig_dir = Path(r"..\elk-generalization\figures\transfer_align") / csv_filename.stem
os.makedirs(fig_dir, exist_ok=True)
# Default seaborn heatmap settings (later cells rebind this name with their own settings)
heatmap_kwargs = {"vmin": -0.5, "vmax": 0.5, "fmt":'', "cmap":'coolwarm', "cbar":False}
# Value of the "pr=" flag that train configs and test columns must contain to be kept
pr_filter_val = False

# Constants
models = ['pythia-12B', 'pythia-6.9B', 'pythia-2.8B', 'pythia-1.4B', 'pythia-1B', 'pythia-410M']
reporters = ["lr", "mean-diff", "lda", "ccs", "crc", "lm"]
unsupervised_reporters = ["ccs", "crc"]
supervised_reporters = ["lr", "mean-diff", "lda"]
# Reporter names including the two synthetic "average" reporters added to the frame below
reporters_with_averages = reporters + ["unsupervised_avg", "supervised_avg"]

# Prepare data
filtered_df = pd.read_csv(csv_dir / csv_filename, index_col=[0,1,2]) # Index levels are (model, reporter, train_cfg)
# df = df.sort_index()
# Filter: keep only rows (train configs) and columns whose descriptor contains each
# of the substrings below. NOTE: str.contains treats the substring as a regular expression.
for f in ["pi=True", f"pr={pr_filter_val}"]:
    relevant_configs = filtered_df.index.get_level_values('train_cfg').str.contains(f)
    filtered_df = filtered_df.loc[relevant_configs, filtered_df.columns.str.contains(f)]
# Add averages: for every (model, train config, test column) cell, store the mean over the
# unsupervised reporters and the mean over the supervised reporters as two extra
# pseudo-reporter rows. The train configs are taken from the column labels, i.e. this
# assumes every column label also occurs as a train_cfg index value.
for model in models:
    for train_cfg in filtered_df.columns:
        for test_dataset in filtered_df.columns:
            unsupervised_performances = filtered_df.loc[pd.IndexSlice[model, unsupervised_reporters, train_cfg], test_dataset]
            supervised_performances = filtered_df.loc[pd.IndexSlice[model, supervised_reporters, train_cfg], test_dataset]
            # Every reporter must contribute exactly one value, otherwise the average is not meaningful
            assert len(unsupervised_performances) == len(unsupervised_reporters)
            assert len(supervised_performances) == len(supervised_reporters)

            # Assigning to a not-yet-existing index label enlarges filtered_df in place
            filtered_df.loc[(model, "unsupervised_avg", train_cfg), test_dataset] = unsupervised_performances.mean()
            filtered_df.loc[(model, "supervised_avg", train_cfg), test_dataset] = supervised_performances.mean()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
import os

def plot_heatmaps_for_model(df, model, save_dir, **heatmap_kwargs):
    """Save a 2x3 grid of heatmaps for ``model``, one panel per reporter.

    Args:
        df: DataFrame whose rows are selected with ``df.loc[model, reporter, :]``
            and whose column labels are config descriptors understood by
            ``cfg_sort_key``.
        model: Model name to select on the first index level.
        save_dir: Directory the figure ``<model>_heatmaps.png`` is written to.
        **heatmap_kwargs: Extra keyword arguments forwarded to ``sns.heatmap``.
            Must not contain ``ax`` or ``cmap``; both are set here.

    The figure is closed after saving; nothing is returned.
    """
    reporters = ['lm', 'lr', 'mean-diff', 'lda', 'ccs', 'crc']
    titles = ['Baseline (LM output)', 'LR', 'Mean-Diff', 'LDA', 'CCS', 'CRC']

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    # Custom colormap: Dark Red to Red to Yellow to Green
    colors = ['#710000', '#f8696b', '#ffeb84', '#63be7b']
    n_bins = 100
    cmap = mcolors.LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # One shared column order for ALL panels. Tick labels are hidden below, so the
    # baseline panel must use the same column order as the other panels; otherwise
    # its cells cannot be compared column-by-column with them.
    sorted_columns = sorted(df.columns, key=cfg_sort_key, reverse=True)

    for i, (reporter, title) in enumerate(zip(reporters, titles)):
        ax = axes[i]
        panel_df = df.loc[model, reporter, :][sorted_columns]

        if reporter == 'lm':
            # For baseline, use only the first row
            panel_df = panel_df.iloc[[0]]
        else:
            # Order the rows (train configs) the same way as the columns
            panel_df = panel_df.sort_index(key=lambda idx: idx.map(cfg_sort_key), ascending=False)

        sns.heatmap(panel_df, ax=ax, cmap=cmap, **heatmap_kwargs)

        ax.set_title(title, fontsize=12)
        ax.set_xticks([])
        ax.set_yticks([])

        # Add "Train" and "Test" labels
        if i >= 3:  # Only for the bottom row
            ax.set_xlabel('Test', fontsize=10)
        if i % 3 == 0:  # Only for the leftmost column
            ax.set_ylabel('Train', fontsize=10)

    # Shared colorbar in its own axes to the right of the grid
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('AUROC', rotation=270, labelpad=15)

    # Mark AUROC = 0.5 on the colorbar
    cbar.ax.axhline(y=0.5, color='k', linestyle='--', linewidth=0.5)
    cbar.ax.text(1.5, 0.5, 'Random\nGuess', va='center', ha='left', fontsize=8)

    plt.suptitle(f'Heatmaps for {model}', fontsize=16)
    # Leave the right-hand strip free for the colorbar and the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])

    # Save the figure
    plt.savefig(os.path.join(save_dir, f'{model}_heatmaps.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)
def plot_all_models(df, save_dir):
    """Render and save the per-reporter heatmap grid for every model found in ``df``.

    Args:
        df: DataFrame with a MultiIndex whose first level identifies the model.
        save_dir: Directory passed on to ``plot_heatmaps_for_model``.
    """
    # Read the model level by position rather than by name: the plotting function
    # selects with ``df.loc[model, reporter, :]``, i.e. it also relies on the model
    # being the first index level, and a positional lookup works whatever that
    # level is called in the CSV header.
    models = df.index.get_level_values(0).unique()

    heatmap_kwargs = {
        "vmin": 0,
        "vmax": 1,
        "cbar": False,
        "annot": True,
        "fmt": '.2f',
        "annot_kws": {"size": 8},
        "linewidths": 0.5
    }

    for model in models:
        plot_heatmaps_for_model(df, model, save_dir, **heatmap_kwargs)

# Example usage: write one heatmap grid per model into the appendix figure folder
save_dir = 'figures/transfer_align/appendix'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir, exist_ok=True)
plot_all_models(df=filtered_df, save_dir=save_dir)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors

def plot_3x3_heatmaps_optimized(df, model='pythia-12B', **heatmap_kwargs):
    """Draw three side-by-side heatmaps: supervised average, unsupervised average, LM baseline.

    Args:
        df: DataFrame whose rows are selected with ``df.loc[model, reporter, :]``
            and whose column labels are config descriptors understood by
            ``cfg_sort_key``. It must contain the ``supervised_avg``,
            ``unsupervised_avg`` and ``lm`` reporters for ``model``.
        model: Model name to select on the first index level.
        **heatmap_kwargs: Extra keyword arguments forwarded to ``sns.heatmap``.
            Must not contain ``ax`` or ``cmap``; both are set here.

    Returns:
        The ``(fig, axes)`` pair created by ``plt.subplots``.
    """
    reporters = ['supervised_avg', 'unsupervised_avg', 'lm']
    titles = ['Supervised (LR, MM, LDA)', 'Unsupervised (CCS, CRC)', 'Baseline (LM output)']

    fig, axes = plt.subplots(1, 3, figsize=(13, 4))

    # Custom colormap: Dark Red to Red to Yellow to Green
    colors = ['#710000', '#f8696b', '#ffeb84', '#63be7b']
    n_bins = 100
    cmap = mcolors.LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # One shared column order for ALL panels. Tick labels are hidden below, so the
    # baseline panel must use the same column order as the other two panels to be
    # comparable with them.
    sorted_columns = sorted(df.columns, key=cfg_sort_key, reverse=True)

    for ax, reporter, title in zip(axes, reporters, titles):
        panel_df = df.loc[model, reporter, :][sorted_columns]

        if reporter == 'lm':
            # For baseline, use only the first row (as it's the same for all)
            baseline_values = panel_df.iloc[[0]].values
            # Replace the descriptor labels by neutral ones; the labels are generated
            # from the actual column count instead of assuming exactly three columns.
            neutral_labels = [f'Column {j}' for j in range(1, baseline_values.shape[1] + 1)]
            panel_df = pd.DataFrame(baseline_values, columns=neutral_labels)
        else:
            # Order the rows (train configs) the same way as the columns
            panel_df = panel_df.sort_index(key=lambda idx: idx.map(cfg_sort_key), ascending=False)

        sns.heatmap(panel_df, ax=ax, cmap=cmap, **heatmap_kwargs)

        ax.set_title(title, fontsize=10, wrap=True)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')  # Remove x-axis label
        ax.set_ylabel('')  # Remove y-axis label

    # Shared colorbar in its own axes to the right of the panels
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)

    # Add AUROC label above the colorbar
    cbar.ax.text(0.5, 1.05, 'AUROC', ha='center', va='bottom', transform=cbar.ax.transAxes)

    # Mark AUROC = 0.5 on the colorbar
    cbar.ax.axhline(y=0.5, color='k', linestyle='--', linewidth=0.5)
    cbar.ax.text(1.5, 0.5, 'Random\nGuess', va='center', ha='left', fontsize=8)

    # Adjust layout
    plt.tight_layout()
    plt.subplots_adjust(right=0.9, wspace=0.1)  # Make room for colorbar and reduce space between subplots

    return fig, axes

# Example usage: annotated heatmaps on a fixed 0..1 colour scale, saved and shown
heatmap_kwargs = dict(
    vmin=0,
    vmax=1,
    cbar=False,
    annot=True,
    fmt='.2f',
    annot_kws={"size": 10},  # Increase font size of numbers
    linewidths=0.5,
)
fig, axes = plot_3x3_heatmaps_optimized(filtered_df, **heatmap_kwargs)
plt.savefig(
    'figures/transfer_align/simulated_knowledge_heatmaps.png',
    dpi=300,
    bbox_inches='tight',
)
plt.show()
plt.close(fig)

In [46]:
from pathlib import Path
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ast
import matplotlib.cm as cm
from matplotlib.lines import Line2D

temp_df = None

def visualize_diversify(
        csv_path,
        reporters,
        train_size,
        layer=13,
        model="meta-llama/Llama-2-13b-hf",
        ignore_train_datasets=[],
        ignore_eval_datasets=[],
        require_train_datasets=[],
        apply_train_examples_per_dataset=False,
        seeds=[],
        transfer_type="any",
        n_datasets_traineds=[]
    ):
    """Plot accuracy against the number of mixed training datasets and save the figure.

    One subplot is drawn per (reporter, transfer_type) group found in the filtered
    data. Each subplot shows one curve per eval dataset (mean and std across seeds),
    horizontal lines for the single-dataset LR oracle of each eval dataset and for
    the jointly trained oracle, and a black curve averaging over all eval datasets
    that are not evaluation-only.

    Args:
        csv_path: Summary CSV to read. The figure is saved next to it.
        reporters: Reporter name or list of names to keep; falsy keeps all.
        train_size: Training-set size to keep; falsy disables the filter.
        layer: Value the ``layer`` column must equal.
        model: Value the ``model`` column must equal.
        ignore_train_datasets: Rows whose ``train_desc`` contains any of these are dropped.
        ignore_eval_datasets: Rows whose ``eval_dataset`` equals any of these are dropped.
        require_train_datasets: Only rows whose ``train_desc`` contains all of these are kept.
        apply_train_examples_per_dataset: If true, ``train_size`` is the size per training
            dataset ("Fixed Contribution"), otherwise the total size ("Fixed Total").
        seeds: Seeds to keep; empty keeps all.
        transfer_type: Transfer type to keep, or ``"any"`` for all.
        n_datasets_traineds: Values of ``n_train_datasets`` to keep; empty keeps all.

    Side effects:
        Stores the filtered frame in the module-level ``temp_df`` (for inspection),
        shows the figure, writes it as a PNG and prints the output path.

    Raises:
        ValueError: If no rows are left after filtering.
    """
    global temp_df

    if isinstance(reporters, str):
        reporters = [reporters]

    metric = "accuracy"

    # Create titles and filename for saving
    pretty_transfer_type = {
        "no transfer": "seen",
        "full transfer": "unseen",
        "semi transfer": "related seen",
        "any": "any"
    }
    pretty_reporter = {
        "<class 'probes.LRProbe'>": "LR",
        "<class 'probes.MMProbe'>": "MM",
        "<class 'probes.CCSProbe'>": "CCS",
        "<class 'probes.CrcReporter'>": "CRC",
        "<class 'probes.MMProbe'>_tuple_inference": "MM (contrast-inference)",
        "<class 'probes.LRProbe'>_tuple_inference": "LR (contrast-inference)",
    }
    aggregation_strategy = "Fixed Contribution" if apply_train_examples_per_dataset else "Fixed Total"

    def create_title(reporter, group_transfer_type):
        # Labels missing from the lookup tables are shown verbatim instead of raising
        reporter_label = pretty_reporter.get(reporter, reporter)
        target_label = pretty_transfer_type.get(group_transfer_type, group_transfer_type)
        return f"Probe: {reporter_label} on {target_label} target, n={train_size} ({aggregation_strategy})"

    save_path = Path(csv_path).parent / f"{train_size}_{pretty_transfer_type[transfer_type]}_{'FC' if apply_train_examples_per_dataset else 'FT'}_{metric}.png"

    # Datasets that are only ever evaluated on; excluded from the black average curve
    eval_only_datasets = [
        'got/companies_true_false',
        'got/common_claim_true_false',
        'got/cities_cities_conj',
        'got/cities_cities_disj',
    ]

    # Preferred legend / colour order
    ordered_datasets = [
        'got/cities',
        'got/neg_cities',
        'got/larger_than',
        'got/smaller_than',
        'got/sp_en_trans',
        'got/neg_sp_en_trans',
        'azaria/animals_true_false',
        'azaria/neg_animals_true_false',
        'azaria/elements_true_false',
        'azaria/neg_elements_true_false',
        'azaria/facts_true_false',
        'azaria/neg_facts_true_false',
        'azaria/inventions_true_false',
        'azaria/neg_inventions_true_false',
    ]

    # Constants
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    colormap = plt.get_cmap('tab20', 20)
    oracle_csv_path = Path(r"experiments\diversify_remake\summary_oracle_full.csv")
    always_ignore_datasets = ["got/counterfact_true", "got/counterfact_false"] # these datasets have been corrupted

    # Build NEW lists: extending the arguments in place would modify the caller's
    # lists and the shared default lists, leaking state from one call into the next.
    ignore_eval_datasets = [*ignore_eval_datasets, *always_ignore_datasets]
    ignore_train_datasets = [*ignore_train_datasets, *always_ignore_datasets]

    # DataFrame and Filtering
    filtered_df = pd.read_csv(csv_path)

    # Fix wrongly stored data
    # Counting medleys used, replace 0->1 only affects oracles
    filtered_df["n_train_datasets"] = (filtered_df["train_desc"].str.count(r'\+')).replace(0,1)

    filtered_df = filtered_df[filtered_df["model"] == model]
    filtered_df = filtered_df[filtered_df["layer"] == layer]
    for ignore_train_dataset in ignore_train_datasets:
        filtered_df = filtered_df[~(filtered_df["train_desc"].str.contains(ignore_train_dataset))]
    for ignore_eval_dataset in ignore_eval_datasets:
        filtered_df = filtered_df[filtered_df["eval_dataset"] != ignore_eval_dataset]
    for require_train_dataset in require_train_datasets:
        filtered_df = filtered_df[filtered_df["train_desc"].str.contains(require_train_dataset)]

    if reporters:
        filtered_df = filtered_df[filtered_df["reporter"].isin(reporters)]
    if train_size:
        if apply_train_examples_per_dataset:
            expected_train_sizes = filtered_df["n_train_datasets"] * train_size
            filtered_df = filtered_df[filtered_df["train_size"] == expected_train_sizes]
        else:
            filtered_df = filtered_df[filtered_df["train_size"] == train_size]
    if seeds:
        filtered_df = filtered_df[filtered_df["seed"].isin(seeds)]
    if n_datasets_traineds:
        filtered_df = filtered_df[filtered_df["n_train_datasets"].isin(n_datasets_traineds)]
    if transfer_type != "any":
        filtered_df = filtered_df[filtered_df["transfer_type"] == transfer_type]

    # Expose the filtered rows at module level for interactive inspection
    temp_df = filtered_df

    if filtered_df.empty:
        raise ValueError("No rows left after filtering; nothing to plot")

    # Datasets that actually occur, with ordered_datasets first and the rest in order of appearance
    present_datasets = list(filtered_df["eval_dataset"].unique())
    present_lookup = set(present_datasets)
    sorted_datasets = (
        [ds for ds in ordered_datasets if ds in present_lookup]
        + [ds for ds in present_datasets if ds not in ordered_datasets]
    )

    # Assign colors to datasets, prioritizing the ordered_datasets list
    ds_to_color = {eval_dataset: colormap(i) for i, eval_dataset in enumerate(sorted_datasets)}

    assert len(filtered_df[(filtered_df["oracle"]) & filtered_df["transfer_type"].str.contains("unseen")]) == 0, "Oracles should be 'seen'"

    # Load baselines
    oracle_df = pd.read_csv(oracle_csv_path)
    oracle_df = oracle_df[(oracle_df['oracle']) & (oracle_df['reporter'] == "<class 'probes.LRProbe'>")]

    # The jointly trained oracle is the oracle row trained on the most datasets
    jointly_trained_oracle_row = oracle_df[oracle_df['n_train_datasets'] == oracle_df['n_train_datasets'].max()]
    n_train_for_oracle = jointly_trained_oracle_row['n_train_datasets'].iloc[0]

    groups = filtered_df.groupby(["reporter", "transfer_type"])
    fig, axes = plt.subplots(len(groups), figsize=(7.5,1+5*len(groups)), squeeze=False)
    axes = axes.flatten()
    for ax, ((reporter, group_transfer_type), subplot_df) in zip(axes, groups):
        ax.set_title(create_title(reporter, group_transfer_type))

        n_train_values = sorted(subplot_df["n_train_datasets"].unique())
        ax.set_xlabel("Training datasets mixed")
        ax.set_xticks(n_train_values)
        ax.set_xticklabels(n_train_values)
        ax.set_ylabel(metric)
        ax.set_ylim(0.6, 1.02)
        ax.set_xlim(0.5, 7.5)

        # Plot one curve for each eval dataset
        for eval_dataset, curve_df in subplot_df.groupby("eval_dataset"):
            # Check if all train_descs are uniformly represented.
            # Grouping by a plain column name (not a one-element list) yields scalar keys.
            for n_train_datasets, datapoint_df in curve_df.groupby("n_train_datasets"):
                rows_per_train_desc = datapoint_df.groupby("train_desc").size()
                if rows_per_train_desc.nunique() > 1:
                    print(f"WARNING: Some train_descs are under / over represented for {eval_dataset=}{n_train_datasets=}")

            # First average over training configurations within each (seed, n_train_datasets),
            # then take mean and std of those per-seed values for each n_train_datasets.
            per_seed_accuracy = curve_df.groupby(["seed", "n_train_datasets"])[metric].mean()
            this_dataset_df = (
                per_seed_accuracy.groupby(level="n_train_datasets").agg(["mean", "std"]).reset_index()
            )

            eval_ds_color = ds_to_color[eval_dataset]
            ax.plot(
                this_dataset_df["n_train_datasets"],
                this_dataset_df["mean"],
                color=eval_ds_color)
            ax.errorbar(
                this_dataset_df["n_train_datasets"],
                this_dataset_df["mean"],
                yerr=this_dataset_df["std"],
                fmt='-o',
                color=eval_ds_color,
                capsize=3,
                lw=1)

            # Plot oracle baseline
            oracle_accuracy = oracle_df[(oracle_df['eval_dataset'] == eval_dataset) & (oracle_df['n_train_datasets'] == 1)]['accuracy']
            # Ignore oracles for which we don't have data
            if len(oracle_accuracy) > 0:
                ax.axhline(y=oracle_accuracy.item(), color=eval_ds_color, linestyle='-')
            else:
                print(f"No oracle for {eval_dataset}")

        # Plot jointly trained oracle baseline
        n_unique_eval_sets = len(subplot_df['eval_dataset'].unique())
        if n_train_for_oracle != n_unique_eval_sets:
            print(f"Oracle was trained on {n_train_for_oracle} datasets, but we show {n_unique_eval_sets}.")
        ax.axhline(y=jointly_trained_oracle_row['accuracy'].item(), color="black", lw=2.0, linestyle="-")

        # Plot averages in black.
        # Ignore datasets that are only used for evaluation to allow fair comparison between transfer types
        trained_eval_df = subplot_df[~subplot_df["eval_dataset"].isin(eval_only_datasets)]
        if not trained_eval_df.empty:
            # Summarize over medleys, then over eval datasets, and finally over seeds,
            # so that the std reflects variation across seeds only.
            per_seed_and_dataset = trained_eval_df.groupby(["seed", "n_train_datasets", "eval_dataset"])[metric].mean()
            per_seed = per_seed_and_dataset.groupby(level=["n_train_datasets", "seed"]).mean()
            agg_df = per_seed.groupby(level="n_train_datasets").agg(["mean", "std"]).reset_index()

            ax.plot(agg_df["n_train_datasets"], agg_df["mean"], color="black", lw=3)
            ax.errorbar(agg_df["n_train_datasets"], agg_df["mean"], yerr=agg_df["std"], fmt='-o', color="black", capsize=3, lw=1)

    # Create custom legend; evaluation-only datasets are shown in parentheses
    legend_elements = [Line2D([0], [0], color="black", lw=3, label="average")]
    legend_elements.extend(
        Line2D([0], [0], color=ds_to_color[ds], lw=2, label=f"({ds})" if ds in eval_only_datasets else ds)
        for ds in sorted_datasets
    )
    axes[0].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    fig.tight_layout()
    fig.show()
    fig.savefig(save_path)
    print(f"Saved fig to {save_path}")

## Diversity

In [None]:
# Merge data from different seeds into one file per setup
import pandas as pd
import os
from glob import glob

# Input: one summary CSV per (setup, seed). Output: one merged CSV per setup.
input_dir = r"experiments\diversify_remake\thesis_summaries"
output_dir = r"experiments\diversify_remake\thesis_summaries_merged"

# Make sure the target directory is there before anything is written to it
os.makedirs(output_dir, exist_ok=True)

# All CSV files directly inside the input directory
csv_files = glob(os.path.join(input_dir, '*.csv'))

# Bucket the files by everything before the last underscore, i.e. the path
# without its trailing seed part
file_groups = {}
for file in csv_files:
    prefix = file.rpartition('_')[0]
    file_groups.setdefault(prefix, []).append(file)

# Concatenate every bucket and report how many evaluations / distinct probes it holds
for prefix, files in file_groups.items():
    frames = [pd.read_csv(f) for f in files]
    merged_df = pd.concat(frames, ignore_index=True)
    output_file = os.path.join(output_dir, os.path.basename(prefix) + '_merged.csv')
    # NOTE(review): writing is currently disabled, so this cell only reports counts
    # although the final message below says the files were saved - confirm intent.
    # merged_df.to_csv(output_file, index=False)
    unique_probes_df = merged_df.drop_duplicates(("reporter", "train_desc", "train_size", "oracle", "seed"))
    print(f"{len(merged_df)} evaluations from {len(unique_probes_df)} probes in {output_file}")

print(f"Files have been merged and saved to '{output_dir}' directory.")


In [75]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import ast
import matplotlib.cm as cm
from matplotlib.lines import Line2D

def prepare_diversify(
        csv_path,
        csv_out,
        reporters,
        train_size,
        layer=13,
        model="meta-llama/Llama-2-13b-hf",
        ignore_train_datasets=[],
        ignore_eval_datasets=[],
        require_train_datasets=[],
        apply_train_examples_per_dataset=False,
        seeds=[],
        transfer_type="any",
        n_datasets_traineds=[]
    ):
    """Filter a summary CSV, aggregate accuracies and write the plot data to disk.

    For every (reporter, transfer_type, n_train_datasets, eval_dataset) combination
    the accuracy is first averaged over training configurations per seed, then mean
    and std are taken across seeds. An additional ``eval_dataset == "average"`` row
    per (reporter, transfer_type, n_train_datasets) aggregates all eval datasets
    that are not evaluation-only.

    Args:
        csv_path: Summary CSV to read.
        csv_out: Path the aggregated data is written to. The filtered LR oracle rows
            are written next to it as ``<stem>_oracle.csv``.
        reporters: Reporter name or list of names to keep; falsy keeps all.
        train_size: Training-set size to keep; falsy disables the filter.
        layer: Value the ``layer`` column must equal.
        model: Value the ``model`` column must equal.
        ignore_train_datasets: Rows whose ``train_desc`` contains any of these are dropped.
        ignore_eval_datasets: Rows whose ``eval_dataset`` equals any of these are dropped.
        require_train_datasets: Only rows whose ``train_desc`` contains all of these are kept.
        apply_train_examples_per_dataset: If true, ``train_size`` is the size per training
            dataset ("Fixed Contribution"), otherwise the total size ("Fixed Total").
        seeds: Seeds to keep; empty keeps all.
        transfer_type: Transfer type to keep, or ``"any"`` for all.
        n_datasets_traineds: Values of ``n_train_datasets`` to keep; empty keeps all.

    None of the list arguments is modified.
    """
    if isinstance(reporters, str):
        reporters = [reporters]

    debug = True
    metric = "accuracy"

    # Constants
    oracle_csv_path = Path(r"experiments\diversify_remake\summary_oracle_full.csv")
    always_ignore_datasets = ["got/counterfact_true", "got/counterfact_false"]

    # Build NEW lists: extending the arguments in place would modify the caller's
    # lists and the shared default lists, leaking state from one call into the next.
    ignore_eval_datasets = [*ignore_eval_datasets, *always_ignore_datasets]
    ignore_train_datasets = [*ignore_train_datasets, *always_ignore_datasets]

    # DataFrame and Filtering
    filtered_df = pd.read_csv(csv_path)

    # Fix wrongly stored data: recompute the dataset count from the number of '+'
    # in train_desc, mapping a count of 0 to 1
    filtered_df["n_train_datasets"] = (filtered_df["train_desc"].str.count(r'\+')).replace(0,1)

    filtered_df = filtered_df[filtered_df["model"] == model]
    filtered_df = filtered_df[filtered_df["layer"] == layer]
    for ignore_train_dataset in ignore_train_datasets:
        filtered_df = filtered_df[~(filtered_df["train_desc"].str.contains(ignore_train_dataset))]
    for ignore_eval_dataset in ignore_eval_datasets:
        filtered_df = filtered_df[filtered_df["eval_dataset"] != ignore_eval_dataset]
    for require_train_dataset in require_train_datasets:
        filtered_df = filtered_df[filtered_df["train_desc"].str.contains(require_train_dataset)]

    if debug: print("Unique n_train_datasets after initial filtering:", filtered_df["n_train_datasets"].unique())

    if reporters:
        filtered_df = filtered_df[filtered_df["reporter"].isin(reporters)]
    if debug: print("Unique n_train_datasets after reporter filter:", filtered_df["n_train_datasets"].unique())
    if train_size:
        if apply_train_examples_per_dataset:
            expected_train_sizes = filtered_df["n_train_datasets"] * train_size
            filtered_df = filtered_df[filtered_df["train_size"] == expected_train_sizes]
        else:
            filtered_df = filtered_df[filtered_df["train_size"] == train_size]
    if debug: print("Unique n_train_datasets after train_size filter:", filtered_df["n_train_datasets"].unique())
    if seeds:
        filtered_df = filtered_df[filtered_df["seed"].isin(seeds)]
    if debug: print("Unique n_train_datasets after seeds filter:", filtered_df["n_train_datasets"].unique())
    if n_datasets_traineds:
        filtered_df = filtered_df[filtered_df["n_train_datasets"].isin(n_datasets_traineds)]
    if debug: print("Unique n_train_datasets after n_datasets_traineds filter:", filtered_df["n_train_datasets"].unique())
    if transfer_type != "any":
        filtered_df = filtered_df[filtered_df["transfer_type"] == transfer_type]
    if debug: print("Unique n_train_datasets after transfer_type filter:", filtered_df["n_train_datasets"].unique())

    # Load baselines
    oracle_df = pd.read_csv(oracle_csv_path)
    oracle_df = oracle_df[(oracle_df['oracle']) & (oracle_df['reporter'] == "<class 'probes.LRProbe'>")]

    # Prepare data for visualization
    if debug: print("Final unique n_train_datasets before processing:", filtered_df["n_train_datasets"].unique())

    # Datasets that are only ever evaluated on; excluded from the "average" rows
    eval_only_datasets = [
        'got/companies_true_false',
        'got/common_claim_true_false',
        'got/cities_cities_conj',
        'got/cities_cities_disj',
    ]
    aggregation_strategy = "Fixed Contribution" if apply_train_examples_per_dataset else "Fixed Total"
    output_columns = [
        "reporter", "transfer_type", "aggregation_strategy", "n_train_datasets",
        "eval_dataset", "acc_mean", "acc_std", "n_seeds",
    ]

    def summarize(rows_df, reporter, group_transfer_type, n_train_datasets, eval_label):
        # First level of aggregation: mean accuracy per seed over all rows given.
        # Second level: mean and std of those per-seed values across seeds.
        per_seed = rows_df.groupby("seed")[metric].mean()
        n_seeds = len(per_seed)
        return {
            "reporter": reporter,
            "transfer_type": group_transfer_type,
            "aggregation_strategy": aggregation_strategy,
            "n_train_datasets": n_train_datasets,
            "eval_dataset": eval_label,
            "acc_mean": per_seed.mean(),
            # std of a single value is undefined (NaN); report 0 instead
            "acc_std": per_seed.std() if n_seeds > 1 else 0,
            "n_seeds": n_seeds,
        }

    all_data = []
    avg_data = []
    for (reporter, group_transfer_type), group_df in filtered_df.groupby(["reporter", "transfer_type"]):
        # One row per eval dataset
        for (n_train_datasets, eval_dataset), subgroup_df in group_df.groupby(["n_train_datasets", "eval_dataset"]):
            all_data.append(summarize(subgroup_df, reporter, group_transfer_type, n_train_datasets, eval_dataset))

        # One "average" row over all non-eval-only datasets
        trained_eval_df = group_df[~group_df["eval_dataset"].isin(eval_only_datasets)]
        for n_train_datasets, subgroup_df in trained_eval_df.groupby("n_train_datasets"):
            avg_data.append(summarize(subgroup_df, reporter, group_transfer_type, n_train_datasets, "average"))

    # Combine individual dataset results and overall average. The explicit column
    # list keeps the header (and the debug print below) valid even with no rows.
    final_data = pd.DataFrame(all_data + avg_data, columns=output_columns)

    # Save processed data
    final_data.to_csv(csv_out, index=False)
    # Derive the oracle file name from the file stem only, so that directory names are
    # never rewritten and the oracle file can never be the same file as csv_out.
    out_path = Path(csv_out)
    oracle_csv_out = out_path.with_name(f"{out_path.stem}_oracle.csv")
    oracle_df.to_csv(oracle_csv_out, index=False)

    if debug:
        print("Unique n_train_datasets in processed data:", final_data["n_train_datasets"].unique())

    print(f"Saved processed data to {csv_out} and {oracle_csv_out}")

In [None]:
# Aggregate data from experiments, compute averages
import os

# Preparing data (not visualizing)
# Each section below points prepare_diversify at one merged summary CSV and has it
# write the processed plot data into csv_plot_data_dir under filename_out.
# Settings for all figures
csv_dir = Path(r"experiments\diversify_remake\thesis_summaries_merged")
csv_plot_data_dir = csv_dir / "plot_data"
os.makedirs(csv_plot_data_dir, exist_ok=True)

# Common settings
ignore_train_datasets = ["got/companies_true_false"]

# n=500, apply_train_examples_per_dataset=False
filename = "summary_500_total_merged.csv"
apply_train_examples_per_dataset=False
filename_out = filename
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    transfer_type="any",
    n_datasets_traineds=[1,2,3,4,5,6,7]
)

# n=500, apply_train_examples_per_dataset=True (input file is named "250")
filename = "summary_250_contrib_merged.csv"
apply_train_examples_per_dataset=True
filename_out = filename.replace("250", "500") # fixing misnamed file. train_size argument below guarantees correctness
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    transfer_type="any",
    n_datasets_traineds=[1,2,3,4,5,6,7]
)

# n=1000, apply_train_examples_per_dataset=True (input file is named "500")
filename = "summary_500_contrib_merged.csv"
apply_train_examples_per_dataset=True
filename_out = filename.replace("500", "1000") # fixing misnamed file. train_size argument below guarantees correctness
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=1000,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    transfer_type="any",
    n_datasets_traineds=[1,2,3,4,5,6,7]
)

# n=1000, apply_train_examples_per_dataset=False
# NOTE(review): transfer_type is not passed here, so prepare_diversify's default applies,
# and n_datasets_traineds is narrower than in the calls above - confirm both are intended.
filename = "summary_1000_total_merged.csv"
apply_train_examples_per_dataset=False
filename_out = filename
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=1000,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    # transfer_type="full transfer",
    n_datasets_traineds=[2,3,4,5,6]
)

# UNSUPERVISED
# For unsupervised methods we store the number of pairs, which is half as many.
# We still name it the same downstream from here, as it is effectively the same number of samples.
# The train_size passed to prepare_diversify is therefore the nominal size times this multiplier.
unsupervised_train_size_multiplier=0.5

# Figure 1 a: 500, Unseen, FT
filename = "ccs_summary_500_total_merged.csv"
apply_train_examples_per_dataset=False
filename_out = filename
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500*unsupervised_train_size_multiplier,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    transfer_type="any",
    n_datasets_traineds=[1,2,3,4,5,6,7]
)

# n=500 (nominal), apply_train_examples_per_dataset=True
filename = "ccs_summary_500_contrib_merged.csv"
apply_train_examples_per_dataset=True
filename_out = filename # output keeps the input name; no rename is applied for this file
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500*unsupervised_train_size_multiplier,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    transfer_type="any",
    n_datasets_traineds=[1,2,3,4,5,6,7]
)

# n=1000 (nominal), apply_train_examples_per_dataset=False
# NOTE(review): transfer_type is not passed here, so prepare_diversify's default applies - confirm intended.
filename = "ccs_summary_1000_total_merged.csv"
apply_train_examples_per_dataset=False
filename_out = filename
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=1000*unsupervised_train_size_multiplier,
    ignore_train_datasets=ignore_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    # transfer_type="full transfer",
    n_datasets_traineds=[2,3,4,5,6]
)

# Require sp_en_trans
# Same inputs as the default n=500 runs, but "got/sp_en_trans" is passed as
# require_train_datasets and the output gets a "_spanish" suffix so it does not
# overwrite the default plot data.
# NOTE(review): transfer_type is not passed in either call, so prepare_diversify's default applies.
filename = "summary_500_total_merged.csv"
apply_train_examples_per_dataset=False
required_train_datasets = ["got/sp_en_trans"]
filename_out = filename.replace(".csv", "_spanish.csv")
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500,
    ignore_train_datasets=ignore_train_datasets,
    require_train_datasets=required_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    # transfer_type="full transfer",
    n_datasets_traineds=[1,2,3,4,5,6]
)

# Input file is named "250"; the output is renamed to "500" to match train_size=500 below
filename = "summary_250_contrib_merged.csv"
required_train_datasets = ["got/sp_en_trans"]
apply_train_examples_per_dataset=True
filename_out = filename.replace("250", "500").replace(".csv", "_spanish.csv")
prepare_diversify(
    csv_dir / filename,
    csv_plot_data_dir / filename_out,
    reporters=None,
    train_size=500,
    ignore_train_datasets=ignore_train_datasets,
    require_train_datasets=required_train_datasets,
    apply_train_examples_per_dataset=apply_train_examples_per_dataset,
    # transfer_type="full transfer",
    n_datasets_traineds=[1,2,3,4,5,6]
)

# Contrastive inference
# Only the two "_tuple_inference" reporters are processed, with an explicit layer and model.
# NOTE: this rebinds the notebook-level name `reporters`.
filename = "summary_500_tuple_inference_merged.csv"
reporters = [
    "<class 'probes.LRProbe'>_tuple_inference",
    "<class 'probes.MMProbe'>_tuple_inference"
    ]
# Evaluation datasets excluded from this run
ignore_eval_datasets=[
            'got/companies_true_false',
            'got/common_claim_true_false',
            'got/cities_cities_conj',
            'got/cities_cities_disj'
        ]
apply_train_examples_per_dataset = False
filename_out = filename
prepare_diversify(
    csv_dir / filename, 
    csv_plot_data_dir / filename_out,
    reporters, 
    train_size=None, # For this particular figure, the train sizes range from 496 - 504. We accept this inaccuracy and do not filter to not lose any data 
    layer=13, 
    model="meta-llama/Llama-2-13b-hf", 
    ignore_train_datasets=['got/companies_true_false'], 
    ignore_eval_datasets=ignore_eval_datasets, 
    require_train_datasets=[], 
    apply_train_examples_per_dataset=apply_train_examples_per_dataset, 
    seeds=[],
    transfer_type="any",
    n_datasets_traineds=[]
)

In [None]:
# Merge data from different experimental setups so they can be in one CSV and one Figure
def prepare_combined_data(csv_dir, output_csv_path, configs):
    """Concatenate several processed CSVs into one file, tagging each with its train size.

    Args:
        csv_dir: Directory containing the input CSVs (``str`` or ``Path``).
        output_csv_path: Destination CSV (``str`` or ``Path``); parent directories
            are created if needed.
        configs: Iterable of ``(filename, train_size)`` pairs. Every row read from
            ``filename`` gets a ``train_size`` column set to the given value.

    Raises:
        ValueError: If none of the configured files exist, so there is nothing to
            combine. No output file is written in that case.

    Missing individual files are reported with a printed warning and skipped.
    """
    csv_dir = Path(csv_dir)
    output_csv_path = Path(output_csv_path)

    frames = []
    missing = []
    for filename, train_size in configs:
        # Keep the try body minimal: only the read can legitimately hit a missing file
        try:
            df = pd.read_csv(csv_dir / filename)
        except FileNotFoundError:
            print(f"Warning: File {filename} not found. Skipping this configuration.")
            missing.append(filename)
            continue
        df["train_size"] = train_size
        frames.append(df)

    if not frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate"; say what went wrong
        raise ValueError(
            f"None of the configured files were found in {csv_dir}: {missing}. Nothing to combine."
        )

    combined_df = pd.concat(frames, ignore_index=True)

    output_csv_path.parent.mkdir(parents=True, exist_ok=True)
    combined_df.to_csv(output_csv_path, index=False)
    print(f"Combined data saved to {output_csv_path}")

working_dir = Path(r"experiments\diversify_remake\thesis_summaries_merged")
csv_dir = working_dir / "plot_data"
combined_dir = working_dir / "combined"

# One entry per combined CSV: (output filename, [(input filename, train_size), ...]).
# Combining only needs to be done once per change of the plot data.
combination_jobs = [
    # Default experiments n=500
    ("combined_data_500.csv", [
        ("summary_500_total_merged.csv", 500),
        ("summary_500_contrib_merged.csv", 500),
        ("ccs_summary_500_total_merged.csv", 500),
        ("ccs_summary_500_contrib_merged.csv", 500),
    ]),
    # Default experiments n=1000
    ("combined_data_1000.csv", [
        ("summary_1000_total_merged.csv", 1000),
        ("summary_1000_contrib_merged.csv", 1000),
        ("ccs_summary_1000_total_merged.csv", 1000),
    ]),
    # Require Spanish
    ("combined_data_spanish.csv", [
        ("summary_500_contrib_merged_spanish.csv", 500),
        ("summary_500_total_merged_spanish.csv", 500),
    ]),
    # Contrastive Inference
    ("combined_data_contrastive.csv", [
        ("summary_500_tuple_inference_merged.csv", 500),
    ]),
]

for combined_name, configs in combination_jobs:
    combined_csv_path = combined_dir / combined_name
    prepare_combined_data(csv_dir, combined_csv_path, configs)

In [78]:
def visualize_combined(
        csv_path,
        csv_oracle_path,
        reporter,
        train_size,
        transfer_types=("full transfer", "no transfer"),
        aggregation_strategies=("Fixed Total", "Fixed Contribution"),
        contrastive_inference=False,
        reference_csv_path=None,
        reference_reporter=None,
        save_path=None):
    """Plot accuracy against the number of training datasets for one reporter.

    One panel is drawn per (transfer type, aggregation strategy) pair. Each panel
    shows one line per evaluation dataset (mean with a +/- std band and whiskers),
    the "average" rows in black, horizontal oracle baselines, and optionally a
    dotted reference average taken from another CSV. All panels share y-limits.
    The figure is saved as a PNG and then shown.

    Args:
        csv_path: Combined CSV with the columns read below (``reporter``,
            ``train_size``, ``transfer_type``, ``aggregation_strategy``,
            ``eval_dataset``, ``n_train_datasets``, ``acc_mean``, ``acc_std``).
        csv_oracle_path: CSV with ``oracle``, ``reporter``, ``eval_dataset``,
            ``n_train_datasets`` and ``accuracy`` columns. Only rows with a truthy
            ``oracle`` flag and the LR probe reporter are used as baselines.
        reporter: Reporter identifier to select from ``csv_path``.
        train_size: Train size to select from ``csv_path`` (and the reference CSV).
        transfer_types: Transfer types to plot.
        aggregation_strategies: Aggregation strategies to plot.
        contrastive_inference: If True, use the paired ``a+neg_a`` dataset names.
        reference_csv_path: Optional CSV supplying a reference "average" curve.
        reference_reporter: Reporter to select from the reference CSV; required
            when ``reference_csv_path`` is given.
        save_path: Optional output path for the PNG. When omitted the name is
            derived from the reporter and train size. Runs that share reporter
            and train size (e.g. different input CSVs) share that default name,
            so pass ``save_path`` to keep them apart.

    Raises:
        ValueError: If no rows match ``reporter``/``train_size``, or if a reference
            CSV is given without ``reference_reporter``.
    """
    df = pd.read_csv(csv_path)
    df = df[df["reporter"] == reporter]
    df = df[df["train_size"] == train_size]
    oracle_df = pd.read_csv(csv_oracle_path)
    oracle_df = oracle_df[(oracle_df["oracle"]) & (oracle_df["reporter"] == "<class 'probes.LRProbe'>")]

    if df.empty:
        raise ValueError(f"No data found for reporter {reporter} with train_size {train_size} in {csv_path}")

    # Load reference data if provided
    ref_df = None
    if reference_csv_path:
        if reference_reporter is None:
            raise ValueError("reference_reporter is required when reference_csv_path is given")
        ref_df = pd.read_csv(reference_csv_path)
        ref_df = ref_df[ref_df["reporter"] == reference_reporter]
        ref_df = ref_df[ref_df["train_size"] == train_size]

    # These datasets get a parenthesised legend label
    eval_only_datasets = [
        'got/companies_true_false',
        'got/common_claim_true_false',
        'got/cities_cities_conj',
        'got/cities_cities_disj',
    ]

    if contrastive_inference:
        ordered_datasets = [
            'got/cities+got/neg_cities',
            'got/larger_than+got/smaller_than',
            'got/sp_en_trans+got/neg_sp_en_trans',
            'azaria/animals_true_false+azaria/neg_animals_true_false',
            'azaria/elements_true_false+azaria/neg_elements_true_false',
            'azaria/facts_true_false+azaria/neg_facts_true_false',
            'azaria/inventions_true_false+azaria/neg_inventions_true_false',
        ]
    else:
        ordered_datasets = [
            'got/cities',
            'got/neg_cities',
            'got/larger_than',
            'got/smaller_than',
            'got/sp_en_trans',
            'got/neg_sp_en_trans',
            'azaria/animals_true_false',
            'azaria/neg_animals_true_false',
            'azaria/elements_true_false',
            'azaria/neg_elements_true_false',
            'azaria/facts_true_false',
            'azaria/neg_facts_true_false',
            'azaria/inventions_true_false',
            'azaria/neg_inventions_true_false',
            'got/companies_true_false',
            'got/common_claim_true_false',
            'got/cities_cities_conj',
            'got/cities_cities_disj'
        ]

    # Human-readable names used in titles; unknown keys fall back to the raw value
    pretty_transfer_type = {
        "no transfer": "seen",
        "full transfer": "unseen",
        "semi transfer": "related seen",
        "any": "any"
    }
    pretty_reporter = {
        "<class 'probes.LRProbe'>": "LR",
        "<class 'probes.MMProbe'>": "MM",
        "<class 'probes.CCSProbe'>": "CCS",
        "<class 'probes.CrcReporter'>": "CRC",
        "<class 'probes.MMProbe'>_tuple_inference": "MM (contrast-inference)",
        "<class 'probes.LRProbe'>_tuple_inference": "LR (contrast-inference)",
    }

    n_rows = len(transfer_types)
    n_cols = len(aggregation_strategies)
    # squeeze=False always yields a 2-D array, so a 1x1 or 1xN grid no longer breaks the reshape.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(3 + 6 * n_cols, 3 + 6 * n_rows))
    # NOTE(review): this is a reshape, not a transpose. It is kept as-is so existing
    # figures keep their panel placement; axes[j, i] below is (aggregation j, transfer i).
    axes = axes.reshape((n_cols, n_rows))
    fig.suptitle(f"{pretty_reporter.get(reporter, reporter)}, n={train_size}", fontsize=20, y=0.998)

    # Create a custom color palette using tab20.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists in recent Matplotlib.
    colormap = plt.get_cmap('tab20', 20)
    color_dict = {eval_dataset: colormap(i) for i, eval_dataset in enumerate(ordered_datasets)}
    color_dict['average'] = 'black'

    # Define x-axis values for each subplot
    x_axis_values = {
        (500, "full transfer", "Fixed Total"): [1,2,3,4,5,6],
        (500, "full transfer", "Fixed Contribution"): [1,2,3,4,5,6],
        (500, "no transfer", "Fixed Total"): [1,2,3,4,5,6,7],
        (500, "no transfer", "Fixed Contribution"): [1,2,3,4,5,6,7],
        (1000, "full transfer", "Fixed Total"): [2,3,4,5,6],
        # (1000, "full transfer", "Fixed Contribution"): [2,3,4,5,6], # Not enough samples
        (1000, "no transfer", "Fixed Total"): [2,3,4,5,6,7],
        # (1000, "no transfer", "Fixed Contribution"): [2,3,4,5,6,7] # Not enough samples
    }

    legend_handles = []
    # The legend is collected from the first panel that actually has data
    legend_collected = False

    # Initialize variables to track overall min and max
    overall_y_min = float('inf')
    overall_y_max = float('-inf')

    for i, transfer_type in enumerate(transfer_types):
        for j, agg_strategy in enumerate(aggregation_strategies):
            ax = axes[j, i]
            data = df[(df["transfer_type"] == transfer_type) & (df["aggregation_strategy"] == agg_strategy)]

            if data.empty:
                # Axes coordinates keep the note centred whatever the axis limits end up being
                ax.text(0.5, 0.5, "No data available", ha='center', va='center', transform=ax.transAxes)
                continue

            add_to_legend = not legend_collected
            legend_collected = True

            subplot_x_values = x_axis_values.get((train_size, transfer_type, agg_strategy))
            if subplot_x_values is None:
                # Combination not listed above: plot every x value present in the data
                subplot_x_values = sorted(data["n_train_datasets"].unique())

            # min()/max() below keep the running value when a subset is empty (NaN extreme)
            subplot_y_min = float('inf')
            subplot_y_max = float('-inf')

            # Plot individual datasets
            for eval_dataset in ordered_datasets:
                color = color_dict[eval_dataset]

                subset = data[data["eval_dataset"] == eval_dataset]
                subset = subset[subset["n_train_datasets"].isin(subplot_x_values)]
                lower = subset["acc_mean"] - subset["acc_std"]
                upper = subset["acc_mean"] + subset["acc_std"]

                ax.plot(subset["n_train_datasets"], subset["acc_mean"],
                        color=color, alpha=0.9, linewidth=2)
                ax.fill_between(subset["n_train_datasets"], lower, upper,
                                color=color, alpha=0.2)

                # Add error bars (whiskers)
                ax.errorbar(subset["n_train_datasets"], subset["acc_mean"],
                            yerr=subset["acc_std"],
                            fmt='none', ecolor=color,
                            elinewidth=1, capsize=3, alpha=0.7)

                # Update subplot min and max
                subplot_y_min = min(subplot_y_min, lower.min())
                subplot_y_max = max(subplot_y_max, upper.max())

                # Plot oracle baseline for individual datasets
                oracle_accuracy = oracle_df[
                    (oracle_df['eval_dataset'] == eval_dataset)
                    & (oracle_df['n_train_datasets'] == 1)
                ]['accuracy']
                if len(oracle_accuracy) > 0:
                    ax.axhline(y=oracle_accuracy.iloc[0], color=color, linestyle='-', alpha=0.5)

                if add_to_legend:
                    label = f"({eval_dataset})" if eval_dataset in eval_only_datasets else eval_dataset
                    legend_handles.append(Line2D([0], [0], color=color, lw=2, label=label))

            # Plot average
            avg_data = data[data["eval_dataset"] == "average"]
            avg_data = avg_data[avg_data["n_train_datasets"].isin(subplot_x_values)]
            if not avg_data.empty:
                ax.plot(avg_data["n_train_datasets"], avg_data["acc_mean"],
                        color='black', linewidth=3, label='Average')
                ax.fill_between(avg_data["n_train_datasets"],
                                avg_data["acc_mean"] - avg_data["acc_std"],
                                avg_data["acc_mean"] + avg_data["acc_std"],
                                color='black', alpha=0.2)

                # Add error bars (whiskers) for average
                ax.errorbar(avg_data["n_train_datasets"], avg_data["acc_mean"],
                            yerr=avg_data["acc_std"],
                            fmt='none', ecolor='black',
                            elinewidth=1, capsize=3, alpha=0.7)

                # Add average to legend only once
                if add_to_legend:
                    legend_handles.insert(0, Line2D([0], [0], color='black', lw=3, label='Average'))

            # Plot reference average if provided
            if ref_df is not None:
                ref_avg_data = ref_df[(ref_df["transfer_type"] == transfer_type) &
                                      (ref_df["aggregation_strategy"] == agg_strategy) &
                                      (ref_df["eval_dataset"] == "average")]
                ref_avg_data = ref_avg_data[ref_avg_data["n_train_datasets"].isin(subplot_x_values)]
                if not ref_avg_data.empty:
                    ax.plot(ref_avg_data["n_train_datasets"], ref_avg_data["acc_mean"],
                            color='black', linestyle=':', linewidth=2, label='Average (normal inference)')

                    # Add reference average to legend only once
                    if add_to_legend:
                        legend_handles.insert(1, Line2D([0], [0], color='black', linestyle=':', lw=2, label='Average (normal inference)'))

            # Update subplot min and max
            subplot_y_min = min(subplot_y_min, (avg_data["acc_mean"] - avg_data["acc_std"]).min())
            subplot_y_max = max(subplot_y_max, (avg_data["acc_mean"] + avg_data["acc_std"]).max())

            # Plot oracle baseline for average (black horizontal line)
            # NOTE(review): this uses the first oracle row at the largest n_train_datasets,
            # not a mean over rows - confirm that row is the intended "average" oracle.
            max_n_train_datasets = oracle_df['n_train_datasets'].max()
            oracle_accuracy_avg = oracle_df[oracle_df['n_train_datasets'] == max_n_train_datasets]['accuracy']
            if len(oracle_accuracy_avg) > 0:
                ax.axhline(y=oracle_accuracy_avg.iloc[0], color='black', linestyle='-', linewidth=2)

            panel_name = pretty_transfer_type.get(transfer_type, transfer_type).capitalize()
            ax.set_title(f"{panel_name}, {agg_strategy}", fontsize=16)
            ax.set_xlabel("Number of Training Datasets", fontsize=14)
            ax.set_ylabel("Accuracy", fontsize=14)
            ax.tick_params(axis='both', which='major', labelsize=12)

            # Set x-axis limits and ticks
            ax.set_xlim(1, 7)
            ax.set_xticks(range(1, 8))
            ax.set_xticklabels(range(1, 8))

            # Update overall min and max
            overall_y_min = min(overall_y_min, subplot_y_min)
            overall_y_max = max(overall_y_max, subplot_y_max)

    # Shared y-limits only make sense if at least one panel produced finite values;
    # otherwise set_ylim would be called with +/-inf and raise.
    if np.isfinite(overall_y_min) and np.isfinite(overall_y_max):
        # Add some padding to the y-limits
        y_range = overall_y_max - overall_y_min
        overall_y_min = max(0, overall_y_min - 0.05 * y_range)
        overall_y_max = min(1.05, overall_y_max + 0.05 * y_range)

        # Set consistent y-axis limits for all subplots
        for ax in axes.flatten():
            ax.set_ylim(overall_y_min, overall_y_max)

    # Add a single legend to the figure with increased font size
    fig.legend(handles=legend_handles, bbox_to_anchor=(1.02, 0.5), loc='center left', fontsize=16)

    plt.tight_layout()

    if save_path is None:
        # "<class 'probes.LRProbe'>_tuple_inference" -> "probes_LRProbe_tuple_inference".
        # Keeping the suffix stops the contrastive figures overwriting the plain ones;
        # reporters without a suffix keep their previous file name.
        quoted = reporter.split("'")
        if len(quoted) >= 3:
            safe_reporter_name = quoted[-2].replace('.', '_') + quoted[-1].lstrip('>')
        else:
            # Reporter is not in "<class '...'>" form: fall back to a sanitised copy of it
            safe_reporter_name = "".join(ch if ch.isalnum() or ch in "_-" else "_" for ch in reporter)
        save_path = f"{safe_reporter_name}_{train_size}.png"

    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# The oracle rows are calculated independently of the run, so any summary file can supply them
working_dir = Path(r"experiments\diversify_remake\thesis_summaries_merged")
combined_dir = working_dir / "combined"
csv_oracle_path = Path(r"experiments\diversify_remake\thesis_summaries_merged\summary_500_contrib_merged.csv")

# Reporters shown in the default figures
reporters = [
    "<class 'probes.LRProbe'>",
    "<class 'probes.MMProbe'>",
    "<class 'probes.CCSProbe'>"
]

# --- Default, n=500: both transfer types, both aggregation strategies ---
train_size = 500
transfer_types = ["full transfer", "no transfer"]
aggregation_strategies = ["Fixed Total", "Fixed Contribution"]
combined_csv_path = combined_dir / f"combined_data_{train_size}.csv"
for reporter in reporters:
    visualize_combined(
        csv_path=combined_csv_path,
        csv_oracle_path=csv_oracle_path,
        reporter=reporter,
        train_size=train_size,
        transfer_types=transfer_types,
        aggregation_strategies=aggregation_strategies,
    )

# --- Default, n=1000: only the "Fixed Total" strategy ---
train_size = 1000
transfer_types = ["full transfer", "no transfer"]
aggregation_strategies = ["Fixed Total"]
combined_csv_path = combined_dir / f"combined_data_{train_size}.csv"
for reporter in reporters:
    visualize_combined(
        csv_path=combined_csv_path,
        csv_oracle_path=csv_oracle_path,
        reporter=reporter,
        train_size=train_size,
        transfer_types=transfer_types,
        aggregation_strategies=aggregation_strategies,
    )

# --- Require Spanish: LR only, unseen target only ---
csv_oracle_path = Path(r"experiments\diversify_remake\thesis_summaries_merged\summary_500_contrib_merged.csv")
train_size = 500
transfer_types = ["full transfer"]
aggregation_strategies = ["Fixed Total", "Fixed Contribution"]
combined_csv_path = combined_dir / "combined_data_spanish.csv"
for reporter in ["<class 'probes.LRProbe'>"]:
    visualize_combined(
        csv_path=combined_csv_path,
        csv_oracle_path=csv_oracle_path,
        reporter=reporter,
        train_size=train_size,
        transfer_types=transfer_types,
        aggregation_strategies=aggregation_strategies,
    )

# --- Contrastive inference: each reporter is paired with its plain counterpart as reference ---
reporters = [
    ("<class 'probes.LRProbe'>_tuple_inference","<class 'probes.LRProbe'>"),
    ("<class 'probes.MMProbe'>_tuple_inference","<class 'probes.MMProbe'>"),
]
csv_oracle_path = Path(r"experiments\diversify_remake\thesis_summaries_merged\summary_500_contrib_merged.csv")
train_size = 500
transfer_types = ["full transfer", "no transfer"]
aggregation_strategies = ["Fixed Total"]
combined_csv_path = combined_dir / "combined_data_contrastive.csv"
reference_csv_path = combined_dir / f"combined_data_{train_size}.csv"
for reporter, reference_reporter in reporters:
    visualize_combined(
        csv_path=combined_csv_path,
        csv_oracle_path=csv_oracle_path,
        reporter=reporter,
        train_size=train_size,
        transfer_types=transfer_types,
        aggregation_strategies=aggregation_strategies,
        contrastive_inference=True,
        reference_csv_path=reference_csv_path,
        reference_reporter=reference_reporter,
    )

## Conceptual Illustrations

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set the style and color palette
sns.set_style("whitegrid")
# One color and one display name per dataset; both lists are indexed by dataset position
colors = ["#E64B35", "#4DBBD5", "#00A087"]
datasets = ["Dataset A", "Dataset B", "Dataset C"]

# Hardcoded data
# For each correlation type: one row per dataset, ten 0/1 values per row.
# The rows are drawn as scatter points (y axis "Co-occurrence") and their means as bars.
data = {
    "point-wise": [
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],  # Dataset A
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Dataset B
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # Dataset C
    ],
    "local": [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Dataset A
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Dataset B
        [1, 1, 0, 1, 1, 1, 1, 1, 0, 1]   # Dataset C
    ],
    "global": [
        [1, 1, 1, 1, 0, 1, 1, 0, 1, 0],  # Dataset A
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],  # Dataset B
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # Dataset C
    ]
}

def plot_frequency(ax, data, leftmost=False):
    """Scatter each dataset's values on ``ax``, ten consecutive x positions per dataset.

    ``data`` holds one sequence per dataset, colored by position using the
    notebook-level ``colors``. Legend labels (from ``datasets``) and the y-axis
    label are only attached when ``leftmost`` is true.
    """
    for idx, series in enumerate(data):
        first_x = idx * 10 + 1
        positions = np.arange(first_x, first_x + 10)
        series_label = datasets[idx] if leftmost else ""
        ax.scatter(positions, series, color=colors[idx], s=20, label=series_label)

    ax.set_xlabel("Samples", fontsize="medium")
    if leftmost:
        ax.set_ylabel("Co-occurrence")
    ax.set_xlim(0, 31)
    ax.set_ylim(-0.1, 1.1)
    ax.set_yticks([0, 1])
    # Sample positions carry no meaning, so the x tick labels are blanked
    ax.set_xticklabels([])

def plot_barchart(ax, data, leftmost=False):
    """Draw one bar per dataset whose height is the mean of that dataset's values.

    Bars are colored with the notebook-level ``colors`` and labelled with
    ``datasets``; the y-axis label is only set when ``leftmost`` is true.
    """
    bar_positions = np.arange(len(datasets))
    bar_heights = list(map(np.mean, data))
    ax.bar(bar_positions, bar_heights, color=colors)
    if leftmost:
        ax.set_ylabel("Empirical Correlation")
    ax.set_ylim(0, 1)
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(datasets, fontsize="medium")

# Build a 2x3 grid: scatter plots on the top row, bar charts below, one column per correlation type
fig, axs = plt.subplots(2, 3, figsize=(12, 3))

for i, corr_type in enumerate(["point-wise", "local", "global"]):
    is_first_column = i == 0
    scatter_ax, bar_ax = axs[:, i]
    series = data[corr_type]
    plot_frequency(scatter_ax, series, leftmost=is_first_column)
    plot_barchart(bar_ax, series, leftmost=is_first_column)
    scatter_ax.set_title(corr_type.capitalize(), fontsize="large")

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Illustration: one blue vector along the x axis, two dotted red component vectors
# along the axes, and their solid red sum. Saved as thm_cosine_illustration.png.

# Set the Seaborn style
sns.set_style("whitegrid")

# Increase font size for all text elements.
# This has to happen before the figure is created: rcParams are read when text
# artists are created, so updating them afterwards does not change this figure.
# NOTE: rcParams are global, so this also affects figures created later in the session.
plt.rcParams.update({'font.size': 14})

# Create a new figure
fig, ax = plt.subplots(figsize=(8, 8))

# Set equal aspect ratio
ax.set_aspect('equal')

# Plot axes
ax.axhline(y=0, color='k', linewidth=0.5, zorder=0)
ax.axvline(x=0, color='k', linewidth=0.5, zorder=0)

# Plot the blue vector
ax.arrow(0, 0, 1, 0, color='blue', width=0.015, head_width=0.08, head_length=0.08, length_includes_head=True, zorder=1)

# Plot the red vectors along the axes (dotted)
ax.arrow(0, 0, 0.5, 0, color='red', width=0.008, head_width=0.06, head_length=0.06, length_includes_head=True, linestyle=':', zorder=2)
ax.arrow(0, 0, 0, 0.5, color='red', width=0.008, head_width=0.06, head_length=0.06, length_includes_head=True, linestyle=':', zorder=2)

# Plot the aggregated red vector (solid)
ax.arrow(0, 0, 0.5, 0.5, color='red', width=0.015, head_width=0.08, head_length=0.08, length_includes_head=True, zorder=3)

# Set the limits to start slightly before 0 and extend slightly beyond 1
ax.set_xlim(-0.08, 1.08)
ax.set_ylim(-0.08, 0.55)

# Set ticks
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5])

# Add labels with larger font size
ax.set_xlabel('Frequency of Feature 1', fontsize=14)
ax.set_ylabel('Frequency of Feature 2', fontsize=14)

# Remove top and right spines
sns.despine()

# Save the figure BEFORE showing it. Once plt.show() has run, pyplot may no longer
# hold this figure as current, and plt.savefig would then write an empty image.
# Saving through the figure object avoids depending on pyplot's current figure.
fig.savefig('thm_cosine_illustration.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()
plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Use seaborn's "whitegrid" style for the figure built in this cell
sns.set_style("whitegrid")

# x axis of every panel: diversity levels d = 1..8
diversity = np.arange(1, 9)
n_datasets = diversity.size

# Constants of the toy model
n_l = 7  # number of local features per dataset
n_g = 5   # number of global features per dataset
n_g_total = 20  # total number of global features
n_pw = 100  # number of point-wise features in a dataset of size n
n = 20  # default dataset size


# Toy curves for each feature type and metric, all evaluated on an array d of diversity levels

def point_wise_number(d):
    """Constant count ``n_pw`` of point-wise features, shaped like ``d``."""
    return np.full_like(d, n_pw)


def point_wise_frequency(d):
    """Frequency ``1 / n`` of a point-wise feature (the FT setting).

    For FC the alternative noted in the notebook is ``n / (n * d)``.
    """
    unit = np.ones_like(d)
    return unit / n


def point_wise_aggregated_salience(d):
    """``sqrt(number * frequency**2)`` for point-wise features."""
    count = point_wise_number(d)
    freq = point_wise_frequency(d)
    return np.sqrt(count * freq**2)


def point_wise_usefulness(d):
    """Zeros: point-wise features are not useful for predicting truth."""
    return np.zeros_like(d)


def local_number(d):
    """``n_l`` local features for each of the ``d`` datasets."""
    return n_l * d


def local_frequency(d):
    """Frequency ``1 / d`` of a local feature."""
    return 1 / d


def local_aggregated_salience(d):
    """``sqrt(number * frequency**2)`` for local features."""
    count = local_number(d)
    freq = local_frequency(d)
    return np.sqrt(count * freq**2)


def local_usefulness(d):
    """Usefulness ``1 / d`` of local features in the "Seen" scenario.

    For the "Unseen" scenario the alternative noted in the notebook is all zeros.
    """
    return 1 / d


def global_number(d):
    """Count of global features, saturating towards ``n_g_total`` as ``d`` grows."""
    exponent = -d * n_g / n_g_total
    return n_g_total * (1 - np.exp(exponent))


def global_frequency(d):
    """Frequency of global features: 1 at ``d = 1``, decaying towards ``n_g / n_g_total``."""
    gap = 1 - n_g / n_g_total
    saturation = 1 - np.exp(-(d-1))
    return 1 - gap * saturation


def global_usefulness_seen(d):
    """Usefulness of global features on a seen target (same curve as ``global_frequency``)."""
    gap = 1 - n_g / n_g_total
    saturation = 1 - np.exp(-(d-1))
    return 1 - gap * saturation


def global_usefulness_unseen(d):
    """Constant usefulness ``n_g / n_g_total`` of global features on an unseen target."""
    return np.ones_like(d) * n_g / n_g_total


def global_aggregated_salience(d):
    """``sqrt(number * frequency**2)`` for global features."""
    count = global_number(d)
    freq = global_frequency(d)
    return np.sqrt(count * freq**2)

# Create the plot: five panels (a)-(e) on a 4x2 grid; the last grid row is left empty
fig = plt.figure(figsize=(15, 11))
gs = fig.add_gridspec(4, 2)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[2, 0])
ax4 = fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[2, 1])
# fig.suptitle('Feature Characteristics vs Diversity', fontsize=16)

# Colors for each feature type
colors = {'Point-wise': 'blue', 'Local': 'orange', 'Global': 'green'}

# Explicit lookup of the model curve for each feature type and metric.
# This replaces building function names as strings and resolving them via globals(),
# so a renamed or missing function fails here, by name, rather than inside a plot loop.
feature_models = {
    'Point-wise': {
        'number': point_wise_number,
        'frequency': point_wise_frequency,
        'aggregated_salience': point_wise_aggregated_salience,
    },
    'Local': {
        'number': local_number,
        'frequency': local_frequency,
        'aggregated_salience': local_aggregated_salience,
    },
    'Global': {
        'number': global_number,
        'frequency': global_frequency,
        'aggregated_salience': global_aggregated_salience,
    },
}

# Plot Number
for feature, color in colors.items():
    y = feature_models[feature]['number'](diversity)
    ax1.plot(diversity, y, color=color, label=feature)
ax1.set_ylabel('Number of Features')
ax1.axhline(y=n_g_total, color='green', linestyle='--', alpha=0.5, label='Total Global Features')
ax1.axhline(y=n_l, color='orange', linestyle='--', alpha=0.5, label='Local Features per Dataset')
ax1.text(-0.1, 1.0, '(a)', transform=ax1.transAxes, fontsize=12, fontweight='bold', va='top', ha='right')
ax1.set_yticks([0])
ax1.set_ylim(-5, None)

# Plot Frequency
for feature, color in colors.items():
    y = feature_models[feature]['frequency'](diversity)
    ax2.plot(diversity, y, color=color)
ax2.set_ylabel('Frequency')
ax2.axhline(y=n_g/n_g_total, color='green', linestyle=':', alpha=0.5, label='Global Feature Prevalence')
ax2.text(-0.1, 1.0, '(b)', transform=ax2.transAxes, fontsize=12, fontweight='bold', va='top', ha='right')
ax2.set_yticks([0, 1])
ax2.set_ylim(-0.05, 1.05)  # Adjusted to show the full range

# Plot Aggregated Salience
for feature, color in colors.items():
    y = feature_models[feature]['aggregated_salience'](diversity)
    ax3.plot(diversity, y, color=color)
ax3.set_ylabel('Aggregated Salience\n(unnormalized)')
ax3.set_xlabel('Diversity d of Training data')
ax3.text(-0.1, 1.0, '(c)', transform=ax3.transAxes, fontsize=12, fontweight='bold', va='top', ha='right')
ax3.set_yticks([])
ax3.set_ylim(0, 5)

# Plot Predictive Relevance (Seen)
for feature, color in colors.items():
    if feature == 'Point-wise':
        y = np.zeros_like(diversity) - 0.0
    elif feature == 'Local':
        y = 1 / diversity
    else:
        y = global_usefulness_seen(diversity)
    ax4.plot(diversity, y, color=color)
ax4.set_ylabel('Predictive Relevance\n(Seen)')
ax4.axhline(y=n_g/n_g_total, color='green', linestyle=':', alpha=0.5)
ax4.text(-0.1, 1.0, '(d)', transform=ax4.transAxes, fontsize=12, fontweight='bold', va='top', ha='right')
ax4.set_yticks([0, 1])
ax4.set_ylim(-0.05, 1.05)  # Adjusted to show the full range

# Plot Predictive Relevance (Unseen)
for feature, color in colors.items():
    # Both of these curves are zero; the tiny opposite offsets keep the two lines
    # from being drawn exactly on top of each other.
    if feature == 'Local':
        y = np.zeros_like(diversity) - 0.0045
    elif feature == 'Point-wise':
        y = np.zeros_like(diversity) + 0.0045
    else:
        y = global_usefulness_unseen(diversity)
    ax5.plot(diversity, y, color=color)
ax5.set_ylabel('Predictive Relevance\n(Unseen)')
ax5.set_xlabel('Diversity d of Training data')
ax5.axhline(y=n_g/n_g_total, color='green', linestyle=':', alpha=0.5)
ax5.text(-0.1, 1.0, '(e)', transform=ax5.transAxes, fontsize=12, fontweight='bold', va='top', ha='right')
ax5.set_yticks([0, 1])
ax5.set_ylim(-0.05, 1.05)  # Adjusted to show the full range

# Keep a single x tick at d = 1 on every panel
for ax in [ax1, ax2, ax3, ax4, ax5]:
    ax.set_xticks([1])
    ax.set_xticklabels([1])

# Add a single legend for all subplots.
# Only panel (a) carries labelled artists, so the dotted prevalence line is added by hand.
handles, labels = ax1.get_legend_handles_labels()
dotted_line = plt.Line2D([0], [0], color='green', linestyle=':', label='Global Feature Prevalence')
handles.append(dotted_line)
labels.append('Global Feature Prevalence')
fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.98, 0.98))

# Adjust layout and display
plt.tight_layout()
plt.show()