In [1]:
from typing import TYPE_CHECKING


# Static-typing-only block: `TYPE_CHECKING` is False at runtime, so these
# imports and annotations are seen only by editors and type checkers.
# NOTE(review): both containers are used in later cells without being assigned
# in any visible cell; presumably they are injected by `hooks.notebook_hook`
# (loaded in the next cell) — confirm.
if TYPE_CHECKING:
    from math_rag.application.containers import ApplicationContainer
    from math_rag.infrastructure.containers import InfrastructureContainer

    # Bare annotations: declare the names' types for tooling, assign no value.
    application_container: ApplicationContainer
    infrastructure_container: InfrastructureContainer

In [None]:
# NOTE(review): `RESET` is not read by any visible cell; presumably the notebook
# hook loaded below consults it — confirm what it controls. It is assigned
# before `%load_ext` in case the hook reads it at load time.
RESET = False
%load_ext hooks.notebook_hook

2025-06-20 14:45:13,557 - INFO - PyTorch version 2.6.0 available.


In [None]:
# Obtain repositories from the infrastructure container (judging by the
# provider names: dataset-test results and math-expression labels).
# NOTE(review): `label_repository` is not used in any cell shown here.
result_repository = infrastructure_container.math_expression_dataset_test_result_repository()
label_repository = infrastructure_container.math_expression_label_repository()

In [None]:
from uuid import UUID


dataset_id = UUID('e4fd82d2-b4e6-4aee-8cf9-fb83c9940ffa')
results = await result_repository.find_many(filter=dict(math_expression_dataset_id=dataset_id))

In [8]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from math_rag.core.enums import MathExpressionLabelEnum
from math_rag.core.models import MathExpressionLabel

In [11]:
def _sort_labels(labels: list[MathExpressionLabel]) -> list[MathExpressionLabel]:
    """Return a new list of labels ordered by ``math_expression_id``.

    Used to put ground-truth and predicted labels into the same order so they
    can be compared position by position. The input list is not modified.

    Labels that share a ``math_expression_id`` keep their original relative
    order, because ``sorted`` is stable.

    NOTE(review): only ``math_expression_id`` is used as the key. If several
    labels can share an expression ID, add a secondary key so that ties are
    ordered the same way in both lists.
    """
    return sorted(labels, key=lambda label: label.math_expression_id)


def evaluate_multiclass_labels(
    y_true: list[MathExpressionLabel],
    y_pred: list[MathExpressionLabel],
) -> tuple[float, str, list[list[int]]]:
    """Score predicted math-expression labels against ground truth.

    Both lists are sorted by ``math_expression_id`` and then compared position
    by position. Input order therefore does not matter, but both lists must
    cover exactly the same expressions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels for the same expressions.

    Returns:
        ``(accuracy, report, cm)``:
        ``accuracy`` is sklearn's ``accuracy_score``;
        ``report`` is sklearn's text ``classification_report``;
        ``cm`` is the confusion matrix with rows (true) and columns (predicted)
        ordered as ``MathExpressionLabelEnum``. sklearn returns ``cm`` as a
        NumPy array, not the nested list the annotation suggests.

    Raises:
        ValueError: If the lists differ in length, or if after sorting they do
            not refer to the same expressions in the same positions.
    """
    # align predictions with ground truth by expression ID
    sorted_true = _sort_labels(y_true)
    sorted_pred = _sort_labels(y_pred)

    # Positional comparison is only meaningful when both sides describe the
    # same expressions; otherwise the metrics silently pair unrelated items.
    if len(sorted_true) != len(sorted_pred):
        raise ValueError(
            f'Label count mismatch: {len(sorted_true)} ground-truth vs '
            f'{len(sorted_pred)} predicted labels'
        )
    mismatched_ids = [
        (true_label.math_expression_id, pred_label.math_expression_id)
        for true_label, pred_label in zip(sorted_true, sorted_pred)
        if true_label.math_expression_id != pred_label.math_expression_id
    ]
    if mismatched_ids:
        raise ValueError(
            f'{len(mismatched_ids)} ground-truth/prediction pairs refer to '
            f'different expressions after alignment, e.g. {mismatched_ids[:5]}'
        )

    # extract raw enum values
    y_true_vals = [label.value.value for label in sorted_true]
    y_pred_vals = [label.value.value for label in sorted_pred]

    # Fix the class order from the enum so the report rows and the confusion
    # matrix axes are stable, even for classes absent from this sample.
    class_values = [member.value for member in MathExpressionLabelEnum]

    # compute metrics
    acc = accuracy_score(y_true_vals, y_pred_vals)
    report = classification_report(
        y_true_vals,
        y_pred_vals,
        labels=class_values,
        target_names=class_values,
        zero_division=0,
    )
    cm = confusion_matrix(
        y_true_vals,
        y_pred_vals,
        labels=class_values,
    )

    return acc, report, cm

In [None]:
# NOTE(review): assumes `find_many` returns the ground-truth run first and the
# model-prediction run second. Nothing here enforces that order — confirm it,
# or select each result by an explicit attribute instead.
gt_labels = results[0].math_expression_labels
pred_labels = results[1].math_expression_labels

# Unpack into `conf_matrix`, not `confusion_matrix`. Rebinding that name would
# shadow the sklearn function imported above, and any later call to
# `evaluate_multiclass_labels` (e.g. re-running this cell) would then fail.
accuracy, report, conf_matrix = evaluate_multiclass_labels(gt_labels, pred_labels)

print(accuracy)
print(report)
print(conf_matrix)

0.9603399433427762
              precision    recall  f1-score   support

    equality       1.00      1.00      1.00        48
  inequality       1.00      1.00      1.00        10
    constant       0.94      0.95      0.94        62
    variable       1.00      0.91      0.95        77
       other       0.94      0.97      0.96       156

    accuracy                           0.96       353
   macro avg       0.97      0.97      0.97       353
weighted avg       0.96      0.96      0.96       353

[[ 48   0   0   0   0]
 [  0  10   0   0   0]
 [  0   0  59   0   3]
 [  0   0   0  70   7]
 [  0   0   4   0 152]]
