In [None]:
from trading_engine.core import (
    read_data, create_model_state, orchestrate_model_backtests,
    orchestrate_model_simulations, orchestrate_portfolio_optimizations,
    orchestrate_portfolio_simulations, orchestrate_portfolio_aggregation
)

import datetime

In [None]:
# Define your trading model
from typing import Callable

import polars as pl
from polars import LazyFrame

from common.bundles import ModelStateBundle

def MomentumModel(
        trade_ticker: str,
        signal_ticker: str,
        momentum_column: str,
        inverse: bool = True,  # True => long when signal < 0, else flat
        threshold: float = 0.0,  # optional dead-band around 0 to avoid noise
) -> Callable[[ModelStateBundle], LazyFrame]:
    """
    Build a long/flat model for `trade_ticker` driven by `signal_ticker` momentum.

    The weight is 1.0 when the signal condition holds and 0.0 otherwise:

        inverse=True  -> condition is momentum(signal_ticker) < -threshold
        inverse=False -> condition is momentum(signal_ticker) >  threshold

    A null momentum value never satisfies the condition (weight 0.0).

    Args:
        trade_ticker: name of the output weight column.
        signal_ticker: value of the "ticker" column whose rows supply the signal.
        momentum_column: model_state column read as the signal.
        inverse: selects which side of the dead-band triggers a long position.
        threshold: half-width of the dead-band around zero.

    Returns:
        A callable that takes a ModelStateBundle, reads `bundle.model_state`
        (must expose "ticker", "date" and `momentum_column`), and returns a
        LazyFrame with columns ["date", trade_ticker].
    """

    def run_model(bundle: ModelStateBundle) -> LazyFrame:
        # Decide the entry rule up front; it is a lazy expression over "sig".
        signal = pl.col("sig")
        if inverse:
            is_long = signal < -threshold
        else:
            is_long = signal > threshold

        # Keep only the signal ticker's rows, reduced to (date, sig).
        signal_rows = bundle.model_state.lazy().filter(
            pl.col("ticker") == signal_ticker
        )
        signal_series = signal_rows.select([
            pl.col("date"),
            pl.col(momentum_column).alias("sig"),
        ])

        # Null comparisons count as "not long", so the weight is always 0.0 or 1.0.
        weight = (
            pl.when(is_long.fill_null(False))
            .then(pl.lit(1.0))
            .otherwise(pl.lit(0.0))
            .cast(pl.Float64)
            .alias(trade_ticker)
        )

        return signal_series.select([pl.col("date"), weight])

    return run_model

In [None]:
# Example: Model using both model_state and supplemental_model_state
# This model trades USO-US based on its momentum AND oil inventory data

def USO_OilInventoryModel(
        trade_ticker: str = "USO-US",
        momentum_column: str = "close_momentum_10",
        inventory_series_id: str = "WCESTUS1",  # U.S. Ending Stocks excl. SPR
        production_series_id: str = "WCRFPUS2",   # U.S. Field Production
        change_lookback_rows: int = 7,  # rows between the two compared observations
) -> Callable[[ModelStateBundle], LazyFrame]:
    """
    Build a long/short/flat model for `trade_ticker` from its own momentum
    combined with oil inventory and production series.

    Inputs read from the bundle:
    1. `momentum_column` of `trade_ticker` (from model_state)
    2. `inventory_series_id` column (from supplemental_model_state)
    3. `production_series_id` column (from supplemental_model_state)

    Strategy:
    - Long  ( 1.0): momentum > 0 AND (inventory change < 0 OR production change > 0)
    - Short (-1.0): momentum < 0 AND inventory change > 0
    - Otherwise: flat (0.0); a condition that evaluates to null counts as False

    The supplemental series are left-joined onto the ticker's dates and
    forward-filled, so a lower-frequency (e.g. weekly) series carries its last
    value until the next observation.

    NOTE(review): the "change" columns are a row shift over the ticker's dates,
    not a calendar offset. The default of 7 rows is kept for backward
    compatibility; if model_state has one row per trading day, 5 rows is
    closer to one calendar week. Confirm which is intended.

    NOTE(review): the join is on exact date equality, so a supplemental
    observation dated on a day with no `trade_ticker` row is not picked up.

    Args:
        trade_ticker: "ticker" value to trade; also the output column name.
        momentum_column: model_state column used as the momentum signal.
        inventory_series_id: supplemental column holding inventory levels.
        production_series_id: supplemental column holding production levels.
        change_lookback_rows: number of rows to look back when computing the
            inventory/production change. Must be a positive integer.

    Returns:
        A callable taking a ModelStateBundle and returning a LazyFrame with
        columns ["date", trade_ticker]. If a series column is absent from the
        supplemental data it is replaced by nulls, so comparisons on it are
        null rather than raising.

    Raises:
        ValueError: if `change_lookback_rows` is not a positive integer.
    """
    if not isinstance(change_lookback_rows, int) or change_lookback_rows < 1:
        raise ValueError(
            f"change_lookback_rows must be a positive integer, got {change_lookback_rows!r}"
        )

    def _series_or_nulls(supplemental, available_columns, series_id, alias):
        # Select (date, <alias>) from the supplemental frame, substituting a
        # null Float64 column when the requested series is not present.
        if series_id in available_columns:
            values = pl.col(series_id)
        else:
            values = pl.lit(None).cast(pl.Float64)
        return supplemental.select([pl.col("date"), values.alias(alias)])

    def run_model(bundle: ModelStateBundle) -> LazyFrame:
        # 1) Momentum of the traded ticker from model_state
        uso_data = (
            bundle.model_state.lazy()
            .filter(pl.col("ticker") == trade_ticker)
            .select([
                pl.col("date"),
                pl.col(momentum_column).alias("uso_momentum"),
            ])
        )

        # 2) Inventory and production series from supplemental_model_state
        supplemental = bundle.supplemental_model_state.lazy()
        supplemental_cols = supplemental.collect_schema().names()
        inventory_data = _series_or_nulls(
            supplemental, supplemental_cols, inventory_series_id, "inventory"
        )
        production_data = _series_or_nulls(
            supplemental, supplemental_cols, production_series_id, "production"
        )

        # 3) Align the supplemental series to the ticker's dates and carry the
        #    last observation forward. Sorting first makes forward_fill and
        #    shift operate in date order.
        # 4) Change versus `change_lookback_rows` rows earlier.
        combined = (
            uso_data
            .join(inventory_data, on="date", how="left")
            .join(production_data, on="date", how="left")
            .sort("date")
            .with_columns([
                pl.col("inventory").forward_fill(),
                pl.col("production").forward_fill(),
            ])
            .with_columns([
                (
                    pl.col("inventory")
                    - pl.col("inventory").shift(change_lookback_rows)
                ).alias("inventory_change_weekly"),
                (
                    pl.col("production")
                    - pl.col("production").shift(change_lookback_rows)
                ).alias("production_change_weekly"),
            ])
        )

        # 5) Entry conditions
        long_condition = (
            (pl.col("uso_momentum") > 0) &
            (
                (pl.col("inventory_change_weekly") < 0) |  # inventories falling
                (pl.col("production_change_weekly") > 0)    # production rising
            )
        )

        short_condition = (
            (pl.col("uso_momentum") < 0) &
            (pl.col("inventory_change_weekly") > 0)  # inventories rising
        )

        # 6) Map conditions to weights; long takes precedence over short.
        weights = combined.select([
            pl.col("date"),
            pl.when(long_condition.fill_null(False))
            .then(pl.lit(1.0))
            .when(short_condition.fill_null(False))
            .then(pl.lit(-1.0))
            .otherwise(pl.lit(0.0))
            .cast(pl.Float64)
            .alias(trade_ticker)
        ])

        return weights

    return run_model

In [None]:
# Registry of custom models, keyed by model name. Each entry lists the tickers
# and model-state columns the model reads, the model callable, and a lookback.
models_registry = {}

# Momentum model: trade RXI-US from the momentum of TLT-US.
models_registry["RXI_TLT_pml_10"] = {
    "tickers": ["RXI-US", "TLT-US"],
    "columns": ["close_momentum_10"],
    "function": MomentumModel(
        trade_ticker="RXI-US",
        signal_ticker="TLT-US",
        momentum_column="close_momentum_10",
        inverse=False,
    ),
    "lookback": 0,
}

# Oil model: trade USO-US from its own momentum plus supplemental oil series.
models_registry["USO_OilInventory"] = {
    "tickers": ["USO-US"],
    "columns": ["close_momentum_10"],
    "function": USO_OilInventoryModel(
        trade_ticker="USO-US",
        momentum_column="close_momentum_10",
        inventory_series_id="WCESTUS1",   # U.S. Ending Stocks excl. SPR
        production_series_id="WCRFPUS2",  # U.S. Field Production
    ),
    "lookback": 0,
}

In [None]:
# 1) Experiment configuration

# Tickers made available to the pipeline.
universe = [
    "SPY-US", "SLV-US", "GLD-US", "TLT-US", "USO-US",
    "UNG-US", "IXJ-US", "KXI-US", "JXI-US", "IXG-US",
    "IXN-US", "RXI-US", "MXI-US", "EXI-US", "IXC-US",
    "IEI-US", "SHY-US", "BIL-US", "JPXN-US", "INDA-US",
    "MCHI-US", "EZU-US", "IBIT-US", "ETHA-US", "VIXY-US",
]

# Names of the features, models, aggregators and optimizers to run.
features = ["close_momentum_10"]              # keys in FEATURES
models = ["USO_OilInventory"]                 # model that uses both data sources
# models = ["RXI_TLT_pml_10"]                 # alternative: the momentum-only model
aggregators = ["model_mvo"]                   # keys in AGGREGATORS
optimizers = ["mean_variance_constrained"]    # keys in OPTIMIZERS

# Starting portfolio value and backtest window.
initial_value = 1_000_000
start_date = datetime.date(year=2021, month=1, day=1)
end_date = datetime.date(year=2025, month=1, day=1)

In [None]:
# 2) Read the raw data (with supplemental series) and build the model state
#    and prices for the configured features, universe and date window
#    (cached locally).
raw_data_bundle = read_data(include_supplemental=True)

model_state_bundle, prices = create_model_state(
    raw_data_bundle=raw_data_bundle,
    universe=universe,
    features=features,
    start_date=start_date,
    end_date=end_date,
)

# Note: the supplemental data is weekly (available every Wednesday); the
# USO_OilInventory model forward-fills it to line up with daily USO-US rows.

## Inspect Supplemental Model State

Let's first check what data is available in the supplemental model state:

In [None]:
# Show the supplemental model state's column names, shape and first rows.
# A single print with a newline separator emits the same text as one print
# call per item.
print(
    "Supplemental Model State columns:",
    model_state_bundle.supplemental_model_state.columns,
    "\nSupplemental Model State shape:",
    model_state_bundle.supplemental_model_state.shape,
    "\nFirst few rows:",
    model_state_bundle.supplemental_model_state.head(),
    sep="\n",
)

In [None]:
# 3) Backtest the selected models, then simulate them against prices.

# The custom registry defined in this notebook is passed explicitly instead of
# relying on the default production registry.
model_insights = orchestrate_model_backtests(
    registry=models_registry,
    models=models,
    universe=universe,
    model_state_bundle=model_state_bundle,
)

model_simulations = orchestrate_model_simulations(
    model_insights=model_insights,
    prices=prices,
    start_date=start_date,
    end_date=end_date,
)

In [None]:
# 4) Portfolio stage: aggregate model insights, optimize, then simulate.

# 4a) Combine per-model insights using the configured aggregators.
aggregated_insights = orchestrate_portfolio_aggregation(
    aggregators=aggregators,
    model_insights=model_insights,
    backtest_results=model_simulations,
    universe=universe,
    start_date=start_date,
    end_date=end_date,
)

# 4b) Run the configured optimizers on the aggregated insights.
optimizer_insights = orchestrate_portfolio_optimizations(
    optimizers=optimizers,
    aggregated_insights=aggregated_insights,
    prices=prices,
    universe=universe,
)

# 4c) Simulate each optimized portfolio from the initial value.
optimizer_simulations = orchestrate_portfolio_simulations(
    portfolio_insights=optimizer_insights,
    prices=prices,
    initial_value=initial_value,
    start_date=start_date,
    end_date=end_date,
)

In [None]:
# 5) Display the backtest metrics for one optimizer.
# The first key must be one of the names listed in `optimizers`
# (here "mean_variance_constrained"); presumably the simulations dict is keyed
# by optimizer name — verify against orchestrate_portfolio_simulations.
# Left as a bare trailing expression so the notebook renders the result.
optimizer_simulations["mean_variance_constrained"]["backtest_metrics"]