[ENHANCEMENT] Multi-Scale encoder for robust Zero-Shot forecasting
Summary
This proposal outlines a Multi-Scale Encoder that processes time-series data at multiple resolutions (e.g., daily vs. weekly) in parallel. Each branch learns scale-specific patterns before a fusion step aggregates their outputs. The goal of this design is to improve zero-shot performance by enabling the model to generalize more effectively to new time-series domains with potentially unknown seasonalities.
Proposal
- **Parallel Branches for Different Timescales** (a minimal patching sketch follows this item)
  - **Short-Term Branch**
    - Use a small patch size (e.g., 24 steps if your data is hourly, capturing a day's pattern) with a stride that matches or slightly overlaps these patches.
    - Ideal for high-frequency variations or local periodicities (daily cycles, intraday fluctuations, etc.).
  - **Medium-Term Branch**
    - Larger patch size (e.g., 168 steps for a weekly cycle, if data is hourly).
    - Aims to capture medium-range seasonality (weekly patterns, multi-day trends).
  - **Optional Additional Branches**
    - If the dataset exhibits strong monthly or yearly periodicities, add further branches. For instance, a monthly branch (around 720 hours) or an annual branch (8760 hours).
    - Each additional branch focuses on its distinct resolution, allowing the model to capture these patterns separately.
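As a concrete illustration, here is a minimal PyTorch sketch of how the same hourly series could be sliced at different resolutions. The `patchify` helper, the branch names, and the patch/stride values are illustrative assumptions, not existing repository code:

```python
import torch


def patchify(x: torch.Tensor, patch_size: int, stride: int) -> torch.Tensor:
    """Split a (batch, length) series into (batch, num_patches, patch_size) windows."""
    return x.unfold(-1, patch_size, stride)


# Example: a batch of 2 hourly series, each 5 weeks (5 * 168 hours) long.
series = torch.randn(2, 5 * 168)

# Hypothetical per-branch configurations; the values are placeholders to tune.
branch_configs = {
    "short_term": {"patch_size": 24, "stride": 24},     # daily patches
    "medium_term": {"patch_size": 168, "stride": 168},  # weekly patches
    # "monthly": {"patch_size": 720, "stride": 720},    # optional extra branch
}

for name, cfg in branch_configs.items():
    patches = patchify(series, cfg["patch_size"], cfg["stride"])
    print(name, tuple(patches.shape))  # short_term -> (2, 35, 24), medium_term -> (2, 5, 168)
```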
- **Branch-Specific Transformer Encoders** (see the encoder sketch after this item)
  - **Individual Patch & Norm Layers**
    - Each branch has its own patching layer (similar to the existing `Patch` class) but configured with its own patch size.
    - Each branch may also have its own normalization (e.g., `InstanceNorm`), so that each scale handles outliers or mean shifts independently.
  - **Specialized Transformer Blocks**
    - In each branch, the patched input is passed through a stack of Transformer encoder layers (or a reduced set of layers, if parameter count is a concern).
    - Optionally, share weights among branches if memory is limited (the same weights are reused, but each branch processes a differently scaled input).
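A minimal sketch of one branch-specific encoder built from standard `torch.nn` blocks. The in-line patching and instance normalization below are simplified stand-ins for the existing `Patch` / `InstanceNorm` modules, and `BranchEncoder`, its hyperparameters, and the shapes are assumptions for illustration only:

```python
import torch
import torch.nn as nn


class BranchEncoder(nn.Module):
    """One scale-specific branch: instance-norm -> patch -> embed -> Transformer encoder."""

    def __init__(self, patch_size: int, stride: int, d_model: int = 128,
                 num_layers: int = 2, nhead: int = 4):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.embed = nn.Linear(patch_size, d_model)  # one token per patch
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, length); per-series normalization handles mean shifts independently.
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-5
        x = (x - mean) / std
        # (batch, num_patches, patch_size) -> (batch, num_patches, d_model)
        tokens = self.embed(x.unfold(-1, self.patch_size, self.stride))
        return self.encoder(tokens)


short = BranchEncoder(patch_size=24, stride=24)     # daily-scale branch
medium = BranchEncoder(patch_size=168, stride=168)  # weekly-scale branch
x = torch.randn(2, 5 * 168)
print(short(x).shape, medium(x).shape)  # (2, 35, 128) (2, 5, 128)
```

Weight sharing would simply amount to instantiating a single encoder stack and calling it on each branch's patched input.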
- **Fusion Layer** (a fusion sketch follows this item)
  - **Concatenation or Cross-Attention**
    - Concatenate the hidden states from all branches along the feature dimension, then feed them into a small MLP or another Transformer block to produce a fused representation.
    - Alternatively, implement a cross-attention mechanism so that each branch can attend to the outputs of the others, providing more granular inter-scale communication.
  - **Dimensionality Reduction**
    - If each branch outputs a long sequence, apply pooling (mean/max/attention pooling) to collapse it to a smaller dimension before concatenation. This helps control memory usage.
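A minimal sketch of the concatenation-style fusion with pooling for dimensionality reduction. Mean pooling, the `FusionLayer` name, and the dimensions are illustrative choices; cross-attention across branch outputs would be a drop-in alternative:

```python
import torch
import torch.nn as nn


class FusionLayer(nn.Module):
    """Pool each branch over its patch dimension, concatenate, and mix with a small MLP."""

    def __init__(self, d_model: int, num_branches: int, d_fused: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_branches * d_model, d_fused),
            nn.ReLU(),
            nn.Linear(d_fused, d_fused),
        )

    def forward(self, branch_outputs: list[torch.Tensor]) -> torch.Tensor:
        # Each element: (batch, num_patches_i, d_model), with a branch-specific patch count.
        pooled = [h.mean(dim=1) for h in branch_outputs]  # (batch, d_model) per branch
        fused = torch.cat(pooled, dim=-1)                 # (batch, num_branches * d_model)
        return self.mlp(fused)                            # (batch, d_fused)


fusion = FusionLayer(d_model=128, num_branches=2)
h_short, h_medium = torch.randn(2, 35, 128), torch.randn(2, 5, 128)
print(fusion([h_short, h_medium]).shape)  # (2, 256)
```

Pooling before concatenation keeps the fused width independent of each branch's sequence length, which is what keeps memory usage under control.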
- **Simplified Decoder** (a decoder-head sketch follows this item)
  - **Minimal Decoder Stack**
    - Replace the traditional T5 decoder (or any heavy seq2seq decoder) with a small set of layers, such as a single Transformer decoder block or a simple MLP that predicts the next values or the quantiles.
    - Since the encoder branches already capture multi-scale patterns, the decoder's role is primarily to integrate these signals into final predictions.
  - **Faster Inference**
    - A lean decoder improves inference speed and makes the model more practical for real-time or edge scenarios.
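A minimal sketch of a lightweight quantile head standing in for a full seq2seq decoder. The fused dimension, forecast horizon, and quantile levels are illustrative assumptions:

```python
import torch
import torch.nn as nn


class QuantileHead(nn.Module):
    """Single linear projection from the fused representation to per-step quantile forecasts."""

    def __init__(self, d_fused: int = 256, horizon: int = 24,
                 quantiles: tuple[float, ...] = (0.1, 0.5, 0.9)):
        super().__init__()
        self.horizon = horizon
        self.quantiles = quantiles
        self.proj = nn.Linear(d_fused, horizon * len(quantiles))

    def forward(self, fused: torch.Tensor) -> torch.Tensor:
        # fused: (batch, d_fused) -> (batch, horizon, num_quantiles)
        return self.proj(fused).view(-1, self.horizon, len(self.quantiles))


head = QuantileHead()
print(head(torch.randn(2, 256)).shape)  # (2, 24, 3)
```

Trained with a pinball (quantile) loss, such a head avoids autoregressive decoding, which is where most of the inference-speed gain would come from.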
Rationale
- **Broader Pattern Recognition**
  - By segmenting the input into multiple scales, the model separately learns day-level, week-level, month-level (etc.) patterns. This reduces confusion and improves adaptability when facing new time series with different or mixed seasonalities.
- **Reduced Overfitting**
  - Each branch focuses on its designated resolution, reducing the tendency of a single model to overfit to one dominant pattern in the training data.
- **Improved Generalization**
  - The fused representations from multiple scales offer a more holistic view of time-series behavior, aiding zero-shot scenarios where no specific fine-tuning on the new domain is performed.
- **Interpretability**
  - It becomes clearer which "branch" contributes most to certain predictions (e.g., if weekly effects dominate, the medium-term branch may carry stronger weight).
Additional Notes
- **Implementation Feasibility**
  - Each branch can replicate the same `Patch` and `InstanceNorm` modules with different configurations. Memory usage may grow with each branch, so weight-sharing or using fewer layers can mitigate this.
  - Hyperparameters like patch size, stride, and number of branches can be tuned based on domain knowledge or validated on a hold-out set.
- **Potential Attention Optimizations**
  - If the time series are lengthy, applying more efficient attention mechanisms (e.g., Performer, Linformer, Nystromformer) can reduce the `O(n^2)` cost in each branch.
By incorporating these adjustments, the model gains the ability to learn and fuse scale-specific patterns effectively, leading to more robust zero-shot forecasting performance.