# FinMamba (Market-Aware Graph + Multi-Level Mamba) - From PDF

This notebook implements the core components described in `finmamba.pdf`:
- Market-Aware Graph (MAG): dynamic stock correlation graph + market-aware sparsification using a market proxy index.
- Graph Attention Aggregation: multi-head neighbor aggregation on the pruned graph.
- Multi-Level Mamba (MLM): multi-level selective SSM blocks.
- Optimization objectives: point-wise regression + pair-wise hinge ranking + GIB loss.

Data:
- Uses the repo feature dataset: `dataset/features/all_features.parquet` (fallback: `all_features.csv`).

Training:
- Time split: the first 7 years are used for training, the last 18 months for testing, and the period in between for validation.
- Epochs: 3.

Backtest:
- Converts predicted scores into weekly Top-K long-only weights.
- Runs `run_backtest` from `src.backtester.engine` and shows the existing Bokeh dashboard.


In [None]:
from __future__ import annotations

from pathlib import Path
import sys

import numpy as np
import pandas as pd

import torch

from bokeh.io import output_notebook, show

from sklearn.feature_selection import mutual_info_classif


In [None]:
# Resolve project root robustly: take the first directory, starting at the
# current working directory and walking up through its parents, that contains
# both a `dataset/` and a `src/` entry.
CWD = Path.cwd().resolve()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (CWD, *CWD.parents)
        if (candidate / 'dataset').exists() and (candidate / 'src').exists()
    ),
    None,
)
if PROJECT_ROOT is None:
    raise RuntimeError(f'Could not locate project root from CWD={CWD}')

# Put the root at the front of sys.path (once) so the `src.*` imports used in
# later cells resolve against this checkout.
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

print('PROJECT_ROOT:', PROJECT_ROOT)


In [None]:
# Seed the global NumPy and torch RNGs so that torch-driven randomness later in
# the notebook (e.g. DataLoader shuffling) is repeatable across
# "Restart Kernel -> Run All". The value matches the `random_state` already
# used for feature selection.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('torch:', torch.__version__)
print('device:', device)


In [None]:
# Load feature-extracted dataset (parquet is preferred, CSV is the fallback)
_features_dir = PROJECT_ROOT / 'dataset' / 'features'
FEATURES_PARQUET_PATH = _features_dir / 'all_features.parquet'
FEATURES_CSV_PATH = _features_dir / 'all_features.csv'

if not FEATURES_PARQUET_PATH.exists() and not FEATURES_CSV_PATH.exists():
    raise FileNotFoundError('Feature dataset not found under dataset/features/.')

if FEATURES_PARQUET_PATH.exists():
    df = pd.read_parquet(FEATURES_PARQUET_PATH)
    # The parquet file may store Date as a regular column; if so, promote it
    # to a datetime index. Otherwise the frame is used as loaded.
    if 'Date' in df.columns:
        df = df.assign(Date=pd.to_datetime(df['Date'])).set_index('Date')
else:
    df = pd.read_csv(FEATURES_CSV_PATH, parse_dates=['Date']).set_index('Date')

# Both columns are required downstream: Asset_ID for grouping/pivoting and
# ret_1d for building the forward-return label.
for _required_col in ('Asset_ID', 'ret_1d'):
    if _required_col not in df.columns:
        raise ValueError(f'Expected {_required_col} column in feature dataset')

# Chronological order is required before the per-asset shift that builds labels.
df = df.sort_index()

# Forward labels: for every row, the same asset's ret_1d from its next row.
TARGET_FWD_COL = 'y_ret_1d_fwd'
df[TARGET_FWD_COL] = df.groupby('Asset_ID', sort=False)['ret_1d'].shift(-1)
# Rows with no label (e.g. the last row of each asset) cannot be used.
df = df.dropna(subset=[TARGET_FWD_COL])

if df.empty:
    # Without this guard the split below would silently operate on NaT bounds.
    raise ValueError('No labelled rows available after building forward returns')

# The forward return above is used as the regression target r_t.

TRAIN_YEARS = 7
TEST_MONTHS = 18

start = pd.Timestamp(df.index.min())
end = pd.Timestamp(df.index.max())
train_end = start + pd.DateOffset(years=TRAIN_YEARS)
test_start = end - pd.DateOffset(months=TEST_MONTHS)

if train_end >= test_start:
    raise ValueError('Not enough history for requested split')

# Half-open windows: train = [start, train_end), val = [train_end, test_start),
# test = [test_start, end].
train_mask = df.index < train_end
val_mask = (df.index >= train_end) & (df.index < test_start)
test_mask = df.index >= test_start

df_train = df.loc[train_mask].copy()
df_val = df.loc[val_mask].copy()
df_test = df.loc[test_mask].copy()


def _print_split(label: str, frame: pd.DataFrame) -> None:
    """Print the date span and row count of one split.

    An empty split (possible when no rows fall inside its window) is reported
    explicitly instead of calling `.date()` on NaT, which would raise.
    """
    if frame.empty:
        print(label, 'n/a (no rows)', 'rows:', 0)
        return
    print(label, frame.index.min().date(), '->', frame.index.max().date(), 'rows:', frame.shape[0])


print('date range:', start.date(), '->', end.date())
_print_split('train:', df_train)
_print_split('val  :', df_val)
_print_split('test :', df_test)
print('assets:', df['Asset_ID'].nunique())

# Numeric feature columns: every column except the asset identifier and the
# label is a candidate feature; only numeric-dtype candidates are kept for
# modelling. Column order of `df` is preserved in both lists.
exclude = {'Asset_ID', TARGET_FWD_COL}
feature_cols = []
numeric_feature_cols = []
for _col in df.columns:
    if _col in exclude:
        continue
    feature_cols.append(_col)
    if pd.api.types.is_numeric_dtype(df[_col]):
        numeric_feature_cols.append(_col)
print('n_numeric_features:', len(numeric_feature_cols))


In [None]:
# Build dense tensors X[dates, assets, features] and r[dates, assets]

N_FEATURES = 32
LOOKBACK = 20  # per PDF window size

# Feature selection (mutual information) on training rows only.
# +/-inf is treated as missing before imputation.
X_fs = df_train[numeric_feature_cols].replace([np.inf, -np.inf], np.nan)
y_fs = df_train[TARGET_FWD_COL].astype(float)

# Median-impute per column. A column with no finite training value has a NaN
# median, which would leave NaN in the matrix and make mutual_info_classif
# fail; fall back to 0.0 so such a column becomes constant.
med = X_fs.median(axis=0).fillna(0.0)
X_imp = X_fs.fillna(med)

# The MI target is the sign of the forward return (1 if > 0, else 0).
mi = mutual_info_classif(X_imp.to_numpy(), (y_fs.to_numpy() > 0).astype(int), random_state=42)
mi_s = pd.Series(mi, index=numeric_feature_cols).sort_values(ascending=False)
# Keep the N_FEATURES highest-scoring columns (or all of them if fewer exist).
sel_features = mi_s.head(min(N_FEATURES, len(mi_s))).index.tolist()
print('selected features:', sel_features)

# Pivot per feature -> [T,N]
dates = pd.Index(sorted(df.index.unique()))
assets = pd.Index(sorted(df['Asset_ID'].unique()))

feat_arrays = []
for f in sel_features:
    # One [T,N] matrix per feature; (date, asset) pairs with no row become NaN.
    m = df.pivot_table(index=df.index, columns='Asset_ID', values=f, aggfunc='mean').reindex(index=dates, columns=assets)
    feat_arrays.append(m.to_numpy(dtype=np.float32))

X_all = np.stack(feat_arrays, axis=-1)  # [T,N,F]
# Treat +/-inf as missing, mirroring the sanitisation applied before feature
# selection. Otherwise a single inf makes the train mean/std non-finite and
# turns the whole standardised feature into NaN.
X_all[~np.isfinite(X_all)] = np.nan

# Forward-return labels in the same [T,N] layout. Missing pairs stay NaN.
r_all = (
    df.pivot_table(index=df.index, columns='Asset_ID', values=TARGET_FWD_COL, aggfunc='mean')
    .reindex(index=dates, columns=assets)
    .to_numpy(dtype=np.float32)
)

# Train normalization stats (computed on training dates only, applied to all dates)
train_dates_mask = dates < train_end
x_tr = X_all[train_dates_mask]
med_f = np.nanmedian(x_tr, axis=(0, 1))
# A feature with no finite value in the training window has a NaN median;
# use 0.0 so the imputation below still removes every NaN.
med_f = np.where(np.isnan(med_f), 0.0, med_f)
X_all = np.where(np.isnan(X_all), med_f, X_all)
mean_f = X_all[train_dates_mask].mean(axis=(0, 1))
std_f = X_all[train_dates_mask].std(axis=(0, 1), ddof=0)
# Constant features would divide by zero; leave them centred but unscaled.
std_f = np.where(std_f == 0.0, 1.0, std_f)
X_all = (X_all - mean_f) / std_f

# Index lists (prediction dates): positions into `dates` for each split.
# A position qualifies only if a full LOOKBACK-long window ends at it.
_first_full_window = LOOKBACK - 1
_val_lo, _val_hi = df_val.index.min(), df_val.index.max()
_test_lo, _test_hi = df_test.index.min(), df_test.index.max()

train_idx, val_idx, test_idx = [], [], []
for _pos, _date in enumerate(dates):
    if _pos < _first_full_window:
        continue
    if _date < train_end:
        train_idx.append(_pos)
    if _val_lo <= _date <= _val_hi:
        val_idx.append(_pos)
    if _test_lo <= _date <= _test_hi:
        test_idx.append(_pos)

print('n_train_dates:', len(train_idx), 'n_val_dates:', len(val_idx), 'n_test_dates:', len(test_idx))


In [None]:
# Prior relationship D: long-term correlation on training window using close prices
from src.backtester.data import load_cleaned_assets, align_close_prices

assets_ohlcv = load_cleaned_assets(symbols=assets.to_list(), cleaned_dir=str(PROJECT_ROOT / 'dataset' / 'cleaned'))
# NOTE(review): bfill() fills leading gaps with later prices; confirm this is
# the intended treatment for assets whose history starts late.
close_prices = align_close_prices(assets_ohlcv).sort_index().ffill().bfill()

# Use only dates strictly before train_end, matching the training split
# (`df.index < train_end`). Label slicing with `.loc[:train_end]` is inclusive
# and would let the first validation day leak into the prior.
cp_train = close_prices.loc[close_prices.index < train_end].copy()
ret = cp_train.pct_change().dropna(how='all')
# Drop any asset that still has a missing return in the training window.
ret = ret.dropna(axis=1, how='any')

# keep consistent ordering: assets_bt is the column order of the prior, and
# idx_assets maps each of those assets to its position on the asset axis of
# X_all / r_all.
assets_bt = ret.columns.to_list()
idx_assets = [assets.get_loc(a) for a in assets_bt]

# Undefined correlations (e.g. zero-variance series) are replaced with 0.
corr = ret.corr().fillna(0.0).to_numpy(dtype=np.float32)
prior_d = torch.tensor(corr, device=device)
print('prior_d:', prior_d.shape)


In [None]:
from src.models.finmamba import FinMamba, FinMambaConfig, finmamba_loss

# Model configuration; the lookback must match the window length used when
# slicing input sequences out of X_all.
cfg = FinMambaConfig(
    lookback=LOOKBACK,
    n_levels=2,
    hidden_dim=64,
    n_heads=4,
)

# Build the network for the selected feature count, move it to the chosen
# device, then hand it the training-window correlation prior.
model = FinMamba(feature_dim=len(sel_features), cfg=cfg)
model = model.to(device)
model.set_prior(prior_d)


In [None]:
# Training loop (epochs=3)

EPOCHS = 3  # number of full passes over the training dates
BATCH_DATES = 8  # prediction dates per mini-batch (used as the DataLoader batch_size)
LR = 1e-3  # learning rate passed to the Adam optimizer

# Dataset/DataLoader are needed by the DateDataset class and loaders defined below.
from torch.utils.data import Dataset, DataLoader

class DateDataset(Dataset):
    """Dataset with one sample per prediction date.

    Each item pairs the LOOKBACK-long feature window ending at a date position
    with the target row for that position, both restricted to `idx_assets`.
    Reads the module-level `X_all`, `r_all`, `idx_assets` and `LOOKBACK`.

    NOTE(review): targets are taken from `r_all` as-is; if it contains NaN for
    some (date, asset) pairs they are passed through unchanged. Confirm the
    loss handles that.
    """

    def __init__(self, idxs: list[int]):
        # Positions along the first axis of X_all / r_all to serve as samples.
        self.idxs = idxs

    def __len__(self) -> int:
        return len(self.idxs)

    def __getitem__(self, k: int):
        end_pos = self.idxs[k]
        start_pos = end_pos - LOOKBACK + 1
        # Window of LOOKBACK consecutive dates ending at (and including) end_pos.
        window = X_all[start_pos : end_pos + 1]
        features = torch.tensor(window[:, idx_assets, :], dtype=torch.float32)  # [L,N,F]
        targets = torch.tensor(r_all[end_pos][idx_assets], dtype=torch.float32)  # [N]
        return features, targets

# Training batches are shuffled across dates; validation keeps chronological order.
train_loader = DataLoader(DateDataset(train_idx), batch_size=BATCH_DATES, shuffle=True)
val_loader = DataLoader(DateDataset(val_idx), batch_size=BATCH_DATES, shuffle=False)

opt = torch.optim.Adam(model.parameters(), lr=LR)

# The per-epoch validation pass is a quick estimate over the first few batches only.
MAX_VAL_BATCHES = 10

for epoch in range(1, EPOCHS + 1):
    print(f'epoch {epoch}/{EPOCHS} (train)')
    train_losses = []
    n_skipped = 0
    model.train()
    for x_seq, r in train_loader:
        x_seq = x_seq.to(device)  # [B,L,N,F]
        r = r.to(device)
        opt.zero_grad(set_to_none=True)
        y_pred, z_seq, s_seq = model(x_seq)
        loss = finmamba_loss(y_pred=y_pred, r_true=r, z_seq=z_seq, s_seq=s_seq, cfg=cfg)
        # A NaN/inf loss (e.g. from NaN targets) would push non-finite gradients
        # into the weights and poison every later step, so skip the update.
        if not torch.isfinite(loss).all():
            n_skipped += 1
            continue
        loss.backward()
        opt.step()
        train_losses.append(float(loss.item()))

    if train_losses:
        print('  train_loss', float(np.mean(train_losses)))
    if n_skipped:
        print(f'  skipped {n_skipped} train batches with non-finite loss')

    # quick val loss
    model.eval()
    losses = []
    with torch.no_grad():
        for b, (x_seq, r) in enumerate(val_loader):
            if b >= MAX_VAL_BATCHES:
                break
            x_seq = x_seq.to(device)
            r = r.to(device)
            y_pred, z_seq, s_seq = model(x_seq)
            losses.append(float(finmamba_loss(y_pred=y_pred, r_true=r, z_seq=z_seq, s_seq=s_seq, cfg=cfg).item()))
    if losses:
        print('  val_loss', float(np.mean(losses)))
    else:
        print('  val_loss n/a (no validation batches)')


In [None]:
# Predict scores on test and run backtest + Bokeh

from src.backtester.engine import BacktestConfig, run_backtest
from src.backtester.report import compute_backtest_report
from src.backtester.bokeh_plots import build_interactive_portfolio_layout
from src.backtester.portfolio import equal_weight

model.eval()

# Per-date pieces of a long (Date, Asset_ID, score) table; the asset labels
# are identical for every date, so they are built once.
_n_bt_assets = len(assets_bt)
_asset_labels = np.array(assets_bt, dtype=object)
all_dates, all_assets, all_scores = [], [], []

with torch.no_grad():
    for i in test_idx:
        # Lookback window ending at position i, as a batch of one: [1,L,N,F]
        window = X_all[i - LOOKBACK + 1 : i + 1][:, idx_assets, :]
        x_seq = torch.tensor(window, dtype=torch.float32).unsqueeze(0).to(device)
        y_pred, _z, _s = model(x_seq)
        all_scores.append(y_pred.squeeze(0).cpu().numpy())  # [N]

        dt = pd.Timestamp(dates[i])
        all_dates.append(np.repeat(np.datetime64(dt.to_datetime64()), _n_bt_assets))
        all_assets.append(_asset_labels)

dates_test = np.concatenate(all_dates)
assets_test = np.concatenate(all_assets)
scores_test = np.concatenate(all_scores)

# Long table -> wide [date x asset] score matrix, ordered by date.
long = pd.DataFrame({'Date': dates_test, 'Asset_ID': assets_test, 'score': scores_test})
pred_matrix = (
    long
    .pivot_table(index='Date', columns='Asset_ID', values='score', aggfunc='mean')
    .sort_index()
)

# Slice backtest window to test period
bt_start = pd.Timestamp(df_test.index.min())
bt_end = pd.Timestamp(df_test.index.max())
close_bt = close_prices.loc[bt_start:bt_end, assets_bt]

# Synthetic "market" OHLCV for the dashboard: cross-asset mean of each price
# field and cross-asset sum of volume, restricted to the backtest window.
_market_cols = {}
for _field in ('Open', 'High', 'Low', 'Close', 'Volume'):
    _wide = pd.concat([ohlcv[_field] for ohlcv in assets_ohlcv.values()], axis=1)
    _combined = _wide.sum(axis=1) if _field == 'Volume' else _wide.mean(axis=1)
    _market_cols[_field] = _combined.loc[bt_start:bt_end]
market_df = pd.DataFrame(_market_cols).sort_index()

# Align scores to the backtest calendar; dates without predictions become NaN rows.
pred_matrix = pred_matrix.reindex(close_bt.index)

# Weekly Top-K long-only weights
REBALANCE_FREQ = 'W'
TOP_K = 20

# Rebalance on the last available date of each resampling period.
rebal_dates = set(pd.Series(pred_matrix.index, index=pred_matrix.index).resample(REBALANCE_FREQ).last().dropna().tolist())

# NOTE(review): weights for date dt come from scores predicted for that same
# date; confirm run_backtest applies them to subsequent returns.
w_last = pd.Series(0.0, index=assets_bt)
w_rows = []
for dt in pred_matrix.index:
    if dt in rebal_dates:
        row = pred_matrix.loc[dt]
        # Drop NaN and +/-inf BEFORE ranking, so a non-finite score cannot take
        # a Top-K slot and then be discarded, leaving fewer than K holdings.
        row = row[np.isfinite(row.to_numpy(dtype=float))].sort_values(ascending=False)
        candidates = row.head(min(TOP_K, len(row))).index.tolist()
        # Always start from a fresh all-zero vector; earlier rows in w_rows
        # keep referencing the previous Series object.
        w_last = pd.Series(0.0, index=assets_bt)
        if candidates:
            for a, w in equal_weight(candidates).items():
                if a in w_last.index:
                    w_last[a] = float(w)
    # Between rebalances the most recent weights are carried forward.
    w_rows.append(w_last)

weights = pd.DataFrame(w_rows, index=pred_matrix.index, columns=assets_bt).fillna(0.0)

# Route Bokeh output to the notebook before anything is shown.
output_notebook()

# Run the backtest over the test window with the weights built above.
bt_cfg = BacktestConfig(
    initial_equity=1_000_000.0,
    transaction_cost_bps=5.0,
    mode='vectorized',
)
res = run_backtest(close_bt, weights, config=bt_cfg)

# Summary report, shown as a one-column table.
report = compute_backtest_report(result=res, close_prices=close_bt)
report_frame = report.to_frame('FinMamba Report')
display(report_frame)

# Interactive dashboard built from the backtest result series.
dashboard_kwargs = {
    'market_ohlcv': market_df,
    'equity': res.equity,
    'returns': res.returns,
    'weights': res.weights,
    'turnover': res.turnover,
    'costs': res.costs,
    'close_prices': close_bt,
    'title': 'FinMamba (PDF-inspired) - Backtest',
}
layout = build_interactive_portfolio_layout(**dashboard_kwargs)
show(layout)
