In [None]:
import json
import matplotlib.pyplot as plt
import argparse
import numpy as np


def flatten_metrics(metrics):
    """Flatten one level of nesting in a metrics dictionary.

    Args:
        metrics (dict): Mapping of metric name to either a plain value or a
            dict of sub-metrics.

    Returns:
        dict: A new dict in which every sub-metric appears under the key
        ``"<metric>.<sub-metric>"`` and every non-dict value keeps its
        original key. Only the first level is expanded; dicts nested deeper
        are stored unchanged as values.
    """
    flattened = {}
    for name, value in metrics.items():
        if not isinstance(value, dict):
            flattened[name] = value
            continue
        # Prefix each sub-metric with its parent name, e.g. "gen_score.avg".
        flattened.update(
            {f"{name}.{inner_name}": inner_value for inner_name, inner_value in value.items()}
        )
    return flattened


def load_data(json_path):
    """Read one evaluation-results JSON file and split it into metric groups.

    Args:
        json_path (str): Path to a JSON file containing a ``"per_viewpoint"``
            mapping plus the three ``"overall_*_metrics"`` entries.

    Returns:
        tuple: ``(viewpoints, semantic, geometric, car_quality,
        overall_semantic, overall_geometric, overall_car_quality)`` where
        ``viewpoints`` is the sorted list of viewpoint keys, the next three
        items map each viewpoint to its flattened metrics of that group, and
        the last three are the flattened overall metrics.

    Raises:
        KeyError: If an expected top-level or per-viewpoint key is missing.
    """
    with open(json_path, 'r') as handle:
        payload = json.load(handle)

    per_view = payload["per_viewpoint"]
    viewpoints = sorted(per_view.keys())

    def _collect(group):
        # One flattened metrics dict per viewpoint for the requested group.
        return {vp: flatten_metrics(per_view[vp][group]) for vp in viewpoints}

    semantic = _collect("semantic_metrics")
    geometric = _collect("geometric_metrics")
    car_quality = _collect("car_quality_metrics")

    overall_semantic = flatten_metrics(payload["overall_semantic_metrics"])
    overall_geometric = flatten_metrics(payload["overall_geometric_metrics"])
    overall_car_quality = flatten_metrics(payload["overall_car_quality_metrics"])

    return (
        viewpoints,
        semantic,
        geometric,
        car_quality,
        overall_semantic,
        overall_geometric,
        overall_car_quality,
    )


def plot_category(
    viewpoints,
    per_view_metrics,
    overall_metrics,
    metrics,
    title,
    normalize=True,
    log_scale=False
):
    """Plot several metrics of one category against the viewpoints.

    Each metric is drawn as a line with one marker per viewpoint. When the
    metric also appears in ``overall_metrics`` a dashed horizontal line in the
    same colour marks its overall value.

    Args:
        viewpoints (list): Viewpoint keys, in the order they appear on the x-axis.
        per_view_metrics (dict): Maps each viewpoint to a dict of metric values.
            Metrics missing for a viewpoint are plotted as NaN (a gap).
        overall_metrics (dict): Maps metric name to its overall value.
        metrics (list of str): Names of the metrics to plot.
        title (str): Figure title.
        normalize (bool): If True, min-max scale every metric independently
            using its per-viewpoint values; the overall value is mapped with
            the same transform and may therefore fall outside [0, 1].
            A constant series is mapped to 0, and a series without any valid
            value is left untouched.
        log_scale (bool): If True, use a logarithmic y-axis.
    """
    # NOTE(review): with normalize=True every series' minimum is mapped to 0,
    # which has no position on a logarithmic axis — confirm that combining
    # normalize=True with log_scale=True is intended.
    x = np.arange(len(viewpoints))
    fig, ax = plt.subplots(figsize=(10, 6))

    series = {}
    # Per-metric (offset, span), kept so the overall value gets the same transform.
    scaling = {}
    for m in metrics:
        vals = np.array(
            [per_view_metrics[vp].get(m, np.nan) for vp in viewpoints],
            dtype=float,
        )
        if normalize:
            offset, span = _minmax_params(vals)
            scaling[m] = (offset, span)
            vals = (vals - offset) / span
        series[m] = vals

    line_colors = {}
    for m, vals in series.items():
        line, = ax.plot(x, vals, marker='o', label=m)
        line_colors[m] = line.get_color()

    for m in metrics:
        if m in overall_metrics:
            avg = overall_metrics[m]
            if normalize:
                offset, span = scaling[m]
                avg = (avg - offset) / span
            # Reuse the series colour so line and overall value read as a pair.
            ax.axhline(avg, linestyle='--', color=line_colors[m], label=f"{m} overall")

    ax.set_xticks(x)
    ax.set_xticklabels(viewpoints, rotation=45)
    ax.set_xlabel("Viewpoint")
    ylabel = "Normalized Value (0 to 1)" if normalize else "Metric Value"
    ax.set_ylabel(ylabel)
    if log_scale:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, which='both' if log_scale else 'major', ls='--')
    plt.tight_layout()
    plt.show()


def _minmax_params(vals):
    """Return the ``(offset, span)`` pair that min-max scales ``vals``, ignoring NaNs.

    Falls back to the identity transform ``(0.0, 1.0)`` when there is no valid
    value, and to a span of 1 when all valid values are equal, so callers can
    always compute ``(vals - offset) / span`` without dividing by zero.
    """
    valid = vals[~np.isnan(vals)]
    if valid.size == 0:
        return 0.0, 1.0
    low = valid.min()
    span = valid.max() - low
    if not span > 0:
        span = 1.0
    return low, span

In [None]:
# Results file of a single evaluation run to visualise.
json_file = "example_data/Meshfleet_Eval/metrics_results_front_1.json"

(
    viewpoints,
    semantic,
    geometric,
    car_quality,
    overall_sem,
    overall_geo,
    overall_car,
) = load_data(json_file)

# Metric names to draw for each category (dotted names are flattened sub-metrics).
semantic_metrics = ["MSE", "CLIP-S", "PSNR", "SSIM", "LPIPS"]
geometric_metrics = [
    "Rel_BB_Aspect_Ratio_Diff",
    "Squared_Outline_Normals_Angle_Diff",
    "Squared_Summed_Outline_Normals_Angle_Diff",
]
car_quality_metrics = [
    "orig_score.avg_quality_score",
    "gen_score.avg_quality_score",
    "rel_diff",
]

# One figure per category; each relies on plot_category's default normalization.
plot_category(
    viewpoints,
    semantic,
    overall_sem,
    semantic_metrics,
    "Semantic Metrics by Viewpoint",
    normalize=True,
    log_scale=False,
)
plot_category(
    viewpoints,
    geometric,
    overall_geo,
    geometric_metrics,
    "Geometric Metrics by Viewpoint",
)
plot_category(
    viewpoints,
    car_quality,
    overall_car,
    car_quality_metrics,
    "Car Quality Metrics by Viewpoint",
)

In [None]:
import os
def plot_metric_across_viewpoints(
    json_files,
    metric_name,
    metric_group='semantic_metrics',
    overall=True,
    labels=None,
    legend_outside=True,
    highlight_label=None
):
    """
    Plots the specified metric across viewpoints for multiple JSON files and optionally shows the overall average.
    Can highlight a specific viewpoint label on the x-axis in bold red.

    The x-axis covers the sorted union of the viewpoints found in all files, so every
    series is drawn at the position of its own viewpoint; a viewpoint or metric that is
    missing from a file is plotted as NaN (a gap).

    Args:
        json_files (list of str): Paths to JSON files. Must not be empty.
        metric_name (str): Name of the metric to plot (e.g., 'MSE', 'PSNR').
        metric_group (str): One of 'semantic_metrics', 'geometric_metrics', or 'car_quality_metrics'.
        overall (bool): Whether to include the overall average line.
        labels (list of str, optional): Custom labels for each JSON file. Files without a
            matching entry fall back to the file's base name without extension.
        legend_outside (bool): Place legend outside the plot area if True.
        highlight_label (str, optional): The x-axis label (e.g., '000.png') to color bold red.

    Raises:
        ValueError: If json_files is empty.
    """
    json_files = list(json_files)
    if not json_files:
        raise ValueError("json_files must contain at least one path")

    overall_key_map = {
        'semantic_metrics': 'overall_semantic_metrics',
        'geometric_metrics': 'overall_geometric_metrics',
        'car_quality_metrics': 'overall_car_quality_metrics'
    }

    # Load everything before drawing: the shared x-axis depends on all files.
    runs = []
    for idx, file_path in enumerate(json_files):
        with open(file_path, 'r') as f:
            data = json.load(f)

        if labels and idx < len(labels):
            label = labels[idx]
        else:
            label = os.path.splitext(os.path.basename(file_path))[0]
        runs.append((label, data))

    viewpoints = sorted({vp for _, data in runs for vp in data.get('per_viewpoint', {})})
    x = list(range(len(viewpoints)))

    plt.figure(figsize=(10, 6))
    if legend_outside:
        plt.subplots_adjust(right=0.75)

    for label, data in runs:
        per_view = data.get('per_viewpoint', {})
        values = [
            per_view.get(vp, {}).get(metric_group, {}).get(metric_name, float('nan'))
            for vp in viewpoints
        ]
        line, = plt.plot(x, values, marker='o', label=label)

        if overall:
            overall_metrics = data.get(overall_key_map.get(metric_group, ''), {})
            avg_value = overall_metrics.get(metric_name)
            if avg_value is not None:
                # Match the series colour; otherwise every average line gets the same default colour.
                plt.axhline(
                    y=avg_value,
                    linestyle='--',
                    color=line.get_color(),
                    label=f'{label} average',
                )

    xt_labels = [f"{vp}.png" for vp in viewpoints]
    plt.xticks(x, xt_labels, rotation=45)

    if highlight_label:
        ax = plt.gca()
        for tick in ax.get_xticklabels():
            if tick.get_text() == highlight_label:
                tick.set_color('red')
                tick.set_fontweight('bold')

    plt.xlabel('Viewpoint')
    plt.ylabel(metric_name)
    plt.title(f"{metric_name} across viewpoints")

    if legend_outside:
        plt.legend(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0)
    else:
        plt.legend()

    plt.tight_layout()
    plt.show()


In [None]:
# Result files to compare, in the same order as the labels passed below.
# NOTE(review): 'metrics_results_diagonal_3.json' has an underscore before the
# number while 'diagonal1'/'diagonal2' do not — confirm the file name on disk.
json_files = [
    'example_data/Meshfleet_Eval/metrics_results_GT.json',
    'example_data/Meshfleet_Eval/metrics_results_front_1.json',
    'example_data/Meshfleet_Eval/metrics_results_diagonal1.json',
    'example_data/Meshfleet_Eval/metrics_results_front_2.json',
    'example_data/Meshfleet_Eval/metrics_results_diagonal2.json',
    'example_data/Meshfleet_Eval/metrics_results_front_3.json',
    'example_data/Meshfleet_Eval/metrics_results_diagonal_3.json',
]
plot_metric_across_viewpoints(
    json_files,
    metric_group='geometric_metrics',
    metric_name='Squared_Summed_Outline_Normals_Angle_Diff',
    labels=[
        'GT',
        'low_front',
        'low_diagonal',
        'mid_front',
        'mid_diagonal',
        'high_front',
        'high_diagonal',
    ],
    highlight_label='000.png',
)