In [1]:
import os
import numpy as np
import pandas as pd
import rasterio
import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor

# ---------------------------------------------------------------------------
# Configuration for the first (in-memory, centre-patch) pipeline
# ---------------------------------------------------------------------------
SEQ_LEN = 6             # past timesteps fed to the model per sample
HORIZONS = 3            # future timesteps predicted per sample
PATCH_SIZE = 13         # side length (pixels) of the square spatial patch
HALF = PATCH_SIZE // 2  # patch half-width (6 for a 13-pixel patch)
FILL_NAN_VALUE = 0.0    # value substituted for NaN / masked pixels on load

# CSV columns that must be present; each holds a raster file path.
REQUIRED_COLS = [
    "era5_t2m_file",
    "era5_d2m_file",
    "era5_tp_file",
    "era5_u10_file",
    "era5_v10_file",
    "viirs_file",
    "dem_file",
    "lulc_file",
]

In [2]:
def _load_single_raster(path):
    """Load band 1 of a raster as a plain float32 ``numpy.ndarray``.

    Masked (nodata) pixels and NaNs are both replaced by ``FILL_NAN_VALUE``.

    Args:
        path: Path to a raster file readable by rasterio.

    Returns:
        A float32 ``numpy.ndarray`` (never a ``MaskedArray``), or ``None`` if
        the file could not be read. Read failures are printed rather than
        raised so that one bad file does not abort a bulk load.
    """
    try:
        with rasterio.open(path, sharing=True) as src:
            band = src.read(1, out_dtype="float32", masked=True)  # read first band
    except Exception as e:  # best-effort: callers skip paths that return None
        print(f"⚠️ Error reading {path}: {e}")
        return None
    # np.ma.filled returns a plain ndarray even when no pixel is masked; an
    # `is_masked` guard would let a MaskedArray (and its mask) leak through.
    arr = np.ma.filled(band, np.nan)
    return np.nan_to_num(arr, nan=FILL_NAN_VALUE).astype("float32", copy=False)

In [3]:
def load_rasters(df, raster_cols, max_workers=8):
    """Read every distinct raster path referenced in ``df`` into a dict.

    Args:
        df: DataFrame whose ``raster_cols`` columns contain raster file paths.
            Columns absent from ``df`` are ignored; NaN cells are skipped.
        raster_cols: Column names to scan for paths.
        max_workers: Thread-pool size used for concurrent reads.

    Returns:
        ``{path: array}`` for every path that ``_load_single_raster`` read
        successfully; paths for which it returned ``None`` are omitted.
    """
    unique_paths = set()
    for col in (name for name in raster_cols if name in df.columns):
        unique_paths |= set(df[col].dropna().unique())
    ordered_paths = list(unique_paths)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        loaded = list(pool.map(_load_single_raster, ordered_paths))

    return {p: a for p, a in zip(ordered_paths, loaded) if a is not None}


In [4]:
def compute_norm_stats(arrays):
    """Compute the mean and standard deviation of all non-NaN values.

    Args:
        arrays: Iterable of numpy arrays (any shape). NaNs are ignored.

    Returns:
        ``(mean, std)`` over the pooled valid values. Both are accumulated in
        float64 so large float32 rasters do not lose precision.

    Raises:
        ValueError: if ``arrays`` is empty or contains no non-NaN values
            (statistics would otherwise be NaN and poison normalization).
    """
    chunks = [np.asarray(a)[~np.isnan(a)].ravel() for a in arrays]
    valid = np.concatenate(chunks) if chunks else np.empty(0)
    if valid.size == 0:
        raise ValueError("compute_norm_stats: no non-NaN values to compute statistics from")
    mean = valid.mean(dtype=np.float64)
    std = valid.std(dtype=np.float64)
    return mean, std

def normalize(arr, mean, std):
    """Standardize ``arr`` as ``(arr - mean) / std``.

    A zero ``std`` (constant input) is treated as 1, so the result is
    ``arr - mean`` instead of inf/NaN from division by zero. ``std`` may be a
    scalar or an array broadcastable against ``arr``.
    """
    if np.ndim(std) == 0:
        safe_std = std if std != 0 else 1.0
    else:
        safe_std = np.where(np.asarray(std) == 0, 1.0, std)
    return (arr - mean) / safe_std

def encode_lulc(lulc_arr, known_classes=None):
    """One-hot encode a 2-D LULC (land-use / land-cover) class raster.

    Implemented with numpy. The previous version relied on scikit-learn's
    ``OneHotEncoder``, which this notebook never imports, and on its
    ``sparse=`` argument, which was removed in scikit-learn 1.4.

    Args:
        lulc_arr: 2-D array of class codes (H, W); values are cast to int.
        known_classes: Optional sequence of class codes that defines the output
            channels and their order. When ``None`` or empty, the sorted unique
            codes present in ``lulc_arr`` are used.

    Returns:
        ``(onehot_img, classes)``. ``onehot_img`` has shape
        ``(H, W, len(classes))``, dtype float64 and exactly one 1.0 per pixel.
        ``classes`` is a 1-D array of the class codes, in channel order.

    Raises:
        ValueError: if ``lulc_arr`` contains a code not in ``known_classes``.
    """
    codes = np.asarray(lulc_arr).astype(int)
    # len() check instead of truthiness: `if array` is ambiguous for numpy arrays.
    if known_classes is None or len(known_classes) == 0:
        classes = np.unique(codes)
    else:
        classes = np.asarray(known_classes)

    onehot = (codes[..., None] == classes).astype(np.float64)

    unknown = ~onehot.any(axis=-1)
    if unknown.any():
        bad = np.unique(codes[unknown]).tolist()
        raise ValueError(f"encode_lulc: class codes {bad} are not in known_classes")
    return onehot, classes


In [5]:
def _safe_center(h, w, patch_size=PATCH_SIZE):
    """Return the (row, col) nearest the raster centre that fits a full patch.

    The centre ``(h // 2, w // 2)`` is clipped into
    ``[half, size - half - 1]`` on each axis (``half = patch_size // 2``), so a
    window of ``half`` pixels on either side stays inside the raster.

    Raises:
        ValueError: if the ``h`` x ``w`` raster is too small for the patch.
            No in-bounds centre exists in that case, and clipping would
            otherwise produce a centre whose window starts at a negative
            index, which numpy slicing silently wraps.
    """
    half = patch_size // 2
    min_side = 2 * half + 1
    if h < min_side or w < min_side:
        raise ValueError(
            f"raster of size {h}x{w} is too small for patch_size={patch_size}"
        )
    r = int(np.clip(h // 2, half, h - half - 1))
    c = int(np.clip(w // 2, half, w - half - 1))
    return r, c

In [6]:
def _extract_patch(arr, row, col, patch_size=PATCH_SIZE):
    """Return the ``patch_size`` x ``patch_size`` window of ``arr`` around (row, col).

    For odd ``patch_size`` the window is centred on (row, col). For even sizes
    it extends one pixel further up/left than down/right, so the result always
    has exactly ``patch_size`` rows and columns. Any trailing axes of ``arr``
    beyond the first two are kept unchanged.

    Raises:
        ValueError: if the window does not lie entirely inside ``arr``. Plain
            slicing would otherwise wrap negative starts or silently return a
            smaller patch.
    """
    half = patch_size // 2
    r0, c0 = row - half, col - half
    r1, c1 = r0 + patch_size, c0 + patch_size
    if r0 < 0 or c0 < 0 or r1 > arr.shape[0] or c1 > arr.shape[1]:
        raise ValueError(
            f"patch of size {patch_size} at ({row}, {col}) does not fit in array "
            f"of shape {arr.shape[:2]}"
        )
    return arr[r0:r1, c0:c1]

In [7]:
def build_sample(seq_rows, horizon_rows, cache):
    """
    Assemble one (X, y) training pair from consecutive CSV rows.

    Every patch is cut around the (clipped) centre of its raster, chosen by
    ``_safe_center``, so every sample uses the same central window.

    Args:
      seq_rows: the input rows (normally SEQ_LEN of them), a DataFrame slice.
      horizon_rows: the following target rows (normally HORIZONS of them).
      cache: mapping raster path -> array, as built by ``load_rasters``.

    Returns:
      X: float32 array (len(seq_rows), P, P, 7). Channels are t2m, d2m, tp,
         u10, v10, DEM, LULC, in that order.
      y: float32 array (len(horizon_rows), P, P) of VIIRS patches.
    """
    dynamic_cols = ("era5_t2m_file", "era5_d2m_file", "era5_tp_file",
                    "era5_u10_file", "era5_v10_file")

    def centre_patch(raster):
        return _extract_patch(raster, *_safe_center(*raster.shape))

    timesteps = []
    for _, rec in seq_rows.iterrows():
        channels = [centre_patch(cache[rec[col]]) for col in dynamic_cols]
        dem = cache[rec["dem_file"]]
        lulc = cache[rec["lulc_file"]]
        # The LULC patch is cut at the DEM's centre, not at its own.
        dem_r, dem_c = _safe_center(*dem.shape)
        channels += [_extract_patch(dem, dem_r, dem_c), _extract_patch(lulc, dem_r, dem_c)]
        timesteps.append(np.stack(channels, axis=-1))  # (P, P, 7)
    X = np.stack(timesteps, axis=0)

    targets = [centre_patch(cache[rec["viirs_file"]]) for _, rec in horizon_rows.iterrows()]
    y = np.stack(targets, axis=0)

    return X.astype("float32"), y.astype("float32")

In [8]:
def make_generator(csv_path, cache):
    """Yield ``build_sample`` outputs over a sliding window of CSV rows.

    Window ``start`` uses rows ``[start, start + SEQ_LEN)`` as inputs and the
    next HORIZONS rows as targets. Windows advance one row at a time.

    NOTE: SEQ_LEN and HORIZONS are read from globals when the generator runs,
    so HORIZONS must be an int at that point. A later cell rebinds it to a list.

    Raises:
        ValueError: if the CSV lacks any column in REQUIRED_COLS.
    """
    df = pd.read_csv(csv_path)
    missing = [col for col in REQUIRED_COLS if col not in df.columns]
    if missing:
        raise ValueError(f"CSV missing required columns: {missing}")

    window = SEQ_LEN + HORIZONS
    for start in range(len(df) - window + 1):
        split = start + SEQ_LEN
        yield build_sample(df.iloc[start:split], df.iloc[split:start + window], cache)

In [9]:
def create_dataset(csv_path, cache, batch_size=4, shuffle=True, shuffle_buf=256):
    """Wrap ``make_generator`` in a batched, prefetched ``tf.data.Dataset``.

    Each element is ``(X, y)``: X has shape (SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7)
    and y has shape (HORIZONS, PATCH_SIZE, PATCH_SIZE), both float32, and both
    are batched along a new leading axis. When ``shuffle`` is true, samples are
    shuffled through a ``shuffle_buf``-element buffer (reshuffled on every
    iteration) before batching.
    """
    x_spec = tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32)
    y_spec = tf.TensorSpec(shape=(HORIZONS, PATCH_SIZE, PATCH_SIZE), dtype=tf.float32)

    def _samples():
        return make_generator(csv_path, cache)

    ds = tf.data.Dataset.from_generator(_samples, output_signature=(x_spec, y_spec))
    if shuffle:
        ds = ds.shuffle(shuffle_buf, reshuffle_each_iteration=True)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [10]:
if __name__ == "__main__":  # always true inside a notebook kernel
    # Sequence-index CSV. Set FOREST_FIRE_SEQ_CSV to point elsewhere instead of
    # editing the machine-specific default path.
    csv_path = os.environ.get(
        "FOREST_FIRE_SEQ_CSV",
        r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_norm.csv",
    )
    df = pd.read_csv(csv_path)

    raster_cols = REQUIRED_COLS
    print("Loading rasters into memory...")
    cache = load_rasters(df, raster_cols, max_workers=8)
    print(f"Loaded {len(cache)} rasters into memory ✅")

    # Smoke test: pull a single batch and show its shapes.
    ds = create_dataset(csv_path, cache, batch_size=2)
    for X, y in ds.take(1):
        print("X shape:", X.shape)  # (B, SEQ_LEN, PATCH, PATCH, 7)
        print("y shape:", y.shape)  # (B, HORIZONS, PATCH, PATCH)

Loading rasters into memory...
Loaded 9 rasters into memory ✅
X shape: (2, 6, 13, 13, 7)
y shape: (2, 3, 13, 13)


In [11]:
# Show the Python types of one batch's input and target tensors.
for X, y in ds.take(1):
    print(*(type(t) for t in (X, y)))

<class 'tensorflow.python.framework.ops.EagerTensor'> <class 'tensorflow.python.framework.ops.EagerTensor'>


In [12]:
import tensorflow as tf
import numpy as np
import rasterio
import pandas as pd
import random
import ast

In [13]:
# Settings for the random-patch pipeline below. NOTE: these rebind names set
# in the first config cell. HORIZONS becomes the list [1, 2, 3] here; in the
# cells that follow only len(HORIZONS) is used, and a later model-config cell
# rebinds it to the int 3. make_generator would fail while it is a list.
SEQ_LEN, PATCH_SIZE, BATCH_SIZE = 6, 13, 4
HORIZONS = [1, 2, 3]

In [14]:
# Sequence-index CSV (same default as the first pipeline). Set the
# FOREST_FIRE_SEQ_CSV environment variable to use another location.
SEQ_CSV = os.environ.get(
    "FOREST_FIRE_SEQ_CSV",
    r"C:\Users\Ankit\Datasets_Forest_fire\sequence_index_hourly_norm.csv",
)
df = pd.read_csv(SEQ_CSV)

In [15]:
# read_csv loads list-valued columns as strings such as "[1, 2, 3, 4, 5, 6]";
# ast.literal_eval safely turns them back into Python lists. Values that are
# already lists are left alone, so re-running this cell does not crash
# (literal_eval raises ValueError when given a list).
def _parse_list_cell(value):
    """Parse a stringified Python literal; pass non-strings through unchanged."""
    return ast.literal_eval(value) if isinstance(value, str) else value

df["seq_band_idxs"] = df["seq_band_idxs"].apply(_parse_list_cell)
df["target_band_idxs"] = df["target_band_idxs"].apply(_parse_list_cell)

In [16]:
def read_patch(file, band_idx, row, col, size=PATCH_SIZE, fill_value=0.0):
    """Read a ``size`` x ``size`` window of one band around (row, col).

    Args:
        file: Raster path.
        band_idx: 0-based band index (rasterio bands are 1-based).
        row, col: Centre pixel of the window in the raster's pixel grid.
        size: Window side length in pixels.
        fill_value: Value used for nodata/NaN pixels and for any part of the
            window that falls outside the raster.

    Returns:
        A float32 array of shape (size, size). A window that extends past a
        raster edge is padded with ``fill_value``. It is not shifted inward
        (top/left) or allowed to run off the raster (bottom/right), so the
        patch stays aligned on (row, col) and always has the expected shape.
    """
    from rasterio.windows import Window

    row_off = row - size // 2
    col_off = col - size // 2
    window = Window(col_off, row_off, size, size)  # (col_off, row_off, width, height)
    with rasterio.open(file) as src:
        inside = (row_off >= 0 and col_off >= 0
                  and row_off + size <= src.height and col_off + size <= src.width)
        # Boundless reads are slower, so request one only when actually needed.
        arr = src.read(band_idx + 1, window=window, boundless=not inside,
                       masked=True, fill_value=fill_value, out_dtype="float32")
    arr = np.ma.filled(arr, np.nan)
    return np.nan_to_num(arr, nan=fill_value).astype("float32", copy=False)

In [17]:
def sample_generator(seed=None):
    """Yield an endless stream of random (x, y) patch samples from ``df``.

    Each sample works as follows:

    - Pick a random row of ``df``.
    - Pick a random pixel (r, c) whose ``PATCH_SIZE`` window fits inside that
      row's ERA5 t2m raster.
    - Read the window around (r, c) from every raster the row references.
      This assumes all of a row's rasters share the t2m pixel grid — TODO confirm.

    Args:
        seed: Optional seed. When given, every call replays the same
            deterministic stream; ``tf.data`` calls the generator afresh on
            each iteration. When ``None``, each call produces a different stream.

    Yields:
        x: float32 array (SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7). Channels are
           t2m, d2m, tp, u10, v10, DEM, LULC; DEM and LULC are repeated over time.
        y: float32 array (len(target_band_idxs), PATCH_SIZE, PATCH_SIZE) of
           VIIRS patches.

    Raises:
        ValueError: if a row's ``seq_band_idxs`` does not have SEQ_LEN entries.
            The static layers are repeated SEQ_LEN times, so any other length
            cannot be stacked.
    """
    rng = random.Random(seed)
    half = PATCH_SIZE // 2
    era5_cols = ("era5_t2m_file", "era5_d2m_file", "era5_tp_file",
                 "era5_u10_file", "era5_v10_file")

    while True:
        # Random row of the index; random_state is drawn from rng so that a
        # seeded generator is fully reproducible.
        row = df.sample(1, random_state=rng.randrange(2**32)).iloc[0]

        seq_bands = row["seq_band_idxs"]
        if len(seq_bands) != SEQ_LEN:
            raise ValueError(
                f"row {row.name}: expected {SEQ_LEN} seq_band_idxs, got {len(seq_bands)}"
            )

        # Random pixel location whose full window lies inside the t2m grid.
        with rasterio.open(row["era5_t2m_file"]) as src:
            h, w = src.height, src.width
        r = rng.randint(half, h - half - 1)
        c = rng.randint(half, w - half - 1)

        # ---- Input sequence: one (time, H, W) stack per ERA5 variable ----
        x_vars = [
            np.stack([read_patch(row[col], b, r, c) for b in seq_bands], axis=0)
            for col in era5_cols
        ]

        # ---- Static layers (band 0), repeated along time ----
        # NOTE(review): the LULC file was described as a one-hot TIFF, yet only
        # its first band is read here. Confirm whether all class bands are needed.
        for col in ("dem_file", "lulc_file"):
            static = read_patch(row[col], 0, r, c)
            x_vars.append(np.repeat(static[None, :, :], SEQ_LEN, axis=0))

        x = np.stack(x_vars, axis=-1)  # (time, H, W, channels)

        # ---- Targets ----
        y = np.stack(
            [read_patch(row["viirs_file"], b, r, c) for b in row["target_band_idxs"]],
            axis=0,
        )  # (horizons, H, W)

        yield x.astype(np.float32), y.astype(np.float32)

In [18]:
# Element spec for sample_generator. x is (SEQ_LEN, P, P, 7): 5 ERA5 variables
# plus DEM plus LULC, matching CHANNELS = 7 in the model config. y is
# (len(HORIZONS), P, P). Declaring the channel count instead of None lets
# tf.data reject malformed samples immediately and gives downstream layers a
# fully defined input shape.
output_signature = (
    tf.TensorSpec(shape=(SEQ_LEN, PATCH_SIZE, PATCH_SIZE, 7), dtype=tf.float32),
    tf.TensorSpec(shape=(len(HORIZONS), PATCH_SIZE, PATCH_SIZE), dtype=tf.float32),
)

In [19]:
# Stream of random patch samples; sample_generator loops forever.
dataset = tf.data.Dataset.from_generator(
    sample_generator,
    output_signature=output_signature,
)

In [20]:
# Nominal number of samples and the fraction held out for validation.
TOTAL, VAL_SPLIT = 1000, 0.2

In [21]:
# sample_generator never terminates, so cap the stream at TOTAL samples.
# Otherwise the training split (a skip() of an endless stream) is itself
# endless, and a Keras epoch never finishes.
# NOTE: tf.data calls the generator again for every iteration, so the train
# and validation splits are independent random draws rather than disjoint
# slices of one fixed sample set.
dataset = dataset.take(TOTAL).shuffle(1000, reshuffle_each_iteration=False)

In [22]:
val_size = int(VAL_SPLIT * TOTAL)  # number of samples routed to validation

In [23]:
# The first val_size samples of the stream go to validation; the rest train.
train_dataset, val_dataset = dataset.skip(val_size), dataset.take(val_size)

In [24]:
# Batch both splits and overlap input preparation with consumption.
train_dataset, val_dataset = (
    split.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    for split in (train_dataset, val_dataset)
)

In [25]:
from tensorflow.keras import layers, models

In [26]:
# Model hyper-parameters. NOTE: HORIZONS is rebound here from the list
# [1, 2, 3] back to the int 3 (number of predicted timesteps).
SEQ_LEN = 6             # input timesteps
PATCH_H = PATCH_W = 13  # spatial patch height / width
CHANNELS = 7            # 5 ERA5 variables + DEM + LULC
HORIZONS = 3            # predicted timesteps
LSTM_UNITS = 16         # LSTM hidden size
CNN_FEATURES = 64       # per-timestep embedding size from the CNN encoder

In [27]:
def build_cnn_lstm_model(seq_len=SEQ_LEN, patch_h=PATCH_H, patch_w=PATCH_W, channels=CHANNELS, horizons=HORIZONS,
                         lstm_units=LSTM_UNITS, cnn_features=CNN_FEATURES):
    """Build a two-headed CNN -> LSTM model.

    A small CNN is applied to every timestep via ``TimeDistributed``: two
    conv + 2x2 max-pool stages, then flatten, then dense. An LSTM then
    summarises the sequence of per-timestep embeddings.

    Args:
        seq_len, patch_h, patch_w, channels: per-sample input shape
            ``(seq_len, patch_h, patch_w, channels)``; the batch axis is implicit.
        horizons: number of predicted timesteps for the regression head.
        lstm_units: LSTM hidden size.
        cnn_features: size of the per-timestep CNN embedding.

    Returns:
        An uncompiled ``keras.Model`` with two outputs: ``reg_out`` (linear,
        shape ``(horizons, patch_h, patch_w)``) and ``cls_out`` (sigmoid,
        shape ``(1,)``).

    Raises:
        ValueError: if the patch is smaller than 4x4. The two 2x2 poolings
            would then reduce it to an empty feature map.
    """
    if patch_h < 4 or patch_w < 4:
        raise ValueError(f"patch {patch_h}x{patch_w} is too small for two 2x2 poolings (need >= 4x4)")

    inp = layers.Input(shape=(seq_len, patch_h, patch_w, channels))

    # Per-timestep spatial encoder, shared across all timesteps.
    cnn = models.Sequential([
        layers.Conv2D(32, (3, 3), activation="relu", padding="same"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation="relu", padding="same"),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(cnn_features, activation="relu"),
    ])
    td = layers.TimeDistributed(cnn)(inp)

    lstm_out = layers.LSTM(lstm_units)(td)

    # Regression head: one value per pixel per horizon.
    reg_out = layers.Dense(horizons * patch_h * patch_w, activation="linear")(lstm_out)
    reg_out = layers.Reshape((horizons, patch_h, patch_w), name="reg_out")(reg_out)

    # Classification head: fire / no-fire probability.
    cls_out = layers.Dense(1, activation="sigmoid", name="cls_out")(lstm_out)

    return models.Model(inputs=inp, outputs=[reg_out, cls_out])

In [28]:
import tensorflow as tf

# Binary fire/no-fire label derived from y_reg: a sample is positive if any
# pixel in any horizon exceeds FIRE_THRESHOLD.
FIRE_THRESHOLD = 0.5  # 0.5 suits 0/1 VIIRS masks; adjust for other encodings

def add_cls_label_batched(X, y_reg):
    """Attach a per-sample fire/no-fire label to a batch.

    Args:
        X: input batch, passed through unchanged.
        y_reg: target batch of shape (B, HORIZONS, H, W).

    Returns:
        ``(X, {"reg_out": y_reg, "cls_out": y_cls})``. ``y_cls`` is a float32
        tensor of shape (B, 1): 1.0 if any value of the sample's ``y_reg``
        exceeds ``FIRE_THRESHOLD``, otherwise 0.0.
    """
    peak = tf.reduce_max(y_reg, axis=[1, 2, 3])  # (B,)
    y_cls = tf.expand_dims(tf.cast(peak > FIRE_THRESHOLD, tf.float32), axis=-1)
    return X, {"reg_out": y_reg, "cls_out": y_cls}

# Attach classification labels, then prefetch again. The prefetch applied when
# batching sits *before* this map, so without a trailing prefetch the map's
# work would not overlap with training steps.
train_dataset = (
    train_dataset
    .map(add_cls_label_batched, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = (
    val_dataset
    .map(add_cls_label_batched, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [29]:
# Initial build from the model-config constants (a later cell rebuilds it).
model = build_cnn_lstm_model(
    seq_len=SEQ_LEN, patch_h=PATCH_H, patch_w=PATCH_W, channels=CHANNELS, horizons=HORIZONS
)

In [30]:
# Quick compile with MSE on both heads. It is superseded by the per-head
# compile of the rebuilt model further below. Metrics are given per output
# name: Keras 3 rejects a single metrics list for a two-output model, while
# tf.keras 2 applies this dict exactly as it applied the list.
model.compile(
    optimizer="adam",
    loss="mse",
    metrics={"reg_out": ["mae"], "cls_out": ["mae"]},
)

In [31]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

In [32]:
from tensorflow.keras import callbacks

In [33]:
# Fresh model built from the model-config constants, replacing the
# quick-compiled one above (arguments in signature order).
model = build_cnn_lstm_model(SEQ_LEN, PATCH_H, PATCH_W, CHANNELS, HORIZONS)

In [34]:
# Per-head losses and metrics, keyed by output layer name. The classification
# loss is down-weighted (0.3) relative to the pixel-wise regression loss.
optimizer = tf.keras.optimizers.Adam(1e-4)
head_losses = {
    "reg_out": tf.keras.losses.MeanSquaredError(),
    # cls_out already applies a sigmoid, so the loss receives probabilities.
    "cls_out": tf.keras.losses.BinaryCrossentropy(from_logits=False),
}
head_metrics = {
    "reg_out": [tf.keras.metrics.MeanAbsoluteError()],
    "cls_out": [tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()],
}
model.compile(
    optimizer=optimizer,
    loss=head_losses,
    loss_weights={"reg_out": 1.0, "cls_out": 0.3},
    metrics=head_metrics,
)


In [35]:
# Both callbacks evaluate the validation loss at the end of each epoch.
early_stop = callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,                 # stop after 5 epochs without improvement
    restore_best_weights=True,  # roll back to the best epoch's weights
)
checkpoint = callbacks.ModelCheckpoint(
    "best_cnn_lstm_model.h5",
    monitor="val_loss",
    save_best_only=True,        # write only when val_loss improves
    verbose=1,
)

In [None]:
# Training samples come from an endless random generator, so an epoch must be
# bounded explicitly. Without steps_per_epoch, epoch 1 never ends, and the
# val_loss-based callbacks never run. repeat() keeps the stream from running
# dry if the training split is finite.
# NOTE: EarlyStopping(patience=5) only matters once epochs > 5.
steps_per_epoch = max(1, (TOTAL - val_size) // BATCH_SIZE)
history = model.fit(
    train_dataset.repeat(),
    validation_data=val_dataset,
    epochs=1,
    steps_per_epoch=steps_per_epoch,
    callbacks=[early_stop, checkpoint],
    verbose=1,
)