In [1]:
%matplotlib inline

# 1) Wipe out your namespace
%reset -f

# 2) Clear Jupyter’s stored outputs (and inputs if you like)
try:
    Out.clear()
except NameError:
    pass

try:
    In.clear()
except NameError:
    pass

# 3) Force Python GC
import gc
gc.collect()

# 4) Free any GPU buffers
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()

import importlib
from libs import params, preps, feats, plots, models_core, models_custom
importlib.reload(params)
importlib.reload(preps)
importlib.reload(feats)
importlib.reload(plots)
importlib.reload(models_core)
importlib.reload(models_custom)

<module 'libs.models_custom' from '/workspace/my_models/Trading/_Stock_Analysis_/libs/models_custom.py'>

In [2]:
import pandas as pd
pd.set_option('display.max_columns', None)

import numpy  as np
import math
import matplotlib.pyplot as plt

import datetime as dt
import os
from typing import Sequence, List, Tuple, Optional, Union

import torch.nn as nn
from torch.nn import MSELoss, Dropout
import torch.nn.functional as Funct
from torch_lr_finder import LRFinder
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau, OneCycleLR
from torch.amp import GradScaler

from tqdm import tqdm

In [3]:
# Load the engineered feature/signal table, keep only the model feature columns
# plus the signal columns, and discard rows before the cut-off date.
# NOTE(review): the cut-off '2020-12-31' is hard-coded here; consider moving it
# into params alongside the other data settings.
selected_cols = params.features_cols_tick + params.signals_cols_tick
df_sign_selfeats = pd.read_csv(
    params.sign_featall_csv, index_col=0, parse_dates=True
)[selected_cols]
df_sign_selfeats = df_sign_selfeats.loc[df_sign_selfeats.index >= '2020-12-31']
df_sign_selfeats

Unnamed: 0,range_pct,atr_pct_7,atr_pct_28,time_afthour,time_premark,kc_w_20_20_2.0,bb_w_20_2p0,donch_w_20,ret_std_63,atr_pct_14,donch_w_55,ret_std_21,upper_shad,time_in_sess,dist_high_200,bb_w_50_2p0,lower_shad,time_hour,dist_low_200,time_week_of_year,trade_count,volume,atr_7_RZ,atr_14_RZ,atr_28_RZ,time_day_of_year,time_month,vol_spike_28,plus_di_28,stoch_k_14_3_3,rolling_max_close_200_RZ,adx_14,minus_di_28,minus_di_14,cci_20,plus_di_7,plus_di_14,adx_28,rsi_6,rolling_min_close_200_RZ,vol_spike_14,stoch_d_9_3_3,sma_5_RZ,minus_di_7,sma_21_RZ,sma_9_RZ,cci_14,sma_pct_200,stoch_k_9_3_3,cmf_14,close_raw,signal_raw,signal_thresh
2020-12-31 08:00:00,0.542954,0.179611,0.088732,0.0,1.0,0.102595,0.090002,0.115841,0.139730,0.122468,0.060489,0.230386,0.0000,0.0,0.066522,0.028942,0.000000,0.270833,0.007069,0.500000,0.002550,0.000234,0.446326,0.452637,0.414132,0.500000,0.416667,0.033347,0.402053,0.461282,0.60294,0.184634,0.456535,0.427611,0.119221,0.252371,0.333563,0.081782,0.276591,0.366904,0.033849,0.538291,0.482865,0.415002,0.447276,0.502435,0.108973,0.475250,0.461282,0.751227,133.560000,1.558898e-08,0.063077
2020-12-31 08:01:00,0.361843,0.233363,0.107583,0.0,1.0,0.127570,0.100127,0.115798,0.141450,0.155615,0.060463,0.233947,0.0000,0.0,0.055798,0.033313,0.000000,0.270833,0.018058,0.500000,0.002040,0.000170,0.572108,0.593658,0.537115,0.500000,0.416667,0.029998,0.305918,0.392883,0.60294,0.170553,0.350936,0.309395,0.232951,0.166325,0.238001,0.075984,0.360348,0.366904,0.029252,0.506859,0.396382,0.274279,0.424026,0.458333,0.236820,0.487457,0.392883,0.747974,133.606667,1.523993e-08,0.063077
2020-12-31 08:02:00,0.180858,0.248224,0.117011,0.0,1.0,0.139405,0.102794,0.115756,0.143089,0.169859,0.060437,0.237105,0.0000,0.0,0.045081,0.034731,0.000000,0.270833,0.029040,0.500000,0.001530,0.000106,0.607046,0.654431,0.598799,0.500000,0.416667,0.025275,0.262146,0.179418,0.60294,0.157478,0.302855,0.260681,0.333068,0.133704,0.198623,0.070394,0.438398,0.366904,0.024416,0.344528,0.328940,0.220930,0.414918,0.391994,0.338798,0.499617,0.179418,0.744203,133.653333,1.452162e-08,0.063077
2020-12-31 08:03:00,0.000000,0.229783,0.117362,0.0,1.0,0.138757,0.102794,0.115714,0.144665,0.166566,0.060411,0.239949,0.0000,0.0,0.034372,0.034855,0.000000,0.270833,0.040014,0.500000,0.001020,0.000041,0.564118,0.640690,0.601373,0.500000,0.416667,0.020719,0.248178,0.358836,0.60294,0.145337,0.287511,0.245491,0.425166,0.123519,0.186344,0.065003,0.510150,0.366904,0.020673,0.310379,0.307826,0.204273,0.416117,0.379466,0.417618,0.511717,0.358836,0.791758,133.700000,1.334189e-08,0.063077
2020-12-31 08:04:00,0.000000,0.234169,0.123431,0.0,1.0,0.145760,0.112237,0.115804,0.151128,0.174262,0.060466,0.251782,0.0000,0.0,0.057330,0.038973,0.000000,0.270833,0.016489,0.500000,0.003570,0.000509,0.573962,0.672842,0.640276,0.500000,0.416667,0.059034,0.220703,0.350293,0.60294,0.154419,0.392251,0.359027,0.035009,0.103694,0.162503,0.069585,0.383585,0.366904,0.060273,0.296182,0.175016,0.338214,0.391189,0.321467,0.132859,0.486107,0.350293,0.775047,133.600000,2.267731e-08,0.063077
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2026-01-09 17:29:00,0.134332,0.190067,0.202690,0.0,0.0,0.195248,0.203577,0.213304,0.172218,0.190330,0.121831,0.173700,0.2400,1.0,0.004729,0.114646,0.207692,0.645833,0.285845,0.519231,0.281999,0.054167,0.145232,0.140438,0.150573,0.521918,0.500000,0.247562,0.360762,0.958040,0.60294,0.161600,0.171854,0.109079,0.985913,0.495497,0.407921,0.113547,0.882500,1.000000,0.318022,0.907809,0.488809,0.032127,0.439971,0.470742,0.992630,0.641047,0.932446,0.711288,259.120000,0.000000e+00,0.114549
2026-01-09 17:30:00,0.190326,0.195380,0.204350,0.0,0.0,0.197654,0.222271,0.234976,0.171744,0.193785,0.133665,0.170860,0.8800,1.0,0.017737,0.118855,0.153846,0.645833,0.282234,0.519231,0.316930,0.051326,0.155111,0.149870,0.154733,0.521918,0.500000,0.237045,0.377511,0.916498,0.60294,0.195943,0.160423,0.097698,0.914192,0.503735,0.426817,0.125810,0.827711,1.000000,0.295539,0.929136,0.507977,0.026670,0.446067,0.478739,0.870487,0.636175,0.906892,0.630932,259.090000,8.627221e-03,0.114549
2026-01-09 17:31:00,0.106354,0.185436,0.201863,0.0,0.0,0.194412,0.237508,0.234967,0.170516,0.189298,0.133660,0.169706,0.4000,1.0,0.016554,0.123305,0.269231,0.645833,0.283438,0.519231,0.372514,0.058680,0.137223,0.143397,0.150715,0.521918,0.500000,0.266432,0.365909,0.857634,0.60294,0.225278,0.160791,0.101967,0.816218,0.454419,0.404017,0.136925,0.831894,1.000000,0.320799,0.895140,0.520831,0.040196,0.452411,0.487369,0.761877,0.636791,0.846082,0.632254,259.100000,4.887917e-03,0.114549
2026-01-09 17:32:00,0.111970,0.177914,0.199776,0.0,0.0,0.191705,0.237214,0.235005,0.170380,0.185681,0.133683,0.172307,0.0800,1.0,0.021405,0.126612,0.376923,0.645833,0.278502,0.519231,0.279449,0.037245,0.122512,0.136667,0.147740,0.521918,0.500000,0.168755,0.353873,0.789616,0.60294,0.244850,0.174061,0.126572,0.740656,0.405610,0.380780,0.145514,0.743239,1.000000,0.203202,0.841089,0.532654,0.086459,0.458174,0.495565,0.685212,0.630843,0.770293,0.601727,259.059000,2.824729e-02,0.114549


In [None]:
# Build the train/val/test DataLoaders from the filtered frame, together with the
# end timestamps returned for each split, then print a short summary per split.
(
    train_loader, val_loader, test_loader,
    end_times_tr, end_times_val, end_times_te,
) = models_core.model_core_pipeline(
    df=df_sign_selfeats,
    train_batch=params.hparams["TRAIN_BATCH"],
    train_workers=params.hparams["TRAIN_WORKERS"],
    prefetch_factor=params.hparams["TRAIN_PREFETCH_FACTOR"],
    look_back=params.hparams["LOOK_BACK"],
    features_cols=params.features_cols_tick,
)

split_specs = (
    ("train", train_loader, end_times_tr),
    ("val",   val_loader,   end_times_val),
    ("test",  test_loader,  end_times_te),
)
for split_name, split_loader, split_times in split_specs:
    models_core.summarize_split(split_name, split_loader, split_times)


Preparing days:   0%|          | 0/1262 [00:00<?, ?it/s]

N_total: 1120170 look_back: 60 F: 50
Estimated X_buf size: 13.44 GB — using RAM (in-memory) (thresh 56 GiB)


Writing days:   0%|          | 0/1262 [00:00<?, ?it/s]

In [None]:
# Compare the distribution of the true targets (element 1 of each batch) between
# the training and validation loaders.
# NOTE(review): iterating the loaders also materialises every input batch just to
# read the targets; this can be slow on the full training set.
y_train = np.concatenate([batch[1].cpu().numpy().ravel() for batch in train_loader])
y_val = np.concatenate([batch[1].cpu().numpy().ravel() for batch in val_loader])

# Shared plotting range: 1st–99th percentile of the *training* targets.
low, high = np.percentile(y_train, [1, 99])
if not high > low:
    # Degenerate range (>=98% of targets share one value): widen it slightly so
    # the bin edges are strictly increasing.
    high = low + max(abs(low), 1.0) * 1e-6
bins = np.linspace(low, high, 50)

# With explicit bin edges, hist() silently drops values outside [low, high].
# Clip instead, so out-of-range values (e.g. zeros below `low`) are counted in
# the leftmost/rightmost bin as the annotation below states.
y_train_plot = np.clip(y_train, low, high)
y_val_plot = np.clip(y_val, low, high)

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(y_train_plot, bins=bins, alpha=0.5, label="train")
ax.hist(y_val_plot,   bins=bins, alpha=0.5, label="val")

# description / annotation
desc = (
    "Histogram compares the distribution of true signal values\n"
    "between training and validation sets. Percentiles (1st–99th)\n"
    f"of the training set were used to define the plotting range: low={low:.4g}, high={high:.4g}.\n"
    "Values outside this range are clipped into the edge bins (e.g. zeros at or below low\n"
    "are counted in the leftmost bin); heavy zero-mass can dominate counts."
)
ax.set_title("True Signal Distribution: Train vs. Validation")
ax.set_xlabel("Signal value")
ax.set_ylabel("Count")
ax.legend()
ax.text(
    0.99, -0.18, desc, ha="right", va="top", transform=ax.transAxes,
    fontsize=9, color="gray"
)
fig.tight_layout()
plt.show()



In [None]:
# Instantiate the model from the hyper-parameter table. Each constructor keyword
# listed below is looked up as params.hparams[<keyword upper-cased>].
_MODEL_HPARAM_KWARGS = (
    # sizes / regularisation
    "short_units", "long_units", "transformer_d_model", "transformer_layers",
    "dropout_short", "dropout_long", "dropout_trans", "pred_hidden", "look_back",
    # gating flags and flatten mode
    "use_conv", "use_tcn", "use_short_lstm", "use_transformer",
    "use_long_lstm", "use_delta", "flatten_mode",
)

model = models_custom.ModelClass(
    n_feats=len(params.features_cols_tick),
    **{kw: params.hparams[kw.upper()] for kw in _MODEL_HPARAM_KWARGS},
)

model.feature_names = params.features_cols_tick  # attached for logging
model.to(params.device)
print('Using:', params.device)
model

In [None]:
# Optimizer with two parameter groups: the backbone peaks at ONECYCLE_MAX_LR, the
# selected head parameter peaks at ONECYCLE_MAX_LR * HEAD_LR_PCT.
base_lr = params.hparams["ONECYCLE_MAX_LR"]
head_lr = base_lr * params.hparams["HEAD_LR_PCT"]  # reduced LR for the head parameter

HEAD_PARAM_NAME = "head_flat.2.bias"  # must match the model's head naming
params_map = dict(model.named_parameters())
if HEAD_PARAM_NAME not in params_map:
    # The name depends on the model architecture (e.g. flatten mode); fail loudly
    # with a hint instead of a bare KeyError.
    raise KeyError(
        f"Head parameter {HEAD_PARAM_NAME!r} not found in model; "
        f"last parameter names are: {list(params_map)[-5:]}"
    )
head_param = params_map[HEAD_PARAM_NAME]
backbone_params = [p for n, p in params_map.items() if n != HEAD_PARAM_NAME]

optimizer = AdamW(
    [
        {"params": backbone_params, "lr": base_lr},
        {"params": [head_param],    "lr": head_lr},
    ],
    weight_decay = params.hparams["WEIGHT_DECAY"]
)

batches_per_epoch = len(train_loader)
total_steps = batches_per_epoch * params.hparams["MAX_EPOCHS"]

# OneCycleLR overwrites every group's lr (starting at max_lr / div_factor), so a
# scalar max_lr would silently discard the reduced head LR. Give one peak LR per
# param group, in the same order as the groups above.
scheduler = OneCycleLR(
  optimizer,
  max_lr           = [base_lr, head_lr],
  total_steps      = total_steps,
  pct_start        = params.hparams["ONECYCLE_PCT_START"],
  div_factor       = params.hparams["ONECYCLE_DIV_FACTOR"],
  final_div_factor = params.hparams["ONECYCLE_FINAL_DIV"],
  anneal_strategy  = params.hparams["ONECYCLE_STRATEGY"],
)
optimizer.scheduler = scheduler # necessary to log sched_field

if getattr(scheduler, "total_steps", None) != total_steps:
    raise RuntimeError(f"Scheduler total_steps mismatch: scheduler={getattr(scheduler,'total_steps',None)} expected={total_steps}")

optimizer

In [None]:
# Report dataset/model size, then run the training loop and keep the best
# validation RMSE it returns.
# NOTE(review): the message assumes one dataset item per trading day — confirm
# against models_core.model_core_pipeline.
n_days = len(train_loader.dataset)
print(f"Training sees {n_days} unique trading days per epoch.\n")

param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_info)
trainable_params = sum(count for count, needs_grad in param_info if needs_grad)
print(f"Model parameters: total={total_params:,}, trainable={trainable_params:,}\n")

print('Using HyperParameters:\n', params.hparams)

# NOTE(review): GradScaler() is built with its default device; if params.device
# is not CUDA, consider matching the device or disabling it — confirm.
best_val_rmse = models_custom.model_training_loop(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=GradScaler(),
    train_loader=train_loader,
    val_loader=val_loader,
)