In [1]:
# Cell 1 — Imports & config
import os, glob, time, json
from datetime import datetime
import numpy as np
import pandas as pd
import MetaTrader5 as mt5
import matplotlib.pyplot as plt
from stable_baselines3 import PPO

# Config paths (adjust if your layout differs)
MODEL_DIR = "models/multiasset"
MODEL_FILE = os.path.join(MODEL_DIR, "ppo_multiasset.zip")   # full path to zip file
EMBED_FILE = os.path.join(MODEL_DIR, "asset_embeddings.npy")
DATA_DIR = os.path.join("data", "multiasset")

LOG_DIR = os.path.join(MODEL_DIR, "live_logs")
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "live_trade_logs.csv")

# Trading / observation settings
WINDOW = 50
TIMEFRAME = "M1"   # human readable timeframe
TF_MAP = { "M1": mt5.TIMEFRAME_M1, "M5": mt5.TIMEFRAME_M5, "M15": mt5.TIMEFRAME_M15,
           "M30": mt5.TIMEFRAME_M30, "H1": mt5.TIMEFRAME_H1, "H4": mt5.TIMEFRAME_H4,
           "D1": mt5.TIMEFRAME_D1 }
TF_MT5 = TF_MAP[TIMEFRAME.upper()]

# Execution / risk params
DRY_RUN = True            # switch to False to actually place orders (test on demo first)
DEFAULT_RISK_PCT = 0.005  # percent of balance risked (heuristic)
MIN_LOT = 0.01
MAX_LOT = 1.0

# SL/TP/trailing (pips)
DEFAULT_SL_PIPS = 20
DEFAULT_TP_PIPS = 40
TRAIL_PIPS = 15

# Limit positions per symbol
MAX_POS_PER_SYMBOL = 2


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [2]:
# Cell 2 — quick file check
print("Model dir contents:", os.listdir(MODEL_DIR))
print("Has model file:", os.path.exists(MODEL_FILE))
print("Has embedding:", os.path.exists(EMBED_FILE))
print("Scaler CSVs found:", len(glob.glob(os.path.join(DATA_DIR, "*_scaler.csv"))))
print("Normalized CSVs found:", len(glob.glob(os.path.join(DATA_DIR, "*_normalized.csv"))))


Model dir contents: ['asset_embeddings.npy', 'live_logs', 'ppo_multiasset.zip', 'tensorboard', 'vec_normalize.pkl']
Has model file: True
Has embedding: True
Scaler CSVs found: 16
Normalized CSVs found: 16


In [3]:
# Cell 3 — loaders & name helpers
def make_safe_name(sym: str) -> str:
    return sym.replace(" ", "_").replace("/", "_").replace("(", "").replace(")", "").replace(".", "_")

def raw_from_safe(safe: str) -> str:
    # Default mapping: replace underscores with spaces (matches how we saved earlier).
    # If your MT5 symbols differ, modify this function to return the correct MT5 symbol name.
    return safe.replace("_", " ")

# Load scalers (per-asset CSVs)
def load_scalers(data_dir=DATA_DIR):
    scalers = {}
    for p in sorted(glob.glob(os.path.join(data_dir, "*_scaler.csv"))):
        safe = os.path.basename(p).replace("_scaler.csv","")
        df = pd.read_csv(p, index_col=0)
        # expect 'mean' and 'std' columns
        scalers[safe] = {"mean": df["mean"], "std": df["std"].replace(0,1.0)}
    return scalers

# Load normalized datasets (for simulation and index mapping)
def load_normalized_datasets(data_dir=DATA_DIR, window=WINDOW):
    datasets = {}
    for p in sorted(glob.glob(os.path.join(data_dir, "*_normalized.csv"))):
        safe = os.path.basename(p).replace("_normalized.csv","")
        df = pd.read_csv(p, index_col=0, parse_dates=True)
        # Expect percent-change columns; attempt to auto-convert if raw columns exist
        expected = ['o_pc','h_pc','l_pc','c_pc','v_pc','Close_raw']
        if all(c in df.columns for c in expected):
            df = df[expected].dropna()
        else:
            # attempt conversion
            if all(c in df.columns for c in ['open','high','low','close','volume']):
                tmp = pd.DataFrame(index=df.index)
                tmp['o_pc'] = df['open'].pct_change()
                tmp['h_pc'] = df['high'].pct_change()
                tmp['l_pc'] = df['low'].pct_change()
                tmp['c_pc'] = df['close'].pct_change()
                tmp['v_pc'] = df['volume'].pct_change()
                tmp['Close_raw'] = df['close']
                df = tmp.dropna()
            else:
                raise ValueError(f"{p} missing expected columns")
        if len(df) > window:
            datasets[safe] = df
    if not datasets:
        raise FileNotFoundError("No datasets loaded from " + data_dir)
    return datasets

# Load embeddings and map to safe names using asset_to_idx.csv if present
def load_embeddings(embed_file=EMBED_FILE, data_dir=DATA_DIR):
    emb_dict = {}
    if not os.path.exists(embed_file):
        print("No embedding file:", embed_file)
        return emb_dict
    emb = np.load(embed_file, allow_pickle=True)
    # load asset_to_idx map if exists
    map_path = os.path.join(data_dir, "asset_to_idx.csv")
    if os.path.exists(map_path):
        am = pd.read_csv(map_path, index_col=0, header=None).iloc[:,0].to_dict()
        # am may map safe->idx; ensure types right
        # am currently looks like {'Volatility_75_Index':0, ...} or similar
        for safe, idx in am.items():
            idx = int(idx)
            if idx < emb.shape[0]:
                emb_dict[safe] = emb[idx]
    else:
        # fallback: map in order of normalized CSVs if same length
        csvs = sorted(glob.glob(os.path.join(data_dir, "*_normalized.csv")))
        safe_list = [os.path.basename(p).replace("_normalized.csv","") for p in csvs]
        if len(safe_list) == emb.shape[0]:
            for i,safe in enumerate(safe_list):
                emb_dict[safe] = emb[i]
        else:
            print("Warning: embedding length and CSV count mismatch; manual mapping required.")
    return emb_dict

# Load everything
scalers = load_scalers(DATA_DIR)
datasets = load_normalized_datasets(DATA_DIR, WINDOW)
embeddings = load_embeddings(EMBED_FILE, DATA_DIR)

safe_list = list(datasets.keys())
print("Loaded assets:", len(safe_list), safe_list[:10])
print("Loaded scalers:", len(scalers))
print("Loaded embeddings:", len(embeddings))


Loaded assets: 16 ['EURUSD', 'Jump_100_Index', 'Jump_10_Index', 'Jump_25_Index', 'Jump_50_Index', 'Jump_75_Index', 'Volatility_100_1s_Index', 'Volatility_100_Index', 'Volatility_10_1s_Index', 'Volatility_10_Index']
Loaded scalers: 16
Loaded embeddings: 17


In [4]:
# Cell 4 — load model & MT5
# Load PPO model
if not os.path.exists(MODEL_FILE):
    raise FileNotFoundError("Model file not found at: " + MODEL_FILE)
model = PPO.load(MODEL_FILE)
print("Loaded PPO model from", MODEL_FILE)

# Initialize MT5
if not mt5.initialize():
    raise RuntimeError("MT5 initialize failed. Open MetaTrader and login.")
print("MT5 connected:", mt5.version())


Loaded PPO model from models/multiasset\ppo_multiasset.zip
MT5 connected: (500, 5370, '17 Oct 2025')


# Cell 5 — fetch_and_build_obs used for live inference
def fetch_and_build_obs(symbol, window, scalers, embeddings, datasets, safe_list):
    """
    symbol: MT5 symbol name (e.g. 'Volatility 75 Index' or 'EURUSD' depending on broker)
    returns: obs (window x features), vol_est, last_price  OR (None,None,None) on failure
    """
    # Transform symbol to safe key: default mapping raw_from_safe inverse
    # We assume safe = make_safe_name(symbol) (e.g. "Volatility 75 Index" -> "Volatility_75_Index")
    safe = make_safe_name(symbol)

    # Check presence
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None
    if safe not in embeddings:
        print(f"⚠️ Missing embedding for: {safe} — using zeros")
        # you can choose to return None or use zeros; we'll use zeros fallback
        emb_vec = np.zeros( (max(1, list(embeddings.values())[0].shape[0]) if embeddings else 8,) )
    else:
        emb_vec = np.array(embeddings[safe], dtype=np.float32)

    # Fetch bars from MT5
    count = window + 30
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)
    if bars is None or len(bars) < window + 2:
        print(f"Insufficient bars for symbol: {symbol} ({0 if bars is None else len(bars)} rows)")
        return None, None, None

    df = pd.DataFrame(bars)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df = df.set_index('time')
    # rename tick_volume -> volume
    if 'tick_volume' in df.columns:
        df = df.rename(columns={'tick_volume':'volume'})

    # compute pct change and take last window
    pct = df[['open','high','low','close','volume']].pct_change().dropna()
    if len(pct) < window:
        print(f"Not enough pct rows for {symbol}")
        return None, None, None
    pct = pct.tail(window)

    # Normalize using stored scaler
    scaler = scalers[safe]
    cols = ["open","high","low","close","volume"]
    # align mean/std (scaler['mean'] is a Series)
    try:
        mean = scaler["mean"][cols]
        std = scaler["std"][cols].replace(0,1.0)
    except Exception:
        # fallback if scaler is Series with positional order
        mean = pd.Series(scaler["mean"].values[:len(cols)], index=cols)
        std  = pd.Series(scaler["std"].values[:len(cols)], index=cols).replace(0,1.0)

    pct_norm = (pct[cols] - mean) / std
    last_price = float(df['close'].iloc[-1])
    vol_est = float(pct['close'].std())

    # build obs: [pct_norm cols] + embedding repeated + balance_norm + asset_id
    emb_rep = np.tile(emb_vec.reshape(1,-1),(window,1)).astype(np.float32) if emb_vec.size>0 else np.zeros((window,0),dtype=np.float32)
    balance_col = np.full((window,1), 1.0, dtype=np.float32)
    try:
        asset_id_val = safe_list.index(safe) / max(1, len(safe_list))
    except ValueError:
        asset_id_val = 0.0
    asset_col = np.full((window,1), asset_id_val, dtype=np.float32)

    obs = np.concatenate([pct_norm[cols].values.astype(np.float32), emb_rep, balance_col, asset_col], axis=1)
    return obs, vol_est, last_price


def fetch_and_build_obs(symbol, window, scalers, embeddings, datasets, safe_list):
    """
    Live inference observation builder.
    Must match training shape EXACTLY: (window, 14)
    = 5 OHLCV + 1 balance + 8 embedding
    """
    safe = make_safe_name(symbol)

    # Scaler check
    if safe not in scalers:
        print(f"❌ Missing scaler for: {safe}")
        return None, None, None

    # Embedding
    if safe not in embeddings:
        print(f"⚠️ Missing embedding for {safe}, using zeros")
        emb_vec = np.zeros(8, dtype=np.float32)
    else:
        emb_vec = np.array(embeddings[safe], dtype=np.float32)

    # Fetch MT5 bars
    count = window + 30
    bars = mt5.copy_rates_from_pos(symbol, TF_MT5, 0, count)

    if bars is None or len(bars) < window + 2:
        print(f"❌ Insufficient bars for {symbol}")
        return None, None, None

    df = pd.DataFrame(bars)
    df["time"] = pd.to_datetime(df["time"], unit="s")
    df = df.set_index("time")
    df = df.rename(columns={"tick_volume": "volume"})

    # Pct-change window
    pct = df[["open","high","low","close","volume"]].pct_change().dropna()

    if len(pct) < window:
        print(f"❌ Not enough pct rows for {symbol}")
        return None, None, None

    pct = pct.tail(window)

    # Normalize
    scaler = scalers[safe]
    cols = ["open","high","low","close","volume"]

    mean = scaler["mean"][cols]
    std = scaler["std"][cols].replace(0, 1.0)

    pct_norm = (pct[cols] - mean) / std

    # Last price + volatility
    last_price = float(df["close"].iloc[-1])
    vol_est = float(pct["close"].std())

    # Repeat embedding per timestep
    emb_rep = np.tile(emb_vec, (window, 1)).astype(np.float32)

    # Balance ONLY (training uses this)
    balance_col = np.full((window, 1), 1.0, dtype=np.float32)

    # FINAL OBS SHAPE (window, 14)
    obs = np.concatenate(
        [
            pct_norm.values.astype(np.float32),  # 5
            balance_col,                         # 1
            emb_rep                              # 8
        ],
        axis=1
    )

    return obs, vol_est, last_price


In [24]:
def fetch_and_build_obs(raw_symbol, window, scalers, embeddings, datasets, safe_list):
    safe = safe_from_raw(raw_symbol)

    if safe not in datasets:
        print("No dataset for", safe)
        return None, None, None

    df = datasets[safe]

    # Check minimum length
    if len(df) < window:
        print("Not enough data for", safe)
        return None, None, None

    # ---- FIX HERE: Use correct column names ----
    try:
        window_df = df.iloc[-window:]

        features = window_df[['o_pc', 'h_pc', 'l_pc', 'c_pc', 'v_pc']].values
        last_price = float(window_df['Close_raw'].iloc[-1])
    except KeyError as e:
        print("Column missing:", e)
        print("Available columns:", df.columns)
        return None, None, None
    # --------------------------------------------

    # Scale
    scaler = scalers[safe]
    features_scaled = scaler.transform(features)

    # Embedding vector
    embed = embeddings[safe]  # shape (embed_dim,)

    # Final obs shape: (window, 5 + embed_dim)
    obs = np.hstack([features_scaled, np.repeat(embed[np.newaxis, :], window, axis=0)])

    # Compute volatility
    vol = np.std(window_df['c_pc'].values)

    return obs.astype(np.float32), float(vol), last_price


In [25]:
# Cell 6 — execution helpers
def pip_value(symbol):
    # crude: JPY pairs differ
    if "JPY" in symbol or "JPY" in symbol.upper():
        return 0.01
    return 0.0001

def compute_lot_from_balance(balance, vol, price, risk_pct=DEFAULT_RISK_PCT, min_lot=MIN_LOT, max_lot=MAX_LOT):
    # heuristic sizing
    risk_amount = balance * risk_pct
    vol = max(vol, 1e-8)
    price_scale = 1000.0
    lot = risk_amount / (vol * price_scale)
    lot = max(min_lot, min(max_lot, round(lot, 2)))
    return float(lot)

def compute_sl_tp_by_pips(symbol, price, direction, sl_pips=DEFAULT_SL_PIPS, tp_pips=DEFAULT_TP_PIPS):
    pip = pip_value(symbol)
    if direction == "BUY":
        sl = price - sl_pips * pip
        tp = price + tp_pips * pip
    else:
        sl = price + sl_pips * pip
        tp = price - tp_pips * pip
    return float(sl), float(tp)

def trailing_sl_level(symbol, pos, trail_pips=TRAIL_PIPS):
    """
    pos: mt5.Position-like object (position returned from mt5.positions_get)
    returns new_sl_value or None
    """
    tick = mt5.symbol_info_tick(pos.symbol)
    if tick is None: 
        return None
    pip = pip_value(pos.symbol)
    if pos.type == 0:  # BUY position type==0 (depends on MT5 build)
        current = float(tick.bid)
        new_sl = current - trail_pips * pip
        # only move SL up (for buy)
        if pos.sl is None or new_sl > pos.sl:
            return float(new_sl)
    else:  # SELL
        current = float(tick.ask)
        new_sl = current + trail_pips * pip
        if pos.sl is None or new_sl < pos.sl:
            return float(new_sl)
    return None


In [26]:
# Cell 7 — order and management helpers
def get_positions_for_symbol(symbol):
    pos = mt5.positions_get(symbol=symbol)
    return [] if pos is None else list(pos)

def place_market_order(symbol, direction, lot, sl, tp, comment="multiasset_live"):
    """
    direction: "BUY" or "SELL"
    returns result object or simulated dict when DRY_RUN=True
    """
    tick = mt5.symbol_info_tick(symbol)
    if tick is None:
        print("❌ No tick for", symbol)
        return None
    price = float(tick.ask if direction=="BUY" else tick.bid)
    if DRY_RUN:
        return {"retcode": 10009, "price": price, "comment":"DRY_RUN", "direction":direction}
    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(lot),
        "type": mt5.ORDER_TYPE_BUY if direction=="BUY" else mt5.ORDER_TYPE_SELL,
        "price": price,
        "sl": float(sl) if sl is not None else 0.0,
        "tp": float(tp) if tp is not None else 0.0,
        "deviation": 20,
        "magic": 234000,
        "comment": comment,
        "type_filling": mt5.ORDER_FILLING_FOK,
    }
    res = mt5.order_send(request)
    return res

def close_position_by_ticket(ticket, volume=None):
    # get position info:
    pos = None
    pos_list = mt5.positions_get(ticket=ticket)
    if pos_list: pos = pos_list[0]
    if pos is None:
        print("No position with ticket", ticket)
        return None
    symbol = pos.symbol
    if pos.type == 0:  # long => close with SELL
        order_type = mt5.ORDER_TYPE_SELL
        price = mt5.symbol_info_tick(symbol).bid
    else:
        order_type = mt5.ORDER_TYPE_BUY
        price = mt5.symbol_info_tick(symbol).ask
    req = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(volume if volume is not None else pos.volume),
        "type": order_type,
        "position": int(ticket),
        "price": price,
        "deviation": 20,
        "magic": 234000,
        "comment": "auto_close",
        "type_filling": mt5.ORDER_FILLING_FOK
    }
    if DRY_RUN:
        return {"retcode": 10009, "comment":"DRY_RUN_CLOSE", "ticket": ticket}
    return mt5.order_send(req)


In [38]:
# Cell 8 — single-pass predictor + order manager
def run_once_predict_and_manage(model, safe_list, scalers, embeddings, datasets):
    """
    Loop over all assets (Option A), predict, place orders, auto-close on reversal,
    enforce MAX_POS_PER_SYMBOL, apply trailing SL.
    """
    # Use raw symbol names for MT5 by default: raw = safe.replace('_',' ')
    raw_symbols = [raw_from_safe(s) for s in safe_list]

    # account balance
    acct = mt5.account_info()
    balance = float(acct.balance) if acct else 10000.0

    # prepare log header
    header = not os.path.exists(LOG_FILE)

    for safe, raw in zip(safe_list, raw_symbols):
        print("\n---", raw, "(", safe, ") ---")
        obs, vol, last_price = fetch_and_build_obs(raw, WINDOW, scalers, embeddings, datasets, safe_list)
        if obs is None:
            print("Skip", raw)
            continue

        # predict
        try:
            action, _ = model.predict(obs[np.newaxis,...], deterministic=True)
            a = int(action[0]) if isinstance(action, (list,tuple,np.ndarray)) else int(action)
        except Exception as e:
            print("Prediction error:", e)
            continue

        positions = get_positions_for_symbol(raw)
        print("Existing positions:", len(positions))

        # Auto-close on reversal: if signal BUY and existing SELL positions exist -> close them
        if a == 1:
            # BUY signal -> close SELL positions (p.type==1 typically)
            for p in positions:
                # depending on MT5 build, p.type: 0 = BUY, 1 = SELL (sometimes reversed). Check type semantics in your MT5.
                if getattr(p, "type", None) == 1:
                    print("Closing opposing SELL position ticket", p.ticket)
                    close_position_by_ticket(p.ticket)
        elif a == 2:
            for p in positions:
                if getattr(p, "type", None) == 0:
                    print("Closing opposing BUY position ticket", p.ticket)
                    close_position_by_ticket(p.ticket)

        # Re-fetch positions and enforce limit
        positions = get_positions_for_symbol(raw)
        if len(positions) >= MAX_POS_PER_SYMBOL:
            print(f"Max positions for {raw} reached ({len(positions)}) — skipping open")
        else:
            if a == 0:
                print("HOLD")
            else:
                direction = "BUY" if a==1 else "SELL"
                lot = compute_lot_from_balance(balance, vol, last_price)
                sl, tp = compute_sl_tp_by_pips(raw, last_price, direction, DEFAULT_SL_PIPS, DEFAULT_TP_PIPS)
                res = place_market_order(raw, direction, lot, sl, tp)
                # normalize response for logging
                if isinstance(res, dict):
                    retcode = res.get("retcode")
                    comment = res.get("comment")
                    price_executed = res.get("price", last_price)
                else:
                    retcode = getattr(res, "retcode", None)
                    comment = getattr(res, "comment", "")
                    price_executed = last_price
                entry = {
                    "timestamp": datetime.utcnow().isoformat(),
                    "safe": safe,
                    "symbol": raw,
                    "action": direction,
                    "lot": lot,
                    "exec_price": price_executed,
                    "sl": sl,
                    "tp": tp,
                    "retcode": retcode,
                    "comment": comment,
                    "dry_run": DRY_RUN
                }
                pd.DataFrame([entry]).to_csv(LOG_FILE, mode="a", index=False, header=header)
                header = False
                print("Placed", direction, "lot", lot, "retcode", retcode)

        # Trailing SL update for open positions
        for p in get_positions_for_symbol(raw):
            new_sl = trailing_sl_level(raw, p, TRAIL_PIPS=TRAIL_PIPS) if False else trailing_sl_level(raw, p, trail_pips=TRAIL_PIPS)
            if new_sl:
                # send modify order to change SL: in MT5, use trade position modify, we'll call order_send with TRADE_ACTION_SLTP or use position modification wrapper
                # Simple approach: send request to set SL/TP using ORDER_TYPE: mt5.ORDER_TYPE_BUY/SELL? use trade_modify if available
                # Use mt5.order_send with action TRADE_ACTION_SLTP if supported:
                req = {
                    "action": mt5.TRADE_ACTION_SLTP,
                    "symbol": raw,
                    "position": int(p.ticket),
                    "sl": new_sl,
                    "tp": float(p.tp) if getattr(p, "tp", None) is not None else 0.0,
                }
                if DRY_RUN:
                    print(f"[DRY] Would modify SL for ticket {p.ticket} -> {new_sl}")
                else:
                    r = mt5.order_send(req)
                    print("Modify SL result:", getattr(r,"retcode", None))

    print("\nSingle pass complete.")


In [39]:
def safe_from_raw(raw_symbol: str) -> str:
    """
    Convert MT5 raw name to safe key:
    'Volatility 75 Index' → 'Volatility_75_Index'
    """
    return raw_symbol.replace(" ", "_")


In [40]:
def raw_from_safe(safe_symbol: str) -> str:
    """
    Convert safe key to raw MT5 name:
    'Volatility_75_Index' → 'Volatility 75 Index'
    """
    return safe_symbol.replace("_", " ")


In [41]:
print(safe, scalers[safe])


EURUSD <__main__.MiniScaler object at 0x000001E9B064B770>


In [42]:
print(type(scalers[safe]))


<class '__main__.MiniScaler'>


In [43]:
# Cell 9 — run once (recommended)
# Trade Option A: all assets present in datasets
safe_list = list(datasets.keys())
print("Trading assets (safe names):", safe_list)
run_once_predict_and_manage(model, safe_list, scalers, embeddings, datasets)


Trading assets (safe names): ['EURUSD', 'Jump_100_Index', 'Jump_10_Index', 'Jump_25_Index', 'Jump_50_Index', 'Jump_75_Index', 'Volatility_100_1s_Index', 'Volatility_100_Index', 'Volatility_10_1s_Index', 'Volatility_10_Index', 'Volatility_25_1s_Index', 'Volatility_25_Index', 'Volatility_50_1s_Index', 'Volatility_50_Index', 'Volatility_75_1s_Index', 'Volatility_75_Index']

--- EURUSD ( EURUSD ) ---
Prediction error: Error: Unexpected observation shape (1, 50, 13) for Box environment, please use (50, 14) or (n_env, 50, 14) for the observation shape.

--- Jump 100 Index ( Jump_100_Index ) ---
Prediction error: Error: Unexpected observation shape (1, 50, 13) for Box environment, please use (50, 14) or (n_env, 50, 14) for the observation shape.

--- Jump 10 Index ( Jump_10_Index ) ---
Prediction error: Error: Unexpected observation shape (1, 50, 13) for Box environment, please use (50, 14) or (n_env, 50, 14) for the observation shape.

( Jump_25_Index ) ---
Prediction error: Error: Unexpecte

In [35]:
class MiniScaler:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
    
    def transform(self, x):
        return (x - self.mean) / self.std
    
    def inverse_transform(self, x):
        return x * self.std + self.mean

# Fix for all symbols
for sym in scalers:
    mean = scalers[sym]["mean"].values.astype(np.float32)
    std = scalers[sym]["std"].values.astype(np.float32)
    scalers[sym] = MiniScaler(mean, std)


In [None]:
# Cell 10 — simulation evaluator (quick/backtest style)
def simulate_trades_on_history(datasets, model, assets=None, window=WINDOW, horizon=10):
    trades = []
    assets = assets or list(datasets.keys())
    for safe in assets:
        df = datasets[safe].reset_index(drop=True)
        emb = embeddings.get(safe, np.zeros( (0,) ))
        for i in range(window, len(df)-horizon):
            window_df = df.iloc[i-window:i]
            feat = window_df[['o_pc','h_pc','l_pc','c_pc','v_pc']].values.astype(np.float32)
            emb_rep = np.tile(emb.reshape(1,-1),(window,1)) if emb.size>0 else np.zeros((window,0),dtype=np.float32)
            balance_col = np.full((window,1),1.0,dtype=np.float32)
            asset_id = np.full((window,1), list(datasets.keys()).index(safe)/len(datasets), dtype=np.float32)
            obs = np.concatenate([feat, emb_rep, balance_col, asset_id], axis=1)
            try:
                action, _ = model.predict(obs[np.newaxis,...], deterministic=True)
                action = int(action[0]) if isinstance(action,(list,tuple,np.ndarray)) else int(action)
            except Exception:
                continue
            if action == 0: 
                continue
            entry_price = float(df['Close_raw'].iat[i-1])
            exit_price = float(df['Close_raw'].iat[i+horizon-1])
            pos = 1 if action==1 else -1
            pnl = (exit_price - entry_price)/entry_price * pos
            trades.append({"safe": safe, "entry_i": i, "horizon": horizon, "action": action, "pnl": pnl})
    trades_df = pd.DataFrame(trades)
    return trades_df

def compute_trade_metrics(trades_df):
    if trades_df.empty:
        return {}
    pnl = trades_df['pnl']
    total = pnl.sum()
    mean = pnl.mean()
    std = pnl.std()
    wins = (pnl>0).sum()
    losses = (pnl<=0).sum()
    win_rate = wins / (wins+losses) if (wins+losses)>0 else 0.0
    profit_factor = (pnl[pnl>0].sum() / abs(pnl[pnl<0].sum())) if (pnl[pnl<0].sum()!=0) else np.nan
    cum = pnl.cumsum()
    max_dd = (cum.cummax() - cum).max()
    return {"n_trades": len(pnl), "total_pnl": float(total), "mean": float(mean), "std": float(std),
            "win_rate": float(win_rate), "profit_factor": float(profit_factor), "max_drawdown": float(max_dd)}


In [None]:
# Cell 11 — run simulation (may take a few minutes depending on dataset size)
sim_trades = simulate_trades_on_history(datasets, model, assets=list(datasets.keys()), window=WINDOW, horizon=10)
sim_trades.to_csv(os.path.join(MODEL_DIR, "simulated_trades_optionA.csv"), index=False)
metrics = compute_trade_metrics(sim_trades)
print("Simulation metrics:", json.dumps(metrics, indent=2))
if not sim_trades.empty:
    plt.figure(figsize=(8,4))
    plt.hist(sim_trades['pnl'], bins=60)
    plt.title("Distribution of simulated trade returns")
    plt.show()


In [None]:
# Cell 12 — show last live logs
if os.path.exists(LOG_FILE):
    df_log = pd.read_csv(LOG_FILE)
    display(df_log.tail(20))
else:
    print("No live log present:", LOG_FILE)
