In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Temporal Fusion Transformer - Versión Optimizada\n",
    "\n",
    "Esta versión del notebook implementa las mejoras de rendimiento y lógica discutidas. El objetivo es maximizar la velocidad de entrenamiento y la eficiencia de memoria.\n",
    "\n",
    "### Resumen de Optimizaciones Implementadas:\n",
    "\n",
    "1.  **Atención de Alto Rendimiento (`Flash Attention`):** Se ha modificado el módulo `InterpretableMultiHeadAttention` para usar `torch.nn.functional.scaled_dot_product_attention`, que es significativamente más rápido y eficiente en memoria.\n",
    "    - **Trade-off:** Esta optimización implica que los pesos de atención no se calculan durante el `forward pass`. Como resultado, las funcionalidades de visualización e interpretación de la atención han sido desactivadas.\n",
    "2.  **Gradient Checkpointing:** Se ha añadido un nuevo hiperparámetro `gradient_checkpointing` al modelo. Cuando se activa, reduce drásticamente el consumo de memoria a costa de un ligero aumento en el tiempo de computación, permitiendo entrenar con secuencias o lotes más grandes.\n",
    "3.  **Lógica de Interpretación Simplificada:** El método `interpret_output` se ha refactorizado, eliminando la compleja y lenta operación `gather` para alinear la atención del encoder. Ahora se centra en la importancia de las variables.\n",
    "4.  **Inicialización LSTM Consistente:** Se ha corregido el método `forward` para que utilice la función `initialize_lstm_states` optimizada con JIT, eliminando código redundante.\n",
    "5.  **Código Autocontenido:** Todos los sub-módulos necesarios (`GatedResidualNetwork`, `VariableSelectionNetwork`, etc.) se han incluido directamente en este notebook para que sea totalmente funcional sin dependencias ocultas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9e9855-350d-4727-a74b-3dbeb22f301a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 1: Importaciones y docstring inicial\n",
    "\n",
    "import math\n",
    "from copy import copy\n",
    "from typing import Dict, List, Optional, Tuple, Union\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.utils.checkpoint import checkpoint\n",
    "from torchmetrics import Metric as LightningMetric\n",
    "\n",
    "from pytorch_forecasting.data import TimeSeriesDataSet\n",
    "from pytorch_forecasting.metrics import (\n",
    "    MAE,\n",
    "    MAPE,\n",
    "    RMSE,\n",
    "    SMAPE,\n",
    "    MultiHorizonMetric,\n",
    "    QuantileLoss,\n",
    ")\n",
    "from pytorch_forecasting.models.base_model import BaseModelWithCovariates\n",
    "from pytorch_forecasting.models.nn import LSTM, MultiEmbedding\n",
    "from pytorch_forecasting.utils import (\n",
    "    create_mask,\n",
    "    detach,\n",
    "    integer_histogram,\n",
    "    masked_op,\n",
    "    padded_stack,\n",
    "    to_list,\n",
    ")\n",
    "from pytorch_forecasting.utils._dependencies import _check_matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c2a1e8-d1f5-4e2b-8a8f-2a4c9f13e7d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 2: Definición de Sub-Módulos (para un notebook autocontenido)\n",
    "\n",
    "class AddNorm(nn.Module):\n",
    "    \"\"\"Add an optional skip tensor and a learnable offset, then apply LayerNorm.\n",
    "\n",
    "    The offset ``self.add`` has shape ``(1, 1, input_size)``, starts at zero and only\n",
    "    receives gradients when ``trainable_add`` is True.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size: int, trainable_add: bool = True):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        zero_offset = torch.zeros(1, 1, input_size)\n",
    "        self.add = nn.Parameter(zero_offset, requires_grad=trainable_add)\n",
    "        self.norm = nn.LayerNorm(input_size)\n",
    "\n",
    "    def forward(self, x: torch.Tensor, add: Optional[torch.Tensor] = None):\n",
    "        \"\"\"Return ``LayerNorm(x + add + self.add)``; ``add`` is skipped when it is None.\"\"\"\n",
    "        summed = x if add is None else x + add\n",
    "        return self.norm(summed + self.add)\n",
    "\n",
    "class GatedLinearUnit(nn.Module):\n",
    "    \"\"\"Linear projection to ``2 * hidden_size`` features followed by a GLU gate.\n",
    "\n",
    "    ``hidden_size`` falls back to ``input_size`` when it is None (or 0). Dropout is\n",
    "    applied to the input only when a dropout rate is given.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size: int, hidden_size: Optional[int] = None, dropout: Optional[float] = None):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout) if dropout is not None else None\n",
    "        self.hidden_size = hidden_size or input_size\n",
    "        self.fc = nn.Linear(input_size, self.hidden_size * 2)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        \"\"\"Zero every bias and Xavier-initialise every weight of this module.\"\"\"\n",
    "        for param_name, param in self.named_parameters():\n",
    "            if \"bias\" in param_name:\n",
    "                torch.nn.init.zeros_(param)\n",
    "                continue\n",
    "            if \"weight\" in param_name:\n",
    "                torch.nn.init.xavier_uniform_(param)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``glu(fc(dropout(x)))``; the last dimension becomes ``hidden_size``.\"\"\"\n",
    "        gate_input = x if self.dropout is None else self.dropout(x)\n",
    "        return F.glu(self.fc(gate_input), dim=-1)\n",
    "\n",
    "class GateAddNorm(nn.Module):\n",
    "    \"\"\"GLU gate whose output is added to a skip tensor and layer-normalised.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size=None, dropout=None, trainable_add=False):\n",
    "        super().__init__()\n",
    "        gated_width = hidden_size or input_size\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = gated_width\n",
    "        self.gating_unit = GatedLinearUnit(input_size, gated_width, dropout)\n",
    "        self.add_norm = AddNorm(gated_width, trainable_add=trainable_add)\n",
    "\n",
    "    def forward(self, x, skip):\n",
    "        \"\"\"Gate ``x`` and combine the result with ``skip`` through ``AddNorm``.\"\"\"\n",
    "        gated = self.gating_unit(x)\n",
    "        return self.add_norm(gated, skip)\n",
    "\n",
    "\n",
    "class GatedResidualNetwork(nn.Module):\n",
    "    \"\"\"Two-layer feed-forward block with ELU, GLU gating, residual and LayerNorm.\n",
    "\n",
    "    The residual is the input itself, or a linear projection of it when\n",
    "    ``input_size != output_size``. An optional context tensor is projected\n",
    "    (without bias) and added after the first linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, dropout=0.1, context_size=None):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.output_size = output_size\n",
    "        self.context_size = context_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.dropout = dropout\n",
    "\n",
    "        # Projection for the residual path, only needed when the widths differ.\n",
    "        if input_size != output_size:\n",
    "            self.skip_layer = nn.Linear(input_size, output_size)\n",
    "\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.elu = nn.ELU()\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.dropout_layer = nn.Dropout(dropout)\n",
    "\n",
    "        if context_size is not None:\n",
    "            self.context_vector = nn.Linear(context_size, hidden_size, bias=False)\n",
    "\n",
    "        self.gate = GatedLinearUnit(hidden_size, output_size, dropout=dropout)\n",
    "        self.norm = nn.LayerNorm(output_size)\n",
    "\n",
    "    def forward(self, x, context=None):\n",
    "        \"\"\"Return ``LayerNorm(gate(dropout(fc2(elu(fc1(x) [+ ctx])))) + residual)``.\"\"\"\n",
    "        residual = self.skip_layer(x) if self.input_size != self.output_size else x\n",
    "\n",
    "        hidden = self.fc1(x)\n",
    "        if context is not None:\n",
    "            hidden = hidden + self.context_vector(context)\n",
    "\n",
    "        hidden = self.fc2(self.elu(hidden))\n",
    "        gated = self.gate(self.dropout_layer(hidden))\n",
    "        return self.norm(gated + residual)\n",
    "\n",
    "class VariableSelectionNetwork(nn.Module):\n",
    "    \"\"\"Per-variable GRNs combined through softmax selection weights.\n",
    "\n",
    "    Every variable is optionally prescaled and then transformed by its own GRN to\n",
    "    ``hidden_size`` features. A second GRN (``flattened_grn``) receives the\n",
    "    concatenation of all prescaled inputs, plus an optional context, and emits one\n",
    "    logit per variable; the softmax of these logits weights the per-variable outputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_sizes, hidden_size, input_embedding_flags=None, dropout=0.1, context_size=None, prescalers=None, single_variable_grns=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            input_sizes: mapping variable name -> feature width fed to that variable's\n",
    "                GRN (i.e. the width after prescaling).\n",
    "            hidden_size: output width of every per-variable GRN.\n",
    "            input_embedding_flags: stored on the instance; not used by this class.\n",
    "            dropout: dropout rate of the GRNs created here.\n",
    "            context_size: width of the context tensor; when set, ``forward`` requires\n",
    "                a context.\n",
    "            prescalers: mapping variable name -> module applied to the raw input\n",
    "                before its GRN.\n",
    "            single_variable_grns: mapping variable name -> existing GRN to reuse\n",
    "                instead of creating a new one.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        if prescalers is None:\n",
    "            prescalers = {}\n",
    "        if input_embedding_flags is None:\n",
    "            input_embedding_flags = {}\n",
    "        if single_variable_grns is None:\n",
    "            single_variable_grns = {}\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "        self.input_sizes = input_sizes\n",
    "        self.input_embedding_flags = input_embedding_flags\n",
    "        self.dropout = dropout\n",
    "        self.context_size = context_size\n",
    "        self.prescalers = prescalers\n",
    "\n",
    "        # Selection GRN: consumes all prescaled inputs side by side and produces\n",
    "        # one logit per variable.\n",
    "        self.flattened_grn = GatedResidualNetwork(\n",
    "            sum(self.input_sizes.values()), self.hidden_size, len(self.input_sizes), self.dropout, self.context_size\n",
    "        )\n",
    "        self.single_variable_grns = nn.ModuleDict()\n",
    "        for name, input_size in self.input_sizes.items():\n",
    "            if name in single_variable_grns:\n",
    "                self.single_variable_grns[name] = single_variable_grns[name]\n",
    "            else:\n",
    "                self.single_variable_grns[name] = GatedResidualNetwork(\n",
    "                    input_size, min(input_size, self.hidden_size), self.hidden_size, self.dropout\n",
    "                )\n",
    "\n",
    "    def forward(self, x, context=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: mapping variable name -> tensor; must contain every key of ``input_sizes``.\n",
    "            context: optional tensor passed to the selection GRN.\n",
    "\n",
    "        Returns:\n",
    "            ``(outputs, sparse_weights)``: the weighted sum of the per-variable GRN\n",
    "            outputs (last dimension ``hidden_size``) and the softmax weights (last\n",
    "            dimension = number of variables).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``context_size`` was set but no context is supplied.\n",
    "        \"\"\"\n",
    "        if self.context_size is not None and context is None:\n",
    "            raise ValueError(\"Context must be supplied\")\n",
    "\n",
    "        selector_inputs = []\n",
    "        var_outputs = []\n",
    "        for name in self.input_sizes.keys():\n",
    "            variable_input = x[name]\n",
    "            if name in self.prescalers:\n",
    "                variable_input = self.prescalers[name](variable_input)\n",
    "            selector_inputs.append(variable_input)\n",
    "            var_outputs.append(self.single_variable_grns[name](variable_input))\n",
    "\n",
    "        # flattened_grn was built for sum(input_sizes) features, so it must see the\n",
    "        # prescaled inputs; the GRN outputs have n_vars * hidden_size features and\n",
    "        # only fit when every input width happens to equal hidden_size.\n",
    "        flat_inputs = torch.cat(selector_inputs, dim=-1)\n",
    "        sparse_weights = F.softmax(self.flattened_grn(flat_inputs, context), dim=-1)\n",
    "\n",
    "        # Weighted sum over the variable axis.\n",
    "        stacked_outputs = torch.stack(var_outputs, dim=-2)\n",
    "        outputs = torch.sum(sparse_weights.unsqueeze(-1) * stacked_outputs, dim=-2)\n",
    "        return outputs, sparse_weights\n",
    "\n",
    "class InterpretableMultiHeadAttention(nn.Module):\n",
    "    \"\"\"Multi-head attention computed with ``F.scaled_dot_product_attention``.\n",
    "\n",
    "    The fused kernel does not expose attention weights, so ``forward`` always\n",
    "    returns ``None`` in their place.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, n_head: int, dropout: float = 0.0):\n",
    "        super().__init__()\n",
    "        self.d_model = d_model\n",
    "        self.n_head = n_head\n",
    "        self.dropout = dropout\n",
    "        self.head_dim = d_model // n_head\n",
    "        assert self.head_dim * n_head == self.d_model, \"d_model must be divisible by n_head\"\n",
    "\n",
    "        self.q_proj = nn.Linear(d_model, d_model)\n",
    "        self.k_proj = nn.Linear(d_model, d_model)\n",
    "        self.v_proj = nn.Linear(d_model, d_model)\n",
    "        self.out_proj = nn.Linear(d_model, d_model)\n",
    "\n",
    "    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:\n",
    "        \"\"\"Reshape ``(batch, seq, d_model)`` to ``(batch, n_head, seq, head_dim)``.\"\"\"\n",
    "        return projected.view(batch_size, -1, self.n_head, self.head_dim).transpose(1, 2)\n",
    "\n",
    "    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            q: queries, ``(batch, query_len, d_model)``.\n",
    "            k: keys, ``(batch, key_len, d_model)``.\n",
    "            v: values, ``(batch, key_len, d_model)``.\n",
    "            mask: optional boolean tensor ``(batch, query_len, key_len)`` in which\n",
    "                True marks a key position that must NOT be attended.\n",
    "\n",
    "        Returns:\n",
    "            ``(attn_output, None)`` with ``attn_output`` of shape\n",
    "            ``(batch, query_len, d_model)``.\n",
    "        \"\"\"\n",
    "        batch_size = q.size(0)\n",
    "\n",
    "        q = self._split_heads(self.q_proj(q), batch_size)\n",
    "        k = self._split_heads(self.k_proj(k), batch_size)\n",
    "        v = self._split_heads(self.v_proj(v), batch_size)\n",
    "\n",
    "        attn_mask = None\n",
    "        if mask is not None:\n",
    "            # scaled_dot_product_attention uses the opposite convention (True = take\n",
    "            # part in attention), hence the inversion. The singleton head axis is\n",
    "            # broadcast by the kernel, so the mask is not copied once per head.\n",
    "            attn_mask = ~mask.bool().unsqueeze(1)\n",
    "\n",
    "        attn_output = F.scaled_dot_product_attention(\n",
    "            q, k, v,\n",
    "            attn_mask=attn_mask,\n",
    "            dropout_p=self.dropout if self.training else 0.0,\n",
    "        )\n",
    "\n",
    "        # Merge the heads again and apply the output projection.\n",
    "        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)\n",
    "        attn_output = self.out_proj(attn_output)\n",
    "\n",
    "        # Attention weights are not computed by the fused kernel.\n",
    "        return attn_output, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0873c74-dea9-44aa-a98d-6af9569046f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 3: Definición de la clase principal y funciones auxiliares optimizadas\n",
    "\n",
    "class TemporalFusionTransformer(BaseModelWithCovariates):\n",
    "    @staticmethod\n",
    "    @torch.jit.script\n",
    "    def initialize_lstm_states(hidden, cell, layers: int):\n",
    "        \"\"\"\n",
    "        Expand the initial hidden and cell state to one copy per LSTM layer.\n",
    "\n",
    "        A leading axis is inserted and expanded to ``layers`` (a view, no data is\n",
    "        copied). Returns the tuple ``(hidden, cell)``.\n",
    "        \"\"\"\n",
    "        layered_hidden = hidden.unsqueeze(0).expand(layers, -1, -1)\n",
    "        layered_cell = cell.unsqueeze(0).expand(layers, -1, -1)\n",
    "        return layered_hidden, layered_cell\n",
    "\n",
    "    def process_lstm_output(self, lstm_output, residual_input, is_encoder=True):\n",
    "        \"\"\"\n",
    "        Apply the post-LSTM gate and add-norm of the encoder or the decoder.\n",
    "\n",
    "        ``is_encoder`` selects the ``*_encoder`` or ``*_decoder`` pair of modules;\n",
    "        ``residual_input`` is the skip tensor handed to the add-norm layer.\n",
    "        \"\"\"\n",
    "        if is_encoder:\n",
    "            gate, add_norm = self.post_lstm_gate_encoder, self.post_lstm_add_norm_encoder\n",
    "        else:\n",
    "            gate, add_norm = self.post_lstm_gate_decoder, self.post_lstm_add_norm_decoder\n",
    "        gated = gate(lstm_output)\n",
    "        return add_norm(gated, residual_input)\n",
    "\n",
    "    def process_variables_importance(self, variables, lengths):\n",
    "        \"\"\"\n",
    "        Average variable-selection weights over the valid time steps of each sample.\n",
    "\n",
    "        Args:\n",
    "            variables: tensor indexed as ``(batch, time, n_variables)``.\n",
    "            lengths: integer tensor ``(batch,)`` with the number of valid steps.\n",
    "\n",
    "        Returns:\n",
    "            ``(batch, n_variables)`` tensor: sum over the first ``lengths[i]`` steps\n",
    "            divided by ``max(lengths[i], 1)``.\n",
    "        \"\"\"\n",
    "        # Build the validity mask explicitly: step j of sample i is valid iff\n",
    "        # j < lengths[i]. The previous code inverted the create_mask() result, which\n",
    "        # is already True on padding, and therefore summed the padded steps instead.\n",
    "        steps = torch.arange(variables.size(1), device=lengths.device)\n",
    "        valid = (steps.unsqueeze(0) < lengths.unsqueeze(-1)).unsqueeze(-1)\n",
    "        summed = variables.masked_fill(~valid, 0.0).sum(dim=1)\n",
    "        # clamp_min(1) avoids dividing by zero for empty sequences.\n",
    "        return summed / lengths.clamp_min(1).unsqueeze(-1)\n",
    "\n",
    "    def transform_multi_output(self, output):\n",
    "        \"\"\"\n",
    "        Apply the output layer(s) to the decoder representation.\n",
    "\n",
    "        With several targets ``self.output_layer`` is a ``ModuleList`` and a list with\n",
    "        one tensor per target is returned; otherwise the single layer's output is\n",
    "        returned.\n",
    "\n",
    "        Each layer is applied directly: ``nn.Linear`` accepts inputs of any rank and\n",
    "        the per-target output sizes may differ, which rules out stacking the\n",
    "        weights for a batched matmul.\n",
    "        \"\"\"\n",
    "        if self.n_targets > 1:\n",
    "            return [layer(output) for layer in self.output_layer]\n",
    "        return self.output_layer(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "115c9a08-f184-40ae-89e4-3fd73cf888e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 4: Método de inicialización (init)\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        hidden_size: int = 16,\n",
    "        lstm_layers: int = 1,\n",
    "        dropout: float = 0.1,\n",
    "        output_size: Union[int, List[int]] = 7,\n",
    "        loss: MultiHorizonMetric = None,\n",
    "        attention_head_size: int = 4,\n",
    "        max_encoder_length: int = 10,\n",
    "        static_categoricals: Optional[List[str]] = None,\n",
    "        static_reals: Optional[List[str]] = None,\n",
    "        time_varying_categoricals_encoder: Optional[List[str]] = None,\n",
    "        time_varying_categoricals_decoder: Optional[List[str]] = None,\n",
    "        categorical_groups: Optional[Union[Dict, List[str]]] = None,\n",
    "        time_varying_reals_encoder: Optional[List[str]] = None,\n",
    "        time_varying_reals_decoder: Optional[List[str]] = None,\n",
    "        x_reals: Optional[List[str]] = None,\n",
    "        x_categoricals: Optional[List[str]] = None,\n",
    "        hidden_continuous_size: int = 8,\n",
    "        hidden_continuous_sizes: Optional[Dict[str, int]] = None,\n",
    "        embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None,\n",
    "        embedding_paddings: Optional[List[str]] = None,\n",
    "        embedding_labels: Optional[Dict[str, np.ndarray]] = None,\n",
    "        learning_rate: float = 1e-3,\n",
    "        log_interval: Union[int, float] = -1,\n",
    "        log_val_interval: Union[int, float] = None,\n",
    "        log_gradient_flow: bool = False,\n",
    "        reduce_on_plateau_patience: int = 1000,\n",
    "        monotone_constraints: Optional[Dict[str, int]] = None,\n",
    "        share_single_variable_networks: bool = False,\n",
    "        causal_attention: bool = True,\n",
    "        logging_metrics: Optional[nn.ModuleList] = None,\n",
    "        use_compile: bool = False,\n",
    "        gradient_checkpointing: bool = False,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Temporal Fusion Transformer for time-series forecasting (optimised variant).\n",
    "\n",
    "        All arguments are stored as hyperparameters. The ones consumed directly in\n",
    "        this constructor are:\n",
    "\n",
    "        Args:\n",
    "            hidden_size: width of the GRNs, the LSTMs and the attention layer.\n",
    "            lstm_layers: number of stacked LSTM layers in encoder and decoder.\n",
    "            dropout: dropout rate handed to the sub-modules.\n",
    "            output_size: size of the output layer; iterated as a list (one layer per\n",
    "                target) when ``self.n_targets > 1``.\n",
    "            loss: training metric; ``QuantileLoss()`` when None.\n",
    "            attention_head_size: number of attention heads.\n",
    "            hidden_continuous_size: default prescaler width of a continuous variable.\n",
    "            hidden_continuous_sizes: per-variable overrides of that width.\n",
    "            embedding_sizes, embedding_paddings, categorical_groups, x_categoricals:\n",
    "                forwarded to ``MultiEmbedding``.\n",
    "            share_single_variable_networks: reuse one GRN per variable name in the\n",
    "                encoder and decoder selection networks.\n",
    "            logging_metrics: ``[SMAPE(), MAE(), RMSE(), MAPE()]`` when None.\n",
    "            use_compile: wrap ``forward`` with ``torch.compile`` when it is available.\n",
    "            gradient_checkpointing: stored in ``self.hparams`` to enable activation\n",
    "                checkpointing.\n",
    "            **kwargs: forwarded to ``BaseModelWithCovariates.__init__``.\n",
    "        \"\"\"\n",
    "        # Replace None defaults by empty containers before save_hyperparameters():\n",
    "        # the code below iterates over these values or calls ``.get`` on them via\n",
    "        # ``self.hparams``, which fails on None.\n",
    "        static_categoricals = static_categoricals or []\n",
    "        static_reals = static_reals or []\n",
    "        time_varying_categoricals_encoder = time_varying_categoricals_encoder or []\n",
    "        time_varying_categoricals_decoder = time_varying_categoricals_decoder or []\n",
    "        time_varying_reals_encoder = time_varying_reals_encoder or []\n",
    "        time_varying_reals_decoder = time_varying_reals_decoder or []\n",
    "        if categorical_groups is None:\n",
    "            categorical_groups = {}\n",
    "        if x_reals is None:\n",
    "            x_reals = []\n",
    "        if x_categoricals is None:\n",
    "            x_categoricals = []\n",
    "        if hidden_continuous_sizes is None:\n",
    "            hidden_continuous_sizes = {}\n",
    "        if embedding_sizes is None:\n",
    "            embedding_sizes = {}\n",
    "        if embedding_paddings is None:\n",
    "            embedding_paddings = []\n",
    "        if embedding_labels is None:\n",
    "            embedding_labels = {}\n",
    "        if monotone_constraints is None:\n",
    "            monotone_constraints = {}\n",
    "        if loss is None:\n",
    "            loss = QuantileLoss()\n",
    "        if logging_metrics is None:\n",
    "            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])\n",
    "\n",
    "        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)\n",
    "        self.save_hyperparameters(ignore=[\"use_compile\"])\n",
    "        self.hparams.use_compile = use_compile\n",
    "        self.hparams.gradient_checkpointing = gradient_checkpointing\n",
    "        hp = self.hparams\n",
    "\n",
    "        def real_size(name):\n",
    "            # Prescaler width of a continuous variable (per-variable override or default).\n",
    "            return hp.hidden_continuous_sizes.get(name, hp.hidden_continuous_size)\n",
    "\n",
    "        def input_sizes_for(categoricals, reals):\n",
    "            # Feature width per variable: embedding size for categoricals, prescaler width for reals.\n",
    "            sizes = {name: self.input_embeddings.output_size[name] for name in categoricals}\n",
    "            sizes.update({name: real_size(name) for name in reals})\n",
    "            return sizes\n",
    "\n",
    "        def hidden_grn(context_size=None):\n",
    "            # GRN mapping hidden_size -> hidden_size.\n",
    "            return GatedResidualNetwork(hp.hidden_size, hp.hidden_size, hp.hidden_size, hp.dropout, context_size=context_size)\n",
    "\n",
    "        def single_variable_grn(input_size):\n",
    "            # GRN mapping one variable's features to hidden_size.\n",
    "            return GatedResidualNetwork(input_size, min(input_size, hp.hidden_size), hp.hidden_size, hp.dropout)\n",
    "\n",
    "        # Embeddings for categoricals and linear prescalers for continuous variables.\n",
    "        self.input_embeddings = MultiEmbedding(\n",
    "            embedding_sizes=hp.embedding_sizes,\n",
    "            categorical_groups=hp.categorical_groups,\n",
    "            embedding_paddings=hp.embedding_paddings,\n",
    "            x_categoricals=hp.x_categoricals,\n",
    "            max_embedding_size=hp.hidden_size,\n",
    "        )\n",
    "        self.prescalers = nn.ModuleDict({name: nn.Linear(1, real_size(name)) for name in self.reals})\n",
    "\n",
    "        # Variable selection for static, encoder and decoder inputs.\n",
    "        static_input_sizes = input_sizes_for(hp.static_categoricals, hp.static_reals)\n",
    "        self.static_variable_selection = VariableSelectionNetwork(\n",
    "            input_sizes=static_input_sizes,\n",
    "            hidden_size=hp.hidden_size,\n",
    "            input_embedding_flags={name: True for name in hp.static_categoricals},\n",
    "            dropout=hp.dropout,\n",
    "            prescalers=self.prescalers,\n",
    "        )\n",
    "\n",
    "        encoder_input_sizes = input_sizes_for(hp.time_varying_categoricals_encoder, hp.time_varying_reals_encoder)\n",
    "        decoder_input_sizes = input_sizes_for(hp.time_varying_categoricals_decoder, hp.time_varying_reals_decoder)\n",
    "\n",
    "        shared_grns = {}\n",
    "        if hp.share_single_variable_networks:\n",
    "            self.shared_single_variable_grns = nn.ModuleDict()\n",
    "            for name, input_size in encoder_input_sizes.items():\n",
    "                self.shared_single_variable_grns[name] = single_variable_grn(input_size)\n",
    "            for name, input_size in decoder_input_sizes.items():\n",
    "                if name not in self.shared_single_variable_grns:\n",
    "                    self.shared_single_variable_grns[name] = single_variable_grn(input_size)\n",
    "            shared_grns = self.shared_single_variable_grns\n",
    "\n",
    "        self.encoder_variable_selection = VariableSelectionNetwork(\n",
    "            input_sizes=encoder_input_sizes,\n",
    "            hidden_size=hp.hidden_size,\n",
    "            input_embedding_flags={name: True for name in hp.time_varying_categoricals_encoder},\n",
    "            dropout=hp.dropout,\n",
    "            context_size=hp.hidden_size,\n",
    "            prescalers=self.prescalers,\n",
    "            single_variable_grns=shared_grns,\n",
    "        )\n",
    "        self.decoder_variable_selection = VariableSelectionNetwork(\n",
    "            input_sizes=decoder_input_sizes,\n",
    "            hidden_size=hp.hidden_size,\n",
    "            input_embedding_flags={name: True for name in hp.time_varying_categoricals_decoder},\n",
    "            dropout=hp.dropout,\n",
    "            context_size=hp.hidden_size,\n",
    "            prescalers=self.prescalers,\n",
    "            single_variable_grns=shared_grns,\n",
    "        )\n",
    "\n",
    "        # Static context encoders.\n",
    "        self.static_context_variable_selection = hidden_grn()\n",
    "        self.static_context_initial_hidden_lstm = hidden_grn()\n",
    "        self.static_context_initial_cell_lstm = hidden_grn()\n",
    "        self.static_context_enrichment = hidden_grn()\n",
    "\n",
    "        # Sequence encoder / decoder; LSTM dropout only applies between stacked layers.\n",
    "        lstm_dropout = hp.dropout if hp.lstm_layers > 1 else 0\n",
    "        self.lstm_encoder = LSTM(input_size=hp.hidden_size, hidden_size=hp.hidden_size, num_layers=hp.lstm_layers, dropout=lstm_dropout, batch_first=True)\n",
    "        self.lstm_decoder = LSTM(input_size=hp.hidden_size, hidden_size=hp.hidden_size, num_layers=hp.lstm_layers, dropout=lstm_dropout, batch_first=True)\n",
    "\n",
    "        # Skip connections around the LSTMs.\n",
    "        self.post_lstm_gate_encoder = GatedLinearUnit(hp.hidden_size, dropout=hp.dropout)\n",
    "        self.post_lstm_gate_decoder = GatedLinearUnit(hp.hidden_size, dropout=hp.dropout)\n",
    "        self.post_lstm_add_norm_encoder = AddNorm(hp.hidden_size, trainable_add=False)\n",
    "        self.post_lstm_add_norm_decoder = AddNorm(hp.hidden_size, trainable_add=False)\n",
    "\n",
    "        # Static enrichment, attention and position-wise feed-forward.\n",
    "        self.static_enrichment = hidden_grn(context_size=hp.hidden_size)\n",
    "        self.multihead_attn = InterpretableMultiHeadAttention(d_model=hp.hidden_size, n_head=hp.attention_head_size, dropout=hp.dropout)\n",
    "        self.post_attn_gate_norm = GateAddNorm(hp.hidden_size, dropout=hp.dropout, trainable_add=False)\n",
    "        self.pos_wise_ff = hidden_grn()\n",
    "\n",
    "        # Output head(s): one linear layer per target.\n",
    "        self.pre_output_gate_norm = GateAddNorm(hp.hidden_size, dropout=None, trainable_add=False)\n",
    "        if self.n_targets > 1:\n",
    "            self.output_layer = nn.ModuleList([nn.Linear(hp.hidden_size, osize) for osize in hp.output_size])\n",
    "        else:\n",
    "            self.output_layer = nn.Linear(hp.hidden_size, hp.output_size)\n",
    "\n",
    "        self._attention_mask_cache = {}\n",
    "\n",
    "        if hp.use_compile and hasattr(torch, \"compile\"):\n",
    "            self.forward = torch.compile(self.forward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c742878-cf40-4898-a130-6af26193ab1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 5: Métodos de utilidad\n",
    "\n",
    "    @classmethod\n",
    "    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):\n",
    "        \"\"\"\n",
    "        Build a model whose parameters are derived from ``dataset``.\n",
    "\n",
    "        ``max_encoder_length`` is taken from the dataset, and the output parameters\n",
    "        deduced by ``deduce_default_output_parameters`` (with ``QuantileLoss()`` as\n",
    "        the default loss) are merged into the keyword arguments before delegating\n",
    "        to the parent implementation.\n",
    "        \"\"\"\n",
    "        new_kwargs = dict(kwargs)\n",
    "        new_kwargs[\"max_encoder_length\"] = dataset.max_encoder_length\n",
    "        output_defaults = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss())\n",
    "        new_kwargs.update(output_defaults)\n",
    "        return super().from_dataset(dataset, **new_kwargs)\n",
    "\n",
    "    def expand_static_context(self, context: torch.Tensor, timesteps: int) -> torch.Tensor:\n",
    "        \"\"\"Insert a time axis at position 1 and expand it (as a view) to ``timesteps``.\"\"\"\n",
    "        with_time_axis = context[:, None]\n",
    "        return with_time_axis.expand(-1, timesteps, -1)\n",
    "\n",
    "    def get_attention_mask(self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Build the boolean mask that is handed to ``self.multihead_attn``.\n",
    "\n",
    "        Args:\n",
    "            encoder_lengths: ``(batch,)`` number of valid encoder steps per sample.\n",
    "            decoder_lengths: ``(batch,)`` number of valid decoder steps per sample.\n",
    "\n",
    "        Returns:\n",
    "            Bool tensor ``(batch, max_decoder_length, max_encoder_length +\n",
    "            max_decoder_length)``. True marks a key position that must NOT be\n",
    "            attended (the attention module inverts the mask before calling the\n",
    "            fused kernel): padded encoder steps, plus either future decoder steps\n",
    "            (``causal_attention``) or padded decoder steps.\n",
    "\n",
    "        The mask depends on the individual lengths and on the batch size, so it is\n",
    "        rebuilt for every batch; a cache keyed on the maximum lengths would return\n",
    "        masks that belong to a different batch.\n",
    "        \"\"\"\n",
    "        batch_size = encoder_lengths.size(0)\n",
    "        max_encoder_length = int(encoder_lengths.max())\n",
    "        max_decoder_length = int(decoder_lengths.max())\n",
    "\n",
    "        def padding_mask(size, lengths):\n",
    "            # True where the step index lies beyond the sample's length.\n",
    "            steps = torch.arange(size, device=lengths.device)\n",
    "            return steps.unsqueeze(0) >= lengths.unsqueeze(-1)\n",
    "\n",
    "        if self.hparams.causal_attention:\n",
    "            # Strict upper triangle: decoder step i may see steps 0..i but not later ones.\n",
    "            future = torch.triu(torch.ones(max_decoder_length, max_decoder_length, device=self.device), 1).bool()\n",
    "            decoder_mask = future.unsqueeze(0).expand(batch_size, -1, -1)\n",
    "        else:\n",
    "            decoder_mask = padding_mask(max_decoder_length, decoder_lengths).unsqueeze(1).expand(-1, max_decoder_length, -1)\n",
    "\n",
    "        encoder_mask = padding_mask(max_encoder_length, encoder_lengths).unsqueeze(1).expand(-1, max_decoder_length, -1)\n",
    "\n",
    "        # Both halves use the same convention (True = ignore); encoder keys come first.\n",
    "        return torch.cat([encoder_mask, decoder_mask], dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d3befcb-70ee-4171-b443-b86a157048a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 6: Método forward\n",
    "\n",
    "    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        Run the network on one batch.\n",
    "\n",
    "        Args:\n",
    "            x: batch dictionary; the keys read here are ``encoder_lengths``,\n",
    "                ``decoder_lengths``, ``encoder_cat``, ``decoder_cat``,\n",
    "                ``encoder_cont``, ``decoder_cont`` and ``target_scale``.\n",
    "\n",
    "        Returns:\n",
    "            The result of ``self.to_network_output`` holding the prediction, the\n",
    "            variable-selection weights and the sequence lengths.\n",
    "            ``encoder_attention`` and ``decoder_attention`` are always ``None``\n",
    "            because the attention module does not return weights.\n",
    "        \"\"\"\n",
    "        batch_size = x[\"encoder_lengths\"].size(0)\n",
    "        encoder_lengths = x[\"encoder_lengths\"]\n",
    "        decoder_lengths = x[\"decoder_lengths\"]\n",
    "        max_encoder_length = int(encoder_lengths.max())\n",
    "        timesteps = max_encoder_length + int(decoder_lengths.max())\n",
    "\n",
    "        # Embed categoricals and collect continuous variables over encoder + decoder time.\n",
    "        input_vectors = self.input_embeddings(torch.cat([x[\"encoder_cat\"], x[\"decoder_cat\"]], dim=1))\n",
    "        x_cont = torch.cat([x[\"encoder_cont\"], x[\"decoder_cont\"]], dim=1)\n",
    "        input_vectors.update({name: x_cont[..., idx].unsqueeze(-1) for idx, name in enumerate(self.hparams.x_reals) if name in self.reals})\n",
    "\n",
    "        # Static variable selection (zeros when the model has no static variables).\n",
    "        if len(self.static_variables) > 0:\n",
    "            static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables}\n",
    "            static_embedding, static_variable_selection = self.static_variable_selection(static_embedding)\n",
    "        else:\n",
    "            static_embedding = torch.zeros((batch_size, self.hparams.hidden_size), device=self.device)\n",
    "            static_variable_selection = torch.zeros((batch_size, 0), device=self.device)\n",
    "\n",
    "        static_context_variable_selection = self.expand_static_context(self.static_context_variable_selection(static_embedding), timesteps)\n",
    "\n",
    "        # Encoder / decoder variable selection, conditioned on the static context.\n",
    "        embeddings_varying_encoder = {name: input_vectors[name][:, :max_encoder_length] for name in self.encoder_variables}\n",
    "        embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(embeddings_varying_encoder, static_context_variable_selection[:, :max_encoder_length])\n",
    "        embeddings_varying_decoder = {name: input_vectors[name][:, max_encoder_length:] for name in self.decoder_variables}\n",
    "        embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(embeddings_varying_decoder, static_context_variable_selection[:, max_encoder_length:])\n",
    "\n",
    "        # Initial LSTM states derived from the static embedding, one copy per layer.\n",
    "        init_hidden = self.static_context_initial_hidden_lstm(static_embedding)\n",
    "        init_cell = self.static_context_initial_cell_lstm(static_embedding)\n",
    "        input_hidden, input_cell = self.initialize_lstm_states(init_hidden, init_cell, self.hparams.lstm_layers)\n",
    "\n",
    "        # The decoder LSTM starts from the final encoder state.\n",
    "        encoder_output, (hidden, cell) = self.lstm_encoder(embeddings_varying_encoder, (input_hidden, input_cell), lengths=encoder_lengths, enforce_sorted=False)\n",
    "        decoder_output, _ = self.lstm_decoder(embeddings_varying_decoder, (hidden, cell), lengths=decoder_lengths, enforce_sorted=False)\n",
    "\n",
    "        lstm_output_encoder = self.process_lstm_output(encoder_output, embeddings_varying_encoder, True)\n",
    "        lstm_output_decoder = self.process_lstm_output(decoder_output, embeddings_varying_decoder, False)\n",
    "        lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1)\n",
    "\n",
    "        # Static enrichment.\n",
    "        static_context_enrichment = self.static_context_enrichment(static_embedding)\n",
    "        attn_input = self.static_enrichment(lstm_output, self.expand_static_context(static_context_enrichment, timesteps))\n",
    "\n",
    "        def attention_block(attn_input, mask):\n",
    "            # Decoder steps are the queries; keys / values span encoder and decoder.\n",
    "            attn_output, _ = self.multihead_attn(q=attn_input[:, max_encoder_length:], k=attn_input, v=attn_input, mask=mask)\n",
    "            attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:])\n",
    "            return self.pos_wise_ff(attn_output)\n",
    "\n",
    "        mask = self.get_attention_mask(encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths)\n",
    "        # Checkpointing trades recomputation for activation memory, which only matters\n",
    "        # when a backward pass follows; it is skipped in eval / no-grad mode.\n",
    "        # use_reentrant=False is the variant recommended by the PyTorch documentation.\n",
    "        use_checkpoint = self.hparams.gradient_checkpointing and self.training and torch.is_grad_enabled()\n",
    "        if use_checkpoint:\n",
    "            output = checkpoint(attention_block, attn_input, mask, use_reentrant=False)\n",
    "        else:\n",
    "            output = attention_block(attn_input, mask)\n",
    "\n",
    "        # Final skip connection over the attention block, then the output head(s).\n",
    "        output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:])\n",
    "        output = self.transform_multi_output(output)\n",
    "\n",
    "        return self.to_network_output(\n",
    "            prediction=self.transform_output(output, target_scale=x[\"target_scale\"]),\n",
    "            # Attention weights are not available with the fused attention kernel.\n",
    "            encoder_attention=None,\n",
    "            decoder_attention=None,\n",
    "            static_variables=static_variable_selection,\n",
    "            encoder_variables=encoder_sparse_weights,\n",
    "            decoder_variables=decoder_sparse_weights,\n",
    "            decoder_lengths=decoder_lengths,\n",
    "            encoder_lengths=encoder_lengths,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8dfb742-a5b5-4a8a-9c1e-2645765c7321",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 7: Método para interpretar la salida\n",
    "\n",
    "    def interpret_output(\n",
    "        self,\n",
    "        out: Dict[str, torch.Tensor],\n",
    "        reduction: str = \"none\",\n",
    "        attention_prediction_horizon: int = 0, # Este parámetro ya no tiene efecto\n",
    "    ) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        Interpreta la salida del modelo (solo importancia de variables).\n",
    "        NOTA: La interpretación de la atención está desactivada debido al uso de \n",
    "        F.scaled_dot_product_attention para mejorar el rendimiento.\n",
    "        \"\"\"\n",
    "        # Vectorización del procesamiento de variables (lógica sin cambios)\n",
    "        encoder_variables = self.process_variables_importance(\n",
    "            out[\"encoder_variables\"].squeeze(-2).clone(),\n",
    "            out[\"encoder_lengths\"]\n",
    "        )\n",
    "        decoder_variables = self.process_variables_importance(\n",
    "            out[\"decoder_variables\"].squeeze(-2).clone(),\n",
    "            out[\"decoder_lengths\"]\n",
    "        )\n",
    "        static_variables = out[\"static_variables\"].squeeze(1)\n",
    "        \n",
    "        if reduction != \"none\":\n",
    "            static_variables = static_variables.sum(dim=0)\n",
    "            encoder_variables = encoder_variables.sum(dim=0)\n",
    "            decoder_variables = decoder_variables.sum(dim=0)\n",
    "\n",
    "        # La atención es None, se devuelve un tensor vacío\n",
    "        attention = torch.empty(0, device=self.device)\n",
    "\n",
    "        interpretation = dict(\n",
    "            attention=attention, # Vacío\n",
    "            static_variables=static_variables,\n",
    "            encoder_variables=encoder_variables,\n",
    "            decoder_variables=decoder_variables,\n",
    "            encoder_length_histogram=integer_histogram(out[\"encoder_lengths\"], min=0, max=self.hparams.max_encoder_length),\n",
    "            decoder_length_histogram=integer_histogram(out[\"decoder_lengths\"], min=1, max=out[\"decoder_variables\"].size(1)),\n",
    "        )\n",
    "        return interpretation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98475a16-f4dc-4e06-9f4f-914adff91129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bloque 8: Métodos para visualización y logging (atención desactivada)\n",
    "\n",
    "    def plot_prediction(self, x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], **kwargs):\n",
    "        \"\"\"\n",
    "        Grafica las predicciones vs. los valores reales.\n",
    "        NOTA: El ploteo de la atención ha sido desactivado.\n",
    "        \"\"\"\n",
    "        # La llamada original a plot_attention se omite\n",
    "        if 'plot_attention' in kwargs:\n",
    "            print(\"W: plot_attention no está disponible en la versión optimizada.\")\n",
    "            kwargs['plot_attention'] = False\n",
    "            \n",
    "        return super().plot_prediction(x, out, **kwargs)\n",
    "\n",
    "    def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]):\n",
    "        \"\"\"\n",
    "        Crea figuras que interpretan la importancia de las variables.\n",
    "        NOTA: La gráfica de atención ha sido desactivada.\n",
    "        \"\"\"\n",
    "        _check_matplotlib(\"plot_interpretation\")\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        figs = {}\n",
    "        \n",
    "        # La gráfica de atención se omite\n",
    "        # figs[\"attention\"] = ...\n",
    "        \n",
    "        def make_selection_plot(title, values, labels):\n",
    "            if len(labels) == 0:\n",
    "                return plt.figure()\n",
    "            fig, ax = plt.subplots(figsize=(7, max(1, len(values) * 0.25) + 2))\n",
    "            order = np.argsort(values.detach().cpu())\n",
    "            norm_values = values / values.sum(-1, keepdim=True).clamp_min(1e-9)\n",
    "            bar_values = norm_values[order] * 100\n",
    "            y_pos = np.arange(len(values))\n",
    "            tick_labels = np.asarray(labels)[order]\n",
    "            ax.barh(y_pos, bar_values.detach().cpu(), tick_label=tick_labels)\n",
    "            ax.set_title(title)\n",
    "            ax.set_xlabel(\"Importance in %\")\n",
    "            plt.tight_layout()\n",
    "            return fig\n",
    "\n",
    "        figs[\"static_variables\"] = make_selection_plot(\"Static variables importance\", interpretation[\"static_variables\"], self.static_variables)\n",
    "        figs[\"encoder_variables\"] = make_selection_plot(\"Encoder variables importance\", interpretation[\"encoder_variables\"], self.encoder_variables)\n",
    "        figs[\"decoder_variables\"] = make_selection_plot(\"Decoder variables importance\", interpretation[\"decoder_variables\"], self.decoder_variables)\n",
    "        return figs\n",
    "    \n",
    "    # Los métodos de logging y de ciclo de vida no requieren cambios significativos\n",
    "    def on_fit_end(self):\n",
    "        if self.log_interval > 0: self.log_embeddings()\n",
    "        \n",
    "    def create_log(self, x, y, out, batch_idx, **kwargs):\n",
    "        log = super().create_log(x, y, out, batch_idx, **kwargs)\n",
    "        if self.log_interval > 0: log[\"interpretation\"] = self.interpret_output(detach(out), reduction=\"sum\")\n",
    "        return log\n",
    "\n",
    "    def log_embeddings(self):\n",
    "        if not self._logger_supports(\"add_embedding\"): return\n",
    "        for name, emb in self.input_embeddings.items():\n",
    "            self.logger.experiment.add_embedding(emb.weight.data.detach().cpu(), metadata=self.hparams.embedding_labels.get(name, None), tag=name, global_step=self.global_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab44c44c-2978-4a4c-93eb-2802b34b7fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Bloque 9: Ejemplo de uso del modelo\n",
    "\n",
    "def create_example_model_and_data():\n",
    "    from pytorch_forecasting.data import TimeSeriesDataSet\n",
    "    from pytorch_forecasting.data.examples import generate_ar_data\n",
    "    \n",
    "    data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)\n",
    "    \n",
    "    training_cutoff = data[\"time_idx\"].max() - 50\n",
    "    training = TimeSeriesDataSet(\n",
    "        data[data[\"time_idx\"] <= training_cutoff],\n",
    "        time_idx=\"time_idx\",\n",
    "        target=\"value\",\n",
    "        group_ids=[\"series\"],\n",
    "        max_encoder_length=30,\n",
    "        max_prediction_length=10,\n",
    "        time_varying_known_reals=[\"time_idx\"],\n",
    "        time_varying_unknown_reals=[\"value\"],\n",
    "    )\n",
    "    \n",
    "    validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff + 1)\n",
    "    \n",
    "    batch_size = 128\n",
    "    train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)\n",
    "    val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)\n",
    "    \n",
    "    # Creamos el modelo con los nuevos parámetros de optimización\n",
    "    tft = TemporalFusionTransformer.from_dataset(\n",
    "        training,\n",
    "        learning_rate=0.03,\n",
    "        hidden_size=32,\n",
    "        attention_head_size=1,\n",
    "        dropout=0.1,\n",
    "        hidden_continuous_size=16,\n",
    "        loss=QuantileLoss(),\n",
    "        log_interval=10,\n",
    "        use_compile=True,  # Usar compilador de PyTorch 2.0+\n",
    "        gradient_checkpointing=False, # Activar para ahorrar memoria\n",
    "    )\n",
    "    \n",
    "    print(f\"Modelo TFT creado. Número de parámetros: {tft.size()/1e6:.2f}M\")\n",
    "    return tft, train_dataloader, val_dataloader\n",
    "\n",
    "# Ejemplo de entrenamiento (requiere pytorch-lightning)\n",
    "# try:\n",
    "#     import pytorch_lightning as pl\n",
    "#     tft, train_dataloader, val_dataloader = create_example_model_and_data()\n",
    "#     trainer = pl.Trainer(\n",
    "#         max_epochs=5, \n",
    "#         accelerator='gpu' if torch.cuda.is_available() else 'cpu', \n",
    "#         devices=1,\n",
    "#         gradient_clip_val=0.1\n",
    "#     )\n",
    "#     trainer.fit(tft, train_dataloader, val_dataloader)\n",
    "# except ImportError:\n",
    "#     print(\"Por favor, instala pytorch-lightning para ejecutar el entrenamiento: pip install pytorch-lightning\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
