In [1]:
import sys
from pathlib import Path

# Current notebook location
notebook_path = Path().resolve()

# Add parent folder (meta/) to sys.path
sys.path.append(str(notebook_path.parent))
import joblib
import torch
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
from add_ons.relative_change import add_pct_changes
from add_ons.drop_column import drop_columns
from add_ons.zigzag_single import add_zigzag
from preprocess.classification_pre_dict import preprocess_csv
from models.LSTM.multilstm_classification import MultiLSTMClassifier
from itertools import islice
from add_ons.featue_pipeline2 import FeaturePipeline


# ----------------- Evaluation -----------------
def evaluate_model(model, val_loader, label_encoder):
    """Print a classification report and confusion matrix for a validation loader.

    Args:
        model: Module called as ``model(X_batch)``; must return logits with the
            class dimension on ``dim=1``. It is switched to eval mode here.
        val_loader: Iterable yielding ``(X_batch, y_batch)`` pairs.
        label_encoder: Object exposing ``classes_``; used for the row names of
            the report.

    Returns:
        None. Results are printed to stdout.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            logits = model(X_batch)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y_batch.cpu().numpy())

    if not all_labels:
        # sklearn cannot build a report from zero samples; say so instead of crashing.
        print("\n⚠️ Validation loader yielded no samples; skipping report.")
        return

    # Pin the label set explicitly. Without `labels=`, sklearn infers the classes
    # from the data and raises ValueError if a class is absent from both y_true
    # and y_pred, because len(target_names) no longer matches.
    # NOTE(review): assumes integer label i corresponds to classes_[i]
    # (the sklearn LabelEncoder convention) — confirm against preprocess_csv.
    class_names = [str(c) for c in label_encoder.classes_]
    class_ids = list(range(len(class_names)))

    print("\n📊 Validation Report:")
    print(
        classification_report(
            all_labels,
            all_preds,
            labels=class_ids,
            target_names=class_names,
            zero_division=0,  # same 0.0 values as the default, without the warning spam
        )
    )

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    print("Confusion Matrix:")
    print(cm)

# ----------------- Training -----------------
def train_model(
    data_csv,
    labels_csv,
    model_out_dir="models/saved_models",
    do_validation=False,
    seq_len=3,
    hidden_dim=64,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=10,
    save_model=False,
    return_val_accuracy=False,
    test_mode=True
):
    """Train a multi-input LSTM classifier on windowed candle features.

    The feature pipeline applies ``add_pct_changes`` and standard-normalises
    ``upper_shadow`` (group "main") and ``open_pct`` / ``high_pct`` (group
    "pct_changes"). Each group is windowed with its own length, defined by the
    local ``seq_dict``.

    Args:
        data_csv: Path to the candle CSV, forwarded to ``preprocess_csv``.
        labels_csv: Path to the labels CSV, forwarded to ``preprocess_csv``.
        model_out_dir: Directory for the checkpoint and metadata files.
        do_validation: If True, request a train/validation split, pass the
            validation loader to the trainer and evaluate after fitting.
        seq_len: Only recorded in the saved metadata; the windows actually used
            come from ``seq_dict`` (also saved, under ``'seq_lens'``).
        hidden_dim: Forwarded to ``MultiLSTMClassifier``.
        num_layers: Forwarded to ``MultiLSTMClassifier``.
        lr: Forwarded to ``MultiLSTMClassifier``.
        batch_size: Batch size for the train and validation DataLoaders.
        max_epochs: Forwarded to ``pl.Trainer``.
        save_model: If True, write a checkpoint and a joblib metadata file.
        return_val_accuracy: If True, return ``{"accuracy": val_acc}``.
        test_mode: If True, print one debug batch, publish its flattened form as
            the global ``df_seq``, and run the trainer with ``fast_dev_run=True``.

    Returns:
        ``{"accuracy": float | None}`` when ``return_val_accuracy`` is True
        (``None`` if no validation was run), otherwise ``None``.
    """
    # Notebook cells read this after the call, so it must remain a module global.
    global df_seq

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_class_{timestamp}.pt"
    meta_out  = f"{model_out_dir}/lstm_meta_class_{timestamp}.pkl"

    # --- Define Feature Pipeline ---
    pipeline = FeaturePipeline(
        steps=[lambda df: add_pct_changes(df, separatable="complete")],
        norm_methods={
            "main": {"upper_shadow": "standard"},
            "pct_changes": {"open_pct": "standard", "high_pct": "standard"}
        }
    )

    seq_dict = {"main": 5, "pct_changes": 3}  # different seq lens per group

    # --- Get dataset(s) ---
    if do_validation:
        train_ds, val_ds, label_encoder, _df = preprocess_csv(
            data_csv, labels_csv,
            n_candles=seq_dict,
            val_split=True,
            feature_pipeline=pipeline
        )
    else:
        # Previously this branch bound `full_dataset` (leaving `train_ds`
        # undefined -> NameError below) and still asked for a split.
        # NOTE(review): assumes that with val_split=False preprocess_csv returns
        # (dataset, label_encoder, ...) — confirm against its implementation.
        train_ds, label_encoder, *_ = preprocess_csv(
            data_csv, labels_csv,
            n_candles=seq_dict,
            val_split=False,
            feature_pipeline=pipeline
        )
        val_ds = None

    # --- Model config ---
    input_dims = {name: train_ds.X_dict[name].shape[-1] for name in seq_dict}
    num_classes = len(label_encoder.classes_)

    model = MultiLSTMClassifier(
        input_dims=input_dims,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_classes=num_classes,
        lr=lr
    )

    # --- DataLoaders ---
    # Honour the `batch_size` argument (it used to be hard-coded to 32).
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size) if do_validation else None

    # --- Debug mode ---
    if test_mode:
        # Take the third batch when it exists, otherwise the last one available,
        # instead of raising StopIteration on small datasets.
        debug_idx, debug_batch = -1, None
        for debug_idx, debug_batch in enumerate(islice(train_loader, 3)):
            pass

        if debug_batch is None:
            print("🔍 Debug batch skipped: train loader is empty.")
        else:
            X_batch_dict, y_batch = debug_batch
            if debug_idx == 2:
                print("🔍 Debug batch (third batch):")
            else:
                print(f"🔍 Debug batch (batch {debug_idx + 1}; loader has fewer than 3 batches):")
            print("  Keys in X_batch:", list(X_batch_dict.keys()))
            print("  y_batch shape:", y_batch.shape)   # (batch_size,)
            print("  First label in batch:", y_batch[0])

            # Iterate over dict to inspect each input
            for name, X_batch in X_batch_dict.items():
                print(f"\nFeature group: {name}")
                print("  X_batch shape:", X_batch.shape)  # (batch_size, seq_len, feature_dim)
                print("  First sequence in batch:\n", X_batch[0])

                # Distinct local names: unpacking into `batch_size` / `seq_len`
                # used to overwrite the function arguments, corrupting the
                # 'seq_len' value written to the metadata file.
                n_rows, n_steps, n_feats = X_batch.shape
                # Overwritten on every iteration, so the global ends up holding
                # the last feature group only.
                df_seq = pd.DataFrame(
                    X_batch.reshape(n_rows * n_steps, n_feats).numpy(),
                    columns=[f"{name}_{c}" for c in range(n_feats)]  # temporary column names
                )

    # --- Trainer ---
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        log_every_n_steps=10,
        fast_dev_run=test_mode
    )

    trainer.fit(model, train_loader, val_loader)

    # --- Save model & meta ---
    if save_model:
        # joblib.dump does not create missing directories.
        Path(model_out_dir).mkdir(parents=True, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
            'input_dim': input_dims,
            'hidden_dim': hidden_dim,
            'num_layers': num_layers,
            'num_classes': num_classes,
            'seq_len': seq_len,
            'seq_lens': dict(seq_dict),  # per-group window lengths actually used
            'lr': lr,
            'label_classes': label_encoder.classes_
        }, meta_out)
        print(f"\n✅ Model saved to {model_out}")
        print(f"✅ Meta saved to {meta_out}")

    # --- Optional evaluation ---
    val_acc = None
    if do_validation and val_loader is not None:
        evaluate_model(model, val_loader, label_encoder)
        model.eval()
        n_correct, n_total = 0, 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                preds = torch.argmax(model(X_batch), dim=1)
                n_correct += (preds.cpu() == y_batch.cpu()).sum().item()
                n_total += y_batch.numel()
        # Leave accuracy as None rather than NaN when there were no samples.
        if n_total:
            val_acc = n_correct / n_total

    if return_val_accuracy:
        return {"accuracy": val_acc}
    return None


if __name__ == "__main__":
    # Entry point: train with a validation split on the local CSV files.
    # Every other option keeps train_model's default, including test_mode=True,
    # which makes the trainer run with fast_dev_run enabled.
    train_model(
        data_csv="/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles_prop.csv",
        labels_csv="/home/iatell/projects/meta-learning/data/labeled_ohlcv_string.csv",
        do_validation=True,
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


🔍 Debug batch (third batch):
  Keys in X_batch: ['main', 'pct_changes']
  y_batch shape: torch.Size([32])
  First label in batch: tensor(2)

Feature group: main
  X_batch shape: torch.Size([32, 5, 16])
  First sequence in batch:
 tensor([[ 1.0001e+04,  1.0323e+04,  9.6660e+03,  1.0160e+04,  3.8161e+04,
          1.1598e-01,  1.1293e-01,  2.3772e-01,  2.0000e+00,  1.0270e+00,
          2.1050e+00,  4.8789e-01,  5.8297e-02,  1.0164e-02,  3.9187e-02,
          1.5989e-02],
        [ 1.0156e+04,  1.1075e+04,  1.0050e+04,  1.1040e+04,  4.1883e+04,
          2.5351e-02,  6.3056e-01,  7.5705e-02,  2.0000e+00,  4.0205e-02,
          1.2006e-01,  3.3487e-01,  1.5517e-02,  7.2815e-02,  3.9727e-02,
          8.6572e-02],
        [ 1.1040e+04,  1.1274e+04,  1.0080e+04,  1.0383e+04,  6.1138e+04,
          1.6783e-01, -4.6968e-01,  2.1721e-01,  1.0000e+00,  3.5733e-01,
          4.6246e-01,  7.7267e-01,  8.6990e-02,  1.7962e-02,  2.9851e-03,
         -5.9434e-02],
        [ 1.0375e+04,  1.1250e+04, 


  | Name       | Type       | Params | Mode 
--------------------------------------------------
0 | lstms      | ModuleDict | 38.9 K | train
1 | classifier | Linear     | 387    | train
--------------------------------------------------
39.3 K    Trainable params
0         Non-trainable params
39.3 K    Total params
0.157     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argumen

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.



📊 Validation Report:
              precision    recall  f1-score   support

           0       0.71      0.21      0.32        24
           a       0.00      0.00      0.00         5
           s       0.00      0.00      0.00         2

    accuracy                           0.16        31
   macro avg       0.24      0.07      0.11        31
weighted avg       0.55      0.16      0.25        31

Confusion Matrix:
[[ 5  0 19]
 [ 0  0  5]
 [ 2  0  0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [2]:
# Show the flattened debug batch. `df_seq` is a module global assigned inside
# train_model when test_mode=True, so this cell only works after that call has
# run. It is reassigned for each feature group and therefore holds the last one.
df_seq

Unnamed: 0,pct_changes_0,pct_changes_1,pct_changes_2,pct_changes_3
0,0.086990,0.017962,0.002985,-0.059434
1,-0.060196,-0.002129,0.018882,0.074115
2,0.074419,0.047645,0.080841,0.004303
3,-0.059567,-0.037497,-0.021281,0.009739
4,0.010269,-0.010501,0.005664,-0.015579
...,...,...,...,...
91,0.044600,0.005535,0.044596,0.017721
92,0.016710,0.010672,-0.039099,-0.033525
93,-0.030641,-0.022082,-0.013121,-0.018947
94,-0.019968,-0.009044,-0.010545,0.013327
