In [None]:
import json
import glob
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.image as mpimg

# Task sub-directory under each training dataset's log folder.
TASK = "binary"

# Training datasets to report on. The values are placeholders that are replaced
# later by the per-evaluation-dataset result tables.
TRAIN_DATASETS = {
    "DeepSpeak_v1_1": None,
    "SWAN_DF": None,
}

# Find all metrics.json files, one list per training dataset.
# Layout: ../logs/<train_ds>/<task>/<fusion>/version_<X>/results/<eval_ds>/metrics.json
# glob returns paths in a filesystem-dependent order, so sort them to keep
# the row order reproducible between runs and machines.
RESULTS = {
    ds: sorted(glob.glob(f"../logs/{ds}/{TASK}/*/version_*/results/*/metrics.json"))
    for ds in TRAIN_DATASETS.keys()
}

In [None]:
# For every training dataset, load its metrics.json files and store in
# TRAIN_DATASETS a nested mapping:
#   eval_dataset -> split ("val" / "test") -> {"metrics": DataFrame, "images": DataFrame}
# "metrics" and "images" are column slices of the same frame, so their row
# labels line up; "images" also carries "version" for lookups.
splits = ["val", "test"]  # "train" results are deliberately not loaded
image_kinds = ["audio_proj", "video_proj", "fused"]
all_image_cols = [f"{split}_{kind}" for split in splits for kind in image_kinds]

for train_ds in TRAIN_DATASETS.keys():
    rows = []

    for file_path in RESULTS[train_ds]:
        # .../<fusion>/<version_X>/results/<eval_dataset>/metrics.json
        result_dir = Path(file_path).parent
        eval_dataset = result_dir.name
        version = result_dir.parent.parent.name
        fusion = result_dir.parent.parent.parent.name

        with open(file_path, "r") as f:
            data = json.load(f)

        # One flat row per (model, split): identifiers, metrics, PNG paths.
        for split in splits:
            if split not in data:
                continue
            row = {
                "fusion": fusion,
                "version": version,
                "split": split,
                "dataset": eval_dataset,
            }
            row.update(data[split])
            row.update(
                {
                    f"{split}_{kind}": str(result_dir / f"{split}_{kind}.png")
                    for kind in image_kinds
                }
            )
            rows.append(row)

    results_df = pd.DataFrame(rows)

    if results_df.empty:
        print(f"[WARN] {train_ds} produced no valid splits - skipping")
        TRAIN_DATASETS[train_ds] = {}  # or whatever default makes sense
        continue

    eval_dfs = {}
    # Group by the scalar column name, not a one-element list: with a list,
    # pandas >= 2 yields 1-tuples as keys while older versions yield the bare
    # string, so indexing the key with [0] would return its first character.
    for eval_dataset, group in results_df.groupby("dataset"):
        group = group.drop(columns=["dataset"])
        eval_dfs[eval_dataset] = {}

        for split in splits:
            split_df = (
                group[group["split"] == split]
                .drop(columns=["split"])
                .reset_index(drop=True)
            )

            # Separate into metrics and image paths
            metric_cols = [col for col in split_df.columns if col not in all_image_cols]
            image_cols = [col for col in split_df.columns if col in all_image_cols] + ["version"]

            eval_dfs[eval_dataset][split] = {
                "metrics": split_df[metric_cols],
                "images": split_df[image_cols],
            }

    TRAIN_DATASETS[train_ds] = eval_dfs

In [None]:
# Columns whose best value is the largest one.
maximize_metrics = [
    "acc",
    "ap",
    "prec",
    "rec",
    "f1",
    "auc",
]
# Columns whose best value is the smallest one.
minimize_metrics = [
    "loss",
]

In [None]:
def highlight_best_per_column(s):
    """Return one CSS string per cell, highlighting the best value(s) of a column.

    Meant for ``DataFrame.style.apply(..., axis=0)``, which passes each column
    as a Series. Columns named in ``maximize_metrics`` mark their maximum,
    columns named in ``minimize_metrics`` mark their minimum, and every other
    column is left unstyled. Ties are all highlighted. The "auc" and "ap"
    columns use red, the remaining metric columns green.
    """
    column = s.name

    if column in maximize_metrics:
        best_mask = s == s.max()
    elif column in minimize_metrics:
        best_mask = s == s.min()
    else:
        best_mask = [False] * len(s)

    if column in ("auc", "ap"):
        best_css = "background-color: red"
    else:
        best_css = "background-color: green"

    return [best_css if flag else "" for flag in best_mask]


# "00", "01", "10", "11" -> "RA-RV", "RA-FV", "FA-RV", "FA-FV"
av_classes = [
    "RA-RV",  # code "00"
    "RA-FV",  # code "01"
    "FA-RV",  # code "10"
    "FA-FV",  # code "11"
]


def plot_embeddings(images_df: pd.DataFrame, split: str, version: str):
    """Show the saved embedding plots of one model side by side.

    Three panels are drawn (audio projection, video projection, fused), each
    loaded from the PNG path stored in the matching ``images_df`` column.

    Args:
        images_df: Frame with a "version" column and the path columns
            ``f"{split}_audio_proj"``, ``f"{split}_video_proj"`` and
            ``f"{split}_fused"``. If several rows share ``version``, the
            first one is used.
        split: Split prefix of the path columns (e.g. "val" or "test"); also
            used in the figure title.
        version: Value to look up in the "version" column.

    Returns:
        None. Prints a message and draws nothing when ``version`` is not in
        ``images_df``. A panel whose image is missing or unreadable is hidden.
    """
    # Look up the row first so no figure is allocated for an unknown version.
    matches = images_df[images_df["version"] == version]
    if matches.empty:
        print(f"No images found for version {version} in {split} split.")
        return
    row = matches.iloc[0]

    mod_names = ["Audio Proj", "Video Proj", "Fused"]

    fig, axs = plt.subplots(1, len(mod_names), figsize=(18, 5))
    fig.suptitle(
        f"{split.capitalize()} Embeddings (model {version})", fontsize=16, x=0.45
    )

    for ax, mod in zip(axs, mod_names):
        # "Audio Proj" -> "<split>_audio_proj", etc.
        img_col = f"{split}_{mod.lower().replace(' ', '_')}"
        img_path = row[img_col] if img_col in row else None

        # Missing column, NaN placeholder or file not on disk: hide the panel.
        if not isinstance(img_path, str) or not Path(img_path).exists():
            ax.set_visible(False)
            continue

        try:
            ax.imshow(mpimg.imread(img_path))
            ax.axis("off")
        except Exception as err:
            # Best effort: keep the other panels, but say why this one is gone.
            print(f"[WARN] could not read {img_path}: {err}")
            ax.set_visible(False)

    # NOTE(review): assumes the PNGs were rendered with the tab10 palette in
    # av_classes order - confirm against the code that produced them.
    handles = [
        mpatches.Patch(color=plt.cm.tab10(i), label=av_classes[i])
        for i in range(len(av_classes))
    ]
    fig.legend(handles=handles, loc="center right", title="AV Classes", borderaxespad=5)
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()


In [None]:
# Report for the first training dataset: for each evaluation dataset show the
# val and test metric tables (both ordered by test AUC) and the embedding
# plots of the model with the highest test AUC.
train_ds_1 = list(TRAIN_DATASETS.keys())[0]
print(f"Models trained on {train_ds_1}.")
print("-" * 50)
print("\n")

eval_dfs = TRAIN_DATASETS[train_ds_1]

# A model is identified by (fusion, version). "version_X" is a directory name
# inside each fusion folder, so it can repeat across fusion strategies and is
# not a safe key on its own.
model_key = ["fusion", "version"]

for eval_name, eval_data in eval_dfs.items():
    print(f"Results on {eval_name}{' (cross-dataset)' if eval_name != train_ds_1 else ''}:")
    print("-" * 50)

    test_raw = eval_data["test"]["metrics"]
    val_raw = eval_data["val"]["metrics"]

    if test_raw.empty:
        print("[WARN] no test results - skipping")
        print("\n")
        continue

    # Sort test metrics by AUC, keeping the original row labels for now: they
    # are shared with the "images" frame and identify the best model's row.
    test_ranked = test_raw.sort_values(by="auc", ascending=False)
    best_test_idx = test_ranked.index[0]
    best_fusion = test_ranked.at[best_test_idx, "fusion"]
    best_version = test_ranked.at[best_test_idx, "version"]
    test_metrics = test_ranked.reset_index(drop=True)

    # Align val metrics to the test ranking. A left merge keeps the test order
    # and leaves NaNs for a model without val results instead of raising.
    val_metrics = test_metrics[model_key].merge(val_raw, on=model_key, how="left")

    # Restrict the image frames to the best model so the plots cannot come
    # from another fusion strategy that happens to share the version name.
    is_best_val = (val_raw["fusion"] == best_fusion) & (val_raw["version"] == best_version)
    val_images = eval_data["val"]["images"].loc[is_best_val]
    test_images = eval_data["test"]["images"].loc[[best_test_idx]]

    print("- Val (sorted by test AUC):")
    display(val_metrics.style.apply(highlight_best_per_column, axis=0))
    plot_embeddings(val_images, "val", best_version)

    print("- Test:")
    display(test_metrics.style.apply(highlight_best_per_column, axis=0))
    plot_embeddings(test_images, "test", best_version)

    print("\n")

In [None]:
# Report for the second training dataset: for each evaluation dataset show the
# val and test metric tables (both ordered by test AUC) and the embedding
# plots of the model with the highest test AUC.
train_ds_2 = list(TRAIN_DATASETS.keys())[1]
print(f"Models trained on {train_ds_2}.")
print("-" * 50)
print("\n")

eval_dfs = TRAIN_DATASETS[train_ds_2]

# A model is identified by (fusion, version). "version_X" is a directory name
# inside each fusion folder, so it can repeat across fusion strategies and is
# not a safe key on its own.
model_key = ["fusion", "version"]

for eval_name, eval_data in eval_dfs.items():
    print(f"Results on {eval_name}{' (cross-dataset)' if eval_name != train_ds_2 else ''}:")
    print("-" * 50)

    test_raw = eval_data["test"]["metrics"]
    val_raw = eval_data["val"]["metrics"]

    if test_raw.empty:
        print("[WARN] no test results - skipping")
        print("\n")
        continue

    # Sort test metrics by AUC, keeping the original row labels for now: they
    # are shared with the "images" frame and identify the best model's row.
    test_ranked = test_raw.sort_values(by="auc", ascending=False)
    best_test_idx = test_ranked.index[0]
    best_fusion = test_ranked.at[best_test_idx, "fusion"]
    best_version = test_ranked.at[best_test_idx, "version"]
    test_metrics = test_ranked.reset_index(drop=True)

    # Align val metrics to the test ranking. A left merge keeps the test order
    # and leaves NaNs for a model without val results instead of raising.
    val_metrics = test_metrics[model_key].merge(val_raw, on=model_key, how="left")

    # Restrict the image frames to the best model so the plots cannot come
    # from another fusion strategy that happens to share the version name.
    is_best_val = (val_raw["fusion"] == best_fusion) & (val_raw["version"] == best_version)
    val_images = eval_data["val"]["images"].loc[is_best_val]
    test_images = eval_data["test"]["images"].loc[[best_test_idx]]

    print("- Val (sorted by test AUC):")
    display(val_metrics.style.apply(highlight_best_per_column, axis=0))
    plot_embeddings(val_images, "val", best_version)

    print("- Test:")
    display(test_metrics.style.apply(highlight_best_per_column, axis=0))
    plot_embeddings(test_images, "test", best_version)

    print("\n")