# Market Regime Tuner

**Regimes:** Green=Up, Yellow=Range, Red=Down

In [None]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from pathlib import Path
from datetime import datetime
import json
import warnings
# Suppress every warning so library noise does not clutter the widget outputs
# (note: this also hides potentially useful pandas/numpy deprecation warnings).
warnings.filterwarnings('ignore')

# Make the parent of the working directory importable so `src.config` resolves
# (assumes the notebook runs from a folder one level below the project root).
sys.path.insert(0, str(Path('..').resolve()))
from src.config import FEATURES_DIR, LABELS_DIR

plt.style.use('dark_background')
SYMBOL = 'BTCUSDT'
# Shared state for the widget callbacks: full history, visible window, loaded
# interval, and the last frame labelled by update_chart().
DATA = {'df': None, 'df_display': None, 'interval': None, 'df_with_regime': None}
CONFIG_FILE = Path('..') / 'configs' / 'regime_tuner_configs.json'


def _list_intervals(features_dir, symbol):
    """Return the sorted intervals that have a ``{symbol}_{interval}_features.parquet`` file.

    The interval is everything between the ``{symbol}_`` prefix and the
    ``_features`` suffix of the file stem, so the exact filename can be rebuilt
    from it (intervals containing '_' are handled correctly).
    """
    prefix, suffix = f'{symbol}_', '_features'
    found = {
        path.stem[len(prefix):-len(suffix)]
        for path in features_dir.glob(f'{prefix}*{suffix}.parquet')
    }
    found.discard('')  # a file named '{symbol}__features.parquet' carries no interval
    return sorted(found)


available = _list_intervals(FEATURES_DIR, SYMBOL)
# Prefer 1h; otherwise the first interval found; None when no feature files
# exist (the Interval dropdown then starts empty and Load reports the error
# instead of this cell crashing with IndexError).
if '1h' in available:
    default_int = '1h'
elif available:
    default_int = available[0]
else:
    default_int = None
print(f"Available intervals: {available}, Default: {default_int}")

# === CONFIG MANAGEMENT ===
def load_configs():
    """Load the saved parameter presets from CONFIG_FILE.

    Returns:
        dict mapping preset name -> parameter dict.

    Legacy presets that stored a single ``min_bars`` are migrated in memory:
    the value becomes ``min_bars_trend`` and ``min_bars_range`` defaults to 40
    unless already present. The file itself is not rewritten here.

    A ``'default'`` preset is always present in the result. The Config dropdown
    is created with ``value='default'``, so a hand-edited file lacking that key
    would otherwise break the UI.

    Raises:
        json.JSONDecodeError: if the config file exists but is not valid JSON
        (deliberately not swallowed, so a later save cannot overwrite it).
    """
    default = {
        'bullish_threshold_up': 0.75, 'bullish_threshold_down': 0.35,
        'price_pos_up': 0.65, 'price_pos_down': 0.35, 'spread_min': 0.3,
        'min_bars_trend': 15, 'min_bars_range': 40, 'hysteresis': 2,
        'use_ema_smooth': True, 'ema_span': 5,
    }
    if not CONFIG_FILE.exists():
        return {'default': default}

    with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
        configs = json.load(f)

    # Migrate old configs with a single min_bars to the split trend/range format.
    for cfg in configs.values():
        if 'min_bars' in cfg and 'min_bars_trend' not in cfg:
            cfg['min_bars_trend'] = cfg.pop('min_bars')
            cfg.setdefault('min_bars_range', 40)  # default for consolidation

    configs.setdefault('default', default)
    return configs

def save_configs(configs):
    """Write every preset to CONFIG_FILE as indented JSON, creating the folder if missing."""
    target = CONFIG_FILE
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('w') as handle:
        json.dump(configs, handle, indent=2)

# Presets available to the Config dropdown (name -> parameter dict).
CONFIGS = load_configs()
print(f"Loaded configs: {list(CONFIGS)}")

# === DARK THEME CSS ===
# Inject a <style> block that forces dark backgrounds and light text on notebook
# cells, output areas and ipywidgets, so the UI matches the 'dark_background'
# matplotlib style set above. The selectors appear to target both JupyterLab
# (.jp-*) and classic-notebook (div.output_*) class names — verify if the UI
# looks unstyled in a different front end.
display(HTML('''
<style>
.jp-Notebook, .jp-Cell, .jp-OutputArea, div.output_area, div.output_subarea { background-color: #1e1e1e !important; color: #e0e0e0 !important; }
.widget-label, .widget-readout { color: #e0e0e0 !important; }
.widget-dropdown > select, .widget-text input { background-color: #3c3c3c !important; color: #e0e0e0 !important; border: 1px solid #555 !important; }
.widget-slider .noUi-connect { background: #007acc !important; }
.widget-checkbox input[type="checkbox"] { accent-color: #007acc !important; }
.widget-button { border-radius: 4px !important; }
.widget-output { background-color: #1a1a1a !important; border: 1px solid #333 !important; }
</style>
'''))

# === REGIME DETECTION ===
def _column_or(df, name, fill):
    """Return column ``name`` as a float ndarray (NA -> NaN), or a constant ``fill`` array if absent."""
    if name in df.columns:
        return df[name].to_numpy(dtype=float, na_value=np.nan)
    return np.full(len(df), fill, dtype=float)


def _apply_hysteresis(regime, hyst):
    """Return a copy of ``regime`` where a regime change needs ``hyst`` confirming bars.

    The count is of consecutive bars that differ from the current regime,
    regardless of which other regime they show. Bars inside the confirmation
    window keep the current regime. The switch goes to the regime of the bar
    that reaches the count.
    """
    smoothed = regime.copy()
    if smoothed.size == 0:
        return smoothed
    cur, cnt = smoothed[0], 0
    for i in range(1, len(smoothed)):
        if smoothed[i] != cur:
            cnt += 1
            if cnt >= hyst:
                cur, cnt = smoothed[i], 0
            else:
                smoothed[i] = cur
        else:
            cnt = 0
    return smoothed


def _merge_short_segments(labels, min_bars_trend, min_bars_range):
    """Return a copy of ``labels`` where too-short runs are absorbed into the preceding regime.

    Runs of 0 (ranging) need ``min_bars_range`` bars; runs of 1/2 (trends) need
    ``min_bars_trend``. The first run is never merged because it has no predecessor.
    Runs are scanned left to right, so a merged run takes the already-final value
    of the bar before it.
    """
    final = labels.copy()
    n = len(final)
    i = 0
    while i < n:
        j = i
        while j < n and final[j] == final[i]:
            j += 1
        # Consolidation is required to last longer than trends (which can be explosive).
        min_bars = min_bars_range if final[i] == 0 else min_bars_trend
        if j - i < min_bars and i > 0:
            final[i:j] = final[i - 1]
        i = j
    return final


def detect_regimes(df, params):
    """Label every bar with a market regime: 0 = Range, 1 = Up, 2 = Down.

    Pipeline:
      1. Optionally EMA-smooth bullish % and price position (``use_ema_smooth``,
         span ``ema_span``).
      2. Raw label per bar:
         - Up when ``bullish_pct_sma >= bullish_threshold_up``, ``price_position >=
           price_pos_up`` and ``spread_pct >= spread_min``.
         - Down when ``bullish_pct_sma <= bullish_threshold_down``, ``price_position
           <= price_pos_down`` and ``spread_pct >= spread_min``.
         - Range otherwise, including bars whose indicators are NaN.
      3. Hysteresis with ``hysteresis`` confirming bars.
      4. Runs shorter than ``min_bars_range`` (Range) or ``min_bars_trend``
         (Up/Down) are merged into the preceding regime.

    Args:
        df: Frame that may contain ``bullish_pct_sma``, ``price_position`` and
            ``spread_pct``. Missing columns fall back to 0.5, 0.5 and 1.0. The
            input frame is not modified.
        params: Dict with keys ``bullish_threshold_up``/``_down``,
            ``price_pos_up``/``_down``, ``spread_min``, ``hysteresis``,
            ``use_ema_smooth`` and ``ema_span``. ``min_bars_trend`` and
            ``min_bars_range`` are optional and default to 15 and 40.

    Returns:
        A copy of ``df`` with int columns ``raw_regime`` (step 2) and
        ``regime`` (after steps 3-4). An empty input yields an empty result.
    """
    df = df.copy()
    bull_up, bull_dn = params['bullish_threshold_up'], params['bullish_threshold_down']
    price_up, price_dn = params['price_pos_up'], params['price_pos_down']
    spread_min = params['spread_min']
    min_bars_trend = params.get('min_bars_trend', 15)
    min_bars_range = params.get('min_bars_range', 40)
    hyst = params['hysteresis']
    use_ema, ema_span = params['use_ema_smooth'], params['ema_span']

    bull_pct = _column_or(df, 'bullish_pct_sma', 0.5)
    price_pos = _column_or(df, 'price_position', 0.5)
    spread = _column_or(df, 'spread_pct', 1.0)

    if use_ema:
        bull_pct = pd.Series(bull_pct).ewm(span=ema_span, adjust=False).mean().to_numpy()
        price_pos = pd.Series(price_pos).ewm(span=ema_span, adjust=False).mean().to_numpy()

    # Vectorized classification; NaN comparisons are False, so such bars fall to Range.
    wide_enough = spread >= spread_min
    is_up = (bull_pct >= bull_up) & (price_pos >= price_up) & wide_enough
    is_down = (bull_pct <= bull_dn) & (price_pos <= price_dn) & wide_enough
    # Up wins when both hold (only possible if bull_dn >= bull_up, e.g. both 0.5).
    raw = np.select([is_up, is_down], [1, 2], default=0).astype(int)

    df['raw_regime'] = raw
    smoothed = _apply_hysteresis(raw, hyst)
    df['regime'] = _merge_short_segments(smoothed, min_bars_trend, min_bars_range)
    return df

# === ALL WIDGETS ===
# Data controls: which interval file to load, how many bars to show (-1 = all),
# and where the window sits in the history (load_data maps 100% to the most
# recent bars). Changing these takes effect on the next Load click.
int_dd = widgets.Dropdown(options=available, value=default_int, description='Interval:')
bars_dd = widgets.Dropdown(options=[('All', -1), ('500', 500), ('1000', 1000), ('2000', 2000), ('5000', 5000)], value=500, description='Bars:')
pos_sl = widgets.IntSlider(value=100, min=0, max=100, description='Position %:', continuous_update=False)
load_btn = widgets.Button(description='Load', button_style='primary')
upd_btn = widgets.Button(description='Update', button_style='success')
status = widgets.HTML(value='<b style="color:gray">Click Load</b>')

# Config widgets: preset selector, name box for saving, save/delete buttons.
# NOTE(review): value='default' assumes CONFIGS contains a 'default' key; a
# Dropdown value outside its options likely raises — confirm the config file
# always includes one.
config_dd = widgets.Dropdown(options=list(CONFIGS.keys()), value='default', description='Config:')
config_name = widgets.Text(placeholder='New config name...', layout=widgets.Layout(width='150px'))
save_cfg_btn = widgets.Button(description='Save', button_style='info')
del_cfg_btn = widgets.Button(description='Delete', button_style='danger')
config_status = widgets.HTML(value='')

# Parameter sliders, initialised from the 'default' preset. The literal
# fallbacks match the built-in defaults in load_configs(). The slider min/max
# bound what the UI can select (e.g. Bull Up 0.5-1.0, Bull Down 0.0-0.5).
default_cfg = CONFIGS.get('default', {})
bu_sl = widgets.FloatSlider(value=default_cfg.get('bullish_threshold_up', 0.75), min=0.5, max=1.0, step=0.05, description='Bull Up:', continuous_update=False)
bd_sl = widgets.FloatSlider(value=default_cfg.get('bullish_threshold_down', 0.35), min=0.0, max=0.5, step=0.05, description='Bull Down:', continuous_update=False)
pu_sl = widgets.FloatSlider(value=default_cfg.get('price_pos_up', 0.65), min=0.5, max=1.0, step=0.05, description='Price Up:', continuous_update=False)
pd_sl = widgets.FloatSlider(value=default_cfg.get('price_pos_down', 0.35), min=0.0, max=0.5, step=0.05, description='Price Down:', continuous_update=False)
sp_sl = widgets.FloatSlider(value=default_cfg.get('spread_min', 0.3), min=0.0, max=2.0, step=0.1, description='Spread Min:', continuous_update=False)

# Separate minimum run lengths for trends (regimes 1/2) vs consolidation
# (regime 0); detect_regimes() merges shorter runs into the preceding regime.
mb_trend_sl = widgets.IntSlider(value=default_cfg.get('min_bars_trend', 15), min=5, max=50, description='Min Trend:', continuous_update=False)
mb_range_sl = widgets.IntSlider(value=default_cfg.get('min_bars_range', 40), min=20, max=100, description='Min Range:', continuous_update=False)

# Hysteresis confirmation bars, and optional EMA smoothing of the indicators.
hy_sl = widgets.IntSlider(value=default_cfg.get('hysteresis', 2), min=1, max=10, description='Hysteresis:', continuous_update=False)
ema_cb = widgets.Checkbox(value=default_cfg.get('use_ema_smooth', True), description='EMA Smooth')
es_sl = widgets.IntSlider(value=default_cfg.get('ema_span', 5), min=2, max=20, description='EMA Span:', continuous_update=False)

# Display options: MA ribbon overlay, and shading by the raw per-bar regime
# instead of the smoothed/filtered one (read by update_chart()).
ma_cb = widgets.Checkbox(value=True, description='Show MAs')
raw_cb = widgets.Checkbox(value=False, description='Raw Regime')

# Output widgets: chart area, stats banner, export button and its log area.
chart_out = widgets.Output()
stats_out = widgets.HTML(value='')
exp_btn = widgets.Button(description='Export', button_style='warning')
exp_out = widgets.Output()

# Help accordion - updated with new parameters
# Collapsible parameter reference. The wording mirrors the rules implemented in
# detect_regimes(): a trend needs ALL three conditions (bullish %, price
# position, spread); everything else is Range.
help_html = widgets.HTML('''
<div style="background:#252526;padding:15px;color:#ccc;font-size:12px;line-height:1.6">
<b style="color:#4fc3f7">Bull Up/Down:</b> Thresholds on the % of MAs with positive slope. Uptrend needs % &ge; Bull Up; downtrend needs % &le; Bull Down (both also require the Price and Spread conditions below)<br>
<b style="color:#4fc3f7">Price Up/Down:</b> Where price sits in MA bundle (0=below all, 1=above all). Uptrend needs position &ge; Price Up; downtrend needs position &le; Price Down<br>
<b style="color:#4fc3f7">Spread Min:</b> Minimum ribbon width % required for either trend (filters noise); bars failing any trend condition are Range<br>
<b style="color:#4fc3f7">Min Trend:</b> Minimum bars for trend regimes (Up/Down); shorter runs merge into the preceding regime. Can be short since moves are explosive<br>
<b style="color:#4fc3f7">Min Range:</b> Minimum bars for consolidation; shorter runs merge into the preceding regime. Should be longer (30-70) since real accumulation takes time<br>
<b style="color:#4fc3f7">Hysteresis:</b> Consecutive bars that must differ from the current regime before a change is accepted (reduces flickering)<br>
<b style="color:#4fc3f7">EMA Smooth:</b> Smooth the bullish % and price position with an EMA of length EMA Span before detection
</div>
''')
help_accordion = widgets.Accordion(children=[help_html])
# set_title() is available in both ipywidgets 7 and 8, whereas passing
# `titles=` to the constructor is only understood by ipywidgets 8.
help_accordion.set_title(0, 'Parameter Help')
help_accordion.selected_index = None  # start collapsed

# === HELPER FUNCTIONS ===
def get_params():
    """Snapshot the parameter widgets into the dict shape detect_regimes() and the presets use."""
    # Insertion order matters: it is the key order written to the JSON presets.
    widget_by_key = {
        'bullish_threshold_up': bu_sl,
        'bullish_threshold_down': bd_sl,
        'price_pos_up': pu_sl,
        'price_pos_down': pd_sl,
        'spread_min': sp_sl,
        'min_bars_trend': mb_trend_sl,
        'min_bars_range': mb_range_sl,
        'hysteresis': hy_sl,
        'use_ema_smooth': ema_cb,
        'ema_span': es_sl,
    }
    return {key: control.value for key, control in widget_by_key.items()}

def set_params(cfg):
    """Push a preset dict into the parameter widgets, using built-in defaults for missing keys."""
    # Legacy presets stored a single 'min_bars'; it serves as the trend minimum.
    legacy_trend = cfg.get('min_bars', 15)
    assignments = (
        (bu_sl, 'bullish_threshold_up', 0.75),
        (bd_sl, 'bullish_threshold_down', 0.35),
        (pu_sl, 'price_pos_up', 0.65),
        (pd_sl, 'price_pos_down', 0.35),
        (sp_sl, 'spread_min', 0.3),
        (mb_trend_sl, 'min_bars_trend', legacy_trend),
        (mb_range_sl, 'min_bars_range', 40),
        (hy_sl, 'hysteresis', 2),
        (ema_cb, 'use_ema_smooth', True),
        (es_sl, 'ema_span', 5),
    )
    for control, key, fallback in assignments:
        control.value = cfg.get(key, fallback)

def on_config_change(change):
    """Dropdown observer: apply the newly selected preset and redraw if data is loaded."""
    selected = change['new']
    if selected not in CONFIGS:
        return
    set_params(CONFIGS[selected])
    config_status.value = f'<span style="color:cyan">Loaded: {selected}</span>'
    if DATA['df_display'] is not None:
        update_chart()

def save_config(b):
    """Button handler: store the current slider values as a preset and persist all presets.

    The preset is named after the text box. When the box is blank, the currently
    selected preset is overwritten. The dropdown is refreshed and pointed at the
    saved preset, and the text box is cleared.
    """
    typed = config_name.value.strip()
    name = typed if typed else config_dd.value
    CONFIGS[name] = get_params()  # mutates the module-level dict in place
    save_configs(CONFIGS)
    config_dd.options = list(CONFIGS)
    config_dd.value = name
    config_name.value = ''
    config_status.value = f'<span style="color:lime">Saved: {name}</span>'

def delete_config(b):
    """Button handler: remove the selected preset (never 'default'), persist, and reselect 'default'."""
    victim = config_dd.value
    if victim == 'default':
        config_status.value = '<span style="color:red">Cannot delete default!</span>'
        return
    CONFIGS.pop(victim)  # KeyError if missing, like `del`
    save_configs(CONFIGS)
    config_dd.options = list(CONFIGS)
    config_dd.value = 'default'
    config_status.value = f'<span style="color:orange">Deleted: {victim}</span>'

# Background colour per regime code (0=Range, 1=Up, 2=Down); reused for the legend.
_REGIME_BG = {0: '#4a4a00', 1: '#004a00', 2: '#4a0000'}


def _regime_runs(values):
    """Split a 1-D sequence into runs of equal consecutive values.

    Returns ``(starts, lengths, labels)`` as NumPy arrays (all empty for empty input).
    """
    values = np.asarray(values)
    if values.size == 0:
        empty = np.zeros(0, dtype=int)
        return empty, empty, values
    # A new run starts wherever a value differs from its predecessor.
    breaks = np.flatnonzero(values[1:] != values[:-1]) + 1
    starts = np.concatenate(([0], breaks))
    ends = np.concatenate((breaks, [values.size]))
    return starts, ends - starts, values[starts]


def _draw_regime_background(ax, regimes):
    """Shade each run of identical regime codes behind the candles.

    Candles sit on integer x positions, so every span extends half a bar on
    each side. This covers all candles of the run, including the final one.
    """
    starts, lengths, labels = _regime_runs(regimes)
    for start, length, label in zip(starts, lengths, labels):
        ax.axvspan(start - 0.5, start + length - 0.5, alpha=0.5,
                   color=_REGIME_BG[int(label)], zorder=0)


def _draw_ma_ribbon(ax, df, x):
    """Draw every available ``ma{5..36}_sma`` column, coloured per segment by slope.

    Rising or flat segments are green, falling segments red.
    """
    for m in range(5, 37):
        col = f'ma{m}_sma'
        if col not in df.columns:
            continue
        v = df[col].to_numpy(dtype=float)
        if len(v) < 2:
            continue  # no segment to draw
        pts = np.column_stack([x, v]).reshape(-1, 1, 2)
        segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
        colors = np.where(np.diff(v) >= 0, '#00ff00', '#ff0000')
        ax.add_collection(LineCollection(segs, colors=colors, linewidths=1.2, alpha=0.7, zorder=1))


def _draw_candles(ax, df):
    """Draw OHLC candles at the frame's index positions (bodies 0.6 wide, wicks 0.1 wide)."""
    for mask, color in ((df['close'] >= df['open'], '#26a69a'),
                        (df['close'] < df['open'], '#ef5350')):
        part = df[mask]
        body_top = part[['open', 'close']].max(axis=1)
        body_bottom = part[['open', 'close']].min(axis=1)
        ax.bar(part.index, part['close'] - part['open'], 0.6, bottom=part['open'], color=color, zorder=2)
        ax.bar(part.index, part['high'] - body_top, 0.1, bottom=body_top, color=color, zorder=2)
        ax.bar(part.index, part['low'] - body_bottom, 0.1, bottom=body_bottom, color=color, zorder=2)


def _regime_stats_html(regimes):
    """Build the stats banner: bars per regime, mean run length per regime, and regime-change count."""
    starts, lengths, labels = _regime_runs(regimes)

    def count_and_avg(code):
        runs = lengths[labels == code]
        return int(runs.sum()), (float(runs.mean()) if runs.size else 0.0)

    n_up, avg_up = count_and_avg(1)
    n_range, avg_range = count_and_avg(0)
    n_down, avg_down = count_and_avg(2)
    changes = max(len(starts) - 1, 0)
    return f'''<div style="background:#1a1a1a;padding:10px;color:white;border:1px solid #444">
        <b style="color:#0f0">Up:</b>{n_up:,} (avg {avg_up:.0f}b) | 
        <b style="color:#ff0">Range:</b>{n_range:,} (avg {avg_range:.0f}b) | 
        <b style="color:#f00">Down:</b>{n_down:,} (avg {avg_down:.0f}b) | 
        <b>Changes:</b>{changes}
    </div>'''


def update_chart(b=None):
    """Re-run regime detection on the visible window and redraw the chart and stats.

    Works as the Update button callback (``b`` is the clicked button, ignored)
    and as a direct call. The labelled frame is cached in
    ``DATA['df_with_regime']``. The background shows the final regime, or the
    raw per-bar regime when "Raw Regime" is ticked. The stats banner always
    uses the final regime.
    """
    if DATA['df_display'] is None:
        with chart_out:
            clear_output()
            print("Load data first!")
        return
    df = detect_regimes(DATA['df_display'], get_params())
    DATA['df_with_regime'] = df
    n = len(df)
    if n == 0:
        # Nothing to plot; the drawing code below indexes the first bar.
        with chart_out:
            clear_output()
            print("No bars to display.")
        stats_out.value = ''
        return

    fig, ax = plt.subplots(figsize=(16, 9), facecolor='#1a1a1a')
    ax.set_facecolor('#1a1a1a')
    x = np.arange(n)

    rcol = 'raw_regime' if raw_cb.value else 'regime'
    _draw_regime_background(ax, df[rcol].to_numpy())
    if ma_cb.value:
        _draw_ma_ribbon(ax, df, x)
    _draw_candles(ax, df)

    # Half-bar margins so the first and last candles are fully visible.
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(df['low'].min() * 0.998, df['high'].max() * 1.002)
    step = max(1, n // 10)
    ticks = range(0, n, step)
    ax.set_xticks(ticks)
    ax.set_xticklabels([pd.Timestamp(df['open_time'].iloc[i]).strftime('%m/%d') for i in ticks],
                       rotation=45, color='white')
    ax.tick_params(colors='white')
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda v, _pos: f'${v:,.0f}'))
    ax.set_title(f'{SYMBOL} {DATA["interval"]} | Config: {config_dd.value}', color='white', fontsize=14)
    ax.grid(True, alpha=0.2)
    ax.legend(handles=[
        mpatches.Patch(color=_REGIME_BG[1], alpha=0.6, label='Up'),
        mpatches.Patch(color=_REGIME_BG[0], alpha=0.6, label='Range'),
        mpatches.Patch(color=_REGIME_BG[2], alpha=0.6, label='Down'),
        plt.Line2D([0], [0], color='#00ff00', lw=2, label='MA+'),
        plt.Line2D([0], [0], color='#ff0000', lw=2, label='MA-'),
    ], loc='upper left', facecolor='#2a2a2a')
    plt.tight_layout()
    with chart_out:
        clear_output()
        plt.show()
    plt.close(fig)

    stats_out.value = _regime_stats_html(df['regime'].to_numpy())

def load_data(b=None):
    """Load the features parquet for the selected interval, cut the display window, and redraw.

    The full sorted history goes to ``DATA['df']`` and the window to
    ``DATA['df_display']``. The window has ``Bars`` rows (-1 = all), and
    ``Position %`` slides it from oldest (0) to newest (100). Any exception
    (missing file, bad schema, plotting error) is shown in the status widget
    instead of propagating.
    """
    try:
        interval = int_dd.value
        status.value = f'<b style="color:yellow">Loading {interval}...</b>'
        frame = pd.read_parquet(FEATURES_DIR / f'{SYMBOL}_{interval}_features.parquet')
        frame['open_time'] = pd.to_datetime(frame['open_time'])
        frame = frame.sort_values('open_time').reset_index(drop=True)
        DATA['df'], DATA['interval'] = frame, interval

        total = len(frame)
        window = total if bars_dd.value == -1 else bars_dd.value
        offset = int(max(0, total - window) * pos_sl.value / 100)
        DATA['df_display'] = frame.iloc[offset:offset + window].reset_index(drop=True)

        status.value = f'<b style="color:lime">Loaded {len(DATA["df_display"]):,} bars ({interval})</b>'
        update_chart()
    except Exception as e:
        status.value = f'<b style="color:red">{e}</b>'

def export(b):
    """Button handler: label the FULL loaded history (not just the visible window) and save it.

    Uses the current slider parameters. Writes
    ``{SYMBOL}_{interval}_features_labeled.parquet`` next to the feature files
    and prints per-regime bar counts.
    NOTE(review): LABELS_DIR is imported but unused; confirm whether labelled
    output was meant to go there instead of FEATURES_DIR.
    """
    with exp_out:
        clear_output()
        history = DATA['df']
        if history is None:
            print("Load data first!")
            return
        labeled = detect_regimes(history, get_params())
        out_path = FEATURES_DIR / f"{SYMBOL}_{DATA['interval']}_features_labeled.parquet"
        labeled.to_parquet(out_path, index=False)
        counts = labeled['regime'].value_counts()
        print(f"Saved {len(labeled):,} bars to {out_path.name}")
        print(f"Up: {counts.get(1,0):,} | Range: {counts.get(0,0):,} | Down: {counts.get(2,0):,}")

# === CONNECT CALLBACKS ===
# Buttons -> click handlers (registration order is irrelevant; each is independent).
for _btn, _handler in (
    (load_btn, load_data),
    (upd_btn, update_chart),
    (save_cfg_btn, save_config),
    (del_cfg_btn, delete_config),
    (exp_btn, export),
):
    _btn.on_click(_handler)
del _btn, _handler  # keep the loop variables out of the notebook namespace

# Selecting a preset applies it to the sliders (and redraws if data is loaded).
config_dd.observe(on_config_change, names='value')

# === DISPLAY UI ===
def _build_layout():
    """Assemble the tuner UI top to bottom.

    Sections: title, preset bar, data controls, help, parameter sliders,
    display toggles, chart/stats, and the export row.
    """
    def divider():
        # A fresh widget per use, so the same instance never appears twice in one box.
        return widgets.HTML('<hr style="border-color:#444;margin:10px 0">')

    title = widgets.HTML('<h2 style="color:#e0e0e0;margin:0 0 10px 0">Market Regime Tuner</h2>')
    config_row = widgets.HBox([widgets.HTML('<b style="color:#888">CONFIG:</b>'), config_dd, config_name, save_cfg_btn, del_cfg_btn, config_status])
    data_row = widgets.HBox([int_dd, bars_dd, pos_sl, load_btn, status])
    threshold_col = widgets.VBox([bu_sl, bd_sl, pu_sl, pd_sl, sp_sl])
    timing_col = widgets.VBox([mb_trend_sl, mb_range_sl, hy_sl, ema_cb, es_sl])
    toggles_row = widgets.HBox([ma_cb, raw_cb, upd_btn])
    export_row = widgets.HBox([exp_btn, exp_out])
    return widgets.VBox([
        title,
        config_row,
        divider(),
        data_row,
        help_accordion,
        widgets.HBox([threshold_col, timing_col]),
        toggles_row,
        chart_out, stats_out,
        divider(),
        export_row,
    ])


display(_build_layout())

# === INITIAL LOAD ===
load_data()