In [None]:
import os
import re
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
from numpy.polynomial import legendre
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from shapely.geometry import box
import xarray as xr

# -------------------------------
# Paths and PCE settings.
BASE_DIR     = "/project2/zhan248_1326/hhao4018/UQ_analysis/PCE_Map_Day_Night_Summer_v4"  # parent of the Plots_Model_<n>/ run directories
INPUT_CSV    = "day_night_citymask_shadow_vars_summer.csv"  # PCE input parameters (relative to the working directory)
OUT_UNCERT   = "/project2/zhan248_1326/hhao4018/UQ_analysis/UncertaintyMap_v4/Summer"  # where the SD / R² maps are saved
FIXED_DEG    = 3   # total degree of the Legendre PCE
MIN_SAMPLES  = 10  # minimum number of usable runs at a grid cell before a fit is attempted
EXTENT       = (-119.0, -116.5, 33.25, 34.85)  # (lon_min, lon_max, lat_min, lat_max); not used in this cell
VALID_COLUMNS = ["ZR_scale_factor","ROOF_WIDTH_scale_factor","ROAD_WIDTH_scale_factor"]  # CSV columns used as PCE inputs

# WRF geo_em file for domain d02; only the lat/lon grid and LANDMASK are read here.
WRF_FILE = "/project2/zhan248_1326/hhao4018/Natural_Earth/geo_em.d02.nc"
ds_wrf   = xr.open_dataset(WRF_FILE)

# Index 0 of the leading dimension (presumably Time — confirm). Not used elsewhere in this cell.
lons2d   = ds_wrf["XLON_M"][0].values   # shape (M,N)
lats2d   = ds_wrf["XLAT_M"][0].values

# 1 where LANDMASK == 0, else 0. Not used elsewhere in this cell.
water_mask = (ds_wrf["LANDMASK"].isel(Time=0).values == 0).astype(int)

# -------------------------------
# PCE inputs. Row k is paired with the k-th Plots_Model_<n> directory in numeric order
# by compute_uncertainty_and_r2 — TODO confirm this matches how the CSV was written.
X_df    = pd.read_csv(INPUT_CSV)
X_input = X_df[VALID_COLUMNS].values
valid_X = ~np.isnan(X_input).any(axis=1)  # True for rows with no missing input value

# -------------------------------
class OrthogonalPolynomialFeatures(BaseEstimator, TransformerMixin):
    """Tensor-product Legendre polynomial features, as a scikit-learn transformer.

    ``fit`` records each column's range.  ``transform`` maps every column
    affinely from ``[X_min_, X_max_]`` onto ``[-1, 1]``.  For each multi-index
    ``comb`` in ``combinations_`` it then emits the feature
    ``prod_j P_{comb[j]}(x_j)``, where ``P_d`` is the degree-``d`` Legendre
    polynomial.

    Parameters
    ----------
    degree : int
        Maximum total degree (``sum(comb) <= degree``).
    include_bias : bool, default=True
        Keep the all-zero multi-index, i.e. the constant feature.

    Attributes
    ----------
    X_min_, X_max_ : ndarray of shape (n_features,)
        Column-wise minimum / maximum of the training data.
    combinations_ : list of tuple of int
        Multi-indices sorted by total degree, then lexicographically.  The
        constant term (when included) therefore comes first.
    """

    def __init__(self, degree, include_bias=True):
        self.degree = degree
        self.include_bias = include_bias

    def fit(self, X, y=None):
        """Record column ranges and enumerate the multi-indices; ``y`` is ignored."""
        X = np.asarray(X)
        self.X_min_ = X.min(axis=0)
        self.X_max_ = X.max(axis=0)
        n_features = X.shape[1]
        self.combinations_ = sorted(
            (
                comb
                for comb in itertools.product(range(self.degree + 1), repeat=n_features)
                if sum(comb) <= self.degree and (self.include_bias or sum(comb) > 0)
            ),
            key=lambda c: (sum(c), c),
        )
        return self

    def transform(self, X):
        """Return the feature matrix, shape ``(n_samples, len(combinations_))``."""
        X = np.asarray(X, dtype=float)
        span = self.X_max_ - self.X_min_
        # A column that was constant during fit has zero span.  Dividing by it
        # would turn every feature into NaN and break the downstream
        # regression.  Use 1 instead, so the training value maps to -1.
        span = np.where(span == 0, 1.0, span)
        Xs = 2 * (X - self.X_min_) / span - 1
        # legvals[j][:, d] is P_d evaluated on column j.
        legvals = [legendre.legvander(Xs[:, j], self.degree) for j in range(Xs.shape[1])]
        feats = np.ones((X.shape[0], len(self.combinations_)))
        for k, comb in enumerate(self.combinations_):
            for j, deg in enumerate(comb):
                feats[:, k] *= legvals[j][:, deg]
        return feats

# -------------------------------
def sorted_model_dirs(base_dir):
    """List the ``Plots_Model_<n>`` sub-directories of *base_dir*, ordered numerically by ``n``.

    Only entries that start with ``Plots_Model_`` and are directories are kept.
    The sort key is the first run of digits after ``Plots_Model_``.
    """
    run_number_re = re.compile(r"Plots_Model_(\d+)")

    def run_number(name):
        return int(run_number_re.search(name).group(1))

    prefixed = (name for name in os.listdir(base_dir) if name.startswith("Plots_Model_"))
    return sorted(
        (name for name in prefixed if os.path.isdir(os.path.join(base_dir, name))),
        key=run_number,
    )

# -------------------------------
def list_average_vars(sample_dir):
    """Return the sorted variable names found as ``Average_<var>.npy`` files in *sample_dir*."""
    names = []
    for fname in os.listdir(sample_dir):
        if not (fname.startswith("Average_") and fname.endswith(".npy")):
            continue
        names.append(fname.replace("Average_", "").replace(".npy", ""))
    names.sort()
    return names

# -------------------------------
def _pce_variance(coeffs, combinations):
    """Return the variance of a Legendre PCE with independent inputs uniform on [-1, 1].

    For U ~ Uniform(-1, 1), E[P_d(U)**2] = 1 / (2d + 1).  The term with
    multi-index ``comb`` and coefficient ``c`` therefore contributes
    ``c**2 * prod_k 1 / (2 * comb[k] + 1)``.  The constant (all-zero) term
    contributes nothing.

    NOTE(review): assumes the scale factors are (close to) uniformly spread
    over the range that OrthogonalPolynomialFeatures maps onto [-1, 1]; confirm.

    Args:
        coeffs: regression coefficients, one per entry of ``combinations``.
        combinations: multi-indices in the same order as the feature columns.

    Returns:
        float: the PCE variance.
    """
    total = 0.0
    for c, comb in zip(coeffs, combinations):
        if sum(comb) == 0:  # constant term carries no variance
            continue
        # Product over ALL inputs; degree-0 factors are 1.  Multiplying
        # 2/(2d+1) over the active inputs and then dividing by 2 only matches
        # this for a single active input.  A term with k active inputs would be
        # over-weighted by 2**(k-1).
        total += c ** 2 * np.prod([1.0 / (2 * d + 1) for d in comb])
    return total


def compute_uncertainty_and_r2():
    """Fit a per-grid-cell PCE surrogate for every averaged output and save SD / R² maps.

    Variables are taken from the ``Average_<var>.npy`` files in the first model
    directory.  ``NET_URB`` (= SNET_URB + LNET_URB) is added when both
    components are present.  For each variable, the per-run 2-D fields are
    stacked.  At every grid cell a degree-``FIXED_DEG`` Legendre regression on
    ``X_input`` is then fitted.  Two maps are written to ``OUT_UNCERT``:

    * ``PCE_std_map_<var>_deg<FIXED_DEG>.npy`` – PCE standard deviation;
    * ``PCE_r2_map_<var>_deg<FIXED_DEG>.npy`` – in-sample R² of the fit.

    Cells with fewer than ``MIN_SAMPLES`` usable runs stay NaN in both maps.

    The k-th model directory (numeric order) is paired with row k of
    ``X_input``.  A run whose file is missing is dropped together with its
    parameter row, so the pairing is preserved.  A variable that no run
    provides is reported and skipped.

    Raises:
        FileNotFoundError: if ``BASE_DIR`` contains no ``Plots_Model_<n>`` directory.
        ValueError: if the number of model directories differs from the number
            of rows in ``X_input``.
    """
    model_dirs = sorted_model_dirs(BASE_DIR)
    if not model_dirs:
        raise FileNotFoundError(f"No Plots_Model_<n> directories in {BASE_DIR}")
    if len(model_dirs) != len(X_input):
        raise ValueError(
            f"{len(model_dirs)} model directories in {BASE_DIR} but "
            f"{len(X_input)} parameter rows in {INPUT_CSV}"
        )

    vars_ = list_average_vars(os.path.join(BASE_DIR, model_dirs[0]))
    # Derived net radiation, assembled from its two components below.
    if "SNET_URB" in vars_ and "LNET_URB" in vars_:
        vars_.append("NET_URB")

    os.makedirs(OUT_UNCERT, exist_ok=True)

    # Multiplier applied to the radiation fields.
    # NOTE(review): the meaning/units of 697.7 * 60 are not documented here — confirm.
    FACTOR = 697.7 * 60

    for var in vars_:
        print(f"--- {var} (degree={FIXED_DEG}) ---")

        # One 2-D field per run, plus the index (row of X_input) of every kept run.
        Y_list, kept_runs = [], []
        for k, md in enumerate(model_dirs):
            run_dir = os.path.join(BASE_DIR, md)
            if var == "NET_URB":
                fp1 = os.path.join(run_dir, "Average_SNET_URB.npy")
                fp2 = os.path.join(run_dir, "Average_LNET_URB.npy")
                if not (os.path.exists(fp1) and os.path.exists(fp2)):
                    continue
                arr = np.load(fp1) * FACTOR + np.load(fp2) * FACTOR
            else:
                fp = os.path.join(run_dir, f"Average_{var}.npy")
                if not os.path.exists(fp):
                    continue
                arr = np.load(fp)
                if var in ("LNET_URB", "SNET_URB"):
                    arr = arr * FACTOR
            Y_list.append(arr)
            kept_runs.append(k)

        if not Y_list:
            print(f"    no run provides {var}; skipped")
            continue

        Y = np.stack(Y_list, axis=0)        # (n_kept_runs, M, N)
        X_runs = X_input[kept_runs]         # parameter rows aligned with Y
        valid_runs = valid_X[kept_runs]
        _, M, N = Y.shape
        std_map = np.full((M, N), np.nan)
        r2_map  = np.full((M, N), np.nan)

        for i in tqdm(range(M), desc=var, leave=False):
            for j in range(N):
                y = Y[:, i, j]
                mask = valid_runs & ~np.isnan(y)
                if mask.sum() < MIN_SAMPLES:
                    continue
                X_use, y_use = X_runs[mask], y[mask]

                # `mask` already drops rows with missing inputs, so the imputer is only a safeguard.
                pipe = Pipeline([
                    ("imp",  SimpleImputer(strategy="mean")),
                    ("poly", OrthogonalPolynomialFeatures(degree=FIXED_DEG)),
                    ("reg",  LinearRegression()),
                ])
                pipe.fit(X_use, y_use)

                pce_var = _pce_variance(
                    pipe.named_steps["reg"].coef_,
                    pipe.named_steps["poly"].combinations_,
                )
                std_map[i, j] = np.sqrt(pce_var)
                r2_map[i, j]  = pipe.score(X_use, y_use)   # in-sample R²

        np.save(os.path.join(OUT_UNCERT, f"PCE_std_map_{var}_deg{FIXED_DEG}.npy"), std_map)
        np.save(os.path.join(OUT_UNCERT, f"PCE_r2_map_{var}_deg{FIXED_DEG}.npy"), r2_map)

# -------------------------------
if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so running this cell starts the
    # per-grid-cell fit for every variable and writes the maps to OUT_UNCERT.
    compute_uncertainty_and_r2()


### **Plot the uncertainty (PCE standard deviation) distribution**

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import geopandas as gpd
from shapely.geometry import box
import xarray as xr
# Figure style: roman math text, Arial sans-serif labels, inward ticks (major size 5), 600-dpi figures.
mpl.rc('mathtext', default='rm')
mpl.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial'], 'size': 16})
mpl.rc('axes', labelsize=18, titlesize=16, linewidth=1)
mpl.rc(('xtick', 'ytick'), labelsize=16, direction='in')
mpl.rc(('xtick.major', 'ytick.major'), size=5)
mpl.rc('figure', dpi=600)

def add_county_labels(ax):
    """Write four county names at fixed data-coordinate positions on *ax*.

    Each label is bold, 10-pt black text on a translucent white rounded box.
    """
    label_style = {
        'fontsize': 10,
        'color': 'black',
        'weight': 'bold',
        'bbox': {
            'facecolor': 'white',
            'edgecolor': 'none',
            'boxstyle': 'round,pad=0.3',
            'alpha': 0.7,
        },
    }
    placements = (
        (-118.5, 34.3, 'Los Angeles'),
        (-117.9, 33.7, 'Orange'),
        (-117.2, 33.6, 'Riverside'),
        (-117.4, 34.2, 'San Bernardino'),
    )
    for lon, lat, county in placements:
        ax.text(lon, lat, county, **label_style)

from matplotlib.ticker import MultipleLocator

# ---- Plot configuration ----
OUT_UNCERT   = "/project2/zhan248_1326/hhao4018/UQ_analysis/UncertaintyMap_v4/Summer"
EXTENT       = (-119.0, -116.5, 33.25, 34.85)   # (lon_min, lon_max, lat_min, lat_max)
R2_THRESH    = 0.6    # SD is shown only where the PCE fit has R² above this
STD_THRESH   = 0.01   # cells with SD below this are hatched
FIXED_DEG    = 3      # PCE degree encoded in the saved file names

# County outlines that intersect the plotting window.
# NOTE(review): relative path; the R² cell reads "Boundary/..." (no "../") — confirm which is right.
COUNTY_SHP = "../Boundary/cb_2018_us_county_500k.shp"
counties = gpd.read_file(COUNTY_SHP).to_crs("EPSG:4326")
bbox_geom = box(EXTENT[0], EXTENT[2], EXTENT[1], EXTENT[3])
counties_clip = counties[counties.geometry.intersects(bbox_geom)].copy()

variables = [
    "T2",      "TC_URB",
    "RH",      "WS",
    "PBLH",    "SNET_URB",
    "LNET_URB","NET_URB"
]

unit_dict = {
    'T2':'K','TC_URB':'K','RH':'%','WS':'m/s',
    'PBLH':'m','SNET_URB':'W/m²','LNET_URB':'W/m²','NET_URB':'W/m²'
}

unit_official_names = {
    'T2': 'T$_2$','TC_URB': 'T$_C$', 'RH':'RH','WS':'WS', "PBLH":'PBLH',
    'SNET_URB': 'SW$_{NET}$', 'LNET_URB': 'LW$_{NET}$', 'NET_URB': 'R$_{NET}$'
}

fig, axes = plt.subplots(4, 2, figsize=(14, 16), sharex=True, sharey=True)
axes = axes.flatten()

# Model grid coordinates and urban footprint from the first time step of a WRF output file.
WRF_FILE = "/project2/zhan248_1326/hhao4018/Model_Evaluation/wrfout_d02_2016-08-10_10_00_00"
ds_wrf   = xr.open_dataset(WRF_FILE)
lons2d   = ds_wrf["XLONG"].isel(Time=0).values
lats2d   = ds_wrf["XLAT"].isel(Time=0).values

utype = ds_wrf["UTYPE_URB"].isel(Time=0).values
urban_mask = (utype > 0).astype(float)   # 1 where UTYPE_URB > 0, else 0

# Hatch appearance for masked cells.
mpl.rcParams['hatch.color']     = 'gray'
mpl.rcParams['hatch.linewidth'] = 1

for idx, var in enumerate(variables):
    ax = axes[idx]

    std_map = np.load(os.path.join(OUT_UNCERT, f"PCE_std_map_{var}_deg{FIXED_DEG}.npy"))
    r2_map  = np.load(os.path.join(OUT_UNCERT, f"PCE_r2_map_{var}_deg{FIXED_DEG}.npy"))

    # Show SD only where the surrogate fits well; colour limits come from the unfiltered map.
    std_filt = np.where(r2_map > R2_THRESH, std_map, np.nan)

    im = ax.pcolormesh(
        lons2d, lats2d, std_filt,
        shading='auto',
        vmin=np.nanpercentile(std_map, 1),
        vmax=np.nanpercentile(std_map, 99),
        cmap='Reds',
    )
    ax.set_aspect('auto')

    # Hatch poorly fitted or negligible-SD cells.  The maps live on the WRF grid,
    # so the hatching must use lons2d/lats2d, not a regular grid stretched over EXTENT.
    hatch_mask = (r2_map <= R2_THRESH) | (std_map < STD_THRESH)
    ax.contourf(
        lons2d, lats2d, hatch_mask.astype(float),
        levels=[0.5, 1.5],
        hatches=['/'],
        colors='none',
        edgecolors='none',
        antialiased=True
    )

    counties_clip.boundary.plot(
        ax=ax,
        linewidth=1,
        edgecolor='black',
        zorder=5,
    )

    # Urban outline: the 0.5 contour of the 0/1 urban mask.
    ax.contour(
        lons2d, lats2d, urban_mask,
        levels=[0.5],
        colors='royalblue',
        linewidths=1.0,
        zorder=6,
    )

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('SD')

    ax.set_xlim(EXTENT[0], EXTENT[1])
    ax.set_ylim(EXTENT[2], EXTENT[3])
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    unit = unit_dict.get(var, "")
    name = unit_official_names.get(var, "")
    ax.set_title(f"{name} [{unit}]", pad=6, fontsize=18)

    ax.text(
        0.02, 0.95,
        f"({chr(97 + idx)})",
        transform=ax.transAxes,
        fontsize=18, fontweight='bold', va='top'
    )
    add_county_labels(ax)

# Tick spacing, set once after all panels exist; a fresh Locator per Axis.
for ax_ in axes:
    ax_.xaxis.set_major_locator(MultipleLocator(1))
    ax_.yaxis.set_major_locator(MultipleLocator(0.5))

for j in range(len(variables), len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout(h_pad=1.0, w_pad=-5.0)
plt.show()

**Plot the distribution of R² (PCE goodness of fit)**

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import geopandas as gpd               
from shapely.geometry import box      
import xarray as xr
# Same figure style as the SD cell: roman math text, Arial, inward ticks, 600 dpi.
mpl.rc('mathtext', default='rm')
mpl.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial'], 'size': 16})
mpl.rc('axes', labelsize=18, titlesize=16, linewidth=1)
mpl.rc(('xtick', 'ytick'), labelsize=16, direction='in')
mpl.rc(('xtick.major', 'ytick.major'), size=5)
mpl.rc('figure', dpi=600)

def add_county_labels(ax):
    """Annotate *ax* with the names of four counties at fixed data coordinates.

    Labels are bold 10-pt black text inside a translucent (alpha 0.7) white
    rounded box without an edge.
    """
    box_props = dict(facecolor='white', edgecolor='none', boxstyle='round,pad=0.3', alpha=0.7)
    text_props = dict(fontsize=10, color='black', weight='bold', bbox=box_props)
    anchors = {
        'Los Angeles': (-118.5, 34.3),
        'Orange': (-117.9, 33.7),
        'Riverside': (-117.2, 33.6),
        'San Bernardino': (-117.4, 34.2),
    }
    for county, (x, y) in anchors.items():
        ax.text(x, y, county, **text_props)

from matplotlib.ticker import MultipleLocator

OUT_UNCERT   = "/project2/zhan248_1326/hhao4018/UQ_analysis/UncertaintyMap_v4/Summer"
EXTENT       = (-119.0, -116.5, 33.25, 34.85)   # (lon_min, lon_max, lat_min, lat_max)
R2_THRESH    = 0.6
STD_THRESH   = 0.01
FIXED_DEG    = 3      # PCE degree encoded in the saved file names

# NOTE(review): relative path; the SD cell reads "../Boundary/..." — confirm which location is correct.
COUNTY_SHP = "Boundary/cb_2018_us_county_500k.shp"

counties = gpd.read_file(COUNTY_SHP).to_crs("EPSG:4326")

# Keep only counties that intersect the plotting window.
bbox_geom = box(EXTENT[0], EXTENT[2], EXTENT[1], EXTENT[3])
counties_clip = counties[counties.geometry.intersects(bbox_geom)].copy()

variables = [
    "T2",      "TC_URB",
    "RH",      "WS",
    "PBLH",    "SNET_URB",
    "LNET_URB","NET_URB"
]

unit_dict = {
    'T2':'K','TC_URB':'K','RH':'%','WS':'m/s',
    'PBLH':'m','SNET_URB':'W/m²','LNET_URB':'W/m²','NET_URB':'W/m²'
}

unit_official_names = {
    'T2': 'T$_2$','TC_URB': 'T$_C$', 'RH':'RH','WS':'WS', "PBLH":'PBLH',
    'SNET_URB': 'SW$_{NET}$', 'LNET_URB': 'LW$_{NET}$', 'NET_URB': 'R$_{NET}$'
}

fig, axes = plt.subplots(4, 2, figsize=(14, 16), sharex=True, sharey=True)
axes = axes.flatten()

# Model grid coordinates from the first time step of a WRF output file.
WRF_FILE = "/project2/zhan248_1326/hhao4018/Model_Evaluation/wrfout_d02_2016-08-10_10_00_00"
ds_wrf   = xr.open_dataset(WRF_FILE)
lons2d   = ds_wrf["XLONG"].isel(Time=0).values   # shape (M,N)
lats2d   = ds_wrf["XLAT"].isel(Time=0).values

# Urban type per grid cell (> 0 marks urban), if the file provides it.
if "UTYPE_URB" in ds_wrf.variables:
    utype      = ds_wrf["UTYPE_URB"].isel(Time=0).values
    urban_mask = (utype > 0).astype(float)
else:
    utype      = None
    urban_mask = None

# Urban-canopy outputs: their R² is displayed as 0 outside cells with UTYPE_URB > 0.
URBAN_ONLY_VARS = ("SNET_URB", "LNET_URB", "NET_URB", "TC_URB")

for idx, var in enumerate(variables):
    ax = axes[idx]

    r2_map = np.load(os.path.join(OUT_UNCERT, f"PCE_r2_map_{var}_deg{FIXED_DEG}.npy"))
    if utype is not None and var in URBAN_ONLY_VARS:
        r2_map = np.where(utype > 0, r2_map, 0.0)

    im = ax.pcolormesh(
        lons2d, lats2d, r2_map,
        shading='auto',
        vmin=0.0,
        vmax=1.0,
        cmap='Oranges',
    )
    ax.set_aspect('auto')

    counties_clip.boundary.plot(
        ax=ax,
        linewidth=0.6,
        edgecolor='black',
        zorder=5,
    )
    if urban_mask is not None:
        # Urban outline: the 0.5 contour of the 0/1 urban mask.
        ax.contour(
            lons2d, lats2d, urban_mask,
            levels=[0.5],
            colors='royalblue',
            linewidths=1.0,
            zorder=6,
        )

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('R$^2$')

    ax.set_xlim(EXTENT[0], EXTENT[1])
    ax.set_ylim(EXTENT[2], EXTENT[3])
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    unit = unit_dict.get(var, "")
    name = unit_official_names.get(var, "")
    ax.set_title(f"{name} [{unit}]", pad=6, fontsize=18)

    ax.text(
        0.02, 0.95,
        f"({chr(97 + idx)})",
        transform=ax.transAxes,
        fontsize=18,
        fontweight='bold',
        va='top',
        bbox=dict(
            facecolor='white',
            edgecolor='none',
            boxstyle='round,pad=0.2',
            alpha=0.7
        )
    )
    add_county_labels(ax)

# Tick spacing; a fresh Locator per Axis.
for ax in axes:
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(0.5))

for j in range(len(variables), len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout(h_pad=1.0, w_pad=-5.0)
plt.show()

# **Plot HSI uncertainty maps (ESI, HI, NET, WBGT)**

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# ——— JGR style settings ———
# Times New Roman serif text, inward ticks (major size 5), 600-dpi figures.
mpl.rc('font', family='serif', serif=['Times New Roman'], size=16)
mpl.rc('axes', labelsize=18, titlesize=16, linewidth=1)
mpl.rc(('xtick', 'ytick'), labelsize=16, direction='in')
mpl.rc(('xtick.major', 'ytick.major'), size=5)
mpl.rc('figure', dpi=600)

# Configuration
# NOTE: lons2d, lats2d, counties_clip and add_county_labels come from the cells above; run those first.
OUT_UNCERT   = "/project2/zhan248_1326/hhao4018/UQ_analysis/UncertaintyMap_v4/Summer"
EXTENT       = (-119.0, -116.5, 33.25, 34.85)
R2_THRESH    = 0.6
STD_THRESH   = 0.01
FIXED_DEG    = 3

variables = ["ESI", "HI", "NET", "WBGT"]

unit_dict = {
    'ESI':'°C','HI':'°C','NET':'°C','WBGT':'°C'
}

unit_official_names = {
    'ESI': 'ESI', 'HI': 'HI', 'NET': 'NET', 'WBGT': 'WBGT'
}

# Hatch appearance for masked cells.
mpl.rcParams['hatch.color']     = 'gray'
mpl.rcParams['hatch.linewidth'] = 1

# Create the 4×1 subplot grid
fig, axes = plt.subplots(4, 1, figsize=(8, 16), sharex=True, sharey=True)
axes = axes.flatten()

for idx, var in enumerate(variables):
    ax = axes[idx]

    # Read std_map and r2_map
    std_map = np.load(os.path.join(OUT_UNCERT, f"PCE_std_map_{var}_deg{FIXED_DEG}.npy"))
    r2_map  = np.load(os.path.join(OUT_UNCERT, f"PCE_r2_map_{var}_deg{FIXED_DEG}.npy"))

    std_filt = np.where(r2_map > R2_THRESH, std_map, np.nan)

    # —— Main map ——
    im = ax.pcolormesh(
        lons2d, lats2d, std_filt,
        shading='auto',
        vmin=np.nanpercentile(std_map, 1),
        vmax=np.nanpercentile(std_map, 99),
        cmap='Reds',
    )
    ax.set_aspect('auto')

    # —— Masked region —— hatched on the model grid (lons2d/lats2d) so it lines up with the data
    hatch_mask = (r2_map <= R2_THRESH) | (std_map < STD_THRESH)
    ax.contourf(
        lons2d, lats2d, hatch_mask.astype(float),
        levels=[0.5, 1.5],
        hatches=['/'],
        colors='none',
        edgecolors='none',
        antialiased=True
    )
    # Draw boundaries only, no fill
    counties_clip.boundary.plot(
        ax=ax,
        linewidth=0.6,
        edgecolor='black',
        zorder=5,          # keep above the pcolormesh
    )
    # —— Colorbar ——
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
    cbar.set_label('SD')

    # Axis settings
    ax.set_xlim(EXTENT[0], EXTENT[1])
    ax.set_ylim(EXTENT[2], EXTENT[3])
    ax.set_ylabel('Latitude')
    if idx == len(variables) - 1:
        ax.set_xlabel('Longitude')

    # Panel title and letter
    name = unit_official_names.get(var, "")
    ax.set_title(f"{name}", pad=2, fontsize=20, fontweight='bold')
    ax.text(
        0.05, 1.1,
        f"({chr(97 + idx)})",
        transform=ax.transAxes,
        fontsize=20, fontweight='bold', va='top'
    )
    add_county_labels(ax)

plt.tight_layout(h_pad=1.5)
# Save BEFORE plt.show(): the inline backend closes the figure on show(), so a later
# plt.savefig() would write a new, blank figure.
fig.savefig(os.path.join(OUT_UNCERT, "summer_uncertainty_4x1_ESI_HI_NET_WBGT.png"), dpi=600)
plt.show()
plt.close(fig)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import xarray as xr   # ← added (used to read wrfout below)

# ——— JGR style settings ———
# Serif (Times New Roman) text with inward ticks; figures rendered at 600 dpi.
mpl.rc('font', family='serif', serif=['Times New Roman'], size=16)
mpl.rc('axes', labelsize=18, titlesize=16, linewidth=1)
mpl.rc(('xtick', 'ytick'), labelsize=16, direction='in')
mpl.rc(('xtick.major', 'ytick.major'), size=5)
mpl.rc('figure', dpi=600)

# Configuration
# NOTE: counties_clip and add_county_labels come from the cells above; run those first.
OUT_UNCERT   = "/project2/zhan248_1326/hhao4018/UQ_analysis/UncertaintyMap_v4/Summer"
EXTENT       = (-119.0, -116.5, 33.25, 34.85)
R2_THRESH    = 0.6
STD_THRESH   = 0.01
FIXED_DEG    = 3

variables = ["ESI", "HI", "NET", "WBGT"]

unit_dict = {
    'ESI':'°C','HI':'°C','NET':'°C','WBGT':'°C'
}

unit_official_names = {
    'ESI': 'ESI', 'HI': 'HI', 'NET': 'NET', 'WBGT': 'WBGT'
}

# ===== Read lat/lon and the urban mask from wrfout =====
WRF_FILE = "/project2/zhan248_1326/hhao4018/Model_Evaluation/wrfout_d02_2016-08-10_10_00_00"
ds_wrf   = xr.open_dataset(WRF_FILE)
lons2d   = ds_wrf["XLONG"].isel(Time=0).values
lats2d   = ds_wrf["XLAT"].isel(Time=0).values

if "UTYPE_URB" in ds_wrf.variables:
    utype      = ds_wrf["UTYPE_URB"].isel(Time=0).values
    urban_mask = (utype > 0).astype(float)   # urban = 1, non-urban = 0
else:
    urban_mask = None

# Hatch appearance for masked cells.
mpl.rcParams['hatch.color']     = 'gray'
mpl.rcParams['hatch.linewidth'] = 1

# Create the 4×1 subplot grid
fig, axes = plt.subplots(4, 1, figsize=(8, 16), sharex=True, sharey=True)
axes = axes.flatten()

for idx, var in enumerate(variables):
    ax = axes[idx]

    std_map = np.load(os.path.join(OUT_UNCERT, f"PCE_std_map_{var}_deg{FIXED_DEG}.npy"))
    r2_map  = np.load(os.path.join(OUT_UNCERT, f"PCE_r2_map_{var}_deg{FIXED_DEG}.npy"))

    std_filt = np.where(r2_map > R2_THRESH, std_map, np.nan)

    # —— Main map ——
    im = ax.pcolormesh(
        lons2d, lats2d, std_filt,
        shading='auto',
        vmin=np.nanpercentile(std_map, 1),
        vmax=np.nanpercentile(std_map, 99),
        cmap='Reds',
    )
    ax.set_aspect('auto')

    # —— Masked region —— hatched on the model grid (lons2d/lats2d) so it lines up with the data
    hatch_mask = (r2_map <= R2_THRESH) | (std_map < STD_THRESH)
    ax.contourf(
        lons2d, lats2d, hatch_mask.astype(float),
        levels=[0.5, 1.5],
        hatches=['/'],
        colors='none',
        edgecolors='none',
        antialiased=True
    )

    # County boundaries
    counties_clip.boundary.plot(
        ax=ax,
        linewidth=0.6,
        edgecolor='black',
        zorder=5,
    )

    # ==== Urban boundary line (UTYPE_URB > 0) ====
    if urban_mask is not None:
        ax.contour(
            lons2d, lats2d, urban_mask,
            levels=[0.5],          # threshold midway between 0 and 1
            colors='royalblue',    # colour distinct from the county lines
            linewidths=1.0,
            zorder=6,
        )

    # Colorbar
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
    cbar.set_label('SD')

    # Axis settings
    ax.set_xlim(EXTENT[0], EXTENT[1])
    ax.set_ylim(EXTENT[2], EXTENT[3])
    ax.set_ylabel('Latitude')
    if idx == len(variables) - 1:
        ax.set_xlabel('Longitude')

    # Panel title and letter
    name = unit_official_names.get(var, "")
    ax.set_title(f"{name}", pad=2, fontsize=20, fontweight='bold')
    ax.text(
        0.05, 1.1,
        f"({chr(97 + idx)})",
        transform=ax.transAxes,
        fontsize=20, fontweight='bold', va='top'
    )
    add_county_labels(ax)

# Tick settings outside the loop to avoid repeating them; a fresh Locator per Axis.
from matplotlib.ticker import MultipleLocator
for ax in axes:
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(0.5))

plt.tight_layout(h_pad=1.5)

# To save, do it before plt.show(): the inline backend closes the figure on show().
#fig.savefig(os.path.join(OUT_UNCERT, "summer_uncertainty_4x1_ESI_HI_NET_WBGT.png"), dpi=600)
plt.show()
plt.close(fig)
