Dynamic Spatial Modeling (Time-Varying Graph)

In [None]:
def build_multiscale_inputs(x):
    """Slice the tail of a sequence into three temporal resolutions.

    Args:
        x: Array/tensor indexed as ``[B, T, N, F]``; the time axis (dim 1)
            is expected to hold at least 48 steps. With fewer steps the
            slices silently cover whatever history exists instead of failing.

    Returns:
        Tuple ``(short, mid, long)`` of slices along the time axis:

        * ``short`` - the last 12 steps at full resolution (the "recent
          1 hour" scale).
        * ``mid``   - every 2nd step of the last 24 (the "10 min" scale).
        * ``long``  - every 4th step of the last 48 (the "20 min" scale).

        Because the strided windows start at ``T-24`` / ``T-48``, their final
        samples are steps ``T-2`` and ``T-4`` respectively, not ``T-1``.
    """
    # (window length in steps, stride) per scale: short, mid, long.
    scales = ((12, 1), (24, 2), (48, 4))
    short, mid, long = (x[:, -window::stride] for window, stride in scales)
    return short, mid, long


Multi-Scale Temporal Encoder

Each scale gets its own temporal model (Mamba / Transformer / GRU).

In [None]:
class TemporalEncoder(nn.Module):
    """Encode every node's time series independently and keep the last step.

    Args:
        dim: Model width handed to ``Mamba(d_model=dim)``. It must equal the
            size of the feature (last) axis of the tensors given to
            :meth:`forward`.
    """

    def __init__(self, dim):
        super().__init__()
        self.mamba = Mamba(d_model=dim)

    def forward(self, x):
        """Return the final-step embedding for each (batch, node) pair.

        Args:
            x: Tensor of shape ``[B, T, N, F]`` (batch, time, node, feature).
                It may be non-contiguous, e.g. a strided slice along time.

        Returns:
            Tensor of shape ``[B, N, F]`` holding the sequence model's output
            at the last time step of each node's series. Assumes
            ``self.mamba`` maps ``[B*N, T, F]`` to the same shape.
        """
        batch, steps, nodes, feat = x.shape
        # The node axis must come before time so that flattening (B, N)
        # yields one length-T sequence per node. Viewing the raw
        # [B, T, N, F] layout as [B*N, T, F] would interleave time steps and
        # nodes. reshape() (rather than view()) also copes with the
        # non-contiguous result of permute() and of strided time slices.
        sequences = x.permute(0, 2, 1, 3).reshape(batch * nodes, steps, feat)
        out = self.mamba(sequences)
        # Rows are ordered batch-major, node-minor, so splitting the leading
        # axis back into (B, N) restores the original pairing.
        return out[:, -1].reshape(batch, nodes, feat)


Temporal Scale Attention 

Let the model learn how much to trust each scale.

In [None]:
class TemporalScaleAttention(nn.Module):
    """Blend per-scale features using weights learned by a linear scorer.

    Args:
        dim: Size of the feature (last) axis of every tensor in ``feats``.
    """

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, 1)

    def forward(self, feats):
        """Fuse the scale features into a single tensor.

        Args:
            feats: List of tensors, one per temporal scale, each ``[B, N, F]``.

        Returns:
            Weighted sum of the inputs, same shape as one element of
            ``feats``. The weights are a softmax over the scale axis, computed
            separately for every (batch, node) position.
        """
        # One scalar logit per position and scale -> [S, B, N, 1].
        logits = torch.stack([self.attn(feat) for feat in feats])
        scale_weights = logits.softmax(dim=0)

        fused = 0
        for weight, feat in zip(scale_weights, feats):
            # weight is [B, N, 1] and broadcasts over the feature axis.
            fused = fused + weight * feat
        return fused


Final Temporal Block

In [None]:
# Cut the input history into the three temporal resolutions.
short, mid, long = build_multiscale_inputs(x)

# Each scale goes through its own encoder (presumably TemporalEncoder
# instances created in __init__ - confirm against the enclosing module).
f_s, f_m, f_l = (
    self.temporal_short(short),
    self.temporal_mid(mid),
    self.temporal_long(long),
)

# Fuse the per-scale features into a single representation.
temporal_out = self.scale_attention([f_s, f_m, f_l])


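Now the model can capture both sudden jams (short scale) and slower trends over the last few hours (mid and long scales). Note that the longest window is 48 steps, so genuinely daily patterns are only visible if the history length and window sizes are increased.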

# TemporalEncoder explanation:
# - Input x has shape [B, T, N, F]
#   B = batch size
#   T = number of time steps
#   N = number of nodes (e.g., sensors)
#   F = feature dimension per node
#
# - `dim` passed to Mamba(d_model=dim) MUST equal F.
#   The feature dimension is fixed at initialization and
#   cannot change dynamically during forward().
#
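# - Each node's time series must become an independent temporal
#   sequence, i.e. a tensor of shape [B*N, T, F]. Because x is laid
#   out as [B, T, N, F], the node axis has to be moved ahead of the
#   time axis (permute to [B, N, T, F]) before flattening; a plain
#   view(B*N, T, F) on the original layout would interleave time
#   steps and nodes.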
#
# - Mamba is a state-space model (NOT a Transformer):
#   * No attention heads
#   * No head splitting
#   * No dimension permutation or reshuffling
#
# - Mamba preserves the input layout and returns output
#   with the same shape [B*N, T, F].
#
# - `out[:, -1]` selects the final time step embedding
#   for each (batch, node) sequence.
#
# - Reshaping back to [B, N, F] is safe because no
#   permutation was performed between view() operations.
#
# - This module performs per-node temporal encoding only;
#   spatial relationships must be modeled separately.
