In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import gc
from scipy.io import loadmat

In [None]:
def generate_time_features(length, periods=(24, 168, 720)):
    """
    Generate sinusoidal time features with multiple periodicities.

    Args:
        length: Number of time steps.
        periods: Iterable of period values (in time steps) for the sinusoidal
            encoding. The default is an immutable tuple so that it cannot be
            altered between calls.

    Returns:
        float32 array of shape (length, 2*len(periods)+1) containing:
        - a sin/cos pair for each period, in the order the periods are given
        - normalized time progress (index / length) in the last column

    Note: For satellite channels, you may want to adjust periods based on
    orbital dynamics. Current defaults (24, 168, 720) are suitable for
    phenomena with daily/weekly patterns.
    """
    # float32 index keeps every derived feature in float32
    i = np.arange(length, dtype=np.float32)

    features = []
    for period in periods:
        angle = 2 * np.pi * i / period
        features.append(np.sin(angle).astype(np.float32))
        features.append(np.cos(angle).astype(np.float32))

    # Add normalized time index for sequence progression
    time_progress = (i / length).astype(np.float32)
    features.append(time_progress)

    return np.column_stack(features)  # Shape: (length, 2*len(periods)+1)

In [None]:
class TimeDataset(Dataset):
    """
    Sliding-window dataset over a single-variable time series.

    The series is z-score normalised and augmented with the columns produced
    by ``generate_time_features``. Every sample is cut from that feature
    matrix as follows (``idx`` is the sample index):

    - encoder_input: rows [idx, idx+P)
    - decoder_input: rows [idx+P-G, idx+P) followed by F all-zero rows
    - target:        rows [idx+P, idx+P+F)

    NOTE(review): ``A`` (ageing delay) only influences the sample count in
    ``__len__``; ``__getitem__`` does not shift any window by it. Confirm
    whether the target was meant to be offset by the ageing delay.

    Args:
        raw_data: 1D array of channel measurements
        P: Number of historical steps (input sequence length)
        G: Number of recent steps used by decoder (must not exceed P)
        F: Number of future steps to predict
        A: Aging delay (time steps)
        mean/std: Normalization parameters; computed from the selected slice
            when omitted (pass the training-set values for val/test splits)
        start_idx/end_idx: Data slice for train/val/test split

    Raises:
        ValueError: If G > P, because the decoder window would then start
            before the first row available to the sample.
    """
    def __init__(self, raw_data, P, G, F, A,
                 mean=None, std=None,
                 start_idx=None, end_idx=None):
        if G > P:
            raise ValueError(f"G ({G}) must not exceed P ({P})")

        self.P = P
        self.G = G
        self.F = F
        self.A = A

        # Convert to float32 and slice
        raw_data = raw_data.astype(np.float32)
        self.raw_data = raw_data[start_idx:end_idx]

        # Compute normalization statistics
        self.mean = mean if mean is not None else self.raw_data.mean()
        self.std = std if std is not None else self.raw_data.std()

        # Normalize data (epsilon guards against a zero standard deviation)
        self.data = ((self.raw_data - self.mean) / (self.std + 1e-8)).astype(np.float32)

        # Time features are indexed relative to the selected slice
        self.time_features = generate_time_features(len(self.raw_data)).astype(np.float32)

        # Combine: [normalized_value, time_features] → shape (T, 1 + n_time_features)
        self.features = np.hstack([
            self.data.reshape(-1, 1),
            self.time_features
        ]).astype(np.float32)

    def __len__(self):
        """
        Number of valid samples.

        The nominal count ``T - P - A - F + 2`` is kept so that existing
        splits see the same size for A >= 1. It is additionally capped at
        ``T - P - F + 1`` (the largest count for which the target window
        still has F rows, relevant for A == 0) and floored at 0 so that
        short series yield an empty dataset instead of a negative length.
        """
        total = len(self.features)
        nominal = total - self.P - self.A - self.F + 2
        servable = total - self.P - self.F + 1
        return max(0, min(nominal, servable))

    def __getitem__(self, idx):
        """
        Get a single training sample.

        Returns:
            encoder_input: (P, n_features) rows [idx, idx+P)
            decoder_input: (G+F, n_features) last G encoder rows + F zero rows
            target: (F, n_features) rows [idx+P, idx+P+F)

        Raises:
            IndexError: If idx is outside [-len(self), len(self)).
        """
        n_samples = len(self)
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            # Required so plain iteration over the dataset terminates and
            # truncated windows are never returned.
            raise IndexError(f"sample index out of range for dataset of size {n_samples}")

        split = idx + self.P

        # Encoder input: P historical steps
        encoder_input = self.features[idx:split]

        # Target: the F steps immediately after the encoder window
        target = self.features[split:split + self.F]

        # Decoder input: the G most recent encoder steps + F zero placeholders
        decoder_recent = self.features[split - self.G:split]
        decoder_zeros = np.zeros((self.F, self.features.shape[1]), dtype=np.float32)
        decoder_input = np.vstack([decoder_recent, decoder_zeros])

        return (
            torch.as_tensor(encoder_input, dtype=torch.float32),
            torch.as_tensor(decoder_input, dtype=torch.float32),
            torch.as_tensor(target, dtype=torch.float32)
        )

In [None]:
class TimeDatasetFF(Dataset):
    """
    Sliding-window dataset with separate input and output series.

    The encoder window is cut from the input series while the decoder window
    and the target are cut from the output series. Both series are sliced
    with the same ``start_idx``/``end_idx``, normalised independently and
    augmented with the columns produced by ``generate_time_features``.

    For sample index ``idx``:
    - encoder_input: input-feature rows [idx, idx+P)
    - decoder_input: output-feature rows [idx+P-G, idx+P) + F all-zero rows
    - target:        output-feature rows [idx+P, idx+P+F)

    NOTE(review): ``A`` only influences the sample count in ``__len__``;
    no window in ``__getitem__`` is shifted by it. Confirm the intent.

    Args:
        raw_data_in: Input channel measurements
        raw_data_out: Output channel measurements
        P, G, F, A: Same as TimeDataset (G must not exceed P)
        mean_in/std_in: Normalization for input
        mean_out/std_out: Normalization for output
        start_idx/end_idx: Data slice applied to both series

    Raises:
        ValueError: If G > P, because the decoder window would then start
            before the first row available to the sample.
    """
    def __init__(self,
                 raw_data_in, raw_data_out,
                 P, G, F, A,
                 mean_in=None, std_in=None,
                 mean_out=None, std_out=None,
                 start_idx=None, end_idx=None):
        self.P = int(P)
        self.G = int(G)
        self.F = int(F)
        self.A = int(A)
        if self.G > self.P:
            raise ValueError(f"G ({self.G}) must not exceed P ({self.P})")

        # Process input data
        raw_data_in = raw_data_in.astype(np.float32)
        self.raw_data_in = raw_data_in[start_idx:end_idx]
        self.mean_in = mean_in if mean_in is not None else self.raw_data_in.mean()
        self.std_in = std_in if std_in is not None else self.raw_data_in.std()
        self.data_in = ((self.raw_data_in - self.mean_in) / (self.std_in + 1e-8)).astype(np.float32)

        # Process output data
        raw_data_out = raw_data_out.astype(np.float32)
        self.raw_data_out = raw_data_out[start_idx:end_idx]
        self.mean_out = mean_out if mean_out is not None else self.raw_data_out.mean()
        self.std_out = std_out if std_out is not None else self.raw_data_out.std()
        self.data_out = ((self.raw_data_out - self.mean_out) / (self.std_out + 1e-8)).astype(np.float32)

        # Generate time features (indexed relative to each slice)
        self.time_features_in = generate_time_features(len(self.raw_data_in)).astype(np.float32)
        self.time_features_out = generate_time_features(len(self.raw_data_out)).astype(np.float32)

        # Combine features: [normalized_value, time_features]
        self.features_in = np.hstack([
            self.data_in.reshape(-1, 1),
            self.time_features_in
        ]).astype(np.float32)

        self.features_out = np.hstack([
            self.data_out.reshape(-1, 1),
            self.time_features_out
        ]).astype(np.float32)

    def __len__(self):
        """
        Number of valid samples, limited by the shorter of the two series.

        The nominal count ``T - P - A - F + 2`` is kept for A >= 1, capped at
        ``T - P - F + 1`` so the target window always has F rows, and floored
        at 0 so short series give an empty dataset.
        """
        total = min(len(self.features_in), len(self.features_out))
        nominal = total - self.P - self.A - self.F + 2
        servable = total - self.P - self.F + 1
        return max(0, min(nominal, servable))

    def __getitem__(self, idx):
        """
        Get sample with separate input/output sources.

        Returns:
            encoder_input: (P, n_features_in)
            decoder_input: (G+F, n_features_out)
            target: (F, n_features_out)

        Raises:
            IndexError: If idx is outside [-len(self), len(self)).
        """
        n_samples = len(self)
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            # Required so plain iteration over the dataset terminates and
            # truncated windows are never returned.
            raise IndexError(f"sample index out of range for dataset of size {n_samples}")

        split = idx + self.P

        # Encoder uses input features
        encoder_input = self.features_in[idx:split]

        # Decoder and target use output features
        decoder_recent = self.features_out[split - self.G:split]
        decoder_zeros = np.zeros((self.F, self.features_out.shape[1]), dtype=np.float32)
        decoder_input = np.vstack([decoder_recent, decoder_zeros])

        target = self.features_out[split:split + self.F]

        return (
            torch.as_tensor(encoder_input, dtype=torch.float32),
            torch.as_tensor(decoder_input, dtype=torch.float32),
            torch.as_tensor(target, dtype=torch.float32)
        )

In [None]:
class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding from "Attention is All You Need".

    Even feature indices hold sin(position * freq) and odd indices hold
    cos(position * freq), with frequencies decaying geometrically from 1 to
    1/10000. The table for the first ``max_len`` positions is precomputed and
    stored in the non-trainable buffer ``pe`` of shape (1, max_len, d_model).

    Both even and odd ``d_model`` are supported, and inputs longer than
    ``max_len`` are handled by computing a larger table on demand.
    """
    def __init__(self, d_model, max_len=500):
        super().__init__()
        # Register as buffer (not a parameter, but part of state); the name
        # 'pe' and its shape are part of the checkpoint format.
        self.register_buffer('pe', self._build_table(max_len, d_model))

    @staticmethod
    def _build_table(length, d_model):
        """Return the sinusoidal table of shape (1, length, d_model)."""
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cos column fewer than sin columns,
        # so the frequency vector is truncated to match.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        return pe.unsqueeze(0)  # Add batch dimension

    def forward(self, x):
        """
        Add positional encoding to input

        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            x + positional_encoding: same shape as x
        """
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            return x + self.pe[:, :seq_len, :]

        # Sequence exceeds the precomputed table: build a longer one for this
        # call only, leaving the registered buffer (and checkpoints) untouched.
        table = self._build_table(seq_len, self.pe.size(2))
        return x + table.to(device=self.pe.device, dtype=self.pe.dtype)

In [None]:
class TransformerEncoderBlock(nn.Module):
    """
    One post-norm Transformer encoder layer.

    Two sub-layers are applied in sequence, each wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``:
    1. Multi-head self-attention over the sequence
    2. A two-layer feed-forward network with a ReLU in between

    Args:
        d_model: Feature width of the input and output
        nhead: Number of attention heads
        d_ff: Hidden width of the feed-forward network
        dropout: Dropout probability used inside attention and on both
            residual branches
    """
    def __init__(self, d_model, nhead, d_ff=256, dropout=0.1):
        super().__init__()
        # Attention sub-layer; batch_first=True means (batch, seq, feature)
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, activation, contract)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply self-attention and the feed-forward network.

        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            Tensor of the same shape as x
        """
        # Sub-layer 1: self-attention (attention weights are discarded)
        attended = self.self_attn(x, x, x)[0]
        hidden = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: feed-forward
        return self.norm2(hidden + self.dropout2(self.ffn(hidden)))

In [None]:
class TransformerEncoder(nn.Module):
    """
    Transformer encoder stack.

    Pipeline: linear projection (data_dim → d_model), sinusoidal positional
    encoding, the supplied encoder blocks in order, then an optional final
    normalisation layer.

    Args:
        attn_layers: Iterable of encoder block modules, applied sequentially
        data_dim: Number of input features per time step
        d_model: Model width
        norm_layer: Optional module applied to the final output
    """
    def __init__(self, attn_layers, data_dim, d_model, norm_layer=None):
        super().__init__()
        self.input_proj = nn.Linear(data_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x):
        """
        Encode an input sequence.

        Args:
            x: (batch_size, seq_len, data_dim)
        Returns:
            encoded: (batch_size, seq_len, d_model)
        """
        # Project to model width, then inject position information
        hidden = self.pos_encoder(self.input_proj(x))

        # Run the encoder blocks in order
        for block in self.attn_layers:
            hidden = block(hidden)

        # Final normalisation is optional
        return hidden if self.norm is None else self.norm(hidden)

In [None]:
def generate_square_subsequent_mask(sz, device):
    """
    Build an additive causal (look-ahead) attention mask.

    Position i may only attend to positions <= i: entries strictly above the
    main diagonal are -inf, every other entry is 0.

    Args:
        sz: Sequence length
        device: torch device on which the mask is created

    Returns:
        mask: (sz, sz) float tensor with -inf in the strict upper triangle

    Example for sz=4:
        [[0,    -inf, -inf, -inf],
         [0,    0,    -inf, -inf],
         [0,    0,    0,    -inf],
         [0,    0,    0,    0   ]]
    """
    # Start fully blocked, then keep only the part above the diagonal;
    # triu fills everything else with zeros.
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    return torch.triu(blocked, diagonal=1)


class TransformerDecoderBlock(nn.Module):
    """
    One post-norm Transformer decoder layer.

    Three sub-layers are applied in sequence, each wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``:
    1. Multi-head self-attention (optionally restricted by ``tgt_mask``)
    2. Cross-attention with the encoder output as keys and values
    3. A two-layer feed-forward network with a ReLU in between

    Args:
        d_model: Feature width of the input and output
        nhead: Number of attention heads
        d_ff: Hidden width of the feed-forward network
        dropout: Dropout probability used inside both attention modules and
            on all three residual branches
    """
    def __init__(self, d_model, nhead, d_ff=256, dropout=0.1):
        super().__init__()

        # Sub-layer 1: (masked) self-attention
        self.masked_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Sub-layer 2: cross-attention to the encoder memory
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # Sub-layer 3: position-wise feed-forward
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, activation, contract)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask=None):
        """
        Run the three decoder sub-layers.

        Args:
            x: (batch, tgt_len, d_model) - decoder input
            memory: (batch, src_len, d_model) - encoder output
            tgt_mask: (tgt_len, tgt_len) - additive mask for self-attention;
                no masking is applied when None

        Returns:
            (batch, tgt_len, d_model) - decoder output
        """
        # Self-attention; attention weights are discarded
        self_ctx = self.masked_attn(x, x, x, attn_mask=tgt_mask)[0]
        x = self.norm1(x + self.dropout1(self_ctx))

        # Cross-attention: queries from the decoder, keys/values from memory
        cross_ctx = self.cross_attn(x, memory, memory)[0]
        x = self.norm2(x + self.dropout2(cross_ctx))

        # Feed-forward
        return self.norm3(x + self.dropout3(self.ffn(x)))


class Decoder(nn.Module):
    """
    Hybrid decoder combining a Transformer path and an LSTM path.

    - Transformer path: projects the decoder input to d_model, adds
      positional encoding and runs the supplied decoder blocks with a causal
      mask and cross-attention to the encoder output.
    - LSTM path: a 2-layer LSTM over the raw value sequence.
    - Fusion: both outputs are concatenated along the feature axis and mapped
      to one value per time step by a small fully connected head.

    Args:
        layers: Iterable of decoder block modules, applied sequentially
        data_dim: Number of input features per time step
        d_model: Transformer width
        lstm_dim: LSTM input size. ``forward`` feeds the LSTM a tensor whose
            last dimension is 1, so this is expected to be 1.
        hidden_size: LSTM hidden size
        norm_layer: Optional module applied after the decoder blocks
    """
    def __init__(self, layers, data_dim, d_model, lstm_dim, hidden_size=64, norm_layer=None):
        super().__init__()
        self.d_model = d_model

        # Transformer path
        self.input_proj = nn.Linear(data_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

        # LSTM path: two stacked layers with dropout between them
        self.lstm = nn.LSTM(
            lstm_dim,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1
        )

        # Fusion head: (d_model + hidden_size) → 64 → 1
        fused_width = d_model + hidden_size
        self.fc = nn.Sequential(
            nn.Linear(fused_width, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )

    def forward(self, data, memory, original_seq):
        """
        Predict one value per decoder time step.

        Args:
            data: (batch, seq_len, data_dim) - decoder input with time features
            memory: (batch, src_len, d_model) - encoder output
            original_seq: (batch, seq_len) - raw channel values for the LSTM

        Returns:
            out: (batch, seq_len) - predicted channel values
        """
        # Causal mask sized to the decoder sequence, on the input's device
        causal_mask = generate_square_subsequent_mask(data.size(1), data.device)

        # Transformer path
        hidden = self.pos_encoder(self.input_proj(data))
        for block in self.layers:
            hidden = block(hidden, memory, causal_mask)
        if self.norm is not None:
            hidden = self.norm(hidden)

        # LSTM path: add a trailing feature axis; the final state is discarded
        recurrent = self.lstm(original_seq.unsqueeze(-1))[0]

        # Fuse both paths and collapse the singleton output axis
        fused = torch.cat((hidden, recurrent), dim=-1)
        return self.fc(fused).squeeze(-1)

In [None]:
def evaluate_model(encoder, decoder, loader, device, F):
    """
    Evaluate model on validation/test set using NMSE metric.

    NMSE (Normalised Mean Squared Error) is computed per sample as
    ``sum((pred - target)^2) / (sum(target^2) + 1e-8)`` over the F future
    steps. The result is the mean over all samples, so a smaller final batch
    is not over-weighted.

    Both modules are switched to eval mode and left in that mode.

    Args:
        encoder, decoder: Model components
        loader: Iterable of (encoder_in, decoder_in, target) batches
        device: torch device
        F: Number of future steps to predict

    Returns:
        avg_loss: Mean per-sample NMSE as a float; NaN when the loader
            yields no samples.
    """
    encoder.eval()
    decoder.eval()
    nmse_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for encoder_in, decoder_in, target in loader:
            # Move to device
            encoder_in = encoder_in.to(device)
            decoder_in = decoder_in.to(device)
            target = target.to(device)

            # Extract original sequence for LSTM (first column = channel value)
            original_seq = decoder_in[:, :, 0]

            # Forward pass
            encoded = encoder(encoder_in)
            pred = decoder(decoder_in, encoded, original_seq)

            # We only care about the last F predictions (future steps)
            pred_future = pred[:, -F:]
            target_vals = target[:, :, 0]

            # Per-sample NMSE, accumulated as a running sum and count
            numerator = torch.sum((pred_future - target_vals) ** 2, dim=1)
            denominator = torch.sum(target_vals ** 2, dim=1) + 1e-8
            nmse = numerator / denominator

            nmse_sum += nmse.sum().item()
            n_samples += nmse.numel()

    if n_samples == 0:
        # Nothing to evaluate: report NaN instead of dividing by zero
        return float('nan')
    return nmse_sum / n_samples

In [None]:
def train_with_early_stop(encoder, decoder, train_loader, val_loader,
                          enc_opt, dec_opt, F, device,
                          epochs=100, patience=20, min_delta=0.001,
                          train_dataset=None, save_path='best_model.pth'):
    """
    Train the encoder/decoder pair with early stopping on validation NMSE.

    Each epoch trains on ``train_loader`` with an NMSE loss over the last F
    decoder outputs, then evaluates on ``val_loader`` via ``evaluate_model``.
    Whenever the validation loss improves by more than ``min_delta`` a
    snapshot of the weights is taken; training stops after ``patience``
    consecutive epochs without such an improvement. The best snapshot is
    written to ``save_path`` at the end. The live models keep the weights of
    the last epoch (they are not rolled back to the best snapshot).

    Args:
        encoder, decoder: Model components
        train_loader, val_loader: DataLoaders (must expose ``.dataset``)
        enc_opt, dec_opt: Optimizers
        F: Number of future steps
        device: torch device
        epochs: Maximum training epochs
        patience: Early stopping patience
        min_delta: Minimum improvement to reset patience
        train_dataset: Source of the normalization stats stored in the
            checkpoint: ``mean``/``std`` if present, otherwise
            ``mean_out``/``std_out`` (as on TimeDatasetFF), otherwise None
        save_path: File the best checkpoint is written to

    Returns:
        history: Dict with 'train' and 'val' loss curves (per-sample mean
            NMSE per epoch)
    """
    best_loss = float('inf')
    counter = 0
    history = {'train': [], 'val': []}
    best_model = None

    # Resolve normalization stats once. An explicit None check is used
    # because dataset truthiness would go through __len__.
    if train_dataset is not None:
        norm_mean = getattr(train_dataset, 'mean', getattr(train_dataset, 'mean_out', None))
        norm_std = getattr(train_dataset, 'std', getattr(train_dataset, 'std_out', None))
    else:
        norm_mean = None
        norm_std = None

    print("Starting training...")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Prediction horizon (F): {F} steps\n")

    for epoch in range(epochs):
        # ========== TRAINING PHASE ==========
        encoder.train()
        decoder.train()
        train_loss_sum = 0.0
        train_samples = 0

        for encoder_in, decoder_in, target in train_loader:
            # Move to device
            encoder_in = encoder_in.to(device)
            decoder_in = decoder_in.to(device)
            target = target.to(device)

            # Extract original sequence for LSTM
            original_seq = decoder_in[:, :, 0]

            # Zero gradients
            enc_opt.zero_grad()
            dec_opt.zero_grad()

            # Forward pass
            encoded = encoder(encoder_in)
            pred = decoder(decoder_in, encoded, original_seq)

            # Compute loss on future predictions only
            pred_future = pred[:, -F:]
            target_vals = target[:, :, 0]

            # NMSE loss (batch mean of the per-sample NMSE)
            numerator = torch.sum((pred_future - target_vals) ** 2, dim=1)
            denominator = torch.sum(target_vals ** 2, dim=1) + 1e-8
            nmse = numerator / denominator
            loss = torch.mean(nmse)

            # Backward pass
            loss.backward()

            # Gradient clipping (prevents exploding gradients)
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)

            # Update weights
            enc_opt.step()
            dec_opt.step()

            # Accumulate per-sample so a smaller last batch is not over-weighted
            train_loss_sum += nmse.detach().sum().item()
            train_samples += nmse.numel()

        # ========== VALIDATION PHASE ==========
        avg_train_loss = train_loss_sum / train_samples if train_samples else float('nan')
        avg_val_loss = evaluate_model(encoder, decoder, val_loader, device, F)

        # Record history
        history['train'].append(avg_train_loss)
        history['val'].append(avg_val_loss)

        # Print progress
        print(f"Epoch {epoch+1:03d}/{epochs} | "
              f"Train Loss: {avg_train_loss:.6f} | "
              f"Val Loss: {avg_val_loss:.6f} | "
              f"EarlyStop: {counter}/{patience}")

        # ========== EARLY STOPPING CHECK ==========
        if avg_val_loss < best_loss - min_delta:
            best_loss = avg_val_loss
            counter = 0
            # state_dict() returns references to the live tensors, which the
            # optimizers keep updating in place. Clone them so the snapshot
            # really holds the weights of this (best) epoch.
            best_model = {
                'encoder': {k: v.detach().cpu().clone() for k, v in encoder.state_dict().items()},
                'decoder': {k: v.detach().cpu().clone() for k, v in decoder.state_dict().items()},
                'mean': norm_mean,
                'std': norm_std
            }
            print(f"  → New best model! Val loss: {best_loss:.6f}")
        else:
            counter += 1
            if counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # Save best snapshot (if any epoch improved on the initial infinite loss)
    if best_model is not None:
        torch.save(best_model, save_path)
        print(f"\nBest model saved with val loss: {best_loss:.6f}")

    return history

In [None]:
def inverse_normalize(data, mean, std):
    """
    Map normalized values back to the original scale.

    Args:
        data: Normalized value(s)
        mean: Offset that was subtracted during normalization
        std: Scale that the data was divided by during normalization

    Returns:
        ``data * std + mean``
    """
    rescaled = data * std
    return rescaled + mean


def calculate_metrics(true, pred):
    """
    Compute and print regression metrics for a set of predictions.

    Both inputs are converted to float32 arrays first.

    Metrics:
    - MSE: Mean Squared Error
    - RMSE: Root Mean Squared Error (square root of MSE)
    - MAE: Mean Absolute Error
    - R²: Coefficient of determination
    - NMSE: sum((pred - true)^2) / (sum(true^2) + 1e-8), also printed in dB

    Args:
        true: Ground truth values (array-like)
        pred: Predicted values (array-like, same shape as ``true``)

    Returns:
        Dict with keys 'MSE', 'RMSE', 'MAE', 'R2' and 'NMSE'
    """
    true = np.asarray(true, dtype=np.float32)
    pred = np.asarray(pred, dtype=np.float32)

    # NMSE: error energy relative to signal energy (epsilon avoids 0/0)
    nmse = np.sum((pred - true) ** 2) / (np.sum(true ** 2) + 1e-8)

    # Standard regression metrics
    mse = mean_squared_error(true, pred)
    mae = mean_absolute_error(true, pred)
    r2 = r2_score(true, pred)
    rmse = np.sqrt(mse)

    print(f"MSE:  {mse:.6f}")
    print(f"RMSE: {rmse:.6f}")
    print(f"MAE:  {mae:.6f}")
    print(f"R²:   {r2:.6f}")
    print(f"NMSE: {nmse:.6f}")
    print(f"NMSE (dB): {10 * np.log10(nmse):.2f} dB")

    return dict(MSE=mse, RMSE=rmse, MAE=mae, R2=r2, NMSE=nmse)

In [None]:
def test_pipeline(model_path, test_dataset, data_dim, d_model, nhead,
                 lstm_dim, hidden_size, F, device):
    """
    Load a saved checkpoint, run inference on a dataset and denormalize.

    The model is rebuilt with two encoder blocks (plus a final LayerNorm) and
    two decoder blocks, so the checkpoint must come from that architecture.
    Predictions and targets are denormalized with the 'mean' and 'std'
    entries stored in the checkpoint.

    Args:
        model_path: Path to saved model
        test_dataset: Test Dataset instance
        data_dim, d_model, nhead, lstm_dim, hidden_size: Model hyperparameters
        F: Prediction horizon
        device: torch device

    Returns:
        trues_denorm: Ground truth in original scale, shape (N, F)
        preds_denorm: Predictions in original scale, shape (N, F)
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Rebuild the encoder: 2 blocks + final LayerNorm
    encoder = TransformerEncoder(
        attn_layers=[TransformerEncoderBlock(d_model=d_model, nhead=nhead) for _ in range(2)],
        data_dim=data_dim,
        d_model=d_model,
        norm_layer=nn.LayerNorm(d_model)
    ).to(device)

    # Rebuild the decoder: 2 blocks, no final norm
    decoder = Decoder(
        [TransformerDecoderBlock(d_model=d_model, nhead=nhead) for _ in range(2)],
        data_dim=data_dim,
        d_model=d_model,
        lstm_dim=lstm_dim,
        hidden_size=hidden_size
    ).to(device)

    # Restore weights and switch to evaluation mode
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    encoder.eval()
    decoder.eval()

    # Run inference in dataset order
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
    pred_chunks, true_chunks = [], []

    with torch.no_grad():
        for enc_batch, dec_batch, tgt_batch in test_loader:
            enc_batch = enc_batch.to(device)
            dec_batch = dec_batch.to(device)
            tgt_batch = tgt_batch.to(device)

            # The first feature column is the raw value sequence fed to the LSTM
            forecast = decoder(dec_batch, encoder(enc_batch), dec_batch[:, :, 0])

            # Keep only the last F steps (the future) and the value column
            pred_chunks.append(forecast[:, -F:].cpu().numpy())
            true_chunks.append(tgt_batch[:, :, 0].cpu().numpy())

    # Concatenate results → (N, F)
    preds = np.concatenate(pred_chunks)
    trues = np.concatenate(true_chunks)

    # Denormalize with the statistics stored alongside the weights
    mean, std = checkpoint['mean'], checkpoint['std']
    preds_denorm = inverse_normalize(preds, mean, std)
    trues_denorm = inverse_normalize(trues, mean, std)

    return trues_denorm, preds_denorm

In [None]:
def plot_results(true_values, pred_values, save_path='prediction_comparison.png'):
    """
    Draw a two-panel comparison of predictions against ground truth.

    Left panel: the first (at most) 500 entries of both arrays as line plots.
    Right panel: scatter of predicted against true values with a dashed red
    identity line spanning the range of the true values.
    The figure is written to ``save_path`` and then shown.

    Args:
        true_values: Ground truth array (``.min()``/``.max()`` are called on it)
        pred_values: Prediction array
        save_path: Where to save the plot
    """
    fig, (series_ax, scatter_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left: limit the line plot to the first 500 samples for clarity
    shown = min(500, len(true_values))
    series_ax.plot(true_values[:shown], label='True', alpha=0.7, linewidth=1)
    series_ax.plot(pred_values[:shown], label='Predicted', alpha=0.7, linewidth=1)
    series_ax.set_title("Time Series Prediction Comparison")
    series_ax.set_xlabel("Sample Index")
    series_ax.set_ylabel("Channel Value")
    series_ax.legend()
    series_ax.grid(True, alpha=0.3)

    # Right: predicted vs. true, with the identity line as reference
    lowest, highest = true_values.min(), true_values.max()
    scatter_ax.scatter(true_values, pred_values, alpha=0.3, s=1)
    scatter_ax.plot([lowest, highest], [lowest, highest],
                    'r--', label='Perfect Prediction')
    scatter_ax.set_title("Prediction Scatter Plot")
    scatter_ax.set_xlabel("True Values")
    scatter_ax.set_ylabel("Predicted Values")
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    print(f"Plot saved to {save_path}")
    plt.show()