In [13]:
import json
import numpy as np

# Function to read JSON file and parse results
def read_results(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        results = json.load(f)
    return results

# Function to compute mean metrics over the classes
def compute_mean_metrics(results):
    analysis = {}
    for config, data in results.items():
        few_shot_data = data.get("few_shot", {})
        analysis[config] = {}
        for shot, shot_data in few_shot_data.items():
            analysis[config][shot] = {"train": {}, "test": {}}
            for res_type in ["train", "test"]:
                res_data = shot_data.get(res_type, {})
                overall_metrics = res_data.get("overall", {})
                class_metrics = {k: v for k, v in res_data.items() if k != "overall"}
                
                # Dynamically get the metric names from the class metrics
                if class_metrics:
                    sample_class = next(iter(class_metrics.values()))
                    metric_names = sample_class.keys()
                else:
                    metric_names = []
                
                # Initialize dictionary to hold the sums and counts for each metric
                metric_sums = {metric: 0 for metric in metric_names}
                metric_counts = {metric: 0 for metric in metric_names}
                
                # Collect metric values for each class
                for metrics in class_metrics.values():
                    for metric in metric_names:
                        value = metrics[metric]
                        if metric in metrics and not np.isnan(value) and value != -1:
                            metric_sums[metric] += value
                            metric_counts[metric] += 1
                
                # Compute mean for each metric
                mean_metrics = {metric: (metric_sums[metric] / metric_counts[metric] if metric_counts[metric] > 0 else float('nan')) for metric in metric_names}
                analysis[config][shot][res_type] = {
                    "mean_metrics": mean_metrics,
                    "overall": overall_metrics
                }
                
    return analysis

# Function to print analysis results
def print_analysis(analysis):
    for config, data in analysis.items():
        print(f"Config: {config}")
        for shot, shot_data in data.items():
            print(f"  Shot: {shot}")
            for res_type, res_data in shot_data.items():
                mean_metrics = res_data["mean_metrics"]
                overall = res_data["overall"]
                print(f"    {res_type.capitalize()}:")
                for metric, value in mean_metrics.items():
                    print(f"      Mean {metric}: {value:.4f}")
                print(f"      Overall Metrics: {overall}")

In [14]:
# Path to the JSON results file
file_path = 'evaluation_results.json'

# Read the results from the JSON file
results = read_results(file_path)

# Compute the analysis
analysis = compute_mean_metrics(results)
print(analysis)
# Print the analysis
print_analysis(analysis)

{'fcos_PVT_V2_B2_LI_FPN_RETINANET_DOTA.yaml': {'1_shot': {'train': {'mean_metrics': {'AP': 0.22985315794179295, 'AP50': 0.46208595592760915, 'AP75': 0.19618684590283667, 'APs': 0.11526205081481823, 'APm': 0.25981433616809163, 'APl': 0.32909304642691833}, 'overall': {'AP': 0.23188080445577552, 'AP50': 0.42217647024369953, 'AP75': 0.2302174036273389, 'APs': 0.29920536498061295, 'APm': 0.23801598296124424, 'APl': 0.3210372678343144}}, 'test': {'mean_metrics': {'AP': 0.08653326682288144, 'AP50': 0.17050578125707072, 'AP75': 0.08345988111254893, 'APs': 0.06436167150106861, 'APm': 0.10191601176728643, 'APl': 0.11975179410020494}, 'overall': {'AP': 0.08653326682288145, 'AP50': 0.17050578125707072, 'AP75': 0.08345988111254893, 'APs': 0.06436167150106861, 'APm': 0.10191601176728643, 'APl': 0.11975179410020495}}}, '10_shot': {'train': {'mean_metrics': {'AP': 0.26674808952491075, 'AP50': 0.5374720466719125, 'AP75': 0.2319565906691231, 'APs': 0.1281714372653335, 'APm': 0.30015004094276865, 'APl': 