In [None]:
# Standard library
import datetime as dt
import gc
import glob
import json
import os
import re
import subprocess
import math
import sys
from datetime import date, datetime, time as dtime
from pathlib import Path

from typing import Iterable, Iterator

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Third-party
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
import numpy as np
import pandas as pd
import polars as pl
import pyarrow.dataset as ds
import requests
from scipy.stats import expon, gaussian_kde, kstest
import networkx as nx

In [None]:
# ============================
# Address-to-Address Flow Graph (Top-K SOURCES → Top-M DESTS)
# MEMORY-SAFE + DIAGNOSTICS + PATTERNS AWARE + PREFIX FIXES
# ============================

# ---------- knobs ----------
TOPK_SRC           = 25                       # Top spenders we trace as SOURCES
TOPM_DST           = 150                      # Keep top-M destinations overall (after reduce)
START_DAY          = "2013-01-01"             # str or datetime.date; we coerce below
MAX_TX_OUTPAIRS    = 400                      # skip pathological tx (sources*destinations)
SHOW_EDGES         = 150                      # edges to draw after pruning
FINAL_MIN_EDGE_BTC = 0.5                      # filter AFTER aggregation (visual clarity)
TX_MIN_EDGE_BTC    = 0.0                      # per-transaction edge cutoff BEFORE aggregation
MAX_DST_PER_TX     = 20                       # per spend tx, keep at most top-N dest outputs
CHUNK_SIZE         = int(os.getenv("CHUNK_SIZE", "200"))   # io files scanned per batch

# On-disk scratch (external drive recommended).
# Override with the FLOWS_WORKDIR env var instead of editing this machine-specific default.
WORKDIR            = Path(os.getenv("FLOWS_WORKDIR", "/media/vatereal/Main/flows_tmp"))
OUTPOINTS_DS       = WORKDIR / "top_outpoints_ds"     # (txid,n)->(addr,value) for TOP SOURCES
EDGES_DS           = WORKDIR / "edges_fragments"      # edge fragments to be reduced later
PFX_LEN            = 3                                 # txid prefix length for partitioning (3–4 ok)

WORKDIR.mkdir(parents=True, exist_ok=True)
OUTPOINTS_DS.mkdir(parents=True, exist_ok=True)
EDGES_DS.mkdir(parents=True, exist_ok=True)

# ---------- config helpers ----------
def _coerce_start_day(x) -> date:
    if isinstance(x, date):
        return x
    if isinstance(x, str):
        return date.fromisoformat(x)
    return date(2013, 1, 1)

SAFER_START: date = _coerce_start_day(START_DAY)

# VALUE_UNITS controls btc conversion; if not pre-defined, read it from the env (default sats)
try:
    VALUE_UNITS  # may be defined earlier
except NameError:
    VALUE_UNITS = os.getenv("VALUE_UNITS", "sats").lower()

try:
    value_btc_expr  # may be defined earlier
except NameError:
    # Normalize here as well: a VALUE_UNITS defined in an earlier cell (e.g. "BTC")
    # skips the .lower() above and would otherwise silently be treated as satoshis.
    _units = str(VALUE_UNITS).strip().lower()
    if _units == "btc":
        value_btc_expr = pl.col("value").cast(pl.Float64)
    else:
        if _units not in ("sats", "sat", "satoshis"):
            print(f"[config] unrecognized VALUE_UNITS={VALUE_UNITS!r}; assuming satoshis")
        value_btc_expr = (pl.col("value").cast(pl.Float64) / 1e8)
    del _units

# Basic address normalizer; reuse if provided
try:
    _addr_norm_py
except NameError:
    def _addr_norm_py(x):
        if x is None:
            return None
        s = str(x).strip()
        return s if s else None

# `patterns` (dataset name -> glob) normally comes from an earlier cell; otherwise build
# it under PARQUET_DIR (env var, default: current directory).
try:
    patterns
except NameError:
    PARQUET_DIR = Path(os.getenv("PARQUET_DIR", "."))
    patterns = dict(
        blocks=str(PARQUET_DIR / "blocks/day=*/blocks-*.parquet"),
        txs=str(PARQUET_DIR / "txs/day=*/txs-*.parquet"),
        io=str(PARQUET_DIR / "io/day=*/io-*.parquet"),
    )

# ---------- utilities ----------
DAY_RE = re.compile(r"[\/\\]day=(\d{4}-\d{2}-\d{2})[\/\\]")

def _day_from_path(p: str) -> date | None:
    m = DAY_RE.search(p)
    if not m:
        return None
    try:
        return date.fromisoformat(m.group(1))
    except Exception:
        return None

def list_io_files_since(since_day: date) -> list[str]:
    """Sorted ``patterns["io"]`` matches whose ``day=`` partition is on/after ``since_day``.

    Paths without a parseable day partition are skipped.
    """
    selected: list[str] = []
    for path in glob.glob(patterns["io"]):
        day = _day_from_path(path)
        if day is not None and day >= since_day:
            selected.append(path)
    return sorted(selected)

def chunked_paths(paths: Iterable[str], n: int) -> Iterator[list[str]]:
    """Yield consecutive lists of up to ``n`` items from ``paths``; the last list may be shorter."""
    current: list[str] = []
    for item in paths:
        current.append(item)
        if len(current) < n:
            continue
        yield current
        current = []
    if current:
        yield current

# Column aliasing (if your schema varies): alternative column name -> canonical name.
# A rename is applied only when the alternative exists and the canonical name does not.
ALIASES = dict(
    addr="address",
    val="value",
    tx_id="txid",
    prev_tx="prev_txid",
    prev_index="prev_vout",
    vout="n",
)

# Columns every scan should expose; missing ones are added as all-null columns.
NEEDED_ANYWAY = [
    "dir", "txid", "n", "address", "value",
    "prev_txid", "prev_vout", "height", "time",
]

def scan_with_override(paths: list[str]) -> pl.LazyFrame:
    """Lazily scan parquet ``paths`` with a normalized column set.

    Applies ALIASES renames (only when the canonical name is absent), then adds every
    NEEDED_ANYWAY column that is still missing as an all-null literal column.
    """
    lf = pl.scan_parquet(paths)
    existing = lf.collect_schema().names()
    rename_map = {}
    for old, new in ALIASES.items():
        if old in existing and new not in existing:
            rename_map[old] = new
    if rename_map:
        lf = lf.rename(rename_map)
    present = set(lf.collect_schema().names())
    absent = [name for name in NEEDED_ANYWAY if name not in present]
    if absent:
        lf = lf.with_columns([pl.lit(None).alias(name) for name in absent])
    return lf

# Small helper: flag obviously bad address tokens.
def _is_bad_addr_expr(col: str) -> pl.Expr:
    """Boolean expression: True where ``col`` is null, "" or the text "null"/"none" (any case)."""
    raw = pl.col(col)
    lowered = raw.cast(pl.Utf8, strict=False).str.to_lowercase()
    return raw.is_null() | (raw == "") | (lowered == "null") | (lowered == "none")

# ---- PREFIX HELPERS (normalize to lowercase) ----
def _prefix_expr(col="txid", n=PFX_LEN) -> pl.Expr:
    """Expression: first ``n`` characters of ``col`` (cast to string, lowercased), named "pfx"."""
    as_text = pl.col(col).cast(pl.Utf8, strict=False)
    return as_text.str.to_lowercase().str.slice(0, n).alias("pfx")

def _txid_prefix_expr(col="txid", n=PFX_LEN) -> pl.Expr:
    """"pfx" expression for a transaction's own txid column."""
    return _prefix_expr(col=col, n=n)

def _prevtxid_prefix_expr(col="prev_txid", n=PFX_LEN) -> pl.Expr:
    """"pfx" expression for the txid referenced by an input (prev_txid)."""
    return _prefix_expr(col=col, n=n)

def _normalize_pfx_key(key) -> str:
    """Turn a polars group key into a lowercase prefix of at most PFX_LEN characters ("" for None)."""
    # Polars group_by iteration may yield ('abc',) for a single key column.
    if isinstance(key, tuple) and len(key) == 1:
        (key,) = key
    text = "" if key is None else str(key)
    return text.lower()[:PFX_LEN]

# ---------- diagnostics ----------
def dataset_smoke_check(max_files: int = 8):
    """Print a quick health report of the parquet datasets.

    Prints per-dataset file counts, then, for up to ``max_files`` io files at/after
    SAFER_START: the columns after aliasing, row counts per ``dir`` value, and the
    min/max of ``time`` cast to Datetime. The time-range step is best-effort; errors
    are printed, not raised.
    """
    file_counts = {name: len(glob.glob(pattern)) for name, pattern in patterns.items()}
    print("[smoke] file counts:", file_counts)

    io_paths = list_io_files_since(SAFER_START)
    if not io_paths:
        print("[smoke] no io files at/after", SAFER_START)
        return

    sample = io_paths[:max_files]
    lf = scan_with_override(sample)
    cols = lf.collect_schema().names()
    print("[smoke] checking up to", len(sample), "io files...")
    print("[smoke] columns after alias:", cols)

    if "dir" in cols:
        dir_counts = (
            lf.select("dir")
              .with_columns(pl.col("dir").cast(pl.Utf8, strict=False))
              .group_by("dir")
              .len()
              .sort("len", descending=True)
              .collect()
        )
        print("[smoke] dir counts:\n", dir_counts)

    # NOTE(review): a numeric epoch `time` column cast to Datetime is interpreted by
    # polars' default time unit, not seconds — check the printed range for plausibility.
    as_time = pl.col("time").cast(pl.Datetime, strict=False)
    try:
        bounds = lf.select(as_time.min().alias("min"), as_time.max().alias("max")).collect()
        print("[smoke] time range:\n", bounds)
    except Exception as e:
        print("[smoke] time range skipped:", repr(e))

# ---------- index prefix discovery (handles both naming styles) ----------
def _parse_pfx_dirname(base: str) -> str | None:
    # base is like "pfx=abc" OR "pfx=('abc',)"
    if not base.startswith("pfx="):
        return None
    val = base[4:]
    if val.startswith("('") and val.endswith("',)"):
        p = val[2:-3]
    else:
        p = val
    if not p:
        return None
    return p.lower()

def list_index_prefixes() -> set[str]:
    """Lowercase txid prefixes that have a ``pfx=…`` folder (either naming style) under OUTPOINTS_DS."""
    candidates = (
        _parse_pfx_dirname(os.path.basename(folder))
        for folder in glob.glob(str(OUTPOINTS_DS / "pfx=*"))
    )
    return {p for p in candidates if p}

# ============================
# 0) Top-K **sources** by “spent outputs”
# ============================
def topk_sources_chunked(k: int = TOPK_SRC) -> list[str]:
    """Return the (approximate) Top-K addresses ranked by number of spent outputs.

    For each batch of CHUNK_SIZE io files since SAFER_START, "out" rows
    (txid, n, normalized address) are joined with the prevout references
    (prev_txid, prev_vout) of "in" rows. The match count per address is the number
    of that address's outputs spent within the batch.

    Approximation caveats:
      - a spend is counted only when its funding output is in the same batch;
      - each batch keeps only its top 5,000 addresses and the running total keeps the
        top 20,000, so long-tail addresses can be dropped early.

    Displays the ranking table and returns up to ``k`` addresses ([] if nothing matched).
    """
    files = list_io_files_since(SAFER_START)
    if not files:
        print("[topk] no io files found via patterns; check `patterns['io']` and START_DAY.")
        return []
    partial = None

    for paths_chunk in chunked_paths(files, CHUNK_SIZE):
        lf = scan_with_override(paths_chunk)
        have = lf.collect_schema().names()
        if "dir" not in have:
            continue
        need_out = [c for c in ("dir","txid","n","address","value") if c in have]
        need_in  = [c for c in ("dir","prev_txid","prev_vout") if c in have]

        outs = (
            lf.select(need_out)
              .with_columns([
                  pl.col("dir").cast(pl.Utf8, strict=False),
                  # Join keys must share dtypes with the `ins` side (Utf8 / Int64);
                  # polars refuses to join on mismatched key dtypes (e.g. UInt32 vs Int64).
                  pl.col("txid").cast(pl.Utf8, strict=False),
                  pl.col("n").cast(pl.Int64, strict=False),
                  pl.col("address").map_elements(_addr_norm_py, return_dtype=pl.Utf8).alias("addr_norm"),
                  value_btc_expr.cast(pl.Float32).alias("value_btc"),
              ])
              .filter(pl.col("dir").str.to_lowercase().str.contains("out", literal=True))
              .filter(~_is_bad_addr_expr("addr_norm"))
              .select(["txid","n","addr_norm","value_btc"])
              .collect()
        )
        if not outs.height:
            continue

        ins = (
            lf.select(need_in)
              .with_columns([
                  pl.col("dir").cast(pl.Utf8, strict=False),
                  pl.col("prev_txid").cast(pl.Utf8, strict=False).alias("ptx"),
                  pl.col("prev_vout").cast(pl.Int64, strict=False).alias("pvout"),
              ])
              .filter(pl.col("dir").str.to_lowercase().str.contains("in", literal=True)
                      & pl.col("ptx").is_not_null() & pl.col("pvout").is_not_null())
              .select(["ptx","pvout"])
              .collect()
        )
        if not ins.height:
            continue

        # spent outputs per address within this batch (truncated to the top 5,000)
        joined = (
            ins.join(outs.rename({"txid":"ptx","n":"pvout"}), on=["ptx","pvout"], how="inner")
               .group_by("addr_norm")
               .len()
               .sort("len", descending=True)
               .limit(5000)
        )
        partial = joined if partial is None else (
            pl.concat([partial, joined], how="vertical_relaxed")
              .group_by("addr_norm")
              .agg(pl.col("len").sum().alias("len"))
              .sort("len", descending=True)
              .limit(20_000)
        )
        del outs, ins, joined
        gc.collect()

    if partial is None or not partial.height:
        return []
    tops = (partial.filter(~_is_bad_addr_expr("addr_norm"))
                   .sort("len", descending=True)
                   .head(k))
    try:
        display(tops)  # Jupyter rich display
    except NameError:  # plain Python: no `display` builtin
        print(tops)
    return tops["addr_norm"].to_list()


In [None]:
# ============================
# 1) Build ON-DISK index (txid,n -> src_addr,value_btc) for TOP sources
# ============================
def _clear_outpoints_dataset() -> None:
    """Delete every outpoint shard under OUTPOINTS_DS plus any pfx=* folder left empty."""
    for f in glob.glob(str(OUTPOINTS_DS / "**/*.parquet"), recursive=True):
        try:
            os.remove(f)
        except FileNotFoundError:
            # Already gone. Any other OSError propagates: a stale shard that cannot be
            # removed would silently be mixed into the freshly built index.
            pass
    # Empty prefix folders would still be reported by list_index_prefixes() although
    # they contain no rows, so remove those as well.
    for d in glob.glob(str(OUTPOINTS_DS / "pfx=*")):
        try:
            os.rmdir(d)  # succeeds only for an empty directory
        except OSError:
            pass  # folder still holds other content; leave it alone

def build_top_outpoints_dataset(top_sources: list[str]):
    """Build the on-disk (txid, n) -> (addr_norm, value_btc) index for ``top_sources``.

    Previous shards are deleted first. Then, for each batch of CHUNK_SIZE io files since
    SAFER_START, "out" rows whose normalized address is in ``top_sources`` and whose
    BTC value is non-null are written as parquet shards under
    ``OUTPOINTS_DS/pfx=<first PFX_LEN chars of txid, lowercased>/``.

    Raises:
        RuntimeError: if ``top_sources`` is empty.
    """
    if not top_sources:
        raise RuntimeError("No top sources.")
    _clear_outpoints_dataset()

    files = list_io_files_since(SAFER_START)
    wrote = 0
    unique_pfx: set[str] = set()

    for chunk_i, paths_chunk in enumerate(chunked_paths(files, CHUNK_SIZE), start=1):
        lf = scan_with_override(paths_chunk)
        have = lf.collect_schema().names()
        need = [c for c in ("dir","txid","n","address","value") if c in have]
        if not {"dir","txid","n","value"}.issubset(set(have)):
            continue

        outp = (
            lf.select(need)
              .with_columns([
                  pl.col("dir").cast(pl.Utf8, strict=False),
                  # uniform key dtypes across shards (what _load_outpoints_for_prefixes expects)
                  pl.col("txid").cast(pl.Utf8, strict=False),
                  pl.col("n").cast(pl.Int64, strict=False),
                  pl.col("address").map_elements(_addr_norm_py, return_dtype=pl.Utf8).alias("addr_norm"),
                  value_btc_expr.cast(pl.Float32).alias("value_btc"),
                  _txid_prefix_expr("txid", PFX_LEN),
              ])
              .filter(pl.col("dir").str.to_lowercase().str.contains("out", literal=True)
                      & pl.col("value_btc").is_not_null())
              .filter(pl.col("addr_norm").is_in(top_sources) & ~_is_bad_addr_expr("addr_norm"))
              .select(["pfx","txid","n","addr_norm","value_btc"])
              .collect()
        )
        if not outp.height:
            continue

        # robust grouping: normalize key to lower + write to canonical "pfx=<pfx>" dirs
        for grp_i, (key, g) in enumerate(outp.group_by("pfx")):
            pfx = _normalize_pfx_key(key)
            if not pfx:
                continue
            pdir = OUTPOINTS_DS / f"pfx={pfx}"
            pdir.mkdir(parents=True, exist_ok=True)
            # Deterministic, collision-free shard name (batch index + group index);
            # random names could collide and silently overwrite an earlier shard.
            out_path = pdir / f"outpoints-c{chunk_i:05d}-g{grp_i:05d}.parquet"
            g.drop("pfx").write_parquet(out_path)
            wrote += g.height
            unique_pfx.add(pfx)

        del outp
        gc.collect()

    print(f"[index] wrote {wrote:,} outpoint rows for top sources into {OUTPOINTS_DS}")
    print(f"[index] unique pfx folders (canonical): {len(unique_pfx)} (sample: {sorted(unique_pfx)[:12]})")

def _glob_both_styles_for_pfx(pfx: str) -> list[str]:
    """Return all parquet paths for either folder style: pfx=abc OR pfx=('abc',)."""
    key = (pfx or "").lower()
    found: list[str] = []
    # canonical folder first, then the legacy tuple-style one
    for dirname in (f"pfx={key}", "pfx=('{}',)".format(key)):
        folder = OUTPOINTS_DS / dirname
        if folder.is_dir():
            found += glob.glob(str(folder / "*.parquet"))
    return found

def _load_outpoints_for_prefixes(prefixes: list[str]) -> pl.DataFrame:
    """Load the outpoint-index rows stored under the requested txid ``prefixes``.

    Only prefixes that have a folder on disk are read. Returns columns
    txid (Utf8), n (Int64), addr_norm (Utf8) and value_btc (Float32); the frame is
    empty (with that schema) when nothing matches.
    """
    schema = {"txid": pl.Utf8, "n": pl.Int64, "addr_norm": pl.Utf8, "value_btc": pl.Float32}
    # Only try the prefixes we actually have on disk (fast path)
    on_disk = list_index_prefixes()
    requested = {(p or "").lower() for p in (prefixes or [])}
    shard_paths = [
        path
        for p in sorted(requested & on_disk)
        for path in _glob_both_styles_for_pfx(p)
    ]
    if not shard_paths:
        return pl.DataFrame(schema=schema)
    merged = pl.concat([pl.read_parquet(path) for path in shard_paths], how="vertical_relaxed")
    return merged.with_columns(
        [pl.col(name).cast(dtype, strict=False) for name, dtype in schema.items()]
    )


In [None]:
# ============================
# 2) Build flows: Top-K SOURCES → ANY dest (reduce later) — vectorized, faster
# ============================
def _split_tx_value(s_amt: np.ndarray, d_amt: np.ndarray) -> np.ndarray | None:
    """Proportionally split one spend tx's traced value over (source, destination) pairs.

    V = min(sum(s_amt), sum(d_amt)) is spread as
    F[i, j] = V * (s_amt[i] / sum(s_amt)) * (d_amt[j] / sum(d_amt)).
    Returns F with shape (len(s_amt), len(d_amt)), or None if either side sums to <= 0.
    """
    s_total = float(s_amt.sum())
    d_total = float(d_amt.sum())
    if s_total <= 0 or d_total <= 0:
        return None
    v = min(s_total, d_total)
    return v * np.outer(s_amt / s_total, d_amt / d_total)

def build_flows_top_sources_streaming() -> pd.DataFrame:
    """Trace value from the indexed Top-K sources to any destination, then reduce.

    Per batch of CHUNK_SIZE io files since SAFER_START:
      1. "in" rows give (spend_txid, prev_txid, prev_vout) plus the prev_txid prefix;
      2. these are joined with the on-disk outpoint index (prefixes present on disk only)
         to get, per spend tx, the BTC each top source contributed;
      3. the "out" rows of those spend txs (same batch) give BTC per destination;
      4. per spend tx, only the MAX_DST_PER_TX largest outputs are kept; txs with more than
         MAX_TX_OUTPAIRS (src, dst) pairs are skipped; the value is split by
         _split_tx_value and pairs below TX_MIN_EDGE_BTC (if > 0) are dropped;
      5. one edge shard per batch is written to EDGES_DS.
    Finally the shards are summed per (src, dst), the TOPM_DST destinations by inbound BTC
    are kept, and edges below FINAL_MIN_EDGE_BTC are dropped.

    Returns a pandas DataFrame with columns src, dst, btc sorted by btc descending
    (an empty frame with those columns when nothing is found).
    """
    files = list_io_files_since(SAFER_START)
    if not files:
        return pd.DataFrame(columns=["src","dst","btc"])

    # clean old edge fragments
    for f in glob.glob(str(EDGES_DS / "*.parquet")):
        try:
            os.remove(f)
        except FileNotFoundError:
            # Already gone. Other OSErrors propagate: a stale shard that cannot be
            # removed would otherwise be summed into this run's result.
            pass

    have_pfx = list_index_prefixes()
    print(f"[flows] index prefixes available: {len(have_pfx)}")

    total_pairs = 0

    for batch_i, paths_chunk in enumerate(chunked_paths(files, CHUNK_SIZE), start=1):
        lf = scan_with_override(paths_chunk)
        have = lf.collect_schema().names()

        need_in = [c for c in ("dir","txid","prev_txid","prev_vout") if c in have]
        if not need_in:
            print(f"[chunk {batch_i}] no inputs columns → skip")
            continue

        ins = (
            lf.select(need_in)
              .with_columns([
                  pl.col("dir").cast(pl.Utf8, strict=False),
                  pl.col("txid").cast(pl.Utf8, strict=False).alias("spend_txid"),
                  pl.col("prev_txid").cast(pl.Utf8, strict=False).alias("ptx"),
                  pl.col("prev_vout").cast(pl.Int64, strict=False).alias("pvout"),
                  _prevtxid_prefix_expr("prev_txid", PFX_LEN),
              ])
              .filter(pl.col("dir").str.to_lowercase().str.contains("in", literal=True)
                      & pl.col("ptx").is_not_null() & pl.col("pvout").is_not_null())
              .select(["spend_txid","ptx","pvout","pfx"])
              .collect()
        )
        if not ins.height:
            print(f"[chunk {batch_i}] ins=0 → skip")
            continue

        need_pfx_all = {(p or "").lower() for p in ins["pfx"].unique().to_list()}
        need_pfx = sorted(need_pfx_all & have_pfx)
        if not need_pfx:
            print(f"[chunk {batch_i}] none of {len(need_pfx_all)} needed prefixes exist on disk → skip")
            del ins; gc.collect(); continue

        out_idx_df = _load_outpoints_for_prefixes(need_pfx)
        if out_idx_df.is_empty():
            print(f"[chunk {batch_i}] out_idx=0 (requested_prefixes={len(need_pfx)} / raw_needed={len(need_pfx_all)}) → skip")
            del ins, out_idx_df; gc.collect(); continue

        # BTC contributed by each top source to each spend tx
        ins_src = (
            ins.join(
                   out_idx_df.rename({
                       "txid":"ptx","n":"pvout",
                       "addr_norm":"src_addr","value_btc":"src_value_btc"
                   }),
                   on=["ptx","pvout"], how="inner"
               )
               .group_by(["spend_txid","src_addr"])
               .agg(pl.col("src_value_btc").sum().alias("s_amt"))
        )
        if not ins_src.height:
            print(f"[chunk {batch_i}] ins_src=0 after join → skip")
            del ins, out_idx_df; gc.collect(); continue

        spend_txids = ins_src["spend_txid"].unique().to_list()

        # outputs of those spend txs, looked up in this same batch only
        need_out = [c for c in ("dir","txid","address","value") if c in have]
        outs = (
            lf.select(need_out)
              .with_columns([
                  pl.col("dir").cast(pl.Utf8, strict=False),
                  pl.col("txid").cast(pl.Utf8, strict=False),
                  pl.col("address").map_elements(_addr_norm_py, return_dtype=pl.Utf8).alias("dst_addr"),
                  value_btc_expr.cast(pl.Float32).alias("d_amt"),
              ])
              .filter(pl.col("dir").str.to_lowercase().str.contains("out", literal=True))
              .filter(pl.col("txid").is_in(spend_txids) & ~_is_bad_addr_expr("dst_addr"))
              .select(["txid","dst_addr","d_amt"])
              .group_by(["txid","dst_addr"])
              .agg(pl.col("d_amt").sum().alias("d_amt"))
              .rename({"txid":"spend_txid"})
              .collect()
        )
        if not outs.height:
            print(f"[chunk {batch_i}] outs=0 for spenders → skip")
            del ins, ins_src, out_idx_df; gc.collect(); continue

        # --- per-tx edge accumulation (matrix math per tx, one shard per batch)
        S = ins_src.to_pandas()
        D = outs.to_pandas()
        gS, gD = S.groupby("spend_txid"), D.groupby("spend_txid")
        common_tx = sorted(set(gS.groups.keys()) & set(gD.groups.keys()))
        if not common_tx:
            print(f"[chunk {batch_i}] common_tx=0 (ins_src_tx={S['spend_txid'].nunique()}, outs_tx={D['spend_txid'].nunique()})")
            del ins, ins_src, outs, S, D, out_idx_df; gc.collect(); continue

        col_src, col_dst, col_btc = [], [], []
        pairs_this_chunk = 0

        for tx in common_tx:
            s = gS.get_group(tx)                                   # src_addr, s_amt
            d = (gD.get_group(tx)                                  # dst_addr, d_amt
                   .sort_values("d_amt", ascending=False)
                   .head(MAX_DST_PER_TX))

            if len(s) * len(d) > MAX_TX_OUTPAIRS:
                continue

            F = _split_tx_value(s["s_amt"].values, d["d_amt"].values)   # (S, D)
            if F is None:
                continue

            if TX_MIN_EDGE_BTC > 0.0:
                mask = F >= TX_MIN_EDGE_BTC
            else:
                mask = np.ones_like(F, dtype=bool)
            if not mask.any():
                continue

            # np.nonzero and boolean indexing both walk the matrix in row-major order,
            # so ii/jj line up with the entries of `weights`.
            ii, jj = np.nonzero(mask)
            weights = F[mask].astype(np.float32)

            srcs = s["src_addr"].values
            dsts = d["dst_addr"].values
            col_src += [srcs[i] for i in ii]
            col_dst += [dsts[j] for j in jj]
            col_btc += weights.tolist()
            pairs_this_chunk += len(weights)

        # one Parquet per batch (fewer shards, faster IO)
        if col_src:
            shard = EDGES_DS / f"edges-b{batch_i:05d}.parquet"
            pl.DataFrame(
                {"src": col_src, "dst": col_dst, "btc": col_btc},
                schema={"src": pl.Utf8, "dst": pl.Utf8, "btc": pl.Float32},
            ).write_parquet(shard, compression="zstd", compression_level=3, statistics=False)

        total_pairs += pairs_this_chunk
        print(f"[chunk {batch_i}] ins={len(ins)} out_idx={len(out_idx_df)} ins_src={len(S)} "
              f"outs={len(D)} common_tx={len(common_tx)} pairs_added={pairs_this_chunk}")

        del ins, ins_src, outs, S, D, out_idx_df, col_src, col_dst, col_btc
        gc.collect()

    shard_paths = glob.glob(str(EDGES_DS / "*.parquet"))
    if not shard_paths:
        print("[flows] no edge fragments were written (try: TX_MIN_EDGE_BTC↓, TOPK_SRC↑, MAX_DST_PER_TX↑, PFX_LEN↑).")
        return pd.DataFrame(columns=["src","dst","btc"])
    print(f"[flows] reduce {len(shard_paths)} edge shards (total_pairs={total_pairs}) from {EDGES_DS}")

    reduced = (
        pl.scan_parquet(shard_paths)
          .group_by(["src","dst"])
          .agg(pl.col("btc").sum().alias("btc"))
          .sort("btc", descending=True)
          .collect()
    )

    # Top-M destinations by inbound BTC
    top_dst = (
        reduced.group_by("dst")
               .agg(pl.col("btc").sum().alias("in_btc"))
               .sort("in_btc", descending=True)
               .head(TOPM_DST)
               .get_column("dst")
               .to_list()
    )
    final_edges = (
        reduced.filter(pl.col("dst").is_in(top_dst) & (pl.col("btc") >= FINAL_MIN_EDGE_BTC))
               .sort("btc", descending=True)
               .to_pandas()
    )
    return final_edges


In [None]:
# ============================
# 3) Plot — cleaner, less cluttered
# ============================
def plot_flow_graph(
    edges: pd.DataFrame,
    sources: set[str] | None = None,
    max_edges: int = SHOW_EDGES,
    min_edge_btc: float = FINAL_MIN_EDGE_BTC,
    layout: str = "bipartite",         # "bipartite" | "spring" | "kk" | "circular"
    label_top: int = 60,               # only label top-N nodes by strength
    min_node_strength: float = 0.0,    # hide nodes with very tiny total flow
    figsize=(22, 13),
    dpi: int = 160,
    save_path: str | None = None
):
    """
    Draw a pruned address-flow graph (networkx + matplotlib).

      - keeps edges with ``btc >= min_edge_btc``, then the ``max_edges`` largest;
      - node strength = summed btc of kept edges touching the node (a self-loop counts
        twice); with ``min_node_strength > 0`` weaker nodes and their edges are dropped;
      - self-loops (src == dst) are not drawn;
      - "bipartite" puts ``sources`` on the left (x=0) and all other nodes on the right
        (x=1); without ``sources`` it falls back to a spring layout. "kk" and "circular"
        use the matching networkx layouts; anything else uses spring;
      - edge width/alpha scale linearly with btc; only the ``label_top`` strongest nodes
        get (shortened) labels;
      - ``save_path`` additionally writes the figure to disk.

    Prints a message and returns early when there is nothing to draw.
    """
    if edges.empty:
        print("No edges to draw."); return

    # ---- prune edges for visibility
    e = edges[edges["btc"] >= min_edge_btc].copy()
    e = e.sort_values("btc", ascending=False).head(max_edges)
    if e.empty:
        print("No edges reach the visibility threshold."); return

    # ---- node strengths (weighted degree)
    node_strength: dict = {}
    for src, dst, btc in zip(e["src"], e["dst"], e["btc"]):
        w = float(btc)
        node_strength[src] = node_strength.get(src, 0.0) + w
        node_strength[dst] = node_strength.get(dst, 0.0) + w

    # optional node pruning
    if min_node_strength > 0:
        keep = {n for n, w in node_strength.items() if w >= min_node_strength}
        e = e[e["src"].isin(keep) & e["dst"].isin(keep)]
        node_strength = {n: w for n, w in node_strength.items() if n in keep}

    # ---- build graph
    G = nx.DiGraph()
    for n, w in node_strength.items():
        G.add_node(n, weight=w)
    for src, dst, btc in zip(e["src"], e["dst"], e["btc"]):
        if src != dst:
            G.add_edge(src, dst, weight=float(btc))

    if G.number_of_nodes() == 0:
        print("Nothing to draw after pruning."); return

    # ---- choose layout
    def _bipartite_positions(G, left: set[str]):
        """Left column (x=0) for nodes in ``left``, right column (x=1) for the rest; heaviest first."""
        left_nodes = [n for n in G.nodes() if n in left]
        right_nodes = [n for n in G.nodes() if n not in left]
        pos = {}
        for x, column in ((0.0, left_nodes), (1.0, right_nodes)):
            if not column:
                continue
            ys = np.linspace(0, 1, len(column), endpoint=True)
            ordered = sorted(column, key=lambda n: -G.nodes[n]["weight"])
            for y, n in zip(ys, ordered):
                pos[n] = np.array([x, y])
        return pos

    use_bipartite = layout == "bipartite" and bool(sources)
    if use_bipartite:
        pos = _bipartite_positions(G, set(sources))
    elif layout == "kk":
        pos = nx.kamada_kawai_layout(G, weight="weight")
    elif layout == "circular":
        pos = nx.circular_layout(G)
    else:
        k = 2.0 / max(1.0, math.sqrt(G.number_of_nodes()))
        pos = nx.spring_layout(G, k=k, iterations=300, seed=42, weight="weight")

    # ---- node sizes & colors
    node_sizes = [max(120, 2000 * math.log10(G.nodes[n]["weight"] + 10.0)) for n in G.nodes()]
    if sources:
        node_colors = ["#2b8cbe" if n in sources else "#a1d99b" for n in G.nodes()]
    else:
        node_colors = "#4c72b0"

    # Optional community tint (best-effort; only used when no sources are highlighted)
    try:
        from networkx.algorithms.community import greedy_modularity_communities
        comms = list(greedy_modularity_communities(G.to_undirected()))
        comm_map = {}
        for i, c in enumerate(comms):
            for n in c:
                comm_map[n] = i
        if not sources:
            palette = ["#4c72b0","#55a868","#c44e52","#8172b3","#ccb974","#64b5cd"]
            node_colors = [palette[comm_map.get(n,0) % len(palette)] for n in G.nodes()]
    except Exception:
        pass  # community step is optional

    # ---- edge widths & alpha by weight (graph may have no edges if all were self-loops)
    edge_list = list(G.edges(data=True))
    edge_pairs = [(u, v) for u, v, _ in edge_list]
    if edge_list:
        es = np.array([d["weight"] for _, _, d in edge_list], dtype=float)
        wmin, wmax = float(es.min()), float(es.max())

        def scale(x, a, b):
            # All weights equal (e.g. a single edge): use the midpoint. Returning 0 here
            # would draw every edge with zero width and zero alpha, i.e. invisibly.
            if wmax == wmin:
                return (a + b) / 2.0
            return a + (x - wmin) * (b - a) / (wmax - wmin)

        edge_widths = [scale(w, 0.5, 6.0) for w in es]
        edge_alphas = [scale(w, 0.15, 0.6) for w in es]
    else:
        edge_widths, edge_alphas = [], []

    # ---- draw
    plt.figure(figsize=figsize, dpi=dpi)
    nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color=node_colors, alpha=0.9)
    # one draw call per (rounded) alpha level, since alpha is a scalar per call
    for alpha in sorted({round(a, 2) for a in edge_alphas}):
        idx = [i for i, a in enumerate(edge_alphas) if round(a, 2) == alpha]
        nx.draw_networkx_edges(
            G, pos,
            edgelist=[edge_pairs[i] for i in idx],
            width=[edge_widths[i] for i in idx],
            alpha=alpha,
            arrows=True, arrowsize=8, arrowstyle="-|>",
            connectionstyle="arc3,rad=0.05"
        )

    # ---- label only top nodes by strength (shortened)
    def short(addr: str, k: int = 6) -> str:
        return addr if len(addr) <= 2*k+1 else f"{addr[:k]}…{addr[-k:]}"
    top_nodes = sorted(G.nodes(), key=lambda n: G.nodes[n]["weight"], reverse=True)[:label_top]
    lbl_pos = {n: pos[n] for n in top_nodes}
    lbls = {n: short(n) for n in top_nodes}
    nx.draw_networkx_labels(G, lbl_pos, labels=lbls, font_size=9, font_weight="bold")

    # Title reflects the layout actually used and the edges actually drawn.
    title_side = " (bipartite)" if use_bipartite else ""
    plt.title(f"Flows from Top Sources → Top Destinations{title_side}\n"
              f"drawn edges: {G.number_of_edges():,}  nodes: {G.number_of_nodes():,}")
    plt.axis("off")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=dpi)
        print(f"[plot] saved to {save_path}")
    plt.show()

In [None]:
# ============================
# Run
# ============================

print(f"[config] START_DAY={SAFER_START}  VALUE_UNITS={VALUE_UNITS}  CHUNK_SIZE={CHUNK_SIZE}")
dataset_smoke_check(max_files=8)  # quick visibility (also prints per-dataset file counts)

TOP_SOURCES = set(topk_sources_chunked(TOPK_SRC))
if not TOP_SOURCES:
    raise RuntimeError("Could not determine top sources (verify prevout linkage, columns, and START_DAY).")

# sorted() so the source list handed to the index builder is deterministic across runs
build_top_outpoints_dataset(sorted(TOP_SOURCES))
# Quick visibility into available index prefixes
have_pfx = list_index_prefixes()
print(f"[post-index] have {len(have_pfx)} pfx dirs (sample: {sorted(list(have_pfx))[:16]})")

edges_df = build_flows_top_sources_streaming()
try:
    display(edges_df.head(20))  # Jupyter rich display
except NameError:  # plain Python: no `display` builtin
    print(edges_df.head(20))

# Cleaner graph defaults (tweak as you like)
plot_flow_graph(
    edges_df,
    sources=TOP_SOURCES,
    layout="bipartite",      # or "kk" / "spring" / "circular"
    max_edges=250,           # draw only top-N edges
    min_edge_btc=max(1.0, FINAL_MIN_EDGE_BTC),
    label_top=60,            # label only strongest nodes
    min_node_strength=0.0,   # hide very-weak nodes if you set >0
    figsize=(24, 14),
    dpi=180,
    save_path=None           # or a path like "flows_bipartite.png"
)


In [None]:
# ============================
# VIZ UTILS — pruning + alternates (Sankey / Heatmap)
# ============================

def _short(addr: str, k: int = 6) -> str:
    return addr if len(addr) <= 2*k+1 else f"{addr[:k]}…{addr[-k:]}"

def condense_edges_for_viz(
    edges: pd.DataFrame,
    keep_per_source: int = 6,
    share_per_source: float = 0.9,   # keep edges until this cumulative share per src is reached
    min_edge_btc: float = 0.0,
    add_other: bool = True,
) -> pd.DataFrame:
    """
    Reduce clutter by limiting the number of outgoing edges per source.

    Edges below ``min_edge_btc`` are dropped first. Then, per ``src``, edges are ranked
    by ``btc`` and an edge is kept when it is among the top ``keep_per_source`` OR the
    share of that source's total already covered by higher-ranked edges is still below
    ``share_per_source``. The edge that crosses the threshold is therefore kept, so the
    kept edges actually reach the requested share. Sources whose total is <= 0 are
    dropped. With ``add_other`` the remainder of each source is rolled into a single
    ``Other(<short src>)`` sink.

    Returns a DataFrame with columns src, dst, btc, aggregated per (src, dst) and
    sorted by btc descending (an empty frame when nothing survives).
    """
    if edges.empty:
        return edges.copy()

    e = edges[edges["btc"] >= min_edge_btc]
    out_rows = []
    for src, g in e.groupby("src", sort=False):
        g = g.sort_values("btc", ascending=False).reset_index(drop=True)
        total = g["btc"].sum()
        if total <= 0:
            continue
        # share covered by strictly higher-ranked edges (0 for the top edge)
        covered_before = g["btc"].cumsum().shift(1, fill_value=0.0) / total
        keep_mask = (g.index < keep_per_source) | (covered_before < share_per_source)
        out_rows.append(g.loc[keep_mask, ["src","dst","btc"]])

        if add_other:
            other_sum = g.loc[~keep_mask, "btc"].sum()
            if other_sum > 0:
                out_rows.append(pd.DataFrame(
                    {"src":[src], "dst":[f"Other({_short(src)})"], "btc":[other_sum]}
                ))

    if not out_rows:
        return pd.DataFrame(columns=["src","dst","btc"])
    out = pd.concat(out_rows, ignore_index=True)
    # Re-aggregate in case the input holds duplicate (src,dst) rows
    out = (out.groupby(["src","dst"], as_index=False)["btc"].sum()
              .sort_values("btc", ascending=False).reset_index(drop=True))
    return out


def plot_sankey_flows(
    edges: pd.DataFrame,
    sources: set[str] | None = None,
    max_nodes_right: int = 80,         # cap # of right-side nodes for readability
    keep_per_source: int = 6,          # also applied before building the Sankey
    share_per_source: float = 0.9,
    min_edge_btc: float = 0.0,
    file_html: str | None = "flows_sankey.html",
):
    """
    Plot an alluvial/Sankey diagram (great for dense bipartite flows).

    Edges are first pruned per source with ``condense_edges_for_viz``.
    Destinations outside the ``max_nodes_right`` largest (by total btc) are
    then merged into a single 'Other (dest)' node. Sources form the left
    column and destinations the right column. The two columns are indexed
    independently, so an address that is both a source and a destination
    appears once on each side and every link goes left → right.

    Parameters
    ----------
    edges : DataFrame with columns ``src``, ``dst``, ``btc``.
    sources : addresses to list first in the left column (optional).
    max_nodes_right, keep_per_source, share_per_source, min_edge_btc :
        pruning knobs (see ``condense_edges_for_viz``).
    file_html : if set, also write an interactive HTML file to this path.
    """
    try:
        import plotly.graph_objects as go
    except Exception as e:
        print("[sankey] plotly not installed; `pip install plotly` to enable. Error:", e)
        return

    if edges.empty:
        print("[sankey] no edges"); return

    # 1) prune per source
    e = condense_edges_for_viz(edges, keep_per_source, share_per_source, min_edge_btc, add_other=True)
    if e.empty:
        print("[sankey] no edges after pruning"); return

    # 2) cap # of right nodes globally (merge tail into 'Other (dest)')
    right_totals = e.groupby("dst")["btc"].sum().sort_values(ascending=False)
    keep_right = set(right_totals.head(max_nodes_right).index)
    e2 = e.assign(dst=e["dst"].where(e["dst"].isin(keep_right), "Other (dest)"))
    # Merging the tail can create several (src, 'Other (dest)') rows; sum them.
    e2 = e2.groupby(["src", "dst"], as_index=False)["btc"].sum()

    # 3) node lists (sources left, dests right), each with its own index map
    src_set = sources or set()
    left_nodes = sorted(set(e2["src"]), key=lambda s: (s not in src_set, s))
    right_nodes = sorted(set(e2["dst"]))
    left_idx = {n: i for i, n in enumerate(left_nodes)}
    offset = len(left_nodes)
    right_idx = {n: offset + i for i, n in enumerate(right_nodes)}

    # 4) Sankey links (indices already point into left_nodes + right_nodes)
    src_idx = e2["src"].map(left_idx).tolist()
    dst_idx = e2["dst"].map(right_idx).tolist()
    vals    = e2["btc"].astype(float).tolist()

    # 5) labels: shorten addresses; synthetic 'Other…' nodes are kept verbatim
    #    so they still match the shortened source label they refer to.
    labels = [n if n.startswith("Other") else _short(n) for n in left_nodes + right_nodes]
    colors = (["#2b8cbe"]*len(left_nodes)) + (["#a1d99b"]*len(right_nodes))

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(label=labels, color=colors, pad=12, thickness=18),
        link=dict(source=src_idx, target=dst_idx,
                  value=vals, color="rgba(0,0,0,0.25)")
    )])

    fig.update_layout(
        # plotly titles are HTML: '<br>' (not '\n') starts a new line
        title=f"Flows from Top Sources → Top Destinations (Sankey)<br>"
              f"links: {len(vals):,}  left nodes: {len(left_nodes)}  right nodes: {len(right_nodes)}",
        font=dict(size=12),
        margin=dict(l=20, r=20, t=70, b=20),
        height=750
    )

    if file_html:
        fig.write_html(file_html, include_plotlyjs="cdn")
        print(f"[sankey] wrote {file_html}")
    fig.show()


def plot_flow_heatmap(
    edges: pd.DataFrame,
    sources: set[str] | None = None,
    top_left: int = 25,
    top_right: int = 60,
    min_edge_btc: float = 0.0,
    figsize=(18, 10),
    dpi=140
):
    """
    Adjacency-matrix heatmap of btc from src (rows) to dst (columns).

    Rows are the ``top_left`` sources and columns the ``top_right``
    destinations, ranked by total btc after dropping rows below
    ``min_edge_btc``. Pairs with no flow are drawn as 0 rather than blank.

    ``sources`` is accepted for call-signature symmetry with
    ``plot_sankey_flows`` but is not used here.
    """
    if edges.empty:
        print("[heatmap] no edges"); return

    e = edges[edges["btc"] >= min_edge_btc]
    # pick top src/dst by total volume
    top_src = (e.groupby("src", as_index=False)["btc"].sum()
                 .sort_values("btc", ascending=False).head(top_left)["src"].tolist())
    top_dst = (e.groupby("dst", as_index=False)["btc"].sum()
                 .sort_values("btc", ascending=False).head(top_right)["dst"].tolist())

    sub = e[e["src"].isin(top_src) & e["dst"].isin(top_dst)]
    if sub.empty:
        print("[heatmap] nothing after filtering"); return

    # reindex fixes the row/column order to the volume ranking; fill_value=0
    # so a top src/dst with no flow inside the top block shows as zeros (not NaN).
    M = (sub.pivot_table(index="src", columns="dst", values="btc", aggfunc="sum", fill_value=0.0)
            .reindex(index=top_src, columns=top_dst, fill_value=0.0))

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    im = ax.imshow(M.to_numpy(dtype=float), aspect="auto", interpolation="nearest")
    fig.colorbar(im, ax=ax, label="BTC")
    ax.set_yticks(range(len(M.index)))
    ax.set_yticklabels([_short(s) for s in M.index])
    ax.set_xticks(range(len(M.columns)))
    ax.set_xticklabels([_short(d) for d in M.columns], rotation=90)
    ax.set_xlabel("destination")
    ax.set_ylabel("source")
    ax.set_title("Flows heatmap (src × dst)")
    fig.tight_layout()
    plt.show()


In [None]:
# Shared pruning knobs so all three views below show the same edge population.
VIZ_KEEP_PER_SOURCE  = 6      # try 4–8
VIZ_SHARE_PER_SOURCE = 0.9    # keep until 90% of each source's mass
VIZ_MIN_EDGE_BTC     = 1.0    # drop edges smaller than this (BTC)

# 1) Much cleaner graph: first condense, then draw bipartite
edges_viz = condense_edges_for_viz(
    edges_df,
    keep_per_source=VIZ_KEEP_PER_SOURCE,
    share_per_source=VIZ_SHARE_PER_SOURCE,
    min_edge_btc=VIZ_MIN_EDGE_BTC,
    add_other=True
)
plot_flow_graph(
    edges_viz,
    sources=TOP_SOURCES,
    layout="bipartite",
    max_edges=len(edges_viz),    # we've already pruned
    min_edge_btc=0.0,            # threshold already applied by condense_edges_for_viz
    label_top=50,
    figsize=(24, 12),
    dpi=160
)

# 2) Interactive alluvial view (best when many edges):
plot_sankey_flows(
    edges_df,
    sources=TOP_SOURCES,
    max_nodes_right=80,      # tune down if still busy
    keep_per_source=VIZ_KEEP_PER_SOURCE,
    share_per_source=VIZ_SHARE_PER_SOURCE,
    min_edge_btc=VIZ_MIN_EDGE_BTC,
    file_html="flows_sankey.html"
)

# 3) Quick pattern scan:
plot_flow_heatmap(
    edges_df,
    sources=TOP_SOURCES,
    top_left=25,
    top_right=60,
    min_edge_btc=VIZ_MIN_EDGE_BTC
)
