# Header

Preprocessing notebook that generates the `/data/GC_output/analysis/percent2` dataset.

Forked from `preprocess.ipynb`.

In [1]:
import glob
import re
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from multiprocessing import Pool
from lib.his_preprocess import *

In [2]:
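# GC data preprocessing

###
# Mask legend (this is the scheme of the commented-out geopotential branch below;
# the active 2m_temperature branch instead uses p_0..p_9 for the different
# magnitude values that appear in the file names):
# p_1: 11111111111 / all variables
# p_2: 00000000001 / target variable only
# p_3: 11111111110 / all variables except the target
# mean: global average / raw: original data
###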

for target_var in ["2m_temperature"]:

    if target_var == '2m_temperature':
        # One sorted file list per magnitude value, smallest (p_0) to largest (p_9).
        # Note that p_0 and p_9 are read from `percent/`, the rest from `percent2/`.
        p_0, p_1, p_2, p_3, p_4, p_5, p_6, p_7, p_8, p_9 = [
            sorted(glob.glob(pattern))
            for pattern in (
                '/data/GC_output/percent/GC_00100000000_0.001_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.01_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.03_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.05_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.07_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.085_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.1_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.3_*.nc',
                '/data/GC_output/percent2/GC_00100000000_0.5_*.nc',
                '/data/GC_output/percent/GC_00100000000_1_*.nc',
            )
        ]

    # elif target_var == 'geopotential':
    #     p_1 = sorted(glob.glob('/data/GC_output/percent2/GC_11111111111_*_*.nc'))
    #     p_2 = sorted(glob.glob('/data/GC_output/percent2/GC_00000100000_*_*.nc'))
    #     p_3 = sorted(glob.glob('/data/GC_output/percent2/GC_11111011111_*_*.nc'))

    # Base color per partition; individual files later get shades of this color.
    partition_colors = dict([
        ('p_0', 'seagreen'),
        ('p_1', 'red'),
        ('p_2', 'darkorange'),
        ('p_3', 'cyan'),
        ('p_4', 'palegreen'),
        ('p_5', 'royalblue'),
        ('p_6', 'blue'),
        ('p_7', 'purple'),
        ('p_8', 'magenta'),
        ('p_9', 'gold'),
        ('p_10', 'deeppink'),
    ])

    def extract_perturbation_info(filename):
        """Build a ``<value>_<region>_<var>`` tag from a GC output file name.

        The name must end in ``GC_<11 binary digits>_<number>_<suffix>.nc``,
        e.g. ``GC_00100000000_0.001_9p.nc`` gives ``0.001_9p_00100000000``.

        Args:
            filename: File name or path to parse.

        Returns:
            The tag string, or ``None`` when the name does not match.
        """
        parsed = re.search(r'GC_([01]{11})_([\d.eE+-]+)_(.*?)\.nc$', filename)
        if not parsed:
            return None
        # Groups in file-name order: perturbation code, value, region code.
        var, value, region = parsed.groups()
        return f"{value}_{region}_{var}"

    # Build the (label, color, file) work list, one entry per parsable file.
    part_col = ['p_0', 'p_1', 'p_2', 'p_3', 'p_4', 'p_5', 'p_6', 'p_7', 'p_8', 'p_9']
    tar_var = [p_0, p_1, p_2, p_3, p_4, p_5, p_6, p_7, p_8, p_9]

    perturb_files = []

    for partition_name, partition_files in zip(part_col, tar_var):
        # Ask for two extra shades and drop the first and last of the palette,
        # leaving one shade of the partition's base color per file.
        shades = sns.light_palette(
            partition_colors[partition_name],
            n_colors=len(partition_files) + 2,
        )[1:-1]
        for idx, path in enumerate(partition_files):
            info = extract_perturbation_info(path)
            if not info:
                # File name does not follow the expected pattern; skip it.
                continue
            perturb_files.append(
                (f"{partition_name} {info}", shades[idx % len(shades)], path)
            )
    
    # Rewrite region codes for the p_0 and p_9 partitions only: '1p', '3p' -> '10', '30'.
    # Labels have the form "<partition> <value>_<region>_<var>", so the partition is
    # matched on the label prefix. A plain substring test ('p_0' in label) would also
    # hit the "<region>p_0..." tail of labels from other partitions, because the
    # variable code starts with '0'.
    for i, (label, color, file) in enumerate(perturb_files):
        if label.startswith(('p_0 ', 'p_9 ')):
            # Replace the trailing 'p' of '_<digits>p' with '0'.
            new_label = re.sub(r'_(\d+)p', r'_\g<1>0', label)
            perturb_files[i] = (new_label, color, file)

    perturb_datasets_raw = []
    perturb_datasets_mean = []

    def process_file(file_info):
        """Open one file and produce its preprocessed and averaged datasets.

        Args:
            file_info: ``(label, color, file)`` tuple from ``perturb_files``.

        Returns:
            Dict with keys ``"mean"`` and ``"raw"``; each value is a
            ``(label, color, dataset)`` tuple. ``"raw"`` holds the output of
            ``preprocess_GC`` and ``"mean"`` the output of ``weighted_mean``
            applied to it (both presumably come from ``lib.his_preprocess``).
        """
        label, color, file = file_info
        raw = preprocess_GC(xr.open_dataset(file), target_var)
        averaged = weighted_mean(raw)
        return {
            "mean": (label, color, averaged),
            "raw": (label, color, raw),
        }

    # Fan the files out over 12 worker processes; map() keeps the input order.
    with Pool(processes=12) as pool:
        results = pool.map(process_file, perturb_files)

    # Split the per-file results into the two parallel output lists.
    perturb_datasets_mean.extend(entry["mean"] for entry in results)
    perturb_datasets_raw.extend(entry["raw"] for entry in results)

    if target_var == '2m_temperature':
        # Write the averaged results first, then the raw ones, each to its own pickle.
        outputs = (
            ('/data/GC_output/analysis/percent2/GC_t2m_GlobAvg.pkl', perturb_datasets_mean),
            ('/data/GC_output/analysis/percent2/GC_t2m_Globraw.pkl', perturb_datasets_raw),
        )
        for out_path, payload in outputs:
            with open(out_path, 'wb') as f:
                pickle.dump(payload, f)