In [None]:
# tensorflow_dataset_dual_lulc_no_coords.py
import ast
import rasterio
import rasterio.windows
import numpy as np
import pandas as pd
import tensorflow as tf

# -----------------------------
# Config
# -----------------------------
PATCH_SIZE = 33
SEQ_LEN = 6
HORIZONS = 3
BATCH_SIZE = 8

DEM_PATH   = r"C:\Users\Ankit\Datasets_Forest_fire\merged_DEM_30m_32644_aligned_filled.tif"
LULC_2015_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\lulc_maps_tif\LULC_2015_clipped_30m_filled_categorical.tif"
LULC_2016_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\lulc_maps_tif\LULC_2016_clipped_30m_filled_categorical.tif"

CSV_INDEX = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly.csv"

# -----------------------------
# Load static rasters (DEM, LULC) into memory
# -----------------------------
with rasterio.open(DEM_PATH) as src:
    DEM = src.read(1)
    DEM_TRANSFORM = src.transform
with rasterio.open(LULC_2015_PATH) as src:
    LULC_2015 = src.read(1)
with rasterio.open(LULC_2016_PATH) as src:
    LULC_2016 = src.read(1)

assert DEM.shape == LULC_2015.shape == LULC_2016.shape, "DEM and LULC maps must align."

H_RASTER, W_RASTER = DEM.shape
HALF = PATCH_SIZE // 2

# -----------------------------
# Helpers
# -----------------------------
def parse_list_column(s):
    if isinstance(s, list):
        return s
    try:
        return list(map(int, ast.literal_eval(s)))
    except Exception:
        s2 = s.strip("[] ")
        if s2 == "":
            return []
        return [int(x) for x in s2.split(",")]

# -----------------------------
# Core extraction
# -----------------------------
def extract_sample(row, patch_size=PATCH_SIZE):
    # 🔹 Pick patch center = raster center
    r = H_RASTER // 2
    c = W_RASTER // 2
    half = patch_size // 2

    # --- parse indices ---
    seq_idxs = parse_list_column(row["seq_band_idxs"])
    target_idxs = parse_list_column(row["target_band_idxs"])

    # --- ERA5 variable filepaths ---
    era5_paths = [
        row["era5_t2m_file"],
        row["era5_d2m_file"],
        row["era5_tp_file"],
        row["era5_u10_file"],
        row["era5_v10_file"],
    ]
    era5_seq = []
    for band_idx in seq_idxs:
        var_chs = []
        for var_path in era5_paths:
            with rasterio.open(var_path) as src:
                win = rasterio.windows.Window(c-half, r-half, patch_size, patch_size)
                patch = src.read(band_idx, window=win)
                var_chs.append(patch)
        timestep_patch = np.stack(var_chs, axis=-1)
        era5_seq.append(timestep_patch.astype(np.float32))
    era5_seq = np.stack(era5_seq, axis=0)  # (SEQ_LEN,H,W,5)

    # --- static patches ---
    dem_patch   = DEM[r-half:r+half+1, c-half:c+half+1].astype(np.float32)
    lulc2015_patch = LULC_2015[r-half:r+half+1, c-half:c+half+1].astype(np.float32)
    lulc2016_patch = LULC_2016[r-half:r+half+1, c-half:c+half+1].astype(np.float32)

    dem_seq   = np.repeat(dem_patch[None, :, :], len(seq_idxs), axis=0)
    lulc2015_seq = np.repeat(lulc2015_patch[None, :, :], len(seq_idxs), axis=0)
    lulc2016_seq = np.repeat(lulc2016_patch[None, :, :], len(seq_idxs), axis=0)

    dem_seq   = dem_seq[..., None]
    lulc2015_seq = lulc2015_seq[..., None]
    lulc2016_seq = lulc2016_seq[..., None]

    X = np.concatenate([era5_seq, dem_seq, lulc2015_seq, lulc2016_seq], axis=-1)  # (SEQ_LEN,H,W,8)

    # --- target VIIRS ---
    y_list = []
    with rasterio.open(row["viirs_file"]) as src:
        for b in target_idxs:
            win = rasterio.windows.Window(c-half, r-half, patch_size, patch_size)
            patch = src.read(b, window=win)
            y_list.append(patch.astype(np.float32))
    y = np.stack(y_list, axis=0)  # (HORIZONS,H,W)

    return X, y

# -----------------------------
# Build tf.data.Dataset
# -----------------------------
df_index = pd.read_csv(CSV_INDEX)

required = ["seq_band_idxs", "target_band_idxs",
            "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
            "era5_u10_file", "era5_v10_file",
            "viirs_file"]
missing = [c for c in required if c not in df_index.columns]
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

def generator_from_df(df):
    for _, row in df.iterrows():
        try:
            X, y = extract_sample(row)
            yield X, y
        except Exception as e:
            print("Skipping row:", e)
            continue

# Output signature with 8 channels
output_signature = (
    tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 8), dtype=tf.float32),
    tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),
)

ds = tf.data.Dataset.from_generator(
    lambda: generator_from_df(df_index),
    output_signature=output_signature
)

ds = ds.shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# -----------------------------
# Quick test
# -----------------------------
for X_batch, y_batch in ds.take(1):
    print("X batch shape:", X_batch.shape)  # (B,SEQ_LEN,H,W,8)
    print("y batch shape:", y_batch.shape)  # (B,HORIZONS,H,W)


Skipping row: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 33
Skipping row: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 33
Skipping row: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 33
Skipping row: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 33
Skipping row: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 33
Skipping row: all the input array dimensions 

In [None]:
# tensorflow_dataset_dual_lulc_no_coords.py
import ast
import rasterio
import rasterio.windows
import numpy as np
import pandas as pd
import tensorflow as tf

# -----------------------------
# Config
# -----------------------------
PATCH_SIZE = 33
SEQ_LEN = 6
HORIZONS = 3
BATCH_SIZE = 8

DEM_PATH   = r"C:\Users\Ankit\Datasets_Forest_fire\merged_DEM_30m_32644_aligned_filled.tif"
LULC_2015_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\lulc_maps_tif\LULC_2015_clipped_30m_filled_categorical.tif"
LULC_2016_PATH = r"C:\Users\Ankit\Datasets_Forest_fire\lulc_maps_tif\LULC_2016_clipped_30m_filled_categorical.tif"

CSV_INDEX = r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly.csv"

# -----------------------------
# Load static rasters (DEM, LULC) into memory
# -----------------------------
with rasterio.open(DEM_PATH) as src:
    DEM = src.read(1)
    DEM_TRANSFORM = src.transform
with rasterio.open(LULC_2015_PATH) as src:
    LULC_2015 = src.read(1)
with rasterio.open(LULC_2016_PATH) as src:
    LULC_2016 = src.read(1)

assert DEM.shape == LULC_2015.shape == LULC_2016.shape, "DEM and LULC maps must align."

H_RASTER, W_RASTER = DEM.shape
HALF = PATCH_SIZE // 2

# -----------------------------
# Helpers
# -----------------------------
def parse_list_column(s):
    if isinstance(s, list):
        return s
    try:
        return list(map(int, ast.literal_eval(s)))
    except Exception:
        s2 = s.strip("[] ")
        if s2 == "":
            return []
        return [int(x) for x in s2.split(",")]

# -----------------------------
# Core extraction
# -----------------------------
def extract_sample(row, patch_size=PATCH_SIZE):
    half = patch_size // 2

    # --- Reference: DEM center ---
    r_dem = H_RASTER // 2
    c_dem = W_RASTER // 2
    lon, lat = DEM_TRANSFORM * (c_dem, r_dem)  # pixel -> geo

    # --- parse band idxs ---
    seq_idxs = parse_list_column(row["seq_band_idxs"])
    target_idxs = parse_list_column(row["target_band_idxs"])

    # --- ERA5 variables ---
    era5_paths = [
        row["era5_t2m_file"],
        row["era5_d2m_file"],
        row["era5_tp_file"],
        row["era5_u10_file"],
        row["era5_v10_file"],
    ]
    era5_seq = []
    for band_idx in seq_idxs:
        var_chs = []
        for var_path in era5_paths:
            with rasterio.open(var_path) as src:
                # convert DEM center geo -> this raster's pixel
                c, r = src.index(lon, lat)
                if (r-half < 0 or c-half < 0 or 
                    r+half >= src.height or c+half >= src.width):
                    raise ValueError("Patch goes outside raster bounds.")
                win = rasterio.windows.Window(c-half, r-half, patch_size, patch_size)
                patch = src.read(band_idx, window=win)
                if patch.shape != (patch_size, patch_size):
                    raise ValueError("Empty patch extracted.")
                var_chs.append(patch)
        timestep_patch = np.stack(var_chs, axis=-1)
        era5_seq.append(timestep_patch.astype(np.float32))
    era5_seq = np.stack(era5_seq, axis=0)  # (SEQ_LEN,H,W,5)

    # --- Static DEM + LULC patches (from already loaded arrays) ---
    dem_patch   = DEM[r_dem-half:r_dem+half+1, c_dem-half:c_dem+half+1].astype(np.float32)
    lulc2015_patch = LULC_2015[r_dem-half:r_dem+half+1, c_dem-half:c_dem+half+1].astype(np.float32)
    lulc2016_patch = LULC_2016[r_dem-half:r_dem+half+1, c_dem-half:c_dem+half+1].astype(np.float32)

    dem_seq   = np.repeat(dem_patch[None, :, :], len(seq_idxs), axis=0)[..., None]
    lulc2015_seq = np.repeat(lulc2015_patch[None, :, :], len(seq_idxs), axis=0)[..., None]
    lulc2016_seq = np.repeat(lulc2016_patch[None, :, :], len(seq_idxs), axis=0)[..., None]

    X = np.concatenate([era5_seq, dem_seq, lulc2015_seq, lulc2016_seq], axis=-1)  # (SEQ_LEN,H,W,8)

    # --- VIIRS targets ---
    y_list = []
    with rasterio.open(row["viirs_file"]) as src:
        c, r = src.index(lon, lat)
        win = rasterio.windows.Window(c-half, r-half, patch_size, patch_size)
        for b in target_idxs:
            patch = src.read(b, window=win)
            if patch.shape != (patch_size, patch_size):
                raise ValueError("Empty VIIRS patch.")
            y_list.append(patch.astype(np.float32))
    y = np.stack(y_list, axis=0)

    return X, y


# -----------------------------
# Build tf.data.Dataset
# -----------------------------
df_index = pd.read_csv(CSV_INDEX)

required = ["seq_band_idxs", "target_band_idxs",
            "era5_t2m_file", "era5_d2m_file", "era5_tp_file",
            "era5_u10_file", "era5_v10_file",
            "viirs_file"]
missing = [c for c in required if c not in df_index.columns]
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

def generator_from_df(df):
    for _, row in df.iterrows():
        try:
            X, y = extract_sample(row)
            yield X, y
        except Exception as e:
            print("Skipping row:", e)
            continue

# Output signature with 8 channels
output_signature = (
    tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 8), dtype=tf.float32),
    tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),
)

ds = tf.data.Dataset.from_generator(
    lambda: generator_from_df(df_index),
    output_signature=output_signature
)

ds = ds.shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# -----------------------------
# Quick test
# -----------------------------
for X_batch, y_batch in ds.take(1):
    print("X batch shape:", X_batch.shape)  # (B,SEQ_LEN,H,W,8)
    print("y batch shape:", y_batch.shape)  # (B,HORIZONS,H,W)


Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster bounds.
Skipping row: Patch goes outside raster 