# 1. Data Acquisition

In [None]:
# 0) Install
!pip install yfinance -q

# 1) Imports
import os
from datetime import datetime
import pandas as pd
import yfinance as yf

# 2) Ticker list (stocks + 3 indices)
# NOTE(review): this section was originally described as "50 stocks", but the
# list below contains 49 symbols (the third row has 9 entries instead of 10).
# Confirm which ticker is missing before relying on the count.
tickers = [
    # Stock symbols (described by the author as a sector-balanced sample)
    "AAPL","MSFT","GOOGL","AMZN","META","NVDA","TSLA","ORCL","INTC","AMD",
    "CSCO","CRM","JPM","BAC","WFC","C","GS","MS","BRK-B","V",
    "MA","JNJ","PFE","MRK","UNH","ABBV","LLY","AMGN","HD",
    "MCD","NKE","SBUX","COST","KO","PEP","WMT","CVS","XOM","CVX",
    "COP","SLB","BA","CAT","GE","HON","AMT","T","VZ","TMUS"
]

# 3) Index symbols downloaded alongside the stocks
indices = ["^GSPC", "^IXIC", "^DJI"]

# Everything is fetched in a single batch request
all_tickers = tickers + indices

# 4) Download parameters
# NOTE(review): yfinance documents `end` as exclusive, so 2024-12-31 itself is
# presumably not part of the result — confirm if the last trading day matters.
start = "2015-01-01"
end   = "2024-12-31"
interval = "1d"   # daily

# 5) Fetch data for all symbols in one call; the result is kept in `raw`
print(f"Downloading {len(all_tickers)} tickers from {start} to {end} ...")
raw = yf.download(
    tickers = all_tickers,
    start = start,
    end = end,
    interval = interval,
    group_by = "ticker",
    auto_adjust = True,   # adjust for splits/dividends -> easier modeling
    threads = True,
    progress = True
)

# 6) Prepare the output directory and the per-ticker container that the
#    following loop fills with one OHLCV DataFrame per symbol
os.makedirs("/content/data", exist_ok=True)
data_dict = {}

def extract_df(all_df, symbol):
    """Return a copy of one ticker's price columns from a download result.

    Parameters
    ----------
    all_df : pandas.DataFrame
        Either a frame with two-level (MultiIndex) columns holding several
        tickers, or a flat single-ticker frame.
    symbol : str
        Ticker to extract.

    Returns
    -------
    pandas.DataFrame
        Independent copy with flat columns. If no ``Close`` column exists but
        ``Adj Close`` does, the latter is renamed to ``Close``.

    Raises
    ------
    KeyError
        If ``all_df`` has MultiIndex columns and ``symbol`` is on no level.
    """
    columns = all_df.columns
    if isinstance(columns, pd.MultiIndex):
        if symbol in columns.get_level_values(0):
            # Layout (symbol, field), e.g. produced with group_by="ticker".
            df = all_df[symbol].copy()
        elif columns.nlevels > 1 and symbol in columns.get_level_values(1):
            # Layout (field, symbol): select on the second level instead.
            df = all_df.xs(symbol, axis=1, level=1).copy()
        else:
            raise KeyError(symbol)
    else:
        # Flat columns: the frame already describes a single ticker.
        df = all_df.copy()

    # Fall back to the adjusted close when a plain close is not available.
    if "Close" not in df.columns and "Adj Close" in df.columns:
        df = df.rename(columns={"Adj Close": "Close"})
    return df

# Split the batch download into one cleaned OHLCV frame per symbol and
# persist each of them as CSV.
for sym in all_tickers:
    try:
        df = extract_df(raw, sym)
    except KeyError:
        # Symbol missing from the batch result: retry it on its own.
        print(f"{sym}: not found in batch download, retrying individually ...")
        df = yf.download(sym, start=start, end=end, interval=interval, auto_adjust=True)
        if isinstance(df.columns, pd.MultiIndex):
            # A single-ticker download may still carry two column levels;
            # keep the level that holds the OHLCV field names.
            for level in range(df.columns.nlevels):
                if "Close" in df.columns.get_level_values(level):
                    df.columns = df.columns.get_level_values(level)
                    break
    # Basic cleaning
    df.index = pd.to_datetime(df.index)
    # Later steps address the date column as 'Date', so name the index
    # explicitly instead of relying on whatever the download returned.
    df.index.name = "Date"
    # Keep the OHLCV columns that are present, in canonical order
    cols = [c for c in ["Open","High","Low","Close","Volume"] if c in df.columns]
    df = df[cols].copy()
    data_dict[sym] = df
    # Save per-ticker csv ('^' is replaced so index symbols give plain file names)
    df.to_csv(f"/content/data/{sym.replace('^','_')}.csv")

print("Saved per-ticker CSVs to /content/data/")

# 7) Quick integrity report: row count, missing cells, zero-volume rows and
#    covered date range for every symbol
report_rows = []
for sym, df in data_dict.items():
    n = len(df)
    n_missing = df.isna().sum().sum()
    zero_vol = int((df["Volume"] == 0).sum()) if "Volume" in df.columns else 0
    first, last = (df.index.min(), df.index.max()) if n>0 else (None,None)
    report_rows.append((sym, n, n_missing, zero_vol, first, last))

report = pd.DataFrame(report_rows, columns=["ticker","n_rows","n_missing_cells","zero_volume_rows","first_date","last_date"])
display(report.sort_values("ticker").reset_index(drop=True))

# 8) Create a combined parquet (long format) for faster I/O later
rows = []
for sym, df in data_dict.items():
    if df.empty:
        continue
    # rename_axis guarantees the date column is called 'Date' (the name the
    # preprocessing step sorts by), whatever the index was called before.
    tmp = df.rename_axis("Date").reset_index()
    tmp["ticker"] = sym
    rows.append(tmp)

if not rows:
    # pd.concat([]) would fail with an unhelpful message; be explicit.
    raise RuntimeError("No price data was downloaded; nothing to combine.")

combined = pd.concat(rows, ignore_index=True)
combined.to_parquet("/content/data/combined_ohlcv.parquet", index=False)
print("Combined dataset saved to /content/data/combined_ohlcv.parquet")
print("Combined shape:", combined.shape)

# 9) Final checks: overall missing & sample
print("Overall missing cells:", combined.isna().sum().sum())
print("Sample rows:")
display(combined.head())

# 2. Data Preprocessing


In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import StandardScaler

# Load combined dataset
combined = pd.read_parquet("/content/data/combined_ohlcv.parquet")

# 1) Sort by ticker and time FIRST: the rolling fill and the return
#    calculation below are only meaningful on chronologically ordered rows,
#    so they must not depend on the order the file happened to be stored in.
combined = combined.sort_values(['ticker','Date'])

# 2) Drop exact duplicate rows before they can distort the rolling mean
combined = combined.drop_duplicates()

# 3) Fill missing values with the 5-day rolling mean of the same ticker
#    (window = current row and up to 4 previous rows; NaNs are skipped)
price_cols = ['Open','High','Low','Close','Volume']
combined[price_cols] = combined.groupby('ticker')[price_cols].transform(
    lambda x: x.fillna(x.rolling(5, min_periods=1).mean())
)

# 4) Drop zero-volume rows (copy so the column assignments below act on an
#    independent frame rather than a filtered view)
combined = combined[combined['Volume'] != 0].copy()

# 4.5) Add Volume_log
combined['Volume_log'] = np.log1p(combined['Volume'])  # log1p = log(1 + Volume)

# ============================================================================
# 5) Add 3 new parameters (calculated per ticker group, using only past data)
# ============================================================================

print("Starting calculation of 3 new parameters...")

# 5.1 Daily returns per ticker (input for the volatility and momentum features)
combined['Returns'] = combined.groupby('ticker')['Close'].pct_change()

# 4.2 Parameter 1: volatility_ratio_5_20 (5-day / 20-day volatility ratio)
def calculate_volatility_ratio(group):
    """Return the short-term / long-term volatility ratio for one ticker.

    ``group`` must provide a ``Returns`` column. Both volatilities are
    trailing rolling standard deviations (the window ends at the current
    row), so no future rows are used. Rows where the 20-day volatility is
    zero or not yet available yield NaN.
    """
    daily_returns = group['Returns']

    def _trailing_std(window, min_obs):
        # Standard deviation over the current row and the preceding ones.
        return daily_returns.rolling(window=window, min_periods=min_obs).std()

    short_vol = _trailing_std(5, 3)
    long_vol = _trailing_std(20, 10)

    # A zero denominator is turned into NaN instead of producing +/-inf.
    safe_long_vol = long_vol.replace(0, np.nan)
    return short_vol / safe_long_vol

# Compute the ratio ticker by ticker, then drop the group level that apply()
# adds so the result lines up with the original row index again.
combined['volatility_ratio_5_20'] = (
    combined.groupby('ticker')
    .apply(calculate_volatility_ratio)
    .reset_index(level=0, drop=True)
)

print("  ‚úÖ Completed volatility_ratio_5_20")

# 4.3 Parameter 2: momentum_5_sign_strength (5-day momentum strength adjusted by volatility)
def calculate_momentum_strength(group):
    """Return the volatility-adjusted 5-day momentum for one ticker.

    ``group`` must provide ``Close`` and ``Returns`` columns. The value is
    the 5-day price return divided by the trailing 5-day standard deviation
    of daily returns, clipped to the range [-5, 5]. Only the current and
    earlier rows are used. Rows with zero or unavailable volatility are NaN.
    """
    prices = group['Close']
    daily_returns = group['Returns']

    # Return between the close five rows back and the current close.
    price_5_back = prices.shift(5)
    five_day_return = (prices - price_5_back) / price_5_back

    # Trailing volatility; zero is mapped to NaN to avoid dividing by zero.
    five_day_vol = daily_returns.rolling(window=5, min_periods=3).std()
    safe_vol = five_day_vol.replace(0, np.nan)

    with np.errstate(divide='ignore', invalid='ignore'):
        raw_strength = five_day_return / safe_vol

    # Limit the influence of extreme ratios.
    return raw_strength.clip(lower=-5, upper=5)

# Per-ticker momentum strength; the group level added by apply() is removed
# so the values align with the original row index.
combined['momentum_5_sign_strength'] = (
    combined.groupby('ticker')
    .apply(calculate_momentum_strength)
    .reset_index(level=0, drop=True)
)

print("  ‚úÖ Completed momentum_5_sign_strength")

# 4.4 Parameter 3: volume_price_divergence (simple divergence indicator)
def calculate_volume_divergence(group):
    """Return a simple volume-price divergence indicator for one ticker.

    ``group`` must provide ``Close`` and ``Volume_log`` columns. The
    indicator is (5-day price return) x (relative deviation of today's log
    volume from the mean log volume of the preceding days), clipped to
    [-1, 1]. Same-sign price and volume moves give a positive value,
    opposite-sign moves a negative one. Only the current and earlier rows
    are used.
    """
    prices = group['Close']
    log_volume = group['Volume_log']

    # Price change relative to the close five rows back.
    price_5_back = prices.shift(5)
    price_momentum = (prices - price_5_back) / price_5_back

    # Baseline = mean log volume of the up-to-5 rows BEFORE today
    # (shift(1) excludes today); today's volume is compared against it.
    baseline = log_volume.shift(1).rolling(window=5, min_periods=3).mean()
    volume_momentum = (log_volume - baseline) / baseline.replace(0, np.nan)

    # Product of the two momenta, with extreme values clipped.
    return (price_momentum * volume_momentum).clip(lower=-1, upper=1)

# Per-ticker divergence indicator; the group level added by apply() is
# removed so the values align with the original row index.
combined['volume_price_divergence'] = (
    combined.groupby('ticker')
    .apply(calculate_volume_divergence)
    .reset_index(level=0, drop=True)
)

print("  ‚úÖ Completed volume_price_divergence")
print("‚úÖ All 3 new parameters calculated")

# ============================================================================
# 5) Create two Y labels: binary (is there significant volatility?) and
#    ternary (in which direction?)
# ============================================================================

print("\nCreating two Y labels...")

def create_volatility_targets(combined, threshold=0.025):
    """Create the binary and ternary volatility labels for every row.

    For each row the closes of the next two rows of the same ticker are
    compared with the row's own close.

    Parameters
    ----------
    combined : pandas.DataFrame
        Needs the columns ``ticker``, ``Date`` and ``Close``.
    threshold : float
        Relative move (e.g. 0.025 = 2.5%) that counts as significant.

    Returns
    -------
    (y_binary, y_3class) : tuple of pandas.Series
        Nullable-integer (``Int64``) series indexed like ``combined``:

        * ``y_binary``: 0 = calm, 1 = a move beyond ``threshold`` (up or
          down) on at least one of the next two days.
        * ``y_3class``: 0 = calm, 1 = big rise, 2 = big drop; when both
          occur, the side with the larger magnitude wins (drop on a tie).

        Rows whose two-day look-ahead is not fully available (the last two
        rows of each ticker) are ``<NA>`` rather than 0: "no future data"
        must not be recorded as "calm". A later ``dropna`` on the label
        columns then removes them.
    """
    combined = combined.sort_values(['ticker', 'Date']).copy()

    # Closes of the next one and two rows, never crossing a ticker boundary.
    close = combined['Close']
    close_by_ticker = combined.groupby('ticker')['Close']
    return_tomorrow = (close_by_ticker.shift(-1) - close) / close
    return_day_after = (close_by_ticker.shift(-2) - close) / close

    # Largest rise and largest drop within the next two days.
    future_returns = pd.concat([return_tomorrow, return_day_after], axis=1)
    max_up = future_returns.max(axis=1)
    max_down = future_returns.min(axis=1)

    # A label is only defined when both look-ahead returns exist.
    has_full_horizon = future_returns.notna().all(axis=1)

    big_rise = max_up > threshold
    big_drop = max_down < -threshold

    # ---- Binary label: any significant move ----
    y_binary = (big_rise | big_drop).astype('Int64')
    y_binary.loc[~has_full_horizon] = pd.NA
    y_binary.name = 'y_binary'

    # ---- Ternary label: direction of the move ----
    # Conditions are evaluated in order; the last one covers "both, and the
    # drop is at least as large as the rise".
    rise_dominates = max_up.abs() > max_down.abs()
    class_codes = np.select(
        [
            big_rise & ~big_drop,
            big_drop & ~big_rise,
            big_rise & big_drop & rise_dominates,
            big_rise & big_drop,
        ],
        [1, 2, 1, 2],
        default=0,
    )
    y_3class = pd.Series(class_codes, index=combined.index, dtype='Int64')
    y_3class.loc[~has_full_horizon] = pd.NA
    y_3class.name = 'y_3class'

    # Distribution summary over the rows that actually have a label.
    total = len(y_binary.dropna())
    if total > 0:
        print(f"Binary classification distribution:")
        print(f"  Class 0 (calm): {(y_binary == 0).sum()/total:.2%}")
        print(f"  Class 1 (significant volatility): {(y_binary == 1).sum()/total:.2%}")

        print(f"\nTernary classification distribution:")
        class_counts = y_3class.value_counts()
        print(f"  Class 0 (calm): {class_counts.get(0, 0)/total:.2%}")
        print(f"  Class 1 (big rise): {class_counts.get(1, 0)/total:.2%}")
        print(f"  Class 2 (big drop): {class_counts.get(2, 0)/total:.2%}")
        print(f"  Total samples: {total}")

    return y_binary, y_3class

# Attach both target columns (moves beyond 2.5% count as significant)
combined['y_binary'], combined['y_3class'] = create_volatility_targets(combined, threshold=0.025)

# ============================================================================
# 6) Clean data and define features
# ============================================================================

print("\nCleaning data...")

# Model inputs: 5 basic columns plus the 3 engineered ones (8 in total)
base_features = ['Open', 'High', 'Low', 'Close', 'Volume_log']
new_features = ['volatility_ratio_5_20', 'momentum_5_sign_strength', 'volume_price_divergence']
features = [*base_features, *new_features]

# Rows with a missing target or a missing engineered feature are removed
required_columns = ['y_binary', 'y_3class', *new_features]
combined = combined.dropna(subset=required_columns)
combined = combined.reset_index(drop=True)

print(f"Data size after cleaning: {len(combined)} rows")

print(f"\nFeature list ({len(features)} features):")
for i, feat in enumerate(features, start=1):
    print(f"  {i:2d}. {feat}")

# ============================================================================
# 7) Construct 30-day sliding windows (using all features)
# ============================================================================
#
# Each sample is `window_size` consecutive rows of one ticker. Its label is
# the label of the LAST row inside the window: that label describes the two
# days following the window, which is exactly what is unknown at prediction
# time. (Taking the label one row later would skip a day and predict moves
# relative to a close the model has not seen.)
#
# Windows are finally ordered by the date of their last row, so that the
# positional train/validation/test splits applied to X afterwards are
# genuinely chronological instead of being a split by ticker.

window_size = 30
X_list, y_binary_list, y_3class_list = [], [], []
window_end_dates = []

print(f"\nConstructing sliding windows ({window_size} days)...")

for ticker, group in combined.groupby('ticker'):
    group = group.sort_values('Date')

    # Extract feature data
    data = group[features].values
    y_binary = group['y_binary'].to_numpy(dtype=int)
    y_3class = group['y_3class'].to_numpy(dtype=int)
    dates = group['Date'].values

    # Number of full windows that fit into this ticker's history
    n_windows = len(group) - window_size + 1
    if n_windows < 1:
        continue

    for i in range(n_windows):
        last = i + window_size - 1  # index of the final row in the window
        X_list.append(data[i:last + 1])
        y_binary_list.append(y_binary[last])
        y_3class_list.append(y_3class[last])
        window_end_dates.append(dates[last])

    if len(X_list) % 5000 == 0:
        print(f"  Created {len(X_list)} windows...")

if not X_list:
    # Fail with a clear message instead of a division by zero further down.
    raise ValueError(
        f"No ticker has at least {window_size} rows; cannot build any window."
    )

# Stable sort keeps the per-ticker order among windows ending on the same day.
order = np.argsort(np.array(window_end_dates), kind='stable')
X = np.array(X_list)[order]
y_binary = np.array(y_binary_list)[order]
y_3class = np.array(y_3class_list)[order]

print("\nDataset shapes:")
print(f"X: {X.shape}  (samples, {window_size} days, {len(features)} features)")
print(f"y_binary: {y_binary.shape}  (binary classification: 0=calm, 1=significant volatility)")
print(f"y_3class: {y_3class.shape}  (ternary classification: 0=calm, 1=big rise, 2=big drop)")

# Binary classification distribution statistics
binary_pos = np.sum(y_binary == 1)
binary_total = len(y_binary)
print(f"\nBinary classification distribution:")
print(f"  Class 0 (calm): {binary_total - binary_pos} samples ({(binary_total - binary_pos)/binary_total:.2%})")
print(f"  Class 1 (significant volatility): {binary_pos} samples ({binary_pos/binary_total:.2%})")

# Ternary classification distribution statistics
class_labels = {0: 'Calm', 1: 'Big rise', 2: 'Big drop'}
unique, counts = np.unique(y_3class, return_counts=True)
print(f"\nTernary classification distribution:")
for cls, cnt in zip(unique, counts):
    label = class_labels[int(cls)]
    print(f"  Class {cls} ({label}): {cnt} samples ({cnt/len(y_3class):.2%})")
# ============================================================================
# 8) Split dataset by position (first 80% training, last 20% testing)
# ============================================================================
# NOTE(review): this is a purely positional cut of X. It is a chronological
# split only if the samples in X are ordered by time; verify the order in
# which the windows were appended before trusting the test metrics.

split_idx = int(len(X) * 0.8)

X_train_full, X_test_final = X[:split_idx], X[split_idx:]
y_binary_train_full, y_binary_test_final = y_binary[:split_idx], y_binary[split_idx:]
y_3class_train_full, y_3class_test_final = y_3class[:split_idx], y_3class[split_idx:]

print(f"\nSplit results:")
print(f"Training set: {len(X_train_full)} samples ({len(X_train_full)/len(X)*100:.1f}%)")
print(f"Test set: {len(X_test_final)} samples ({len(X_test_final)/len(X)*100:.1f}%)")

print(f"\nTraining set distribution - Binary: calm={np.sum(y_binary_train_full==0)}, significant volatility={np.sum(y_binary_train_full==1)}")
print(f"Training set distribution - Ternary: calm={np.sum(y_3class_train_full==0)}, big rise={np.sum(y_3class_train_full==1)}, big drop={np.sum(y_3class_train_full==2)}")

# ============================================================================
# 9) Fit normalizer only on training set
# ============================================================================

print("\nNormalizing...")

n_train_samples = len(X_train_full)
n_test_samples = len(X_test_final)
n_features = len(features)

# The scaler works on 2D input, so all time steps of all windows are stacked
# into rows: (samples * window_size, n_features)
X_train_2d = X_train_full.reshape(-1, n_features)
X_test_2d = X_test_final.reshape(-1, n_features)

# Statistics come from the training rows only; the test rows are merely
# transformed with them
scaler = StandardScaler()
scaler.fit(X_train_2d)
X_train_normalized_2d = scaler.transform(X_train_2d)
X_test_normalized_2d = scaler.transform(X_test_2d)

# Back to (samples, window_size, n_features)
X_train = X_train_normalized_2d.reshape(n_train_samples, window_size, n_features)
X_test = X_test_normalized_2d.reshape(n_test_samples, window_size, n_features)

print("‚úÖ Normalization completed")
print(f"Training set after normalization - Mean: {np.mean(X_train):.4f}, Std: {np.std(X_train):.4f}")
print(f"Test set after normalization - Mean: {np.mean(X_test):.4f}, Std: {np.std(X_test):.4f}")

# ============================================================================
# 10) Save normalizer and feature list
# ============================================================================

import joblib

# Persist the fitted scaler and the feature order next to the notebook
joblib.dump(scaler, 'scaler_volatility.pkl')
joblib.dump(features, 'features_volatility.pkl')

print("\n‚úÖ Data preparation completed!")
print("‚úÖ Normalizer saved as: scaler_volatility.pkl")
print("‚úÖ Feature list saved as: features_volatility.pkl")
print(f"‚úÖ You can now train a model using {len(features)} features")

# ============================================================================
# Usage instructions
# ============================================================================
# Prints a short guide to the arrays produced above. The ternary labels are
# integer class ids (0/1/2), not one-hot vectors, so the matching Keras loss
# is 'sparse_categorical_crossentropy'.

# Human-readable description per feature (built once, outside the loop)
feat_desc = {
    'Open': 'Opening price',
    'High': 'Highest price',
    'Low': 'Lowest price',
    'Close': 'Closing price',
    'Volume_log': 'Logarithm of trading volume',
    'volatility_ratio_5_20': '5-day / 20-day volatility ratio',
    'momentum_5_sign_strength': '5-day momentum strength adjusted by volatility',
    'volume_price_divergence': 'Volume-price divergence indicator'
}

print("\n" + "="*60)
print("Data Preparation Complete - Usage Instructions")
print("="*60)
print("Available datasets:")
print("\n1. Binary classification model (detect significant volatility):")
print("   Features: X_train, X_test")
print("   Labels: y_binary_train_full, y_binary_test_final")
print("   Description: 0=calm(no 2.5%+ volatility in next 2 days), 1=significant volatility")
print("   Suitable for: simple volatility detection, use 'binary_crossentropy' loss function")
print("   Output layer: Dense(1, activation='sigmoid')")
print("")
print("2. Ternary classification model (predict volatility direction):")
print("   Features: X_train, X_test")
print("   Labels: y_3class_train_full, y_3class_test_final")
print("   Description: 0=calm, 1=big rise(>2.5%), 2=big drop(<-2.5%)")
print("   Suitable for: directional trading strategies, use 'sparse_categorical_crossentropy' loss function (labels are integer class ids)")
print("   Output layer: Dense(3, activation='softmax')")
print("")
print("3. Feature descriptions:")
for i, feat in enumerate(features, 1):
    print(f"   {i:2d}. {feat:25s} - {feat_desc.get(feat, '')}")

# 3. LSTM Scenario 1


class 0: No significant fluctuations expected in the next 2 days

class 1: Significant fluctuations expected in the next 2 days

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, confusion_matrix, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Practical Strategy: Find High Precision Threshold on Validation Set
# Goal: Validation set precision ≥ 70%, Test set precision ≥ 60%
# ============================================================================

print(
    "="*60,
    "Practical High Precision Threshold Strategy",
    "Strategy: Find threshold achieving precision ‚â• 70% on validation set",
    "Expectation: Achieve precision ‚â• 60% on test set",
    "="*60,
    sep="\n",
)

# ============================================================================
# 1. Data Preparation
# ============================================================================

# Work on copies of the binary labels so the arrays from the preprocessing
# step stay untouched
y_train_full = y_binary_train_full.copy()
y_test_final = y_binary_test_final.copy()

print("Data Distribution:")
print(f"  Test Set: Calm={np.sum(y_test_final==0)} ({np.mean(y_test_final==0):.2%}), "
      f"High Volatility={np.sum(y_test_final==1)} ({np.mean(y_test_final==1):.2%})")

# Positional cut: the first 80% of the training samples are used for
# fitting, the remaining 20% as validation set
train_ratio = 0.8
train_idx = int(len(X_train) * train_ratio)

X_train_final, X_val = X_train[:train_idx], X_train[train_idx:]
y_train_final, y_val = y_train_full[:train_idx], y_train_full[train_idx:]

print(
    "\nData Split:",
    f"  Training Set: {len(X_train_final)} samples",
    f"  Validation Set: {len(X_val)} samples (for threshold selection)",
    f"  Test Set: {len(X_test)} samples (for final evaluation)",
    sep="\n",
)

# ============================================================================
# 2. Train Model (or Load Existing Model)
# ============================================================================
# A previously saved model is reused when it can be loaded; otherwise a small
# LSTM classifier is built and trained on the training split, with the
# validation split driving early stopping.

print("\n" + "="*60)
print("Model Training")
print("="*60)

# Defined up front so the name exists even when a saved model is loaded and
# no training run takes place.
history = None

try:
    # Try to load an existing model
    print("Attempting to load existing model...")
    model = tf.keras.models.load_model('precision_optimized_trading_model.keras')
    print("‚úÖ Model loaded successfully")
except Exception as load_error:
    # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate;
    # the reason is printed instead of being silently discarded.
    print("‚ö†Ô∏è  Unable to load model, training a new model...")
    print(f"   Reason: {type(load_error).__name__}: {load_error}")

    def create_simple_model():
        """Build and compile the LSTM binary classifier.

        Input: windows of shape (window_size, len(features)).
        Output: a single sigmoid unit (probability of class 1).
        """
        model = Sequential([
            LSTM(64, input_shape=(window_size, len(features))),
            Dropout(0.3),
            BatchNormalization(),
            Dense(32, activation='relu'),
            Dropout(0.2),
            Dense(16, activation='relu'),
            Dense(1, activation='sigmoid')
        ])

        model.compile(
            optimizer=Adam(learning_rate=0.0005),
            loss='binary_crossentropy',
            metrics=['accuracy']
        )
        return model

    model = create_simple_model()

    # Train the model; training stops once the validation loss has not
    # improved for 5 epochs, and the best weights are restored
    print("Training model...")
    history = model.fit(
        X_train_final, y_train_final,
        validation_data=(X_val, y_val),
        epochs=20,
        batch_size=64,
        verbose=1,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True,
                verbose=1
            )
        ]
    )

# ============================================================================
# 3. Find Threshold Achieving ≥70% Precision on Validation Set (Key Step)
# ============================================================================
# Decision thresholds between 0.5 and 0.95 are scanned on the validation set.
# Preference order for the final choice:
#   1. precision >= 70%, taking the threshold with the most signals;
#      if that leaves fewer than 100 signals, relax to precision >= 65%
#   2. precision >= 65%, most signals
#   3. the threshold with the highest precision
#   4. 0.5, when no scanned threshold produces a single positive prediction

print("\n" + "="*60)
print("Find Threshold Achieving ‚â•70% Precision on Validation Set")
print("="*60)


def _announce_choice(title, row):
    """Print the chosen threshold row and return (threshold, precision, signals)."""
    threshold = row['threshold']
    precision = row['precision']
    signals = int(row['signals'])
    print(title)
    print(f"  Threshold: {threshold:.3f}")
    print(f"  Validation Set Precision: {precision:.2%}")
    print(f"  Validation Set Signals: {signals} ({row['signal_ratio']:.2%})")
    return threshold, precision, signals


# Get prediction probabilities on validation set
y_val_proba = model.predict(X_val, verbose=0, batch_size=128).flatten()

thresholds = np.arange(0.5, 0.95, 0.005)  # Start search from 0.5
val_results = []

print("Searching thresholds...")
for thresh in thresholds:
    y_val_pred = (y_val_proba > thresh).astype(int)

    # Precision is undefined without positive predictions: skip those thresholds
    if np.sum(y_val_pred == 1) > 0:
        precision = np.mean(y_val[y_val_pred == 1] == 1)
        recall = np.mean(y_val_pred[y_val == 1] == 1) if np.sum(y_val == 1) > 0 else 0
        signals = np.sum(y_val_pred == 1)

        val_results.append({
            'threshold': thresh,
            'precision': precision,
            'recall': recall,
            'signals': signals,
            'signal_ratio': signals / len(y_val)
        })

# Explicit columns keep the frame usable even when no threshold qualified
# (an empty list would otherwise give a frame without a 'precision' column).
val_results_df = pd.DataFrame(
    val_results,
    columns=['threshold', 'precision', 'recall', 'signals', 'signal_ratio']
)

# Candidate sets, each ordered by signal count (most signals first)
high_precision_thresholds = val_results_df[val_results_df['precision'] >= 0.7].copy()
high_precision_thresholds = high_precision_thresholds.sort_values('signals', ascending=False)
medium_precision_thresholds = val_results_df[val_results_df['precision'] >= 0.65].copy()
medium_precision_thresholds = medium_precision_thresholds.sort_values('signals', ascending=False)

if len(high_precision_thresholds) > 0:
    print(f"\n‚úÖ Found {len(high_precision_thresholds)} thresholds achieving ‚â•70% precision")

    print(f"\nThresholds achieving ‚â•70% precision on validation set (sorted by signal count):")
    print(f"{'Threshold':<8} {'Precision':<10} {'Signals':<10} {'Signal Ratio':<10}")
    print("-" * 45)

    for i, row in high_precision_thresholds.head(10).iterrows():
        print(f"{row['threshold']:.3f}     {row['precision']:.2%}     "
              f"{int(row['signals']):<10} {row['signal_ratio']:.2%}")

    # Threshold selection strategy
    print(f"\nüîç Threshold Selection Strategy:")

    # Strategy 1: Select threshold with the most signals (most practical)
    selected_threshold, val_precision, val_signals = _announce_choice(
        "Strategy 1 - Maximize Signal Count:", high_precision_thresholds.iloc[0]
    )

    # If too few signals, accept slightly lower precision for more signals
    if val_signals < 100:
        print(f"\n‚ö†Ô∏è  Low signal count ({val_signals}), relaxing conditions...")

        if len(medium_precision_thresholds) > 0:
            selected_threshold, val_precision, val_signals = _announce_choice(
                "Strategy 2 - Precision ‚â• 65%:", medium_precision_thresholds.iloc[0]
            )

elif val_results_df.empty:
    # The model never predicted the positive class above 0.5 on the
    # validation set, so there is nothing to rank. Fall back to the neutral
    # threshold so the evaluation below can still run.
    print("\n‚ö†Ô∏è  No threshold in the scanned range produced any positive prediction")
    selected_threshold, val_precision, val_signals = 0.5, 0.0, 0
    print("Using default threshold:")
    print(f"  Threshold: {selected_threshold:.3f}")

else:
    print(f"\n‚ö†Ô∏è  No threshold found achieving ‚â•70% precision")
    print(f"  Maximum precision on validation set: {val_results_df['precision'].max():.2%}")

    if len(medium_precision_thresholds) > 0:
        selected_threshold, val_precision, val_signals = _announce_choice(
            "Using threshold achieving ‚â•65% precision:", medium_precision_thresholds.iloc[0]
        )
    else:
        # Use threshold with highest precision
        best_row = val_results_df.loc[val_results_df['precision'].idxmax()]
        selected_threshold, val_precision, val_signals = _announce_choice(
            "Using threshold with highest precision:", best_row
        )

# ============================================================================
# 4. Evaluate Selected Threshold on Test Set
# ============================================================================
# The threshold chosen on the validation set is applied unchanged to the
# test-set probabilities; no tuning happens on the test data.

print(
    "\n" + "="*60,
    "Evaluate Selected Threshold on Test Set",
    f"Threshold: {selected_threshold:.3f} (from validation set)",
    "="*60,
    sep="\n",
)

# Probabilities for the positive class, flattened to one value per sample
y_test_proba = model.predict(X_test, verbose=0, batch_size=128).flatten()

# Hard predictions: 1 where the probability exceeds the selected threshold
y_test_pred = (y_test_proba > selected_threshold).astype(int)

# Threshold-dependent metrics use the hard predictions, AUC the probabilities
test_precision = precision_score(y_test_final, y_test_pred, zero_division=0)
test_recall = recall_score(y_test_final, y_test_pred, zero_division=0)
test_f1 = f1_score(y_test_final, y_test_pred, zero_division=0)
test_accuracy = accuracy_score(y_test_final, y_test_pred)
test_auc = roc_auc_score(y_test_final, y_test_proba)
test_signals = np.sum(y_test_pred == 1)

print(
    "Test Set Performance:",
    f"  Threshold: {selected_threshold:.3f}",
    f"  Precision: {test_precision:.4f} (Target: ‚â•60%)",
    f"  Recall: {test_recall:.4f}",
    f"  F1 Score: {test_f1:.4f}",
    f"  Accuracy: {test_accuracy:.4f}",
    f"  AUC: {test_auc:.4f}",
    f"  Signal Count: {test_signals} ({test_signals/len(y_test_final):.2%})",
    sep="\n",
)

# Report whether the 60% precision target was met
print(
    f"\n‚úÖ Success! Test set precision ‚â• 60% ({test_precision:.2%})"
    if test_precision >= 0.6
    else f"\n‚ö†Ô∏è  Target not reached! Test set precision: {test_precision:.2%} (< 60%)"
)

print("\nClassification Report:")
print(classification_report(y_test_final, y_test_pred,
                           target_names=['Calm', 'High Volatility']))

# Confusion matrix: rows are the actual classes, columns the predicted ones
cm_test = confusion_matrix(y_test_final, y_test_pred)
print("Confusion Matrix:")
print("               Predicted Calm    Predicted High Volatility")
print(f"Actual Calm      {cm_test[0,0]:8d}               {cm_test[0,1]:8d}")
print(f"Actual High Volatility    {cm_test[1,0]:8d}               {cm_test[1,1]:8d}")

# ============================================================================
# 5. Visualization Analysis
# ============================================================================
# Six panels: validation precision and signal count against the threshold,
# test probability histogram, ROC curve, precision-recall curve, and a bar
# chart comparing validation and test results.

print("\n" + "="*60)
print("Visualization Analysis")
print("="*60)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 1. Validation Set Precision vs Threshold Curve
axes[0, 0].plot(val_results_df['threshold'], val_results_df['precision'], 'b-', linewidth=2)
axes[0, 0].axhline(y=0.7, color='red', linestyle='--', alpha=0.5, label='Target Line (70%)')
axes[0, 0].axvline(x=selected_threshold, color='green', linestyle='--',
                  label=f'Selected Threshold ({selected_threshold:.3f})')
axes[0, 0].set_title('Validation Set: Precision vs Threshold')
axes[0, 0].set_xlabel('Threshold')
axes[0, 0].set_ylabel('Precision')
axes[0, 0].legend()
axes[0, 0].grid(True)

# 2. Validation Set Signal Count vs Threshold Curve
axes[0, 1].plot(val_results_df['threshold'], val_results_df['signals'], 'g-', linewidth=2)
axes[0, 1].axvline(x=selected_threshold, color='green', linestyle='--')
axes[0, 1].set_title('Validation Set: Signal Count vs Threshold')
axes[0, 1].set_xlabel('Threshold')
axes[0, 1].set_ylabel('Signal Count')
axes[0, 1].grid(True)

# 3. Test Set Prediction Probability Distribution, split by actual class
axes[0, 2].hist(y_test_proba[y_test_final == 0], alpha=0.5, label='Actual Calm', bins=50, color='blue')
axes[0, 2].hist(y_test_proba[y_test_final == 1], alpha=0.5, label='Actual High Volatility', bins=50, color='red')
axes[0, 2].axvline(x=selected_threshold, color='green', linestyle='--', linewidth=2,
                  label=f'Threshold={selected_threshold:.3f}')
axes[0, 2].set_title('Test Set: Prediction Probability Distribution')
axes[0, 2].set_xlabel('Probability of High Volatility')
axes[0, 2].set_ylabel('Sample Count')
axes[0, 2].legend()
axes[0, 2].grid(True)

# 4. ROC Curve (the dashed diagonal is the chance level)
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(y_test_final, y_test_proba)
axes[1, 0].plot(fpr, tpr, 'b-', label=f'AUC={test_auc:.3f}')
axes[1, 0].plot([0, 1], [0, 1], 'k--')
axes[1, 0].set_title('ROC Curve')
axes[1, 0].set_xlabel('False Positive Rate')
axes[1, 0].set_ylabel('True Positive Rate')
axes[1, 0].legend()
axes[1, 0].grid(True)

# 5. Precision-Recall Curve, with the operating point of the selected threshold
from sklearn.metrics import precision_recall_curve
precision_curve, recall_curve, _ = precision_recall_curve(y_test_final, y_test_proba)
axes[1, 1].plot(recall_curve, precision_curve, 'purple', linewidth=2)
axes[1, 1].axhline(y=0.6, color='red', linestyle='--', alpha=0.5, label='Target Line (60%)')
axes[1, 1].scatter([test_recall], [test_precision], color='green', s=100, marker='o',
                  label=f'Current Point ({test_recall:.2f}, {test_precision:.2f})')
axes[1, 1].set_title('Precision-Recall Curve')
axes[1, 1].set_xlabel('Recall')
axes[1, 1].set_ylabel('Precision')
axes[1, 1].legend()
axes[1, 1].grid(True)

# 6. Performance Comparison (signal counts are divided by 100 so they share
#    an axis with the precision values)
categories = ['Validation Precision', 'Test Precision', 'Validation Signals/100', 'Test Signals/100']
values = [val_precision, test_precision, val_signals/100, test_signals/100]
colors = ['blue', 'red', 'green', 'orange']

bars = axes[1, 2].bar(categories, values, color=colors)
axes[1, 2].set_title('Validation Set vs Test Set Performance Comparison')
axes[1, 2].set_ylabel('Value (Precision as %, Signals/100)')
axes[1, 2].axhline(y=0.7, color='red', linestyle='--', alpha=0.5, label='70% Precision Target')
axes[1, 2].axhline(y=0.6, color='orange', linestyle='--', alpha=0.5, label='60% Precision Target')
axes[1, 2].legend()

# Add value labels on bars. Labels are built from the underlying quantities
# rather than from the scaled bar heights: precision bars show the fraction,
# signal bars the exact count. (Deriving the count via int(height*100) can be
# off by one through float truncation and mislabels counts below 100.)
bar_labels = [
    f'{val_precision:.2f}',
    f'{test_precision:.2f}',
    f'{int(val_signals)}',
    f'{int(test_signals)}',
]
for bar, bar_label in zip(bars, bar_labels):
    height = bar.get_height()
    axes[1, 2].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                   bar_label,
                   ha='center', va='bottom')

plt.tight_layout()
plt.show()

# ============================================================================
# 6. Trading Performance Analysis
# ============================================================================
# Summarises the test-set signals as trades and estimates an expected return
# from the test precision and fixed average win/loss sizes.
# (A stray "Save Model and Configuration" banner used to be printed here,
# ahead of the real save section further down; it has been removed.)

print("\n" + "="*60)
print("Trading Performance Analysis")
print("="*60)

# Pre-define both results so later sections can read them even when no signal
# fired. Resetting them here also stops a value from an earlier notebook run
# from leaking into the saved configuration.
expected_return_per_trade = 0
total_expected_return = 0

print(f"üéØ Trading Signal Quality:")
print(f"  Total Signals: {test_signals}")
print(f"  Correct Signals: {cm_test[1,1]} (Success Rate: {test_precision:.1%})")
print(f"  Wrong Signals: {cm_test[0,1]} (Risk Rate: {cm_test[0,1]/test_signals if test_signals>0 else 0:.1%})")
print(f"  Missed Opportunities: {cm_test[1,0]} (Opportunity Cost)")

if test_signals > 0:
    # Simple return calculation
    win_rate = test_precision
    loss_rate = 1 - win_rate
    avg_win = 0.025  # average gain of 2.5% on a correct signal
    avg_loss = 0.015  # average loss of 1.5% on a wrong signal

    expected_return_per_trade = win_rate * avg_win - loss_rate * avg_loss
    total_expected_return = expected_return_per_trade * test_signals

    # Signals per 252 test samples.
    # NOTE(review): this equals trades per year only if each test sample is
    # one trading day of a single series - confirm against how the test set
    # is assembled.
    annual_trades = test_signals / (len(y_test_final) / 252)

    print(f"\nüí∞ Expected Returns:")
    print(f"  Expected Return Per Trade: {expected_return_per_trade:.2%}")
    print(f"  Total Expected Return: {total_expected_return:.2%}")
    print(f"  Estimated Annual Trade Count: {annual_trades:.0f} trades")
    print(f"  Annualized Expected Return: {expected_return_per_trade * annual_trades:.1%}")

    if expected_return_per_trade > 0:
        print(f"  ‚úÖ Positive expected return, model has trading value")
    else:
        print(f"  ‚ö†Ô∏è  Negative expected return, needs optimization")

# ============================================================================
# 7. Save Model and Configuration
# ============================================================================
# Writes the Keras model and a joblib-pickled dict holding the threshold
# selection, test metrics and trading analysis. Both file names embed the
# selected threshold rounded to three decimals.

print("\n" + "="*60)
print("Save Model and Configuration")
print("="*60)

# Save model
model_path = f'practical_model_threshold_{selected_threshold:.3f}.keras'
model.save(model_path)
print(f"‚úÖ Model saved: {model_path}")

# Save configuration
import joblib

# Derive the total from values that are always bound at this point
# (expected_return_per_trade is reset to 0 before the trading analysis)
# instead of probing locals(): with zero signals this is 0, and a stale
# total from an earlier notebook run can no longer end up in the config.
total_expected_return = expected_return_per_trade * test_signals

# Validation metrics recorded for the selected threshold (looked up once).
val_row = val_results_df[val_results_df['threshold'] == selected_threshold].iloc[0]

# Guard the ratios that divide by the test-set size.
has_signals = test_signals > 0

config = {
    'strategy': 'Find ‚â•70% precision threshold on validation set, expecting ‚â•60% on test set',
    'threshold_selection': {
        'source': 'validation_set',
        'min_precision_target': 0.7,
        'selected_threshold': float(selected_threshold),
        'validation_performance': {
            'precision': float(val_precision),
            'recall': float(val_row['recall']),
            'signals': int(val_signals),
            'signal_ratio': float(val_row['signal_ratio'])
        }
    },
    'test_performance': {
        'precision': float(test_precision),
        'recall': float(test_recall),
        'f1': float(test_f1),
        'accuracy': float(test_accuracy),
        'auc': float(test_auc),
        'signals': int(test_signals),
        'signal_ratio': float(test_signals / len(y_test_final)) if has_signals else 0
    },
    'trading_analysis': {
        'expected_return_per_trade': float(expected_return_per_trade) if has_signals else 0,
        'total_expected_return': float(total_expected_return),
        'win_rate': float(test_precision),
        'total_signals': int(test_signals),
        'correct_signals': int(cm_test[1,1]),
        'wrong_signals': int(cm_test[0,1]),
        'annualized_trades': float(test_signals/(len(y_test_final)/252)) if has_signals else 0
    }
}
config_path = f'practical_model_config_{selected_threshold:.3f}.pkl'
joblib.dump(config, config_path)
print(f"‚úÖ Configuration saved: {config_path}")

# ============================================================================
# 8. Final Recommendations
# ============================================================================
# Prints the closing report: strategy recap, pass/fail verdict against the
# 60% test-precision target, and instructions for reusing the saved model.

rule = "=" * 60
print("\n" + rule)
print("Final Recommendations")
print(rule)

summary_lines = [
    "üéØ Strategy Summary:",
    "  1. Find threshold achieving ‚â•70% precision on validation set",
    f"  2. Selected threshold: {selected_threshold:.3f}",
    f"  3. Validation set precision: {val_precision:.1%}",
    f"  4. Test set precision: {test_precision:.1%}",
]
print("\n".join(summary_lines))

# Verdict: the success branch is two lines; the failure branch also lists
# optimization suggestions.
if test_precision >= 0.6:
    evaluation_lines = [
        "  ‚úÖ Target achieved! Test set precision ‚â• 60%",
        "  ‚úÖ Model can be used for actual trading",
    ]
else:
    evaluation_lines = [
        f"  ‚ö†Ô∏è  Target not reached! Test set precision: {test_precision:.1%} (< 60%)",
        "  ‚ö†Ô∏è  Further optimization needed",
        "\nüí° Optimization Suggestions:",
        "  1. Try higher thresholds for higher precision",
        "  2. Retrain model, adjust class weights",
        "  3. Improve feature engineering",
    ]
print("\nüìä Performance Evaluation:")
print("\n".join(evaluation_lines))

usage_lines = [
    "\nüîß Usage Instructions:",
    f"  1. Load model: practical_model_threshold_{selected_threshold:.3f}.keras",
    "  2. Predict on new data: probabilities = model.predict(X_new)",
    f"  3. Generate trading signals: signals = probabilities > {selected_threshold:.3f}",
]
print("\n".join(usage_lines))

print("\n‚úÖ Practical High Precision Strategy Complete!")

# 4. LSTM Scenario 2

class 0: No significant fluctuations expected in the next 2 days

class 1: Significant fluctuations expected (upward trend) in the next 2 days

class 2: Significant fluctuations expected (downward trend) in the next 2 days

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Practical Ternary Classification Strategy: Optimization Based on 0.4 Baseline Threshold
# Goal: precision for both big-rise (1) and big-drop (2) predictions >= 55%
# Baseline threshold: 0.4
# ============================================================================

rule = "=" * 60
print(rule)
print("Ternary Classification High Precision Threshold Strategy - Based on 0.4 Baseline")
print("Strategy: Find thresholds achieving ‚â•55% precision for both big rise and big drop on validation set")
print("Baseline threshold: 0.4 (more reasonable starting point for ternary classification)")
print("Goal: Achieve ‚â•55% precision for both big rise and big drop on test set")
print(rule)

# ============================================================================
# 1. Data Preparation - Ternary Classification
# ============================================================================

# Work on copies of the 3-class labels: 0=calm, 1=big rise, 2=big drop
y_train_full = y_3class_train_full.copy()
y_test_final = y_3class_test_final.copy()

# One-hot targets for the softmax / categorical cross-entropy model
y_train_onehot = to_categorical(y_train_full, num_classes=3)

# Report how many samples of each class the two label sets contain
class_names = ['Calm', 'Big Rise', 'Big Drop']
print("Data Distribution:")
for heading, split_labels in (
    ("Training Set Label Distribution:", y_train_full),
    ("\nTest Set Label Distribution:", y_test_final),
):
    print(heading)
    split_size = len(split_labels)
    for class_id, class_name in enumerate(class_names):
        count = np.sum(split_labels == class_id)
        print(f"  Class {class_id}({class_name}): {count} samples ({count/split_size:.2%})")

# Positional split: the first 80% of the training data is used for fitting,
# the remaining tail is held out as the validation set.
train_ratio = 0.8
train_idx = int(len(X_train) * train_ratio)

X_train_final, X_val = X_train[:train_idx], X_train[train_idx:]
y_train_final, y_val = y_train_full[:train_idx], y_train_full[train_idx:]
y_train_onehot_final, y_val_onehot = y_train_onehot[:train_idx], y_train_onehot[train_idx:]

print("\nData Split:")
print(f"  Training Set: {len(X_train_final)} samples")
print(f"  Validation Set: {len(X_val)} samples (for threshold selection)")
print(f"  Test Set: {len(X_test)} samples (for final evaluation)")

# ============================================================================
# 2. Train Model (or Load Existing Model)
# ============================================================================
# Reuses 'ternary_classification_trading_model.keras' when it can be loaded;
# otherwise builds and trains a fresh LSTM classifier. Either way the section
# ends with validation-set class probabilities for the threshold search.
# NOTE(review): nothing in this section saves the newly trained model under
# that file name - confirm it is written elsewhere, otherwise every run
# retrains.

print("\n" + "="*60)
print("Model Training - Ternary Classification")
print("="*60)


def create_ternary_model():
    """Build and compile the LSTM classifier with a 3-way softmax output.

    The input shape is read from the notebook globals ``window_size`` and
    ``features`` at call time.
    """
    model = Sequential([
        LSTM(64, input_shape=(window_size, len(features)), return_sequences=False),
        Dropout(0.3),
        BatchNormalization(),
        Dense(48, activation='relu'),
        Dropout(0.2),
        Dense(32, activation='relu'),
        Dense(3, activation='softmax')  # Ternary classification output layer
    ])

    model.compile(
        optimizer=Adam(learning_rate=0.0005),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model


# Start from None so a failed load can never leave an earlier notebook model
# (for example the binary one) bound to `model`.
model = None
try:
    # Try to load existing model
    print("Attempting to load existing model...")
    model = tf.keras.models.load_model('ternary_classification_trading_model.keras')
    print("‚úÖ Model loaded successfully")
except Exception as load_error:
    # Catch Exception rather than using a bare `except:` so Ctrl-C and
    # SystemExit still propagate, and report why the load failed.
    print("‚ö†Ô∏è  Unable to load model, training a new model...")
    print(f"  Reason: {type(load_error).__name__}: {load_error}")

# Train outside the handler so a training failure is reported on its own
# instead of being chained to the load error.
if model is None:
    model = create_ternary_model()

    # Calculate class weights (handle imbalanced data)
    present_classes = np.unique(y_train_final)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=y_train_final
    )
    # Pair every weight with the class it was computed for. Indexing the
    # weights with range(3) would misalign them - or raise IndexError -
    # whenever a class is absent from the training slice.
    class_weight_dict = {
        int(cls): float(weight)
        for cls, weight in zip(present_classes, class_weights)
    }

    print(f"Class Weights: {class_weight_dict}")

    # Train the model
    print("Training model...")
    history = model.fit(
        X_train_final, y_train_onehot_final,
        validation_data=(X_val, y_val_onehot),
        epochs=30,
        batch_size=64,
        verbose=1,
        class_weight=class_weight_dict,
        callbacks=[
            # Stop after 8 epochs without val_loss improvement, keep best weights
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=8,
                restore_best_weights=True,
                verbose=1
            ),
            # Halve the learning rate after 3 stagnant epochs
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=3,
                verbose=1
            )
        ]
    )

# Get validation set prediction probabilities for threshold search
y_val_proba = model.predict(X_val, verbose=0, batch_size=128)

# ============================================================================
# Optimized Threshold Strategy
# ============================================================================

print("\n" + "="*60)
print("Optimized Threshold Strategy")
print("1. Default threshold set to 0.4 (ternary classification)")
print("2. Validation set goal: Big rise and big drop precision ‚â•60%")
print("3. Threshold enumeration step size: 0.01")
print("="*60)

# ============================================================================
# 1. Re-search Thresholds (Finer Step Size)
# ============================================================================
# Evaluates every (rise threshold, drop threshold) pair on the validation set.
# A sample is a rise (drop) signal when its rise (drop) probability is both
# above the threshold and strictly the largest of the three probabilities.


def _signal_stats(signal_mask, is_true_class, total_true):
    """Return (precision, signal count, recall) for one boolean signal mask.

    Precision is 0 when no signal fired and recall is 0 when the class has
    no true samples.
    """
    signals = int(np.sum(signal_mask))
    precision = np.mean(is_true_class[signal_mask]) if signals > 0 else 0
    recall = np.sum(signal_mask & is_true_class) / total_true if total_true > 0 else 0
    return precision, signals, recall


print("\nRe-searching thresholds, step size: 0.01")
# np.arange treats the stop value as exclusive; with a float step the grid
# values are not exact decimals (and the length can differ by one), so look
# them up later with values taken from this grid.
rise_thresholds = np.arange(0.3, 0.8, 0.01)
drop_thresholds = np.arange(0.3, 0.8, 0.01)

val_results_fine = []
print("Fine-grained search in progress...")

# Everything below is independent of the thresholds, so compute it once
# instead of once per grid cell.
calm_proba = y_val_proba[:, 0]
rise_proba = y_val_proba[:, 1]
drop_proba = y_val_proba[:, 2]
rise_is_top = (rise_proba > calm_proba) & (rise_proba > drop_proba)
drop_is_top = (drop_proba > calm_proba) & (drop_proba > rise_proba)
is_true_rise = y_val == 1
is_true_drop = y_val == 2
total_rise = np.sum(is_true_rise)
total_drop = np.sum(is_true_drop)

# The strict "largest probability" tests make rise and drop signals mutually
# exclusive, so drop statistics depend only on the drop threshold (and rise
# statistics only on the rise threshold).
drop_stats_by_threshold = [
    _signal_stats(drop_is_top & (drop_proba > drop_thresh), is_true_drop, total_drop)
    for drop_thresh in drop_thresholds
]

for rise_thresh in rise_thresholds:
    rise_mask = rise_is_top & (rise_proba > rise_thresh)
    rise_precision, rise_signals, rise_recall = _signal_stats(
        rise_mask, is_true_rise, total_rise
    )

    for drop_thresh, drop_stats in zip(drop_thresholds, drop_stats_by_threshold):
        drop_precision, drop_signals, drop_recall = drop_stats
        total_signals = rise_signals + drop_signals

        val_results_fine.append({
            'rise_threshold': rise_thresh,
            'drop_threshold': drop_thresh,
            'rise_precision': rise_precision,
            'drop_precision': drop_precision,
            'rise_signals': rise_signals,
            'drop_signals': drop_signals,
            'total_signals': total_signals,
            'rise_recall': rise_recall,
            'drop_recall': drop_recall,
            'signal_ratio': total_signals / len(y_val)
        })

# One row per (rise threshold, drop threshold) pair, rise threshold outermost
val_results_df_fine = pd.DataFrame(val_results_fine)

# ============================================================================
# 2. Threshold Selection Strategy
# ============================================================================
# Picks the (rise, drop) threshold pair from the validation grid by relaxing
# the precision floor step by step: 60% -> 58% -> 55% -> default 0.4. Within
# a tier the pair producing the most signals wins.
#
# Sets: selected_rise_threshold, selected_drop_threshold, selection_strategy
# (and selected_row, which is None when the default thresholds are used).

print("\n" + "="*60)
print("Threshold Selection Strategy")
print("Primary goal: Big rise and big drop precision ‚â•60%")
print("Alternative goal: If signals are too few, relax to ‚â•58%")
print("="*60)


def _combos_meeting(min_precision):
    """Return grid rows whose rise AND drop precision reach min_precision.

    Rows are sorted by total signal count, largest first.
    """
    qualifying = val_results_df_fine[
        (val_results_df_fine['rise_precision'] >= min_precision) &
        (val_results_df_fine['drop_precision'] >= min_precision)
    ].copy()
    return qualifying.sort_values('total_signals', ascending=False)


# Each looser tier is a superset of the stricter one.
thresholds_60 = _combos_meeting(0.60)
thresholds_58 = _combos_meeting(0.58)
thresholds_55 = _combos_meeting(0.55)

selected_row = None  # stays None when no tier qualifies

if len(thresholds_60) > 0:
    print(f"\n‚úÖ Found {len(thresholds_60)} threshold combinations with both big rise and big drop precision ‚â•60%")

    print(f"\nThreshold combinations with precision ‚â•60% (top 10):")
    print(f"{'Rise Thresh':<10} {'Drop Thresh':<10} {'Rise Prec':<12} {'Drop Prec':<12} {'Rise Signals':<12} {'Drop Signals':<12} {'Total Signals':<12}")
    print("-" * 90)

    for _, row in thresholds_60.head(10).iterrows():
        print(f"{row['rise_threshold']:.3f}      {row['drop_threshold']:.3f}      "
              f"{row['rise_precision']:.2%}        {row['drop_precision']:.2%}        "
              f"{int(row['rise_signals']):<12} {int(row['drop_signals']):<12} {int(row['total_signals']):<12}")

    # Check if signal count is reasonable
    if thresholds_60.iloc[0]['total_signals'] >= 50:
        print(f"\nStrategy 1A: Sufficient signal count, using ‚â•60% precision combination")
        selected_row = thresholds_60.iloc[0]
        selection_strategy = "60% precision, signal priority"

    else:
        print(f"\n‚ö†Ô∏è  Best combination has too few signals ({int(thresholds_60.iloc[0]['total_signals'])}), relaxing to 58%")

        # Strategy 2: Precision >= 58%. This tier contains every 60% row, so
        # it cannot be empty here and needs no further fallback.
        good_thresholds_58 = thresholds_58[thresholds_58['total_signals'] >= 100]

        if len(good_thresholds_58) > 0:
            print(f"\nStrategy 2A: Found 58% precision combination with signals ‚â•100")
            selected_row = good_thresholds_58.iloc[0]
            selection_strategy = "58% precision, signals ‚â•100"
        else:
            # Select 58% precision combination with most signals
            print(f"\nStrategy 2B: Using 58% precision combination with most signals")
            selected_row = thresholds_58.iloc[0]
            selection_strategy = "58% precision, most signals"

else:
    print(f"\n‚ö†Ô∏è  No threshold combinations found with both big rise and big drop precision ‚â•60%")
    print(f"  Maximum big rise precision on validation set: {val_results_df_fine['rise_precision'].max():.2%}")
    print(f"  Maximum big drop precision on validation set: {val_results_df_fine['drop_precision'].max():.2%}")

    # Directly try 58% precision
    print(f"\nTrying combinations with precision ‚â•58%...")

    if len(thresholds_58) > 0:
        selected_row = thresholds_58.iloc[0]
        selection_strategy = "58% precision (no 60% results)"

        print(f"‚úÖ Found ‚â•58% precision combination")
        print(f"  Big rise threshold: {selected_row['rise_threshold']:.3f}")
        print(f"  Big drop threshold: {selected_row['drop_threshold']:.3f}")
        print(f"  Big rise precision: {selected_row['rise_precision']:.2%}")
        print(f"  Big drop precision: {selected_row['drop_precision']:.2%}")
        print(f"  Total signals: {int(selected_row['total_signals'])}")

    else:
        print(f"\n‚ö†Ô∏è  No combinations found with ‚â•58% precision, using ‚â•55% precision")

        # Strategy 3: Precision >= 55%. This tier used to sit in the branch
        # above, where it could never run (the 58% set is never empty when
        # the 60% set is not); this is the only path on which it can apply.
        if len(thresholds_55) > 0:
            selected_row = thresholds_55.iloc[0]
            selection_strategy = "55% precision, most signals"
        else:
            print(f"\n‚ùå No combinations found with ‚â•55% precision, using default threshold 0.4")
            selection_strategy = "Default threshold 0.4"

if selected_row is None:
    selected_rise_threshold = 0.4  # Default value
    selected_drop_threshold = 0.4  # Default value
else:
    selected_rise_threshold = selected_row['rise_threshold']
    selected_drop_threshold = selected_row['drop_threshold']

# ============================================================================
# 3. Display Final Selected Thresholds
# ============================================================================
# Prints the chosen thresholds and, unless the default thresholds are in use,
# the validation metrics stored for that exact grid cell.

print("\n" + "="*60)
print("Final Selected Thresholds")
print("="*60)

print(f"Selection Strategy: {selection_strategy}")
for side, chosen in (("rise", selected_rise_threshold), ("drop", selected_drop_threshold)):
    print(f"Big {side} threshold: {chosen:.3f}")

used_default = selection_strategy == "Default threshold 0.4"
if not used_default:
    # Exact float comparison: the selected values are expected to come
    # straight from this grid.
    rise_matches = val_results_df_fine['rise_threshold'] == selected_rise_threshold
    drop_matches = val_results_df_fine['drop_threshold'] == selected_drop_threshold
    mask = rise_matches & drop_matches

    if mask.any():
        row_data = val_results_df_fine[mask].iloc[0]
        print("\nValidation Set Performance:")
        for side in ("rise", "drop"):
            print(f"  Big {side} precision: {row_data[side + '_precision']:.2%}")
        for side in ("rise", "drop"):
            print(f"  Big {side} signals: {int(row_data[side + '_signals'])}")
        print(f"  Total signals: {int(row_data['total_signals'])} ({row_data['signal_ratio']:.2%})")
        for side in ("rise", "drop"):
            print(f"  Big {side} recall: {row_data[side + '_recall']:.2%}")

# ============================================================================
# 4. Evaluate on Test Set
# ============================================================================
# Applies the selected thresholds to the test-set probabilities, derives the
# 3-class predictions and reports per-class metrics against the 55% goal.

print("\n" + "="*60)
print("Test Set Evaluation")
print("="*60)

# Class probabilities for every test sample
y_test_proba = model.predict(X_test, verbose=0, batch_size=128)

print(f"Using thresholds: Big rise={selected_rise_threshold:.3f}, Big drop={selected_drop_threshold:.3f}")

# Columns 0 / 1 / 2 are used as calm / big rise / big drop throughout
calm_p = y_test_proba[:, 0]
rise_p = y_test_proba[:, 1]
drop_p = y_test_proba[:, 2]

# A signal needs a probability above its threshold AND strictly above both
# other classes.
rise_mask_test = (rise_p > selected_rise_threshold) & (rise_p > calm_p) & (rise_p > drop_p)
drop_mask_test = (drop_p > selected_drop_threshold) & (drop_p > calm_p) & (drop_p > rise_p)

# Safety net: the strict comparisons above should already rule out a sample
# being flagged as both; if one ever is, keep the more probable side.
conflict_mask_test = rise_mask_test & drop_mask_test
conflict_count = np.sum(conflict_mask_test)
if conflict_count > 0:
    print(f"‚ö†Ô∏è  Found {conflict_count} conflicting samples, handling by probability")
    rise_more_likely = rise_p > drop_p
    drop_mask_test[conflict_mask_test & rise_more_likely] = False
    rise_mask_test[conflict_mask_test & ~rise_more_likely] = False

# Generate final predictions (0 = calm unless a signal fired)
y_test_pred = np.zeros(len(y_test_final), dtype=int)
y_test_pred[rise_mask_test] = 1
y_test_pred[drop_mask_test] = 2

# Per-class metrics, always reported for classes 0, 1, 2 in that order
per_class = dict(average=None, labels=[0, 1, 2], zero_division=0)
test_accuracy = accuracy_score(y_test_final, y_test_pred)
test_precision = precision_score(y_test_final, y_test_pred, **per_class)
test_recall = recall_score(y_test_final, y_test_pred, **per_class)
test_f1 = f1_score(y_test_final, y_test_pred, **per_class)

test_rise_signals = np.sum(y_test_pred == 1)
test_drop_signals = np.sum(y_test_pred == 2)
test_total_signals = test_rise_signals + test_drop_signals

test_size = len(y_test_final)
report_lines = (
    "\nTest Set Performance:",
    f"  Accuracy: {test_accuracy:.4f}",
    f"  Big rise precision: {test_precision[1]:.4f} (Goal: ‚â•55%)",
    f"  Big drop precision: {test_precision[2]:.4f} (Goal: ‚â•55%)",
    f"  Big rise recall: {test_recall[1]:.4f}",
    f"  Big drop recall: {test_recall[2]:.4f}",
    f"  Big rise F1 score: {test_f1[1]:.4f}",
    f"  Big drop F1 score: {test_f1[2]:.4f}",
    f"  Big rise signal count: {test_rise_signals} ({test_rise_signals/test_size:.2%})",
    f"  Big drop signal count: {test_drop_signals} ({test_drop_signals/test_size:.2%})",
    f"  Total signal count: {test_total_signals} ({test_total_signals/test_size:.2%})",
)
for report_line in report_lines:
    print(report_line)

# Check if goal is achieved
rise_achieved = test_precision[1] >= 0.55
drop_achieved = test_precision[2] >= 0.55
target_achieved = rise_achieved and drop_achieved

print("\nGoal Achievement Status:")
for side, class_id, achieved in (("rise", 1, rise_achieved), ("drop", 2, drop_achieved)):
    mark = '‚úì' if achieved else '‚úó'
    status = '(Achieved)' if achieved else '(Below 55%)'
    print(f"  Big {side} precision: {test_precision[class_id]:.2%} {mark} {status}")

if target_achieved:
    print("‚úÖ Success! Both big rise and big drop precision ‚â• 55%")
elif rise_achieved:
    print("‚ö†Ô∏è  Partial success: Big rise achieved, big drop not achieved")
    print("  Suggestion: Can trade big rise, avoid big drop trades")
elif drop_achieved:
    print("‚ö†Ô∏è  Partial success: Big drop achieved, big rise not achieved")
    print("  Suggestion: Can trade big drop, avoid big rise trades")
else:
    print("‚ùå Not achieved: Both big rise and big drop precision < 55%")

# ============================================================================
# 5. Enhanced Visualization: Final Prediction Results Analysis
# ============================================================================

print("\n" + "="*60)
print("Enhanced Visualization: Final Prediction Results Analysis")
print("="*60)

# Create a comprehensive visualization figure (subplots placed on a 3x4 grid)
fig = plt.figure(figsize=(20, 16))
fig.suptitle(f'Ternary Classification Results Analysis\nRise Threshold: {selected_rise_threshold:.3f}, Drop Threshold: {selected_drop_threshold:.3f}',
             fontsize=16, fontweight='bold')

# 1. Test Set Probability Distribution by True Class
# For each true class, histogram the probability the model gave to that class.
ax1 = plt.subplot(3, 4, 1)
for i, label in enumerate(['Calm', 'Big Rise', 'Big Drop']):
    probs = y_test_proba[y_test_final == i, i]
    ax1.hist(probs, bins=30, alpha=0.6, label=f'{label} (True)', density=True)
ax1.axvline(x=selected_rise_threshold, color='green', linestyle='--', alpha=0.7, label=f'Rise Thr={selected_rise_threshold:.3f}')
ax1.axvline(x=selected_drop_threshold, color='red', linestyle='--', alpha=0.7, label=f'Drop Thr={selected_drop_threshold:.3f}')
ax1.set_title('1. True Class Probability Distribution')
ax1.set_xlabel('Predicted Probability')
ax1.set_ylabel('Density')
# Build the legend only after the threshold lines exist; calling it earlier
# left their labels out of the legend.
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Probability Heatmap by True Class
# Row r holds the mean predicted probability vector over the test samples
# whose true class is r.
ax2 = plt.subplot(3, 4, 2)
labels = ['Calm Prob', 'Rise Prob', 'Drop Prob']
probs_by_class = [y_test_proba[y_test_final == cls].mean(axis=0) for cls in range(3)]
probs_matrix = np.array(probs_by_class)
im = ax2.imshow(probs_matrix, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
ax2.set_title('2. Average Probability by True Class')
ax2.set_xlabel('Predicted Class Probability')
ax2.set_ylabel('True Class')
ax2.set_xticks(range(3))
ax2.set_yticks(range(3))
ax2.set_xticklabels(labels)
ax2.set_yticklabels(['Calm', 'Rise', 'Drop'])
# Annotate every cell; switch to white text on the darker (> 0.5) cells
for i, row_values in enumerate(probs_matrix):
    for j, cell_value in enumerate(row_values):
        text_color = "white" if cell_value > 0.5 else "black"
        ax2.text(j, i, f'{cell_value:.3f}', ha="center", va="center", color=text_color)
plt.colorbar(im, ax=ax2)

# 3. Signal Distribution by Predicted Class
ax3 = plt.subplot(3, 4, 3)
pred_counts = [np.sum(y_test_pred == cls) for cls in range(3)]
colors = ['blue', 'green', 'red']
bars = ax3.bar(['Calm', 'Rise', 'Drop'], pred_counts, color=colors, alpha=0.7)
ax3.set_title('3. Signal Distribution by Predicted Class')
ax3.set_ylabel('Count')
ax3.set_xlabel('Predicted Class')
# Caption each bar with its count and share, lifted 1% of the tallest bar
label_offset = max(pred_counts)*0.01
n_predictions = len(y_test_pred)
for bar, count in zip(bars, pred_counts):
    bar_center = bar.get_x() + bar.get_width()/2.
    ax3.text(bar_center, bar.get_height() + label_offset,
            f'{count}\n({count/n_predictions:.1%})', ha='center', va='bottom', fontsize=9)
ax3.grid(True, alpha=0.3)

# 4. Signal Count vs Threshold
# Counts signals with the same rule that produces the test predictions: the
# class probability must exceed the threshold AND be strictly the largest of
# the three. Counting on `proba > thresh` alone overstated the signals for
# thresholds below 0.5, where a class can clear the threshold without being
# the top class, so the curve disagreed with the real counts at the marked
# thresholds.
ax4 = plt.subplot(3, 4, 4)
threshold_range = np.linspace(0.1, 0.9, 50)
# Threshold-independent part of the rule, computed once. Dedicated names are
# used so the validation-stage rise_mask / drop_mask are not overwritten.
rise_is_top_test = (y_test_proba[:, 1] > y_test_proba[:, 0]) & \
                   (y_test_proba[:, 1] > y_test_proba[:, 2])
drop_is_top_test = (y_test_proba[:, 2] > y_test_proba[:, 0]) & \
                   (y_test_proba[:, 2] > y_test_proba[:, 1])
rise_signals_by_thresh = [
    np.sum(rise_is_top_test & (y_test_proba[:, 1] > thresh))
    for thresh in threshold_range
]
drop_signals_by_thresh = [
    np.sum(drop_is_top_test & (y_test_proba[:, 2] > thresh))
    for thresh in threshold_range
]
ax4.plot(threshold_range, rise_signals_by_thresh, 'g-', label='Rise Signals', linewidth=2)
ax4.plot(threshold_range, drop_signals_by_thresh, 'r-', label='Drop Signals', linewidth=2)
# Dashed verticals mark the selected rise (green) and drop (red) thresholds
ax4.axvline(x=selected_rise_threshold, color='green', linestyle='--', alpha=0.7)
ax4.axvline(x=selected_drop_threshold, color='red', linestyle='--', alpha=0.7)
ax4.set_title('4. Signal Count vs Threshold')
ax4.set_xlabel('Threshold')
ax4.set_ylabel('Signal Count')
ax4.legend()
ax4.grid(True, alpha=0.3)

# 5. Confusion Matrix Heatmap
ax5 = plt.subplot(3, 4, 5)
# Pin the label order so the matrix is always 3x3 and lines up with the fixed
# tick labels, even when a class appears in neither y_test_final nor
# y_test_pred (same labels=[0, 1, 2] convention as the precision/recall/F1
# calls).
cm = confusion_matrix(y_test_final, y_test_pred, labels=[0, 1, 2])
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax5,
            xticklabels=['Calm', 'Rise', 'Drop'],
            yticklabels=['Calm', 'Rise', 'Drop'])
ax5.set_title('5. Confusion Matrix')
ax5.set_xlabel('Predicted')
ax5.set_ylabel('True')

# 6. Grouped bars: per-class precision next to per-class recall.
ax6 = plt.subplot(3, 4, 6)
classes = ['Calm', 'Rise', 'Drop']
precisions, recalls = test_precision, test_recall
x = np.arange(len(classes))
width = 0.35
half_width = width / 2
bars1 = ax6.bar(x - half_width, precisions, width, label='Precision', alpha=0.7)
bars2 = ax6.bar(x + half_width, recalls, width, label='Recall', alpha=0.7)

ax6.set(title='6. Precision & Recall by Class', xlabel='Class', ylabel='Score')
ax6.set_xticks(x)
ax6.set_xticklabels(classes)
# Reference line at the 55% target.
ax6.axhline(y=0.55, color='red', linestyle='--', alpha=0.5, label='55% Target')
ax6.legend()
ax6.grid(True, alpha=0.3)

# 7. Rise-probability histograms for the samples predicted as Rise (class 1),
#    split by whether the true label is also Rise.
ax7 = plt.subplot(3, 4, 7)
rise_indices = np.where(y_test_pred == 1)[0]
if len(rise_indices) > 0:
    true_rise = y_test_final[rise_indices] == 1
    false_rise = ~true_rise
    predicted_rise_proba = y_test_proba[rise_indices]
    # Draw a histogram only for groups that actually contain samples.
    for group_mask, shade, tag in ((true_rise, 'green', 'True Rise'),
                                   (false_rise, 'orange', 'False Rise')):
        if np.sum(group_mask) > 0:
            ax7.hist(predicted_rise_proba[group_mask, 1], bins=20, alpha=0.7,
                     color=shade, label=tag, density=True)
    ax7.axvline(x=selected_rise_threshold, color='darkgreen', linestyle='--', linewidth=2)
    ax7.set(title=f'7. Rise Predictions (Precision: {test_precision[1]:.2%})',
            xlabel='Rise Probability', ylabel='Density')
    ax7.legend()
    ax7.grid(True, alpha=0.3)

# 8. Drop-probability histograms for the samples predicted as Drop (class 2),
#    split by whether the true label is also Drop.
ax8 = plt.subplot(3, 4, 8)
drop_indices = np.where(y_test_pred == 2)[0]
if len(drop_indices) > 0:
    true_drop = y_test_final[drop_indices] == 2
    false_drop = ~true_drop
    predicted_drop_proba = y_test_proba[drop_indices]
    # Draw a histogram only for groups that actually contain samples.
    for group_mask, shade, tag in ((true_drop, 'red', 'True Drop'),
                                   (false_drop, 'orange', 'False Drop')):
        if np.sum(group_mask) > 0:
            ax8.hist(predicted_drop_proba[group_mask, 2], bins=20, alpha=0.7,
                     color=shade, label=tag, density=True)
    ax8.axvline(x=selected_drop_threshold, color='darkred', linestyle='--', linewidth=2)
    ax8.set(title=f'8. Drop Predictions (Precision: {test_precision[2]:.2%})',
            xlabel='Drop Probability', ylabel='Density')
    ax8.legend()
    ax8.grid(True, alpha=0.3)

# 9. 2D scatter of Rise probability (x) against Drop probability (y),
#    coloured by the predicted class.
ax9 = plt.subplot(3, 4, 9)
rise_axis_values = y_test_proba[:, 1]
drop_axis_values = y_test_proba[:, 2]
scatter = ax9.scatter(rise_axis_values, drop_axis_values,
                      c=y_test_pred, cmap='viridis', alpha=0.6, s=10)
ax9.set(title='9. Rise vs Drop Probability Scatter',
        xlabel='Rise Probability', ylabel='Drop Probability')
# Selected thresholds: vertical line for Rise, horizontal line for Drop.
ax9.axvline(x=selected_rise_threshold, color='green', linestyle='--', alpha=0.5)
ax9.axhline(y=selected_drop_threshold, color='red', linestyle='--', alpha=0.5)
ax9.grid(True, alpha=0.3)

# 10. Performance Summary Bar Chart: accuracy plus Rise/Drop precision and recall.
ax10 = plt.subplot(3, 4, 10)
metrics = ['Accuracy', 'Rise Prec', 'Drop Prec', 'Rise Rec', 'Drop Rec']
values = [test_accuracy, test_precision[1], test_precision[2],
          test_recall[1], test_recall[2]]
colors_metric = ['blue', 'lightgreen', 'lightcoral', 'green', 'red']
bars10 = ax10.bar(metrics, values, color=colors_metric, alpha=0.7)
ax10.set_title('10. Performance Metrics Summary')
ax10.set_ylabel('Score')
ax10.axhline(y=0.55, color='red', linestyle='--', alpha=0.5, label='55% Target')
# Print each value just above its bar.
for bar, value in zip(bars10, values):
    height = bar.get_height()
    ax10.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{value:.3f}', ha='center', va='bottom', fontsize=9)
# Leave headroom above 1.0 for the value annotations.
ax10.set_ylim([0, 1.1])
# The target line carries a label, so draw the legend that displays it
# (same as chart 6); without this call the label is never shown.
ax10.legend(loc='upper right', fontsize=8)
ax10.grid(True, alpha=0.3)

# 11. Signal Quality Analysis: signal count and precision for Rise and Drop.
#     Counts use the left y-axis and precision a separate right y-axis fixed to
#     [0, 1.1]; on a single shared axis the 0-1 precision bars and the 55%
#     target line collapse onto the baseline next to the raw counts.
ax11 = plt.subplot(3, 4, 11)
if test_rise_signals > 0 and test_drop_signals > 0:
    signal_data = {
        'Rise': [test_rise_signals, test_precision[1]],
        'Drop': [test_drop_signals, test_precision[2]]
    }
    x = np.arange(2)
    width = 0.35

    # Signal count bars (left axis)
    counts = [signal_data['Rise'][0], signal_data['Drop'][0]]
    bars_counts = ax11.bar(x - width/2, counts, width, label='Signal Count',
                           alpha=0.7, color=['lightgreen', 'lightcoral'])

    # Precision bars (right axis, shares the x-axis with the count bars)
    ax11_precision = ax11.twinx()
    precs = [signal_data['Rise'][1], signal_data['Drop'][1]]
    bars_precs = ax11_precision.bar(x + width/2, precs, width, label='Precision',
                                    alpha=0.7, color=['green', 'red'])
    ax11_precision.set_ylim([0, 1.1])
    ax11_precision.set_ylabel('Precision')
    # The 55% target belongs to the precision scale.
    ax11_precision.axhline(y=0.55, color='red', linestyle='--', alpha=0.5)

    ax11.set_title('11. Signal Quality Analysis')
    ax11.set_xlabel('Signal Type')
    ax11.set_ylabel('Signal Count')
    ax11.set_xticks(x)
    ax11.set_xticklabels(['Rise', 'Drop'])
    # One legend for both bar groups, drawn on the top-most (twin) axis so the
    # precision bars cannot cover it.
    ax11_precision.legend(handles=[bars_counts, bars_precs],
                          loc='upper center', fontsize=8)
    ax11.grid(True, alpha=0.3)

# 12. Pie chart of model confidence, where confidence is the largest class
#     probability of each sample, bucketed into five fixed ranges.
ax12 = plt.subplot(3, 4, 12)
max_probs = np.max(y_test_proba, axis=1)
confidence_bins = [0, 0.3, 0.5, 0.7, 0.9, 1.0]
confidence_labels = ['Very Low', 'Low', 'Medium', 'High', 'Very High']
wedge_colors = ['red', 'orange', 'yellow', 'lightgreen', 'green']
# Only the per-bucket counts are needed; the returned edges are discarded.
confidence_counts = np.histogram(max_probs, bins=confidence_bins)[0]
ax12.pie(confidence_counts, labels=confidence_labels,
         colors=wedge_colors, autopct='%1.1f%%')
ax12.set_title('12. Model Confidence Distribution')

# Finalise and display the 3x4 dashboard.
plt.tight_layout()
plt.show()

# ============================================================================
# Additional Detailed Analysis Charts
# ============================================================================

banner = "=" * 60
print("\n" + banner)
print("Additional Detailed Analysis")
print(banner)

# Second figure: a 2x3 grid of panels for the detailed analysis.
fig2, axes2 = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))
fig2.suptitle('Detailed Prediction Analysis', fontsize=14, fontweight='bold')

# 1. For each class, the distribution of the probability the model assigned
#    to that class over the samples whose true label is that class.
ax1 = axes2[0, 0]
for class_id in range(3):
    probs = y_test_proba[y_test_final == class_id, class_id]
    if len(probs) == 0:
        continue  # no test samples of this class
    ax1.hist(probs, bins=30, alpha=0.5,
             label=f'True {class_id} -> Pred {class_id}', density=True)
ax1.set(title='1. Correct Prediction Probability Distribution',
        xlabel='Probability', ylabel='Density')
ax1.legend(fontsize=8)
ax1.grid(True, alpha=0.3)

# 2. Pie chart of misclassification types, each labelled "true->predicted".
ax2 = axes2[0, 1]
misclassified_mask = y_test_pred != y_test_final
if np.sum(misclassified_mask) > 0:
    misclassified_probs = y_test_proba[misclassified_mask]
    misclassified_true = y_test_final[misclassified_mask]
    misclassified_pred = y_test_pred[misclassified_mask]

    # One "true->pred" tag per misclassified sample.
    error_types = [
        f'{true_label}->{pred_label}'
        for true_label, pred_label in zip(misclassified_true, misclassified_pred)
    ]

    unique_errors, error_counts = np.unique(error_types, return_counts=True)
    # Evenly spaced colours from the Set3 colormap, one per error type.
    colors_error = plt.cm.Set3(np.linspace(0, 1, len(unique_errors)))
    ax2.pie(error_counts, labels=unique_errors, autopct='%1.1f%%', colors=colors_error)
    ax2.set_title('2. Error Type Distribution')

# 3. Precision of Rise / Drop signals when the same threshold is applied to
#    both classes, swept over 20 thresholds in [0.3, 0.7].
ax3 = axes2[0, 2]
test_thresholds = np.linspace(0.3, 0.7, 20)

# A signal also requires its class to have the strictly largest probability;
# that part does not depend on the threshold, so compute it once.
rise_is_top = (y_test_proba[:, 1] > y_test_proba[:, 0]) & (y_test_proba[:, 1] > y_test_proba[:, 2])
drop_is_top = (y_test_proba[:, 2] > y_test_proba[:, 0]) & (y_test_proba[:, 2] > y_test_proba[:, 1])

rise_precisions = []
drop_precisions = []
for thresh in test_thresholds:
    rise_mask = (y_test_proba[:, 1] > thresh) & rise_is_top
    drop_mask = (y_test_proba[:, 2] > thresh) & drop_is_top

    # Precision is recorded as 0 when a threshold yields no signals.
    rise_prec = np.mean(y_test_final[rise_mask] == 1) if np.sum(rise_mask) > 0 else 0
    drop_prec = np.mean(y_test_final[drop_mask] == 2) if np.sum(drop_mask) > 0 else 0

    rise_precisions.append(rise_prec)
    drop_precisions.append(drop_prec)

ax3.plot(test_thresholds, rise_precisions, 'g-', label='Rise Precision', linewidth=2)
ax3.plot(test_thresholds, drop_precisions, 'r-', label='Drop Precision', linewidth=2)
# Dashed markers at the selected thresholds and at the 55% target.
for chosen, shade in ((selected_rise_threshold, 'green'),
                      (selected_drop_threshold, 'red')):
    ax3.axvline(x=chosen, color=shade, linestyle='--', alpha=0.7)
ax3.axhline(y=0.55, color='black', linestyle='--', alpha=0.5, label='55% Target')
ax3.set(title='3. Threshold Sensitivity Analysis',
        xlabel='Threshold', ylabel='Precision')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Histogram of the sample positions at which a non-Calm class was predicted.
#    Drawn only when X_test exposes a shape with more than two dimensions,
#    which this script takes to mean the input has a time dimension.
ax4 = axes2[1, 0]
if len(getattr(X_test, 'shape', ())) > 2:
    signal_positions = np.where(y_test_pred > 0)[0]
    if len(signal_positions) > 0:
        ax4.hist(signal_positions, bins=30, alpha=0.7, color='purple')
        ax4.set(title='4. Signal Distribution Over Time',
                xlabel='Time Position', ylabel='Signal Count')
        ax4.grid(True, alpha=0.3)

# 5. Model confidence vs accuracy: accuracy of the predictions inside each of
#    ten equal-width buckets of max class probability. Empty buckets are skipped.
ax5 = axes2[1, 1]
confidence_bins = 10
bin_edges = np.linspace(0, 1, confidence_bins + 1)
accuracies = []
confidence_midpoints = []
for i in range(confidence_bins):
    lower, upper = bin_edges[i], bin_edges[i + 1]
    if i == confidence_bins - 1:
        # Close the last bucket on the right so samples predicted with
        # probability exactly 1.0 are counted rather than silently dropped.
        mask = (max_probs >= lower) & (max_probs <= upper)
    else:
        mask = (max_probs >= lower) & (max_probs < upper)
    if np.sum(mask) > 0:
        accuracy = np.mean(y_test_pred[mask] == y_test_final[mask])
        accuracies.append(accuracy)
        confidence_midpoints.append((lower + upper) / 2)

ax5.plot(confidence_midpoints, accuracies, 'b-o', linewidth=2)
ax5.set_title('5. Model Confidence vs Accuracy')
ax5.set_xlabel('Model Confidence (Max Probability)')
ax5.set_ylabel('Accuracy')
ax5.grid(True, alpha=0.3)
ax5.set_ylim([0, 1.1])

# 6. Summary statistics panel: a text box with thresholds, metrics and signal
#    counts, drawn on an axis whose frame and ticks are hidden.
ax6 = axes2[1, 2]
ax6.axis('tight')
ax6.axis('off')
# The summary reads `annual_trades`, but the visible code only assigns it in
# the final-analysis section further down. Compute it here with the same
# formula (signal rate scaled to 252 periods) so this panel does not depend on
# code that runs later.
annual_trades = test_total_signals / len(y_test_final) * 252
# NOTE: the rule, bullet and check/cross glyphs below were mis-encoded
# (UTF-8 bytes decoded as Mac Roman) and have been restored.
summary_text = f"""
Model Performance Summary:
─────────────────────────
Threshold Strategy: {selection_strategy}
Rise Threshold: {selected_rise_threshold:.3f}
Drop Threshold: {selected_drop_threshold:.3f}

Performance Metrics:
• Overall Accuracy: {test_accuracy:.2%}
• Rise Precision: {test_precision[1]:.2%} {'✓' if test_precision[1] >= 0.55 else '✗'}
• Drop Precision: {test_precision[2]:.2%} {'✓' if test_precision[2] >= 0.55 else '✗'}
• Rise Recall: {test_recall[1]:.2%}
• Drop Recall: {test_recall[2]:.2%}

Signal Statistics:
• Total Signals: {test_total_signals}
• Rise Signals: {test_rise_signals} ({test_rise_signals/len(y_test_final):.2%})
• Drop Signals: {test_drop_signals} ({test_drop_signals/len(y_test_final):.2%})

Trading Potential:
• Annual Trades: {annual_trades:.0f}
• Target Achieved: {'Yes' if target_achieved else 'Partial' if (rise_achieved or drop_achieved) else 'No'}
"""
ax6.text(0.1, 0.9, summary_text, fontsize=9, family='monospace',
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Finalise and display the detailed-analysis figure.
plt.tight_layout()
plt.show()

# ============================================================================
# 6. Final Analysis and Recommendations
# ============================================================================

print("\n" + "="*60)
print("Final Analysis and Recommendations")
print("="*60)

# The header emoji was mis-encoded (UTF-8 decoded as Mac Roman); restored.
print("\n📊 Final Performance Summary:")
print(f"  Threshold selection strategy: {selection_strategy}")
print(f"  Big rise threshold: {selected_rise_threshold:.3f}")
print(f"  Big drop threshold: {selected_drop_threshold:.3f}")
print(f"  Big rise precision: {test_precision[1]:.2%} {'(Achieved)' if test_precision[1] >= 0.55 else '(Not achieved)'}")
print(f"  Big drop precision: {test_precision[2]:.2%} {'(Achieved)' if test_precision[2] >= 0.55 else '(Not achieved)'}")
# Signal counts are reported together with their share of all test samples.
print(f"  Big rise signal count: {test_rise_signals} ({test_rise_signals/len(y_test_final):.2%})")
print(f"  Big drop signal count: {test_drop_signals} ({test_drop_signals/len(y_test_final):.2%})")
print(f"  Total signal count: {test_total_signals} ({test_total_signals/len(y_test_final):.2%})")

# Calculate annualized trade count: the per-sample signal rate scaled to 252
# periods (presumably trading days per year -- confirm the sampling interval).
annual_trades = test_total_signals / len(y_test_final) * 252
print(f"  Estimated annual trade count: {annual_trades:.0f} trades")

print(f"\nüí° Trading Recommendations:")
if test_total_signals == 0:
    print(f"  ‚ùå No trading signals, model needs retraining")
elif test_total_signals < 20:
    print(f"  ‚ö†Ô∏è  Too few signals ({test_total_signals}), not recommended for actual trading")
elif test_total_signals >= 20 and test_total_signals < 100:
    print(f"  ‚ö° Moderate signals ({test_total_signals}), recommended for small position testing")
    if test_precision[1] >= 0.55 and test_precision[2] >= 0.55:
        print(f"    Two-way trading, position: 1-2%")
    elif test_precision[1] >= 0.55:
        print(f"    Focus on big rise trading, position: 2-3%")
    elif test_precision[2] >= 0.55:
        print(f"    Focus on big drop trading, position: 2-3%")
else:
    print(f"  ‚úÖ Sufficient signals ({test_total_signals}), can be used for regular trading")
    if test_precision[1] >= 0.55 and test_precision[2] >= 0.55:
        print(f"    Two-way trading, position: 2-3%")
    elif test_precision[1] >= 0.55:
        print(f"    Focus on big rise trading, position: 3-4%")
    elif test_precision[2] >= 0.55:
        print(f"    Focus on big drop trading, position: 3-4%")

print(f"\nüìà Visualization Insights:")
print(f"  1. Probability Distribution: Check if predictions are well-calibrated")
print(f"  2. Confusion Matrix: Identify which errors are most common")
print(f"  3. Threshold Sensitivity: Shows how precision changes with threshold")
print(f"  4. Model Confidence: High confidence ‚â† high accuracy (check Chart 5)")

print(f"\nüîß Model Usage Instructions:")
print(f"  After loading model, predict on new data:")
print(f"  probabilities = model.predict(X_new)")
print(f"  # Generate big rise signals")
print(f"  rise_signals = (probabilities[:, 1] > {selected_rise_threshold:.3f}) & \\")
print(f"                 (probabilities[:, 1] > probabilities[:, 0]) & \\")
print(f"                 (probabilities[:, 1] > probabilities[:, 2])")
print(f"  # Generate big drop signals")
print(f"  drop_signals = (probabilities[:, 2] > {selected_drop_threshold:.3f}) & \\")
print(f"                 (probabilities[:, 2] > probabilities[:, 0]) & \\")
print(f"                 (probabilities[:, 2] > probabilities[:, 1])")

print(f"\n‚úÖ Enhanced visualization and threshold optimization completed!")