# neural network

In [2]:
import joblib
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, multilabel_confusion_matrix
from datetime import datetime
import pandas as pd
import io
import os
import numpy as np
from preprocess.multilabel_preprocess2 import preprocess_csv_multilabel
from models.LSTM.lstm_multi_label import LSTMMultiLabelClassifier
from utils.print_batch import print_batch
from utils.json_to_csv import json_to_csv_in_memory  # <-- new util
from utils.multilabel_threshold_tuning import tune_thresholds_nn
from add_ons.feature_pipeline5 import FeaturePipeline
from add_ons.drop_columns2 import drop_columns
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_proportion import add_candle_proportions
from add_ons.candle_rate_of_change import add_candle_ratios
from add_ons.candle_proportion_simple import add_candle_shape_features
from add_ons.normalize_candle_seq import add_label_normalized_candles
from utils.make_step import make_step

def evaluate_model(model, val_loader, mlb, threshold=0.2, return_probs=False):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            logits = model(X_batch)
            probs = torch.sigmoid(logits)
            preds = (probs >= threshold).int()
            all_preds.append(preds.cpu().numpy())
            all_labels.append(y_batch.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    all_probs = np.vstack(all_probs)

    print("\n📊 Validation Report (Multi-label):")
    print(classification_report(all_labels, all_preds, target_names=mlb.classes_, zero_division=0))

    print("\n🧮 Multi-label Confusion Matrices (per class):")
    mcm = multilabel_confusion_matrix(all_labels, all_preds)
    for i, cls in enumerate(mlb.classes_):
        print(f"\nClass '{cls}':")
        print(mcm[i])

    val_acc_exact = np.all(all_preds == all_labels, axis=1).mean()
    val_acc_micro = (all_preds == all_labels).mean()
    print("\nExact match ratio:", val_acc_exact)
    print("Micro accuracy (per-label):", val_acc_micro)

    if return_probs:
        return val_acc_exact, val_acc_micro, all_probs
    else:
        return val_acc_exact, val_acc_micro



def train_model(
    data_csv,
    labels_json=None,
    model_out_dir="models/saved_models",
    do_validation=True,
    seq_len=1,
    hidden_dim=10,
    num_layers=1,
    lr=0.001,
    batch_size=32,
    max_epochs=200,
    save_model=False,
    return_val_accuracy=True,
    test_mode=False,
    tune_thresholds = False,
    include_no_label = False,
    label_weighting = "none"
):
    """
    Train an LSTM classification model with labels coming from JSON (in-memory CSV).
    """

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/lstm_model_class_{timestamp}.pt"
    meta_out = f"{model_out_dir}/lstm_meta_class_{timestamp}.pkl"

    # --- Prepare labels ---
    if labels_json is not None:
        csv_string = json_to_csv_in_memory(labels_json)   # returns CSV string
        labels_csv = io.StringIO(csv_string)              # file-like for pandas
    else:
        raise ValueError("labels_json must be provided")

    pipeline = FeaturePipeline(
        steps=[
            make_step(add_candle_shape_features),
            # make_step(add_candle_rocp),
            # make_step(add_label_normalized_candles),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
        ],
        # norm_methods={
            # "main": {
            #     "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
            #     "upper_body_ratio": "standard", "lower_body_ratio": "standard",
            #     "upper_lower_body_ratio": "standard", "Candle_Color": "standard",
                
            # }
        #         "candle_shape": {
        #             "upper_shadow": "standard",
        #             "lower_shadow": "standard",
        #             "body": "standard",
        #             "color": "standard",
        #         }
        # },
        # window_norms={
        # "main": {"open_prop": "standard", "high_prop": "standard","low_prop": "standard", "close_prop": "standard"},},

        per_window_flags=[
        False, 
        False, 
        # True
                ]
    )
        # --- Get dataset(s) ---
    if do_validation:
        train_ds, val_ds, df, feature_cols, label_encoder, label_weights = preprocess_csv_multilabel(
            data_csv, labels_csv,
            n_candles=seq_len,
            val_split=True,
            debug_sample=True,
            feature_pipeline=pipeline,
            label_weighting=label_weighting,
            include_no_label = include_no_label
        )
    else:
        full_dataset, df, feature_cols, label_encoder, label_weights = preprocess_csv_multilabel(
            data_csv, labels_csv,
            n_candles=seq_len,
            val_split=False,
            debug_sample=True,
            label_weighting=label_weighting,
            include_no_label =include_no_label
        )

    # --- Model config ---
    input_dim = train_ds[0][0].shape[1] if do_validation else full_dataset[0][0].shape[1]
    num_classes = len(label_encoder.classes_)
    label_weights_tensor = torch.tensor(label_weights, dtype=torch.float32)

    model = LSTMMultiLabelClassifier(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_classes=num_classes,
        lr=lr,
        label_weights_tensor=label_weights_tensor
    )

    # --- DataLoaders ---
    if do_validation:
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size)
    else:
        train_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)
        val_loader = None

    # --- Debug batch ---
    if test_mode:
        global df_seq
        df_seq = print_batch(train_loader, feature_cols, batch_idx=2)

    # --- Trainer ---
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        log_every_n_steps=10,
        fast_dev_run=test_mode,
    )

    trainer.fit(model, train_loader, val_loader)

    # --- Save model & metadata ---
    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        trainer.save_checkpoint(model_out)
        joblib.dump({
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "num_classes": num_classes,
            "seq_len": seq_len,
            "lr": lr,
            "label_classes": label_encoder.classes_,
        }, meta_out)
        print(f"\n✅ Model saved to {model_out}")
        print(f"✅ Meta saved to {meta_out}")

    # --- Validation accuracy ---
    val_acc_exact, val_acc_micro = None, None

    if do_validation:
        # --- Extract all validation labels once ---
        y_true_val = np.vstack([y for _, y in val_loader.dataset])

        # --- Step 1: Evaluate with default threshold ---
        val_acc_exact_default, val_acc_micro_default, y_probs = evaluate_model(
            model, val_loader, label_encoder, threshold=0.5, return_probs=True
        )

        print(f"\n✅ Validation before tuning: Exact={val_acc_exact_default:.3f}, Micro={val_acc_micro_default:.3f}")

        # --- Optional: tune thresholds per label ---
        if tune_thresholds:  # NEW PARAMETER
            optimal_thresholds = tune_thresholds_nn(y_true=y_true_val, y_probs=y_probs)
            print("\n📌 Optimal thresholds per label:", dict(zip(label_encoder.classes_, optimal_thresholds)))

            # --- Step 2: Apply per-label thresholds manually ---
            y_pred_tuned = (y_probs >= np.array(optimal_thresholds)).astype(int)
            val_acc_exact_tuned = np.all(y_pred_tuned == y_true_val, axis=1).mean()
            val_acc_micro_tuned = (y_pred_tuned == y_true_val).mean()
            print(f"✅ Validation after tuning: Exact={val_acc_exact_tuned:.3f}, Micro={val_acc_micro_tuned:.3f}")
        else:
            val_acc_exact_tuned, val_acc_micro_tuned = val_acc_exact_default, val_acc_micro_default



if __name__ == "__main__":
    train_model(
        data_csv="/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        labels_json="/home/iatell/projects/meta-learning/data/candle_labels.json",  # JSON labels, no CSV needed on disk
        do_validation=True,
        save_model=False,
        include_no_label = True,
        label_weighting="scale_pos"
    )


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | lstm      | LSTM              | 640    | train
1 | fc        | Linear            | 77     | train
2 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
717       Trainable params
0         Non-trainable params
717       Total params
0.003     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


clean [['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['H'], ['+', 's', 'v'], ['no_label'], ['+', 'v'], ['+', 's', 'v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['+', 's', 'v'], ['no_label'], ['s', 'v'], ['+', 's', 'v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['+', 's', 'v'], ['no_label'], ['H'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['+', 's', 'v'], ['no_label'], ['no_label'], ['v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['v'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['no_label'], ['+', 'v'], ['no_label'], ['no_label'], ['no_label'], ['+', 's', 'v'], ['no_label'], ['+', 's', '

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/iatell/envs/Rllib2.43/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.



📊 Validation Report (Multi-label):
              precision    recall  f1-score   support

           +       0.64      1.00      0.78        51
           H       0.04      0.57      0.07         7
           I       0.01      0.33      0.02         3
    no_label       0.96      0.81      0.88       231
           q       0.04      1.00      0.08         3
           s       0.49      0.98      0.66        44
           v       0.71      0.95      0.81        81

   micro avg       0.49      0.87      0.62       420
   macro avg       0.41      0.81      0.47       420
weighted avg       0.80      0.87      0.80       420
 samples avg       0.56      0.83      0.64       420


🧮 Multi-label Confusion Matrices (per class):

Class '+':
[[241  29]
 [  0  51]]

Class 'H':
[[206 108]
 [  3   4]]

Class 'I':
[[223  95]
 [  2   1]]

Class 'no_label':
[[ 83   7]
 [ 45 186]]

Class 'q':
[[248  70]
 [  0   3]]

Class 's':
[[233  44]
 [  1  43]]

Class 'v':
[[208  32]
 [  4  77]]

Exact match r

# XGboost

In [1]:
import joblib
from datetime import datetime
import xgboost as xgb
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report, multilabel_confusion_matrix, f1_score
import os
import io
import numpy as np
from preprocess.multilabel_preprocess2 import preprocess_csv_multilabel
from utils.json_to_csv import json_to_csv_in_memory
from utils.multilabel_threshold_tuning import tune_thresholds
from add_ons.feature_pipeline5 import FeaturePipeline
from utils.make_step import make_step
from add_ons.drop_columns2 import drop_columns
from add_ons.candle_dif_rate_of_change_percentage2 import add_candle_rocp
from add_ons.candle_proportion import add_candle_proportions
from add_ons.candle_rate_of_change import add_candle_ratios
from add_ons.candle_proportion_simple import add_candle_shape_features
from add_ons.normalize_candle_seq import add_label_normalized_candles
def evaluate_multilabel_model(model, X_val, y_val, mlb, thresholds=None):
    """
    Evaluate a multi-label XGBoost model and print metrics.
    Optionally apply per-label thresholds.
    """
    # Predict probabilities per label
    y_probs = np.column_stack([est.predict_proba(X_val)[:, 1] for est in model.estimators_])

    # Apply thresholds
    if thresholds is None:
        thresholds = [0.5] * y_val.shape[1]
    y_pred = np.zeros_like(y_val)
    for i, t in enumerate(thresholds):
        y_pred[:, i] = (y_probs[:, i] >= t).astype(int)

    print("\n📊 Validation Report (Multi-label):")
    print(classification_report(y_val, y_pred, target_names=mlb.classes_, zero_division=0))

    print("\n🧮 Multi-label Confusion Matrices (per class):")
    mcm = multilabel_confusion_matrix(y_val, y_pred)
    for i, cls in enumerate(mlb.classes_):
        print(f"\nClass '{cls}':")
        print(mcm[i])

    exact_match = np.all(y_pred == y_val, axis=1).mean()
    print("\nExact match ratio:", exact_match)

    micro_acc = (y_pred == y_val).mean()
    print("Micro accuracy (per-label):", micro_acc)

    return exact_match, micro_acc, y_probs


def train_model_xgb_multilabel(
    data_csv,
    labels_json,
    model_out_dir="models/saved_models",
    do_validation=True,
    seq_len=1,
    n_estimators=200,
    max_depth=6,
    learning_rate=0.05,
    subsample=0.8,
    colsample_bytree=0.8,
    save_model=False,
    return_val_accuracy=True,
    label_weighting="none",  # "none", dict, or "scale_pos"
    threshold_tuning = False,
    include_no_label = False,
):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_out = f"{model_out_dir}/xgb_model_multilabel_{timestamp}.pkl"
    meta_out = f"{model_out_dir}/xgb_meta_multilabel_{timestamp}.pkl"

    csv_string = json_to_csv_in_memory(labels_json)
    labels_csv = io.StringIO(csv_string)

    pipeline = FeaturePipeline(
        steps=[
            make_step(add_candle_shape_features),
            # make_step(add_candle_rocp),
            # make_step(add_label_normalized_candles),
            make_step(drop_columns, cols_to_drop=["open","high","low","close","volume"]),
        ],
        # norm_methods={
            # "main": {
            #     "upper_shadow": "robust", "body": "standard", "lower_shadow": "standard",
            #     "upper_body_ratio": "standard", "lower_body_ratio": "standard",
            #     "upper_lower_body_ratio": "standard", "Candle_Color": "standard",
                
            # }
        #         "candle_shape": {
        #             "upper_shadow": "standard",
        #             "lower_shadow": "standard",
        #             "body": "standard",
        #             "color": "standard",
        #         }
        # },
        # window_norms={
        # "main": {"open_prop": "standard", "high_prop": "standard","low_prop": "standard", "close_prop": "standard"},},

        per_window_flags=[
        False, 
        False, 
        # True
                ]
    )
    if do_validation:
        X_train,y_train, X_val, y_val, df, feature_cols,mlb, label_weights = preprocess_csv_multilabel(
            data_csv, labels_csv,
            n_candles=seq_len,
            val_split=True,
            for_xgboost=True,
            debug_sample=[0, 1],
            label_weighting=label_weighting,
            feature_pipeline=pipeline,
            include_no_label = include_no_label
        )
    else:
        X_train, y_train, df, feature_cols,mlb, label_weights = preprocess_csv_multilabel(
            data_csv, labels_csv,
            n_candles=seq_len,
            val_split=False,
            for_xgboost=True,
            label_weighting=label_weighting,
            feature_pipeline=pipeline,
            include_no_label=include_no_label
        )
        X_val, y_val = None, None

    xgb_models = []
    for w in label_weights:
        xgb_model = xgb.XGBClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth,
            learning_rate=learning_rate,
            subsample=subsample,
            colsample_bytree=colsample_bytree,
            eval_metric='logloss',
            scale_pos_weight=w,
        )
        xgb_models.append(xgb_model)

    model = MultiOutputClassifier(xgb_models[0], n_jobs=-1)
    model.estimators_ = xgb_models
    model.fit(X_train, y_train)

    # Tune thresholds if validation set exists
    optimal_thresholds = None
    val_acc_exact, val_acc_micro = None, None
    if do_validation:
        # --- Step 1: Predict probabilities once ---
        y_probs = np.column_stack([est.predict_proba(X_val)[:, 1] for est in model.estimators_])

        # --- Step 2: Evaluate with default threshold 0.5 ---
        val_acc_exact_default, val_acc_micro_default, _ = evaluate_multilabel_model(
            model, X_val, y_val, mlb, thresholds=[0.5]*y_val.shape[1]
        )
        if threshold_tuning:
        # --- Step 3: Tune optimal thresholds per label ---
            optimal_thresholds = tune_thresholds(y_val, y_probs)
            print("\n📌 Optimal thresholds per label:", dict(zip(mlb.classes_, optimal_thresholds)))

            # --- Step 4: Evaluate with tuned thresholds ---
            val_acc_exact_tuned, val_acc_micro_tuned, _ = evaluate_multilabel_model(
                model, X_val, y_val, mlb, thresholds=optimal_thresholds
            )

    if save_model:
        os.makedirs(model_out_dir, exist_ok=True)
        joblib.dump(model, model_out)
        joblib.dump({
            'seq_len': seq_len,
            'label_classes': mlb.classes_,
            'feature_cols': feature_cols,
            'optimal_thresholds': optimal_thresholds
        }, meta_out)
        print(f"✅ Model saved to {model_out}")
        print(f"✅ Meta saved to {meta_out}")

    if return_val_accuracy:
        return {
            "exact_match": val_acc_exact,
            "micro_accuracy": val_acc_micro,
            "label_weights": label_weights,
            "optimal_thresholds": optimal_thresholds
        }


if __name__ == "__main__":
    train_model_xgb_multilabel(
        data_csv="/home/iatell/projects/meta-learning/data/Bitcoin_BTCUSDT_kaggle_1D_candles.csv",
        labels_json="/home/iatell/projects/meta-learning/data/candle_labels.json",
        do_validation=True,
        label_weighting="scale_pos"# "none", dict, or "scale_pos"
        ,include_no_label= True,

    )



=== DEBUG SAMPLE CHECK ===
Total sequences collected: 1603

--- Sequence 0 ---
Original label(s): ['no_label']
Cleaned label(s): ['no_label']
Encoded: [0 0 0 1 0 0 0]
Feature shape: (1, 4)
First few timesteps:
 [[0.05440368 0.03677583 0.08810496 0.7       ]]

--- Sequence 1 ---
Original label(s): ['no_label']
Cleaned label(s): ['no_label']
Encoded: [0 0 0 1 0 0 0]
Feature shape: (1, 4)
First few timesteps:
 [[0.02600957 0.0367597  0.01538321 0.7       ]]


📊 Validation Report (Multi-label):
              precision    recall  f1-score   support

           +       0.94      0.96      0.95        51
           H       0.67      0.29      0.40         7
           I       0.00      0.00      0.00         3
    no_label       0.92      1.00      0.96       231
           q       0.00      0.00      0.00         3
           s       0.76      0.84      0.80        44
           v       0.96      0.94      0.95        81

   micro avg       0.90      0.94      0.92       420
   macro avg   