In [2]:
import pickle
import pathlib, sys
import warnings
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder

# Suppress every warning category for the rest of the kernel session.
# NOTE(review): this is a blanket filter; it also hides deprecation notices.
warnings.filterwarnings("ignore")

# ----------------------------------------------------------
# Project root and configuration
# ----------------------------------------------------------
# The directory two levels above the current working directory is prepended to
# sys.path (once) so that the `src` package can be imported below.
ROOT = pathlib.Path().resolve().parent.parent
root_entry = str(ROOT)
if root_entry not in sys.path:
    sys.path.insert(0, root_entry)

# Must stay after the sys.path tweak above.
from src import config as cfg

# ----------------------------------------------------------
# Load the pickled dataset
# ----------------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project.
tft_data_path = cfg.DATA / "processed" / "tft_data.pkl"
with open(tft_data_path, "rb") as pkl_file:
    dataset = pickle.load(pkl_file)

print("✅ Dataset cargado con éxito")
# NOTE(review): presumably a pytorch_forecasting TimeSeriesDataSet, in which
# case len(dataset) likely counts samples (index rows) rather than distinct
# series — confirm that the printed label is accurate.
print(f"🔍 N.º series: {len(dataset)}")


✅ Dataset cargado con éxito
🔍 N.º series: 18577


In [4]:
from torch.utils.data import DataLoader

# ----------------------------------------------------------
# Split the dataset into training and validation parts
# ----------------------------------------------------------
print("📋 decoded_index columns:", dataset.decoded_index.columns)

# Samples whose `time_idx_last` lies within the final
# 5 * max_prediction_length steps are held out for validation.
last_time_idx = dataset.decoded_index["time_idx_last"].max()
training_cutoff = last_time_idx - 5 * dataset.max_prediction_length

train_dataset = dataset.filter(
    lambda index: index["time_idx_last"] <= training_cutoff
)
val_dataset = dataset.filter(
    lambda index: index["time_idx_last"] > training_cutoff
)

# ----------------------------------------------------------
# Build the DataLoaders
# ----------------------------------------------------------
batch_size = 8     # lowered from 64
num_workers = 0    # avoid spawning extra worker processes

train_dataloader = train_dataset.to_dataloader(
    train=True,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Validation uses twice the training batch size (16).
val_batch_size = batch_size * 2
val_dataloader = val_dataset.to_dataloader(
    train=False,
    batch_size=val_batch_size,
    num_workers=num_workers,
)

# NOTE(review): the recorded run reports only 2 validation batches, i.e. at
# most 32 validation samples — check that the cutoff leaves a large enough
# hold-out set for a stable val_loss.
print(f"✅ DataLoaders listos")
print(f"🔹 Train batches: {len(train_dataloader)}")
print(f"🔹 Val batches:   {len(val_dataloader)}")


📋 decoded_index columns: Index(['group_id', 'time_idx_first', 'time_idx_last',
       'time_idx_first_prediction'],
      dtype='object')
✅ DataLoaders listos
🔹 Train batches: 2319
🔹 Val batches:   2


In [None]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE

# 0 ───────────────────────────  Reproducibility
# Seed all RNGs *before* the model is created. The network's initial weights
# are drawn when the model is instantiated, so seeding only right before
# `trainer.fit` (as this cell did previously) left the starting point
# different on every run.
pl.seed_everything(42, workers=True)

# 1 ───────────────────────────  TFT model
# Architecture hyper-parameters are derived from `train_dataset`; the rest are
# set explicitly here.
tft = TemporalFusionTransformer.from_dataset(
    train_dataset,
    learning_rate=1e-3,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=RMSE(),
    # Figure/interpretation logging is disabled because the Trainer below runs
    # with logger=False: there is nowhere to send the figures.
    # NOTE(review): some pytorch-forecasting versions raise when
    # log_interval > 0 and no logger is attached — confirm for the installed
    # version before re-enabling.
    log_interval=0,
    reduce_on_plateau_patience=4
)

# 2 ───────────────────────────  Callbacks
# Keep only the single best checkpoint (lowest val_loss) in cfg.MODELS.
ckpt_cb = ModelCheckpoint(
    dirpath=cfg.MODELS,
    filename="tft_base-{epoch:02d}-{val_loss:.4f}",
    monitor="val_loss",
    save_top_k=1,
    mode="min"
)

# Stop when val_loss has not improved by at least 1e-4 for 5 validation checks.
es_cb = EarlyStopping(
    monitor="val_loss",
    patience=5,
    min_delta=1e-4,
    mode="min"
)

# 3 ───────────────────────────  Trainer
trainer = pl.Trainer(
    max_epochs=30,
    gradient_clip_val=0.1,
    callbacks=[es_cb, ckpt_cb],
    accelerator="cpu", devices=1,
    enable_progress_bar=False,      # no progress bar
    enable_model_summary=False,     # no model summary table
    logger=False,                   # no TensorBoard logger
    limit_train_batches=1.0,
    limit_val_batches=1.0
)

# 4 ───────────────────────────  Training
trainer.fit(
    model=tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

print("🏁  Entrenamiento acabado")
print("📦  Mejor checkpoint:", ckpt_cb.best_model_path)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
Seed set to 42
