In [3]:
%%capture
!pip install gradio beautifulsoup4 yfinance torch

# Main Code

## NOTE: The DQN agent does not switch signals yet (work in progress).

## Charting and data saving are fully functional.

In [5]:
import gradio as gr
import pandas as pd
import numpy as np
import yfinance as yf
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
import requests
from bs4 import BeautifulSoup
import re
import io
from PIL import Image
import math
from datetime import datetime, timedelta
import random

# For the DQN portion:
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
except ImportError:
    print("PyTorch not installed. Please install it if you want to run the DQN training tab.")

# --- Debug flag ---
DEBUG = True  # set to False to silence debug_print output

def debug_print(msg):
    """Print ``msg`` to stdout, but only while the module-level ``DEBUG`` flag is truthy."""
    if not DEBUG:
        return
    print(msg)

# ------------------------------------------------------------------------
# 1) Utility functions for scraping FOMC dates, downloading data, indicators
# ------------------------------------------------------------------------
def get_fomc_dates(start_date, end_date):
    """Scrape the Federal Reserve FOMC calendar page for dates within [start_date, end_date].

    Every "Month D, YYYY" string found in the page text is parsed with ``%B %d, %Y``;
    unparseable matches are ignored. The result is de-duplicated, sorted ascending and
    filtered to the inclusive date range.

    NOTE(review): the regex matches any full-month date on the page, which may include
    dates that are not meeting days (e.g. page-update stamps) — confirm against the page.

    Args:
        start_date: Lower bound, anything ``pd.to_datetime`` accepts.
        end_date: Upper bound, anything ``pd.to_datetime`` accepts.

    Returns:
        Sorted list of ``datetime`` objects. Best-effort: on a non-200 response or any
        network/parsing failure the problem is reported via ``debug_print`` and ``[]`` returned.
    """
    url = "https://www.federalreserve.gov/monetarypolicy/fomccalendars.htm"
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            debug_print(f"Error: Received status code {response.status_code}")
            return []
        soup = BeautifulSoup(response.text, "html.parser")
        # A set gives O(1) de-duplication instead of an O(n) list scan per match.
        found = set()
        for text in soup.stripped_strings:
            for date_str in re.findall(r'([A-Za-z]+ \d{1,2}, \d{4})', text):
                try:
                    found.add(datetime.strptime(date_str, "%B %d, %Y"))
                except ValueError:
                    # Loose regex: words that are not full month names fail to parse.
                    continue
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)
        return [dt for dt in sorted(found) if start_dt <= dt <= end_dt]
    except Exception as e:
        # Deliberately best-effort: FOMC markers are decorative, so never break the chart.
        debug_print(f"Error scraping FOMC dates: {e}")
        return []

def exp_average(series, period):
    """Recursive (adjust=False) exponential moving average with span ``period``."""
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()

def wilder_average(series, length):
    """Wilder-style smoothing: a recursive EMA whose smoothing factor is ``1 / length``."""
    smoothing_factor = 1 / length
    return series.ewm(alpha=smoothing_factor, adjust=False).mean()

def weighted_moving_average(series, window):
    """Linearly weighted moving average; the newest value in each window has weight ``window``.

    The first ``window - 1`` results are NaN (incomplete window).
    """
    linear_weights = np.arange(1, window + 1)
    total_weight = linear_weights.sum()

    def _weighted_mean(window_values):
        return np.dot(window_values, linear_weights) / total_weight

    return series.rolling(window).apply(_weighted_mean, raw=True)

def t3(source, length=21, vf=0.7):
    """T3-style moving average built from three chained generalised-DEMA passes.

    Each pass computes ``ema * (1 + vf) - ema_of_ema * vf`` (both via ``exp_average`` with
    span ``length``) and feeds its output into the next pass.

    Args:
        source: Input series.
        length: EMA span used in every pass.
        vf: Volume factor weighting the EMA against the EMA-of-EMA.
    """
    stage = source
    for _ in range(3):
        first = exp_average(stage, length)
        second = exp_average(first, length)
        stage = first * (1 + vf) - second * vf
    return stage

def vwma(series, window, volume):
    """Volume-weighted moving average over ``window`` bars (NaN until the window is full)."""
    def _full_window_sum(values):
        return values.rolling(window=window, min_periods=window).sum()

    weighted_total = _full_window_sum(series * volume)
    volume_total = _full_window_sum(volume)
    return weighted_total / volume_total

def rsi_function(close, sensitivity, rsiPeriod, rsiBase):
    """Scaled RSI offset used by the MCDX banker / hot-money lines.

    RSI is computed from simple rolling means of gains and losses over ``rsiPeriod`` bars;
    undefined values (warm-up, or no movement at all) are treated as 50. The result is
    ``sensitivity * (rsi - rsiBase)`` clipped to the range [0, 20].
    """
    changes = close.diff()
    ups = changes.clip(lower=0)
    downs = -changes.clip(upper=0)
    window_kwargs = dict(window=rsiPeriod, min_periods=rsiPeriod)
    mean_up = ups.rolling(**window_kwargs).mean()
    mean_down = downs.rolling(**window_kwargs).mean()
    relative_strength = mean_up / mean_down
    raw_rsi = (100 - (100 / (1 + relative_strength))).fillna(50)
    scaled = sensitivity * (raw_rsi - rsiBase)
    return scaled.clip(lower=0, upper=20)

def download_data(ticker, start_date, end_date):
    """Download daily OHLCV data from Yahoo Finance and normalise column names to lowercase.

    When yfinance returns MultiIndex columns, only the first level (the price field) is
    kept, so a single-ticker download yields ``open``/``high``/``low``/``close``/``volume``.
    """
    raw = yf.download(ticker, start=pd.to_datetime(start_date), end=pd.to_datetime(end_date))
    if isinstance(raw.columns, pd.MultiIndex):
        names = [levels[0].lower() for levels in raw.columns]
    else:
        names = [str(name).lower() for name in raw.columns]
    raw.columns = names
    return raw

def compute_bressert(df, n_period=8, r_period=13):
    """Add Bressert double-smoothed stochastic columns to ``df`` (mutated in place) and return it.

    Columns written, in order: Ln, Hn, Y, X, Lxn, Hxn, DSS, DSSb, DSSsignal.
    Rolling extremes use ``min_periods=1``; the EMAs use span ``r_period``.
    DSSsignal is DSSb lagged by one bar.
    """
    def _lowest(series):
        return series.rolling(window=n_period, min_periods=1).min()

    def _highest(series):
        return series.rolling(window=n_period, min_periods=1).max()

    # First stochastic pass on price.
    df['Ln'] = _lowest(df['low'])
    df['Hn'] = _highest(df['high'])
    df['Y'] = ((df['close'] - df['Ln']) / (df['Hn'] - df['Ln'])) * 100
    df['X'] = exp_average(df['Y'], r_period)

    # Second stochastic pass on the smoothed first pass.
    df['Lxn'] = _lowest(df['X'])
    df['Hxn'] = _highest(df['X'])
    df['DSS'] = ((df['X'] - df['Lxn']) / (df['Hxn'] - df['Lxn'])) * 100
    df['DSSb'] = exp_average(df['DSS'], r_period)
    df['DSSsignal'] = df['DSSb'].shift(1)
    return df

def compute_zscore(df, length_m=14):
    """Z-score of ``length_m``-bar price momentum relative to its own rolling mean/std.

    momentum = close - close.shift(length_m); the z-score uses a rolling window of
    ``length_m`` momentum values (NaN until the window is full).

    A zero rolling std is treated as undefined (NaN) rather than dividing by zero, so
    flat-momentum stretches yield NaN instead of ±inf from floating-point noise. This
    matches how ``extract_macd_signals`` normalises its z-score.
    """
    momentum = df['close'] - df['close'].shift(length_m)
    window = momentum.rolling(window=length_m, min_periods=length_m)
    avg_momentum = window.mean()
    std_momentum = window.std().replace(0, np.nan)
    return (momentum - avg_momentum) / std_momentum

def compute_zero_lag_macd(source, fastLength=12, slowLength=26, signalLength=9, MacdEmaLength=9, useEma=True, useOldAlgo=False):
    """Zero-lag MACD with signal line, split histogram, EMA line and crossover dots.

    Each leg is ``2 * ma(source) - ma(ma(source))``. ``ma`` is an EMA when ``useEma``
    and otherwise a full-window SMA. MACD = fast leg - slow leg.
    The signal line is a zero-lag EMA of MACD, or a plain SMA when ``useOldAlgo``.

    Returns:
        dict with keys ZeroLagMACD, signal, hist, upHist (hist with values <= 0 set to 0),
        downHist (hist with values > 0 set to 0), EMALine, dotUP (MACD only on bars where
        it crosses up through the signal, else NaN) and dotDN (mirror image).
    """
    def _smooth(series, length):
        if useEma:
            return series.ewm(span=length, adjust=False).mean()
        return series.rolling(window=length, min_periods=length).mean()

    def _zero_lag_leg(length):
        first_pass = _smooth(source, length)
        second_pass = _smooth(first_pass, length)
        return (2 * first_pass) - second_pass

    macd_line = _zero_lag_leg(fastLength) - _zero_lag_leg(slowLength)

    if useOldAlgo:
        signal_line = macd_line.rolling(window=signalLength, min_periods=signalLength).mean()
    else:
        sig_pass1 = macd_line.ewm(span=signalLength, adjust=False).mean()
        sig_pass2 = sig_pass1.ewm(span=signalLength, adjust=False).mean()
        signal_line = (2 * sig_pass1) - sig_pass2

    histogram = macd_line - signal_line
    was_at_or_above = macd_line.shift(1) >= signal_line.shift(1)
    was_at_or_below = macd_line.shift(1) <= signal_line.shift(1)

    # mask() keeps NaN where the condition is NaN-driven False, like boolean assignment did.
    return {
        "ZeroLagMACD": macd_line,
        "signal": signal_line,
        "hist": histogram,
        "upHist": histogram.mask(histogram <= 0, 0),
        "downHist": histogram.mask(histogram > 0, 0),
        "EMALine": macd_line.ewm(span=MacdEmaLength, adjust=False).mean(),
        "dotUP": macd_line.mask(was_at_or_above | (macd_line < signal_line), np.nan),
        "dotDN": macd_line.mask(was_at_or_below | (macd_line > signal_line), np.nan),
    }

def extract_macd_signals(df, macd_dict, length_m=14):
    """List zero-lag MACD / signal-line crossovers as dated buy/sell rows.

    A "ZeroLagMACD Buy" is emitted when MACD moves from <= signal to > signal, and a "Sell"
    for the opposite move; bars where any of the four compared values is NaN are skipped.
    Each row carries the rolling ``length_m`` z-score of the MACD line (None when undefined).

    Returns:
        DataFrame with columns Date, Signal, Z-Score sorted newest first, or an empty
        DataFrame when there are no crossovers.
    """
    line = macd_dict["ZeroLagMACD"]
    sig = macd_dict["signal"]
    rolling = line.rolling(window=length_m, min_periods=length_m)
    zscores = (line - rolling.mean()) / rolling.std().replace(0, np.nan)

    def _record(pos, label):
        z = zscores.iloc[pos]
        return {
            "Date": df.index[pos].strftime("%Y-%m-%d"),
            "Signal": label,
            "Z-Score": None if pd.isna(z) else round(z, 2),
        }

    records = []
    for pos in range(1, len(df)):
        cur_line, cur_sig = line.iloc[pos], sig.iloc[pos]
        prev_line, prev_sig = line.iloc[pos - 1], sig.iloc[pos - 1]
        if pd.isna(cur_line) or pd.isna(cur_sig) or pd.isna(prev_line) or pd.isna(prev_sig):
            continue
        if cur_line > cur_sig and prev_line <= prev_sig:
            records.append(_record(pos, "ZeroLagMACD Buy"))
        elif cur_line < cur_sig and prev_line >= prev_sig:
            records.append(_record(pos, "ZeroLagMACD Sell"))

    signals_df = pd.DataFrame(records)
    if not signals_df.empty:
        signals_df["Date"] = pd.to_datetime(signals_df["Date"])
        signals_df = signals_df.sort_values("Date", ascending=False)
    return signals_df

def extract_momentum_signals(df, length_m=14):
    """Emit a dated row whenever momentum grade, direction or state changes.

    momentum = close - close.shift(length_m); its z-score uses a rolling ``length_m`` window.
      * Grade: A (z >= 2), B (>= 1), C (>= 0), D (>= -1), E (>= -2), F otherwise.
      * Direction: "Increasing" when momentum > 0, else "Decreasing".
      * State: Consolidating / Turning / Positive Trending / Negative Trending.

    Bars whose underlying value is still undefined (indicator warm-up) are labelled "N/A".
    Transitions into or out of "N/A" are not reported. Previously NaN was graded "F" and
    classed as "Decreasing", which produced spurious change signals when warm-up ended.
    A zero rolling std is treated as NaN rather than dividing by zero.

    Returns:
        DataFrame with columns Date, Signal, Z-Score (None when undefined), sorted newest
        first with a stable sort so same-day rows keep grade/direction/state order.
        Empty DataFrame when nothing changes.
    """
    momentum = df['close'] - df['close'].shift(length_m)
    window = momentum.rolling(window=length_m, min_periods=length_m)
    avgMomentum = window.mean()
    stdDevMomentum = window.std().replace(0, np.nan)
    zScore = (momentum - avgMomentum) / stdDevMomentum

    def grade(x):
        if pd.isna(x):
            return "N/A"
        if x >= 2:
            return "A"
        if x >= 1:
            return "B"
        if x >= 0:
            return "C"
        if x >= -1:
            return "D"
        if x >= -2:
            return "E"
        return "F"

    def direction(x):
        if pd.isna(x):
            return "N/A"
        return "Increasing" if x > 0 else "Decreasing"

    momentum_grade = zScore.apply(grade)
    momentum_direction = momentum.apply(direction)

    momentum_state = []
    for i in range(len(momentum)):
        current = momentum.iloc[i]
        if i == 0 or pd.isna(current):
            momentum_state.append("N/A")
        elif abs(current) < abs(avgMomentum.iloc[i]) * 0.1:
            momentum_state.append("Consolidating")
        elif current * momentum.iloc[i - 1] < 0:
            momentum_state.append("Turning")
        elif current > 0:
            momentum_state.append("Positive Trending")
        else:
            momentum_state.append("Negative Trending")
    momentum_state = pd.Series(momentum_state, index=df.index)

    # Reporting order within one bar: grade, direction, state (as before).
    tracked = [
        ("Momentum Grade", momentum_grade),
        ("Momentum Direction", momentum_direction),
        ("Momentum State", momentum_state),
    ]

    signals = []
    for i in range(1, len(df)):
        z = zScore.iloc[i]
        z_out = None if pd.isna(z) else round(z, 2)
        date_str = df.index[i].strftime("%Y-%m-%d")
        for label, series in tracked:
            previous, current = series.iloc[i - 1], series.iloc[i]
            if current != previous and "N/A" not in (previous, current):
                signals.append({
                    "Date": date_str,
                    "Signal": f"{label} Changed to {current}",
                    "Z-Score": z_out,
                })

    signals_df = pd.DataFrame(signals)
    if not signals_df.empty:
        signals_df["Date"] = pd.to_datetime(signals_df["Date"])
        signals_df = signals_df.sort_values("Date", ascending=False, kind="mergesort")
    return signals_df

def extract_signals(df, signalUp_ZLMA, signalDn_ZLMA, bullPt, bearPt, upSig_MCDX, dnSig_MCDX, length_m=14):
    """Collect ZLMA, RSI-trail, MCDX and DSS buy/sell events into one dated table.

    Args:
        df: Price frame; must contain DSSb and DSSsignal columns when it has 2+ rows.
        signalUp_ZLMA, signalDn_ZLMA: Boolean crossover masks aligned with ``df``.
        bullPt, bearPt, upSig_MCDX, dnSig_MCDX: Series that are non-NaN on event bars;
            anything that is not a Series is treated as "no events".
        length_m: Momentum length passed to ``compute_zscore``.

    Returns:
        DataFrame (Date, Signal, Z-Score) sorted newest first, or an empty DataFrame.
        Z-Score is the momentum z-score on the event date, None when undefined.
    """
    zScore = compute_zscore(df, length_m)

    def _record(dt, label):
        z = zScore.loc[dt]
        return {
            "Date": dt.strftime("%Y-%m-%d"),
            "Signal": label,
            "Z-Score": None if pd.isna(z) else round(z, 2),
        }

    def _as_series(candidate):
        return candidate if isinstance(candidate, pd.Series) else pd.Series(dtype='float64')

    bullPt, bearPt = _as_series(bullPt), _as_series(bearPt)
    upSig_MCDX, dnSig_MCDX = _as_series(upSig_MCDX), _as_series(dnSig_MCDX)

    event_groups = [
        (df.index[signalUp_ZLMA.fillna(False)], "ZLMA Buy"),
        (df.index[signalDn_ZLMA.fillna(False)], "ZLMA Sell"),
        (bullPt.dropna().index, "RSI Buy"),
        (bearPt.dropna().index, "RSI Sell"),
        (upSig_MCDX.dropna().index, "MCDX Buy"),
        (dnSig_MCDX.dropna().index, "MCDX Sell"),
    ]
    records = [_record(dt, label) for dates, label in event_groups for dt in dates]

    # DSS crossovers of DSSb against its one-bar-lagged copy.
    if len(df) > 1:
        fast, slow = df['DSSb'], df['DSSsignal']
        for pos in range(1, len(df)):
            now_fast, now_slow = fast.iloc[pos], slow.iloc[pos]
            was_fast, was_slow = fast.iloc[pos - 1], slow.iloc[pos - 1]
            if pd.isna(now_fast) or pd.isna(now_slow) or pd.isna(was_fast) or pd.isna(was_slow):
                continue
            if now_fast > now_slow and was_fast <= was_slow:
                records.append(_record(df.index[pos], "DSS Buy"))
            elif now_fast < now_slow and was_fast >= was_slow:
                records.append(_record(df.index[pos], "DSS Sell"))

    signals_df = pd.DataFrame(records)
    if not signals_df.empty:
        signals_df["Date"] = pd.to_datetime(signals_df["Date"])
        signals_df = signals_df.sort_values("Date", ascending=False)
    return signals_df

def extract_current_status(df, signalUp_ZLMA, signalDn_ZLMA, bullPt, bearPt, upSig_MCDX, dnSig_MCDX, length_m=14, macd_dict=None):
    """Summarise the latest state of every indicator as an Indicator / Current Signal table.

    Event-style indicators (ZLMA, RSI trail, MCDX) report whichever event happened most
    recently: "Buy" if the last buy event is newer than the last sell event, "Sell" if the
    reverse, "N/A" if neither occurred. The previous version reported "Buy" only when a buy
    fired on the very last bar and "Sell" otherwise.

    Args:
        df: Price frame with close, DSSb and DSSsignal columns.
        signalUp_ZLMA, signalDn_ZLMA: Boolean crossover masks aligned with ``df``.
        bullPt, bearPt, upSig_MCDX, dnSig_MCDX: Series aligned with ``df`` that are non-NaN
            on event bars.
        length_m: Momentum length for the z-score / grade / direction / state rows.
        macd_dict: Optional output of ``compute_zero_lag_macd``; adds a ZeroLagMACD row.

    Returns:
        DataFrame with columns Indicator and Current Signal. Momentum rows are "N/A" when
        there is not enough history (e.g. a single-row frame).
    """
    def _latest_event(buy_mask, sell_mask):
        # Positions are relative to the masks, which are built from df and share its order.
        buy_hits = np.flatnonzero(np.asarray(buy_mask, dtype=bool))
        sell_hits = np.flatnonzero(np.asarray(sell_mask, dtype=bool))
        last_buy = buy_hits[-1] if buy_hits.size else -1
        last_sell = sell_hits[-1] if sell_hits.size else -1
        if last_buy < 0 and last_sell < 0:
            return "N/A"
        return "Buy" if last_buy > last_sell else "Sell"

    zlma_status = _latest_event(signalUp_ZLMA.fillna(False).astype(bool),
                                signalDn_ZLMA.fillna(False).astype(bool))
    rsi_status = _latest_event(bullPt.notna(), bearPt.notna())
    mcdx_status = _latest_event(upSig_MCDX.notna(), dnSig_MCDX.notna())
    dss_status = "Buy" if df['DSSb'].iloc[-1] > df['DSSsignal'].iloc[-1] else "Sell"

    # Momentum z-score computed once (previously duplicated); zero std -> NaN, not ±inf.
    momentum = df['close'] - df['close'].shift(length_m)
    window = momentum.rolling(window=length_m, min_periods=length_m)
    avgMomentum = window.mean()
    zScore_m = (momentum - avgMomentum) / window.std().replace(0, np.nan)
    current_zscore = round(zScore_m.iloc[-1], 2)

    def get_momentum_grade(z):
        if pd.isna(z):
            return "N/A"
        if z >= 2:
            return "A"
        if z >= 1:
            return "B"
        if z >= 0:
            return "C"
        if z >= -1:
            return "D"
        if z >= -2:
            return "E"
        return "F"

    current_momentum_grade = get_momentum_grade(zScore_m.iloc[-1])

    # Guard the two-bar lookback before touching iloc[-2] (single-row frames used to crash).
    if len(df) < 2 or pd.isna(momentum.iloc[-1]) or pd.isna(momentum.iloc[-2]):
        current_momentum_direction = "N/A"
    else:
        current_momentum_direction = "Increasing" if momentum.iloc[-1] > momentum.iloc[-2] else "Decreasing"

    if len(df) < 2 or pd.isna(momentum.iloc[-1]):
        current_momentum_state = "N/A"
    elif abs(momentum.iloc[-1]) < abs(avgMomentum.iloc[-1]) * 0.1:
        current_momentum_state = "Consolidating"
    elif momentum.iloc[-1] * momentum.iloc[-2] < 0:
        current_momentum_state = "Turning"
    elif momentum.iloc[-1] > 0:
        current_momentum_state = "Positive Trending"
    else:
        current_momentum_state = "Negative Trending"

    indicators = ["ZLMA", "RSI", "MCDX", "DSS", "Z-Score", "Momentum Grade", "Momentum Direction", "Momentum State"]
    signals = [zlma_status, rsi_status, mcdx_status, dss_status, current_zscore,
               current_momentum_grade, current_momentum_direction, current_momentum_state]
    if macd_dict is not None:
        macd_status = "Buy" if macd_dict["ZeroLagMACD"].iloc[-1] > macd_dict["signal"].iloc[-1] else "Sell"
        indicators.append("ZeroLagMACD")
        signals.append(macd_status)
    return pd.DataFrame({"Indicator": indicators, "Current Signal": signals})

def create_six_panel_combined_plot(df, ticker, start_date, end_date,
                                   ema_value, zlma, signalUp_ZLMA, signalDn_ZLMA, zlma_color, ema_color,
                                   rsi_ma_base, rsi_upper_bound, rsi_lower_bound, bullPt, bearPt,
                                   b_X, b_DSSb, b_DSSsignal,
                                   hbma, threshold, upSig_MCDX, dnSig_MCDX,
                                   Dump, DnCandle, PumpCandle, Retest, Banker,
                                   iv_series, macd_dict):
    """Draw the six stacked indicator panels (shared x-axis) and return the Figure.

    Panels:
      1. Candlesticks from df open/high/low/close, EMA_50/100/200/500 columns, the trend EMA
         and ZLMA with shaded spread, ZLMA crossover dots, RSI-trail base/upper/lower bands
         with RSI buy/sell markers, and vertical FOMC lines.
      2. Bressert X (marker colour = direction vs previous bar), DSSb and DSSsignal, with
         guide lines at 20/50/80.
      3. MCDX HBMA with a horizontal ``threshold`` line and buy/sell markers.
      4. MCDX bars (Dump, Down Candle, Pump Candle, Retest, Banker).
      5. Zero-lag MACD and signal with shaded spread, histogram (drawn at 2x) and dots.
      6. The ``iv_series`` line (labelled as VIX implied volatility).

    Side effects: calls ``get_fomc_dates`` (an HTTP request) for panel 1.
    The figure is not closed here; the caller owns it.
    """
    fig, axs = plt.subplots(6, 1, sharex=True, figsize=(12, 16),
                            gridspec_kw={"height_ratios": [2, 1, 1, 1, 1, 1]})
    fig.suptitle(f"{ticker} - Combined Chart", fontsize=14)
    # Matplotlib date numbers (float days), needed to position the candle Rectangles.
    x_vals = mdates.date2num(df.index.to_pydatetime())

    # Panel 1: Price + ZLMA + RSI Trail + FOMC markers
    # Candles: a high-low wick line plus a body rectangle from open to close (height may be negative).
    for i in range(len(df)):
        o, c, h, l = df['open'].iloc[i], df['close'].iloc[i], df['high'].iloc[i], df['low'].iloc[i]
        color = 'green' if c >= o else 'red'
        axs[0].plot([x_vals[i], x_vals[i]], [l, h], color=color, linewidth=1, zorder=1)
        candle_width = 0.6
        axs[0].add_patch(Rectangle((x_vals[i] - candle_width/2, o), candle_width, c - o,
                                   facecolor=color, edgecolor=color, zorder=2))
    axs[0].plot(df.index, df['EMA_50'], label="EMA 50", color='blue', linewidth=1.5, zorder=3)
    axs[0].plot(df.index, df['EMA_100'], label="EMA 100", color='orange', linewidth=1.5, zorder=3)
    axs[0].plot(df.index, df['EMA_200'], label="EMA 200", color='purple', linewidth=1.5, zorder=3)
    axs[0].plot(df.index, df['EMA_500'], label="EMA 500", color='brown', linewidth=1.5, zorder=3)

    axs[0].plot(df.index, ema_value, label="EMA (Trend)", color=ema_color, linewidth=2, zorder=4)
    axs[0].plot(df.index, zlma,      label="ZLMA",       color=zlma_color, linewidth=2, zorder=4)

    # Green shading where ZLMA is at/above the trend EMA, red where below.
    axs[0].fill_between(df.index, zlma, ema_value, where=(zlma>=ema_value),
                        facecolor="darkgreen", alpha=0.3, interpolate=True, zorder=3)
    axs[0].fill_between(df.index, zlma, ema_value, where=(zlma<ema_value),
                        facecolor="darkred", alpha=0.3, interpolate=True, zorder=3)

    # where() leaves NaN off the crossover bars, so only crossover points are plotted.
    axs[0].scatter(df.index, zlma.where(signalUp_ZLMA), color="cyan",    marker="o", s=50, label="ZLMA Buy",  zorder=5)
    axs[0].scatter(df.index, zlma.where(signalDn_ZLMA), color="magenta", marker="o", s=50, label="ZLMA Sell", zorder=5)

    axs[0].plot(df.index, rsi_ma_base,        label="RSI Trail Base",  color="gray", linestyle="--", linewidth=1)
    axs[0].plot(df.index, rsi_upper_bound,    label="RSI Trail Upper", color="blue", linewidth=1)
    axs[0].plot(df.index, rsi_lower_bound,    label="RSI Trail Lower", color="red",  linewidth=1)

    axs[0].scatter(df.index, bullPt, color="cyan",    marker="^", s=50, label="RSI Buy",  zorder=6)
    axs[0].scatter(df.index, bearPt, color="magenta", marker="v", s=50, label="RSI Sell", zorder=6)

    axs[0].fill_between(df.index, rsi_ma_base,    rsi_upper_bound, facecolor="darkgreen", alpha=0.2, interpolate=True)
    axs[0].fill_between(df.index, rsi_lower_bound,rsi_ma_base,     facecolor="darkred",   alpha=0.2, interpolate=True)

    # Only the first FOMC line is labelled so the legend shows a single "FOMC" entry.
    fomc_dates = get_fomc_dates(start_date, end_date)
    for i, dt in enumerate(fomc_dates):
        axs[0].axvline(dt, color="purple", linestyle="--", linewidth=1, label="FOMC" if i==0 else "")

    axs[0].set_ylabel("Price")
    axs[0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=5, fontsize=8)

    # Panel 2: Bressert
    axs[1].plot(df.index, b_X, label="X (EMA of Y)", color="black", linewidth=2)
    # First marker black; afterwards red when X fell versus the previous bar, else green.
    marker_colors = ['black'] + [
        'red' if b_X.iloc[i] < b_X.iloc[i-1] else 'green'
        for i in range(1, len(b_X))
    ]
    axs[1].scatter(df.index, b_X, c=marker_colors, s=20, label="X Color")
    axs[1].plot(df.index, b_DSSb,       label="DSSb",       color="blue",    linewidth=2)
    axs[1].plot(df.index, b_DSSsignal,  label="DSSsignal",  color="magenta", linewidth=2)
    axs[1].axhline(50, color="gray", linewidth=1)
    axs[1].axhline(80, color="red",  linewidth=2)
    axs[1].axhline(20, color="green",linewidth=2)
    axs[1].set_ylabel("Bressert")
    axs[1].legend(loc="lower left", fontsize=8)

    # Panel 3: MCDX HBMA & Signals
    axs[2].plot(df.index, hbma, label="HBMA", color="black", linewidth=2, zorder=3)
    axs[2].axhline(threshold, color="gray", linestyle="--", label="Threshold", zorder=2)
    axs[2].scatter(df.index, upSig_MCDX, color="green", marker="o", s=50, label="MCDX Buy",  zorder=4)
    axs[2].scatter(df.index, dnSig_MCDX, color="red",   marker="o", s=50, label="MCDX Sell", zorder=4)
    axs[2].set_ylabel("MCDX HBMA")
    axs[2].legend(loc="lower left", fontsize=8)

    # Panel 4: MCDX Bars
    # Later bar() calls draw on top of earlier ones; Banker (drawn last) is the full bank series.
    axs[3].bar(df.index, Dump,       width=0.8, color="red",      alpha=0.7, label="Dump",        zorder=1)
    axs[3].bar(df.index, DnCandle,   width=0.8, color="darkgray", alpha=0.7, label="Down Candle", zorder=1)
    axs[3].bar(df.index, PumpCandle, width=0.8, color="green",    alpha=0.7, label="Pump Candle", zorder=1)
    axs[3].bar(df.index, Retest,     width=0.8, color="darkred",  alpha=0.7, label="Retest",      zorder=1)
    axs[3].bar(df.index, Banker,     width=0.8, color="#84AFC9",  alpha=0.7, label="Banker",      zorder=1)
    axs[3].set_ylabel("MCDX Bars")
    axs[3].legend(loc="lower left", fontsize=8)

    # Panel 5: Enhanced Zero Lag MACD
    axs[4].fill_between(df.index, macd_dict["ZeroLagMACD"], macd_dict["signal"],
                        where=(macd_dict["ZeroLagMACD"] >= macd_dict["signal"]),
                        facecolor="green", alpha=0.3, interpolate=True)
    axs[4].fill_between(df.index, macd_dict["ZeroLagMACD"], macd_dict["signal"],
                        where=(macd_dict["ZeroLagMACD"] < macd_dict["signal"]),
                        facecolor="red", alpha=0.3, interpolate=True)
    axs[4].plot(df.index, macd_dict["ZeroLagMACD"], label="ZeroLagMACD", color="green", linewidth=1)
    axs[4].plot(df.index, macd_dict["signal"],      label="Signal",      color="red",   linewidth=1)
    # Histogram is scaled x2 purely for visibility.
    axs[4].bar(df.index, macd_dict["upHist"]   *2,  label="Histogram Up",   color="gray", width=0.8)
    axs[4].bar(df.index, macd_dict["downHist"] *2,  label="Histogram Down", color="red",  width=0.8)
    axs[4].scatter(df.index, macd_dict["dotUP"], color="green",  marker="o", s=50, label="Dot Up")
    axs[4].scatter(df.index, macd_dict["dotDN"], color="red",    marker="o", s=50, label="Dot Down")
    axs[4].set_ylabel("ZeroLag MACD")
    axs[4].legend(loc="lower left", fontsize=8)

    # Panel 6: Implied Volatility
    axs[5].plot(df.index, iv_series, label="Implied Volatility (VIX)", color="darkorange", linewidth=2)
    axs[5].set_ylabel("IV")
    axs[5].legend(loc="lower left", fontsize=8)
    axs[5].xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    for tick in axs[5].get_xticklabels():
        tick.set_rotation(45)

    plt.tight_layout()
    return fig

def figure_to_pil(fig):
    """Render ``fig`` to an in-memory PNG (tight bounding box) and return it as a PIL image.

    The returned image keeps a reference to the in-memory buffer, so it stays readable
    after the figure itself is closed.
    """
    stream = io.BytesIO()
    fig.savefig(stream, bbox_inches="tight", format="png")
    stream.seek(0)
    image = Image.open(stream)
    return image

# ------------------------------------------------------------------------
# 2) The main chart-generation function (used by Chart tab)
# ------------------------------------------------------------------------
# Default chart window: the trailing 365 days, both derived from one clock reading so the
# two defaults can never disagree (two separate now() calls could straddle midnight).
# NOTE: evaluated once at import; a long-running kernel keeps these values until re-run.
_now = datetime.now()
default_end_date = _now.strftime("%Y-%m-%d")
default_start_date = (_now - timedelta(days=365)).strftime("%Y-%m-%d")
del _now

def generate_plot(ticker="SPY", start_date=default_start_date, end_date=default_end_date, n_period=8, r_period=13):
    """Download ``ticker`` data, compute all indicators and render the combined chart.

    Args:
        ticker: Yahoo Finance symbol.
        start_date, end_date: Date range, anything ``pd.to_datetime`` accepts.
        n_period, r_period: Bressert window / EMA lengths. Coerced to int because pandas
            rolling windows reject floats, and UI number widgets may deliver floats.

    Returns:
        Tuple of (PIL image of the chart, current-status DataFrame, historical-signals DataFrame).

    Raises:
        gr.Error: For missing or insufficient data, these are re-raised unchanged. For any
            other failure, a wrapping ``gr.Error`` is raised.
    """
    fig = None
    try:
        n_period = int(n_period)
        r_period = int(r_period)

        df = download_data(ticker, start_date, end_date)
        if df.empty:
            raise gr.Error(f"No data for {ticker} from {start_date} to {end_date}")
        for col in ["open", "high", "low", "close", "volume"]:
            if col not in df.columns:
                raise gr.Error(f"Missing {col} data for {ticker}")
        if len(df) < 2:
            # zlma_color / momentum direction below compare the last two bars.
            raise gr.Error(f"Need at least two bars of data for {ticker}, got {len(df)}")

        df['EMA_50']  = exp_average(df['close'], 50)
        df['EMA_100'] = exp_average(df['close'], 100)
        df['EMA_200'] = exp_average(df['close'], 200)
        df['EMA_500'] = exp_average(df['close'], 500)

        # --- ZLMA trend ---
        movAvgLength = 15
        ema_value  = exp_average(df['close'], movAvgLength)
        correction = df['close'] + (df['close'] - ema_value)
        zlma       = exp_average(correction, movAvgLength)
        signalUp_ZLMA = (zlma > ema_value) & (zlma.shift(1) <= ema_value.shift(1))
        signalDn_ZLMA = (zlma < ema_value) & (zlma.shift(1) >= ema_value.shift(1))
        zlma_color = "green" if zlma.iloc[-1] > zlma.iloc[-2] else "red"
        ema_color  = "green" if ema_value.iloc[-1] < zlma.iloc[-1] else "red"

        # --- Bressert DSS ---
        df = compute_bressert(df, n_period, r_period)
        b_X         = df['X']
        b_DSSb      = df['DSSb']
        b_DSSsignal = df['DSSsignal']

        # --- MCDX (banker / hot money) ---
        RSIBaseBanker    = 50;   RSIPeriodBanker    = 50
        RSIBaseHotMoney  = 30;   RSIPeriodHotMoney  = 40
        SensitivityBanker= 1.5;  SensitivityHotMoney= 0.7
        threshold        = 8.5

        rsi_Banker  = rsi_function(df['close'], SensitivityBanker, RSIPeriodBanker, RSIBaseBanker)
        rsi_HotMoney= rsi_function(df['close'], SensitivityHotMoney, RSIPeriodHotMoney, RSIBaseHotMoney)

        hot  = rsi_HotMoney
        bank = rsi_Banker

        hotma2  = wilder_average(hot, 2)
        hotma7  = wilder_average(hot, 7)
        hotma31 = wilder_average(hot,31)
        hotma   = exp_average((hotma2*34 + hotma7*33 + hotma31*33)/100, 2)

        bankma2  = df['close'].rolling(window=2, min_periods=2).mean()
        bankma7  = exp_average(bank, 7)
        bankma31 = exp_average(bank,31)
        bankma   = ((bankma2*70 + bankma7*20 + bankma31*10)/100).rolling(window=2, min_periods=2).mean()
        banksignal = wilder_average(bankma, 4)

        hbAvg = ((hot*10) + (hotma*35) + (wilder_average(hotma,2)*15) + (bankma*35) + (banksignal*5)) / 100
        hbma = vwma(hbAvg, 2, df['volume'])

        downtrendsignal = (hotma.shift(1) >= wilder_average(hotma,2).shift(1)) & (hotma < wilder_average(hotma,2))
        uptrendsignal   = (hotma.shift(1) <= wilder_average(hotma,2).shift(1)) & (hotma > wilder_average(hotma,2))

        upSig_MCDX = hbma.where(uptrendsignal,   np.nan)
        dnSig_MCDX = hbma.where(downtrendsignal, np.nan)

        Dump = bank.where(bank < bank.shift(1) / 1.75, np.nan)
        dnCond = (bank < bank.shift(1)) & (bank < bank.shift(2)) & (bank.shift(1) < bank.shift(2)) & \
                 (bank < bank.shift(3)) & (bank < bank.shift(4)) & (bank.shift(3) < bank.shift(4)) & \
                 (bank.shift(6) > 8.5) & (bank < 10)
        DnCandle   = bank.where(dnCond, np.nan)
        PumpCandle = bank.where(bank > hbma, np.nan)
        Retest     = bank.where((banksignal > bankma) & (bank > 0), np.nan)
        Banker     = bank

        # --- RSI trail ---
        lookbackPeriod = 15
        atrLength = 27
        atrMultiplier = 0.1
        rsiLowerThreshold = 40
        rsiUpperThreshold = 60

        ohlc4 = (df['open'] + df['high'] + df['low'] + df['close'])/4
        rsi_ma_base = t3(ohlc4, length=lookbackPeriod, vf=0.7)
        tr_series = pd.concat([df['high']-df['low'],
                               abs(df['high']-df['close'].shift(1)),
                               abs(df['low']-df['close'].shift(1))], axis=1).max(axis=1)
        nzTR = tr_series.fillna(df['high']-df['low'])
        f_volatility = wilder_average(nzTR, atrLength) * atrMultiplier
        rsi_upper_bound = rsi_ma_base + (rsiUpperThreshold - 50)/10 * f_volatility
        rsi_lower_bound = rsi_ma_base - (50 - rsiLowerThreshold)/10 * f_volatility

        crossUp = (ohlc4 > rsi_upper_bound) & (ohlc4.shift(1) <= rsi_upper_bound.shift(1))
        crossDn=  (df['close'] < rsi_lower_bound) & (df['close'].shift(1) >= rsi_lower_bound.shift(1))
        bullPt = rsi_lower_bound.where(crossUp, np.nan)
        bearPt = rsi_upper_bound.where(crossDn, np.nan)

        # --- Implied volatility (VIX), best-effort ---
        # download_data flattens yfinance MultiIndex columns, so "close" is always a Series.
        try:
            vix_df = download_data("^VIX", start_date, end_date)
        except Exception as vix_err:
            debug_print(f"VIX download failed: {vix_err}")
            vix_df = pd.DataFrame()
        if vix_df.empty or "close" not in vix_df.columns:
            iv_series = pd.Series(np.nan, index=df.index)
        else:
            iv_series = vix_df["close"].reindex(df.index, method="ffill")

        macd_dict = compute_zero_lag_macd(df['close'], fastLength=12, slowLength=26, signalLength=9,
                                          MacdEmaLength=9, useEma=True, useOldAlgo=False)

        fig = create_six_panel_combined_plot(
            df, ticker, start_date, end_date,
            ema_value, zlma, signalUp_ZLMA, signalDn_ZLMA, zlma_color, ema_color,
            rsi_ma_base, rsi_upper_bound, rsi_lower_bound, bullPt, bearPt,
            b_X, b_DSSb, b_DSSsignal,
            hbma, threshold, upSig_MCDX, dnSig_MCDX,
            Dump, DnCandle, PumpCandle, Retest, Banker,
            iv_series, macd_dict
        )

        # create_six_panel_combined_plot already fetches and draws the FOMC lines; re-fetching
        # here doubled the HTTP request and the drawn lines. Only de-duplicate the legend.
        ax0 = fig.axes[0]
        handles, labels = ax0.get_legend_handles_labels()
        unique = dict(zip(labels, handles))
        ax0.legend(unique.values(), unique.keys(), loc="lower right", ncol=3, fontsize=8)

        signals_df         = extract_signals(df, signalUp_ZLMA, signalDn_ZLMA, bullPt, bearPt, upSig_MCDX, dnSig_MCDX, length_m=14)
        momentum_signals_df= extract_momentum_signals(df, length_m=14)
        macd_signals_df    = extract_macd_signals(df, macd_dict, length_m=14)

        historical_signals_df = pd.concat([signals_df, momentum_signals_df, macd_signals_df], ignore_index=True)
        if not historical_signals_df.empty:
            historical_signals_df["Date"] = pd.to_datetime(historical_signals_df["Date"])
            historical_signals_df = historical_signals_df.sort_values("Date", ascending=False)

        current_status_df = extract_current_status(
            df, signalUp_ZLMA, signalDn_ZLMA, bullPt, bearPt, upSig_MCDX, dnSig_MCDX,
            length_m=14, macd_dict=macd_dict
        )

        pil_img = figure_to_pil(fig)
        return pil_img, current_status_df, historical_signals_df

    except gr.Error:
        # Already user-facing; don't re-wrap as "An error occurred: ...".
        raise
    except Exception as e:
        debug_print(f"Error: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
    finally:
        # Close on every path so failed runs don't leak Matplotlib figures.
        if fig is not None:
            plt.close(fig)

# ------------------------------------------------------------------------
# 3) New function to save data + signals to CSV
# ------------------------------------------------------------------------
# Indicators whose 0/1 "<name>_Buy" / "<name>_Sell" columns are written to the data CSV;
# TradingEnv reads exactly these column names.
_SIGNAL_INDICATORS = ("ZLMA", "RSI", "MCDX", "DSS", "ZeroLagMACD")


def _add_indicator_flag_columns(df, historical_signals_df, indicators=_SIGNAL_INDICATORS):
    """
    Add 0/1 "<ind>_Buy" and "<ind>_Sell" columns to ``df`` in place and return it.

    A day is flagged when any entry of ``historical_signals_df["Signal"]`` recorded for that
    date contains both the indicator name and "Buy" (resp. "Sell"). An empty signal table, or
    one without "Date"/"Signal" columns, yields all-zero columns instead of a KeyError.
    """
    has_table = (not historical_signals_df.empty
                 and {"Date", "Signal"}.issubset(historical_signals_df.columns))
    # date -> list of signal strings, built once instead of one lookup per (row, indicator).
    signals_by_date = (historical_signals_df.groupby("Date")["Signal"].apply(list).to_dict()
                       if has_table else {})
    for ind in indicators:
        for side in ("Buy", "Sell"):
            df[f"{ind}_{side}"] = [
                int(any(ind in str(sig) and side in str(sig) for sig in signals_by_date.get(d, [])))
                for d in df.index
            ]
    return df


def save_historical_data(ticker="SPY", start_date=default_start_date, end_date=default_end_date,
                         n_period=8, r_period=13,
                         data_filename="full_data.csv", signals_filename="signals_data.csv"):
    """
    Download price data, compute the chart-tab indicators, and write two CSV files.

    Args:
        ticker: Symbol passed to ``download_data``.
        start_date, end_date: Date range passed to ``download_data``.
        n_period, r_period: Periods forwarded to ``compute_bressert``.
        data_filename: Destination of the per-day frame: OHLCV, EMA_50/100/200/500, whatever
            ``compute_bressert`` returns, ZLMA, HBMA, RSI_Banker, RSI_HotMoney, and the 0/1
            ``<Indicator>_Buy`` / ``<Indicator>_Sell`` columns.
        signals_filename: Destination of the long-form signal table (newest date first,
            written without the index).

    Returns:
        A status message for the Gradio Markdown output.

    Raises:
        gr.Error: if no data is returned, an OHLCV column is missing, or any later step fails.
    """
    try:
        df = download_data(ticker, start_date, end_date)
        if df.empty:
            raise gr.Error(f"No data for {ticker} from {start_date} to {end_date}")
        for col in ["open", "high", "low", "close", "volume"]:
            if col not in df.columns:
                raise gr.Error(f"Missing {col} data for {ticker}")

        df['EMA_50']  = exp_average(df['close'], 50)
        df['EMA_100'] = exp_average(df['close'], 100)
        df['EMA_200'] = exp_average(df['close'], 200)
        df['EMA_500'] = exp_average(df['close'], 500)

        # --- Zero-lag moving average and its EMA crossovers ---
        movAvgLength = 15
        ema_value  = exp_average(df['close'], movAvgLength)
        correction = df['close'] + (df['close'] - ema_value)
        zlma       = exp_average(correction, movAvgLength)
        signalUp_ZLMA = (zlma > ema_value) & (zlma.shift(1) <= ema_value.shift(1))
        signalDn_ZLMA = (zlma < ema_value) & (zlma.shift(1) >= ema_value.shift(1))

        df = compute_bressert(df, n_period, r_period)

        # --- MCDX (Banker / Hot Money RSI blend) ---
        RSIBaseBanker     = 50;   RSIPeriodBanker    = 50
        RSIBaseHotMoney   = 30;   RSIPeriodHotMoney  = 40
        SensitivityBanker = 1.5;  SensitivityHotMoney= 0.7
        threshold         = 8.5

        rsi_Banker   = rsi_function(df['close'], SensitivityBanker, RSIPeriodBanker, RSIBaseBanker)
        rsi_HotMoney = rsi_function(df['close'], SensitivityHotMoney, RSIPeriodHotMoney, RSIBaseHotMoney)
        hot  = rsi_HotMoney
        bank = rsi_Banker

        hotma2  = wilder_average(hot, 2)
        hotma7  = wilder_average(hot, 7)
        hotma31 = wilder_average(hot, 31)
        hotma   = exp_average((hotma2*34 + hotma7*33 + hotma31*33)/100, 2)

        # NOTE(review): bankma2 averages the close price while bankma7/bankma31 average the
        # Banker RSI, mixing scales; presumably `bank.rolling(...)` was intended. Left as-is
        # so the CSV matches the chart tab. TODO confirm against the reference indicator.
        bankma2  = df['close'].rolling(window=2, min_periods=2).mean()
        bankma7  = exp_average(bank, 7)
        bankma31 = exp_average(bank, 31)
        bankma   = ((bankma2*70 + bankma7*20 + bankma31*10)/100).rolling(window=2, min_periods=2).mean()
        banksignal = wilder_average(bankma, 4)

        hbAvg = ((hot*10) + (hotma*35) + (wilder_average(hotma,2)*15) + (bankma*35) + (banksignal*5)) / 100
        hbma = vwma(hbAvg, 2, df['volume'])

        downtrendsignal = (hotma.shift(1) >= wilder_average(hotma,2).shift(1)) & (hotma < wilder_average(hotma,2))
        uptrendsignal   = (hotma.shift(1) <= wilder_average(hotma,2).shift(1)) & (hotma > wilder_average(hotma,2))
        upSig_MCDX = hbma.where(uptrendsignal, np.nan)
        dnSig_MCDX = hbma.where(downtrendsignal, np.nan)

        df['ZLMA'] = zlma
        df['HBMA'] = hbma
        df['RSI_Banker'] = bank
        df['RSI_HotMoney'] = hot

        Dump = bank.where(bank < bank.shift(1) / 1.75, np.nan)
        dnCond = (bank < bank.shift(1)) & (bank < bank.shift(2)) & (bank.shift(1) < bank.shift(2)) & \
                 (bank < bank.shift(3)) & (bank < bank.shift(4)) & (bank.shift(3) < bank.shift(4)) & \
                 (bank.shift(6) > 8.5) & (bank < 10)
        DnCandle   = bank.where(dnCond, np.nan)
        PumpCandle = bank.where(bank > hbma, np.nan)
        Retest     = bank.where((banksignal > bankma) & (bank > 0), np.nan)

        # --- RSI trail bands around a T3 of ohlc4 ---
        lookbackPeriod = 15
        atrLength = 27
        atrMultiplier = 0.1
        rsiLowerThreshold = 40
        rsiUpperThreshold = 60

        ohlc4 = (df['open'] + df['high'] + df['low'] + df['close'])/4
        rsi_ma_base = t3(ohlc4, length=lookbackPeriod, vf=0.7)
        tr_series = pd.concat([df['high']-df['low'],
                               abs(df['high']-df['close'].shift(1)),
                               abs(df['low']-df['close'].shift(1))], axis=1).max(axis=1)
        nzTR = tr_series.fillna(df['high']-df['low'])
        f_volatility = wilder_average(nzTR, atrLength) * atrMultiplier
        rsi_upper_bound = rsi_ma_base + (rsiUpperThreshold - 50)/10 * f_volatility
        rsi_lower_bound = rsi_ma_base - (50 - rsiLowerThreshold)/10 * f_volatility

        crossUp = (ohlc4 > rsi_upper_bound) & (ohlc4.shift(1) <= rsi_upper_bound.shift(1))
        crossDn = (df['close'] < rsi_lower_bound) & (df['close'].shift(1) >= rsi_lower_bound.shift(1))
        bullPt = rsi_lower_bound.where(crossUp, np.nan)
        bearPt = rsi_upper_bound.where(crossDn, np.nan)

        macd_dict = compute_zero_lag_macd(df['close'], fastLength=12, slowLength=26, signalLength=9,
                                          MacdEmaLength=9, useEma=True, useOldAlgo=False)

        # --- Long-form signal table (newest first) ---
        signals_df = extract_signals(df, signalUp_ZLMA, signalDn_ZLMA, bullPt, bearPt,
                                      upSig_MCDX, dnSig_MCDX, length_m=14)
        momentum_signals_df = extract_momentum_signals(df, length_m=14)
        macd_signals_df = extract_macd_signals(df, macd_dict, length_m=14)
        historical_signals_df = pd.concat([signals_df, momentum_signals_df, macd_signals_df], ignore_index=True)
        if not historical_signals_df.empty:
            historical_signals_df["Date"] = pd.to_datetime(historical_signals_df["Date"])
            historical_signals_df = historical_signals_df.sort_values("Date", ascending=False)

        # --- Per-day 0/1 flags consumed by the DQN environment ---
        _add_indicator_flag_columns(df, historical_signals_df)

        df.to_csv(data_filename)
        historical_signals_df.to_csv(signals_filename, index=False)
        return f"Saved {data_filename} and {signals_filename} successfully."

    except gr.Error:
        # Already a user-facing message; re-raise as-is instead of nesting the prefix.
        raise
    except Exception as e:
        debug_print(f"Error: {e}")
        raise gr.Error(f"An error occurred: {e}") from e

# ------------------------------------------------------------------------
# 4) DQN + environment that reads from CSV
# ------------------------------------------------------------------------
class TradingEnv:
    """
    Single-asset trading environment over a price/signal DataFrame.

    Each episode:
      1) Randomly picks a start day within the dataset.
      2) Randomly picks an initial position (long, short, or flat).
      3) Rewards daily PnL, adds an alignment bonus when the position agrees with the
         day's signals and a penalty when it ignores them.
      4) Charges a transaction cost for flipping directly between long and short.
      5) Force-closes (with a penalty) a position held longer than 'max_hold_days',
         so the agent has to re-enter.

    Actions: 0 = Hold (keep the current position), 1 = Go Long, 2 = Go Short.
    State:   the last 'window_size' rows of [close + signal columns], flattened,
             followed by the current position (+1 long, -1 short, 0 flat).
    """
    def __init__(self, df, window_size=3, max_hold_days=10):
        """
        df must include a 'close' column. Any of the signal columns
        [ZLMA_Buy, ZLMA_Sell, RSI_Buy, RSI_Sell, MCDX_Buy, MCDX_Sell, DSS_Buy, DSS_Sell,
        ZeroLagMACD_Buy, ZeroLagMACD_Sell] that are missing are filled with 0.

        Raises:
            ValueError: if df has no rows.
        """
        if len(df) == 0:
            raise ValueError("TradingEnv needs at least one row of data")
        self.df_original = df.reset_index(drop=False).copy()  # keep the date index as a column
        self.length_full = len(self.df_original)
        self.window_size = window_size
        self.max_hold_days = max_hold_days

        # Buy/Sell columns alternate, so [0::2] are buys and [1::2] are sells (used in step()).
        self.signal_cols = [
            "ZLMA_Buy", "ZLMA_Sell", "RSI_Buy", "RSI_Sell",
            "MCDX_Buy", "MCDX_Sell", "DSS_Buy", "DSS_Sell",
            "ZeroLagMACD_Buy", "ZeroLagMACD_Sell"
        ]
        for col in self.signal_cols:
            if col not in self.df_original.columns:
                self.df_original[col] = 0

        self.random_start_index = 0
        self.current_step = 0
        self.position = 0   # +1=long, -1=short, 0=flat
        self.hold_counter = 0
        self.history_buffer = []

    def reset(self):
        """
        Start a new episode and return its initial state.

          1) Pick a random start day, leaving at least window_size + 2 rows when the data
             allows it (datasets shorter than that always start at day 0).
          2) Pick a random initial position from {flat, long, short}.
          3) Slice self.df from the start day to the end and prefill the history window.
        """
        # Clamp so short datasets don't make randint() raise "empty range".
        max_start = max(0, self.length_full - self.window_size - 2)
        self.random_start_index = random.randint(0, max_start)
        self.df = self.df_original.iloc[self.random_start_index:].reset_index(drop=True).copy()
        self.length = len(self.df)

        self.position = random.choice([0, 1, -1])

        self.current_step = 0
        self.hold_counter = 1 if self.position != 0 else 0
        self.history_buffer = []
        # The window starts as window_size copies of day 0.
        for _ in range(self.window_size):
            self._append_history()
        return self._get_state()

    def step(self, action):
        """
        Apply one action for the current day and advance by one day.

        action: 0=Hold (keep the current position), 1=Go Long, 2=Go Short
        Reward components:
          - daily PnL of the position held coming into today
          - transaction cost (0.1% of price) when flipping long <-> short
          - forced close plus a 2%-of-price penalty if hold_counter > max_hold_days
          - alignment bonus / ignoring penalty from today's signal columns

        Returns:
            (next_state, reward, done)
        """
        prev_position = self.position
        current_price = self.df.loc[self.current_step, 'close']
        reward = 0.0

        # Daily PnL earned by the position carried from yesterday.
        if self.current_step > 0:
            price_prev = self.df.loc[self.current_step - 1, 'close']
            reward = (current_price - price_prev) * prev_position

        # Hold keeps the open position; only actions 1/2 retarget it.
        if action == 1:
            new_position = 1
        elif action == 2:
            new_position = -1
        else:
            new_position = prev_position

        # Transaction cost if flipping from +1 to -1 or vice versa
        if (prev_position ==  1 and new_position == -1) or (prev_position == -1 and new_position == 1):
            reward -= 0.001 * current_price

        if new_position != prev_position:
            self.position = new_position
            self.hold_counter = 1 if self.position != 0 else 0
        elif self.position != 0:
            # same non-flat position carried for another day
            self.hold_counter += 1

        # Held too long: close to flat and penalize, forcing the agent to re-enter.
        if self.hold_counter > self.max_hold_days:
            reward -= 0.02 * current_price
            self.position = 0
            self.hold_counter = 0

        # Signals alignment
        day_signals = self.df.loc[self.current_step, self.signal_cols].values
        buy_sum = day_signals[0::2].sum()   # sum of the buy columns
        sell_sum = day_signals[1::2].sum()  # sum of the sell columns

        if self.position == 1:
            reward += 0.05 * buy_sum    # aligned with buy signals
            reward -= 0.02 * sell_sum   # ignoring sell signals
        elif self.position == -1:
            reward += 0.05 * sell_sum   # aligned with sell signals
            reward -= 0.02 * buy_sum    # ignoring buy signals
        else:
            # flat => penalize ignoring any signal
            if buy_sum > 0 or sell_sum > 0:
                reward -= 0.01 * (buy_sum + sell_sum)

        # Move forward; the episode ends one row before the end of self.df.
        self.current_step += 1
        done = (self.current_step >= self.length - 1)

        self._append_history()
        state = self._get_state()
        return state, reward, done

    def _append_history(self):
        """Push [close + signal columns] for the current day (clamped to the last row)."""
        if self.current_step < self.length:
            row = self.df.loc[self.current_step, ["close"] + self.signal_cols].values
        else:
            row = self.df.loc[self.length - 1, ["close"] + self.signal_cols].values
        self.history_buffer.append(row.astype(np.float32))
        if len(self.history_buffer) > self.window_size:
            self.history_buffer.pop(0)

    def _get_state(self):
        """Flattened (window_size, 1 + len(signal_cols)) history followed by the position."""
        history_array = np.array(self.history_buffer)
        flattened = history_array.flatten()
        return np.concatenate([flattened, [float(self.position)]], dtype=np.float32)

# For window_size=3, state dimension = 3*11 + 1 = 33 + 1 = 34.
class QNetwork(nn.Module):
    """
    Three-layer MLP mapping a flattened TradingEnv state to one Q-value per action.

    Args:
        input_dim: Length of the state vector (34 for window_size=3 with 10 signal columns).
        hidden_dim: Width of both hidden layers.
        output_dim: Number of discrete actions (3: Hold / Long / Short).
    """

    def __init__(self, input_dim=34, hidden_dim=64, output_dim=3):
        super().__init__()
        # Layers are created in fc1 -> fc2 -> fc3 order under these attribute names, which
        # fixes both seeded initialization order and the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw Q-values; ReLU is applied to the two hidden layers only."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def train_dqn_from_df(df, episodes=50, gamma=0.99, lr=1e-3,
                      epsilon=1.0, epsilon_decay=0.99, epsilon_min=0.01,
                      window_size=3, max_hold_days=10, seed=None):
    """
    Train a QNetwork with online one-step Q-learning on a TradingEnv built from ``df``.

    Each transition is used for one gradient step immediately; there is no replay buffer
    and no separate target network (targets come from the same network under no_grad).

    Args:
        df: Frame with a 'close' column plus optional per-indicator 0/1 signal columns.
        episodes: Number of episodes (coerced to int; slider inputs may arrive as floats).
        gamma: Discount factor.
        lr: Adam learning rate.
        epsilon, epsilon_decay, epsilon_min: Epsilon-greedy schedule, decayed once per episode.
        window_size, max_hold_days: Forwarded to TradingEnv (defaults match the original).
        seed: If not None, seeds ``random``, ``numpy`` and ``torch`` for reproducible runs.

    Returns:
        (rewards_df, policy_net): a DataFrame with "Episode" / "Total Reward" columns and
        the trained network.
    """
    episodes = int(episodes)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    env = TradingEnv(df, window_size=window_size, max_hold_days=max_hold_days)
    # State = window_size rows of [close + signal columns], flattened, plus the position
    # (34 for the defaults). Derived from the env so other window sizes still fit.
    state_dim = env.window_size * (1 + len(env.signal_cols)) + 1
    policy_net = QNetwork(input_dim=state_dim, hidden_dim=64, output_dim=3)
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    def as_batch(s):
        """Wrap one state vector as a (1, state_dim) float32 tensor."""
        return torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)

    rewards_per_episode = []
    for ep in range(episodes):
        state = env.reset()
        done = False
        episode_reward = 0.0
        while not done:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = int(np.random.randint(0, 3))
            else:
                with torch.no_grad():
                    action = policy_net(as_batch(state)).argmax(dim=1).item()

            next_state, reward, done = env.step(action)
            episode_reward += float(reward)

            # One-step bootstrapped target (no gradient through the target).
            with torch.no_grad():
                max_next = policy_net(as_batch(next_state)).max(dim=1)[0].item()
            target_value = float(reward) + (0.0 if done else gamma * max_next)

            q_current = policy_net(as_batch(state))[0, action]
            # 0-d target with the same shape as q_current; a shape-(1,) target would be
            # broadcast and trigger PyTorch's size-mismatch warning in mse_loss.
            target = torch.tensor(target_value, dtype=torch.float32)
            loss = F.mse_loss(q_current, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_per_episode.append({"Episode": ep+1, "Total Reward": episode_reward})

    return pd.DataFrame(rewards_per_episode), policy_net

def simulate_dqn_trades(policy_net, df, window_size=3, max_hold_days=10):
    """
    Run one greedy episode of ``policy_net`` on ``df`` and log the action taken each day.

    TradingEnv.reset() begins the episode at a random day (and a random initial position),
    so each logged date/price is taken at ``env.random_start_index + day`` to line up with
    the row the agent actually observed when it chose that action.

    Args:
        policy_net: Trained network mapping a state batch to 3 action values.
        df: The same frame the policy was trained on (date index, 'close' column).
        window_size, max_hold_days: Forwarded to TradingEnv; defaults match training.

    Returns:
        List of (date, action_name, price) tuples with action_name in {"Buy", "Sell", "Hold"}.
    """
    env = TradingEnv(df, window_size=window_size, max_hold_days=max_hold_days)
    state = env.reset()
    start = env.random_start_index  # df row corresponding to env step 0
    action_names = {1: "Buy", 2: "Sell"}  # anything else (0) is "Hold"
    closes = df['close']

    all_actions = []
    day_idx = 0
    done = False
    while not done:
        with torch.no_grad():
            q_values = policy_net(torch.FloatTensor(state).unsqueeze(0))
            action = q_values.argmax(dim=1).item()

        row = start + day_idx
        if row < len(df):
            # Positional access keeps date and price paired even with duplicate index labels.
            all_actions.append((df.index[row], action_names.get(action, "Hold"), closes.iloc[row]))

        state, _, done = env.step(action)
        day_idx += 1

    return all_actions

def plot_trades_chart(df, all_actions):
    """
    Plot the close price with one marker per logged DQN action and return a PIL image.

    Markers:
      - Green ^ for Buy
      - Red v for Sell
      - Blue . for Hold (and any other action name)

    Args:
        df: Frame indexed by date with a 'close' column.
        all_actions: Iterable of (date, action_name, price) tuples.

    Returns:
        PIL image of the rendered PNG.
    """
    styles = {"Buy": ("^", "green"), "Sell": ("v", "red")}
    hold_style = (".", "blue")

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(df.index, df['close'], label="Close Price")

    # One scatter call per action name (one artist and one legend entry each) instead of
    # one artist per day, which gets slow for multi-year daily data.
    groups = {}
    for date, action_name, price in all_actions:
        dates, prices = groups.setdefault(action_name, ([], []))
        dates.append(date)
        prices.append(price)

    for action_name, (dates, prices) in groups.items():
        marker, color = styles.get(action_name, hold_style)
        ax.scatter(dates, prices, marker=marker, color=color, s=100, label=action_name)

    ax.legend()
    ax.set_title("DQN Trades on Price Chart (All Actions Plotted)")
    ax.set_xlabel("Date")
    ax.set_ylabel("Close")

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def run_dqn_csv(csv_data_path="full_data.csv", episodes=5):
    """
    Gradio handler: load a saved CSV, train a DQN on it, and chart the greedy policy's actions.

    Args:
        csv_data_path: CSV whose first column is the date index and which has a 'close'
            column; missing signal columns are filled with 0 by TradingEnv.
        episodes: Number of training episodes (coerced to int, it comes from a slider).

    Returns:
        (rewards_df, trade_plot) for the "Episode Rewards" table and "DQN Trades Chart" image.

    Raises:
        gr.Error: with a user-facing message on any failure.
    """
    try:
        df = pd.read_csv(csv_data_path, parse_dates=True, index_col=0)
        if df.empty:
            raise gr.Error(f"CSV '{csv_data_path}' contains no rows.")
        if 'close' not in df.columns:
            raise gr.Error("No 'close' column in CSV. Please ensure CSV has the necessary columns.")

        rewards_df, policy_net = train_dqn_from_df(df, episodes=int(episodes))

        # Greedy replay, logging every day's action
        all_actions = simulate_dqn_trades(policy_net, df)

        # Per-day action log; goes through debug_print so DEBUG can silence it.
        debug_print("=== DQN Actions Day by Day ===")
        for item in all_actions:
            debug_print(item)
        debug_print("================================")

        trade_plot = plot_trades_chart(df, all_actions)
        return rewards_df, trade_plot
    except gr.Error:
        # Already user-facing; don't nest it inside another "An error occurred" message.
        raise
    except Exception as e:
        debug_print(f"Error: {e}")
        raise gr.Error(f"An error occurred in the DQN module: {e}") from e

# ------------------------------------------------------------------------
# 5) Gradio UI
# ------------------------------------------------------------------------
# Chart tab: wraps generate_plot (ticker, dates, Bressert periods) -> image + two tables.
chart_interface = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(label="Ticker", value="SPY"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value=default_start_date),
        gr.Textbox(label="End Date (YYYY-MM-DD)",   value=default_end_date),
        gr.Slider(minimum=3, maximum=20, step=1, value=8,  label="n_period"),
        gr.Slider(minimum=3, maximum=20, step=1, value=13, label="r_period")
    ],
    outputs=[
        gr.Image(type="pil",   label="Chart"),
        gr.Dataframe(label="Current Indicator Status"),
        gr.Dataframe(label="Historical Signals")
    ],
    title="Six-Panel Chart: ZLMA + RSI + Bressert + MCDX + MACD + VIX",
    # Gradio renders `description` as Markdown, where a single "\n" is only a soft break;
    # a bulleted list keeps one panel per line instead of one run-on paragraph.
    description=(
        "- Panel 1: Price candlestick with EMAs, ZLMA, RSI Trail.\n"
        "- Panel 2: Bressert.\n"
        "- Panel 3: MCDX (HBMA & signals).\n"
        "- Panel 4: MCDX Bars.\n"
        "- Panel 5: Zero Lag MACD.\n"
        "- Panel 6: Implied Volatility (VIX).\n"
        "\n"
        "Outputs: Chart, Current Status, Historical Signals."
    )
)

# Top-level app: the chart Interface plus the CSV-export and DQN-training tabs.
with gr.Blocks() as demo:
    # --- Tab 1: interactive chart (renders the Interface defined above) ---
    with gr.Tab("Chart"):
        chart_interface.render()

    # --- Tab 2: export price data, indicators and signals to CSV ---
    with gr.Tab("Save Data"):
        gr.Markdown("### Save Historical Data & Signals to CSV")
        with gr.Row():
            sd_ticker = gr.Textbox(value="SPY", label="Ticker")
            sd_start = gr.Textbox(value=default_start_date, label="Start Date")
            sd_end = gr.Textbox(value=default_end_date, label="End Date")
        with gr.Row():
            sd_n = gr.Slider(label="n_period", value=8, minimum=3, maximum=20, step=1)
            sd_r = gr.Slider(label="r_period", value=13, minimum=3, maximum=20, step=1)
        with gr.Row():
            data_csv_path = gr.Textbox(value="full_data.csv", label="Data CSV Filename")
            signals_csv_path = gr.Textbox(value="signals_data.csv", label="Signals CSV Filename")
        btn_save = gr.Button("Save to CSV")
        save_output_msg = gr.Markdown()

        def save_data_wrapper(ticker, start, end, n, r, datafile, sigfile):
            """Forward the Save Data tab's inputs (in UI order) to save_historical_data."""
            return save_historical_data(
                ticker=ticker,
                start_date=start,
                end_date=end,
                n_period=n,
                r_period=r,
                data_filename=datafile,
                signals_filename=sigfile,
            )

        btn_save.click(
            fn=save_data_wrapper,
            inputs=[sd_ticker, sd_start, sd_end, sd_n, sd_r, data_csv_path, signals_csv_path],
            outputs=save_output_msg,
        )

    # --- Tab 3: train a DQN on a saved CSV and chart its daily actions ---
    with gr.Tab("DQN Training"):
        gr.Markdown("### Train DQN from Saved CSV")
        csv_input = gr.Textbox(
            value="full_data.csv",
            label="CSV file path",
            info="Path to the CSV with price data, as saved in the previous tab.",
        )
        episodes_input = gr.Slider(minimum=1, maximum=200, step=1, value=5, label="Episodes")
        run_dqn_btn = gr.Button("Train DQN & Show Trades")
        dqn_rewards_out = gr.Dataframe(label="Episode Rewards")
        dqn_tradeplot_out = gr.Image(label="DQN Trades Chart", type="pil")
        run_dqn_btn.click(
            fn=run_dqn_csv,
            inputs=[csv_input, episodes_input],
            outputs=[dqn_rewards_out, dqn_tradeplot_out],
        )

if __name__ == "__main__":
    demo.launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://48bec4c85de7e571a3.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
