In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def collate_keep_meta(batch):
    """Collate a list of sample dicts into one batch dict.

    Tensor fields are stacked along a new leading batch dimension, while the
    bookkeeping fields stay plain Python lists so they can be printed or
    logged without further conversion.

    Args:
        batch: sequence of dicts carrying the keys ``x_hist``, ``y_future``,
            ``tile_id``, ``start_time``, ``lat`` and ``lon``.

    Returns:
        ``{"x_hist": Tensor, "y_future": Tensor, "meta": {...}}`` where
        ``meta`` holds ``tile_id`` (list), ``start_time`` (list of ``str``)
        and float32 ``lat`` / ``lon`` tensors.
    """
    def _column(key):
        # Pull one field out of every sample, keeping batch order.
        return [sample[key] for sample in batch]

    stacked_hist = torch.stack(_column("x_hist"))      # [B, L, F]
    stacked_future = torch.stack(_column("y_future"))  # [B, H]
    lat_tensor = torch.tensor(_column("lat"), dtype=torch.float32)
    lon_tensor = torch.tensor(_column("lon"), dtype=torch.float32)

    return {
        "x_hist": stacked_hist,
        "y_future": stacked_future,
        "meta": {
            "tile_id": _column("tile_id"),
            # str() turns Timestamps (or anything else) into plain strings
            "start_time": [str(ts) for ts in _column("start_time")],
            "lat": lat_tensor,
            "lon": lon_tensor,
        },
    }

# ---------- 1) Identify columns ----------
# Columns that are never model inputs: coordinates, raw time/bookkeeping
# fields, and the PM2.5 / class columns.
NON_FEATURE_COLS = {
    "lon",
    "lat",
    "time",
    "source_file",
    "PM25_MERRA2",
    "PM25_ug_m3",
    "class",
}

def get_feature_cols(df: pd.DataFrame):
    """Return the names of the numeric columns of *df* used as features.

    A column qualifies when it has a numeric dtype, is not listed in
    ``NON_FEATURE_COLS`` and is not named ``timestamp`` (case-insensitive).
    Column order follows the frame's own column order.
    """
    selected = []
    for col in df.select_dtypes(include=[np.number]).columns:
        if col in NON_FEATURE_COLS:
            continue
        if col.lower() == "timestamp":
            continue
        selected.append(col)
    return selected

# ---------- 2) Parse & tidy ----------
def prepare_dataframe(df: pd.DataFrame, hourly=True, dayfirst=True, freq="H") -> pd.DataFrame:
    """
    Parse, tidy and (optionally) resample a raw long-format frame.

    - Parses timestamps from df['time'] (unparseable values are dropped).
    - Builds tile_id from (lat, lon), each rounded to 4 decimals.
    - Optionally resamples per tile to an evenly spaced time grid
      (freq='H' or '30T'; newer pandas releases spell these 'h' / '30min').
    - Returns a tidy frame where features are float and PM25_ug_m3 is present.

    Args:
        df: raw frame with at least ``time``, ``lat``, ``lon`` and
            ``PM25_ug_m3`` columns plus any numeric feature columns.
        hourly: when True, resample every tile onto the regular ``freq`` grid
            and fill gaps (time interpolation, then forward/backward fill).
        dayfirst: forwarded to ``pd.to_datetime``.
        freq: pandas offset alias for the resampling grid.

    Returns:
        A new DataFrame (the input is not modified). With ``hourly=True`` the
        rows are ordered by tile then time and carry a fresh 0..n-1 index.
    """
    df = df.copy()

    # 1) timestamp + tile id
    df["timestamp"] = pd.to_datetime(df["time"], dayfirst=dayfirst, errors="coerce")
    df["tile_id"] = (df["lat"].round(4).astype(str) + "_" + df["lon"].round(4).astype(str))

    keep = ["timestamp", "tile_id", "lat", "lon", "PM25_ug_m3"] + get_feature_cols(df)
    df = df[keep].dropna(subset=["timestamp"]).sort_values(["tile_id", "timestamp"])

    if hourly:
        def _resample(tile, g):
            """Resample one tile's rows onto the regular grid and fill gaps."""
            lat0 = float(g["lat"].iloc[0])
            lon0 = float(g["lon"].iloc[0])

            g = g.set_index("timestamp").sort_index()

            # numeric-only columns for resampling (avoid strings like tile_id)
            num_cols = g.select_dtypes(include=[np.number]).columns
            g_num = g[num_cols].resample(freq).mean()

            # fill gaps: interpolate inside the series, then pad both ends
            g_num = g_num.interpolate("time").ffill().bfill()

            # add metadata back (constant per tile)
            g_num["tile_id"] = tile
            g_num["lat"] = lat0
            g_num["lon"] = lon0

            return g_num.reset_index()

        # Iterate the groups explicitly instead of groupby.apply: the applied
        # function needs the grouping column, which apply() is phasing out,
        # and apply() left a duplicated 0..n-1 index for every tile.
        pieces = [_resample(tile, g) for tile, g in df.groupby("tile_id")]
        if pieces:  # pd.concat([]) raises, so leave an empty frame untouched
            df = pd.concat(pieces, ignore_index=True)

    # ensure numeric dtypes and fill any leftovers
    # NOTE(review): this also zero-fills missing PM25_ug_m3 targets — confirm
    # that is intended rather than dropping those rows.
    feat_cols = get_feature_cols(df)
    df[feat_cols + ["PM25_ug_m3"]] = df[feat_cols + ["PM25_ug_m3"]].astype(float).fillna(0.0)
    return df

# ---------- 3) Compute normalization stats ----------
def compute_feature_stats(df: pd.DataFrame):
    """Compute per-feature normalisation statistics.

    Returns:
        ``(feat_cols, mean, std)`` — the feature column names plus float32
        arrays of their means and population (ddof=0) standard deviations.
        A standard deviation of exactly 0 is replaced by 1.0 so that
        constant features do not cause a division by zero.
    """
    feat_cols = get_feature_cols(df)
    features = df[feat_cols]

    mean = features.mean().astype("float32").to_numpy()

    spread = features.std(ddof=0)
    spread = spread.replace(0, 1.0)
    std = spread.astype("float32").to_numpy()

    return feat_cols, mean, std

# ---------- 4) Window index builder ----------
def build_indices(df: pd.DataFrame, L=168, H=72, stride=1):
    """List every valid ``(tile_id, t)`` window anchor.

    For each tile, ``t`` runs from ``L`` to ``len(tile) - H`` inclusive in
    steps of ``stride``, so that ``L`` rows precede ``t`` and ``H`` rows
    start at ``t``. Tiles with fewer than ``L + H`` rows contribute nothing.
    """
    return [
        (tile, t)
        for tile, rows in df.groupby("tile_id")
        # the +1 makes t == len(rows) - H a valid (last) anchor
        for t in range(L, len(rows) - H + 1, stride)
    ]

# ---------- 5) PyTorch Dataset ----------
class TSWindowDataset(Dataset):
    """Sliding-window dataset over per-tile time series.

    Each item pairs ``L`` rows of normalised features (the history) with the
    next ``H`` values of ``PM25_ug_m3`` (the target) for a single tile.
    """

    def __init__(self, df: pd.DataFrame, L=168, H=72, stride=1, stats=None):
        """
        df: output of prepare_dataframe()
        L: lookback length in rows
        H: horizon length in rows
        stride: step between consecutive window anchors within a tile
        stats: (feat_cols, mean, std) from compute_feature_stats(train_df);
               computed from df itself when omitted
        """
        self.df = df
        self.L = L
        self.H = H

        if stats is None:
            stats = compute_feature_stats(df)
        self.feat_cols, self.mean, self.std = stats

        self.idx = build_indices(df, L, H, stride)

        # One frame per tile with a fresh 0..n-1 index, so that __getitem__
        # never has to group again and can slice by row position.
        self.groups = {}
        for tile, rows in df.groupby("tile_id"):
            self.groups[tile] = rows.reset_index(drop=True)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        tile, t = self.idx[i]
        frame = self.groups[tile]

        # history rows t-L .. t-1, normalised with the stored statistics
        past = frame.iloc[t - self.L:t]
        hist = past[self.feat_cols].to_numpy().astype(np.float32)       # [L, F]
        hist = ((hist - self.mean) / self.std).astype(np.float32)

        # target rows t .. t+H-1
        ahead = frame.iloc[t:t + self.H]
        fut = ahead["PM25_ug_m3"].to_numpy().astype(np.float32)         # [H]

        # metadata: first forecast timestamp and the tile's coordinates
        start_ts = frame.at[t, "timestamp"]
        lat = float(frame["lat"].iloc[0])
        lon = float(frame["lon"].iloc[0])

        return {
            "x_hist": torch.from_numpy(hist),        # [L, F]
            "y_future": torch.from_numpy(fut),       # [H]
            "tile_id": tile,
            "start_time": pd.Timestamp(start_ts),
            "lat": lat, "lon": lon,
        }

# ---------- 6) Quick usage + rich debug ----------
if __name__ == "__main__":
    # End-to-end smoke run: load a CSV, build windows with a tiny L/H, print a
    # summary, inspect the first window, pull one batch, and patchify it.
    # Kept as flat module-level code on purpose: later notebook cells read the
    # names defined here (df, L, H, feat_cols, ...).

    # Load a CSV just for this run (swap to parquet reader when ready)
    df_raw = pd.read_csv("data.csv")  # expects: time, lat, lon, PM25_ug_m3 + numeric features

    # Build tidy frame
    df = prepare_dataframe(df_raw, hourly=True, dayfirst=True, freq="H")

    # Config (small so you see windows immediately)
    L, H, stride = 8, 4, 1

    # Feature stats
    feat_cols, mean, std = compute_feature_stats(df)
    F = len(feat_cols)

    # Dataset
    ds = TSWindowDataset(df, L=L, H=H, stride=stride, stats=(feat_cols, mean, std))

    # ----------------- High-level summary -----------------
    print("\n=== DATA SUMMARY ===")
    print("df shape:", df.shape)
    print("time range:", df["timestamp"].min(), "→", df["timestamp"].max())
    print("#tiles:", df["tile_id"].nunique())
    print("#features (F):", F)
    print("first 10 feature cols:", feat_cols[:10])
    print(f"window config: L={L}, H={H}, stride={stride}")
    print("#windows:", len(ds))

    # Per-tile coverage (top/bottom few)
    counts = df.groupby("tile_id")["timestamp"].count().sort_values()
    print("\nrows per tile (smallest 5):")
    print(counts.head(5))
    print("rows per tile (largest 5):")
    print(counts.tail(5))

    if len(ds) == 0:
        raise SystemExit("\n[!] No windows available. Increase coverage or lower L/H.")

    # ----------------- Inspect first window -----------------
    print("\n=== FIRST WINDOW DEBUG ===")
    tile0, t0 = ds.idx[0]
    g0 = ds.groups[tile0]
    b0 = ds[0]
    x0, y0 = b0["x_hist"], b0["y_future"]

    print("tile_id:", tile0)
    print("history shape [L,F]:", tuple(x0.shape), "| future shape [H]:", tuple(y0.shape))

    # time ranges for history & future (label slices are inclusive on both ends)
    hist_ts = g0.loc[t0-L:t0-1, "timestamp"].to_list()
    fut_ts  = g0.loc[t0:t0+H-1, "timestamp"].to_list()
    print("hist timestamps:", hist_ts[0], "→", hist_ts[-1])
    print("fut  timestamps:", fut_ts[0],  "→", fut_ts[-1])

    # check spacing is hourly
    hist_deltas = pd.Series(hist_ts).diff().dropna().dt.total_seconds().unique()
    fut_deltas  = pd.Series(fut_ts).diff().dropna().dt.total_seconds().unique()
    print("hist Δt seconds (unique):", hist_deltas)
    print("fut  Δt seconds (unique):", fut_deltas)

    # normalization sanity: mean ~ 0, std ~ 1 over this window (rough check)
    print("x_hist window mean/std (overall):", float(x0.mean()), float(x0.std()))
    # show a single feature's first few timesteps
    print("x_hist[0:5, 0] sample:", x0[:5, 0].tolist())
    # show a few future targets
    print("y_future sample:", y0[:min(5, H)].tolist())

    # ----------------- Batch check via DataLoader -----------------
    # With drop_last=True a dataset smaller than the batch size yields zero
    # batches and next(iter(loader)) raises StopIteration, so cap the batch
    # size at the number of windows (known to be > 0 at this point).
    loader = DataLoader(ds, batch_size=min(32, len(ds)), shuffle=True, drop_last=True,
                    collate_fn=collate_keep_meta)

    batch = next(iter(loader))
    Xb, Yb = batch["x_hist"], batch["y_future"]
    meta = batch["meta"]

    print("\n=== BATCH CHECK (custom collate) ===")
    print("batch x_hist shape [B,L,F]:", tuple(Xb.shape))
    print("batch y_future shape [B,H]:", tuple(Yb.shape))
    print("meta keys:", list(meta.keys()))
    print("tile_id[0]:", meta["tile_id"][0])
    print("start_time[0]:", meta["start_time"][0])
    print("lat/lon tensors:", tuple(meta["lat"].shape), tuple(meta["lon"].shape))

    # ----------------- PatchTST token check (dev aid) -----------------
    # If you plan to patchify later, this shows expected token count.
    def patchify(x, P=16, S=8):  # x: [B,L,F] -> [B,N,P*F]
        # N = (L - P) // S + 1 patches of P consecutive rows, flattened to P*F
        B, L_, F_ = x.shape
        return torch.stack([x[:, s:s+P, :].reshape(B, P*F_) for s in range(0, L_-P+1, S)], dim=1)

    P, S = 4, 2  # small numbers just to visualize with L=8
    tokens = patchify(Xb, P=P, S=S)
    print("\n=== PATCHIFY SMOKE TEST ===")
    print(f"P={P}, S={S} -> N={(L-P)//S + 1}")
    print("tokens shape [B,N,P*F]:", tuple(tokens.shape))

    print("\nAll sanity checks passed.\n")


=== DATA SUMMARY ===
df shape: (24400, 55)
time range: 2021-01-10 00:00:00 → 2021-02-10 19:00:00
#tiles: 400
#features (F): 50
first 10 feature cols: ['DUEXTTAU', 'BCFLUXU', 'OCFLUXV', 'BCANGSTR', 'SUFLUXV', 'SSSMASS25', 'SSSMASS', 'OCSMASS', 'BCCMASS', 'BCSMASS']
window config: L=8, H=4, stride=1
#windows: 20000

rows per tile (smallest 5):
tile_id
26.5_36.875    24
28.5_30.625    24
28.5_30.0      24
28.5_29.375    24
28.5_28.75     24
Name: timestamp, dtype: int64
rows per tile (largest 5):
tile_id
31.0_25.0    764
26.0_25.0    764
27.5_25.0    764
25.5_25.0    764
22.0_25.0    764
Name: timestamp, dtype: int64

=== FIRST WINDOW DEBUG ===
tile_id: 22.0_25.0
history shape [L,F]: (8, 50) | future shape [H]: (4,)
hist timestamps: 2021-01-10 00:00:00 → 2021-01-10 07:00:00
fut  timestamps: 2021-01-10 08:00:00 → 2021-01-10 11:00:00
hist Δt seconds (unique): [3600.]
fut  Δt seconds (unique): [3600.]
x_hist window mean/std (overall): -0.3407861292362213 0.7213065028190613
x_hist[0:5, 0] sa

In [2]:
# ---- time-based split (80/20 by timestamp) ----
# The cutoff is the 0.80 quantile of the timestamp column, i.e. roughly 80% of
# the *rows* fall at or before it (not necessarily 80% of the calendar span).
cutoff = df["timestamp"].quantile(0.80)
train_df = df[df["timestamp"] <= cutoff].copy()
val_df   = df[df["timestamp"] >  cutoff].copy()

# ---- stats on train only ----
# Normalisation statistics come from the training rows and are reused for
# validation so no validation information leaks into the scaling.
feat_cols, mean, std = compute_feature_stats(train_df)

# ---- datasets ----
# Windows are built inside each split independently, so a tile only yields
# validation windows if it has at least L + H rows after the cutoff.
# NOTE(review): that makes the validation set small; consider giving val_df
# the last L pre-cutoff rows as history context.
L, H, stride = 168, 72, 1   # real config
train_ds = TSWindowDataset(train_df, L=L, H=H, stride=stride, stats=(feat_cols, mean, std))
val_ds   = TSWindowDataset(val_df,   L=L, H=H, stride=stride, stats=(feat_cols, mean, std))

# ---- loaders (keep your custom collate) ----
# Training shuffles and drops the ragged last batch; validation keeps every
# window in a fixed order.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  drop_last=True, collate_fn=collate_keep_meta)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, drop_last=False, collate_fn=collate_keep_meta)

print("#train windows:", len(train_ds), "| #val windows:", len(val_ds))

#train windows: 5620 | #val windows: 100


In [3]:
import math, torch.nn as nn, torch

def sinusoidal_positional_encoding(n_pos: int, d_model: int, device=None):
    """Build a fixed sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        n_pos: number of positions (rows).
        d_model: encoding width (columns); may be odd.
        device: torch device on which to create the table.

    Returns:
        Float tensor of shape ``[n_pos, d_model]``.
    """
    pe = torch.zeros(n_pos, d_model, device=device)
    pos = torch.arange(0, n_pos, device=device).unsqueeze(1).float()
    div = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model))
    angles = pos * div                                   # [N, ceil(d/2)]
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so
    # trim the cosine half; without this the assignment fails on shape.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe  # [N, d]

class PatchPosEncoder(nn.Module):
    """Cut a sequence into overlapping patches, project them, add positions.

    A ``[B, L, F]`` input is sliced into windows of ``patch_len`` rows taken
    every ``stride`` rows. Each window is flattened to ``patch_len * F``
    values, linearly projected to ``d_model`` and offset by a sinusoidal
    positional encoding indexed by patch number.
    """

    def __init__(self, in_features, patch_len=16, stride=8, d_model=128):
        super().__init__()
        self.P, self.S, self.F, self.d = patch_len, stride, in_features, d_model
        self.proj = nn.Linear(self.P * self.F, self.d)

    def forward(self, x):           # x: [B, L, F]
        """Return patch tokens of shape ``[B, N, d_model]``.

        ``N = (L - patch_len) // stride + 1``.

        Raises:
            ValueError: if the feature dimension differs from ``in_features``
                or the sequence is shorter than one patch.
        """
        B, L, F = x.shape
        if F != self.F:
            raise ValueError(f"expected {self.F} features per step, got {F}")
        if L < self.P:
            # Otherwise the patch list is empty and torch.stack fails with an
            # error that does not mention the real cause.
            raise ValueError(f"sequence length {L} is shorter than patch_len {self.P}")

        starts = range(0, L - self.P + 1, self.S)
        patches = [x[:, s:s+self.P, :].reshape(B, self.P*self.F) for s in starts]
        T = torch.stack(patches, dim=1)              # [B, N, P*F]
        T = self.proj(T)                             # [B, N, d]
        pe = sinusoidal_positional_encoding(T.size(1), self.d, device=x.device)
        return T + pe                                # [B, N, d] (pe broadcasts over B)

class SimplePatcherHead(nn.Module):
    """Baseline forecaster: patch encoder -> last token -> linear head.

    ``L`` is accepted so the constructor matches the other models; it is not
    used here.
    """

    def __init__(self, in_features, L, H, patch_len=16, stride=8, d_model=128):
        super().__init__()
        self.enc = PatchPosEncoder(
            in_features=in_features,
            patch_len=patch_len,
            stride=stride,
            d_model=d_model,
        )
        self.head = nn.Linear(in_features=d_model, out_features=H)

    def forward(self, x_hist):      # [B, L, F]
        # Only the final patch token feeds the head (last-token pooling).
        last_token = self.enc(x_hist)[:, -1, :]     # [B, d]
        return self.head(last_token)                # [B, H]

In [4]:
# Shape smoke test for the classical baseline: one training batch through an
# untrained SimplePatcherHead, using the L / H / feat_cols set in the split cell.
F = len(feat_cols)
model = SimplePatcherHead(in_features=F, L=L, H=H, patch_len=16, stride=8, d_model=128)
batch = next(iter(train_loader))
x, y = batch["x_hist"], batch["y_future"]
y_hat = model(x)
print("y_hat:", y_hat.shape)  # [B, H]

y_hat: torch.Size([32, 72])


# TESTER

In [None]:
import math, torch, torch.nn as nn
import pennylane as qml

class QuantumEmbedOnly(nn.Module):
    """
    Data-only quantum embedding with no trainable circuit weights.

    A linear layer maps the input vector to one angle per qubit; each angle
    drives an RY rotation on its own qubit and the <Z> expectation of every
    qubit is returned.
    """
    def __init__(self, d_model: int, n_qubits: int = 8):
        super().__init__()
        self.n_qubits = n_qubits
        self.angle_proj = nn.Linear(d_model, n_qubits)

        self.dev = qml.device("default.qubit", wires=n_qubits)

        def _ry_readout(angles_1d):
            for wire in range(n_qubits):
                qml.RY(angles_1d[wire], wires=wire)
            # one Z expectation value per wire
            return [qml.expval(qml.PauliZ(wire)) for wire in range(n_qubits)]

        self.circuit = qml.QNode(
            _ry_readout, self.dev, interface="torch", diff_method="parameter-shift"
        )

    def forward(self, pooled):  # pooled: [B, d_model], float32
        # tanh bounds the projection to (-1, 1), so angles lie in (-pi, pi)
        angles = torch.tanh(self.angle_proj(pooled)) * math.pi      # [B, n_qubits]

        per_sample = []
        for sample_angles in angles:                 # the circuit runs one sample at a time
            measured = self.circuit(sample_angles.double())
            if isinstance(measured, (list, tuple)):
                # Some QNode versions return a sequence of scalars; merge
                # them into a single 1-D tensor.
                parts = []
                for value in measured:
                    if not isinstance(value, torch.Tensor):
                        value = torch.as_tensor(value, dtype=torch.float64)
                    parts.append(value)
                measured = torch.stack(parts, dim=0)
            per_sample.append(measured)

        return torch.stack(per_sample, dim=0).float()               # [B, n_qubits], float32

class PatchToQuantum72(nn.Module):
    """
    Patch+PosEnc -> QuantumEmbedOnly on every token -> mean over tokens
    -> Linear -> H outputs.

    ``L`` is accepted for interface parity with the other models and is not
    used.
    """
    def __init__(self, in_features: int, L: int, H: int,
                 patch_len=16, stride=8, d_model=128, n_qubits=8):
        super().__init__()
        self.enc   = PatchPosEncoder(in_features, patch_len, stride, d_model)  # encoder
        self.qemb  = QuantumEmbedOnly(d_model=d_model, n_qubits=n_qubits)      # quantum embedding only
        self.head  = nn.Linear(n_qubits, H)

    def forward(self, x_hist):            # x_hist: [B, L, F]
        tokens = self.enc(x_hist)         # [B, N, d_model]
        B, N, d = tokens.shape

        # The embedding treats every row independently, so all B*N tokens go
        # through in one call instead of one Python-level call per token.
        q_flat = self.qemb(tokens.reshape(B * N, d))      # [B*N, n_qubits]
        q_stack = q_flat.reshape(B, N, -1)                # [B, N, n_qubits]

        qfeat = q_stack.mean(dim=1)       # [B, n_qubits]  mean-pool over tokens
        return self.head(qfeat)           # [B, H]

In [6]:
# Smoke test for the quantum-embedding model: one forward pass on a training
# batch, then a single backward/optimizer step to confirm gradients flow.
F = len(feat_cols)
model = PatchToQuantum72(in_features=F, L=L, H=H, patch_len=16, stride=8,
                         d_model=128, n_qubits=8)

batch = next(iter(train_loader))
x, y = batch["x_hist"], batch["y_future"]         # x:[B,L,F], y:[B,H]
y_hat = model(x)                                  # -> [B,H]
print("y_hat shape:", y_hat.shape)

# quick gradient check: one Adam step on the MSE of this single batch
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = nn.MSELoss()(y_hat, y)
opt.zero_grad(); loss.backward(); opt.step()
print("loss:", float(loss))   # loss computed before the step was applied

y_hat shape: torch.Size([32, 72])
loss: 2478.3671875


# QLSTM

In [5]:
# === QLSTM cell + wrapper that uses your PatchPosEncoder ===
import pennylane as qml
import torch
import torch.nn as nn

class zzfeatuermapQLSTM(nn.Module):
    """
    LSTM layer whose four gates are variational quantum circuits.

    Per time step:
      concat([h_t, x_t]) -> Linear -> n_qubits values
      -> (per gate) IQPEmbedding + BasicEntanglerLayers -> Z expectation values
      -> shared Linear -> hidden_size
      -> the usual LSTM cell/hidden-state update
    """
    def __init__(self, input_size, hidden_size, n_qubits=4, n_qlayers=1, backend="default.qubit"):
        super().__init__()
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.concat_size = input_size + hidden_size
        self.n_qubits    = n_qubits
        self.n_qlayers   = n_qlayers

        # One simulator device per gate
        self.dev_forget = qml.device(backend, wires=n_qubits)
        self.dev_input  = qml.device(backend, wires=n_qubits)
        self.dev_update = qml.device(backend, wires=n_qubits)
        self.dev_output = qml.device(backend, wires=n_qubits)

        wires = range(n_qubits)
        weight_shapes = {"weights": (n_qlayers, n_qubits)}

        def _gate_layer(gate_device):
            """Wrap the shared gate circuit, bound to one device, as a TorchLayer."""
            # TorchLayer looks for the data argument by the name `inputs`.
            def _circuit(inputs, weights):
                qml.templates.IQPEmbedding(inputs, wires=wires)
                qml.templates.BasicEntanglerLayers(weights, wires=wires)
                return [qml.expval(qml.PauliZ(w)) for w in wires]

            return qml.qnn.TorchLayer(
                qml.QNode(_circuit, gate_device, interface="torch"),
                weight_shapes,
            )

        # Same topology for every gate, independent trainable weights
        self.qlayer_forget = _gate_layer(self.dev_forget)
        self.qlayer_input  = _gate_layer(self.dev_input)
        self.qlayer_update = _gate_layer(self.dev_update)
        self.qlayer_output = _gate_layer(self.dev_output)

        # Classical pre/post projections (clayer_out is shared by all gates)
        self.clayer_in  = nn.Linear(self.concat_size, n_qubits)   # [h_t, x_t] -> n_qubits
        self.clayer_out = nn.Linear(n_qubits, hidden_size)        # Z-expvals -> hidden_size

    def forward(self, x, init_states=None):
        """
        x: [B, N, input_size]  (e.g. Patch+PosEnc tokens)
        init_states: optional (h_0, c_0); zeros when omitted
        returns: hidden_seq [B, N, hidden], (h_T, c_T)
        """
        batch, n_steps, _ = x.shape
        if init_states is not None:
            h_t, c_t = init_states
        else:
            h_t = torch.zeros(batch, self.hidden_size, device=x.device, dtype=x.dtype)
            c_t = torch.zeros(batch, self.hidden_size, device=x.device, dtype=x.dtype)

        outputs = []
        for step in range(n_steps):
            x_t = x[:, step, :]                                  # [B, input_size]
            joint = torch.cat([h_t, x_t], dim=1)                 # [B, hidden+input]
            q_in = self.clayer_in(joint)                         # [B, n_qubits]

            forget_gate = torch.sigmoid(self.clayer_out(self.qlayer_forget(q_in)))  # [B, hidden]
            input_gate  = torch.sigmoid(self.clayer_out(self.qlayer_input(q_in)))   # [B, hidden]
            candidate   = torch.tanh(self.clayer_out(self.qlayer_update(q_in)))     # [B, hidden]
            output_gate = torch.sigmoid(self.clayer_out(self.qlayer_output(q_in)))  # [B, hidden]

            c_t = forget_gate * c_t + input_gate * candidate
            h_t = output_gate * torch.tanh(c_t)
            outputs.append(h_t)

        hidden_seq = torch.stack(outputs, dim=1)                 # [B, N, hidden]
        return hidden_seq, (h_t, c_t)


class PatchToQLSTM72(nn.Module):
    """
    PatchPosEncoder -> QLSTM over the patch tokens -> Linear on the last
    hidden state -> H outputs (e.g. 72).

    ``L`` is accepted for interface parity and is not used.
    """
    def __init__(self, in_features, L, H, patch_len=16, stride=8,
                 d_model=128, hidden_size=128, n_qubits=4, n_qlayers=1, backend="default.qubit"):
        super().__init__()
        self.enc = PatchPosEncoder(in_features, patch_len, stride, d_model)
        self.qlstm = zzfeatuermapQLSTM(
            input_size=d_model,
            hidden_size=hidden_size,
            n_qubits=n_qubits,
            n_qlayers=n_qlayers,
            backend=backend,
        )
        self.head = nn.Linear(hidden_size, H)

    def forward(self, x_hist):               # x_hist: [B, L, F]
        tokens = self.enc(x_hist)            # [B, N, d_model]
        hidden_seq, _final_states = self.qlstm(tokens)   # [B, N, hidden]
        last_hidden = hidden_seq[:, -1, :]   # [B, hidden]
        return self.head(last_hidden)        # [B, H]

In [6]:
# --- Build the QLSTM model (must run before the training loop) ---
# The training cell refers to the model by the global name `qlstm_model`.
F = len(feat_cols)  # number of features from the dataframe pipeline

# If you have lightning backends installed you can swap "default.qubit" -> "lightning.qubit" (CPU) or "lightning.gpu"
backend = "default.qubit"

qlstm_model = PatchToQLSTM72(
    in_features=F, L=L, H=H,
    patch_len=16, stride=8,
    d_model=128, hidden_size=128,
    n_qubits=4, n_qlayers=1,
    backend=backend
)

# (optional) quick smoke test so we fail early if shapes don't match;
# only the first two samples of one batch are used to keep it cheap
with torch.no_grad():
    _ = qlstm_model(next(iter(train_loader))["x_hist"][:2])
print("QLSTM model built and forward pass OK.")

QLSTM model built and forward pass OK.


# QGRU

In [5]:
# === QGRU cell + wrapper that uses your PatchPosEncoder (drop-in alongside QLSTM) ===
import pennylane as qml
import torch
import torch.nn as nn

class zzfeatuermapQGRU(nn.Module):
    """
    GRU layer whose three gates (reset / update / candidate) are variational
    quantum circuits.

      - reset r and update z: concat([h, x]) -> Linear -> n_qubits -> QNode -> Linear -> hidden_size
      - candidate n:          concat([r * h, x]) -> Linear -> n_qubits -> QNode -> Linear -> hidden_size
      - state update:         h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Every QNode is IQPEmbedding followed by BasicEntanglerLayers, read out as
    per-wire Z expectation values.
    """
    def __init__(self, input_size, hidden_size, n_qubits=4, n_qlayers=1, backend="default.qubit"):
        super().__init__()
        self.n_inputs = input_size
        self.hidden_size = hidden_size
        self.concat_size = self.n_inputs + self.hidden_size
        self.n_qubits = n_qubits
        self.n_qlayers = n_qlayers
        self.backend = backend

        # Each gate gets its own uniquely named wires and its own device
        self.wires_reset  = [f"wire_reset_{i}"  for i in range(self.n_qubits)]
        self.wires_update = [f"wire_update_{i}" for i in range(self.n_qubits)]
        self.wires_new    = [f"wire_new_{i}"    for i in range(self.n_qubits)]

        self.dev_reset  = qml.device(self.backend, wires=self.wires_reset)
        self.dev_update = qml.device(self.backend, wires=self.wires_update)
        self.dev_new    = qml.device(self.backend, wires=self.wires_new)

        weight_shapes = {"weights": (n_qlayers, n_qubits)}

        def _gate_layer(gate_device, gate_wires):
            """Wrap the shared gate circuit on the given device/wires as a TorchLayer."""
            # TorchLayer looks for the data argument by the name `inputs`.
            def _circuit(inputs, weights):
                qml.templates.IQPEmbedding(inputs, wires=gate_wires)
                qml.templates.BasicEntanglerLayers(weights, wires=gate_wires)
                return [qml.expval(qml.PauliZ(w)) for w in gate_wires]

            return qml.qnn.TorchLayer(
                qml.QNode(_circuit, gate_device, interface="torch"), weight_shapes
            )

        self.qlayer_reset  = _gate_layer(self.dev_reset,  self.wires_reset)
        self.qlayer_update = _gate_layer(self.dev_update, self.wires_update)
        self.qlayer_new    = _gate_layer(self.dev_new,    self.wires_new)

        # classical pre/post projections (clayer_out is shared by all gates)
        self.clayer_in_gates = nn.Linear(self.concat_size, n_qubits)  # for r, z
        self.clayer_in_cand  = nn.Linear(self.concat_size, n_qubits)  # for candidate n with (r * h)
        self.clayer_out      = nn.Linear(self.n_qubits, self.hidden_size)

    def forward(self, x, init_state=None):
        """
        x: [B, N, input_size]  tokens from Patch+PosEnc
        init_state: optional h_0 of shape [B, hidden]; zeros when omitted
        returns: hidden_seq [B, N, hidden], h_T
        """
        batch, n_steps, _ = x.shape

        if init_state is not None:
            h_t = init_state
        else:
            h_t = torch.zeros(batch, self.hidden_size, device=x.device, dtype=x.dtype)

        states = []
        for step in range(n_steps):
            x_t = x[:, step, :]                                        # [B, input_size]

            gate_in = self.clayer_in_gates(torch.cat((h_t, x_t), dim=1))       # [B, n_qubits]
            reset_gate  = torch.sigmoid(self.clayer_out(self.qlayer_reset(gate_in)))   # [B, hidden]
            update_gate = torch.sigmoid(self.clayer_out(self.qlayer_update(gate_in)))  # [B, hidden]

            # the candidate sees the previous state only through the reset gate
            cand_in = self.clayer_in_cand(torch.cat((reset_gate * h_t, x_t), dim=1))   # [B, n_qubits]
            candidate = torch.tanh(self.clayer_out(self.qlayer_new(cand_in)))          # [B, hidden]

            h_t = (1.0 - update_gate) * candidate + update_gate * h_t
            states.append(h_t)

        hidden_seq = torch.stack(states, dim=1)                        # [B, N, hidden]
        return hidden_seq, h_t


class PatchToQGRU72(nn.Module):
    """
    PatchPosEncoder -> QGRU over the patch tokens -> Linear on the last hidden
    state -> H outputs (e.g. 72).

    Constructor signature mirrors PatchToQLSTM72 so the two can be swapped;
    ``L`` is accepted for that parity and is not used.
    """
    def __init__(self, in_features, L, H, patch_len=16, stride=8,
                 d_model=128, hidden_size=128, n_qubits=4, n_qlayers=1, backend="default.qubit"):
        super().__init__()
        self.enc = PatchPosEncoder(in_features, patch_len, stride, d_model)
        self.qgru = zzfeatuermapQGRU(
            input_size=d_model,
            hidden_size=hidden_size,
            n_qubits=n_qubits,
            n_qlayers=n_qlayers,
            backend=backend,
        )
        self.head = nn.Linear(hidden_size, H)

    def forward(self, x_hist):                # x_hist: [B, L, F]
        tokens = self.enc(x_hist)             # [B, N, d_model]
        hidden_seq, _final_state = self.qgru(tokens)   # [B, N, hidden], [B, hidden]
        last_hidden = hidden_seq[:, -1, :]    # [B, hidden]
        return self.head(last_hidden)         # [B, H]

In [6]:
# Build the QGRU variant from the values defined in earlier cells
# (feat_cols, L, H, train_loader); `backend` is (re)set just below.
F = len(feat_cols)  # number of features from the dataframe pipeline

# If you have lightning backends installed you can swap "default.qubit" -> "lightning.qubit" (CPU) or "lightning.gpu"
backend = "default.qubit"

qgru_model = PatchToQGRU72(
    in_features=F, L=L, H=H,
    patch_len=16, stride=8,
    d_model=128, hidden_size=128,
    n_qubits=4, n_qlayers=1,
    backend=backend
)

# quick smoke test on the first two samples of one batch (no gradients)
with torch.no_grad():
    _ = qgru_model(next(iter(train_loader))["x_hist"][:2])
print("QGRU model built and forward pass OK.")

# The training cell refers to the model as `qlstm_model`; rebinding that name
# trains the QGRU without editing the loop.
# NOTE: after this line the QLSTM instance is no longer reachable by that
# name, and the loop's "QLSTM" plot title will be describing a QGRU.
qlstm_model = qgru_model

QGRU model built and forward pass OK.


# Training loop

In [7]:
# === Train QLSTM for multiple epochs and plot the loss ===
import time
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# --- config (tweak as you like) ---
EPOCHS = 30
LR = 1e-3
WEIGHT_DECAY = 0.0
GRAD_CLIP = 1.0  # max gradient norm; set to None to disable clipping (helps stabilize PQC training)
PRINT_EVERY = 1  # print a progress line every this many epochs

# Pennylane's default.qubit + TorchLayer are CPU-friendly; keeping everything on CPU avoids device mismatches
device = torch.device("cpu")
qlstm_model.to(device)

# These module-level objects are read directly by run_epoch and the loop below.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(qlstm_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# optional cosine decay over the whole run (stepped once per epoch in the loop)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

def run_epoch(model, loader, train: bool):
    """Run one pass over *loader* and return ``(mean_loss, seconds)``.

    Uses the module-level ``device``, ``criterion``, ``optimizer`` and
    ``GRAD_CLIP``. With ``train=True`` the model is put in training mode and
    updated after every batch; otherwise it is put in eval mode and gradients
    are disabled.

    The mean is weighted by batch size, so a smaller final batch does not
    skew it. An empty loader returns a loss of 0.0.
    """
    model.train(train)
    loss_sum, n_samples = 0.0, 0
    t_begin = time.time()

    for batch in loader:
        inputs = batch["x_hist"].to(device)       # [B, L, F]
        targets = batch["y_future"].to(device)    # [B, H]

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            predictions = model(inputs)           # [B, H]
            loss = criterion(predictions, targets)
            if train:
                loss.backward()
                if GRAD_CLIP is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

        batch_size = inputs.size(0)
        loss_sum += float(loss) * batch_size
        n_samples += batch_size

    elapsed = time.time() - t_begin
    return loss_sum / max(n_samples, 1), elapsed

# Per-epoch history used by the plots at the end of the cell.
train_losses, val_losses = [], []
lrs = []

print(f"Training on {device} for {EPOCHS} epochs...")
best_val = math.inf
best_state = None   # copy of the state_dict at the lowest validation loss seen

for epoch in range(1, EPOCHS + 1):
    tr_loss, tr_time = run_epoch(qlstm_model, train_loader, train=True)
    va_loss, va_time = run_epoch(qlstm_model, val_loader,   train=False)
    train_losses.append(tr_loss)
    val_losses.append(va_loss)
    # recorded before scheduler.step(), i.e. the LR actually used this epoch
    lrs.append(optimizer.param_groups[0]["lr"])

    if scheduler is not None:
        scheduler.step()

    if va_loss < best_val:
        best_val = va_loss
        # clone so later optimizer steps do not overwrite the snapshot
        best_state = {k: v.cpu().clone() for k, v in qlstm_model.state_dict().items()}

    if epoch % PRINT_EVERY == 0:
        print(f"Epoch {epoch:03d} | "
              f"train {tr_loss:.4f} ({tr_time:.1f}s) | "
              f"val {va_loss:.4f} ({va_time:.1f}s) | "
              f"lr {lrs[-1]:.2e}")

# (optional) load best: restore the weights from the best validation epoch
if best_state is not None:
    qlstm_model.load_state_dict(best_state)
    print(f"\nLoaded best model (val MSE = {best_val:.4f}).")

# --- Plot ---
# NOTE(review): the title is hard-coded to "QLSTM" even when `qlstm_model`
# has been rebound to the QGRU model.
plt.figure(figsize=(7,4.5))
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.xlabel("epoch")
plt.ylabel("MSE loss")
plt.title("QLSTM: training and validation loss")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# (optional) plot learning rate
plt.figure(figsize=(7,3))
plt.plot(lrs)
plt.xlabel("epoch")
plt.ylabel("learning rate")
plt.title("Learning rate schedule")
plt.grid(True)
plt.tight_layout()
plt.show()

Training on cpu for 30 epochs...
Epoch 001 | train 2386.8321 (143.7s) | val 1973.1548 (2.0s) | lr 1.00e-03
Epoch 002 | train 1754.0602 (143.0s) | val 1386.1457 (2.1s) | lr 9.97e-04
Epoch 003 | train 1154.2798 (144.2s) | val 951.5430 (2.1s) | lr 9.89e-04



KeyboardInterrupt

