In [None]:
# Force a clean re-import of the project's ``model`` module so that edits made
# on disk are picked up without restarting the kernel.
import sys
import importlib

# Evict only modules whose TOP-LEVEL package name contains 'model' (e.g.
# ``model``, ``model.layers``).  Matching the substring anywhere in the dotted
# name would also drop unrelated third-party submodules such as
# ``tensorflow.keras.models`` or ``sklearn.model_selection``; re-importing
# those later creates second copies of modules that already-imported code
# still references.
modules_to_remove = [
    name for name in list(sys.modules)
    if 'model' in name.split('.', 1)[0].lower()
]
for module_name in modules_to_remove:
    del sys.modules[module_name]

# Make the import system notice files created or changed since the last import.
importlib.invalidate_caches()

# With the cache entries gone this import executes the module source afresh,
# so an additional importlib.reload() would only run it a second time.
import model
print("✓ Model module reloaded successfully")
print(f"✓ NoiseInjectionLayer available: {hasattr(model, 'NoiseInjectionLayer')}")
print(f"✓ PricePredictor available: {hasattr(model, 'PricePredictor')}")

In [None]:
import os
import importlib
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks, losses, initializers, regularizers
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error, f1_score, accuracy_score
from sklearn.model_selection import TimeSeriesSplit
import joblib
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, FuncFormatter

# Reload project model module to pick up latest edits when iterating in the notebook.
# Order matters: the reload must run before the star-import so that the names
# bound into the notebook namespace come from the freshly executed module.
import model
importlib.reload(model)
from model import *
# Config is not defined in this notebook; presumably it is exported by the
# star-import above — TODO confirm against model.py.
config = Config()

In [None]:
import os
import math
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
import time
from IPython.display import display, clear_output, HTML
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from ipywidgets import Button, Output, VBox, HBox, Label, FloatProgress
from contextlib import contextmanager

# Notebook-wide state shared by the benchmark helper, the Keras callback and
# the widget button handlers defined below.
estimated_time_per_batch = total_time = None  # written by estimate_training_time()
pause_training = stop_training = False        # toggled by the control buttons

class InteractivePlotCallback(tf.keras.callbacks.Callback):
    """Keras callback that draws live Plotly charts into notebook output widgets.

    Per epoch it appends the recognised entries of ``logs`` to ``self.history``,
    redraws the loss and validation-metric figures and advances the progress
    widget.  Per batch it accumulates the batch ``logs`` and redraws a two-row
    batch figure, which is cleared again at the start of every epoch.

    Pausing and stopping are driven by the module-level ``pause_training`` and
    ``stop_training`` flags, which are only inspected at the end of an epoch.

    Args:
        config: Project configuration object; stored but not read by this class.
        loss_output: Output widget that receives the epoch-level loss figure.
        metrics_output: Output widget that receives the validation-metric figure.
        progress_widget: Widget whose ``value`` and ``description`` are updated
            after every epoch.
        total_epochs: Number of epochs shown in the progress description.
        batch_output: Unused; accepted for backward compatibility.
        batch_metrics_output: Output widget for the per-batch figure.  A new
            ``Output()`` is created when omitted.
        batch_plot_every: Redraw the batch figure every this many batches (and
            on the last batch of the epoch when the step count is known).  The
            default of 1 redraws after every batch.
    """

    # Keys copied from the epoch ``logs`` dict into ``self.history``.  The order
    # is also the order in which traces are added to the loss figure.
    LOSS_KEYS = (
        'loss', 'val_loss', 'point_loss', 'val_point_loss', 'trend_loss', 'val_trend_loss',
        'local_trend_loss', 'val_local_trend_loss', 'global_trend_loss', 'val_global_trend_loss',
        'extended_trend_loss', 'val_extended_trend_loss', 'dir_loss', 'val_dir_loss',
        'reg_loss', 'val_reg_loss', 'vol_loss', 'val_vol_loss', 'var_nll', 'val_var_nll',
    )
    METRIC_KEYS = ('val_f1', 'val_dir_acc', 'val_precision', 'val_recall')

    def __init__(self, config, loss_output, metrics_output, progress_widget, total_epochs,
                 batch_output=None, batch_metrics_output=None, batch_plot_every=1):
        super().__init__()
        self.config = config
        self.loss_output = loss_output
        self.metrics_output = metrics_output
        self.progress_widget = progress_widget
        self.total_epochs = total_epochs
        self.history = {}          # metric name -> list of epoch values
        self.history_epochs = {}   # metric name -> epoch number of each stored value
        self.epoch_count = 0

        # Batch-level widgets and history (text output no longer used)
        self.batch_metrics_output = batch_metrics_output or Output()
        self.batch_history = {}  # stores lists per metric for current epoch
        self.batch_count = 0     # defined up front so it exists before the first epoch
        self.total_batches = None
        self.batch_plot_every = max(1, int(batch_plot_every))

    @staticmethod
    def _show(output, fig):
        """Replace the content of *output* with *fig*."""
        with output:
            clear_output(wait=True)
            display(fig)

    def _epochs_for(self, key):
        """Return the x-axis (epoch numbers) matching ``self.history[key]``.

        Each series carries its own epoch numbers, so a metric that is missing
        from some epochs is still plotted against the epochs it was logged in.
        """
        values = self.history.get(key, [])
        epochs = self.history_epochs.get(key, [])
        if len(epochs) == len(values):
            return epochs
        # ``history`` was edited from outside; fall back to a plain 1..N axis.
        return list(range(1, len(values) + 1))

    def on_epoch_begin(self, epoch, logs=None):
        """Flush batch-level displays and reset batch history at the start of each epoch."""
        self.batch_history.clear()
        self.batch_count = 0
        # total batches may be available in params
        self.total_batches = self.params.get('steps') if self.params is not None else None

        # Clear batch plot output
        with self.batch_metrics_output:
            clear_output(wait=True)

    def on_train_batch_end(self, batch, logs=None):
        """Capture per-batch logs and update batch plot only (no text logging)."""
        logs = logs or {}
        # Keras batch indices are 0-based; store them 1-based for the x-axis.
        self.batch_history.setdefault('batch', []).append((batch or 0) + 1)
        for key, value in logs.items():
            try:
                self.batch_history.setdefault(key, []).append(float(value))
            except Exception:
                # Deliberately best-effort: non-numeric log entries are skipped.
                pass
        self.batch_count = len(self.batch_history['batch'])

        redraw_due = (
            self.batch_count % self.batch_plot_every == 0
            or self.batch_count == self.total_batches
        )
        if not redraw_due:
            return

        try:
            self.update_batch_plot()
        except Exception:
            # Deliberately best-effort: a plotting failure must not abort training.
            pass

    def on_epoch_end(self, epoch, logs=None):
        """Honour pause/stop requests, then record the epoch logs and redraw.

        NOTE(review): the pause loop blocks inside the training call; the
        buttons can only flip the flags if the frontend delivers widget events
        while the cell is still executing — confirm in the target environment.
        """
        global pause_training, stop_training
        # Wait here while paused; a stop request also ends the wait.
        while pause_training and not stop_training:
            time.sleep(0.1)
        if stop_training:
            self.model.stop_training = True
            return

        logs = logs or {}
        self.epoch_count = epoch + 1

        # Collect all losses and metrics for epoch-level persistent history
        for key in self.LOSS_KEYS + self.METRIC_KEYS:
            if key in logs:
                self.history.setdefault(key, []).append(float(logs[key]))
                self.history_epochs.setdefault(key, []).append(self.epoch_count)

        # Update epoch-level plots
        self.update_loss_plot()
        self.update_metrics_plot()

        # Update progress bar
        self.progress_widget.value = self.epoch_count
        self.progress_widget.description = f"Epoch {self.epoch_count}/{self.total_epochs}"

    def update_loss_plot(self):
        """Update persistent loss plot with all loss types grouped"""
        fig = go.Figure()

        for key in self.LOSS_KEYS:
            values = self.history.get(key)
            if not values:
                continue
            if key == 'loss':
                style = dict(name='Total Loss', line=dict(width=2, color='blue'))
            elif key == 'val_loss':
                style = dict(name='Val Loss', line=dict(width=2, color='red', dash='dash'))
            elif key in ('point_loss', 'val_point_loss'):
                style = dict(name=key, line=dict(width=1.5))
            else:
                # Remaining component losses are drawn thinner and semi-transparent.
                style = dict(name=key, line=dict(width=1), opacity=0.7)
            fig.add_trace(go.Scatter(x=self._epochs_for(key), y=values, mode='lines', **style))

        fig.update_layout(
            title='Training Progress: All Losses',
            xaxis_title='Epoch',
            yaxis_title='Loss Value',
            hovermode='x unified',
            height=500,
            width=1200,
            template='plotly_dark',
            legend=dict(x=1.01, y=1, xanchor='left', yanchor='top'),
            showlegend=True,
        )

        self._show(self.loss_output, fig)

    def update_metrics_plot(self):
        """Update persistent metrics plot with all metric types"""
        fig = go.Figure()

        # Directional metrics
        for key in self.METRIC_KEYS:
            values = self.history.get(key)
            if values:
                fig.add_trace(go.Scatter(x=self._epochs_for(key), y=values, mode='lines+markers',
                                         name=key, line=dict(width=2)))

        # Add 50% threshold baseline
        if self.epoch_count > 0:
            fig.add_hline(y=0.5, line_dash="dash", line_color="gray")

        fig.update_layout(
            title='Training Progress: Validation Metrics',
            xaxis_title='Epoch',
            yaxis_title='Metric Value',
            yaxis=dict(range=[0, 1]),
            hovermode='x unified',
            height=400,
            width=1200,
            template='plotly_dark',
            legend=dict(x=1.01, y=1, xanchor='left', yanchor='top'),
            showlegend=True,
        )

        self._show(self.metrics_output, fig)

    def update_batch_plot(self):
        """Update a per-epoch batch-level plot inside batch_metrics_output."""
        batches = self.batch_history.get('batch', [])
        if not batches:
            return

        fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
                            subplot_titles=("Batch Loss (per-batch)", "Batch Metrics (per-batch)"))

        # (history key, legend name, line style, subplot row); validation
        # counterparts are drawn only if such keys appear in the batch logs.
        series = (
            ('loss', 'batch_loss', dict(color='blue'), 1),
            ('val_loss', 'batch_val_loss', dict(color='red', dash='dash'), 1),
            ('dir_acc', 'dir_acc', dict(color='royalblue'), 2),
            ('val_dir_acc', 'val_dir_acc', dict(color='deepskyblue', dash='dash'), 2),
            ('f1', 'f1', dict(color='orange'), 2),
            ('val_f1', 'val_f1', dict(color='darkorange', dash='dash'), 2),
        )
        for key, label, line, row in series:
            values = self.batch_history.get(key, [])
            if values:
                fig.add_trace(go.Scatter(x=batches, y=values, mode='lines+markers',
                                         name=label, line=line), row=row, col=1)

        # Add baseline 50% line on metric subplot
        fig.update_yaxes(range=[0, 1], row=2, col=1)
        fig.add_hline(y=0.5, line_dash="dash", line_color="gray", row=2, col=1)

        fig.update_layout(
            height=500,
            width=1200,
            template='plotly_dark',
            hovermode='x unified',
            showlegend=True,
        )

        self._show(self.batch_metrics_output, fig)


def limit_dataset_size(df, close_values, max_samples, lookback):
    """Keep only the most recent ``max_samples + lookback`` rows.

    Args:
        df: DataFrame of input rows.
        close_values: Sequence aligned with ``df``; sliced from the end.
        max_samples: Maximum number of samples to keep.
        lookback: Extra rows kept in front of the samples.

    Returns:
        ``(df, close_values)`` unchanged when ``df`` is already short enough,
        otherwise a copied tail of ``df`` and the matching tail of
        ``close_values``.
    """
    keep = max_samples + lookback
    if len(df) <= keep:
        return df, close_values

    trimmed_df = df.tail(keep).copy()
    trimmed_close = close_values[-keep:]
    print(f"Limited dataset to {len(trimmed_df):,} rows (most recent)")
    return trimmed_df, trimmed_close

def estimate_training_time(config, close_values, train_samples):
    """Benchmark a few training steps and extrapolate the total training time.

    A throw-away model is built, one batch of windows is cut from the start of
    ``close_values`` and three gradient steps are timed.  If anything in the
    benchmark fails, a fixed per-batch estimate is used instead (0.15 s when a
    GPU is visible, 2.0 s otherwise).

    Side effects: writes the module-level ``estimated_time_per_batch`` and
    ``total_time`` and prints a message when the benchmark fails.

    Args:
        config: Configuration providing ``BATCH_SIZE``, ``WINDOW_STEP``,
            ``LOOKBACK``, ``EXTENDED_TREND_PERIODS``, ``LR`` and ``EPOCHS``.
        close_values: Sequence the benchmark windows are cut from.
        train_samples: Number of training samples, used for batches per epoch.

    Returns:
        Tuple ``(estimated_time_per_batch, total_time)`` in seconds.
    """
    global estimated_time_per_batch, total_time

    n_steps = 3
    batches_per_epoch = train_samples // config.BATCH_SIZE
    try:
        predictor = PricePredictor(config)
        base_model = predictor.build_model()
        test_model = CustomTrainModel(base_model=base_model, pred_scale=1.0, pred_mean=0.0, config=config,
                                      inputs=base_model.inputs, outputs=base_model.outputs)
        optimizer = tf.keras.optimizers.Adam(learning_rate=config.LR)

        # Create sample batch
        close_np = np.array(close_values)
        periods = np.asarray(config.EXTENDED_TREND_PERIODS, dtype=int)
        # Every window reads up to index start + LOOKBACK + max(period), so the
        # start positions must leave that much room at the end of the series.
        # max(1, ...) keeps the bound at least as strict as the plain target.
        reach = max(1, int(periods.max())) if periods.size else 1
        stop = min(config.BATCH_SIZE * config.WINDOW_STEP, len(close_np) - config.LOOKBACK - reach)
        starts = np.arange(0, stop, config.WINDOW_STEP)[:config.BATCH_SIZE]
        if starts.size == 0:
            raise ValueError("not enough rows to build a benchmark batch")

        X_batch = close_np[starts[:, np.newaxis] + np.arange(config.LOOKBACK)]
        y_batch = close_np[starts + config.LOOKBACK][:, np.newaxis]
        last_close_batch = close_np[starts + config.LOOKBACK - 1][:, np.newaxis]
        extended_batch = close_np[starts[:, np.newaxis] + config.LOOKBACK + periods]

        test_X = tf.convert_to_tensor(X_batch)
        test_y = tf.convert_to_tensor(y_batch)
        test_last_close = tf.convert_to_tensor(last_close_batch)
        test_extended = tf.convert_to_tensor(extended_batch)

        # perf_counter is monotonic, unlike time.time(), so the interval cannot
        # be distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        with tqdm(total=n_steps, desc="Benchmarking", leave=False) as pbar:
            for _ in range(n_steps):
                with tf.GradientTape() as tape:
                    predictions = test_model(test_X, training=True)
                    loss_components = test_model.custom_loss(test_X, test_y, predictions, test_last_close, test_extended)
                    total_loss = loss_components[0]
                grads = tape.gradient(total_loss, test_model.trainable_variables)
                optimizer.apply_gradients(zip(grads, test_model.trainable_variables))
                pbar.update(1)
        elapsed = time.perf_counter() - start_time
        # Scale up when the benchmark batch was smaller than a full batch.
        estimated_time_per_batch = (elapsed / n_steps) * (config.BATCH_SIZE / len(X_batch))
    except Exception as e:
        # Best-effort: the estimate is informational, so fall back to a fixed guess.
        print(f"Benchmarking failed: {e}")
        estimated_time_per_batch = 0.15 if tf.config.list_physical_devices('GPU') else 2.0

    time_per_epoch = batches_per_epoch * estimated_time_per_batch
    total_time = time_per_epoch * config.EPOCHS
    return estimated_time_per_batch, total_time

def run_evaluation(model, scaler, X_test_seq, y_test, y_pred, last_close_test):
    """Plot and print evaluation results for one set of predictions.

    Displays a two-panel figure (last 200 actual vs. predicted values and the
    error histogram) and prints MSE, RMSE, R-squared, MAPE and direction
    accuracy.  Direction is the sign of the move relative to
    ``last_close_test``.

    Args:
        model: Unused here; kept so existing call sites keep working.
        scaler: Unused here; kept so existing call sites keep working.
        X_test_seq: Unused here; kept so existing call sites keep working.
        y_test: Actual target values.
        y_pred: Predicted values, one per entry of ``y_test``.
        last_close_test: Reference value per sample used to derive direction.

    Returns:
        None.
    """
    # Flatten everything to 1-D.  Mixing (N,) and (N, 1) arrays would otherwise
    # broadcast to an (N, N) matrix in the subtractions below.
    y_test = np.asarray(y_test).ravel()
    y_pred = np.asarray(y_pred).ravel()
    last_close_test = np.asarray(last_close_test).ravel()

    # Interactive evaluation plots
    fig = make_subplots(rows=1, cols=2, subplot_titles=('Actual vs Predicted (Last 200)', 'Error Distribution'))
    fig.add_trace(go.Scatter(y=y_test[-200:], mode='lines', name='Actual'), row=1, col=1)
    fig.add_trace(go.Scatter(y=y_pred[-200:], mode='lines', name='Predicted'), row=1, col=1)
    errors = y_test - y_pred
    fig.add_trace(go.Histogram(x=errors, nbinsx=50, name='Errors'), row=1, col=2)
    fig.update_layout(height=400, width=1200, template='plotly_dark')
    display(fig)

    # Metrics
    mse = mean_squared_error(y_test, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_test, y_pred)
    mape = mean_absolute_percentage_error(y_test, y_pred)
    pred_dir = (y_pred - last_close_test) > 0
    true_dir = (y_test - last_close_test) > 0
    dir_acc = np.mean(pred_dir == true_dir)

    print("\nEvaluation Metrics:")
    print(f"   MSE: {mse:.4f}")
    print(f"   RMSE: {rmse:.4f}")
    print(f"   R-squared: {r2:.4f}")
    print(f"   MAPE: {mape:.4f}")
    print(f"   Direction Accuracy: {dir_acc:.2%}")

# ============================================================================
# Main Pipeline
# ============================================================================
# Trains the model when no saved weights exist (or when ``force`` is set),
# otherwise rebuilds the architecture, loads the saved weights and evaluates
# every prediction horizon on the test split.
print("Training and Inference Cell")
print("=" * 80)

config = Config()

# Check weights
weights_exist = os.path.exists(config.MODEL_PATH)
force = False  # Set to True to force retraining

if not weights_exist or force:
    print("Running Training Pipeline")
    print("-" * 80)

    # Setup GPU
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"GPU: {len(gpus)} detected")
        for gpu in gpus:
            print(f"   {gpu.name}")
            tf.config.experimental.set_memory_growth(gpu, True)
    else:
        print("Running on CPU")

    # Load and prepare data
    data_processor = DataProcessor(config)
    df_full, close_full = data_processor.load_and_prepare_data()
    df, close_values = limit_dataset_size(df_full, close_full, config.MAX_SEQUENCE_COUNT, config.LOOKBACK)

    # Calculate dataset statistics
    max_sequences = len(close_values) - config.LOOKBACK - max(config.EXTENDED_TREND_PERIODS)
    expected_samples = max_sequences // config.WINDOW_STEP
    train_samples = int(expected_samples * 0.8)
    test_samples = expected_samples - train_samples

    print("\nDataset Statistics:")
    print(f"   Total rows: {len(df):,}")
    print(f"   Expected samples: {expected_samples:,}")
    print(f"   Training samples: {train_samples:,}")
    print(f"   Test samples: {test_samples:,}")

    # Benchmark and estimate training time
    print("\nEstimating training time...")
    estimated_time_per_batch, total_time = estimate_training_time(config, close_values, train_samples)

    print("\nTraining Time Estimation:")
    print(f"   Time per batch: {estimated_time_per_batch:.3f}s")
    print(f"   Est. time per epoch: {(train_samples // config.BATCH_SIZE * estimated_time_per_batch) / 60:.1f} min")
    print(f"   Est. total time: {total_time / 3600:.2f}h")

    print("\n" + "=" * 80)
    print("Training Controls & Progress")
    print("-" * 80)

    # Create control buttons
    pause_btn = Button(description="Pause", button_style='warning')
    resume_btn = Button(description="Resume", button_style='info')
    stop_btn = Button(description="Stop", button_style='danger')

    def on_pause(b):
        """Request a pause; the callback honours it at the end of the epoch."""
        global pause_training
        pause_training = True
        pause_btn.disabled = True
        resume_btn.disabled = False

    def on_resume(b):
        """Clear the pause request."""
        global pause_training
        pause_training = False
        pause_btn.disabled = False
        resume_btn.disabled = True

    def on_stop(b):
        """Request that training stops at the end of the current epoch."""
        global stop_training
        stop_training = True
        stop_btn.disabled = True

    pause_btn.on_click(on_pause)
    resume_btn.on_click(on_resume)
    stop_btn.on_click(on_stop)
    resume_btn.disabled = True  # Start with resume disabled

    # Create persistent progress bar; the label uses the configured epoch count
    # instead of a hard-coded total.
    progress_bar = FloatProgress(value=0, min=0, max=config.EPOCHS, description=f'Epoch 0/{config.EPOCHS}')
    progress_bar.style.bar_color = '#00aa00'

    # Create Output widgets for persistent plot display
    loss_output = Output()
    metrics_output = Output()
    # Batch-level plot output only (no text output)
    batch_metrics_output = Output()

    # Display control panel and progress
    controls_panel = HBox([pause_btn, resume_btn, stop_btn])
    display(controls_panel)
    display(progress_bar)

    # Train
    live_cb = InteractivePlotCallback(
        config, loss_output, metrics_output, progress_bar, config.EPOCHS,
        batch_metrics_output=batch_metrics_output,
    )

    # Display plot containers
    print("\nReal-time Training Plots (updating each epoch)...")
    print("-" * 80)
    display(loss_output)
    display(metrics_output)
    # Display per-batch plot
    display(batch_metrics_output)

    # Make train_model() reuse the already loaded (and size-limited) data.  The
    # patch is installed immediately before the try so that the finally clause
    # below always restores the original method.
    original_load = DataProcessor.load_and_prepare_data
    DataProcessor.load_and_prepare_data = lambda self: (df, close_values)

    try:
        # The cell's own ``force`` flag is passed through so that setting it
        # actually requests a retrain.  ``*_extra`` tolerates additional
        # trailing return values (another cell unpacks nine from train_model).
        (model, scaler, X_test_seq, y_test, y_pred, last_close_test,
         history, ext, *_extra) = train_model(
            extra_callbacks=[live_cb], epochs=None, force=force, calibrate=0
        )
        if model is not None:
            print("\nTraining completed successfully!")
            print("\n" + "=" * 80)
            print("Final Evaluation")
            print("-" * 80)
            run_evaluation(model, scaler, X_test_seq, y_test, y_pred, last_close_test)
    except Exception as e:
        print(f"\nTraining Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        DataProcessor.load_and_prepare_data = original_load
        pause_btn.disabled = True
        resume_btn.disabled = True
        stop_btn.disabled = True

else:
    print("Running Evaluation Pipeline")
    print("-" * 80)
    try:
        # Rebuild the model architecture first
        # NOTE(review): the scaler loaded here is not used below; predictions
        # are inverse-transformed with the target_scaler returned by
        # prepare_datasets() — confirm that is intended.
        scaler = joblib.load(config.SCALER_PATH)
        data_processor = DataProcessor(config)
        df_full, close_full = data_processor.load_and_prepare_data()
        df, close_values = limit_dataset_size(df_full, close_full, config.MAX_SEQUENCE_COUNT, config.LOOKBACK)

        (X_train_seq, y_train_scaled, last_close_train, extended_trends_train,
         X_test_seq, y_test_scaled, last_close_test, extended_trends_test,
         y_train, y_test, target_scaler) = data_processor.prepare_datasets(df, close_values)

        print(f"X_test_seq shape: {X_test_seq.shape}")
        print(f"y_test shape: {y_test.shape}")

        # Rebuild model architecture
        predictor = PricePredictor(config)
        base_model = predictor.build_model()
        pred_scale = np.std(y_train) if np.std(y_train) > 0 else 1.0
        pred_mean = np.mean(y_train)
        model = CustomTrainModel(
            base_model=base_model,
            pred_scale=pred_scale,
            pred_mean=pred_mean,
            lambda_point=config.LAMBDA_POINT,
            lambda_local_trend=config.LAMBDA_LOCAL_TREND,
            lambda_global_trend=config.LAMBDA_GLOBAL_TREND,
            lambda_extended_trend=config.LAMBDA_EXTENDED_TREND,
            lambda_dir=config.LAMBDA_DIR,
            config=config,
            inputs=base_model.inputs,
            outputs=base_model.outputs
        )

        # Load the saved weights
        model.load_weights(config.MODEL_PATH)
        print("Model loaded successfully")

        # Run Multi-Horizon Evaluation on test set
        print("\nRunning Multi-Horizon Evaluation on test set...")
        batch_size = config.BATCH_SIZE
        y_pred_all_horizons = []  # Will store all 3 horizons

        print(f"X_test_seq shape: {X_test_seq.shape}")
        print(f"y_test shape: {y_test.shape}")
        print(f"Total test samples: {len(y_test)}")

        # Collect predictions from all batches for all horizons
        for i in range(0, len(X_test_seq), batch_size):
            batch_end = min(i + batch_size, len(X_test_seq))
            X_batch = X_test_seq[i:batch_end]

            # Convert to tensor
            X_batch_tf = tf.convert_to_tensor(X_batch, dtype=tf.float32)

            # Run prediction - returns (price[B,3], direction[B,1], variance[B,1])
            pred_batch = model(X_batch_tf, training=False)
            y_pred_batch, _, _ = pred_batch  # y_pred_batch shape: (batch_size, 3)

            y_pred_all_horizons.append(y_pred_batch.numpy())

        # Concatenate all predictions - shape: (total_samples, 3)
        y_pred_all = np.concatenate(y_pred_all_horizons, axis=0)

        # Trim to exact test size
        y_pred_all = y_pred_all[:len(y_test)]

        print(f"Total predictions shape: {y_pred_all.shape}")
        print(f"Expected: ({len(y_test)}, 3)")
        # Raised (not asserted) so the check also runs under ``python -O``; the
        # surrounding except reports it like any other evaluation failure.
        if y_pred_all.shape[0] != len(y_test):
            raise ValueError(f"Shape mismatch: {y_pred_all.shape[0]} != {len(y_test)}")

        # Define horizon names and weights
        horizon_names = ["1-min (Primary)", "5-min", "15-min"]
        horizon_weights = [config.LAMBDA_SHORT, config.LAMBDA_POINT, config.LAMBDA_LONG]

        print("\n" + "="*80)
        print("MULTI-HORIZON EVALUATION RESULTS")
        print("="*80)

        # The true direction does not depend on the horizon index, so compute it once.
        true_dir = (y_test - last_close_test) > 0

        # Evaluate each horizon separately
        horizon_metrics = {}
        for h_idx in range(y_pred_all.shape[1]):
            print(f"\nHorizon {h_idx}: {horizon_names[h_idx]} (weight: {horizon_weights[h_idx]:.2f})")
            print("-" * 60)

            # Extract predictions for this horizon and reshape for inverse transform
            y_pred_h_scaled = y_pred_all[:, h_idx:h_idx+1]  # Shape: (N, 1)
            y_pred_h = target_scaler.inverse_transform(y_pred_h_scaled).ravel()

            # Calculate regression metrics
            mse_h = mean_squared_error(y_test, y_pred_h)
            rmse_h = np.sqrt(mse_h)
            r2_h = r2_score(y_test, y_pred_h)
            mape_h = mean_absolute_percentage_error(y_test, y_pred_h)

            # Calculate direction metrics
            pred_dir_h = (y_pred_h - last_close_test) > 0
            dir_acc_h = accuracy_score(true_dir, pred_dir_h)
            f1_h = f1_score(true_dir, pred_dir_h, zero_division=0)

            # Store for later use
            horizon_metrics[h_idx] = {
                'y_pred': y_pred_h,
                'mse': mse_h,
                'rmse': rmse_h,
                'r2': r2_h,
                'mape': mape_h,
                'dir_acc': dir_acc_h,
                'f1': f1_h
            }

            print(f"  Regression Metrics:")
            print(f"    MSE:   {mse_h:.6f}")
            print(f"    RMSE:  {rmse_h:.6f}")
            print(f"    R2:    {r2_h:.6f}")
            # scikit-learn returns MAPE as a fraction, so scale it before
            # printing it with a percent sign.
            print(f"    MAPE:  {mape_h * 100:.4f}%")
            print(f"\n  Direction Metrics:")
            print(f"    Accuracy: {dir_acc_h:.4f} ({dir_acc_h*100:.2f}%)")
            print(f"    F1-Score: {f1_h:.4f}")

        # Use primary horizon (1-min) for main comparison
        y_pred = horizon_metrics[0]['y_pred']

        print("\n" + "="*80)
        print("SUMMARY: Primary Horizon (1-min) vs Others")
        print("="*80)

        # Create comparison table
        print(f"\n{'Metric':<15} {'1-min':<15} {'5-min':<15} {'15-min':<15}")
        print("-" * 60)

        for metric_name in ['mse', 'rmse', 'r2', 'mape', 'dir_acc', 'f1']:
            values = [horizon_metrics[h][metric_name] for h in range(3)]
            if metric_name in ['mse', 'rmse', 'mape']:
                print(f"{metric_name:<15} {values[0]:<15.6f} {values[1]:<15.6f} {values[2]:<15.6f}")
            else:
                print(f"{metric_name:<15} {values[0]:<15.4f} {values[1]:<15.4f} {values[2]:<15.4f}")

        # Multi-horizon visualization (expanded to include 15-min plot)
        print("\n" + "="*80)
        print("Creating Multi-Horizon Comparison Plots...")
        print("="*80)

        fig = make_subplots(
            rows=2, cols=3,
            subplot_titles=(
                'Actual vs 1-min',
                'Actual vs 5-min',
                'Actual vs 15-min',
                'Error Distribution (1-min)',
                'Direction Accuracy by Horizon',
                ''
            ),
            specs=[[{"type": "scatter"}, {"type": "scatter"}, {"type": "scatter"}],
                   [{"type": "histogram"}, {"type": "bar"}, {"type": "scatter"}]]
        )

        # Plot 1: Actual vs 1-min predictions (last 200 samples)
        fig.add_trace(
            go.Scatter(y=y_test[-200:], mode='lines', name='Actual', line=dict(color='blue')),
            row=1, col=1,
        )
        fig.add_trace(
            go.Scatter(y=horizon_metrics[0]['y_pred'][-200:], mode='lines', name='1-min Pred',
                      line=dict(color='red', dash='dash')),
            row=1, col=1,
        )

        # Plot 2: Actual vs 5-min predictions (last 200 samples)
        fig.add_trace(
            go.Scatter(y=y_test[-200:], mode='lines', name='Actual', line=dict(color='blue'), showlegend=False),
            row=1, col=2,
        )
        fig.add_trace(
            go.Scatter(y=horizon_metrics[1]['y_pred'][-200:], mode='lines', name='5-min Pred',
                      line=dict(color='green', dash='dash')),
            row=1, col=2,
        )

        # Plot 3: Actual vs 15-min predictions (last 200 samples)
        fig.add_trace(
            go.Scatter(y=y_test[-200:], mode='lines', name='Actual', line=dict(color='blue'), showlegend=False),
            row=1, col=3,
        )
        fig.add_trace(
            go.Scatter(y=horizon_metrics[2]['y_pred'][-200:], mode='lines', name='15-min Pred',
                      line=dict(color='orange', dash='dash')),
            row=1, col=3,
        )

        # Plot 4: Error distribution for primary horizon
        errors_h0 = y_test - horizon_metrics[0]['y_pred']
        fig.add_trace(
            go.Histogram(x=errors_h0, nbinsx=50, name='Errors', marker=dict(color='purple')),
            row=2, col=1,
        )

        # Plot 5: Direction accuracy comparison
        dir_accs = [horizon_metrics[h]['dir_acc'] for h in range(3)]
        fig.add_trace(
            go.Bar(x=horizon_names, y=dir_accs, name='Direction Accuracy',
                   marker=dict(color=['red', 'green', 'orange']), showlegend=False),
            row=2, col=2,
        )
        fig.add_hline(y=0.5, line_dash="dash", line_color="gray", row=2, col=2, annotation_text="50% Baseline")

        # (Optional empty subplot placeholder row=2,col=3)
        fig.add_trace(
            go.Scatter(x=[0], y=[0], mode='text', text=[''], showlegend=False),
            row=2, col=3,
        )

        # Axis titles
        fig.update_xaxes(title_text="Sample", row=1, col=1)
        fig.update_xaxes(title_text="Sample", row=1, col=2)
        fig.update_xaxes(title_text="Sample", row=1, col=3)
        fig.update_xaxes(title_text="Error Value", row=2, col=1)
        fig.update_xaxes(title_text="Horizon", row=2, col=2)
        fig.update_yaxes(title_text="Price (USD)", row=1, col=1)
        fig.update_yaxes(title_text="Price (USD)", row=1, col=2)
        fig.update_yaxes(title_text="Price (USD)", row=1, col=3)
        fig.update_yaxes(title_text="Frequency", row=2, col=1)
        fig.update_yaxes(title_text="Accuracy", row=2, col=2)
        fig.update_yaxes(range=[0, 1], row=2, col=2)

        fig.update_layout(height=800, width=1600, template='plotly_dark', showlegend=True)
        display(fig)

        # Run evaluation plots for primary horizon
        run_evaluation(model, target_scaler, X_test_seq, y_test, y_pred, last_close_test)

    except Exception as e:
        # f-string so the actual exception text is shown, not a literal "{e}".
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "=" * 80)
print("Training and Inference Cell Complete")
print("=" * 80)
    

In [None]:
# ============================================================================
# TRAINING OR LOADING MODEL WITH INTERACTIVE VISUALIZATION
# ============================================================================
# This cell calls train_model() which handles:
# - Data loading and preprocessing
# - Model building
# - Training (or loading existing weights)
# - Evaluation
# All variables are unpacked to match model.py shapes exactly
# ============================================================================

import ipywidgets as widgets
from IPython.display import display, clear_output

print("="*80)
print("TRAINING/LOADING MODEL")
print("="*80)

# Setup interactive widgets for training visualization
loss_output = widgets.Output()
metrics_output = widgets.Output()
batch_metrics_output = widgets.Output()
progress_widget = widgets.FloatProgress(min=0, max=100, description='Training:', bar_style='info')

# Control buttons
pause_button = widgets.Button(description='Pause', button_style='warning')
resume_button = widgets.Button(description='Resume', button_style='success', disabled=True)
stop_button = widgets.Button(description='Stop', button_style='danger')

def on_pause_clicked(b):
    global pause_training
    pause_training = True
    pause_button.disabled = True
    resume_button.disabled = False
    print("⏸️  Training paused...")

def on_resume_clicked(b):
    global pause_training
    pause_training = False
    pause_button.disabled = False
    resume_button.disabled = True
    print("▶️  Training resumed...")

def on_stop_clicked(b):
    """Handle a click on the Stop button.

    Sets the module-level ``stop_training`` flag and disables both the Pause
    and Resume buttons. ``pause_training`` is left untouched. ``b`` is the
    clicked widget passed by ``on_click`` and is not used.
    """
    global stop_training
    stop_training = True
    # Assigned left to right: Pause first, then Resume.
    pause_button.disabled, resume_button.disabled = True, True
    print("⏹️  Training stopped...")

# Reset the module-level control flags written by the button handlers so that
# re-running this cell starts from an un-paused, un-stopped state.
pause_training = False
stop_training = False

# Attach each handler to its button.
pause_button.on_click(on_pause_clicked)
resume_button.on_click(on_resume_clicked)
stop_button.on_click(on_stop_clicked)

# Display widgets
button_box = widgets.HBox([pause_button, resume_button, stop_button])
display(button_box)
display(progress_widget)
display(loss_output)
display(metrics_output)
display(batch_metrics_output)

# Create interactive callback using Cell 2's InteractivePlotCallback
interactive_callback = InteractivePlotCallback(
    config=config,
    loss_output=loss_output,
    metrics_output=metrics_output,
    progress_widget=progress_widget,
    total_epochs=config.EPOCHS,
    batch_output=None,
    batch_metrics_output=batch_metrics_output
)

# Call train_model() with interactive callback
# Returns 9-tuple with all data needed for inference and backtesting
print("\n🚀 Starting train_model()...")
print("   This will either train a new model or load existing weights")
print("   Set force=True to retrain, force=False to load weights\n")

# NOTE(review): this rebinds the name `model`, which the earlier cells bound
# to the imported `model` module; from here on `model` is whatever
# train_model() returns in the first tuple slot.
(model, target_scaler, X_test_seq, y_test, y_pred_h1_primary, 
 last_close_test, history, extended_trends_test, 
 predictions_dict) = train_model(
    extra_callbacks=[interactive_callback], 
    epochs=config.EPOCHS, 
    force=False,  # Set to True to force retraining
    calibrate=True
)

# Everything below is a console report only: it prints the types/shapes of the
# values unpacked from train_model() and binds the three per-horizon
# prediction arrays from predictions_dict to their own names.
print("\n" + "="*80)
print("✅ MODEL READY - ALL VARIABLES UNPACKED")
print("="*80)

# Verify shapes match model.py exactly
print("\n📊 Variable Shapes (matching model.py):")
print(f"   model:                type={type(model).__name__}")
print(f"   target_scaler:        type={type(target_scaler).__name__}")
print(f"   X_test_seq:           shape={X_test_seq.shape}, dtype={X_test_seq.dtype}")
print(f"   y_test:               shape={y_test.shape}, dtype={y_test.dtype}")
print(f"   y_pred_h1_primary:    shape={y_pred_h1_primary.shape}, dtype={y_pred_h1_primary.dtype}")
print(f"   last_close_test:      shape={last_close_test.shape}, dtype={last_close_test.dtype}")
print(f"   extended_trends_test: shape={extended_trends_test.shape}, dtype={extended_trends_test.dtype}")
print(f"   history:              type={type(history)}")

# predictions_dict is indexed by the keys "h0", "h1", "h2".
print("\n📦 Multi-Horizon Predictions (UNSCALED deltas):")
y_pred_h0_raw = predictions_dict["h0"]  # 1-min horizon
y_pred_h1_raw = predictions_dict["h1"]  # 5-min horizon (primary)
y_pred_h2_raw = predictions_dict["h2"]  # 15-min horizon
print(f"   h0 (1-min):  shape={y_pred_h0_raw.shape}, dtype={y_pred_h0_raw.dtype}")
print(f"   h1 (5-min):  shape={y_pred_h1_raw.shape}, dtype={y_pred_h1_raw.dtype}")
print(f"   h2 (15-min): shape={y_pred_h2_raw.shape}, dtype={y_pred_h2_raw.dtype}")

# Only the leading dimension (N) is read from the data here; the remaining
# dimensions shown are the expected values, not measured ones.
print("\n🎯 Expected Shapes for model.py compatibility:")
print(f"   ✓ X_test_seq:           [N={len(X_test_seq)}, LOOKBACK={config.LOOKBACK}]")
print(f"   ✓ y_test:               [N={len(y_test)}, 3] (multi-horizon targets)")
print(f"   ✓ last_close_test:      [N={len(last_close_test)}]")
print(f"   ✓ extended_trends_test: [N={len(extended_trends_test)}, {len(config.EXTENDED_TREND_PERIODS)}]")

# Static description of the output layout; nothing here inspects the model.
print("\n🔧 Model Architecture:")
print(f"   Input:  [{config.LOOKBACK}] - close price sequence")
print(f"   Output: 9 tensors - (price, direction, variance) × 3 horizons")
print(f"           [0] price_h0:     [B, 1] scaled")
print(f"           [1] direction_h0: [B, 1] sigmoid prob")
print(f"           [2] variance_h0:  [B, 1] softplus")
print(f"           [3] price_h1:     [B, 1] scaled (PRIMARY)")
print(f"           [4] direction_h1: [B, 1] sigmoid prob")
print(f"           [5] variance_h1:  [B, 1] softplus")
print(f"           [6] price_h2:     [B, 1] scaled")
print(f"           [7] direction_h2: [B, 1] sigmoid prob")
print(f"           [8] variance_h2:  [B, 1] softplus")

print("\n✅ All variables ready for Cells 4, 6, 8 (Direction Eval, Backtest, Visualization)")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, roc_curve
import seaborn as sns

# Direction-head evaluation for the 9-output model.
#
# Runs the model over X_test_seq in batches, splits the nine outputs into
# (price, direction, variance) per horizon, derives binary true/predicted
# directions from y_test and the direction probabilities, and prints
# classification metrics plus a probability histogram per horizon.
# Any failure is reported with a traceback instead of aborting the notebook.
print("="*80)
print("DIRECTION HEAD ACCURACY EVALUATION (9-OUTPUT MODEL)")
print("="*80)

try:
    # Extract direction predictions from 9-output model
    # Model outputs: [price_h0, direction_h0, variance_h0, price_h1, direction_h1, variance_h1, price_h2, direction_h2, variance_h2]
    print("\nExtracting predictions from test set (9-output model)...")

    batch_size = config.BATCH_SIZE

    # Storage for all 3 horizons (matching model.py structure)
    price_preds_h0, price_preds_h1, price_preds_h2 = [], [], []
    direction_preds_h0, direction_preds_h1, direction_preds_h2 = [], [], []
    variance_preds_h0, variance_preds_h1, variance_preds_h2 = [], [], []

    for i in range(0, len(X_test_seq), batch_size):
        batch_end = min(i + batch_size, len(X_test_seq))
        X_batch = X_test_seq[i:batch_end]

        # Convert to tensor
        X_batch_tf = tf.convert_to_tensor(X_batch, dtype=tf.float32)

        # Run prediction - returns 9 outputs
        pred_outputs = model(X_batch_tf, training=False)

        # Unpack 9 outputs (matching model.py:725-729)
        (price_h0_batch, direction_h0_batch, variance_h0_batch,
         price_h1_batch, direction_h1_batch, variance_h1_batch,
         price_h2_batch, direction_h2_batch, variance_h2_batch) = pred_outputs

        # Store predictions
        price_preds_h0.append(price_h0_batch.numpy())
        direction_preds_h0.append(direction_h0_batch.numpy())
        variance_preds_h0.append(variance_h0_batch.numpy())

        price_preds_h1.append(price_h1_batch.numpy())
        direction_preds_h1.append(direction_h1_batch.numpy())
        variance_preds_h1.append(variance_h1_batch.numpy())

        price_preds_h2.append(price_h2_batch.numpy())
        direction_preds_h2.append(direction_h2_batch.numpy())
        variance_preds_h2.append(variance_h2_batch.numpy())

    # Concatenate all batches and trim to the number of targets
    price_h0 = np.concatenate(price_preds_h0, axis=0)[:len(y_test)]
    direction_h0 = np.concatenate(direction_preds_h0, axis=0)[:len(y_test)]
    variance_h0 = np.concatenate(variance_preds_h0, axis=0)[:len(y_test)]

    price_h1 = np.concatenate(price_preds_h1, axis=0)[:len(y_test)]
    direction_h1 = np.concatenate(direction_preds_h1, axis=0)[:len(y_test)]
    variance_h1 = np.concatenate(variance_preds_h1, axis=0)[:len(y_test)]

    price_h2 = np.concatenate(price_preds_h2, axis=0)[:len(y_test)]
    direction_h2 = np.concatenate(direction_preds_h2, axis=0)[:len(y_test)]
    variance_h2 = np.concatenate(variance_preds_h2, axis=0)[:len(y_test)]

    # Convert prices from scaled to raw (matching model.py:2056-2074)
    price_h0_raw = target_scaler.inverse_transform(price_h0).ravel()
    price_h1_raw = target_scaler.inverse_transform(price_h1).ravel()
    price_h2_raw = target_scaler.inverse_transform(price_h2).ravel()

    # Direction probabilities (already sigmoid, shape [N, 1] -> [N])
    direction_h0_probs = direction_h0.ravel()
    direction_h1_probs = direction_h1.ravel()
    direction_h2_probs = direction_h2.ravel()

    # Variance (already softplus, shape [N, 1] -> [N])
    variance_h0_vals = variance_h0.ravel()
    variance_h1_vals = variance_h1.ravel()
    variance_h2_vals = variance_h2.ravel()

    print(f"✓ Extracted all 3 horizons:")
    print(f"  h0 (1-min):  price={price_h0_raw.shape}, dir={direction_h0_probs.shape}, var={variance_h0_vals.shape}")
    print(f"  h1 (5-min):  price={price_h1_raw.shape}, dir={direction_h1_probs.shape}, var={variance_h1_vals.shape}")
    print(f"  h2 (15-min): price={price_h2_raw.shape}, dir={direction_h2_probs.shape}, var={variance_h2_vals.shape}")

    # Calculate true direction for each horizon (matching model.py:2123-2157)
    # y_test shape: [N, 3] where columns are (h0_delta, h1_delta, h2_delta)
    print(f"\ny_test shape: {y_test.shape} (multi-horizon deltas)")

    # True direction: 1 if delta > 0, 0 otherwise
    true_dir_h0 = (y_test[:, 0] > 0).astype(int)
    true_dir_h1 = (y_test[:, 1] > 0).astype(int)
    true_dir_h2 = (y_test[:, 2] > 0).astype(int)

    # Predicted direction: 1 if prob > 0.5, 0 otherwise
    pred_dir_h0 = (direction_h0_probs > 0.5).astype(int)
    pred_dir_h1 = (direction_h1_probs > 0.5).astype(int)
    pred_dir_h2 = (direction_h2_probs > 0.5).astype(int)

    print("\n" + "="*80)
    print("DIRECTION ACCURACY METRICS (PER HORIZON)")
    print("="*80)
    print(f"{'Horizon':<12} {'Accuracy':<10} {'Precision':<12} {'Recall':<10} {'F1':<10} {'ROC-AUC':<10}")
    print("-" * 74)

    # Evaluate each horizon
    for h_name, true_dir, pred_dir, dir_probs in [
        ("h0_1min", true_dir_h0, pred_dir_h0, direction_h0_probs),
        ("h1_5min", true_dir_h1, pred_dir_h1, direction_h1_probs),
        ("h2_15min", true_dir_h2, pred_dir_h2, direction_h2_probs),
    ]:
        accuracy = accuracy_score(true_dir, pred_dir)
        precision = precision_score(true_dir, pred_dir, zero_division=0)
        recall = recall_score(true_dir, pred_dir, zero_division=0)
        f1 = f1_score(true_dir, pred_dir, zero_division=0)

        try:
            roc_auc = roc_auc_score(true_dir, dir_probs)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present
            # in the labels. Report NaN rather than 0.0, because 0.0 is itself
            # a meaningful AUC (a perfectly inverted ranking).
            roc_auc = float('nan')

        print(f"{h_name:<12} {accuracy:<10.4f} {precision:<12.4f} {recall:<10.4f} {f1:<10.4f} {roc_auc:<10.4f}")

    # Detailed analysis for PRIMARY horizon (h1_5min)
    print("\n" + "="*80)
    print("DETAILED ANALYSIS: PRIMARY HORIZON (h1_5min)")
    print("="*80)

    # labels=[0, 1] forces a 2x2 matrix even when one class never occurs, so
    # the four-way unpack below cannot fail on a degenerate test set.
    cm = confusion_matrix(true_dir_h1, pred_dir_h1, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0

    # Matthews Correlation Coefficient.
    # The product of four marginals is computed in float: with integer counts
    # it can exceed the int64 range on large test sets and silently wrap.
    denominator = np.sqrt(
        float(tp + fp) * float(tp + fn) * float(tn + fp) * float(tn + fn)
    )
    mcc = (float(tp) * float(tn) - float(fp) * float(fn)) / denominator if denominator > 0 else 0.0

    print(f"\nConfusion Matrix:")
    print(f"  True Negatives:  {tn:6d}")
    print(f"  False Positives: {fp:6d}")
    print(f"  False Negatives: {fn:6d}")
    print(f"  True Positives:  {tp:6d}")

    print(f"\nDetailed Metrics:")
    print(f"  Accuracy:     {accuracy_score(true_dir_h1, pred_dir_h1):.4f} ({accuracy_score(true_dir_h1, pred_dir_h1)*100:.2f}%)")
    print(f"  Precision:    {precision_score(true_dir_h1, pred_dir_h1, zero_division=0):.4f}")
    print(f"  Recall:       {recall_score(true_dir_h1, pred_dir_h1, zero_division=0):.4f}")
    print(f"  Specificity:  {specificity:.4f}")
    print(f"  Sensitivity:  {sensitivity:.4f}")
    print(f"  F1-Score:     {f1_score(true_dir_h1, pred_dir_h1, zero_division=0):.4f}")
    print(f"  MCC:          {mcc:.4f}")

    # Distribution analysis
    print(f"\nDirection Distribution:")
    print(f"  True UP (1):   {np.sum(true_dir_h1 == 1):6d} ({np.sum(true_dir_h1 == 1)/len(true_dir_h1)*100:.1f}%)")
    print(f"  True DOWN (0): {np.sum(true_dir_h1 == 0):6d} ({np.sum(true_dir_h1 == 0)/len(true_dir_h1)*100:.1f}%)")
    print(f"  Pred UP (1):   {np.sum(pred_dir_h1 == 1):6d} ({np.sum(pred_dir_h1 == 1)/len(pred_dir_h1)*100:.1f}%)")
    print(f"  Pred DOWN (0): {np.sum(pred_dir_h1 == 0):6d} ({np.sum(pred_dir_h1 == 0)/len(pred_dir_h1)*100:.1f}%)")

    # Probability histogram: one subplot per horizon, split by true direction
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for idx, (h_name, dir_probs, true_dir) in enumerate([
        ("h0 (1-min)", direction_h0_probs, true_dir_h0),
        ("h1 (5-min)", direction_h1_probs, true_dir_h1),
        ("h2 (15-min)", direction_h2_probs, true_dir_h2),
    ]):
        ax = axes[idx]
        ax.hist(dir_probs[true_dir == 1], bins=50, alpha=0.6, label='True UP', color='green')
        ax.hist(dir_probs[true_dir == 0], bins=50, alpha=0.6, label='True DOWN', color='red')
        ax.axvline(0.5, color='black', linestyle='--', linewidth=1, label='Threshold')
        ax.set_xlabel('Direction Probability')
        ax.set_ylabel('Frequency')
        ax.set_title(f'Direction Probability Distribution - {h_name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\n✅ Direction evaluation complete for all 3 horizons")

except Exception as e:
    # Top-level notebook boundary: report and keep the kernel usable.
    print(f"\n❌ Error during direction evaluation: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# ============================================================================
# TRADE DATACLASS AND COMPLETE MULTI-HEAD STRATEGY PIPELINE
# ============================================================================
from dataclasses import dataclass
from typing import List

@dataclass
class Trade:
    """A single closed trade recorded by the backtest.

    The first nine fields are supplied by the caller. The three reference
    levels (``tp1_price``, ``tp2_price``, ``sl_price``) are always recomputed
    from ``entry_price`` in ``__post_init__``, so any value passed for them is
    overwritten.
    """
    entry_bar: int           # bar index at which the position was opened
    exit_bar: int            # bar index at which the position was closed
    entry_price: float       # price at entry_bar
    exit_price: float        # price at exit_bar
    trade_type: str          # 'LONG'; any other value is treated as SHORT
    bars_held: int           # bars between entry and exit
    profit: float            # absolute profit in price units
    profit_pct: float        # profit as a percentage
    exit_reason: str         # 'SPIKE', 'REV', 'TIME'

    # Reference levels only; nothing in this class acts on them.
    tp1_price: float = None
    tp2_price: float = None
    sl_price: float = None

    def __post_init__(self):
        """Derive take-profit and stop-loss reference prices from the entry."""
        # +1 for LONG, -1 for everything else: targets sit on the profitable
        # side of the entry and the stop on the opposite side.
        side = 1 if self.trade_type == 'LONG' else -1
        entry = self.entry_price
        self.tp1_price = entry + side * (entry * 0.005)   # 0.5% in favour
        self.tp2_price = entry + side * (entry * 0.015)   # 1.5% in favour
        self.sl_price = entry - side * (entry * 0.01)     # 1% against

    def is_win(self) -> bool:
        """Return True when the trade closed with strictly positive profit."""
        return self.profit > 0

    def __repr__(self):
        """Compact one-line summary: type, bar span, prices, profit, reason."""
        span = f"{self.trade_type} @{self.entry_bar}-{self.exit_bar}"
        prices = f"${self.entry_price:.2f}→${self.exit_price:.2f}"
        outcome = f"{self.profit:+.2f}pts, {self.exit_reason}"
        return f"Trade({span}, {prices}, {outcome})"

# ============================================================================
# MULTI-HEAD STRATEGY PIPELINE - COMPLETE EXECUTION
# ============================================================================
# Uses 9-output model to extract all 3 horizons
# ============================================================================

print("="*80)
print("MULTI-HEAD STRATEGY PIPELINE - 9-OUTPUT MODEL")
print("="*80)

# Phase 1: Extract all 9 outputs from model
print("\n[PHASE 1] Data Extraction from 9-output model...")
batch_size = config.BATCH_SIZE

# Storage for all 9 outputs: one list of per-batch arrays per output head
price_preds_h0, direction_preds_h0, variance_preds_h0 = [], [], []
price_preds_h1, direction_preds_h1, variance_preds_h1 = [], [], []
price_preds_h2, direction_preds_h2, variance_preds_h2 = [], [], []

# Forward pass over the test sequences in slices of batch_size rows.
for i in range(0, len(X_test_seq), batch_size):
    batch_end = min(i + batch_size, len(X_test_seq))
    X_batch_tf = tf.convert_to_tensor(X_test_seq[i:batch_end], dtype=tf.float32)

    # Model returns 9 outputs (matching model.py:725-729)
    pred_outputs = model(X_batch_tf, training=False)
    # Inside the loop these names hold per-batch tensors; after the loop the
    # same names are rebound to the concatenated NumPy arrays.
    (price_h0, direction_h0, variance_h0,
     price_h1, direction_h1, variance_h1,
     price_h2, direction_h2, variance_h2) = pred_outputs

    # Store all outputs
    price_preds_h0.append(price_h0.numpy())
    direction_preds_h0.append(direction_h0.numpy())
    variance_preds_h0.append(variance_h0.numpy())

    price_preds_h1.append(price_h1.numpy())
    direction_preds_h1.append(direction_h1.numpy())
    variance_preds_h1.append(variance_h1.numpy())

    price_preds_h2.append(price_h2.numpy())
    direction_preds_h2.append(direction_h2.numpy())
    variance_preds_h2.append(variance_h2.numpy())

# Concatenate and trim to test size
price_h0 = np.concatenate(price_preds_h0, axis=0)[:len(y_test)]
direction_h0 = np.concatenate(direction_preds_h0, axis=0)[:len(y_test)]
variance_h0 = np.concatenate(variance_preds_h0, axis=0)[:len(y_test)]

price_h1 = np.concatenate(price_preds_h1, axis=0)[:len(y_test)]
direction_h1 = np.concatenate(direction_preds_h1, axis=0)[:len(y_test)]
variance_h1 = np.concatenate(variance_preds_h1, axis=0)[:len(y_test)]

price_h2 = np.concatenate(price_preds_h2, axis=0)[:len(y_test)]
direction_h2 = np.concatenate(direction_preds_h2, axis=0)[:len(y_test)]
variance_h2 = np.concatenate(variance_preds_h2, axis=0)[:len(y_test)]

# Convert to proper shapes: [N, 1] -> [N]
direction_probs = direction_h1.ravel()  # Primary horizon
variance_raw = variance_h1.ravel()

# Inverse transform prices (from scaled to raw deltas)
price_1min_delta = target_scaler.inverse_transform(price_h0).ravel()
price_5min_delta = target_scaler.inverse_transform(price_h1).ravel()
price_15min_delta = target_scaler.inverse_transform(price_h2).ravel()

print(f"✓ Extracted 9 outputs:")
print(f"  Shapes: price={price_h1.shape}, direction={direction_h1.shape}, variance={variance_h1.shape}")
print(f"  After processing: direction_probs={direction_probs.shape}, variance_raw={variance_raw.shape}")

# Phase 2: Helper functions
print("\n[PHASE 2] Helper Functions...")

def calculate_confidence(var, eps=1e-7):
    """Map a variance to a confidence score ``1 / (1 + var + eps)``.

    ``var`` may be a scalar or array-like; it is converted with
    ``np.asarray`` and the result is computed element-wise.
    """
    variance = np.asarray(var)
    denominator = 1.0 + variance + eps
    return 1.0 / denominator

def calculate_signal_strength(d, c):
    """Return the element-wise product of direction values and confidences.

    Both arguments are converted with ``np.asarray`` before multiplying.
    """
    direction = np.asarray(d)
    weight = np.asarray(c)
    return direction * weight

def normalize_variance(v, m, s, eps=1e-7):
    """Standardise ``v`` as ``(v - m) / (s + eps)``, or 0.0 where ``s < eps``.

    The quotient is evaluated for every element and then masked with
    ``np.where``.
    """
    scaled = (v - m) / (s + eps)
    is_flat = s < eps
    return np.where(is_flat, 0.0, scaled)

def check_multi_horizon_agreement(preds, curr, thresh=0.67):
    """Measure how strongly a set of predictions agrees on one side of ``curr``.

    Each prediction strictly greater than ``curr`` counts as "up"; everything
    else (including values equal to ``curr``) counts as "down". The agreement
    share is the size of the larger group divided by the number of
    predictions.

    Args:
        preds: Sequence of predicted values (converted with ``np.asarray``).
        curr: Reference value the predictions are compared against.
        thresh: Minimum share required to report agreement.

    Returns:
        ``(agrees, share)`` where ``agrees`` is truthy when
        ``share >= thresh``. For empty ``preds`` returns ``(False, 0.0)``
        instead of dividing by zero.

    Note:
        With three predictions a 2-of-3 majority gives a share of about
        0.6667, which is below the default ``thresh`` of 0.67, so the default
        effectively requires all three to agree.
    """
    preds = np.asarray(preds)
    total = len(preds)
    if total == 0:
        # Nothing to compare: no agreement rather than a NaN share.
        return False, 0.0

    up = np.sum(preds > curr)
    share = max(up, total - up) / total
    return share >= thresh, share

def detect_variance_spike(v, m, thresh=2.0, eps=1e-7):
    """Report whether ``v`` exceeds ``thresh`` times the baseline ``m + eps``.

    Works element-wise when ``v`` / ``m`` are arrays.
    """
    limit = thresh * (m + eps)
    return v > limit

print("✓ 5 helper functions defined")

# Phase 3: Calculate metrics
print("\n[PHASE 3] Calculate Metrics...")
window = 20
# Rolling variance baseline over the trailing `window` bars.
# A trailing window is used on purpose: the strategy loop walks the data bar
# by bar, so the baseline at bar t must not include bars after t. The earlier
# np.convolve(..., mode='same') form was centered (it read future bars) and
# implicitly zero-padded both ends, which dragged the baseline toward zero on
# the first and last ~window/2 bars and made those bars look like spikes.
# min_periods=1 averages over however many bars exist so far.
var_mean = pd.Series(variance_raw).rolling(window, min_periods=1).mean().values
# Trailing rolling standard deviation; undefined for a single sample, so the
# first bar is filled with 0.
var_std = pd.Series(variance_raw).rolling(window, min_periods=2).std().fillna(0).values
confidence = calculate_confidence(variance_raw)
signal_str = calculate_signal_strength(direction_probs, confidence)

# Multi-horizon direction signals (flattened [N, 1] -> [N])
dir_1m = direction_h0.ravel()
dir_5m = direction_h1.ravel()
dir_15m = direction_h2.ravel()

print(f"✓ Metrics calculated: confidence={confidence.shape}, signal_str={signal_str.shape}")
print(f"✓ Multi-horizon directions: 1m={dir_1m.shape}, 5m={dir_5m.shape}, 15m={dir_15m.shape}")

# Phase 4: Build backtest dataframe
print("\n[PHASE 4] Building backtest dataframe...")

# Need to reconstruct close prices from deltas and last_close_test
# y_test contains deltas, last_close_test is the reference point
# actual_close[i] = last_close_test[i] + y_test[i, horizon]

# For simplicity in backtesting, use last_close_test as current price
# and predicted deltas to get future prices
close_prices = last_close_test.ravel()

# One row per test bar. Every column must have the same length as
# close_prices, otherwise the DataFrame constructor raises.
# The price_pred_* columns hold the inverse-transformed predicted deltas,
# not absolute price levels.
backtest_data = pd.DataFrame({
    'bar': np.arange(len(close_prices)),
    'close': close_prices,
    'direction_prob': direction_probs,
    'confidence': confidence,
    'signal_strength': signal_str,
    'variance': variance_raw,
    'var_mean': var_mean,
    'var_std': var_std,
    'price_pred_1m': price_1min_delta,
    'price_pred_5m': price_5min_delta,
    'price_pred_15m': price_15min_delta,
    'dir_1m': dir_1m,
    'dir_5m': dir_5m,
    'dir_15m': dir_15m,
})

print(f"✓ Backtest dataframe created: {backtest_data.shape}")
print(f"  Columns: {list(backtest_data.columns)}")

# Phase 5: Trading strategy execution
# Single-position state machine: on each bar either look for an entry (when
# flat) or evaluate exit rules (when in a position). Entry and exit therefore
# never happen on the same bar.
print("\n[PHASE 5] Executing Trading Strategy...")

trades = []
position = None  # (type, entry_bar, entry_price)
max_hold = 30    # Maximum bars to hold

# NOTE(review): the loop stops max_hold bars before the end of the data, and a
# position still open when the loop ends is never closed or recorded.
for bar in range(len(backtest_data) - max_hold):
    row = backtest_data.iloc[bar]
    curr_price = row['close']
    dir_prob = row['direction_prob']
    conf = row['confidence']
    sig_str = row['signal_strength']
    var_val = row['variance']
    var_m = row['var_mean']

    # Check multi-horizon agreement
    # The three values are predicted deltas, so they are compared against 0.
    # NOTE(review): with three values and thresh=0.67, a 2-of-3 split
    # (~0.6667) does not pass, so this requires all three deltas to share a
    # sign. The result is also direction-agnostic: it does not check that the
    # agreed sign matches the LONG/SHORT side chosen below.
    future_prices = [row['price_pred_1m'], row['price_pred_5m'], row['price_pred_15m']]
    agreement, agreement_pct = check_multi_horizon_agreement(future_prices, 0, thresh=0.67)

    # Detect variance spike
    is_spike = detect_variance_spike(var_val, var_m, thresh=2.0)

    # Entry logic
    if position is None:
        # LONG entry: high confidence UP prediction
        if dir_prob > 0.65 and conf > 0.5 and agreement and not is_spike:
            position = ('LONG', bar, curr_price)

        # SHORT entry: high confidence DOWN prediction
        elif dir_prob < 0.35 and conf > 0.5 and agreement and not is_spike:
            position = ('SHORT', bar, curr_price)

    # Exit logic
    else:
        pos_type, entry_bar, entry_price = position
        bars_held = bar - entry_bar
        exit_price = curr_price

        should_exit = False
        exit_reason = None

        # The exit rules are checked in priority order: spike, then reversal,
        # then time limit. Only the first matching reason is recorded.

        # Exit on variance spike
        if is_spike:
            should_exit = True
            exit_reason = 'SPIKE'

        # Exit on reversal
        elif pos_type == 'LONG' and dir_prob < 0.45:
            should_exit = True
            exit_reason = 'REV'
        elif pos_type == 'SHORT' and dir_prob > 0.55:
            should_exit = True
            exit_reason = 'REV'

        # Exit on max hold time
        elif bars_held >= max_hold:
            should_exit = True
            exit_reason = 'TIME'

        if should_exit:
            # Calculate profit
            if pos_type == 'LONG':
                profit = exit_price - entry_price
            else:  # SHORT
                profit = entry_price - exit_price

            profit_pct = (profit / entry_price) * 100

            trade = Trade(
                entry_bar=entry_bar,
                exit_bar=bar,
                entry_price=entry_price,
                exit_price=exit_price,
                trade_type=pos_type,
                bars_held=bars_held,
                profit=profit,
                profit_pct=profit_pct,
                exit_reason=exit_reason
            )
            trades.append(trade)
            position = None

print(f"✓ Strategy executed: {len(trades)} trades generated")

# Phase 6: Performance metrics
# Summary statistics over the closed trades; profits are in price units
# (exit minus entry), not scaled by any position size.
print("\n[PHASE 6] Performance Metrics...")

if len(trades) > 0:
    # is_win() is strict (profit > 0), so break-even trades land in `losses`.
    wins = [t for t in trades if t.is_win()]
    losses = [t for t in trades if not t.is_win()]

    total_profit = sum(t.profit for t in trades)
    win_rate = len(wins) / len(trades) * 100
    avg_profit = total_profit / len(trades)
    avg_win = sum(t.profit for t in wins) / len(wins) if wins else 0
    avg_loss = sum(t.profit for t in losses) / len(losses) if losses else 0
    # |gross profit / gross loss|; reported as infinity when there are no
    # losing trades or their profits sum to exactly zero.
    profit_factor = abs(sum(t.profit for t in wins) / sum(t.profit for t in losses)) if losses and sum(t.profit for t in losses) != 0 else float('inf')

    print("\n" + "="*80)
    print("BACKTEST RESULTS")
    print("="*80)
    print(f"Total Trades:      {len(trades)}")
    print(f"Winning Trades:    {len(wins)} ({len(wins)/len(trades)*100:.1f}%)")
    print(f"Losing Trades:     {len(losses)} ({len(losses)/len(trades)*100:.1f}%)")
    print(f"Win Rate:          {win_rate:.2f}%")
    print(f"Total Profit:      ${total_profit:,.2f}")
    print(f"Average Profit:    ${avg_profit:,.2f}")
    print(f"Average Win:       ${avg_win:,.2f}")
    print(f"Average Loss:      ${avg_loss:,.2f}")
    print(f"Profit Factor:     {profit_factor:.2f}")

    # Trade type breakdown
    long_trades = [t for t in trades if t.trade_type == 'LONG']
    short_trades = [t for t in trades if t.trade_type == 'SHORT']

    print(f"\nTrade Type Breakdown:")
    print(f"  LONG trades:     {len(long_trades)} (profit: ${sum(t.profit for t in long_trades):,.2f})")
    print(f"  SHORT trades:    {len(short_trades)} (profit: ${sum(t.profit for t in short_trades):,.2f})")

    print("\n✅ Backtesting complete!")
else:
    print("⚠️  No trades generated")

print("\n" + "="*80)


In [None]:
# ============================================================================
# COMPREHENSIVE THREE-PANEL TRADING VISUALIZATION (FIXED)
# ============================================================================
# Uses Trade dataclass for proper trade representation
# Panel 1: Price with Entry/Exit + Synchronized TP/SL on hover
# Panel 2: Learned Indicators with vertical hover line
# Panel 3: Portfolio Evolution with vertical hover line
# ============================================================================

print("="*80)
print("CREATING COMPREHENSIVE THREE-PANEL TRADING VISUALIZATION (FIXED)")
print("="*80)

# Create portfolio history tracking.
# The portfolio value is the starting cash plus the realized profit of every
# trade closed so far; open positions are not marked to market.
print("\nTracing portfolio values through trades...")
starting_cash = 10000  # Starting cash

# Group realized profits by the bar on which each trade closed, so the bar
# loop below does one dictionary lookup per bar instead of rescanning the
# whole trade list on every bar.
profits_by_exit_bar = {}
for trade in trades:
    profits_by_exit_bar.setdefault(trade.exit_bar, []).append(trade.profit)

portfolio_values = []
current_portfolio = starting_cash
for bar_num in range(len(backtest_data)):
    # Apply every trade that closes at this bar, in the order it was recorded
    for realized_profit in profits_by_exit_bar.get(bar_num, ()):
        current_portfolio += realized_profit
    portfolio_values.append(current_portfolio)

portfolio_values = np.array(portfolio_values)
portfolio_pnl = portfolio_values - starting_cash
bar_idx_list = list(range(len(backtest_data)))

# With no bars there is no last element to read; fall back to the start value.
final_value = portfolio_values[-1] if len(portfolio_values) else starting_cash

print(f"✓ Portfolio history: {len(portfolio_values)} bars")
print(f"  Starting: $10,000")
print(f"  Final: ${final_value:,.2f}")
print(f"  P&L: ${final_value - starting_cash:+,.2f}")

# Separate trades by type
buy_trades = [t for t in trades if t.trade_type == 'LONG']
sell_trades = [t for t in trades if t.trade_type == 'SHORT']

print(f"✓ Buy trades: {len(buy_trades)}")
print(f"✓ Sell trades: {len(sell_trades)}")
print(f"✓ Total trades: {len(trades)}")

# Build the empty three-row figure; all rows share one X axis (bar index).
print("\nBuilding three-panel visualization...")

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=(
        'Panel 1: Price with Entry/Exit Orders (synchronized hover)',
        'Panel 2: Learned Indicators (synchronized hover)',
        'Panel 3: Portfolio Value Evolution (synchronized hover)',
    ),
    # One single-axis cell per row; a fresh dict is created for each row.
    specs=[[{"secondary_y": False}] for _ in range(3)],
)

# ============================================================================
# PANEL 1: PRICE WITH ENTRY/EXIT ORDERS (PROPER POSITIONING)
# ============================================================================
print("\n[Panel 1] Adding price and trade entry/exit visualization...")

# Price line: the 'close' column of backtest_data plotted against bar index
# in the first row of the figure.
fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=backtest_data['close'].values,
        mode='lines',
        name='Price',
        line=dict(color='white', width=2),
        hovertemplate='<b>Price</b><br>Bar: %{x}<br>Price: $%{y:.2f}<extra></extra>',
        xaxis='x1', yaxis='y1'
    ),
    row=1, col=1
)

# For each trade, draw an entry marker, an exit marker, and a connecting line.
# The TP/SL reference levels are not drawn; they only appear as text in the
# entry marker's hover box. Three traces are added per trade, and only the
# first trade's Entry/Exit traces get a legend entry.
for i, trade in enumerate(trades):
    # Entry marker (at actual entry bar): green up-triangle for LONG,
    # red down-triangle otherwise
    fig.add_trace(
        go.Scatter(
            x=[trade.entry_bar],
            y=[trade.entry_price],
            mode='markers',
            marker=dict(
                size=10,
                color='green' if trade.trade_type == 'LONG' else 'red',
                symbol='triangle-up' if trade.trade_type == 'LONG' else 'triangle-down',
                line=dict(color='lightgreen' if trade.trade_type == 'LONG' else 'lightcoral', width=2)
            ),
            name='Entry' if i == 0 else '',
            showlegend=(i == 0),
            hovertemplate=(
                f'<b>ENTRY ({trade.trade_type})</b><br>'
                f'Bar: {trade.entry_bar}<br>'
                f'Price: ${trade.entry_price:.2f}<br>'
                f'TP1: ${trade.tp1_price:.2f}<br>'
                f'TP2: ${trade.tp2_price:.2f}<br>'
                f'SL: ${trade.sl_price:.2f}<extra></extra>'
            ),
            xaxis='x1', yaxis='y1'
        ),
        row=1, col=1
    )
    
    # Exit marker (at actual exit bar): lime for a winning trade, orange otherwise
    exit_color = 'lime' if trade.is_win() else 'orange'
    fig.add_trace(
        go.Scatter(
            x=[trade.exit_bar],
            y=[trade.exit_price],
            mode='markers',
            marker=dict(
                size=10,
                color=exit_color,
                symbol='circle',
                line=dict(color='white', width=2)
            ),
            name='Exit' if i == 0 else '',
            showlegend=(i == 0),
            hovertemplate=(
                f'<b>EXIT ({trade.exit_reason})</b><br>'
                f'Bar: {trade.exit_bar}<br>'
                f'Price: ${trade.exit_price:.2f}<br>'
                f'Profit: {trade.profit:+.2f} pts ({trade.profit_pct:+.2f}%)<br>'
                f'Hold: {trade.bars_held} bars<extra></extra>'
            ),
            xaxis='x1', yaxis='y1'
        ),
        row=1, col=1
    )
    
    # Draw connection line between entry and exit (faint, no hover)
    fig.add_trace(
        go.Scatter(
            x=[trade.entry_bar, trade.exit_bar],
            y=[trade.entry_price, trade.exit_price],
            mode='lines',
            line=dict(color='gray', width=1, dash='dot'),
            showlegend=False,
            hoverinfo='skip',
            xaxis='x1', yaxis='y1'
        ),
        row=1, col=1
    )

# ============================================================================
# PANEL 2: LEARNED INDICATORS
# ============================================================================
print("\n[Panel 2] Adding learned indicators...")

# Direction head outputs for the 1-min, 5-min and 15-min horizons.
# These are the flattened per-bar direction outputs as computed in the
# strategy cell; no smoothing or moving average is applied here.
fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=dir_1m,
        mode='lines',
        name='Direction 1-min',
        line=dict(color='cyan', width=1.5),
        opacity=0.7,
        hovertemplate='<b>Dir 1-min</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=dir_5m,
        mode='lines',
        name='Direction 5-min',
        line=dict(color='blue', width=1.5),
        opacity=0.7,
        hovertemplate='<b>Dir 5-min</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=dir_15m,
        mode='lines',
        name='Direction 15-min',
        line=dict(color='navy', width=1.5),
        opacity=0.7,
        hovertemplate='<b>Dir 15-min</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

# Confidence
fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=confidence,
        mode='lines',
        name='Confidence',
        line=dict(color='yellow', width=2),
        hovertemplate='<b>Confidence</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

# Weighted signal
# NOTE(review): `weighted_sig` is not assigned in this cell or in the
# strategy cell above it; unless another cell defines it, this call raises
# NameError. Confirm where it is supposed to come from.
fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=weighted_sig,
        mode='lines',
        name='Weighted Signal',
        line=dict(color='lime', width=2.5),
        hovertemplate='<b>Weighted Signal</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

# Signal strength (filled area)
fig.add_trace(
    go.Scatter(
        x=bar_idx_list,
        y=signal_str,
        mode='lines',
        name='Signal Strength',
        line=dict(color='orange', width=2),
        fill='tozeroy',
        fillcolor='rgba(255, 165, 0, 0.2)',
        hovertemplate='<b>Signal Strength</b><br>Bar: %{x}<br>Value: %{y:.4f}<extra></extra>',
        xaxis='x2', yaxis='y2'
    ),
    row=2, col=1
)

# Add subtle horizontal reference lines (low opacity) at 0.25 and 0.75.
# NOTE(review): these do not coincide with the thresholds the strategy loop
# uses on direction_prob (0.65/0.35 for entry, 0.45/0.55 for reversal exit).
fig.add_hline(y=0.25, line_dash="dash", line_color="rgba(128,128,128,0.3)", row=2, col=1)
fig.add_hline(y=0.75, line_dash="dash", line_color="rgba(128,128,128,0.3)", row=2, col=1)

# ============================================================================
# PANEL 3: PORTFOLIO VALUE EVOLUTION
# ============================================================================
print("\n[Panel 3] Adding portfolio evolution...")

# Everything in this section lands in the bottom subplot.
panel3_cell = dict(row=3, col=1)

# Main portfolio line: thick white curve of `portfolio_values` with a faint
# green fill down to zero.
portfolio_value_trace = go.Scatter(
    x=bar_idx_list,
    y=portfolio_values,
    mode='lines',
    name='Portfolio Value',
    line=dict(color='white', width=3),
    fill='tozeroy',
    fillcolor='rgba(100, 200, 100, 0.1)',
    hovertemplate='<b>Portfolio Value</b><br>Bar: %{x}<br>Value: $%{y:.2f}<extra></extra>',
    xaxis='x3', yaxis='y3',
)
fig.add_trace(portfolio_value_trace, **panel3_cell)

# P&L line: gold curve of `portfolio_pnl`, hover shows a signed dollar amount.
pnl_trace = go.Scatter(
    x=bar_idx_list,
    y=portfolio_pnl,
    mode='lines',
    name='P&L',
    line=dict(color='gold', width=2),
    hovertemplate='<b>P&L</b><br>Bar: %{x}<br>Value: $%{y:+.2f}<extra></extra>',
    xaxis='x3', yaxis='y3',
)
fig.add_trace(pnl_trace, **panel3_cell)

# Breakeven line (subtle) at y = 0
fig.add_hline(y=0, line_dash="dash", line_color="rgba(128,128,128,0.3)", **panel3_cell)

# Mark trade exit points on portfolio panel.
#
# All exits are drawn by ONE marker trace that carries per-point colour, symbol
# and hover text. Adding a separate trace per trade makes figure construction
# and browser rendering slow once there are many trades; plotly accepts arrays
# for `marker.color`, `marker.symbol` and `hovertemplate`, so the markers look
# and hover the same.
exit_marker_x = []
exit_marker_y = []
exit_marker_colors = []
exit_marker_symbols = []
exit_marker_hover = []

for trade in trades:
    # Cumulative P&L at the exit bar; fall back to the last known value when the
    # exit bar lies beyond the end of the P&L series.
    exit_pnl = portfolio_pnl[trade.exit_bar] if trade.exit_bar < len(portfolio_pnl) else portfolio_pnl[-1]
    is_winner = trade.is_win()
    color = 'lime' if is_winner else 'orange'

    exit_marker_x.append(trade.exit_bar)
    exit_marker_y.append(exit_pnl)
    exit_marker_colors.append(color)
    exit_marker_symbols.append('diamond' if is_winner else 'x')
    exit_marker_hover.append(
        f'<b>Trade Close ({trade.exit_reason})</b><br>'
        f'Bar: {trade.exit_bar}<br>'
        f'P&L: {trade.profit:+.2f} pts<br>'
        f'Cumulative: ${exit_pnl:+.2f}<extra></extra>'
    )

# With no trades there is nothing to draw, so no empty trace is added.
if exit_marker_x:
    fig.add_trace(
        go.Scatter(
            x=exit_marker_x,
            y=exit_marker_y,
            mode='markers',
            marker=dict(
                size=8,
                color=exit_marker_colors,
                symbol=exit_marker_symbols
            ),
            showlegend=False,
            hovertemplate=exit_marker_hover,
            xaxis='x3', yaxis='y3'
        ),
        row=3, col=1
    )

# ============================================================================
# FORMATTING AND LAYOUT
# ============================================================================
print("\nFormatting layout...")

# Update axes labels
fig.update_xaxes(title_text="Bar Index (Time)", row=3, col=1)
fig.update_yaxes(title_text="Price (USD)", row=1, col=1)
fig.update_yaxes(title_text="Indicator Value", row=2, col=1)
fig.update_yaxes(title_text="Portfolio Value (USD)", row=3, col=1)

# Set Y-axis ranges for better visibility
# Panel 1: pad the close-price span by 1% on each side.
fig.update_yaxes(
    range=[backtest_data['close'].min() * 0.99, backtest_data['close'].max() * 1.01],
    row=1, col=1
)

# Panel 2: fixed [0, 1] window.
fig.update_yaxes(
    range=[0, 1],
    row=2, col=1
)

# Panel 3: the portfolio-value line, the P&L line, the y=0 breakeven line and the
# trade-exit markers (plotted at P&L values) are all drawn on this one axis.
# The range therefore has to span both series and zero; sizing it from
# `portfolio_values` alone pushes the P&L curve, the breakeven line and the
# markers outside the visible window whenever the two series sit on different
# levels.
panel3_low = min(min(portfolio_values), min(portfolio_pnl), 0)
panel3_high = max(max(portfolio_values), max(portfolio_pnl), 0)
fig.update_yaxes(
    range=[panel3_low - 500, panel3_high + 500],
    row=3, col=1
)

# Update layout with synchronized hover
dashboard_title = (
    "<b>Comprehensive Trading Analysis Dashboard (Fixed)</b><br>"
    "<sub>Synchronized hover across all panels | Entry: Triangles | "
    "Exit: Circles | Hover for details</sub>"
)

# Legend sits just outside the plot area, anchored to the top-right corner.
legend_box = dict(
    x=1.01,
    y=1,
    xanchor='left',
    yanchor='top',
    bgcolor='rgba(0,0,0,0.7)',
    bordercolor='white',
    borderwidth=1,
    font=dict(size=10),
)

# Extra right margin leaves room for the legend placed outside the plot.
figure_margins = dict(l=80, r=150, t=120, b=80)

fig.update_layout(
    title_text=dashboard_title,
    height=1200,
    width=1600,
    template='plotly_dark',
    hovermode='x unified',  # KEY: synchronized hover across all subplots
    legend=legend_box,
    font=dict(size=11),
    margin=figure_margins,
    showlegend=True,
)

print("✓ Displaying interactive three-panel visualization...")
display(fig)

# Print summary
# The recap is assembled as a list of lines and written with a single print.
# NOTE(review): the check marks under "Trade Logic Validation" are static text;
# nothing in this section actually inspects the trades — confirm that the
# validation happens elsewhere before relying on them.
summary_rule = "=" * 80
close_prices = backtest_data['close']

summary_lines = [
    "\n" + summary_rule,
    "VISUALIZATION SUMMARY (FIXED)",
    summary_rule,

    "\n📊 Panel 1 (Top): Price Chart with Trade Entry/Exit",
    f"  • Price range: ${close_prices.min():.2f} - ${close_prices.max():.2f}",
    "  • Entry markers: Green triangles (LONG) / Red triangles (SHORT)",
    "  • Exit markers: Lime circles (WIN) / Orange X (LOSS)",
    "  • Connection: Faint dotted line from entry to exit",
    "  • Hover shows: Entry price, TP1/TP2, SL levels",

    "\n📈 Panel 2 (Middle): Learned Indicators",
    "  • Direction heads: 1-min (cyan), 5-min (blue), 15-min (navy)",
    "  • Confidence: Yellow line [0, 1]",
    "  • Weighted Signal: Lime line - primary entry signal",
    "  • Signal Strength: Orange filled area",
    "  • Entry thresholds: 0.25 (buy) and 0.75 (sell) - subtle gray dashed lines",

    "\n💰 Panel 3 (Bottom): Portfolio Evolution",
    "  • Portfolio Value: White line with green fill",
    "  • P&L: Gold line relative to starting $10,000",
    "  • Exit markers: Lime diamonds (wins), Orange X (losses)",
    f"  • Final Value: ${portfolio_values[-1]:,.2f}",
    f"  • Total P&L: ${portfolio_pnl[-1]:+,.2f}",

    "\n🔗 Synchronized Hover (NEW):",
    "  • Hover over ANY bar index to see synchronized data across all 3 panels",
    "  • Vertical hover line shows where you're looking across all panels",
    "  • Entry/Exit details show on hover at markers",
    "  • All X-axis values are synchronized",

    "\n🐛 Trade Logic Validation:",
    f"  • Total trades: {len(trades)}",
    "  • All entry bars < exit bars: ✓",
    "  • All buy trades have entry < exits (wins) or entry > exits (losses): ✓",
    "  • TP/SL levels correctly positioned relative to entries: ✓",

    "\n" + summary_rule,
    "✅ THREE-PANEL VISUALIZATION COMPLETE (FIXED & VALIDATED)",
    summary_rule,
]

print("\n".join(summary_lines))


# ============================================================================
# INDICATOR PARAMETER EVOLUTION ANALYSIS DURING TRAINING
# ============================================================================
Analysis of how learnable technical indicator parameters change over training epochs.

In [None]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Load indicator parameter history
indicator_df = pd.read_csv('indicator_params_history.csv')

# Extract parameter columns: drop the 'change_*' and 'log_*' columns and the
# 'epoch' / 'timestamp' bookkeeping columns; whatever remains is treated as a
# tracked parameter.
skipped_prefixes = ('change_', 'log_')
bookkeeping_cols = ('epoch', 'timestamp')
param_cols = [
    col for col in indicator_df.columns
    if not (col.startswith(skipped_prefixes) or col in bookkeeping_cols)
]

# Create subplots for different indicator groups: a 3x2 grid, one cell per
# indicator family, none with a secondary y-axis.
panel_titles = (
    'Moving Average Periods', 'MACD Parameters',
    'Stochastic Oscillator Pairs', 'RSI Periods',
    'Bollinger Band Periods', 'Momentum Periods',
)
fig = make_subplots(
    rows=3,
    cols=2,
    subplot_titles=panel_titles,
    specs=[[{'secondary_y': False} for _ in range(2)] for _ in range(3)],
)

# Colors for different parameters within each group
colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown']

# Moving Averages: one line per column whose name contains 'ma_period',
# drawn in the top-left cell and coloured by cycling through `colors`.
ma_cols = [col for col in param_cols if 'ma_period' in col]
for ma_index, ma_column in enumerate(ma_cols):
    ma_trace = go.Scatter(
        x=indicator_df['epoch'],
        y=indicator_df[ma_column],
        mode='lines+markers',
        name=f'MA {ma_index}',
        line=dict(color=colors[ma_index % len(colors)]),
        showlegend=True,
    )
    fig.add_trace(ma_trace, row=1, col=1)

# MACD Parameters
# Columns are grouped by MACD instance so that each group holds that
# instance's parameters (fast / slow / signal). The instance key is the second
# '_'-separated token of the column name.
# NOTE(review): assumes names shaped like 'macd_<index>_<param>' — confirm
# against the CSV writer.
#
# The key must NOT include the parameter token: keying on '<index>_<param>'
# puts every column in its own group, so `j` below is always 0, the dashed
# style never triggers, and legend names repeat the parameter
# (e.g. 'MACD 0_fast fast').
macd_cols = [col for col in param_cols if 'macd' in col]
macd_groups = {}
for col in macd_cols:
    parts = col.split('_')
    # Fall back to the whole name when there is no '_'-separated index token,
    # instead of raising IndexError.
    group = parts[1] if len(parts) > 1 else col  # e.g., '0'
    macd_groups.setdefault(group, []).append(col)

for i, (group, cols) in enumerate(macd_groups.items()):
    for j, col in enumerate(cols):
        param_type = col.split('_')[-1]  # fast, slow, signal
        fig.add_trace(
            go.Scatter(x=indicator_df['epoch'], y=indicator_df[col],
                      mode='lines+markers', name=f'MACD {group} {param_type}',
                      # Colour varies per (group, parameter); the second
                      # parameter of each group is dashed to tell it apart.
                      line=dict(color=colors[(i*3 + j) % len(colors)], dash='dash' if j==1 else 'solid'),
                      showlegend=True),
            row=1, col=2
        )

# Stochastic Pairs
# Columns are grouped by pair instance so that each group holds that pair's
# parameters (short / long / sig). The instance key is the second
# '_'-separated token of the column name.
# NOTE(review): assumes names shaped like 'pair_<index>_<param>' — confirm
# against the CSV writer.
#
# The key must NOT include the parameter token: keying on '<index>_<param>'
# puts every column in its own group, so `j` below is always 0, the dotted
# style never triggers, and legend names repeat the parameter
# (e.g. 'Stoch 0_short short').
pair_cols = [col for col in param_cols if 'pair' in col]
pair_groups = {}
for col in pair_cols:
    parts = col.split('_')
    # Fall back to the whole name when there is no '_'-separated index token,
    # instead of raising IndexError.
    group = parts[1] if len(parts) > 1 else col  # e.g., '0'
    pair_groups.setdefault(group, []).append(col)

for i, (group, cols) in enumerate(pair_groups.items()):
    for j, col in enumerate(cols):
        param_type = col.split('_')[-1]  # short, long, sig
        fig.add_trace(
            go.Scatter(x=indicator_df['epoch'], y=indicator_df[col],
                      mode='lines+markers', name=f'Stoch {group} {param_type}',
                      # Colour varies per (group, parameter); the third
                      # parameter of each group is dotted to tell it apart.
                      line=dict(color=colors[(i*3 + j) % len(colors)], dash='dot' if j==2 else 'solid'),
                      showlegend=True),
            row=2, col=1
        )

# RSI Periods, Bollinger Bands and Momentum all use the same plotting recipe:
# pick the columns whose name contains a marker substring, then draw one
# line per column in that family's subplot cell.
rsi_cols = [col for col in param_cols if 'rsi_period' in col]
bb_cols = [col for col in param_cols if 'bb_period' in col]
momentum_cols = [col for col in param_cols if 'momentum_period' in col]

# (columns, legend prefix, (row, col) of the target cell); traces are added
# family by family in this order.
period_families = (
    (rsi_cols, 'RSI', (2, 2)),
    (bb_cols, 'BB', (3, 1)),
    (momentum_cols, 'Momentum', (3, 2)),
)

for family_cols, legend_prefix, (grid_row, grid_col) in period_families:
    for position, column in enumerate(family_cols):
        family_trace = go.Scatter(
            x=indicator_df['epoch'],
            y=indicator_df[column],
            mode='lines+markers',
            name=f'{legend_prefix} {position}',
            line=dict(color=colors[position % len(colors)]),
            showlegend=True,
        )
        fig.add_trace(family_trace, row=grid_row, col=grid_col)

# Update layout: horizontal legend placed just above the plot area,
# right-aligned.
legend_above_plot = dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1,
)
fig.update_layout(
    height=1200,
    title_text="Evolution of Learnable Technical Indicator Parameters During Training",
    showlegend=True,
    legend=legend_above_plot,
)

# Label both axes of every cell in the 3x2 grid.
for grid_row in (1, 2, 3):
    for grid_col in (1, 2):
        fig.update_xaxes(title_text="Epoch", row=grid_row, col=grid_col)
        fig.update_yaxes(title_text="Period/Value", row=grid_row, col=grid_col)

fig.show()

# Summary statistics
print("📊 INDICATOR PARAMETER EVOLUTION SUMMARY")
print("=" * 50)
print(f"Training epochs: {len(indicator_df)}")
print(f"Parameters tracked: {len(param_cols)}")

# Calculate parameter stability (coefficient of variation = |std / mean|),
# ranked from most to least variable.
tracked_params = indicator_df[param_cols]
param_stats = tracked_params.describe()
cv = tracked_params.std().div(tracked_params.mean()).abs()
cv_sorted = cv.sort_values(ascending=False)

print("\n🔄 Most Variable Parameters (Coefficient of Variation):")
for param, var in cv_sorted.iloc[:10].items():
    history = indicator_df[param]
    initial_val = history.iloc[0]
    final_val = history.iloc[-1]
    # Relative change from the first to the last recorded row, in percent.
    change_pct = ((final_val - initial_val) / initial_val) * 100
    print(f"{param}: CV={var:.3f}, Change={change_pct:.1f}%")

print("\n📈 Parameter Ranges During Training:")
for param in param_cols[:10]:  # Show first 10
    history = indicator_df[param]
    min_val, max_val = history.min(), history.max()
    range_val = max_val - min_val
    print(f"{param}: Range={range_val:.3f} (Min={min_val:.3f}, Max={max_val:.3f})")

# Correlation with loss
# Ranks parameters by the absolute correlation between their history and the
# first 'log_val_*' column found in the frame.
loss_cols = [col for col in indicator_df.columns if col.startswith('log_val_')]
if loss_cols:
    loss_history = indicator_df[loss_cols[0]]
    correlations = {}
    for param in param_cols:
        corr_with_loss = indicator_df[param].corr(loss_history)
        # A parameter that never changes (or has too few valid rows) yields a
        # NaN correlation. NaN compares False against everything, so leaving it
        # in the dict makes the sort order below unreliable and can push
        # genuinely correlated parameters out of the top 10. Skip it.
        if pd.notna(corr_with_loss):
            correlations[param] = abs(corr_with_loss)

    corr_sorted = sorted(correlations.items(), key=lambda x: x[1], reverse=True)
    print("\n🔗 Parameters Most Correlated with Validation Loss:")
    for param, corr in corr_sorted[:10]:
        print(f"{param}: |Corr|={corr:.3f}")