In [None]:
import json

from models_under_pressure.config import RESULTS_DIR

# One results JSON per run; the filename suffix names the variation type used
# for that run (the "None" file presumably has no variation — TODO confirm).
result_files = [
    RESULTS_DIR / "prompts_28_02_25_Llama-3.2-1B-Instruct_language_fig2.json",
    RESULTS_DIR / "prompts_28_02_25_Llama-3.2-1B-Instruct_prompt_style_fig2.json",
    RESULTS_DIR / "prompts_28_02_25_Llama-3.2-1B-Instruct_None_fig2.json",
]


def load_result_file(path) -> dict:
    """Parse one results JSON file.

    Uses a context manager so the file handle is closed as soon as the JSON
    has been read (the previous ``json.load(open(...))`` leaked the handle).

    Args:
        path: Path (or path string) of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    with open(path, "r") as f:
        return json.load(f)


result_list = [load_result_file(result_file) for result_file in result_files]
print(result_list[0])

In [16]:
def get_run_name(results: dict) -> str:
    """Build a short, human-readable label for one results run.

    The label always contains the layer and, when set, the variation type
    and variation value, joined by commas, e.g.
    ``"layer=11,variation_type=language"``.

    Args:
        results: Results dictionary for one run. Must contain the keys
            ``"layer"``, ``"variation_type"`` and ``"variation_value"``;
            the latter two may be ``None`` to mean "not set".

    Returns:
        The comma-separated run label.

    Raises:
        KeyError: If any of the three keys is missing from ``results``.
    """
    parts = [f"layer={results['layer']}"]
    for key in ("variation_type", "variation_value"):
        value = results[key]
        # None means "not set" for this run. Values are formatted rather than
        # concatenated so non-string values (e.g. numbers) don't raise TypeError.
        if value is not None:
            parts.append(f"{key}={value}")
    return ",".join(parts)

first_run_name = get_run_name(result_list[0])
print(first_run_name)


layer=11,variation_type=language


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart: one group of bars per dataset, one bar (colour) per run.
assert result_list, "No results loaded; nothing to plot."
n_runs = len(result_list)
# 0.3 per bar as before, shrunk only when there are so many runs that a group
# would otherwise spill into its neighbours.
bar_width = min(0.3, 0.9 / n_runs)

# Dataset labels come from the first run; every run must report the same
# datasets in the same order, otherwise bars within a group would not refer to
# the same dataset.
datasets = result_list[0]["datasets"]
x = np.arange(len(datasets))
for results in result_list:
    assert results["datasets"] == datasets, (
        f"Dataset list differs for run {get_run_name(results)}"
    )
    assert len(results["AUROC"]) == len(datasets), (
        f"AUROC count does not match datasets for run {get_run_name(results)}"
    )

fig, ax = plt.subplots(figsize=(8, 6))

for i, results in enumerate(result_list):
    # Centre each group of n_runs bars on its x position so the tick label sits
    # under the middle of the group regardless of how many runs are plotted.
    offset = (i - (n_runs - 1) / 2) * bar_width
    bars = ax.bar(x + offset, results["AUROC"], bar_width, label=get_run_name(results))

    # Annotate each bar with its AUROC value.
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.2f}",
            ha="center",
            va="bottom",
        )

ax.set_ylim(0, 1)  # AUROC ranges from 0 to 1
ax.set_ylabel("AUROC Score")
ax.set_title("Generalization to non-AIS datasets")

# Bars are centred on x, so the ticks go exactly at x.
ax.set_xticks(x)
ax.set_xticklabels(datasets)
ax.legend()

plt.show()
