# In [1]:  (notebook cell marker, kept as a comment so the file parses as Python)
# models/gru_predictor.py

import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler

class GRUTrafficPredictor:
    """Next-step traffic-volume predictor backed by a saved GRU model.

    On construction the volume history is loaded from a CSV file, sorted by
    its "Timestamp" column, and its "Volume" column is min-max scaled to
    [0, 1].  ``predict`` then feeds the ``seq_len`` scaled readings that
    precede a given timestamp to the model and returns the model output
    mapped back to the original volume scale.
    """

    def __init__(self, data_pkl: str, models_dir: str) -> None:
        """Load the volume history and the pretrained model.

        Args:
            data_pkl: Path to the volume data file.  Despite the parameter
                name, the file is read with ``pandas.read_csv`` and must
                contain "Timestamp" and "Volume" columns.
            models_dir: Directory that contains ``gru_model.h5``.
        """
        # 1) load & sort the volume DataFrame chronologically
        self.df = pd.read_csv(data_pkl, parse_dates=["Timestamp"])
        self.df = self.df.sort_values("Timestamp").reset_index(drop=True)

        # 2) keep the (sorted) array of timestamps; predict() binary-searches it
        self.timestamps = pd.to_datetime(self.df["Timestamp"]).values

        # 3) fit a [0, 1] scaler on the whole raw volume series of this file
        # NOTE(review): assumes this reproduces the scaling used when the
        # model was trained -- confirm against the training code.
        vols = self.df["Volume"].values.reshape(-1, 1)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.vols_scaled = self.scaler.fit_transform(vols)

        # 4) load the pretrained GRU model
        model_path = os.path.join(models_dir, "gru_model.h5")
        self.model = tf.keras.models.load_model(model_path)

        # 5) infer the window length from the model's input shape, which is
        #    unpacked as a 3-tuple (expected: (None, window_length, 1))
        _, self.seq_len, _ = self.model.input_shape

    def predict(self, site_id: str, arm: str, timestamp: str) -> float:
        """Predict the volume at ``timestamp`` from the preceding window.

        Args:
            site_id: Ignored; the model is univariate.
            arm: Ignored; the model is univariate.
            timestamp: Point in time to predict for.  Anything accepted by
                ``numpy.datetime64`` works (an ISO-8601 string is the usual
                input).

        Returns:
            The model output, inverse-scaled to the original volume units.

        Raises:
            ValueError: If fewer than ``seq_len`` readings precede
                ``timestamp``.
        """
        # Insertion index with the default side="left": every reading before
        # this index is strictly earlier than `timestamp`.
        idx = int(np.searchsorted(self.timestamps, np.datetime64(timestamp)))

        # grab the preceding window of seq_len scaled readings
        start = idx - self.seq_len
        if start < 0:
            # f-string (not "+") so non-str timestamps still raise ValueError
            # rather than a TypeError from string concatenation.
            raise ValueError(f"Not enough history before {timestamp}")
        window = self.vols_scaled[start:idx]

        # Add a leading batch axis and predict the next step.  verbose=0
        # suppresses the per-call Keras progress bar.
        y_scaled = self.model.predict(window[np.newaxis, ...], verbose=0)[0, 0]

        # inverse-scale and return the volume as a plain Python float
        return float(self.scaler.inverse_transform([[y_scaled]])[0, 0])
