In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os 
# Load the metrics table; the path is relative to the current working directory.
df = pd.read_csv(filepath_or_buffer="df_metrics.csv")

In [2]:
# Harmonise the hyphenated label with the 'Multimodal BiLSTM' spelling used by
# plot_metric (plain substring replacement, not a regex).
df['model'] = df['model'].str.replace(
    pat='Multimodal Bi-LSTM',
    repl='Multimodal BiLSTM',
    regex=False,
)

In [3]:
import matplotlib as mpl
# Make Arial the default font family for all figures created afterwards.
mpl.rc('font', family='Arial')

In [4]:
def plot_metric(df, metric_name, save_path='./figs/', max_k=10):
    """Plot ``metric_name`` against the number of predictions k for each model.

    One line is drawn per model, in the fixed order of ``models_order`` below.
    Each line is labelled directly at its right end instead of with a legend.
    The figure is written twice: ``<save_path>/<metric_name>_por_k.pdf`` and
    ``<save_path>/<metric_name>_por_k.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'model'`` column and a ``metric_name`` column. Each
        model's rows are used in their existing order as k = 1, 2, ...
        (assumes the frame is already sorted by k -- TODO confirm).
    metric_name : str
        Column to plot. Also used for the y-axis label and the file names.
    save_path : str, default './figs/'
        Output directory. It is created if missing, and a trailing separator
        is optional.
    max_k : int, default 10
        Maximum number of points (k values) plotted per model.

    Notes
    -----
    Models not listed in ``models_order`` are ignored. A model with fewer than
    ``max_k`` rows is drawn with the points it has instead of raising. Y ticks
    are only customised when at least one finite value was plotted.
    """
    os.makedirs(save_path, exist_ok=True)

    # Fixed drawing order keeps colours and overlaps identical across metrics.
    models_order = [
        'Multimodal Transformer',
        'Multimodal BiLSTM',
        'BiLSTM',
        'Random Forest',
        'Transformer',
        'Neural Network',
        'Random Classifier'
    ]

    color_dict = {
        'Random Classifier': plt.cm.tab10(0),
        'Neural Network': plt.cm.tab10(1),
        'Random Forest': plt.cm.tab10(3),
        'BiLSTM': plt.cm.tab10(2),
        'Transformer': plt.cm.tab10(4),
        'Multimodal BiLSTM': plt.cm.tab10(9),
        'Multimodal Transformer': plt.cm.tab10(6),
    }

    # Vertical nudges (data units) for the end-of-line labels, so that models
    # with close final values do not print on top of each other.
    offsets = {
        'Multimodal Transformer': -0.003,
        'Multimodal BiLSTM': 0.003,
        'BiLSTM': 0,
        'Random Forest': 0.002,
        'Transformer': -0.002,
        'Neural Network': 0,
        'Random Classifier': 0
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        all_values = []
        for model in models_order:
            y = df.loc[df['model'] == model, metric_name].to_numpy()[:max_k]
            if y.size == 0:
                # Model absent from df (or max_k <= 0): nothing to draw.
                continue
            # x values match the number of available points, so short series
            # no longer trigger a length-mismatch error.
            k = np.arange(1, len(y) + 1)
            all_values.extend(y)
            color = color_dict.get(model, 'black')
            ax.plot(k, y, label=model, marker='o', color=color)
            ax.text(
                k[-1] + 0.15,
                y[-1] + offsets.get(model, 0),
                model.replace("Multimodal ", "M."),
                color=color,
                fontsize=9,
                va='center',
                fontname='Arial',
            )

        ax.set_xlabel("Number of predictions", fontsize=12, fontname='Arial')
        ax.set_ylabel(metric_name.capitalize(), fontsize=12, fontname='Arial')
        ax.grid(True)

        # Y ticks every 0.02, spanning the data range rounded outward to 0.01.
        # NaNs are dropped, and an empty plot keeps matplotlib's default ticks.
        values = np.asarray(all_values, dtype=float)
        values = values[np.isfinite(values)]
        if values.size:
            step = 0.02
            lo = np.floor(values.min() * 100) / 100
            hi = np.ceil(values.max() * 100) / 100
            ax.set_yticks(np.arange(lo, hi + step, step))

        ax.set_xlim(0.9, max_k + 0.1)
        ax.set_xticks(np.arange(1, max_k + 1))
        fig.tight_layout()

        # os.path.join works whether or not save_path ends with a separator.
        stem = os.path.join(save_path, f"{metric_name}_por_k")
        fig.savefig(f"{stem}.pdf", format="pdf")
        fig.savefig(f"{stem}.png")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

# Render one figure per metric. The 'map' metric is currently left out.
for metric in ('precision', 'recall', 'f1'):
    plot_metric(df, metric)