In [48]:
import json
from pathlib import Path

import numpy as np
import polars as pl
from sklearn.metrics import rand_score, adjusted_rand_score, mutual_info_score, normalized_mutual_info_score, adjusted_mutual_info_score, homogeneity_score, completeness_score, v_measure_score, fowlkes_mallows_score, silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.metrics.pairwise import cosine_distances

In [54]:
# Label column compared between reference and hypothesis, and the columns used
# as point coordinates for the internal (position-based) metrics.
y_col, x_cols = "cluster", ["x_pos", "y_pos"]

In [49]:
# Supervised (external) metrics: each is called as fn(labels_true, labels_pred)
# and compares a predicted clustering against reference labels.
super_metrics = dict(
    rand_score=rand_score,
    adjusted_rand_score=adjusted_rand_score,
    mutual_info_score=mutual_info_score,
    normalized_mutual_info_score=normalized_mutual_info_score,
    adjusted_mutual_info_score=adjusted_mutual_info_score,
    homogeneity_score=homogeneity_score,
    completeness_score=completeness_score,
    v_measure_score=v_measure_score,
    fowlkes_mallows_score=fowlkes_mallows_score,
)

# Internal metrics: each is called as fn(X, labels) and scores a clustering
# using only the feature matrix it was computed on (no reference labels).
pos_metrics = dict(
    silhouette_score=silhouette_score,
    calinski_harabasz_score=calinski_harabasz_score,
    davies_bouldin_score=davies_bouldin_score,
)

In [50]:
def do_metric(metric, yhat, y_or_X, metric_kwargs=None):
    """Score a predicted clustering with a metric looked up by name.

    Supervised metrics (keys of ``super_metrics``) take
    ``(labels_true, labels_pred)``. Internal metrics (keys of ``pos_metrics``)
    take ``(X, labels)``. Both families therefore receive ``y_or_X`` first and
    ``yhat`` second.

    Parameters
    ----------
    metric : str
        Key into ``super_metrics`` or ``pos_metrics``. ``super_metrics`` is
        checked first.
    yhat : array-like
        Predicted cluster labels.
    y_or_X : array-like
        Reference labels for a supervised metric, or the feature matrix
        (or distance matrix) for an internal metric.
    metric_kwargs : dict, optional
        Extra keyword arguments forwarded to the metric function, e.g.
        ``{"metric": "precomputed"}`` to run ``silhouette_score`` on a
        precomputed distance matrix. This is a dict rather than ``**kwargs``
        because this function's own first parameter is already named
        ``metric``.

    Returns
    -------
    The value returned by the selected metric function.

    Raises
    ------
    ValueError
        If ``metric`` is not a key of either table.
    """
    if metric in super_metrics:
        fn = super_metrics[metric]
    elif metric in pos_metrics:
        fn = pos_metrics[metric]
    else:
        raise ValueError(f"Invalid metric: {metric}")
    return fn(y_or_X, yhat, **(metric_kwargs or {}))

In [None]:
def find_hyps(book_dir: Path):
    """Yield hypothesis point files directly inside ``book_dir``, sorted by path.

    A hypothesis file is any regular file (no recursion into subdirectories)
    whose name ends with ``"df_points.jsonl"``. ``Path.iterdir`` makes no
    ordering guarantee, so the entries are sorted to make repeated runs visit
    files in the same order.

    Parameters
    ----------
    book_dir : Path
        Directory to scan.

    Yields
    ------
    Path
        Each matching file.
    """
    for path in sorted(book_dir.iterdir()):
        # Skip directories or other non-files that happen to share the suffix.
        if path.is_file() and path.name.endswith("df_points.jsonl"):
            yield path

In [51]:
# Reference labelling (used as labels_true below) and one hypothesis run
# (used as labels_pred) to compare against it. Paths are relative to the
# notebook's working directory.
ref, hyp = (
    Path("outp/AGUILAR_REF_GT_df_points.jsonl"),
    Path("corpus_en/AGUILAR_home-influence/AGUILAR_REF/AGUILAR_home-influence_REF_AffpropHyperparams2_df_points.jsonl"),
)

In [52]:
# Load both newline-delimited JSON files into polars DataFrames (ref first, then hyp).
ref_df, hyp_df = map(pl.read_ndjson, (ref, hyp))

In [55]:
# NOTE(review): the supervised metrics pair y[i] with yhat[i], so this assumes
# ref_df and hyp_df list the same points in the same row order. Confirm this,
# or join on a shared point id before extracting labels.
y = ref_df.get_column(y_col).to_numpy()      # reference labels
yhat = hyp_df.get_column(y_col).to_numpy()   # predicted labels
coords = hyp_df.select(x_cols).to_numpy()    # feature matrix for internal metrics


In [56]:
# Square matrix of pairwise cosine distances between the hypothesis points.
# NOTE(review): not used by any cell visible here; presumably intended for
# silhouette_score(dists, yhat, metric="precomputed"). Confirm or drop.
dists = cosine_distances(X=coords)

In [58]:
# External evaluation: hypothesis labels vs. reference labels.
for name in super_metrics:
    score = do_metric(name, yhat, y)
    print(f"{name}: {score}")

rand_score: 0.9221213569039656
adjusted_rand_score: 0.036905192077508524
mutual_info_score: 1.9919468661428772
normalized_mutual_info_score: 0.6277813964814433
adjusted_mutual_info_score: 0.05742181444488404
homogeneity_score: 0.5675852901302326
completeness_score: 0.7022607052961277
v_measure_score: 0.6277813964814433
fowlkes_mallows_score: 0.07731523606134356


In [59]:
# Internal evaluation: hypothesis labels vs. the hypothesis point coordinates.
for name in pos_metrics:
    score = do_metric(name, yhat, coords)
    print(f"{name}: {score}")

silhouette_score: -0.4297444391206185
calinski_harabasz_score: 23.25427755025825
davies_bouldin_score: 12.12243265034912
