In [None]:
# ==============================================================================
# 1. SETUP & CONFIGURATION
# ==============================================================================
# Install the required libraries (only needed on the first run)
# !pip install pytorch-forecasting pytorch-lightning polars --quiet

import pandas as pd
import polars as pl
import numpy as np
import os
from datetime import timedelta

# Pytorch
import torch
import torch.nn as nn
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

# Pytorch Forecasting
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# Fix the global random seed so that runs are reproducible.
SEED = 42
seed_everything(SEED, workers=True)

# Input / output locations.
# The defaults are the original Kaggle paths; set the TFT_DATA_PATH and/or
# TFT_CHECKPOINT_PATH environment variables to run the notebook elsewhere
# without editing it.
DATA_PATH = os.environ.get(
    "TFT_DATA_PATH",
    "/kaggle/input/financial-data-ohlcv-global/dataset_final_kaggle.parquet",
)
CHECKPOINT_PATH = os.environ.get("TFT_CHECKPOINT_PATH", "/kaggle/working/tft_v1")

# Model / training hyper-parameters
MAX_ENCODER_LENGTH = 30  # look back over the past 30 days
MAX_PREDICTION_LENGTH = 5  # forecast the next 5 days
BATCH_SIZE = 128
MAX_EPOCHS = 50
LEARNING_RATE = 1e-3

In [None]:
# ==============================================================================
# 2. LOAD DATA & PREPROCESSING
# ==============================================================================

print("1. ƒêang t·∫£i v√† ti·ªÅn x·ª≠ l√Ω d·ªØ li·ªáu...")
df_pl = pl.read_parquet(DATA_PATH)
df = df_pl.to_pandas()

# Order every symbol chronologically before any further processing.
df = df.sort_values(by=['symbol', 'ts']).reset_index(drop=True)

# Encode the categorical columns.
# pytorch-forecasting refuses numeric columns listed as categoricals
# ("Data type of category ... was found to be numeric - use a string type"),
# so the integer category codes are stored as strings.
df['symbol_id'] = df['symbol'].astype('category').cat.codes.astype(str)
df['asset_type'] = df['asset_type'].astype('category').cat.codes.astype(str)

# Keep only the symbols with the most rows (i.e. the longest histories).
# Note: this is a "top N by row count" filter, not a minimum-length filter.
n_symbols_to_keep = 3000
symbols_to_keep = df.groupby('symbol').size().nlargest(n_symbols_to_keep).index
df = df[df['symbol'].isin(symbols_to_keep)].copy()

print(f"   -> K√≠ch th∆∞·ªõc t·∫≠p d·ªØ li·ªáu cu·ªëi c√πng: {df.shape[0]:,} d√≤ng ({len(symbols_to_keep)} m√£)")

# Time-based train/validation split: the last 3 * MAX_PREDICTION_LENGTH time
# steps (15 with the current settings) are held out for validation.
training_cutoff = df["time_idx"].max() - MAX_PREDICTION_LENGTH * 3 
print(f"   -> C·∫Øt d·ªØ li·ªáu t·∫°i time_idx: {training_cutoff}")

In [None]:
# ==============================================================================
# 3. DEFINE THE TIMESERIESDATASET
# ==============================================================================

print("2. ƒêang c·∫•u h√¨nh TimeSeriesDataSet...")

training = TimeSeriesDataSet(
    df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="log_return",  # target: forecast the (log) return
    group_ids=["symbol_id"],  # one series per symbol (multi-asset learning)

    max_encoder_length=MAX_ENCODER_LENGTH,
    max_prediction_length=MAX_PREDICTION_LENGTH,

    # 3a. STATIC: constant within a series
    static_categoricals=["symbol_id", "asset_type"],

    # 3b. KNOWN REALS: values that are known for future time steps
    time_varying_known_reals=[
        "time_idx", "day_sin", "day_cos", "month_sin", "month_cos"
    ],

    # 3c. UNKNOWN REALS: only observed in the past
    time_varying_unknown_reals=[
        "log_return", "vol_relative", "bb_width", "roc_10", "macd_proxy",  # indicators
        "ctx_sp500_ret", "ctx_gold_ret", "ctx_oil_ret", "ctx_forex_ret",   # global context
        "ctx_sp500_vol", "ctx_gold_vol", "ctx_oil_vol", "ctx_forex_vol",   # global volume
    ],

    # Normalize the target separately for every symbol.
    # Centering is left at its default (True): with center=False the
    # normalizer divides by the per-group *mean*, which for log returns is
    # close to zero (and may be negative), producing unstable scaled targets.
    target_normalizer=GroupNormalizer(groups=["symbol_id"], scale_by_group=True),
    add_target_scales=True,
    add_encoder_length=True,
)

# Validation dataset: same configuration and fitted encoders/normalizers as
# `training`, restricted to windows whose first predicted step lies after the
# cutoff. Extra keyword arguments are forwarded to the TimeSeriesDataSet
# constructor, so only genuine constructor parameters may be passed here.
validation = TimeSeriesDataSet.from_dataset(
    training,
    df,
    min_prediction_idx=training_cutoff + 1,
    stop_randomization=True,
)

# DataLoaders (worker processes speed up batch preparation)
train_dataloader = training.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=4)
val_dataloader = validation.to_dataloader(train=False, batch_size=BATCH_SIZE, num_workers=4)

In [None]:
# ==============================================================================
# 4. INITIALIZE AND TRAIN THE TFT MODEL
# ==============================================================================

print("3. ƒêang c·∫•u h√¨nh TFT Model v√† Callbacks...")

# Callbacks:
# 1. Early stopping: stop when val_loss has not improved for 7 epochs.
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=7, verbose=False, mode="min")
# 2. Model checkpoint: keep only the best model (lowest val_loss).
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=CHECKPOINT_PATH,
    filename="best_tft_v1-{epoch:02d}-{val_loss:.4f}",
    save_top_k=1,
    mode="min",
    verbose=True
)

# The network needs one output per predicted quantile, so derive the output
# size from the loss itself instead of repeating the number by hand.
quantile_loss = QuantileLoss()

# Build the TFT model from the training dataset definition.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LEARNING_RATE,
    hidden_size=64,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=64,
    output_size=len(quantile_loss.quantiles),  # 7 with the default QuantileLoss
    loss=quantile_loss,
    log_interval=10,
    optimizer="adam",
    reduce_on_plateau_patience=4,
)

# Train on the GPU when one is available, otherwise fall back to the CPU
# instead of failing at Trainer construction.
# NOTE(review): recent pytorch-forecasting releases are built on
# `lightning.pytorch`; if the installed version is one of those, Trainer and
# the callbacks must be imported from `lightning.pytorch` rather than
# `pytorch_lightning` - confirm against the installed versions.
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, checkpoint_callback, LearningRateMonitor()],
)

print("\nüöÄ B·∫ÆT ƒê·∫¶U TRAINING BASE TFT (Version 1.0)...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Reload the best checkpoint (lowest val_loss) written by the checkpoint callback.
best_model_path = checkpoint_callback.best_model_path
if not best_model_path:
    # Happens when training stopped before any validation epoch finished.
    raise RuntimeError("No checkpoint was saved during training; cannot load the best model.")
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
print(f"‚úÖ Training ho√†n th√†nh! Model t·ªët nh·∫•t ƒë∆∞·ª£c l∆∞u t·∫°i: {best_model_path}")

In [None]:
# ==============================================================================
# 5. FORECAST & VISUALIZATION (after training)
# ==============================================================================

print("\n4. ƒêang ki·ªÉm tra d·ª± b√°o tr√™n t·∫≠p Validation...")

# One forecast needs a full window per symbol: MAX_ENCODER_LENGTH steps ending
# at the training cutoff plus the MAX_PREDICTION_LENGTH steps right after it
# (the latter supply the known covariates and the observed values to plot).
window_start = training_cutoff - MAX_ENCODER_LENGTH + 1
window_end = training_cutoff + MAX_PREDICTION_LENGTH
forecast_window = df[df["time_idx"].between(window_start, window_end)]

# Only symbols with a complete window can be forecast.
# (`validation.data` is a dict of tensors, so candidates are taken from `df`.)
rows_per_symbol = forecast_window.groupby("symbol").size()
full_window_length = MAX_ENCODER_LENGTH + MAX_PREDICTION_LENGTH
target_symbols = rows_per_symbol[rows_per_symbol == full_window_length].index.to_numpy()

# Pick up to 3 random symbols to plot (fewer if not enough are available).
n_plots = min(3, len(target_symbols))
symbols_to_plot = np.random.choice(target_symbols, n_plots, replace=False)

for sym in symbols_to_plot:
    # Encoder + decoder rows for this symbol only.
    symbol_window = forecast_window[forecast_window["symbol"] == sym]

    # Passing a DataFrame makes `predict` build a prediction dataset from the
    # model's stored dataset parameters; the last MAX_PREDICTION_LENGTH steps
    # of the window are the ones being forecast.
    raw_predictions = best_tft.predict(symbol_window, mode="raw", return_x=True)

    # Plot with pytorch-forecasting's built-in helper: network input first,
    # raw network output second, and the sample index within the batch.
    fig = best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=0,
        add_loss_to_title=True,
        show_future_observed=True,
    )
    # The axes title carries the loss; put the symbol in the figure title.
    fig.suptitle(f"D·ª± b√°o 5 ng√†y cho M√£: {sym}")
    fig.show()  # on Kaggle the figure is rendered right below this cell