# 使用 pyresample 将弯曲网格 GOCI 子集重采样到 Landsat 规则网格（5 波段批处理）

**目标**：1) 5 波段上采样；2) 输出 GeoTIFF；3) 可选 NPY；4) 每波段三图对比。

## 0. 依赖

In [1]:

import os, json, math, time
import numpy as np
import rasterio
from rasterio.transform import Affine
from rasterio.crs import CRS
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from pyproj import Transformer
from pyresample import geometry, kd_tree
print("✅ imports ok")

✅ imports ok


In [None]:
print(os.getcwd())


/Users/zy/Python_code/My_Git/img_match/TOA_match


## 1. 参数与路径

In [21]:

GOCI_NC = '../goci_subset_5bands.nc'
LANDSAT_TIF = '../SR_Imagery/LC09_L1TP_116035_20250504_20250504_02_T1/LC09_L1TP_116035_20250504_20250504_02_T1_TOA_RAD_B1-2-3-4-5.tif'
OUT_TIF   = '../outputs/goci_resampled_to_landsat.tif'
SAVE_NPY  = True
NPY_DIR   = '../outputs/npy'
META_JSON = '../outputs/goci_to_landsat_meta.json'
USE_GAUSSIAN = True
ROI_METERS   = 800
SIGMA_METERS = 320
NEIGHBOURS   = 16
FILL_VALUE   = np.nan
NPROCS       = 1
GOCI_BANDS = {443:'L_TOA_443', 490:'L_TOA_490', 555:'L_TOA_555', 660:'L_TOA_660', 865:'L_TOA_865'}
LANDSAT_WAVELENGTHS = [443, 483, 561, 655, 865]
PAIR_L2G = {0:443, 1:490, 2:555, 3:660, 4:865}
AOI_BOUNDS = None
GEOTIFF_NODATA = -9999.0
os.makedirs(os.path.dirname(OUT_TIF), exist_ok=True)
if SAVE_NPY: os.makedirs(NPY_DIR, exist_ok=True)
os.makedirs(os.path.dirname(META_JSON), exist_ok=True)
print("✅ params ready")

✅ params ready


## 2. 数据读取

In [4]:

def read_goci_subset(nc_path: str, bands_map: dict):
    if not os.path.exists(nc_path): raise FileNotFoundError(nc_path)
    with Dataset(nc_path, 'r') as ds:
        nav = ds['navigation_data']; geo = ds['geophysical_data']
        lat = np.array(nav['latitude'][:], dtype=np.float32)
        lon = np.array(nav['longitude'][:], dtype=np.float32)
        band_list, wl_order = [], []
        for wl in sorted(bands_map.keys()):
            var = geo[bands_map[wl]]; data = var[:]
            arr = data.filled(np.nan).astype(np.float32) if np.ma.isMaskedArray(data) else np.array(data, dtype=np.float32)
            if '_FillValue' in var.ncattrs():
                try: fv = float(var.getncattr('_FillValue'))
                except: fv = None
                if fv is not None: arr = np.where(arr == fv, np.nan, arr)
            band_list.append(arr); wl_order.append(wl)
        data_array = np.stack(band_list, axis=-1)
    print("GOCI:", data_array.shape, "(H,W,B)")
    return {'data': data_array, 'lat': lat, 'lon': lon, 'wl_order': wl_order}

def read_landsat_tif(tif_path: str):
    if not os.path.exists(tif_path): raise FileNotFoundError(tif_path)
    with rasterio.open(tif_path) as ds:
        stack = ds.read().astype(np.float32)
        if ds.nodata is not None: stack = np.where(stack == ds.nodata, np.nan, stack)
        H, W = ds.height, ds.width; T, crs = ds.transform, ds.crs
        rows = np.arange(H); cols = np.arange(W); cgrid, rgrid = np.meshgrid(cols, rows)
        x = T.c + T.a*(cgrid+0.5) + T.b*(rgrid+0.5); y = T.f + T.d*(cgrid+0.5) + T.e*(rgrid+0.5)
        lon, lat = Transformer.from_crs(crs, 'EPSG:4326', always_xy=True).transform(x, y)
    print("Landsat:", stack.shape, "(B,H,W)")
    return {'data': stack, 'lon': lon, 'lat': lat, 'transform': T, 'crs': crs}

## 3. 逐波段重采样

In [5]:

def resample_all_bands(goci_dict, landsat_dict, use_gauss=True,
                       roi_m=800, sigma_m=320, neighbours=16, fill_value=np.nan, nprocs=1):
    g_arr = goci_dict['data']; g_lat=goci_dict['lat']; g_lon=goci_dict['lon']; wl_order=goci_dict['wl_order']
    L_lat=landsat_dict['lat']; L_lon=landsat_dict['lon']
    source_swath = geometry.SwathDefinition(lons=g_lon, lats=g_lat)
    target_swath = geometry.SwathDefinition(lons=L_lon, lats=L_lat)
    print("Swath:", g_lat.shape, "->", L_lat.shape)
    outs, logs = [], []
    for bi in range(g_arr.shape[-1]):
        src = g_arr[:, :, bi]; src_masked = np.ma.array(src, mask=~np.isfinite(src))
        t0 = time.time()
        if use_gauss:
            out = kd_tree.resample_gauss(source_swath, src_masked, target_swath,
                                         radius_of_influence=roi_m, sigmas=sigma_m,
                                         fill_value=fill_value, neighbours=neighbours,
                                         reduce_data=True, with_uncert=False, nprocs=nprocs); method="gauss"
        else:
            out = kd_tree.resample_nearest(source_swath, src_masked, target_swath,
                                           radius_of_influence=roi_m, fill_value=fill_value); method="nearest"
        dt = time.time()-t0; finite=np.isfinite(out); ratio=float(finite.mean())
        vmin=float(np.nanmin(out)) if finite.any() else float('nan')
        vmax=float(np.nanmax(out)) if finite.any() else float('nan')
        vmean=float(np.nanmean(out)) if finite.any() else float('nan')
        print(f"  band {bi} (wl={wl_order[bi]} nm) -> {out.shape}, {method}, {dt:.2f}s, finite={ratio*100:.2f}%")
        outs.append(out.astype(np.float32))
        logs.append({"band":bi,"wl_nm":int(wl_order[bi]),"method":method,"seconds":round(dt,3),
                     "finite_ratio":round(ratio,5),"min":vmin,"max":vmax,"mean":vmean})
    return np.stack(outs, axis=0), logs

## 4. 保存：GeoTIFF / NPY / Meta

In [6]:

def save_geotiff(out_path, stack_BHW, transform, crs, nodata_value=-9999.0):
    B,H,W = stack_BHW.shape
    profile={"driver":"GTiff","dtype":"float32","count":B,"height":H,"width":W,
             "transform":transform,"crs":crs,"nodata":nodata_value,
             "compress":"deflate","predictor":2,"tiled":True,"blockxsize":512,"blockysize":512}
    data_to_write = np.where(np.isfinite(stack_BHW), stack_BHW, nodata_value).astype(np.float32)
    with rasterio.open(out_path,"w",**profile) as dst:
        for i in range(B): dst.write(data_to_write[i], i+1)
    print("💾 GeoTIFF:", out_path)

def save_npys(npy_dir, stack_BHW, wl_order, save_stack=True):
    os.makedirs(npy_dir, exist_ok=True)
    for i, wl in enumerate(wl_order): np.save(os.path.join(npy_dir, f"resampled_{wl}nm.npy"), stack_BHW[i])
    if save_stack: np.save(os.path.join(npy_dir, "resampled_stack_BHW.npy"), stack_BHW)
    print("💾 NPY ->", npy_dir)

def save_meta_json(json_path, meta_dict):
    with open(json_path,"w",encoding="utf-8") as f: json.dump(meta_dict,f,ensure_ascii=False,indent=2)
    print("💾 Meta:", json_path)

## 5. 可视化对比

In [7]:

def robust_min_max(*arrays, pmin=2, pmax=98):
    vals_list=[a[np.isfinite(a)] for a in arrays if a is not None and np.isfinite(a).any()]
    if not vals_list: return 0.0,1.0
    vals=np.concatenate(vals_list); v1,v2=np.percentile(vals,[pmin,pmax])
    if not np.isfinite(v1) or not np.isfinite(v2) or v1==v2:
        v1,v2=float(np.nanmin(vals)),float(np.nanmax(vals)); 
        if v1==v2: v2=v1+1e-6
    return float(v1),float(v2)

def crop_to_bounds(arr, lon, lat, bounds):
    if bounds is None: return arr,lon,lat
    lon_min,lon_max,lat_min,lat_max = bounds
    mask=(lon>=lon_min)&(lon<=lon_max)&(lat>=lat_min)&(lat<=lat_max)
    if not np.any(mask): return arr,lon,lat
    rows=np.any(mask,axis=1); cols=np.any(mask,axis=0)
    rmin,rmax=int(np.argmax(rows)), int(len(rows)-np.argmax(rows[::-1])-1)
    cmin,cmax=int(np.argmax(cols)), int(len(cols)-np.argmax(cols[::-1])-1)
    return arr[rmin:rmax+1, cmin:cmax+1], lon[rmin:rmax+1, cmin:cmax+1], lat[rmin:rmax+1, cmin:cmax+1]

def quicklook(arr, lon=None, lat=None, max_pixels=2_000_000):
    h,w=arr.shape; total=h*w
    if total<=max_pixels:
        stride=1; arr_sub=arr; lon_sub=lon; lat_sub=lat
    else:
        import math
        stride=int(math.ceil((total/max_pixels)**0.5))
        arr_sub=arr[::stride,::stride]
        lon_sub=lon[::stride,::stride] if lon is not None else None
        lat_sub=lat[::stride,::stride] if lat is not None else None
    if lon_sub is not None and lat_sub is not None and np.isfinite(lon_sub).any() and np.isfinite(lat_sub).any():
        extent=[float(np.nanmin(lon_sub)),float(np.nanmax(lon_sub)),
                float(np.nanmin(lat_sub)),float(np.nanmax(lat_sub))]
    else: extent=None
    return arr_sub,stride,extent

def save_band_comparisons(out_dir, band_idx, wl_g, goci_native, g_lon, g_lat,
                          resampled, L_lon, L_lat, landsat_native, AOI_BOUNDS=None, vmin=None, vmax=None):
    os.makedirs(out_dir, exist_ok=True)
    g_native_crop,g_lon_crop,g_lat_crop=crop_to_bounds(goci_native,g_lon,g_lat,AOI_BOUNDS)
    g_show,s_g,ext_g=quicklook(g_native_crop,g_lon_crop,g_lat_crop)
    if vmin is None or vmax is None: v1,v2=robust_min_max(g_show)
    else: v1,v2=vmin,vmax
    plt.figure(figsize=(6,5)); plt.imshow(g_show,vmin=v1,vmax=v2,origin='upper',extent=ext_g); plt.grid(False)
    plt.title(f"GOCI original {wl_g} nm (stride {s_g})"); plt.xlabel("Longitude"); plt.ylabel("Latitude")
    plt.savefig(os.path.join(out_dir,f"band{band_idx}_GOCI_{wl_g}nm.png"),dpi=300,bbox_inches='tight'); plt.close()

    r_crop,r_lon_crop,r_lat_crop=crop_to_bounds(resampled,L_lon,L_lat,AOI_BOUNDS)
    r_show,s_r,ext_r=quicklook(r_crop,r_lon_crop,r_lat_crop)
    plt.figure(figsize=(6,5)); plt.imshow(r_show,vmin=v1,vmax=v2,origin='upper',extent=ext_r); plt.grid(False)
    plt.title(f"GOCI→Landsat resampled {wl_g} nm (stride {s_r})"); plt.xlabel("Longitude"); plt.ylabel("Latitude")
    plt.savefig(os.path.join(out_dir,f"band{band_idx}_Resampled_{wl_g}nm.png"),dpi=300,bbox_inches='tight'); plt.close()

    l_crop,l_lon_crop,l_lat_crop=crop_to_bounds(landsat_native,L_lon,L_lat,AOI_BOUNDS)
    l_show,s_l,ext_l=quicklook(l_crop,l_lon_crop,l_lat_crop)
    plt.figure(figsize=(6,5)); plt.imshow(l_show,vmin=v1,vmax=v2,origin='upper',extent=ext_l); plt.grid(False)
    plt.title(f"Landsat original (~{wl_g} nm) (stride {s_l})"); plt.xlabel("Longitude"); plt.ylabel("Latitude")
    plt.savefig(os.path.join(out_dir,f"band{band_idx}_Landsat_{wl_g}nm.png"),dpi=300,bbox_inches='tight'); plt.close()

In [13]:
goci_data = read_goci_subset(GOCI_NC, GOCI_BANDS)
print(goci_data['wl_order'])

GOCI: (669, 901, 5) (H,W,B)
[443, 490, 555, 660, 865]


## 6. 运行全流程

In [22]:

goci_data = read_goci_subset(GOCI_NC, GOCI_BANDS)
landsat_data = read_landsat_tif(LANDSAT_TIF)
resampled_stack_BHW, logs = resample_all_bands(
    goci_data, landsat_data,
    use_gauss=USE_GAUSSIAN, roi_m=ROI_METERS, sigma_m=SIGMA_METERS,
    neighbours=NEIGHBOURS, fill_value=FILL_VALUE, nprocs=NPROCS
)
save_geotiff(OUT_TIF, resampled_stack_BHW, landsat_data['transform'], landsat_data['crs'], GEOTIFF_NODATA)
if SAVE_NPY: save_npys(NPY_DIR, resampled_stack_BHW, goci_data['wl_order'], save_stack=True)
meta={"goci_nc":os.path.abspath(GOCI_NC),"landsat_tif":os.path.abspath(LANDSAT_TIF),"out_tif":os.path.abspath(OUT_TIF),
      "use_gaussian":USE_GAUSSIAN,"roi_m":ROI_METERS,"sigma_m":SIGMA_METERS,"neighbours":NEIGHBOURS,"nprocs":NPROCS,
      "geotiff_nodata":GEOTIFF_NODATA,"goci_wavelength_order":goci_data['wl_order'],"logs":logs}
with open(META_JSON,"w",encoding="utf-8") as f: json.dump(meta,f,ensure_ascii=False,indent=2)
fig_dir='outputs/figs_compare'; os.makedirs(fig_dir, exist_ok=True)
for bi, wl in enumerate(goci_data['wl_order']):
    landsat_idx=int(np.argmin(np.abs(np.array(LANDSAT_WAVELENGTHS)-wl)))
    save_band_comparisons(fig_dir, bi, wl,
        goci_data['data'][:, :, bi], goci_data['lon'], goci_data['lat'],
        resampled_stack_BHW[bi, :, :], landsat_data['lon'], landsat_data['lat'],
        landsat_data['data'][landsat_idx, :, :], AOI_BOUNDS=AOI_BOUNDS)
print("✅ 全流程完成")

GOCI: (669, 901, 5) (H,W,B)
Landsat: (5, 7961, 7841) (B,H,W)
Swath: (669, 901) -> (7961, 7841)


  get_neighbour_info(source_geo_def,


  band 0 (wl=443 nm) -> (7961, 7841), gauss, 182.68s, finite=100.00%
  band 1 (wl=490 nm) -> (7961, 7841), gauss, 188.62s, finite=100.00%
  band 2 (wl=555 nm) -> (7961, 7841), gauss, 191.81s, finite=100.00%
  band 3 (wl=660 nm) -> (7961, 7841), gauss, 193.06s, finite=100.00%
  band 4 (wl=865 nm) -> (7961, 7841), gauss, 189.16s, finite=100.00%
💾 GeoTIFF: ../outputs/goci_resampled_to_landsat.tif


NotImplementedError: MaskedArray.tofile() not implemented yet.