In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score, ConfusionMatrixDisplay, f1_score

In [None]:
def create_error_data(result_type, model, data_dir='../clean_data', results_dir='../clean_results'):
    """Attach a model's predictions to a test split and save it for error analysis.

    Reads the gold split ``{data_dir}/test_{result_type}.csv`` and the
    tab-separated prediction file ``{results_dir}/{model}/{result_type}.txt``.
    It then copies the prediction file's ``prediction`` column onto the gold
    rows (matched by row order) and writes the combined frame to
    ``{data_dir}/test_with_predictions_{model}_{result_type}.csv`` without
    the index.

    Parameters
    ----------
    result_type : str
        Split / strategy identifier used in the file names (e.g. ``'i'``,
        ``'l_context'``).
    model : str
        Model name; selects the prediction sub-directory and is embedded in
        the output file name.
    data_dir : str, optional
        Directory containing the gold CSVs; the output CSV is written here too.
    results_dir : str, optional
        Directory containing one sub-directory of prediction files per model.

    Raises
    ------
    ValueError
        If the prediction file does not have exactly one row per test example.
    """
    real = pd.read_csv(f'{data_dir}/test_{result_type}.csv')
    pred = pd.read_csv(f'{results_dir}/{model}/{result_type}.txt', sep='\t')

    # Predictions are matched to examples purely by row order. Plain column
    # assignment aligns on the index, so a row-count mismatch would be silently
    # truncated or padded with NaN instead of failing loudly.
    if len(pred) != len(real):
        raise ValueError(
            f'{model}/{result_type}: {len(pred)} predictions for {len(real)} test examples'
        )

    with_predictions = real.assign(prediction=pred['prediction'].to_numpy())
    with_predictions.to_csv(
        f'{data_dir}/test_with_predictions_{model}_{result_type}.csv', index=False
    )

In [None]:
# Save each roberta-large test split with its predictions attached, one file per strategy.
for result_type in ('i', 'i_context', 'l', 'l_context'):
    create_error_data(result_type, 'roberta-large')

In [None]:
def create_plot(result_type, model, data_dir='../clean_data', results_dir='../clean_results'):
    """Show the confusion matrix of one model's predictions on one test split.

    Reads the gold labels from ``{data_dir}/test_{result_type}.csv`` (column
    ``label``) and the predictions from the tab-separated
    ``{results_dir}/{model}/{result_type}.txt`` (column ``prediction``). It
    computes the macro F1 and draws a confusion matrix over the four relation
    classes in a fixed order. The title names the model, the strategy and the
    score, and the figure is displayed with ``plt.show()``.

    Parameters
    ----------
    result_type : str
        Split / strategy identifier (``'i'``, ``'i_context'``, ``'l'``,
        ``'l_context'``). Any other value is shown verbatim in the title.
    model : str
        Model name; selects the prediction sub-directory and appears in the title.
    data_dir : str, optional
        Directory containing the gold CSVs.
    results_dir : str, optional
        Directory containing one sub-directory of prediction files per model.
    """
    real = pd.read_csv(f'{data_dir}/test_{result_type}.csv')
    pred = pd.read_csv(f'{results_dir}/{model}/{result_type}.txt', sep='\t')
    y_true = real['label']
    y_pred = pred['prediction']

    # With sklearn's default labels=None, macro F1 averages over every label
    # observed in y_true or y_pred, not only the four classes plotted below.
    macro_f1 = f1_score(y_true, y_pred, average='macro')

    # Human-readable strategy names for the title; unknown ids fall back to themselves.
    strategy_names = {
        'l': 'Loc',
        'l_context': 'Loc+Context',
        'i': 'Prop',
        'i_context': 'Prop+Context',
    }
    strategy_name = strategy_names.get(result_type, result_type)

    # Fixed class order so matrices are comparable across models and strategies.
    labels = ['inference', 'conflict', 'rephrase', 'no_rel']
    display_labels = ['inference', 'conflict', 'rephrase', 'no rel.']

    _, ax = plt.subplots()
    ax.set_title(f'{model}, {strategy_name}, macro f1: {macro_f1:.2f}')
    ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=labels,
        display_labels=display_labels,
        colorbar=False,
        ax=ax,
        xticks_rotation=45,
    )
    plt.show()

In [None]:
# One confusion matrix per (model, strategy); models form the outer loop so the
# figures come out grouped by model, in the same order as before.
for model in ('bert-large-cased', 'roberta-large'):
    for result_type in ('i', 'i_context', 'l', 'l_context'):
        create_plot(result_type, model)