In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Labels treated as the negative / "no relation" class when counting positives.
_NEGATIVE_LABELS = ('NA', 'na', 'no_relation', 'Other', 'Others', 'false', 'unanswerable', 'NONE')


def f1_score(true, pred_result, negative_labels=_NEGATIVE_LABELS):
    """Compute accuracy and micro precision/recall/F1 over the positive labels.

    A label is "positive" when it is not in ``negative_labels``. Precision
    counts correct positive predictions over all positive predictions, and
    recall counts them over all positive gold labels. Negative-class
    agreements count towards accuracy only.

    Args:
        true: sequence of gold labels.
        pred_result: sequence of predicted labels, aligned element-wise with ``true``.
        negative_labels: container of labels treated as the negative class
            (membership is tested with ``in``).

    Returns:
        dict with keys ``'acc'``, ``'micro_p'``, ``'micro_r'`` and ``'micro_f1'``.
        Precision, recall and F1 are 0.0 when their denominator is zero.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    total = len(true)
    if total != len(pred_result):
        raise ValueError(
            f'true and pred_result differ in length ({total} vs {len(pred_result)})'
        )
    if total == 0:
        raise ValueError('f1_score needs at least one label')

    correct = correct_positive = pred_positive = gold_positive = 0
    for golden, pred in zip(true, pred_result):
        golden_is_positive = golden not in negative_labels
        if golden == pred:
            correct += 1
            if golden_is_positive:
                correct_positive += 1
        if golden_is_positive:
            gold_positive += 1
        if pred not in negative_labels:
            pred_positive += 1

    acc = correct / total
    micro_p = correct_positive / pred_positive if pred_positive else 0.0
    micro_r = correct_positive / gold_positive if gold_positive else 0.0
    micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r else 0.0
    return {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}

In [None]:
def f1_score_na(true, pred_result):
    """Score predictions treating every label, including "no relation", as a class.

    Every example counts as both a gold positive and a predicted positive.
    Micro precision and recall therefore both equal accuracy, and so does F1
    (or 0.0 when accuracy is 0).

    Args:
        true: sequence of gold labels.
        pred_result: sequence of predicted labels, aligned element-wise with ``true``.

    Returns:
        dict with keys ``'acc'``, ``'micro_p'``, ``'micro_r'`` and ``'micro_f1'``.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    total = len(true)
    if total != len(pred_result):
        raise ValueError(
            f'true and pred_result differ in length ({total} vs {len(pred_result)})'
        )
    if total == 0:
        raise ValueError('f1_score_na needs at least one label')

    correct = sum(1 for golden, pred in zip(true, pred_result) if golden == pred)
    acc = correct / total
    # Every example is a positive on both sides, so P == R == accuracy.
    micro_p = micro_r = acc
    micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r else 0.0
    return {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}

In [None]:
# Methods to compare. Each name is also the sub-directory under `base_path`
# that holds that method's few-shot predictions; the GenPT entries are one per backbone.
methods = [
    'KnowPrompt',
    '020',
    'Roberta_base',
    'GenPT/Bart',
    'GenPT/T5',
    'GenPT/roberta',
    'GPT-RE',
    'DeepKE',
]

In [None]:
# Dataset directory names under each method's `few_output/` folder.
dataset = [
    'crossRE',
    'NYT10',
    'FewRel',
    'tacred',
    'retacred',
    'WebNLG',
    'semeval_nodir',
]

In [None]:
# Collect few-shot results: for every (method, dataset, k), average micro P/R/F1
# over the seeds whose prediction file exists. Missing runs are skipped; a
# (method, dataset, k) cell with no runs at all gets NaN metrics.
# NOTE(review): `base_path` is not defined anywhere in this notebook. Set it
# (e.g. in a config cell) before running this cell, otherwise a NameError is raised.
FEW_SHOT_KS = [1, 5, 10, 20, 30]
SEEDS = [13, 42, 100]

records = []
for method in methods:
    print(method)
    # All GenPT backbones write their predictions to `GenPT_test.jsonl`.
    name = 'GenPT' if method.split('/')[0] == 'GenPT' else method

    for data in dataset:
        for k in FEW_SHOT_KS:
            f1, p, r = [], [], []
            for seed in SEEDS:
                res_file = f'{base_path}/{method}/few_output/{data}/{seed}-{k}/{name}_test.jsonl'
                try:
                    with open(res_file, encoding='utf-8') as fh:
                        batch = [json.loads(line) for line in fh if line.strip()]
                except FileNotFoundError:
                    # This (method, dataset, seed, k) run was not produced; skip it.
                    continue
                if not batch:
                    # An empty prediction file has nothing to score.
                    continue

                true_label = [x['label_true'] for x in batch]
                pred_label = [x['label_pred'] for x in batch]
                results = f1_score(true_label, pred_label)

                f1.append(results['micro_f1'])
                p.append(results['micro_p'])
                r.append(results['micro_r'])

            # Explicit NaN avoids numpy's "Mean of empty slice" warning.
            records.append({
                'Method': method,
                'Dataset': data,
                'k': k,
                'f1': np.mean(f1) if f1 else np.nan,
                'p': np.mean(p) if p else np.nan,
                'r': np.mean(r) if r else np.nan,
            })

# Build the frame once: DataFrame.append was removed in pandas 2.0 (and was quadratic).
df = pd.DataFrame(records, columns=['Method', 'Dataset', 'k', 'f1', 'p', 'r'])

In [None]:
# One row per (Method, Dataset, k) with the mean of each metric column.
# Metrics are cast to float first, so `mean` works even if `df` was built with object dtype.
METRIC_COLS = ['f1', 'p', 'r']
g_df = (
    df.astype({col: float for col in METRIC_COLS})
      .groupby(['Method', 'Dataset', 'k'], as_index=False)[METRIC_COLS]
      .mean()
)

In [None]:
# Replace internal directory names with the display names used in the figure.
# Unlisted methods keep their names. Vectorised, and safe on an empty frame.
METHOD_DISPLAY_NAMES = {'020': 'RBERT', 'DeepKE': 'UnleashLLM'}
g_df['Method'] = g_df['Method'].replace(METHOD_DISPLAY_NAMES)

In [None]:
# Few-shot F1 per (method, k). Each box summarises the spread across datasets;
# each dot is one dataset, coloured by dataset.
k_values = sorted(g_df['k'].unique())
# Separate name, so the `methods` list used by the result-loading cell is not
# overwritten with display names (which would break re-running that cell).
plot_methods = list(g_df['Method'].unique())
n_k = len(k_values)

# Legend labels for the dataset directory names; unlisted names are shown as-is.
dataset_display_names = {
    'FewRel': 'FewRel',
    'NYT10': 'NYT10',
    'WebNLG': 'WebNLG',
    'crossRE': 'CrossRE',
    'retacred': 'RETACRED',
    'semeval_nodir': 'SemEval',
    'tacred': 'TACRED',
}

# String version of k for the combined "<method> k=<k>" category labels.
g_df['k_str'] = 'k=' + g_df['k'].astype(str)
x_categories = g_df['Method'] + ' ' + g_df['k_str']
# Explicit orders pin every (method, k) box to a fixed x position, so the tick
# labels placed below stay aligned even if some combination has no finite F1.
x_order = [f'{method} k={k}' for method in plot_methods for k in k_values]
hue_order = sorted(g_df['Dataset'].unique())

fig, ax = plt.subplots(figsize=(22, 6))

sns.boxplot(x=x_categories, y='f1', data=g_df, order=x_order,
            boxprops={'facecolor': 'none', 'edgecolor': 'black'},  # transparent box, black edges
            whiskerprops={'color': 'black'},
            capprops={'color': 'black'},
            medianprops={'color': 'black'}, ax=ax)

sns.stripplot(x=x_categories, y='f1', hue='Dataset', data=g_df, order=x_order,
              hue_order=hue_order, dodge=False, palette='bright',
              marker='o', alpha=0.7, size=8, ax=ax)

ax.set_ylabel('F1 Score', fontsize=14)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim(0, 1)

# Primary x-axis: one tick per (method, k) box, labelled with k only.
ax.set_xticks(np.arange(len(x_order)),
              labels=[f'{k}' for _ in plot_methods for k in k_values],
              rotation=0, ha='center', fontsize=12)

# Secondary x-axis: method names centred under each group of k-ticks.
sec = ax.secondary_xaxis(location=0)
group_centers = np.arange(len(plot_methods)) * n_k + (n_k - 1) / 2
sec.set_xticks(group_centers, labels=[f'\n\n{method}' for method in plot_methods], fontsize=14)

# Another secondary x-axis whose long, unlabelled ticks separate the method groups.
sec2 = ax.secondary_xaxis(location=0)
sec2.set_xticks(np.arange(len(plot_methods) + 1) * n_k - 0.5, labels=[])
sec2.tick_params('x', length=40, width=1.5)

# Relabel the legend by dataset name, not by position.
handles, labels = ax.get_legend_handles_labels()
legend_labels = [dataset_display_names.get(label, label) for label in labels]
ax.legend(handles, legend_labels, title='Datasets', loc='upper center', fontsize=14, title_fontsize=14,
          bbox_to_anchor=(0.5, 1.2), ncol=max(1, len(legend_labels)), frameon=False)

fig.tight_layout()

# Create the output directory if needed; savefig does not.
out_dir = Path('images')
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / 'few_all_subplots.png', format='png', dpi=600, bbox_inches='tight')

plt.show()
