In [None]:
# Per-configuration evaluation results: one dict per (dataset, mask_type,
# algorithm) run. Note the "mean time" key is spelled with a space, unlike the
# other underscore-separated metric keys.
data = [
    {
        "dataset": "ImageNet",
        "mask_type": "expand",
        "algorithm": "Copaint",
        "mean_lpips": 0.503011167049408,
        "mean_psnr": 17.135379791259766,
        "mean_ssim": 0.5807474851608276,
        "final_loss": 1.5211669206619263,
        "mean time": 244.44522905349731
    },

    {
        "dataset": "ImageNet",
        "mask_type": "half",
        "algorithm": "Copaint",
        "mean_lpips": 0.18792548775672913,
        "mean_psnr": 20.609283447265625,
        "mean_ssim": 0.7838151454925537,
        "final_loss": 9.873313903808594,
        "mean time": 277.57729959487915
    },

    {
        "dataset": "ImageNet",
        "mask_type": "line",
        "algorithm": "Copaint",
        "mean_lpips": 0.034693643450737,
        "mean_psnr": 38.161407470703125,
        "mean_ssim": 0.9552620649337769,
        "final_loss": 13.067879676818848,
        "mean time": 253.88609099388123
    },

    {
        "dataset": "ImageNet",
        "mask_type": "sr2",
        "algorithm": "Copaint",
        "mean_lpips": 0.05219435691833496,
        "mean_psnr": 37.041473388671875,
        "mean_ssim": 0.9436663389205933,
        "final_loss": 7.975658416748047,
        "mean time": 244.29270911216736
    },

    # NOTE(review): the two entries below are exact copies of the first
    # ("expand") entry -- likely copy-paste placeholders; confirm or replace
    # with the intended results.
    {
        "dataset": "ImageNet",
        "mask_type": "expand",
        "algorithm": "Copaint",
        "mean_lpips": 0.503011167049408,
        "mean_psnr": 17.135379791259766,
        "mean_ssim": 0.5807474851608276,
        "final_loss": 1.5211669206619263,
        "mean time": 244.44522905349731
    },

    {
        "dataset": "ImageNet",
        "mask_type": "expand",
        "algorithm": "Copaint",
        "mean_lpips": 0.503011167049408,
        "mean_psnr": 17.135379791259766,
        "mean_ssim": 0.5807474851608276,
        "final_loss": 1.5211669206619263,
        "mean time": 244.44522905349731
    },
]

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt

# Root folder scanned for `.result` files, laid out as
# <algorithm>/<dataset>/<mask>/ (see the directory walk that collects results).
base_dir = "train_results"

# Function to parse result files and extract the metrics
def extract_metrics_from_result(file_path):
    """Parse a ``.result`` file and collect the evaluation metrics it reports.

    The file is read line by line. Every line that decodes as a JSON object is
    checked for the keys ``mean_lpips``, ``mean_psnr``, ``mean_ssim``,
    ``final_loss`` and ``mean time``. A single object may carry any subset of
    them. Lines that are not valid JSON, or that decode to something other than
    an object (a bare number, list, ...), are skipped. When a key appears on
    several lines, the last occurrence wins.

    Args:
        file_path: Path of the result file to read.

    Returns:
        dict mapping metric name to value. The source key ``mean time`` is
        stored as ``mean_time``. Metrics that never appear are absent.
    """
    # Key as written in the result file -> key used in the returned dict.
    key_map = {
        'mean_lpips': 'mean_lpips',
        'mean_psnr': 'mean_psnr',
        'mean_ssim': 'mean_ssim',
        'final_loss': 'final_loss',
        'mean time': 'mean_time',
    }
    metrics = {}
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # non-JSON lines (logs, blank lines) are skipped
            if not isinstance(record, dict):
                continue  # valid JSON but not an object: nothing to extract
            # Test every key independently: one record may report several
            # metrics at once, and all of them must be kept.
            for src_key, dst_key in key_map.items():
                if src_key in record:
                    metrics[dst_key] = record[src_key]
    return metrics

# Prepare parallel lists to hold the data for plotting: entry i of
# datasets / mask_types / algorithms and of every metrics_data[...] list
# describes the same collected result record.
datasets = []
mask_types = []
algorithms = []
metrics_data = {
    "mean_lpips": [],
    "mean_psnr": [],
    "mean_ssim": [],
    "final_loss": [],
    "mean_time": []
}

# Walk base_dir/<algorithm>/<dataset>/<mask>/ and collect one record per mask
# folder. Every listing is sorted because os.listdir returns entries in
# arbitrary order; without sorting, record (and bar) order could differ
# between runs or filesystems.
for algo in sorted(os.listdir(base_dir)):
    algo_dir = os.path.join(base_dir, algo)
    if not os.path.isdir(algo_dir):
        continue

    for dataset in sorted(os.listdir(algo_dir)):
        dataset_dir = os.path.join(algo_dir, dataset)
        if not os.path.isdir(dataset_dir):
            continue

        for mask in sorted(os.listdir(dataset_dir)):
            mask_dir = os.path.join(dataset_dir, mask)
            if not os.path.isdir(mask_dir):
                continue

            # Only the second .result file (in lexicographic order) is used, so
            # masks with fewer than two result files are skipped.
            # NOTE(review): the reason the first file is ignored is not visible
            # here, and lexicographic order puts "10.result" before
            # "2.result" -- confirm the file naming makes index 1 the right one.
            result_files = sorted(
                f for f in os.listdir(mask_dir) if f.endswith('.result')
            )
            if len(result_files) < 2:
                continue
            second_result_file = os.path.join(mask_dir, result_files[1])

            metrics = extract_metrics_from_result(second_result_file)
            if not metrics:
                continue  # nothing parseable in the file; record nothing

            datasets.append(dataset)
            mask_types.append(mask)
            algorithms.append(algo)
            # Metrics missing from this file are stored as None so that all
            # lists stay index-aligned.
            for metric, values in metrics_data.items():
                values.append(metrics.get(metric))

# Plot the data: one figure per metric, with datasets along the x-axis and one
# bar per (mask type, algorithm) combination inside each dataset group.
# datasets / mask_types / algorithms hold one entry per collected record, so
# take their unique values (in first-seen order) to define groups and series.
unique_datasets = list(dict.fromkeys(datasets))
series = list(dict.fromkeys(zip(mask_types, algorithms)))
x = np.arange(len(unique_datasets))  # x-axis position of each dataset group
# Split a fixed group width between the series so that neighbouring dataset
# groups never overlap, however many series there are.
group_width = 0.8
bar_width = group_width / max(len(series), 1)

if not datasets:
    print('No result records were collected; nothing to plot.')
else:
    for metric, values in metrics_data.items():
        # (dataset, mask, algorithm) -> this metric's value for that record.
        by_key = {
            (ds, mask, algo): value
            for ds, mask, algo, value in zip(datasets, mask_types, algorithms, values)
        }

        plt.figure(figsize=(12, 6))
        for s_idx, (mask, algo) in enumerate(series):
            # Missing combinations or missing metrics (None) become NaN, so
            # no bar is drawn for them instead of passing None to plt.bar.
            heights = [
                np.nan if by_key.get((ds, mask, algo)) is None else by_key[(ds, mask, algo)]
                for ds in unique_datasets
            ]
            plt.bar(x + s_idx * bar_width, heights, bar_width,
                    label=f'{mask} - {algo}')

        # Set plot details
        pretty_metric = metric.replace('_', ' ').title()
        plt.xlabel('Datasets')
        plt.ylabel(pretty_metric)
        plt.title(f'{pretty_metric} Across Datasets, Mask Types, and Algorithms')
        # Centre each tick label under its group of bars.
        plt.xticks(x + bar_width * (len(series) - 1) / 2, unique_datasets)
        plt.legend()
        plt.tight_layout()
        plt.show()
