In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from colors import COLOR_MAP
import pandas as pd

In [2]:

# Directory that receives the PDF figures written by the plotting cell.
PLOT_DIR = "../figure_lamma1vs2"
# exist_ok=True keeps this cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(PLOT_DIR, exist_ok=True)

# Font size shared by tick labels and bar annotations; axis labels use +5.
GLOBAL_FONTSIZE = 40

In [3]:
# Compact labels for the four configurations. The visible cells never read
# this list; everything downstream is keyed by the dictionary below.
models = ["LLaMa1", "LLaMa1-S", "LLaMa2", "LLaMa2-S"]

# Model label -> root directory that is scanned for that model's result CSVs.
# Insertion order is kept, so iteration follows the order written here.
model_result_path = dict(
    [
        (
            "Llama-1",
            '/share/data/mei-work/kangrui/github/mango/kangrui/eval_results/results_llama_overall_0713/llama',
        ),
        (
            "Llama-1-S",
            '/share/data/mei-work/kangrui/github/mango/kangrui/eval_results/results_llama_overall_0713/llama_anno',
        ),
        (
            "Llama-2",
            '/share/data/mei-work/kangrui/github/mango/kangrui/eval_results/results_llama2_13b_base_overall_0816/llama',
        ),
        (
            "Llama-2-S",
            '/share/data/mei-work/kangrui/github/mango/kangrui/eval_results/results_llama2_13b_base_overall_0816/llama_anno',
        ),
    ]
)

In [4]:
def getresult(path):
    """Collect last-row metrics from the "loose" CSV files found under ``path``.

    ``path`` is walked recursively. Every file whose name ends in ``.csv`` and
    contains ``loose`` is read with pandas, and four columns of its final row
    are kept: ``easy_success_rate``, ``hard_success_rate``,
    ``easy_reasoning_acc`` and ``hard_reasoning_acc``.

    Args:
        path: Root directory to scan. A missing directory yields an empty
            result because ``os.walk`` ignores listing errors by default.

    Returns:
        A dict with at most two entries. ``'df'`` holds the metrics of a file
        whose joined path contains ``stepnav``; ``'rf'`` holds the metrics of
        any other matching file. Each value maps the four column names above
        to the value taken from the last row.

    Raises:
        IndexError: A matching CSV has no data rows.
        KeyError: A matching CSV lacks one of the four columns.

    Directories and files are visited in sorted order. When several matching
    files map to the same key the last one in that order wins, so the result
    no longer depends on the filesystem's enumeration order.
    """
    metric_columns = (
        'easy_success_rate',
        'hard_success_rate',
        'easy_reasoning_acc',
        'hard_reasoning_acc',
    )
    data_all={}
    for dirpath, dirnames, filenames in os.walk(path):
        # Sorting in place steers os.walk's top-down descent deterministically.
        dirnames.sort()
        for file in sorted(filenames):
            if not (file.endswith('.csv') and 'loose' in file):
                continue
            full_path = os.path.join(dirpath, file)
            last_row = pd.read_csv(full_path).iloc[-1]
            data = {column: last_row[column] for column in metric_columns}
            # The substring test runs on the whole joined path, so a
            # 'stepnav' component anywhere (including in ``path``) selects 'df'.
            task_key = 'df' if 'stepnav' in full_path else 'rf'
            data_all[task_key] = data
    return data_all

In [5]:
# Scan every configured result directory once; the outcome is keyed by the
# same model labels as model_result_path.
model_results_all = {
    label: getresult(result_dir)
    for label, result_dir in model_result_path.items()
}

In [6]:
# Left-to-right bar order used by plot(); each entry is looked up both in
# COLOR_MAP and in the per-figure results dict.
model_names = ["Llama-1", "Llama-1-S", "Llama-2", "Llama-2-S"]

In [7]:
def plot(name,model_results):
    """Draw one bar chart of per-model scores and save it as a PDF.

    Bars follow the order of the module-level ``model_names`` and take their
    colours from ``COLOR_MAP``. Each bar is annotated with its height to two
    decimals, the y-axis is fixed to ``(0, 1)`` and dashed guide lines are
    drawn at 0.2, 0.4, 0.6 and 0.8. The figure is written to
    ``{PLOT_DIR}/{name}.pdf`` and then closed; nothing is returned.

    Args:
        name: Base name of the output file. When it contains ``'reasoning'``
            the y-axis is labelled "Reasoning Accuracy", otherwise
            "Success Rate".
        model_results: Mapping from model name to bar height. It must hold an
            entry for every name in ``model_names``.
            NOTE(review): heights are presumably fractions in [0, 1], given
            the fixed y-range - confirm against the CSV contents.

    Raises:
        KeyError: A name in ``model_names`` is missing from ``COLOR_MAP`` or
            from ``model_results``.
    """
    y_range=(0,1)
    fig = plt.figure(figsize=(16, 16))

    colors = [COLOR_MAP[model_name] for model_name in model_names]
    heights = [model_results[model_name] for model_name in model_names]
    bars = plt.bar(model_names, heights, color=colors)

    # Print each value just above its bar.
    for bar in bars:
        yval = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width()/2,
            yval + 0.01,
            "{:.2f}".format(yval),
            ha='center',
            va='bottom',
            fontsize=GLOBAL_FONTSIZE,
        )

    # Set the y range
    plt.ylim(y_range)
    plt.yticks(fontsize=GLOBAL_FONTSIZE)
    plt.xticks(fontsize=GLOBAL_FONTSIZE, rotation=45)

    # Substring test rather than endswith(): callers may append a suffix after
    # the metric (e.g. "..._reasoning_acclamma1vs2"), which a suffix test misses.
    if 'reasoning' in name:
        plt.ylabel('Reasoning Accuracy',fontsize=GLOBAL_FONTSIZE+5)
    else:
        plt.ylabel('Success Rate',fontsize=GLOBAL_FONTSIZE+5)

    # Dashed horizontal guide lines at every 0.2 step inside the y range.
    for guide_y in (0.2, 0.4, 0.6, 0.8):
        plt.axhline(y=guide_y, color="lightcoral", linestyle="--")

    # bbox_inches="tight" crops the saved page to the drawn content.
    plt.savefig(f"{PLOT_DIR}/{name}.pdf", dpi=300, bbox_inches="tight")
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [8]:
# One figure per (task, difficulty, metric) combination: 2 x 2 x 2 = 8 PDFs.
for task_type in ('df', 'rf'):
    for difficulty in ('easy', 'hard'):
        for metrics in ('success_rate', 'reasoning_acc'):
            column = f"{difficulty}_{metrics}"
            # The suffix is appended without a separator,
            # e.g. "df_easy_success_ratelamma1vs2".
            name = f"{task_type}_{column}lamma1vs2"
            # Pick the same task/column out of every model's results.
            model_results = {
                label: per_task[task_type][column]
                for label, per_task in model_results_all.items()
            }
            plot(name, model_results)
            