In [None]:
# Save this as:
# C:\Users\Adrija\Downloads\DFGCN\data\scripts\extract_hurdat2_irma_ian.py

import os
import csv

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"
HURDAT_PATH = os.path.join(BASE, "raw", "hurdat2", "hurdat2_atlantic.txt")
OUT_DIR = os.path.join(BASE, "processed", "tracks")

# Irma 2017 = AL112017, Ian 2022 = AL092022
TARGET_IDS = {"AL112017": "irma_2017", "AL092022": "ian_2022"}

def parse_lat_lon(lat_str: str, lon_str: str):
    # examples:  "25.1N", "80.3W"
    lat = float(lat_str[:-1])
    if lat_str.endswith("S"):
        lat = -lat
    lon = float(lon_str[:-1])
    if lon_str.endswith("W"):
        lon = -lon
    return lat, lon

def main():
    os.makedirs(OUT_DIR, exist_ok=True)

    current_id = None
    current_name = None
    rows_by_id = {sid: [] for sid in TARGET_IDS}

    with open(HURDAT_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Header line: AL112017, IRMA, 51
            if line.startswith("AL") and "," in line:
                parts = [p.strip() for p in line.split(",")]
                current_id = parts[0]
                current_name = parts[1] if len(parts) > 1 else None
                continue

            if current_id in TARGET_IDS:
                parts = [p.strip() for p in line.split(",")]
                # date, time, record_id, status, lat, lon, vmax, mslp, ...
                if len(parts) < 8:
                    continue

                date = parts[0]
                time = parts[1]
                status = parts[3]
                lat_str = parts[4]
                lon_str = parts[5]
                vmax = parts[6]
                mslp = parts[7]

                # keep synoptic times only (6-hourly)
                if time not in {"0000", "0600", "1200", "1800"}:
                    continue

                lat, lon = parse_lat_lon(lat_str, lon_str)

                rows_by_id[current_id].append({
                    "storm_id": current_id,
                    "storm_name": current_name,
                    "date": date,
                    "time": time,
                    "datetime_utc": f"{date[:4]}-{date[4:6]}-{date[6:8]} {time[:2]}:{time[2:]}:00",
                    "status": status,
                    "lat": lat,
                    "lon": lon,
                    "vmax_kt": int(vmax) if vmax.isdigit() else "",
                    "mslp_mb": int(mslp) if mslp.isdigit() else "",
                })

    for sid, tag in TARGET_IDS.items():
        out_path = os.path.join(OUT_DIR, f"{tag}_hurdat2.csv")
        with open(out_path, "w", newline="", encoding="utf-8") as wf:
            writer = csv.DictWriter(
                wf,
                fieldnames=[
                    "storm_id","storm_name","date","time","datetime_utc",
                    "status","lat","lon","vmax_kt","mslp_mb"
                ],
            )
            writer.writeheader()
            writer.writerows(rows_by_id[sid])

        print(f"Wrote {len(rows_by_id[sid])} rows -> {out_path}")

if __name__ == "__main__":
    main()


In [1]:
import cdsapi
c = cdsapi.Client()
print("CDS client OK")

CDS client OK


In [2]:
# Save this as:
# C:\Users\Adrija\Downloads\DFGCN\data\scripts\download_era5_irma_2017.py

%pip install cdsapi

import os
import cdsapi

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"
OUTDIR = os.path.join(BASE, "raw", "era5", "irma_2017")
os.makedirs(OUTDIR, exist_ok=True)

c = cdsapi.Client()

# Florida + Gulf + Bahamas coverage, storage-safe:
AREA = [35, -90, 20, -70]  # North, West, South, East

c.retrieve(
    "reanalysis-era5-pressure-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": ["u_component_of_wind", "v_component_of_wind", "geopotential"],
        "pressure_level": ["850", "500"],
        "year": "2017",
        "month": ["08", "09"],
        "day": [
            "30","31",  # Aug
            "01","02","03","04","05","06","07","08","09","10","11","12","13","14","15"  # Sep
        ],
        "time": ["00:00","06:00","12:00","18:00"],
        "area": AREA,
    },
    os.path.join(OUTDIR, "era5_pl_irma_2017.nc"),
)

print("Saved:", os.path.join(OUTDIR, "era5_pl_irma_2017.nc"))



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip





2026-01-07 08:03:28,837 INFO Request ID is 33e5de44-5a2d-4361-9a0f-336d220de802
2026-01-07 08:03:29,028 INFO status has been updated to accepted
2026-01-07 08:03:51,173 INFO status has been updated to successful


ba80c89c5aaa0f0f616e96413613522e.nc:   0%|          | 0.00/6.89M [00:00<?, ?B/s]

Saved: C:\Users\Adrija\Downloads\DFGCN\data\raw\era5\irma_2017\era5_pl_irma_2017.nc


In [3]:
# Save this as:
# C:\Users\Adrija\Downloads\DFGCN\data\scripts\download_era5_ian_2022.py

import os
import cdsapi

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"
OUTDIR = os.path.join(BASE, "raw", "era5", "ian_2022")
os.makedirs(OUTDIR, exist_ok=True)

c = cdsapi.Client()

AREA = [35, -90, 20, -70]  # North, West, South, East

c.retrieve(
    "reanalysis-era5-pressure-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": ["u_component_of_wind", "v_component_of_wind", "geopotential"],
        "pressure_level": ["850", "500"],
        "year": "2022",
        "month": ["09", "10"],
        "day": [
            "20","21","22","23","24","25","26","27","28","29","30",  # Sep
            "01","02","03"  # Oct
        ],
        "time": ["00:00","06:00","12:00","18:00"],
        "area": AREA,
    },
    os.path.join(OUTDIR, "era5_pl_ian_2022.nc"),
)

print("Saved:", os.path.join(OUTDIR, "era5_pl_ian_2022.nc"))


2026-01-07 08:04:52,976 INFO Request ID is 7a5c3e1e-bb80-4de0-b7f1-db97bddac432
2026-01-07 08:04:53,207 INFO status has been updated to accepted
2026-01-07 08:05:07,541 INFO status has been updated to running
2026-01-07 08:06:10,017 INFO status has been updated to successful


7fdc13d545a35258fa10af1fd127337d.nc:   0%|          | 0.00/5.92M [00:00<?, ?B/s]

Saved: C:\Users\Adrija\Downloads\DFGCN\data\raw\era5\ian_2022\era5_pl_ian_2022.nc


In [4]:
import os, csv
import pandas as pd

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"
HURDAT = os.path.join(BASE, "raw", "hurdat2", "hurdat2_atlantic.txt")
OUTDIR = os.path.join(BASE, "processed", "tracks")
os.makedirs(OUTDIR, exist_ok=True)

TARGET = {"AL112017": "irma_2017", "AL092022": "ian_2022"}

def parse_lat_lon(lat_str, lon_str):
    lat = float(lat_str[:-1]) * (-1 if lat_str.endswith("S") else 1)
    lon = float(lon_str[:-1]) * (-1 if lon_str.endswith("W") else 1)
    return lat, lon

rows = {sid: [] for sid in TARGET}

current_id = None
current_name = None

with open(HURDAT, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue

        if line.startswith("AL") and "," in line:
            p = [x.strip() for x in line.split(",")]
            current_id = p[0]
            current_name = p[1] if len(p) > 1 else ""
            continue

        if current_id in TARGET:
            p = [x.strip() for x in line.split(",")]
            if len(p) < 8:
                continue
            date, time = p[0], p[1]
            status = p[3]
            lat_str, lon_str = p[4], p[5]
            vmax, mslp = p[6], p[7]

            # keep only 6-hourly
            if time not in {"0000", "0600", "1200", "1800"}:
                continue

            lat, lon = parse_lat_lon(lat_str, lon_str)
            rows[current_id].append({
                "storm_id": current_id,
                "storm_name": current_name,
                "datetime_utc": f"{date[:4]}-{date[4:6]}-{date[6:8]} {time[:2]}:{time[2:]}:00",
                "status": status,
                "lat": lat,
                "lon": lon,
                "vmax_kt": int(vmax) if vmax.isdigit() else None,
                "mslp_mb": int(mslp) if mslp.isdigit() else None
            })

for sid, tag in TARGET.items():
    df = pd.DataFrame(rows[sid])
    df["datetime_utc"] = pd.to_datetime(df["datetime_utc"], utc=True)
    df = df.sort_values("datetime_utc").reset_index(drop=True)
    out = os.path.join(OUTDIR, f"{tag}_hurdat2.csv")
    df.to_csv(out, index=False)
    print("Wrote:", out, "rows=", len(df))


Wrote: C:\Users\Adrija\Downloads\DFGCN\data\processed\tracks\irma_2017_hurdat2.csv rows= 59
Wrote: C:\Users\Adrija\Downloads\DFGCN\data\processed\tracks\ian_2022_hurdat2.csv rows= 35


In [10]:
%pip install xarray netCDF4

Collecting xarray
  Using cached xarray-2025.12.0-py3-none-any.whl.metadata (12 kB)
Collecting netCDF4
  Using cached netcdf4-1.7.4-cp311-abi3-win_amd64.whl.metadata (2.1 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.5-cp312-cp312-win_amd64.whl.metadata (8.8 kB)
Using cached xarray-2025.12.0-py3-none-any.whl (1.4 MB)
Using cached netcdf4-1.7.4-cp311-abi3-win_amd64.whl (21.3 MB)
Downloading cftime-1.6.5-cp312-cp312-win_amd64.whl (464 kB)
Installing collected packages: cftime, netCDF4, xarray

   ------------- -------------------------- 1/3 [netCDF4]
   ------------- -------------------------- 1/3 [netCDF4]
   ------------- -------------------------- 1/3 [netCDF4]
   ------------- -------------------------- 1/3 [netCDF4]
   ------------- -------------------------- 1/3 [netCDF4]
   ------------- -------------------------- 1/3 [netCDF4]
   -------------------------- ------------- 2/3 [xarray]
   -------------------------- ------------- 2/3 [xarray]
   ----------------------


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [15]:
import numpy as np
import os

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"
for storm in ["irma","ian"]:
    path = os.path.join(BASE, "processed", "samples", f"{storm}_samples.npz")
    d = np.load(path, allow_pickle=True)
    print("\n", storm.upper(), path)
    print("X:", d["X"].shape, " (N,F,G,G) expected")
    print("past:", d["past"].shape, " (N,H,2)")
    print("meta:", d["meta"].shape, " (N,2)")
    print("y:", d["y"].shape, " (N,L,2)")
    print("t0:", d["t0"].shape)



 IRMA C:\Users\Adrija\Downloads\DFGCN\data\processed\samples\irma_samples.npz
X: (21, 5, 33, 33)  (N,F,G,G) expected
past: (21, 4, 2)  (N,H,2)
meta: (21, 2)  (N,2)
y: (21, 4, 2)  (N,L,2)
t0: (21,)

 IAN C:\Users\Adrija\Downloads\DFGCN\data\processed\samples\ian_samples.npz
X: (23, 5, 33, 33)  (N,F,G,G) expected
past: (23, 4, 2)  (N,H,2)
meta: (23, 2)  (N,2)
y: (23, 4, 2)  (N,L,2)
t0: (23,)


In [14]:
# ============================================================
# Save as:
# C:\Users\Adrija\Downloads\DFGCN\data\scripts\02_build_samples_npz.py
# ============================================================

import os
import numpy as np
import pandas as pd
import xarray as xr
import torch
import torch.nn.functional as F
from tqdm import tqdm

BASE = r"C:\Users\Adrija\Downloads\DFGCN\data"

TRACKS = {
    "irma": os.path.join(BASE, "processed", "tracks", "irma_2017_hurdat2.csv"),
    "ian":  os.path.join(BASE, "processed", "tracks", "ian_2022_hurdat2.csv"),
}
ERA5 = {
    "irma": os.path.join(BASE, "raw", "era5", "irma_2017", "era5_pl_irma_2017.nc"),
    "ian":  os.path.join(BASE, "raw", "era5", "ian_2022", "era5_pl_ian_2022.nc"),
}

OUTDIR = os.path.join(BASE, "processed", "samples")
os.makedirs(OUTDIR, exist_ok=True)

GRID = 33
CROP_DEG = 8.0
HIST = 4
LEADS_H = [6, 12, 24, 48]
LEAD_STEPS = [h // 6 for h in LEADS_H]


def open_era5(path: str) -> xr.Dataset:
    ds = xr.open_dataset(path)

    # variable names
    u = "u" if "u" in ds.variables else "u_component_of_wind"
    v = "v" if "v" in ds.variables else "v_component_of_wind"
    z = "z" if "z" in ds.variables else "geopotential"
    if u not in ds.variables or v not in ds.variables or z not in ds.variables:
        raise ValueError(f"ERA5 variables missing. Found vars: {list(ds.variables)}")

    # pressure coordinate
    if "level" in ds.coords:
        plev = "level"
    elif "pressure_level" in ds.coords:
        plev = "pressure_level"
    else:
        raise ValueError(f"No pressure-level coord found. coords={list(ds.coords)}")

    # time coordinate
    if "time" in ds.coords:
        tcoord = "time"
    elif "valid_time" in ds.coords:
        tcoord = "valid_time"
    else:
        raise ValueError(f"No time coord found. coords={list(ds.coords)}")

    # --- CRITICAL FIX: normalize longitudes to [-180, 180] and sort ---
    if "longitude" not in ds.coords:
        raise ValueError("ERA5 missing longitude coordinate")

    lon = ds["longitude"]
    # If 0..360 -> convert to -180..180
    if float(lon.max()) > 180:
        lon_new = ((lon + 180) % 360) - 180
        ds = ds.assign_coords(longitude=lon_new)
    # Sort longitudes increasing
    ds = ds.sortby("longitude")

    # Latitudes sometimes descending; keep as-is but slicing will handle
    ds.attrs["_u"] = u
    ds.attrs["_v"] = v
    ds.attrs["_z"] = z
    ds.attrs["_plev"] = plev
    ds.attrs["_tcoord"] = tcoord
    return ds


def crop_X(ds: xr.Dataset, tstamp: pd.Timestamp, lat0: float, lon0: float) -> np.ndarray:
    tcoord = ds.attrs["_tcoord"]
    u = ds.attrs["_u"]
    v = ds.attrs["_v"]
    z = ds.attrs["_z"]
    plev = ds.attrs["_plev"]

    # nearest time selection
    dsel = ds.sel({tcoord: np.datetime64(tstamp.to_datetime64())}, method="nearest")

    lat_min, lat_max = lat0 - CROP_DEG, lat0 + CROP_DEG
    lon_min, lon_max = lon0 - CROP_DEG, lon0 + CROP_DEG

    # slice latitude regardless of ascending/descending
    lat_vals = dsel["latitude"].values
    if lat_vals[0] > lat_vals[-1]:  # descending
        lat_slice = slice(lat_max, lat_min)
    else:
        lat_slice = slice(lat_min, lat_max)

    # longitude now guaranteed to be [-180,180] and sorted ascending
    lon_slice = slice(lon_min, lon_max)

    box = dsel.sel(latitude=lat_slice, longitude=lon_slice)

    # Guard: empty crop (can happen if bounds are outside dataset)
    if box.sizes.get("longitude", 0) == 0 or box.sizes.get("latitude", 0) == 0:
        raise RuntimeError(
            f"Empty crop: lat[{lat_min:.2f},{lat_max:.2f}] lon[{lon_min:.2f},{lon_max:.2f}] "
            f"-> got lat={box.sizes.get('latitude',0)} lon={box.sizes.get('longitude',0)}"
        )

    def pl(varname: str, level: int) -> np.ndarray:
        return box[varname].sel({plev: level}).values.astype(np.float32)

    u850 = pl(u, 850); v850 = pl(v, 850)
    u500 = pl(u, 500); v500 = pl(v, 500)
    z500 = pl(z, 500)

    X = np.stack([u850, v850, u500, v500, z500], axis=0)  # (F,H,W)

    Xt = torch.from_numpy(X).unsqueeze(0)  # (1,F,H,W)
    Xt = F.interpolate(Xt, size=(GRID, GRID), mode="bilinear", align_corners=False)
    return Xt.squeeze(0).numpy()  # (F,GRID,GRID)


def build(storm: str):
    df = pd.read_csv(TRACKS[storm])
    df["datetime_utc"] = pd.to_datetime(df["datetime_utc"], utc=True)
    df = df.sort_values("datetime_utc").reset_index(drop=True)

    # numeric meta
    df["vmax_kt"] = pd.to_numeric(df.get("vmax_kt"), errors="coerce")
    df["mslp_mb"] = pd.to_numeric(df.get("mslp_mb"), errors="coerce")

    ds = open_era5(ERA5[storm])

    X_list, past_list, meta_list, y_list, t0_list = [], [], [], [], []
    skipped = 0

    for i in tqdm(range(HIST, len(df)), desc=f"building {storm}"):
        if i + max(LEAD_STEPS) >= len(df):
            break

        lat0, lon0 = float(df.loc[i, "lat"]), float(df.loc[i, "lon"])
        t0 = df.loc[i, "datetime_utc"]

        # past positions
        past = []
        for k in range(HIST, 0, -1):
            past.append([float(df.loc[i-k, "lat"]), float(df.loc[i-k, "lon"])])
        past = np.array(past, dtype=np.float32)

        # metadata
        vmax = df.loc[i, "vmax_kt"]
        mslp = df.loc[i, "mslp_mb"]
        vmax = float(vmax) if pd.notna(vmax) else 0.0
        mslp = float(mslp) if pd.notna(mslp) else 0.0
        meta = np.array([vmax, mslp], dtype=np.float32)

        # unified X
        try:
            X = crop_X(ds, t0, lat0, lon0)
        except Exception as e:
            skipped += 1
            continue

        # targets
        y = []
        for s in LEAD_STEPS:
            y.append([float(df.loc[i+s, "lat"]), float(df.loc[i+s, "lon"])])
        y = np.array(y, dtype=np.float32)

        X_list.append(X)
        past_list.append(past)
        meta_list.append(meta)
        y_list.append(y)
        t0_list.append(str(t0))

    if len(X_list) == 0:
        raise RuntimeError(f"No samples built for {storm}. Check ERA5 coverage/time range.")

    out = os.path.join(OUTDIR, f"{storm}_samples.npz")
    np.savez_compressed(
        out,
        X=np.stack(X_list),          # (N,F,GRID,GRID)
        past=np.stack(past_list),    # (N,HIST,2)
        meta=np.stack(meta_list),    # (N,2)
        y=np.stack(y_list),          # (N,L,2)
        t0=np.array(t0_list)
    )
    print(f"Saved: {out} | N={len(X_list)} | skipped={skipped}")


if __name__ == "__main__":
    build("irma")
    build("ian")


building irma:  85%|████████▌ | 47/55 [00:00<00:00, 105.08it/s]


Saved: C:\Users\Adrija\Downloads\DFGCN\data\processed\samples\irma_samples.npz | N=21 | skipped=26


building ian:  74%|███████▍  | 23/31 [00:00<00:00, 86.37it/s]


Saved: C:\Users\Adrija\Downloads\DFGCN\data\processed\samples\ian_samples.npz | N=23 | skipped=0
