In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Read the per-question chi-squared results that every plot in this notebook uses.
CHI_SQUARED_CSV = '../model_results/analysed_data/chi_squared_per_question.csv'
all_data = pd.read_csv(CHI_SQUARED_CSV)

In [None]:
# Axes of the results grid: one panel per model, one bar per correctness
# subset, one bar group per question type.
model_names = ['Llama3-8b', 'Llama3-70b', 'Yi-34b', 'Mistral-7b']
# Values matched against the 'model_correctness' column; create_plot draws the
# first entry as "All" and the second as the correct-only subset.
subsets_of_correctness = ['-1', 'True']
question_types = list(range(1, 6))

# Sizes derived from the lists above so they cannot drift apart.
n_correctness = len(subsets_of_correctness)
n_question_types = len(question_types)

# Plotting Function

In [None]:
def create_plot(data, plot_title, output_filename, y_max=1200, output_dir='plots'):
    """Draw a 2x2 grid of grouped bar charts, save it as a PNG and show it.

    Each panel corresponds to one entry of the module-level ``model_names``
    list (in order). Within a panel there is one bar group per entry of the
    module-level ``question_types`` and, inside each group, one bar per
    correctness subset (``n_correctness`` bars).

    Parameters
    ----------
    data : array-like
        Values indexed as ``data[model, correctness_subset, question_type]``.
    plot_title : str
        Figure-level title.
    output_filename : str
        File name without extension; the figure is written to
        ``<output_dir>/<output_filename>.png``.
    y_max : float, optional
        Upper y-axis limit applied to every panel (default 1200).
    output_dir : str, optional
        Directory the PNG is written to. It is created when missing
        (default ``'plots'``).

    Returns
    -------
    None
        Deliberately returns nothing: returning the figure would make a
        notebook render it a second time when the call ends a cell.
    """
    data = np.asarray(data)

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10),
                             sharey=True, sharex=True)

    # Bar geometry: groups sit at integer positions, bars inside a group are
    # offset by multiples of `width`.
    positions = np.arange(len(question_types))
    width = 0.35

    colors = ['#61A958', '#3339AC']
    # Single source of truth for the legend text, so the per-bar labels and
    # the figure legend cannot disagree.
    legend_labels = ['All', 'Correct']
    tick_labels = [f'Type {question_type}' for question_type in question_types]

    for idx, ax in enumerate(axes.flat):
        for bar_idx in range(n_correctness):
            ax.bar(positions + bar_idx * width, data[idx, bar_idx], width,
                   label=legend_labels[bar_idx], color=colors[bar_idx])

        # Model name drawn inside the panel, near the top, instead of a title.
        ax.text(0.5, 0.93, model_names[idx], fontsize=20,
                horizontalalignment='center',
                transform=ax.transAxes)

        # Centre each tick under its group of bars.
        ax.set_xticks(positions + width * (n_correctness - 1) / 2)
        ax.set_xticklabels(tick_labels, fontsize=17)
        ax.set_ylim(0, y_max)

    # Every panel uses the same bars/colours, so the first panel's handles
    # describe the whole figure. Passing handles explicitly avoids relying on
    # matplotlib's implicit handle/label ordering.
    handles, labels = axes.flat[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', fontsize=17,
               bbox_to_anchor=(0.85, 0.05), ncol=2)

    # Shared axis labels and title, placed in figure coordinates.
    fig.text(0.54, 0, 'Question Type', ha='center', fontsize=23)
    fig.text(0.02, 0.5, 'Chi-Squared Value (lower is better)',
             va='center', rotation='vertical', fontsize=23)
    fig.suptitle(plot_title, fontsize=25, x=0.54, y=0.95)

    # Leave room on the left/bottom/top for the figure-level labels and title.
    fig.tight_layout(rect=[0.05, 0.05, 1, 0.96])

    # Create the target directory first so a fresh checkout does not fail.
    output_path = Path(output_dir) / f'{output_filename}.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=200, bbox_inches='tight')

    plt.show()

# Chi-Squared: Logit

In [None]:
chi_squared_of_interest = 'logit_chi'  # logit == 1st token

# Fill a (model, correctness subset, question type) grid with the selected
# chi-squared column; the axis order is what create_plot indexes into.
data = np.zeros((len(model_names), n_correctness, n_question_types))
for model_idx, model_name in enumerate(model_names):
    for correctness_idx, correctness in enumerate(subsets_of_correctness):
        for type_idx, question_type in enumerate(question_types):
            subset = all_data[
                (all_data['model_name'] == model_name)
                & (all_data['model_correctness'] == correctness)
                & (all_data['question_subset'] == question_type)
            ]
            # Exactly one row is expected per combination; fail loudly with
            # the offending keys instead of an opaque NumPy assignment error.
            if len(subset) != 1:
                raise ValueError(
                    f'Expected exactly 1 row for model={model_name!r}, '
                    f'correctness={correctness!r}, question_type={question_type!r}; '
                    f'found {len(subset)}.'
                )
            # Take the scalar explicitly rather than assigning a Series
            # into a single array slot.
            data[model_idx, correctness_idx, type_idx] = subset[chi_squared_of_interest].iloc[0]

create_plot(data, '1st Token Probability', 'chi_1st_token')

In [None]:
chi_squared_of_interest = 'order_chi'  # plotted below as "Choice Order Sensitivity"

# Fill a (model, correctness subset, question type) grid with the selected
# chi-squared column; the axis order is what create_plot indexes into.
data = np.zeros((len(model_names), n_correctness, n_question_types))
for model_idx, model_name in enumerate(model_names):
    for correctness_idx, correctness in enumerate(subsets_of_correctness):
        for type_idx, question_type in enumerate(question_types):
            subset = all_data[
                (all_data['model_name'] == model_name)
                & (all_data['model_correctness'] == correctness)
                & (all_data['question_subset'] == question_type)
            ]
            # Exactly one row is expected per combination; fail loudly with
            # the offending keys instead of an opaque NumPy assignment error.
            if len(subset) != 1:
                raise ValueError(
                    f'Expected exactly 1 row for model={model_name!r}, '
                    f'correctness={correctness!r}, question_type={question_type!r}; '
                    f'found {len(subset)}.'
                )
            # Take the scalar explicitly rather than assigning a Series
            # into a single array slot.
            data[model_idx, correctness_idx, type_idx] = subset[chi_squared_of_interest].iloc[0]

create_plot(data, 'Choice Order Sensitivity', 'chi_order')