In [1]:
import os
import numpy as np
import pandas as pd
import rasterio
import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor

# =====================
# CONFIG
# =====================
SEQ_LEN = 6                 # number of past timesteps (consecutive CSV rows) used as model input
HORIZONS = 3                # number of future timesteps (following CSV rows) to predict
PATCH_SIZE = 13             # side length, in pixels, of the square patch cut from each raster
HALF = PATCH_SIZE // 2      # patch half-width (6 for a 13-pixel patch)
FILL_NAN_VALUE = 0.0        # value substituted for NaN / masked pixels when a raster is loaded

# CSV columns that hold raster file paths. make_generator() refuses a CSV that
# lacks any of them, and the driver cell preloads every path found in them.
REQUIRED_COLS = [
    "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
    "era5_u10_file", "era5_v10_file",
    "viirs_file", "dem_file", "lulc_file"
]

In [3]:
def _load_single_raster(path):
    """Helper to load one raster efficiently (single band)."""
    try:
        with rasterio.open(path, sharing=True) as src:
            arr = src.read(1, out_dtype="float32", masked=True)  # read first band
            if np.ma.is_masked(arr):
                arr = arr.filled(np.nan)
            arr = np.nan_to_num(arr, nan=FILL_NAN_VALUE).astype("float32", copy=False)
            return arr
    except Exception as e:
        print(f"⚠️ Error reading {path}: {e}")
        return None

In [4]:
def load_rasters(df, raster_cols, max_workers=8):
    """Read every distinct raster referenced by ``df`` into an in-memory dict.

    Paths are gathered from those ``raster_cols`` that exist in ``df``
    (missing values are dropped), de-duplicated, and loaded concurrently
    through ``_load_single_raster`` on a thread pool.

    Args:
        df: DataFrame whose cells in ``raster_cols`` are raster paths.
        raster_cols: Column names to scan; names absent from ``df`` are ignored.
        max_workers: Size of the thread pool.

    Returns:
        ``{path: array}`` for every path that loaded; paths for which the
        loader returned ``None`` are left out.
    """
    path_sets = (
        set(df[name].dropna().unique())
        for name in raster_cols
        if name in df.columns
    )
    ordered_paths = list(set().union(*path_sets))

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        loaded = list(pool.map(_load_single_raster, ordered_paths))

    # Keep only the successful reads.
    return {
        path: arr
        for path, arr in zip(ordered_paths, loaded)
        if arr is not None
    }


In [None]:
def compute_norm_stats(arrays):
    """Return the pooled ``(mean, std)`` of all non-NaN values in ``arrays``.

    Every array contributes its non-NaN entries to one flat pool; the mean
    and (population) standard deviation of that pool are returned.
    """
    kept = []
    for values in arrays:
        not_nan = np.logical_not(np.isnan(values))
        kept.append(values[not_nan].ravel())

    pooled = np.concatenate(kept)
    return pooled.mean(), pooled.std()

def normalize(arr, mean, std):
    """Standardise ``arr`` as ``(arr - mean) / std``.

    A standard deviation of exactly zero (e.g. statistics taken from a
    constant raster) would turn the output into NaN/inf. Wherever ``std`` is
    zero a divisor of 1 is used instead, so those values come out as
    ``arr - mean``.

    Args:
        arr: Array to normalise.
        mean: Scalar or array broadcastable against ``arr``.
        std: Scalar or array broadcastable against ``arr``.

    Returns:
        The normalised array.
    """
    if np.ndim(std) == 0:
        divisor = std if std != 0 else 1
    else:
        std = np.asarray(std)
        divisor = np.where(std == 0, 1, std)
    return (arr - mean) / divisor

def encode_lulc(lulc_arr, known_classes=None):
    """One-hot encode a land-use / land-cover raster with NumPy only.

    Args:
        lulc_arr: Array of class labels; values are cast to ``int``.
        known_classes: Optional list/array of class labels fixing the channel
            order. When ``None`` or empty, the sorted unique labels found in
            ``lulc_arr`` are used.

    Returns:
        ``(onehot_img, classes)`` where ``onehot_img`` is a float64 array of
        shape ``lulc_arr.shape + (n_classes,)`` and ``classes`` is the 1-D
        integer array giving the label of each channel.

    Raises:
        ValueError: if ``known_classes`` is given and ``lulc_arr`` contains a
            label that is not listed in it.
    """
    labels = np.asarray(lulc_arr).astype(int)

    # known_classes may be a NumPy array (for instance the `classes` returned
    # by a previous call), so test its length rather than its truthiness.
    if known_classes is None or len(known_classes) == 0:
        classes = np.unique(labels)
    else:
        classes = np.asarray(known_classes).astype(int)
        unknown = np.setdiff1d(labels, classes)
        if unknown.size:
            raise ValueError(
                f"Found unknown LULC classes {unknown.tolist()} "
                f"not listed in known_classes"
            )

    # Broadcast-compare every pixel against the class list: (..., 1) == (C,)
    onehot_img = (labels[..., np.newaxis] == classes).astype("float64")
    return onehot_img, classes


In [5]:
def _safe_center(h, w, patch_size=PATCH_SIZE):
    """Return the raster centre ``(row, col)``, clamped so a patch fits.

    The centre ``(h // 2, w // 2)`` is clipped to ``[half, size - half - 1]``
    on each axis so that the window ``[centre - half, centre + half]`` lies
    inside an ``h`` x ``w`` raster.

    Args:
        h: Raster height in pixels.
        w: Raster width in pixels.
        patch_size: Side length of the patch to be cut around the centre.

    Returns:
        ``(row, col)`` of the clamped centre.

    Raises:
        ValueError: if the raster is too small to contain the window. In that
            case the clip bounds cross and no valid centre exists.
    """
    half = patch_size // 2
    min_side = 2 * half + 1  # pixels spanned by [centre - half, centre + half]
    if h < min_side or w < min_side:
        raise ValueError(
            f"raster of shape ({h}, {w}) is smaller than the "
            f"{min_side}x{min_side} window needed for patch_size={patch_size}"
        )
    r = np.clip(h // 2, half, h - half - 1)
    c = np.clip(w // 2, half, w - half - 1)
    return r, c

In [7]:
def _extract_patch(arr, row, col, patch_size=PATCH_SIZE):
    """Cut a ``patch_size`` x ``patch_size`` window around ``(row, col)``.

    The window starts ``patch_size // 2`` pixels above/left of the centre and
    spans exactly ``patch_size`` pixels on each of the first two axes. For odd
    sizes this is the symmetric window ``[centre - half, centre + half]``.

    Args:
        arr: Array with at least two dimensions (rows, cols).
        row: Centre row index.
        col: Centre column index.
        patch_size: Side length of the window.

    Returns:
        The slice ``arr[top:top + patch_size, left:left + patch_size]``.

    Raises:
        ValueError: if the window extends beyond the array. A negative start
            index would otherwise wrap around and yield an empty or
            undersized slice without any error.
    """
    half = patch_size // 2
    top, left = row - half, col - half
    bottom, right = top + patch_size, left + patch_size
    if top < 0 or left < 0 or bottom > arr.shape[0] or right > arr.shape[1]:
        raise ValueError(
            f"patch of size {patch_size} centred at ({row}, {col}) does not fit "
            f"inside an array of shape {arr.shape}"
        )
    return arr[top:bottom, left:right]

In [8]:
def _center_patch(cache, path):
    """Return the centre patch of the cached raster stored under ``path``."""
    try:
        arr = cache[path]
    except KeyError:
        raise KeyError(
            f"raster {path!r} is not in the cache (missing path or failed load)"
        ) from None
    # Each raster is centred using its own shape, so rasters on different
    # grids are never indexed with another raster's coordinates.
    r, c = _safe_center(*arr.shape)
    return _extract_patch(arr, r, c)


def build_sample(seq_rows, horizon_rows, cache):
    """
    Build one training sample from preloaded rasters.

    Args:
      seq_rows: past SEQ_LEN rows from CSV
      horizon_rows: next HORIZONS rows from CSV
      cache: mapping of raster path -> 2-D array (see load_rasters)

    Returns:
      X: (SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7) float32. Channels are, in order:
         t2m, d2m, tp, u10, v10, DEM, LULC.
      y: (HORIZONS, PATCH_SIZE, PATCH_SIZE) float32 VIIRS patches.

    Raises:
      KeyError: if a row references a raster that is not in ``cache``.
    """
    dynamic_cols = (
        "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
        "era5_u10_file", "era5_v10_file",
    )
    static_cols = ("dem_file", "lulc_file")
    band_cols = dynamic_cols + static_cols

    # --- sequence input ---
    seq_patches = []
    for _, row in seq_rows.iterrows():
        bands = [_center_patch(cache, row[col]) for col in band_cols]
        seq_patches.append(np.stack(bands, axis=-1))  # (P, P, 7)
    X = np.stack(seq_patches, axis=0)  # (SEQ_LEN, P, P, 7)

    # --- horizon target ---
    horizon_patches = [
        _center_patch(cache, row["viirs_file"])
        for _, row in horizon_rows.iterrows()
    ]
    y = np.stack(horizon_patches, axis=0)  # (HORIZONS, P, P)

    return X.astype("float32"), y.astype("float32")

In [9]:
def make_generator(csv_path, cache):
    """Yield ``(X, y)`` samples for every sliding window over the index CSV.

    The CSV at ``csv_path`` is read once per call. Each window consists of
    SEQ_LEN consecutive rows (inputs) followed by HORIZONS rows (targets);
    windows advance one row at a time and are passed to ``build_sample``.

    Raises:
        ValueError: if the CSV lacks any column listed in REQUIRED_COLS.
    """
    table = pd.read_csv(csv_path)

    absent = [name for name in REQUIRED_COLS if name not in table.columns]
    if absent:
        raise ValueError(f"CSV missing required columns: {absent}")

    window = SEQ_LEN + HORIZONS
    start = 0
    while start + window <= len(table):
        split = start + SEQ_LEN
        past = table.iloc[start:split]
        future = table.iloc[split:start + window]
        yield build_sample(past, future, cache)
        start += 1

In [10]:
def create_dataset(csv_path, cache, batch_size=4, shuffle=True, shuffle_buf=256):
    """Wrap ``make_generator`` in a batched, prefetching ``tf.data.Dataset``.

    Args:
        csv_path: Index CSV handed to ``make_generator``.
        cache: Preloaded raster mapping handed to ``make_generator``.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle samples (reshuffled on every iteration).
        shuffle_buf: Size of the shuffle buffer.

    Returns:
        Dataset yielding ``(X, y)`` float32 batches with per-sample shapes
        ``(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7)`` and
        ``(HORIZONS, PATCH_SIZE, PATCH_SIZE)``.
    """
    def sample_stream():
        # Called by TensorFlow each time the dataset is (re-)iterated.
        return make_generator(csv_path, cache)

    x_spec = tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32)
    y_spec = tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32)

    dataset = tf.data.Dataset.from_generator(
        sample_stream,
        output_signature=(x_spec, y_spec),
    )
    if shuffle:
        dataset = dataset.shuffle(shuffle_buf, reshuffle_each_iteration=True)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [11]:
if __name__ == "__main__":
    # Smoke test: preload all rasters, build the dataset and pull one batch.
    # The index CSV location can be overridden with the FOREST_FIRE_INDEX_CSV
    # environment variable so the notebook also runs on other machines; the
    # original local path remains the default.
    csv_path = os.environ.get(
        "FOREST_FIRE_INDEX_CSV",
        r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_right.csv",
    )
    df = pd.read_csv(csv_path)

    raster_cols = REQUIRED_COLS
    print("Loading rasters into memory...")
    cache = load_rasters(df, raster_cols, max_workers=8)
    print(f"Loaded {len(cache)} rasters into memory ✅")

    ds = create_dataset(csv_path, cache, batch_size=2)
    for X, y in ds.take(1):
        print("X shape:", X.shape)  # (B, SEQ_LEN, PATCH, PATCH, 7)
        print("y shape:", y.shape)  # (B, HORIZONS, PATCH, PATCH)

Loading rasters into memory...
Loaded 9 rasters into memory ✅
X shape: (2, 6, 13, 13, 7)
y shape: (2, 3, 13, 13)


In [13]:
# Sanity check: pull a single batch from `ds` (built in the driver cell) and
# print the Python types of the feature and target objects it yields.
for X, y in ds.take(1):
    print(type(X), type(y))

<class 'tensorflow.python.framework.ops.EagerTensor'> <class 'tensorflow.python.framework.ops.EagerTensor'>
