In [74]:
import re
import json
import numpy as np
import pandas as pd
from prettytable import PrettyTable
from tqdm.auto import tqdm
from collections import Counter
from omegaconf import OmegaConf
from sklearn.metrics import f1_score, multilabel_confusion_matrix, matthews_corrcoef, confusion_matrix
from scipy.stats import spearmanr
from scipy.stats import pointbiserialr



from cgeval.method import ClassifyAndCount, StandardClassification, BCC
from cgeval.rating import Ratings, Label, Observation

In [75]:
MODEL = {
    'llama 2': '../../out/pipeline/2025-05-15_sentiment_analysis_llama2',
    'llama 3.3': '../../out/pipeline/2025-05-15_sentiment_analysis_llama3-3',
    'mistral': '../../out/pipeline/2025-05-15_sentiment_analysis_mistral',
}

REPORT_PATH = '../../out/pipeline/2025-04-30_15-26_sentiment_analysis_llama2'
CONFIG = '../../config.yaml'

In [76]:
def label_name_to_id(name: str, labels) -> int:
    return next((l['id'] for l in labels if l['name'] == name), None)

def label_match_to_id(match: str, matching_label: str) -> int:
    return int(match == matching_label) if match is not None else match

def extract_sentiment(input):
    m = re.search('The story should have a (.+?) sentiment', input)
    if m:
        found = m.group(1)
        return found

    return input

def load_ratings(cfg, classifier, report_path):
    with open(f"{report_path}/evaluate/dataset_{classifier.id}.json", 'r') as f:
        ratings_data = json.load(f)

    observations = list(map(lambda i: Observation(
        id=i['id'],
        output=i['output'],
        input=label_name_to_id(i['input'], classifier.labels),
        oracle=label_name_to_id(i['oracle'], classifier.labels),
        metric=label_name_to_id(i['metric'], classifier.labels)
    ), ratings_data))

    labels = list(map(lambda l: Label(**l), classifier.labels))

    return Ratings(labels=labels, observations=observations)

def load_binary_ratings(classifier, report_path):
    df = pd.read_json(f"{report_path}/evaluate/dataset_{classifier.id}.json", orient='records')

    df['condition'] = df['input'].apply(extract_sentiment)
    df['oracle'] = df.apply(lambda r: r['oracle'] == r['condition'] if(pd.notnull(r['oracle'])) else r['oracle'], axis=1)
    df['metric'] = df.apply(lambda r: r['metric'] == r['condition'], axis=1)

    observations = df.to_dict(orient='records')
    observations = list(map(lambda i: Observation(
        id=i['id'],
        output=i['output'],
        input=1,
        oracle=label_match_to_id(i['oracle'], True),
        metric=label_match_to_id(i['metric'], True)
    ), observations))

    labels = list(map(lambda l: Label(**l), [{'id': 0, 'name': 'match'},{'id': 1, 'name': 'no_match'}]))

    return Ratings(labels=labels, observations=observations)

In [77]:
def compute_mixture_matrix(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]


    return multilabel_confusion_matrix(oracle_ratings, metric_ratings, labels=rating.get_label_ids())

def compute_single_mixture_matrix(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]


    return confusion_matrix(oracle_ratings, metric_ratings, labels=rating.get_label_ids())

def compute_f1_score(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]


    return f1_score(oracle_ratings, metric_ratings, labels=rating.get_label_ids(), average='macro')


def compute_spearman(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]


    return spearmanr(oracle_ratings, metric_ratings)


def compute_pointbiserialr(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]


    return pointbiserialr(oracle_ratings, metric_ratings)

def compute_matthews_corrcoef(rating):
    items = [o for o in rating.observations if o.oracle is not None]

    oracle_ratings = [o.oracle for o in items]
    metric_ratings = [o.metric for o in items]

    return matthews_corrcoef(oracle_ratings, metric_ratings)


def pretty_print_latex(latex_str):
    lines = latex_str.replace(r" \\ ", r" \\" + "\n").splitlines()
    formatted_lines = []
    indent_level = 0
    for line in lines:
        if r"\begin" in line:
            formatted_lines.append(line)
            indent_level += 1
        elif r"\end" in line:
            indent_level -= 1
            formatted_lines.append(line)
        else:
            formatted_lines.append("    " * indent_level + line)
    return "\n".join(formatted_lines)


In [78]:
cfg = OmegaConf.load(CONFIG)

tpr = np.zeros((3,3))
fpr = np.zeros((3,3))
f1 = np.zeros(3)
p = np.zeros(3)

pbar = tqdm(total=len(MODEL)*len(cfg.classifier))

for id, base_path in MODEL.items():
    pbar.set_description(f'Processing {id}')
    for i, cls in enumerate(cfg.classifier):
        pbar.update(1)
        r = load_ratings(cfg, cls, base_path)
        cm = compute_mixture_matrix(r)

        f1[i] += compute_f1_score(r)
        p[i] += compute_spearman(r).statistic

        for idx in r.get_label_ids():
            (tn, fp, fn, tp) = cm[idx].ravel()
            label_name = cls.labels[idx]['name']
            tpr[idx][i] += tp / (tp + fn)
            fpr[idx][i] += fp / (fp + tn)

pbar.close()

tpr = tpr / len(MODEL)
fpr = fpr / len(MODEL)
f1 = f1 / len(MODEL)
p = p / len(MODEL)

Processing mistral: 100%|██████████| 9/9 [5:33:04<00:00, 2220.48s/it]
Processing mistral: 100%|██████████| 9/9 [00:07<00:00,  1.21it/s]


In [79]:
pbar = tqdm(total=len(MODEL)*len(cfg.classifier))

tpr_match = np.zeros(3)
fpr_match = np.zeros(3)
for id, base_path in MODEL.items():
    pbar.set_description(f'Processing {id}')
    for i, cls in enumerate(cfg.classifier):
        pbar.update(1)
        r = load_binary_ratings(cls, base_path)
        cm = compute_single_mixture_matrix(r)
        (tn, fp, fn, tp) = cm.ravel()

        tpr_match[i] += tp / (tp + fn)
        fpr_match[i] += fp / (fp + tn)

tpr_match = tpr_match / len(MODEL)
fpr_match = fpr_match / len(MODEL)
tpr_match = np.round(tpr_match, 3)
fpr_match = np.round(fpr_match, 3)

tpr_match, fpr_match

Processing mistral: 100%|██████████| 9/9 [00:01<00:00,  7.41it/s]  

(array([0.327, 0.696, 0.844]), array([0.913, 0.   , 0.052]))

In [80]:
tpr = np.round(tpr, 3)
fpr = np.round(fpr, 3)
f1 = np.round(f1, 3)
p = np.round(p, 3)

In [81]:
t = PrettyTable()
t.add_column('', ['Macro F1', 'Spearman', 'TPR_positive', 'FPR_positive', 'TPR_neutral', 'FPR_neutral', 'TPR_negative', 'FPR_negative'])
t.add_column('FIB', [f1[0], p[0], tpr[0][0], fpr[0][0], tpr[1][0], fpr[1][0], tpr[2][0], fpr[2][0]])
t.add_column('DSS', [f1[1], p[1], tpr[0][1], fpr[0][1], tpr[1][1], fpr[1][1], tpr[2][1], fpr[2][1]])
t.add_column('LL3', [f1[2], p[2], tpr[0][2], fpr[0][2], tpr[1][2], fpr[1][2], tpr[2][2], fpr[2][2]])

t.add_row(['TPR_match',tpr_match[0],tpr_match[1],tpr_match[2]])
t.add_row(['FPR_match',fpr_match[0],fpr_match[1],fpr_match[2]])

t

Unnamed: 0,FIB,DSS,LL3
Macro F1,0.275,0.521,0.671
Spearman,0.263,0.598,0.914
TPR_positive,0.216,0.695,1.0
FPR_positive,0.074,0.127,0.195
TPR_neutral,0.886,0.045,0.089
FPR_neutral,0.784,0.0,0.0
TPR_negative,0.129,0.943,0.99
FPR_negative,0.016,0.388,0.063
TPR_match,0.327,0.696,0.844
FPR_match,0.913,0.0,0.052


In [82]:
l = pretty_print_latex(t.get_latex_string())

print(l)

\begin{tabular}{cccc}
     & FIB & DSS & LL3 \\
    Macro F1 & 0.275 & 0.521 & 0.671 \\
    Spearman & 0.263 & 0.598 & 0.914 \\
    TPR_positive & 0.216 & 0.695 & 1.0 \\
    FPR_positive & 0.074 & 0.127 & 0.195 \\
    TPR_neutral & 0.886 & 0.045 & 0.089 \\
    FPR_neutral & 0.784 & 0.0 & 0.0 \\
    TPR_negative & 0.129 & 0.943 & 0.99 \\
    FPR_negative & 0.016 & 0.388 & 0.063 \\
    TPR_match & 0.327 & 0.696 & 0.844 \\
    FPR_match & 0.913 & 0.0 & 0.052 \\
\end{tabular}
