In [5]:
ensemble_list = ["ensemble_1001_001", "ensemble_1011_001", "ensemble_1021_001", "ensemble_1031_002", "ensemble_1041_001", "ensemble_1051_003", "ensemble_1061_001", "ensemble_1071_004", "ensemble_1081_001", "ensemble_1091_005", "ensemble_1101_001",
                 "ensemble_1111_006", "ensemble_1121_001", "ensemble_1131_007", "ensemble_1141_001", "ensemble_1151_008", "ensemble_1161_001", "ensemble_1171_009", "ensemble_1181_001", "ensemble_1191_010",
                 "ensemble_1231_001", "ensemble_1251_001", "ensemble_1281_001", "ensemble_1301_001", "ensemble_1301_002", "ensemble_1301_010", "ensemble_1301_011"]

In [6]:
from collections import defaultdict
import numpy as np
import json

def parallel_variance(n_a, avg_a, M2_a, n_b, avg_b, M2_b):
    """ Combines two sets of statistics (size, mean, and M2) to calculate a combined variance. """
    n = n_a + n_b
    delta = avg_b - avg_a
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    var_ab = M2 / (n - 1) if n > 1 else 0  # Bessel's correction for unbiased estimator
    return var_ab, M2, n

def calculate_global_mean_std_parallel(ensemble_dict):
    """
    Calculates the global mean and standard deviation over multiple ensembles for each level using a parallel algorithm.

    Parameters:
    ensemble_dict (defaultdict of dict): A defaultdict where each key is a level and the value is
                                         a dictionary containing two lists ('mean' and 'std') representing
                                         the means and standard deviations for different ensembles at that level.

    Returns:
    final_stats (dict): A dictionary where each key is a level, and the value is a dictionary
                        containing 'global_mean' and 'global_std' for that level.
    """

    final_stats = {}

    # Iterate over each level in the ensemble_dict
    for level, stats in ensemble_dict.items():
        means = stats['mean']
        stds = stats['std']
        n_ensembles = len(means)
        
        # Initialize for parallel variance calculation
        global_mean = means[0]
        M2 = 0
        n = 1

        # Accumulate global mean and variance in parallel
        for i in range(1, n_ensembles):
            n_a = n
            n_b = 1  # Each ensemble has weight 1 (equal sample sizes)

            avg_a = global_mean
            avg_b = means[i]

            # Calculate M2 (sum of squares of differences from the mean) for each ensemble
            M2_b = stds[i] ** 2

            # Combine statistics for the current ensemble with the global statistics so far
            var_ab, M2, n = parallel_variance(n_a, avg_a, M2, n_b, avg_b, M2_b)
            global_mean = (n_a * avg_a + n_b * avg_b) / n  # Updated global mean

        # Final global standard deviation is the square root of the combined variance
        global_std = np.sqrt(var_ab)

        # Store the results for this level
        final_stats[f"level_{level}"] = {
            "global_mean": global_mean,
            "global_std": global_std
        }

    return final_stats

def calculate_global_mean_std(ensemble_list, data_dict, levels=["0", "1", "3", "5", "8", "11", "14", 'tux', 'tuy', 'sst', 'ssh']):

    # Initialize a dictionary to store the global mean and std for each level
    global_stats = defaultdict(lambda: {"mean": [], "std": []})
    
    # Iterate through each key in the data dictionary
    for key, value in data_dict.items():
        # Split the key to extract the ensemble and level
        parts = key.split('_')
        ensemble = '_'.join(parts[:3])  # e.g., "ensemble_1001_001"
        # Extract the level number, assuming "levelX" or "levelXX" format
        #level_str = parts[-1].replace('level', '')  # Remove the "level" part
        level_str = parts[-1].replace('a', '')  # Remove the "a" part
        level = level_str  # Convert the remaining part to an integer

        # Check if the ensemble is in the input list and the level is one of the levels of interest
        if ensemble in ensemble_list and level in levels:
            # Append the mean and std from the dictionary for that level
            global_stats[level]["mean"].append(value["mean"])
            global_stats[level]["std"].append(value["std"])

    # Calculate the global mean and std for each level
    final_stats = calculate_global_mean_std_parallel(global_stats)

    return final_stats

# Open and read the file
with open('C:/Users/felix/PycharmProjects/deeps2a-enso/scripts/data_processing/CESM2/lens/calculate_ensemble_metrics/ensemble_stats_1_2_all.txt', 'r') as file:
    data_dict = json.load(file)

# Example usage
global_stats = calculate_global_mean_std(ensemble_list, data_dict)
print(global_stats)

# Save the resulting dictionary to a file
with open('global_stats_1_1grid.json', 'w') as outfile:
    json.dump(global_stats, outfile)

{'level_tux': {'global_mean': -3.9094611618755137e-20, 'global_std': 0.19905465166576986}, 'level_tuy': {'global_mean': 1.7936438524139273e-20, 'global_std': 0.13234788275774506}, 'level_0': {'global_mean': -1.4607491879568907e-17, 'global_std': 0.6707193411666164}, 'level_11': {'global_mean': 4.809813862225758e-19, 'global_std': 0.9499027942590248}, 'level_14': {'global_mean': -1.1246988115251005e-17, 'global_std': 0.9049123856449006}, 'level_1': {'global_mean': 6.5882962321022365e-18, 'global_std': 0.6708457100967729}, 'level_3': {'global_mean': 6.049085615681659e-18, 'global_std': 0.7219668185682295}, 'level_5': {'global_mean': -1.404773589676592e-17, 'global_std': 0.8409158748930468}, 'level_8': {'global_mean': -1.8972213187577904e-18, 'global_std': 0.911882499347164}}
