In [None]:
import mysql.connector
import pandas as pd
import numpy as np
import os
from datetime import timedelta
from collections import defaultdict
from bisect import bisect_left, bisect_right
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
import networkx as nx

import matplotlib.pyplot as plt

# Database credentials for the NAS-hosted MySQL instance. They are read from
# the environment so secrets are never committed with the notebook; the
# non-secret defaults are the values previously hard-coded here.
# Set DB_PASSWORD_NAS in the environment before running the DB load.
DB_HOST_NAS = os.environ.get("DB_HOST_NAS", "192.168.0.165")
DB_PORT_NAS = int(os.environ.get("DB_PORT_NAS", "3306"))
DB_USER_NAS = os.environ.get("DB_USER_NAS", "teo")
DB_PASSWORD_NAS = os.environ.get("DB_PASSWORD_NAS", "")
DB_NAME_NAS = os.environ.get("DB_NAME_NAS", "polymarket_project")

# Seed the RNGs used by model initialisation/training so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [None]:
# Load the trades and the market/event/series mapping tables: either write
# freshly DB-loaded frames to a local cache, or read them back from it.
import pickle

TRADES_CACHE_PATH = "cached_trades.parquet"
MAPPINGS_CACHE_PATH = "cached_mappings.pkl"

# True to write the DB-loaded frames to the local cache.
SAVE_CACHE = False

# True to read the frames from the local cache (only used when SAVE_CACHE is False).
LOAD_FROM_CACHE = True

if SAVE_CACHE:
    print("Saving trades to local cache...")
    df_trades.to_parquet(TRADES_CACHE_PATH, index=False)

    # Mapping tables and series titles are stored together in one pickle.
    mappings = {
        "df_markets": df_markets,
        "df_market_events": df_market_events,
        "df_event_series": df_event_series,
        "series_id_to_title": series_id_to_title
    }
    with open(MAPPINGS_CACHE_PATH, "wb") as f:
        pickle.dump(mappings, f)

    print(f"Saved {len(df_trades)} trades to {TRADES_CACHE_PATH}")
    print(f"Saved mappings to {MAPPINGS_CACHE_PATH}")

elif LOAD_FROM_CACHE:
    # Fail loudly here instead of letting the next cell die on an undefined df_trades.
    if not (os.path.exists(TRADES_CACHE_PATH) and os.path.exists(MAPPINGS_CACHE_PATH)):
        raise FileNotFoundError(
            "Cache files not found. Run with SAVE_CACHE=True first after loading from DB."
        )

    print("Loading trades from local cache...")
    df_trades = pd.read_parquet(TRADES_CACHE_PATH)

    # NOTE: pickle.load can execute arbitrary code; only load a cache this
    # notebook wrote itself.
    with open(MAPPINGS_CACHE_PATH, "rb") as f:
        mappings = pickle.load(f)

    df_markets = mappings["df_markets"]
    df_market_events = mappings["df_market_events"]
    df_event_series = mappings["df_event_series"]
    series_id_to_title = mappings["series_id_to_title"]

    # Recreate trade_time if the cached frame lacks it: numeric timestamps are
    # treated as epoch seconds, anything else is parsed by pandas.
    # is_numeric_dtype (unlike np.issubdtype) also accepts pandas extension
    # dtypes such as Int64/Float64 instead of raising TypeError.
    if "trade_time" not in df_trades.columns:
        if pd.api.types.is_numeric_dtype(df_trades["timestamp"]):
            df_trades["trade_time"] = pd.to_datetime(df_trades["timestamp"], unit="s")
        else:
            df_trades["trade_time"] = pd.to_datetime(df_trades["timestamp"])

    print(f"Loaded {len(df_trades)} trades from cache")
    print(f"Series loaded: {len(series_id_to_title)}")
else:
    print("Using data from database (already loaded in Cell 2)")


In [None]:
# Market/event/series enrichment: attach market_id, event_id and series_id to
# every trade via condition_id -> market -> event -> series.

# Drop columns produced by a previous run of this cell so re-running it does
# not create *_x / *_y merge suffixes (which would break the later merges).
df_trades = (
    df_trades
    .drop(columns=["market_id", "event_id", "series_id"], errors="ignore")
    .sort_values("trade_time")
    .reset_index(drop=True)
)
df_trades["timestamp"] = pd.to_numeric(df_trades["timestamp"], errors="coerce")
df_trades = df_trades.dropna(subset=["timestamp"])

# condition -> market. Exact duplicate mapping rows are removed because a
# left merge would otherwise duplicate every matching trade.
df_markets_trim = (
    df_markets[["id", "conditionId"]]
    .dropna(subset=["conditionId"])
    .rename(columns={"id": "market_id"})
    .drop_duplicates()
)
df_trades = (
    df_trades
    .merge(df_markets_trim, how="left", left_on="condition_id", right_on="conditionId")
    .drop(columns=["conditionId"], errors="ignore")
)

# market -> event
if not df_market_events.empty:
    df_trades = df_trades.merge(
        df_market_events[["market_id", "event_id"]].dropna().drop_duplicates(),
        how="left", on="market_id"
    )
else:
    df_trades["event_id"] = np.nan

# event -> series
if not df_event_series.empty:
    df_trades = df_trades.merge(
        df_event_series[["event_id", "series_id"]].dropna().drop_duplicates(),
        how="left", on="event_id"
    )
else:
    df_trades["series_id"] = np.nan

# Convert series_id to string; missing ids become the literal "nan".
# NOTE(review): if series_id arrives as float (ints upcast by NaNs in the
# merge), ids render like "123.0" — confirm this matches series_id_to_title keys.
df_trades["series_id"] = df_trades["series_id"].astype(str)

print("Trades with series info:")
print(df_trades[["proxy_wallet", "condition_id", "series_id"]].head())
print("\nUnique series:", df_trades["series_id"].nunique())


In [None]:
# ---- Experiment parameters ----

# Graph construction
MIN_TRADES_PER_WALLET = 10         # a wallet needs at least this many trades in a series
MIN_RELATIONSHIP_WEIGHT = 0.50     # min share of A's windows co-traded with B to keep edge A->B

MAX_GAP_SECONDS = 15               # max gap between consecutive trades inside one window
CO_TRADE_EXTENSION_SECONDS = 40    # how far past a window's last trade co-traders are counted

# Training
NUM_CHUNKS = 6                     # time grid is split into this many chunks
NUM_EPOCHS = 500
EARLY_STOP_PATIENCE = 100
HIDDEN_DIM = 16
NUM_LAYERS = 4
LEARNING_RATE = 1e-3

# Minimum requirements before a series is analyzed
MIN_WALLETS_IN_SERIES = 20
MIN_EDGES_IN_GRAPH = 10

TOP_N_SERIES = 50                  # falsy value means "all series"

# Per-feature switches for the node feature matrix.
feature_flags = {
    # base features
    "use_return_mean_std": True,
    "use_log_trades_per_hour": True,
    "use_avg_trade_size": True,
    "use_avg_trade_price": True,
    "use_max_drawdown_pct": True,

    # graph features
    "use_degree_in": True,
    "use_degree_out": True,
    "use_leader_ratio": True,   # deg_out / (1 + deg_in)

    # curvature features
    "use_orc_curvature": False,   # Ollivier-Ricci curvature (slow)
    "use_local_curvature": True   # 1 - clustering coefficient (fast)
}

summary_rows = (
    ("MIN_TRADES_PER_WALLET", MIN_TRADES_PER_WALLET),
    ("MIN_RELATIONSHIP_WEIGHT", MIN_RELATIONSHIP_WEIGHT),
    ("MAX_GAP_SECONDS", MAX_GAP_SECONDS),
    ("CO_TRADE_EXTENSION_SECONDS", CO_TRADE_EXTENSION_SECONDS),
    ("TOP_N_SERIES", TOP_N_SERIES if TOP_N_SERIES else 'ALL'),
)

print("Configuration:")
for setting_name, setting_value in summary_rows:
    print(f"\t{setting_name}: {setting_value}")

print("\nFeature flags:")
for flag_name, enabled in feature_flags.items():
    print(f"{flag_name}: {enabled}")


In [None]:
# Select the series that have enough distinct wallets to analyze, ordered by
# trade volume (number of trades).

# trades per series
series_trade_counts = df_trades.groupby("series_id").size().sort_values(ascending=False)

# distinct wallets per series
series_wallet_counts = df_trades.groupby("series_id")["proxy_wallet"].nunique().sort_values(ascending=False)

# keep series with enough wallets
valid_series = series_wallet_counts[series_wallet_counts >= MIN_WALLETS_IN_SERIES].index.tolist()

# drop the stringified missing id ("nan") and order by trade volume
valid_series = [s for s in valid_series if s != 'nan' and pd.notna(s)]
valid_series = sorted(valid_series, key=lambda s: series_trade_counts.get(s, 0), reverse=True)

# optionally cap the number of series
if TOP_N_SERIES is not None and TOP_N_SERIES > 0:
    valid_series = valid_series[:TOP_N_SERIES]
    print(f"*** LIMITING TO TOP {TOP_N_SERIES} SERIES BY TRADE VOLUME ***\n")

print(f"Total unique series: {df_trades['series_id'].nunique()}")
print(f"Series with >= {MIN_WALLETS_IN_SERIES} wallets: {len(valid_series)}")
print(f"\nSeries to analyze ({len(valid_series)} total, sorted by volume):")
for rank, sid in enumerate(valid_series[:20], start=1):
    # str() guards against missing/non-string titles, which would break slicing/formatting.
    title = str(series_id_to_title.get(sid, sid))
    n_wallets = series_wallet_counts[sid]
    n_trades = series_trade_counts[sid]
    print(f"  {rank:>2}. {sid}: {title[:45]:<45} | {n_wallets:>5} wallets | {n_trades:>7} trades")


In [None]:
# Helpers

def normalize_long_short(scores, eps=1e-8):
    """Turn raw scores into zero-sum long/short weights with gross exposure 2.

    The scores are squashed with tanh and demeaned, so the weights sum to zero.
    They are then rescaled so their absolute values sum to 2. If the demeaned
    values have an absolute sum below ``eps`` (e.g. all scores equal), the
    divisor falls back to 1 instead of dividing by ~0.

    Args:
        scores: tensor of per-node scores (presumably 1-D, one per wallet —
            the mean is taken over all elements).
        eps: threshold under which the absolute sum is treated as zero.

    Returns:
        Tensor with the same shape, dtype and device as ``scores``.
    """
    centered = torch.tanh(scores)
    centered = centered - centered.mean()
    gross = centered.abs().sum()
    fallback = torch.tensor(1.0, device=centered.device, dtype=centered.dtype)
    divisor = torch.where(gross < eps, fallback, gross)
    return 2.0 * centered / divisor


class GNNPortfolioModel(nn.Module):
    """GraphSAGE encoder with an MLP head that emits long/short portfolio weights.

    ``num_layers`` SAGEConv layers, each followed by ReLU, embed the node
    features. A Linear-ReLU-Linear head maps each embedding to one scalar
    score, and ``normalize_long_short`` converts the scores into weights.
    """

    def __init__(self, in_dim, hidden_dim=16, num_layers=4):
        """
        Args:
            in_dim: number of input features per node.
            hidden_dim: output width of every SAGEConv layer and of the head's
                hidden layer.
            num_layers: number of SAGEConv layers; must be >= 1.

        Raises:
            ValueError: if ``num_layers`` < 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # The first layer maps in_dim -> hidden_dim; every later layer keeps hidden_dim.
        self.convs = nn.ModuleList(
            [SAGEConv(in_dim, hidden_dim)]
            + [SAGEConv(hidden_dim, hidden_dim) for _ in range(num_layers - 1)]
        )

        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index):
        """Score every node and convert the scores to portfolio weights.

        Args:
            x: node feature matrix, one row per node (presumably
                (num_nodes, in_dim) — TODO confirm against callers).
            edge_index: graph connectivity, passed unchanged to each SAGEConv.

        Returns:
            (weights, scores): ``normalize_long_short(scores)`` and the raw
            per-node scores (head output with its trailing dim squeezed).
        """
        h = x
        for conv in self.convs:
            h = F.relu(conv(h, edge_index))
        scores = self.score_mlp(h).squeeze(-1)
        wout = normalize_long_short(scores)
        return wout, scores


def sharpe_loss(weights, R_steps, eps=1e-8):
    """Negative Sharpe ratio of a fixed-weight portfolio, for minimisation.

    Args:
        weights: per-asset weights, matrix-multiplied on the right of R_steps.
        R_steps: per-step asset returns (presumably (steps, assets) — TODO confirm).
        eps: added to the population std to avoid division by zero.

    Returns:
        (loss, sharpe, mean_ret, std_ret) with loss == -sharpe and
        sharpe == mean_ret / std_ret, where std_ret already includes eps.
    """
    portfolio_returns = R_steps @ weights
    avg = portfolio_returns.mean()
    vol = portfolio_returns.std(unbiased=False) + eps
    ratio = avg / vol
    return -ratio, ratio, avg, vol


def build_co_trading_graph(df_series, wallet_list, wallet_to_idx,
                           max_gap_seconds=None, co_trade_extension_seconds=None,
                           min_relationship_weight=None):
    """
    Build the directed co-trading graph for one series.

    Each wallet's trades on a condition are split into "continuum windows":
    runs of trades whose consecutive gaps are <= ``max_gap_seconds``. A window
    [first_trade, last_trade + co_trade_extension_seconds] of wallet A co-trades
    with wallet B if B trades the same condition inside it. Edge A -> B is kept
    when the fraction of A's windows that co-trade with B is at least
    ``min_relationship_weight``; that fraction is the edge weight.

    Args:
        df_series: trades with at least ``condition_id``, ``trade_time``
            (datetime-like) and ``proxy_wallet`` columns.
        wallet_list: wallets; its length sets the size of the degree arrays.
        wallet_to_idx: wallet -> node index. Wallets missing from it get no edges.
        max_gap_seconds, co_trade_extension_seconds, min_relationship_weight:
            optional overrides; ``None`` falls back to the notebook globals
            MAX_GAP_SECONDS, CO_TRADE_EXTENSION_SECONDS and MIN_RELATIONSHIP_WEIGHT.

    Returns:
        (directed_graph, deg_in, deg_out, edge_index, num_edges):
        directed_graph maps wallet_a -> {wallet_b: weight}; deg_in / deg_out
        are float32 arrays of length len(wallet_list); edge_index is a (2, E)
        long tensor; num_edges == E. With fewer than 10 trades the first four
        entries are None and num_edges is 0.
    """
    if max_gap_seconds is None:
        max_gap_seconds = MAX_GAP_SECONDS
    if co_trade_extension_seconds is None:
        co_trade_extension_seconds = CO_TRADE_EXTENSION_SECONDS
    if min_relationship_weight is None:
        min_relationship_weight = MIN_RELATIONSHIP_WEIGHT

    num_wallets = len(wallet_list)

    # time-sorted (condition_id, trade_time, proxy_wallet) tuples
    df_series = df_series.sort_values("trade_time").reset_index(drop=True)
    cols_needed = ["condition_id", "trade_time", "proxy_wallet"]
    all_trades = list(df_series[cols_needed].itertuples(index=False, name=None))

    if len(all_trades) < 10:
        return None, None, None, None, 0

    # Single pass: trade times per (condition, wallet) for window building,
    # and (time, wallet) rows per condition for co-trade lookups.
    trades_by_cp = defaultdict(list)
    trades_by_condition = defaultdict(list)
    for condition_id, trade_time, proxy_wallet in all_trades:
        trades_by_cp[(condition_id, proxy_wallet)].append(trade_time)
        trades_by_condition[condition_id].append((trade_time, proxy_wallet))

    # Continuum windows as (condition_id, wallet, min_ts, max_ts).
    continuum_windows = []
    for (condition_id, proxy_wallet), times in trades_by_cp.items():
        times = sorted(times)
        current_start = current_end = times[0]
        for ts in times[1:]:
            if (ts - current_end).total_seconds() <= max_gap_seconds:
                current_end = ts
            else:
                continuum_windows.append((condition_id, proxy_wallet, current_start, current_end))
                current_start = current_end = ts
        continuum_windows.append((condition_id, proxy_wallet, current_start, current_end))

    # Sorted trade times per condition, for bisect lookups.
    cond_times = {}
    for cond_id, rows_c in trades_by_condition.items():
        rows_c.sort(key=lambda r: r[0])
        cond_times[cond_id] = [r[0] for r in rows_c]

    # Count, per wallet A, its windows and how many of them each other wallet B traded in.
    extension = timedelta(seconds=co_trade_extension_seconds)
    total_windows_by_wallet = defaultdict(int)
    co_trade_windows = defaultdict(lambda: defaultdict(int))

    for cond_id, wallet_a, min_ts, max_ts in continuum_windows:
        total_windows_by_wallet[wallet_a] += 1

        rows_c = trades_by_condition[cond_id]
        times_c = cond_times[cond_id]
        left = bisect_left(times_c, min_ts)
        right = bisect_right(times_c, max_ts + extension)

        co_wallets = {wallet for _, wallet in rows_c[left:right] if wallet != wallet_a}
        for wallet_b in co_wallets:
            co_trade_windows[wallet_a][wallet_b] += 1

    # Keep edges whose co-trade fraction reaches the threshold.
    directed_graph = defaultdict(dict)
    for wallet_a, targets in co_trade_windows.items():
        if wallet_a not in wallet_to_idx:
            continue
        total_windows = total_windows_by_wallet[wallet_a]
        if total_windows == 0:
            continue
        for wallet_b, count in targets.items():
            if wallet_b not in wallet_to_idx:
                continue
            pct_overlap = count / total_windows
            if pct_overlap >= min_relationship_weight:
                directed_graph[wallet_a][wallet_b] = float(pct_overlap)

    # Degrees and COO edge list in one pass (every key/neighbour is in wallet_to_idx).
    deg_out = np.zeros(num_wallets, dtype=np.float32)
    deg_in = np.zeros(num_wallets, dtype=np.float32)
    edge_src = []
    edge_dst = []
    for w_a, neighbors in directed_graph.items():
        i = wallet_to_idx[w_a]
        deg_out[i] = len(neighbors)
        for w_b in neighbors:
            j = wallet_to_idx[w_b]
            deg_in[j] += 1.0
            edge_src.append(i)
            edge_dst.append(j)

    num_edges = len(edge_src)
    if num_edges == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)

    return directed_graph, deg_in, deg_out, edge_index, num_edges


def compute_local_curvature(directed_graph, wallet_list, wallet_to_idx):
    """
    Cheap per-wallet curvature proxy: 1 - weighted clustering coefficient.

    The directed graph is collapsed to an undirected networkx graph. Self-loops
    are dropped, and for reciprocal edges the larger weight is kept. Weighted
    clustering is then computed with ``nx.clustering``.

    Args:
        directed_graph: wallet_a -> {wallet_b: weight}.
        wallet_list: wallets; its length sets the output size.
        wallet_to_idx: wallet -> position in the output vector.

    Returns:
        (local_curv_vec, curv_min, curv_max, curv_mean). local_curv_vec is
        float32 of length len(wallet_list); wallets absent from the graph get
        1.0 (clustering treated as 0). The stats are taken over wallets that
        appear in the undirected graph. If the graph has no edges, the vector
        is all zeros and the stats are 0.0.
    """
    curvature = np.zeros(len(wallet_list), dtype=np.float32)

    undirected = nx.Graph()
    for src, nbrs in directed_graph.items():
        for dst, wt in nbrs.items():
            if src == dst:
                continue
            if not undirected.has_edge(src, dst):
                undirected.add_edge(src, dst, weight=wt)
            else:
                prev = undirected[src][dst]["weight"]
                undirected[src][dst]["weight"] = max(prev, wt)

    if undirected.number_of_edges() == 0:
        return curvature, 0.0, 0.0, 0.0

    coeffs = nx.clustering(undirected, weight="weight")
    for wallet, pos in wallet_to_idx.items():
        curvature[pos] = 1.0 - float(coeffs.get(wallet, 0.0))

    in_graph = [curvature[wallet_to_idx[node]] for node in undirected.nodes() if node in wallet_to_idx]
    if not in_graph:
        return curvature, 0.0, 0.0, 0.0
    return (
        curvature,
        float(np.min(in_graph)),
        float(np.max(in_graph)),
        float(np.mean(in_graph)),
    )


def compute_orc_curvature(directed_graph, wallet_list, wallet_to_idx):
    """
    Per-wallet Ollivier-Ricci curvature (ORC) of the co-trading graph.

    The directed graph is collapsed to an undirected one. Self-loops are
    dropped, and when both A->B and B->A exist the edge weight is their
    average. Curvature is computed with GraphRicciCurvature's
    ``OllivierRicci`` (alpha=0.5, proc=1). Each wallet's value is the node
    ``ricciCurvature`` attribute the library writes.

    Args:
        directed_graph: wallet_a -> {wallet_b: weight}.
        wallet_list: wallets; its length sets the output size.
        wallet_to_idx: wallet -> position in the output vector.

    Returns:
        (orc_vec, orc_min, orc_max, orc_mean): orc_vec is float32 of length
        len(wallet_list) (0.0 for wallets not in the graph). The stats are
        min/max/mean over wallets present in the graph. Everything is zero
        when the library is missing, the graph has no edges, or the
        computation raises.
    """
    import multiprocessing as mp

    orc_vec = np.zeros(len(wallet_list), dtype=np.float32)

    # The library asks multiprocessing for a 'fork' context, which Windows
    # lacks; redirect it to 'spawn' only for the duration of this call and
    # always restore the original, so repeated calls never stack wrappers or
    # leak the patch into unrelated code.
    original_get_context = mp.get_context

    def _fork_as_spawn(method=None):
        return original_get_context("spawn" if method == "fork" else method)

    mp.get_context = _fork_as_spawn
    try:
        try:
            from GraphRicciCurvature.OllivierRicci import OllivierRicci
        except ImportError:
            print("GraphRicciCurvature not installed.")
            return orc_vec, 0.0, 0.0, 0.0

        # undirected graph, averaging reciprocal weights
        G = nx.Graph()
        for w_a, neighbors in directed_graph.items():
            for w_b, weight in neighbors.items():
                if w_a == w_b:
                    continue
                w_val = float(weight)
                if G.has_edge(w_a, w_b):
                    G[w_a][w_b]["weight"] = 0.5 * (G[w_a][w_b]["weight"] + w_val)
                else:
                    G.add_edge(w_a, w_b, weight=w_val)

        if G.number_of_edges() == 0:
            return orc_vec, 0.0, 0.0, 0.0

        try:
            orc = OllivierRicci(G, alpha=0.5, verbose="ERROR", proc=1)
            orc.compute_ricci_curvature()

            # OllivierRicci works on its own copy of the input graph (orc.G);
            # the curvature attributes are written there, not onto G.
            curved_graph = getattr(orc, "G", G)
            node_orc = {n: data.get("ricciCurvature", 0.0) for n, data in curved_graph.nodes(data=True)}
            for w, idx in wallet_to_idx.items():
                orc_vec[idx] = float(node_orc.get(w, 0.0))

            active_orcs = [orc_vec[wallet_to_idx[w]] for w in G.nodes() if w in wallet_to_idx]
            if active_orcs:
                orc_min = float(np.min(active_orcs))
                orc_max = float(np.max(active_orcs))
                orc_mean = float(np.mean(active_orcs))
            else:
                orc_min, orc_max, orc_mean = 0.0, 0.0, 0.0

        except Exception as e:
            # Best-effort feature: report and fall back to all zeros.
            print(f"ORC computation failed: {e}")
            orc_vec[:] = 0.0
            orc_min, orc_max, orc_mean = 0.0, 0.0, 0.0

        return orc_vec, orc_min, orc_max, orc_mean
    finally:
        mp.get_context = original_get_context


# Signal that the helper cell ran to completion and its definitions are in scope.
print("Helper functions defined.")


In [None]:
# Analysis function

def analyze_series(series_id, df_trades, verbose=False):
    series_title = series_id_to_title.get(series_id, series_id)
    
    # filter to this series
    df_series = df_trades[df_trades["series_id"] == series_id].copy()
    
    if len(df_series) < 100:
        if verbose:
            print(f"Skipping {series_title} only {len(df_series)} trades")
        return None
    
    df_series = df_series.sort_values("trade_time").reset_index(drop=True)
    
    # get wallets with enough trades
    wallet_counts = df_series.groupby("proxy_wallet").size()
    eligible_wallets = set(wallet_counts[wallet_counts >= MIN_TRADES_PER_WALLET].index)
    
    if len(eligible_wallets) < MIN_WALLETS_IN_SERIES:
        if verbose:
            print(f"Skipping {series_title}: only {len(eligible_wallets)} eligible wallets")
        return None
    
    df_series = df_series[df_series["proxy_wallet"].isin(eligible_wallets)].copy()
    
    # make mappings
    wallet_list = sorted(eligible_wallets)
    wallet_to_idx = {w: i for i, w in enumerate(wallet_list)}
    num_wallets = len(wallet_list)
    
    # make time grid
    unique_times = pd.Index(df_series["trade_time"].sort_values().unique())
    num_grid = min(200, len(unique_times))
    
    if num_grid < 2 * NUM_CHUNKS:
        if verbose:
            print(f"Skipping {series_title}: not enough time points ({num_grid})")
        return None
    
    grid_indices = np.linspace(0, len(unique_times) - 1, num_grid).astype(int)
    grid_times = unique_times[grid_indices]
    
    # token price paths
    df_prices = df_series[["trade_time", "asset", "price"]].copy()
    df_prices["price"] = df_prices["price"].astype(float)
    
    token_price_paths = {}
    for token_id, g in df_prices.groupby("asset", sort=False):
        g = g.sort_values("trade_time")
        g_token = g[["trade_time", "price"]].groupby("trade_time")["price"].last().to_frame()
        token_price_paths[token_id] = g_token
    
    token_list = sorted(token_price_paths.keys())
    token_to_idx = {tok: i for i, tok in enumerate(token_list)}
    num_tokens = len(token_list)
    
    # price grid
    def prices_on_grid(df_token, grid_times):
        tok_times = df_token.index.to_numpy()
        tok_prices = df_token["price"].to_numpy()
        if tok_times.size == 0:
            return np.zeros(len(grid_times), dtype=np.float32)
        idx = np.searchsorted(tok_times, grid_times, side="right") - 1
        idx[idx < 0] = 0
        idx = np.clip(idx, 0, tok_times.size - 1)
        return tok_prices[idx].astype(np.float32)
    
    price_grid = np.zeros((num_tokens, num_grid), dtype=np.float32)
    for tok, i_tok in token_to_idx.items():
        price_grid[i_tok, :] = prices_on_grid(token_price_paths[tok], grid_times)
    
    # we infer initial holdings here
    df_wallet = df_series[["trade_time", "proxy_wallet", "asset", "side", "size", "price"]].copy()
    df_wallet["size"] = df_wallet["size"].astype(float)
    df_wallet["price"] = df_wallet["price"].astype(float)
    df_wallet = df_wallet.sort_values("trade_time").reset_index(drop=True)
    
    df_wallet["share_change"] = np.where(df_wallet["side"] == "BUY", df_wallet["size"], -df_wallet["size"])
    df_wallet["cash_change"] = np.where(
        df_wallet["side"] == "BUY",
        -df_wallet["size"] * df_wallet["price"],
        df_wallet["size"] * df_wallet["price"]
    )
    
    wallet_initial_shares = defaultdict(dict)
    for (wallet, asset), grp in df_wallet.groupby(["proxy_wallet", "asset"]):
        cum_shares = np.cumsum(grp["share_change"].values)
        min_cum = cum_shares.min()
        wallet_initial_shares[wallet][asset] = max(0.0, -float(min_cum))
    
    wallet_initial_cash = {}
    for wallet, grp in df_wallet.groupby("proxy_wallet"):
        cum_cash = np.cumsum(grp["cash_change"].values)
        min_cash = cum_cash.min()
        wallet_initial_cash[wallet] = max(0.0, -float(min_cash))
    
    # compute portfolio values over time
    V = np.zeros((num_grid, num_wallets), dtype=np.float64)
    
    time_arr = df_wallet["trade_time"].to_numpy()
    wallet_arr = df_wallet["proxy_wallet"].to_numpy()
    asset_arr = df_wallet["asset"].to_numpy()
    side_arr = df_wallet["side"].to_numpy()
    size_arr = df_wallet["size"].to_numpy(dtype=float)
    price_arr = df_wallet["price"].to_numpy(dtype=float)
    n_trades = len(df_wallet)
    
    wallet_cash = {w: float(wallet_initial_cash.get(w, 0.0)) for w in wallet_list}
    wallet_pos = {w: dict(wallet_initial_shares.get(w, {})) for w in wallet_list}
    
    def compute_wallet_value(wallet, k):
        val = wallet_cash[wallet]
        pos = wallet_pos[wallet]
        for asset, shares in pos.items():
            if shares == 0.0:
                continue
            tok_idx = token_to_idx.get(asset)
            if tok_idx is None:
                continue
            val += shares * price_grid[tok_idx, k]
        return val
    
    trade_idx = 0
    active_wallets_set = set()
    
    for k, t in enumerate(grid_times):
        if k > 0:
            V[k, :] = V[k - 1, :]
        
        while trade_idx < n_trades and time_arr[trade_idx] <= t:
            w = wallet_arr[trade_idx]
            a = asset_arr[trade_idx]
            s = side_arr[trade_idx]
            q = size_arr[trade_idx]
            p = price_arr[trade_idx]
            
            if w not in wallet_to_idx:
                trade_idx += 1
                continue
            
            if a not in wallet_pos[w]:
                wallet_pos[w][a] = 0.0
            
            if s == "BUY":
                wallet_pos[w][a] += q
                wallet_cash[w] -= q * p
            else:
                wallet_pos[w][a] -= q
                wallet_cash[w] += q * p
            
            if wallet_pos[w][a] < 0 and wallet_pos[w][a] > -1e-9:
                wallet_pos[w][a] = 0.0
            if wallet_cash[w] < 0 and wallet_cash[w] > -1e-9:
                wallet_cash[w] = 0.0
            
            active_wallets_set.add(w)
            trade_idx += 1
        
        for w in active_wallets_set:
            j = wallet_to_idx.get(w)
            if j is not None:
                V[k, j] = compute_wallet_value(w, k)
        
        active_wallets_set.clear()
    
    # normalize to unit value paths and compute returns
    V_unit = V.copy()
    for j in range(num_wallets):
        col = V[:, j]
        positive_idx = np.where(col > 0)[0]
        if len(positive_idx) > 0:
            start_val = col[positive_idx[0]]
            if start_val > 0:
                V_unit[:, j] = col / start_val
    
    R_wallet = np.zeros((num_grid, num_wallets), dtype=np.float64)
    V_prev = V_unit[:-1, :]
    V_curr = V_unit[1:, :]
    mask = V_prev > 0
    delta = V_curr - V_prev
    R_step = np.zeros_like(V_prev)
    np.divide(delta, V_prev, out=R_step, where=mask)
    R_wallet[1:, :] = R_step
    
    # chunk indices
    chunk_edges = np.linspace(0, num_grid, NUM_CHUNKS + 1, dtype=int)
    
    def get_chunk_indices(k):
        return int(chunk_edges[k - 1]), int(chunk_edges[k])
    
    # cut off for features/graph: end of chunk NUM_CHUNKS-2
    cut_start, cut_end = get_chunk_indices(NUM_CHUNKS - 2)
    cut_idx = cut_end - 1
    t_cut = grid_times[cut_idx]
    
    # make the feats up to t_cut
    r_up_to_cut = R_wallet[1:cut_end, :]
    ret_mean = r_up_to_cut.mean(axis=0).astype(np.float32)
    ret_std = r_up_to_cut.std(axis=0).astype(np.float32)
    
    # trade activity features
    df_wallet_tcut = df_wallet[df_wallet["trade_time"] <= t_cut]
    
    def reindex_to_wallets(series, default=0.0):
        return series.reindex(wallet_list).fillna(default).astype(np.float32).to_numpy()
    
    if not df_wallet_tcut.empty:
        agg_time = df_wallet_tcut.groupby("proxy_wallet")["trade_time"].agg(["count", "min", "max"])
        horizon_hours = (agg_time["max"] - agg_time["min"]).dt.total_seconds() / 3600.0
        horizon_hours = horizon_hours.replace(0, 1.0 / 3600.0)
        trades_per_hour = agg_time["count"] / horizon_hours
        log_trades_per_hour = np.log1p(trades_per_hour)
    else:
        log_trades_per_hour = pd.Series([], dtype=float)
    
    log_trades = reindex_to_wallets(log_trades_per_hour)
    avg_size = reindex_to_wallets(df_wallet_tcut.groupby("proxy_wallet")["size"].mean())
    avg_price = reindex_to_wallets(df_wallet_tcut.groupby("proxy_wallet")["price"].mean())
    
    # max drawdown calc
    #
    max_drawdown = np.zeros(num_wallets, dtype=np.float32)
    for j in range(num_wallets):
        v = V_unit[:cut_end, j]
        if np.any(v > 0):
            running_max = np.maximum.accumulate(v)
            with np.errstate(divide="ignore", invalid="ignore"):
                dd = np.where(running_max > 0, v / running_max - 1.0, 0.0)
            max_drawdown[j] = float(-dd.min()) if dd.min() < 0 else 0.0
    
    # make co trading graph
    df_graph = df_series[df_series["trade_time"] <= t_cut].copy()
    directed_graph, deg_in, deg_out, edge_index, num_edges = build_co_trading_graph(
        df_graph, wallet_list, wallet_to_idx
    )
    
    if num_edges < MIN_EDGES_IN_GRAPH:
        if verbose:
            print(f"Skipping {series_title} since it has only {num_edges} edges in graph")
        return None
    
    # compute curvatures
    local_curv_vec = np.zeros(num_wallets, dtype=np.float32)
    orc_curv_vec = np.zeros(num_wallets, dtype=np.float32)
    curv_min, curv_max, curv_mean = 0.0, 0.0, 0.0
    orc_min, orc_max, orc_mean = 0.0, 0.0, 0.0
    
    if feature_flags.get("use_local_curvature", True):
        local_curv_vec, curv_min, curv_max, curv_mean = compute_local_curvature(
            directed_graph, wallet_list, wallet_to_idx
        )
    
    if feature_flags.get("use_orc_curvature", False):
        orc_curv_vec, orc_min, orc_max, orc_mean = compute_orc_curvature(
            directed_graph, wallet_list, wallet_to_idx
        )
    
    # assemble base feats
    feature_list_base = []
    
    if feature_flags.get("use_return_mean_std", True):
        feature_list_base.append(ret_mean[:, None])
        feature_list_base.append(ret_std[:, None])
    
    if feature_flags.get("use_log_trades_per_hour", True):
        feature_list_base.append(log_trades[:, None])
    
    if feature_flags.get("use_avg_trade_size", True):
        feature_list_base.append(avg_size[:, None])
    
    if feature_flags.get("use_avg_trade_price", True):
        feature_list_base.append(avg_price[:, None])
    
    if feature_flags.get("use_max_drawdown_pct", True):
        feature_list_base.append(max_drawdown[:, None])
    
    if feature_flags.get("use_degree_out", True):
        feature_list_base.append(deg_out[:, None])
    
    if feature_flags.get("use_degree_in", True):
        feature_list_base.append(deg_in[:, None])

        # lender ratio
    leader_ratio = deg_out / (1 + deg_in)
    
    if feature_flags.get("use_leader_ratio", True):
        feature_list_base.append(leader_ratio[:, None])
    
    if len(feature_list_base) == 0:
        feature_list_base = [ret_mean[:, None], ret_std[:, None]]
    
    X_base = np.concatenate(feature_list_base, axis=1).astype(np.float32)
    
    # assemble curve feats
    feature_list_curv = [X_base]
    
    if feature_flags.get("use_local_curvature", True):
        feature_list_curv.append(local_curv_vec[:, None])
    
    if feature_flags.get("use_orc_curvature", False):
        feature_list_curv.append(orc_curv_vec[:, None])
    
    X_with_curv = np.concatenate(feature_list_curv, axis=1).astype(np.float32)
    
    # active wallets filtration
    deg_total = deg_in + deg_out
    active_mask = deg_total >= 1
    active_idx = np.where(active_mask)[0]
    
    if active_idx.size < 10:
        if verbose:
            print(f"Skipping {series_title} since it has only {active_idx.size} active wallets")
        return None
    
    X_no_curv_active = X_base[active_idx]
    X_with_curv_active = X_with_curv[active_idx]
    R_wallet_active = R_wallet[:, active_idx].astype(np.float32)
    
    # edge index for active wallets
    edge_index_np = edge_index.numpy()
    src_full, dst_full = edge_index_np[0], edge_index_np[1]
    edge_keep_mask = active_mask[src_full] & active_mask[dst_full]
    src_kept = src_full[edge_keep_mask]
    dst_kept = dst_full[edge_keep_mask]
    
    idx_map = -np.ones(num_wallets, dtype=np.int64)
    idx_map[active_idx] = np.arange(active_idx.size, dtype=np.int64)
    
    src_sub = idx_map[src_kept]
    dst_sub = idx_map[dst_kept]
    edge_index_active = torch.tensor(np.vstack([src_sub, dst_sub]), dtype=torch.long, device=device)
    
    # make train/val/test returns
    def get_return_block(k):
        s, e = get_chunk_indices(k)
        if e <= s + 1:
            return None
        return R_wallet_active[s + 1:e, :].astype(np.float32)
    
    train_blocks = []
    for i in range(1, NUM_CHUNKS - 2):
        R_k = get_return_block(i + 1)
        if R_k is not None:
            train_blocks.append(R_k)
    
    if not train_blocks:
        if verbose:
            print(f"Skipping {series_title} with no training blocks")
        return None
    
    R_train = np.concatenate(train_blocks, axis=0).astype(np.float32)
    R_val = get_return_block(NUM_CHUNKS - 1)
    R_test = get_return_block(NUM_CHUNKS)
    
    if R_val is None or R_test is None or len(R_train) < 5 or len(R_val) < 5 or len(R_test) < 5:
        if verbose:
            print(f"Skipping {series_title} ad there is insufficient return data")
        return None
    
    # train models
    def train_model(X_np, label):
        """Train one GNNPortfolioModel on the given node features and score it.

        Fits the model by minimising ``sharpe_loss`` on the training returns
        (full-batch, one step per epoch), early-stops on the validation loss,
        restores the parameters from the best validation epoch, then evaluates
        the resulting weights on the validation and test windows.

        Args:
            X_np: numpy node-feature matrix; ``X_np.shape[1]`` becomes the
                model's input dimension. Rows presumably align with the nodes
                of ``edge_index_active`` — TODO confirm.
            label: tag for the feature variant (callers pass "no_curv" /
                "with_curv"); not used inside this function.

        Returns:
            Tuple ``(val_sharpe, test_sharpe, weights)``: the Sharpe values
            (Python floats) reported by ``sharpe_loss`` on the validation and
            test windows, and the final model weights as a numpy array.
        """
        # move features and the three return windows to the training device once
        x_all = torch.from_numpy(X_np).float().to(device)
        R_train_t = torch.from_numpy(R_train).float().to(device)
        R_val_t = torch.from_numpy(R_val).float().to(device)
        R_test_t = torch.from_numpy(R_test).float().to(device)

        # initialize the model
        in_dim = x_all.size(1)
        model = GNNPortfolioModel(in_dim=in_dim, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

        best_val_loss = float("inf")
        best_state = None
        epochs_since_best = 0

        for _ in range(NUM_EPOCHS):
            # one full-batch gradient step on the training window
            model.train()
            optimizer.zero_grad()
            weights, _ = model(x_all, edge_index_active)
            loss, _, _, _ = sharpe_loss(weights, R_train_t)
            loss.backward()
            optimizer.step()

            # score the freshly updated parameters on the validation window
            model.eval()
            with torch.no_grad():
                weights_val, _ = model(x_all, edge_index_active)
                loss_val, _, _, _ = sharpe_loss(weights_val, R_val_t)
            val_loss = loss_val.item()

            # 1e-6 tolerance so tiny fluctuations don't reset the patience counter
            if val_loss < best_val_loss - 1e-6:
                best_val_loss = val_loss
                # state_dict() returns references to the live parameter tensors,
                # which optimizer.step() keeps updating in place; clone them so the
                # snapshot really holds this epoch's weights.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                epochs_since_best = 0
            else:
                epochs_since_best += 1

            if epochs_since_best >= EARLY_STOP_PATIENCE:
                break

        # restore the best-validation parameters before final scoring
        if best_state is not None:
            model.load_state_dict(best_state)

        model.eval()
        with torch.no_grad():
            final_weights, _ = model(x_all, edge_index_active)
            _, val_sharpe, _, _ = sharpe_loss(final_weights, R_val_t)
            _, test_sharpe, _, _ = sharpe_loss(final_weights, R_test_t)

        # sharpe and weights return
        return val_sharpe.item(), test_sharpe.item(), final_weights.cpu().numpy()
    
    # train without curvature
    val_sharpe_no_curv, test_sharpe_no_curv, weights_no_curv = train_model(X_no_curv_active, "no_curv")
    
    # train with curvature
    val_sharpe_with_curv, test_sharpe_with_curv, weights_with_curv = train_model(X_with_curv_active, "with_curv")
    
    # build portfolio returns for test only 
    test_start, test_end = get_chunk_indices(NUM_CHUNKS)
    R_test_segment = R_wallet_active[test_start:test_end, :]
    test_times = grid_times[test_start:test_end]
    test_len = test_end - test_start
    
    port_ret_test_no = R_test_segment @ weights_no_curv
    port_ret_test_curv = R_test_segment @ weights_with_curv
    
    cum_returns_test_no = np.zeros(test_len)
    cum_returns_test_curv = np.zeros(test_len)
    cum_returns_test_no[0] = 1.0
    cum_returns_test_curv[0] = 1.0
    for t in range(1, test_len):
        cum_returns_test_no[t] = cum_returns_test_no[t-1] * (1 + port_ret_test_no[t])
        cum_returns_test_curv[t] = cum_returns_test_curv[t-1] * (1 + port_ret_test_curv[t])
    
    return {
        "series_id": series_id,
        "series_title": series_title,
        "num_trades": len(df_series),
        "num_wallets": num_wallets,
        "num_active_wallets": active_idx.size,
        "num_edges": num_edges,
        "num_features_base": X_base.shape[1],
        "num_features_curv": X_with_curv.shape[1],
        "leader_ratio_mean": float(np.mean(leader_ratio)),
        # curve stats local
        "local_curv_min": curv_min,
        "local_curv_max": curv_max,
        "local_curv_mean": curv_mean,
        # orc curvature stats (if enabled)
        "orc_curv_min": orc_min,
        "orc_curv_max": orc_max,
        "orc_curv_mean": orc_mean,
        # performance
        "val_sharpe_no_curv": val_sharpe_no_curv,
        "val_sharpe_with_curv": val_sharpe_with_curv,
        "test_sharpe_no_curv": test_sharpe_no_curv,
        "test_sharpe_with_curv": test_sharpe_with_curv,
        "curv_improvement_val": val_sharpe_with_curv - val_sharpe_no_curv,
        "curv_improvement_test": test_sharpe_with_curv - test_sharpe_no_curv,
        # returns 
        "test_times": test_times,
        "cum_returns_no_curv": cum_returns_test_no,
        "cum_returns_with_curv": cum_returns_test_curv,
        "final_return_no_curv": cum_returns_test_no[-1] - 1,
        "final_return_with_curv": cum_returns_test_curv[-1] - 1,
        # per wallet data
        "weights_no_curv": weights_no_curv,
        "weights_with_curv": weights_with_curv,
        "leader_ratio_active": leader_ratio[active_idx]
    }


# confirm in the cell output that analyze_series is now in scope
print("Per-series analysis function defined.", end="\n")


In [None]:
# Run analysis on all valid series.
# Each series is analysed independently: a failure in one series is reported
# and recorded in `failed_series` so that the remaining series still run.

results = []
failed_series = []

print(f"Analyzing {len(valid_series)} series")

for i, series_id in enumerate(tqdm(valid_series, desc="Analyzing series")):
    # fall back to the raw id when no title is known; str() because the id may
    # not be a string and the title is sliced in the error message below
    series_title = str(series_id_to_title.get(series_id, series_id))

    try:
        result = analyze_series(series_id, df_trades, verbose=False)
    except Exception as e:
        # tqdm.write keeps the progress bar intact; the full error is kept for inspection
        tqdm.write(f"\n[{i+1}/{len(valid_series)}] {series_title[:40]} with the following error: - {str(e)[:50]}")
        failed_series.append({"series_id": series_id, "series_title": series_title, "error": repr(e)})
        continue

    # analyze_series returns None for series it skips
    if result is not None:
        results.append(result)

print(f"\n\nCompleted analysis for {len(results)} series.")
if failed_series:
    print(f"{len(failed_series)} series failed; full errors are in `failed_series`.")


In [None]:
# Convert the per-series results to a DataFrame and summarise them

df_results = pd.DataFrame(results)

# Columns holding per-series arrays (return curves, per-wallet values); they are
# left out of the tabular display at the end of this cell to keep it readable.
array_cols = [
    "test_times",
    "cum_returns_no_curv",
    "cum_returns_with_curv",
    "weights_no_curv",
    "weights_with_curv",
    "leader_ratio_active",
]

if df_results.empty:
    print("No results to analyze.")
else:
    n_series = len(df_results)
    print("RESULTS")

    print(f"\nSeries analyzed: {n_series}")
    print(f"\nFlags used:")
    for flag, val in feature_flags.items():
        print(f"  {flag}: {val}")

    local_curv = df_results['local_curv_mean']
    print(f"\nOverall stats with local curvature:")
    print(f"\tstd curvature over series: {local_curv.std():.4f}")
    print(f"\tmean curvature over series: {local_curv.mean():.4f}")
    print(f"\tMin curvature: {local_curv.min():.4f}")
    print(f"\tMax curvature: {local_curv.max():.4f}")

    if feature_flags.get("use_orc_curvature", False):
        print(f"\nOverall stats with ORC curvature:")
        print(f"\tMean ORC curvature over series: {df_results['orc_curv_mean'].mean():.4f}")
        print(f"\tStd ORC curvature over series: {df_results['orc_curv_mean'].std():.4f}")

    # own header so these lines don't read as part of the ORC section above
    print("\nSharpe improvement from curvature:")
    print(f"\tmean val Sharpe improvement from curvature: {df_results['curv_improvement_val'].mean():.4f}")
    print(f"\tmean test Sharpe improvement from curvature: {df_results['curv_improvement_test'].mean():.4f}")

    # How many series benefit from curvature?
    n_val_improved = (df_results['curv_improvement_val'] > 0).sum()
    n_test_improved = (df_results['curv_improvement_test'] > 0).sum()
    print(f"\nseries where curvature improved val Sharpe: {n_val_improved} / {n_series} which is {100*n_val_improved / n_series:.3f}%")
    print(f"\nseries where curvature improved test Sharpe: {n_test_improved} / {n_series} which is {100*n_test_improved / n_series:.3f}%")

    # A bare `df_results` here would not be rendered: Jupyter only auto-displays
    # the last top-level expression of a cell, and this line is inside `else:`.
    display(df_results.drop(columns=array_cols, errors="ignore"))


In [None]:
# summary table and leader ratio vs curvature:
# top series by mean leader ratio, then a scatter of leader ratio against mean
# local curvature across all analysed series.

TOP_N_SUMMARY = 10

if df_results.empty:
    # pd.DataFrame([]) has no columns, so the column selections below would raise KeyError
    print("No results to summarize.")
else:
    print("Top Markets by Leader Ratio")
    summary_cols = ['series_title', 'leader_ratio_mean', 'local_curv_mean']
    summary_df = (
        df_results[summary_cols]
        .rename(columns={
            'series_title': 'Market',
            'leader_ratio_mean': 'Leader Ratio',
            'local_curv_mean': 'Curvature',
        })
        .sort_values('Leader Ratio', ascending=False)
        .head(TOP_N_SUMMARY)
    )

    display(summary_df.reset_index(drop=True))

    # Pearson correlation; NaN when fewer than two series are available
    corr_leader_curv = df_results['leader_ratio_mean'].corr(df_results['local_curv_mean'])
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(df_results['leader_ratio_mean'], df_results['local_curv_mean'], alpha=0.6)
    ax.set_xlabel('Leader Ratio')
    ax.set_ylabel('Curvature')
    ax.set_title(f'Leader Ratio vs Curvature with r = {corr_leader_curv:.3f}')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
