In [None]:
# =====================================================
# GADHT - Centralized Setup, Imports, and Environment
# =====================================================

# ----------------------------
# Standard Libraries
# ----------------------------
import os
import sys
import random
import logging
import warnings
from datetime import datetime
from typing import List, Tuple, Dict, Any, Optional, Union

# ----------------------------
# Numerical Computation & Data
# ----------------------------
import numpy as np
import pandas as pd
import yfinance as yf

# ----------------------------
# Visualization
# ----------------------------
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------------------
# Machine Learning / Deep Learning
# ----------------------------
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset

# ----------------------------
# Time Series Decomposition
# ----------------------------
from PyEMD import CEEMDAN

# ----------------------------
# Preprocessing & Evaluation
# ----------------------------
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold, TimeSeriesSplit
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# ----------------------------
# Explainability
# ----------------------------
import shap

# ----------------------------
# Utilities
# ----------------------------
from tqdm import tqdm

# =====================================================
# Global Configuration
# =====================================================

# Warnings: silence every warning category for the rest of the session
warnings.filterwarnings(action="ignore")

# Logging Configuration: timestamped, level-tagged records at INFO and above
logging.basicConfig(
    format="[%(asctime)s] [%(levelname)s] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device Setup (CPU / GPU): use CUDA whenever PyTorch reports it as available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"🖥️ Using device: {device}")

# Global Random Seed
def set_global_seed(seed: int = 42) -> None:
    """
    Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed (int): Seed applied to every generator (default=42).

    When CUDA is available the seed is also applied to all CUDA devices.
    """
    # Same order as before: random -> numpy -> torch (CPU)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    logger.info(f"🔒 Global seed set to {seed}")

set_global_seed(42)


# Version Information
def _installed_version(package_name: str) -> str:
    """
    Return ``__version__`` of an already-imported top-level package.

    scikit-learn and PyEMD are only pulled in through ``from ... import ...``
    statements, so their package objects are looked up in ``sys.modules``.
    Falls back to "unknown" when the package is not loaded or exposes no
    ``__version__`` attribute.
    """
    return str(getattr(sys.modules.get(package_name), "__version__", "unknown"))


# Previously the Sklearn and PyEMD slots logged the literal strings "sklearn"
# and "PyEMD" instead of version numbers.
logger.info(
    "📦 Versions -> Torch: %s | NumPy: %s | Pandas: %s | Matplotlib: %s | Sklearn: %s | SHAP: %s | yfinance: %s | PyEMD: %s",
    torch.__version__, np.__version__, pd.__version__, plt.matplotlib.__version__,
    _installed_version("sklearn"), shap.__version__, yf.__version__, _installed_version("PyEMD")
)

In [None]:
# ============================================================
# GADHT - Multi-Phase Financial Data Downloader (Yahoo Finance)
# ============================================================

# Ticker universe per experiment phase (insertion order: pretraining, finetuning, zeroshot)
PHASE_TICKERS = dict(
    pretraining=[
        "NVDA",  # NVIDIA Corp. (Technology)
        "JPM",   # JPMorgan Chase (Financials)
        "XOM",   # ExxonMobil (Energy)
        "LIN",   # Linde plc (Materials)
        "PLD",   # Prologis, Inc. (Real Estate)
        "AMZN",  # Amazon.com Inc. (Consumer Discretionary)
        "PFE",   # Pfizer Inc. (Health Care)
        "BA",    # Boeing Co. (Industrials)
    ],
    finetuning=[
        "TSLA",  # Tesla, Inc. (Consumer Discretionary)
        "JNJ",   # Johnson & Johnson (Health Care)
        "MSFT",  # Microsoft Corp. (Technology)
        "NKE",   # Nike, Inc. (Consumer Discretionary)
        "PG",    # Procter & Gamble (Consumer Staples)
        "UNH",   # UnitedHealth Group (Health Care)
    ],
    zeroshot=[
        "META",  # Meta Platforms (Communication Services)
        "KO",    # Coca-Cola Co. (Consumer Staples)
        "CAT",   # Caterpillar Inc. (Industrials)
        "BAC",   # Bank of America (Financials)
        "AAPL",  # Apple Inc. (Technology)
        "NEE",   # NextEra Energy (Utilities)
    ],
)

# Alphabetically sorted union of every phase's tickers (duplicates collapsed)
ALL_TICKERS = sorted(set().union(*PHASE_TICKERS.values()))

# Download window and output location
START_DATE, END_DATE = "2005-01-01", "2025-05-31"
DATA_DIRECTORY = "data/ohlcv"

# Start every run from an empty download directory:
# remove whatever exists at DATA_DIRECTORY, then recreate it.
import shutil
if os.path.exists(DATA_DIRECTORY):
    shutil.rmtree(DATA_DIRECTORY)
os.makedirs(DATA_DIRECTORY, exist_ok=True)


def download_ohlcv_data(ticker: str, start: str, end: str, save_dir: str) -> None:
    """
    Download OHLCV data for a single ticker (sequential mode) and save it as CSV.

    The file is written to ``<save_dir>/<ticker>.csv`` with a ``Date`` index
    column followed by Open, High, Low, Close, Volume. Rows containing NaN in
    any of those columns are dropped. Nothing is written when no data remains.

    Args:
        ticker (str): Symbol passed to ``yf.download``.
        start (str): Start date forwarded to ``yf.download``.
        end (str): End date forwarded to ``yf.download``.
        save_dir (str): Existing directory that receives the CSV file.

    Any exception raised while downloading or saving is logged and swallowed
    so that a sequential multi-ticker run continues with the next symbol.
    """
    ohlcv_columns = ["Open", "High", "Low", "Close", "Volume"]

    try:
        logger.info(f"📥 Downloading {ticker} ({start} → {end})")

        raw = yf.download(ticker, start=start, end=end, progress=False)

        # Check for "no data" BEFORE selecting columns: an empty frame may have
        # no OHLCV columns at all, which would otherwise surface as a KeyError
        # and be reported as a download error instead of this warning.
        if raw is None or raw.empty:
            logger.warning(f"⚠️ {ticker}: No valid OHLCV data found. Skipping.")
            return

        # Recent yfinance releases return MultiIndex columns (field, ticker)
        # even for one symbol. Written as-is, that produces extra header rows
        # in the CSV; keep only the level that carries the field names.
        if isinstance(raw.columns, pd.MultiIndex):
            field_level = next(
                (
                    level
                    for level in range(raw.columns.nlevels)
                    if "Close" in raw.columns.get_level_values(level)
                ),
                0,
            )
            raw.columns = raw.columns.get_level_values(field_level)

        df = raw[ohlcv_columns].dropna()

        if df.empty:
            logger.warning(f"⚠️ {ticker}: No valid OHLCV data found. Skipping.")
            return

        filepath = os.path.join(save_dir, f"{ticker}.csv")
        df.to_csv(filepath, index=True, index_label="Date")

        logger.info(f"✅ {ticker}: {len(df)} rows saved → {filepath}")

    except Exception as e:
        # Deliberate best-effort: one failing ticker must not abort the batch.
        logger.error(f"❌ {ticker}: Error downloading ({e})")


logger.info(f"🚀 Starting SEQUENTIAL OHLCV download for {len(ALL_TICKERS)} tickers: {ALL_TICKERS}")

for ticker in ALL_TICKERS:
    download_ohlcv_data(ticker, START_DATE, END_DATE, DATA_DIRECTORY)

# download_ohlcv_data logs its own failures and returns None either way, so
# success is judged by which CSV files actually exist after the loop instead of
# unconditionally reporting that everything was downloaded.
failed_tickers = [
    symbol
    for symbol in ALL_TICKERS
    if not os.path.isfile(os.path.join(DATA_DIRECTORY, f"{symbol}.csv"))
]

if failed_tickers:
    logger.warning(
        f"⚠️ OHLCV download finished with {len(failed_tickers)} of {len(ALL_TICKERS)} "
        f"tickers missing: {failed_tickers}"
    )
else:
    logger.info("🏁 All OHLCV asset data downloaded successfully.")

In [None]:
# =============================================================
# GADHT - Adaptive Dataset with Sliding-Window CEEMDAN Caching
# =============================================================

class FinancialIMFDataset(Dataset):
    """
    PyTorch dataset for GADHT model.

    Modes:
    - Supervised mode: predict the closing price ``prediction_horizon`` rows
      after the end of each window.
    - Pretraining mode: randomly mask IMF components (self-supervised).

    Optimization:
        Instead of recomputing CEEMDAN for every window (very slow),
        we decompose the time series in overlapping segments (sliding-window CEEMDAN).
        Where segments overlap, the segment processed later overwrites the
        values written by the earlier one.

    Args:
        filepath (str): CSV file with Open, High, Low, Close, Volume columns.
        window_size (int): Number of time steps per sample.
        max_imfs (int): Number of IMF slots kept per feature (padded with zeros
            or truncated to this size).
        use_ceemdan (bool): If False, samples carry the raw features with a
            single component axis, shape (window_size, 1, 5).
        pretraining_mode (bool): If True, ``__getitem__`` returns
            (masked, target, mask) instead of (input, target).
        normalize (bool): Z-score each column with a StandardScaler.
        prediction_horizon (int): Offset (in rows) of the Close target.
        segment_size (int): Length of each CEEMDAN segment.
        overlap (int): Rows shared by consecutive segments; must be smaller
            than ``segment_size``.

    NOTE(review): the scaler is fitted on the full series and each CEEMDAN
    segment spans rows that lie after some of the windows cut from it, so
    inputs can contain information from later rows — confirm this is
    acceptable for the evaluation protocol.
    """

    def __init__(
        self,
        filepath: str,
        window_size: int = 30,
        max_imfs: int = 5,
        use_ceemdan: bool = True,
        pretraining_mode: bool = False,
        normalize: bool = True,
        prediction_horizon: int = 1,
        segment_size: int = 90,
        overlap: int = 30
    ):
        self.window_size = window_size
        self.prediction_horizon = prediction_horizon
        self.max_imfs = max_imfs
        self.use_ceemdan = use_ceemdan
        self.pretraining_mode = pretraining_mode
        self.segment_size = segment_size
        self.overlap = overlap

        # ----------------------------
        # Load OHLCV data
        # ----------------------------
        df = pd.read_csv(filepath)
        df = df[["Open", "High", "Low", "Close", "Volume"]].dropna()
        # Non-numeric cells (e.g. stray header rows) become NaN and are dropped
        df = df.apply(pd.to_numeric, errors="coerce").dropna()

        if normalize:
            scaler = StandardScaler()
            df[df.columns] = scaler.fit_transform(df)

        self.data = df.values  # shape: (T, d)

        # ----------------------------
        # CEEMDAN decomposition
        # ----------------------------
        if self.use_ceemdan:
            self.imf_cache = self._precompute_imfs()
        else:
            self.imf_cache = None

        # ----------------------------
        # Build training samples
        # ----------------------------
        self.samples = self._build_sequences()

    @staticmethod
    def _segment_starts(total_length: int, segment_size: int, step: int) -> List[int]:
        """
        Start offsets of the CEEMDAN segments covering ``total_length`` rows.

        Starts advance by ``step``. A final segment anchored at
        ``total_length - segment_size`` is appended whenever the regular
        stride would leave trailing rows uncovered. A series not longer than
        one segment is covered by a single segment starting at 0.
        """
        if total_length <= 0:
            return []
        if total_length <= segment_size:
            return [0]

        last_start = total_length - segment_size
        starts = list(range(0, last_start + 1, step))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def _precompute_imfs(self) -> np.ndarray:
        """
        Decompose every feature column segment by segment.

        Returns:
            np.ndarray: float32 array of shape (T, max_imfs, d).

        Raises:
            ValueError: if ``segment_size`` is not positive or ``overlap`` is
                not smaller than ``segment_size`` (the stride would be <= 0).
        """
        step = self.segment_size - self.overlap
        if self.segment_size <= 0 or step <= 0:
            raise ValueError(
                f"segment_size ({self.segment_size}) must be positive and greater "
                f"than overlap ({self.overlap})"
            )

        ceemdan = CEEMDAN()
        T, d = self.data.shape
        imf_cache = np.zeros((T, self.max_imfs, d), dtype=np.float32)

        # Includes a tail-anchored segment so the most recent rows are
        # decomposed too instead of being left as all-zero IMFs.
        starts = self._segment_starts(T, self.segment_size, step)

        for feature_index in range(d):
            signal = self.data[:, feature_index]
            for start in starts:
                end = min(start + self.segment_size, T)
                imfs = ceemdan(signal[start:end])  # (n_imfs, end - start)

                if imfs.shape[0] < self.max_imfs:
                    # Fewer components than slots: fill the rest with zeros
                    padding = self.max_imfs - imfs.shape[0]
                    imfs = np.pad(imfs, ((0, padding), (0, 0)), mode="constant")
                else:
                    # More components than slots: keep the first max_imfs rows
                    imfs = imfs[:self.max_imfs]

                imf_cache[start:end, :, feature_index] = imfs.T

        return imf_cache  # (T, max_imfs, d)

    def _build_sequences(self) -> List:
        """
        Cut the series into (input, target) pairs.

        Each input covers ``window_size`` consecutive rows; the target is the
        Close column (index 3) at row ``i + window_size + prediction_horizon - 1``.
        """
        sequences = []
        total_length = len(self.data)

        for i in range(total_length - self.window_size - self.prediction_horizon + 1):
            if self.use_ceemdan:
                imf_cube = self.imf_cache[i:i+self.window_size, :, :]  # (T, M, F)
                input_tensor = imf_cube
            else:
                window = self.data[i:i+self.window_size]  # (T, F)
                input_tensor = np.expand_dims(window, axis=1)  # (T, 1, F)

            target_close = self.data[i + self.window_size + self.prediction_horizon - 1, 3]
            sequences.append((input_tensor.astype(np.float32), np.float32(target_close)))

        return sequences

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Returns:
            Pretraining mode: (masked, target, mask) with shapes
                (T, M, F), (T, M, F), (M, F); mask is 0 on masked components.
            Supervised mode: (input, target) with input (T, M, F) or (T, 1, F)
                and a scalar target.
        """
        input_tensor, target = self.samples[idx]

        if self.pretraining_mode:
            # Full IMF cube as target
            target_tensor = torch.tensor(input_tensor, dtype=torch.float32)  # (T, M, F)

            # Start from clean input
            masked_tensor = target_tensor.clone()

            # Size the mask from the sample itself: without CEEMDAN the cube
            # has a single component, not max_imfs of them.
            num_components, num_features = target_tensor.shape[1], target_tensor.shape[2]
            mask = torch.ones(num_components, num_features, dtype=torch.float32)

            # Mask 1–2 components (never more than exist). torch's RNG is used
            # because DataLoader seeds it differently in every worker process,
            # whereas NumPy's global RNG state can be duplicated across workers.
            num_masked = min(int(torch.randint(1, 3, (1,)).item()), num_components)
            masked_indices = torch.randperm(num_components)[:num_masked].tolist()

            for m in masked_indices:
                masked_tensor[:, m, :] = 0.0
                mask[m, :] = 0.0

            return (
                masked_tensor,   # (T, M, F)
                target_tensor,   # (T, M, F)
                mask             # (M, F)
            )

        else:
            return (
                torch.tensor(input_tensor, dtype=torch.float32),  # (T, M, F) or (T, 1, F)
                torch.tensor(target, dtype=torch.float32)         # scalar
            )

In [None]:
# =============================================================
# GADHT - Dataset Loader and Safe Collation for DataLoader Batches
# =============================================================

def load_ticker_datasets(
    tickers: List[str],
    data_path: str = DATA_DIRECTORY,
    window_size: int = 30,
    max_imfs: int = 5,
    use_ceemdan: bool = True,
    pretraining_mode: bool = False,
    normalize: bool = True,
    prediction_horizon: int = 1
) -> ConcatDataset:
    """
    Build one FinancialIMFDataset per ticker CSV and concatenate them.

    Tickers whose file is missing, whose dataset is empty, or whose
    construction raises are logged and left out of the result.

    Args:
        tickers (List[str]): Ticker symbols (e.g., ['AAPL', 'TSLA']).
        data_path (str): Directory holding ``<ticker>.csv`` files.
        window_size (int): Sequence length per sample.
        max_imfs (int): Number of IMFs per feature (if CEEMDAN enabled).
        use_ceemdan (bool): Whether to apply CEEMDAN decomposition.
        pretraining_mode (bool): If True, load in IMF-masking mode.
        normalize (bool): Apply Z-score normalization per feature.
        prediction_horizon (int): Days ahead to predict (default=1).

    Returns:
        ConcatDataset: Combined dataset of all tickers that loaded.

    Raises:
        RuntimeError: if not a single ticker produced a usable dataset.
    """
    logger.info(f"📦 Loading datasets for {len(tickers)} tickers...")

    # Settings shared by every per-ticker dataset
    shared_settings = dict(
        window_size=window_size,
        max_imfs=max_imfs,
        use_ceemdan=use_ceemdan,
        pretraining_mode=pretraining_mode,
        normalize=normalize,
        prediction_horizon=prediction_horizon,
    )

    datasets = []
    for ticker in tqdm(tickers, desc="Loading datasets"):
        filepath = os.path.join(data_path, f"{ticker}.csv")

        if os.path.isfile(filepath):
            try:
                dataset = FinancialIMFDataset(filepath=filepath, **shared_settings)

                if len(dataset) > 0:
                    datasets.append(dataset)
                    logger.info(f"✅ {ticker} loaded with {len(dataset)} samples.")
                else:
                    logger.warning(f"⚠️ Empty dataset for {ticker}. Skipped.")

            except Exception as e:
                logger.error(f"❌ Failed to load {ticker}: {e}")
        else:
            logger.warning(f"⚠️ Missing file: {filepath}")

    if len(datasets) == 0:
        logger.critical("❌ No valid datasets could be loaded.")
        raise RuntimeError("No datasets available. Check input files and configuration.")

    logger.info(f"🎯 Successfully loaded {len(datasets)} datasets ({sum(len(ds) for ds in datasets)} total samples).")
    return ConcatDataset(datasets)


def safe_collate(batch: List) -> Union[dict, torch.Tensor, None]:
    """
    Collate function for DataLoader that drops ``None`` samples.

    The layout of the first remaining sample decides how the batch is built:

    Returns:
        - None when no sample is left after dropping ``None`` entries.
        - 3-tuples (pretraining): dict with {"masked", "target", "mask"}.
        - 2-tuples (supervised): dict with {"input", "target"}.
        - Anything else: result of PyTorch's ``default_collate``.
    """
    samples = [item for item in batch if item is not None]
    if len(samples) == 0:
        return None

    first = samples[0]
    # Number of fields per sample; None when samples are not tuples
    arity = len(first) if isinstance(first, tuple) else None

    if arity == 3:
        # (masked_tensor, original_tensor, mask)
        masked_tensors, originals, masks = zip(*samples)
        return {
            "masked": torch.stack(masked_tensors),
            "target": torch.stack(originals),
            "mask": torch.stack(masks)
        }

    if arity == 2:
        # (input_tensor, target)
        inputs, targets = zip(*samples)
        return {
            "input": torch.stack(inputs),
            "target": torch.stack(targets)
        }

    # Unknown sample layout: defer to PyTorch
    return torch.utils.data.default_collate(samples)

In [None]:
# ==========================================================
# GADHT - Energy-Aware Multihead Attention
# ==========================================================

class EnergyAwareMultiheadAttention(nn.Module):
    """
    Multi-head self-attention whose logits can be rescaled by an energy signal.

    When ``energy`` is supplied, the scaled dot-product logits of every query
    position ``t`` are multiplied by ``energy[b, t]`` before the softmax (the
    same factor is used for all heads and all keys of that query). Without
    ``energy`` the layer is plain multi-head self-attention.

    Args:
        d_model (int): Embedding dimension of inputs.
        num_heads (int): Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dim = d_model // num_heads

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Query / key / value / output projections, created in that order
        self.q_proj, self.k_proj, self.v_proj, self.o_proj = (
            nn.Linear(d_model, d_model) for _ in range(4)
        )

    def forward(
        self,
        x: torch.Tensor,              # (B, T, d_model)
        energy: Optional[torch.Tensor] = None  # (B, T, 1) or (B, T)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            output: (B, T, d_model) attended and output-projected features.
            attn_weights: (B, num_heads, T, T) softmax weights over keys.
        """
        batch, steps, _ = x.shape

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (B, T, d_model) -> (B, h, T, d_k)
            return projected.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        # Scaled dot-product logits: (B, h, T, T)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if energy is not None:
            # Bring energy to (B, T, 1), then add a head axis -> (B, 1, T, 1)
            per_query = energy.unsqueeze(-1) if energy.dim() == 2 else energy
            scores = scores * per_query.unsqueeze(1)

        # Normalise over the key axis
        attn_weights = torch.softmax(scores, dim=-1)

        # Mix values, then merge heads back into (B, T, d_model)
        context = torch.matmul(attn_weights, values)
        merged = context.transpose(1, 2).contiguous().view(batch, steps, self.d_model)

        return self.o_proj(merged), attn_weights

In [None]:
# ==========================================================
# GADHT - Generative Adaptive Decomposition Hierarchical Transformer
# ==========================================================

class GADHTModel(nn.Module):
    """
    GADHT: Generative Adaptive Decomposition Hierarchical Transformer.

    Modes:
      - Supervised forecasting: map an IMF cube to one scalar per sample.
      - Pretraining: masked IMF reconstruction (reconstruct full (T, M, F) cube).

    Args:
        input_shape (Tuple[int, int, int]): (time_steps, num_imfs, num_features).
        d_model (int): Embedding dimension for Transformer.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of temporal Transformer layers (per-IMF).
        dropout (float): Dropout rate.
        use_energy (bool): Stored on the instance; the spectral layers in this
            class always receive the energy vector regardless of its value.
        pretraining (bool): If True, model outputs IMF reconstruction instead of forecast.
        spectral_layers (int): Number of spectral attention layers across IMFs
            (at least one layer is always created).
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int] = (30, 5, 5),
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 3,
        dropout: float = 0.1,
        use_energy: bool = True,
        pretraining: bool = False,
        spectral_layers: int = 1
    ):
        super().__init__()

        self.seq_len, self.num_imfs, self.num_features = input_shape
        self.d_model = d_model
        self.pretraining = pretraining
        self.use_energy = use_energy
        self.spectral_layers = spectral_layers

        # ----------------------------
        # Input embedding (feature -> d_model), shared across IMFs
        # ----------------------------
        self.input_proj = nn.Linear(self.num_features, d_model)

        # Fixed sinusoidal positional encoding (temporal); non-persistent, so it
        # is rebuilt on construction rather than stored in the state dict.
        self.register_buffer("positional_encoding",
            self._build_positional_encoding(self.seq_len, d_model),
            persistent=False
        )

        # ----------------------------
        # Temporal Encoder (per-IMF): shared parameters for all IMFs
        # We reuse EnergyAwareMultiheadAttention with energy=None here.
        # ----------------------------
        self.temporal_layers = nn.ModuleList([
            EnergyAwareMultiheadAttention(d_model, num_heads) for _ in range(num_layers)
        ])
        self.temporal_norm = nn.LayerNorm(d_model)
        self.temporal_dropout = nn.Dropout(dropout)

        # ----------------------------
        # Spectral Encoder (across IMFs): energy-weighted attention over IMF tokens
        # Sequence length here is M (number of IMFs).
        # ----------------------------
        self.spectral_layers_mod = nn.ModuleList([
            EnergyAwareMultiheadAttention(d_model, num_heads) for _ in range(max(1, spectral_layers))
        ])
        self.spectral_norm = nn.LayerNorm(d_model)
        self.spectral_dropout = nn.Dropout(dropout)

        # ----------------------------
        # Output heads
        # ----------------------------
        if pretraining:
            # Decoder back to features for each time step and each IMF
            self.decoder = nn.Linear(d_model, self.num_features)
        else:
            # Forecast head (global pooled representation -> scalar)
            self.head = nn.Sequential(
                nn.Linear(d_model, 64),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(32, 1)
            )

        # Initialize weights
        self._initialize_weights()

    # ======================================================
    # Initialization
    # ======================================================
    def _initialize_weights(self) -> None:
        """Xavier-uniform weights and zero biases for every linear layer."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _build_positional_encoding(self, seq_len: int, d_model: int) -> torch.Tensor:
        """
        Build fixed sinusoidal positional encoding (temporal).

        Even channels hold sines and odd channels cosines. ``d_model`` may be
        odd: the cosine half then simply has one channel fewer.

        Returns: (1, T, d_model)
        """
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (T, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there are only floor(d_model / 2) odd channels
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)  # (1, T, d_model)

    # ======================================================
    # Shared encoder stages (used by forward and encode)
    # ======================================================
    def _imf_energy(self, x: torch.Tensor) -> torch.Tensor:
        """
        Per-sample IMF energies, softmax-normalised over the IMF axis.

        e_m = mean over time and features of x[t, m, j]^2. Computed without
        gradient tracking. Shape: (B, T, M, F) -> (B, M).
        """
        with torch.no_grad():
            return torch.softmax(x.pow(2).mean(dim=(1, 3)), dim=-1)

    def _encode_temporal(self, x: torch.Tensor, apply_dropout: bool) -> torch.Tensor:
        """
        Run the shared temporal attention stack on every IMF independently.

        Each IMF becomes its own sequence of length T (batch axis B*M), is
        projected to d_model, receives the positional encoding and passes
        through all temporal layers with ``energy=None``.

        Args:
            x: (B, T, M, F)
            apply_dropout: apply ``temporal_dropout`` after each layer.

        Returns:
            (B*M, T, d_model) sequence-level features (not pooled).
        """
        B, T, M, num_feat = x.shape
        h = x.permute(0, 2, 1, 3).contiguous().view(B * M, T, num_feat)  # (B*M, T, F)
        h = self.input_proj(h) + self.positional_encoding.to(h.device)   # (B*M, T, d_model)
        for layer in self.temporal_layers:
            h, _ = layer(h, energy=None)
            if apply_dropout:
                h = self.temporal_dropout(h)
        return h

    def _encode_spectral(
        self, x_spec: torch.Tensor, energy: torch.Tensor, apply_dropout: bool
    ) -> torch.Tensor:
        """
        Energy-weighted attention across the M IMF tokens, then LayerNorm.

        Args:
            x_spec: (B, M, d_model) per-IMF embeddings.
            energy: (B, M) normalised IMF energies.
            apply_dropout: apply ``spectral_dropout`` after each layer.

        Returns:
            (B, M, d_model)
        """
        for layer in self.spectral_layers_mod:
            x_spec, _ = layer(x_spec, energy=energy)
            if apply_dropout:
                x_spec = self.spectral_dropout(x_spec)
        return self.spectral_norm(x_spec)

    # ======================================================
    # Forward Pass
    # ======================================================
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (B, T, M, F)

        Returns:
            - Pretraining: reconstruction tensor (B, T, M, F)
            - Forecasting: scalar prediction (B, 1)
        """
        B, T, M, F = x.shape
        assert T == self.seq_len and M == self.num_imfs and F == self.num_features, \
            f"Input shape mismatch: expected {(self.seq_len, self.num_imfs, self.num_features)}, got {(T, M, F)}"

        # ----------------------------
        # Pretraining: decode every time step of every IMF back to features
        # ----------------------------
        if self.pretraining:
            # The temporal stack is run once, without dropout, and decoded
            # directly. (Previously a pooled pass was computed first and thrown
            # away; the returned reconstruction is unchanged.)
            # NOTE(review): each IMF is encoded on its own here (batch axis
            # B*M) and the spectral layers are not used, so the reconstruction
            # of one IMF cannot draw on the other IMFs of the same sample.
            x_seq = self._encode_temporal(x, apply_dropout=False)   # (B*M, T, d_model)
            rec = self.decoder(x_seq)                               # (B*M, T, F)
            rec = rec.view(B, M, T, F).permute(0, 2, 1, 3).contiguous()  # (B, T, M, F)
            return rec

        # ----------------------------
        # Forecasting
        # ----------------------------
        energy = self._imf_energy(x)                                # (B, M)

        # Temporal encoding per IMF, then mean-pool over time
        x_seq = self._encode_temporal(x, apply_dropout=True)        # (B*M, T, d_model)
        x_spec = self.temporal_norm(x_seq.mean(dim=1)).view(B, M, self.d_model)  # (B, M, d_model)

        # Spectral encoding across IMFs
        x_spec = self._encode_spectral(x_spec, energy, apply_dropout=True)  # (B, M, d_model)

        # Spectral pooling -> global representation
        x_global = x_spec.mean(dim=1)  # (B, d_model)

        # Forecast head
        out = self.head(x_global)  # (B, 1)
        return out

    # ======================================================
    # Encoder-only mode (returns latent vector)
    # ======================================================
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encoder-only forward pass that returns a pooled latent vector.

        Runs the temporal and spectral stages without the dropout layers and
        without any output head, so it works for both model modes.

        Args:
            x: (B, T, M, F)
        Returns:
            (B, d_model)
        """
        B, _, M, _ = x.shape

        energy = self._imf_energy(x)  # (B, M)

        # Temporal path
        x_seq = self._encode_temporal(x, apply_dropout=False)
        x_spec = self.temporal_norm(x_seq.mean(dim=1)).view(B, M, self.d_model)

        # Spectral path
        x_spec = self._encode_spectral(x_spec, energy, apply_dropout=False)

        # Global pooled latent
        return x_spec.mean(dim=1)  # (B, d_model)

In [None]:
# ===========================================================
# GADHT - Pretraining Wrapper for Masked IMF Reconstruction
# ===========================================================

class GADHTPretrainingWrapper(nn.Module):
    """
    Thin module around a GADHT model used for masked IMF reconstruction.

    The wrapper only accepts an encoder whose ``pretraining`` flag is set and
    forwards every call to it unchanged, so its output is whatever that
    encoder returns for the masked input.

    Args:
        encoder (GADHTModel): Base GADHT model with `pretraining=True`.

    Raises:
        ValueError: if ``encoder.pretraining`` is falsy.
    """

    def __init__(self, encoder: GADHTModel):
        super().__init__()
        if encoder.pretraining:
            self.encoder = encoder
        else:
            raise ValueError("Encoder must be initialized with pretraining=True")

    def forward(self, x_masked: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct the IMF cube from its masked version.

        Args:
            x_masked (Tensor): Masked sequence of shape (B, T, M, F).

        Returns:
            Tensor: Output of the wrapped encoder for ``x_masked``.
        """
        reconstruction = self.encoder(x_masked)
        return reconstruction

def reconstruction_loss(x_reconstructed, x_original, mask):
    """
    Masked MSE reconstruction loss, evaluated on the masked-out IMFs.

    The squared error is kept only where ``mask == 0`` (the components hidden
    from the model) and averaged over all B*T*M*F elements, matching the loss
    computed inside ``run_gadht_pretraining``.

    Args:
        x_reconstructed, x_original: (B, T, M, F)
        mask: binary mask of which IMFs were kept (1) or masked (0). Accepted
            layouts:
              - (B, M)
              - (B, M, F)
              - (B, M, 1, 1)
              - any other 4-D tensor that already broadcasts against
                (B, T, M, F)
            A 4-D mask whose shape is (B, M, 1, 1) is always read as the
            per-IMF layout.

    Returns:
        Scalar tensor.
    """
    mask = mask.to(dtype=x_reconstructed.dtype)
    num_imfs = x_original.shape[2]

    # Bring the mask to a layout that broadcasts over time (axis 1)
    if mask.dim() == 2:                       # (B, M) -> (B, 1, M, 1)
        mask = mask[:, None, :, None]
    elif mask.dim() == 3:                     # (B, M, F) -> (B, 1, M, F)
        mask = mask.unsqueeze(1)
    elif (
        mask.dim() == 4
        and mask.shape[1] == num_imfs
        and tuple(mask.shape[2:]) == (1, 1)
    ):                                        # (B, M, 1, 1) -> (B, 1, M, 1)
        mask = mask.permute(0, 2, 1, 3)

    # 1 where the IMF was hidden from the model, 0 where it was visible
    hidden = 1.0 - mask
    masked_diff = (x_reconstructed - x_original) * hidden
    return (masked_diff ** 2).mean()

In [None]:
# ===========================================================
# GADHT - Pretraining Loop for IMF Masking Reconstruction
# ===========================================================

def run_gadht_pretraining(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    num_epochs: int = 150,
    learning_rate: float = 1e-5,
    early_stopping_patience: int = 10,
    max_gradient_norm: float = 1.0,
    checkpoint_path: str = "checkpoints/pretraining/gadht_pretrain.pth",
    plot_loss_curve: bool = True,
    loss_plot_path: str = "results/pretraining/gadht_pretrain_loss.png"
) -> None:
    """
    Self-supervised pretraining loop for GADHT (masked IMF reconstruction).

    Trains with Adam and a cosine-annealed learning rate. After every epoch
    the average training loss is compared with the best one seen so far: an
    improvement saves ``model.state_dict()`` to ``checkpoint_path``, otherwise
    the patience counter increases. The model object keeps the weights of the
    last executed epoch; the best weights are only on disk.

    Args:
        model (nn.Module): GADHTPretrainingWrapper (with encoder.pretraining=True).
        dataloader (DataLoader): Training batches with dict:
                                 {"masked": ..., "target": ..., "mask": ...}.
                                 Batches that are ``None`` are skipped.
        num_epochs (int): Maximum number of training epochs.
        learning_rate (float): Optimizer learning rate.
        early_stopping_patience (int): Stop training if no improvement after N epochs.
        max_gradient_norm (float): Gradient clipping threshold (falsy disables clipping).
        checkpoint_path (str): File path to save best model weights.
        plot_loss_curve (bool): Whether to save loss curve as PNG.
        loss_plot_path (str): Path to save the loss curve plot.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_loss = float("inf")
    patience_counter = 0
    epoch_losses = []

    # os.makedirs("") raises, so only create a directory when the path has one
    # (a bare filename refers to the current working directory).
    for output_path in (checkpoint_path, loss_plot_path):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    logger.info(f"🔧 Starting GADHT pretraining for {num_epochs} epochs...")

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0  # batches that actually contributed to total_loss

        for batch in tqdm(dataloader, desc=f"[Epoch {epoch}/{num_epochs}]"):
            # A collate function may hand back None when every sample of the
            # batch was dropped; there is nothing to train on in that case.
            if batch is None:
                continue

            masked_inputs = batch["masked"].to(device)   # (B, T, M, F) with masked IMFs
            targets = batch["target"].to(device)         # (B, T, M, F) original sequence
            mask = batch["mask"].to(device)              # (B, M, F) or (B, M)

            # Forward
            reconstructions = model(masked_inputs)       # (B, T, M, F)

            # Align mask dimensions with reconstructions
            if mask.dim() == 2:  # (B, M)
                mask = mask.unsqueeze(-1)  # (B, M, 1)
            if mask.dim() == 3:  # (B, M, F) or (B, M, 1)
                mask = mask.unsqueeze(1)  # (B, 1, M, F) or (B, 1, M, 1)
            mask = mask.expand(-1, targets.size(1), -1, -1)  # time axis expanded to T

            # Masked reconstruction loss: only on masked IMFs (mask == 0),
            # averaged over all elements of the batch
            masked_diff = (reconstructions - targets) * (1 - mask)
            loss = (masked_diff ** 2).mean()

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            if max_gradient_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Without a single usable batch there is no loss to average and no
        # optimizer step to schedule; stop instead of dividing by zero.
        if num_batches == 0:
            logger.warning(f"⚠️ Epoch {epoch}: dataloader produced no usable batches. Stopping pretraining.")
            break

        scheduler.step()

        # Average over the batches actually processed, so skipped batches do
        # not deflate the reported loss.
        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        logger.info(f"📉 Epoch {epoch} - Avg MSE: {avg_loss:.6f}")

        # Early stopping
        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"✅ Best model updated at epoch {epoch} (loss={best_loss:.6f})")
        else:
            patience_counter += 1
            logger.info(f"⏳ Early stopping patience: {patience_counter}/{early_stopping_patience}")

        if patience_counter >= early_stopping_patience:
            logger.info("🛑 Early stopping triggered.")
            break

    logger.info(f"🏁 Pretraining complete. Best MSE: {best_loss:.6f}")

    # Plot loss curve
    if plot_loss_curve:
        plt.figure(figsize=(8, 5))
        plt.plot(epoch_losses, label="Reconstruction Loss")
        plt.title("GADHT Pretraining Loss")
        plt.xlabel("Epoch")
        plt.ylabel("MSE Loss")
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.legend()
        plt.tight_layout()
        plt.savefig(loss_plot_path)
        plt.close()
        logger.info(f"📊 Loss curve saved to {loss_plot_path}")

In [None]:
# ============================================================
# GADHT - Fine-Tuning & Evaluation Pipeline (Final Version)
# ============================================================

def run_gadht_finetune(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    epochs: int = 150,
    lr: float = 1e-5,
    patience: int = 10,
    max_grad_norm: float = 1.0,
    save_path: Optional[str] = None,
    log_metrics: bool = True,
    plot_curve: bool = True,
    fig_path: Optional[str] = None
) -> None:
    """
    Fine-tune a forecasting model with MSE loss, gradient clipping and early stopping.

    Every epoch performs one optimisation pass over ``dataloader``. RMSE, MAE, MAPE
    and R² are then computed from the predictions gathered *during that training
    pass* (no separate validation data is used here), and both checkpointing and
    early stopping are driven by that training RMSE.

    The model is modified in place; after the call it holds the weights of the
    last executed epoch, while ``save_path`` (if given) holds the best-RMSE weights.

    Args:
        model (nn.Module): Model to fine-tune; moved to the global ``device``.
        dataloader (DataLoader): Yields dict batches with "input" and "target"
            tensors; ``None`` batches are skipped.
        epochs (int): Maximum number of epochs (also the cosine-annealing period).
        lr (float): Adam learning rate.
        patience (int): Number of consecutive non-improving epochs tolerated.
        max_grad_norm (float): Gradient-norm clipping threshold; falsy disables clipping.
        save_path (str, optional): Where to write the best ``state_dict``.
        log_metrics (bool): Whether to log the per-epoch metric line.
        plot_curve (bool): Whether to save the RMSE curve (requires ``fig_path``
            and at least two recorded epochs).
        fig_path (str, optional): Output path for the RMSE curve figure.
    """
    # ----------------------------
    # Setup
    # ----------------------------
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.MSELoss()

    best_rmse = float("inf")
    patience_counter = 0
    rmse_history = []

    # Ensure output directories exist. A bare filename has an empty dirname, and
    # os.makedirs("") raises, so only create a directory when there is one.
    for out_path in (save_path, fig_path):
        parent_dir = os.path.dirname(out_path) if out_path else ""
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    logger.info(f"🚀 Starting GADHT fine-tuning for {epochs} epochs...")

    # ----------------------------
    # Training loop
    # ----------------------------
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        valid_batches = 0
        preds_all, targets_all = [], []

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for batch in pbar:
            if batch is None:
                continue
            try:
                # Use dict keys from safe_collate.
                # NOTE(review): the unsqueeze assumes targets arrive as (B,) and the
                # model emits (B, 1) — confirm against the dataset / model head.
                x, targets = batch["input"].to(device), batch["target"].to(device).unsqueeze(1)

                # Forward + Loss
                outputs = model(x)
                loss = loss_fn(outputs, targets)

                # Backward
                optimizer.zero_grad()
                loss.backward()
                if max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

                # Metrics accumulation
                epoch_loss += loss.item()
                valid_batches += 1
                preds_all.append(outputs.detach().cpu().numpy())
                targets_all.append(targets.detach().cpu().numpy())

                pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            except Exception as e:
                # Deliberate best-effort: a failing batch is logged and skipped.
                logger.error(f"⚠️ Skipped batch due to error: {e}")
                continue

        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        # ----------------------------
        # End of epoch evaluation
        # ----------------------------
        if valid_batches > 0:
            preds_np = np.concatenate(preds_all).flatten()
            targets_np = np.concatenate(targets_all).flatten()

            rmse = np.sqrt(mean_squared_error(targets_np, preds_np))
            mae = mean_absolute_error(targets_np, preds_np)

            # MAPE is undefined where the target is zero; exclude those points
            # instead of letting inf/nan dominate the average.
            nonzero = targets_np != 0
            if np.any(nonzero):
                mape = np.mean(
                    np.abs((targets_np[nonzero] - preds_np[nonzero]) / targets_np[nonzero])
                ) * 100
            else:
                mape = np.nan

            r2 = r2_score(targets_np, preds_np)
            avg_loss = epoch_loss / valid_batches
            rmse_history.append(rmse)

            if log_metrics:
                logger.info(
                    f"📉 Epoch {epoch} — Loss: {avg_loss:.6f} | "
                    f"RMSE: {rmse:.4f} | MAE: {mae:.4f} | "
                    f"MAPE: {mape:.2f}% | R²: {r2:.4f}"
                )

            # Save best model
            if rmse < best_rmse:
                best_rmse = rmse
                patience_counter = 0
                if save_path:
                    torch.save(model.state_dict(), save_path)
                    logger.info(f"✅ Best model saved with RMSE {best_rmse:.4f}")
            else:
                patience_counter += 1
                logger.info(f"⏳ Patience counter: {patience_counter}/{patience}")

            # Early stopping
            if patience_counter >= patience:
                logger.info("🛑 Early stopping triggered.")
                break
        else:
            logger.warning(f"❌ Epoch {epoch}: No valid batches.")

    # ----------------------------
    # Training complete
    # ----------------------------
    logger.info(f"🏁 Fine-tuning complete. Best RMSE: {best_rmse:.4f}")

    # ----------------------------
    # Plot learning curve
    # ----------------------------
    if plot_curve and len(rmse_history) > 1 and fig_path:
        plt.figure(figsize=(8, 5))
        plt.plot(rmse_history, label="RMSE", linewidth=2)
        plt.xlabel("Epoch")
        plt.ylabel("RMSE")
        plt.title("GADHT Fine-Tuning RMSE Curve")
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fig_path)
        plt.close()
        logger.info(f"📊 RMSE curve saved to {fig_path}")

In [None]:
# =====================================================
# 📈 Backtesting Utilities
# =====================================================

def collect_predictions(model: nn.Module, dataloader: torch.utils.data.DataLoader, device: torch.device = device):
    """
    Run inference over a dataloader and return flat numpy arrays.

    The model is switched to eval mode and run under ``torch.no_grad()``. It is
    not moved to ``device`` here; only the batch tensors are.

    Args:
        model (nn.Module): Trained forecasting model.
        dataloader (DataLoader): DataLoader providing {"input", "target"} batches;
            ``None`` batches are skipped.
        device (torch.device): Device the batch tensors are moved to.

    Returns:
        preds (np.ndarray): Model predictions, shape (N,).
        reals (np.ndarray): Ground-truth targets, shape (N,).
        Both are empty arrays when no batch produced output.
    """
    model.eval()
    all_preds, all_true = [], []

    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue

            # Use dict keys from safe_collate
            x, y = batch["input"].to(device), batch["target"].to(device)

            # Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
            # single-sample batch into a 0-d array, which np.concatenate rejects.
            y_hat = model(x).reshape(-1)
            all_preds.append(y_hat.detach().cpu().numpy())
            all_true.append(y.detach().cpu().numpy().reshape(-1))

    if not all_preds or not all_true:
        return np.array([]), np.array([])

    preds = np.concatenate(all_preds).ravel()
    reals = np.concatenate(all_true).ravel()
    return preds, reals


def backtest_strategy(preds: np.ndarray, reals: np.ndarray, cost_bps: float = 20, freq: int = 252) -> dict:
    """
    Simple long/short backtest driven by the predicted direction of price change.

    Strategy:
        - Signal = sign of the predicted price change (Δ preds); 0 means flat.
        - PnL    = signal × realized simple return of ``reals`` over the same step.
        - Costs  = ``cost_bps`` per unit of position change, so an entry or exit
          costs one unit and a long↔short flip costs two. The initial entry
          from flat is charged as well.

    Drawdown is measured against a running peak that includes the starting
    capital of 1.0, so a loss on the very first bar counts as a drawdown.

    NOTE(review): ``reals`` is used as a divisor, so it is assumed to hold
    strictly positive price levels (not standardized values) — confirm upstream.

    Args:
        preds (np.ndarray): Predicted prices aligned with ``reals``, shape (T,).
        reals (np.ndarray): Realized prices, shape (T,).
        cost_bps (float): Transaction cost in basis points (20 = 0.20%).
        freq (int): Periods per year used for annualization (default 252).

    Returns:
        dict with keys "Annualized Return", "Annualized Volatility", "Sharpe",
        "Max Drawdown" and "Win Rate". All values are 0.0 when the inputs have
        different lengths or fewer than two observations.
    """
    # Accept lists and column vectors as well as flat arrays.
    preds = np.asarray(preds, dtype=float).ravel()
    reals = np.asarray(reals, dtype=float).ravel()

    if len(preds) != len(reals) or len(preds) < 2:
        return {
            "Annualized Return": 0.0,
            "Annualized Volatility": 0.0,
            "Sharpe": 0.0,
            "Max Drawdown": 0.0,
            "Win Rate": 0.0,
        }

    # Realized simple returns, shape (T-1,)
    real_rets = np.diff(reals) / reals[:-1]

    # Position for each step: +1 long, -1 short, 0 flat (np.sign yields 0 for no change)
    sig = np.sign(np.diff(preds))

    # Transaction costs proportional to the size of the position change;
    # prepend=0 means the book starts flat, so the first entry is charged.
    per_trade_cost = cost_bps / 10000.0
    costs = per_trade_cost * np.abs(np.diff(sig, prepend=0.0))

    # Net strategy returns
    strat = sig * real_rets - costs

    # Equity curve (starting capital 1.0 is implicit before the first element)
    equity = np.cumprod(1.0 + strat)

    # Max drawdown: clamp the running peak at the 1.0 starting capital so that
    # an opening loss is not silently treated as the high-water mark.
    peak = np.maximum(np.maximum.accumulate(equity), 1.0)
    drawdown = equity / peak - 1.0
    max_dd = float(drawdown.min())

    # Annualized metrics (arithmetic mean / population std of per-step returns)
    ann_ret = float(strat.mean() * freq)
    ann_vol = float(strat.std() * np.sqrt(freq))
    sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else 0.0
    win_rate = float((strat > 0).mean())

    return {
        "Annualized Return": ann_ret,
        "Annualized Volatility": ann_vol,
        "Sharpe": sharpe,
        "Max Drawdown": max_dd,
        "Win Rate": win_rate,
    }

In [None]:
# ===========================================================
# GADHT - Rolling-Origin Cross-Validation Pipeline
# ===========================================================

def run_gadht_rolling_cv(
    tickers: List[str],
    model_name: str,
    dataset_name: str,
    pretrained_path: str,
    n_splits: int = 5,
    batch_size: int = 64,
    lr: float = 1e-5,
    epochs: int = 150,
    patience: int = 10,
    max_grad_norm: float = 1.0,
    seed: int = 42,
    prediction_horizon: int = 1
) -> None:
    """
    Execute GADHT fine-tuning with rolling-origin cross-validation.

    For every ``TimeSeriesSplit`` fold a fresh ``GADHTModel`` is initialised from
    the pretrained checkpoint (non-strict load), fine-tuned on the training
    indices, and scored on the validation indices with:
        - Statistical metrics: RMSE, MAE, MAPE, R²
        - Financial metrics:   Annualized Return, Volatility, Sharpe, MaxDD, WinRate

    Per-fold metrics are written to ``results/finetune/<model_name>_cv_metrics.csv``
    and a mean ± std error-bar plot to ``results/finetune/<model_name>_cv_plot.png``.

    NOTE(review): the split is positional over the dataset returned by
    ``load_ticker_datasets``. If that dataset concatenates several tickers,
    positional order may not be chronological across tickers — confirm.

    Args:
        tickers (List[str]): Tickers to load.
        model_name (str): Prefix for the output file names.
        dataset_name (str): Human-readable label used in logs and the plot title.
        pretrained_path (str): Path to the pretrained ``state_dict``.
        n_splits (int): Number of rolling-origin folds.
        batch_size (int): Batch size for both loaders.
        lr (float): Fine-tuning learning rate.
        epochs (int): Maximum epochs per fold.
        patience (int): Early-stopping patience per fold.
        max_grad_norm (float): Gradient clipping threshold.
        seed (int): Seed applied to ``random``, NumPy and PyTorch before the
            folds run, so that freshly initialised layers are reproducible.
        prediction_horizon (int): Forecast horizon passed to the dataset loader.
    """
    # Make weight initialisation of non-pretrained layers reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    logger.info(
        f"🚀 Starting GADHT {n_splits}-fold rolling CV "
        f"on {dataset_name} | Horizon={prediction_horizon}"
    )
    os.makedirs("results/finetune", exist_ok=True)

    # ----------------------------
    # Dataset loading
    # ----------------------------
    dataset = load_ticker_datasets(
        tickers=tickers,
        data_path=DATA_DIRECTORY,
        window_size=30,
        max_imfs=5,
        use_ceemdan=True,
        pretraining_mode=False,
        normalize=True,
        prediction_horizon=prediction_horizon
    )

    # Read the checkpoint once; load_state_dict copies values into each new model,
    # so the same dict can safely seed every fold.
    pretrained_state = torch.load(pretrained_path, map_location=device)

    # ----------------------------
    # Rolling-origin CV definition (chronological, no shuffle!)
    # ----------------------------
    tscv = TimeSeriesSplit(n_splits=n_splits)
    sample_indices = np.arange(len(dataset))  # split positions, not the dataset object
    fold_metrics = []

    for fold_idx, (train_idx, val_idx) in enumerate(tscv.split(sample_indices)):
        logger.info(
            f"🌀 Fold {fold_idx+1}/{n_splits} | "
            f"Train: {len(train_idx)} | Val: {len(val_idx)}"
        )

        # DataLoaders
        train_loader = DataLoader(
            Subset(dataset, train_idx),
            batch_size=batch_size,
            shuffle=False,   # preserve time order
            collate_fn=safe_collate
        )
        val_loader = DataLoader(
            Subset(dataset, val_idx),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=safe_collate
        )

        # ----------------------------
        # Model loading
        # ----------------------------
        # strict=False: keys that do not match between checkpoint and model are ignored.
        model = GADHTModel(pretraining=False).to(device)
        model.load_state_dict(pretrained_state, strict=False)

        # ----------------------------
        # Fine-tuning
        # ----------------------------
        run_gadht_finetune(
            model=model,
            dataloader=train_loader,
            epochs=epochs,
            lr=lr,
            patience=patience,
            max_grad_norm=max_grad_norm,
            save_path=None
        )

        # ----------------------------
        # Evaluation (val fold)
        # ----------------------------
        preds, reals = collect_predictions(model, val_loader, device=device)

        if len(preds) == 0:
            logger.warning(f"❌ Fold {fold_idx+1}: no predictions collected.")
            continue

        # Statistical metrics
        rmse = np.sqrt(mean_squared_error(reals, preds))
        mae = mean_absolute_error(reals, preds)

        # MAPE is undefined for zero targets; exclude them rather than produce inf/nan.
        nonzero = reals != 0
        if np.any(nonzero):
            mape = np.mean(np.abs((reals[nonzero] - preds[nonzero]) / reals[nonzero])) * 100
        else:
            mape = np.nan

        r2 = r2_score(reals, preds)
        metrics = {
            "RMSE": rmse,
            "MAE": mae,
            "MAPE": mape,
            "R²": r2,
        }

        # Financial metrics
        bt_metrics = backtest_strategy(preds, reals, cost_bps=20)
        metrics.update(bt_metrics)

        logger.info(f"📊 Fold {fold_idx+1} metrics: {metrics}")
        fold_metrics.append(metrics)

    # ----------------------------
    # Save aggregated metrics
    # ----------------------------
    if fold_metrics:
        df = pd.DataFrame(fold_metrics)
        csv_path = f"results/finetune/{model_name}_cv_metrics.csv"
        df.to_csv(csv_path, index=False)
        logger.info(f"✅ Saved CV metrics to {csv_path}")

        # Plot error bars (mean ± std across folds, one point per metric)
        plt.figure(figsize=(10, 5))
        plt.errorbar(df.columns, df.mean(), yerr=df.std(), fmt="o", capsize=5)
        plt.title(
            f"{dataset_name} — {n_splits}-Fold Rolling-Origin CV "
            f"(Horizon={prediction_horizon})"
        )
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.tight_layout()
        plot_path = f"results/finetune/{model_name}_cv_plot.png"
        plt.savefig(plot_path)
        plt.close()
        logger.info(f"📊 CV plot saved to {plot_path}")
    else:
        logger.warning("❌ No metrics collected during CV.")

In [None]:
# ===========================================================
# GADHT - Model Evaluation for Close@t+H Forecasting
# ===========================================================

def evaluate_model(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    return_predictions: bool = False,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    cost_bps: float = 20
) -> Union[Dict[str, float], Tuple[Dict[str, float], np.ndarray, np.ndarray]]:
    """
    Score a trained forecasting model on a labeled dataloader.

    The model is moved to ``device``, put in eval mode and run without gradients.
    Batches that are ``None`` or that raise are skipped (with a warning).

    Reported metrics:
        - Statistical: RMSE, MAE, MAPE (zero targets excluded; NaN if all are zero), R2
        - Financial (from ``backtest_strategy``): AnnReturn, AnnVol, Sharpe, MaxDD,
          WinRate — omitted if the backtest raises.

    Args:
        model (nn.Module): Trained model in forecasting mode.
        dataloader (DataLoader): Yields dict batches with "input" and "target".
        return_predictions (bool): Also return the flattened predictions and targets.
        device (torch.device): Inference device (default = CUDA if available).
        cost_bps (float): Transaction cost in basis points passed to the backtest.

    Returns:
        metrics (Dict[str, float]), or ``(metrics, y_pred, y_true)`` when
        ``return_predictions`` is True; both arrays have shape (N,).

    Raises:
        RuntimeError: If no batch produced a prediction.
    """
    model.to(device)
    model.eval()

    pred_chunks: List[np.ndarray] = []
    target_chunks: List[np.ndarray] = []

    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue
            try:
                # Use dict keys from safe_collate
                inputs = batch["input"].to(device)
                labels = batch["target"].to(device).unsqueeze(1)

                pred_chunks.append(model(inputs).detach().cpu().numpy())
                target_chunks.append(labels.cpu().numpy())
            except Exception as e:
                logger.warning(f"⚠️ Skipped batch during evaluation: {e}")
                continue

    if not (pred_chunks and target_chunks):
        raise RuntimeError("❌ No valid predictions. Check dataloader or model output.")

    # Flatten everything to 1-D
    y_true = np.concatenate(target_chunks).ravel()
    y_pred = np.concatenate(pred_chunks).ravel()

    # --- Statistical metrics ---
    usable = y_true != 0  # MAPE is only defined where the target is non-zero
    mape = (
        np.mean(np.abs((y_true[usable] - y_pred[usable]) / y_true[usable])) * 100
        if np.any(usable)
        else np.nan
    )

    metrics = {}
    metrics["RMSE"] = round(np.sqrt(mean_squared_error(y_true, y_pred)), 4)
    metrics["MAE"] = round(mean_absolute_error(y_true, y_pred), 4)
    metrics["MAPE"] = round(mape, 2)
    metrics["R2"] = round(r2_score(y_true, y_pred), 4)

    # --- Financial backtest metrics (short key -> backtest key) ---
    backtest_keys = {
        "AnnReturn": "Annualized Return",
        "AnnVol": "Annualized Volatility",
        "Sharpe": "Sharpe",
        "MaxDD": "Max Drawdown",
        "WinRate": "Win Rate",
    }
    try:
        bt_metrics = backtest_strategy(y_pred, y_true, cost_bps=cost_bps)
        metrics.update(
            {short: round(bt_metrics[full], 4) for short, full in backtest_keys.items()}
        )
    except Exception as e:
        logger.warning(f"⚠️ Backtest failed: {e}")

    # --- Logging summary ---
    summary_fields = [
        "📊 Evaluation",
        f"RMSE: {metrics['RMSE']:.4f}",
        f"MAE: {metrics['MAE']:.4f}",
        f"MAPE: {metrics['MAPE']:.2f}%",
        f"R²: {metrics['R2']:.4f}",
        f"Sharpe: {metrics.get('Sharpe', np.nan):.4f}",
        f"MaxDD: {metrics.get('MaxDD', np.nan):.4f}",
        f"WinRate: {metrics.get('WinRate', np.nan):.2f}",
    ]
    logger.info(" | ".join(summary_fields))

    if return_predictions:
        return metrics, y_pred, y_true
    return metrics

In [None]:
# ===========================================================
# GADHT - Model Saving Utility
# ===========================================================

def save_gadht_model(
    model: nn.Module,
    filename: str = "gadht_pretrained.pt",
    directory: str = "checkpoints/pretraining",
    hyperparams: Optional[dict] = None,
    save_full_model: bool = False
) -> Optional[str]:
    """
    Persist a model to ``<directory>/<filename>``.

    By default only ``model.state_dict()`` is written. With ``save_full_model``
    the whole module object is pickled instead. When ``hyperparams`` is given it
    is written next to the weights as ``<filename stem>_hparams.json``.

    Errors are logged, not raised: the function returns ``None`` on any failure.

    Args:
        model (nn.Module): Model (or wrapper) to save.
        filename (str): Target filename for the saved model.
        directory (str): Directory where the model will be stored (created if missing).
        hyperparams (dict, optional): JSON-serializable hyperparameters to save alongside.
        save_full_model (bool): If True, save the full model object (not recommended
            for production; less portable than a state_dict).

    Returns:
        str or None: Full path of the saved model if successful, else None.
    """
    # Imported locally: `json` is not part of the notebook's centralized import
    # cell, and without it the hyperparameter branch fails with a NameError that
    # the broad handler below would hide.
    import json

    try:
        os.makedirs(directory, exist_ok=True)
        save_path = os.path.join(directory, filename)

        if save_full_model:
            # Warning: less portable, but keeps full class structure
            torch.save(model, save_path)
            logger.info(f"✅ Full model object saved → {save_path}")
        else:
            torch.save(model.state_dict(), save_path)
            logger.info(f"✅ Model state_dict saved → {save_path}")

        # Save hyperparameters if provided
        if hyperparams is not None:
            hp_path = os.path.splitext(save_path)[0] + "_hparams.json"
            with open(hp_path, "w", encoding="utf-8") as f:
                json.dump(hyperparams, f, indent=4)
            logger.info(f"⚙️  Hyperparameters saved → {hp_path}")

        return save_path

    except OSError as e:
        # IOError is an alias of OSError in Python 3, so one clause covers both.
        logger.error(f"❌ File system error while saving model: {e}")
    except Exception as e:
        logger.error(f"❌ Unexpected error while saving model: {e}")

    return None

In [None]:
# ===========================================================
# GADHT - Model Loading Utility
# ===========================================================

def load_gadht_model(
    model: Optional[nn.Module],
    filename: str,
    directory: str = "checkpoints/pretraining",
    map_location: Optional[str] = None,
    strict: bool = True,
    load_hparams: bool = True
) -> Union[nn.Module, Tuple[nn.Module, dict]]:
    """
    Load saved weights (and optionally hyperparameters) into a GADHT model.

    The checkpoint at ``<directory>/<filename>`` is read as a state_dict. If
    ``load_hparams`` is True and ``<filename stem>_hparams.json`` exists, it is
    loaded too; when ``model`` is None the model is then rebuilt as
    ``GADHTModel(**hparams)`` before the weights are applied.

    Args:
        model (nn.Module or None): Initialized model, or None to rebuild it from
            the saved hyperparameters.
        filename (str): Name of the saved weight file (.pt or .pth).
        directory (str): Folder where the model checkpoint is stored.
        map_location (Optional[str]): Device override; defaults to CPU when falsy.
        strict (bool): Whether to strictly enforce state_dict key matching.
        load_hparams (bool): If True, also try to load the hyperparameter JSON.

    Returns:
        ``(model, hparams)`` when a hyperparameter file was found and loaded,
        otherwise just ``model``. Callers must handle both shapes.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If state_dict loading fails due to mismatch (when strict=True).
        ValueError: If ``model`` is None and no hyperparameters were available.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # Imported locally: `json` is not part of the notebook's centralized import
    # cell, and the hyperparameter branch below needs it.
    import json

    load_path = os.path.join(directory, filename)
    try:
        location = map_location or torch.device("cpu")
        # NOTE(review): torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(load_path, map_location=location)

        # Load hyperparameters if available
        hparams = None
        if load_hparams:
            hp_path = os.path.splitext(load_path)[0] + "_hparams.json"
            if os.path.exists(hp_path):
                with open(hp_path, "r", encoding="utf-8") as f:
                    hparams = json.load(f)
                logger.info(f"⚙️ Loaded hyperparameters from {hp_path}")
                # If model not provided, try to re-init automatically
                if model is None:
                    model = GADHTModel(**hparams)
                    logger.info("🛠️ Reconstructed GADHTModel from hyperparameters.")

        if model is None:
            raise ValueError("Model is None and no hyperparameters were found to rebuild it.")

        model.load_state_dict(state_dict, strict=strict)
        logger.info(f"✅ Model weights loaded successfully from: {load_path}")

        return (model, hparams) if hparams is not None else model

    except FileNotFoundError:
        logger.error(f"❌ Model file not found: {load_path}")
        raise
    except RuntimeError as e:
        logger.error(f"❌ State_dict mismatch while loading model (strict={strict}): {e}")
        raise
    except Exception as e:
        logger.error(f"❌ Unexpected error while loading model: {e}")
        raise

In [None]:
# ===========================================================
# GADHT - Full Pretraining Pipeline (Tickers-Aware)
# ===========================================================
# This cell defines the pretraining hyperparameters and output locations,
# loads the pretraining dataset, and wraps it in a shuffled DataLoader.
# It relies on names defined in earlier cells: PHASE_TICKERS, DATA_DIRECTORY,
# load_ticker_datasets, safe_collate and logger.

# ----------------------------
# 🔧 Pretraining configuration
# ----------------------------
# Training hyperparameters; passed to run_gadht_pretraining in a later cell.
PRETRAIN_EPOCHS = 50
PRETRAIN_BATCH_SIZE = 64
PRETRAIN_LR = 1e-5
PRETRAIN_PATIENCE = 8

# Output locations (relative to the notebook's working directory).
CHECKPOINT_DIR = "checkpoints/pretraining"
PLOT_DIR = "results/pretraining"
PRETRAINED_MODEL_FILE = "gadht_pretrained.pt"
PRETRAIN_LOSS_CURVE_FILE = "pretraining_loss_curve.png"

# Create both output directories up front so later cells can write to them.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(PLOT_DIR, exist_ok=True)

# ----------------------------
# Step 1 — Load dataset
# ----------------------------
# A missing or empty "pretraining" entry aborts the pipeline immediately.
pretrain_tickers = PHASE_TICKERS.get("pretraining", [])
if not pretrain_tickers:
    logger.critical("❌ No tickers defined for pretraining phase.")
    raise RuntimeError("No pretraining tickers available.")

logger.info(f"🔁 Loading pretraining dataset for tickers: {pretrain_tickers}")

# Loading errors are logged and then re-raised, so the cell still fails loudly.
try:
    pretrain_dataset = load_ticker_datasets(
        tickers=pretrain_tickers,
        data_path=DATA_DIRECTORY,
        window_size=30,
        max_imfs=5,
        use_ceemdan=True,
        pretraining_mode=True,   # IMF masking for self-supervised training
        normalize=True,
        prediction_horizon=1     # always 1-day ahead for pretraining
    )
except Exception as e:
    logger.error(f"❌ Failed to load pretraining dataset: {e}")
    raise

# ----------------------------
# Step 1b — DataLoader setup
# ----------------------------
# Batches are shuffled for pretraining and collated with safe_collate
# (defined in an earlier cell).
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=PRETRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=safe_collate
)

logger.info(
    f"📦 Pretraining dataset ready "
    f"({len(pretrain_dataset)} samples, batch_size={PRETRAIN_BATCH_SIZE})"
)

In [None]:
# ===========================================================
# Step 2 — 🧠 Initialize model (with wrapper for IMF reconstruction)
# ===========================================================
# Builds `base_model` and `pretrain_model`, both of which are used by later
# cells: `pretrain_model` is the object that gets trained, `base_model` is the
# object whose weights are exported afterwards.

logger.info("🧠 Initializing GADHT base model and pretraining wrapper...")

# Any construction error is logged at CRITICAL level and re-raised.
try:
    # Base GADHT model in pretraining mode (outputs reconstructed (T, M, F) sequences)
    base_model = GADHTModel(pretraining=True)
    
    # Wrapper ensures correct interface for masked IMF reconstruction training
    # NOTE(review): presumably the wrapper keeps a reference to base_model (so
    # training the wrapper updates base_model) — confirm in GADHTPretrainingWrapper.
    pretrain_model = GADHTPretrainingWrapper(base_model)

    # The log line reads d_model, num_imfs and num_features from the base model.
    logger.info(
        f"✅ Model initialized | "
        f"d_model={base_model.d_model}, "
        f"IMFs={base_model.num_imfs}, "
        f"features={base_model.num_features}"
    )
except Exception as e:
    logger.critical(f"❌ Failed to initialize GADHT pretraining model: {e}")
    raise

In [None]:
# ===========================================================
# Step 3 — 🏋️ Launch self-supervised pretraining
# ===========================================================
# Trains `pretrain_model` (the wrapper) on `pretrain_loader` using the
# PRETRAIN_* settings defined in the configuration cell. The checkpoint and the
# loss-curve figure are written under CHECKPOINT_DIR and PLOT_DIR respectively.

logger.info("🏁 Starting GADHT self-supervised pretraining...")

# Failures are logged at CRITICAL level and re-raised so the notebook run stops.
try:
    run_gadht_pretraining(
        model=pretrain_model,
        dataloader=pretrain_loader,
        num_epochs=PRETRAIN_EPOCHS,
        learning_rate=PRETRAIN_LR,
        early_stopping_patience=PRETRAIN_PATIENCE,
        max_gradient_norm=1.0,
        # NOTE(review): the checkpoint stored here is the state_dict of the model
        # passed in, i.e. the wrapper rather than the bare base model — confirm
        # in run_gadht_pretraining.
        checkpoint_path=os.path.join(CHECKPOINT_DIR, PRETRAINED_MODEL_FILE),
        plot_loss_curve=True,
        loss_plot_path=os.path.join(PLOT_DIR, PRETRAIN_LOSS_CURVE_FILE)
    )
    logger.info(f"✅ Pretraining completed successfully. Model saved to {os.path.join(CHECKPOINT_DIR, PRETRAINED_MODEL_FILE)}")

except Exception as e:
    logger.critical(f"❌ Pretraining failed: {e}")
    raise

In [None]:
# ===========================================================
# Step 4 — 💾 Save pretrained encoder weights
# ===========================================================
# Exports the weights of `base_model` to CHECKPOINT_DIR/PRETRAINED_MODEL_FILE.
#
# That path is also where pretraining stored its best-epoch checkpoint (of the
# wrapper). After early stopping the in-memory model holds *last-epoch* weights,
# so exporting it directly would replace the best checkpoint with worse weights.
# The best wrapper weights are therefore restored first.

logger.info("💾 Saving pretrained encoder weights...")

try:
    best_ckpt_path = os.path.join(CHECKPOINT_DIR, PRETRAINED_MODEL_FILE)

    if os.path.exists(best_ckpt_path):
        best_state = torch.load(best_ckpt_path, map_location=device)
        wrapper_keys = set(pretrain_model.state_dict().keys())

        # Only restore when the file really is a wrapper checkpoint. If this cell
        # is re-run, the file already contains the exported base-model weights
        # (different keys) and the in-memory weights are kept as they are.
        if isinstance(best_state, dict) and set(best_state.keys()) == wrapper_keys:
            # NOTE(review): assumes GADHTPretrainingWrapper holds `base_model` by
            # reference, so loading into the wrapper also updates base_model — confirm.
            pretrain_model.load_state_dict(best_state)
            logger.info(f"♻️ Restored best-epoch weights from {best_ckpt_path} before export.")
        else:
            logger.warning(
                "⚠️ Checkpoint keys do not match the pretraining wrapper; "
                "exporting current in-memory weights."
            )
    else:
        logger.warning("⚠️ No pretraining checkpoint found; exporting current in-memory weights.")

    saved_path = save_gadht_model(
        model=base_model,  # only encoder, not wrapper
        filename=PRETRAINED_MODEL_FILE,
        directory=CHECKPOINT_DIR
    )

    if saved_path:
        logger.info(f"✅ Pretrained encoder weights saved at {saved_path}")
    else:
        logger.warning("⚠️ Encoder weights could not be saved.")

    logger.info("🎯 GADHT pretraining completed. Encoder and loss curve saved.")

except Exception as e:
    logger.critical(f"❌ Failed to save pretrained encoder weights: {e}")
    raise

In [None]:
# ===========================================================
# GADHT - Fine-Tuning via Rolling CV + Final Model Training + Backtest
# ===========================================================
# For every ticker group and forecast horizon this cell:
#   1. runs rolling-origin cross-validation,
#   2. fine-tunes a final model on the full dataset and saves its best weights,
#   3. backtests the saved weights on that same (in-sample) dataset.
# A failure in one (group, horizon) combination is logged and the loop moves on.

# ----------------------------
# 🔧 Fine-tuning configuration
# ----------------------------
FINETUNE_EPOCHS = 150
FINETUNE_BATCH_SIZE = 64
FINETUNE_LR = 1e-5
FINETUNE_PATIENCE = 10
FINETUNE_CV_SPLITS = 5  # number of rolling-origin folds

PRETRAINED_MODEL_PATH = "checkpoints/pretraining/gadht_pretrained.pt"
FINETUNE_MODEL_DIR = "checkpoints/finetuned"
RESULTS_DIR = "results/finetune"
os.makedirs(FINETUNE_MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# ----------------------------
# Assets to fine-tune on (single + multi-asset)
# ----------------------------
FINE_TUNE_TICKER_GROUPS = {
    "tsla": ["TSLA"],
    "jnj": ["JNJ"],
    "msft": ["MSFT"],
    "nke": ["NKE"],
    "pg":  ["PG"],
    "unh": ["UNH"],
    "all_assets": ["TSLA", "JNJ", "MSFT", "NKE", "PG", "UNH"]
}

HORIZONS = [1, 5, 10]  # run fine-tuning for 1, 5, and 10 days ahead


# ----------------------------
# Main fine-tuning loop
# ----------------------------
for model_name, tickers in FINE_TUNE_TICKER_GROUPS.items():
    for H in HORIZONS:
        try:
            logger.info(f"🚀 Running {FINETUNE_CV_SPLITS}-fold CV for: {model_name.upper()} | Horizon={H}")

            # === 1. Rolling-origin CV (statistical metrics + backtest per fold)
            run_gadht_rolling_cv(
                tickers=tickers,
                model_name=f"finetuned_{model_name}_H{H}",
                dataset_name=" + ".join(tickers),
                pretrained_path=PRETRAINED_MODEL_PATH,
                n_splits=FINETUNE_CV_SPLITS,
                batch_size=FINETUNE_BATCH_SIZE,
                lr=FINETUNE_LR,
                epochs=FINETUNE_EPOCHS,
                patience=FINETUNE_PATIENCE,
                max_grad_norm=1.0,
                prediction_horizon=H
            )

            # === 2. Final training on full dataset
            logger.info(f"🧠 Training final model on full dataset: {model_name.upper()} | Horizon={H}")

            full_dataset = load_ticker_datasets(
                tickers=tickers,
                data_path=DATA_DIRECTORY,
                window_size=30,
                max_imfs=5,
                use_ceemdan=True,
                pretraining_mode=False,
                normalize=True,
                prediction_horizon=H
            )

            full_loader = DataLoader(
                full_dataset,
                batch_size=FINETUNE_BATCH_SIZE,
                shuffle=True,
                collate_fn=safe_collate
            )

            # Load pretrained encoder (strict=False: non-matching keys are ignored)
            final_model = GADHTModel(pretraining=False).to(device)
            final_model.load_state_dict(
                torch.load(PRETRAINED_MODEL_PATH, map_location=device),
                strict=False
            )

            save_path = os.path.join(FINETUNE_MODEL_DIR, f"gadht_finetuned_{model_name}_H{H}.pt")

            run_gadht_finetune(
                model=final_model,
                dataloader=full_loader,
                epochs=FINETUNE_EPOCHS,
                lr=FINETUNE_LR,
                patience=FINETUNE_PATIENCE,
                max_grad_norm=1.0,
                save_path=save_path,
                plot_curve=True,
                fig_path=os.path.join(RESULTS_DIR, f"{model_name}_final_curve_H{H}.png")
            )

            logger.info(f"✅ Final fine-tuned model saved: {save_path}")

            # After training, `final_model` holds last-epoch weights, while
            # `save_path` holds the best-RMSE weights. Reload the saved weights so
            # the reported backtest describes the artifact that was persisted.
            if os.path.exists(save_path):
                best_state = torch.load(save_path, map_location=device)
                if set(best_state.keys()) == set(final_model.state_dict().keys()):
                    final_model.load_state_dict(best_state)
                else:
                    logger.warning(
                        f"⚠️ Checkpoint at {save_path} does not match the model; "
                        "backtesting in-memory weights."
                    )

            # === 3. Backtest on full dataset (in-sample: same data as training)
            logger.info(f"📊 Running backtest on full dataset: {model_name.upper()} | Horizon={H}")
            metrics, preds, targets = evaluate_model(
                final_model,
                DataLoader(full_dataset, batch_size=FINETUNE_BATCH_SIZE, shuffle=False, collate_fn=safe_collate),
                return_predictions=True,
                device=device,
                cost_bps=20  # 20 bps transaction cost
            )

            results_path = os.path.join(RESULTS_DIR, f"{model_name}_final_backtest_H{H}.csv")
            pd.DataFrame([metrics]).to_csv(results_path, index=False)
            logger.info(f"✅ Backtest results saved: {results_path}")

        except Exception as e:
            # Keep going with the remaining groups/horizons, but record the
            # traceback so the failure can be diagnosed.
            logger.critical(
                f"❌ Fine-tuning failed for {model_name.upper()} | Horizon={H}: {e}",
                exc_info=True
            )
            continue

In [None]:
# ===========================================================
# ✅ ZEROSHOT PIPELINE – All Zero-Shot Assets (with Backtesting)
# ===========================================================

def backtest_portfolio(preds, targets, cost=0.002, risk_free_rate=0.0, freq=252):
    """
    Run a simple directional backtest given predictions and true values.

    The position at step ``t`` is the sign of ``preds[t] - preds[t-1]``
    (+1 long, -1 short, 0 flat) and it earns the realized simple return of
    ``targets`` over the same step. A fixed ``cost`` is subtracted from the
    strategy return every time the position changes.

    Args:
        preds (array-like): Predicted series. Flattened to 1-D.
        targets (array-like): Realized series. Flattened to 1-D.
        cost (float): Cost charged per position change (0.002 = 20 bps).
        risk_free_rate (float): Annual risk-free rate used in the Sharpe ratio.
        freq (int): Number of periods per year used for annualization.

    Returns:
        dict: ``Sharpe``, ``MaxDD`` and ``WinRate`` as floats and
        ``EquityCurve`` as a 1-D array. When the inputs differ in length or
        hold fewer than two points, all metrics are 0.0 and the curve is empty.
    """
    # Flatten so column vectors of shape (N, 1) behave like plain 1-D series.
    preds = np.asarray(preds, dtype=float).ravel()
    targets = np.asarray(targets, dtype=float).ravel()

    if len(preds) != len(targets) or len(preds) < 2:
        return {"Sharpe": 0.0, "MaxDD": 0.0, "WinRate": 0.0, "EquityCurve": np.array([])}

    # Trading signals: direction of the change in the prediction
    signals = np.sign(np.diff(preds, prepend=preds[0]))

    # Realized returns. A zero previous value has no defined simple return,
    # so that step is treated as flat (0) instead of producing inf/NaN that
    # would poison every downstream metric.
    previous = targets[:-1]
    step_rets = np.zeros(len(previous), dtype=float)
    np.divide(np.diff(targets), previous, out=step_rets, where=previous != 0)
    rets = np.concatenate([[0.0], step_rets])

    # Strategy returns
    strat_rets = signals * rets

    # Transaction costs: charged on every step where the signal changed
    trades = np.concatenate([[False], np.diff(signals) != 0])
    strat_rets[trades] -= cost

    # Equity curve
    equity = np.cumprod(1 + strat_rets)

    # Sharpe ratio (epsilon keeps a constant return series from dividing by zero)
    mu, sigma = strat_rets.mean(), strat_rets.std() + 1e-8
    sharpe = (mu - risk_free_rate / freq) / sigma * np.sqrt(freq)

    # Max drawdown
    peak = np.maximum.accumulate(equity)
    drawdown = equity / peak - 1
    max_dd = drawdown.min()

    # Win rate: share of steps with a strictly positive strategy return
    win_rate = (strat_rets > 0).mean()

    return {
        "Sharpe": float(sharpe),
        "MaxDD": float(max_dd),
        "WinRate": float(win_rate),
        "EquityCurve": equity
    }


def _unpack_supervised_batch(batch):
    """Return ``(x, y)`` from a collated batch, or ``None`` when it is unusable."""
    if batch is None:
        return None

    # Handle dict batches (from safe_collate)
    if isinstance(batch, dict):
        if "input" in batch and "target" in batch:  # supervised mode
            return batch["input"], batch["target"]
        if "masked" in batch:  # pretraining mode (should not happen here)
            logger.warning("⚠️ Skipping pretraining batch in zero-shot eval.")
        else:
            logger.error(f"❌ Unexpected batch keys: {list(batch.keys())}")
        return None

    try:
        x, y = batch
    except Exception as e:
        logger.error(f"❌ Could not unpack batch: {e}")
        return None
    return x, y


def run_zero_shot_evaluation(model_path, save_dir="figures/zeroshot"):
    """
    Run zero-shot evaluation on all assets defined in PHASE_TICKERS["zeroshot"].

    For each ticker the dataset is split into 5 contiguous chunks
    (``KFold`` without shuffling). Only the validation indices are used: the
    checkpoint is evaluated as-is on every chunk, nothing is trained.
    Per-chunk RMSE / MAE / MAPE / R2 are averaged, the concatenated
    predictions are backtested, and two figures are written to ``save_dir``.

    Args:
        model_path (str): Checkpoint loaded (non-strictly) into ``GADHTModel``.
        save_dir (str): Directory receiving the per-ticker figures.

    Returns:
        pd.DataFrame: One row per ticker with ``Stock``, the averaged fold
        metrics and the ``backtest_portfolio`` outputs (including the
        ``EquityCurve`` array).
    """
    os.makedirs(save_dir, exist_ok=True)
    results = []

    tickers = PHASE_TICKERS["zeroshot"]
    if not tickers:
        return pd.DataFrame(results)

    # The checkpoint is identical for every ticker and fold and is only used
    # in eval mode under no_grad, so it is loaded once instead of per fold.
    model = GADHTModel(pretraining=False).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    model.eval()

    for ticker in tickers:
        logger.info(f"\n🧪 Zero-Shot Evaluation on {ticker}...")

        # === Load dataset
        dataset = load_ticker_datasets(
            tickers=[ticker],
            data_path=DATA_DIRECTORY,
            window_size=30,
            max_imfs=5,
            use_ceemdan=True,
            pretraining_mode=False,
            normalize=True,
            prediction_horizon=1
        )
        logger.info(f"🌟 Dataset ready for {ticker}: {len(dataset)} samples.")

        # === 5 contiguous evaluation chunks (training indices are ignored)
        kf = KFold(n_splits=5, shuffle=False)
        fold_preds, fold_targets = [], []
        fold_metrics = []

        for fold, (_, val_idx) in enumerate(kf.split(dataset)):
            val_loader = DataLoader(
                Subset(dataset, val_idx),
                batch_size=64,
                shuffle=False,
                collate_fn=safe_collate
            )

            preds, targets = [], []
            with torch.no_grad():
                for batch in val_loader:
                    pair = _unpack_supervised_batch(batch)
                    if pair is None:
                        continue

                    x, y = pair
                    x, y = x.to(device), y.to(device)
                    y_hat = model(x)
                    # reshape(-1) instead of squeeze(): a single-sample batch
                    # would otherwise become a 0-d array that np.concatenate
                    # rejects, and flattening the target as well keeps
                    # (N, 1) targets from broadcasting against (N,)
                    # predictions in the MAPE expression below.
                    preds.append(y_hat.reshape(-1).cpu().numpy())
                    targets.append(y.reshape(-1).cpu().numpy())

            if not preds or not targets:
                logger.warning(f"⚠️ No valid predictions for fold {fold+1} of {ticker}, skipping...")
                continue

            preds = np.concatenate(preds)
            targets = np.concatenate(targets)

            fold_preds.append(preds)
            fold_targets.append(targets)

            # Fold metrics (MAPE is undefined wherever a target equals zero)
            rmse = np.sqrt(mean_squared_error(targets, preds))
            mae = mean_absolute_error(targets, preds)
            mape = np.mean(np.abs((targets - preds) / targets)) * 100
            r2 = r2_score(targets, preds)

            fold_metrics.append({"RMSE": rmse, "MAE": mae, "MAPE": mape, "R2": r2})

        if not fold_metrics:
            logger.warning(f"❌ No valid folds for {ticker}, skipping...")
            continue

        # === Aggregate fold metrics
        avg_metrics = pd.DataFrame(fold_metrics).mean().to_dict()

        # === Concatenate predictions (chunks are contiguous, so order is preserved)
        all_preds = np.concatenate(fold_preds)
        all_targets = np.concatenate(fold_targets)

        # Backtest
        bt_metrics = backtest_portfolio(all_preds, all_targets, cost=0.002)

        # Save metrics
        results.append({"Stock": ticker, **avg_metrics, **bt_metrics})

        # === Save plots
        plot_predictions_vs_real_array(
            all_targets, all_preds, ticker,
            strategy="Zero-shot",
            save_path=f"{save_dir}/{ticker.lower()}_pred_actual.png"
        )
        plot_scatter_array(
            all_targets, all_preds, ticker,
            strategy="Zero-shot",
            save_path=f"{save_dir}/{ticker.lower()}_scatter.png"
        )

    return pd.DataFrame(results)

In [None]:
# ===========================================================
# 📈 ZEROSHOT VISUALIZATIONS
# ===========================================================

def plot_predictions_vs_real_array(y_true, y_pred, stock_name, strategy="Zero-shot", save_path=None):
    """
    Plot time-series comparison between actual and predicted values.

    The figure is always closed before returning; it is written to disk only
    when ``save_path`` is given.

    Args:
        y_true (array-like): Ground truth values.
        y_pred (array-like): Model predictions.
        stock_name (str): Ticker symbol (e.g., "KO", "META").
        strategy (str): Label for strategy (default = "Zero-shot").
        save_path (str or None): Path to save figure (if provided). May be a
            bare file name; missing parent directories are created.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(y_true, label="Actual", linestyle="--", linewidth=2, color="black")
        ax.plot(y_pred, label="Prediction", linewidth=2, color="royalblue", alpha=0.8)
        ax.set_title(f"{stock_name} – Close Price Prediction (+1 day) [{strategy}]", fontsize=16)
        ax.set_xlabel("Samples")
        ax.set_ylabel("Normalized Price")
        ax.legend(framealpha=0.8)
        ax.grid(True, linestyle="--", alpha=0.7)
        fig.tight_layout()
        if save_path:
            # dirname is "" for a bare file name and os.makedirs("") raises,
            # so only create the directory when there is one.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path, dpi=300)
    finally:
        # Close even when plotting or saving fails, so figures do not pile up
        # when this is called once per ticker.
        plt.close(fig)


def plot_scatter_array(y_true, y_pred, stock_name, strategy="Zero-shot", save_path=None):
    """
    Scatter plot of actual vs predicted values, with diagonal reference.

    The R² shown in the title is best-effort: if it cannot be computed it is
    displayed as NaN. The figure is always closed before returning.

    Args:
        y_true (array-like): Ground truth values.
        y_pred (array-like): Model predictions.
        stock_name (str): Ticker symbol (e.g., "KO", "META").
        strategy (str): Label for strategy (default = "Zero-shot").
        save_path (str or None): Path to save figure (if provided). May be a
            bare file name; missing parent directories are created.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Compute R² for quick visual feedback
    try:
        r2 = r2_score(y_true, y_pred)
    except Exception:
        r2 = np.nan

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.scatter(y_true, y_pred, alpha=0.6, color="dodgerblue", edgecolor="k")
        # min()/max() raise on empty arrays, so the reference diagonal is only
        # drawn when there is data to span.
        if y_true.size and y_pred.size:
            min_val = min(y_true.min(), y_pred.min())
            max_val = max(y_true.max(), y_pred.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label="Ideal Fit")
            ax.legend(framealpha=0.8)
        ax.set_title(f"{stock_name} – Actual vs Predicted [{strategy}] (R²={r2:.3f})", fontsize=16)
        ax.set_xlabel("Actual")
        ax.set_ylabel("Predicted")
        ax.grid(True, linestyle="--", alpha=0.7)
        fig.tight_layout()
        if save_path:
            # dirname is "" for a bare file name and os.makedirs("") raises,
            # so only create the directory when there is one.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path, dpi=300)
    finally:
        # Close even when plotting or saving fails, so figures do not pile up
        # when this is called once per ticker.
        plt.close(fig)

In [None]:
# ========================================
# ⚡️ Run the zero-shot pipeline
# ========================================

# Checkpoint evaluated on the zero-shot assets.
# NOTE(review): this rebinds PRETRAINED_MODEL_PATH, which the fine-tuning cell
# earlier in the notebook also reads - re-running that cell after this one
# would pick up this value.
PRETRAINED_MODEL_PATH = "checkpoints/finetuned/gadht_finetuned_all_assets_H1.pt"

zeroshot_df = run_zero_shot_evaluation(PRETRAINED_MODEL_PATH)

print(zeroshot_df)

In [None]:
# Tabular summary: the per-step "EquityCurve" column holds a NumPy array per
# row, which serialises as a truncated repr in CSV/LaTeX, so it is left out of
# the console table and the exports (it stays available in zeroshot_df).
zeroshot_summary = zeroshot_df.drop(columns=["EquityCurve"], errors="ignore")

# Console display
print(zeroshot_summary.round(4))

os.makedirs("results/zeroshot", exist_ok=True)

# Export CSV + LaTeX
zeroshot_summary.round(2).to_csv("results/zeroshot/zeroshot_summary.csv", index=False)
print(zeroshot_summary.round(2).to_latex(index=False))

In [None]:
# ===========================================================
# 📊 IMF & Temporal Attention Analysis (All Fine-Tuning Assets)
# ===========================================================

# Make sure these come from the main codebase
# from model import GADHTModel
# from data import load_ticker_datasets, safe_collate

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fine-tuned 1-day-horizon checkpoints, keyed by ticker
CHECKPOINTS = dict(
    TSLA="checkpoints/finetuned/gadht_finetuned_tsla_H1.pt",
    JNJ="checkpoints/finetuned/gadht_finetuned_jnj_H1.pt",
    MSFT="checkpoints/finetuned/gadht_finetuned_msft_H1.pt",
    NKE="checkpoints/finetuned/gadht_finetuned_nke_H1.pt",
    PG="checkpoints/finetuned/gadht_finetuned_pg_H1.pt",
    UNH="checkpoints/finetuned/gadht_finetuned_unh_H1.pt",
)

# Output folder for interpretability figures, loader batch size, IMF count
FIG_INT_DIR = "figures/interpretability"
BATCH_SIZE = 32
NUM_IMFS = 5

# Create the figure folder up front so later savefig calls cannot fail on it
os.makedirs(FIG_INT_DIR, exist_ok=True)

In [None]:
# ===========================================================
# 🔍 Extract Average Temporal Attention Weights
# ===========================================================

def extract_average_temporal_attention(ticker, model_path, max_batches=30):
    """
    Extract the average temporal attention weights for a given ticker
    using the first temporal attention layer of GADHT.

    The inputs are projected, position-encoded and passed through
    ``model.temporal_layers[0]`` only; the attention it returns is averaged
    over heads, IMF sequences, batches and finally the query axis.

    Args:
        ticker (str): Stock ticker whose dataset is loaded.
        model_path (str): Checkpoint loaded (non-strictly) into ``GADHTModel``.
        max_batches (int): Upper bound on the number of loader batches visited
            (skipped ``None`` batches count towards it).

    Returns:
        np.ndarray: Mean attention received per time step, or an empty array
        when no batch produced attention weights.
    """
    model = GADHTModel(pretraining=False).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    model.eval()

    dataset = load_ticker_datasets(
        [ticker],
        data_path=DATA_DIRECTORY,
        window_size=30,
        max_imfs=5,
        use_ceemdan=True,
        pretraining_mode=False,
        normalize=True,
        prediction_horizon=1
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=safe_collate)

    all_weights = []

    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= max_batches:
                break

            # The other consumers of safe_collate in this notebook skip None
            # batches; without this guard a None batch would crash on the
            # tuple unpacking below.
            if batch is None:
                continue

            # Supervised mode → dict {"input", "target"}
            if isinstance(batch, dict):
                x = batch["input"].to(device)   # (B, T, M, F)
            else:
                x, _ = batch
                x = x.to(device)

            B, T, M, F = x.shape

            # Reshape: treat each IMF as an independent sequence
            x_tm = x.permute(0, 2, 1, 3).contiguous().view(B * M, T, F)  # (B*M, T, F)
            x_tm = model.input_proj(x_tm)  # presumably (B*M, T, d_model) - TODO confirm

            # Add positional encoding
            x_tm = x_tm + model.positional_encoding.to(x_tm.device)

            # Take the FIRST temporal attention layer for interpretability
            # assumes attn is (B*M, heads, T, T) - TODO confirm against the layer
            x_tm, attn = model.temporal_layers[0](x_tm, energy=None)

            # Average over heads, then over the B*M sequence dimension
            attn_mean = attn.mean(dim=1).mean(dim=0)    # (T, T) under the assumption above
            all_weights.append(attn_mean.cpu().numpy())

    if not all_weights:
        logger.warning(f"⚠️ No attention extracted for {ticker}")
        return np.array([])

    # Average across batches
    mean_attn_matrix = np.mean(np.stack(all_weights, axis=0), axis=0)  # (T, T)

    # Collapse query dimension → importance per time step
    mean_weights = mean_attn_matrix.mean(axis=0)  # (T,)

    return mean_weights

In [None]:
# ===========================================================
# 📈 Temporal Attention Weights – All Fine-Tuning Assets
# ===========================================================

# Collect the per-time-step attention profile of every fine-tuned checkpoint
weights_dict = {}
for ticker, path in CHECKPOINTS.items():
    weights = extract_average_temporal_attention(ticker, path)
    if weights.size > 0:
        weights_dict[ticker] = weights
    else:
        logger.warning(f"⚠️ Skipping {ticker}, no valid attention weights.")

if len(weights_dict) == 0:
    raise RuntimeError("❌ No attention weights available to plot.")

# X-axis labels T1..Tn, sized from the first available profile
time_labels = [f"T{step}" for step in range(1, len(next(iter(weights_dict.values()))) + 1)]

# One line per asset on a wide canvas
plt.figure(figsize=(16, 8))
for ticker, weights in weights_dict.items():
    plt.plot(time_labels, weights, marker="o", linewidth=2, label=ticker)

# Titles and axis labels
plt.title("Mean Attention Weights per Time Step (Fine-Tuning Assets)", fontsize=18)
plt.xlabel("Time Step", fontsize=14)
plt.ylabel("Average Attention Weight", fontsize=14)

# Tick styling, two-column legend, grid
plt.xticks(rotation=45, fontsize=10)
plt.yticks(fontsize=10)
plt.legend(fontsize=12, ncol=2)
plt.grid(True)

plt.tight_layout()
plt.savefig(f"{FIG_INT_DIR}/temporal_attention_weights_all.png")
plt.show()

In [None]:
# ===========================================================
# 🔍 Temporal Attention Heatmap for 1 Sample
# ===========================================================

def plot_temporal_attention_map(ticker, model_path):
    """
    Plot a heatmap of temporal attention weights for a single sample.

    Only the first usable sample of the ticker's dataset is used. It is run
    through ``model.temporal_layers[0]`` and the returned attention is
    averaged over heads. The heatmap is saved under ``FIG_INT_DIR``, shown,
    and then closed.

    Args:
        ticker (str): Stock ticker (e.g., "TSLA").
        model_path (str): Path to the fine-tuned checkpoint.
    """
    # Load model
    model = GADHTModel(pretraining=False).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    model.eval()

    # Load dataset
    dataset = load_ticker_datasets(
        [ticker],
        data_path=DATA_DIRECTORY,
        window_size=30,
        max_imfs=5,
        use_ceemdan=True,
        pretraining_mode=False,
        normalize=True,
        prediction_horizon=1
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=safe_collate)

    attn_map = None

    with torch.no_grad():
        for batch in loader:
            if batch is None:
                continue

            # Accept dict and tuple batches, matching
            # extract_average_temporal_attention.
            if isinstance(batch, dict):
                x = batch["input"].to(device)   # (1, T, M, F)
            else:
                x, _ = batch
                x = x.to(device)

            B, T, M, F = x.shape
            # Prepare temporal input: (B*M, T, F)
            x_tm = x.permute(0, 2, 1, 3).contiguous().view(B * M, T, F)
            x_tm = model.input_proj(x_tm) + model.positional_encoding.to(x_tm.device)

            # Run through first temporal attention layer to extract attention map
            # attn is expected to be (heads, T, T) or (B*M, heads, T, T) - TODO confirm
            x_tm, attn = model.temporal_layers[0](x_tm, energy=None)

            # Average over heads; for the 4-D layout keep the first sequence only
            if attn.dim() == 4:
                attn_map = attn.mean(dim=1)[0].cpu().numpy()  # (T, T) from first IMF
            elif attn.dim() == 3:
                attn_map = attn.mean(dim=0).cpu().numpy()     # (T, T)
            # Only the first usable sample is inspected
            break

    if attn_map is None:
        logger.warning(f"⚠️ No attention map extracted for {ticker}")
        return

    # Make the function usable without relying on another cell having
    # created the output directory.
    os.makedirs(FIG_INT_DIR, exist_ok=True)

    # Plot heatmap
    fig = plt.figure(figsize=(10, 6))
    sns.heatmap(attn_map, cmap="YlGnBu", xticklabels=False, yticklabels=False)
    plt.title(f"Temporal Attention Map – {ticker} (1-day horizon)")
    plt.xlabel("Time (Key)")
    plt.ylabel("Time (Query)")
    plt.tight_layout()
    plt.savefig(f"{FIG_INT_DIR}/{ticker.lower()}_temporal_attention_map_H1.png")
    plt.show()
    # Called once per ticker: release the figure so they do not accumulate.
    plt.close(fig)

In [None]:
# ===========================================================
# 🔍 Temporal Attention Maps – All Fine-Tuning Assets
# ===========================================================

# One heatmap per fine-tuned checkpoint
for ticker, path in CHECKPOINTS.items():
    logger.info(f"🖼️ Generating temporal attention map for {ticker}...")
    plot_temporal_attention_map(ticker=ticker, model_path=path)

In [None]:
# ===========================================================
# 📈 IMF-Level Attention Analysis (L2-Norm Method)
# ===========================================================

def analyze_imf_attention_weights(ticker, model_path):
    """
    Compute relative IMF-level importance weights using L2-norm energy.

    For every sample, the energy of an IMF is the mean (over batch and time)
    of the L2 norm of its feature vector; energies are normalized to sum to 1
    and then averaged over samples. The result depends only on the input
    data: no model is involved, so the checkpoint is not loaded.

    Args:
        ticker (str): Stock ticker (e.g., "TSLA").
        model_path (str): Path to fine-tuned checkpoint. Unused by the
            computation; kept so existing callers keep working.

    Returns:
        np.ndarray: Normalized IMF importance weights (shape: [num_imfs]).
        All-NaN when no sample yields a usable (positive, finite) energy.
    """
    max_imfs = 5

    # Load dataset
    dataset = load_ticker_datasets(
        [ticker],
        data_path=DATA_DIRECTORY,
        window_size=30,
        max_imfs=max_imfs,
        use_ceemdan=True,
        pretraining_mode=False,
        normalize=True,
        prediction_horizon=1
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=safe_collate)

    all_weights = []

    with torch.no_grad():
        for batch in loader:
            if batch is None:
                continue

            # Extract input tensor; unpacking enforces the 4-D (B, T, M, F) layout
            X = batch["input"]
            B, T, M, F = X.shape

            # Per-IMF L2 energy: norm over features, mean over batch and time.
            # One vectorized reduction replaces the per-IMF loop.
            imf_energies = X.norm(p=2, dim=-1).mean(dim=(0, 1)).double().cpu().numpy()

            # An all-zero (or non-finite) sample cannot be normalized; skipping
            # it keeps a single bad sample from turning the whole mean into NaN.
            total_energy = imf_energies.sum()
            if not np.isfinite(total_energy) or total_energy <= 0:
                continue

            all_weights.append(imf_energies / total_energy)

    if not all_weights:
        logger.warning(f"⚠️ No IMF energy could be computed for {ticker}")
        return np.full(max_imfs, np.nan)

    # Return average across samples
    return np.mean(all_weights, axis=0)

In [None]:
# ===========================================================
# 📊 Aggregated Bar Plot for Multiple Stocks (IMF Weights)
# ===========================================================

# Energy-based IMF weights for every fine-tuned ticker
results = {}
for ticker, checkpoint_path in CHECKPOINTS.items():
    results[ticker] = analyze_imf_attention_weights(ticker, checkpoint_path)

# Rows are IMF1..IMFn, columns are tickers
imf_labels = [f"IMF{k}" for k in range(1, NUM_IMFS + 1)]
df_imf = pd.DataFrame(results, index=imf_labels)

if not df_imf.empty:
    plt.figure(figsize=(12, 6))
    bar_width = 0.12  # narrow bars: one per ticker inside each IMF group
    x = np.arange(NUM_IMFS)

    # Grouped bars: each ticker is shifted by one bar width within a group
    for i, ticker in enumerate(df_imf.columns):
        plt.bar(
            x + i * bar_width,
            df_imf[ticker],
            width=bar_width,
            label=ticker,
            alpha=0.85
        )

    # Title and axis labels
    plt.title("Average Energy-Based IMF Weights (Across Fine-Tuned Models)", fontsize=14)
    plt.xlabel("IMF Index", fontsize=12)
    plt.ylabel("Average Energy-Based Weight", fontsize=12)

    # Put each tick under the middle of its group of bars
    group_centres = x + (len(df_imf.columns) - 1) * bar_width / 2
    plt.xticks(group_centres, df_imf.index)

    plt.legend()
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    plt.tight_layout()

    # Save, then display
    plt.savefig(f"{FIG_INT_DIR}/imf_energy_weights.png")
    plt.show()

    # Tickers as rows for the console table
    print(df_imf.T.round(4))
else:
    logger.warning("⚠️ No IMF weights could be computed. Skipping plot.")