In [None]:
# Bloque 1: Importaciones y docstring inicial

"""
Temporal Fusion Transformer Optimizado
-----------------------------------------

Esta versión implementa las optimizaciones generales propuestas en
el código original publicado en 2020. Se ha estructurado el código para mejorar su
legibilidad, eficiencia computacional y mantenibilidad, sin alterar la lógica esencial
del modelo.

Modificaciones destacadas:
- Normalización de argumentos opcionales al inicio de __init__
- Inicialización optimizada de estados LSTM usando unsqueeze/expand
- Opción para compilar el método forward con torch.compile (PyTorch 2.0+)
- Eliminación de código comentado redundante
- Documentación detallada (docstrings) para cambios y funciones clave
- NUEVAS OPTIMIZACIONES: Cacheo de máscaras, fusión de operaciones, 
  vectorización y funciones JIT para mejorar rendimiento
"""

import math
from copy import copy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torchmetrics import Metric as LightningMetric

from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import (
    MAE,
    MAPE,
    RMSE,
    SMAPE,
    MultiHorizonMetric,
    QuantileLoss,
)
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.models.nn import LSTM, MultiEmbedding
from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
    AddNorm,
    GateAddNorm,
    GatedLinearUnit,
    GatedResidualNetwork,
    InterpretableMultiHeadAttention,
    VariableSelectionNetwork,
)
from pytorch_forecasting.utils import (
    create_mask,
    detach,
    integer_histogram,
    masked_op,
    padded_stack,
    to_list,
)
from pytorch_forecasting.utils._dependencies import _check_matplotlib

In [None]:
# Bloque 2: Definición de la clase principal y funciones auxiliares optimizadas

class TemporalFusionTransformer(BaseModelWithCovariates):
    # Inicialización optimizada de estados LSTM
    @staticmethod
    @torch.jit.script
    def initialize_lstm_states(hidden, cell, layers: int):
        """
        Inicializa los estados LSTM de manera optimizada con JIT.
        
        Args:
            hidden: Estado hidden inicial.
            cell: Estado cell inicial.
            layers: Número de capas LSTM.
            
        Returns:
            Tupla de estados inicializados (hidden, cell).
        """
        return (
            hidden.unsqueeze(0).expand(layers, -1, -1),
            cell.unsqueeze(0).expand(layers, -1, -1)
        )

    # Fusión de operaciones en el procesamiento LSTM
    def process_lstm_output(self, lstm_output, residual_input, is_encoder=True):
        """
        Apply the post-LSTM gated skip connection.

        The LSTM output is passed through the gate module and then merged with
        ``residual_input`` by the add & norm module.

        Args:
            lstm_output: output of the LSTM.
            residual_input: tensor fed to the skip connection (the LSTM input).
            is_encoder: truthy to use the encoder-side modules, falsy for the
                decoder-side ones.

        Returns:
            Result of ``add_norm(gate(lstm_output), residual_input)``.
        """
        if is_encoder:
            gate_module = self.post_lstm_gate_encoder
            norm_module = self.post_lstm_add_norm_encoder
        else:
            gate_module = self.post_lstm_gate_decoder
            norm_module = self.post_lstm_add_norm_decoder
        gated = gate_module(lstm_output)
        return norm_module(gated, residual_input)

    # Atención eficiente en memoria
    def efficient_attention(self, q, k, v, mask=None):
        """
        Implementación de atención eficiente en memoria para secuencias largas.
        
        Args:
            q: Queries.
            k: Keys.
            v: Values.
            mask: Máscara para la atención.
            
        Returns:
            Tupla de (salida de atención, pesos de atención).
        """
        batch_size, seq_len, d_model = q.shape
        head_dim = d_model // self.hparams.attention_head_size
        
        # Reshape para computación por cabezas
        q = q.view(batch_size, seq_len, self.hparams.attention_head_size, head_dim)
        k = k.view(batch_size, -1, self.hparams.attention_head_size, head_dim)
        v = v.view(batch_size, -1, self.hparams.attention_head_size, head_dim)
        
        # Calcular atención eficientemente usando operaciones por lotes
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
        
        if mask is not None:
            scores = scores.masked_fill(~mask.unsqueeze(1), -1e9)
            
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        
        return context.reshape(batch_size, seq_len, d_model), attn_weights

    # Vectorización del procesamiento de interpretación
    def process_variables_importance(self, variables, lengths):
        """
        Average variable-selection weights over the valid time steps of each sample.

        Positions flagged by ``create_mask(variables.size(1), lengths)`` are zeroed.
        The remaining values are summed over dim 1 and divided by each sample's
        length, with lengths clamped to at least 1 to avoid division by zero.

        NOTE(review): the mask is ``unsqueeze(-1)``'d once, so ``variables`` is
        presumably (batch, time, n_variables). A (batch, time, 1, n_variables) tensor
        would need ``squeeze(-2)`` first — confirm against callers.

        Args:
            variables: per-time-step variable weights.
            lengths: effective length of each sample.

        Returns:
            Per-sample mean of the unmasked weights.
        """
        padding = create_mask(variables.size(1), lengths).unsqueeze(-1)
        totals = variables.masked_fill(padding, 0.0).sum(dim=1)
        denominators = lengths.clamp_min(1).unsqueeze(-1)
        return totals / denominators

    # Procesamiento paralelo para salida multiobjetivo
    def transform_multi_output(self, output):
        """
        Project the network output with the per-target output layer(s).

        With several targets, ``self.output_layer`` is a ``ModuleList`` with one
        ``nn.Linear`` per target. Each layer is applied to the full output, so targets
        may have different output sizes and ``output`` may have any number of leading
        dimensions (e.g. (batch, time, hidden)).

        The previous stacked-``bmm`` version had three problems:
        - ``bmm`` only accepts 3-D tensors, and ``output.unsqueeze(1)`` of a
          (batch, time, hidden) tensor is 4-D.
        - The batch dims of its two operands did not match.
        - ``torch.stack`` fails whenever output sizes differ between targets.

        Args:
            output: network output whose last dimension is the hidden size.

        Returns:
            A list with one tensor per target when ``n_targets > 1``, otherwise the
            single projected tensor.
        """
        if self.n_targets > 1:
            return [layer(output) for layer in self.output_layer]
        return self.output_layer(output)

In [None]:
# Bloque 3: Método de inicialización (init)

    def __init__(
        self,
        hidden_size: int = 16,
        lstm_layers: int = 1,
        dropout: float = 0.1,
        output_size: Union[int, List[int]] = 7,
        loss: MultiHorizonMetric = None,
        attention_head_size: int = 4,
        max_encoder_length: int = 10,
        static_categoricals: Optional[List[str]] = None,
        static_reals: Optional[List[str]] = None,
        time_varying_categoricals_encoder: Optional[List[str]] = None,
        time_varying_categoricals_decoder: Optional[List[str]] = None,
        categorical_groups: Optional[Union[Dict, List[str]]] = None,
        time_varying_reals_encoder: Optional[List[str]] = None,
        time_varying_reals_decoder: Optional[List[str]] = None,
        x_reals: Optional[List[str]] = None,
        x_categoricals: Optional[List[str]] = None,
        hidden_continuous_size: int = 8,
        hidden_continuous_sizes: Optional[Dict[str, int]] = None,
        embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None,
        embedding_paddings: Optional[List[str]] = None,
        embedding_labels: Optional[Dict[str, np.ndarray]] = None,
        learning_rate: float = 1e-3,
        log_interval: Union[int, float] = -1,
        log_val_interval: Union[int, float] = None,
        log_gradient_flow: bool = False,
        reduce_on_plateau_patience: int = 1000,
        monotone_constaints: Optional[Dict[str, int]] = None,
        share_single_variable_networks: bool = False,
        causal_attention: bool = True,
        logging_metrics: Optional[nn.ModuleList] = None,
        use_compile: bool = False,
        **kwargs,
    ):
        """
        Temporal Fusion Transformer for time series forecasting.

        Optional arguments are normalised up front, so the rest of the constructor
        never checks for ``None``. With ``use_compile=True``, ``forward`` and
        ``interpret_output`` are wrapped with ``torch.compile`` when it is available
        (PyTorch 2.0+).

        Args:
            hidden_size: size of the hidden layers.
            lstm_layers: number of LSTM layers.
            dropout: dropout rate.
            output_size: number of outputs (e.g. number of quantiles for QuantileLoss).
                With several targets, either a list with one size per target or a
                single int shared by all targets.
            loss: loss function; must be a torchmetrics ``Metric``. Defaults to
                ``QuantileLoss()``.
            attention_head_size: number of attention heads.
            max_encoder_length: maximum encoder length.
            static_categoricals: static categorical variables.
            static_reals: static continuous variables.
            time_varying_categoricals_encoder: categorical variables for the encoder.
            time_varying_categoricals_decoder: categorical variables for the decoder.
            categorical_groups: dict or list of categorical variable groups.
            time_varying_reals_encoder: continuous variables for the encoder.
            time_varying_reals_decoder: continuous variables for the decoder.
            x_reals: order of continuous variables in the input tensor.
            x_categoricals: order of categorical variables in the input tensor.
            hidden_continuous_size: default hidden size for continuous variables.
            hidden_continuous_sizes: per-variable hidden sizes for continuous variables.
            embedding_sizes: maps categorical variable names to
                (number of classes, embedding size).
            embedding_paddings: categorical variables that use padding.
            embedding_labels: maps categorical variable names to their labels.
            learning_rate: learning rate.
            log_interval: interval for logging predictions.
            log_val_interval: logging interval during validation.
            log_gradient_flow: whether to log gradient flow.
            reduce_on_plateau_patience: patience before reducing the learning rate.
            monotone_constaints: monotonicity constraints for continuous variables.
            share_single_variable_networks: share single-variable networks between
                encoder and decoder.
            causal_attention: apply causal attention in the decoder.
            logging_metrics: metrics logged during training.
            use_compile: compile ``forward``/``interpret_output`` with ``torch.compile``.
            **kwargs: extra arguments forwarded to the base model.

        Raises:
            TypeError: if ``loss`` is not a torchmetrics ``Metric``.
        """
        # Normalise optional arguments once so later code never checks for None.
        static_categoricals = static_categoricals or []
        static_reals = static_reals or []
        time_varying_categoricals_encoder = time_varying_categoricals_encoder or []
        time_varying_categoricals_decoder = time_varying_categoricals_decoder or []
        time_varying_reals_encoder = time_varying_reals_encoder or []
        time_varying_reals_decoder = time_varying_reals_decoder or []
        x_categoricals = x_categoricals or []
        x_reals = x_reals or []
        embedding_labels = embedding_labels or {}
        embedding_paddings = embedding_paddings or []
        embedding_sizes = embedding_sizes or {}
        hidden_continuous_sizes = hidden_continuous_sizes or {}
        categorical_groups = categorical_groups or {}
        if monotone_constaints is None:
            monotone_constaints = {}
        if logging_metrics is None:
            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])
        if loss is None:
            loss = QuantileLoss()
        if not isinstance(loss, LightningMetric):
            raise TypeError(
                f"loss has to be a torchmetrics Metric, got {type(loss).__name__}"
            )

        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
        # Saves every init argument (including use_compile) from this frame's locals,
        # i.e. the normalised values above.
        self.save_hyperparameters()

        # 1. Embeddings for categorical variables
        self.input_embeddings = MultiEmbedding(
            embedding_sizes=self.hparams.embedding_sizes,
            categorical_groups=self.hparams.categorical_groups,
            embedding_paddings=self.hparams.embedding_paddings,
            x_categoricals=self.hparams.x_categoricals,
            max_embedding_size=self.hparams.hidden_size,
        )

        # 2. Linear prescalers for continuous variables
        self.prescalers = nn.ModuleDict(
            {
                name: nn.Linear(
                    1,
                    self.hparams.hidden_continuous_sizes.get(
                        name, self.hparams.hidden_continuous_size
                    ),
                )
                for name in self.reals  # self.reals is provided by the base class
            }
        )

        # 3. Variable selection for static, encoder and decoder variables
        static_input_sizes = {
            name: self.input_embeddings.output_size[name]
            for name in self.hparams.static_categoricals
        }
        static_input_sizes.update(
            {
                name: self.hparams.hidden_continuous_sizes.get(
                    name, self.hparams.hidden_continuous_size
                )
                for name in self.hparams.static_reals
            }
        )
        self.static_variable_selection = VariableSelectionNetwork(
            input_sizes=static_input_sizes,
            hidden_size=self.hparams.hidden_size,
            input_embedding_flags={
                name: True for name in self.hparams.static_categoricals
            },
            dropout=self.hparams.dropout,
            prescalers=self.prescalers,
        )

        encoder_input_sizes = {
            name: self.input_embeddings.output_size[name]
            for name in self.hparams.time_varying_categoricals_encoder
        }
        encoder_input_sizes.update(
            {
                name: self.hparams.hidden_continuous_sizes.get(
                    name, self.hparams.hidden_continuous_size
                )
                for name in self.hparams.time_varying_reals_encoder
            }
        )
        decoder_input_sizes = {
            name: self.input_embeddings.output_size[name]
            for name in self.hparams.time_varying_categoricals_decoder
        }
        decoder_input_sizes.update(
            {
                name: self.hparams.hidden_continuous_sizes.get(
                    name, self.hparams.hidden_continuous_size
                )
                for name in self.hparams.time_varying_reals_decoder
            }
        )

        # Optionally share single-variable networks between encoder and decoder
        if self.hparams.share_single_variable_networks:
            self.shared_single_variable_grns = nn.ModuleDict()
            for name, input_size in encoder_input_sizes.items():
                self.shared_single_variable_grns[name] = GatedResidualNetwork(
                    input_size,
                    min(input_size, self.hparams.hidden_size),
                    self.hparams.hidden_size,
                    self.hparams.dropout,
                )
            for name, input_size in decoder_input_sizes.items():
                if name not in self.shared_single_variable_grns:
                    self.shared_single_variable_grns[name] = GatedResidualNetwork(
                        input_size,
                        min(input_size, self.hparams.hidden_size),
                        self.hparams.hidden_size,
                        self.hparams.dropout,
                    )

        self.encoder_variable_selection = VariableSelectionNetwork(
            input_sizes=encoder_input_sizes,
            hidden_size=self.hparams.hidden_size,
            input_embedding_flags={
                name: True for name in self.hparams.time_varying_categoricals_encoder
            },
            dropout=self.hparams.dropout,
            context_size=self.hparams.hidden_size,
            prescalers=self.prescalers,
            single_variable_grns=(
                {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns
            ),
        )

        self.decoder_variable_selection = VariableSelectionNetwork(
            input_sizes=decoder_input_sizes,
            hidden_size=self.hparams.hidden_size,
            input_embedding_flags={
                name: True for name in self.hparams.time_varying_categoricals_decoder
            },
            dropout=self.hparams.dropout,
            context_size=self.hparams.hidden_size,
            prescalers=self.prescalers,
            single_variable_grns=(
                {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns
            ),
        )

        # 4. Static encoders for variable selection, initial LSTM state and enrichment
        self.static_context_variable_selection = GatedResidualNetwork(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.hidden_size,
            dropout=self.hparams.dropout,
        )
        self.static_context_initial_hidden_lstm = GatedResidualNetwork(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.hidden_size,
            dropout=self.hparams.dropout,
        )
        self.static_context_initial_cell_lstm = GatedResidualNetwork(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.hidden_size,
            dropout=self.hparams.dropout,
        )
        self.static_context_enrichment = GatedResidualNetwork(
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            self.hparams.dropout,
        )

        # 5. LSTMs for local processing: encoder (history) and decoder (future)
        self.lstm_encoder = LSTM(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
            batch_first=True,
        )
        self.lstm_decoder = LSTM(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
            batch_first=True,
        )

        # 6. Post-LSTM skip connection and normalisation
        self.post_lstm_gate_encoder = GatedLinearUnit(
            self.hparams.hidden_size, dropout=self.hparams.dropout
        )
        # the decoder deliberately reuses the encoder instances (shared weights)
        self.post_lstm_gate_decoder = self.post_lstm_gate_encoder
        self.post_lstm_add_norm_encoder = AddNorm(
            self.hparams.hidden_size, trainable_add=False
        )
        self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder

        # 7. Static enrichment and attention for long-range processing
        self.static_enrichment = GatedResidualNetwork(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.hidden_size,
            dropout=self.hparams.dropout,
            context_size=self.hparams.hidden_size,
        )
        self.multihead_attn = InterpretableMultiHeadAttention(
            d_model=self.hparams.hidden_size,
            n_head=self.hparams.attention_head_size,
            dropout=self.hparams.dropout,
        )
        self.post_attn_gate_norm = GateAddNorm(
            self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False
        )
        self.pos_wise_ff = GatedResidualNetwork(
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            dropout=self.hparams.dropout,
        )

        # 8. Output processing (no dropout at this stage)
        self.pre_output_gate_norm = GateAddNorm(
            self.hparams.hidden_size, dropout=None, trainable_add=False
        )
        if self.n_targets > 1:  # multi-target architecture: one projection per target
            output_sizes = self.hparams.output_size
            if isinstance(output_sizes, (int, np.integer)):
                # A single size is shared by every target. Previously only one layer was
                # built here, so the ModuleList was shorter than the number of targets.
                output_sizes = [output_sizes] * self.n_targets
            self.output_layer = nn.ModuleList(
                [nn.Linear(self.hparams.hidden_size, size) for size in output_sizes]
            )
        else:
            self.output_layer = nn.Linear(
                self.hparams.hidden_size, self.hparams.output_size
            )

        # cache used by get_attention_mask
        self._attention_mask_cache = {}

        # Optional torch.compile of the hot paths
        if self.hparams.use_compile and hasattr(torch, "compile"):
            self.forward = torch.compile(self.forward)
            self.interpret_output = torch.compile(
                self.interpret_output,
                dynamic=True,  # input sizes vary between batches
            )
            self.forward.__doc__ = (self.forward.__doc__ or "") + "\n\nOptimized with torch.compile."

In [None]:
    # Bloque 4: Métodos de utilidad

    @classmethod
    def from_dataset(
        cls,
        dataset: TimeSeriesDataSet,
        allowed_encoder_known_variable_names: List[str] = None,
        **kwargs,
    ):
        """
        Build the model from a ``TimeSeriesDataSet``.

        The dataset's ``max_encoder_length`` is injected, and default output
        parameters are deduced (with ``QuantileLoss()`` as the fallback loss) before
        delegating to the base class constructor helper.

        Args:
            dataset: time series dataset.
            allowed_encoder_known_variable_names: known variables allowed in the encoder.
            **kwargs: additional arguments (e.g. hyperparameters).

        Returns:
            A ``TemporalFusionTransformer`` instance.
        """
        model_kwargs = dict(kwargs)
        model_kwargs["max_encoder_length"] = dataset.max_encoder_length
        output_defaults = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss())
        model_kwargs.update(output_defaults)
        return super().from_dataset(
            dataset,
            allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
            **model_kwargs,
        )

    def expand_static_context(self, context: torch.Tensor, timesteps: int) -> torch.Tensor:
        """
        Broadcast a static context vector along a new time axis.

        No data is copied: the result is an ``expand`` view.

        Args:
            context: static context, shape (batch, features).
            timesteps: number of time steps to repeat the context for.

        Returns:
            Tensor view of shape (batch, timesteps, features).
        """
        with_time_axis = context[:, None]
        return with_time_axis.expand(-1, timesteps, -1)

    # Cache de máscaras de atención
    def get_attention_mask(
        self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor
    ) -> torch.Tensor:
        """
        Build the self-attention mask over the concatenated encoder + decoder steps.

        ``True`` entries mark key positions a decoder query must NOT attend to:
        - encoder positions flagged by ``create_mask`` for the encoder lengths;
        - with ``causal_attention``, the query's own decoder step and every later one;
        - otherwise, decoder positions flagged by ``create_mask`` for the decoder lengths.

        Only the batch-independent causal triangle is cached, keyed by
        (decoder length, device). The previous cache was keyed on the maximum lengths
        alone and returned stale masks: two batches with equal maxima can differ in
        batch size, per-sample lengths, or device.

        Args:
            encoder_lengths: effective encoder length of each sample, shape (batch,).
            decoder_lengths: effective decoder length of each sample, shape (batch,).

        Returns:
            Boolean tensor of shape
            (batch, max_decoder_length, max_encoder_length + max_decoder_length).
        """
        batch_size = encoder_lengths.size(0)
        encoder_length = int(encoder_lengths.max())
        decoder_length = int(decoder_lengths.max())
        device = encoder_lengths.device

        if self.hparams.causal_attention:
            cache_key = (decoder_length, device)
            causal = self._attention_mask_cache.get(cache_key)
            if causal is None:
                attend_step = torch.arange(decoder_length, device=device)
                predict_step = attend_step.unsqueeze(1)
                # block the query's own step and every later step
                causal = attend_step >= predict_step
                self._attention_mask_cache[cache_key] = causal
            decoder_mask = causal.unsqueeze(0).expand(batch_size, -1, -1)
        else:
            decoder_mask = (
                create_mask(decoder_length, decoder_lengths)
                .unsqueeze(1)
                .expand(-1, decoder_length, -1)
            )
        encoder_mask = (
            create_mask(encoder_length, encoder_lengths)
            .unsqueeze(1)
            .expand(-1, decoder_length, -1)
        )
        return torch.cat((encoder_mask, decoder_mask), dim=2)

In [None]:
    # Bloque 5: Método forward

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the forward pass of the model.

        Input tensors are expected as (n_samples x time x variables).

        Attention always goes through the learned ``self.multihead_attn``. An earlier
        version switched to the parameter-free ``efficient_attention`` whenever
        ``max_encoder_length > 100`` (the accompanying ``hasattr`` check was always
        true). That switch silently bypassed the trained attention weights, used an
        inverted mask, and returned a differently laid-out attention tensor.

        Args:
            x: dictionary of input tensors, including
                - "encoder_cat", "decoder_cat"
                - "encoder_cont", "decoder_cont"
                - "encoder_lengths", "decoder_lengths"
                - "target_scale", among others

        Returns:
            Output of ``to_network_output``: the prediction plus auxiliary tensors
            (attention weights, variable-selection weights, lengths).
        """
        batch_size = x["encoder_lengths"].size(0)
        encoder_lengths = x["encoder_lengths"]
        decoder_lengths = x["decoder_lengths"]
        max_encoder_length = int(encoder_lengths.max())

        # concatenate encoder and decoder inputs along the time dimension
        x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1)
        x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1)
        timesteps = x_cont.size(1)

        input_vectors = self.input_embeddings(x_cat)
        # add continuous variables following the order defined by x_reals
        input_vectors.update(
            {
                name: x_cont[..., idx].unsqueeze(-1)
                for idx, name in enumerate(self.hparams.x_reals)
                if name in self.reals
            }
        )

        # static variable selection (zeros when there are no static variables)
        if len(self.static_variables) > 0:
            static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables}
            static_embedding, static_variable_selection = self.static_variable_selection(static_embedding)
        else:
            static_embedding = torch.zeros(
                (batch_size, self.hparams.hidden_size),
                dtype=self.dtype, device=self.device
            )
            static_variable_selection = torch.zeros(
                (batch_size, 0), dtype=self.dtype, device=self.device
            )

        static_context_variable_selection = self.expand_static_context(
            self.static_context_variable_selection(static_embedding), timesteps
        )

        # variable selection for encoder and decoder
        embeddings_varying_encoder = {
            name: input_vectors[name][:, :max_encoder_length]
            for name in self.encoder_variables
        }
        embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(
            embeddings_varying_encoder,
            static_context_variable_selection[:, :max_encoder_length],
        )
        embeddings_varying_decoder = {
            name: input_vectors[name][:, max_encoder_length:]
            for name in self.decoder_variables
        }
        embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(
            embeddings_varying_decoder,
            static_context_variable_selection[:, max_encoder_length:],
        )

        # initial LSTM states from the static context, replicated across layers
        init_hidden = self.static_context_initial_hidden_lstm(static_embedding)
        init_cell = self.static_context_initial_cell_lstm(static_embedding)
        input_hidden = init_hidden.unsqueeze(0).expand(self.hparams.lstm_layers, -1, -1)
        input_cell = init_cell.unsqueeze(0).expand(self.hparams.lstm_layers, -1, -1)

        # LSTM: the decoder starts from the encoder's final state
        encoder_output, (hidden, cell) = self.lstm_encoder(
            embeddings_varying_encoder,
            (input_hidden, input_cell),
            lengths=encoder_lengths,
            enforce_sorted=False,
        )
        decoder_output, _ = self.lstm_decoder(
            embeddings_varying_decoder,
            (hidden, cell),
            lengths=decoder_lengths,
            enforce_sorted=False,
        )

        # post-LSTM gated skip connection
        lstm_output_encoder = self.process_lstm_output(encoder_output, embeddings_varying_encoder, True)
        lstm_output_decoder = self.process_lstm_output(decoder_output, embeddings_varying_decoder, False)
        lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1)

        # static enrichment
        static_context_enrichment = self.static_context_enrichment(static_embedding)
        attn_input = self.static_enrichment(
            lstm_output,
            self.expand_static_context(static_context_enrichment, timesteps),
        )

        # masked interpretable multi-head attention (decoder steps query all steps)
        attention_mask = self.get_attention_mask(
            encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths
        )
        attn_output, attn_output_weights = self.multihead_attn(
            q=attn_input[:, max_encoder_length:],
            k=attn_input,
            v=attn_input,
            mask=attention_mask,
        )

        attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:])
        output = self.pos_wise_ff(attn_output)

        # final skip connection before the output layer
        output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:])

        if self.n_targets > 1:
            # one projection per target; output sizes may differ between targets
            output = [output_layer(output) for output_layer in self.output_layer]
        else:
            output = self.output_layer(output)

        return self.to_network_output(
            prediction=self.transform_output(output, target_scale=x["target_scale"]),
            encoder_attention=attn_output_weights[..., :max_encoder_length],
            decoder_attention=attn_output_weights[..., max_encoder_length:],
            static_variables=static_variable_selection,
            encoder_variables=encoder_sparse_weights,
            decoder_variables=decoder_sparse_weights,
            decoder_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
        )

In [None]:
    # Bloque 6: Métodos para interpretabilidad y logging

    def on_fit_end(self):
        """
        Hook run when fitting finishes: log the embeddings if logging is enabled
        (``log_interval > 0``).
        """
        if not self.log_interval > 0:
            return
        self.log_embeddings()

    def create_log(self, x, y, out, batch_idx, **kwargs):
        """
        Build the per-batch log via the base class and, when logging is enabled
        (``log_interval > 0``), attach the output interpretation.

        Args:
            x: model input.
            y: target.
            out: model output.
            batch_idx: batch index.
            **kwargs: forwarded to the base-class ``create_log``.

        Returns:
            The log dictionary, with an ``"interpretation"`` entry when logging is on.
        """
        base_log = super().create_log(x, y, out, batch_idx, **kwargs)
        wants_interpretation = self.log_interval > 0
        if wants_interpretation:
            base_log["interpretation"] = self._log_interpretation(out)
        return base_log

    def _log_interpretation(self, out):
        """
        Compute the interpretation stored in the per-batch log.

        ``out`` is detached first. The interpretation is summed over the batch
        (``reduction="sum"``) and uses the attention of the first decoder step only
        (``attention_prediction_horizon=0``).
        """
        return self.interpret_output(
            detach(out),
            reduction="sum",
            attention_prediction_horizon=0,
        )

    def on_epoch_end(self, outputs):
        """
        Hook run at the end of an epoch (training or validation).

        Logs the aggregated interpretation of ``outputs`` only when the model is NOT in
        training mode and ``self.log_interval > 0``; otherwise it does nothing.
        """
        if not self.log_interval > 0 or self.training:
            return
        self.log_interpretation(outputs)

In [None]:
    # Bloque 7: Método para interpretar la salida

    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Turn raw network output into attention and variable-importance summaries.

        Args:
            out: Network output (result of ``forward``). Keys read here:
                ``"encoder_attention"`` and ``"decoder_attention"`` (each either a tensor or a
                list/tuple of per-sample tensors), ``"encoder_lengths"``, ``"decoder_lengths"``,
                ``"encoder_variables"``, ``"decoder_variables"`` and ``"static_variables"``.
            reduction: ``"none"`` keeps one row per sample and renormalises each sample's
                attention over the attended positions. Any other value (e.g. ``"sum"`` or
                ``"mean"``) sums the variable importances over the batch and reduces the
                attention over the batch with ``masked_op`` using that op name.
            attention_prediction_horizon: Decoder step whose attention is returned. Attention is
                kept for the ``max_encoder_length`` encoder positions plus the first
                ``attention_prediction_horizon`` decoder positions.

        Returns:
            Dict with keys ``"attention"`` (NaNs replaced by 0), ``"static_variables"``,
            ``"encoder_variables"``, ``"decoder_variables"``, ``"encoder_length_histogram"`` and
            ``"decoder_length_histogram"``.
        """
        # len() works for both a stacked tensor and a list of per-sample tensors.
        batch_size = len(out["decoder_attention"])
        # Decoder attention: a list/tuple of per-sample tensors (possibly with different last
        # dimensions) is scattered into one NaN-padded tensor.
        if isinstance(out["decoder_attention"], (list, tuple)):
            max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
            first_elm = out["decoder_attention"][0]
            decoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1], max_last_dimension),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # Copy only the first decoder_length attended positions of each sample.
            for idx, att in enumerate(out["decoder_attention"]):
                decoder_length = out["decoder_lengths"][idx]
                decoder_attention[idx, :, :, :decoder_length] = att[..., :decoder_length]
        else:
            # Tensor input: set the decoder steps (dim 1) beyond each sample's decoder length to
            # NaN. create_mask presumably flags padded positions as True - TODO confirm.
            decoder_attention = out["decoder_attention"].clone()
            decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
            decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")

        # Encoder attention, right-aligned to max_encoder_length.
        if isinstance(out["encoder_attention"], (list, tuple)):
            first_elm = out["encoder_attention"][0]
            encoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # Place each sample's encoder_length values at the right end; the left stays NaN.
            for idx, att in enumerate(out["encoder_attention"]):
                encoder_length = out["encoder_lengths"][idx]
                encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = att[..., :encoder_length]
        else:
            # Roll the last dimension of each sample by (size - encoder_length) with gather, so the
            # last observed encoder step ends up at the right end (padding wraps to the left).
            encoder_attention = out["encoder_attention"].clone()
            shifts = encoder_attention.size(3) - out["encoder_lengths"]
            new_index = (torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None]
                         .expand_as(encoder_attention) - shifts[:, None, None, None]) % encoder_attention.size(3)
            encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
            # Left-pad with NaN up to max_encoder_length.
            # NOTE(review): the pad width uses encoder_lengths.max() rather than the current
            # size(-1); this assumes both are equal - confirm.
            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
                encoder_attention = torch.concat(
                    [
                        torch.full(
                            ( *encoder_attention.shape[:-1],
                              self.hparams.max_encoder_length - out["encoder_lengths"].max(),),
                            float("nan"),
                            dtype=encoder_attention.dtype,
                            device=encoder_attention.device,
                        ),
                        encoder_attention,
                    ],
                    dim=-1,
                )

        # Join encoder + decoder attention along the attended-position axis; weights below 1e-5
        # are treated as missing (NaN).
        attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
        attention[attention < 1e-5] = float("nan")

        # Histograms of encoder lengths over [0, max_encoder_length] and decoder lengths over
        # [1, number of decoder steps].
        encoder_length_histogram = integer_histogram(
            out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length
        )
        decoder_length_histogram = integer_histogram(
            out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
        )

        # Aggregate variable-selection weights over time per sample. process_variables_importance
        # is defined elsewhere (not in this view); presumably a length-masked average - TODO confirm.
        encoder_variables = self.process_variables_importance(
            out["encoder_variables"].squeeze(-2).clone(),
            out["encoder_lengths"]
        )

        decoder_variables = self.process_variables_importance(
            out["decoder_variables"].squeeze(-2).clone(),
            out["decoder_lengths"]
        )

        static_variables = out["static_variables"].squeeze(1)
        # Keep only the chosen decoder step and the attended positions up to
        # max_encoder_length + horizon, then average over dim 1 (presumably the attention heads)
        # with masked_op, which presumably ignores NaN entries.
        attention = masked_op(
            attention[:, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon],
            op="mean", dim=1
        )

        if reduction != "none":
            # Aggregate over the batch. NOTE(review): variable importances are summed for any
            # reduction other than "none", including "mean".
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)
            attention = masked_op(attention, dim=0, op=reduction)
        else:
            # Renormalise each sample's attention over its attended positions.
            attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)

        # Missing (NaN) attention is reported as 0.
        interpretation = dict(
            attention=attention.masked_fill(torch.isnan(attention), 0.0),
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation

In [None]:
    # Bloque 8: Métodos para visualización

    def plot_prediction(
        self,
        x: Dict[str, torch.Tensor],
        out: Dict[str, torch.Tensor],
        idx: int,
        plot_attention: bool = True,
        add_loss_to_title: bool = False,
        show_future_observed: bool = True,
        ax=None,
        **kwargs,
    ):
        """
        Plot prediction vs. actuals for one sample and optionally overlay its encoder attention.

        Args:
            x: Network input.
            out: Network output (result of ``forward``).
            idx: Index of the sample within the batch to plot.
            plot_attention: Whether to overlay the encoder attention on a secondary y-axis.
            add_loss_to_title: Whether to add the loss to the figure title.
            show_future_observed: Whether to show the observed future values.
            ax: Matplotlib axes to plot on (forwarded to the parent implementation).
            **kwargs: Additional arguments forwarded to the parent ``plot_prediction``.

        Returns:
            The figure, or list of figures, returned by the parent ``plot_prediction``. When
            ``plot_attention`` is set, each figure gets an extra twin axis with the attention.
        """
        fig = super().plot_prediction(
            x,
            out,
            idx=idx,
            add_loss_to_title=add_loss_to_title,
            show_future_observed=show_future_observed,
            ax=ax,
            **kwargs,
        )
        if plot_attention:
            # Interpret only the selected sample, so row 0 of the result belongs to sample `idx`.
            interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
            attention = interpretation["attention"][0].detach().cpu()
            # Use the encoder length of the plotted sample. The previous code read index 0, which
            # misaligned the overlay for idx != 0 whenever encoder lengths differ.
            encoder_length = int(x["encoder_lengths"][idx])
            # interpret_output right-aligns encoder attention, so the observed steps are the last
            # `encoder_length` entries. Slicing from an explicit start also handles
            # encoder_length == 0, where `[-0:]` would select the whole tensor.
            attention = attention[attention.size(0) - encoder_length :]
            for f in to_list(fig):
                ax = f.axes[0]
                ax2 = ax.twinx()
                ax2.set_ylabel("Attention")
                ax2.plot(
                    torch.arange(-encoder_length, 0),
                    attention,
                    alpha=0.2,
                    color="k",
                )
                f.tight_layout()
        return fig

    def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]):
        """
        Build matplotlib figures for an interpretation produced by ``interpret_output``.

        Args:
            interpretation: Dict with ``"attention"``, ``"static_variables"``,
                ``"encoder_variables"`` and ``"decoder_variables"`` tensors.

        Returns:
            Dict mapping ``"attention"``, ``"static_variables"``, ``"encoder_variables"`` and
            ``"decoder_variables"`` to their figures. Attention and importances are normalised
            to sum to 1 (importances are shown in percent, sorted ascending).

        Raises:
            Whatever ``_check_matplotlib`` raises when matplotlib is not available.
        """
        _check_matplotlib("plot_interpretation")
        import matplotlib.pyplot as plt

        max_enc = self.hparams.max_encoder_length

        def _selection_figure(title, importances, names):
            # Horizontal bar chart of normalised importances, smallest first.
            n_vars = len(importances)
            figure, axis = plt.subplots(figsize=(7, n_vars * 0.25 + 2))
            ranking = np.argsort(importances)
            shares = importances / importances.sum(-1).unsqueeze(-1)
            axis.barh(np.arange(n_vars), shares[ranking] * 100, tick_label=np.asarray(names)[ranking])
            axis.set_title(title)
            axis.set_xlabel("Importance in %")
            plt.tight_layout()
            return figure

        figures = {}

        # Attention over time, with index 0 being the first decoder step.
        attn_fig, attn_ax = plt.subplots()
        weights = interpretation["attention"].detach().cpu()
        weights = weights / weights.sum(-1).unsqueeze(-1)
        attn_ax.plot(np.arange(-max_enc, weights.size(0) - max_enc), weights)
        attn_ax.set_xlabel("Time index")
        attn_ax.set_ylabel("Attention")
        attn_ax.set_title("Attention")
        figures["attention"] = attn_fig

        selection_specs = (
            ("static_variables", "Static variables importance", self.static_variables),
            ("encoder_variables", "Encoder variables importance", self.encoder_variables),
            ("decoder_variables", "Decoder variables importance", self.decoder_variables),
        )
        for key, title, names in selection_specs:
            figures[key] = _selection_figure(title, interpretation[key].detach().cpu(), names)
        return figures

In [None]:
    # Bloque 9: Métodos para logging adicional

    def log_interpretation(self, outputs):
        """
        Log aggregated interpretation figures to the experiment logger (e.g. TensorBoard).

        Sums the per-batch ``"interpretation"`` entries of ``outputs``, then logs:

        - the attention and variable-importance figures from ``plot_interpretation``;
        - one length-distribution figure for encoder lengths;
        - one length-distribution figure for decoder lengths.

        Args:
            outputs: Sequence of per-batch log dicts, each containing an ``"interpretation"``
                dict (as added by ``create_log``).

        Returns:
            None. Returns early without logging when ``outputs`` is empty, matplotlib is not
            installed, or the logger does not support ``add_figure``.
        """
        if not outputs:
            # No batches collected: outputs[0] below would raise IndexError.
            return None

        # Sum interpretations over batches. padded_stack right-pads tensors of differing length
        # (e.g. decoder length histograms) with zeros before stacking.
        interpretation = {
            name: padded_stack(
                [x["interpretation"][name].detach() for x in outputs],
                side="right",
                value=0,
            ).sum(0)
            for name in outputs[0]["interpretation"].keys()
        }
        # Cumulative sample counts per right-aligned encoder position (oldest position first).
        # Bin 0 of the histogram (presumably encoder length 0) is skipped.
        attention_occurances = (
            interpretation["encoder_length_histogram"][1:].flip(0).float().cumsum(0)
        )
        # clamp(min=1.0) avoids 0/0 -> NaN when every sample has encoder length 0. Counts are
        # integers, so any non-zero maximum is already >= 1 and unaffected.
        attention_occurances = attention_occurances / attention_occurances.max().clamp(min=1.0)
        # Pad with ones for attention positions beyond the encoder (if any).
        attention_occurances = torch.cat(
            [
                attention_occurances,
                torch.ones(
                    interpretation["attention"].size(0) - attention_occurances.size(0),
                    dtype=attention_occurances.dtype,
                    device=attention_occurances.device,
                ),
            ],
            dim=0,
        )
        # NOTE(review): after normalisation the occurrences lie in [0, 1], so
        # `pow(2).clamp(1.0)` (clamp with min=1.0) is 1 everywhere and this division is
        # effectively a no-op. Kept to preserve the logged values - confirm the intent.
        interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0)
        interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum()

        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
        if not mpl_available or not self._logger_supports("add_figure"):
            return None

        import matplotlib.pyplot as plt

        figs = self.plot_interpretation(interpretation)
        label = self.current_stage
        for name, fig in figs.items():
            self.logger.experiment.add_figure(
                f"{label.capitalize()} {name} importance", fig, global_step=self.global_step
            )
        # Length distributions. `length_kind` avoids shadowing the builtin `type`.
        for length_kind in ["encoder", "decoder"]:
            fig, ax = plt.subplots()
            lengths = padded_stack(
                [out["interpretation"][f"{length_kind}_length_histogram"] for out in outputs]
            ).sum(0).detach().cpu()
            # Decoder histograms start at length 1, encoder histograms at length 0.
            start = 1 if length_kind == "decoder" else 0
            ax.plot(torch.arange(start, start + len(lengths)), lengths)
            ax.set_xlabel(f"{length_kind.capitalize()} length")
            ax.set_ylabel("Number of samples")
            ax.set_title(f"{length_kind.capitalize()} length distribution in {label} epoch")
            self.logger.experiment.add_figure(
                f"{label.capitalize()} {length_kind} length distribution", fig, global_step=self.global_step
            )

    def log_embeddings(self):
        """
        Log the weights of every input embedding to the experiment logger.

        Nothing is logged when the logger does not support ``add_embedding``. Metadata labels
        come from ``self.hparams.embedding_labels``; a missing entry yields ``None``.
        """
        if self._logger_supports("add_embedding"):
            for tag, embedding in self.input_embeddings.items():
                tag_labels = self.hparams.embedding_labels.get(tag)
                weights = embedding.weight.data.detach().cpu()
                self.logger.experiment.add_embedding(
                    weights,
                    metadata=tag_labels,
                    tag=tag,
                    global_step=self.global_step,
                )
        return None

In [None]:
#Bloque 10: Ejemplo de uso del modelo
# Este bloque muestra cómo se puede crear y utilizar el modelo con un dataset de ejemplo

def create_example_model(batch_size: int = 128, use_compile: bool = True):
    """
    Build an example TemporalFusionTransformer with train/validation dataloaders.

    The data is synthetic autoregressive data from ``generate_ar_data``: 100 series of 400
    time steps, seasonality 10.0.

    - Training uses ``time_idx <= max(time_idx) - 50``.
    - Validation (built with ``from_dataset``) only predicts steps after that cutoff.

    Args:
        batch_size: Batch size of both dataloaders (default 128, as before).
        use_compile: Forwarded to the model's ``from_dataset`` as ``use_compile`` (default
            True, as before). It presumably toggles ``torch.compile`` of ``forward`` (see the
            module docstring); pass False where that is unsupported.

    Returns:
        Tuple ``(tft, train_dataloader, val_dataloader)``.
    """
    # Local imports keep the example self-contained.
    from pytorch_forecasting.data import TimeSeriesDataSet
    from pytorch_forecasting.data.examples import generate_ar_data

    # Synthetic example data.
    data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)

    # Training dataset: everything up to 50 steps before the end.
    training_cutoff = data["time_idx"].max() - 50
    training = TimeSeriesDataSet(
        data[data["time_idx"] <= training_cutoff],
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        max_encoder_length=30,
        max_prediction_length=10,
        static_categoricals=[],
        static_reals=[],
        time_varying_known_categoricals=[],
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_categoricals=[],
        time_varying_unknown_reals=["value"],
    )

    # Validation dataset: same preprocessing, predictions only after the training cutoff.
    validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff + 1)

    train_dataloader = training.to_dataloader(train=True, batch_size=batch_size)
    val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size)

    tft = TemporalFusionTransformer.from_dataset(
        training,
        learning_rate=0.03,
        hidden_size=32,
        attention_head_size=1,
        dropout=0.1,
        hidden_continuous_size=16,
        loss=QuantileLoss(),
        log_interval=10,
        use_compile=use_compile,
    )

    return tft, train_dataloader, val_dataloader

# Ejemplo de entrenamiento (descomentar para ejecutar)
# import pytorch_lightning as pl
# tft, train_dataloader, val_dataloader = create_example_model()
# trainer = pl.Trainer(max_epochs=10, accelerator='gpu', devices=1)
# trainer.fit(tft, train_dataloader, val_dataloader)