In [3]:
from typing import TYPE_CHECKING


# Static-typing aid only: TYPE_CHECKING is False at runtime, so nothing in this
# block executes in the kernel; it exists for editors and type checkers.
if TYPE_CHECKING:
    from math_rag.application.containers import ApplicationContainer
    from math_rag.infrastructure.containers import InfrastructureContainer

    # Bare annotations declare the types of the container globals without
    # assigning them.
    # NOTE(review): no visible cell creates these objects; presumably the
    # `hooks.notebook_hook` extension loaded in a later cell injects them — confirm.
    application_container: ApplicationContainer
    infrastructure_container: InfrastructureContainer

In [5]:
# NOTE(review): RESET is not read by any cell visible in this notebook;
# presumably the `hooks.notebook_hook` extension consults it on load — confirm.
RESET = False
# IPython line magic (not plain Python): loads the `hooks.notebook_hook` extension.
%load_ext hooks.notebook_hook

2025-06-20 21:14:31,747 - INFO - PyTorch version 2.6.0 available.


In [3]:
# Obtain repository objects by calling the infrastructure container's providers.
# NOTE(review): `infrastructure_container` is only declared under TYPE_CHECKING
# in this notebook, so it must already exist in the kernel namespace when this
# cell runs — confirm where it is created.
result_repository = infrastructure_container.math_expression_dataset_test_result_repository()
# NOTE(review): `label_repository` is not used by any cell visible here.
label_repository = infrastructure_container.math_expression_label_repository()

In [None]:
from uuid import UUID


dataset_id = UUID('f7626dc1-6852-499f-b650-2c30137cffda')
results = await result_repository.find_many(filter=dict(math_expression_dataset_id=dataset_id))

In [8]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from math_rag.core.enums import MathExpressionLabelEnum
from math_rag.core.models import MathExpressionLabel

In [None]:
def _sort_by_math_expression_id(labels: list[MathExpressionLabel]) -> list[MathExpressionLabel]:
    return sorted(
        labels,
        key=lambda label: label.math_expression_id,
    )


def evaluate_multiclass_labels(
    y_true: list[str],
    y_pred: list[str],
    labels: list[str],
) -> tuple[float, str, list[list[int]]]:
    """Score predicted class labels against reference labels.

    Args:
        y_true: Reference label for each sample.
        y_pred: Predicted label for each sample, in the same order as ``y_true``.
        labels: Class names to report on. The same list is used both to select
            the classes and as their display names, and it fixes the row and
            column order of the confusion matrix.

    Returns:
        A ``(accuracy, report, confusion)`` tuple: the overall accuracy, the
        text report produced by ``classification_report``, and the confusion
        matrix (rows are true classes, columns are predicted classes).

        NOTE(review): the annotation says ``list[list[int]]``, but the matrix
        is returned exactly as ``confusion_matrix`` produces it, which per
        scikit-learn's documentation is a NumPy array — confirm and adjust
        the annotation if desired.
    """
    overall_accuracy = accuracy_score(y_true, y_pred)

    # zero_division=0 is forwarded so undefined per-class metrics are reported as 0.
    report_options = dict(labels=labels, target_names=labels, zero_division=0)
    per_class_report = classification_report(y_true, y_pred, **report_options)

    confusion = confusion_matrix(y_true, y_pred, labels=labels)

    return overall_accuracy, per_class_report, confusion

In [None]:
# Compare two stored label sets for the dataset fetched in the earlier cell.
# NOTE(review): this assumes results[0] holds the reference ("true") labels and
# results[1] the predictions, which depends on the order `find_many` returns
# its results in — confirm.
true_labels = results[0].math_expression_labels
pred_labels = results[1].math_expression_labels

# Order both sides by math expression id so that position i in each list
# refers to the same expression.
sorted_true_labels = _sort_by_math_expression_id(true_labels)
sorted_pred_labels = _sort_by_math_expression_id(pred_labels)

# Fail loudly if the two results do not cover the same expressions; otherwise
# the positional pairing below would silently compare unrelated labels.
true_ids = [label.math_expression_id for label in sorted_true_labels]
pred_ids = [label.math_expression_id for label in sorted_pred_labels]
if true_ids != pred_ids:
    raise ValueError('true and predicted labels do not refer to the same math expressions')

y_true = [label.value.value for label in sorted_true_labels]
y_pred = [label.value.value for label in sorted_pred_labels]
labels = [e.value for e in MathExpressionLabelEnum]

# Bind the matrix to `confusion`, not `confusion_matrix`: reusing the imported
# sklearn function's name would overwrite it in the notebook namespace, so the
# next call to evaluate_multiclass_labels (e.g. re-running this cell) would
# fail with "'numpy.ndarray' object is not callable".
accuracy, report, confusion = evaluate_multiclass_labels(y_true, y_pred, labels)

print(accuracy)
print(report)
print(confusion)

0.9603399433427762
              precision    recall  f1-score   support

    equality       1.00      1.00      1.00        48
  inequality       1.00      1.00      1.00        10
    constant       0.94      0.95      0.94        62
    variable       1.00      0.91      0.95        77
       other       0.94      0.97      0.96       156

    accuracy                           0.96       353
   macro avg       0.97      0.97      0.97       353
weighted avg       0.96      0.96      0.96       353

[[ 48   0   0   0   0]
 [  0  10   0   0   0]
 [  0   0  59   0   3]
 [  0   0   0  70   7]
 [  0   0   4   0 152]]
