## Import libraries

In [None]:
import os
import json

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [None]:
def read_memmap(filepath):
    """Open a memory-mapped array in read-only mode using its sidecar config.

    The array's shape and dtype are read from a JSON file stored next to the
    data file, with the same base name and a ``.conf`` extension (for example
    ``foo.dat`` -> ``foo.conf``).

    Args:
        filepath: Path to the raw memmap data file.

    Returns:
        A read-only ``np.memmap`` with the shape and dtype listed in the
        config file.

    Raises:
        OSError: If the config file or the data file cannot be opened.
        KeyError: If the config lacks the ``"shape"`` or ``"dtype"`` entry.
        ValueError: If the config is not valid JSON.
    """
    # Swap only the file extension. A plain str.replace(".dat", ".conf") would
    # also rewrite ".dat" occurring elsewhere in the path (e.g. a directory
    # named "results.data"), pointing at a config file that does not exist.
    config_path = os.path.splitext(filepath)[0] + ".conf"
    with open(config_path, "r") as fin_config:
        memmap_configs = json.load(fin_config)
    # The config handle is closed before the data file is mapped.
    return np.memmap(filepath, mode="r", shape=tuple(memmap_configs["shape"]), dtype=memmap_configs["dtype"])

In [None]:
# Root directory of the stored results. It is a relative path, so it is
# resolved against the current working directory. Per-run files are looked up
# under "<result_dir>/<dataset_name>/<model_name>/".
result_dir = "../../../results"

## Embedding visualization

In [None]:
# Render a 2-D PCA scatter plot of the layer-wise hidden states for every
# (dataset, model, prompting strategy) combination whose results exist on disk.
# Each plot is saved as a PDF under `plot_dir` and also shown inline.
plot_dir = "embedding_2d_plot"
os.makedirs(plot_dir, exist_ok=True)

# Indices (along the layer axis of the stored array) to project and plot.
# A wider selection such as [0, 7, 15, 23, 31] can be used here; note that all
# selected layers are drawn onto the same axes.
layers_to_plot = [31]

for dataset_name in ["100TFQA", "CommonsenseQA", "QASC", "GSM8K"]:
    for model_name in ["Llama-3.1-8B-Instruct", "Phi-3.5-mini-instruct"]:
        for prompting_strategy in ["zero-shot", "zero-shot-cot", "few-shot", "few-shot-cot"]:
            output_dir = f"{result_dir}/{dataset_name}/{model_name}"
            layer_wise_path = os.path.join(output_dir, f"{prompting_strategy}_layer_wise_hidden_states.dat")

            try:
                layer_wise_hidden_states = read_memmap(layer_wise_path)
            except (OSError, ValueError, KeyError):
                # Results for this combination are missing or unreadable
                # (no file, malformed config, size mismatch): skip it.
                # Unlike a bare `except:`, this lets KeyboardInterrupt through.
                continue
            num_samples, num_formats, num_layers, hidden_size = layer_wise_hidden_states.shape

            # Only plot layers that actually exist in this array; if none do,
            # skip the combination instead of saving an empty figure.
            selected_layers = [idx for idx in layers_to_plot if 0 <= idx < num_layers]
            if not selected_layers:
                continue

            # Format label of every row after flattening (sample, format) into
            # one axis: the format index varies fastest, matching the reshape
            # below. It does not depend on the layer, so compute it once.
            Y = np.tile(np.arange(num_formats), num_samples)

            fig, axes = plt.subplots(1, 1, figsize=(6, 5))
            for layer_idx in selected_layers:
                # Step 1: Prepare input -- one row per (sample, format) pair.
                X = layer_wise_hidden_states[:, :, layer_idx, :].reshape(-1, hidden_size)

                # Step 2: PCA projection onto the first two components.
                pca = PCA(n_components=2)
                X_pca = pca.fit_transform(X)

                # Step 3: Plot the projection, coloured by format index.
                scatter = axes.scatter(X_pca[:, 0], X_pca[:, 1], c=Y, cmap="Accent", alpha=0.7)
                axes.set_title(f"Layer {layer_idx}")
                axes.set_xlabel("PC1")
                axes.set_ylabel("PC2")

                fig.colorbar(scatter, ax=axes, ticks=range(num_formats))
            fig.tight_layout()
            fig.savefig(f"{plot_dir}/pca_2d_{dataset_name}_{model_name}_{prompting_strategy}.pdf")
            plt.show()
            # Release the figure so figures do not pile up across iterations.
            plt.close(fig)
            print(f"{dataset_name} / {model_name} / {prompting_strategy} (above)")