In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyproj import Geod
import os

# -------------------------------
# Load GRACE and downscaled data
# -------------------------------
gr = xr.open_dataset("grace_025.nc")  # variable: grace
dn = xr.open_dataset("downscaled_grace_tws_data_with_uncertainty_gan_hat.nc")  # variable: tws_downscaled

gr["time"] = pd.to_datetime(gr["time"].values)
dn["time"] = pd.to_datetime(dn["time"].values)

# -------------------------------
# Compute area grid for any lat/lon resolution
# -------------------------------
def compute_area_grid(lat, lon):
    geod = Geod(ellps="WGS84")
    area_grid = np.zeros((len(lat), len(lon)))
    for i in range(len(lat) - 1):
        for j in range(len(lon) - 1):
            lons = [lon[j], lon[j+1], lon[j+1], lon[j]]
            lats = [lat[i], lat[i], lat[i+1], lat[i+1]]
            area, _ = geod.polygon_area_perimeter(lons, lats)
            area_grid[i, j] = abs(area) / 1e6  # m² → km²
    area_grid[-1, :] = area_grid[-2, :]
    area_grid[:, -1] = area_grid[:, -2]
    return area_grid

area_grid_grace = compute_area_grid(gr.lat.values, gr.lon.values)
area_grid_down = compute_area_grid(dn.lat.values, dn.lon.values)

# -------------------------------
# Define drought/flood events
# -------------------------------
events = {
    "2004_06-ali_drought": {
        "time_range": ("2003-12", "2004-12"),
        "region": {"lat_min": 30.0, "lat_max": 34.5, "lon_min": 78.0, "lon_max": 84.5}
    },
    "2004_6-8southwest_flood": {
        "time_range": ("2004-01", "2005-01"),
        "region": {"lat_min": 28.0, "lat_max": 31.5, "lon_min": 90.0, "lon_max": 97.0}
    },
    "2009_-6tibet_drought": {
        "time_range": ("2008-12", "2009-12"),
        "region": {"lat_min": 28.5, "lat_max": 32.5, "lon_min": 87.0, "lon_max": 97.0}
    },
    "2010_qinghai_flood": {
        "time_range": ("2010-01", "2011-01"),
        "region": {"lat_min": 36.5, "lat_max": 37.5, "lon_min": 101.5, "lon_max": 102.5}
}
}

# -------------------------------
# Compute water volume and time series
# -------------------------------
def calc_water_volume_and_series(ds, var, event, area_grid, lat_array, lon_array):
    t_start = pd.to_datetime(event["time_range"][0])
    t_end = pd.to_datetime(event["time_range"][1])
    region = event["region"]

    sub = ds[var].sel(
        time=slice(t_start, t_end),
        lat=slice(region["lat_min"], region["lat_max"]),
        lon=slice(region["lon_min"], region["lon_max"])
    )
    delta = sub.isel(time=-1) - sub.isel(time=0)
    delta = delta.where(np.isfinite(delta), 0)

    lat_mask = (lat_array >= region["lat_min"]) & (lat_array <= region["lat_max"])
    lon_mask = (lon_array >= region["lon_min"]) & (lon_array <= region["lon_max"])
    area_sub = area_grid[np.ix_(lat_mask, lon_mask)]

    assert delta.shape == area_sub.shape, f"Shape mismatch: {delta.shape} vs {area_sub.shape}"
    water_km3 = np.nansum(delta.values * area_sub) * 0.01 / 1000  # cm * km² → km³

    mean_series = sub.mean(dim=["lat", "lon"]).to_dataframe()[var].reset_index()
    return water_km3, mean_series

# -------------------------------
# Run analysis and plot/save per-event results
# -------------------------------
out_dir = "event_results"
os.makedirs(out_dir, exist_ok=True)

print("🌊 TWSA Volumetric Change and Mean Time Series (km³)")

for idx, (event_name, info) in enumerate(events.items()):
    vol_grace, ts_grace = calc_water_volume_and_series(
        gr, "data", info, area_grid_grace, gr.lat.values, gr.lon.values)

    vol_dn, ts_dn = calc_water_volume_and_series(
        dn, "data", info, area_grid_down, dn.lat.values, dn.lon.values)

    print(f"\n🌀 {event_name.replace('_', ' ').title()}:")
    print(f"   • GRACE     ΔTWSA : {vol_grace:.2f} km³")
    print(f"   • Downscaled ΔTWSA: {vol_dn:.2f} km³")

    # Save time series to CSV
    ts = pd.DataFrame({
        "Date": ts_grace["time"],
        "GRACE_TWSA": ts_grace["data"].values,
        "Downscaled_TWSA": ts_dn["data"].values
    })
    csv_path = os.path.join(out_dir, f"{event_name}_timeseries.csv")
    ts.to_csv(csv_path, index=False)

    # Plot individual event figure
    plt.figure(figsize=(8, 5))
    plt.plot(ts["Date"], ts["GRACE_TWSA"], label="GRACE TWSA", color='blue', linestyle='--', marker='o')
    plt.plot(ts["Date"], ts["Downscaled_TWSA"], label="Downscaled TWSA", color='red', linestyle='-', marker='s')
    plt.title(event_name.replace('_', ' ').title())
    plt.xlabel("Date")
    plt.ylabel("Mean TWSA (cm)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    fig_path = os.path.join(out_dir, f"{event_name}_timeseries.png")
    plt.savefig(fig_path, dpi=300)
    plt.show()
    plt.close()

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyproj import Geod
import geopandas as gpd
import os
from shapely.geometry import box
# -------------------------------
# User settings
# -------------------------------
ADM_SHAPEFILE = "gadm41_CHN_shp/gadm41_CHN_3.shp"  # GADM v4.1 China admin level-3 shapefile
OUT_DIR = "event_results"
os.makedirs(OUT_DIR, exist_ok=True)

# -------------------------------
# 1. Load GRACE and downscaled data
# -------------------------------
gr = xr.open_dataset("grace_025.nc")               # variable: grace
_dn = xr.open_dataset("downscaled_grace_tws_data_with_uncertainty_gan_danet.nc")  # variable: tws_downscaled

gr["time"] = pd.to_datetime(gr["time"].values)
_dn["time"] = pd.to_datetime(_dn["time"].values)

# -------------------------------
# 2. Compute area grid for any lat/lon resolution
# -------------------------------
def compute_area_grid(lat, lon):
    geod = Geod(ellps="WGS84")
    A = np.zeros((len(lat), len(lon)))
    for i in range(len(lat)-1):
        for j in range(len(lon)-1):
            ring_lons = [lon[j], lon[j+1], lon[j+1], lon[j]]
            ring_lats = [lat[i], lat[i], lat[i+1], lat[i+1]]
            area, _ = geod.polygon_area_perimeter(ring_lons, ring_lats)
            A[i, j] = abs(area) / 1e6  # km²
    A[-1, :] = A[-2, :]
    A[:, -1] = A[:, -2]
    return A

area_grace = compute_area_grid(gr.lat.values, gr.lon.values)
area_down  = compute_area_grid(_dn.lat.values, _dn.lon.values)

# -------------------------------
# 3. Define drought/flood events
# -------------------------------
events = {
    "2004_06-ali_drought": {
        "time_range": ("2003-12", "2004-12"),
        "timestep": ("2004-6"),
        "region": {"lat_min": 30.0, "lat_max": 34.5, "lon_min": 78.0, "lon_max": 84.5}
    },
    "2004_6-8southwest_flood": {
        "time_range": ("2004-01", "2005-01"),
        "timestep": ("2004-7"),
        "region": {"lat_min": 28.0, "lat_max": 31.5, "lon_min": 90.0, "lon_max": 97.0}
    },
    "2009_-6tibet_drought": {
        "time_range": ("2008-12", "2009-12"),
        "timestep": ("2009-7"),
        "region": {"lat_min": 28.5, "lat_max": 32.5, "lon_min": 87.0, "lon_max": 97.0}
    },
    "2010_qinghai_flood": {
        "time_range": ("2010-01", "2011-01"),
        "timestep": ("2010-7"),
        "region": {"lat_min": 36.5, "lat_max": 37.5, "lon_min": 101.5, "lon_max": 102.5}
    }
}

# -------------------------------
# 4. Compute water volume and time series
# -------------------------------
def calc_water_volume_and_series(ds, var, event, area_grid, lat_array, lon_array):
    t_start = pd.to_datetime(event["time_range"][0])
    t_end   = pd.to_datetime(event["time_range"][1])
    reg     = event["region"]

    sub = ds[var].sel(
        time=slice(t_start, t_end),
        lat=slice(reg["lat_min"], reg["lat_max"]),
        lon=slice(reg["lon_min"], reg["lon_max"])
    )
    delta = sub.isel(time=-1) - sub.isel(time=0)
    delta = delta.where(np.isfinite(delta), 0)

    lat_mask = (lat_array >= reg["lat_min"]) & (lat_array <= reg["lat_max"])
    lon_mask = (lon_array >= reg["lon_min"]) & (lon_array <= reg["lon_max"])
    area_sub = area_grid[np.ix_(lat_mask, lon_mask)]

    assert delta.shape == area_sub.shape, f"Shape mismatch: {delta.shape} vs {area_sub.shape}"
    water_km3 = np.nansum(delta.values * area_sub) * 0.01 / 1000  # cm * km² → km³

    mean_series = sub.mean(dim=["lat", "lon"]).to_dataframe()[var].reset_index()
    return water_km3, mean_series

# Load administrative boundaries
gdf_adm = gpd.read_file(ADM_SHAPEFILE)
gdf_yunnan = gdf_adm[gdf_adm["NAME_1"]=="Yunnan"]
# -------------------------------
# 5. Run analysis and plot/save per-event results
# -------------------------------
for event_name, info in events.items():
    # Calculate volume and mean series
    vol_grace, ts_grace = calc_water_volume_and_series(gr, "data", info, area_grace, gr.lat.values, gr.lon.values)
    vol_dn,    ts_dn    = calc_water_volume_and_series(_dn,  "data", info, area_down, _dn.lat.values, _dn.lon.values)

    print(f"\n🌀 {event_name.replace('_',' ').title()}: GRACE Δ={vol_grace:.2f} km³, Downscaled Δ={vol_dn:.2f} km³")

    # Save mean time series CSV
    df_ts = pd.DataFrame({
        "Date": ts_grace["time"],
        "GRACE_TWSA": ts_grace["data"].values,
        "Downscaled_TWSA": ts_dn["data"].values
    })
    df_ts.to_csv(os.path.join(OUT_DIR, f"{event_name}_timeseries.csv"), index=False)

    # Plot time series
    plt.figure(figsize=(8,5))
    plt.plot(df_ts["Date"], df_ts["GRACE_TWSA"],     '--o', label="GRACE",     color='blue')
    plt.plot(df_ts["Date"], df_ts["Downscaled_TWSA"], '-s', label="Downscaled", color='red')
    plt.title(event_name.replace('_',' ').title())
    plt.xlabel("Date"); plt.ylabel("Mean TWSA (mm)")
    plt.legend(); plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, f"{event_name}_timeseries.pdf"), dpi=300)
    plt.close()

    # Plot spatial comparison at last time step
    last_date = info["timestep"]
    
    print(last_date)
    tg = gr["data"].sel(time=last_date).isel(time=-1).sel(
        lat=slice(info["region"]["lat_min"], info["region"]["lat_max"]),
        lon=slice(info["region"]["lon_min"], info["region"]["lon_max"]) )
    td = _dn["data"].sel(time=last_date).isel(time=-1).sel(
        lat=slice(info["region"]["lat_min"], info["region"]["lat_max"]),
        lon=slice(info["region"]["lon_min"], info["region"]["lon_max"]) )

    fig, axes = plt.subplots(1,2,figsize=(12,5), sharex=True, sharey=True)
    for ax, data, title in zip(axes, [tg, td], ["GRACE", "Downscaled"]):
        pcm = ax.pcolormesh(data.lon, data.lat, data.values, cmap='jet_r', shading='auto')
        # Clip and overlay boundaries
        window = gpd.GeoDataFrame(geometry=[box(
            info["region"]["lon_min"], info["region"]["lat_min"],
            info["region"]["lon_max"], info["region"]["lat_max"] )], crs=gdf_adm.crs)
        gpd.clip(gdf_adm, window).boundary.plot(ax=ax, edgecolor='black', linewidth=0.8)
        ax.set_title(f"{title} TWSA {last_date}")
        ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude")
    fig.colorbar(pcm, ax=axes.ravel(), label="TWSA (mm)", fraction=0.02)
    #plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, f"{event_name}_spatial.pdf"), dpi=300)
    plt.savefig(os.path.join(OUT_DIR, f"{event_name}_spatial.png"), dpi=300)
    plt.close(fig)



🌀 2004 06-Ali Drought: GRACE Δ=-146.71 km³, Downscaled Δ=-168.32 km³
2004-6

🌀 2004 6-8Southwest Flood: GRACE Δ=334.64 km³, Downscaled Δ=183.41 km³
2004-7

🌀 2009 -6Tibet Drought: GRACE Δ=-117.95 km³, Downscaled Δ=16.77 km³
2009-7

🌀 2010 Qinghai Flood: GRACE Δ=0.36 km³, Downscaled Δ=0.89 km³
2010-7


In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ------------------------------------------------------------------
# 1.  Paths
# ------------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_hat.nc"
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"      # province level
OUT_FIG   = "qng_tws_2009-09_2010-05.png"

# ------------------------------------------------------------------
# 2.  Load data (downscaled only) & Qinghai boundary
# ------------------------------------------------------------------
ds   = xr.open_dataset(DS_PATH)          # variable: data
ds["time"] = pd.to_datetime(ds["time"])

gdf  = gpd.read_file(SHAPE_L1)
qh   = gdf[gdf["NAME_1"] == "Qinghai"].to_crs("EPSG:4326")  # ensure lon/lat

# Bounding box of Qinghai
minx, miny, maxx, maxy = qh.total_bounds

# ------------------------------------------------------------------
# 3.  Time range & figure layout
# ------------------------------------------------------------------
times = pd.date_range("2009-09", "2010-05", freq="MS")
n     = len(times)
cols  = 3
rows  = int(np.ceil(n / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharex=True, sharey=True)

# ------------------------------------------------------------------
# 4.  Plot every month, clipped to Qinghai bbox
# ------------------------------------------------------------------
for i, t in enumerate(times):
    ax   = axes.flat[i]
    da   = ds["data"].sel(time=t, method="nearest")
    clip = da.sel(lat=slice(miny, maxy), lon=slice(minx, maxx))

    pcm  = ax.pcolormesh(clip.lon, clip.lat, clip.values,
                         cmap="jet_r",vmin=-200, vmax=200, shading="auto")   # jet_r as requested

    # Draw Qinghai province outline only
    qh.boundary.plot(ax=ax, edgecolor="black", linewidth=1.0)

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

# Hide any unused panels (if rows*cols > n)
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Color-bar
cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("TWSA (cm)")

fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# --------------------------------------------------------------
# 1.  Paths and basic settings
# --------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_danet.nc"
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"   # province level
OUT_FIG   = "xizang_tws_2009-09_2010-05.png"

PROVINCE_NAME = "Xizang"       # change to "Xizang" if that’s the label in the shapefile
CMAP = "jet_r"

# --------------------------------------------------------------
# 2.  Load dataset and Xizang shape
# --------------------------------------------------------------
ds = xr.open_dataset(DS_PATH)
ds["time"] = pd.to_datetime(ds["time"])

gdf = gpd.read_file(SHAPE_L1)
xz  = gdf[gdf["NAME_1"] == PROVINCE_NAME].to_crs("EPSG:4326")

# Bounding box of Xizang
minx, miny, maxx, maxy = xz.total_bounds
print(minx, miny, maxx, maxy)
# Determine whether lat is ordered north→south or south→north
lat_descending = ds.lat[0] > ds.lat[-1]

def lat_slice(min_lat, max_lat):
    """Return a slice object that works whether lat is ascending or descending."""
    return slice(max_lat, min_lat) if lat_descending else slice(min_lat, max_lat)

# --------------------------------------------------------------
# 3.  Time range and figure layout
# --------------------------------------------------------------
times = pd.date_range("2009-09", "2010-05", freq="MS")
n     = len(times)
cols  = 3
rows  = int(np.ceil(n / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4),
                         sharex=True, sharey=True)
print(da.lon, da.lat, da)
# --------------------------------------------------------------
# 4.  Plot each month clipped to Xizang extent
# --------------------------------------------------------------
for i, t in enumerate(times):
    ax  = axes.flat[i]

    # Subset to time and spatial bounds
    da  = ds["data"].sel(time=t, method="nearest")
    clip  = da.sel(
        lon=slice(minx, maxx),
        lat=lat_slice(miny, maxy)
    )

    # pcolormesh needs 2-D lon/lat arrays; xarray broadcasts automatically
    pcm = ax.pcolormesh(clip.lon, clip.lat, clip, cmap=CMAP,vmin=-200, vmax=200, shading="auto")

    # Draw Xizang boundary only
    xz.boundary.plot(ax=ax, edgecolor="black", linewidth=1.0)

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

# Hide any unused axes
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Shared colour-bar
cbar2 = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar2.set_label("TWSA (cm)")

fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)

78.40432739300019 27.305332183000075 99.11231231700009 36.483509063000156
<xarray.DataArray 'lon' (lon: 900)>
array([ 65.1 ,  65.15,  65.2 , ..., 109.95, 110.  , 110.05], dtype=float32)
Coordinates:
    time     datetime64[ns] 2010-04-21
  * lon      (lon) float32 65.1 65.15 65.2 65.25 ... 109.9 109.9 110.0 110.1
Attributes:
    units:    degrees_east <xarray.DataArray 'lat' (lat: 440)>
array([24.1 , 24.15, 24.2 , ..., 45.95, 46.  , 46.05], dtype=float32)
Coordinates:
    time     datetime64[ns] 2010-04-21
  * lat      (lat) float32 24.1 24.15 24.2 24.25 24.3 ... 45.9 45.95 46.0 46.05
Attributes:
    units:    degrees_north <xarray.DataArray 'data' (lat: 440, lon: 900)>
[396000 values with dtype=float32]
Coordinates:
    time     datetime64[ns] 2010-04-21
  * lat      (lat) float32 24.1 24.15 24.2 24.25 24.3 ... 45.9 45.95 46.0 46.05
  * lon      (lon) float32 65.1 65.15 65.2 65.25 ... 109.9 109.9 110.0 110.1
Attributes:
    units:                cm
    source:               GRACE/GRAC

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ------------------------------------------------------------------
# 1.  Paths
# ------------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_hat.nc"
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"      # province level
OUT_FIG   = "yunnan_tws_2009-09_2010-05.png"

# ------------------------------------------------------------------
# 2.  Load data (downscaled only) & Qinghai boundary
# ------------------------------------------------------------------
ds   = xr.open_dataset(DS_PATH)          # variable: data
ds["time"] = pd.to_datetime(ds["time"])

gdf  = gpd.read_file(SHAPE_L1)
qh   = gdf[gdf["NAME_1"] == "Yunnan"].to_crs("EPSG:4326")  # ensure lon/lat

# Bounding box of Qinghai
minx, miny, maxx, maxy = qh.total_bounds

# ------------------------------------------------------------------
# 3.  Time range & figure layout
# ------------------------------------------------------------------
times = pd.date_range("2009-09", "2010-05", freq="MS")
n     = len(times)
cols  = 3
rows  = int(np.ceil(n / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharex=True, sharey=True)

# ------------------------------------------------------------------
# 4.  Plot every month, clipped to Qinghai bbox
# ------------------------------------------------------------------
for i, t in enumerate(times):
    ax   = axes.flat[i]
    da   = ds["data"].sel(time=t, method="nearest")
    clip = da.sel(lat=slice(miny, maxy), lon=slice(minx, maxx))

    pcm  = ax.pcolormesh(clip.lon, clip.lat, clip.values,
                         cmap="jet_r",vmin=-200, vmax=200, shading="auto")   # jet_r as requested

    # Draw Qinghai province outline only
    qh.boundary.plot(ax=ax, edgecolor="black", linewidth=1.0)

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

# Hide any unused panels (if rows*cols > n)
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Color-bar
cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("TWSA (cm)")

fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ------------------------------------------------------------------
# 1.  Paths
# ------------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_hat.nc"
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"      # province level
OUT_FIG   = "Sichuan_tws_2009-09_2010-05.png"

# ------------------------------------------------------------------
# 2.  Load data (downscaled only) & Qinghai boundary
# ------------------------------------------------------------------
ds   = xr.open_dataset(DS_PATH)          # variable: data
ds["time"] = pd.to_datetime(ds["time"])

gdf  = gpd.read_file(SHAPE_L1)
qh   = gdf[gdf["NAME_1"] == "Sichuan"].to_crs("EPSG:4326")  # ensure lon/lat

# Bounding box of Qinghai
minx, miny, maxx, maxy = qh.total_bounds

# ------------------------------------------------------------------
# 3.  Time range & figure layout
# ------------------------------------------------------------------
times = pd.date_range("2009-09", "2010-05", freq="MS")
n     = len(times)
cols  = 3
rows  = int(np.ceil(n / cols))

fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), sharex=True, sharey=True)

# ------------------------------------------------------------------
# 4.  Plot every month, clipped to Qinghai bbox
# ------------------------------------------------------------------
for i, t in enumerate(times):
    ax   = axes.flat[i]
    da   = ds["data"].sel(time=t, method="nearest")
    clip = da.sel(lat=slice(miny, maxy), lon=slice(minx, maxx))

    pcm  = ax.pcolormesh(clip.lon, clip.lat, clip.values,
                         cmap="jet_r",vmin=-200, vmax=200, shading="auto")   # jet_r as requested

    # Draw Qinghai province outline only
    qh.boundary.plot(ax=ax, edgecolor="black", linewidth=1.0)

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

# Hide any unused panels (if rows*cols > n)
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Color-bar
cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("TWSA (cm)")

fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)

In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from shapely.geometry import box
# --------------------------------------------------------------
# 1.  Paths
# --------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_danet.nc"
SHAPE_L0  = "gadm41_CHN_shp/gadm41_CHN_0.shp"   # national outline
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"   # provinces
OUT_FIG   = "china_tws_2009-09_2010-05.png"
CMAP      = "jet_r"

# --------------------------------------------------------------
# 2.  Load dataset & administrative boundaries
# --------------------------------------------------------------
ds = xr.open_dataset(DS_PATH)
ds["time"] = pd.to_datetime(ds["time"])

# National outline and province boundaries
gdf_nat  = gpd.read_file(SHAPE_L0).to_crs("EPSG:4326")
gdf_prov = gpd.read_file(SHAPE_L1).to_crs("EPSG:4326")

# --------------------------------------------------------------
# 3.  Determine dataset bounding box & clip boundaries
# --------------------------------------------------------------
lon_min, lon_max = float(ds.lon.min()), float(ds.lon.max())
lat_min, lat_max = float(ds.lat.min()), float(ds.lat.max())

bbox_geom = box(lon_min, lat_min, lon_max, lat_max)

# Clip national / province outlines to that bbox
# (gpd.clip works in all recent versions; fall back to boolean mask if needed)
try:
    gdf_nat_clip  = gpd.clip(gdf_nat,  bbox_geom)
    gdf_prov_clip = gpd.clip(gdf_prov, bbox_geom)
except AttributeError:                       # very old GeoPandas
    gdf_nat_clip  = gdf_nat[gdf_nat.intersects(bbox_geom)]
    gdf_prov_clip = gdf_prov[gdf_prov.intersects(bbox_geom)]

# --------------------------------------------------------------
# 4.  Time range & colour scale (global min/max over period)
# --------------------------------------------------------------
times  = pd.date_range("2009-09", "2010-05", freq="MS")
ds_sub = ds["data"].sel(time=times, method="nearest")       # <-- add method="nearest"
vmin   = float(ds_sub.min())
vmax   = float(ds_sub.max())

# --------------------------------------------------------------
# 5.  Figure layout
# --------------------------------------------------------------
n    = len(times)
cols = 3
rows = int(np.ceil(n / cols))

fig, axes = plt.subplots(rows, cols,
                         figsize=(cols * 4, rows * 4),
                         sharex=True, sharey=True)

# --------------------------------------------------------------
# 6.  Plot each month
# --------------------------------------------------------------
lat_desc = ds.lat[0] > ds.lat[-1]
lat_slice = slice(lat_max, lat_min) if lat_desc else slice(lat_min, lat_max)

for i, t in enumerate(times):
    ax = axes.flat[i]

    da = ds["data"].sel(time=t, method="nearest").sel(
        lon=slice(lon_min, lon_max),
        lat=lat_slice
    )

    pcm = ax.pcolormesh(da.lon, da.lat, da,
                        cmap=CMAP, shading="auto",
                        vmin=vmin, vmax=vmax)

    # Overlays
    gdf_nat_clip.boundary.plot(ax=ax, edgecolor="black", linewidth=1.2)    # national outline
    gdf_prov_clip.boundary.plot(ax=ax, edgecolor="black", linewidth=0.6)   # provinces

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

# Hide unused panels
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Shared colour-bar
cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("TWSA (cm)")

fig.tight_layout()
fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)


  fig.tight_layout()


In [None]:
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from shapely.geometry import box
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# --------------------------------------------------------------
# 1) Paths & settings
# --------------------------------------------------------------
DS_PATH   = "downscaled_grace_tws_data_with_uncertainty_gan_danet.nc"
SHAPE_L0  = "gadm41_CHN_shp/gadm41_CHN_0.shp"
SHAPE_L1  = "gadm41_CHN_shp/gadm41_CHN_1.shp"
OUT_FIG   = "china_tws_2009-09_2010-05_with_inset.png"
CMAP      = "jet_r"

# Focus window (lon_min, lon_max, lat_min, lat_max)
HLON_MIN, HLON_MAX = 85.5, 110.5
HLAT_MIN, HLAT_MAX = 25.5, 35.5

# --------------------------------------------------------------
# 2) Load dataset & boundaries
# --------------------------------------------------------------
ds = xr.open_dataset(DS_PATH)
ds["time"] = pd.to_datetime(ds["time"])

gdf_nat  = gpd.read_file(SHAPE_L0).to_crs("EPSG:4326")
gdf_prov = gpd.read_file(SHAPE_L1).to_crs("EPSG:4326")

# --------------------------------------------------------------
# 3) Dataset bounds & clipped outlines (GeoPandas-version safe)
# --------------------------------------------------------------
lon_min, lon_max = float(ds.lon.min()), float(ds.lon.max())
lat_min, lat_max = float(ds.lat.min()), float(ds.lat.max())
bbox_geom = box(lon_min, lat_min, lon_max, lat_max)

try:  # gpd.clip exists ≥0.10
    gdf_nat_clip  = gpd.clip(gdf_nat,  bbox_geom)
    gdf_prov_clip = gpd.clip(gdf_prov, bbox_geom)
except AttributeError:  # very old GeoPandas
    gdf_nat_clip  = gdf_nat[gdf_nat.intersects(bbox_geom)]
    gdf_prov_clip = gdf_prov[gdf_prov.intersects(bbox_geom)]

# --------------------------------------------------------------
# 4) Time range & colour scale
# --------------------------------------------------------------
times  = pd.date_range("2009-09", "2010-05", freq="MS")
ds_sub = ds["data"].sel(time=slice("2009-09", "2010-05"))
vmin   = float(ds_sub.min())
vmax   = float(ds_sub.max())

# --------------------------------------------------------------
# 5) Figure layout
# --------------------------------------------------------------
n, cols = len(times), 3
rows    = int(np.ceil(n / cols))
fig, axes = plt.subplots(rows, cols, figsize=(cols * 4.2, rows * 4.2),
                         sharex=True, sharey=True)

# Latitude orientation helper
lat_desc  = ds.lat[0] > ds.lat[-1]
def lat_slice(lo, hi):
    return slice(hi, lo) if lat_desc else slice(lo, hi)

# --------------------------------------------------------------
# 6) Draw each month
# --------------------------------------------------------------
for i, t in enumerate(times):
    ax = axes.flat[i]

    da_full = ds["data"].sel(time=t, method="nearest").sel(
        lon=slice(lon_min, lon_max),
        lat=lat_slice(lat_min, lat_max)
    )

    pcm = ax.pcolormesh(da_full.lon, da_full.lat, da_full,
                        cmap=CMAP, shading="auto",
                        vmin=vmin, vmax=vmax)

    # Outlines
    gdf_nat_clip.boundary.plot(ax=ax, edgecolor="black", linewidth=1.0)
    gdf_prov_clip.boundary.plot(ax=ax, edgecolor="black", linewidth=0.6)

    # Main focus rectangle
    ax.add_patch(Rectangle((HLON_MIN, HLAT_MIN),
                           HLON_MAX - HLON_MIN,
                           HLAT_MAX - HLAT_MIN,
                           fill=False, edgecolor="white",
                           linewidth=1.2, linestyle="--"))

    ax.set_title(t.strftime("%Y-%m"))
    ax.set_xticks([]); ax.set_yticks([])

    # ---- Hawk-eye inset with DATA ----
    iax = inset_axes(ax, width="28%", height="28%", loc="lower left", borderpad=0.6)
    da_zoom = da_full.sel(
        lon=slice(HLON_MIN, HLON_MAX),
        lat=lat_slice(HLAT_MIN, HLAT_MAX)
    )
    iax.pcolormesh(da_zoom.lon, da_zoom.lat, da_zoom,
                   cmap=CMAP, shading="auto",
                   vmin=vmin, vmax=vmax)

    # Draw rectangle (optional inside the inset too, here as a red frame)
    iax.add_patch(Rectangle((HLON_MIN, HLAT_MIN),
                            HLON_MAX - HLON_MIN,
                            HLAT_MAX - HLAT_MIN,
                            fill=False, edgecolor="red", linewidth=1.0))

    gdf_nat_clip.boundary.plot(ax=iax, edgecolor="black", linewidth=0.8)
    gdf_prov_clip.boundary.plot(ax=iax, edgecolor="black", linewidth=0.5)

    iax.set_xlim(HLON_MIN, HLON_MAX)
    iax.set_ylim(HLAT_MIN, HLAT_MAX)
    iax.set_xticks([]); iax.set_yticks([])

# Hide unused axes
for ax in axes.flat[n:]:
    ax.set_visible(False)

# Shared colour-bar
cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
cbar.set_label("TWSA (cm)")

fig.tight_layout()
fig.savefig(OUT_FIG, dpi=300)
plt.close(fig)

  fig.tight_layout()


In [6]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import box
import string
import cv2
import geopandas as gpd
from shapely.geometry import box
# --- new imports ---
import matplotlib.patheffects as pe
import geopandas as gpd
from shapely.geometry import box
# ------------------------------------------------------------------
# 1)  Paths & settings
# ------------------------------------------------------------------
DS_PATH = "downscaled_grace_tws_data_with_uncertainty_gan_danet.nc"
OUT_FIG = "hawk_eye_only_2009-09_2010-05.pdf"
CMAP    = "jet_r"

# Shapefiles
TP_SHP  = "Union/TP_basins.shp"                     # <-- update if your TP boundary lives elsewhere
ADM_SHP = "gadm41_CHN_shp/gadm41_CHN_2.shp"         # GADM level-3 for China

# Region of interest
HLON_MIN, HLON_MAX = 87.5, 102.5   # °E
HLAT_MIN, HLAT_MAX = 25.5, 32.5    # °N
vmin, vmax = -200, 200             # same scale as before; label below says mm

# ------------------------------------------------------------------
# 2)  Load dataset
# ------------------------------------------------------------------
ds = xr.open_dataset(DS_PATH)
ds["time"] = pd.to_datetime(ds["time"])
lat_desc = ds.lat[0] > ds.lat[-1]
lat_slice = lambda lo, hi: slice(hi, lo) if lat_desc else slice(lo, hi)

# Optional bilateral smoothing (kept as in your script)
def smooth_spatial_bilateral(raw_downscaled, d=1, sigma_color=75, sigma_space=75):
    n, w, l = raw_downscaled.shape
    smoothed = np.zeros_like(raw_downscaled)
    for i in range(n):
        raw_slice = raw_downscaled[i, :, :]
        nan_mask = np.isnan(raw_slice)
        mean_value = np.nanmean(raw_slice)
        filled_slice = np.where(nan_mask, mean_value, raw_slice)
        filled_slice = filled_slice.astype(np.float32)
        smoothed_slice = cv2.bilateralFilter(filled_slice, d=d, sigmaColor=sigma_color, sigmaSpace=sigma_space)
        smoothed_slice[nan_mask] = np.nan
        smoothed[i, :, :] = smoothed_slice
    return smoothed

ds["data"].values = smooth_spatial_bilateral(ds["data"].values, d=11, sigma_color=75, sigma_space=75)

# ------------------------------------------------------------------
# 3)  Load & prep boundaries (WGS84 lon/lat), clip to ROI for speed
# ------------------------------------------------------------------
TP_SHP  = "Union/TP_basins.shp"
ADM_SHP = "gadm41_CHN_shp/gadm41_CHN_2.shp"

# --- load + reproject ---
gdf_tp  = gpd.read_file(TP_SHP)
gdf_adm = gpd.read_file(ADM_SHP)
if gdf_tp.crs is not None:  gdf_tp  = gdf_tp.to_crs(epsg=4326)
if gdf_adm.crs is not None: gdf_adm = gdf_adm.to_crs(epsg=4326)

# --- true geometric clip to ROI ---
roi_poly = gpd.GeoDataFrame(geometry=[box(HLON_MIN, HLAT_MIN, HLON_MAX, HLAT_MAX)], crs="EPSG:4326")
gdf_tp_clip  = gpd.clip(gdf_tp, roi_poly)
gdf_adm_clip = gpd.clip(gdf_adm, roi_poly)

# --- pick the best available name column (robust across GADM variants) ---
name_candidates = ["NAME_3", "NL_NAME_3", "VARNAME_3", "NAME_2", "NL_NAME_2"]
ADM_NAME_COL = next((c for c in name_candidates if c in gdf_adm_clip.columns), None)

# --- precompute label anchor points (inside each polygon) ---
adm_labels = []
if ADM_NAME_COL is not None and len(gdf_adm_clip) > 0:
    gdf_lab = gdf_adm_clip[gdf_adm_clip[ADM_NAME_COL].notna()].copy()
    reps = gdf_lab.representative_point()  # guaranteed inside polygon
    adm_labels = list(zip(reps.x, reps.y, gdf_lab[ADM_NAME_COL].astype(str)))

# ------------------------------------------------------------------
# 4)  Time range
# ------------------------------------------------------------------
times = pd.date_range("2009-09", "2010-05", freq="MS")

# ------------------------------------------------------------------
# 5)  Figure & axes (extra left margin, share x only)
# ------------------------------------------------------------------
n, cols = len(times), 3
rows    = int(np.ceil(n / cols))

fig, axes = plt.subplots(
    rows, cols, figsize=(9.0, 7.5),  # tweak if needed
    sharex=True,
    gridspec_kw={"left": 0.17}
)
axes = np.atleast_2d(axes).ravel()

# Tick positions
xticks = np.arange(np.ceil(HLON_MIN), HLON_MAX + 1, 4)
yticks = np.arange(np.ceil(HLAT_MIN), HLAT_MAX + 1, 2)

# ------------------------------------------------------------------
# 6)  Plot panels
# ------------------------------------------------------------------
pcm = None
for i, t in enumerate(times):
    row, col = divmod(i, cols)
    ax = axes[i]

    da = ds["data"].sel(time=t, method="nearest").sel(
        lon=slice(HLON_MIN, HLON_MAX),
        lat=lat_slice(HLAT_MIN, HLAT_MAX)
    )

    pcm = ax.pcolormesh(
        da.lon, da.lat, da,
        cmap=CMAP, shading="auto",
        vmin=vmin, vmax=vmax,
        rasterized=True, zorder=1
    )

    # --- lock view strictly to ROI (so nothing outside shows) ---
    ax.set_xlim(HLON_MIN, HLON_MAX)
    if lat_desc:
        ax.set_ylim(HLAT_MAX, HLAT_MIN)
    else:
        ax.set_ylim(HLAT_MIN, HLAT_MAX)

    # --- boundaries (clipped to ROI) ---
    gdf_adm_clip.boundary.plot(ax=ax, linewidth=0.3, edgecolor="black", alpha=0.6, zorder=3)
    gdf_tp_clip.boundary.plot(ax=ax,  linewidth=1.2, edgecolor="black", zorder=4)

    # --- NEW: admin labels (names) ---
    # small font + white outline for readability over colors
    for xlab, ylab, name in adm_labels:
        ax.text(
            xlab, ylab, name,
            fontsize=4, ha="center", va="center", zorder=5, clip_on=True, alpha=0.95,
            path_effects=[pe.withStroke(linewidth=1.5, foreground="white")]
        )

    # panel title "(a) YYYY-MM"
    panel_letter = string.ascii_lowercase[i]
    ax.set_title(f"({panel_letter}) {t.strftime('%Y-%m')}", fontsize=10, loc="left")


    # First column: show y-ticks
    if col == 0:
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{y:.0f}" for y in yticks], fontsize=8)
        ax.set_ylabel("Latitude (°N)")
    else:
        ax.set_yticks([])
        ax.set_ylabel("")

    # Bottom row: show x-ticks
    if row == rows - 1:
        ax.set_xticks(xticks)
        ax.set_xticklabels([f"{x:.0f}" for x in xticks], fontsize=8)
        ax.set_xlabel("Longitude (°E)")
    else:
        ax.set_xticks([])
        ax.set_xlabel("")

# Hide unused axes
for ax in axes[n:]:
    ax.set_visible(False)

# Shared colorbar
cbar = fig.colorbar(pcm, ax=axes[:n], fraction=0.02, pad=0.02)
cbar.set_label("TWSA (mm)")

# Save
fig.savefig(OUT_FIG, bbox_inches="tight")
plt.close(fig)
