#### Header

No need to change I guess....

In [4]:
import glob
import re
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.io as pio
import dask
from dask import delayed, compute
from dask.distributed import Client

PRESSURE_VARIABLES=[
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "temperature",
    "vertical_velocity"
]

SURFACE_VARIABLES=[
    "10m_u_component_of_wind", 
    "10m_v_component_of_wind", 
    "total_precipitation_6hr", 
    "mean_sea_level_pressure"
]

# Define the weighted_mean function (ensure it's properly defined in your code)
def weighted_mean(dataset: xr.Dataset):
    # for GC_* files
    if "level" in dataset.dims:
        dataset = dataset.resample(time="1D").mean().squeeze('batch')
        dataset = dataset.drop_vars(PRESSURE_VARIABLES + SURFACE_VARIABLES + ['level'])
        dataset["time"] = pd.date_range("2021-06-22", periods=7, freq="1D")
        dataset = dataset.rename({"time":"date"})
        # dataset = dataset.sel(lat=slice(25, 60), lon=slice(102.5, 150))
    
    # for nwp files
    elif "time" in dataset.dims:
        dataset = dataset.expand_dims(dim={'date': [dataset.time.values[0]]}).compute()
        dataset = dataset.rename({'time': 'ensemble'})
        dataset['ensemble'] = np.arange(1, 51)
        dataset = dataset.drop_vars('height')
        # dataset = dataset.sel(lat=slice(60, 24), lon=slice(102, 150))
    
    # for era5 files
    # else:
        # dataset = dataset.sel(lat=slice(60, 25), lon=slice(102.5, 150))
    
    weights = np.cos(np.deg2rad(dataset.lat))
    weights.name = "weights"
    weighted = dataset.weighted(weights)
    
    return weighted.mean(('lat', 'lon'))

client = Client()
dask.config.set({"array.query-planning": False})


Perhaps you already have a cluster running?
Hosting the HTTP server on port 33723 instead


<dask.config.set at 0x14b4b8ee5690>

#### Partition samples in reasonable way

1. **t2m**
    1. all + all + global  $= 4 \times 23 \times 1 = 92$  
    2. all + only t2m + all $= 4 \times 1 \times 12 = 48$  
    3. all + except t2m + all $= 4 \times 1 \times 12 = 48$  
2. **500hPa geopotential height**
    1. all + all + global  $= 4 \times 23 \times 1 = 92$  
    2. all + only geopotential + all $= 4 \times 1 \times 12 = 48$  
    3. all + except geopotential + all $= 4 \times 1 \times 12 = 48$

In [5]:
target_var = '2m_temperature'

if target_var == '2m_temperature':
    p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))
    p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00100000000_*_scale*.nc'))
    p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11011111111_*_scale*.nc'))
    
elif target_var == 'geopotential':
    p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))
    p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00000100000_*_scale*.nc'))
    p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111011111_*_scale*.nc'))    

p_4 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111111111_global_scale_*.nc'))
len(p_2)

files_t2m = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/Dailyaveraged/t2m/nc/*/Temperature2m_2021-06-21.nc'))
files_t2m = [f for f in files_t2m if int(f.split('/')[-2]) <= 24 * 7]

# Multithreading test code 

정말 빨라지는지 확인해보자ㅏㅇ

In [6]:
# Assign base colors for each partition
partition_colors = {
    # 'p_1': 'blue',
    # 'p_2': 'green',
    # 'p_3': 'red',
    'p_4': 'purple'
}

# Function to extract perturbation type and value from filename
def extract_perturbation_info(filename):
    match = re.search(r'_([01][01][01][01][01][01][01][01][01][01][01])_(.*?)_(scale|wipeout)_([\d.eE+-]+)\.nc$', filename)
    if match:
        var=match.group(1)
        region=match.group(2)
        perturb_type = match.group(3)
        value = match.group(4)
        return f"{var}_{region}_{perturb_type}_{value}"
    else:
        return None

# Collect perturbation files with labels and colors
perturb_files = []
# for partition_name, partition_files in zip(['p_1', 'p_2', 'p_3'], [p_1, p_2, p_3]):
for partition_name, partition_files in zip(['p_4'], [p_4]):
    base_color = partition_colors[partition_name]
    num_files = len(partition_files)
    # Generate different shades of the base color
    colors = sns.light_palette(base_color, n_colors=num_files + 2)[1:-1]
    for i, file in enumerate(partition_files):
        perturb_info = extract_perturbation_info(file)
        if perturb_info:
            label = f"{partition_name} {perturb_info}"
            color = colors[i % len(colors)]
            perturb_files.append((label, color, file))

def process_file(file):
    dataset = xr.open_dataset(file, chunks={'latitude': 1000, 'longitude': 1000})
    dataset = weighted_mean(dataset)
    return dataset

# Prepare delayed tasks
delayed_tasks = []
for label, color, file in perturb_files:
    delayed_task = delayed(process_file)(file)
    delayed_tasks.append((label, color, delayed_task))

# Compute datasets in parallel
results = compute(*[task for _, _, task in delayed_tasks])

# Reconstruct perturb_datasets with computed datasets
perturb_datasets = [(label, color, result) for (label, color, _), result in zip(delayed_tasks, results)]

# Process nwp data
nwp = xr.open_mfdataset(
    files_t2m,
    combine='by_coords',
    chunks={'latitude': 1440, 'longitude': 720},
    preprocess=weighted_mean
)
nwp = nwp.rename({"2t": "2m_temperature"}).compute()

2024-11-12 12:32:08,214 - distributed.worker - ERROR - Compute Failed
Key:       process_file-fea8f669-e941-4086-978a-5f13aeb5fd22
State:     executing
Function:  process_file
args:      ('/data/GC_output/2021-06-21/GC_11111111111_global_scale_1.nc')
kwargs:    {}
Exception: 'TypeError("manager must be a string or instance of ChunkManagerEntrypoint, but received type <class \'xarray.core.daskmanager.DaskManager\'>")'
Traceback: '  File "/tmp/ipykernel_1635224/2442268749.py", line 37, in process_file\n  File "/home/hiskim1/.conda/envs/hiskim1_graphcast/lib/python3.11/site-packages/xarray/backends/api.py", line 677, in open_dataset\n    ds = _dataset_from_backend_dataset(\n         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/home/hiskim1/.conda/envs/hiskim1_graphcast/lib/python3.11/site-packages/xarray/backends/api.py", line 401, in _dataset_from_backend_dataset\n    ds = _chunk_ds(\n         ^^^^^^^^^^\n  File "/home/hiskim1/.conda/envs/hiskim1_graphcast/lib/python3.11/site-packages/xarray

TypeError: manager must be a string or instance of ChunkManagerEntrypoint, but received type <class 'xarray.core.daskmanager.DaskManager'>

In [None]:

# Assign base colors for each partition
partition_colors = {
    # 'p_1': 'blue',
    # 'p_2': 'green',
    # 'p_3': 'red',
    'p_4': 'purple'
}

# Function to extract perturbation type and value from filename
def extract_perturbation_info(filename):
    match = re.search(r'_([01][01][01][01][01][01][01][01][01][01][01])_(.*?)_(scale|wipeout)_([\d.eE+-]+)\.nc$', filename)
    if match:
        var=match.group(1)
        region=match.group(2)
        perturb_type = match.group(3)
        value = match.group(4)
        return f"{var}_{region}_{perturb_type}_{value}"
    else:
        return None

# Collect perturbation files with labels and colors
perturb_files = []
# for partition_name, partition_files in zip(['p_1', 'p_2', 'p_3'], [p_1, p_2, p_3]):
for partition_name, partition_files in zip(['p_4'], [p_4]):
    base_color = partition_colors[partition_name]
    num_files = len(partition_files)
    # Generate different shades of the base color
    colors = sns.light_palette(base_color, n_colors=num_files + 2)[1:-1]
    for i, file in enumerate(partition_files):
        perturb_info = extract_perturbation_info(file)
        if perturb_info:
            label = f"{partition_name} {perturb_info}"
            color = colors[i % len(colors)]
            perturb_files.append((label, color, file))

# Preprocess the perturbation datasets
perturb_datasets = []
for label, color, file in perturb_files:
    dataset = weighted_mean(xr.open_dataset(file))
    perturb_datasets.append((label, color, dataset))

files_t2m = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/Dailyaveraged/t2m/nc/*/Temperature2m_2021-06-21.nc'))
files_t2m = [f for f in files_t2m if int(f.split('/')[-2]) <= 24 * 7]

nwp = xr.open_mfdataset(files_t2m, combine='by_coords', preprocess=weighted_mean)
nwp = nwp.rename({"2t":"2m_temperature"}).compute()
df = nwp[target_var].to_dataframe().reset_index()

In [None]:

# # Plot ensemble members
# first_ensemble = True
# for ensemble in df['ensemble'].unique():
#     subset = df[df['ensemble'] == ensemble]
#     if first_ensemble:
#         plt.plot(subset['date'], subset['2m_temperature'], color='grey', linewidth=0.5, alpha=0.5, label='Ensemble Members')
#         first_ensemble = False
#     else:
#         plt.plot(subset['date'], subset['2m_temperature'], color='grey', linewidth=0.5, alpha=0.5)

# # Plot ensemble mean
# mean_temp = df.groupby('date')['2m_temperature'].mean().reset_index()
# plt.plot(mean_temp['date'], mean_temp['2m_temperature'], color='black', linewidth=1, label='Ensemble Mean')

# # Plot perturbation datasets
# for label, color, dataset in perturb_datasets:
#     plt.plot(dataset['date'], dataset['2m_temperature'], color=color, linewidth=1, label=label)

# # Optional: Plot ERA5 data if available
# # era5 = xr.open_dataset("/camdata2/ERA5/daily/t2m/2021.nc").rename({"time":"date", "latitude":"lat", "longitude":"lon"}).sel(date=slice("2021-06-22", "2021-07-01"))
# # era5 = weighted_mean(era5)
# # plt.plot(era5['date'], era5['t2m'], color='red', linewidth=1.5, linestyle='dashed', label='ERA5')

# plt.title('Mean 2m Temperature Forecast / 2021-06-21  + 10 days', fontsize=16)
# plt.xlabel('Date', fontsize=12)
# plt.ylabel('Temperature (K)', fontsize=12)
# plt.legend(fontsize=5)

# # Adjust y-axis limits based on all datasets
# all_temps = []
# for _, _, dataset in perturb_datasets:
#     all_temps.extend(dataset['2m_temperature'].values)
# all_temps.extend(df['2m_temperature'].values)
# y_min = min(all_temps) - 2
# y_max = max(all_temps) + 2
# plt.ylim(y_min, y_max)

# plt.tight_layout()
# plt.show()
# # plt.savefig('figure/2m_temperature_forecast_mean_2021-06-21.png')

In [None]:
plt.figure(figsize=(16, 9))
sns.set_style("whitegrid")

# Plot ensemble members
first_ensemble = True
for ensemble in df['ensemble'].unique():
    subset = df[df['ensemble'] == ensemble]
    if first_ensemble:
        plt.plot(subset['date'], subset[target_var], color='grey', linewidth=0.5, alpha=0.5, label='Ensemble Members')
        first_ensemble = False
    else:
        plt.plot(subset['date'], subset[target_var], color='grey', linewidth=0.5, alpha=0.5)

# Plot ensemble mean
mean_temp = df.groupby('date')[target_var].mean().reset_index()
plt.plot(mean_temp['date'], mean_temp[target_var], color='black', linewidth=1, label='Ensemble Mean')

# Plot perturbation datasets
for label, color, dataset in perturb_datasets:
    plt.plot(dataset['date'], dataset[target_var], color=color, linewidth=1, label=label)

# Optional: Plot ERA5 data if available
# era5 = xr.open_dataset("/camdata2/ERA5/daily/t2m/2021.nc").rename({"time":"date", "latitude":"lat", "longitude":"lon"}).sel(date=slice("2021-06-22", "2021-07-01"))
# era5 = weighted_mean(era5)
# plt.plot(era5['date'], era5['t2m'], color='red', linewidth=1.5, linestyle='dashed', label='ERA5')

plt.title('Mean 2m Temperature Forecast / 2021-06-21  + 7 days', fontsize=16)
plt.xlabel('Date', fontsize=12)
plt.ylabel('Temperature (K)', fontsize=12)


# Adjust y-axis limits based on all datasets
all_temps = []
for _, _, dataset in perturb_datasets:
    all_temps.extend(dataset[target_var].values)
all_temps.extend(df[target_var].values)
y_min = min(mean_temp[target_var]) - 1
y_max = max(mean_temp[target_var]) + 1
plt.ylim(y_min, y_max)
plt.xlim([pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')])

plt.tight_layout()
plt.show()

#

In [None]:
import plotly.graph_objs as go
import plotly.io as pio
import pandas as pd

# Plot ensemble members
ensemble_lines = []
ensemble_legend_shown = False  # Flag to control legend display
for ensemble in df['ensemble'].unique():
    subset = df[df['ensemble'] == ensemble]
    ensemble_lines.append(go.Scatter(
        x=subset['date'],
        y=subset[target_var],
        mode='lines',
        line=dict(color='grey', width=0.5),
        opacity=0.5,
        name='Ensemble Members' if not ensemble_legend_shown else None,
        showlegend=not ensemble_legend_shown,
        legendgroup='Ensemble Members',
        legendgrouptitle_text='Ensemble Members'
    ))
    ensemble_legend_shown = True  # Only show legend once

# Plot ensemble mean
mean_temp = df.groupby('date')[target_var].mean().reset_index()
ensemble_mean_line = go.Scatter(
    x=mean_temp['date'],
    y=mean_temp[target_var],
    mode='lines',
    line=dict(color='black', width=1),
    name='Ensemble Mean'
)

# Plot perturbation datasets
perturb_lines = []
partition_legend_shown = {}  # Dictionary to track legend entries per partition
for label, color, dataset in perturb_datasets:
    # Extract partition name from label (assuming label starts with partition name)
    partition_name = label.split()[0]
    # Only show legend once per partition
    if partition_name not in partition_legend_shown:
        show_legend = True
        partition_legend_shown[partition_name] = True
    else:
        show_legend = False
    # Convert color to a valid format if it's a tuple
    if isinstance(color, tuple):
        color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'
    perturb_lines.append(go.Scatter(
        x=dataset['date'],
        y=dataset[target_var],
        mode='lines',
        line=dict(color=color, width=1),
        name=partition_name if show_legend else None,
        showlegend=show_legend,
        legendgroup=partition_name,
        legendgrouptitle_text=partition_name
    ))

# Optional: Plot ERA5 data if available
# Uncomment and adjust accordingly if ERA5 data is available
# era5_line = go.Scatter(
#     x=era5['date'],
#     y=era5['t2m'],
#     mode='lines',
#     line=dict(color='red', width=1.5, dash='dash'),
#     name='ERA5',
#     legendgroup='ERA5',
#     legendgrouptitle_text='ERA5'
# )

# Combine all traces
all_traces = ensemble_lines + [ensemble_mean_line] + perturb_lines  # + [era5_line] if ERA5 data is included

# Create the layout
layout = go.Layout(
    title='Mean 2m Temperature Forecast / 2021-06-21  + 7 days',
    xaxis=dict(title='Date', range=[pd.Timestamp('2021-06-22'), pd.Timestamp('2021-06-28')]),
    yaxis=dict(title='Temperature (K)'),
    margin=dict(l=40, r=40, t=40, b=40),
    height=800,  # Increased height for better visibility
    width=1200,  # Increased width for better visibility
    template='plotly_white',
    legend=dict(
        title='Legend',
        orientation='v',  # Vertical legend
        x=1.05,  # Position it just outside the right edge of the plot
        y=1,    # Align at the top
        itemsizing='constant',  # Makes the legend box size consistent
        traceorder='grouped',  # Groups traces in the legend
        itemclick='toggle',  # Enables toggling traces on and off
        itemdoubleclick='toggleothers'  # Double-clicking will turn other traces off
    )
)

# Create the figure
fig = go.Figure(data=all_traces, layout=layout)

# Save the figure as an interactive HTML file
fig.write_html("interactive_temperature_forecast.html")

# Optional: Show the figure in the browser (still interactive)
pio.show(fig)


# Claude

In [None]:
import glob
import re
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import concurrent.futures
import dask
from dask.distributed import Client, LocalCluster
import warnings
warnings.filterwarnings('ignore')

PRESSURE_VARIABLES = [
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "temperature",
    "vertical_velocity"
]

SURFACE_VARIABLES = [
    "10m_u_component_of_wind", 
    "10m_v_component_of_wind", 
    "total_precipitation_6hr", 
    "mean_sea_level_pressure"
]

def weighted_mean(dataset: xr.Dataset):
    if "level" in dataset.dims:
        dataset = dataset.resample(time="1D").mean().squeeze('batch')
        dataset = dataset.drop_vars(PRESSURE_VARIABLES + SURFACE_VARIABLES + ['level'])
        dataset["time"] = pd.date_range("2021-06-22", periods=7, freq="1D")
        dataset = dataset.rename({"time":"date"})
    
    elif "time" in dataset.dims:
        dataset = dataset.expand_dims(dim={'date': [dataset.time.values[0]]}).compute()
        dataset = dataset.rename({'time': 'ensemble'})
        dataset['ensemble'] = np.arange(1, 51)
        dataset = dataset.drop_vars('height')
    
    weights = np.cos(np.deg2rad(dataset.lat))
    weights.name = "weights"
    weighted = dataset.weighted(weights)
    
    return weighted.mean(('lat', 'lon'))

def process_single_file(file_info):
    """단일 파일을 처리하는 함수"""
    label, color, file = file_info
    try:
        # lazy loading을 위해 chunks parameter 사용
        with xr.open_dataset(file, chunks={'time': 1, 'lat': 100, 'lon': 100}) as ds:
            result = weighted_mean(ds)
        return (label, color, result)
    except Exception as e:
        print(f"Error processing {file}: {str(e)}")
        return None

def main():
    # Dask 클러스터 설정
    cluster = LocalCluster(n_workers=4, threads_per_worker=2, memory_limit='8GB')
    client = Client(cluster)
    print(f"Dashboard link: {client.dashboard_link}")

    target_var = '2m_temperature'
    
    # 파일 패턴 정의
    if target_var == '2m_temperature':
        p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))
        p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00100000000_*_scale*.nc'))
        p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11011111111_*_scale*.nc'))
    elif target_var == 'geopotential':
        p_1 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_???????????_global_scale*.nc'))
        p_2 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_00000100000_*_scale*.nc'))
        p_3 = sorted(glob.glob('/data/GC_output/2021-06-21/GC_11111011111_*_scale*.nc'))

    # 파티션 색상 정의
    partition_colors = {
        'p_1': 'blue',
        'p_2': 'green',
        'p_3': 'red',
    }

    def extract_perturbation_info(filename):
        match = re.search(r'_([01][01][01][01][01][01][01][01][01][01][01])_(.*?)_(scale|wipeout)_([\d.eE+-]+)\.nc$', filename)
        if match:
            var = match.group(1)
            region = match.group(2)
            perturb_type = match.group(3)
            value = match.group(4)
            return f"{var}_{region}_{perturb_type}_{value}"
        return None

    # 파일 정보 수집
    perturb_files = []
    for partition_name, partition_files in zip(['p_1', 'p_2', 'p_3'], [p_1, p_2, p_3]):
        base_color = partition_colors[partition_name]
        num_files = len(partition_files)
        colors = sns.light_palette(base_color, n_colors=num_files + 2)[1:-1]
        
        for i, file in enumerate(partition_files):
            perturb_info = extract_perturbation_info(file)
            if perturb_info:
                label = f"{partition_name} {perturb_info}"
                color = colors[i % len(colors)]
                perturb_files.append((label, color, file))

    # 병렬 처리 실행
    perturb_datasets = []
    chunk_size = 10  # 한 번에 처리할 파일 수
    
    for i in range(0, len(perturb_files), chunk_size):
        chunk = perturb_files[i:i + chunk_size]
        
        # 청크 단위로 병렬 처리
        futures = [dask.delayed(process_single_file)(file_info) for file_info in chunk]
        results = dask.compute(*futures)
        
        # 결과 수집
        for result in results:
            if result is not None:
                perturb_datasets.append(result)
        
        print(f"Processed files {i+1} to {min(i+chunk_size, len(perturb_files))} of {len(perturb_files)}")

    # NWP 파일 처리
    files_t2m = sorted(glob.glob('/geodata2/S2S/ECMWF_Perturbed/Dailyaveraged/t2m/nc/*/Temperature2m_2021-06-21.nc'))
    files_t2m = [f for f in files_t2m if int(f.split('/')[-2]) <= 24 * 10]
    
    # NWP 파일도 청크로 나눠서 처리
    chunks = {'time': 1, 'lat': 100, 'lon': 100}
    nwp = xr.open_mfdataset(files_t2m, 
                           combine='by_coords', 
                           preprocess=weighted_mean,
                           parallel=True,
                           chunks=chunks)
    
    nwp = nwp.rename({"2t": "2m_temperature"})
    df = nwp[target_var].to_dataframe().reset_index()

    # 클러스터 종료
    client.close()
    cluster.close()

    return perturb_datasets, df

if __name__ == "__main__":
    perturb_datasets, df = main()