In [1]:
import numpy as np
import pandas as pd

Для решения задачи обучение модели не требуется, поэтому сразу используем валидационный датасет.

In [2]:
# Directory with the preprocessed "simple" datasets, relative to this notebook.
# Kept in one place so both files are always read from the same location.
DATA_DIR = "../../data/preprocessed/simple"

# Validation sample: one row per name to resolve, with its expected school id.
valid = pd.read_csv(f"{DATA_DIR}/valid.csv")
# Reference list: one row per known id with its reference name.
reference = pd.read_csv(f"{DATA_DIR}/reference.csv")

In [3]:
# Show the columns, dtypes and non-null counts of the validation sample.
valid.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 184 entries, 0 to 183
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   name       184 non-null    object
 1   school_id  184 non-null    int64 
dtypes: int64(1), object(1)
memory usage: 3.0+ KB


In [4]:
# Show the columns, dtypes and non-null counts of the reference list.
reference.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 305 entries, 0 to 304
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   id         305 non-null    int64 
 1   reference  305 non-null    object
dtypes: int64(1), object(1)
memory usage: 4.9+ KB


### Преобразования в np.array

In [24]:
# Validation inputs (names) and their expected ids as 1-D NumPy arrays.
x = valid["name"].to_numpy(dtype="str").flatten()
y = valid["school_id"].to_numpy(dtype="int").flatten()
# Reference ids and names taken from the same rows of `reference`, so position i
# of one array corresponds to position i of the other.
reference_id = reference["id"].to_numpy(dtype="int").flatten()
reference_name = reference["reference"].to_numpy(dtype="str").flatten()

### Функции расчета метрик и simple-предсказаний

In [25]:
def find_matches(X, reference_id, reference_name, top_k=5):
    """Look up reference ids whose name contains each query as a substring.

    Parameters
    ----------
    X : iterable of str
        Query strings to look up.
    reference_id : np.ndarray
        Ids of the reference entries; indexed with a boolean mask computed
        from ``reference_name``, so both arrays must have the same length.
    reference_name : np.ndarray of str
        Reference names searched with ``np.char.find``.
    top_k : int, default 5
        Number of ``(id, score)`` slots to produce per query.

    Returns
    -------
    list of list of tuple
        One list per query. Every hit is ``(id, 1.0)``, kept in reference
        order and cut to ``top_k``; missing slots are ``(None, 0.0)``.
    """
    empty_slot = (None, 0.0)
    y_pred = []
    for query in X:
        # ``np.char.find`` returns -1 where the query is absent, so ``>= 0``
        # marks the references that contain it somewhere.
        is_hit = np.char.find(reference_name, query) >= 0
        hit_ids = reference_id[is_hit][:top_k].astype(int).tolist()
        ranked = [(ref_id, 1.0) for ref_id in hit_ids]
        # Pad up to ``top_k`` slots; a non-positive count multiplies to an
        # empty list, so full lists are left as they are.
        ranked.extend([empty_slot] * (top_k - len(ranked)))
        y_pred.append(ranked)
    return y_pred

In [26]:
def accuracy_top_k(true_values, predictions, k):
    """Share of samples whose true id appears among the first ``k`` predicted ids.

    Every sample counts toward the denominator, including samples whose
    prediction slots hold only ``None`` ids.

    NOTE(review): a later cell of this notebook defines ``accuracy_top_k``
    again with a different denominator; when the cells run top to bottom that
    later definition replaces this one. Keep a single definition, or give the
    two variants distinct names.

    Parameters
    ----------
    true_values : sequence
        Expected id for each sample.
    predictions : sequence of sequence of tuple
        Per-sample candidate lists; element ``[0]`` of each candidate is its id.
    k : int
        Number of leading candidates to inspect per sample.

    Returns
    -------
    float or int
        ``correct / len(true_values)``, or ``0`` when ``true_values`` is empty.
    """
    total = len(true_values)
    if total == 0:
        # No samples to score: report 0 instead of raising ZeroDivisionError.
        return 0
    correct = 0
    for true_id, pred_list in zip(true_values, predictions):
        top_k_ids = [pred[0] for pred in pred_list[:k]]
        if true_id in top_k_ids:
            correct += 1
    return correct / total

In [27]:
def accuracy_top_k(y_true, y_pred, k):
    """Top-k accuracy over the samples that received at least one candidate.

    A sample is scored only if one of its first ``k`` candidates has a
    non-``None`` id; samples without such a candidate are left out of both
    the numerator and the denominator.

    Parameters
    ----------
    y_true : sequence
        Expected id for each sample.
    y_pred : sequence of sequence of tuple
        Per-sample candidate lists; element ``[0]`` of each candidate is its id.
    k : int
        Number of leading candidates to inspect per sample.

    Returns
    -------
    float or int
        Hits divided by the number of scored samples, or ``0`` when no sample
        was scored.
    """
    hits = 0
    scored = 0  # samples that have at least one non-None candidate
    for expected, candidates in zip(y_true, y_pred):
        shortlist = [slot[0] for slot in candidates[:k] if slot[0] is not None]
        if not shortlist:
            # Nothing was predicted for this sample: skip it entirely.
            continue
        scored += 1
        if expected in shortlist:
            hits += 1
    if scored == 0:
        return 0
    return hits / scored

In [28]:
def check_top_1(true_values, predictions):
    """Flag the samples whose first predicted id is not the true id.

    Parameters
    ----------
    true_values : sequence
        Expected id for each sample.
    predictions : sequence of sequence of tuple
        Per-sample candidate lists; only ``[0][0]`` (the id of the first
        candidate) is read, so every list needs at least one candidate.

    Returns
    -------
    list of bool
        ``True`` where the first candidate's id differs from the true id
        (a ``None`` id never equals a real id, so it is flagged too),
        ``False`` where they are equal.
    """
    return [
        # `not (a == b)` rather than `a != b` keeps the result a plain bool.
        not (ranked[0][0] == expected)
        for ranked, expected in zip(predictions, true_values)
    ]

In [29]:
def calculate_metrics(true_values, predictions):
    """Build the evaluation report for a set of ranked predictions.

    Parameters
    ----------
    true_values : sequence
        Expected id for each sample.
    predictions : sequence of sequence of tuple
        Per-sample candidate lists as accepted by ``accuracy_top_k`` and
        ``check_top_1``.

    Returns
    -------
    dict
        ``Accuracy@1``/``@3``/``@5`` as returned by ``accuracy_top_k``,
        ``error_rate`` and ``manual_processing_rate`` in percent, and
        ``general_error``; every value is rounded to 3 decimals.
    """
    # Accuracy at the three cut-offs reported below.
    acc_at_1, acc_at_3, acc_at_5 = (
        accuracy_top_k(true_values, predictions, k) for k in (1, 3, 5)
    )

    # Business metric: percentage of errors.
    error_rate = (1 - acc_at_1) * 100

    # Business metric: percentage of data processed manually, i.e. the share
    # of samples flagged by check_top_1.
    flagged = check_top_1(true_values, predictions)
    manual_processing_rate = (sum(flagged) / len(predictions)) * 100

    # Error rate applied to the samples that are not processed manually,
    # expressed relative to all samples.
    n_samples = len(true_values)
    not_manual = n_samples - n_samples * manual_processing_rate / 100
    general_error = not_manual * error_rate / n_samples

    # Spend-based metrics (manual_spend, correction_spend, general_spend) are
    # not computed here.
    report = {
        "Accuracy@1": acc_at_1,
        "Accuracy@3": acc_at_3,
        "Accuracy@5": acc_at_5,
        "error_rate": error_rate,
        "manual_processing_rate": manual_processing_rate,
        "general_error": general_error,
    }
    return {key: round(value, 3) for key, value in report.items()}

In [30]:
# Substring-match every validation name against the reference names,
# keeping 5 candidate slots per name.
y_pred = find_matches(x, reference_id, reference_name, top_k=5)

# Score the candidates against the expected ids; the bare name on the last
# line makes the notebook display the resulting dict.
metrics = calculate_metrics(y, y_pred)
metrics

{'Accuracy@1': 0.857,
 'Accuracy@3': 0.952,
 'Accuracy@5': 0.952,
 'error_rate': 14.286,
 'manual_processing_rate': 90.217,
 'general_error': 1.398}

: 