In [4]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


from models_under_pressure.scripts.train_probes import RESULTS_DIR

# One heatmap result file is expected per variation type, in this order.
result_files = [
    f"generated_heatmap_{variation_type}.json" for variation_type in ["prompt_style", "tone", "language"]
]


def _load_result(file_name: str):
    """Parse one JSON result file located in ``RESULTS_DIR``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes (the previous ``json.load(open(...))`` left it to the
    garbage collector), and it is decoded as UTF-8 rather than with the
    platform default encoding.
    """
    with open(RESULTS_DIR / file_name, encoding="utf-8") as fh:
        return json.load(fh)


# Parsed results, index-aligned with ``result_files``.
result_list = [_load_result(file_name) for file_name in result_files]

def generate_heatmap_plot(result: dict, variation_type: "str | None" = None):
    """Plot, save and show one train-vs-test performance heatmap per layer.

    For every layer in ``result["layers"]`` the matching entry of
    ``result["performances"]`` (looked up by ``str(layer)``) is rendered as an
    annotated heatmap whose rows (train) and columns (test) are both labelled
    with ``result["variation_values"]``. Each figure is written as a PNG into
    ``RESULTS_DIR.parent / "plots"`` and then displayed.

    Args:
        result: Parsed heatmap result containing the keys ``"layers"``,
            ``"performances"`` and ``"variation_values"``.
        variation_type: Optional label of the variation being plotted. When
            omitted, ``result.get("variation_type")`` is used if present
            (assumption: result files may carry that key -- TODO confirm).
            When a label is available it is added to the title and the file
            name, so plots for different variation types no longer overwrite
            each other; without one, the legacy title and file name are kept.
    """
    if variation_type is None:
        # NOTE(review): falls back to a key that may not exist in the result
        # schema; ``.get`` keeps the legacy behaviour when it is absent.
        variation_type = result.get("variation_type")
    title_tag = f" ({variation_type})" if variation_type else ""
    file_tag = f"_{variation_type}" if variation_type else ""

    # Loop-invariant setup: axis labels and the output directory.
    variation_values = result["variation_values"]
    plots_dir = RESULTS_DIR.parent / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    for layer in result["layers"]:
        # Performances are keyed by the layer rendered as a string.
        performances = result["performances"][str(layer)]

        # Rows = train variations, columns = test variations.
        df = pd.DataFrame(
            performances,
            index=variation_values,
            columns=variation_values,
        )

        # Explicit figure/axes instead of the implicit pyplot state, so the
        # figure can be saved and closed unambiguously.
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(
            df,
            annot=True,
            fmt=".3f",
            cmap="RdBu",
            vmin=0,
            vmax=1,
            ax=ax,
        )

        ax.set_title(f"Probe Generalization Across Variations{title_tag}, Layer {layer}")
        ax.set_xlabel("Test Variation")
        ax.set_ylabel("Train Variation")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()

        # Save before showing; some backends discard the figure on show().
        fig.savefig(plots_dir / f"probe_generalisation_heatmap{file_tag}_layer_{layer}.png")
        plt.show()
        # Release the figure so one-figure-per-layer loops do not accumulate
        # open figures (a no-op if the backend already closed it).
        plt.close(fig)


In [None]:

# Per-layer heatmaps for the first loaded result file
# ("generated_heatmap_prompt_style.json", per the order of result_files).
generate_heatmap_plot(result_list[0])



In [None]:
# Per-layer heatmaps for the second loaded result file
# ("generated_heatmap_tone.json", per the order of result_files).
generate_heatmap_plot(result_list[1])

In [None]:
# Per-layer heatmaps for the third loaded result file
# ("generated_heatmap_language.json", per the order of result_files).
generate_heatmap_plot(result_list[2])