In [None]:
# tensorflow_dataset.py
import ast
import rasterio
import rasterio.windows
import numpy as np
import pandas as pd
import tensorflow as tf

# -----------------------------
# Config
# -----------------------------
# Side length, in pixels, of the square patch cut around each (row, col) center.
PATCH_SIZE = 33
# Leading dimension of X in the dataset's output signature (input timesteps).
SEQ_LEN = 6
# Leading dimension of y in the dataset's output signature (target bands).
HORIZONS = 3
# Number of samples per batch produced by the tf.data pipeline.
BATCH_SIZE = 8

# Static rasters; band 1 of each is read fully into memory further below.
DEM_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\merged_DEM_30m_32644_aligned_filled.tif"
LULC_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\lulc_maps_tif\LULC_2015_clipped_30m_filled_categorical.tif"

# CSV index read with pandas; each row describes one sample (band indices,
# raster file paths and the patch center).
CSV_INDEX = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly.csv"

# -----------------------------
# Load static rasters (DEM, LULC) into memory
# -----------------------------
# Read band 1 of each static raster once, so per-sample patches can later be
# cut with plain NumPy slicing instead of repeated file reads.
with rasterio.open(DEM_PATH) as src:
    DEM = src.read(1)  # 2D numpy array (H, W)
with rasterio.open(LULC_PATH) as src:
    LULC = src.read(1)  # 2D numpy array (H, W)

# Basic checks
# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# which would let differently shaped rasters through unnoticed.
if DEM.shape != LULC.shape:
    raise ValueError(
        "DEM and LULC must have same shape/alignment. "
        f"Got DEM {DEM.shape} vs LULC {LULC.shape}."
    )

# Raster extent used for the patch boundary check in extract_sample.
H_RASTER, W_RASTER = DEM.shape
# Half-width of the default patch (kept as a module-level convenience).
HALF = PATCH_SIZE // 2

# -----------------------------
# Helper: parse index strings like "[1, 2, 3]" to list[int]
# -----------------------------
def parse_list_column(s):
    """
    Convert a band-index cell into a list of ints.

    Accepted inputs:
      * a list            -> returned unchanged (same object, no conversion)
      * any other non-string iterable (tuple, NumPy array, ...)
                          -> every element converted with int()
      * a string such as "[1, 2, 3]", "(1, 2, 3)", "1,2,3" or the
        NumPy-style "[1 2 3]"; empty or bracket-only strings give []

    Raises:
      TypeError  if s is neither a string nor iterable (e.g. the float NaN
                 pandas produces for an empty CSV cell)
      ValueError if a token in the string is not an integer literal
    """
    # handle already-list objects: nothing to parse
    if isinstance(s, list):
        return s

    # other in-memory sequences (tuple, ndarray, ...) only need int conversion
    if not isinstance(s, str):
        try:
            return [int(v) for v in s]
        except TypeError as err:
            raise TypeError(
                f"cannot parse band indices from {type(s).__name__}: {s!r}"
            ) from err

    try:
        return list(map(int, ast.literal_eval(s)))
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal sequence (e.g. "[1 2 3]", "3", or ""):
        # fall back to splitting on commas and/or whitespace. Empty tokens
        # are dropped, so "" and "[]" both give [].
        tokens = s.strip("[]() \t\r\n").replace(",", " ").split()
        return [int(tok) for tok in tokens]

# -----------------------------
# Core: single-sample extractor (returns numpy arrays)
# -----------------------------
def _read_band_patches(path, band_idxs, window, patch_size):
    """
    Read several bands of one raster inside `window` as float32.

    path:       raster file to open with rasterio
    band_idxs:  list of band indices (rasterio band indexing is 1-based)
    window:     rasterio.windows.Window to read
    patch_size: expected height and width of the window that comes back
    returns:    np.float32 array (len(band_idxs), patch_size, patch_size)
    raises:     IndexError if the returned patch has a different size, e.g.
                because the raster does not fully cover the window
    """
    with rasterio.open(path) as src:
        # One multi-band read per file instead of one open + read per band.
        patches = src.read(band_idxs, window=window)  # (N, H, W)
    if patches.shape[1:] != (patch_size, patch_size):
        raise IndexError(
            f"{path}: read patch of shape {patches.shape[1:]}, "
            f"expected {(patch_size, patch_size)}"
        )
    return patches.astype(np.float32)


def extract_sample(row, patch_size=PATCH_SIZE):
    """
    Build one training sample (inputs and targets) for a CSV row.

    row: a pandas Series (one CSV row)
    patch_size: side length in pixels of the square patch; the patch covers
                rows [r - patch_size // 2, r - patch_size // 2 + patch_size)
                and the matching column range, so even sizes work too
    returns: X -> np.float32 (SEQ_LEN, H, W, 7)
                  channels: t2m, d2m, tp, u10, v10, DEM, LULC
             y -> np.float32 (HORIZONS, H, W)
    raises:  IndexError if the patch does not fit inside the DEM/LULC grid
                        or a raster returns a smaller window
             ValueError if either band-index list is empty
    Assumes CSV columns:
      seq_band_idxs, target_band_idxs,
      era5_t2m_file, era5_d2m_file, era5_tp_file, era5_u10_file, era5_v10_file,
      viirs_file, row, col
    NOTE(review): the same pixel window is used for the ERA5, VIIRS, DEM and
    LULC rasters, i.e. all are assumed to share one grid -- this function
    does not verify that.
    """
    r = int(row["row"])
    c = int(row["col"])
    half = patch_size // 2

    # Patch extent, end-exclusive. Deriving the NumPy slices and the rasterio
    # window from the same numbers keeps them the same size for any
    # patch_size (the old `r-half : r+half+1` slice was one pixel too large
    # for even sizes).
    r0, c0 = r - half, c - half
    r1, c1 = r0 + patch_size, c0 + patch_size

    # boundary check
    if r0 < 0 or c0 < 0 or r1 > H_RASTER or c1 > W_RASTER:
        raise IndexError(f"Center {(r,c)} too close to edge for patch size {patch_size}")

    # parse indices (plain ints for rasterio)
    seq_idxs = [int(b) for b in parse_list_column(row["seq_band_idxs"])]
    target_idxs = [int(b) for b in parse_list_column(row["target_band_idxs"])]
    if not seq_idxs or not target_idxs:
        raise ValueError(
            "seq_band_idxs and target_band_idxs must each list at least one band"
        )

    # The window is identical for every file and band, so build it once.
    win = rasterio.windows.Window(c0, r0, patch_size, patch_size)

    # ERA5 variable filepaths; this order defines the first five channels of X
    era5_paths = [
        row["era5_t2m_file"],
        row["era5_d2m_file"],
        row["era5_tp_file"],
        row["era5_u10_file"],
        row["era5_v10_file"],
    ]
    # Each file contributes (SEQ_LEN, H, W); stacking on a new last axis gives
    # (SEQ_LEN, H, W, 5).
    era5_seq = np.stack(
        [_read_band_patches(p, seq_idxs, win, patch_size) for p in era5_paths],
        axis=-1,
    )

    # Static patches (DEM + LULC) -> (H, W, 2)
    dem_patch = DEM[r0:r1, c0:c1].astype(np.float32)
    lulc_patch = LULC[r0:r1, c0:c1].astype(np.float32)
    static = np.stack([dem_patch, lulc_patch], axis=-1)

    # Repeat across time as a read-only broadcast view; np.concatenate below
    # copies the values into X.
    static_seq = np.broadcast_to(static, (len(seq_idxs),) + static.shape)

    # concatenate static channels to ERA5 per-timestep channels
    X = np.concatenate([era5_seq, static_seq], axis=-1)  # (SEQ_LEN, H, W, 7)

    # Targets: VIIRS bands for the horizons, same window
    y = _read_band_patches(row["viirs_file"], target_idxs, win, patch_size)  # (HORIZONS, H, W)

    return X, y

# -----------------------------
# Generator for tf.data
# -----------------------------
def generator_from_df(df):
    """
    Yields X, y numpy arrays for each row in df.

    df: DataFrame whose rows are passed one by one to extract_sample.
    Rows for which extract_sample raises are reported on stdout and skipped,
    so the generator may yield fewer items than df has rows (possibly none).
    If you want random sampling of centers per row, modify this function.
    """
    for _, row in df.iterrows():
        # Only the extraction is guarded. The yield stays outside the try so
        # that an exception raised into the generator at the yield point is
        # not misreported as a bad row and swallowed.
        try:
            X, y = extract_sample(row)
        except Exception as e:
            # Deliberate best-effort: print and skip invalid rows (e.g., patch
            # near edge). Note this also skips rows failing for systematic
            # reasons such as a missing raster file -- watch the output.
            print("skipping row due to:", e)
            continue
        # Optionally: normalization steps here (mean/std)
        yield X, y

# -----------------------------
# Build tf.data.Dataset
# -----------------------------
df_index = pd.read_csv(CSV_INDEX)

# Fail fast on a mismatching CSV schema instead of skipping every row later
# inside the generator. If your CSV uses different column names adapt here.
required = ["seq_band_idxs", "target_band_idxs",
            "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
            "era5_u10_file", "era5_v10_file",
            "viirs_file", "row", "col"]
missing = [c for c in required if c not in df_index.columns]
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

# Element spec of the generator output. SEQ_LEN and HORIZONS are fixed, so
# concrete shapes are used; for variable-length sequences replace the leading
# dimension of each spec with None.
# The 7 channels of X are the 5 ERA5 variables plus DEM and LULC.
output_signature = (
    tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32),  # X: (SEQ_LEN, H, W, C)
    tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),    # y: (HORIZONS, H, W)
)

# The lambda gives from_generator a zero-argument callable, so a fresh
# generator over df_index is created each time the dataset is iterated.
ds = tf.data.Dataset.from_generator(
    lambda: generator_from_df(df_index),
    output_signature=output_signature
)

ds = ds.shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# -----------------------------
# Quick test: iterate once
# -----------------------------
# Pull one batch through the pipeline as a smoke test and print its shapes.
# The loop body does not run at all if the dataset yields no batches
# (e.g. every CSV row was skipped by the generator).
for X_batch, y_batch in ds.take(1):
    print("X batch shape:", X_batch.shape)  # (B, SEQ_LEN, H, W, C)
    print("y batch shape:", y_batch.shape)  # (B, HORIZONS, H, W)
