In [1]:
%matplotlib inline


# 1) Wipe out your namespace
%reset -f

# 2) Clear Jupyter’s stored outputs (and inputs if you like)
try:
    Out.clear()
except NameError:
    pass

try:
    In.clear()
except NameError:
    pass

# 3) Force Python GC
import gc
gc.collect()

# 4) Free any GPU buffers
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()


import importlib
from libs import params, trades, feats, plots, models_core
from libs.models import dual_lstm
importlib.reload(params)
importlib.reload(trades)
importlib.reload(feats)
importlib.reload(plots)
importlib.reload(models_core)
importlib.reload(dual_lstm)

<module 'libs.models.dual_lstm' from '/workspace/my_models/Trading/_Stock_Analysis_/libs/models/dual_lstm.py'>

In [2]:
# --- tabular data -------------------------------------------------------------
import pandas as pd
pd.set_option('display.max_columns', None)  # None = no column limit when a DataFrame is displayed

# --- numerics & plotting ------------------------------------------------------
import numpy  as np
import math
import matplotlib.pyplot as plt

# --- standard-library helpers -------------------------------------------------
import datetime as dt
import os
from typing import Sequence, List, Tuple, Optional, Union

# --- PyTorch building blocks and the LR-finder add-on -------------------------
import torch.nn as nn
import torch.nn.functional as Funct
from torch_lr_finder import LRFinder
from torch.utils.data import DataLoader, TensorDataset

# --- progress bars ------------------------------------------------------------
from tqdm import tqdm

In [3]:
# Columns kept from the feature file, in this order:
# the per-tick feature columns, then 'bid' and 'ask', then the label column.
selected_columns = params.features_cols_tick + ['bid', 'ask'] + [params.label_col]

# First CSV column becomes the index and is parsed as dates.
df_feat_sel = pd.read_csv(
    params.feat_all_csv,
    index_col=0,
    parse_dates=True,
)[selected_columns]

df_feat_sel

Unnamed: 0,rsi_14,macd_line_12_26_9,macd_signal_12_26_9,macd_diff_12_26_9,sma_20,sma_100,atr_14,bb_lband_20,bb_hband_20,bb_width_20,plus_di_14,minus_di_14,adx_14,obv,obv_sma_14,vwap_20,vol_spike_14,vwap_dev_20,ema_7,sma_7,sma_15,sma_30,macd_diff_7_15_3,atr_15,atr_30,bb_lband_15,bb_hband_15,bb_width_15,rsi_15,stoch_k_15,stoch_d_3,plus_di_15,minus_di_15,adx_15,obv_sma_15,vwap_dev_15,vol_spike_15,r_1,r_15,r_30,vol_15,eng_ma,eng_macd,eng_bb,eng_rsi,eng_adx,eng_obv,eng_atr_div,open,high,low,close,volume,hour,day_of_week,month,bid,ask,signal
2004-01-02 13:09:00,0.00000,0.00000,1.00000,0.19934,0.00000,0.00000,0.00000,0.00000,1.00000,0.18696,0.111549,0.642061,0.625459,0.104721,0.474260,0.515377,0.528843,0.500000,0.6,0.018115,1.0,0.500000,0.666667,0.315907,0.509338,0.467641,0.484536,0.000087,0.000052,0.000000,0.000139,0.000052,0.000000,0.322382,0.322619,0.000094,0.000104,0.000122,0.000104,0.000087,0.517986,0.000000,0.000000,0.000156,0.000069,0.000000,0.322633,0.000279,0.764235,0.764235,0.764235,0.764235,48081.25,0.496,-0.976,0.696,0.763664,0.764807,0.021536
2004-01-02 13:10:00,0.00000,0.00000,1.00000,0.25653,0.00000,0.00000,0.00000,0.00000,1.00000,0.24117,0.119645,0.641972,0.625379,0.112746,0.474260,0.515304,0.528775,0.500000,0.6,0.016827,1.0,0.464375,0.666667,0.315907,0.509338,0.467641,0.484536,0.000087,0.000052,0.000000,0.000139,0.000052,0.000000,0.322367,0.322621,0.000093,0.000104,0.000122,0.000104,0.000087,0.517986,0.000000,0.000000,0.000156,0.000069,0.000000,0.322635,0.000310,0.764219,0.764219,0.764219,0.764219,54775.00,0.496,-0.976,0.696,0.763647,0.764790,0.023284
2004-01-02 13:11:00,0.00000,0.00000,1.00000,0.30964,0.00000,0.00000,0.00000,0.00000,1.00000,0.29175,0.125161,0.641894,0.625318,0.118430,0.474260,0.515231,0.528707,0.500000,0.6,0.015539,1.0,0.431250,0.666667,0.315907,0.509338,0.467641,0.484536,0.000087,0.000052,0.000000,0.000139,0.000052,0.000000,0.322351,0.322621,0.000093,0.000104,0.000122,0.000104,0.000087,0.517986,0.000000,0.000000,0.000156,0.000069,0.000000,0.322636,0.000341,0.764202,0.764202,0.764202,0.764202,61468.75,0.496,-0.976,0.696,0.763631,0.764773,0.025175
2004-01-02 13:12:00,0.00000,0.00000,1.00000,0.35895,0.00000,0.00000,0.00000,0.00000,1.00000,0.33897,0.128314,0.641835,0.625263,0.121907,0.474260,0.515161,0.528641,0.500000,0.6,0.014251,1.0,0.400625,0.666667,0.315907,0.509338,0.467641,0.484536,0.000087,0.000052,0.000000,0.000139,0.000052,0.000000,0.322334,0.322619,0.000093,0.000104,0.000122,0.000104,0.000087,0.517986,0.000000,0.000000,0.000156,0.000069,0.000000,0.322635,0.000341,0.764185,0.764185,0.764185,0.764185,68162.50,0.496,-0.976,0.696,0.763614,0.764757,0.027222
2004-01-02 13:13:00,0.00000,0.00000,1.00000,0.40474,0.00000,0.00000,0.00000,0.00000,1.00000,0.38304,0.129531,0.641781,0.625227,0.123579,0.474260,0.515088,0.528573,0.500000,0.6,0.012964,1.0,0.371875,0.666667,0.315907,0.509338,0.467641,0.484536,0.000087,0.000052,0.000000,0.000139,0.000052,0.000000,0.322314,0.322616,0.000093,0.000104,0.000122,0.000104,0.000087,0.517986,0.000000,0.000000,0.000156,0.000069,0.000000,0.322632,0.000341,0.764169,0.764169,0.764169,0.764169,74856.25,0.496,-0.976,0.696,0.763597,0.764740,0.029436
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2025-06-18 20:56:00,0.67217,0.36312,0.07485,0.30194,0.66709,0.73661,0.62549,0.35895,0.07958,0.28287,0.117495,0.652911,0.635234,0.113816,0.478595,0.527683,0.541606,0.515625,0.6,0.000017,0.0,0.518125,0.666667,0.327030,0.808149,0.749478,0.625430,3.399805,3.398280,0.534043,3.397055,3.399910,0.024691,0.321726,0.321693,3.402269,3.403066,3.402026,3.399395,3.398880,0.661871,0.480545,0.270370,3.393991,3.401916,0.023121,0.321703,0.024910,196.680000,196.860000,196.630000,196.815000,385695.00,-0.976,0.039,-0.861,196.667400,196.962600,0.684908
2025-06-18 20:57:00,0.61669,0.34721,0.06704,0.32868,0.61512,0.58962,0.65041,0.34426,0.07170,0.30770,0.126236,0.647611,0.629672,0.123111,0.469708,0.525829,0.542649,0.515625,0.6,0.000000,0.0,0.522500,0.666667,0.327401,0.813243,0.768267,0.608247,3.400205,3.398384,0.553191,3.397020,3.400725,0.024691,0.321606,0.321710,3.402877,3.403535,3.402200,3.400090,3.399384,0.568345,0.498054,0.277778,3.394790,3.402523,0.023121,0.321720,0.025902,196.810000,196.940000,196.560000,196.675000,460630.00,-0.976,0.039,-0.861,196.527500,196.822500,0.612206
2025-06-18 20:58:00,0.61473,0.33627,0.06493,0.35352,0.61329,0.58091,0.63571,0.33401,0.06957,0.33087,0.128887,0.646647,0.628744,0.125853,0.474240,0.523206,0.545497,0.515625,0.6,0.000000,0.0,0.526250,0.666667,0.324064,0.814941,0.782881,0.584192,3.400553,3.398523,0.529787,3.397020,3.401437,0.030864,0.321469,0.321714,3.403342,3.403865,3.403225,3.400612,3.399993,0.510791,0.478599,0.272840,3.395190,3.403165,0.023121,0.321726,0.025499,196.675000,196.740000,196.630000,196.670000,525245.00,-0.976,0.039,-0.861,196.522500,196.817500,0.610468
2025-06-18 20:59:00,0.47540,0.29058,0.16002,0.34896,0.48143,0.17135,0.44729,0.29098,0.15913,0.32834,0.344319,0.634643,0.616953,0.337903,0.459964,0.514750,0.538547,0.515625,0.6,0.000000,0.0,0.511875,0.666667,0.326288,0.753820,0.778706,0.467354,3.400535,3.398593,0.570213,3.396968,3.401437,0.030864,0.320926,0.321683,3.402760,3.402249,3.402843,3.400560,3.400341,0.287770,0.513619,0.285185,3.395034,3.403200,0.023121,0.321694,0.031509,196.680000,196.750000,196.240000,196.240000,2075503.00,-0.976,0.039,-0.861,196.092800,196.387200,0.477090


In [4]:
# Convert the selected feature frame into the model's input tensors
# (the original note describes these as disk-backed memmaps).
(
    X,
    y_sig,
    y_ret,
    raw_close,
    raw_bid,
    raw_ask,
    end_times,
) = models_core.build_tensors(
    df=df_feat_sel,
    sess_start=params.sess_start_pred_tick,
)

# Shape summary: each row is passed to print() as-is
# (label, shape and, where present, an axis legend).
shape_report = (
    ("  X         =", X.shape, "(samples, look_back, features)"),
    ("  y_sig     =", y_sig.shape, "(samples,)"),
    ("  y_ret     =", y_ret.shape, "(samples,)"),
    ("  raw_close =", raw_close.shape),
    ("  raw_bid   =", raw_bid.shape),
    ("  raw_ask   =", raw_ask.shape),
    ("  end_times =", end_times.shape),
)
print("Shapes:")
for report_row in shape_report:
    print(*report_row)


Inside build_tensors, features: ['rsi_14', 'macd_line_12_26_9', 'macd_signal_12_26_9', 'macd_diff_12_26_9', 'sma_20', 'sma_100', 'atr_14', 'bb_lband_20', 'bb_hband_20', 'bb_width_20', 'plus_di_14', 'minus_di_14', 'adx_14', 'obv', 'obv_sma_14', 'vwap_20', 'vol_spike_14', 'vwap_dev_20', 'ema_7', 'sma_7', 'sma_15', 'sma_30', 'macd_diff_7_15_3', 'atr_15', 'atr_30', 'bb_lband_15', 'bb_hband_15', 'bb_width_15', 'rsi_15', 'stoch_k_15', 'stoch_d_3', 'plus_di_15', 'minus_di_15', 'adx_15', 'obv_sma_15', 'vwap_dev_15', 'vol_spike_15', 'r_1', 'r_15', 'r_30', 'vol_15', 'eng_ma', 'eng_macd', 'eng_bb', 'eng_rsi', 'eng_adx', 'eng_obv', 'eng_atr_div', 'open', 'high', 'low', 'close', 'volume', 'hour', 'day_of_week', 'month']


Counting windows:   0%|          | 0/5400 [00:00<?, ?it/s]

Writing memmaps:   0%|          | 0/5400 [00:00<?, ?it/s]

Shapes:
  X         = torch.Size([2555208, 90, 56]) (samples, look_back, features)
  y_sig     = torch.Size([2555208]) (samples,)
  y_ret     = torch.Size([2555208]) (samples,)
  raw_close = torch.Size([2555208])
  raw_bid   = torch.Size([2555208])
  raw_ask   = torch.Size([2555208])
  end_times = (2555208,)


In [5]:
# Chronological train / validation / test split (by calendar day, per the
# original note). The call returns three per-split tuples followed by
# `samples_per_day` and one day-id array per split.
split_result = models_core.chronological_split(
    X, y_sig, y_ret,
    raw_close, raw_bid, raw_ask,
    end_times=end_times,
    train_prop=params.train_prop,
    val_prop=params.val_prop,
    train_batch=params.hparams['TRAIN_BATCH'],
)

(
    train_split,
    val_split,
    test_split,
    samples_per_day,
    day_id_tr,
    day_id_val,
    day_id_te,
) = split_result

X_tr, y_sig_tr, y_ret_tr = train_split
X_val, y_sig_val, y_ret_val = val_split
# Only the test split also carries the raw close / bid / ask series.
X_te, y_sig_te, y_ret_te, raw_close_te, raw_bid_te, raw_ask_te = test_split

# Shape summary: every row is forwarded unchanged to print().
split_shape_report = (
    ("  X_tr  =", X_tr.shape),
    ("  y_sig_tr, y_ret_tr =", y_sig_tr.shape, y_ret_tr.shape),
    ("  X_val =", X_val.shape),
    ("  y_sig_val, y_ret_val =", y_sig_val.shape, y_ret_val.shape),
    ("  X_te  =", X_te.shape),
    ("  y_sig_te, y_ret_te =", y_sig_te.shape, y_ret_te.shape),
)
print("Shapes:")
for report_row in split_shape_report:
    print(*report_row)


Shapes:
  X_tr  = torch.Size([1805015, 90, 56])
  y_sig_tr, y_ret_tr = torch.Size([1805015]) torch.Size([1805015])
  X_val = torch.Size([361231, 90, 56])
  y_sig_val, y_ret_val = torch.Size([361231]) torch.Size([361231])
  X_te  = torch.Size([388962, 90, 56])
  y_sig_te, y_ret_te = torch.Size([388962]) torch.Size([388962])


In [None]:
# -----------------------------------------------------------------------------
#  Carve `end_times` into the same three contiguous splits as X
# -----------------------------------------------------------------------------
# Each `end_times_*` slice is handed to split_to_day_datasets side by side with
# the matching `X_*` tensor, so the slice boundaries are taken from those
# tensors themselves instead of from the day-id arrays (whose length is not
# guaranteed, from this notebook alone, to equal the number of samples).
n_tr  = X_tr.shape[0]
n_val = X_val.shape[0]
i_tr  = n_tr
i_val = n_tr + n_val

# Sanity check: one end-time per window, otherwise slicing would misalign.
assert len(end_times) == X.shape[0], (
    f"end_times has {len(end_times)} entries but X has {X.shape[0]} samples"
)

end_times_tr  = end_times[:i_tr]
end_times_val = end_times[i_tr:i_val]
end_times_te  = end_times[i_val:]

# The remainder after train + val must be exactly the test split.
assert len(end_times_te) == X_te.shape[0], (
    f"end_times_te has {len(end_times_te)} entries but X_te has {X_te.shape[0]} samples"
)

# -----------------------------------------------------------------------------
#  Build DataLoaders over calendar-days
# -----------------------------------------------------------------------------
train_loader, val_loader, test_loader = models_core.split_to_day_datasets(
    # train split:
    X_tr,            y_sig_tr,     y_ret_tr,   end_times_tr,
    # val split:
    X_val,           y_sig_val,    y_ret_val,  end_times_val,
    # test split + raw prices
    X_te,            y_sig_te,     y_ret_te,   end_times_te,
    raw_close_te, raw_bid_te, raw_ask_te,

    sess_start_time       = params.sess_start_pred_tick,
    signal_thresh         = params.best_optuna_params["buy_threshold"],
    return_thresh         = 0.01,  # flat-zone threshold for returns (to tune)
    train_batch           = params.hparams["TRAIN_BATCH"],
    train_workers         = params.hparams["NUM_WORKERS"],
    train_prefetch_factor = params.hparams["TRAIN_PREFETCH_FACTOR"]
)

# Dataset length vs. loader length, reported per split.
print(f"Days  → train={len(train_loader.dataset)}, val={len(val_loader.dataset)}, test={len(test_loader.dataset)}")
print(f"Batches → train={len(train_loader)},   val={len(val_loader)},   test={len(test_loader)}")

Creating DayWindowDatasets:   0%|          | 0/3 [00:00<?, ?split/s]

In [None]:
# -----------------------------------------------------------------------------
# Create the DualMemoryLSTM from the configured hyper-parameters, then move it
# to the configured device
# -----------------------------------------------------------------------------
hp = params.hparams  # short alias for the hyper-parameter mapping

model = dual_lstm.DualMemoryLSTM(
    n_feats=X.shape[-1],  # size of the last axis of X
    short_units=hp['SHORT_UNITS'],
    long_units=hp['LONG_UNITS'],
    dropout_short=hp['DROPOUT_SHORT'],
    dropout_long=hp['DROPOUT_LONG'],
    att_heads=hp['ATT_HEADS'],
    att_drop=hp['ATT_DROPOUT'],
)
model.to(params.device)

# Last expression: show the model in the cell output.
model

In [None]:
# -----------------------------------------------------------------------------
# One factory call yields five objects, unpacked in this order: the optimizer,
# the plateau scheduler, the cosine scheduler, the scaler and the clip norm
# -----------------------------------------------------------------------------
hp = params.hparams  # short alias for the hyper-parameter mapping

(
    optimizer,
    plateau_sched,
    cosine_sched,
    scaler,
    clipnorm,
) = models_core.make_optimizer_and_scheduler(
    model,
    initial_lr=hp['INITIAL_LR'],
    weight_decay=hp['WEIGHT_DECAY'],
    clipnorm=hp['CLIPNORM'],
)

# Last expression: show the optimizer in the cell output.
optimizer

In [None]:
# -----------------------------------------------------------------------------
# Helper: gather element [1] of every batch a loader yields into one flat array
# -----------------------------------------------------------------------------
def extract_y(loader):
    """Return element ``[1]`` of each batch in ``loader`` as a single 1-D NumPy array.

    Every batch's second element is moved to the CPU, converted to NumPy and
    flattened; the pieces are concatenated in iteration order.
    """
    flat_chunks = []
    for batch in loader:
        targets = batch[1]
        flat_chunks.append(targets.cpu().numpy().ravel())
    return np.concatenate(flat_chunks)

# Flat target arrays for the training and validation loaders
y_train, y_val = extract_y(train_loader), extract_y(val_loader)

# -----------------------------------------------------------------------------
# 1) Zero-forecast baseline: the RMSE obtained by always predicting 0,
#    i.e. sqrt(mean(y^2))
# -----------------------------------------------------------------------------
rmse_zero_train, rmse_zero_val = (
    np.sqrt(np.mean(targets**2)) for targets in (y_train, y_val)
)
print(f"Zero‐forecast RMSE (predict 0): train = {rmse_zero_train:.6f},  val = {rmse_zero_val:.6f}\n")

# -----------------------------------------------------------------------------
# 2) Per-split target statistics and the mean-predictor baseline.
#    Always predicting the split's own mean gives RMSE = std(target), R^2 = 0.
# -----------------------------------------------------------------------------
split_names   = ("Train", "Validation")
split_targets = (y_train, y_val)
rmse_mean_by_split = {}

for split, targets in zip(split_names, split_targets):
    mean_y = targets.mean()
    std_y = targets.std(ddof=0)   # ddof=0 -> population standard deviation
    var_y = std_y**2
    rmse_mean = std_y             # mean-predictor RMSE equals the target std
    rmse_mean_by_split[split] = rmse_mean

    print(f"{split} target stats:")
    print(f"  mean = {mean_y:.4f},  var = {var_y:.4f},  std = {std_y:.4f}")
    print(f"{split} mean‐predictor baseline:")
    print(f"  RMSE_baseline = {rmse_mean:.6f}")
    print("  R²_baseline   = 0.00\n")

# Kept for the final report, which compares the model against this baseline.
rmse_mean_val = rmse_mean_by_split["Validation"]
        


In [None]:
# Overlay the training and validation target histograms so their
# distributions can be compared at a glance.
hist_style = dict(bins=100, alpha=0.5)
for values, legend_label in ((y_train, "train true"), (y_val, "val true")):
    plt.hist(values, label=legend_label, **hist_style)

plt.xlabel("Signal value")
plt.ylabel("Count")
plt.title("True Signal Distribution: Train vs. Validation")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Size of the training dataset (one item per trading day, per the original note)
n_days = len(train_loader.dataset)
print(f"Training sees {n_days} unique trading days per epoch.\n")

# Record the configuration this run was trained with.
print('Using HyperParameters:\n "look_back":', params.look_back_tick, params.hparams)

# -----------------------------------------------------------------------------
# Launch the project's training loop; its return value is kept as the best
# validation RMSE for the final report
# -----------------------------------------------------------------------------
hp = params.hparams  # short alias for the hyper-parameter mapping

best_val_rmse = dual_lstm.lstm_training_loop(
    model=model,
    optimizer=optimizer,
    cosine_sched=cosine_sched,
    plateau_sched=plateau_sched,
    scaler=scaler,
    train_loader=train_loader,
    val_loader=val_loader,
    max_epochs=hp['MAX_EPOCHS'],
    early_stop_patience=hp['EARLY_STOP_PATIENCE'],
    clipnorm=clipnorm,
    device=params.device,
)


In [None]:
# -----------------------------------------------------------------------------
# Final report: best validation RMSE and how far it beats each naive baseline
# -----------------------------------------------------------------------------
def _pct_improvement(baseline_rmse):
    """Percentage by which ``best_val_rmse`` undercuts ``baseline_rmse``."""
    return 100.0 * (1.0 - best_val_rmse / baseline_rmse)


print(f"\nChampion validation RMSE = {best_val_rmse:.6f}")

# Versus always predicting 0
improvement_zero = _pct_improvement(rmse_zero_val)
print(f"Improvement over zero‐baseline = {improvement_zero:5.1f}%")

# Versus always predicting the validation mean
improvement_mean = _pct_improvement(rmse_mean_val)
print(f"Improvement over mean‐baseline = {improvement_mean:5.1f}%")
