In [None]:
# import everything
import os
import pickle
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm
import mysql.connector
from datetime import timedelta
import matplotlib.pyplot as plt
from collections import defaultdict
from bisect import bisect_left, bisect_right


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv

from dotenv import load_dotenv

# the path to the .env file
# it can be overridden with the DOTENV_PATH environment variable so the
# notebook also runs on machines that do not have the original directory;
# when the variable is not set the original location is used
dotenv_path = os.environ.get(
    "DOTENV_PATH",
    "/Users/tristanbrigham/Desktop/Classes/CPSC 6440/FinalProject/Submission/.env",
)
load_dotenv(dotenv_path)


In [None]:
# load in the data
# NOTE(review): pickle.load runs whatever code the file contains, so only
# open a snapshot that you produced yourself
with open("training_gnn_datasets_snapshot.pkl", "rb") as f:
    snapshot = pickle.load(f)

# unpack the tables stored in the snapshot dictionary
snapshot_keys = (
    "df_trades",
    "df_mo",
    "clob_outcome_dict",
    "df_markets",
    "df_market_events",
    "df_event_series",
)
(
    df_trades,
    df_mo,
    clob_outcome_dict,
    df_markets,
    df_market_events,
    df_event_series,
) = (snapshot[key] for key in snapshot_keys)

# quick look at what was loaded
print("loaded trades shape:", df_trades.shape)
print(df_trades.head())
print("number of clob tokens in dict:", len(clob_outcome_dict))
print(
    "mapping tables shapes:",
    "markets:", df_markets.shape,
    "market_events:", df_market_events.shape,
    "event_series:", df_event_series.shape,
)


In [None]:
# chronological train / val / test split of the trades

# trade_time is already datetime, so sorting on it orders the trades in time;
# the timestamp column is coerced to a number and rows where that fails are dropped
df_trades = (
    df_trades
    .sort_values("trade_time")
    .reset_index(drop=True)
    .assign(timestamp=lambda d: pd.to_numeric(d["timestamp"], errors="coerce"))
    .dropna(subset=["timestamp"])
)

# split positions: first 80% train, next 10% val, last 10% test
n = len(df_trades)
idx_train_end = int(0.8 * n)
idx_val_end = int(0.9 * n)

# get the train, val, and test sets (each one an independent copy)
df_train, df_val, df_test = (
    df_trades.iloc[lo:hi].copy()
    for lo, hi in (
        (0, idx_train_end),
        (idx_train_end, idx_val_end),
        (idx_val_end, n),
    )
)

for label, part in (("train rows:", df_train), ("val rows:", df_val), ("test rows:", df_test)):
    print(label, len(part))

# training period length in years (365.25-day year)
t_min_train = df_train["trade_time"].min()
t_max_train = df_train["trade_time"].max()
train_period_seconds = float((t_max_train - t_min_train).total_seconds())
SECONDS_PER_YEAR = 365.25 * 24 * 60 * 60
train_period_years = train_period_seconds / SECONDS_PER_YEAR

print("train period years:", train_period_years)

# minimum number of trades in the training set required for a wallet to be included
# so that we don't get a bunch of rando wallets
MIN_TRAIN_TRADES = 15

# number of (non-null) trade ids per wallet inside the training window
train_trade_counts = df_train.groupby("proxy_wallet")["id"].count()
eligible_wallets = set(
    train_trade_counts.loc[lambda counts: counts >= MIN_TRAIN_TRADES].index
)

print("wallets with at least", MIN_TRAIN_TRADES, "trades in training:", len(eligible_wallets))


In [None]:
# enrich the data that we have: attach market_id, event_id and series_id to every trade

# drop any copies of the enrichment columns left over from a previous run of
# this cell; without this a re-run creates market_id_x / market_id_y and the
# later merge on "market_id" fails
enrich_cols = ["market_id", "event_id", "series_id"]
df_trades = df_trades.drop(columns=enrich_cols, errors="ignore")

# merge condition_id into the market_id
# exact duplicate mapping rows are removed because a left merge would
# otherwise duplicate the matching trades
df_markets_trim = (
    df_markets[["id", "conditionId"]]
    .dropna(subset=["conditionId"])
    .rename(columns={"id": "market_id"})
    .drop_duplicates()
)

# merge the dataframes (left merge keeps every trade, matched or not)
df_trades = df_trades.merge(
    df_markets_trim,
    how="left",
    left_on="condition_id",
    right_on="conditionId"
)

# make sure that we don't have any duplicate columns
df_trades = df_trades.drop(columns=["conditionId"])

# merge market_id into the event_id
if not df_market_events.empty:
    df_me_trim = (
        df_market_events[["market_id", "event_id"]]
        .dropna(subset=["market_id"])
        .drop_duplicates()
    )
    df_trades = df_trades.merge(
        df_me_trim,
        how="left",
        on="market_id"
    )
else:
    # no mapping available: keep the column so later cells can rely on it
    df_trades["event_id"] = np.nan

# merge event_id into the series_id
if not df_event_series.empty:
    df_es_trim = (
        df_event_series[["event_id", "series_id"]]
        .dropna(subset=["event_id"])
        .drop_duplicates()
    )
    df_trades = df_trades.merge(
        df_es_trim,
        how="left",
        on="event_id"
    )
else:
    df_trades["series_id"] = np.nan


# print some base statistics
print("df_trades with market/event/series columns:")
print(df_trades[["proxy_wallet", "condition_id", "market_id", "event_id", "series_id"]].head())


In [None]:
# get the prices of the assets in question over time

# extract needed columns
df_prices = df_trades[["trade_time", "asset", "price"]].copy()
df_prices["price"] = df_prices["price"].astype(float)

# make sure timestamps are sorted
# a stable sort is used so that trades sharing a timestamp keep their
# df_trades order; with the default (unstable) sort the "last trade at a
# timestamp" picked below would be arbitrary for ties
df_prices = df_prices.sort_values("trade_time", kind="mergesort").reset_index(drop=True)

# all unique timestamps across all trades
# (df_prices is already time sorted, so unique() returns them in order)
unique_times = pd.Index(df_prices["trade_time"].unique())

# for each token, get the price at each of its own trade timestamps
token_price_paths = {}

for token_id, g in tqdm(
    df_prices.groupby("asset", sort=False),
    total=df_prices["asset"].nunique(),
    desc="building sparse token price paths"
):
    # groupby keeps the row order inside each group, so g is already in
    # time order and does not need to be sorted again

    # collapse duplicate timestamps per token by taking the last trade at each timestamp
    g_token = (
        g[["trade_time", "price"]]
        .groupby("trade_time", as_index=True)["price"]
        .last()
        .to_frame()
    )

    token_price_paths[token_id] = g_token

print("number of tokens with sparse price series:", len(token_price_paths))

# see the prices of one example token
if token_price_paths:
    example_token = next(iter(token_price_paths.keys()))
    print("example token:", example_token)
    print(token_price_paths[example_token].head(20))


In [None]:

# find the token whose price changed the most times and plot its path
token_price_change_counts = {}

# count the number of price changes for each token
for token_id, df_token in token_price_paths.items():
    prices = df_token["price"].values
    # a single observation cannot show a change, so such tokens are left out
    if len(prices) >= 2:
        # steps whose absolute move exceeds 1e-6 count as a change
        num_changes = np.sum(np.abs(np.diff(prices)) > 1e-6)
        token_price_change_counts[token_id] = num_changes

if len(token_price_change_counts) == 0:
    raise ValueError("no tokens with nontrivial price series found.")

# get the one that has changed price the most times overall
most_active_token = max(token_price_change_counts.items(), key=lambda kv: kv[1])[0]
print("token with most price changes along the way:", most_active_token)
print("number of price changes:", token_price_change_counts[most_active_token])

df_most_active = token_price_paths[most_active_token]

# variance is reported for information only; it is not what selected the token
var_most_active = np.var(df_most_active["price"].values)
print("variance of this token's price path:", var_most_active)


# get the price plotted over time
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(df_most_active.index, df_most_active["price"], marker='.', lw=1)
ax.set_xlabel("time")
ax.set_ylabel("price")
ax.set_title(f"price trajectory with most changes for token {most_active_token}")
fig.tight_layout()
plt.show()


In [None]:
# restrict to eligible wallets only, keep the columns the replay needs,
# make size and price floats and order the trades by time
wallet_cols = ["trade_time", "proxy_wallet", "asset", "side", "size", "price"]
df_trades_wallet = (
    df_trades.loc[df_trades["proxy_wallet"].isin(eligible_wallets), wallet_cols]
    .copy()
    .astype({"size": float, "price": float})
    .sort_values("trade_time")
    .reset_index(drop=True)
)

# signed effect of every trade: a BUY adds shares and spends cash,
# anything else removes shares and brings cash in
is_buy = df_trades_wallet["side"] == "BUY"
notional = df_trades_wallet["size"] * df_trades_wallet["price"]

# compute share changes
df_trades_wallet["share_change"] = np.where(
    is_buy, df_trades_wallet["size"], -df_trades_wallet["size"]
)

# compute cash changes
df_trades_wallet["cash_change"] = np.where(is_buy, -notional, notional)

# initial shares per (wallet, asset): the smallest amount that keeps the
# running share balance from ever going negative
wallet_initial_shares = defaultdict(dict)
groupby_wallet_asset = df_trades_wallet.groupby(["proxy_wallet", "asset"])

for (wallet, asset), grp in tqdm(groupby_wallet_asset, desc="wallet-asset pairs"):
    lowest_shares = np.cumsum(grp["share_change"].values).min()
    wallet_initial_shares[wallet][asset] = max(0.0, -float(lowest_shares))

# initial cash per wallet: same idea applied to the running cash balance
wallet_initial_cash = {}
groupby_wallet = df_trades_wallet.groupby("proxy_wallet")
for wallet, grp in tqdm(groupby_wallet, desc="wallets"):
    lowest_cash = np.cumsum(grp["cash_change"].values).min()
    wallet_initial_cash[wallet] = max(0.0, -float(lowest_cash))

print("wallets with inferred initial holdings:", len(wallet_initial_cash))


In [None]:
# build a downsampled global time grid
all_times = unique_times

# downsampling to make things faster to run: at most 400 grid points,
# evenly spaced by position over the unique trade times
num_grid = min(400, len(all_times))

grid_indices = np.linspace(0, len(all_times) - 1, num_grid).astype(int)
grid_times = all_times[grid_indices]

print("number of grid times:", len(grid_times))
print("grid start:", grid_times[0], "grid end:", grid_times[-1])

# give every token a row number (tokens in sorted order) for the price matrix
token_list = sorted(token_price_paths.keys())
num_tokens = len(token_list)
token_to_idx = dict(zip(token_list, range(num_tokens)))

def prices_on_grid(df_token, grid_times):
    """Sample one token's price path at the given grid times.

    df_token is indexed by the token's own trade times and has a "price"
    column. For each grid time the price of the last trade at or before
    that time is used; grid times earlier than the first trade get the
    first trade's price. Returns a float32 array with one entry per grid
    time (all zeros when the token has no trades).
    """
    trade_times = df_token.index.to_numpy()
    trade_prices = df_token["price"].to_numpy()
    n_obs = trade_times.size

    if n_obs == 0:
        # no trades for this token
        return np.zeros(len(grid_times), dtype=np.float32)

    # position of the last trade at or before each grid time (-1 if none yet)
    last_seen = np.searchsorted(trade_times, grid_times, side="right") - 1

    # grid times before the first trade fall back to the first trade
    last_seen = np.clip(last_seen, 0, n_obs - 1)

    return trade_prices[last_seen].astype(np.float32)

# the price grid is a matrix of the prices of the tokens at the grid times
# (one row per token in token_list order, one column per grid time)
price_grid = np.zeros((num_tokens, num_grid), dtype=np.float32)
for i_tok, tok in enumerate(tqdm(token_list, desc="building price grid")):
    price_grid[i_tok, :] = prices_on_grid(token_price_paths[tok], grid_times)

# wallet list and index mapping (both directions)
wallet_list = sorted(df_trades_wallet["proxy_wallet"].unique())
num_wallets = len(wallet_list)
idx_to_wallet = dict(enumerate(wallet_list))
wallet_to_idx = {w: i for i, w in idx_to_wallet.items()}

print("number of wallets with trades (eligible):", num_wallets)

# portfolio value matrix: rows are grid times, columns are wallets
V = np.zeros((num_grid, num_wallets), dtype=np.float64)

# prepare numpy arrays for trades to avoid pandas overhead in inner loop
dfw = df_trades_wallet.sort_values("trade_time").reset_index(drop=True)
n_trades = len(dfw)
time_arr, wallet_arr, asset_arr, side_arr = (
    dfw[col].to_numpy() for col in ("trade_time", "proxy_wallet", "asset", "side")
)
size_arr, price_arr = (
    dfw[col].to_numpy(dtype=float) for col in ("size", "price")
)

# initialize per-wallet state that the replay continues from:
# inferred starting cash and a private copy of the inferred starting shares
wallet_cash = {}
wallet_pos = {}

for w in tqdm(wallet_list, desc="init wallet state"):
    starting_shares = wallet_initial_shares.get(w, {})
    wallet_cash[w] = float(wallet_initial_cash.get(w, 0.0))
    wallet_pos[w] = dict(starting_shares)

# get the portfolio value for the wallet at a given point in time
def compute_wallet_value(wallet, k):
    """Value of a wallet at grid step k: its cash plus shares times grid price.

    Reads the module-level wallet_cash, wallet_pos, token_to_idx and
    price_grid. Zero positions and assets without a row in the price grid
    contribute nothing.
    """
    total = wallet_cash[wallet]

    # an empty position dict simply leaves the cash as the total
    for held_asset, n_shares in wallet_pos[wallet].items():
        if n_shares == 0.0:
            continue
        row = token_to_idx.get(held_asset)
        if row is not None:
            total += n_shares * price_grid[row, k]
    return total

# loop over the grid times, replaying trades and valuing the wallets
trade_idx = 0
active_wallets = set()   # wallets that traded inside the current grid interval
started_wallets = set()  # wallets that have traded at least once so far

for k, t in enumerate(tqdm(grid_times, desc="grid time steps")):
    # wallets that have not traded yet keep their previous (zero) value
    if k > 0:
        V[k, :] = V[k - 1, :]

    # process trades up to and including time t
    while trade_idx < n_trades and time_arr[trade_idx] <= t:
        w = wallet_arr[trade_idx]
        a = asset_arr[trade_idx]
        s = side_arr[trade_idx]
        q = size_arr[trade_idx]
        p = price_arr[trade_idx]

        if w not in wallet_cash:
            wallet_cash[w] = 0.0
            wallet_pos[w] = {}

        if a not in wallet_pos[w]:
            wallet_pos[w][a] = 0.0

        if s == "BUY":
            wallet_pos[w][a] += q
            wallet_cash[w] -= q * p
        else:
            wallet_pos[w][a] -= q
            wallet_cash[w] += q * p

        # snap tiny negative balances caused by float rounding back to zero
        if wallet_pos[w][a] < 0 and wallet_pos[w][a] > -1e-9:
            wallet_pos[w][a] = 0.0

        if wallet_cash[w] < 0 and wallet_cash[w] > -1e-9:
            wallet_cash[w] = 0.0

        active_wallets.add(w)
        trade_idx += 1

    started_wallets.update(active_wallets)
    active_wallets.clear()

    # re-value every wallet that has started trading, not only the ones that
    # traded in this interval: token prices move between a wallet's trades,
    # so a carried-forward value would be stale and the whole price move
    # would show up as one jump at the wallet's next trade
    for w in started_wallets:
        j = wallet_to_idx.get(w)
        if j is None:
            continue
        V[k, j] = compute_wallet_value(w, k)

print("portfolio value matrix shape:", V.shape)


num_times, num_wallets = V.shape

# normalize the wallets: each path is divided by its first strictly
# positive value; wallets that are never positive are left unchanged
V_unit = np.array(V, dtype=np.float64)

for j in range(num_wallets):
    path = V[:, j]
    positive_steps = np.flatnonzero(path > 0)
    if positive_steps.size and path[positive_steps[0]] > 0:
        V_unit[:, j] = path / path[positive_steps[0]]

# get the percentage changes between consecutive grid times;
# steps whose previous value is not positive get a return of 0
R_wallet = np.zeros((num_times, num_wallets), dtype=np.float64)
V_prev, V_curr = V_unit[:-1, :], V_unit[1:, :]

mask = V_prev > 0
delta = V_curr - V_prev
R_step = np.zeros_like(V_prev)
np.divide(delta, V_prev, out=R_step, where=mask)

R_wallet[1:, :] = R_step

# get the historical sharpe (mean over std of the per-step returns,
# 0 when the std is 0); row 0 is skipped because it has no previous step
r = R_wallet[1:, :]
mean_r = r.mean(axis=0)
std_r = r.std(axis=0)
with np.errstate(divide="ignore", invalid="ignore"):
    sharpe_vec = np.where(std_r > 0, mean_r / std_r, 0.0)

wallet_sharpe = {w: float(sharpe_vec[j]) for j, w in enumerate(wallet_list)}

print("computed historical sharpe for", len(wallet_sharpe), "wallets")
print("example wallet sharpe (first 10):")
print(list(wallet_sharpe.items())[:10])



In [None]:
# guard: everything below iterates over wallet_list, so stop early with a
# clear message when it is empty
if not wallet_list:
    raise ValueError("wallet_list is empty; run the value computation cell first.")

# get wallets that we are interested in
def select_wallets_with_sell_and_final_cash(df_trades_wallet, wallet_list, N=9, seed=None):
    """Pick up to N random wallets that both bought and sold and end with cash > 0.

    Parameters
    ----------
    df_trades_wallet : DataFrame with the columns trade_time, proxy_wallet,
        side, size and price; rows are replayed in the order they appear.
    wallet_list : candidate wallets; they are visited in random order.
    N : maximum number of wallets to return.
    seed : optional seed for a reproducible candidate order. When None the
        global numpy random state is used, as before.

    Returns
    -------
    list of the selected wallets (may be shorter than N; a message is
    printed in that case).

    Reads the module-level wallet_initial_cash and grid_times. Only trades
    up to and including the last grid time are replayed.
    """
    if seed is None:
        candidates = np.random.permutation(wallet_list)
    else:
        candidates = np.random.default_rng(seed).permutation(wallet_list)

    # trades after the last grid time are never replayed; with an empty grid
    # no trade is replayed at all
    last_grid_time = grid_times[-1] if len(grid_times) > 0 else None

    # the ret list
    found_wallets = []

    # iterate over the candidates
    for w in candidates:
        if len(found_wallets) >= N:
            break

        # get the trades for the wallet
        df_w = df_trades_wallet[df_trades_wallet["proxy_wallet"] == w]
        sides = set(df_w["side"]) if not df_w.empty else set()
        if not ("SELL" in sides and "BUY" in sides):
            continue

        # replay the cash balance to determine the final cash; positions
        # do not influence the selection, so they are not tracked here
        cash = float(wallet_initial_cash.get(w, 0.0))

        if last_grid_time is not None:
            time_arr_w = df_w["trade_time"].to_numpy()
            side_arr_w = df_w["side"].to_numpy()
            size_arr_w = df_w["size"].to_numpy(dtype=float)
            price_arr_w = df_w["price"].to_numpy(dtype=float)

            for i in range(len(df_w)):
                # stop at the first trade that lies beyond the grid
                if not (time_arr_w[i] <= last_grid_time):
                    break

                q = size_arr_w[i]
                p = price_arr_w[i]
                if side_arr_w[i] == "BUY":
                    cash -= q * p
                else:
                    cash += q * p

                # snap tiny negative balances caused by float rounding to zero
                if cash < 0 and cash > -1e-9:
                    cash = 0.0

        # check final cash
        if cash > 0:
            found_wallets.append(w)

    if len(found_wallets) < N:
        print(f"Only found {len(found_wallets)} wallets with both BUY and SELL and final cash > 0.")
    return found_wallets

# pick the example wallets and draw one panel per wallet on a 3x3 grid
wallets_selected = select_wallets_with_sell_and_final_cash(df_trades_wallet, wallet_list, N=9)
print(f"Selected {len(wallets_selected)} wallets for visualization:")

fig, axes = plt.subplots(3, 3, figsize=(16, 14))
axes = axes.flat if hasattr(axes, "flat") else axes.reshape(-1)

for i, wallet_selected in enumerate(wallets_selected):
    ax = axes[i]
    df_w = df_trades_wallet[df_trades_wallet["proxy_wallet"] == wallet_selected].copy()

    if "trade_time" not in df_w.columns:
        raise KeyError("Column 'trade_time' does not exist in df_trades_wallet, cannot proceed.")
    df_w = df_w.sort_values("trade_time").reset_index(drop=True)

    if df_w.empty:
        print(f"no trades found for wallet {wallet_selected}, skipping.")
        continue

    # start from the inferred initial cash and a private copy of the initial shares
    cash = float(wallet_initial_cash.get(wallet_selected, 0.0))
    pos = dict(wallet_initial_shares.get(wallet_selected, {}))

    # get the data in the right format
    n_trades_w = len(df_w)
    time_arr_w, asset_arr_w, side_arr_w = (
        df_w[col].to_numpy() for col in ("trade_time", "asset", "side")
    )
    size_arr_w, price_arr_w = (
        df_w[col].to_numpy(dtype=float) for col in ("size", "price")
    )

    # one value per grid time for each of the three plotted curves
    cash_series = np.zeros(len(grid_times), dtype=np.float64)
    pos_value_series = np.zeros(len(grid_times), dtype=np.float64)
    total_series = np.zeros(len(grid_times), dtype=np.float64)

    trade_idx = 0

    # check each of the grid times
    for k, t in enumerate(grid_times):
        # apply every trade up to and including this grid time
        while trade_idx < n_trades_w and time_arr_w[trade_idx] <= t:
            a = asset_arr_w[trade_idx]
            s = side_arr_w[trade_idx]
            q = size_arr_w[trade_idx]
            p = price_arr_w[trade_idx]

            pos.setdefault(a, 0.0)

            if s == "BUY":
                pos[a] += q
                cash -= q * p
            else:
                pos[a] -= q
                cash += q * p

            # snap tiny negative balances caused by float rounding to zero
            if -1e-9 < pos[a] < 0:
                pos[a] = 0.0
            if -1e-9 < cash < 0:
                cash = 0.0

            trade_idx += 1

        # value the open positions at this grid time's prices
        pos_val = 0.0
        for a, q in pos.items():
            if q == 0.0:
                continue
            tok_idx = token_to_idx.get(a)
            if tok_idx is None:
                continue
            p_t = price_grid[tok_idx, k]
            pos_val += q * p_t

        val = cash + pos_val

        total_series[k] = val
        cash_series[k] = cash
        pos_value_series[k] = pos_val

    # show me the wallets
    for curve, curve_label in (
        (total_series, "total value"),
        (cash_series, "cash"),
        (pos_value_series, "positions value"),
    ):
        ax.plot(grid_times, curve, label=curve_label)
    ax.set_xlabel("time")
    ax.set_ylabel("value")
    ax.set_title(f"wallet {wallet_selected[:8]}...")
    ax.legend(fontsize="small")
    ax.grid(True, alpha=0.25)


# remove the panels that have no wallet
for i in range(len(wallets_selected), 9):
    fig.delaxes(axes[i])

fig.suptitle("Portfolio evolution for example wallets", fontsize=18)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


In [None]:
# choose how many chunks you want
NUM_CHUNKS = 6

# the number of times that we have data for
num_times = len(grid_times)
if NUM_CHUNKS < 4:
    raise ValueError("NUM_CHUNKS must be at least 4 (need train, val, test).")

# NUM_CHUNKS + 1 integer boundaries that split the grid indices into chunks
chunk_edges = np.linspace(0, num_times, NUM_CHUNKS + 1, dtype=int)


def get_chunk_indices(k):
    """Return the half-open grid index range (start, end) of chunk k (1-based)."""
    if not (1 <= k <= NUM_CHUNKS):
        raise ValueError(f"chunk index k must be in 1..{NUM_CHUNKS}")
    lo, hi = chunk_edges[k - 1 : k + 1]
    return int(lo), int(hi)


print("number of grid_times:", num_times)
for k in range(1, NUM_CHUNKS + 1):
    s, e = get_chunk_indices(k)
    print(f"chunk {k}: indices [{s}, {e}) length {e - s}")

# we will build features and graph using only data up to the end of chunk
# NUM_CHUNKS - 2; t_cut is the last grid time inside that chunk
cut_start, cut_end = get_chunk_indices(NUM_CHUNKS - 2)
cut_idx = cut_end - 1
t_cut = grid_times[cut_idx]
print(f"feature/graph cutoff t_cut (end of chunk {NUM_CHUNKS-2}):", t_cut)

# feature flags
# each entry switches one group of wallet feature columns on or off; the
# flags are read further down when the feature matrix is put together
feature_flags = {
    "use_return_mean_std": True,
    "use_num_series": True,
    "use_num_events": True,
    "use_num_markets": True,
    "use_log_trades_per_hour": True,
    "use_avg_trade_size": True,
    "use_avg_trade_price": True,
    "use_max_drawdown_pct": True,
    "use_series_onehot": True,
    "use_degree_in": True,
    "use_degree_out": True,
    "use_orc_curvature": False,
    "use_local_curvature": True
}

# onehot dimension for most traded series
# (upper bound: fewer columns are produced when fewer series exist)
TOP_SERIES_K = 20  

# per-step returns up to the cutoff (row 0 has no previous step and is skipped)
r_up_to_cut = R_wallet[1:cut_end, :]

ret_mean_full = r_up_to_cut.mean(axis=0).astype(np.float32)
ret_std_full = r_up_to_cut.std(axis=0).astype(np.float32)

# trades at or before the cutoff made by the wallets we model
df_trades_tcut = df_trades[
    (df_trades["trade_time"] <= t_cut) & df_trades["proxy_wallet"].isin(wallet_list)
].copy()
df_wallet_tcut = df_trades_wallet[
    (df_trades_wallet["trade_time"] <= t_cut) & df_trades_wallet["proxy_wallet"].isin(wallet_list)
].copy()


def reindex_to_wallets(series, default=0.0, dtype=np.float32):
    """Align a per-wallet Series to wallet_list order and return it as an array.

    Wallets missing from the series get `default`; the result is cast to `dtype`.
    """
    aligned = series.reindex(wallet_list)
    filled = aligned.fillna(default)
    return filled.astype(dtype).to_numpy()


# how many distinct series / events / markets each wallet touched
trades_by_wallet_tcut = df_trades_tcut.groupby("proxy_wallet")

num_series_per_wallet = reindex_to_wallets(trades_by_wallet_tcut["series_id"].nunique())

num_events_per_wallet = reindex_to_wallets(trades_by_wallet_tcut["event_id"].nunique())

num_markets_per_wallet = reindex_to_wallets(trades_by_wallet_tcut["market_id"].nunique())

# trading intensity: log(1 + trades per hour) between first and last trade
if df_wallet_tcut.empty:
    log_trades_per_hour = pd.Series([], dtype=float)
else:
    agg_time = df_wallet_tcut.groupby("proxy_wallet")["trade_time"].agg(["count", "min", "max"])
    # a zero-length horizon is replaced by one second to avoid dividing by zero
    horizon_hours = (
        ((agg_time["max"] - agg_time["min"]).dt.total_seconds() / 3600.0)
        .replace(0, 1.0 / 3600.0)
    )
    trades_per_hour = agg_time["count"] / horizon_hours
    log_trades_per_hour = np.log1p(trades_per_hour)

log_trades_per_wallet = reindex_to_wallets(log_trades_per_hour)

# average trade size and price per wallet
wallet_groups_tcut = df_wallet_tcut.groupby("proxy_wallet")

avg_size_per_wallet = reindex_to_wallets(wallet_groups_tcut["size"].mean())

avg_price_per_wallet = reindex_to_wallets(wallet_groups_tcut["price"].mean())

# largest peak-to-trough drop of each wallet's normalized value up to t_cut,
# stored as a positive fraction (0 when the wallet never had positive value)
max_drawdown_pct = np.zeros(num_wallets, dtype=np.float32)

for j in range(num_wallets):
    v = V_unit[:cut_end, j]
    if np.any(v > 0):
        running_max = np.maximum.accumulate(v)
        with np.errstate(divide="ignore", invalid="ignore"):
            # drawdown relative to the running peak; 0 while the peak is not positive
            dd = np.where(running_max > 0, v / running_max - 1.0, 0.0)
        min_dd = dd.min()
        max_drawdown_pct[j] = float(-min_dd) if min_dd < 0 else 0.0
    else:
        max_drawdown_pct[j] = 0.0

# the TOP_SERIES_K series with the most trades before the cutoff
series_counts = df_trades_tcut["series_id"].dropna().value_counts().head(TOP_SERIES_K)
top_series_ids = list(series_counts.index)
series_id_to_col = dict(zip(top_series_ids, range(len(top_series_ids))))

# one row per wallet, one column per top series; a 1 marks that the wallet
# traded that series at least once (a wallet can have several 1s)
series_onehot = np.zeros((num_wallets, len(top_series_ids)), dtype=np.float32)

if top_series_ids:
    df_series_pairs = (
        df_trades_tcut[["proxy_wallet", "series_id"]]
        .dropna(subset=["series_id"])
        .drop_duplicates()
    )
    for wallet, sid in df_series_pairs.itertuples(index=False):
        si = series_id_to_col.get(sid)
        wi = wallet_to_idx.get(wallet)
        if si is not None and wi is not None:
            series_onehot[wi, si] = 1.0

print("top series used for one-hot:", top_series_ids)

# (flag, columns) pairs in the order the columns appear in the feature matrix
feature_blocks = [
    ("use_return_mean_std", [ret_mean_full[:, None], ret_std_full[:, None]]),
    ("use_num_series", [num_series_per_wallet[:, None]]),
    ("use_num_events", [num_events_per_wallet[:, None]]),
    ("use_num_markets", [num_markets_per_wallet[:, None]]),
    ("use_log_trades_per_hour", [log_trades_per_wallet[:, None]]),
    ("use_avg_trade_size", [avg_size_per_wallet[:, None]]),
    ("use_avg_trade_price", [avg_price_per_wallet[:, None]]),
    ("use_max_drawdown_pct", [max_drawdown_pct[:, None]]),
]

feature_list = []
for flag_name, flag_columns in feature_blocks:
    if feature_flags[flag_name]:
        feature_list.extend(flag_columns)

# the series block is only added when it has at least one column
if feature_flags["use_series_onehot"] and series_onehot.shape[1] > 0:
    feature_list.append(series_onehot)

# stack everything side by side: one row per wallet
X_base = np.concatenate(feature_list, axis=1).astype(np.float32)
print("X_base shape:", X_base.shape)
print("feature_flags:", feature_flags)


In [None]:
# trades used for the graph: at or before the cutoff, modelled wallets only,
# in time order
in_graph_window = df_trades["trade_time"] <= t_cut
by_known_wallet = df_trades["proxy_wallet"].isin(wallet_list)
df_graph = (
    df_trades[in_graph_window & by_known_wallet]
    .copy()
    .sort_values("trade_time")
    .reset_index(drop=True)
)

# columns that make up one trade tuple, in tuple order
cols_needed = [
    "id", "transaction_hash", "condition_id", "size", "price",
    "timestamp", "trade_time", "side", "proxy_wallet",
    "asset", "outcome_id", "outcome_index", "trade_type_id",
    "created_at", "updated_at"
]
missing_cols = [c for c in cols_needed if c not in df_graph.columns]
if missing_cols:
    # report the first missing column
    raise ValueError(f"missing column {missing_cols[0]} in df_graph")

all_trades_graph = list(df_graph[cols_needed].itertuples(index=False, name=None))

# position of each field inside a trade tuple (same order as cols_needed)
(
    IDX_ID,
    IDX_TX_HASH,
    IDX_CONDITION_ID,
    IDX_SIZE,
    IDX_PRICE,
    IDX_TIMESTAMP_INT,
    IDX_TRADE_TIME,
    IDX_SIDE,
    IDX_PROXY_WALLET,
    IDX_ASSET,
    IDX_OUTCOME_ID,
    IDX_OUTCOME_INDEX,
    IDX_TRADE_TYPE_ID,
    IDX_CREATED_AT,
    IDX_UPDATED_AT,
) = range(15)

# graph construction parameters
MIN_RELATIONSHIP_WEIGHT = 0.65   # minimum share of co-traded windows for an edge
MAX_GAP_SECONDS = 15             # trades closer than this belong to one window
CO_TRADE_EXTENSION_SECONDS = 40  # how far past a window's end co-trades are searched

sorted_trades = sorted(all_trades_graph, key=lambda r: r[IDX_TRADE_TIME])

# trades of each (condition, wallet) pair, still in time order
trades_by_cp = defaultdict(list)
for row in tqdm(sorted_trades, desc="populating trades_by_cp", disable=len(sorted_trades) < 10000):
    trades_by_cp[(row[IDX_CONDITION_ID], row[IDX_PROXY_WALLET])].append(row)

# a continuum window is a run of one wallet's trades on one condition in
# which consecutive trades are at most MAX_GAP_SECONDS apart
continuum_windows = []

for (condition_id, proxy_wallet), rows_cp in tqdm(
    trades_by_cp.items(),
    desc="building continuum windows",
    disable=len(trades_by_cp) < 5000
):
    if not rows_cp:
        continue
    current_start = current_end = rows_cp[0][IDX_TRADE_TIME]

    for row in rows_cp[1:]:
        ts = row[IDX_TRADE_TIME]
        gap_seconds = (ts - current_end).total_seconds()
        if not (gap_seconds <= MAX_GAP_SECONDS):
            # gap too large: close the running window and open a new one at ts
            continuum_windows.append(
                {"condition_id": condition_id, "wallet": proxy_wallet,
                 "min_ts": current_start, "max_ts": current_end}
            )
            current_start = ts
        current_end = ts

    # close the window that was still open at the end of the pair's trades
    continuum_windows.append(
        {"condition_id": condition_id, "wallet": proxy_wallet,
         "min_ts": current_start, "max_ts": current_end}
    )

print("number of continuum windows:", len(continuum_windows))

# all trades of each condition (time ordered) plus their times for bisecting
trades_by_condition = defaultdict(list)
for row in tqdm(sorted_trades, desc="assigning trades by condition", disable=len(sorted_trades) < 10000):
    trades_by_condition[row[IDX_CONDITION_ID]].append(row)

cond_times = {
    cond_id: [r[IDX_TRADE_TIME] for r in rows_c]
    for cond_id, rows_c in trades_by_condition.items()
}

total_windows_by_wallet = defaultdict(int)
co_trade_windows = defaultdict(lambda: defaultdict(int))

# for every window of wallet A, find the other wallets that traded the same
# condition between the window start and CO_TRADE_EXTENSION_SECONDS after its end
for w in tqdm(continuum_windows, desc="finding co-trade windows", disable=len(continuum_windows) < 5000):
    cond_id, wallet_a = w["condition_id"], w["wallet"]
    start_ts = w["min_ts"]
    end_ts = w["max_ts"] + timedelta(seconds=CO_TRADE_EXTENSION_SECONDS)

    total_windows_by_wallet[wallet_a] += 1

    times_c = cond_times[cond_id]
    left = bisect_left(times_c, start_ts)
    right = bisect_right(times_c, end_ts)

    # each co-trading wallet counts once per window
    co_wallets = {
        row[IDX_PROXY_WALLET]
        for row in trades_by_condition[cond_id][left:right]
        if row[IDX_PROXY_WALLET] != wallet_a
    }

    for wallet_b in co_wallets:
        co_trade_windows[wallet_a][wallet_b] += 1

# edge A -> B with weight = share of A's windows in which B co-traded,
# kept only when the share reaches MIN_RELATIONSHIP_WEIGHT
directed_graph = defaultdict(dict)

for wallet_a, targets in tqdm(co_trade_windows.items(), desc="building directed graph", disable=len(co_trade_windows) < 3000):
    if wallet_a not in wallet_to_idx:
        continue
    total_windows = total_windows_by_wallet[wallet_a]
    if total_windows == 0:
        continue
    for wallet_b, count in targets.items():
        pct_overlap = count / total_windows
        if wallet_b in wallet_to_idx and pct_overlap >= MIN_RELATIONSHIP_WEIGHT:
            directed_graph[wallet_a][wallet_b] = float(pct_overlap)

print("wallets with outgoing edges:", len(directed_graph))


print("computing wallet degrees")
print("num_wallets:", num_wallets)

# out-degree: edges leaving a wallet; in-degree: edges arriving at a wallet
deg_out = np.zeros(num_wallets, dtype=np.float32)
deg_in = np.zeros(num_wallets, dtype=np.float32)

for w_a, neighbors in tqdm(directed_graph.items(), desc="computing wallet degrees", disable=len(directed_graph) < 3000):
    i = wallet_to_idx.get(w_a)
    if i is None:
        continue
    deg_out[i] = len(neighbors)
    for w_b in neighbors:
        j = wallet_to_idx.get(w_b)
        if j is not None:
            deg_in[j] += 1.0

# Ollivier-Ricci curvature per wallet (stays all zeros unless the flag is on)
orc_vec = np.zeros(num_wallets, dtype=np.float32)
orc_feature_index = None

if feature_flags.get("use_orc_curvature", False):
    # best effort: any failure (for example the optional GraphRicciCurvature
    # package not being installed) is reported and orc_vec stays zero
    try:
        import networkx as nx
        from GraphRicciCurvature.OllivierRicci import OllivierRicci

        # undirected version of the wallet graph; when both directions of an
        # edge exist their weights are averaged
        G_orc = nx.Graph()
        for w_a, neighbors in tqdm(
            directed_graph.items(),
            desc="building undirected graph for ORC",
            disable=len(directed_graph) < 3000
        ):
            for w_b, weight in neighbors.items():
                if w_a == w_b:
                    continue
                w_val = float(weight)
                if G_orc.has_edge(w_a, w_b):
                    G_orc[w_a][w_b]["weight"] = 0.5 * (G_orc[w_a][w_b]["weight"] + w_val)
                else:
                    G_orc.add_edge(w_a, w_b, weight=w_val)

        print("graph for ORC: nodes", G_orc.number_of_nodes(), "edges", G_orc.number_of_edges())

        if G_orc.number_of_edges() > 0:
            orc = OllivierRicci(G_orc, alpha=0.5, verbose="INFO")
            orc.compute_ricci_curvature()

            # OllivierRicci works on its own copy of the graph (orc.G) and
            # writes the "ricciCurvature" attributes there; G_orc itself is
            # left untouched, so reading from it would always give 0.0
            node_orc = {n: data.get("ricciCurvature", 0.0) for n, data in orc.G.nodes(data=True)}
            for w, idx in tqdm(wallet_to_idx.items(), desc="mapping ORC values to wallets", disable=len(wallet_to_idx) < 2000):
                # wallets that are not in the graph keep curvature 0.0
                orc_vec[idx] = float(node_orc.get(w, 0.0))

            print(
                "ollivier-ricci curvature stats:",
                "min", float(orc_vec.min()),
                "max", float(orc_vec.max()),
                "mean", float(orc_vec.mean())
            )
        else:
            print("empty graph for ORC, leaving orc_vec as zeros.")
    except Exception as e:
        print("ORC curvature computation failed, skipping. error:", e)


# cheap curvature proxy per wallet (stays all zeros unless the flag is on)
local_curv_vec = np.zeros(num_wallets, dtype=np.float32)
local_curv_feature_index = None


# good approximation of the local curvature
if feature_flags.get("use_local_curvature", False):
    # best effort: a failure is printed and local_curv_vec keeps its current values
    try:
        # undirected version of the wallet graph; when both directions of an
        # edge exist the larger weight is kept
        G_curv = nx.Graph()
        for w_a, neighbors in tqdm(
            directed_graph.items(),
            desc="building undirected graph for local curvature",
            disable=len(directed_graph) < 3000
        ):
            for w_b, weight in neighbors.items():
                if w_a == w_b:
                    continue
                w_val = float(weight)
                if not G_curv.has_edge(w_a, w_b):
                    G_curv.add_edge(w_a, w_b, weight=w_val)
                else:
                    # keep max weight (or average; not critical)
                    G_curv[w_a][w_b]["weight"] = max(G_curv[w_a][w_b]["weight"], w_val)

        print(
            "graph for local curvature: nodes",
            G_curv.number_of_nodes(),
            "edges",
            G_curv.number_of_edges()
        )

        if G_curv.number_of_edges() == 0:
            print("empty graph for local curvature, leaving local_curv_vec as zeros.")
        else:
            clustering = nx.clustering(G_curv, weight="weight")

            # curvature-like scalar: kappa_i = 1 - C_i
            # (a wallet that is not a node of the graph has C_i = 0, so it gets 1.0)
            for w, idx in wallet_to_idx.items():
                C_i = float(clustering.get(w, 0.0))
                local_curv_vec[idx] = 1.0 - C_i

            print(
                "local curvature proxy stats:",
                "min", float(local_curv_vec.min()),
                "max", float(local_curv_vec.max()),
                "mean", float(local_curv_vec.mean())
            )
    except Exception as e:
        print("local curvature computation failed, skipping. error:", e)


# start from the base features and append the graph-derived columns
X_all = X_base.astype(np.float32)
X_all_aug = X_all

# degree columns (on unless explicitly switched off)
if feature_flags.get("use_degree_out", True):
    X_all_aug = np.column_stack([X_all_aug, deg_out])

if feature_flags.get("use_degree_in", True):
    X_all_aug = np.column_stack([X_all_aug, deg_in])

# curvature columns; the position of each one is remembered
orc_feature_index = None
if feature_flags.get("use_orc_curvature", False):
    X_all_aug = np.column_stack([X_all_aug, orc_vec])
    orc_feature_index = X_all_aug.shape[1] - 1
    print("ORC feature column index:", orc_feature_index)

local_curv_feature_index = None
if feature_flags.get("use_local_curvature", False):
    X_all_aug = np.column_stack([X_all_aug, local_curv_vec])
    local_curv_feature_index = X_all_aug.shape[1] - 1
    print("local curvature feature column index:", local_curv_feature_index)

print("X_all_aug shape:", X_all_aug.shape)

# record which columns are curvature columns in X_all_aug
curvature_col_indices = sorted(
    col for col in (orc_feature_index, local_curv_feature_index) if col is not None
)

print("curvature columns:", curvature_col_indices)

# flatten the directed graph into parallel lists: source index, target index, weight
edge_src = []
edge_dst = []
edge_weight_list = []

for w_a, neighbors in tqdm(directed_graph.items(), desc="building edge indices and weights", disable=len(directed_graph) < 3000):
    i = wallet_to_idx[w_a]
    for w_b, weight in neighbors.items():
        j = wallet_to_idx[w_b]
        edge_src.append(i)
        edge_dst.append(j)
        edge_weight_list.append(float(weight))

if edge_src:
    # edge_index is 2 x E (row 0 sources, row 1 targets); edge_attr is E x 1
    edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
    edge_attr = torch.tensor(edge_weight_list, dtype=torch.float32).unsqueeze(1)
else:
    # no edges: empty tensors with the same layout
    edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.empty((0, 1), dtype=torch.float32)

print("edge_index shape:", edge_index.shape)
print("edge_attr shape:", edge_attr.shape)


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.nn import SAGEConv

# Prepare everything the training functions read as globals: the device, the
# set of "active" wallets, two feature matrices (with / without curvature
# columns), the return matrix restricted to active wallets, and the edge index
# re-numbered to the active-wallet subgraph.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device:", device)

# a wallet is active when it has at least one incoming or outgoing edge
deg_total = deg_in + deg_out
active_mask = deg_total >= 1
active_idx = np.where(active_mask)[0]

print(f"active wallets: {active_idx.size} / {num_wallets}")

num_features = X_all_aug.shape[1]

# tolerate running this cell without the feature-augmentation cell having
# defined the curvature column list
if "curvature_col_indices" not in globals():
    curvature_col_indices = []

all_cols = np.arange(num_features)
# NOTE: this rebinds curvature_col_indices from a list to an int ndarray
curvature_col_indices = np.array(curvature_col_indices, dtype=int) if len(curvature_col_indices) > 0 else np.array([], dtype=int)

if curvature_col_indices.size == 0:
    print("no curvature columns detected")
    non_curv_cols = all_cols
else:
    # boolean column mask that drops exactly the curvature columns
    mask_cols = np.ones(num_features, dtype=bool)
    mask_cols[curvature_col_indices] = False
    non_curv_cols = all_cols[mask_cols]

X_all_no_curv_full = X_all_aug[:, non_curv_cols].astype(np.float32)
X_all_with_curv_full = X_all_aug.astype(np.float32)

print("X_all_no_curv_full shape:", X_all_no_curv_full.shape)
print("X_all_with_curv_full shape:", X_all_with_curv_full.shape)

# restrict to active wallets
X_features_no_curv = X_all_no_curv_full[active_idx]
X_features_with_curv = X_all_with_curv_full[active_idx]

# compress returns to active wallets
# (columns of R_wallet are indexed by wallet; rows are kept unchanged)
R_wallet_active = R_wallet[:, active_idx].astype(np.float32)

# build compressed edge_index induced by active_idx
edge_index_np = edge_index.cpu().numpy()
src_full = edge_index_np[0]
dst_full = edge_index_np[1]


# the edge_index_active is the edge_index of the active wallets
# keep only edges whose two endpoints are both active
edge_keep_mask = active_mask[src_full] & active_mask[dst_full]
src_kept = src_full[edge_keep_mask]
dst_kept = dst_full[edge_keep_mask]

# idx_map translates a full-universe wallet index into its position inside
# active_idx; inactive wallets keep the sentinel -1 (never looked up below
# because inactive endpoints were filtered out above)
idx_map = -np.ones(num_wallets, dtype=np.int64)
idx_map[active_idx] = np.arange(active_idx.size, dtype=np.int64)

src_sub = idx_map[src_kept]
dst_sub = idx_map[dst_kept]

edge_index_active = torch.tensor(
    np.vstack([src_sub, dst_sub]),
    dtype=torch.long,
    device=device
)

print("compressed edge_index shape:", edge_index_active.shape)


def normalize_long_short(scores, eps=1e-8):
    """Turn raw scores into dollar-neutral long-short weights.

    The scores are squashed with tanh, shifted to have zero mean, and then
    rescaled so the absolute weights sum to 2.

    Args:
        scores: tensor of raw scores.
        eps: threshold below which the gross exposure is treated as zero; in
            that case the centred values are divided by 1 instead.

    Returns:
        Tensor of weights with the same shape as ``scores``.
    """
    centered = torch.tanh(scores)
    centered = centered - centered.mean()

    gross = centered.abs().sum()
    # substitute a unit denominator when the gross exposure is (numerically) zero
    unit = torch.tensor(1.0, device=centered.device, dtype=centered.dtype)
    denom = torch.where(gross < eps, unit, gross)

    return 2.0 * centered / denom

class GNNPortfolioModel(nn.Module):
    """GraphSAGE encoder followed by an MLP head that scores every node.

    Args:
        in_dim: number of input features per node.
        hidden_dim: width of every conv layer and of the MLP hidden layer.
        num_layers: number of SAGEConv layers. ``1`` builds a single layer;
            any other value builds ``max(num_layers, 2)`` layers (so values
            below 1 still yield two layers).
    """

    def __init__(self, in_dim=2, hidden_dim=16, num_layers=3):
        super().__init__()

        # (input width, output width) of each conv layer, in order
        layer_dims = [(in_dim, hidden_dim)]
        if num_layers != 1:
            # a non-positive repeat count contributes no middle layers
            layer_dims += [(hidden_dim, hidden_dim)] * (num_layers - 2)
            layer_dims.append((hidden_dim, hidden_dim))

        self.convs = nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in layer_dims]
        )

        # per-node scalar score head
        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index):
        """Return ``(weights, scores)`` for the nodes in ``x``.

        ``scores`` are the raw MLP outputs with the trailing dimension
        squeezed; ``weights`` are those scores passed through
        ``normalize_long_short``.
        """
        h = x
        for conv in self.convs:
            # ReLU follows every conv layer, including the last one
            h = F.relu(conv(h, edge_index))
        scores = self.score_mlp(h).squeeze(-1)
        return normalize_long_short(scores), scores

def sharpe_loss(weights, R_steps, eps=1e-8):
    """Negative Sharpe ratio of the portfolio defined by ``weights``.

    Args:
        weights: 1-D tensor of portfolio weights.
        R_steps: 2-D tensor of per-step returns whose last dimension matches
            ``weights``.
        eps: added to the standard deviation to avoid division by zero.

    Returns:
        ``(loss, sharpe, mean_ret, std_ret)`` where ``loss == -sharpe`` and
        ``std_ret`` is the population standard deviation plus ``eps``.
    """
    port_ret = R_steps @ weights
    mean_ret = port_ret.mean()
    # population (biased) standard deviation, offset by eps
    std_ret = port_ret.std(unbiased=False) + eps
    sharpe = mean_ret / std_ret
    return -sharpe, sharpe, mean_ret, std_ret


def get_return_block(k):
    """Return the float32 rows of ``R_wallet_active`` belonging to chunk ``k``.

    The chunk's first row is skipped, so the block covers rows
    ``s + 1 .. e - 1`` where ``(s, e) = get_chunk_indices(k)``.

    Raises:
        ValueError: if the chunk has no rows left after skipping the first.
    """
    s, e = get_chunk_indices(k)
    has_rows = e > s + 1
    if not has_rows:
        raise ValueError(f"chunk {k} too short for returns: start {s}, end {e}")
    rows = slice(s + 1, e)
    block = R_wallet_active[rows, :]
    return block.astype(np.float32)

def run_experiment(
    X_np,
    R_train_np,
    R_val_np,
    label,
    hidden_dim=16,
    num_layers=8,
    lr=1e-3,
    weight_decay=1e-5,
    num_epochs=2000,
    early_stop_patience=300,
):
    """Train a GNNPortfolioModel to maximise Sharpe and keep the best-validation weights.

    Training is full-batch: each epoch does one Adam step on the negative
    Sharpe of the training returns, then evaluates the negative Sharpe of the
    validation returns. The parameters from the epoch with the lowest
    validation loss are restored before returning.

    Reads the module-level ``device`` and ``edge_index_active``.

    Args:
        X_np: node feature matrix (NumPy array), one row per active wallet.
        R_train_np: training return matrix (NumPy array).
        R_val_np: validation return matrix (NumPy array).
        label: name used in log lines and stored in the result.
        hidden_dim: hidden width passed to GNNPortfolioModel.
        num_layers: number of conv layers passed to GNNPortfolioModel.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        num_epochs: maximum number of epochs.
        early_stop_patience: stop after this many consecutive epochs without
            a validation-loss improvement of more than 1e-6.

    Returns:
        dict with keys ``label``, ``model``, ``weights_val``, ``val_sharpe``,
        ``train_losses``, ``val_losses``, ``val_sharpes`` and ``best_epoch``.
    """
    print("experiment:", label)

    # get the data in the right format
    x_all = torch.from_numpy(X_np).float().to(device)
    edge_index_local = edge_index_active
    R_train_local = torch.from_numpy(R_train_np).float().to(device)
    R_val_local = torch.from_numpy(R_val_np).float().to(device)

    model = GNNPortfolioModel(
        in_dim=x_all.size(1),
        hidden_dim=hidden_dim,
        num_layers=num_layers,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_losses = []
    val_losses = []
    val_sharpes = []

    best_val_loss = float("inf")
    best_model_state = None
    best_epoch = 0
    epochs_since_improvement = 0

    # iterate through the epochs
    for epoch in range(1, num_epochs + 1):
        model.train()
        optimizer.zero_grad()

        # one full-batch gradient step on the training returns
        weights_train, _ = model(x_all, edge_index_local)
        loss_tr, sharpe_tr, _, _ = sharpe_loss(weights_train, R_train_local)

        loss_tr.backward()
        optimizer.step()

        # score the updated parameters on the validation returns
        model.eval()
        with torch.no_grad():
            weights_val, _ = model(x_all, edge_index_local)
            loss_val, sharpe_v, _, _ = sharpe_loss(weights_val, R_val_local)

        # pull scalars off the device once per epoch
        train_loss_value = loss_tr.item()
        train_sharpe_value = sharpe_tr.item()
        val_loss_value = loss_val.item()
        val_sharpe_value = sharpe_v.item()

        train_losses.append(train_loss_value)
        val_losses.append(val_loss_value)
        val_sharpes.append(val_sharpe_value)

        # check if we are done or not
        if val_loss_value < best_val_loss - 1e-6:
            best_val_loss = val_loss_value
            # state_dict() hands back references to the live parameter
            # tensors, which later optimizer steps overwrite in place; clone
            # them so the snapshot really is the best epoch's weights
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            best_epoch = epoch
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        # print the results along the way for updates
        if epoch % 20 == 0 or epoch == 1:
            print(
                f"epoch {epoch:04d} | "
                f"train loss {train_loss_value:12f} | "
                f"train sharpe {train_sharpe_value:12f} | "
                f"val sharpe {val_sharpe_value:12f} | "
                f"val loss {val_loss_value:12f}"
            )

        if epochs_since_improvement >= early_stop_patience:
            print(f"\nearly stopping at epoch {epoch} (best epoch: {best_epoch})")
            break

    # best_model_state stays None only if no epoch ever improved on +inf
    # (e.g. the validation loss was NaN throughout)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"loaded best model from epoch {best_epoch} with val loss {best_val_loss:}")

    model.eval()
    with torch.no_grad():
        weights_val_best, _ = model(x_all, edge_index_local)
        _, sharpe_v_best, _, _ = sharpe_loss(weights_val_best, R_val_local)

    print(f"final best validation sharpe ({label}): {sharpe_v_best.item():}")

    return {
        "label": label,
        "model": model,
        "weights_val": weights_val_best.detach().cpu().numpy(),
        "val_sharpe": sharpe_v_best.item(),
        "train_losses": train_losses,
        "val_losses": val_losses,
        "val_sharpes": val_sharpes,
        "best_epoch": best_epoch,
    }


# Split the chunked return history into train / validation / test blocks and
# train one model per feature set on the same split.
num_chunks = NUM_CHUNKS
if num_chunks < 4:
    raise ValueError("NUM_CHUNKS must be at least 4 for train/val/test setup or else we do not hvae enough data to work with.")


# training uses chunks 2 .. num_chunks - 2 (target_k = i + 1); chunk 1 is not
# used as a return target here
train_blocks = []
for i in range(1, num_chunks - 2):
    target_k = i + 1
    R_k = get_return_block(target_k)
    train_blocks.append(R_k)
R_train_all = np.concatenate(train_blocks, axis=0).astype(np.float32)
# the second-to-last chunk is validation, the last chunk is the held-out test
R_val_block = get_return_block(num_chunks - 1)
R_test_block = get_return_block(num_chunks)

# baseline: features without the curvature columns
results_no_curv = run_experiment(
    X_features_no_curv,
    R_train_np=R_train_all,
    R_val_np=R_val_block,
    label="no_curvature_features"
)

# comparison: the full augmented feature matrix
results_with_curv = run_experiment(
    X_features_with_curv,
    R_train_np=R_train_all,
    R_val_np=R_val_block,
    label="with_curvature_features"
)

# Overlay the per-epoch train (solid) and validation (dashed) loss curves of
# both experiments; blue = no curvature, orange = with curvature.
plt.figure(figsize=(10, 5))

plt.plot(results_no_curv["train_losses"], label="Train Loss (no curvature)", color="blue", alpha=0.6)
plt.plot(results_no_curv["val_losses"], label="Val Loss (no curvature)", color="blue", linestyle="dashed", alpha=0.7)

plt.plot(results_with_curv["train_losses"], label="Train Loss (with curvature)", color="orange", alpha=0.7)
plt.plot(results_with_curv["val_losses"], label="Val Loss (with curvature)", color="orange", linestyle="dashed", alpha=0.8)

# the two curves can have different lengths because each run stops early on
# its own schedule
plt.xlabel("Epoch")
plt.ylabel("Loss (negative Sharpe)")
plt.title("Train/Validation Losses for GNN Portfolio: With vs. Without Curvature Features")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()





# evaluate on the test set
def eval_on_test(model, X_np, label):
    """Score a trained model on the held-out test block.

    Reads the module-level ``device``, ``edge_index_active`` and
    ``R_test_block``.

    Args:
        model: trained model returning ``(weights, scores)``.
        X_np: node feature matrix (NumPy array) the model was trained on.
        label: name used in the printed summary line.

    Returns:
        ``(weights, sharpe)``: the flattened weight vector as a NumPy array
        and the test Sharpe ratio as a Python float.
    """
    features = torch.from_numpy(X_np).float().to(device)
    test_returns = torch.from_numpy(R_test_block).float().to(device)

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        weights_test, _ = model(features, edge_index_active)
        test_stats = sharpe_loss(weights_test, test_returns)

    # drop the loss term; keep sharpe, mean and std for reporting
    sharpe_test, mean_test, std_test = test_stats[1:]

    print(
        f"test result ({label}): "
        f"sharpe={sharpe_test.item():} | "
        f"mean_ret={mean_test.item():.6e} | std_ret={std_test.item():.6e}"
    )

    flat_weights = weights_test.detach().cpu().numpy().reshape(-1)
    return flat_weights, sharpe_test.item()

# Evaluate both trained models on the test block, then report validation and
# test Sharpe side by side.
weights_no_curv_active, sharpe_test_no_curv = eval_on_test(
    results_no_curv["model"],
    X_features_no_curv,
    "no_curvature_features"
)

weights_with_curv_active, sharpe_test_with_curv = eval_on_test(
    results_with_curv["model"],
    X_features_with_curv,
    "with_curvature_features"
)

print(f"val sharpe (no curv): {results_no_curv['val_sharpe']:}")
print(f"val sharpe (with curv): {results_with_curv['val_sharpe']:}")
print(f"test sharpe (no curv): {sharpe_test_no_curv:}")
print(f"test sharpe (with curv): {sharpe_test_with_curv:}")

# expand active weights back to full wallet universe for downstream diagnostics
# (inactive wallets keep weight 0)
final_weights_no_curv = np.zeros(num_wallets, dtype=np.float32)
final_weights_no_curv[active_idx] = weights_no_curv_active

final_weights_with_curv = np.zeros(num_wallets, dtype=np.float32)
final_weights_with_curv[active_idx] = weights_with_curv_active


# final_weights_val is a copy of the with-curvature weights
final_weights_val = final_weights_with_curv.copy()
print("stored final_weights_val (with curvature) length:", final_weights_val.shape[0])


In [None]:

# Fail fast when this cell is run before the comparison training cell.
# (at cell top level locals() is the notebook's global namespace)
required = ["final_weights_no_curv", "final_weights_with_curv"]
for name in required:
    if name not in locals():
        raise ValueError(f"{name} not found. run the comparison training cell first.")

w_no  = final_weights_no_curv
w_cur = final_weights_with_curv

# both weight vectors must have one entry per column of V
if w_no.shape[0] != V.shape[1] or w_cur.shape[0] != V.shape[1]:
    raise ValueError("mismatch between number of weights and number of wallets.")

# sanity summary of each weight vector: net, gross, and extremes
print(" check the weights ")
print("[no_curv] sum w:", w_no.sum(), "sum |w|:", np.abs(w_no).sum(), "min:", w_no.min(), "max:", w_no.max())
print("[with_curv] sum w:", w_cur.sum(), "sum |w|:", np.abs(w_cur).sum(), "min:", w_cur.min(), "max:", w_cur.max())

num_chunks = NUM_CHUNKS

# validation window
# (same row range get_return_block uses: the chunk's first row is skipped)
s_val, e_val = get_chunk_indices(num_chunks - 1)
val_slice = slice(s_val + 1, e_val)

# test window
s_test, e_test = get_chunk_indices(num_chunks)
test_slice = slice(s_test + 1, e_test)

times_val = grid_times[val_slice]
times_test = grid_times[test_slice]

# Overlaid histograms of the two learned weight vectors over the full wallet
# universe (inactive wallets contribute zeros).
plt.figure(figsize=(10, 4))
plt.hist(w_no,  bins=50, alpha=0.5, label="no_curv")
plt.hist(w_cur, bins=50, alpha=0.5, label="with_curv")
plt.xlabel("weight")
plt.ylabel("count")
plt.title("distribution of learned long-short weights")
plt.legend()
plt.tight_layout()
plt.show()

def _plot_top_weight_bars(weights, ranked_idx, k, figsize, rotation):
    """Bar chart of the ``k`` wallets with the largest absolute weight.

    Args:
        weights: weight vector indexed like ``wallet_list``.
        ranked_idx: wallet indices sorted by descending absolute weight.
        k: number of wallets requested; clamped to ``len(ranked_idx)`` so the
            chart still renders when fewer wallets exist.
        figsize: matplotlib figure size.
        rotation: x tick label rotation in degrees.

    Returns:
        ``(k, indices, weights, wallets)`` for the wallets actually plotted.
    """
    k = min(k, len(ranked_idx))
    idx = ranked_idx[:k]
    vals = weights[idx]
    names = [wallet_list[i] for i in idx]

    plt.figure(figsize=figsize)
    plt.bar(range(k), vals)
    # show only the first 8 characters of each wallet address
    plt.xticks(range(k), [w[:8] + "..." for w in names], rotation=rotation, ha="right")
    plt.ylabel("weight")
    plt.title(f"top {k} wallets by long-short weight (with curvature)")
    plt.tight_layout()
    plt.show()
    return k, idx, vals, names


# rank wallets once by absolute with-curvature weight, largest first
sorted_idx_cur = np.argsort(-np.abs(w_cur))

top_k, top_idx, top_weights_cur, top_wallets = _plot_top_weight_bars(
    w_cur, sorted_idx_cur, 20, (12, 5), 45
)

top_k_60, top_idx_60, top_weights_cur_60, top_wallets_60 = _plot_top_weight_bars(
    w_cur, sorted_idx_cur, 60, (16, 6), 90
)

# per-step portfolio return over the whole time grid: one value per row of
# R_wallet, using the full-universe weight vectors
num_times = R_wallet.shape[0]
port_step_no  = R_wallet @ w_no
port_step_cur = R_wallet @ w_cur

def build_path(port_step):
    """Compound per-step returns into a unit-value path.

    ``path[0]`` is 1.0 and ``path[t] = path[t - 1] * (1 + port_step[t])`` for
    ``t >= 1``; ``port_step[0]`` is ignored because the first grid point is
    the starting value.

    Args:
        port_step: 1-D sequence of per-step returns.

    Returns:
        float64 array with the same length as ``port_step`` (empty input
        yields an empty array).
    """
    # size the path from the input itself instead of a notebook-level global
    steps = np.asarray(port_step, dtype=np.float64)
    if steps.size == 0:
        return np.zeros(0, dtype=np.float64)

    growth = 1.0 + steps  # new array; the caller's data is not modified
    growth[0] = 1.0
    return np.cumprod(growth)

# unit-value paths over the entire time grid for both weight vectors
portfolio_full_no = build_path(port_step_no)
portfolio_full_cur = build_path(port_step_cur)

# the same paths restricted to the validation rows
portfolio_val_no = portfolio_full_no[val_slice]
portfolio_val_cur = portfolio_full_cur[val_slice]

# helper to convert time to seconds
def times_to_unix_seconds(times_array):
    """Convert an array of timestamps to float seconds since the Unix epoch.

    Args:
        times_array: array-like of ``datetime64`` values, plain numbers
            (returned as floats unchanged in value), or anything
            ``pd.to_datetime`` can parse (strings, ``datetime`` objects,
            pandas Timestamps).

    Returns:
        float64 NumPy array in every case. Sub-second precision is truncated
        for datetime inputs.
    """
    times_array = np.asarray(times_array)

    if np.issubdtype(times_array.dtype, np.datetime64):
        return times_array.astype("datetime64[s]").astype("int64").astype(float)

    if np.issubdtype(times_array.dtype, np.number):
        return times_array.astype(float)

    parsed = pd.DatetimeIndex(pd.to_datetime(times_array))
    if parsed.tz is not None:
        # convert to UTC and drop the timezone so .values is plain datetime64
        parsed = parsed.tz_convert(None)
    # cast through datetime64[s] rather than dividing by 1e9, so the result
    # is correct whatever resolution pandas chose for the parsed index
    return parsed.values.astype("datetime64[s]").astype("int64").astype(float)

# validation timestamps in epoch seconds; read as a global by stats_for_window
times_val_sec = times_to_unix_seconds(np.asarray(times_val))

def stats_for_window(port_step, portfolio_window, label):
    """Print per-step and annualized statistics for the validation window.

    Despite the generic name, the window is always the validation one: the
    step returns are taken from the module-level ``val_slice`` and the time
    axis from the module-level ``times_val_sec``.

    Args:
        port_step: per-step portfolio returns over the full time grid.
        portfolio_window: unit-value path restricted to the validation rows;
            only its first and last entries are used.
        label: tag printed in the header line.
    """
    seconds_per_year = 365.25 * 24 * 3600

    step_ret_val = port_step[val_slice]
    mean_r = step_ret_val.mean()
    std_r = step_ret_val.std()
    if std_r > 0:
        sharpe_val = mean_r / std_r
    else:
        sharpe_val = 0.0

    # annualized figures need at least two timestamps to infer a step size
    annualized_return = float("nan")
    annualized_sharpe = float("nan")
    if len(times_val_sec) > 1:
        # typical spacing between grid points, expressed in years
        dt = np.median(np.diff(times_val_sec))
        dt_years = dt / seconds_per_year
        n_steps_per_year = 1.0 / dt_years

        period_years = (times_val_sec[-1] - times_val_sec[0]) / seconds_per_year
        growth = portfolio_window[-1] / portfolio_window[0]
        annualized_return = growth**(1 / period_years) - 1

        ann_mean = mean_r * n_steps_per_year
        ann_std = std_r * np.sqrt(n_steps_per_year)
        if ann_std > 0:
            annualized_sharpe = ann_mean / ann_std
        else:
            annualized_sharpe = 0.0

    print(f"\nvalidation stats [{label}]:")
    print("\tmean step return:", float(mean_r))
    print("\tstd step return:", float(std_r))
    print("\tsharpe:", float(sharpe_val))
    print("\tannualized return:", float(annualized_return))
    print("\tannualized sharpe:", float(annualized_sharpe))

# validation-window statistics for both weight vectors
stats_for_window(port_step_no, portfolio_val_no,  "no_curv")
stats_for_window(port_step_cur, portfolio_val_cur, "with_curv")

# full-horizon unit-value paths, with vertical markers at the first grid
# point of the validation chunk and of the test chunk
plt.figure(figsize=(12, 5))
plt.plot(grid_times, portfolio_full_no, label="no_curv")
plt.plot(grid_times, portfolio_full_cur, label="with_curv")
plt.axvline(grid_times[s_val], color="gray", linestyle="--", label="validation start")
plt.axvline(grid_times[s_test], color="red", linestyle="--", label="test start")
plt.xlabel("time")
plt.ylabel("unit value")
plt.title("portfolio over entire horizon: no_curv vs with_curv")
plt.legend()
plt.tight_layout()
plt.show()

# validation window only, each path rebased to 1.0 at the window's first point
plt.figure(figsize=(12, 5))
plt.plot(times_val, portfolio_val_no / portfolio_val_no[0],  label="no_curv")
plt.plot(times_val, portfolio_val_cur / portfolio_val_cur[0], label="with_curv")
plt.xlabel("time")
plt.ylabel("unit value")
plt.title("validation window: no_curv vs with_curv")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
def show_top_wallets(weights, label, top_n=40):
    """Print a ranked table of the wallets with the largest absolute weight.

    Reads wallet addresses from the module-level ``wallet_list``.

    Args:
        weights: weight vector indexed like ``wallet_list``.
        label: tag printed in the table heading.
        top_n: maximum number of rows to print.
    """
    print(f"\n TOP {top_n} WALLETS with {label} ")

    # wallet indices ordered by descending |weight|, truncated to top_n
    ranked = np.argsort(-np.abs(weights))[:top_n]

    # one tuple per row: rank, wallet index, address, signed and absolute weight
    rows = [
        (
            position,
            wallet_idx,
            wallet_list[wallet_idx],
            float(weights[wallet_idx]),
            abs(float(weights[wallet_idx])),
        )
        for position, wallet_idx in enumerate(ranked, start=1)
    ]

    print(f"{'rank':>4} | {'idx':>5} | {'wallet':<48} | {'weight':>12} | {'abs(weight)':>12}")
    for r, i, addr, w, aw in rows:
        print(f"{r:4d} | {i:5d} | {addr:<48} | {w:10f} | {aw:10f}")

# print the ranked wallet table for both models (default top 40 rows each)
show_top_wallets(final_weights_no_curv,  "NO CURVATURE")
show_top_wallets(final_weights_with_curv, "WITH CURVATURE")


In [None]:
# ensure required columns exist before any of them is used, so that a missing
# column surfaces as the descriptive ValueError below rather than a KeyError
required_cols = ["proxy_wallet", "condition_id", "event_id", "series_id"]
missing = [c for c in required_cols if c not in df_trades.columns]
if missing:
    raise ValueError(f"missing columns in df_trades: {missing}")

# restrict to wallets of interest
df_all = df_trades[df_trades["proxy_wallet"].isin(wallet_list)].copy()

def _trade_count_totals(cond_counts, level_col):
    """Sum per-condition trade counts to one row per (wallet, level_col)."""
    grouped = cond_counts.groupby(["proxy_wallet", level_col], dropna=False)
    return grouped["trade_count"].sum().reset_index()


def _wide_trade_counts(totals, level_col):
    """Pivot long (wallet, level_col, trade_count) rows to a wallet x level table."""
    return totals.pivot_table(
        index="proxy_wallet",
        columns=level_col,
        values="trade_count",
        fill_value=0,
        aggfunc="sum"
    )


# number of trades per (wallet, condition) pair; missing keys are kept
per_wallet_condition = df_all.groupby(["proxy_wallet", "condition_id"], dropna=False)
df_cond_counts = per_wallet_condition.size().rename("trade_count").reset_index()

# condition -> (event, series) lookup, first occurrence wins
cond_map = df_all[["condition_id", "event_id", "series_id"]]
cond_map = cond_map.drop_duplicates(subset=["condition_id"])

# attach event and series ids to every (wallet, condition) row
df_cond_counts = df_cond_counts.merge(cond_map, on="condition_id", how="left")

# roll the condition-level counts up to event level and to series level
df_event_counts = _trade_count_totals(df_cond_counts, "event_id")
df_series_counts = _trade_count_totals(df_cond_counts, "series_id")

# wide wallet x series / wallet x event tables, zero where a wallet never traded
wallet_series_counts_wide = _wide_trade_counts(df_series_counts, "series_id")
wallet_event_counts_wide = _wide_trade_counts(df_event_counts, "event_id")

print("df_cond_counts shape:", df_cond_counts.shape)
print("df_event_counts shape:", df_event_counts.shape)
print("df_series_counts shape:", df_series_counts.shape)
