# TODO: put me in library code plx

In [1]:
import logging
import typing as ty
import pandas as pd
import numpy as np
import mlflow
import xarray as xr
from sklearn.metrics import (
    confusion_matrix,
    roc_curve,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score,
    f1_score
)
from pathlib import Path
from sklearn.metrics import classification_report


logger = logging.getLogger(__name__)


def clf_report(y_true: pd.Series | np.ndarray, y_pred: pd.Series | np.ndarray) -> pd.DataFrame:
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    assert isinstance(report, dict)
    return pd.DataFrame(report).transpose()


def save_classification_report(
    y_true: pd.Series,
    y_pred: pd.Series,
    output_dir: Path,
    log_mlflow: bool = False
) -> Path:
    """
    Generates a classification report and saves it to a CSV file.

    :param y_true: Series containing the true labels.
    :param y_pred: Series containing the predicted labels.
    :param encoder: HierarchicalEncoder object used for label encoding.
    :param classification_level: The level of classification used by the encoder.
    :param output_dir: Directory to save the classification report.
    :return: Path to the saved classification report CSV file.
    """
    report_df = clf_report(y_true, y_pred)
    report_csv_path = output_dir / 'classification_report.csv'
    report_df.to_csv(report_csv_path, index=True)
    if log_mlflow:
        # Logging all metrics in classification_report
        mlflow.log_metric("accuracy", report.pop("accuracy"))
        for class_or_avg, metrics_dict in report.items():
            for metric, value in metrics_dict.items():
                mlflow.log_metric(class_or_avg + '_' + metric,value)
    return report_csv_path


# Function to update indices using Pandas Index intersection
def update_indices_pandas(indices, sampled_indices):
    return pd.Index(indices).intersection(sampled_indices)


def precision_recall_dataset(
    Y_true: pd.Series | np.ndarray,
    Y_pred: pd.Series | np.ndarray,
    pos_label = None
) -> xr.Dataset:
    """
    Builds a PR curve dataset from predictions and ground truth
    """
    precision, recall, thresh_pr = precision_recall_curve(Y_true, Y_pred, pos_label=pos_label)
    average_precision = average_precision_score(Y_true, Y_pred, pos_label=pos_label)
    ds_pr = pd.DataFrame(
        data=np.stack([precision, recall], axis=-1),
        columns=["Precision", "Recall"],
    ).to_xarray().rename(index="pr_curve_index")
    ds_pr["thresh_pr"] = (
        "pr_curve_index",
        np.pad(
            thresh_pr,
            pad_width=(1, 0),
            mode="constant",
            constant_values=0.
        )
    )
    ds_pr["F1_score"] = ("pr_curve_index", 2*precision*recall / (precision+recall))
    ds_pr["AP"] = average_precision
    return ds_pr


def roc_curve_dataset(
    Y_true: pd.Series | np.ndarray,
    Y_pred: pd.Series | np.ndarray,
    pos_label=None
) -> xr.Dataset:
    """
    Builds a ROC curve dataset from predictions and ground truth
    """
    fpr, tpr, thresh_roc = roc_curve(Y_true, Y_pred, pos_label=pos_label)
    auc_score = roc_auc_score(Y_true, Y_pred)
    ds_roc = pd.DataFrame(
        data=np.stack([fpr, tpr], axis=-1),
        columns=["FPR", "TPR"]
    ).to_xarray().rename(index="roc_curve_index")
    ds_roc["thresh_roc"] = (
        "roc_curve_index",
        thresh_roc
    )
    ds_roc["balanced_accuracy"] = ("roc_curve_index", (tpr + (1-fpr))/2)
    ds_roc["AUC"] = auc_score
    return ds_roc


def confusion_matrix_dataframe(
    y_pred: np.ndarray | pd.Series,
    y_true: np.ndarray | pd.Series,
    classes: list[str]
) -> pd.DataFrame:
    """
    Build a confusion matrix and embed it into a DataFrame
    
    Calculates precision/recall/specificity/NPV and pred/real totals
    """
    # make cm, swapping pred/true to get our preferred convention (opposite to sklearn)
    cm = confusion_matrix(y_true=y_pred, y_pred=y_true)
    # stick the CM in a dataframe
    df_cm = pd.DataFrame(index=classes, columns=classes, data=cm)
    # calculate total predicted/true in each class
    df_cm.loc["total", :] = df_cm.sum()
    df_cm['total'] = df_cm.sum(axis=1)
    # calculate recalls and precisions for each class
    for i in range(len(classes)):
        rec = np.round(df_cm.iloc[i, i]/df_cm.loc["total", df_cm.columns[i]],2)
        prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)
        df_cm.loc['recall', df_cm.columns[i]] = rec
        df_cm.loc[df_cm.index[i], 'precision'] = prec
    return df_cm


In [2]:
import math
import logging
from typing import Optional
from pathlib import Path

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import umap
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    average_precision_score,
    precision_recall_curve
)


def plot_confusion_matrix(
    y_true: pd.Series,
    y_pred: pd.Series,
    output_dir: Path = Path("."),
    file_name: str | None = "confusion_matrix.png"
) -> Path | go.Figure:
    """
    Generates and saves the confusion matrix plot to a file.

    :param y_true: Series containing the true labels.
    :param y_pred: Series containing the predicted labels.
    :param output_dir: Directory to save the confusion matrix plot.
    :return: Path to the saved confusion matrix plot file.
    """
    logger.info("Plot confusion matrix")
    fig, ax = plt.subplots(figsize=(15, 15))
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, ax=ax)
    plt.title('Confusion Matrix')

    if file_name is None:
        return fig
    confusion_matrix_path = output_dir / file_name 
    plt.savefig(confusion_matrix_path)
    plt.close()
    
    return confusion_matrix_path


def plot_predicted_vs_real_counts(
    y_true: pd.Series,
    y_pred: pd.Series,
    output_dir: Path  = Path("."),
    file_name: str | None = "predicted_vs_real_counts.png"
) -> Path | go.Figure:
    """
    Plots and saves the predicted vs real counts.

    :param y_true: Series containing the true labels.
    :param y_pred: Series containing the predicted labels.
    :param encoder: HierarchicalEncoder object used for label encoding.
    :param classification_level: The level of classification used by the encoder.
    :param output_dir: Directory to save the plot.
    :return: Path to the saved plot file.
    """
    logger.info("Plot counts")
    value_counts = (
        pd.concat(
            [
                y_true.value_counts().rename("true"),
                y_pred.value_counts().rename("predicted")
            ],
            axis=1
        )
        .stack()
        .reset_index()
        .rename(
            {
                "level_0": "class",
                "level_1": "set",
                0: "count"
            },
            axis=1
        )
    )
    
    fig = px.bar(
        value_counts,
        x="class",
        y="count",
        color="set",
        barmode="group",
        title="Predicted vs real counts"
    )
    if file_name is None:
        return fig
    counts_plot_path = output_dir / file_name 
    fig.write_image(str(counts_plot_path), engine="kaleido")
   
    return counts_plot_path


def plot_multiclass_roc_auc_save(
    y_true, y_proba, output_dir: Path = Path("."), file_name: str | None = "roc_curve.png"
) -> Path | go.Figure:
    """
    Generates, saves, and returns the ROC curve plot with AUC for multi-class.

    :param y_true: Series containing the true labels.
    :param y_proba: DataFrame containing the predicted probabilities.
    :param output_dir: Directory to save the ROC curve plot.
    :param file_name: File name for saving the plot.
    :return: Path to the saved ROC curve plot file.
    """
    logger.info("Plot multiclass ROC")
    # Create a figure for the plot
    fig, ax = plt.subplots(figsize=(10, 7))

    # Compute ROC curve and ROC area for each class
    for i, class_label in enumerate(y_proba.columns):
        fpr, tpr, _ = roc_curve(y_true == class_label, y_proba[class_label])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f'ROC curve of class {class_label} (AUC = {roc_auc:.2f})')

    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) with AUC for Multi-Class')
    ax.legend(loc="lower right")

    if file_name is None:
        return fig
    
    # Save the plot to the specified directory
    roc_curve_path = output_dir / file_name
    fig.savefig(roc_curve_path)
    plt.close(fig)

    return roc_curve_path

def plot_multiclass_pr_curve_save_with_auc(
    y_true: pd.Series, 
    y_proba: pd.DataFrame, 
    output_dir: Path = Path("."), 
    file_name: str | None = "pr_curve.png"
) -> Path | go.Figure:
    """
    Generates, saves, and returns the Precision-Recall curve plot with AUC for multi-class.

    :param y_true: Series containing the true labels.
    :param y_proba: DataFrame containing the predicted probabilities.
    :param output_dir: Directory to save the PR curve plot.
    :param file_name: File name for saving the plot.
    :return: Path to the saved PR curve plot file.
    """
    logger.info("Plot multiclass PR curve")
    fig, ax = plt.subplots(figsize=(10, 7))

    # Compute and plot PR curve and AUC for each class
    for i, class_label in enumerate(y_proba.columns):
        precision, recall, _ = precision_recall_curve(y_true == class_label, y_proba[class_label])
        pr_auc = average_precision_score(y_true == class_label, y_proba[class_label])
        ax.plot(recall, precision, label=f'Class {class_label} (AUC PR = {pr_auc:.2f})')

    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve for Multi-Class with AUC')
    ax.legend(loc="lower right")

    if file_name is None:
        return fig
    
    # Save the plot
    pr_curve_path = output_dir / file_name
    fig.savefig(pr_curve_path)
    plt.close(fig)

    return pr_curve_path


def evaluate_hard_label_metrics(
    y_true: pd.Series, 
    y_pred: pd.Series, 
    metrics: list[str] = [
        "accuracy",
        "balanced_accuracy",
        "f1_macro",
        "f1_micro",
        "f1_weighted",
        "precision_macro",
        "precision_micro",
        "precision_weighted",
        "recall_macro",
        "recall_micro",
        "recall_weighted"
    ]
) -> dict[str, float]:
    """
    Evaluates various metrics using hard labels.

    :param y_true: Series containing the true labels.
    :param y_pred: Series containing the predicted labels.
    :param metrics: List of metric names to evaluate.
    :return: Dictionary with metric names as keys and their corresponding scores as values.
    """
    logger.info("Evaluate hard label metrics")
    results = {}
    for metric in metrics:
        logger.info(f"Evaluating {metric=}")
        if metric == "accuracy":
            results[metric] = accuracy_score(y_true, y_pred)
        elif metric == "balanced_accuracy":
            results[metric] = balanced_accuracy_score(y_true, y_pred)
        elif metric.startswith("f1"):
            average_type = metric.split("_")[1]  # e.g., macro, micro, weighted
            results[metric] = f1_score(y_true, y_pred, average=average_type)
        elif metric.startswith("precision"):
            average_type = metric.split("_")[1]
            results[metric] = precision_score(y_true, y_pred, average=average_type, zero_division=0)
        elif metric.startswith("recall"):
            average_type = metric.split("_")[1]
            results[metric] = recall_score(y_true, y_pred, average=average_type, zero_division=0)
    return results


def evaluate_probability_metrics(
    y_true: pd.Series, 
    y_proba: pd.DataFrame, 
    metrics: list[str] = ["roc_auc_ovr", "average_precision"]
) -> dict[str, float]:
    """
    Evaluates various metrics using probability scores for multiclass classification.

    :param y_true: Series containing the true labels.
    :param y_proba: DataFrame containing the predicted probabilities for each class.
    :param metrics: List of metric names to evaluate.
    :return: Dictionary with metric names as keys and their corresponding scores as values.
    """
    logger.info("Evaluate probability metrics")
    results = {}
    y_true_cat = y_true

    for metric in metrics:
        if metric == "roc_auc_ovr":
            # Handle binary classification differently
            if y_proba.shape[1] == 2: 
                # Assuming the second column is the one of interest for binary classification
                y_proba_binary = y_proba.iloc[:, 1]
                results[metric] = roc_auc_score(y_true_cat, y_proba_binary)
            else:
                # Multiclass case
                results[metric] = roc_auc_score(
                    y_true_cat, y_proba, multi_class="ovr", labels=sorted(y_proba.columns)
                )
        elif metric == "average_precision":
            # Compute Average Precision for each class and take the mean
            ap_scores = [
                average_precision_score(
                    y_true_cat == class_label,
                    y_proba[class_label]
                )
                for class_label in y_proba.columns
            ]
            results[metric] = sum(ap_scores) / len(ap_scores)
    
    return results



# Actual plotting shit

In [3]:
from pathlib import Path

import torch
import numpy as np

all_exp_dir = Path("../../exp")
heritage_exp_dir = all_exp_dir / "heritage"
data_dir = Path("../../data")

experiment_names = [
    "library",
    "maritime_and_park_v1",
    "rog_and_foundry_v2",
    "maritime_park_library_v1",
]

dataset_names = [
    "library",
    "maritime_and_park",
    "rog_and_foundry",
    "maritime_park_library",
]

class_labels = pd.Series({
    1 : "wall",
    2 : "floor",
    3 : "roof",
    4 : "ceiling",
    5 : "footpath",
    6 : "grass",
    7 : "column",
    8 : "door",
    9 : "window",
    10 : "stair",
    11 : "railing",
    12 : "rainwater_pipe",
    13 : "other"
})

In [4]:
def dataset_filename_to_scene_id(dataset_filename: str) -> str:
    assert dataset_filename.endswith(".pth")
    return dataset_filename[:-len(".pth")]

def predictions_filename_to_scene_id(preds_filename: str) -> str:
    assert preds_filename.endswith("_pred.npy")
    return preds_filename[:-len("_pred.npy")]

def load_predictions(exp_dir: Path) -> dict[str, np.ndarray]:
    assert (files := list((exp_dir / "result").glob("*.npy")))
    return {predictions_filename_to_scene_id(f.name): np.load(f) for f in files}

def load_clouds(dataset_dir: Path, subset: str | Path | None = "test") -> dict[str, torch.Tensor]:
    glob_prefix = "*" if subset is None else f"{subset}/"
    files = dataset_dir.glob(f"{glob_prefix}*.pth")
    return {dataset_filename_to_scene_id(f.name): torch.load(f) for f in files}

In [7]:
metrics = {}

for exp_name, dataset_name in zip(experiment_names, dataset_names):
    print(f"{exp_name=}")
    metrics[dataset_name] = {}
    exp_preds = load_predictions(heritage_exp_dir / exp_name)
    exp_clouds = load_clouds(data_dir / dataset_name)
    assert set(exp_preds.keys()) == set(exp_clouds.keys()) # all match up to common scene id
    # collect individual scene arrays /results as we go to collect global metrics at the end
    y_trues_all, y_preds_all = [], []
    # evaluate per scene in the experiment
    for scene_key in exp_preds.keys():
        metrics[dataset_name][scene_key] = {}
        print(f"  {scene_key=}")
        y_pred = exp_preds[scene_key]; y_preds_all.append(y_pred)
        y_true = exp_clouds[scene_key]["gt"]; y_trues_all.append(y_true)
        unique_labels_pred = np.unique(y_pred)
        unique_labels_true = np.unique(y_true)
        unique_labels = set(unique_labels_pred).union(set(unique_labels_true))
        metrics[dataset_name][scene_key] = dict(
            hard_label_metrics = evaluate_hard_label_metrics(y_true, y_pred),
            df_clf_report = clf_report(y_true, y_pred),
            df_cm = confusion_matrix_dataframe(y_true, y_pred, classes=class_labels.loc[list(unique_labels)])
        )
    # now globally across all scenes in expt
    y_true_all, y_pred_all = np.concatenate(y_trues_all), np.concatenate(y_preds_all)
    unique_labels_pred_all = np.unique(y_pred_all)
    unique_labels_true_all = np.unique(y_true_all)
    unique_labels_all = set(unique_labels_pred_all).union(set(unique_labels_true_all))
    metrics[dataset_name]["global"] = dict(
        hard_label_metrics = evaluate_hard_label_metrics(y_true_all, y_pred_all),
        df_clf_report = clf_report(y_true_all, y_pred_all),
        df_cm = confusion_matrix_dataframe(y_true_all, y_pred_all, classes=class_labels.loc[list(unique_labels_all)])
    )

exp_name='library'
  scene_key='combined_library_test_sceneid2'
exp_name='maritime_and_park_v1'
  scene_key='combined_maritime_museum_test_sceneid2'


  prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)


  scene_key='combined_maritime_museum_test_sceneid3'


  prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)


  scene_key='combined_park_row_test_sceneid2'
exp_name='rog_and_foundry_v2'
  scene_key='combined_brass_foundry_test_sceneid2'
  scene_key='combined_rog_south_test_sceneid2'


  rec = np.round(df_cm.iloc[i, i]/df_cm.loc["total", df_cm.columns[i]],2)


  scene_key='combined_rog_south_test_sceneid3'


  rec = np.round(df_cm.iloc[i, i]/df_cm.loc["total", df_cm.columns[i]],2)


  scene_key='combined_rog_north_test_sceneid2'


  prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)


exp_name='maritime_park_library_v1'
  scene_key='combined_maritime_museum_test_sceneid2'


  prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)


  scene_key='combined_maritime_museum_test_sceneid3'


  prec = np.round(df_cm.iloc[i, i]/df_cm.loc[df_cm.index[i], 'total'],2)


  scene_key='combined_park_row_test_sceneid2'


# Plot shit

In [8]:
for dataset, dataset_metrics in metrics.items():
    pass