In [1]:
# ==================================================
# Cell 1 — Install dependencies (run once if needed)
# ==================================================
import sys
!{sys.executable} -m pip install --upgrade pip setuptools wheel
!{sys.executable} -m pip install --quiet numpy pandas matplotlib seaborn MetaTrader5 stable-baselines3 gymnasium==0.29.1 torch ffmpeg-python pillow




In [22]:
# ==================================================
# Cell 2 — Imports & configuration
# ==================================================
import os, glob, json, time
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import MetaTrader5 as mt5
import gymnasium as gym

from stable_baselines3 import PPO


In [23]:
# ==================================================
# Cell 3 — Paths (adjust if different)
# ==================================================
DATA_DIR = os.path.join("data", "multiasset")
MODEL_DIR = os.path.join("models", "multiasset")
MODEL_FILE = os.path.join(MODEL_DIR, "ppo_multiasset.zip")
EMBED_FILE = os.path.join(MODEL_DIR, "asset_embeddings.npy")
ASSET_MAP_FILE = os.path.join(DATA_DIR, "asset_to_idx.csv")
SCALER_GLOB = os.path.join(DATA_DIR, "*_scaler.csv")  # per-symbol scalers
LOG_DIR = os.path.join(MODEL_DIR, "live_logs")
os.makedirs(LOG_DIR, exist_ok=True)


In [24]:
# ==================================================
# Cell 4 — Live configuration
# ==================================================
TIMEFRAME = "M15"        # timeframe to request from MT5
WINDOW = 50
RISK_PER_TRADE = 0.02   # default risk per trade (1%)
MIN_LOT = 0.01
MAX_LOT = 10.0
DEVIATION = 20
MAGIC = 2025001
DRY_RUN_DEFAULT = False  # safety default

In [25]:
# ==================================================
# Cell 5 — Helpers: safe action extraction, load datasets, scalers, embeddings
# ==================================================
def extract_action_scalar(action):
    import numpy as _np
    if isinstance(action, (int, _np.integer)):
        return int(action)
    a = _np.array(action)
    if a.size == 1:
        return int(a.flatten()[0])
    return int(a.flatten()[0])

def load_scalers(data_dir=DATA_DIR):
    scalers = {}
    for p in glob.glob(os.path.join(data_dir, "*_scaler.csv")):
        safe = os.path.basename(p).replace("_scaler.csv","")
        try:
            df = pd.read_csv(p, index_col=0)
            scalers[safe] = {"mean": df['mean'].astype(float), "std": df['std'].astype(float)}
        except Exception:
            scalers[safe] = None
    return scalers

def load_assets_and_embeddings(data_dir=DATA_DIR, embed_file=EMBED_FILE):
    # load asset order from asset_to_idx.csv if present
    asset_map = {}
    if os.path.exists(ASSET_MAP_FILE):
        try:
            s = pd.read_csv(ASSET_MAP_FILE, index_col=0, squeeze=True)
            asset_map = (s.to_dict() if hasattr(s, "to_dict") else dict(s))
            # s might be safe->idx or vice-versa; normalize
            # produce ordered list by idx
            try:
                ordered = sorted(asset_map.items(), key=lambda x:int(x[1]))
                safes = [k for k,_ in ordered]
            except Exception:
                safes = list(asset_map.keys())
        except Exception:
            safes = []
    else:
        # fallback: read normalized csv names
        safes = [os.path.basename(p).replace("_normalized.csv","") for p in sorted(glob.glob(os.path.join(DATA_DIR,"*_normalized.csv")))]
    embeddings = np.load(embed_file) if os.path.exists(embed_file) else np.zeros((len(safes), 1), dtype=np.float32)
    # build canonical list
    safes = safes if len(safes)>0 else list({os.path.basename(p).replace("_normalized.csv",""):None for p in glob.glob(os.path.join(DATA_DIR,"*_normalized.csv"))}.keys())
    return safes, embeddings, load_scalers(data_dir)


In [26]:
# ==================================================
# Cell 6 — MT5 init helpers (safe)
# ==================================================
def mt5_init_if_needed():
    try:
        if not mt5.initialize():
            # sometimes initialize() returns False but terminal still ok; give another try
            return mt5.initialize()
        return True
    except Exception:
        return False

def mt5_shutdown_if_needed():
    try:
        mt5.shutdown()
    except Exception:
        pass


In [27]:
# ==================================================
# Cell 7 — symbol_published (symbol)
# ==================================================
def ensure_symbol_published(symbol):
    """Ensure symbol is visible in Market Watch; attempt to add if not."""
    try:
        sinfo = mt5.symbol_info(symbol)
        if sinfo is None:
            # try to enable the symbol
            mt5.symbol_select(symbol, True)
            sinfo = mt5.symbol_info(symbol)
        return sinfo is not None
    except Exception:
        return False


In [28]:
# ==================================================
# Cell 8 — Time frame mapping
# ==================================================
# Cell 5 — Build observation for a symbol in the same format used in training
def timeframe_to_mt5(tf_str):
    TF_MAP = {
        "M1": mt5.TIMEFRAME_M1, "M5": mt5.TIMEFRAME_M5, "M15": mt5.TIMEFRAME_M15,
        "M30": mt5.TIMEFRAME_M30, "H1": mt5.TIMEFRAME_H1, "H4": mt5.TIMEFRAME_H4,
        "D1": mt5.TIMEFRAME_D1
    }
    return TF_MAP.get(tf_str.upper(), mt5.TIMEFRAME_M15)

TF_MT5 = timeframe_to_mt5(TIMEFRAME)

In [32]:
# ==================================================
# Cell 8 — 
# ==================================================
def fetch_and_build_obs(symbol, window, scalers, embeddings, safe_names, fallback_embedding_dim=None):
    """
    Build observation window for a symbol using:
    - normalized percent-change features
    - saved scaler (mean/std)
    - saved embedding vector

    Parameters:
        symbol (str): raw symbol, e.g. "EURUSD"
        window (int): sliding window length
        scalers (dict): {safe: {"mean": Series, "std": Series}}
        embeddings (dict): {safe: np.ndarray}
        safe_names (dict): {raw_symbol: safe_symbol}
        fallback_embedding_dim (int or None): optional zero-vector if embedding missing

    Returns:
        obs_window (np.ndarray)
        scaler_info (dict)
        embedding (np.ndarray)
    """

    # -----------------------------------------------------------
    # 1) Convert symbol → safe-name
    # -----------------------------------------------------------
    if symbol not in safe_names:
        print(f"❌ No safe-name found for symbol: {symbol}")
        return None, None, None

    safe = safe_names[symbol]

    # -----------------------------------------------------------
    # 2) Check required preprocessing objects
    # -----------------------------------------------------------
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None

    if safe not in embeddings:
        if fallback_embedding_dim is None:
            print(f"❌ Missing embedding for: {safe}")
            return None, None, None
        else:
            print(f"⚠️ Missing embedding → using zero vector for {safe}")
            embedding = np.zeros(fallback_embedding_dim)
    else:
        embedding = embeddings[safe]

    scaler_info = scalers[safe]

    # Extract mean / std
    mean = scaler_info["mean"]
    std = scaler_info["std"].replace(0, 1.0)

    # -----------------------------------------------------------
    # 3) Load raw normalized dataset for this symbol
    # -----------------------------------------------------------
    if safe not in datasets:
        print(f"❌ No dataset loaded for: {safe}")
        return None, None, None

    df = datasets[safe].copy()

    # Must contain the required columns
    required_columns = ["o_pc", "h_pc", "l_pc", "c_pc", "v_pc"]
    for col in required_columns:
        if col not in df.columns:
            print(f"❌ Missing required column '{col}' in dataset for {safe}")
            return None, None, None

    # -----------------------------------------------------------
    # 4) Apply normalization using stored scaler
    # -----------------------------------------------------------
    df_norm = (df[required_columns] - mean[required_columns]) / std[required_columns]

    # -----------------------------------------------------------
    # 5) Extract last window
    # -----------------------------------------------------------
    if len(df_norm) < window:
        print(f"❌ Not enough data for window={window} | {safe} has {len(df_norm)} rows")
        return None, None, None

    obs_window = df_norm.tail(window).values.astype(np.float32)

    # -----------------------------------------------------------
    # 6) Append embedding at each timestep (optional architecture)
    # -----------------------------------------------------------
    embedding_repeated = np.tile(embedding, (window, 1))
    obs_window = np.concatenate([obs_window, embedding_repeated], axis=1)

    return obs_window, scaler_info, embedding


# ==================================================
# Cell 9 — 
# ==================================================
def fetch_and_build_obs(symbol, window=WINDOW, scalers=None, embeddings=None, safe_names=None):
    """
    Returns: obs (shape (window, feat_dim)), vol_est (std of c_pc), last_price
    obs columns: o_pc,h_pc,l_pc,c_pc,v_pc, balance_col omitted (live uses balance separately), embedding appended as columns
    """
    # fetch extra slack bars
    count = window + 10
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)
    if bars is None or len(bars) < window + 2:
        return None, None, None
    df = pd.DataFrame(bars)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df = df.set_index('time')
    df = df[['open','high','low','close','tick_volume']].rename(columns={'tick_volume':'volume'})
    pct = pd.DataFrame(index=df.index)
    pct['o_pc'] = df['open'].pct_change()
    pct['h_pc'] = df['high'].pct_change()
    pct['l_pc'] = df['low'].pct_change()
    pct['c_pc'] = df['close'].pct_change()
    pct['v_pc'] = df['volume'].pct_change()
    pct.dropna(inplace=True)
    if len(pct) < window:
        return None, None, None
    recent = pct.iloc[-window:]
    # normalization using scaler if available
    safe = symbol.replace(" ", "_").replace("/", "_").replace("(", "").replace(")","")
    if scalers and safe in scalers and scalers[safe] is not None:
        s = scalers[safe]
        mean = s['mean']
        std = s['std'].replace(0,1.0)
        # ensure correct order of columns
        mean = mean.reindex(recent.columns).fillna(0.0)
        std = std.reindex(recent.columns).replace(0,1.0).fillna(1.0)
        norm = (recent - mean) / std
    else:
        mean = recent.mean()
        std = recent.std().replace(0,1.0)
        norm = (recent - mean) / std
    # features
    feats = norm[['o_pc','h_pc','l_pc','c_pc','v_pc']].values.astype(np.float32)
    vol_est = float(norm['c_pc'].std())
    last_price = float(df['close'].iloc[-1])
    # embeddings: find index for symbol in safe_names (if provided)
    if embeddings is not None and safe_names is not None and safe in safe_names:
        idx = safe_names.index(safe)
        emb = np.tile(embeddings[idx].reshape(1,-1).astype(np.float32),(window,1))
        obs = np.concatenate([feats, emb], axis=1)
    else:
        obs = feats
    return obs, vol_est, last_price


# ==================================================
# Cell 10 — 
# ==================================================
def fetch_and_build_obs(symbol, window, scalers, embeddings, safe_names):
    safe = symbol.replace(" ", "_").replace("/", "_")

    if safe not in scalers or safe not in embeddings:
        return None, None, None

    scaler = scalers[safe]
    embed_vec = embeddings[safe]      # shape MUST match training (e.g., 7 dims)

    # fetch extra slack bars
    count = window + 10
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)
    if bars is None or len(bars) < window + 2:
        return None, None, None

    print("bars: ",bars)
    
    df = pd.DataFrame(bars)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df = df.set_index('time')
    df = df[['open','high','low','close','tick_volume']].rename(columns={'tick_volume':'volume'})
    pct = pd.DataFrame(index=df.index)
    pct['o_pc'] = df['open'].pct_change()
    pct['h_pc'] = df['high'].pct_change()
    pct['l_pc'] = df['low'].pct_change()
    pct['c_pc'] = df['close'].pct_change()
    pct['v_pc'] = df['volume'].pct_change()
    pct.dropna(inplace=True)
    
    #df = fetch_recent_raw(symbol, window + 1)   # you must have this helper
    if df is None or len(df) < window + 1:
        return None, None, None

    pct = df.pct_change().dropna()
    pct = pct.tail(window)

    # normalize with per-asset scaler
    pct_norm = (pct - scaler["mean"]) / scaler["std"]

    # last raw close (for entry price)
    last_price = df["close"].iloc[-1]

    # volatility estimate (std of returns)
    vol_est = pct["close"].std()

    # balance normalized like training (fixed 10000 for live)
    balance_norm = np.full((window, 1), 1.0)

    # asset_id normalized
    asset_id_norm = safe_names.index(safe) / len(safe_names)
    asset_id_norm = np.full((window, 1), asset_id_norm)

    # repeat embedding per row
    emb = np.tile(embed_vec, (window, 1))

    # BUILD FINAL OBSERVATION — EXACT MATCH WITH TRAINING SHAPE
    obs = np.column_stack([
        pct_norm[["open","high","low","close","volume"]].values,
        emb,
        balance_norm,
        asset_id_norm
    ]).astype(np.float32)

    return obs, vol_est, last_price


In [37]:
# ==================================================
# Cell 11 — 
# ==================================================
def fetch_and_build_obs(symbol, window, scalers, embeddings, safe_names):# The correct one
    safe = symbol.replace(" ", "_").replace("/", "_")

    # Check if preprocessing objects exist
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None

    if safe not in embeddings:
        print(f"❌ Missing embedding for: {safe}")
        return None, None, None


    scaler = scalers[safe]
    embed_vec = embeddings[safe]  # Expecting shape (embedding_dim,)

    # Fetch raw bars
    count = window + 10   # fetch surplus for pct-change
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)

    if bars is None or len(bars) < window + 2:
        return None, None, None

    df = pd.DataFrame(bars)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df = df.set_index('time')
    df = df[['open','high','low','close','tick_volume']].rename(columns={'tick_volume':'volume'})

    # Ensure enough rows
    if len(df) < window + 1:
        return None, None, None

    # Compute pct changes
    pct = df.pct_change().dropna()
    pct = pct.tail(window)

    if pct.shape[0] < window:
        return None, None, None

    # Normalize using the saved scaler
    pct_norm = (pct - scaler["mean"]) / scaler["std"]

    # Extract last price
    last_price = df["close"].iloc[-1]

    # Volatility estimate
    vol_est = pct["close"].std()

    # balance normalized (fixed 1.0 for live)
    balance_norm = np.full((window, 1), 1.0, dtype=np.float32)

    # asset-id normalized
    asset_id_norm_val = safe_names.index(safe) / len(safe_names)
    asset_id_norm = np.full((window, 1), asset_id_norm_val, dtype=np.float32)

    # repeat embedding per timestep
    emb = np.tile(embed_vec, (window, 1)).astype(np.float32)

    # final obs = [normalized OHLCV + embedding + balance + asset_id]
    obs = np.column_stack([
        pct_norm[["open", "high", "low", "close", "volume"]].values.astype(np.float32),
        emb,
        balance_norm,
        asset_id_norm
    ]).astype(np.float32)

    return obs, float(vol_est), float(last_price)


In [41]:
# ==================================================
# Cell 12 — 
# ==================================================
def make_safe(symbol: str) -> str:
    return (
        symbol.replace(" ", "_")
              .replace("/", "_")
              .replace("(", "")
              .replace(")", "")
              .replace(".", "_")
    )


In [42]:
def fetch_and_build_obs(symbol, window, scalers, embeddings, safe_names):
    # -------------------------------------------------------
    # 1) Always generate safe-name the SAME way
    # -------------------------------------------------------
    safe = make_safe(symbol)

    # -------------------------------------------------------
    # 2) Verify preprocessing data
    # -------------------------------------------------------
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None

    if safe not in embeddings:
        print(f"❌ Missing embedding for: {safe}")
        return None, None, None

    scaler = scalers[safe]
    embed_vec = embeddings[safe]

    # -------------------------------------------------------
    # 3) MT5 price fetch
    # -------------------------------------------------------
    count = window + 10
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)

    if bars is None or len(bars) < window + 2:
        print(f"❌ Insufficient data for {symbol}")
        return None, None, None

    df = pd.DataFrame(bars)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df = df.set_index('time')
    df = df[['open', 'high', 'low', 'close', 'tick_volume']].rename(columns={'tick_volume': 'volume'})

    if len(df) < window + 1:
        return None, None, None

    # -------------------------------------------------------
    # 4) Percent change window
    # -------------------------------------------------------
    pct = df.pct_change().dropna().tail(window)

    if pct.shape[0] < window:
        return None, None, None

    # -------------------------------------------------------
    # 5) Normalize using the saved scaler
    # -------------------------------------------------------
    pct_norm = (pct - scaler["mean"]) / scaler["std"]

    # -------------------------------------------------------
    # 6) Build obs components
    # -------------------------------------------------------
    last_price = df["close"].iloc[-1]
    vol_est = pct["close"].std()

    balance_norm = np.full((window, 1), 1.0, dtype=np.float32)

    # ---- safe_names MUST be a dict: {raw_symbol: safe} ----
    safe_list = list(safe_names.values())
    asset_id_norm_val = safe_list.index(safe) / len(safe_list)
    asset_id_norm = np.full((window, 1), asset_id_norm_val, dtype=np.float32)

    emb = np.tile(embed_vec, (window, 1)).astype(np.float32)

    # -------------------------------------------------------
    # 7) Final Observation
    # -------------------------------------------------------
    obs = np.column_stack([
        pct_norm[["open", "high", "low", "close", "volume"]].values.astype(np.float32),
        emb,
        balance_norm,
        asset_id_norm
    ]).astype(np.float32)

    return obs, float(vol_est), float(last_price)


In [33]:
# Cell 6 — Lot-size helpers (MT5-aware)
def pip_value_per_lot_from_mt5_live(symbol, entry_price):
    info = None
    try:
        info = mt5.symbol_info(symbol)
    except Exception:
        info = None
    if info is None:
        return None, None
    try:
        pip_val = info.trade_tick_value / info.trade_tick_size
    except Exception:
        contract_size = getattr(info, "trade_contract_size", 100000.0)
        pip_val = (info.point / entry_price) * contract_size
    return float(pip_val), float(info.point)

def calculate_lot_size_live(symbol, balance, risk_percent, entry_price, stop_loss_price):
    pip_val, point = pip_value_per_lot_from_mt5_live(symbol, entry_price)
    if point and point not in (0,None):
        pip_risk = abs(entry_price - stop_loss_price) / point
    else:
        pip_risk = abs(entry_price - stop_loss_price) / (0.01 if "JPY" in symbol else 0.0001)
    if pip_risk <= 0:
        return MIN_LOT
    dollar_risk = balance * float(risk_percent)
    if pip_val is not None:
        lot = dollar_risk / (pip_risk * pip_val)
        info = mt5.symbol_info(symbol)
        if info is not None:
            try:
                step = float(info.volume_step)
                if step>0:
                    lot = round(lot/step)*step
            except Exception:
                pass
        lot = max(MIN_LOT, min(MAX_LOT, round(lot,2)))
        return lot
    # fallback
    pip_size = 0.01 if "JPY" in symbol else 0.0001
    pip_value_per_lot = (pip_size / entry_price) * 100000.0
    lot = dollar_risk / (pip_risk * pip_value_per_lot)
    lot = max(MIN_LOT, min(MAX_LOT, round(lot,2)))
    return lot


In [34]:
# Cell 7 — Place order wrapper (dry_run default True)
def place_order_on_mt5(symbol, direction, lot, sl, tp=None, dry_run=DRY_RUN_DEFAULT, deviation=DEVIATION, magic=MAGIC, comment="rl_live"):
    """
    direction: "BUY" or "SELL"
    Returns MT5 response or a dict if dry_run.
    """
    if dry_run:
        req = {
            "action": "DEAL (dry_run)",
            "symbol": symbol,
            "volume": lot,
            "type": direction,
            "price": None,
            "sl": sl,
            "tp": tp,
            "deviation": deviation,
            "magic": magic,
            "comment": comment
        }
        print("DRY RUN order:", req)
        return {"dry_run": True, "request": req}

    if not mt5_init_if_needed():
        raise RuntimeError("MT5 not initialized. Open terminal and login.")
    if not ensure_symbol_published(symbol):
        raise RuntimeError(f"Symbol not available in MT5: {symbol}")

    tick = mt5.symbol_info_tick(symbol)
    if tick is None:
        raise RuntimeError("No tick info for symbol: "+symbol)

    price = tick.ask if direction=="BUY" else tick.bid
    order_type = mt5.ORDER_TYPE_BUY if direction=="BUY" else mt5.ORDER_TYPE_SELL

    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(lot),
        "type": order_type,
        "price": price,
        "sl": float(sl),
        "tp": float(tp) if tp is not None else 0.0,
        "deviation": int(deviation),
        "magic": int(magic),
        "comment": comment,
        "type_filling": mt5.ORDER_FILLING_FOK,
    }
    res = mt5.order_send(request)
    return res


In [35]:
# Cell 8 — Load trained model and environment assets
# Model
if not os.path.exists(MODEL_FILE):
    raise FileNotFoundError("Trained model not found at: " + MODEL_FILE)
model = PPO.load(MODEL_FILE)
print("Loaded model:", MODEL_FILE)

# Assets, embeddings, scalers
safe_names, embeddings, scalers = load_assets_and_embeddings(DATA_DIR, EMBED_FILE)
print("Assets:", safe_names)
print("Embeddings shape:", embeddings.shape)


Loaded model: models\multiasset\ppo_multiasset.zip
Assets: ['EURUSD', 'Jump_100_Index', 'Jump_10_Index', 'Jump_25_Index', 'Jump_50_Index', 'Jump_75_Index', 'Volatility_100_1s_Index', 'Volatility_100_Index', 'Volatility_10_1s_Index', 'Volatility_10_Index', 'Volatility_25_1s_Index', 'Volatility_25_Index', 'Volatility_50_1s_Index', 'Volatility_50_Index', 'Volatility_75_1s_Index', 'Volatility_75_Index']
Embeddings shape: (16, 8)


In [36]:
# Cell 9 — Single-pass prediction and optional execution across all assets.
# This performs one sweep and returns a log of attempted/executed orders.

def run_once_live(dry_run=DRY_RUN_DEFAULT, risk_per_trade=RISK_PER_TRADE):
    if not mt5_init_if_needed():
        print("⚠️ Warning: MT5 initialization failed or MT5 not running. Running in offline/dry mode.")
    acct = mt5.account_info() if mt5_init_if_needed() else None
    balance = float(acct.balance) if acct is not None else 10000.0

    records = []
    for sym_safe in safe_names:
        # Try to reconstruct broker symbol name: assume safe==symbol; if your safe names differ, update asset_to_symbol mapping
        symbol = sym_safe
        print("\n---", symbol, "---")
        obs, vol_est, last_price = fetch_and_build_obs(symbol, window=WINDOW, scalers=scalers, embeddings=embeddings, safe_names=safe_names)
        if obs is None:
            print("Insufficient data for", symbol, "- skipping.")
            continue

        # model.predict expects shape like env obs. Ensure dims align:
        try:
            action, _ = model.predict(obs, deterministic=True)
        except Exception as e:
            # try flatten
            try:
                action, _ = model.predict(obs.flatten(), deterministic=True)
            except Exception as e2:
                print("Model predict failed for", symbol, ":", e, e2)
                continue

        act = extract_action_scalar(action)
        if act == 0:
            print(symbol, "→ HOLD")
            records.append({"symbol":symbol, "action":"HOLD"})
            continue

        direction = "BUY" if act==1 else "SELL"
        # compute stoploss/TP heuristics (based on vol_est)
        if vol_est is None or vol_est<=0:
            sl_dist = 0.001 * last_price
        else:
            sl_dist = max(1.5 * vol_est * last_price, 0.0005 * last_price)
        sl = last_price - sl_dist if direction=="BUY" else last_price + sl_dist
        tp = last_price + (2.5 * vol_est * last_price) if direction=="BUY" else last_price - (2.5 * vol_est * last_price)

        lot = calculate_lot_size_live(symbol, balance, risk_per_trade, last_price, sl)
        lot = max(MIN_LOT, min(MAX_LOT, round(lot, 2)))

        # place order
        res = place_order_on_mt5(symbol, direction, lot, sl, tp, dry_run=dry_run)
        rec = {
            "timestamp": datetime.utcnow().isoformat(),
            "symbol": symbol,
            "action": direction,
            "lot": lot,
            "price": last_price,
            "sl": sl,
            "tp": tp,
            "result": str(getattr(res, "retcode", res))
        }
        records.append(rec)
        print("Planned trade:", rec)
    # save records
    out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")
    pd.DataFrame(records).to_csv(out_csv, index=False)
    print("Saved run log to", out_csv)
    return records

# Example single-run (dry-run): change dry_run=False to execute live (be cautious!)
records = run_once_live(dry_run=False)



--- EURUSD ---


TypeError: list indices must be integers or slices, not str

In [20]:
# Cell 10 — Continuous loop for live trading (uncomment to enable). KEEP dry_run=True until fully tested.
# WARNING: Live trading will place real orders if dry_run is False. Test on demo account first.

def run_continuous(interval_s=60, dry_run=DRY_RUN_DEFAULT, risk_per_trade=RISK_PER_TRADE, max_iterations=None):
    print("Starting continuous live loop — dry_run =", dry_run)
    iteration = 0
    try:
        while True:
            iteration += 1
            print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
            run_once_live(dry_run=dry_run, risk_per_trade=risk_per_trade)
            if max_iterations is not None and iteration >= max_iterations:
                print("Max iterations reached; stopping.")
                break
            time.sleep(interval_s)
    except KeyboardInterrupt:
        print("Stopped by user (KeyboardInterrupt).")
    except Exception as e:
        print("Loop error:", e)
    finally:
        mt5_shutdown_if_needed()
        print("MT5 connection closed.")

# Example usage (commented): run_continuous(interval_s=60, dry_run=True, max_iterations=10)
run_continuous(interval_s=60, dry_run=True, max_iterations=10)

Starting continuous live loop — dry_run = True

=== Iteration 1 @ 2025-11-16T08:26:34.948204 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for V

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")
  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 2 @ 2025-11-16T08:27:34.976006 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")



=== Iteration 3 @ 2025-11-16T08:28:49.037744 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")
  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 4 @ 2025-11-16T08:33:44.028995 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 5 @ 2025-11-16T08:34:44.058892 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 6 @ 2025-11-16T08:35:44.084677 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 7 @ 2025-11-16T08:36:44.108793 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 8 @ 2025-11-16T08:37:44.139535 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")



=== Iteration 9 @ 2025-11-16T08:38:44.167804 ===

--- EURUSD ---
Insufficient data for EURUSD - skipping.

--- Jump_100_Index ---
Insufficient data for Jump_100_Index - skipping.

--- Jump_10_Index ---
Insufficient data for Jump_10_Index - skipping.

--- Jump_25_Index ---
Insufficient data for Jump_25_Index - skipping.

--- Jump_50_Index ---
Insufficient data for Jump_50_Index - skipping.

--- Jump_75_Index ---
Insufficient data for Jump_75_Index - skipping.

--- Volatility_100_1s_Index ---
Insufficient data for Volatility_100_1s_Index - skipping.

--- Volatility_100_Index ---
Insufficient data for Volatility_100_Index - skipping.

--- Volatility_10_1s_Index ---
Insufficient data for Volatility_10_1s_Index - skipping.

--- Volatility_10_Index ---
Insufficient data for Volatility_10_Index - skipping.

--- Volatility_25_1s_Index ---
Insufficient data for Volatility_25_1s_Index - skipping.

--- Volatility_25_Index ---
Insufficient data for Volatility_25_Index - skipping.

--- Volatility_

  print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
  out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")


In [10]:
# Cell 11 — Visualize logs and compute simple metrics
def summarize_run_log(csv_path):
    df = pd.read_csv(csv_path)
    total_trades = len(df[df['action']!="HOLD"])
    if total_trades==0:
        print("No trades in log.")
    print("Run summary:", csv_path)
    print("Total rows:", len(df), "Trades executed/planned:", total_trades)
    return df

def plot_price_with_signals(symbol, run_records_csv):
    df_log = pd.read_csv(run_records_csv)
    df_sym = df_log[df_log['symbol']==symbol]
    if df_sym.empty:
        print("No records for", symbol)
        return
    # fetch historical price for plotting (recent window)
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, WINDOW*5)
    if bars is None:
        print("No bars to plot for", symbol)
        return
    d = pd.DataFrame(bars)
    d['time'] = pd.to_datetime(d['time'], unit='s')
    d.set_index('time', inplace=True)
    plt.figure(figsize=(12,5))
    plt.plot(d['close'], label='close')
    for _, row in df_sym.iterrows():
        # mark entry price
        plt.axhline(row['price'], linestyle='--', alpha=0.6)
    plt.title(f"{symbol} recent closes and signals")
    plt.legend()
    plt.show()

# Usage: df = summarize_run_log("models/multiasset/live_logs/live_run_YYYY...csv")
# plot_price_with_signals("Volatility_75_Index", "path-to-log.csv")


In [11]:
# Cell 12 — Safe shutdown and final notes
print("To run a single dry-run pass: records = run_once_live(dry_run=True)")
print("To run continuous (demo): run_continuous(interval_s=60, dry_run=True, max_iterations=10)")
print("To execute live for real orders set dry_run=False in run_once_live() or run_continuous(). DO NOT DO THIS UNTIL FULLY TESTED ON DEMO ACCOUNT.")
# Do not auto-init MT5 here — user will call run_* functions to start.


To run a single dry-run pass: records = run_once_live(dry_run=True)
To run continuous (demo): run_continuous(interval_s=60, dry_run=True, max_iterations=10)
To execute live for real orders set dry_run=False in run_once_live() or run_continuous(). DO NOT DO THIS UNTIL FULLY TESTED ON DEMO ACCOUNT.


In [None]:
# ==================================================
# Cell 1 — Install dependencies (run once if needed)
# ==================================================
import sys
!{sys.executable} -m pip install --upgrade pip setuptools wheel
!{sys.executable} -m pip install --quiet numpy pandas matplotlib seaborn MetaTrader5 stable-baselines3 gymnasium==0.29.1 torch ffmpeg-python pillow

# ==================================================
# Cell 2 — Imports & configuration
# ==================================================
import os, glob, json, time
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import MetaTrader5 as mt5
import gymnasium as gym

from stable_baselines3 import PPO

# ==================================================
# Cell 3 — Paths (adjust if different)
# ==================================================
DATA_DIR = os.path.join("data", "multiasset")
MODEL_DIR = os.path.join("models", "multiasset")
MODEL_FILE = os.path.join(MODEL_DIR, "ppo_multiasset.zip")
EMBED_FILE = os.path.join(MODEL_DIR, "asset_embeddings.npy")
ASSET_MAP_FILE = os.path.join(DATA_DIR, "asset_to_idx.csv")
SCALER_GLOB = os.path.join(DATA_DIR, "*_scaler.csv")  # per-symbol scalers
LOG_DIR = os.path.join(MODEL_DIR, "live_logs")
os.makedirs(LOG_DIR, exist_ok=True)

# ==================================================
# Cell 4 — Live configuration
# ==================================================
TIMEFRAME = "M15"        # timeframe to request from MT5
WINDOW = 50
RISK_PER_TRADE = 0.02   # default risk per trade (1%)
MIN_LOT = 0.01
MAX_LOT = 10.0
DEVIATION = 20
MAGIC = 2025001
DRY_RUN_DEFAULT = False  # safety default

# ==================================================
# Cell 5 — Helpers: safe action extraction, load datasets, scalers, embeddings
# ==================================================
def extract_action_scalar(action):
    import numpy as _np
    if isinstance(action, (int, _np.integer)):
        return int(action)
    a = _np.array(action)
    if a.size == 1:
        return int(a.flatten()[0])
    return int(a.flatten()[0])

def load_scalers(data_dir=DATA_DIR):
    scalers = {}
    for p in glob.glob(os.path.join(data_dir, "*_scaler.csv")):
        safe = os.path.basename(p).replace("_scaler.csv","")
        try:
            df = pd.read_csv(p, index_col=0)
            scalers[safe] = {"mean": df['mean'].astype(float), "std": df['std'].astype(float)}
        except Exception:
            scalers[safe] = None
    return scalers

def load_assets_and_embeddings(data_dir=DATA_DIR, embed_file=EMBED_FILE):
    # load asset order from asset_to_idx.csv if present
    asset_map = {}
    if os.path.exists(ASSET_MAP_FILE):
        try:
            s = pd.read_csv(ASSET_MAP_FILE, index_col=0, squeeze=True)
            asset_map = (s.to_dict() if hasattr(s, "to_dict") else dict(s))
            # s might be safe->idx or vice-versa; normalize
            # produce ordered list by idx
            try:
                ordered = sorted(asset_map.items(), key=lambda x:int(x[1]))
                safes = [k for k,_ in ordered]
            except Exception:
                safes = list(asset_map.keys())
        except Exception:
            safes = []
    else:
        # fallback: read normalized csv names
        safes = [os.path.basename(p).replace("_normalized.csv","") for p in sorted(glob.glob(os.path.join(DATA_DIR,"*_normalized.csv")))]
    embeddings = np.load(embed_file) if os.path.exists(embed_file) else np.zeros((len(safes), 1), dtype=np.float32)
    # build canonical list
    safes = safes if len(safes)>0 else list({os.path.basename(p).replace("_normalized.csv",""):None for p in glob.glob(os.path.join(DATA_DIR,"*_normalized.csv"))}.keys())
    return safes, embeddings, load_scalers(data_dir)

# ==================================================
# Cell 6 — MT5 init helpers (safe)
# ==================================================
def mt5_init_if_needed():
    try:
        if not mt5.initialize():
            # sometimes initialize() returns False but terminal still ok; give another try
            return mt5.initialize()
        return True
    except Exception:
        return False

def mt5_shutdown_if_needed():
    try:
        mt5.shutdown()
    except Exception:
        pass


# ==================================================
# Cell 7 — symbol_published (symbol)
# ==================================================
def ensure_symbol_published(symbol):
    """Ensure symbol is visible in Market Watch; attempt to add if not."""
    try:
        sinfo = mt5.symbol_info(symbol)
        if sinfo is None:
            # try to enable the symbol
            mt5.symbol_select(symbol, True)
            sinfo = mt5.symbol_info(symbol)
        return sinfo is not None
    except Exception:
        return False

# ==================================================
# Cell 8 — Time frame mapping
# ==================================================
# Cell 5 — Build observation for a symbol in the same format used in training
def timeframe_to_mt5(tf_str):
    TF_MAP = {
        "M1": mt5.TIMEFRAME_M1, "M5": mt5.TIMEFRAME_M5, "M15": mt5.TIMEFRAME_M15,
        "M30": mt5.TIMEFRAME_M30, "H1": mt5.TIMEFRAME_H1, "H4": mt5.TIMEFRAME_H4,
        "D1": mt5.TIMEFRAME_D1
    }
    return TF_MAP.get(tf_str.upper(), mt5.TIMEFRAME_M15)

TF_MT5 = timeframe_to_mt5(TIMEFRAME)

# ==================================================
# Cell 8 — 
# ==================================================
def fetch_and_build_obs(symbol, window, scalers, embeddings, safe_names, fallback_embedding_dim=None):
    """
    Build observation window for a symbol using:
    - normalized percent-change features
    - saved scaler (mean/std)
    - saved embedding vector

    Parameters:
        symbol (str): raw symbol, e.g. "EURUSD"
        window (int): sliding window length
        scalers (dict): {safe: {"mean": Series, "std": Series}}
        embeddings (dict): {safe: np.ndarray}
        safe_names (dict): {raw_symbol: safe_symbol}
        fallback_embedding_dim (int or None): optional zero-vector if embedding missing

    Returns:
        obs_window (np.ndarray)
        scaler_info (dict)
        embedding (np.ndarray)
    """

    # -----------------------------------------------------------
    # 1) Convert symbol → safe-name
    # -----------------------------------------------------------
    if symbol not in safe_names:
        print(f"❌ No safe-name found for symbol: {symbol}")
        return None, None, None

    safe = safe_names[symbol]

    # -----------------------------------------------------------
    # 2) Check required preprocessing objects
    # -----------------------------------------------------------
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None

    if safe not in embeddings:
        if fallback_embedding_dim is None:
            print(f"❌ Missing embedding for: {safe}")
            return None, None, None
        else:
            print(f"⚠️ Missing embedding → using zero vector for {safe}")
            embedding = np.zeros(fallback_embedding_dim)
    else:
        embedding = embeddings[safe]

    scaler_info = scalers[safe]

    # Extract mean / std
    mean = scaler_info["mean"]
    std = scaler_info["std"].replace(0, 1.0)

    # -----------------------------------------------------------
    # 3) Load raw normalized dataset for this symbol
    # -----------------------------------------------------------
    if safe not in datasets:
        print(f"❌ No dataset loaded for: {safe}")
        return None, None, None

    df = datasets[safe].copy()

    # Must contain the required columns
    required_columns = ["o_pc", "h_pc", "l_pc", "c_pc", "v_pc"]
    for col in required_columns:
        if col not in df.columns:
            print(f"❌ Missing required column '{col}' in dataset for {safe}")
            return None, None, None

    # -----------------------------------------------------------
    # 4) Apply normalization using stored scaler
    # -----------------------------------------------------------
    df_norm = (df[required_columns] - mean[required_columns]) / std[required_columns]

    # -----------------------------------------------------------
    # 5) Extract last window
    # -----------------------------------------------------------
    if len(df_norm) < window:
        print(f"❌ Not enough data for window={window} | {safe} has {len(df_norm)} rows")
        return None, None, None

    obs_window = df_norm.tail(window).values.astype(np.float32)

    # -----------------------------------------------------------
    # 6) Append embedding at each timestep (optional architecture)
    # -----------------------------------------------------------
    embedding_repeated = np.tile(embedding, (window, 1))
    obs_window = np.concatenate([obs_window, embedding_repeated], axis=1)

    return obs_window, scaler_info, embedding

# Cell 6 — Lot-size helpers (MT5-aware)
def pip_value_per_lot_from_mt5_live(symbol, entry_price):
    info = None
    try:
        info = mt5.symbol_info(symbol)
    except Exception:
        info = None
    if info is None:
        return None, None
    try:
        pip_val = info.trade_tick_value / info.trade_tick_size
    except Exception:
        contract_size = getattr(info, "trade_contract_size", 100000.0)
        pip_val = (info.point / entry_price) * contract_size
    return float(pip_val), float(info.point)

def calculate_lot_size_live(symbol, balance, risk_percent, entry_price, stop_loss_price):
    pip_val, point = pip_value_per_lot_from_mt5_live(symbol, entry_price)
    if point and point not in (0,None):
        pip_risk = abs(entry_price - stop_loss_price) / point
    else:
        pip_risk = abs(entry_price - stop_loss_price) / (0.01 if "JPY" in symbol else 0.0001)
    if pip_risk <= 0:
        return MIN_LOT
    dollar_risk = balance * float(risk_percent)
    if pip_val is not None:
        lot = dollar_risk / (pip_risk * pip_val)
        info = mt5.symbol_info(symbol)
        if info is not None:
            try:
                step = float(info.volume_step)
                if step>0:
                    lot = round(lot/step)*step
            except Exception:
                pass
        lot = max(MIN_LOT, min(MAX_LOT, round(lot,2)))
        return lot
    # fallback
    pip_size = 0.01 if "JPY" in symbol else 0.0001
    pip_value_per_lot = (pip_size / entry_price) * 100000.0
    lot = dollar_risk / (pip_risk * pip_value_per_lot)
    lot = max(MIN_LOT, min(MAX_LOT, round(lot,2)))
    return lot

# Cell 7 — Place order wrapper (dry_run default True)
def place_order_on_mt5(symbol, direction, lot, sl, tp=None, dry_run=DRY_RUN_DEFAULT, deviation=DEVIATION, magic=MAGIC, comment="rl_live"):
    """
    direction: "BUY" or "SELL"
    Returns MT5 response or a dict if dry_run.
    """
    if dry_run:
        req = {
            "action": "DEAL (dry_run)",
            "symbol": symbol,
            "volume": lot,
            "type": direction,
            "price": None,
            "sl": sl,
            "tp": tp,
            "deviation": deviation,
            "magic": magic,
            "comment": comment
        }
        print("DRY RUN order:", req)
        return {"dry_run": True, "request": req}

    if not mt5_init_if_needed():
        raise RuntimeError("MT5 not initialized. Open terminal and login.")
    if not ensure_symbol_published(symbol):
        raise RuntimeError(f"Symbol not available in MT5: {symbol}")

    tick = mt5.symbol_info_tick(symbol)
    if tick is None:
        raise RuntimeError("No tick info for symbol: "+symbol)

    price = tick.ask if direction=="BUY" else tick.bid
    order_type = mt5.ORDER_TYPE_BUY if direction=="BUY" else mt5.ORDER_TYPE_SELL

    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(lot),
        "type": order_type,
        "price": price,
        "sl": float(sl),
        "tp": float(tp) if tp is not None else 0.0,
        "deviation": int(deviation),
        "magic": int(magic),
        "comment": comment,
        "type_filling": mt5.ORDER_FILLING_FOK,
    }
    res = mt5.order_send(request)
    return res

# Cell 8 — Load trained model and environment assets
# Model
if not os.path.exists(MODEL_FILE):
    raise FileNotFoundError("Trained model not found at: " + MODEL_FILE)
model = PPO.load(MODEL_FILE)
print("Loaded model:", MODEL_FILE)

# Assets, embeddings, scalers
safe_names, embeddings, scalers = load_assets_and_embeddings(DATA_DIR, EMBED_FILE)
print("Assets:", safe_names)
print("Embeddings shape:", embeddings.shape)

# Cell 9 — Single-pass prediction and optional execution across all assets.
# This performs one sweep and returns a log of attempted/executed orders.

def run_once_live(dry_run=DRY_RUN_DEFAULT, risk_per_trade=RISK_PER_TRADE):
    if not mt5_init_if_needed():
        print("⚠️ Warning: MT5 initialization failed or MT5 not running. Running in offline/dry mode.")
    acct = mt5.account_info() if mt5_init_if_needed() else None
    balance = float(acct.balance) if acct is not None else 10000.0

    records = []
    for sym_safe in safe_names:
        # Try to reconstruct broker symbol name: assume safe==symbol; if your safe names differ, update asset_to_symbol mapping
        symbol = sym_safe
        print("\n---", symbol, "---")
        obs, vol_est, last_price = fetch_and_build_obs(symbol, window=WINDOW, scalers=scalers, embeddings=embeddings, safe_names=safe_names)
        if obs is None:
            print("Insufficient data for", symbol, "- skipping.")
            continue

        # model.predict expects shape like env obs. Ensure dims align:
        try:
            action, _ = model.predict(obs, deterministic=True)
        except Exception as e:
            # try flatten
            try:
                action, _ = model.predict(obs.flatten(), deterministic=True)
            except Exception as e2:
                print("Model predict failed for", symbol, ":", e, e2)
                continue

        act = extract_action_scalar(action)
        if act == 0:
            print(symbol, "→ HOLD")
            records.append({"symbol":symbol, "action":"HOLD"})
            continue

        direction = "BUY" if act==1 else "SELL"
        # compute stoploss/TP heuristics (based on vol_est)
        if vol_est is None or vol_est<=0:
            sl_dist = 0.001 * last_price
        else:
            sl_dist = max(1.5 * vol_est * last_price, 0.0005 * last_price)
        sl = last_price - sl_dist if direction=="BUY" else last_price + sl_dist
        tp = last_price + (2.5 * vol_est * last_price) if direction=="BUY" else last_price - (2.5 * vol_est * last_price)

        lot = calculate_lot_size_live(symbol, balance, risk_per_trade, last_price, sl)
        lot = max(MIN_LOT, min(MAX_LOT, round(lot, 2)))

        # place order
        res = place_order_on_mt5(symbol, direction, lot, sl, tp, dry_run=dry_run)
        rec = {
            "timestamp": datetime.utcnow().isoformat(),
            "symbol": symbol,
            "action": direction,
            "lot": lot,
            "price": last_price,
            "sl": sl,
            "tp": tp,
            "result": str(getattr(res, "retcode", res))
        }
        records.append(rec)
        print("Planned trade:", rec)
    # save records
    out_csv = os.path.join(LOG_DIR, f"live_run_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv")
    pd.DataFrame(records).to_csv(out_csv, index=False)
    print("Saved run log to", out_csv)
    return records

# Example single-run (dry-run): change dry_run=False to execute live (be cautious!)
records = run_once_live(dry_run=False)

# Cell 10 — Continuous loop for live trading (uncomment to enable). KEEP dry_run=True until fully tested.
# WARNING: Live trading will place real orders if dry_run is False. Test on demo account first.

def run_continuous(interval_s=60, dry_run=DRY_RUN_DEFAULT, risk_per_trade=RISK_PER_TRADE, max_iterations=None):
    print("Starting continuous live loop — dry_run =", dry_run)
    iteration = 0
    try:
        while True:
            iteration += 1
            print(f"\n=== Iteration {iteration} @ {datetime.utcnow().isoformat()} ===")
            run_once_live(dry_run=dry_run, risk_per_trade=risk_per_trade)
            if max_iterations is not None and iteration >= max_iterations:
                print("Max iterations reached; stopping.")
                break
            time.sleep(interval_s)
    except KeyboardInterrupt:
        print("Stopped by user (KeyboardInterrupt).")
    except Exception as e:
        print("Loop error:", e)
    finally:
        mt5_shutdown_if_needed()
        print("MT5 connection closed.")

# Example usage (commented): run_continuous(interval_s=60, dry_run=True, max_iterations=10)
run_continuous(interval_s=60, dry_run=True, max_iterations=10)

# Cell 11 — Visualize logs and compute simple metrics
def summarize_run_log(csv_path):
    df = pd.read_csv(csv_path)
    total_trades = len(df[df['action']!="HOLD"])
    if total_trades==0:
        print("No trades in log.")
    print("Run summary:", csv_path)
    print("Total rows:", len(df), "Trades executed/planned:", total_trades)
    return df

def plot_price_with_signals(symbol, run_records_csv):
    df_log = pd.read_csv(run_records_csv)
    df_sym = df_log[df_log['symbol']==symbol]
    if df_sym.empty:
        print("No records for", symbol)
        return
    # fetch historical price for plotting (recent window)
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, WINDOW*5)
    if bars is None:
        print("No bars to plot for", symbol)
        return
    d = pd.DataFrame(bars)
    d['time'] = pd.to_datetime(d['time'], unit='s')
    d.set_index('time', inplace=True)
    plt.figure(figsize=(12,5))
    plt.plot(d['close'], label='close')
    for _, row in df_sym.iterrows():
        # mark entry price
        plt.axhline(row['price'], linestyle='--', alpha=0.6)
    plt.title(f"{symbol} recent closes and signals")
    plt.legend()
    plt.show()

# Usage: df = summarize_run_log("models/multiasset/live_logs/live_run_YYYY...csv")
# plot_price_with_signals("Volatility_75_Index", "path-to-log.csv")

# Cell 12 — Safe shutdown and final notes
print("To run a single dry-run pass: records = run_once_live(dry_run=True)")
print("To run continuous (demo): run_continuous(interval_s=60, dry_run=True, max_iterations=10)")
print("To execute live for real orders set dry_run=False in run_once_live() or run_continuous(). DO NOT DO THIS UNTIL FULLY TESTED ON DEMO ACCOUNT.")
# Do not auto-init MT5 here — user will call run_* functions to start.
