### Step 0: Environment Setup (Install Libraries)

In [None]:
# Run this cell first to install all dependencies for the entire project.
! pip install -q numpy pandas matplotlib seaborn scipy scikit-learn
! pip install -q dask distributed xarray bottleneck
! pip install -q zarr gcsfs fsspec earthaccess
! pip install cdsapi earthaccess xarray pandas numpy netCDF4 h5netcdf
! pip install -q netCDF4 h5netcdf rioxarray rasterio
! pip install netcdf4

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.3/43.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m284.1/284.1 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.5/70.5 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.2/9.2 MB[0m [31m94.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.3/87.3 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.3/14.3 MB[0m [31m126.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.0/88.0 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cdsapi
  Downloading cdsapi-0.7.7-py2

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import os
import glob
import cdsapi
import earthaccess
import shutil
import re
import h5py
import matplotlib.pyplot as plt
from scipy import stats
from scipy.stats import gamma, lognorm, weibull_min
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.impute import SimpleImputer
from scipy.spatial.distance import pdist
from sklearn.utils import resample

In [None]:
# Mount Google Drive at /content/drive so the project paths used below resolve.
import google.colab.drive as drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Step 1: Ground Truth & Auxiliary Data (CWA Observations + Synthetic DEM)

In [None]:
# Disabled alternative configuration from a teammate ("KIWI'S CHANGES").
# Everything below is wrapped in a bare triple-quoted string, so none of it executes.
# Because the string is the cell's last expression, Jupyter echoes it as the cell output.
# The active path constants are defined in the next code cell.
# NOTE(review): the recorded output of this cell is a SyntaxError ("incomplete input").
# The saved notebook may therefore differ from what is shown here; check that the
# closing quotes are present before re-running.
'''
#KIWI'S CHANGES
from google.colab import drive
drive.mount("/content/drive")

# ===== Team-fixed BASE (everyone must have the shortcut here) =====
BASE = "/content/drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/Running_Result"

# ---- CSV patterns ----
RAIN_PATTERN = os.path.join(BASE,"Monthly_Average_Rainfall_Data","觀測_月資料_臺灣_降雨量_*.csv")
TEMP_PATTERN = os.path.join(BASE,"Monthly_Average_Temperature_Data","觀測_月資料_臺灣_平均溫_*.csv")

# ---- Output directory ----
OUT_DIR = os.path.join(BASE, "outputs")
os.makedirs(OUT_DIR, exist_ok=True)

OUT_RAIN_NC = os.path.join(OUT_DIR, "monthly_rain_climatology.nc")
OUT_TEMP_NC = os.path.join(OUT_DIR, "monthly_temp_climatology.nc")

'''

SyntaxError: incomplete input (ipython-input-2396795403.py, line 19)

In [None]:
# Compile the TCCIP/CWA monthly rainfall and temperature CSV grids into NetCDF,
# then build a synthetic DEM on the rainfall grid.
START_DATE, END_DATE = "2000-06-01", "2023-12-31"  # inclusive analysis window

# Inputs: glob patterns for the monthly CSVs. In the file names,
# 降雨量 = rainfall and 平均溫 = mean temperature.
RAIN_PATTERN = ("./drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/TCCIP_Data/"
                "Monthly_Average_Rainfall_Data/觀測_月資料_臺灣_降雨量_*.csv")
TEMP_PATTERN = ("./drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/TCCIP_Data/"
                "Monthly_Average_Temperature_Data/觀測_月資料_臺灣_平均溫_*.csv")

# Outputs: compiled NetCDF files go to Drive; the DEM goes to the working directory.
OUT_RAIN_NC = ("./drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/Running_Result/"
               "CWA_Rainfall_Compiled.nc")
OUT_TEMP_NC = ("./drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/Running_Result/"
               "CWA_Temperature_Compiled.nc")
OUT_DEM_NC = "Taiwan_DEM.nc"

def compile_csv_to_netcdf(file_pattern, output_filename, var_name,
                          start_date=None, end_date=None):
    """Compile a set of monthly gridded CSV files into one NetCDF dataset.

    Each CSV holds one row per grid point: ``LON`` and ``LAT`` columns plus one
    column per month, with headers like ``YYYYMM``. The processing steps are:

    1. Melt each file to long format.
    2. Replace missing-value codes with NaN.
    3. Keep only months inside ``[start_date, end_date]``.
    4. Average duplicate (time, lat, lon) rows.
    5. Convert to an xarray Dataset.

    Parameters
    ----------
    file_pattern : str
        ``glob`` pattern of the input CSVs. If it contains ``降雨量`` (rainfall),
        negative values are treated as missing. Otherwise, values <= -50
        (e.g. the -99.9 fill code) are treated as missing.
    output_filename : str or None
        NetCDF path to write. A falsy value skips writing.
    var_name : str
        Name of the data variable in the returned dataset.
    start_date, end_date : str or datetime-like, optional
        Inclusive time window. Defaults to the module-level ``START_DATE`` /
        ``END_DATE``.

    Returns
    -------
    xarray.Dataset
        One variable ``var_name`` on dims (time, lat, lon).

    Raises
    ------
    FileNotFoundError
        If ``file_pattern`` matches no files.
    ValueError
        If no row survives month parsing and date filtering.
    """
    if start_date is None:
        start_date = START_DATE
    if end_date is None:
        end_date = END_DATE
    start_ts, end_ts = pd.Timestamp(start_date), pd.Timestamp(end_date)
    is_rain = '降雨量' in file_pattern

    files = sorted(glob.glob(file_pattern))
    if not files:
        raise FileNotFoundError(f"No CSV files match pattern: {file_pattern}")

    all_data = []
    for f in files:
        # utf-8-sig strips a leading BOM that would otherwise corrupt the 'LON' header.
        df = pd.read_csv(f, encoding='utf-8-sig')
        df.columns = [str(c).strip() for c in df.columns]  # Standardize column names
        id_vars = ['LON', 'LAT']
        val_vars = [c for c in df.columns if c not in id_vars and 'Unnamed' not in c]
        if not val_vars:
            continue
        df_melt = df.melt(id_vars=id_vars, value_vars=val_vars, var_name='YYYYMM', value_name='Value')
        df_melt['Value'] = pd.to_numeric(df_melt['Value'], errors='coerce')

        if is_rain:
            df_melt.loc[df_melt['Value'] < 0, 'Value'] = np.nan
        else:
            df_melt.loc[df_melt['Value'] <= -50, 'Value'] = np.nan

        # Accept month headers written as 200006, 2000/06 or 2000-06.
        month_str = df_melt['YYYYMM'].astype(str).str.strip().str.replace(r'[/-]', '', regex=True)
        df_melt['time'] = pd.to_datetime(month_str, format='%Y%m', errors='coerce')
        # NaT never satisfies these comparisons, so unparseable headers are dropped as well.
        in_window = (df_melt['time'] >= start_ts) & (df_melt['time'] <= end_ts)
        df_filtered = df_melt.loc[in_window, ['time', 'LAT', 'LON', 'Value']]
        if not df_filtered.empty:
            all_data.append(df_filtered)

    if not all_data:
        raise ValueError(
            f"No rows from {len(files)} file(s) matching {file_pattern} fall inside "
            f"{start_ts.date()}..{end_ts.date()}; check the YYYYMM column headers."
        )

    full_df = pd.concat(all_data, ignore_index=True)
    # Several files may cover the same (month, grid point); average them.
    full_df = full_df.groupby(['time', 'LAT', 'LON'], as_index=False)['Value'].mean()
    ds = full_df.set_index(['time', 'LAT', 'LON']).to_xarray()
    ds = ds.rename({'Value': var_name, 'LAT': 'lat', 'LON': 'lon'})
    if output_filename:
        ds.to_netcdf(output_filename)
    return ds

# #KIWI'S CHANGES
# def compile_csv_to_netcdf(file_pattern, output_filename, var_name):
#     files = sorted(glob.glob(file_pattern))
#     print(f"[compile_csv_to_netcdf] pattern={file_pattern}")
#     print(f"[compile_csv_to_netcdf] files found={len(files)} sample={files[:2]}")

#     if len(files) == 0:
#         raise FileNotFoundError(f"glob=0, pattern={file_pattern}")

#     all_data = []
#     empty_file_cnt = 0
#     time_fail_cnt = 0

#     for f in files:
#         # 1) read
#         df = pd.read_csv(f, encoding="utf-8-sig")
#         df.columns = [c.strip() for c in df.columns]

#         # 2) melt
#         id_vars = ["LON", "LAT"]
#         if not all(v in df.columns for v in id_vars):
#             raise KeyError(f"{f} missing LON/LAT. cols={df.columns.tolist()[:20]}")

#         val_vars = [c for c in df.columns if c not in id_vars and "Unnamed" not in str(c)]
#         if len(val_vars) == 0:
#             empty_file_cnt += 1
#             print(f"[skip] {f}: no val_vars")
#             continue

#         df_melt = df.melt(id_vars=id_vars, value_vars=val_vars, var_name="YYYYMM", value_name="Value")
#         df_melt["Value"] = pd.to_numeric(df_melt["Value"], errors="coerce")

#         # ✅ 3) 把 -99.9 當缺值（你檔案就是 -99.9）
#         df_melt.loc[df_melt["Value"] <= -99, "Value"] = np.nan

#         # 4) time parse：欄名是 int 200001 這種 => 轉成字串再 parse
#         s = df_melt["YYYYMM"].astype(str).str.strip().str.replace("/", "", regex=False).str.replace("-", "", regex=False)
#         df_melt["time"] = pd.to_datetime(s, format="%Y%m", errors="coerce")

#         if df_melt["time"].notna().sum() == 0:
#             time_fail_cnt += 1
#             print(f"[skip] {f}: time parse failed, month examples={sorted(set(df_melt['YYYYMM'].astype(str)))[:5]}")
#             continue

#         # ✅ 5) 不做 START_DATE/END_DATE 篩選（避免全空）
#         df_filtered = df_melt[["time", "LAT", "LON", "Value"]].dropna(subset=["time"])

#         # ✅ 6) 只要不是整份都 NaN 就收
#         if df_filtered["Value"].notna().sum() == 0:
#             # 這種情況通常是整張網格都是 -99.9
#             print(f"[skip] {f}: all values are NaN after replacing -99.9")
#             continue

#         all_data.append(df_filtered)

#     print(f"[compile_csv_to_netcdf] appended dfs={len(all_data)} (empty_file={empty_file_cnt}, time_fail={time_fail_cnt})")

#     if len(all_data) == 0:
#         raise ValueError(
#             "No objects to concatenate: all files were skipped.\n"
#             "Common reasons:\n"
#             "1) glob抓到檔但每個檔都沒有月份欄位\n"
#             "2) time parse 全失敗\n"
#             "3) 全部值都是 -99.9 被轉成 NaN\n"
#             f"empty_file_cnt={empty_file_cnt}, time_fail_cnt={time_fail_cnt}"
#         )

#     full_df = pd.concat(all_data, ignore_index=True)

#     # 7) average duplicates
#     full_df = full_df.groupby(["time", "LAT", "LON"], as_index=False)["Value"].mean()

#     # 8) to xarray
#     ds = full_df.set_index(["time", "LAT", "LON"]).to_xarray()
#     ds = ds.rename({"Value": var_name, "LAT": "lat", "LON": "lon"})

#     # 9) write (注意：shared folder 可能 permission denied；output_filename 換到 MyDrive 自己資料夾/或 /content)
#     if output_filename:
#         ds.to_netcdf(output_filename)

#     return ds

def generate_dem(reference_ds, output_filename=None):
    """Build a synthetic elevation field on the lat/lon grid of ``reference_ds``.

    NOTE: this is not observed terrain. Elevation is ``3500 * exp(-3 * d)``,
    where ``d`` is the distance in degrees from (23.5N, 121.0E). The field is
    therefore positive at every grid cell, including cells over the ocean.

    Parameters
    ----------
    reference_ds : xarray.Dataset
        Provides ``lat``/``lon`` coordinates, either both 1-D (regular grid) or
        both N-D sharing the same dims.
    output_filename : str, optional
        NetCDF path to write. Defaults to the module-level ``OUT_DEM_NC``.

    Returns
    -------
    xarray.Dataset
        A single ``elevation`` variable on the reference grid.
    """
    if output_filename is None:
        output_filename = OUT_DEM_NC
    lat = reference_ds['lat']
    lon = reference_ds['lon']
    if lat.ndim == 1 and lon.ndim == 1:
        # Regular grid: build a (lat, lon) mesh from the 1-D axes.
        lat_grid, lon_grid = np.meshgrid(lat.values, lon.values, indexing='ij')
        dims = (lat.dims[0], lon.dims[0])
    else:
        # Curvilinear grid: lat/lon are already full 2-D fields.
        lat_grid = lat.values
        lon_grid = lon.values
        dims = lat.dims

    # Synthetic elevation, peak at 23.5N, 121.0E.
    dist = np.sqrt((lat_grid - 23.5) ** 2 + (lon_grid - 121.0) ** 2)
    elevation = np.maximum(3500 * np.exp(-dist * 3), 0)

    # Explicit dims/coords: the previous coords-only form also carried the scalar
    # `time` coordinate and left the dimensions unnamed.
    ds_dem = xr.DataArray(
        elevation,
        dims=dims,
        coords={'lat': lat.variable, 'lon': lon.variable},
        name='elevation',
        attrs={'units': 'm', 'description': 'synthetic bump, not observed terrain'},
    ).to_dataset()
    ds_dem.to_netcdf(output_filename, mode='w')
    return ds_dem

if __name__ == "__main__":
    ds_rain = compile_csv_to_netcdf(RAIN_PATTERN, OUT_RAIN_NC, "Precip")
    ds_temp = compile_csv_to_netcdf(TEMP_PATTERN, OUT_TEMP_NC, "Temperature")
    if ds_rain:
        generate_dem(ds_rain)

PermissionError: [Errno 13] Permission denied: '/content/drive/MyDrive/CIE5140_Final_Project/Google_Colab/Database/Running_Result/CWA_Rainfall_Compiled.nc'

### Step 2: Cloud Data Download (ERA5 + IMERG)

In [None]:
# ERA5
def download_era5(url, key, start_year=2000, end_year=2023, start_month=6):
    """Download ERA5 monthly-mean total precipitation over Taiwan as mm/month totals.

    The CDS request is split into two parts:

    * Part 1: ``start_year`` from ``start_month`` to December.
    * Part 2: full years ``start_year + 1 .. end_year``, if any.

    After download, the function:

    1. Merges the parts.
    2. Collapses ERA5/ERA5T ``expver`` duplicates.
    3. Snaps time stamps to the first day of each month.
    4. Converts mean daily totals (m/day) to mm/month.

    The result is written to ``ERA5_Taiwan_Monthly.nc`` in the working directory.

    Parameters
    ----------
    url, key : str
        CDS API endpoint and personal access token. They are passed directly to
        ``cdsapi.Client`` rather than written to ``~/.cdsapirc``. An empty
        ``key`` skips the download unless all part files are already on disk.
    start_year, end_year : int
        Inclusive year range.
    start_month : int
        First month requested in ``start_year``. The default 6 matches the June
        2000 start used for IMERG and the CWA data.

    Returns
    -------
    None. Returns early if the output file already exists or the download is skipped.

    Raises
    ------
    KeyError
        If the merged data has neither ``tp`` nor ``total_precipitation``.
    """
    output_file = "ERA5_Taiwan_Monthly.nc"
    if os.path.exists(output_file):
        return

    common = {
        'product_type': 'monthly_averaged_reanalysis',
        'variable': 'total_precipitation',
        'time': '00:00',
        'area': [26, 119, 21, 123],  # North, West, South, East
        'format': 'netcdf',
    }
    parts = [(
        'ERA5_Part1.nc',
        {**common, 'year': str(start_year),
         'month': [f"{m:02d}" for m in range(start_month, 13)]},
        f"Part 1 ({start_year}, months {start_month:02d}-12)",
    )]
    if end_year > start_year:
        parts.append((
            'ERA5_Part2.nc',
            {**common, 'year': [str(y) for y in range(start_year + 1, end_year + 1)],
             'month': [f"{m:02d}" for m in range(1, 13)]},
            f"Part 2 ({start_year + 1}-{end_year} Full Years)",
        ))

    missing = [name for name, _, _ in parts if not os.path.exists(name)]
    if missing:
        if not key:
            print("No CDS key supplied; skipping ERA5 download.")
            return
        client = cdsapi.Client(url=url, key=key)
        for name, request, label in parts:
            if name in missing:
                print(f"  -> Downloading {label}...")
                client.retrieve('reanalysis-era5-single-levels-monthly-means', request, name)

    pieces = []
    for name, _, _ in parts:
        # Load into memory so the part files can be closed (and deleted) safely.
        with xr.open_dataset(name) as raw:
            part = raw.load()
        if 'valid_time' in part.coords:
            part = part.rename({'valid_time': 'time'})
        pieces.append(part)
    ds = xr.concat(pieces, dim='time').sortby('time')

    # 'expver' appears when the data includes recent, not-yet-final months.
    if 'expver' in ds.coords or 'expver' in ds.dims:
        try:
            # expver=1 is final ERA5, expver=5 is preliminary ERA5T; prefer final values.
            ds = ds.sel(expver=1).combine_first(ds.sel(expver=5))
        except (KeyError, ValueError, TypeError):
            try:
                ds = ds.isel(expver=0)
            except (KeyError, ValueError, TypeError, IndexError):
                # expver is a plain coordinate rather than a dimension: nothing to collapse.
                print("Continuing with raw data.")

    if not np.issubdtype(ds['time'].dtype, np.datetime64):
        try:
            ds['time'] = ds.indexes['time'].to_datetimeindex()
        except (AttributeError, KeyError, ValueError):
            pass  # pd.to_datetime below is the second attempt

    # Snap every stamp to the first day of its month so it matches CWA/IMERG.
    month_starts = pd.to_datetime(ds['time'].values).to_period('M').to_timestamp()
    ds = ds.assign_coords(time=month_starts)
    if month_starts.has_duplicates:
        ds = ds.groupby('time').mean(keep_attrs=True)

    target_var = next((v for v in ('tp', 'total_precipitation') if v in ds.data_vars), None)
    if target_var is None:
        raise KeyError(f"No precipitation variable in ERA5 data; found {list(ds.data_vars)}")
    ds[target_var].encoding = {}

    # Total monthly precipitation: (m/day) * 1000 (mm/m) * days_in_month = mm/month.
    # Values >= 100 are assumed to be in millimetres already and are left unchanged.
    days = ds['time'].dt.days_in_month
    if float(ds[target_var].max()) < 100:
        precip_mm = ds[target_var] * 1000 * days
    else:
        precip_mm = ds[target_var]

    precip_mm.name = 'total_precipitation'
    precip_mm.attrs['units'] = 'mm'
    precip_mm.to_netcdf(output_file, encoding={'total_precipitation': {'dtype': 'float32'}})
    ds.close()
    for name in ('ERA5_Part1.nc', 'ERA5_Part2.nc'):
        if os.path.exists(name):
            os.remove(name)

# IMERG
def download_imerg(start_year=2000, end_year=2023, start_month=6):
    """Download GPM IMERG V07 monthly precipitation over Taiwan as mm/month totals.

    The function:

    1. Searches and downloads ``GPM_3IMERGM`` granules for the bounding box
       119-123E, 21-26N.
    2. Subsets each HDF5 file to that box.
    3. Masks negative fill values.
    4. Converts the rate (mm/hr) to a monthly total.

    The result is written to ``IMERG_Taiwan_Monthly_Full.nc`` in the working directory.

    Parameters
    ----------
    start_year, end_year : int
        Inclusive year range.
    start_month : int
        First month searched in ``start_year`` (default June).

    Returns
    -------
    None. Returns early, without logging in, if the output file already exists.

    Raises
    ------
    ValueError
        If no usable granule could be read.
    """
    output_file = "IMERG_Taiwan_Monthly_Full.nc"
    if os.path.exists(output_file):
        return

    # Only log in when a download is actually needed (login may prompt interactively).
    earthaccess.login(strategy="interactive")
    results = earthaccess.search_data(
        short_name="GPM_3IMERGM",
        version="07",
        temporal=(f"{start_year}-{start_month:02d}-01", f"{end_year}-12-31"),
        bounding_box=(119, 21, 123, 26),
    )
    temp_dir = "imerg_temp_download"
    os.makedirs(temp_dir, exist_ok=True)
    paths = earthaccess.download(results, temp_dir)

    # Granule metadata can only be paired with paths by position if nothing was dropped.
    aligned_results = list(results) if len(results) == len(paths) else None

    def _granule_start(index, file_path):
        """Return the granule's start time, or None if it cannot be determined."""
        if aligned_results is not None:
            try:
                date_str = aligned_results[index]["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
                dt = pd.to_datetime(date_str)
                return dt.tz_localize(None) if dt.tz is not None else dt
            except (KeyError, TypeError, ValueError, IndexError):
                pass
        # Fall back to the first YYYYMMDD stamp in the file name.
        match = re.search(r'(\d{8})', os.path.basename(str(file_path)))
        if match:
            try:
                return pd.to_datetime(match.group(1), format='%Y%m%d')
            except ValueError:
                return None
        return None

    datasets = []
    for i, file_path in enumerate(paths):
        if os.path.getsize(file_path) < 10000:
            continue  # truncated / failed download
        dt = _granule_start(i, file_path)
        if dt is None:
            continue

        with h5py.File(file_path, 'r') as f:
            grid = f['Grid']
            var_name = next((v for v in ('precipitation', 'precipitationCal') if v in grid), None)
            if var_name is None:
                continue
            lats = grid['lat'][:]
            lons = grid['lon'][:]
            lat_idx = np.where((lats >= 21) & (lats <= 26))[0]
            lon_idx = np.where((lons >= 119) & (lons <= 123))[0]
            if lat_idx.size == 0 or lon_idx.size == 0:
                continue
            lon_sl = slice(lon_idx.min(), lon_idx.max() + 1)
            lat_sl = slice(lat_idx.min(), lat_idx.max() + 1)
            # The grid variable is indexed (time, lon, lat); take the single time step.
            raw_data = grid[var_name][0, lon_sl, lat_sl].astype('float32')

        # Precipitation cannot be negative; negative values are presumably fill codes
        # (e.g. -9999.9) and must not enter the monthly totals.
        raw_data[raw_data < 0] = np.nan
        ds_sub = xr.DataArray(
            raw_data,
            coords={'lon': lons[lon_sl], 'lat': lats[lat_sl]},
            dims=['lon', 'lat'],
            name='precipitation',
        )
        datasets.append(ds_sub.expand_dims(time=[dt]))

        if i % 10 == 0:
            print(f"  Processed {i}/{len(paths)}...", end='\r')

    if not datasets:
        raise ValueError("No usable IMERG granules were read; check the search results and downloads.")
    ds_combined = xr.concat(datasets, dim='time').sortby('time')

    # Conversion: rate (mm/hr) -> total (mm)
    days = ds_combined['time'].dt.days_in_month
    ds_total = ds_combined * 24 * days
    ds_total.name = 'precipitation'
    ds_total.attrs['units'] = 'mm'
    ds_total.to_netcdf(output_file, encoding={'precipitation': {'dtype': 'float32'}})

    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)

if __name__ == "__main__":
    # getpass hides the token instead of echoing it into the saved notebook output.
    from getpass import getpass

    CDS_URL = "https://cds.climate.copernicus.eu/api"
    CDS_KEY = getpass("Enter your NEW CDS Personal Access Token (or press Enter to skip ERA5): ").strip()
    if CDS_KEY:
        download_era5(CDS_URL, CDS_KEY, 2000, 2023)
    else:
        print("No CDS token entered; skipping ERA5.")
    download_imerg(2000, 2023)

### Step 3: Distributional Diagnosis (Gamma / Log-Normal / Weibull Fitting, best by AIC)

In [None]:
# Step 3: align ERA5 and IMERG to the CWA grid and diagnose bias, seasonality and
# value distributions. Gamma, Log-Normal and Weibull are fitted and the lowest AIC wins.
# NOTE(review): Step 1 writes the compiled CWA rainfall to OUT_RAIN_NC on Drive,
# not to the local INPUT_CWA filename below; confirm the file is copied here first.
INPUT_ERA5, INPUT_IMERG = "ERA5_Taiwan_Monthly.nc", "IMERG_Taiwan_Monthly_Full.nc"
INPUT_CWA = "CWA_Rainfall_Compiled.nc"
OUTPUT_DIAGNOSIS = "Taiwan_Rainfall_Diagnosis.nc"

def fix_era5_raw_structure(ds):
    """Collapse an ERA5 ``expver`` dimension, if present, into a single record.

    Final ERA5 values (``expver=1``) are preferred, and gaps are filled from
    preliminary ERA5T (``expver=5``). If either label is missing, the first
    ``expver`` slice is kept and the dimension is dropped. Datasets without an
    ``expver`` dimension are returned unchanged.

    Unexpected errors (anything other than a missing label or an invalid
    selection) propagate instead of being silently swallowed.
    """
    if 'expver' not in ds.dims:
        return ds
    try:
        return ds.sel(expver=1).combine_first(ds.sel(expver=5))
    except (KeyError, ValueError):
        return ds.isel(expver=0, drop=True)

def find_best_fit(data, name, wet_threshold=0.1, min_wet=10):
    """Fit Gamma, Log-Normal and Weibull to the wet part of ``data`` and pick the lowest AIC.

    The three candidates are defined only for positive values. Values at or
    below ``wet_threshold`` (default 0.1 mm, treated as trace/zero) are therefore
    excluded from the fit and reported separately as a dry percentage. ``loc`` is
    fixed at 0 for every fit, so AIC counts only the estimated shape and scale.

    Parameters
    ----------
    data : array-like
        Rainfall values of any shape (flattened). NaNs and negatives are discarded.
    name : str
        Label used in the printed AIC summary.
    wet_threshold : float
        Values strictly greater than this count as "wet".
    min_wet : int
        Minimum number of wet values required to attempt a fit.

    Returns
    -------
    tuple (best_name, params, dist, dry_percent)
        best_name : 'Gamma', 'LogNorm' or 'Weibull'.
        params : the scipy ``fit`` tuple (shape, loc, scale).
        dist : the scipy distribution object.
        dry_percent : share of valid values that are not wet, in percent
            (0.0 when there is no valid data).
        With fewer than ``min_wet`` wet values, the first three entries are None.
        ``dry_percent`` is still returned, so callers can always unpack four values.
    """
    values = np.asarray(data, dtype=float).ravel()
    valid = values[~np.isnan(values)]
    valid = valid[valid >= 0]
    wet = valid[valid > wet_threshold]
    dry_prob = (len(valid) - len(wet)) / len(valid) * 100 if len(valid) > 0 else 0.0
    if len(wet) < min_wet:
        return None, None, None, dry_prob

    candidates = {'Gamma': gamma, 'LogNorm': lognorm, 'Weibull': weibull_min}
    results = {}
    for dist_name, dist in candidates.items():
        params = dist.fit(wet, floc=0)
        # logpdf avoids log(0) = -inf when a pdf value underflows.
        log_lik = np.sum(dist.logpdf(wet, *params))
        # loc is fixed by floc=0, so it is not an estimated parameter.
        k = len(params) - 1
        # AIC = 2k - 2 log L; lower is better.
        results[dist_name] = {'params': params, 'aic': 2 * k - 2 * log_lik, 'func': dist}

    best_dist_name = min(results, key=lambda n: results[n]['aic'])
    best = results[best_dist_name]
    print(f"{name} AIC Scores: Gamma={results['Gamma']['aic']:.0f}, "
          f"LogNorm={results['LogNorm']['aic']:.0f}, Weibull={results['Weibull']['aic']:.0f}")
    return best_dist_name, best['params'], best['func'], dry_prob

def plot_for_presentation(ds, output_prefix="Presentation_Fig"):
    """Save three presentation figures from the Step 3 diagnosis dataset.

    Files written:

    * ``{output_prefix}_1_Bias_Maps.png``: time-mean IMERG and ERA5 bias maps vs CWA.
    * ``{output_prefix}_2_Seasonal_Cycle.png``: area-mean monthly climatology of
      CWA, ERA5 and IMERG.
    * ``{output_prefix}_3_Distribution_Fit.png``: histograms with the best-AIC
      PDF for each source (fitted via ``find_best_fit``).

    It also prints the dry-month percentage per source.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``bias_imerg_cwa``, ``bias_era5_cwa``, ``seasonality_era``,
        ``seasonality_imerg``, ``seasonality_cwa``, ``era5_on_cwa``,
        ``imerg_on_cwa`` and ``cwa_ref``.
    output_prefix : str
        File-name prefix for the PNGs.
    """
    # 1) Spatial bias maps. Assuming bias = product - CWA, RdBu draws positive
    #    (too wet) values in blue and negative (too dry) values in red.
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    for ax, var, label in ((axes[0], 'bias_imerg_cwa', 'IMERG'), (axes[1], 'bias_era5_cwa', 'ERA5')):
        ds[var].mean(dim='time', skipna=True).plot.pcolormesh(
            ax=ax, cmap='RdBu', center=0, cbar_kwargs={'label': 'Bias (mm/month)'})
        ax.set_title(f'{label} Bias (vs CWA Ground Truth)\n(Blue=Wet, Red=Dry)')
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
    fig.tight_layout()
    fig.savefig(f"{output_prefix}_1_Bias_Maps.png", dpi=150)
    plt.close(fig)

    # 2) Seasonal cycle: average over lat/lon to get one line for the whole region.
    fig, ax = plt.subplots(figsize=(10, 6))
    series = (
        ('seasonality_cwa', 'CWA (Ground Truth)', dict(color='green', linewidth=3, marker='^', linestyle='-')),
        ('seasonality_era', 'ERA5 (Reanalysis)', dict(color='blue', linewidth=2, marker='o', linestyle='--')),
        ('seasonality_imerg', 'IMERG (Satellite)', dict(color='red', linewidth=2, marker='s', linestyle='--')),
    )
    for var, label, style in series:
        ds[var].mean(dim=['lat', 'lon'], skipna=True).plot(ax=ax, label=label, **style)
    ax.set_title('Average Seasonal Cycle (Taiwan Region)')
    ax.set_ylabel('Precipitation (mm/month)')
    ax.set_xlabel('Month')
    ax.set_xticks(range(1, 13))
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{output_prefix}_2_Seasonal_Cycle.png", dpi=150)
    plt.close(fig)

    # 3) Distribution fitting: best-fit PDF over the empirical histogram.
    fig, ax = plt.subplots(figsize=(10, 6))
    era_vals = ds['era5_on_cwa'].values.ravel()
    imerg_vals = ds['imerg_on_cwa'].values.ravel()
    cwa_vals = ds['cwa_ref'].values.ravel()
    cwa_vals = cwa_vals[cwa_vals >= 0]
    hist_kw = dict(bins=50, density=True, range=(0, 800), alpha=0.3)
    ax.hist(cwa_vals, color='green', label='CWA (Ground Truth)', **hist_kw)
    ax.hist(imerg_vals[np.isfinite(imerg_vals)], color='red', label='IMERG', **hist_kw)
    ax.hist(era_vals[np.isfinite(era_vals)], color='blue', label='ERA5', **hist_kw)

    x = np.linspace(0, 800, 1000)
    best, dry = {}, {}
    for source, vals, fmt, lw in (('CWA', cwa_vals, 'g-', 2.5),
                                  ('ERA5', era_vals, 'b--', 2),
                                  ('IMERG', imerg_vals, 'r--', 2)):
        dist_name, params, func, dry[source] = find_best_fit(vals, source)
        best[source] = dist_name
        if func is not None:
            ax.plot(x, func.pdf(x, *params), fmt, lw=lw, label=f'{source} ({dist_name} Fit)')
    ax.set_title('Rainfall Distribution Fitting (Best Fit Selection)')
    ax.set_xlabel('Monthly Rainfall (mm)')
    ax.set_ylabel('Probability Density')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 600)
    fig.tight_layout()
    fig.savefig(f"{output_prefix}_3_Distribution_Fit.png", dpi=150)
    plt.close(fig)

    def _pct(value):
        # Dry share may be None if a fitting helper returns nothing for it.
        return "n/a" if value is None else f"{value:.1f}%"

    print(f"CWA (Truth): {_pct(dry['CWA'])} dry months. Real stations record true '0.0' rain.")
    print(f"ERA5: {_pct(dry['ERA5'])} dry months. Reanalysis tends to generate 'micro-drizzle', rarely 0.")
    print(f"IMERG: {_pct(dry['IMERG'])} dry months. Spatial averaging makes 0.0 rare over large pixels.")
    # Only state the Gamma-vs-Weibull conclusion when the fits actually came out that way.
    if (dry['ERA5'] is not None and dry['ERA5'] < 0.1
            and best['ERA5'] == 'Gamma' and best['CWA'] == 'Weibull'):
        print("ERA5's lack of dry months explains why it prefers Gamma (continuous) while CWA prefers Weibull.")

def run_distribution_diagnosis():
    """Align ERA5 and IMERG to the CWA grid, compute bias and seasonality, then save and plot.

    Reads ``INPUT_ERA5``, ``INPUT_IMERG`` and ``INPUT_CWA``. Writes
    ``OUTPUT_DIAGNOSIS`` (NetCDF) and the presentation PNGs via
    ``plot_for_presentation``.

    Raises
    ------
    FileNotFoundError
        If any input file is missing. The CWA file comes from Step 1 and may
        have been written under a different path.
    """
    missing = [p for p in (INPUT_ERA5, INPUT_IMERG, INPUT_CWA) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Missing input file(s) for the distribution diagnosis: {missing}. "
            "The CWA file is produced by Step 1 (compile_csv_to_netcdf)."
        )

    def _load(path):
        # Read fully into memory so the file handle is released immediately.
        with xr.open_dataset(path) as opened:
            return opened.load()

    def _standardize(ds):
        # Unify coordinate names to lat/lon and sort them ascending before interpolation
        # (the source grids may store latitude north-to-south).
        aliases = {'latitude': 'lat', 'longitude': 'lon', 'LAT': 'lat', 'LON': 'lon'}
        mapping = {old: new for old, new in aliases.items() if old in ds.coords}
        if mapping:
            ds = ds.rename(mapping)
        sort_keys = [c for c in ('lat', 'lon') if c in ds.coords]
        return ds.sortby(sort_keys) if sort_keys else ds

    ds_era = _standardize(fix_era5_raw_structure(_load(INPUT_ERA5)))
    ds_imerg = _standardize(_load(INPUT_IMERG))
    ds_cwa = _standardize(_load(INPUT_CWA))

    var_e = 'tp' if 'tp' in ds_era else 'total_precipitation'
    var_i = 'precipitation' if 'precipitation' in ds_imerg else 'precipitationCal'
    var_c = 'Precip'
    cwa = ds_cwa[var_c]

    # Spatial alignment: regrid both products onto the CWA grid.
    era_on_cwa = ds_era[var_e].interp_like(cwa, method='linear')
    imerg_on_cwa = ds_imerg[var_i].interp_like(cwa, method='linear')
    # Bias = product - CWA (positive means the product is wetter than the ground truth).
    bias_imerg = (imerg_on_cwa - cwa).rename('bias_imerg_cwa')
    bias_era5 = (era_on_cwa - cwa).rename('bias_era5_cwa')

    ds_out = xr.Dataset({
        'cwa_ref': cwa,
        'era5_on_cwa': era_on_cwa,
        'imerg_on_cwa': imerg_on_cwa,
        'bias_imerg_cwa': bias_imerg,
        'bias_era5_cwa': bias_era5,
        # Monthly climatologies (one value per calendar month and pixel).
        'seasonality_era': era_on_cwa.groupby('time.month').mean(dim='time'),
        'seasonality_imerg': imerg_on_cwa.groupby('time.month').mean(dim='time'),
        'seasonality_cwa': cwa.groupby('time.month').mean(dim='time'),
    })
    ds_out.to_netcdf(OUTPUT_DIAGNOSIS)
    plot_for_presentation(ds_out)

if __name__ == "__main__":
    run_distribution_diagnosis()

### Step 4: Error Modeling (Regression on Altitude and Temperature)

In [None]:
def load_dataset_safe(path):
    """Load ``path`` fully into memory with standardized, sorted lat/lon coordinates.

    Returns None instead of raising when the file does not exist.
    """
    if os.path.exists(path):
        with xr.open_dataset(path) as handle:
            return standardize_coords(handle.load())
    return None

# Normalize coordinate names to lat/lon and sort both ascending.
def standardize_coords(ds):
    """Rename ``latitude``/``longitude`` to ``lat``/``lon`` (when present) and sort by both.

    Sorting gives monotonic coordinates for the later ``interp`` calls.
    """
    aliases = {'latitude': 'lat', 'longitude': 'lon'}
    present = {old: new for old, new in aliases.items() if old in ds.coords}
    if present:
        ds = ds.rename(present)
    return ds.sortby(['lat', 'lon'])

# Focus: model the per-pixel error variance of ERA5 and IMERG (vs CWA) as a linear
# function of altitude and mean temperature (static: all months pooled).
def run_error_modeling_alt_temp():
    """Regress per-pixel error variance on altitude and mean temperature, then plot.

    Steps:

    1. Load the CWA rainfall and temperature, the DEM, ERA5 and IMERG, and
       interpolate everything onto the CWA rainfall grid.
    2. For months present in all of CWA, ERA5 and IMERG, compute residuals
       (product - CWA) and their variance over time, sigma^2(lat, lon).
    3. Fit ``sigma^2 = a * altitude + b * temperature + c`` separately for ERA5
       and IMERG, and predict a full sigma^2 map.
    4. Save scatter plots and predicted maps to
       ``Presentation_Fig_4_Static_Error.png``.

    Raises
    ------
    FileNotFoundError
        If any input NetCDF file is missing.
    ValueError
        If CWA, ERA5 and IMERG share no month.
    """
    paths = {'rain': OUT_RAIN_NC, 'temp': OUT_TEMP_NC, 'dem': OUT_DEM_NC,
             'era5': INPUT_ERA5, 'imerg': INPUT_IMERG}
    # load_dataset_safe already standardizes and sorts lat/lon.
    loaded = {key: load_dataset_safe(p) for key, p in paths.items()}
    missing = [paths[key] for key, ds in loaded.items() if ds is None]
    if missing:
        raise FileNotFoundError(f"Missing input file(s) for error modeling: {missing}")
    ds_rain, ds_temp, ds_dem = loaded['rain'], loaded['temp'], loaded['dem']
    ds_era, ds_imerg = loaded['era5'], loaded['imerg']

    var_rain, var_temp, var_elev = 'Precip', 'Temperature', 'elevation'
    var_e = 'tp' if 'tp' in ds_era else 'total_precipitation'
    var_i = 'precipitation' if 'precipitation' in ds_imerg else 'precipitationCal'

    # Interpolate everything onto the CWA rainfall grid.
    target = dict(lat=ds_rain.lat, lon=ds_rain.lon)
    era_aligned = ds_era[var_e].interp(**target, method='linear').transpose('time', 'lat', 'lon')
    imerg_aligned = ds_imerg[var_i].interp(**target, method='linear').transpose('time', 'lat', 'lon')
    dem_aligned = ds_dem[var_elev].interp(**target, method='nearest').transpose('lat', 'lon')
    temp_aligned = (ds_temp[var_temp].mean(dim='time', skipna=True)
                    .interp(**target, method='linear')
                    .transpose('lat', 'lon'))

    # Use only months present in all three sources, so .sel() cannot hit a missing label.
    common_time = np.intersect1d(np.intersect1d(ds_rain.time.values, era_aligned.time.values),
                                 imerg_aligned.time.values)
    if common_time.size == 0:
        raise ValueError("CWA, ERA5 and IMERG have no month in common.")
    truth = ds_rain[var_rain].sel(time=common_time)
    resid_era = era_aligned.sel(time=common_time) - truth
    resid_imerg = imerg_aligned.sel(time=common_time) - truth
    # Static error variance per pixel (over all common months).
    sigma2_era = resid_era.var(dim='time', skipna=True).transpose('lat', 'lon')
    sigma2_imerg = resid_imerg.var(dim='time', skipna=True).transpose('lat', 'lon')
    print("   DEBUG SHAPES (Post-Transpose):")
    print(f"   Sigma ERA: {sigma2_era.shape}")
    print(f"   DEM:       {dem_aligned.shape}")
    print(f"   Temp:      {temp_aligned.shape}")

    def fit_and_predict(sigma_map, label):
        """Fit sigma^2 ~ altitude + temperature; return (model, predicted map) or (None, None)."""
        # Stacking (lat, lon) into a single dimension 'z' keeps pixels matched across arrays.
        y_stacked = sigma_map.stack(z=('lat', 'lon'))
        y = y_stacked.values
        x1 = dem_aligned.stack(z=('lat', 'lon')).values
        x2 = temp_aligned.stack(z=('lat', 'lon')).values
        mask = np.isfinite(y) & np.isfinite(x1) & np.isfinite(x2)
        if mask.sum() < 3:
            # Three coefficients need at least three pixels.
            print(f"{label}: only {int(mask.sum())} usable pixels; skipping regression.")
            return None, None

        X_train = np.column_stack((x1[mask], x2[mask]))
        y_train = y[mask]
        model = LinearRegression().fit(X_train, y_train)
        a, b = model.coef_
        c = model.intercept_
        r2 = r2_score(y_train, model.predict(X_train))
        print(f"σ² = ({a:.2f} * Alt) + ({b:.2f} * Temp) + {c:.2f}")
        print(f"R^2: {r2:.4f}")

        # Mean-impute missing predictors so every pixel receives a prediction.
        X_full = SimpleImputer(strategy='mean').fit_transform(np.column_stack((x1, x2)))
        y_pred = xr.DataArray(model.predict(X_full), coords=y_stacked.coords,
                              dims=y_stacked.dims).unstack('z')
        # The synthetic DEM from Step 1 has no NaNs, so masking on it alone keeps every
        # pixel. Also require an observed variance, which drops pixels without CWA
        # data (presumably ocean).
        has_data = dem_aligned.notnull() & sigma_map.notnull()
        return model, y_pred.where(has_data).transpose('lat', 'lon')

    _, pred_map_e = fit_and_predict(sigma2_era, 'ERA5')
    _, pred_map_i = fit_and_predict(sigma2_imerg, 'IMERG')

    x_flat = dem_aligned.stack(z=('lat', 'lon')).values
    t_flat = temp_aligned.stack(z=('lat', 'lon')).values
    y_era_flat = sigma2_era.stack(z=('lat', 'lon')).values
    y_img_flat = sigma2_imerg.stack(z=('lat', 'lon')).values

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    # Top row: error variance against each predictor.
    for ax, predictor, xlabel, title in ((axes[0, 0], x_flat, 'Altitude (m)', 'Error vs Altitude'),
                                         (axes[0, 1], t_flat, 'Mean Temp (°C)', 'Error vs Temperature')):
        ax.scatter(predictor, y_era_flat, alpha=0.3, s=5, c='blue', label='ERA5')
        ax.scatter(predictor, y_img_flat, alpha=0.3, s=5, c='red', label='IMERG')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Error Variance')
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
    # Bottom row: predicted sigma^2 maps (magma colormap).
    for ax, pred_map, title in ((axes[1, 0], pred_map_i, 'Predicted Error Map (IMERG)'),
                                (axes[1, 1], pred_map_e, 'Predicted Error Map (ERA5)')):
        ax.axis('off')
        if pred_map is None:
            continue
        # lat is sorted ascending, so flip rows to put north at the top.
        image = ax.imshow(np.flipud(pred_map.values), cmap='magma')
        ax.set_title(title)
        fig.colorbar(image, ax=ax, label='Predicted σ²')
    fig.tight_layout()
    fig.savefig("Presentation_Fig_4_Static_Error.png", dpi=150)

if __name__ == "__main__":
    run_error_modeling_alt_temp()

In [None]:
# Focus: Modeling error variance based on Altitude, Temperature, and Seasonality.
def run_error_modeling_alt_temp_seas():
    """Regress monthly ERA5/IMERG error variance on altitude, temperature and season.

    For each product the residual (model - CWA observation) on the CWA grid is
    grouped by calendar month and its per-pixel variance is computed, giving one
    σ² map per month. A linear model

        σ² ≈ a·Alt + b·Temp + c·sin(2πm/12) + d·cos(2πm/12) + intercept

    is fitted over all valid (month, pixel) samples. Its equation and in-sample
    R² are printed, and it is evaluated for August (m = 8) to draw one predicted
    error map per product. The figure is saved to
    "Presentation_Fig_4_Error_Analysis_Seasonal.png".

    Relies on notebook-level names defined in earlier cells: ``load_dataset_safe``,
    ``standardize_coords``, ``OUT_RAIN_NC``, ``OUT_TEMP_NC``, ``OUT_DEM_NC``,
    ``INPUT_ERA5`` and ``INPUT_IMERG``.
    """
    ds_rain = load_dataset_safe(OUT_RAIN_NC)
    ds_temp = load_dataset_safe(OUT_TEMP_NC)
    ds_dem = load_dataset_safe(OUT_DEM_NC)

    ds_era = load_dataset_safe(INPUT_ERA5)
    ds_imerg = load_dataset_safe(INPUT_IMERG)
    ds_era = standardize_coords(ds_era)
    ds_imerg = standardize_coords(ds_imerg)

    # Variables
    var_rain = 'Precip'
    var_temp = 'Temperature'
    var_elev = 'elevation'
    # ERA5 / IMERG files may use either variable name.
    var_e = 'tp' if 'tp' in ds_era else 'total_precipitation'
    var_i = 'precipitation' if 'precipitation' in ds_imerg else 'precipitationCal'

    # Interpolate models to the CWA grid
    era_aligned = ds_era[var_e].interp(lat=ds_rain.lat, lon=ds_rain.lon, method='linear')
    imerg_aligned = ds_imerg[var_i].interp(lat=ds_rain.lat, lon=ds_rain.lon, method='linear')
    # Static predictors on the CWA grid, forced to (lat, lon) order so their
    # flattened values line up with the (month, lat, lon) variance maps below.
    dem_aligned = (ds_dem[var_elev]
                   .interp(lat=ds_rain.lat, lon=ds_rain.lon, method='nearest')
                   .transpose('lat', 'lon'))
    temp_mean = ds_temp[var_temp].mean(dim='time', skipna=True)
    temp_aligned = (temp_mean
                    .interp(lat=ds_rain.lat, lon=ds_rain.lon, method='linear')
                    .transpose('lat', 'lon'))

    # Use only months present in CWA, ERA5 *and* IMERG. The products cover
    # different periods, and .sel() raises for a month one of them lacks.
    common_time = np.intersect1d(
        np.intersect1d(ds_rain.time.values, era_aligned.time.values),
        imerg_aligned.time.values,
    )
    if len(common_time) == 0:
        print("No months shared by CWA, ERA5 and IMERG; skipping seasonal error model.")
        return
    obs = ds_rain[var_rain].sel(time=common_time)
    resid_era = era_aligned.sel(time=common_time) - obs
    resid_imerg = imerg_aligned.sel(time=common_time) - obs

    # Variance per calendar month: one σ² map per month (Jan, Feb, ...) instead of a single map.
    sigma2_era = resid_era.groupby('time.month').var(dim='time', skipna=True)
    sigma2_imerg = resid_imerg.groupby('time.month').var(dim='time', skipna=True)

    n_lat, n_lon = dem_aligned.shape

    def solve_regression_seasonal(sigma_map, target_month=8):
        """Fit σ² ~ Alt + Temp + sin/cos(month) and predict a (lat, lon) map for ``target_month``.

        Temperature is the time-mean field (a spatial predictor, not monthly).
        Returns None when there are too few valid samples to fit the model.
        """
        # Fixed (month, lat, lon) order so flattening matches the tiled predictors.
        sigma_map = sigma_map.transpose('month', 'lat', 'lon')
        months = sigma_map.month.values
        n_months = len(months)

        # Repeat the static maps for every month and build the seasonal phase grid.
        dem_3d = np.tile(dem_aligned.values, (n_months, 1, 1))
        temp_3d = np.tile(temp_aligned.values, (n_months, 1, 1))
        month_grid = np.tile(months[:, None, None], (1, n_lat, n_lon))
        angle = 2 * np.pi * month_grid / 12

        y = sigma_map.values.ravel()
        X = np.column_stack((
            dem_3d.ravel(),
            temp_3d.ravel(),
            np.sin(angle).ravel(),
            np.cos(angle).ravel(),
        ))

        # Keep samples whose target and predictors are all finite.
        mask = np.isfinite(y) & np.isfinite(X).all(axis=1)
        n_valid = int(mask.sum())
        if n_valid <= X.shape[1]:
            # 4 coefficients + intercept need at least 5 samples.
            print(f"Too few valid samples ({n_valid}) to fit the seasonal error model.")
            return None
        y_clean = y[mask]
        X_clean = X[mask]

        # Regression
        model = LinearRegression()
        model.fit(X_clean, y_clean)
        a, b, c, d = model.coef_
        intercept = model.intercept_
        r2 = r2_score(y_clean, model.predict(X_clean))
        print(f"Equation: σ² = ({a:.2f}*Alt) + ({b:.2f}*Temp) + ({c:.2f}*sinM) + ({d:.2f}*cosM) + {intercept:.2f}")
        print(f"R^2: {r2:.4f}")

        # Predicted map for a single month (default: August / typhoon season).
        dem_2d = dem_aligned.values.ravel()
        temp_2d = temp_aligned.values.ravel()
        # Explicit float dtype: full_like on an integer DEM would truncate sin/cos to ints.
        sin_2d = np.full(dem_2d.shape, np.sin(2 * np.pi * target_month / 12), dtype=float)
        cos_2d = np.full(dem_2d.shape, np.cos(2 * np.pi * target_month / 12), dtype=float)

        # Mean-impute missing predictors so every pixel can be predicted, then
        # mask pixels with no elevation value (ocean) back to NaN.
        imputer = SimpleImputer(strategy='mean')
        X_target = imputer.fit_transform(np.column_stack((dem_2d, temp_2d, sin_2d, cos_2d)))
        y_pred = model.predict(X_target)
        y_pred[np.isnan(dem_2d)] = np.nan
        return y_pred.reshape(n_lat, n_lon)

    pred_map_e = solve_regression_seasonal(sigma2_era)
    pred_map_i = solve_regression_seasonal(sigma2_imerg)

    plt.figure(figsize=(12, 6))
    # Left panel IMERG, right panel ERA5; a product whose fit failed is left blank.
    for panel, (pred_map, label) in enumerate(((pred_map_i, 'IMERG'), (pred_map_e, 'ERA5')), start=1):
        if pred_map is None:
            continue
        plt.subplot(1, 2, panel)
        # flipud presumably puts north at the top (assumes lat is stored ascending) — TODO confirm.
        plt.imshow(np.flipud(pred_map), cmap='magma')
        plt.title(f'Predicted Error Map ({label}) - August\n(Including Seasonality)')
        plt.colorbar(label='Predicted σ²')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig("Presentation_Fig_4_Error_Analysis_Seasonal.png", dpi=150)

if __name__ == "__main__":
    run_error_modeling_alt_temp_seas()

### Step 5: Bayesian Fusion (Merging Datasets)

In [None]:
# Step 5 configuration.
# run_data_fusion blends bias-corrected ERA5 and IMERG with ONE global weight
# per product, proportional to 1 / MSE against CWA and normalised to sum to 1.
# The author reports that per-pixel weighting overfitted, so it is not used.
OUTPUT_FUSED = "Taiwan_Rainfall_Fused.nc"        # fused field written by run_data_fusion
OUTPUT_CSV = "Fusion_Variogram_Data.csv"         # binned semivariances: CWA / ERA5 / IMERG / fused
INPUT_CWA_TEMP = "CWA_Temperature_Compiled.nc"   # not referenced by the Step 5 functions in this cell
INPUT_DEM = "Taiwan_DEM.nc"                      # not referenced by the Step 5 functions in this cell

def calculate_variogram(data_map, n_bins=20, max_points=5000, seed=None):
    """Empirical semivariogram of a 2-D lat/lon field.

    Pair distances are Euclidean in (lat, lon) degree space, with no
    cos(latitude) correction. For each distance bin the semivariance is
    γ(h) = ½ · mean[(z_i - z_j)²] over the pairs falling in that bin. Only
    distances in [0, max_pair_distance / 3) are binned.

    Args:
        data_map: 2-D field exposing ``.values`` and ``.lat`` / ``.lon``
            coordinates (1-D axes or 2-D grids), e.g. an xarray DataArray.
        n_bins: number of distance bins; both returned arrays have this length.
        max_points: if there are more finite cells than this, a random subset
            of this size is used to bound the O(n²) pairwise computation.
            ``None`` disables subsampling.
        seed: seed for that subsampling (``None`` gives a fresh random draw).

    Returns:
        ``(bin_centers, semivariance)`` as 1-D float arrays of length ``n_bins``
        (bins with no pairs hold NaN), or ``(None, None)`` when fewer than two
        finite cells are available.
    """
    # Enforce (lat, lon) order so the flattened values match the meshgrid below.
    dims = tuple(getattr(data_map, 'dims', ()))
    if set(dims) == {'lat', 'lon'} and dims != ('lat', 'lon'):
        data_map = data_map.transpose('lat', 'lon')

    vals = np.asarray(data_map.values, dtype=float).ravel()
    lats = np.asarray(data_map.lat.values)
    lons = np.asarray(data_map.lon.values)
    # Handle coordinates given as 1-D axes or as full 2-D grids.
    if lats.ndim == 1 and lons.ndim == 1:
        lat_grid, lon_grid = np.meshgrid(lats, lons, indexing='ij')
    else:
        lat_grid, lon_grid = lats, lons

    valid = np.isfinite(vals)
    z = vals[valid]
    coords = np.column_stack((lat_grid.ravel()[valid], lon_grid.ravel()[valid]))
    if z.size < 2:
        # No pairs -> no variogram; callers test for None.
        return None, None

    # Subsample large fields to keep the pairwise arrays in memory.
    if max_points is not None and z.size > max_points:
        rng = np.random.default_rng(seed)
        keep = rng.choice(z.size, max_points, replace=False)
        z, coords = z[keep], coords[keep]

    dists = pdist(coords)
    sq_diffs = pdist(z[:, None], metric='sqeuclidean')

    # n_bins + 1 edges -> exactly n_bins bins.
    edges = np.linspace(0, dists.max() / 3, n_bins + 1)
    bin_centers = 0.5 * (edges[1:] + edges[:-1])
    bin_idx = np.digitize(dists, edges)  # i means edges[i-1] <= d < edges[i]
    semivariance = np.array([
        0.5 * np.mean(sq_diffs[bin_idx == i]) if np.any(bin_idx == i) else np.nan
        for i in range(1, n_bins + 1)
    ])
    return bin_centers, semivariance

# Spatial Fusion
def run_data_fusion():
    """Bias-correct ERA5 and IMERG against CWA, blend them with global inverse-MSE weights and check spatial structure.

    Steps:
      1. Interpolate ERA5 and IMERG onto the CWA grid. Restrict all statistics
         to months present in CWA, ERA5 and IMERG.
      2. Remove each product's per-pixel time-mean bias (model - CWA) and clip
         negative precipitation to 0, leaving no-data pixels as NaN.
      3. Weight the corrected products by 1 / MSE, normalised to sum to 1.
         Both MSEs use only samples where CWA, ERA5 and IMERG are all valid,
         so they are measured on the same data.
      4. Remove the blend's own per-pixel mean bias, clip at 0 and write it as
         ``precip_fused`` to ``OUTPUT_FUSED``.
      5. Compute variograms of the time-mean fields, save them to
         ``OUTPUT_CSV`` and save the map and variogram figures.

    Biases and weights are estimated in-sample (all common months).
    Relies on notebook-level names: ``load_dataset_safe``, ``INPUT_CWA``,
    ``INPUT_ERA5``, ``INPUT_IMERG``, ``OUTPUT_FUSED``, ``OUTPUT_CSV`` and
    ``calculate_variogram``.
    """
    ds_rain = load_dataset_safe(INPUT_CWA)
    ds_era = load_dataset_safe(INPUT_ERA5)
    ds_imerg = load_dataset_safe(INPUT_IMERG)

    var_c = 'Precip'
    var_e = 'tp' if 'tp' in ds_era else 'total_precipitation'
    var_i = 'precipitation' if 'precipitation' in ds_imerg else 'precipitationCal'

    era_aligned = ds_era[var_e].interp(lat=ds_rain.lat, lon=ds_rain.lon, method='linear')
    imerg_aligned = ds_imerg[var_i].interp(lat=ds_rain.lat, lon=ds_rain.lon, method='linear')

    # Months present in all three datasets; the products cover different periods
    # and .sel() raises for a month one of them lacks.
    common_time = np.intersect1d(
        np.intersect1d(ds_rain.time.values, era_aligned.time.values),
        imerg_aligned.time.values,
    )
    print(f"   -> Found {len(common_time)} common months.")
    if len(common_time) == 0:
        print("   -> No overlapping months between CWA, ERA5 and IMERG; skipping fusion.")
        return
    obs = ds_rain[var_c].sel(time=common_time)

    # Bias correction: climatological (time-mean) bias per pixel.
    bias_era = (era_aligned.sel(time=common_time) - obs).mean(dim='time')
    bias_imerg = (imerg_aligned.sel(time=common_time) - obs).mean(dim='time')
    # clip(min=0) keeps no-data pixels as NaN. .where(x >= 0, 0) would
    # replace them with 0, because NaN >= 0 evaluates to False.
    era_corrected = (era_aligned - bias_era).clip(min=0)
    imerg_corrected = (imerg_aligned - bias_imerg).clip(min=0)

    # Global inverse-MSE weights: W_k = (1 / MSE_k) / sum_j (1 / MSE_j).
    resid_era = era_corrected.sel(time=common_time) - obs
    resid_imerg = imerg_corrected.sel(time=common_time) - obs
    # Score both products on the same samples so their MSEs are comparable.
    joint_valid = resid_era.notnull() & resid_imerg.notnull()
    mse_era = float((resid_era ** 2).where(joint_valid).mean())
    mse_imerg = float((resid_imerg ** 2).where(joint_valid).mean())
    print(f"ERA5 MSE: {mse_era:.2f}")
    print(f"IMERG MSE: {mse_imerg:.2f}")
    if not (mse_era > 0 and mse_imerg > 0):
        # Catches NaN (no jointly valid samples) as well as a zero MSE (1/0).
        print("   -> MSE undefined or zero; cannot form inverse-variance weights.")
        return
    w_era_val = 1.0 / mse_era
    w_imerg_val = 1.0 / mse_imerg
    w_sum = w_era_val + w_imerg_val
    W_ERA = w_era_val / w_sum
    W_IMERG = w_imerg_val / w_sum
    print(f"Optimal weights: ERA5={W_ERA:.3f}, IMERG={W_IMERG:.3f}")

    ds_fused_raw = (W_ERA * era_corrected) + (W_IMERG * imerg_corrected)
    fused_bias = (ds_fused_raw.sel(time=common_time) - obs).mean(dim='time')
    ds_fused = (ds_fused_raw - fused_bias).clip(min=0)
    ds_fused.name = 'precip_fused'
    ds_fused.to_netcdf(OUTPUT_FUSED)

    # Time-mean fields used for both the variograms and the maps.
    map_cwa = ds_rain[var_c].mean(dim='time', skipna=True)
    map_era = era_corrected.mean(dim='time', skipna=True)
    map_imerg = imerg_corrected.mean(dim='time', skipna=True)
    map_fused = ds_fused.mean(dim='time', skipna=True)

    # Spatial check (variogram)
    bins_cwa, gam_cwa = calculate_variogram(map_cwa)
    bins_era, gam_era = calculate_variogram(map_era)
    bins_img, gam_imerg = calculate_variogram(map_imerg)
    bins_fus, gam_fused = calculate_variogram(map_fused)

    if bins_cwa is not None:
        df_var = pd.DataFrame({
            'Distance_Deg': bins_cwa,
            'Semivariance_CWA': gam_cwa,
            'Semivariance_ERA5': gam_era if bins_era is not None else np.nan,
            'Semivariance_IMERG': gam_imerg if bins_img is not None else np.nan,
            'Semivariance_FUSED': gam_fused if bins_fus is not None else np.nan,
        })
        df_var.to_csv(OUTPUT_CSV, index=False)

    # Plot 1: Spatial maps on a shared colour scale.
    panels = [
        (map_cwa, 'GROUND TRUTH (CWA)\n(Target Pattern)'),
        (map_era, f'ERA5 Corrected (W={W_ERA:.2f})'),
        (map_imerg, f'IMERG Corrected (W={W_IMERG:.2f})'),
        (map_fused, 'FINAL FUSED PRODUCT\n(Optimal Global Weight)'),
    ]
    vmin = np.nanmin([float(field.min()) for field, _ in panels])
    vmax = np.nanmax([float(field.max()) for field, _ in panels])

    plt.figure(figsize=(15, 10))
    for position, (field, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        field.plot(cmap='Blues', vmin=vmin, vmax=vmax, cbar_kwargs={'label': 'mm/month'})
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.savefig("Presentation_Fig_5a_Spatial_Maps.png", dpi=150)

    # Plot 2: Variogram comparison. The CWA curve is the reference the others should match.
    plt.figure(figsize=(10, 6))
    if bins_cwa is not None: plt.plot(bins_cwa, gam_cwa, 'k-o', linewidth=2, label='CWA (Ground Truth)')
    if bins_era is not None: plt.plot(bins_era, gam_era, 'b--.', label=f'ERA5 (W={W_ERA:.2f})')
    if bins_img is not None: plt.plot(bins_img, gam_imerg, 'r--.', label=f'IMERG (W={W_IMERG:.2f})')
    if bins_fus is not None: plt.plot(bins_fus, gam_fused, 'g-^', linewidth=2, label='Fused Product Texture')

    plt.xlabel('Distance (Degrees)')
    plt.ylabel('Semivariance (Spatial Variability)')
    plt.title('Spatial Structure Comparison (Variogram)\n(Goal: Green line should match Black line)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("Presentation_Fig_5b_Variogram.png", dpi=150)

if __name__ == "__main__":
    run_data_fusion()

### Step 6: Validation (Bootstrapping & Final Stats)

In [None]:
# Step 6 configuration.
# Aim: test whether the fused product's RMSE against CWA is lower than that of
# raw ERA5 and raw IMERG, using bootstrap resampling of the matched samples.
RANDOM_SEED = 42                          # seed for the validation split and bootstrap draws
TEST_SPLIT_RATIO = 0.20                   # share of matched samples held out for validation (80/20)
BOOTSTRAP_N = 10000                       # bootstrap replicates per product
INPUT_FUSED = "Taiwan_Rainfall_Fused.nc"  # same file that Step 5 writes as OUTPUT_FUSED

def load_and_align():
    """Load CWA, ERA5, IMERG and the fused product and return matched 1-D sample vectors.

    ERA5 and IMERG are linearly interpolated onto the CWA grid. The fused
    product is nearest-neighbour interpolated onto the same grid. All four
    fields are restricted to months present in every dataset and put in
    (time, lat, lon) order. They are then flattened and filtered to samples
    where all four are non-NaN.

    Returns:
        tuple ``(cwa, era5, imerg, fused)`` of equal-length 1-D numpy arrays;
        element i of each refers to the same (time, lat, lon) cell.

    Note: ERA5/IMERG are returned as raw interpolated values, whereas the
    fused product was bias-corrected against CWA in Step 5.
    Relies on notebook-level names: ``load_dataset_safe``, ``INPUT_CWA``,
    ``INPUT_ERA5``, ``INPUT_IMERG`` and ``INPUT_FUSED``.
    """
    # Load CWA (ground truth) and the three candidate products.
    ds_cwa = load_dataset_safe(INPUT_CWA)
    ds_era = load_dataset_safe(INPUT_ERA5)
    ds_imerg = load_dataset_safe(INPUT_IMERG)
    ds_fused = load_dataset_safe(INPUT_FUSED)

    # Variables
    v_c = 'Precip'
    v_e = 'tp' if 'tp' in ds_era else 'total_precipitation'
    v_i = 'precipitation' if 'precipitation' in ds_imerg else 'precipitationCal'
    v_f = 'precip_fused'

    era_a = ds_era[v_e].interp(lat=ds_cwa.lat, lon=ds_cwa.lon, method='linear')
    imerg_a = ds_imerg[v_i].interp(lat=ds_cwa.lat, lon=ds_cwa.lon, method='linear')
    fused_a = ds_fused[v_f].interp(lat=ds_cwa.lat, lon=ds_cwa.lon, method='nearest')  # Fused is already on the CWA grid

    fields = (ds_cwa[v_c], era_a, imerg_a, fused_a)

    # Months present in *every* product; .sel() raises for a missing month.
    common_time = fields[0].time.values
    for da in fields[1:]:
        common_time = np.intersect1d(common_time, da.time.values)

    def flatten_data(da):
        # Fixed (time, lat, lon) order so element i is the same cell in every product.
        return da.sel(time=common_time).transpose('time', 'lat', 'lon').values.ravel()

    cwa_flat, era_flat, imerg_flat, fused_flat = (flatten_data(da) for da in fields)
    mask = ~np.isnan(cwa_flat) & ~np.isnan(era_flat) & ~np.isnan(imerg_flat) & ~np.isnan(fused_flat)
    return cwa_flat[mask], era_flat[mask], imerg_flat[mask], fused_flat[mask]

def bootstrap_rmse(truth, model, n_boot=1000):
    """Bootstrap distribution of the RMSE between ``model`` and ``truth``.

    Each replicate redraws the element positions with replacement (same size as
    the input) and records sqrt(mean squared error) over the drawn positions.

    Args:
        truth: 1-D array of observations.
        model: 1-D array of predictions aligned element-wise with ``truth``.
        n_boot: number of bootstrap replicates.

    Returns:
        numpy array holding ``n_boot`` RMSE values.
    """
    sq_err = (model - truth) ** 2
    n_obs = len(truth)
    positions = np.arange(n_obs)

    def one_replicate():
        # sklearn.utils.resample: draw n_obs positions with replacement.
        draw = resample(positions, replace=True, n_samples=n_obs)
        return np.sqrt(np.mean(sq_err[draw]))

    return np.array([one_replicate() for _ in range(n_boot)])

def _paired_bootstrap_rmse(truth, models, n_boot, rng):
    """Bootstrap RMSE for several models using the SAME resampled indices in each replicate.

    Sharing indices keeps the replicates paired, so differences between models
    within a replicate reflect the models rather than different random draws.

    Args:
        truth: 1-D array of observations.
        models: dict mapping a name to a 1-D prediction array aligned with ``truth``.
        n_boot: number of bootstrap replicates.
        rng: ``numpy.random.Generator`` used to draw the indices.

    Returns:
        dict mapping each model name to a length-``n_boot`` array of RMSE values.
    """
    n = len(truth)
    sq_err = {name: (pred - truth) ** 2 for name, pred in models.items()}
    rmse = {name: np.empty(n_boot) for name in models}
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)
        for name, err in sq_err.items():
            rmse[name][b] = np.sqrt(err[idx].mean())
    return rmse


def run_validation():
    """Bootstrap-compare the RMSE of ERA5, IMERG and the fused product against CWA.

    A random ``TEST_SPLIT_RATIO`` share of the matched samples from
    ``load_and_align`` forms the evaluation set. ``BOOTSTRAP_N`` paired
    bootstrap replicates (same resampled indices for all three products) give
    an RMSE distribution per product; the mean and 95 % percentile interval
    are printed. The one-sided p-value for "Fused RMSE < other RMSE" is the
    share of replicates in which Fused is not better, with a +1 correction.
    The histogram is saved to "Presentation_Fig_6_Validation.png".

    Caveat: Step 5 fitted the fused product's bias and weights on all common
    months, so this is an in-sample comparison; the split only subsamples.
    Relies on notebook-level names: ``load_and_align``, ``BOOTSTRAP_N``,
    ``TEST_SPLIT_RATIO`` and ``RANDOM_SEED``.
    """
    obs, mod_era, mod_imerg, mod_fused = load_and_align()

    # Evaluation subset (see the in-sample caveat in the docstring).
    n_samples = len(obs)
    n_val = int(n_samples * TEST_SPLIT_RATIO)
    if n_val < 2:
        print(f"Not enough matched samples for validation (n={n_samples}); skipping.")
        return
    rng = np.random.default_rng(RANDOM_SEED)
    val_indices = rng.choice(n_samples, n_val, replace=False)
    val_obs = obs[val_indices]

    # Paired bootstrapping of RMSE
    boot = _paired_bootstrap_rmse(
        val_obs,
        {
            'era': mod_era[val_indices],
            'imerg': mod_imerg[val_indices],
            'fused': mod_fused[val_indices],
        },
        BOOTSTRAP_N,
        rng,
    )
    rmse_boot_era = boot['era']
    rmse_boot_imerg = boot['imerg']
    rmse_boot_fused = boot['fused']

    # Confidence Intervals
    def get_ci(boot_dist):
        return np.percentile(boot_dist, [2.5, 97.5]), np.mean(boot_dist)
    ci_era, m_era = get_ci(rmse_boot_era)
    ci_img, m_img = get_ci(rmse_boot_imerg)
    ci_fus, m_fus = get_ci(rmse_boot_fused)

    print(f"ERA5 RMSE:  {m_era:.2f} [{ci_era[0]:.2f}, {ci_era[1]:.2f}] mm")
    print(f"IMERG RMSE: {m_img:.2f} [{ci_img[0]:.2f}, {ci_img[1]:.2f}] mm")
    print(f"FUSED RMSE: {m_fus:.2f} [{ci_fus[0]:.2f}, {ci_fus[1]:.2f}] mm")

    # Hypothesis testing. H0: RMSE_fused >= RMSE_other; H1: RMSE_fused < RMSE_other.
    # Bootstrap replicates are not independent observations, so a t-test on them
    # is not used: its p-value would shrink just by increasing BOOTSTRAP_N.
    def p_fused_better(rmse_other):
        # Approximate one-sided bootstrap p-value from paired replicates.
        not_better = np.count_nonzero(rmse_boot_fused >= rmse_other)
        return (not_better + 1) / (len(rmse_other) + 1)

    p_val = p_fused_better(rmse_boot_era)
    is_sig = p_val < 0.05
    print(f"   vs ERA5:  ΔRMSE={m_fus - m_era:+.2f} mm, p={p_val:.4e} | Significant? {'Yes' if is_sig else 'No'}")
    p_val2 = p_fused_better(rmse_boot_imerg)
    is_sig2 = p_val2 < 0.05
    print(f"   vs IMERG: ΔRMSE={m_fus - m_img:+.2f} mm, p={p_val2:.4e} | Significant? {'Yes' if is_sig2 else 'No'}")

    # Plot Distribution
    plt.figure(figsize=(10, 6))
    plt.hist(rmse_boot_era, bins=30, alpha=0.5, label='ERA5 Error', color='blue')
    plt.hist(rmse_boot_imerg, bins=30, alpha=0.5, label='IMERG Error', color='red')
    plt.hist(rmse_boot_fused, bins=30, alpha=0.7, label='FUSED Error', color='green')
    plt.axvline(m_fus, color='green', linestyle='--', linewidth=2, label='Fused Mean')
    plt.axvline(m_era, color='blue', linestyle='--', linewidth=2)
    plt.axvline(m_img, color='red', linestyle='--', linewidth=2)
    plt.xlabel('RMSE (mm/month)')
    plt.ylabel('Frequency (Bootstrap Samples)')
    plt.title(f'Improvement Verification: Fused Model vs Inputs\n(p-value vs ERA5: {p_val:.1e})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("Presentation_Fig_6_Validation.png", dpi=150)

if __name__ == "__main__":
    run_validation()