# Header

### Goal
perturbation의 area / intensity와 forecast가 퍼진 정도(spread) 간에 연관성이 있는가?

1. 기존과 같이, global t2m & z500 mean을 비교  
    1. date 설정: 6일 뒤; 2021-06-27일  
    2. GC: GC - original 
    3. NWP: NWP - mean  
    4. 새로운 file로 저장

2. perturbation region size와 intensity간의 연관성 확인
    1. perturbation region size point 계산하기
    2. perturbation intensity별로 plot
    


In [7]:
import glob
import re
import xarray as xr
import numpy as np
import pandas as pd
import pickle
from his_preprocess import *

# Load Data

In [8]:
#@ load data

# Variable name in the datasets -> short tag used in the input file names.
target_var = {"2m_temperature": "t2m", "geopotential": "z500"}
target_date = 6 # 2021-06-27

# nwp: NWP ensemble (50 members per the original note), stored as member - ensemble mean
# nwp_mean: mean over the "ensemble" dimension
nwp, nwp_mean = {}, {}

# gc_sample: perturbed GC runs (94 per variable, 2 variables, per the original note)
# gc_original: weighted mean of the unperturbed GC run, one entry per variable
gc_sample, gc_original = {}, {}

for key, value in target_var.items():
    # The per-date selection (.isel(date=target_date)) is intentionally left
    # disabled, so every date in the file is kept.
    members = xr.open_dataset(f"/data/GC_output/analysis/regional/nwp_{value}_GlobAvg.nc")
    nwp_mean[key] = members.mean(dim="ensemble")
    members -= nwp_mean[key]
    nwp[key] = members

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f"/data/GC_output/analysis/regional/GC_{value}_GlobAvg.pkl", "rb") as f:
        gc_sample[key] = pickle.load(f)

    # The unperturbed run is re-opened for every variable before preprocessing.
    unperturbed = xr.open_dataset("/data/GC_output/2021-06-21/GC_output.nc")
    gc_original[key] = weighted_mean(preprocess_GC(unperturbed, key)[key])


In [3]:
#@ gc_sample 추가 processing

from concurrent.futures import ThreadPoolExecutor
import itertools

# Same variable-name -> short-tag mapping as in the load-data cell; only its
# keys are used in the rest of this cell.
target_var = {"2m_temperature": "t2m", "geopotential": "z500"}

def process_item(args):
    """Turn one raw GC sample into its difference from the unperturbed run.

    Reads the module-level ``gc_sample`` and ``gc_original`` dictionaries.

    Args:
        args: ``(variable_name, sample_index)`` pair.

    Returns:
        ``None`` when ``sample_index`` is past the end of that variable's
        sample list; otherwise ``(variable_name, (filename, color, value))``
        where ``value`` is the ``.values`` of ``sample - gc_original``
        selected by ``variable_name``.
    """
    var_name, idx = args
    samples = gc_sample[var_name]
    if idx >= len(samples):
        return None

    entry = samples[idx]
    filename, color = entry[0], entry[1]
    difference = entry[2] - gc_original[var_name]
    value = difference[var_name].values
    return (var_name, (filename, color, value))

# Build one (variable, index) work item per sample that actually exists,
# instead of assuming a fixed count of 94 per variable (which would silently
# drop any sample beyond index 93). Order is variable-major, as before.
work_items = [
    (var_name, idx)
    for var_name in target_var
    for idx in range(len(gc_sample[var_name]))
]

with ThreadPoolExecutor() as executor:
    # executor.map yields results in submission order; materialise them while
    # the pool is still open so worker exceptions surface here.
    results = list(executor.map(process_item, work_items))

# Replace the raw samples with the processed (filename, color, value) tuples.
# NOTE: this overwrites gc_sample, so the cell cannot be re-run without
# reloading the raw samples first.
gc_sample = {var_name: [] for var_name in target_var}
for r in results:
    if r is None:  # process_item returns None for out-of-range indices
        continue
    var, item = r
    gc_sample[var].append(item)

# Plot

모든 데이터 준비는 끝났땅! 이제 그림을 그려볼까요?

In [None]:
#@ plot 1

import plotly.graph_objs as go
import plotly.io as pio
import pandas as pd

# Plot 1: for each variable, overlay every processed GC sample on the NWP
# ensemble members and the ensemble mean, then save and show the figure.
for target_var in ["2m_temperature", "geopotential"]:
    if target_var == '2m_temperature':
        title = 'Mean 2m Temperature Forecast / 2021-06-21  + 7 days'
        unit = 'Temperature (K)'
    elif target_var == 'geopotential':
        title = 'Mean 500hPa Geopotential Forecast / 2021-06-21  + 7 days'
        unit = 'Geopotential (m^2/s^2)'

    # Time window shown on the x axis; also used to place the GC samples.
    window_start = pd.Timestamp('2021-06-22')
    window_end = pd.Timestamp('2021-06-28')

    # Plot perturbation datasets with enhanced legend grouping
    perturb_lines = []
    partition_groups = {}

    for label, color, dataset in gc_sample[target_var]:
        # First whitespace-separated token of the label selects the legend group.
        partition_name = label.split()[0]
        group_traces = partition_groups.setdefault(partition_name, [])

        # Convert color to RGB if it's a tuple
        if isinstance(color, tuple):
            color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'

        # One timestamp per sample: a two-element x with a longer y would draw
        # only the first two values. With exactly two samples this is the same
        # [start, end] pair as before.
        # NOTE(review): assumes the samples are evenly spaced across the
        # window -- TODO confirm against the original time coordinate.
        sample_times = pd.date_range(window_start, window_end, periods=len(dataset))

        # Create trace; only the first trace of a group carries the group title.
        trace = go.Scatter(
            x=sample_times,
            y=dataset,
            mode='lines',
            line=dict(color=color, width=1),
            name=label[4:],
            legendgroup=partition_name,
            legendgrouptitle=dict(
                text=partition_name,
                font=dict(size=14, color='black', family='Arial Bold')
            ) if not group_traces else None,
            showlegend=True
        )
        group_traces.append(trace)
        perturb_lines.append(trace)

    df = nwp[target_var].to_dataframe().reset_index()

    # Add ensemble lines; only the first member gets a legend entry.
    for member_idx, ensemble in enumerate(df['ensemble'].unique()):
        first_member = member_idx == 0
        subset = df[df['ensemble'] == ensemble]
        trace = go.Scatter(
            x=subset['date'],
            y=subset[target_var],
            mode='lines',
            line=dict(color='grey', width=0.5),
            opacity=0.5,
            name='Ensemble Members' if first_member else None,
            showlegend=first_member,
            legendgroup='Ensemble Members',
            legendgrouptitle=dict(
                text='Ensemble Members',
                font=dict(size=14, color='black', family='Arial Bold')
            ) if first_member else None,
        )
        perturb_lines.append(trace)

    # Add ensemble mean line (mean over members for each date)
    mean_temp = df.groupby('date')[target_var].mean().reset_index()
    ensemble_mean_line = go.Scatter(
        x=mean_temp['date'],
        y=mean_temp[target_var],
        mode='lines',
        line=dict(color='black', width=2),
        name='Ensemble Mean',
        legendgroup='Ensemble Mean',
        legendgrouptitle=dict(
            text='Ensemble Mean',
            font=dict(size=14, color='black', family='Arial Bold')
        ),
        showlegend=True
    )
    perturb_lines.append(ensemble_mean_line)

    # Calculate y-axis range over GC samples and ensemble data, ignoring NaNs
    # so that a single missing value cannot turn the whole range into NaN.
    y_values = [v for _, _, dataset in gc_sample[target_var] for v in dataset]
    y_values.extend(df[target_var].values)  # Include ensemble data
    y_min = np.nanmin(y_values)
    y_max = np.nanmax(y_values)
    y_range = y_max - y_min
    y_padding = y_range * 0.05

    # Create the layout
    # NOTE(review): tickformat '.0f' rounds tick labels to whole numbers;
    # confirm this is fine for the magnitude of the plotted differences.
    layout = go.Layout(
        title=title,
        xaxis=dict(
            title='Date', 
            range=[window_start, window_end]
        ),
        yaxis=dict(
            title=unit,
            range=[y_min - y_padding, y_max + y_padding],
            tickformat='.0f'
        ),
        margin=dict(l=40, r=40, t=40, b=40),
        height=900,
        width=1600,
        template='plotly_white',
        legend=dict(
            title=dict(
                text='Legend',
                font=dict(size=16)
            ),
            orientation='v',
            x=1.05,
            y=1,
            itemsizing='constant',
            groupclick='toggleitem',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            tracegroupgap=15,
            font=dict(size=12),
            grouptitlefont=dict(size=14, color='black', family='Arial Bold'),
            borderwidth=1,
            bordercolor='rgba(0,0,0,0.2)',
            bgcolor='rgba(255,255,255,0.95)',
            traceorder='grouped'
        )
    )

    # Create the figure
    fig = go.Figure(data=perturb_lines, layout=layout)

    # Save and display the figure
    fig.write_html(f"{target_var}_spread.html")
    pio.show(fig)

# per region size

위에서 뽑아낸 값 / grid point를 해보자. 어떻게 되나.

In [None]:
#@ grid point count

from lib.his_utils import *

# For every region box, compute the share of grid points it covers.
# The factor 4 is the number of points per degree (0.25-degree spacing), and
# 1036800 = (180 * 4) * (360 * 4) is what the same formula gives for a box
# spanning 180 x 360 degrees.
grid_points = {}
for region, bounds in REGION_BOUNDARIES.items():
    lat_bounds = bounds["lat"]
    lon_bounds = bounds.get("lon", (0, 360))  # no "lon" entry -> full longitude range

    # Number of latitude rows covered by the box
    n_lat = int((lat_bounds[1] - lat_bounds[0]) * 4)

    # An end longitude smaller than the start means the box wraps past 360 degrees
    wraps_around = lon_bounds[1] < lon_bounds[0]
    if wraps_around:
        n_lon = int((360 - lon_bounds[0] + lon_bounds[1]) * 4)
    else:
        n_lon = int((lon_bounds[1] - lon_bounds[0]) * 4)

    grid_points[region] = n_lat * n_lon / 1036800

grid_points


In [6]:
#@ grid point count 

from multiprocessing import Pool

def process_file(args):
    """Scale one sample by the grid-point share of its region.

    Reads the module-level ``grid_points`` dictionary.

    Args:
        args: ``(filename, color, dataset)`` triple. The third
            underscore-separated token of ``filename`` is used as the key
            into ``grid_points``.

    Returns:
        ``(filename, color, dataset * grid_points[region])``.
    """
    label, colour, values = args
    region_key = label.split('_')[2]

    # NOTE(review): the note preceding this cell says "value / grid point",
    # but the value is multiplied by the region's share here -- confirm
    # which one is intended.
    scaled = values * grid_points[region_key]
    return (label, colour, scaled)

# Scale every sample of every variable by its region's grid-point share.
# This runs in-process: the per-item work is a single multiplication, so a
# process pool only adds start-up and pickling cost, and pool workers cannot
# find notebook-defined functions or globals when the start method is "spawn".
result = {
    var: [process_file(item) for item in items]
    for var, items in gc_sample.items()
}

# NOTE: gc_sample is overwritten with the scaled values, so re-running this
# cell applies the scaling again on top of the previous result.
gc_sample = result

In [None]:
#@ plot 2

# Plot 2: same figure as plot 1, drawn from the current contents of gc_sample
# and written to a separate HTML file.
# NOTE(review): the preceding cell rescales gc_sample by each region's
# grid-point share while the NWP ensemble is left as is, and the titles and
# axis units do not mention the rescaling -- confirm that this is intended.
for target_var in ["2m_temperature", "geopotential"]:
    if target_var == '2m_temperature':
        title = 'Mean 2m Temperature Forecast / 2021-06-21  + 7 days'
        unit = 'Temperature (K)'
    elif target_var == 'geopotential':
        title = 'Mean 500hPa Geopotential Forecast / 2021-06-21  + 7 days'
        unit = 'Geopotential (m^2/s^2)'

    # Time window shown on the x axis; also used to place the GC samples.
    window_start = pd.Timestamp('2021-06-22')
    window_end = pd.Timestamp('2021-06-28')

    # Plot perturbation datasets with enhanced legend grouping
    perturb_lines = []
    partition_groups = {}

    for label, color, dataset in gc_sample[target_var]:
        # First whitespace-separated token of the label selects the legend group.
        partition_name = label.split()[0]
        group_traces = partition_groups.setdefault(partition_name, [])

        # Convert color to RGB if it's a tuple
        if isinstance(color, tuple):
            color = f'rgb({int(color[0] * 255)}, {int(color[1] * 255)}, {int(color[2] * 255)})'

        # One timestamp per sample: a two-element x with a longer y would draw
        # only the first two values. With exactly two samples this is the same
        # [start, end] pair as before.
        # NOTE(review): assumes the samples are evenly spaced across the
        # window -- TODO confirm against the original time coordinate.
        sample_times = pd.date_range(window_start, window_end, periods=len(dataset))

        # Create trace; only the first trace of a group carries the group title.
        trace = go.Scatter(
            x=sample_times,
            y=dataset,
            mode='lines',
            line=dict(color=color, width=1),
            name=label[4:],
            legendgroup=partition_name,
            legendgrouptitle=dict(
                text=partition_name,
                font=dict(size=14, color='black', family='Arial Bold')
            ) if not group_traces else None,
            showlegend=True
        )
        group_traces.append(trace)
        perturb_lines.append(trace)

    df = nwp[target_var].to_dataframe().reset_index()

    # Add ensemble lines; only the first member gets a legend entry.
    for member_idx, ensemble in enumerate(df['ensemble'].unique()):
        first_member = member_idx == 0
        subset = df[df['ensemble'] == ensemble]
        trace = go.Scatter(
            x=subset['date'],
            y=subset[target_var],
            mode='lines',
            line=dict(color='grey', width=0.5),
            opacity=0.5,
            name='Ensemble Members' if first_member else None,
            showlegend=first_member,
            legendgroup='Ensemble Members',
            legendgrouptitle=dict(
                text='Ensemble Members',
                font=dict(size=14, color='black', family='Arial Bold')
            ) if first_member else None,
        )
        perturb_lines.append(trace)

    # Add ensemble mean line (mean over members for each date)
    mean_temp = df.groupby('date')[target_var].mean().reset_index()
    ensemble_mean_line = go.Scatter(
        x=mean_temp['date'],
        y=mean_temp[target_var],
        mode='lines',
        line=dict(color='black', width=2),
        name='Ensemble Mean',
        legendgroup='Ensemble Mean',
        legendgrouptitle=dict(
            text='Ensemble Mean',
            font=dict(size=14, color='black', family='Arial Bold')
        ),
        showlegend=True
    )
    perturb_lines.append(ensemble_mean_line)

    # Calculate y-axis range over GC samples and ensemble data, ignoring NaNs
    # so that a single missing value cannot turn the whole range into NaN.
    y_values = [v for _, _, dataset in gc_sample[target_var] for v in dataset]
    y_values.extend(df[target_var].values)  # Include ensemble data
    y_min = np.nanmin(y_values)
    y_max = np.nanmax(y_values)
    y_range = y_max - y_min
    y_padding = y_range * 0.05

    # Create the layout
    # NOTE(review): tickformat '.0f' rounds tick labels to whole numbers;
    # confirm this is fine for the magnitude of the plotted values.
    layout = go.Layout(
        title=title,
        xaxis=dict(
            title='Date', 
            range=[window_start, window_end]
        ),
        yaxis=dict(
            title=unit,
            range=[y_min - y_padding, y_max + y_padding],
            tickformat='.0f'
        ),
        margin=dict(l=40, r=40, t=40, b=40),
        height=900,
        width=1600,
        template='plotly_white',
        legend=dict(
            title=dict(
                text='Legend',
                font=dict(size=16)
            ),
            orientation='v',
            x=1.05,
            y=1,
            itemsizing='constant',
            groupclick='toggleitem',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            tracegroupgap=15,
            font=dict(size=12),
            grouptitlefont=dict(size=14, color='black', family='Arial Bold'),
            borderwidth=1,
            bordercolor='rgba(0,0,0,0.2)',
            bgcolor='rgba(255,255,255,0.95)',
            traceorder='grouped'
        )
    )

    # Create the figure
    fig = go.Figure(data=perturb_lines, layout=layout)

    # Save and display the figure
    fig.write_html(f"{target_var}_spread gridpoint percentage.html")
    pio.show(fig)