In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Read the per-question entropy results table; every plot below is built from it.
ENTROPIES_CSV = '../model_results/analysed_data/entropies_per_question.csv'
all_data = pd.read_csv(ENTROPIES_CSV)

In [None]:
# --- Selection of what to plot ----------------------------------------------
# Values are matched against the 'question_subset', 'model_correctness' and
# 'model_name' columns of `all_data` respectively.
question_types = [1, 2, 3, 4, 5]
# NOTE(review): the plotting cell labels the first subset "All" and the second
# "Only Correct", so '-1' presumably means "no correctness filter" - confirm.
subsets_of_correctness = ['-1', 'True']
model_names = ['Llama3-8b', 'Llama3-70b', 'Yi-34b', 'Mistral-7b']
correlation_of_interest = 'spearman_entropy'
correlation_of_interest_p_value = 'spearman_p_entropy'
significance_level = 0.05  # p-values strictly below this are flagged

n_models = len(model_names)
n_question_types = len(question_types)
n_correctness = len(subsets_of_correctness)

# Both arrays are indexed [model, correctness subset, question type].
# `data` holds the correlation value, `significances` holds 1.0 where the
# matching p-value is below `significance_level` and 0.0 otherwise.
data = np.zeros((n_models, n_correctness, n_question_types))
significances = np.zeros((n_models, n_correctness, n_question_types))

for model_idx, model_name in enumerate(model_names):
    for correctness_idx, correctness in enumerate(subsets_of_correctness):
        for type_idx, question_type in enumerate(question_types):
            subset = all_data[
                (all_data['model_name'] == model_name)
                & (all_data['model_correctness'] == correctness)
                & (all_data['question_subset'] == question_type)
            ]
            # Each (model, correctness, question type) combination must map to
            # exactly one row; fail loudly instead of relying on an implicit
            # Series -> scalar conversion.
            if len(subset) != 1:
                raise ValueError(
                    f"Expected exactly 1 row for model_name={model_name!r}, "
                    f"model_correctness={correctness!r}, "
                    f"question_subset={question_type!r}; found {len(subset)}."
                )
            row = subset.iloc[0]
            data[model_idx, correctness_idx, type_idx] = row[correlation_of_interest]
            significances[model_idx, correctness_idx, type_idx] = (
                row[correlation_of_interest_p_value] < significance_level
            )

# Generate Plot

In [None]:
# Draw one grouped bar chart per model on a 2x2 grid: for every question type
# there is one bar per correctness subset, and bars whose correlation is
# flagged in `significances` get an asterisk above them.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10),
                         sharey=True, sharex=True)

# Bar geometry: one group per question type, bars inside a group are offset by
# `width`.
positions = np.arange(n_question_types)
width = 0.35  # the width of the bars

# One color / legend entry per correctness subset (same order as the second
# axis of `data`).
colors = ['#61A958', '#3339AC']
legend_labels = ['All', 'Correct']

# Gap between the top of a bar and its significance asterisk (data units).
asterisk_offset = 0.005

# Plot data: pair every panel with its model explicitly instead of relying on
# a bare index into `model_names`.
for idx, (ax, model_name) in enumerate(zip(axes.flatten(), model_names)):
    for bar_idx in range(n_correctness):
        bar_centres = positions + bar_idx * width
        ax.bar(bar_centres, data[idx, bar_idx], width,
               label=legend_labels[bar_idx], color=colors[bar_idx])

        for pos_idx, centre in enumerate(bar_centres):
            # If the significance is 1, add an asterisk
            if significances[idx, bar_idx, pos_idx] == 1:
                # A bar spans from 0 to its value, so "above the bar" is above
                # max(value, 0); without the clamp the asterisk of a negative
                # bar would be drawn inside the bar.
                height = max(data[idx, bar_idx, pos_idx], 0) + asterisk_offset
                ax.text(centre, height, '*', ha='center',
                        va='bottom', fontsize=16, color='black')

    # Model name as an in-panel title (axes coordinates)
    ax.text(0.5, 0.93, model_name, fontsize=20,
            horizontalalignment='center',
            transform=ax.transAxes)

    # Center each tick under its group of bars and label it with the question type
    ax.set_xticks(positions + width / 2)
    ax.set_xticklabels([f"Type {question_type}" for question_type in question_types],
                       fontsize=17)

# Use the same y range, -0.2 to 1, on every panel
for ax in axes.flatten():
    ax.set_ylim(-0.2, 1)

# Add a single figure-level legend. Pass the bar handles of the first panel
# explicitly so the handle/label pairing does not depend on implicit ordering.
legend_handles, _ = axes.flat[0].get_legend_handles_labels()
fig.legend(legend_handles, legend_labels, loc='upper center',
           fontsize=17, bbox_to_anchor=(0.85, 0.05), ncol=2)

# x-axis label
fig.text(0.54, 0, 'Question Type', ha='center', fontsize=23)

# Add a main y-axis label
fig.text(0.02, 0.5, 'Spearman Correlation',
         va='center', rotation='vertical', fontsize=23)

# Main title
fig.suptitle('Student-Model Correlation of Choices Entropy',
             fontsize=25, x=0.54, y=0.95)

# Adjust the layout to prevent overlapping; the rect leaves room for the
# figure-level labels and title.
fig.tight_layout(rect=[0.05, 0.05, 1, 0.96])

# Save the plot, creating the output directory first so a fresh checkout
# does not fail on a missing 'plots/' folder.
output_path = Path('plots/spear_entropy.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=200, bbox_inches='tight')
# Show plot
plt.show()