### Load Model

In [None]:
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import MultiLoss, QuantileLoss, MAE
from pytorch_forecasting.data.encoders import MultiNormalizer, TorchNormalizer

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
# Saved state_dict of the wrapped TFT module; loaded further below.
modelPath = "mar3_model.pth"

# Compute device: CUDA when available, otherwise CPU. (Apple MPS is not
# considered; the Trainer below also only distinguishes "cuda" vs "cpu".)
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Let float32 matmuls use faster reduced-precision internal kernels when the
# hardware supports them (e.g. NVIDIA Tensor Cores).
torch.set_float32_matmul_precision("medium")
print("Using device:", DEVICE)

# -----------------------------
# 2. Data Loading
# -----------------------------
def load_data(folder):
    """Load every ``*.csv`` file in *folder* into one long-format DataFrame.

    Each file is treated as one time series (one ``group``):

    * the ``date`` column is dropped (``KeyError`` if a file lacks it);
    * ``time_idx`` is a 0-based row counter that restarts for every file;
    * ``group`` is the file name up to its first ``.`` (``"AAPL.csv"`` -> ``"AAPL"``);
    * ``Close`` / ``vclose`` are renamed to ``target_1`` / ``target_2``.

    Files are read in sorted name order so the concatenated frame is the same
    on every machine (``os.listdir`` returns entries in arbitrary order).

    Args:
        folder: Directory containing the per-symbol CSV files.

    Returns:
        The concatenated DataFrame with a fresh ``RangeIndex``.

    Raises:
        ValueError: if *folder* contains no ``.csv`` files.
    """
    csv_names = sorted(name for name in os.listdir(folder) if name.endswith(".csv"))
    if not csv_names:
        # pd.concat([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"No CSV files found in {folder!r}")

    frames = []
    for name in csv_names:
        df = pd.read_csv(os.path.join(folder, name))
        df = df.drop(columns="date")
        df["time_idx"] = range(len(df))
        df["group"] = name.split(".")[0]
        df = df.rename(columns={"Close": "target_1", "vclose": "target_2"})
        frames.append(df)
    return pd.concat(frames, ignore_index=True)

# Each folder holds one CSV per symbol; load_data concatenates them into a
# single long frame with "group" and "time_idx" columns.
train_data_folder = "data/train"
test_data_folder = "data/test"
oos_data_folder = "data/oos"

train_df, test_df, oos_df = (
    load_data(split_folder)
    for split_folder in (train_data_folder, test_data_folder, oos_data_folder)
)

# -----------------------------
# 3. Multi-Target Dataset
# -----------------------------
# Every column except the id / time-index columns is treated as an observed
# (unknown-in-future) real -- this includes target_1 / target_2 themselves, so
# the encoder sees their past values.
_unknown_reals = [col for col in train_df.columns if col not in ("group", "time_idx")]

training = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target=["target_1", "target_2"],
    group_ids=["group"],
    static_categoricals=["group"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=_unknown_reals,
    max_encoder_length=90,
    max_prediction_length=5,
    # Identity normalizers: the targets are not rescaled by the dataset.
    target_normalizer=MultiNormalizer(
        [TorchNormalizer(method="identity") for _ in range(2)]
    ),
)

# Reuse the training dataset's configuration for the other splits.
# predict_mode=False keeps every window together with its target values.
validation = TimeSeriesDataSet.from_dataset(training, test_df, predict_mode=False)
oos = TimeSeriesDataSet.from_dataset(training, oos_df, predict_mode=False)

# -----------------------------
# 4. DataLoaders
# -----------------------------
# Worker / memory settings shared by all three loaders.
# NOTE(review): process_symbol below builds its own per-symbol loaders and
# does not use these.
_loader_opts = {"num_workers": 16, "pin_memory": False}

train_dataloader = training.to_dataloader(
    train=True, batch_size=64, shuffle=True, **_loader_opts
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=64, shuffle=False, **_loader_opts
)
oos_dataloader = oos.to_dataloader(
    train=False, batch_size=16, shuffle=False, **_loader_opts
)

# -----------------------------
# 5. Define Multi-Target TFT Model w/ MSE Loss
# -----------------------------
import torch.nn as nn

# The architecture settings (sizes, layer/head counts, output_size) determine
# the parameter shapes, so they must match the model saved at `modelPath`;
# otherwise the state_dict loaded below will not fit.
_tft_kwargs = dict(
    learning_rate=4.171863120490385e-05,
    lstm_layers=2,
    hidden_size=256,
    attention_head_size=4,
    dropout=0.35,
    hidden_continuous_size=256,
    # One point forecast per target, each trained with its own MSE loss.
    output_size=[1, 1],
    loss=MultiLoss([nn.MSELoss() for _ in range(2)]),
    log_interval=200,
    reduce_on_plateau_patience=5,
)
tft = TemporalFusionTransformer.from_dataset(training, **_tft_kwargs).to(DEVICE)

print(f"Number of params in network: {tft.size() / 1e3:.1f}k")

# -----------------------------
# 6. LightningModule
# -----------------------------
class TFTLightningModule(LightningModule):
    """LightningModule wrapper that delegates to a pytorch-forecasting TFT.

    The wrapped network lives in ``self.tft_model``. That attribute name is
    the prefix of every state_dict key (``tft_model.*``), so renaming it would
    break loading previously saved weights.
    """

    def __init__(self, tft_model):
        super().__init__()
        self.tft_model = tft_model

    def forward(self, x):
        """Run the wrapped TFT on the input dict ``x`` and return its output."""
        return self.tft_model(x)

    def _loss_and_log(self, batch, metric_name):
        """Compute the wrapped model's loss on ``batch`` and log it as ``metric_name``."""
        inputs, targets = batch
        # targets is a tuple: ([target0_tensor, target1_tensor], None)
        predictions = self(inputs)["prediction"]  # list: one tensor per target
        loss = self.tft_model.loss(predictions, targets)
        self.log(metric_name, loss, prog_bar=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._loss_and_log(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._loss_and_log(batch, "val_loss")

    def configure_optimizers(self):
        # Reuse the optimizer / LR-scheduler setup of the wrapped model.
        return self.tft_model.configure_optimizers()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # (x, y) batches: only x is fed to the model.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(x)

tft_module = TFTLightningModule(tft).to(DEVICE)

# Optional: training setup. Only needed for (re-)training; process_symbol below
# does not use the trainer.
early_stop_callback = EarlyStopping(monitor="val_loss", patience=7, mode="min")
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="my-tft-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

trainer = Trainer(
    max_epochs=35,   # set higher for real training
    accelerator="gpu" if DEVICE == "cuda" else "cpu",
    devices=1,
    precision=32,
    logger=CSVLogger("logs", name="tft_multi_target_quantile"),
    callbacks=[early_stop_callback, checkpoint_callback],
)

# --- Load the saved state dict ---
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load. A plain
# `module.state_dict()` checkpoint loads unchanged under this mode.
state_dict = torch.load(modelPath, map_location=DEVICE, weights_only=True)
# Strict (default) loading: keys and shapes must match the module built above.
# The parameters already live on DEVICE, and load_state_dict copies into them
# in place, so no extra .to(DEVICE) is needed.
tft_module.load_state_dict(state_dict)
tft_module.eval()  # evaluation mode: disables dropout for inference


  from tqdm.autonotebook import tqdm


Using device: cpu


/Users/alecjeffery/Documents/Playgrounds/Python/tft_v1/env/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/Users/alecjeffery/Documents/Playgrounds/Python/tft_v1/env/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of params in network: 32358.6k


/Users/alecjeffery/Documents/Playgrounds/Python/tft_v1/env/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


TFTLightningModule(
  (tft_model): TemporalFusionTransformer(
    	"attention_head_size":               4
    	"categorical_groups":                {}
    	"causal_attention":                  True
    	"dataset_parameters":                {'time_idx': 'time_idx', 'target': ['target_1', 'target_2'], 'group_ids': ['group'], 'weight': None, 'max_encoder_length': 90, 'min_encoder_length': 90, 'min_prediction_idx': 0, 'min_prediction_length': 5, 'max_prediction_length': 5, 'static_categoricals': ['group'], 'static_reals': None, 'time_varying_known_categoricals': None, 'time_varying_known_reals': ['time_idx'], 'time_varying_unknown_categoricals': None, 'time_varying_unknown_reals': ['open', 'high', 'low', 'target_1', 'target_2', 'vopen', 'vhigh', 'vlow', 'VIX', 'SPY', 'TNX', 'rsi14', 'rsi9', 'rsi24', 'MACD5355macddiff', 'MACD5355macddiffslope', 'MACD5355macd', 'MACD5355macdslope', 'MACD5355macdsig', 'MACD5355macdsigslope', 'MACD12269macddiff', 'MACD12269macddiffslope', 'MACD12269macd', 'MAC

## Predict with the loaded model and rank symbols by backtest MSE on the out-of-sample (OOS) set

In [3]:
# Per-ticker table; process_symbol reads its 'ticker', 'closemu' and 'closesig' columns.
_MU_SIG_CSV = './data/TixMuSig.csv'
# Destination of the ranked-prediction CSV written at the end of this cell.
saveAs = 'RankedPreds_3-02_newModel.csv'
mu_sig_df = pd.read_csv(_MU_SIG_CSV)

def move_to_device(batch_x, device):
    """Recursively move every tensor inside ``batch_x`` to ``device``.

    Dicts, lists and tuples (including named tuples) are walked recursively.
    Lists and tuples keep their container type, while dicts are rebuilt as plain
    ``dict``. Any other object is returned unchanged.

    Args:
        batch_x: A tensor or a (nested) dict / list / tuple containing tensors.
        device: Target passed to ``torch.Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        A structure mirroring ``batch_x`` with all tensors moved.
    """
    if isinstance(batch_x, torch.Tensor):
        return batch_x.to(device)
    if isinstance(batch_x, dict):
        return {k: move_to_device(v, device) for k, v in batch_x.items()}
    if isinstance(batch_x, list):
        return [move_to_device(item, device) for item in batch_x]
    if isinstance(batch_x, tuple):
        # Tuples must be walked too; otherwise their tensors would silently
        # stay on the original device.
        moved = [move_to_device(item, device) for item in batch_x]
        if hasattr(batch_x, "_fields"):  # namedtuple: rebuild positionally
            return type(batch_x)(*moved)
        return tuple(moved)
    return batch_x

def _single_batch_predictions(dataset):
    """Run ``tft_module`` once over every sample of ``dataset``.

    All samples are collated into one batch (``batch_size=len(dataset)``).

    Returns:
        ``(preds, y)``. ``preds`` is a CPU tensor of target_1 point forecasts
        with shape ``(n_samples, prediction_length)``. ``y`` is the target part
        of the batch exactly as yielded by the dataloader.
    """
    loader = dataset.to_dataloader(
        train=False,
        batch_size=len(dataset),  # all samples in one batch
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )
    x, y = next(iter(loader))
    with torch.no_grad():
        out = tft_module(move_to_device(x, DEVICE))
    # prediction[0] is target_1; output_size is 1, so the last axis holds the
    # single point forecast.
    return out["prediction"][0][:, :, 0].cpu(), y


def process_symbol(symbol):
    """Forecast ``target_1`` for one symbol and score it on its OOS history.

    Steps:

    1. Select the rows of the global ``oos_df`` whose ``group`` equals
       ``symbol``. Return ``None`` if there are none.
    2. Require a ``closemu`` / ``closesig`` row for ``symbol`` in
       ``mu_sig_df``. Return ``None`` if the row or columns are missing.
       NOTE(review): the values are looked up but not applied, so predictions
       and MSE stay in the model's own target units. Confirm whether
       ``pred * sig + mu`` de-normalisation was intended.
    3. "Current" forecast from the single ``predict_mode=True`` window.
       NOTE(review): per the TimeSeriesDataSet docs, that window's decoder
       covers the *last* ``max_prediction_length`` rows already present in the
       data. It is not a forecast beyond the data unless future placeholder
       rows are appended. Confirm this is intended.
    4. Backtest over every ``predict_mode=False`` window. The MSE is averaged
       over all windows and all horizon steps, so it reflects the whole OOS
       span rather than a single window.

    Args:
        symbol: Group id as produced by ``load_data`` (file name up to the first ".").

    Returns:
        A dict with keys ``Symbol``, ``MSE``, ``Pred1`` .. ``PredN`` and
        ``Delta``, where ``N`` is the prediction length and
        ``Delta = PredN - Pred1``. Returns ``None`` if the symbol cannot be
        processed.
    """
    stock_oos_df = oos_df[oos_df["group"] == symbol]
    if stock_oos_df.empty:
        print(f"No OOS data for {symbol}.")
        return None

    # Missing ticker -> IndexError from .iloc[0]; missing column -> KeyError.
    try:
        mu_p, sig_p = mu_sig_df.loc[
            mu_sig_df["ticker"] == symbol, ["closemu", "closesig"]
        ].iloc[0]
    except (KeyError, IndexError) as e:
        print(f"Error retrieving mu/sig for {symbol}: {e}")
        return None

    current_ds = TimeSeriesDataSet.from_dataset(training, stock_oos_df, predict_mode=True)
    backtest_ds = TimeSeriesDataSet.from_dataset(training, stock_oos_df, predict_mode=False)
    if len(current_ds) == 0 or len(backtest_ds) == 0:
        # A DataLoader cannot be built with batch_size=0.
        print(f"Not enough OOS history for {symbol}.")
        return None

    # --- Current predictions (one window per symbol) ---
    current_preds, _ = _single_batch_predictions(current_ds)
    horizon = current_preds[0].tolist()  # Pred1 .. PredN as Python floats

    # --- Backtest MSE over all windows and horizon steps ---
    backtest_preds, (y_targets, _weight) = _single_batch_predictions(backtest_ds)
    actual_future = y_targets[0]  # target_1, shape (n_windows, prediction_length)
    mse_value = torch.mean((backtest_preds - actual_future) ** 2).item()

    result = {"Symbol": symbol, "MSE": mse_value}
    result.update({f"Pred{step}": value for step, value in enumerate(horizon, start=1)})
    result["Delta"] = horizon[-1] - horizon[0]
    return result

# Loop over each CSV file in "./data/oos" and collect one result row per symbol.
results_list = []
oos_folder = "./data/oos"
oos_files = [f for f in os.listdir(oos_folder) if f.endswith('.csv')]
for file in oos_files:
    # Derive the symbol exactly like load_data derives the "group" id (file
    # name up to the first "."), so the oos_df lookup in process_symbol
    # matches. Upper-casing here would silently skip lower-case file names.
    symbol = file.split('.')[0]
    print(f"Processing {symbol} ...")
    res = process_symbol(symbol)
    if res is not None:
        results_list.append(res)

results_df = pd.DataFrame(results_list)
print("\nResults DataFrame:")
print(results_df)

if results_df.empty:
    # No symbol produced a result, so there is no "MSE" column to sort by.
    results_df_sorted = results_df
else:
    results_df_sorted = results_df.sort_values(by="MSE")
print("\nRanked by MSE:")
print(results_df_sorted)
results_df_sorted.to_csv(saveAs, index=False)


Processing CSCO ...
Processing ISRG ...
Processing BA ...
Processing VRTX ...
Processing GILD ...
Processing EQIX ...
Processing MDT ...
Processing V ...
Processing MO ...
Processing CDNS ...
Processing HCA ...
Processing AJG ...
Processing C ...
Processing T ...
Processing APH ...
Processing MSI ...
Processing FCX ...
Processing BAC ...
Processing PSX ...
Processing ADI ...
Processing ADBE ...
Processing CPRT ...
Processing TDG ...
Processing SYK ...
Processing CB ...
Processing NOW ...
Processing LLY ...
Processing COST ...
Processing LOW ...
Processing MDLZ ...
Processing BKNG ...
Processing MET ...
Processing DLR ...
Processing TJX ...
Processing MPC ...
Processing D ...
Processing MRK ...
Processing NOC ...
Processing UNP ...
Processing ABBV ...
Processing ORCL ...
Processing ECL ...
Processing SBUX ...
Processing AMT ...
Processing INTU ...
Processing PG ...
Processing CAT ...
Processing MCD ...
Processing AMZN ...
Processing INTC ...
Processing BDX ...
Processing KMI ...
Process