In [None]:
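"""
NOTE:
- Before running, update the input paths (the FWI reference GeoTIFF and the ECMWF GRIB file) and OUT_DIR to match your setup.
- Inspect the regridded outputs in QGIS to confirm they line up with the FWI grid.
"""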

import torch

In [2]:
import math
import os
import re
from datetime import datetime, timezone

import numpy as np
import pygrib
import rasterio
from rasterio.io import DatasetReader
from rasterio.warp import reproject, Resampling

In [3]:
# Folder that receives the regridded single-band GeoTIFFs (created if missing).
OUT_DIR = "Jin_fwi2"
os.makedirs(OUT_DIR, exist_ok=True)

# Resampling used when regridding continuous meteorological fields;
# switch to Resampling.nearest for categorical fields.
RESAMPLE_METHOD = Resampling.bilinear

# Every output layer is written as float32 with NaN standing in for nodata.
DST_DTYPE = "float32"
DST_NODATA = np.nan


In [4]:
def read_fwi_reference(path: str):
    """Open the FWI reference GeoTIFF and derive the output profile for regridded layers.

    Parameters
    ----------
    path : str
        Path to the FWI GeoTIFF that defines the target grid.

    Returns
    -------
    (ref, profile)
        ref : the open rasterio dataset (left open; the caller is responsible
            for closing it).
        profile : a copy of ``ref.profile`` set up for single-band
            ``DST_DTYPE`` GeoTIFFs with ``DST_NODATA`` nodata, deflate
            compression and tiling.
    """
    ref = rasterio.open(path)
    profile = ref.profile.copy()
    # Override dtype and nodata (float32/NaN) so missing values carry safely across layers.
    profile.update(
        dtype=DST_DTYPE,
        nodata=DST_NODATA,
        count=1,  # writing single-band files
        compress="deflate",
        tiled=True,
        predictor=3 if DST_DTYPE.startswith("float") else 2,
        BIGTIFF="IF_SAFER",
    )
    # Tiled GeoTIFFs need block sizes that are multiples of 16. The sizes inherited
    # from the reference profile are not guaranteed to satisfy this, so fall back
    # to standard 256x256 tiles when they don't.
    block_x, block_y = profile.get("blockxsize"), profile.get("blockysize")
    if not (block_x and block_y and block_x % 16 == 0 and block_y % 16 == 0):
        profile.update(blockxsize=256, blockysize=256)
    return ref, profile

def list_grib_subdatasets(path: str):
    """List the GDAL subdataset identifiers of a GRIB file.

    Some GDAL builds expose the GRIB bands directly and report no subdatasets;
    in that case the file path itself is returned as the only entry.
    """
    with rasterio.open(path) as src:
        found = list(src.subdatasets)
    return found or [path]

def sanitize_name(s: str) -> str:
    """Turn an arbitrary label (e.g. a GDAL subdataset string) into a safe filename chunk.

    Each of the following is replaced by ``_``:
    - path separators (``/`` and backslash);
    - characters Windows forbids in filenames (``<>:"|?*`` and control characters);
    - whitespace, commas and equals signs.

    Results longer than 140 characters keep only their trailing 140 characters.
    """
    safe = re.sub(r'[<>:"/\\|?*,=\s\x00-\x1f]', "_", s)
    return safe[-140:] if len(safe) > 140 else safe

def guess_nodata(src: "DatasetReader"):
    """Return a usable source nodata value for a GRIB dataset.

    The dataset's declared ``src.nodata`` is returned when it is set and is not
    NaN (Python or NumPy float NaN). Otherwise the conventional GDAL GRIB
    float32 fill value ``np.float32(-3.4028235e38)`` is returned.
    """
    declared = src.nodata
    is_nan = isinstance(declared, (float, np.floating)) and math.isnan(declared)
    if declared is not None and not is_nan:
        return declared
    # Nodata missing or NaN: use the common GDAL GRIB float32 fill (-3.4028234663852886e+38).
    return np.float32(-3.4028235e38)

In [5]:
def _parse_grib_time(value) -> str:
    """Return a ``YYYYMMDDTHHMM`` stamp for a GDAL GRIB time tag, or '' if it cannot be parsed.

    Accepts epoch seconds, optionally followed by text (GDAL's GRIB driver
    typically reports GRIB_REF_TIME / GRIB_VALID_TIME this way, e.g.
    "1758758400" or "  1758758400 sec UTC"), as well as ISO-8601 strings such
    as "2025-09-27T00:00:00Z".
    """
    text = str(value).strip() if value else ""
    if not text:
        return ""
    first_token = text.split()[0]
    try:
        if first_token.lstrip("+-").isdigit():
            dt = datetime.fromtimestamp(int(first_token), tz=timezone.utc)
        else:
            dt = datetime.fromisoformat(text.replace("Z", ""))
    except (ValueError, OverflowError, OSError):
        return ""
    return dt.strftime("%Y%m%dT%H%M")


def _grib_name_hint(tags: dict) -> str:
    """Build a ``param[_lev<level>][_<stamp>]`` filename hint from GDAL GRIB_* tags."""
    param = tags.get("GRIB_ELEMENT") or tags.get("GRIB_SHORT_NAME") or tags.get("GRIB_COMMENT") or "layer"
    level = tags.get("GRIB_LAYER") or tags.get("GRIB_SHORT_NAME_LEVEL") or tags.get("GRIB_INDICATOR_OF_PARAMETER")
    stamp = ""
    # Prefer the valid time; fall back to the reference time.
    for key in ("GRIB_VALID_TIME", "GRIB_REF_TIME"):
        stamp = _parse_grib_time(tags.get(key))
        if stamp:
            break
    hint = f"{param}"
    if level:
        hint += f"_lev{level}"
    if stamp:
        hint += f"_{stamp}"
    return hint


def reproject_to_fwi_grid(src_ds_path: str, ref_ds: DatasetReader, out_profile_base: dict,
                          resampling=RESAMPLE_METHOD, out_dir=None):
    """
    Reproject (regrid) every band of a GDAL-readable GRIB dataset onto the FWI grid.

    Each band is written as a single-band GeoTIFF with the same
    width/height/transform/CRS as the FWI reference.

    Parameters
    ----------
    src_ds_path : str
        Path (or GDAL subdataset identifier) of the source dataset.
    ref_ds : rasterio.io.DatasetReader
        Open FWI reference raster that defines the target grid.
    out_profile_base : dict
        Base rasterio profile for the outputs (dtype, nodata, compression, ...).
    resampling : rasterio.warp.Resampling
        Resampling method (defaults to RESAMPLE_METHOD).
    out_dir : str or None
        Output folder; defaults to OUT_DIR. Created if missing.

    Returns
    -------
    list[str]
        Paths of the GeoTIFFs written (one per band).

    Raises
    ------
    RuntimeError
        If the source dataset has no CRS.
    """
    out_dir = OUT_DIR if out_dir is None else out_dir
    os.makedirs(out_dir, exist_ok=True)

    # Destination grid from the FWI reference
    dst_height, dst_width = ref_ds.height, ref_ds.width
    dst_transform = ref_ds.transform
    dst_crs = ref_ds.crs

    out_profile = out_profile_base.copy()
    out_profile.update(
        width=dst_width,
        height=dst_height,
        transform=dst_transform,
        crs=dst_crs,
    )

    written = []
    with rasterio.open(src_ds_path) as src:
        src_crs = src.crs
        if src_crs is None:
            raise RuntimeError(f"Source CRS is None for {src_ds_path}; cannot reproject. "
                               "You likely need a GDAL build with GRIB support and metadata enabled.")

        src_nodata = guess_nodata(src)
        dataset_tags = src.tags() or {}
        bands = src.count

        # If the dataset has multiple bands, each one is processed separately.
        for b in range(1, bands + 1):
            # GDAL's GRIB driver typically stores GRIB_* metadata per band;
            # dataset-level tags act as a fallback.
            tags = {**dataset_tags, **(src.tags(b) or {})}

            src_arr = src.read(b).astype(np.float32)
            # Replace the GRIB fill value (and crazy extremes) and the declared nodata with NaN
            src_arr = np.where(src_arr <= (np.float32(-3.402823e38) / 10), np.nan, src_arr)
            if src_nodata is not None and not np.isnan(src_nodata):
                src_arr = np.where(src_arr == np.float32(src_nodata), np.nan, src_arr)

            dst_arr = np.full((dst_height, dst_width), np.nan, dtype=np.float32)

            # Reproject to the FWI grid and crop to its extent via the FWI transform/shape.
            reproject(
                source=src_arr,
                destination=dst_arr,
                src_transform=src.transform,
                src_crs=src_crs,
                # Every missing value is NaN at this point, so NaN is the source nodata.
                src_nodata=np.float32(np.nan),
                dst_transform=dst_transform,
                dst_crs=dst_crs,
                dst_nodata=np.float32(np.nan),
                resampling=resampling,
            )

            hint = sanitize_name(_grib_name_hint(tags))
            out_name = f"{hint}_band{b}.tif" if bands > 1 else f"{hint}.tif"
            out_path = os.path.join(out_dir, out_name)

            with rasterio.open(out_path, "w", **out_profile) as dst:
                dst.write(dst_arr, 1)

            print(f"Wrote: {out_path}")
            written.append(out_path)

    return written


In [6]:
import math
import numpy as np
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.transform import Affine
from rasterio.crs import CRS


import pygrib

def list_grib_subdatasets_pygrib(grib_path):
    """
    Describe every message of a GRIB file (read with pygrib) for later regridding.

    Parameters
    ----------
    grib_path : str
        Path to the GRIB file.

    Returns
    -------
    list[dict]
        One dict per message, in file order, with keys:
          index      -- 1-based message number in the GRIB file
          shortName  -- str (falls back to the message name, then "layer")
          name       -- str (falls back to shortName)
          level, validDate, analDate, units -- message attributes, or None if absent
          nx, ny, crs, transform -- from _message_georeferencing(); transform is
                        None when no affine could be built (use msg.latlons())
          nodata     -- always GRIB_FILL32 (GRIB rarely declares a fill value)
    """
    catalog = []
    with pygrib.open(grib_path) as messages:
        for number, message in enumerate(messages, start=1):
            short_name = getattr(message, 'shortName', None) or getattr(message, 'name', 'layer')
            long_name = getattr(message, 'name', None) or short_name
            level, units, valid_date, anal_date = (
                getattr(message, attr, None)
                for attr in ('level', 'units', 'validDate', 'analDate')
            )

            transform, crs, nx, ny = _message_georeferencing(message)

            catalog.append({
                "index": number,
                "shortName": str(short_name),
                "name": str(long_name),
                "level": level,
                "validDate": valid_date,
                "analDate": anal_date,
                "units": units,
                "nx": nx,
                "ny": ny,
                "crs": crs,
                "transform": transform,
                "nodata": GRIB_FILL32,
            })
    return catalog

def _message_georeferencing(msg):
    """
    Return ``(transform, crs, nx, ny)`` for a pygrib message.

    When _affine_from_regular_grid() cannot build an affine, the transform is
    None. The CRS is then a best-effort guess (EPSG:4326 if the message has a
    latitudeOfFirstGridPointInDegrees key, else None), and nx/ny come from the
    Nx/Ny keys. Callers can then fall back to lat/lon grids via msg.latlons().
    """
    affine, crs, nx, ny = _affine_from_regular_grid(msg)
    if affine is None:
        has_latlon = 'latitudeOfFirstGridPointInDegrees' in msg.keys()
        crs = CRS.from_epsg(4326) if has_latlon else None
        nx, ny = int(msg['Nx']), int(msg['Ny'])
    return affine, crs, nx, ny

# handle the encoding error in our fwi images 
GRIB_FILL32 = np.float32(-3.4028235e38)

def _nan_mask(a, nodata):
    a = a.astype(np.float32, copy=False)
    if nodata is not None and not (isinstance(nodata, float) and math.isnan(nodata)):
        a = np.where(a == np.float32(nodata), np.nan, a)
    # also mask extreme GRIB fill (sometimes nodata missing from metadata)
    a = np.where(a <= (GRIB_FILL32 / 10.0), np.nan, a)
    return a

def _affine_from_regular_grid(msg):
    """
    Build ``(transform, crs, nx, ny)`` for a regular lat/lon (``regular_ll``) GRIB message.

    Every other grid type returns ``(None, None, None, None)`` so the caller can
    fall back to ``msg.latlons()``. This covers rotated lat/lon, Gaussian,
    Lambert, Mercator, polar stereographic and any grid missing the required keys.
    A true projected affine would need the first grid point projected to x/y
    (e.g. with pyproj), which is not implemented here.

    The transform maps the (col, row) corner of pixels in ``msg.values`` to
    EPSG:4326 longitude/latitude.
    """
    keys = set(msg.keys())
    required = (
        'latitudeOfFirstGridPointInDegrees',
        'longitudeOfFirstGridPointInDegrees',
        'latitudeOfLastGridPointInDegrees',
        'longitudeOfLastGridPointInDegrees',
        'iDirectionIncrementInDegrees',
        'jDirectionIncrementInDegrees',
    )
    # Rotated lat/lon grids expose the same keys (in rotated coordinates), so the
    # grid type is checked explicitly; assume regular_ll if the key is absent.
    grid_type = msg['gridType'] if 'gridType' in keys else 'regular_ll'
    if grid_type != 'regular_ll' or not keys.issuperset(required):
        return None, None, None, None

    nx, ny = int(msg['Nx']), int(msg['Ny'])
    lat1 = float(msg['latitudeOfFirstGridPointInDegrees'])
    lon1 = float(msg['longitudeOfFirstGridPointInDegrees'])
    di = float(msg['iDirectionIncrementInDegrees'])
    dj = float(msg['jDirectionIncrementInDegrees'])

    # GRIB2 stores longitudes in [0, 360). Shift western-hemisphere origins
    # (e.g. 219 -> -141) so the source extent overlaps the [-180, 180] longitudes
    # of the destination grid. Grids starting below 180 (incl. global 0..360) are untouched.
    if lon1 >= 180.0:
        lon1 -= 360.0

    # Scanning mode: by default i runs west->east and j runs north->south.
    # NOTE(review): assumes msg.values keeps GRIB storage order (pygrib does not
    # appear to flip it); verify on a jScansPositively=1 file.
    i_negative = 'iScansNegatively' in keys and int(msg['iScansNegatively']) == 1
    j_positive = 'jScansPositively' in keys and int(msg['jScansPositively']) == 1
    dx = -di if i_negative else di
    dy = dj if j_positive else -dj

    # GRIB coordinates refer to cell centres; Affine wants the corner of pixel
    # (0, 0), so step back half a pixel along each axis's scan direction.
    transform = Affine(dx, 0.0, lon1 - dx / 2.0, 0.0, dy, lat1 - dy / 2.0)
    crs = CRS.from_epsg(4326)  # geographic
    return transform, crs, nx, ny


def reproject_to_fwi_grid_pygrib(
    grib_path,
    entry,
    ref_ds,
    out_profile_base,
    out_dir,
    resampling=Resampling.bilinear
):
    """
    Reproject a single pygrib message into the FWI grid and write a single-band GeoTIFF.

    The message is described by ``entry``, as returned by list_grib_subdatasets_pygrib.

    Parameters
    ----------
    grib_path : str
        Path to the .grib file.
    entry : dict
        One item returned by list_grib_subdatasets_pygrib(), e.g.:
        {
          "index": 5, "shortName": "t2m", "name": "...", "level": 2,
          "validDate": datetime(...), "analDate": datetime(...),
          "units": "K", "nx": 240, "ny": 121,
          "crs": rasterio.crs.CRS, "transform": Affine(...) or None,
          "nodata": -3.4028235e38
        }
    ref_ds : rasterio.io.DatasetReader
        Open FWI reference raster (defines target CRS, transform, width, height).
    out_profile_base : dict
        A base rasterio profile to update for output (dtype, nodata, compression, etc.).
    out_dir : str
        Output folder (created if missing).
    resampling : rasterio.warp.Resampling
        Resampling method (bilinear for continuous, nearest for categorical).

    Returns
    -------
    str
        Path to the written GeoTIFF.

    Raises
    ------
    RuntimeError
        If ``entry`` has no affine transform or CRS (rotated/irregular grid).

    Notes
    -----
    - Requires a *regular* grid with a valid affine ``entry["transform"]``.
    - Masked cells (bitmapped messages) and GRIB fill values become NaN before
      reprojection, via the module-level ``_nan_mask``.
    """
    idx = entry["index"]
    transform = entry.get("transform")
    crs = entry.get("crs")

    # Fail before touching the GRIB file when no affine georeferencing exists.
    if transform is None or crs is None:
        raise RuntimeError(
            f"Message {idx} appears irregular/rotated; no affine transform could be built. "
            f"Implement a lat/lon fallback using msg.latlons() + a dedicated remapping step."
        )

    with pygrib.open(grib_path) as grbs:
        values = grbs.message(idx).values

    # Bitmapped messages come back as a numpy MaskedArray whose masked cells hold a
    # placeholder value rather than NaN. Fill them with NaN explicitly so they are
    # never treated as data.
    if np.ma.isMaskedArray(values):
        values = values.astype(np.float32).filled(np.nan)
    src_arr = _nan_mask(values, entry.get("nodata"))

    # Target grid from the FWI reference
    dst_height, dst_width = ref_ds.height, ref_ds.width
    dst_transform = ref_ds.transform
    dst_crs = ref_ds.crs

    dst_arr = np.full((dst_height, dst_width), np.nan, dtype=np.float32)

    reproject(
        source=src_arr,
        destination=dst_arr,
        src_transform=transform,
        src_crs=crs,
        src_nodata=np.float32(np.nan),  # already NaN-masked
        dst_transform=dst_transform,
        dst_crs=dst_crs,
        dst_nodata=np.float32(np.nan),
        resampling=resampling,
    )

    # Build filename from metadata
    hint = entry.get("shortName") or "layer"
    if entry.get("level") is not None:
        hint += f"_lev{entry['level']}"
    if entry.get("validDate") is not None:
        hint += f"_{entry['validDate'].strftime('%Y%m%dT%H%M')}"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{sanitize_name(hint)}.tif")

    out_profile = out_profile_base.copy()
    out_profile.update(
        width=dst_width,
        height=dst_height,
        transform=dst_transform,
        crs=dst_crs,
        count=1,
        dtype="float32",
        nodata=np.nan,
    )

    with rasterio.open(out_path, "w", **out_profile) as dst:
        dst.write(dst_arr, 1)

    print(f"Wrote: {out_path}")
    return out_path

In [7]:
FWI_REF_PATH = "fwi/fwi_scribe_20250911.tif"
ref, out_profile = read_fwi_reference(FWI_REF_PATH)

# Summarize the target grid that every GRIB layer is regridded onto.
summary_lines = [
    "FWI reference:",
    f"  CRS:   {ref.crs}",
    f"  Size:  {ref.width} x {ref.height}",
    f"  Res.:  {ref.res}",
    f"  Bounds:{ref.bounds}",
]
print("\n".join(summary_lines))

FWI reference:
  CRS:   EPSG:3978
  Size:  2709 x 2281
  Res.:  (1999.999630860096, 1999.9995615957914)
  Bounds:BoundingBox(left=-2378164.0, bottom=-707617.0, right=3039835.0, top=3854382.0)


In [9]:
# Enumerate every message in the ECMWF GRIB file.
GRIB_PATH = "ecmwf/20250911grib.grib"
subdatasets = list_grib_subdatasets_pygrib(GRIB_PATH)
print("Found", len(subdatasets), "GRIB subdataset(s).")

Found 160 GRIB subdataset(s).


In [10]:
# Output profile for the per-message GeoTIFFs: FWI grid metadata plus
# float32/NaN, deflate-compressed, 256x256-tiled single-band files.
out_profile_base = ref.profile.copy()
out_profile_base.update(
    driver="GTiff",
    dtype="float32",
    nodata=np.nan,
    compress="deflate",
    predictor=3,        # floating-point predictor, pairs well with deflate
    tiled=True,
    blockxsize=256,     # tile sizes MUST be multiples of 16 (256 is standard)
    blockysize=256,
    count=1,            # single-band outputs
    BIGTIFF="IF_SAFER",
)

In [11]:
# Regrid every GRIB message onto the FWI grid, writing one GeoTIFF per message.
# out_profile_base is the output profile assembled for this purpose in the cell
# above (float32, NaN nodata, 256x256 tiles).
n_messages = len(subdatasets)
for msg_no, entry in enumerate(subdatasets, 1):
    print(f"Processing message {msg_no}/{n_messages}: {entry['shortName']} level={entry['level']} valid={entry['validDate']}")
    reproject_to_fwi_grid_pygrib(
        GRIB_PATH,            # path to the .grib file
        entry,                # the dict entry from list_grib_subdatasets_pygrib
        ref,                  # open FWI reference dataset (target grid)
        out_profile_base,     # output profile (dtype=float32, nodata=NaN, compression, tiling)
        OUT_DIR,              # where to write the per-layer GeoTIFFs
        resampling=RESAMPLE_METHOD,
    )


Processing message 1/160: tcw level=0 valid=2025-09-25 00:00:00
Wrote: Jin_fwi2\tcw_lev0_20250925T0000.tif
Processing message 2/160: 2t level=2 valid=2025-09-25 00:00:00
Wrote: Jin_fwi2\2t_lev2_20250925T0000.tif
Processing message 3/160: 2d level=2 valid=2025-09-25 00:00:00
Wrote: Jin_fwi2\2d_lev2_20250925T0000.tif
Processing message 4/160: sm20 level=0 valid=2025-09-25 00:00:00
Wrote: Jin_fwi2\sm20_lev0_20250925T0000.tif
Processing message 5/160: st20 level=0 valid=2025-09-25 00:00:00
Wrote: Jin_fwi2\st20_lev0_20250925T0000.tif
Processing message 6/160: tcw level=0 valid=2025-09-26 00:00:00
Wrote: Jin_fwi2\tcw_lev0_20250926T0000.tif
Processing message 7/160: 2t level=2 valid=2025-09-26 00:00:00
Wrote: Jin_fwi2\2t_lev2_20250926T0000.tif
Processing message 8/160: 2d level=2 valid=2025-09-26 00:00:00
Wrote: Jin_fwi2\2d_lev2_20250926T0000.tif
Processing message 9/160: sm20 level=0 valid=2025-09-26 00:00:00
Wrote: Jin_fwi2\sm20_lev0_20250926T0000.tif
Processing message 10/160: st20 level=0