## Data Preparation

This notebook references **WiSoSuper** (Caltech SURF 2020–2021), a public benchmark study of super-resolution models on wind and solar fields built from NREL WIND Toolkit and NSRDB data.

In this repository, we follow WiSoSuper mainly for its HSDS-based data access patterns and dataset construction workflow (e.g., cropping and tiling), while using our own choices for variables, downsampling, and storage format.

Reference: https://github.com/RupaKurinchiVendhan/WiSoSuper/blob/main/data.ipynb

### Key Differences from the Reference Implementation
See the dataset README for details: [datasets/README.md](datasets/README.md)

## 1) Setup

Install dependencies (run once):

In [None]:
!pip install -q h5pyd numpy matplotlib

In [None]:
import os, getpass

os.environ["NREL_API_KEY"] = getpass.getpass("Enter NREL API key (hidden): ")
print("Key loaded:", "NREL_API_KEY" in os.environ)

Alternatively, set your NREL API key as an environment variable before launching the notebook (recommended for public repos, since the key never appears in a cell):

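- **macOS/Linux**

  ```bash
  export NREL_API_KEY="YOUR_KEY"
  ```

- **Windows (PowerShell)**

  ```powershell
  setx NREL_API_KEY "YOUR_KEY"
  ```

  `setx` persists the variable for new sessions only, so open a new terminal before launching the notebook.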

## 2) Configuration

Edit `CONFIG` below to match your needs. Key choices to document in a public repo:

- `row_slice` / `col_slice`: what region you select and why
- `sampling_interval`: your intended temporal cadence
- filters (`min_mean`, `min_value_gt`): these can bias the distribution (explain or simplify)
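
As a rough check on these settings, the sketch below computes how many tiles one timestep yields under the default values used in `CONFIG`. `tiles_per_timestep` is a hypothetical helper, not part of the notebook, and the counts are upper bounds because the filters can discard tiles afterwards.

```python
import math

def tiles_per_timestep(row_slice, col_slice, tile_hr, keep_ratio):
    """Hypothetical helper: candidate and kept tile counts for one timestep."""
    rows = row_slice[1] - row_slice[0]
    cols = col_slice[1] - col_slice[0]
    n_h = rows // tile_hr[0]          # whole tiles that fit vertically
    n_w = cols // tile_hr[1]          # whole tiles that fit horizontally
    total = n_h * n_w
    # Same rule as the notebook's tile selection: keep at least one tile.
    kept = max(1, int(total * keep_ratio)) if total > 0 else 0
    return total, kept

# Default values copied from CONFIG.
total, kept = tiles_per_timestep((0, 1600), (500, 2100), (100, 100), 0.30)
lr_shape = (100 // 5, 100 // 5)                 # tile_hr divided by down_factor
min_timesteps = math.ceil(40000 / kept)         # timesteps needed to reach max_samples

print("candidate tiles per timestep:", total)
print("kept tiles per timestep     :", kept)
print("LR tile shape               :", lr_shape)
print("min timesteps for max_samples:", min_timesteps)
```

With the defaults, the 1600×1600 crop holds 256 candidate tiles, of which 76 are kept per timestep, so reaching `max_samples` takes at least 527 sampled timesteps before any filtering.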

In [None]:
import os
import math
import numpy as np
import h5pyd

CONFIG = {
    "h5_path": "/nrel/wtk-us.h5",
    "bucket": "nrel-pds-hsds",
    "endpoint": "https://developer.nrel.gov/api/hsds",
    "api_key_env": "NREL_API_KEY",
    "speed_dataset": "windspeed_100m",

    "row_slice": (0, 1600),
    "col_slice": (500, 2100),

    "sampling_interval": 48,

    "tile_hr": (100, 100),
    "down_factor": 5,
    "tile_keep_ratio": 0.30,

    "drop_if_nan": True,
    "min_value_gt": 0.0,
    "min_mean": 0.5,

    "train_ratio": 0.70,
    "val_ratio": 0.15,

    "max_samples": 40000,
    "seed": 42,

    "out_dir": "./dataset_wind_sr",
    "chunk_flush": 2048,
}

CONFIG


## 3) Utilities

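This section defines helper functions for:
- block average-pooling downsampling (HR → LR): `avg_pool2d`
- safe HSDS file opening (API key read from an environment variable): `open_hsds_file`
- streaming writes via NumPy memmap for memory-efficient dataset generation: `MemmapWriter`, `finalize_npy_from_memmap`

Time-ordered timestep splitting (to reduce temporal leakage) is not defined in this cell; section 6 uses a random tile-level split.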
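
The memmap streaming pattern named in the list can be illustrated in isolation. This is a minimal sketch with illustrative sizes and temporary file paths (all assumptions, not the notebook's values): it preallocates an on-disk buffer, writes samples one at a time, then saves only the rows that were actually written as a standard `.npy` file.

```python
import os
import tempfile
import numpy as np

n_alloc, tile_shape = 8, (4, 4)            # illustrative sizes only
tmp_dir = tempfile.mkdtemp()
dat_path = os.path.join(tmp_dir, "tiles.dat")
npy_path = os.path.join(tmp_dir, "tiles.npy")

# Preallocate the raw buffer on disk; nothing is held in RAM beyond one tile.
mm = np.memmap(dat_path, dtype=np.float32, mode="w+", shape=(n_alloc, *tile_shape))
written = 0
for i in range(5):                         # fewer samples than allocated rows
    mm[written] = np.full(tile_shape, i, dtype=np.float32)
    written += 1
mm.flush()
del mm                                     # release the mapping

# Reopen read-only with the allocated shape and keep only the written rows.
mm = np.memmap(dat_path, dtype=np.float32, mode="r", shape=(n_alloc, *tile_shape))
np.save(npy_path, np.asarray(mm[:written]))
del mm

tiles = np.load(npy_path)
print("Saved tiles:", tiles.shape)
```

Tracking `written` separately matters because filtering can leave the buffer partly empty; the final array is trimmed to the rows that hold real data.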

In [None]:
def ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)

def avg_pool2d(arr: np.ndarray, k: int) -> np.ndarray:
    """
    Block average pooling for a 2D array (H, W) with kernel size k.
    Requires H and W to be divisible by k.
    """
    h, w = arr.shape
    if (h % k != 0) or (w % k != 0):
        raise ValueError(f"Input shape {arr.shape} must be divisible by k={k}.")
    return arr.reshape(h // k, k, w // k, k).mean(axis=(1, 3))

def open_hsds_file(cfg: dict):
    """
    Open remote HSDS file using an API key from an env var (do NOT hardcode keys).
    """
    api_key = os.environ.get(cfg["api_key_env"], "").strip()
    if not api_key:
        raise RuntimeError(
            f"Missing API key. Please set environment variable {cfg['api_key_env']}."
        )

    f = h5pyd.File(
        cfg["h5_path"],
        "r",
        bucket=cfg["bucket"],
        endpoint=cfg["endpoint"],
        api_key=api_key,
    )
    return f

class MemmapWriter:
    """
    Stream samples to disk using numpy memmap to avoid holding everything in RAM.
    """
    def __init__(self, path: str, shape: tuple, dtype=np.float32):
        self.path = path
        self.shape = shape
        self.dtype = dtype
        self.mm = np.memmap(path, dtype=dtype, mode="w+", shape=shape)

    def write(self, idx: int, arr: np.ndarray):
        self.mm[idx] = arr

    def flush(self):
        self.mm.flush()

    def close(self):
        self.flush()
        del self.mm  # ensure file is closed

def finalize_npy_from_memmap(dat_path: str, out_npy_path: str, final_shape: tuple, dtype=np.float32):
    """
    Convert a raw memmap .dat file into a standard .npy file with exact final shape.
    Note: this loads `final_shape` into RAM once. If too big, prefer zarr/npz chunks.
    """
    mm = np.memmap(dat_path, dtype=dtype, mode="r", shape=final_shape)
    arr = np.asarray(mm)
    np.save(out_npy_path, arr)
    del mm


## 4) Access HSDS and Inspect the Dataset

We connect to the NREL HSDS endpoint and verify the dataset shape and dtype.
Expected shape is `(time, y, x)`.

In [None]:
f = open_hsds_file(CONFIG)
dset_speed = f[CONFIG["speed_dataset"]]

print("Dataset:", CONFIG["speed_dataset"])
print("Shape  :", dset_speed.shape)   # (time, y, x)
print("Dtype  :", dset_speed.dtype)

## 5) Tile Extraction + Downsampling

For each timestep:
1. crop a fixed spatial region (`row_slice`, `col_slice`)
2. split into HR tiles (default `100×100`)
3. keep a random subset of tiles (`tile_keep_ratio`)
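4. filter tiles: drop any tile that contains NaN, has a minimum at or below `min_value_gt`, or has a mean below `min_mean`
5. generate LR tiles via **block average pooling**
6. collect the HR/LR tiles in memory and return them as arrays (the memmap helpers from section 3 are not used here)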
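
Steps 2 to 5 can be tried without an HSDS connection. The sketch below runs them on a small synthetic field; the field, the tile size of 4, the pooling factor of 2, and the keep ratio of 0.5 are illustrative assumptions rather than the notebook's settings.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for one cropped timestep; values are strictly positive.
field = rng.uniform(1.0, 10.0, size=(8, 12)).astype(np.float32)

tile, k, keep_ratio = 4, 2, 0.5            # illustrative values

# Step 2: enumerate non-overlapping HR tiles.
n_h, n_w = field.shape[0] // tile, field.shape[1] // tile
coords = [(i, j) for i in range(n_h) for j in range(n_w)]

# Step 3: keep a random subset of tiles.
keep_n = max(1, int(len(coords) * keep_ratio))
chosen = rng.choice(len(coords), size=keep_n, replace=False)

hr_list, lr_list = [], []
for idx in chosen:
    i, j = coords[int(idx)]
    hr = field[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile]
    # Step 4: skip tiles with NaNs or non-positive values.
    if np.isnan(hr).any() or hr.min() <= 0.0:
        continue
    # Step 5: block average pooling, each k x k block becomes one LR pixel.
    lr = hr.reshape(tile // k, k, tile // k, k).mean(axis=(1, 3))
    hr_list.append(hr)
    lr_list.append(lr)

hr_tiles = np.stack(hr_list)
lr_tiles = np.stack(lr_list)
print("HR:", hr_tiles.shape, "LR:", lr_tiles.shape)
```

Because every block has the same size, block averaging preserves each tile's mean, which gives a quick sanity check on the HR/LR pairing.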

In [None]:
def build_all_tiles(
    dset_speed,
    cfg: dict,
    rng: np.random.Generator,
):
    tile_h, tile_w = cfg["tile_hr"]
    k = cfg["down_factor"]
    tile_lr_h, tile_lr_w = tile_h // k, tile_w // k

    r0, r1 = cfg["row_slice"]
    c0, c1 = cfg["col_slice"]

    total_timesteps = dset_speed.shape[0]
    valid_timesteps = np.arange(0, total_timesteps, cfg["sampling_interval"], dtype=np.int64)

    HR_tiles = []
    LR_tiles = []

    for ts in valid_timesteps:
        speed = dset_speed[int(ts), r0:r1, c0:c1]
        speed = np.asarray(speed, dtype=np.float32)

        n_h = speed.shape[0] // tile_h
        n_w = speed.shape[1] // tile_w
        if n_h <= 0 or n_w <= 0:
            continue

        tile_indices = [(i, j) for i in range(n_h) for j in range(n_w)]
        keep_n = max(1, int(len(tile_indices) * cfg["tile_keep_ratio"]))
        chosen = rng.choice(len(tile_indices), size=keep_n, replace=False)

        for idx in chosen:
            i, j = tile_indices[int(idx)]
            tile_hr = speed[i*tile_h:(i+1)*tile_h, j*tile_w:(j+1)*tile_w]

            # Filters: skip tiles with NaNs, a minimum at or below min_value_gt,
            # or a mean below min_mean.
            if cfg["drop_if_nan"] and np.isnan(tile_hr).any():
                continue
            if np.min(tile_hr) <= cfg["min_value_gt"]:
                continue
            if np.mean(tile_hr) < cfg["min_mean"]:
                continue

            tile_lr = avg_pool2d(tile_hr, k).astype(np.float32)

            HR_tiles.append(tile_hr.astype(np.float32))
            LR_tiles.append(tile_lr)

            if len(HR_tiles) >= cfg["max_samples"]:
                print("Reached max_samples.")
                return np.array(HR_tiles), np.array(LR_tiles)

        print(f"Timestep {int(ts)} processed. Total tiles: {len(HR_tiles)}")

    return np.array(HR_tiles), np.array(LR_tiles)

## 6) Generate Dataset (random tile-level split)

All HR/LR tiles are first collected across timesteps. The pooled tiles are then shuffled and split at the tile level into train and validation sets according to `train_ratio` and `val_ratio`; the remainder becomes the test set.
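
A random tile-level split can place tiles from the same timestep in both the training and test sets. If that temporal leakage is a concern, one alternative is to split by source timestep in time order. The sketch below is an assumption-laden illustration, not the notebook's method: `split_by_timestep` is a hypothetical helper, and it needs a per-tile timestep array (`tile_ts`) that `build_all_tiles` does not currently return.

```python
import numpy as np

def split_by_timestep(tile_ts, train_ratio=0.70, val_ratio=0.15):
    """Hypothetical helper: assign tiles to splits by source timestep, in time order."""
    tile_ts = np.asarray(tile_ts)
    unique_ts = np.unique(tile_ts)                 # sorted ascending
    n = len(unique_ts)
    train_end = int(train_ratio * n)
    val_end = train_end + int(val_ratio * n)
    train_ts = unique_ts[:train_end]               # earliest timesteps
    val_ts = unique_ts[train_end:val_end]
    test_ts = unique_ts[val_end:]                  # latest timesteps
    return (
        np.flatnonzero(np.isin(tile_ts, train_ts)),
        np.flatnonzero(np.isin(tile_ts, val_ts)),
        np.flatnonzero(np.isin(tile_ts, test_ts)),
    )

# Synthetic example: 20 timesteps spaced 48 apart, 3 tiles from each.
tile_ts = np.repeat(np.arange(0, 20 * 48, 48), 3)
train_idx, val_idx, test_idx = split_by_timestep(tile_ts)

print("train/val/test tiles:", len(train_idx), len(val_idx), len(test_idx))
print("latest train timestep:", tile_ts[train_idx].max())
print("earliest test timestep:", tile_ts[test_idx].min())
```

The returned index arrays can be used the same way as `train_idx`, `val_idx`, and `test_idx` in the cell below, so every tile from a given timestep lands in exactly one split.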

In [None]:
rng = np.random.default_rng(CONFIG["seed"])

ensure_dir(CONFIG["out_dir"])

# Build all tiles
HR_all, LR_all = build_all_tiles(dset_speed, CONFIG, rng)

print("Total tiles collected:", HR_all.shape[0])

# Random tile-level split (original implementation)
N = HR_all.shape[0]
indices = rng.permutation(N)

train_end = int(CONFIG["train_ratio"] * N)
val_end = train_end + int(CONFIG["val_ratio"] * N)

train_idx = indices[:train_end]
val_idx = indices[train_end:val_end]
test_idx = indices[val_end:]

HR_train, HR_val, HR_test = HR_all[train_idx], HR_all[val_idx], HR_all[test_idx]
LR_train, LR_val, LR_test = LR_all[train_idx], LR_all[val_idx], LR_all[test_idx]

# Save
np.save(os.path.join(CONFIG["out_dir"], "train_HR.npy"), HR_train)
np.save(os.path.join(CONFIG["out_dir"], "train_LR.npy"), LR_train)

np.save(os.path.join(CONFIG["out_dir"], "val_HR.npy"), HR_val)
np.save(os.path.join(CONFIG["out_dir"], "val_LR.npy"), LR_val)

np.save(os.path.join(CONFIG["out_dir"], "test_HR.npy"), HR_test)
np.save(os.path.join(CONFIG["out_dir"], "test_LR.npy"), LR_test)

print("Dataset saved using random tile-level split.")

## 7) Visualize a Few HR/LR Pairs (optional)

Loads `train_HR.npy` and `train_LR.npy` and plots a few random examples.

In [None]:
import matplotlib.pyplot as plt

train_hr = np.load(os.path.join(CONFIG["out_dir"], "train_HR.npy"), mmap_mode="r")
train_lr = np.load(os.path.join(CONFIG["out_dir"], "train_LR.npy"), mmap_mode="r")

n = train_hr.shape[0]
rng_vis = np.random.default_rng(CONFIG["seed"] + 123)
idxs = rng_vis.choice(n, size=min(5, n), replace=False)

fig, axes = plt.subplots(len(idxs), 2, figsize=(8, 3*len(idxs)))
if len(idxs) == 1:
    axes = np.array([axes])

for r, idx in enumerate(idxs):
    axes[r, 0].imshow(train_hr[idx], cmap="viridis")
    axes[r, 0].set_title(f"HR ({train_hr.shape[1]}x{train_hr.shape[2]}) idx={idx}")
    axes[r, 0].axis("off")

    axes[r, 1].imshow(train_lr[idx], cmap="viridis")
    axes[r, 1].set_title(f"LR ({train_lr.shape[1]}x{train_lr.shape[2]}) idx={idx}")
    axes[r, 1].axis("off")

plt.tight_layout()
plt.show()

print("Train HR shape:", train_hr.shape)
print("Train LR shape:", train_lr.shape)


## 8) Cleanup

Close the HSDS file handle when you are done.

In [None]:
f.close()
print("Closed HSDS file.")