In [9]:
import numpy as np
import pandas as pd
import xarray as xr

import multiprocessing as mp
from multiprocessing import Pool, Manager

from pathlib import Path
from tqdm.notebook import tqdm, trange
from concurrent.futures import ThreadPoolExecutor, as_completed

In [2]:
version = "daily_100"  # sub-directory / file tag shared by every product below

# All intermediate products live under one root; each source keeps a
# per-version sub-directory.
dataset_root = Path("/data6t/AIWP_TP_dataset")

Merge_data_path = dataset_root / "merge_data"
Merge_daily_all_data_path = Merge_data_path / version

Durre_data_path = dataset_root / "Durre2010_data"
Durre_daily_all_data_path = Durre_data_path / version

Hamada_data_path = dataset_root / "Hamada2011_data"
Hamada_daily_all_data_path = Hamada_data_path / version

Beck_data_path = dataset_root / "Beck2019_data"
Beck_daily_all_data_path = Beck_data_path / version

QC_data_path = dataset_root / "QC_data"
QC_daily_data_path = QC_data_path / version

# Station metadata table (coordinates, elevation, reporting frequency, ...).
station_info_df = pd.read_csv(QC_data_path / f"Daily_station_info_{version}.csv")
station_info_df

Unnamed: 0,ID,LATITUDE,LONGITUDE,ELEVATION,FREQUENCY,REPORTING_HOUR,best_shift_hour,reporting_time,year_length
0,GHCNh_ACW00011647,17.13330,298.21670,19.2,6,0.0,4.0,-4.0,4.468493
1,GHCNh_AGI0000DAAT,22.81150,5.45110,1377.1,3,0.0,48.0,-48.0,0.654795
2,GHCNh_AGI0000DAOH,34.82000,358.23000,426.0,24,6.0,5.0,-5.0,4.928767
3,GHCNh_AGM00060353,36.81670,5.88330,6.0,6,0.0,2.0,-2.0,0.608219
4,GHCNh_AGM00060395,36.70000,4.13330,153.0,24,6.0,7.0,-7.0,4.947945
...,...,...,...,...,...,...,...,...,...
59164,GSOD_A0000253928,30.21889,263.62583,93.9,24,,-17.0,17.0,2.827397
59165,GSOD_A0000594076,40.05361,253.63111,2258.9,24,,-15.0,15.0,3.000000
59166,GSOD_A0002453848,30.51195,273.04594,34.0,24,,-17.0,17.0,4.063014
59167,GSOD_A0003225715,52.22028,185.79389,17.1,24,,-14.0,14.0,1.986301


# Merge 2020–2025 station data into a NetCDF file

Per-check QC flag values:
- `1`: passed the check
- `0`: failed the check
- `-1`: data missing

In [None]:
def merge_one_station(args):
    """Build (or reuse) the QC'd CSV of one station and return its 2020–2025 series.

    Parameters
    ----------
    args : tuple
        ``(station_id, source_dir, target_dir)``, packed into one tuple so the
        function can be mapped with ``Pool.imap_unordered``. ``source_dir`` holds
        ``<station_id>.csv`` with a date index, a ``PRCP`` column and one
        ``*_flags`` column per QC check. ``target_dir`` receives the cached output
        CSV and is created if needed.

    Returns
    -------
    dict or None
        ``{"ID", "raw_data", "qc_data", "qc_flag"}``, where the last three are daily
        Series reindexed to 2020-01-01..2025-12-31 (days outside the record are NaN).
        Returns ``None`` when the station has no ``PRCP`` value in that window.

    Notes
    -----
    An existing output CSV is reused as-is; delete it to recompute the QC columns.
    """
    station_id, source_dir, target_dir = args
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / f"{station_id}.csv"

    if output_path.exists():
        # Reuse the QC result of a previous run.
        df = pd.read_csv(output_path, parse_dates=True, index_col=0)
    else:
        df = pd.read_csv(source_dir / f"{station_id}.csv", parse_dates=True, index_col=0)
        flag_cols = [col for col in df.columns if col.endswith("_flags")]
        n_checks = len(flag_cols)

        # Number of passed checks per day (1 = pass, 0 = fail, -1 = missing);
        # left NaN on days without a PRCP value.
        df["PRCP_Flag"] = df[flag_cols].sum(axis=1).where(df["PRCP"].notna())
        # Keep a value only if it passed every check that exists for this
        # station (previously hard-coded as 12). With no checks, nothing passes.
        passed_all = (df["PRCP_Flag"] == n_checks) & (n_checks > 0)
        df["PRCP_QC"] = df["PRCP"].where(passed_all)
        df.to_csv(output_path)

    window = pd.date_range(start="2020-01-01", end="2025-12-31")
    df_2020_2025 = df[["PRCP", "PRCP_QC", "PRCP_Flag"]].reindex(window)
    if not df_2020_2025["PRCP"].notna().any():
        return None
    return {"ID": station_id,
            "raw_data": df_2020_2025["PRCP"],
            "qc_data": df_2020_2025["PRCP_QC"],
            "qc_flag": df_2020_2025["PRCP_Flag"]}


# Station files are named "<ID>.csv"; only stations listed in the metadata table are merged.
known_station_ids = set(station_info_df["ID"])  # set: O(1) membership for ~60k IDs
station_id_list = [
    f.stem for f in Beck_daily_all_data_path.iterdir()
    if f.is_file() and f.suffix == ".csv"
]

n_jobs = mp.cpu_count()

args_list = [
    (station_id, Beck_daily_all_data_path, QC_daily_data_path)
    for station_id in station_id_list if station_id in known_station_ids
]

with Pool(n_jobs) as pool:
    results = list(tqdm(
        pool.imap_unordered(merge_one_station, args_list),
        total=len(args_list),
        desc="Processing progress",
    ))
# Stations without any PRCP value in 2020-2025 come back as None.
results = [res for res in results if res is not None]
if not results:
    raise RuntimeError("No station has PRCP data in 2020-2025; nothing to merge.")
# imap_unordered yields in completion order; the station axis follows that order.
station_id = [res["ID"] for res in results]
station_meta = station_info_df.set_index("ID").loc[station_id]

merge_data = xr.Dataset(
    {
        "raw_data": (("station", "time"), np.array([res["raw_data"].values for res in results])),
        "qc_data": (("station", "time"), np.array([res["qc_data"].values for res in results])),
        "qc_flag": (("station", "time"), np.array([res["qc_flag"].values for res in results])),
    },
    coords={
        "station": station_id,
        "time": pd.date_range(start="2020-01-01", end="2025-12-31"),
        "latitude": ("station", station_meta["LATITUDE"].values),
        "longitude": ("station", station_meta["LONGITUDE"].values),
        "elevation": ("station", station_meta["ELEVATION"].values),
    },
)
del results
merge_data["raw_data"] = merge_data["raw_data"].astype("float32")
merge_data["qc_data"] = merge_data["qc_data"].astype("float32")
# qc_flag (number of passed checks) is NaN on days without PRCP. Casting NaN to an
# integer dtype is undefined (NumPy yields INT_MIN plus a RuntimeWarning), so keep
# it as float32 with NaN, the same choice as "flags_num" in the per-check dataset.
merge_data["qc_flag"] = merge_data["qc_flag"].astype("float32")

merge_data

In [4]:
qc_nc_path = QC_data_path / f"PRCP_QC_{version}_2020_2025.nc"
merge_data.to_netcdf(qc_nc_path)

In [None]:
def merge_one_station(args):
    """Build (or reuse) the QC'd CSV of one station and return its 2020–2025
    series together with every individual QC flag.

    Parameters
    ----------
    args : tuple
        ``(station_id, source_dir, target_dir)``, packed into one tuple so the
        function can be mapped with ``Pool.imap_unordered``. ``source_dir`` holds
        ``<station_id>.csv`` with a date index, a ``PRCP`` column and one
        ``*_flags`` column per QC check. ``target_dir`` receives the cached output
        CSV and is created if needed.

    Returns
    -------
    dict or None
        ``{"ID", "raw_data", "qc_data", "flags_num", <each *_flags column>}`` with
        daily Series reindexed to 2020-01-01..2025-12-31. Flag series are ints,
        with -1 on days that have no PRCP value or no flag. Returns ``None`` when
        the station has no ``PRCP`` value in that window.

    Notes
    -----
    An existing output CSV is reused as-is; delete it to recompute the QC columns.
    """
    station_id, source_dir, target_dir = args
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / f"{station_id}.csv"
    use_cache = output_path.exists()

    input_path = output_path if use_cache else source_dir / f"{station_id}.csv"
    df = pd.read_csv(input_path, parse_dates=True, index_col=0)
    flag_cols = [col for col in df.columns if col.endswith("_flags")]

    if not use_cache:
        n_checks = len(flag_cols)
        # Number of passed checks per day; left NaN on days without PRCP.
        df["PRCP_Flag"] = df[flag_cols].sum(axis=1).where(df["PRCP"].notna())
        # Keep a value only if it passed every check that exists for this
        # station (previously hard-coded as 12). With no checks, nothing passes.
        passed_all = (df["PRCP_Flag"] == n_checks) & (n_checks > 0)
        df["PRCP_QC"] = df["PRCP"].where(passed_all)
        df.to_csv(output_path)

    window = pd.date_range(start="2020-01-01", end="2025-12-31")
    df_2020_2025 = df.reindex(window)
    has_prcp = df_2020_2025["PRCP"].notna()
    if not has_prcp.any():
        return None

    result = {"ID": station_id,
              "raw_data": df_2020_2025["PRCP"],
              "qc_data": df_2020_2025["PRCP_QC"],
              "flags_num": df_2020_2025["PRCP_Flag"]}
    for col in flag_cols:
        # -1 marks "missing": no PRCP value on that day, or no flag recorded.
        result[col] = (df_2020_2025[col].astype(float)
                       .where(has_prcp, other=-1)
                       .fillna(-1)
                       .astype(int))
    return result


# Station files are named "<ID>.csv"; only stations listed in the metadata table are merged.
known_station_ids = set(station_info_df["ID"])  # set: O(1) membership for ~60k IDs
station_id_list = [
    f.stem for f in Beck_daily_all_data_path.iterdir()
    if f.is_file() and f.suffix == ".csv"
]

n_jobs = mp.cpu_count()

args_list = [
    (station_id, Beck_daily_all_data_path, QC_daily_data_path)
    for station_id in station_id_list if station_id in known_station_ids
]

with Pool(n_jobs) as pool:
    results = list(tqdm(
        pool.imap_unordered(merge_one_station, args_list),
        total=len(args_list),
        desc="Processing progress"
    ))
# Stations without any PRCP value in 2020-2025 come back as None.
results = [res for res in results if res is not None]
if not results:
    raise RuntimeError("No station has PRCP data in 2020-2025; nothing to merge.")

station_id = [res["ID"] for res in results]
# Variables = every result key except "ID".
# NOTE(review): assumes all station files share the *_flags columns of the
# first result; a station missing one would raise KeyError here. Confirm.
cols = list(results[0].keys())[1:]
flag_cols = [col for col in cols if col.endswith("_flags")]

data = {
    col: (("station", "time"), np.array([res[col].values for res in results]))
    for col in cols
}
del results

station_meta = station_info_df.set_index("ID").loc[station_id]
merge_data = xr.Dataset(
    data,
    coords={
        "station": station_id,
        "time": pd.date_range(start="2020-01-01", end="2025-12-31"),
        "latitude": ("station", station_meta["LATITUDE"].values),
        "longitude": ("station", station_meta["LONGITUDE"].values),
        "elevation": ("station", station_meta["ELEVATION"].values),
    },
)
del data

# Individual flags are in {-1, 0, 1} (missing values already filled with -1).
for col in flag_cols:
    merge_data[col] = merge_data[col].astype("int8")
# flags_num keeps NaN for missing days, so it stays floating point.
merge_data["flags_num"] = merge_data["flags_num"].astype("float32")
merge_data["raw_data"] = merge_data["raw_data"].astype("float32")
merge_data["qc_data"] = merge_data["qc_data"].astype("float32")

merge_data

处理进度:   0%|          | 0/59165 [00:00<?, ?it/s]

In [6]:
flags_nc_path = QC_data_path / f"PRCP_QC_flags_{version}_2020_2025.nc"
merge_data.to_netcdf(flags_nc_path)

In [7]:
# Keep only the stations present in the merged dataset, in the dataset's station order.
station_info_df = (
    station_info_df
    .set_index("ID")
    .loc[merge_data.station.values]
    .reset_index()
)
station_info_df.to_csv(QC_data_path / f"{version}_station_info.csv", index=False)

# Infer reporting times

In [5]:
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

In [6]:
# Root directory of the gridded benchmark precipitation products.
BENCHMARK_DIR = Path("/data6t/AIWP_TP_dataset").joinpath("benchmark")

In [7]:
# load_dataset reads everything into memory and closes the file right away,
# so no handle stays open while per-station slices are pickled to workers.
QC_daily_tp = xr.load_dataset(QC_data_path / f"PRCP_QC_{version}_2020_2025.nc")
QC_daily_tp

In [None]:
def calc_station_time_shift_corr(args):
    """Correlate one station's daily QC precipitation with a benchmark product
    for every time shift from -48 h to +48 h in 1 h steps.

    Parameters
    ----------
    args : tuple
        ``(station_qc_tp, bench_name)``, packed into one tuple so the function can
        be mapped with ``Pool.imap_unordered``. ``station_qc_tp`` is a 1-D
        ``xarray.DataArray`` over ``time`` for one station, carrying scalar
        ``station``, ``latitude`` and ``longitude`` coordinates. ``bench_name`` is
        one of ``"IMERG"``, ``"MSWEP"``, ``"GPCC"`` or ``"ERA5"``.

    Returns
    -------
    dict
        ``{"station": station_id, shift_hour: pearson_r, ...}``. A shift is omitted
        when fewer than 30 days have valid values in both series. The dict is also
        saved to ``BENCHMARK_DIR/<bench_name>_corr/<station_id>_corr.npy``.
    """
    station_qc_tp, bench_name = args

    bench_path_dict = {"IMERG": "IMERG_V07_Late_24h_rolling.zarr",
                       "MSWEP": "MSWEP_V280_24h_rolling.zarr",
                       "GPCC": "GPCC_first_guess_daily_2020_2024.nc",
                       "ERA5": "ERA5_tp_24h_rolling.zarr"}
    bench_var_dict = {"IMERG": "precipitation",
                      "MSWEP": "precipitation",
                      "GPCC": "p",
                      "ERA5": "tp"}
    bench_tp_path = BENCHMARK_DIR / bench_path_dict[bench_name]
    bench_var = bench_var_dict[bench_name]
    # Minimum number of day pairs (both series valid) for a usable correlation.
    min_valid_pairs = 30

    station_id = station_qc_tp.station.item()
    lat, lon = station_qc_tp.latitude.item(), station_qc_tp.longitude.item()
    # Work on plain arrays so the caller's DataArray is never mutated.
    # A fixed +24 h is applied to all station timestamps before the shift search.
    # NOTE(review): presumably this aligns the station's daily total with the end
    # of the benchmark's 24 h rolling window. Confirm.
    station_values = station_qc_tp.values
    station_time = station_qc_tp["time"].values + np.timedelta64(24, "h")

    temp_dir = BENCHMARK_DIR / f"{bench_name}_corr"
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_corr_path = temp_dir / f"{station_id}_corr.npy"

    # GPCC is a NetCDF file and the other products are Zarr stores;
    # open_zarr cannot read NetCDF, so choose the reader by file suffix.
    open_bench = xr.open_dataset if bench_tp_path.suffix == ".nc" else xr.open_zarr
    # NOTE(review): assumes every product uses "lat"/"lon" dimension names and a
    # longitude convention matching the station LONGITUDE values. Confirm per product.
    with open_bench(bench_tp_path) as bench_ds:
        station_bench_tp = bench_ds[bench_var].sel(lat=lat, lon=lon, method="nearest").load()
    bench_time = station_bench_tp["time"].values
    bench_values = station_bench_tp.values

    station_corrs = {"station": station_id}
    for shift_hour in range(-48, 49):
        shifted_time = station_time + np.timedelta64(shift_hour, "h")
        # Positions of the timestamps shared by the shifted station series and the benchmark.
        _, station_idx, bench_idx = np.intersect1d(shifted_time, bench_time, return_indices=True)
        qc_data = station_values[station_idx]
        bench_data = bench_values[bench_idx]

        # Count only days where BOTH series are valid; these are the samples
        # that actually enter the correlation.
        valid_mask = ~np.isnan(qc_data) & ~np.isnan(bench_data)
        if valid_mask.sum() < min_valid_pairs:
            continue
        station_corrs[shift_hour] = np.corrcoef(qc_data[valid_mask], bench_data[valid_mask])[0, 1]

    np.save(temp_corr_path, station_corrs)
    return station_corrs

In [None]:
bench_name = "ERA5"
args_list = [
    (QC_daily_tp["qc_data"].sel(station=station_id), bench_name)
    for station_id in QC_daily_tp.station.data
]

with Pool(mp.cpu_count()) as pool:
    results = list(tqdm(
        pool.imap_unordered(calc_station_time_shift_corr, args_list),
        total=len(args_list),
        desc="处理进度"
    ))

# One row per station, one column per tested shift (hours) holding Pearson r.
shift_corr = pd.DataFrame(results)
shift_corr.to_csv(f"Station_{bench_name}_shift_corr_{version}_2020_2025.csv", index=None)

# Best shift = the one with the highest correlation. Rows where no shift had
# enough valid pairs are all-NaN; idxmax is undefined there (deprecated and
# slated to raise in pandas), so evaluate it only on the other rows and leave NaN.
corr_by_station = shift_corr.set_index("station")
best_shift = (
    corr_by_station.dropna(how="all")
    .idxmax(axis=1)
    .reindex(corr_by_station.index)
    .astype(float)
)
reporting_times = best_shift.rename("best_shift_hour").rename_axis("ID").reset_index()
reporting_times["reporting_time"] = -reporting_times["best_shift_hour"]
reporting_times.to_csv(f"daily_station_{bench_name}_reporting_times_{version}_2020_2025.csv", index=None)
reporting_times

# Add reporting times as metadata

In [3]:
bench_name = "ERA5"
# Read exactly the file written by the reporting-time inference cell
# (same bench name, version and 2020_2025 period suffix).
reporting_times = pd.read_csv(f"daily_station_{bench_name}_reporting_times_{version}_2020_2025.csv")
reporting_times

Unnamed: 0,ID,best_shift_hour,reporting_time
0,GSOD_72252012907,-16.0,16.0
1,GHCNd_US1TXGP0123,-8.0,8.0
2,GHCNd_US1KSMI0015,-11.0,11.0
3,GHCNd_USC00140119,-9.0,9.0
4,GHCNd_US1WIRK0015,-10.0,10.0
...,...,...,...
59164,GHCNd_CA001128584,15.0,-15.0
59165,GSOD_16360099999,-9.0,9.0
59166,GHCNd_USC00407184,-1.0,1.0
59167,GHCNd_US1CAHM0004,-9.0,9.0


In [4]:
flags_nc_path = QC_data_path / f"PRCP_QC_flags_{version}_2020_2025.nc"
dataset = xr.open_dataset(flags_nc_path)
dataset

In [10]:
# Give the variables their published names: PRCP_raw / PRCP_QC / Flags_num and
# "<check>_flag" instead of "<check>_flags". Only names still present are
# renamed, so re-running this cell is a no-op instead of raising.
rename_map = {"raw_data": "PRCP_raw", "qc_data": "PRCP_QC", "flags_num": "Flags_num"}
rename_map.update({
    name: name[: -len("_flags")] + "_flag"
    for name in dataset.data_vars
    if name.endswith("_flags")
})
dataset = dataset.rename({old: new for old, new in rename_map.items() if old in dataset.data_vars})
dataset

In [11]:
# Attach each station's inferred reporting time as a per-station coordinate
# and write the final station dataset.
station_reporting_time = (
    reporting_times
    .set_index("ID")
    .loc[dataset["station"].values, "reporting_time"]
    .values
)
dataset.assign_coords(
    inferred_reporting_time=("station", station_reporting_time)
).to_netcdf(QC_data_path / "GHCNdailyPrcp_2020_2025.nc")

# Convert to gridded data

In [3]:
# Must match the name written by the inference cell exactly: the file system is
# case-sensitive ("ERA5" != "era5") and the period suffix is part of the name.
bench_name = "ERA5"
reporting_times = pd.read_csv(
    f"daily_station_{bench_name}_reporting_times_{version}_2020_2025.csv"
).set_index("ID")
reporting_times

Unnamed: 0_level_0,best_shift_hour,reporting_time
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
GSOD_72252012907,-16.0,16.0
GHCNd_US1TXGP0123,-8.0,8.0
GHCNd_US1KSMI0015,-11.0,11.0
GHCNd_USC00140119,-9.0,9.0
GHCNd_US1WIRK0015,-10.0,10.0
...,...,...
GHCNd_CA001128584,15.0,-15.0
GSOD_16360099999,-9.0,9.0
GHCNd_USC00407184,-1.0,1.0
GHCNd_US1CAHM0004,-9.0,9.0


In [4]:
station_info_path = QC_data_path / f"{version}_station_info.csv"
station_info_df = pd.read_csv(station_info_path)
station_info_df

Unnamed: 0,ID,LATITUDE,LONGITUDE,ELEVATION,FREQUENCY,REPORTING_HOUR,best_shift_hour,reporting_time,year_length
0,GHCNd_US1NMSC0063,34.0799,252.7839,2187.5,24,,-8.0,8.0,4.841096
1,GHCNd_US1KSMI0015,38.5911,265.1294,278.0,24,,-11.0,11.0,2.569863
2,GHCNd_US1MSJC0035,30.4057,271.2334,20.1,24,,0.0,0.0,0.646575
3,GHCNd_US1WYGS0027,42.1817,255.4283,1315.2,24,,-10.0,10.0,3.586301
4,GHCNd_ASN00010628,-32.0094,117.4014,250.0,24,,-24.0,24.0,4.019178
...,...,...,...,...,...,...,...,...,...
58791,GHCNd_US1TXCLL097,32.9909,263.3206,185.0,24,,-8.0,8.0,4.780822
58792,GHCNd_US1TXHYS183,30.0588,261.8726,344.4,24,,-9.0,9.0,3.065753
58793,GHCNd_US1Harl4299,40.1045,260.6510,672.1,24,,-7.0,7.0,0.526027
58794,GHCNd_US1ARSR0009,35.7746,267.3613,506.0,24,,-11.0,11.0,4.893151


In [5]:
# Drop stations without coordinates. Index by ID for .loc lookups, but keep the
# ID column because the gridding cells group on it.
station_info_df = station_info_df.dropna(subset=["LATITUDE", "LONGITUDE"])
station_info_df = station_info_df.set_index("ID", drop=False)
station_info_df

Unnamed: 0_level_0,ID,LATITUDE,LONGITUDE,ELEVATION,FREQUENCY,REPORTING_HOUR,best_shift_hour,reporting_time,year_length
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
GHCNd_US1NMSC0063,GHCNd_US1NMSC0063,34.0799,252.7839,2187.5,24,,-8.0,8.0,4.841096
GHCNd_US1KSMI0015,GHCNd_US1KSMI0015,38.5911,265.1294,278.0,24,,-11.0,11.0,2.569863
GHCNd_US1MSJC0035,GHCNd_US1MSJC0035,30.4057,271.2334,20.1,24,,0.0,0.0,0.646575
GHCNd_US1WYGS0027,GHCNd_US1WYGS0027,42.1817,255.4283,1315.2,24,,-10.0,10.0,3.586301
GHCNd_ASN00010628,GHCNd_ASN00010628,-32.0094,117.4014,250.0,24,,-24.0,24.0,4.019178
...,...,...,...,...,...,...,...,...,...
GHCNd_US1TXCLL097,GHCNd_US1TXCLL097,32.9909,263.3206,185.0,24,,-8.0,8.0,4.780822
GHCNd_US1TXHYS183,GHCNd_US1TXHYS183,30.0588,261.8726,344.4,24,,-9.0,9.0,3.065753
GHCNd_US1Harl4299,GHCNd_US1Harl4299,40.1045,260.6510,672.1,24,,-7.0,7.0,0.526027
GHCNd_US1ARSR0009,GHCNd_US1ARSR0009,35.7746,267.3613,506.0,24,,-11.0,11.0,4.893151


In [6]:
# Elevations above 9000 m are higher than any land surface and are treated as
# invalid. NOTE(review): presumably a fill value in the source metadata.
implausible_elev = station_info_df["ELEVATION"] > 9000
station_info_df.loc[implausible_elev, "ELEVATION"] = np.nan

In [7]:
# Read fully into memory and close the file immediately.
station_data = xr.load_dataset(QC_data_path / "GHCNdailyPrcp_2020_2025.nc")
station_data

In [None]:
# Target grid: 0.25° (721 latitudes from 90 to -90, 1440 longitudes in [0, 360)).
# For 1° use linspace(90, -90, 181) / linspace(0, 360, 361)[:360];
# for 2° use linspace(90, -90, 91) / linspace(0, 360, 181)[:180].
lat = np.linspace(90, -90, 721)
lon = np.linspace(0, 360, 1441)[:1440]

# Nearest grid index of every station. Longitude distance is measured on the
# circle, so a station just below 360° (or given in negative degrees) snaps to
# the 0° column when that is the closest one.
station_info_df["LAT_INDEX"] = station_info_df["LATITUDE"].apply(lambda x: np.abs(lat - x).argmin())
station_info_df["LON_INDEX"] = station_info_df["LONGITUDE"].apply(
    lambda x: np.abs((lon - x + 180.0) % 360.0 - 180.0).argmin()
)

# (lat_index, lon_index) -> list of station IDs falling in that cell.
grid_to_id_mapping = station_info_df.groupby(["LAT_INDEX", "LON_INDEX"])["ID"].apply(list).to_dict()

n_lat, n_lon = len(lat), len(lon)
grid_coords = {"lat": lat, "lon": lon}

# Year-independent cell attributes, computed once instead of once per year.
station_elev_grid = np.full((n_lat, n_lon), np.nan, dtype=np.float32)
reporting_time_grid = np.full((n_lat, n_lon), np.nan, dtype=np.float32)
for (lat_index, lon_index), station_ids in grid_to_id_mapping.items():
    station_elev_grid[lat_index, lon_index] = station_info_df.loc[station_ids, "ELEVATION"].mean()
    try:
        mode_result = reporting_times.loc[station_ids, "reporting_time"].mode()
    except (KeyError, IndexError):
        # At least one station of the cell has no inferred reporting time.
        continue
    if len(mode_result) > 0:
        # Most common reporting time in the cell (smallest value on ties).
        reporting_time_grid[lat_index, lon_index] = mode_result.iloc[0]

for year in range(2020, 2026):
    target_time = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31")
    shape_3d = (n_lat, n_lon, len(target_time))
    prcp_mean = np.full(shape_3d, np.nan, dtype=np.float32)
    prcp_std = np.full(shape_3d, np.nan, dtype=np.float32)
    station_num = np.zeros(shape_3d, dtype=int)

    for (lat_index, lon_index), station_ids in tqdm(grid_to_id_mapping.items(), desc=str(year)):
        # (station, time) QC'd precipitation of the stations in this cell.
        cell_prcp = station_data.sel(station=station_ids, time=target_time)["PRCP_QC"].data
        valid_count = np.count_nonzero(~np.isnan(cell_prcp), axis=0)
        station_num[lat_index, lon_index, :] = valid_count
        with warnings.catch_warnings():
            # Days without any valid station trigger "Mean of empty slice" /
            # "Degrees of freedom <= 0"; those days correctly stay NaN.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            # nanmean covers every case: 0 valid -> NaN, 1 valid -> that value
            # (whichever station row holds it), >1 valid -> their mean.
            prcp_mean[lat_index, lon_index, :] = np.nanmean(cell_prcp, axis=0)
            # A spread needs at least two stations.
            prcp_std[lat_index, lon_index, :] = np.where(
                valid_count > 1, np.nanstd(cell_prcp, axis=0), np.nan
            )

    # Save this year's dataset.
    time_coords = {"lat": lat, "lon": lon, "time": target_time}
    dims_3d = ["lat", "lon", "time"]
    ds = xr.Dataset({
        "PRCP_mean": xr.DataArray(prcp_mean, coords=time_coords, dims=dims_3d,
                                  attrs={"units": "mm/day"}),
        "PRCP_std": xr.DataArray(prcp_std, coords=time_coords, dims=dims_3d,
                                 attrs={"units": "mm/day"}),
        "Station_num": xr.DataArray(station_num, coords=time_coords, dims=dims_3d,
                                    attrs={"units": "stations per grid cell"}),
        "station_elev": xr.DataArray(station_elev_grid, coords=grid_coords, dims=["lat", "lon"],
                                     attrs={"units": "m"}),
        "inferred_reporting_time": xr.DataArray(reporting_time_grid, coords=grid_coords, dims=["lat", "lon"],
                                                attrs={"units": "hour"}),
    })
    ds.to_netcdf(f"GHCNdgp_{year}.nc")
    print(f"Year {year} completed and saved.")

In [None]:
# Target grid: 0.25° (721 latitudes from 90 to -90, 1440 longitudes in [0, 360)).
# For 1° use linspace(90, -90, 181) / linspace(0, 360, 361)[:360];
# for 2° use linspace(90, -90, 91) / linspace(0, 360, 181)[:180].
lat = np.linspace(90, -90, 721)
lon = np.linspace(0, 360, 1441)[:1440]

# Nearest grid index of every station; longitude distance is measured on the
# circle so stations just below 360° can snap to the 0° column.
station_info_df["LAT_INDEX"] = station_info_df["LATITUDE"].apply(lambda x: np.abs(lat - x).argmin())
station_info_df["LON_INDEX"] = station_info_df["LONGITUDE"].apply(
    lambda x: np.abs((lon - x + 180.0) % 360.0 - 180.0).argmin()
)

# (lat_index, lon_index) -> list of station IDs falling in that cell.
grid_to_id_mapping = station_info_df.groupby(["LAT_INDEX", "LON_INDEX"])["ID"].apply(list).to_dict()

# One full-period series per occupied grid cell.
inferred_reporting_time = []
PRCP = []
PRCP_raw = []
lat_lon = []
for (lat_index, lon_index), station_ids in tqdm(grid_to_id_mapping.items()):
    lat_lon.append((lat[lat_index], lon[lon_index]))

    # Most common reporting time among the cell's stations. Use NaN when it is
    # unknown: 0 would be indistinguishable from a genuine 0-hour reporting time.
    try:
        mode_result = reporting_times.loc[station_ids, "reporting_time"].mode()
    except (KeyError, IndexError):
        mode_result = pd.Series(dtype=float)
    inferred_reporting_time.append(mode_result.iloc[0] if len(mode_result) > 0 else np.nan)

    cell_data = station_data.sel(station=station_ids)
    with warnings.catch_warnings():
        # Days without any valid station raise "Mean of empty slice"; they stay NaN.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        # nanmean is correct for 0, 1 or many valid stations, whichever rows hold
        # the values. Each variable is averaged over its own valid stations.
        PRCP.append(np.nanmean(cell_data["PRCP_QC"].data, axis=0))
        PRCP_raw.append(np.nanmean(cell_data["PRCP_raw"].data, axis=0))

ds = xr.Dataset(
    {
        "PRCP_mean": (("station", "time"), np.array(PRCP)),
        "PRCP_raw": (("station", "time"), np.array(PRCP_raw)),
    },
    coords={
        "station": np.arange(len(lat_lon)),
        "time": station_data["time"].values,
        "latitude": ("station", [cell_lat for cell_lat, _ in lat_lon]),
        "longitude": ("station", [cell_lon for _, cell_lon in lat_lon]),
        "reporting_time": ("station", np.asarray(inferred_reporting_time, dtype=float)),
    },
)
ds.to_netcdf("Raw_QC_PRCP_by_station.nc")

In [29]:
# Display the per-grid-cell dataset: one "station" entry per occupied 0.25° cell.
ds

# Calculate the observation length of each station

In [18]:
# Same file the reporting-time inference cell writes: case-sensitive "ERA5"
# and the 2020_2025 period suffix are part of the name.
reporting_times = pd.read_csv(f"daily_station_ERA5_reporting_times_{version}_2020_2025.csv")
reporting_times

Unnamed: 0,ID,best_shift_hour,reporting_time
0,GSOD_72252012907,-16.0,16.0
1,GHCNd_US1TXGP0123,-8.0,8.0
2,GHCNd_US1KSMI0015,-11.0,11.0
3,GHCNd_USC00140119,-9.0,9.0
4,GHCNd_US1WIRK0015,-10.0,10.0
...,...,...,...
59164,GHCNd_CA001128584,15.0,-15.0
59165,GSOD_16360099999,-9.0,9.0
59166,GHCNd_USC00407184,-1.0,1.0
59167,GHCNd_US1CAHM0004,-9.0,9.0


In [19]:
# Read fully into memory and close the file immediately.
merge_daily_data = xr.load_dataset(QC_data_path / f"PRCP_QC_{version}_2020_2025.nc")
merge_daily_data

In [None]:
# Observation length per station in years: days with a QC-passed value / 365.
qc_days = merge_daily_data["qc_data"].notnull().sum(dim="time")
data_length = (
    (qc_days / 365)
    .to_dataframe()
    .reset_index()
    .loc[:, ["station", "qc_data"]]
    .set_axis(["ID", "year_length"], axis=1)
)

In [22]:
info_columns = ["ID", "LATITUDE", "LONGITUDE", "ELEVATION", "FREQUENCY"]
station_info_df = pd.read_csv(QC_data_path / f"Daily_station_info_{version}.csv").loc[:, info_columns]
station_info_df

Unnamed: 0,ID,LATITUDE,LONGITUDE,ELEVATION,FREQUENCY
0,GHCNh_ACW00011647,17.13330,298.21670,19.2,6
1,GHCNh_AGI0000DAAT,22.81150,5.45110,1377.1,3
2,GHCNh_AGI0000DAOH,34.82000,358.23000,426.0,24
3,GHCNh_AGM00060353,36.81670,5.88330,6.0,6
4,GHCNh_AGM00060395,36.70000,4.13330,153.0,24
...,...,...,...,...,...
59164,GSOD_A0000253928,30.21889,263.62583,93.9,24
59165,GSOD_A0000594076,40.05361,253.63111,2258.9,24
59166,GSOD_A0002453848,30.51195,273.04594,34.0,24
59167,GSOD_A0003225715,52.22028,185.79389,17.1,24


In [23]:
# Station metadata + inferred reporting time + observation length (inner joins on ID).
station_info_2025 = (
    station_info_df
    .merge(reporting_times, on="ID")
    .merge(data_length, on="ID")
)
station_info_2025.to_csv(QC_data_path / f"{version}_station_info_2025.csv", index=None)