In [None]:
import pandas as pd
import matplotlib.pyplot as plt

from config.config import METRICS, THRESHOLDS
from utils.io import load_validated_data, export_eval_data, plot_metrics, plot_confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

In [None]:
# Dataset identifiers to evaluate, in processing order.
# "BOTH" is not loaded directly: the evaluation loop builds it by
# concatenating the "STREET" and "FULLNAME" datasets.
datasets = [
    "STREET",
    "FULLNAME",
    "BOTH",
]

In [None]:
def _prepare_labels(data, algo, dataset_name):
    """Return the cleaned ``(y_true, y_pred)`` label pair for one algorithm.

    Looks up the ``<ALGO>_TRUE_MATCH`` / ``<ALGO>_BEST_MATCH_BINARY`` columns of
    ``data``, drops rows where either value is missing, and checks that only the
    binary labels 0/1 (or False/True) remain.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame holding the per-algorithm truth and prediction columns.
    algo : str
        Algorithm name; upper-cased to build the column names.
    dataset_name : str
        Only used in the diagnostic messages.

    Returns
    -------
    tuple of (pandas.Series, pandas.Series) or None
        Integer-typed 0/1 series of equal length, or ``None`` (after printing
        the reason) when the algorithm cannot be evaluated on this dataset.
    """
    y_true_col = f"{algo.upper()}_TRUE_MATCH"
    y_pred_col = f"{algo.upper()}_BEST_MATCH_BINARY"

    if y_true_col not in data.columns or y_pred_col not in data.columns:
        print(f"Columns for algorithm '{algo}' not found in dataset '{dataset_name}'. Skipping...")
        return None

    y_true = data[y_true_col]
    y_pred = data[y_pred_col]

    # Keep only the rows where both the truth and the prediction are present.
    mask = y_true.notna() & y_pred.notna()
    y_true = y_true[mask]
    y_pred = y_pred[mask]

    if y_true.empty:
        print(f"No valid data for algorithm '{algo}' in dataset '{dataset_name}'. Skipping...")
        return None

    if y_true.dtype == "object":
        # NOTE(review): astype(bool) uses truthiness, so any non-empty string
        # (including "False") becomes True - confirm labels are never strings.
        y_true = y_true.astype(bool)
        y_pred = y_pred.astype(bool)

    valid_labels = (0, 1)
    if not set(y_true.unique()).issubset(valid_labels) or not set(y_pred.unique()).issubset(valid_labels):
        print(f"Invalid labels detected in algorithm '{algo}' for dataset '{dataset_name}'. Skipping...")
        return None

    # Both vectors are now known to hold only 0/1 (or False/True). Casting to
    # int gives scikit-learn a plain binary target regardless of whether the
    # column arrived as bool, float or object (previously an object-typed
    # y_pred next to a non-object y_true was passed through uncoerced).
    return y_true.astype(int), y_pred.astype(int)


for dataset_name in datasets:
    # "BOTH" is evaluated on the row-wise concatenation of the two datasets.
    if dataset_name == "BOTH":
        data_street = load_validated_data("STREET")
        data_fullname = load_validated_data("FULLNAME")
        DATA = pd.concat([data_street, data_fullname], ignore_index=True)
    else:
        DATA = load_validated_data(dataset_name)

    if DATA.empty:
        print(f"No data found for dataset '{dataset_name}'. Skipping...")
        continue

    # One column per algorithm, one row per metric.
    eval_df = pd.DataFrame(index=METRICS)
    # Cleaned label pairs, cached so the confusion matrices below reuse exactly
    # the rows the metrics were computed on instead of re-deriving them.
    labels_by_algo = {}
    for algo in THRESHOLDS:
        prepared = _prepare_labels(DATA, algo, dataset_name)
        if prepared is None:
            continue
        labels_by_algo[algo] = prepared
        y_true, y_pred = prepared

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        # ROC AUC is undefined when y_true contains a single class; depending
        # on the scikit-learn version roc_auc_score then raises ValueError,
        # which would abort the whole run. Record NaN for that case instead.
        if y_true.nunique() > 1:
            roc_auc = roc_auc_score(y_true, y_pred)
        else:
            roc_auc = float("nan")

        # Assumes METRICS lists exactly these five metrics in this order -
        # TODO confirm against config.config.
        eval_df[algo] = [accuracy, precision, recall, f1, roc_auc]

    # A frame without any algorithm column reports empty.
    if eval_df.empty:
        print(f"No evaluation data for dataset '{dataset_name}'. Skipping...")
        continue

    eval_df = export_eval_data(eval_df, dataset_name)
    metrics_fig = plot_metrics(eval_df, dataset_name)
    plt.close(metrics_fig)

    cm_list = []
    for algo, (y_true, y_pred) in labels_by_algo.items():
        # Pin the label order so the matrix is always 2x2; without `labels`
        # its shape depends on which classes occur (1x1 if only one is seen).
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        cm_fig = plot_confusion_matrix(cm, algo, dataset_name)
        cm_list.append(cm_fig)
        # Close each figure right away so figures do not pile up in the loop.
        plt.close(cm_fig)