In [1]:
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    import numpy as np

    from math_rag.application.containers import ApplicationContainer
    from math_rag.infrastructure.containers import InfrastructureContainer

    application_container: ApplicationContainer
    infrastructure_container: InfrastructureContainer

In [2]:
# NOTE(review): RESET is defined before the extension is loaded, presumably because
# hooks.notebook_hook reads it on load — TODO confirm what it resets.
# The hook presumably also injects `application_container` / `infrastructure_container`,
# which the first cell only declares for type checking — TODO confirm.
RESET = False
%load_ext hooks.notebook_hook

2025-06-21 12:04:38,553 - INFO - PyTorch version 2.6.0 available.


In [3]:
# Resolve the services and repositories used in the cells below from the containers.
# NOTE(review): both containers are only declared under TYPE_CHECKING in the first cell;
# they are presumably populated at runtime by the notebook hook — TODO confirm.
importer_service = application_container.math_expression_label_task_importer_service()
exporter_service = application_container.math_expression_label_exporter_service()

result_repository = infrastructure_container.math_expression_dataset_test_result_repository()
label_repository = infrastructure_container.math_expression_label_repository()

In [None]:
# TODO move router parameters out of LLMParams and EMParams
# TODO add split_name to MathExpressionDatasetTestResult
# TODO remove the redundancy between the notebook hook and main
# TODO filter out duplicates and empty KaTeX!
# TODO add a shortcut for moving in Label Studio?

In [5]:
from uuid import UUID


# Math expression dataset whose test split is imported for labeling and evaluated below.
dataset_id = UUID(hex='f7626dc1-6852-499f-b650-2c30137cffda')

In [6]:
project_id = await importer_service.import_tasks(None, dataset_id=dataset_id, split_name='test')

README.md:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/7.60k [00:00<?, ?B/s]

validate-00000-of-00001.parquet:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating validate split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/353 [00:00<?, ? examples/s]

prompt.json:   0%|          | 0.00/696 [00:00<?, ?B/s]

2025-06-21 12:04:52,698 - INFO - Project math_expression_label_task | f7626dc1 | test created
2025-06-21 12:04:52,768 - INFO - Imported 353 tasks into math_expression_label_task | f7626dc1 | test


In [17]:
# Export the labels from the project created by import_tasks above; these serve as
# y_true in the evaluation cells (presumably the human annotations — TODO confirm).
human_labels = await exporter_service.export(project_id)

In [None]:
results = await result_repository.find_many(filter=dict(math_expression_dataset_id=dataset_id))

llama_labels = results[0].math_expression_labels
gpt_4_1_labels = results[1].math_expression_labels

In [33]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from math_rag.core.enums import MathExpressionLabelEnum
from math_rag.core.models import MathExpressionLabel

In [34]:
def _sort_by_math_expression_id(labels: list[MathExpressionLabel]) -> list[MathExpressionLabel]:
    """Return a new list of ``labels`` in ascending ``math_expression_id`` order.

    The input is left untouched. The sort is stable, so labels that share an id
    keep their relative input order.
    """
    ordered = list(labels)
    ordered.sort(key=lambda item: item.math_expression_id)

    return ordered


def prepare(math_expression_labels: list[MathExpressionLabel]) -> list[str]:
    """Convert labels to their raw class values, ordered by math expression id.

    Ordering by ``math_expression_id`` is what lets two independently produced
    label lists be compared position by position. This function does NOT check
    that two lists cover the same expressions; callers must verify that.

    Args:
        math_expression_labels: Labels to convert. The input list is not modified.

    Returns:
        ``label.value.value`` for each label, in ascending ``math_expression_id`` order.
    """
    # Named `ordered` (not `sorted`) so the builtin is not shadowed.
    ordered = _sort_by_math_expression_id(math_expression_labels)

    # label.value is presumably a MathExpressionLabelEnum member; `.value` is its raw value.
    return [label.value.value for label in ordered]


def evaluate_multiclass_labels(
    y_true: list[str],
    y_pred: list[str],
    labels: list[str],
) -> tuple[float, str, 'np.ndarray']:
    """Score predicted class labels against reference labels.

    Args:
        y_true: Reference label values, one per item.
        y_pred: Predicted label values, positionally aligned with ``y_true``.
        labels: Class values fixing the row order of the report and the
            row/column order of the confusion matrix.

    Returns:
        ``(accuracy, report, confusion_matrix)``:
        - ``accuracy``: fraction of positions where ``y_pred`` equals ``y_true``.
        - ``report``: sklearn's plain-text per-class precision/recall/F1 report.
        - ``confusion_matrix``: ``len(labels) x len(labels)`` ndarray of counts as
          returned by sklearn. Entry ``[i, j]`` counts items whose true label is
          ``labels[i]`` and whose predicted label is ``labels[j]``.

    Raises:
        ValueError: raised by sklearn when ``y_true`` and ``y_pred`` differ in length.
    """
    accuracy = accuracy_score(y_true, y_pred)

    # zero_division=0 reports 0.0 silently for classes that are never predicted
    # or never present, instead of emitting warnings.
    report = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=labels,
        zero_division=0,
    )

    # Passing `labels` keeps the matrix axes in the same fixed class order as the report.
    matrix = confusion_matrix(
        y_true,
        y_pred,
        labels=labels,
    )

    return accuracy, report, matrix

In [35]:
# Evaluate GPT-4.1 predictions against the human reference labels.
# prepare() aligns the two lists only by sorting on math_expression_id, so first check
# that both lists label exactly the same expressions. Otherwise the positional
# comparison would silently pair unrelated expressions.
human_ids = [label.math_expression_id for label in _sort_by_math_expression_id(human_labels)]
gpt_4_1_ids = [
    label.math_expression_id for label in _sort_by_math_expression_id(gpt_4_1_labels)
]
assert human_ids == gpt_4_1_ids, 'Human and GPT-4.1 labels cover different math expressions'

y_true = prepare(human_labels)
y_pred = prepare(gpt_4_1_labels)
labels = [e.value for e in MathExpressionLabelEnum]

accuracy, report, cm = evaluate_multiclass_labels(y_true, y_pred, labels)

print(accuracy)
print(report)
print(cm)

0.8413597733711048
              precision    recall  f1-score   support

    equality       0.96      1.00      0.98        43
  inequality       0.71      1.00      0.83         5
    constant       0.24      1.00      0.39        15
    variable       1.00      0.93      0.97        74
       other       0.97      0.76      0.85       216

    accuracy                           0.84       353
   macro avg       0.78      0.94      0.80       353
weighted avg       0.94      0.84      0.87       353

[[ 43   0   0   0   0]
 [  0   5   0   0   0]
 [  0   0  15   0   0]
 [  0   0   0  69   5]
 [  2   2  47   0 165]]


In [36]:
# Evaluate Llama predictions against the human reference labels.
# prepare() aligns the two lists only by sorting on math_expression_id, so first check
# that both lists label exactly the same expressions. Otherwise the positional
# comparison would silently pair unrelated expressions.
human_ids = [label.math_expression_id for label in _sort_by_math_expression_id(human_labels)]
llama_ids = [label.math_expression_id for label in _sort_by_math_expression_id(llama_labels)]
assert human_ids == llama_ids, 'Human and Llama labels cover different math expressions'

y_true = prepare(human_labels)
y_pred = prepare(llama_labels)
labels = [e.value for e in MathExpressionLabelEnum]

accuracy, report, cm = evaluate_multiclass_labels(y_true, y_pred, labels)

print(accuracy)
print(report)
print(cm)

0.7365439093484419
              precision    recall  f1-score   support

    equality       0.69      0.26      0.37        43
  inequality       0.10      0.60      0.17         5
    constant       0.37      1.00      0.54        15
    variable       0.85      0.97      0.91        74
       other       0.88      0.74      0.80       216

    accuracy                           0.74       353
   macro avg       0.58      0.71      0.56       353
weighted avg       0.82      0.74      0.75       353

[[ 11  12   0   0  20]
 [  0   3   0   2   0]
 [  0   0  15   0   0]
 [  0   0   0  72   2]
 [  5  15  26  11 159]]
