#### Header

The cells in this section (imports and helper definitions) probably do not need to be changed.

In [None]:
import glob
import re
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.io as pio
import dask


# Variable to analyse. The cells below branch on "2m_temperature" or "geopotential".
target_var = "2m_temperature"

In [None]:

# Variables that preprocess_GC drops from the GC output (first group).
PRESSURE_VARIABLES = [
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "temperature",
    "vertical_velocity",
]

# Variables that preprocess_GC drops from the GC output (second group).
SURFACE_VARIABLES = [
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "total_precipitation_6hr",
    "mean_sea_level_pressure",
]

def weighted_mean(dataset: xr.Dataset):
    """Return the cos(latitude)-weighted mean of ``dataset`` over ``lat`` and ``lon``.

    The weights are ``cos(deg2rad(dataset.lat))``, i.e. ``lat`` is interpreted
    as degrees. Both ``lat`` and ``lon`` are averaged away in the result.
    """
    cos_lat = np.cos(np.deg2rad(dataset.lat))
    cos_lat.name = "weights"
    return dataset.weighted(cos_lat).mean(('lat', 'lon'))


def preprocess_GC(dataset:xr.Dataset, target_var="2m_temperature", region=None):
    """Reduce a GC output dataset to daily means of a single target variable.

    For ``"2m_temperature"`` and ``"geopotential"`` the data is resampled to
    daily means, the ``batch`` dimension is squeezed out, ``time`` is replaced
    by seven daily dates starting 2021-06-22 and renamed to ``date``, and all
    other variables (plus ``level``) are dropped. For ``"geopotential"`` only
    ``level=500`` is kept. Any other ``target_var`` leaves the dataset as is.

    Args:
        dataset: GC output dataset.
        target_var: ``"2m_temperature"`` or ``"geopotential"``.
        region: Any truthy value selects the fixed box lat 25..60,
            lon 102.5..150; the value itself is not used.

    Returns:
        The reduced dataset.
    """
    if target_var in ("2m_temperature", "geopotential"):
        if target_var == "geopotential":
            # Select the 500 level before averaging.
            dataset = dataset.sel(level=500)
            other_target = '2m_temperature'
        else:
            other_target = 'geopotential'

        daily = dataset.resample(time="1D").mean()
        dataset = daily.squeeze('batch')
        # Overwrite the time axis with fixed dates (requires exactly 7 daily steps).
        dataset["time"] = pd.date_range("2021-06-22", periods=7, freq="1D")
        dataset = dataset.rename({"time":"date"})
        unwanted = PRESSURE_VARIABLES + SURFACE_VARIABLES + ['level', other_target]
        dataset = dataset.drop_vars(unwanted)

    # TODO: add region selection
    if region:
        dataset = dataset.sel(lat=slice(25, 60), lon=slice(102.5, 150))

    return dataset


def preprocess_nwp(dataset:xr.Dataset, target_var="2m_temperature", region=None):
    """Reshape one NWP ensemble file so it lines up with the GC datasets.

    A length-1 ``date`` dimension is added (valued with the first ``time``
    entry), the data is computed, and the original ``time`` dimension is
    renamed to ``ensemble`` and relabelled 1..50 (so it must have 50 entries).

    Args:
        dataset: Dataset of a single NWP file.
        target_var: ``"2m_temperature"`` drops the ``height`` variable;
            ``"geopotential"`` scales ``gh`` by 9.80665, renames it to
            ``geopotential`` and keeps ``level=500`` (``lev`` divided by 100).
        region: Any truthy value selects the fixed box lat 60..24,
            lon 102..150; the value itself is not used.

    Returns:
        The reshaped dataset.
    """
    dataset = dataset.expand_dims(dim={'date': [dataset.time.values[0]]}).compute()
    dataset = dataset.rename({'time': 'ensemble'})
    dataset['ensemble'] = np.arange(1, 51)

    if target_var == "2m_temperature":
        dataset = dataset.drop_vars('height')

    elif target_var == "geopotential":
        dataset = dataset.assign_coords(lev=dataset.lev / 100)
        dataset.lev.attrs['units'] = 'hPa'
        dataset['gh'] = dataset['gh'] * 9.80665
        # Attach the metadata AFTER the scaling: arithmetic can discard attrs
        # (depending on xarray's keep_attrs setting), so setting them first
        # could silently lose them. This mirrors how `lev` is handled above.
        dataset.gh.attrs['units'] = 'm^2/s^2'
        dataset.gh.attrs['long_name'] = 'Geopotential'
        dataset = dataset.rename({'gh':'geopotential', 'lev':'level'})
        dataset = dataset.sel(level=500)

    # TODO: add region selection
    if region:
        dataset = dataset.sel(lat=slice(60, 24), lon=slice(102, 150))

    return dataset

def preprocess_era(dataset:xr.Dataset, target_var="2m_temperature", region=None):
    """Optionally crop an ERA dataset to the fixed analysis box.

    Args:
        dataset: Dataset with ``lat`` and ``lon`` coordinates.
        target_var: Currently unused; kept for a signature parallel to the
            other ``preprocess_*`` helpers.
        region: Any truthy value selects the fixed box lat 60..25,
            lon 102.5..150; the value itself is not used.

    Returns:
        The (possibly cropped) dataset. Previously the function fell off the
        end and returned ``None``.
    """
    if region:
        dataset = dataset.sel(lat=slice(60, 25), lon=slice(102.5, 150))

    return dataset

# Partition

1. **t2m**
    1. all + all + global  $= 4 \times 23 \times 1 = 92$  
    2. all + only t2m + all $= 4 \times 1 \times 12 = 48$  
    3. all + except t2m + all $= 4 \times 1 \times 12 = 48$  
2. **500hPa geopotential height**
    1. all + all + global  $= 4 \times 23 \times 1 = 92$  
    2. all + only geopotential + all $= 4 \times 1 \times 12 = 48$  
    3. all + except geopotential + all $= 4 \times 1 \times 12 = 48$

In [None]:
# File lists for the partitions of the 2021-06-21 GC run (see the Partition notes).
# p_1: any 11-character mask with the "global" region; same pattern for both targets.
if target_var in ('2m_temperature', 'geopotential'):
    p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))

# p_2 / p_3: one fixed mask and its bitwise complement, any region
# (presumably "only target" and "all except target" -- confirm against the run config).
if target_var == '2m_temperature':
    p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00100000000_*_scale*.nc'))
    p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11011111111_*_scale*.nc'))
elif target_var == 'geopotential':
    p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00000100000_*_scale*.nc'))
    p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111011111_*_scale*.nc'))

# p_4: all-ones mask with the "global" region, independent of target_var.
p_4 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111111111_global_scale_*.nc'))


# GC Data load

In [None]:
# Base color per partition; individual files get lighter/darker shades of it.
partition_colors = dict(
    p_1='blue',
    p_2='green',
    p_3='red',
    p_4='purple',
)

def extract_perturbation_info(filename):
    """Build a ``"<value>_<region>_<var>"`` tag from a perturbation file name.

    The name must end in ``_<11 chars of 0/1>_<region>_<scale|wipeout>_<value>.nc``.
    The perturbation type (``scale``/``wipeout``) is matched but is not part
    of the returned tag.

    Returns:
        The tag string, or ``None`` when the file name does not match.
    """
    found = re.search(r'_([01][01][01][01][01][01][01][01][01][01][01])_(.*?)_(scale|wipeout)_([\d.eE+-]+)\.nc$', filename)
    if found is None:
        return None

    var, region, _perturb_type, value = found.groups()
    return f"{value}_{region}_{var}"

# Build (label, color, path) triples for every file whose name can be parsed.
perturb_files = []
# To plot only p_4 instead, iterate over (('p_4', p_4),).
for partition_name, partition_files in (('p_1', p_1), ('p_2', p_2), ('p_3', p_3)):
    # One shade of the partition's base color per file; the two extreme
    # shades of the palette are discarded.
    shades = sns.light_palette(
        partition_colors[partition_name],
        n_colors=len(partition_files) + 2,
    )[1:-1]
    for idx, path in enumerate(partition_files):
        info = extract_perturbation_info(path)
        if not info:
            # File name did not match the expected pattern.
            continue
        perturb_files.append(
            (f"{partition_name} {info}", shades[idx % len(shades)], path)
        )

# Placeholder; replaced by the pool.map result at the end of this cell.
perturb_datasets = []

from multiprocessing import Pool

def process_file(file_info):
    """Open one GC file and reduce it to its weighted lat/lon mean.

    Args:
        file_info: ``(label, color, file)`` triple as stored in ``perturb_files``.

    Returns:
        ``(label, color, dataset)`` where ``dataset`` is the preprocessed,
        spatially averaged data for the module-level ``target_var``.
    """
    label, color, file = file_info
    dataset = weighted_mean(preprocess_GC(xr.open_dataset(file), target_var))
    return (label, color, dataset)

# Process all files with 30 worker processes; pool.map keeps the results in
# the same order as perturb_files.
with Pool(processes=30) as pool:
    perturb_datasets = pool.map(process_file, perturb_files)

# NWP Data load

In [None]:
def piping(dataset:xr.Dataset, target_var):
    """Preprocess one NWP dataset and reduce it to its weighted lat/lon mean."""
    prepared = preprocess_nwp(dataset, target_var)
    return weighted_mean(prepared)

from functools import partial
# Bind the current target_var so the callable takes only a dataset, which is
# how it is used as the `preprocess` hook of open_mfdataset below.
pipe = partial(piping, target_var = target_var)

# Collect the files of the 2021-06-21 forecast for the chosen variable.
if target_var == '2m_temperature':
    files = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/Dailyaveraged/t2m/nc/*/Temperature2m_2021-06-21.nc'))

elif target_var == 'geopotential':
    files = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/InstantaneousAccumulated/z/nc/*/Z_2021-06-21.nc'))

# Keep only files whose parent directory name, parsed as an integer, is <= 24 * 7
# (presumably the lead time in hours, i.e. the first 7 days -- confirm against the archive layout).
files = [f for f in files if int(f.split('/')[-2]) <= 24 * 7]
nwp = xr.open_mfdataset(
    files,
    combine='by_coords',
    preprocess=pipe
)
# Use the same variable name as the GC datasets.
if target_var == "2m_temperature":
    nwp = nwp.rename({"2t":"2m_temperature"})

nwp = nwp.compute()
# Long-format table of the target variable, used by the plotting cells.
df = nwp[target_var].to_dataframe().reset_index()

In [None]:

# # Plot ensemble members
# first_ensemble = True
# for ensemble in df['ensemble'].unique():
#     subset = df[df['ensemble'] == ensemble]
#     if first_ensemble:
#         plt.plot(subset['date'], subset['2m_temperature'], color='grey', linewidth=0.5, alpha=0.5, label='Ensemble Members')
#         first_ensemble = False
#     else:
#         plt.plot(subset['date'], subset['2m_temperature'], color='grey', linewidth=0.5, alpha=0.5)

# # Plot ensemble mean
# mean_temp = df.groupby('date')['2m_temperature'].mean().reset_index()
# plt.plot(mean_temp['date'], mean_temp['2m_temperature'], color='black', linewidth=1, label='Ensemble Mean')

# # Plot perturbation datasets
# for label, color, dataset in perturb_datasets:
#     plt.plot(dataset['date'], dataset['2m_temperature'], color=color, linewidth=1, label=label)

# # Optional: Plot ERA5 data if available
# # era5 = xr.open_dataset("/camdata2/ERA5/daily/t2m/2021.nc").rename({"time":"date", "latitude":"lat", "longitude":"lon"}).sel(date=slice("2021-06-22", "2021-07-01"))
# # era5 = weighted_mean(era5)
# # plt.plot(era5['date'], era5['t2m'], color='red', linewidth=1.5, linestyle='dashed', label='ERA5')

# plt.title('Mean 2m Temperature Forecast / 2021-06-21  + 10 days', fontsize=16)
# plt.xlabel('Date', fontsize=12)
# plt.ylabel('Temperature (K)', fontsize=12)
# plt.legend(fontsize=5)

# # Adjust y-axis limits based on all datasets
# all_temps = []
# for _, _, dataset in perturb_datasets:
#     all_temps.extend(dataset['2m_temperature'].values)
# all_temps.extend(df['2m_temperature'].values)
# y_min = min(all_temps) - 2
# y_max = max(all_temps) + 2
# plt.ylim(y_min, y_max)

# plt.tight_layout()
# plt.show()
# # plt.savefig('figure/2m_temperature_forecast_mean_2021-06-21.png')

In [None]:
# Static matplotlib figure: NWP ensemble (2m temperature only), its mean, and
# every GC perturbation run, all as spatially averaged daily series.
plt.figure(figsize=(16, 9))
sns.set_style("whitegrid")

# Plot ensemble members
if target_var == '2m_temperature':
    first_ensemble = True
    for ensemble in df['ensemble'].unique():
        subset = df[df['ensemble'] == ensemble]
        if first_ensemble:
            # Only the first member carries a label so a legend would list it once.
            plt.plot(subset['date'], subset[target_var], color='grey', linewidth=0.5, alpha=0.5, label='Ensemble Members')
            first_ensemble = False
        else:
            plt.plot(subset['date'], subset[target_var], color='grey', linewidth=0.5, alpha=0.5)

    # Plot the ensemble mean (average over members for each date). The line
    # labelled 'Ensemble Mean' used to be the first member drawn again in black.
    mean_temp = df.groupby('date')[target_var].mean().reset_index()
    plt.plot(mean_temp['date'], mean_temp[target_var], color='black', linewidth=1, label='Ensemble Mean')


# Plot perturbation datasets
for label, color, dataset in perturb_datasets:
    plt.plot(dataset['date'], dataset[target_var], color=color, linewidth=1, label=label)

# Optional: Plot ERA5 data if available
# era5 = xr.open_dataset("/camdata2/ERA5/daily/t2m/2021.nc").rename({"time":"date", "latitude":"lat", "longitude":"lon"}).sel(date=slice("2021-06-22", "2021-07-01"))
# era5 = weighted_mean(era5)
# plt.plot(era5['date'], era5[target_var], color='red', linewidth=1.5, linestyle='dashed', label='ERA5')



plt.xlabel('Date', fontsize=12)
if target_var == '2m_temperature':
    plt.title('Mean 2m Temperature Forecast / 2021-06-21  + 7 days', fontsize=16)
    plt.ylabel('Temperature (K)', fontsize=12)
elif target_var == 'geopotential':
    plt.title('Mean 500hPa geopotential Forecast / 2021-06-21  + 7 days', fontsize=16)
    plt.ylabel('Geopotential (m^2/s^2)', fontsize=12)


# Collect all perturbation values; currently unused because the y-limit
# adjustment below is commented out.
all_temps = []
for _, _, dataset in perturb_datasets:
    all_temps.extend(dataset[target_var].values)
# all_temps.extend(df[target_var].values)
# y_min = min(mean_temp[target_var]) - 1
# y_max = max(mean_temp[target_var]) + 1
# plt.ylim(y_min, y_max)
plt.xlim([pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')])

plt.tight_layout()
plt.show()

#

In [None]:
import plotly.graph_objs as go
import plotly.io as pio
import pandas as pd


# Interactive plotly version of the forecast figure: one grey trace per NWP
# ensemble member, the ensemble mean, and one trace per GC perturbation run.
ensemble_lines = []
ensemble_legend_shown = False  # Flag to control legend display
for ensemble in df['ensemble'].unique():
    subset = df[df['ensemble'] == ensemble]
    ensemble_lines.append(go.Scatter(
        x=subset['date'],
        y=subset[target_var],
        mode='lines',
        line=dict(color='grey', width=0.5),
        opacity=0.5,
        name='Ensemble Members' if not ensemble_legend_shown else None,
        showlegend=not ensemble_legend_shown,
        legendgroup='Ensemble Members',
        legendgrouptitle_text='Ensemble Members'
    ))
    ensemble_legend_shown = True  # Only show legend once

# Plot ensemble mean (average over members for each date)
mean_temp = df.groupby('date')[target_var].mean().reset_index()
ensemble_mean_line = go.Scatter(
    x=mean_temp['date'],
    y=mean_temp[target_var],
    mode='lines',
    line=dict(color='black', width=1),
    name='Ensemble Mean'
)

# Plot perturbation datasets
perturb_lines = []
partition_legend_shown = {}  # Tracks which names already have a legend entry
for label, color, dataset in perturb_datasets:
    # NOTE(review): this takes the SECOND whitespace-separated token of the
    # label. For labels of the form "<partition> <info>" that is the
    # perturbation info, not the partition name (which is label.split()[0],
    # used for legendgroup below). As written, the legend gets one entry per
    # distinct info string and raises IndexError for labels without a space.
    # Confirm whether [0] was intended.
    partition_name = label.split()[1]
    # Only show a legend entry the first time this name is seen
    if partition_name not in partition_legend_shown:
        show_legend = True
        partition_legend_shown[partition_name] = True
    else:
        show_legend = False
    # Convert an RGB tuple with components in 0..1 to a CSS 'rgb(r, g, b)' string
    if isinstance(color, tuple):
        color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'
    perturb_lines.append(go.Scatter(
        x=dataset['date'],
        y=dataset[target_var],
        mode='lines',
        line=dict(color=color, width=1),
        name=partition_name if show_legend else None,
        showlegend=show_legend,
        legendgroup=label.split()[0],
        legendgrouptitle_text=label.split()[0]
    ))

# Optional: Plot ERA5 data if available
# Uncomment and adjust accordingly if ERA5 data is available
# era5_line = go.Scatter(
#     x=era5['date'],
#     y=era5['t2m'],
#     mode='lines',
#     line=dict(color='red', width=1.5, dash='dash'),
#     name='ERA5',
#     legendgroup='ERA5',
#     legendgrouptitle_text='ERA5'
# )

# Combine all traces
all_traces = perturb_lines + ensemble_lines + [ensemble_mean_line] # + [era5_line]


# Title and y-axis label per target variable; `title` and `unit` are also
# read by later plotting cells.
if target_var == '2m_temperature':
    title = 'Mean 2m Temperature Forecast / 2021-06-21  + 7 days'
    unit = 'Temperature (K)'
elif target_var == 'geopotential':
    title = 'Mean 500hPa Geopotential Forecast / 2021-06-21  + 7 days'
    unit = 'Geopotential (m^2/s^2)'

# Create the layout
layout = go.Layout(
    title=title,
    xaxis=dict(title='Date', range=[pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')]),
    yaxis=dict(title=unit),
    margin=dict(l=40, r=40, t=40, b=40),
    height=900,  # Increased height for better visibility
    width=1600,  # Increased width for better visibility
    template='plotly_white',
    legend=dict(
        title='Legend',
        orientation='v',  # Vertical legend
        x=1.05,  # Position it just outside the right edge of the plot
        y=1,    # Align at the top
        itemsizing='constant',  # Makes the legend box size consistent
        # traceorder='grouped',  # Groups traces in the legend
        itemclick='toggle',  # Enables toggling traces on and off
        itemdoubleclick='toggleothers'  # Double-clicking will turn other traces off
    )
)

# Create the figure
fig = go.Figure(data=all_traces, layout=layout)

# Save the figure as an interactive HTML file
# fig.write_html("interactive_geopotential_forecast.html")

# Optional: Show the figure in the browser (still interactive)
pio.show(fig)


# Claude

In [31]:
import pickle

# Reload the cached GC results: a list of (label, color, dataset) tuples.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/data/GC_output/analysis/GC_t2m_GlobAvg.pkl", "rb") as f:
    perturb_datasets = pickle.load(f)

# NOTE(review): `nwp` is reloaded here but not used again in this cell. The
# ensemble traces below read `df`, and the layout reads `title` / `unit`;
# all three must already exist from earlier cells.
nwp = xr.open_dataset("/data/GC_output/analysis/nwp_t2m_GlobAvg.nc")

# Plot perturbation datasets with enhanced legend grouping
perturb_lines = []
partition_groups = {}

for label, color, dataset in perturb_datasets:
    # Labels are expected to look like "<partition> <info>"; the first token
    # is the legend group and the second the trace name.
    partition_name = label.split()[0]
    if partition_name not in partition_groups:
        partition_groups[partition_name] = []
    
    # Convert color to RGB if it's a tuple
    if isinstance(color, tuple):
        color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'
    
    # Create trace; only the first trace of each group carries the group title
    trace = go.Scatter(
        x=dataset['date'],
        y=dataset[target_var],
        mode='lines',
        line=dict(color=color, width=1),
        name=label.split(' ')[1],
        legendgroup=partition_name,
        legendgrouptitle=dict(
            text=partition_name,
            font=dict(size=14, color='black', family='Arial Bold')
        ) if len(partition_groups[partition_name]) == 0 else None,
        showlegend=True
    )
    partition_groups[partition_name].append(trace)
    perturb_lines.append(trace)

# Add ensemble lines (appended to the same trace list as the perturbation runs)
ensemble_legend_shown = False  # Flag to control legend display
for ensemble in df['ensemble'].unique():
    subset = df[df['ensemble'] == ensemble]
    trace = go.Scatter(
        x=subset['date'],
        y=subset[target_var],
        mode='lines',
        line=dict(color='grey', width=0.5),
        opacity=0.5,
        name='Ensemble Members' if not ensemble_legend_shown else None,
        showlegend=not ensemble_legend_shown,
        legendgroup='Ensemble Members',
        legendgrouptitle=dict(
            text='Ensemble Members',
            font=dict(size=14, color='black', family='Arial Bold')
        ) if not ensemble_legend_shown else None,
    )
    perturb_lines.append(trace)
    ensemble_legend_shown = True  # Only show legend once

# Add ensemble mean line (average over members for each date)
mean_temp = df.groupby('date')[target_var].mean().reset_index()
ensemble_mean_line = go.Scatter(
    x=mean_temp['date'],
    y=mean_temp[target_var],
    mode='lines',
    line=dict(color='black', width=2),
    name='Ensemble Mean',
    legendgroup='Ensemble Mean',
    legendgrouptitle=dict(
        text='Ensemble Mean',
        font=dict(size=14, color='black', family='Arial Bold')
    ),
    showlegend=True
)
perturb_lines.append(ensemble_mean_line)

# Calculate y-axis range: full data range padded by 5 % on each side
y_values = []
for _, _, dataset in perturb_datasets:
    y_values.extend(dataset[target_var].values)
y_values.extend(df[target_var].values)  # Include ensemble data
y_min = np.min(y_values)
y_max = np.max(y_values)
y_range = y_max - y_min
y_padding = y_range * 0.05

# Create the layout
layout = go.Layout(
    title=title,
    xaxis=dict(
        title='Date', 
        range=[pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')]
    ),
    yaxis=dict(
        title=unit,
        range=[y_min - y_padding, y_max + y_padding],
        tickformat='.0f'
    ),
    margin=dict(l=40, r=40, t=40, b=40),
    height=900,
    width=1600,
    template='plotly_white',
    legend=dict(
        title=dict(
            text='Legend',
            font=dict(size=16)
        ),
        orientation='v',
        x=1.05,
        y=1,
        itemsizing='constant',
        groupclick='toggleitem',
        itemclick='toggle',
        itemdoubleclick='toggleothers',
        tracegroupgap=15,
        font=dict(size=12),
        grouptitlefont=dict(size=14, color='black', family='Arial Bold'),
        borderwidth=1,
        bordercolor='rgba(0,0,0,0.2)',
        bgcolor='rgba(255,255,255,0.95)',
        traceorder='grouped'
    )
)

# Create the figure
fig = go.Figure(data=perturb_lines, layout=layout)

# Save and display the figure (the HTML file is written to the working directory)
fig.write_html(f"interactive_{target_var}_forecast.html")
pio.show(fig)


# Error Analysis

In [None]:
# from error_accum import compute_error_accumulation_jax

# results_df = compute_error_accumulation_jax(
#     nwp_data=nwp,
#     perturb_datasets=[item[2] for item in perturb_datasets],
#     target_var='geopotential'  # or 'z500'
# )

# print(results_df)

# Scratch Paper

In [None]:
# Scratch variant of the grouped-legend plotly figure. It reads
# `perturb_datasets`, `df`, `target_var`, `title` and `unit` from earlier cells.
# Plot perturbation datasets with enhanced legend grouping
perturb_lines = []
partition_groups = {}

for label, color, dataset in perturb_datasets:
    partition_name = label.split()[0]
    if partition_name not in partition_groups:
        partition_groups[partition_name] = []
    
    # Convert color to RGB if it's a tuple
    if isinstance(color, tuple):
        color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'
    
    # Create trace; here the trace name is the FIRST token of the label,
    # i.e. the same text as the legend group.
    trace = go.Scatter(
        x=dataset['date'],
        y=dataset[target_var],
        mode='lines',
        line=dict(color=color, width=1),
        name=label.split(' ')[0],
        legendgroup=partition_name,
        legendgrouptitle=dict(
            text=partition_name,
            font=dict(size=14, color='black', family='Arial Bold')
        ) if len(partition_groups[partition_name]) == 0 else None,
        showlegend=True
    )
    partition_groups[partition_name].append(trace)
    perturb_lines.append(trace)

# Add ensemble lines (appended to the same trace list as the perturbation runs)
ensemble_legend_shown = False  # Flag to control legend display
for ensemble in df['ensemble'].unique():
    subset = df[df['ensemble'] == ensemble]
    trace = go.Scatter(
        x=subset['date'],
        y=subset[target_var],
        mode='lines',
        line=dict(color='grey', width=0.5),
        opacity=0.5,
        name='Ensemble Members' if not ensemble_legend_shown else None,
        showlegend=not ensemble_legend_shown,
        legendgroup='Ensemble Members',
        legendgrouptitle=dict(
            text='Ensemble Members',
            font=dict(size=14, color='black', family='Arial Bold')
        ) if not ensemble_legend_shown else None,
    )
    perturb_lines.append(trace)
    ensemble_legend_shown = True  # Only show legend once

# Add ensemble mean line (average over members for each date)
mean_temp = df.groupby('date')[target_var].mean().reset_index()
ensemble_mean_line = go.Scatter(
    x=mean_temp['date'],
    y=mean_temp[target_var],
    mode='lines',
    line=dict(color='black', width=2),
    name='Ensemble Mean',
    legendgroup='Ensemble Mean',
    legendgrouptitle=dict(
        text='Ensemble Mean',
        font=dict(size=14, color='black', family='Arial Bold')
    ),
    showlegend=True
)
perturb_lines.append(ensemble_mean_line)

# Calculate y-axis range: full data range padded by 5 % on each side
y_values = []
for _, _, dataset in perturb_datasets:
    y_values.extend(dataset[target_var].values)
y_values.extend(df[target_var].values)  # Include ensemble data
y_min = np.min(y_values)
y_max = np.max(y_values)
y_range = y_max - y_min
y_padding = y_range * 0.05

# Create the layout
layout = go.Layout(
    title=title,
    xaxis=dict(
        title='Date', 
        range=[pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')]
    ),
    yaxis=dict(
        title=unit,
        range=[y_min - y_padding, y_max + y_padding],
        tickformat='.0f'
    ),
    margin=dict(l=40, r=40, t=40, b=40),
    height=900,
    width=1600,
    template='plotly_white',
    legend=dict(
        title=dict(
            text='Legend',
            font=dict(size=16)
        ),
        orientation='v',
        x=1.05,
        y=1,
        itemsizing='constant',
        groupclick='toggleitem',
        itemclick='toggle',
        itemdoubleclick='toggleothers',
        tracegroupgap=15,
        font=dict(size=12),
        grouptitlefont=dict(size=14, color='black', family='Arial Bold'),
        borderwidth=1,
        bordercolor='rgba(0,0,0,0.2)',
        bgcolor='rgba(255,255,255,0.95)',
        traceorder='grouped'
    )
)

# Create the figure
fig = go.Figure(data=perturb_lines, layout=layout)

# Save and display the figure (the HTML file is written to the working directory)
fig.write_html(f"interactive_{target_var}_forecast.html")
pio.show(fig)


In [None]:
# Cache whatever is currently held in `nwp` and `perturb_datasets`.
# NOTE(review): the file names say t2m, so this presumably has to run while
# target_var is still "2m_temperature" -- confirm before re-running.
nwp.to_netcdf('/data/GC_output/analysis/nwp_t2m_GlobAvg.nc')

import pickle

# perturb_datasets is stored as a pickled list of (label, color, dataset) tuples.
with open('/data/GC_output/analysis/GC_t2m_GlobAvg.pkl', 'wb') as f:
    pickle.dump(perturb_datasets, f)

# Switch the notebook-wide target to geopotential; this rebinds the global
# `target_var` that other cells and process_file read.
target_var = "geopotential"

# Re-create the preprocess hook so it carries the new target_var.
pipe = partial(piping, target_var = target_var)

files = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/InstantaneousAccumulated/z/nc/*/Z_2021-06-21.nc'))
# Keep only files whose parent directory name, parsed as an integer, is <= 24 * 7
# (presumably the lead time in hours -- confirm against the archive layout).
files = [f for f in files if int(f.split('/')[-2]) <= 24 * 7]
nwp = xr.open_mfdataset(
    files,
    combine='by_coords',
    preprocess=pipe
)
# Never true here because target_var was just set to "geopotential"; kept as in the loading cell.
if target_var == "2m_temperature":
    nwp = nwp.rename({"2t":"2m_temperature"})

nwp = nwp.compute()

# Cache the spatially averaged NWP geopotential.
nwp.to_netcdf('/data/GC_output/analysis/nwp_z500_GlobAvg.nc')


# Rebuild the GC file lists for geopotential, process them and cache the result.
if target_var == 'geopotential':
    p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))
    p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00000100000_*_scale*.nc'))
    p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111011111_*_scale*.nc')) 
# Assign base colors for each partition
partition_colors = {
    'p_1': 'blue',
    'p_2': 'green',
    'p_3': 'red',
    'p_4': 'purple'
}

# Collect perturbation files with labels and colors
perturb_files = []
for partition_name, partition_files in zip(['p_1', 'p_2', 'p_3'], [p_1, p_2, p_3]):
# for partition_name, partition_files in zip(['p_4'], [p_4]):
    base_color = partition_colors[partition_name]
    num_files = len(partition_files)
    # Generate different shades of the base color
    colors = sns.light_palette(base_color, n_colors=num_files + 2)[1:-1]
    for i, file in enumerate(partition_files):
        perturb_info = extract_perturbation_info(file)
        if perturb_info:
            # NOTE(review): unlike the collection loop in the "GC Data load"
            # cell, the label here has no "<partition> " prefix, so code that
            # reads label.split()[1] will fail on these labels. Confirm this
            # is intended.
            label = f"{perturb_info}"
            color = colors[i % len(colors)]
            perturb_files.append((label, color, file))

# Placeholder; replaced by the pool.map result below.
perturb_datasets = []

def process_file(file_info):
    """Open one GC file and reduce it to its weighted lat/lon mean.

    Takes a ``(label, color, file)`` triple and returns
    ``(label, color, dataset)`` for the module-level ``target_var``.
    """
    label, color, file = file_info
    dataset = weighted_mean(preprocess_GC(xr.open_dataset(file), target_var))
    return (label, color, dataset)

# Process all files with 30 worker processes.
with Pool(processes=30) as pool:
    perturb_datasets = pool.map(process_file, perturb_files)

# Cache the list of (label, color, dataset) tuples for geopotential.
with open('/data/GC_output/analysis/GC_z500_GlobAvg.pkl', 'wb') as f:
    pickle.dump(perturb_datasets, f)

In [None]:
# Bare expression: shows the current perturb_datasets list as the cell output.
perturb_datasets
