In [None]:
# from telegram_toolchain.data.database import get_conn
from telegram_data_models import Message, Chat, MessageTextContent, Queue
from dotenv import load_dotenv

load_dotenv()  # loads .env from cwd (or parents)
load_dotenv("../../credentials/credentials.env")
from sqlalchemy import select, func, case, create_engine
from tqdm.auto import tqdm  # works in both notebooks & terminals
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import pandas as pd
import os
import time
from pathlib import Path
import json
import duckdb

con = duckdb.connect()
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator, LogFormatterMathtext, NullFormatter

from cmcrameri import cm

plt.style.use('../../resources/mpl_styles/default.mplstyle')

In [None]:
# -----------------------------
# Utilities
# -----------------------------
def _get_positive_xy(df, x_col, y_col, *, log_base=np.e, mask=None):
    """
    Extract x,y arrays; keep strictly positive; apply optional extra mask;
    return x_pos, y_pos, X=log(x), Y=log(y), and final mask.
    """
    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)

    m = np.isfinite(x) & np.isfinite(y) & (x > 0) & (y > 0)
    if mask is not None:
        m = m & mask

    x_pos = x[m]
    y_pos = y[m]

    if x_pos.size == 0:
        raise ValueError("No strictly positive finite points in both x and y after masking.")

    if log_base == 10:
        X = np.log10(x_pos)
        Y = np.log10(y_pos)
    elif log_base in (np.e, None):
        X = np.log(x_pos)
        Y = np.log(y_pos)
    else:
        # change of base: log_b(x) = ln(x)/ln(b)
        X = np.log(x_pos) / np.log(log_base)
        Y = np.log(y_pos) / np.log(log_base)

    return x_pos, y_pos, X, Y, m


def _line_params_to_powerlaw(a, b, *, log_base=np.e):
    """
    If fitting log y = a + b log x with log base 'log_base':
    returns multiplicative constant c s.t. y = c * x^b.
    """
    if log_base == 10:
        # log10 y = a + b log10 x  => y = 10^a * x^b
        c = 10**a
    elif log_base in (np.e, None):
        # ln y = a + b ln x => y = exp(a) * x^b
        c = np.exp(a)
    else:
        # y = (log_base)^a * x^b
        c = log_base**a
    return c


# -----------------------------
# Correlation measures
# -----------------------------
def correlation_measures(df, x_col, y_col, *, log_base=10, mask=None):
    """
    Correlation statistics between two columns after a log transform.

    The columns are filtered and log-transformed by ``_get_positive_xy``; all
    three statistics are then evaluated on the log values:

      - Pearson r (linear correlation in log space)
      - Spearman rho
      - Kendall tau

    Returns a dict with the number of points used (``n``), the row mask
    (``mask``), and each statistic together with its p-value.
    """
    from scipy.stats import kendalltau, pearsonr, spearmanr

    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    report = {"n": int(log_x.size), "mask": kept}
    # (statistic key, p-value key, scipy function) in output order
    for stat_key, p_key, measure in (
        ("pearson_r_log", "pearson_p", pearsonr),
        ("spearman_rho", "spearman_p", spearmanr),
        ("kendall_tau", "kendall_p", kendalltau),
    ):
        stat, p_value = measure(log_x, log_y)
        report[stat_key] = float(stat)
        report[p_key] = float(p_value)

    return report


# -----------------------------
# Line fits in log-log space
# Each returns a dict:
#  - method, a, b for log y = a + b log x
#  - c for y = c * x^b
#  - plus optional diagnostics
# -----------------------------
def fit_loglog_ols(df, x_col, y_col, *, log_base=10, mask=None):
    """
    Ordinary least-squares regression of log(y) on log(x).

    Returns a dict with the intercept ``a`` and slope ``b`` of
    ``log y = a + b log x``, the power-law prefactor ``c`` (``y = c * x**b``),
    the row mask used, and the ``scipy.stats.linregress`` diagnostics
    (r-value, p-value, standard errors of slope and intercept).
    """
    from scipy.stats import linregress

    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    reg = linregress(log_x, log_y)
    intercept, slope = reg.intercept, reg.slope
    prefactor = _line_params_to_powerlaw(intercept, slope, log_base=log_base)

    return dict(
        method="ols",
        n=int(log_x.size),
        mask=kept,
        a=float(intercept),
        b=float(slope),
        c=float(prefactor),
        rvalue=float(reg.rvalue),
        pvalue=float(reg.pvalue),
        stderr_slope=float(reg.stderr),
        stderr_intercept=float(reg.intercept_stderr),
    )


def fit_loglog_tls(df, x_col, y_col, *, log_base=10, mask=None):
    """
    Total least squares (orthogonal regression) line in log-log space.

    The line direction is the first right-singular vector of the centred
    (log x, log y) point cloud, i.e. the direction minimising the squared
    perpendicular distances. The line passes through the centroid.

    Returns a dict with ``a``, ``b`` (``log y = a + b log x``), the power-law
    prefactor ``c``, the row mask, and ``ortho_rms`` — the RMS perpendicular
    distance of the points from the line, in log units.

    Raises
    ------
    ValueError
        If the principal direction is (numerically) vertical, so no finite
        slope exists.
    """
    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    mean_x = log_x.mean()
    mean_y = log_y.mean()
    centred = np.vstack([log_x - mean_x, log_y - mean_y]).T

    # Row 0 of Vt is the first principal direction of the centred cloud.
    principal = np.linalg.svd(centred, full_matrices=False)[2][0]
    dir_x, dir_y = principal[0], principal[1]
    if np.isclose(dir_x, 0):
        raise ValueError("TLS failed: vx ~ 0 (vertical line in log space).")

    b = dir_y / dir_x
    a = log_y.mean() - b * log_x.mean()
    c = _line_params_to_powerlaw(a, b, log_base=log_base)

    # Point-to-line distance: |b*X - Y + a| / sqrt(b^2 + 1)
    perp = np.abs(b * log_x - log_y + a) / np.sqrt(b * b + 1)
    rms = np.sqrt(np.mean(perp**2))

    return {
        "method": "tls",
        "n": int(log_x.size),
        "mask": kept,
        "a": float(a),
        "b": float(b),
        "c": float(c),
        "ortho_rms": float(rms),
    }


def fit_loglog_theilsen(df, x_col, y_col, *, log_base=10, mask=None, random_state=0):
    """
    Theil–Sen robust line in log-log space (scikit-learn ``TheilSenRegressor``).

    ``random_state`` is forwarded to the regressor. Returns a dict with ``a``,
    ``b`` (``log y = a + b log x``), the power-law prefactor ``c``, the number
    of points and the row mask.
    """
    from sklearn.linear_model import TheilSenRegressor

    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    # sklearn expects a 2-D design matrix: one column holding log(x).
    design = log_x.reshape(-1, 1)
    regressor = TheilSenRegressor(random_state=random_state).fit(design, log_y)

    slope = float(regressor.coef_[0])
    intercept = float(regressor.intercept_)
    prefactor = _line_params_to_powerlaw(intercept, slope, log_base=log_base)

    return dict(
        method="theil_sen",
        n=int(log_x.size),
        mask=kept,
        a=intercept,
        b=slope,
        c=float(prefactor),
    )


def fit_loglog_huber(df, x_col, y_col, *, log_base=10, mask=None, epsilon=1.35):
    """
    Huber robust regression in log-log space (scikit-learn ``HuberRegressor``).

    ``epsilon`` is forwarded to the regressor and echoed back in the result.
    Returns a dict with ``a``, ``b`` (``log y = a + b log x``), the power-law
    prefactor ``c``, the number of points and the row mask.
    """
    from sklearn.linear_model import HuberRegressor

    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    # sklearn expects a 2-D design matrix: one column holding log(x).
    design = log_x.reshape(-1, 1)
    regressor = HuberRegressor(epsilon=epsilon).fit(design, log_y)

    slope = float(regressor.coef_[0])
    intercept = float(regressor.intercept_)
    prefactor = _line_params_to_powerlaw(intercept, slope, log_base=log_base)

    return dict(
        method="huber",
        n=int(log_x.size),
        mask=kept,
        a=intercept,
        b=slope,
        c=float(prefactor),
        epsilon=float(epsilon),
    )


def fit_loglog_ransac(
    df,
    x_col,
    y_col,
    *,
    log_base=10,
    mask=None,
    min_samples=0.5,
    residual_threshold=None,
    random_state=0,
):
    """
    RANSAC line in log-log space (scikit-learn ``RANSACRegressor`` wrapping a
    plain ``LinearRegression``).

    ``min_samples``, ``residual_threshold`` and ``random_state`` are forwarded
    to the regressor. ``residual_threshold`` is expressed in log units; when it
    is ``None`` scikit-learn picks its own default.

    Returns a dict with ``a``, ``b`` (``log y = a + b log x``) taken from the
    final inlier fit, the power-law prefactor ``c``, the fraction of points
    flagged as inliers, and the threshold that was passed in.
    """
    from sklearn.linear_model import LinearRegression, RANSACRegressor

    _, _, log_x, log_y, kept = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    ransac = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=min_samples,
        residual_threshold=residual_threshold,
        random_state=random_state,
    )
    ransac.fit(log_x.reshape(-1, 1), log_y)

    # estimator_ is the inner linear model refitted on the consensus set.
    slope = float(ransac.estimator_.coef_[0])
    intercept = float(ransac.estimator_.intercept_)
    prefactor = _line_params_to_powerlaw(intercept, slope, log_base=log_base)

    return dict(
        method="ransac",
        n=int(log_x.size),
        mask=kept,
        a=intercept,
        b=slope,
        c=float(prefactor),
        inlier_frac=float(np.mean(ransac.inlier_mask_)),
        residual_threshold=residual_threshold,
    )


def fit_loglog_binned_median(df, x_col, y_col, *, log_base=10, mask=None, bins=30, min_per_bin=20):
    """
    Fit a line to per-bin medians in log-log space.

    The points are binned into ``bins`` equal-width bins of log(x). For every
    bin holding at least ``min_per_bin`` points the median log(x) and median
    log(y) are taken, and an ordinary least-squares line is fitted through
    those bin summaries.

    Parameters
    ----------
    df, x_col, y_col, log_base, mask
        Passed to ``_get_positive_xy`` to select and log-transform the data.
    bins : int
        Number of equal-width bins spanning [min log x, max log x].
    min_per_bin : int
        Bins with fewer points than this are ignored.

    Returns
    -------
    dict
        ``a``, ``b`` for ``log y = a + b log x``; ``c`` for ``y = c * x**b``;
        ``n`` (points considered), ``n_bins_used``, ``mask`` and the
        ``rvalue`` of the regression over bin medians.

    Raises
    ------
    ValueError
        If fewer than two bins are sufficiently populated.
    """
    from scipy import stats

    _, _, X, Y, m = _get_positive_xy(df, x_col, y_col, log_base=log_base, mask=mask)

    # Equal-width bin edges in log(x); linspace hits both endpoints exactly.
    edges = np.linspace(X.min(), X.max(), bins + 1)

    # np.digitize uses half-open bins [e_k, e_{k+1}), so points sitting on the
    # final edge (X == X.max()) would get index `bins` and fall outside every
    # bin. Clip so they are counted in the last bin instead of being dropped.
    idx = np.clip(np.digitize(X, edges) - 1, 0, bins - 1)

    Xb, Yb = [], []
    for k in range(bins):
        sel = idx == k
        if np.sum(sel) >= min_per_bin:
            Xb.append(np.median(X[sel]))
            Yb.append(np.median(Y[sel]))

    Xb = np.asarray(Xb)
    Yb = np.asarray(Yb)
    if Xb.size < 2:
        raise ValueError("Not enough populated bins to fit binned median line.")

    res = stats.linregress(Xb, Yb)
    a, b = res.intercept, res.slope
    c = _line_params_to_powerlaw(a, b, log_base=log_base)

    return {
        "method": "binned_median",
        "n": int(X.size),
        "n_bins_used": int(Xb.size),
        "mask": m,
        "a": float(a),
        "b": float(b),
        "c": float(c),
        "rvalue": float(res.rvalue),
    }


def bootstrap_fit_ci(
    fit_func, df, x_col, y_col, *, log_base=10, mask=None, n_boot=500, random_state=0
):
    """
    Percentile-bootstrap 95% intervals for the log-log intercept and slope.

    The (log x, log y) pairs are resampled with replacement ``n_boot`` times and
    a straight line is fitted to each resample with ``np.polyfit`` (plain OLS).

    NOTE: ``fit_func`` is accepted but is currently NOT used — every resample is
    fitted with OLS regardless of which fit function is passed. Method-specific
    bootstraps would need each resample to be routed through ``fit_func``,
    at a higher compute cost.

    Returns a dict with ``n`` (points), ``n_boot``, and the (2.5%, 97.5%)
    percentile intervals ``a_ci`` (intercept) and ``b_ci`` (slope).
    """
    rng = np.random.default_rng(random_state)
    _, _, log_x, log_y, _ = _get_positive_xy(
        df, x_col, y_col, log_base=log_base, mask=mask
    )

    n_pts = log_x.size
    intercepts = np.empty(n_boot)
    slopes = np.empty(n_boot)

    for k in range(n_boot):
        # Draw n_pts row indices with replacement and refit on that resample.
        draw = rng.integers(0, n_pts, size=n_pts)
        slopes[k], intercepts[k] = np.polyfit(log_x[draw], log_y[draw], 1)

    def _percentile_interval(samples):
        return tuple(float(np.percentile(samples, q)) for q in (2.5, 97.5))

    return {
        "n": int(n_pts),
        "n_boot": int(n_boot),
        "a_ci": _percentile_interval(intercepts),
        "b_ci": _percentile_interval(slopes),
    }

In [None]:
def integer_logspace_edges(vmin: float, vmax: float, n_bins: int) -> np.ndarray:
    """
    Build integer-valued bin edges that approximate logarithmic spacing.

    - The first edge is ``max(1, floor(vmin))`` and the last is ``ceil(vmax)``.
    - Each following edge targets ``vmin * r**i`` with the geometric ratio
      ``r = (vmax / vmin) ** (1 / n_bins)``, rounded to an integer and forced
      to be at least one greater than the previous edge.
    - The result is always strictly increasing and has at least two edges.
      If the integer range collapses (for example ``vmin == vmax`` or
      ``vmax < 1``), a single unit-width bin ``[left, left + 1]`` is returned.

    Parameters
    ----------
    vmin, vmax : float
        Positive, finite data range.
    n_bins : int
        Target number of bins. The actual count can differ because edges are
        integers; ``n_bins < 1`` yields a single bin spanning the range.

    Returns
    -------
    np.ndarray of float
        Strictly increasing integer-valued edges.

    Raises
    ------
    ValueError
        If ``vmin`` or ``vmax`` is non-positive or not finite.
    """
    if vmin <= 0 or vmax <= 0:
        raise ValueError("vmin and vmax must be > 0 for log spacing.")
    if not np.isfinite(vmin) or not np.isfinite(vmax):
        raise ValueError("vmin/vmax must be finite.")

    left = max(1.0, float(np.floor(vmin)))
    right = float(np.ceil(vmax))
    # Degenerate integer range: widen to one unit so the edges stay strictly
    # increasing (zero-width or reversed bins break histogramming/plotting).
    if right <= left:
        right = left + 1.0

    if n_bins < 1:
        return np.array([left, right], dtype=float)

    r = (vmax / vmin) ** (1.0 / n_bins)

    edges = [left]
    i = 1
    # The iteration cap is a safety net against pathological inputs.
    while edges[-1] < right and i <= 10_000:
        target = vmin * (r**i)
        # Round to an integer but always advance by at least 1.
        next_edge = max(edges[-1] + 1.0, np.round(target))
        if next_edge >= right:
            edges.append(right)
            break
        edges.append(next_edge)
        i += 1

    # Only reached without `right` when the iteration cap was hit.
    if edges[-1] < right:
        edges.append(right)

    return np.array(edges, dtype=float)


def compute_loglog_fit(df, x_col, y_col, *, method="huber", log_base=10, mask=None, **kwargs):
    """
    Dispatch to one of the ``fit_loglog_*`` routines and return its fit dict
    (``log y = a + b log x``, plus ``c`` for ``y = c * x**b``).

    Parameters
    ----------
    method : str (case-insensitive)
        One of ``"ols"``, ``"tls"``, ``"theil"``/``"theilsen"``/``"theil_sen"``,
        ``"huber"``, ``"ransac"``, ``"binned"``/``"binned_median"``.
    log_base, mask
        Forwarded to the selected fit routine.
    **kwargs
        Extra options forwarded to the selected routine. ``"ols"`` and
        ``"tls"`` take no extra options, so ``kwargs`` are ignored for them.

    Raises
    ------
    ValueError
        If ``method`` is not recognised.
    """
    method = method.lower()
    if method == "ols":
        return fit_loglog_ols(df, x_col, y_col, log_base=log_base, mask=mask)
    if method == "tls":
        return fit_loglog_tls(df, x_col, y_col, log_base=log_base, mask=mask)
    if method in ("theil", "theilsen", "theil_sen"):
        return fit_loglog_theilsen(df, x_col, y_col, log_base=log_base, mask=mask, **kwargs)
    if method == "huber":
        return fit_loglog_huber(df, x_col, y_col, log_base=log_base, mask=mask, **kwargs)
    if method == "ransac":
        return fit_loglog_ransac(df, x_col, y_col, log_base=log_base, mask=mask, **kwargs)
    if method in ("binned", "binned_median"):
        # Previously missing from the dispatcher although the fit routine exists.
        return fit_loglog_binned_median(
            df, x_col, y_col, log_base=log_base, mask=mask, **kwargs
        )

    raise ValueError(f"Unknown method: {method}")


def overlay_loglog_fit(
    ax,
    fit,
    *,
    x_min=None,
    x_max=None,
    n=200,
    color="#1a1a1a",
    lw=2.5,
    ls="-",
    alpha=0.9,
    label=None,
):
    """
    Plot the power law ``y = c * x**b`` described by ``fit`` on ``ax``.

    ``fit`` must provide ``"b"`` and ``"c"`` (and ``"method"`` when no explicit
    ``label`` is given). The curve is sampled at ``n`` geometrically spaced
    x values between ``x_min`` and ``x_max``; a bound left as ``None`` is taken
    from the current x-limits of ``ax``.

    Returns the sampled ``(xs, ys)`` arrays.
    """
    exponent = fit["b"]
    prefactor = fit["c"]

    # Only query the axis when at least one bound is missing.
    if x_min is None or x_max is None:
        axis_lo, axis_hi = ax.get_xlim()
        if x_min is None:
            x_min = axis_lo
        if x_max is None:
            x_max = axis_hi

    grid_x = np.geomspace(x_min, x_max, n)
    grid_y = prefactor * (grid_x**exponent)

    legend_text = (
        label
        if label is not None
        else f"{fit['method']}: y = {prefactor:.3g}·x^{exponent:.3g}"
    )

    ax.plot(grid_x, grid_y, color=color, lw=lw, ls=ls, alpha=alpha, label=legend_text)
    return grid_x, grid_y


def two_d_histogram_int_bins(
    df,
    x_col="out_deg_true",
    y_col="priority",
    *,
    gridsize=50,  # target number of approx log bins per axis
    frac_guess=20,
    title="Chat-by-chat: priority vs true out-degree, all languages",
    x_label="True out-degree",
    y_label="Priority",
    figsize=(7, 6),
    mincnt=1,
    ylim=(1, None),
    normalize_by_area=False,  # if True, divide by bin area (only when stat_col is None)
    cscale="log",  # "log" or "linear" color scale
    stat_col=None,  # e.g. "rank"; if None, color by counts/density
    stat_func="mean",  # "mean" or "median" when stat_col is not None
    fit_method=None,  # e.g. "huber", "ols", "tls", "theil_sen", "ransac"
    fit_kwargs=None,  # dict of kwargs for the chosen fit
    fit_color="#d62728",  # overlay color (change if you want)
    fit_label=None,
    include_fit_box=False,  # optional textbox with slope/intercept
    log_base=10,  # log base used for fitting (independent of axes)
    ax=None,
    labelsize=18,
    legendsize=18,
    include_stat_box=True,
    cbar_labelsize=18,
    legend_loc="upper right",
    show_minor_ticks=False,
    reverse_colormap=False,
    facecolor="white",
    include_diagonal=True,
    include_legend=True,
):
    """
    Log-log density/count plot using integer-aligned, approx-log-spaced *rectangular* bins.

    Only rows with ``x > 0`` and ``y > 0`` are used. Bin edges come from
    ``integer_logspace_edges`` (so the smallest edge is at least 1), the bins are
    drawn with ``pcolormesh`` on log-log axes, and both axes share the same decade
    major ticks.

    Parameters
    ----------
    df : DataFrame-like
        Source table; ``x_col``/``y_col`` (and ``stat_col`` if given) are read from it.
    gridsize : int
        Target number of approximately log-spaced bins per axis.
    frac_guess : None, scalar, or iterable of two positive numbers
        - scalar ``f``: draw the reference line ``y = x / f``.
        - two values ``(f1, f2)``: draw both lines and compute/print the fraction of
          points lying between them.
        - ``None``: no reference line.
    mincnt : int
        Bins with fewer raw counts than this are blanked.
    ylim : tuple or None
        Passed to ``ax.set_ylim`` when not None.
    normalize_by_area : bool, default False
        If True and stat_col is None, color shows counts divided by bin area (Δx·Δy).
        If False, or if stat_col is not None, color does not divide by area.
    cscale : {"log", "linear"}, default "log"
        Color scale for the counts/densities/statistics.
        - "log": logarithmic color scale via LogNorm; colorbar ticks at 10^x.
        - "linear": linear color scale.
    stat_col : str or None, default None
        If None, color encodes bin counts (or area-normalized density).
        If a column name, color encodes a statistic (mean/median) of this column
        within each 2D bin.
    stat_func : {"mean", "median"}, default "mean"
        Statistic used on `stat_col` when stat_col is not None.
    fit_method, fit_kwargs, log_base
        When ``fit_method`` is not None, it is forwarded together with ``fit_kwargs``
        and ``log_base`` to ``compute_loglog_fit`` and the resulting power law is
        drawn with ``overlay_loglog_fit`` in ``fit_color`` (label ``fit_label``).
    include_fit_box, include_stat_box, include_diagonal, include_legend : bool
        Toggle the fit textbox, the fraction textbox, the ``y = x`` line and the legend.
    ax : matplotlib Axes or None
        Axis to draw on; a new figure of size ``figsize`` is created when None.
    reverse_colormap : bool
        Use the reversed version of the current default colormap.

    Returns
    -------
    fig, ax, frac_above, mask_pos
        ``frac_above`` is the fraction of used points with ``y > x``;
        ``mask_pos`` is the boolean row mask ``(x > 0) & (y > 0)``.

    Raises
    ------
    ValueError
        If no strictly positive points exist, ``stat_func``/``cscale``/``frac_guess``
        are invalid, or a log color scale has nothing positive to show.
    """
    # 1) Extract data
    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)

    # 2) Keep only strictly positive values (NaN compares False, so it is dropped too)
    mask_pos = (x > 0) & (y > 0)
    x_pos, y_pos = x[mask_pos], y[mask_pos]

    if x_pos.size == 0:
        raise ValueError("No strictly positive points in both x and y.")

    # Fraction above y = x
    frac_above = np.mean(y_pos > x_pos)
    print(f"Using {x_pos.size} points with x>0 and y>0")
    print(f"Above the line y=x: {frac_above:.2%}")

    # If using a statistic, pull that column now (only for the positive-mask subset)
    if stat_col is not None:
        stat_vals = df.loc[mask_pos, stat_col].to_numpy(dtype=float)

    # 3) Axes
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # 4) Integer-ish log edges
    xedges = integer_logspace_edges(x_pos.min(), x_pos.max(), gridsize)
    yedges = integer_logspace_edges(y_pos.min(), y_pos.max(), gridsize)

    # 5) 2D histogram (raw counts)
    H, xedges, yedges = np.histogram2d(x_pos, y_pos, bins=[xedges, yedges])

    # 6) Build Z depending on whether we're using counts/density or a statistic
    if stat_col is None:
        # ---- counts or density ----
        if normalize_by_area:
            dx = np.diff(xedges)  # shape (Nx,)
            dy = np.diff(yedges)  # shape (Ny,)
            area = np.outer(dx, dy)  # shape (Nx, Ny), same as H

            # avoid division by zero
            area = np.where(area > 0, area, np.nan)

            Z = H / area  # counts per unit x per unit y
            # apply mincnt based on raw counts, like hexbin
            Z[H < mincnt] = 0
            base_label = "density: count / (Δx·Δy)"
        else:
            # keep raw counts
            Z = H.copy()
            Z[Z < mincnt] = 0
            base_label = "count per bin"

    else:
        # ---- statistic of another column per bin ----
        if stat_func not in {"mean", "median"}:
            raise ValueError("stat_func must be 'mean' or 'median'")

        if normalize_by_area:
            print("Warning: normalize_by_area is ignored when stat_col is not None.")

        Nx = len(xedges) - 1
        Ny = len(yedges) - 1

        # Compute bin indices for each point
        ix = np.digitize(x_pos, xedges) - 1  # 0..Nx-1, may be -1 or Nx
        iy = np.digitize(y_pos, yedges) - 1  # 0..Ny-1, may be -1 or Ny

        # np.histogram2d treats the last bin as closed on its right edge, whereas
        # np.digitize puts such points one bin too far. Fold them back so the
        # statistic covers exactly the points counted in H.
        ix[x_pos == xedges[-1]] = Nx - 1
        iy[y_pos == yedges[-1]] = Ny - 1

        # Mask out-of-range indices (points below the first integer edge)
        valid = (ix >= 0) & (ix < Nx) & (iy >= 0) & (iy < Ny)
        ix = ix[valid]
        iy = iy[valid]
        stat_vals_valid = stat_vals[valid]

        # Collect the values falling in each (ix, iy) bin, then aggregate.
        from collections import defaultdict

        bin_values = defaultdict(list)
        for bx, by, v in zip(ix, iy, stat_vals_valid):
            bin_values[(bx, by)].append(v)

        # Prepare Z as statistic per bin (NaN where a bin is empty)
        Z = np.full((Nx, Ny), np.nan)
        aggregate = np.mean if stat_func == "mean" else np.median
        for (bx, by), vals in bin_values.items():
            Z[bx, by] = aggregate(np.asarray(vals, dtype=float))

        # Apply mincnt threshold based on counts H
        Z[H < mincnt] = np.nan

        base_label = f"{stat_func}({stat_col}) per bin"

    # Mask invalid / zero-or-less bins for plotting
    if stat_col is None:
        Z_plot = np.ma.masked_where(Z <= 0, Z)
    else:
        # For statistics, mask NaNs; if using log scale also mask <= 0
        if cscale == "log":
            Z_plot = np.ma.masked_where(~np.isfinite(Z) | (Z <= 0), Z)
        else:
            Z_plot = np.ma.masked_invalid(Z)
    cmap = plt.get_cmap().reversed() if reverse_colormap else plt.get_cmap()

    # 7) Plot with pcolormesh, choose color scale
    if cscale == "log":
        if Z_plot.size == 0 or np.all(Z_plot.mask):
            raise ValueError("No positive values to plot on a log color scale.")
        vmin = Z_plot.min()
        vmax = Z_plot.max()
        mesh = ax.pcolormesh(
            xedges,
            yedges,
            Z_plot.T,
            norm=LogNorm(
                vmin=vmin,
                vmax=vmax,
            ),
            cmap=cmap,
        )
    elif cscale == "linear":
        mesh = ax.pcolormesh(
            xedges,
            yedges,
            Z_plot.T,
            cmap=cmap,
        )

    else:
        raise ValueError("cscale must be 'log' or 'linear'")

    ax.set_xscale("log")
    ax.set_yscale("log")

    # Make sure limits are set before computing ticks
    if ylim is not None:
        ax.set_ylim(ylim)

    # ---- FORCE SAME MAJOR TICKS ON BOTH AXES ----
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()

    lo = min(x_lo, y_lo)
    hi = max(x_hi, y_hi)

    dec_min = int(np.floor(np.log10(lo)))
    dec_max = int(np.ceil(np.log10(hi)))
    decades = np.arange(dec_min, dec_max + 1)

    # Major ticks: 10^dec
    major_ticks = 10.0**decades
    ax.set_xticks(major_ticks)
    ax.set_yticks(major_ticks)

    ax.tick_params(axis="both", which="major", length=7)

    # 8) Optional fit overlay (in log-log space)
    fit = None
    if fit_method is not None:
        if fit_kwargs is None:
            fit_kwargs = {}

        # use the same positive mask you already computed (mask_pos)
        fit = compute_loglog_fit(
            df,
            x_col=x_col,
            y_col=y_col,
            method=fit_method,
            log_base=log_base,
            mask=mask_pos,
            **fit_kwargs,
        )

        # Overlay on current axis range
        overlay_loglog_fit(
            ax,
            fit,
            color=fit_color,
            label=fit_label,
        )

        if include_fit_box:
            b = fit["b"]
            c = fit["c"]
            txt = f"{fit['method']} fit\nb = {b:.3g}\nc = {c:.3g}"
            ax.text(
                0.05,
                1.02,
                txt,
                transform=ax.transAxes,
                ha="left",
                va="bottom",
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.7),
                fontsize=legendsize,
            )

    # ---- OPTIONAL MINOR TICKS ----
    if show_minor_ticks:
        minor_ticks = []
        for d in decades:
            minor_ticks.extend((10.0**d) * np.arange(2, 10))
        minor_ticks = np.array(minor_ticks)
        minor_ticks = minor_ticks[(minor_ticks >= lo) & (minor_ticks <= hi)]

        ax.set_xticks(minor_ticks, minor=True)
        ax.set_yticks(minor_ticks, minor=True)

        ax.xaxis.set_minor_formatter(NullFormatter())
        ax.yaxis.set_minor_formatter(NullFormatter())

        ax.tick_params(axis="both", which="minor", length=4)
    # ----------------------------------------------------

    ax.set_xlabel(x_label, fontsize=labelsize)
    ax.set_ylabel(y_label, fontsize=labelsize)
    if title is not None:
        ax.set_title(title)

    # 9) 1:1 line
    if include_diagonal:
        lo_line = min(x_pos.min(), y_pos.min())
        hi_line = max(x_pos.max(), y_pos.max())
        ax.plot(
            [lo_line, hi_line],
            [lo_line, hi_line],
            linewidth=2,
            alpha=0.8,
            label="y = x",
            color="#1a1a1a",
        )

    # 10) Reference slope(s)
    frac_between = None

    if frac_guess is not None:
        xgrid = np.array([x_pos.min(), x_pos.max()])

        # Case 1: single scalar → one reference line
        if np.isscalar(frac_guess):
            ax.plot(
                xgrid,
                xgrid / frac_guess,
                label=f"y = x/{frac_guess}",
                linestyle="dashdot",
                linewidth=2,
                color="#1a1a1a",
            )

        # Case 2: two values → band between two lines
        else:
            try:
                f1, f2 = sorted(frac_guess)
            except Exception as exc:
                raise ValueError(
                    "frac_guess must be a scalar or an iterable of two numbers"
                ) from exc

            if f1 <= 0 or f2 <= 0:
                raise ValueError("frac_guess values must be positive")

            # Draw both lines
            ax.plot(
                xgrid,
                xgrid / f1,
                linestyle="dashdot",
                linewidth=2,
                label=f"y = x/{f1}",
                color="#1a1a1a",
            )
            ax.plot(
                xgrid,
                xgrid / f2,
                linestyle="dotted",
                linewidth=2,
                label=f"y = x/{f2}",
                color="#1a1a1a",
            )

            # Compute fraction of points between the two lines
            lower = x_pos / f2
            upper = x_pos / f1
            frac_between = np.mean((y_pos >= lower) & (y_pos <= upper))

    # 11) Colorbar
    cb = fig.colorbar(mesh, ax=ax)

    if cscale == "log":
        cb.locator = LogLocator(base=10)
        cb.formatter = LogFormatterMathtext(base=10)  # shows 10^x
        cb.update_ticks()
    cb.set_label(base_label, fontsize=cbar_labelsize, rotation=270, labelpad=20)

    # 12) Annotate fractions
    annotation_lines = [f"y > x: {frac_above:.1%}"]

    # f1/f2/frac_between only exist in the two-value frac_guess case, so both
    # the annotation and the printout must stay inside this guard.
    if frac_between is not None:
        between_txt = f"x/{f2} ≤ y ≤ x/{f1}: {frac_between:.1%}"
        annotation_lines.append(between_txt)
        print(between_txt)
    if include_stat_box:
        ax.text(
            0.05,
            1.15,
            "\n".join(annotation_lines),
            transform=ax.transAxes,
            ha="left",
            va="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.7),
            fontsize=legendsize,
        )

    # 13) Cosmetics
    if include_legend:
        ax.legend(fontsize=legendsize, loc=legend_loc)
    ax.grid(True, which="major", ls="--", alpha=0.4, color="k")
    ax.set_facecolor(facecolor)
    if show_minor_ticks:
        ax.grid(True, which="minor", ls="--", alpha=0.25)

    fig.tight_layout()

    return fig, ax, frac_above, mask_pos

In [None]:
# Database setup
# Connection settings are read from environment variables (the import cell
# calls load_dotenv before this point). os.environ.get returns None when a
# variable is missing.
# NOTE(review): a missing variable is interpolated into db_url below as the
# text "None" instead of failing here — confirm all five are always set.
db_user = os.environ.get("DB_USER")
db_pass = os.environ.get("DB_PASSWORD")
db_host = os.environ.get("DB_HOST")
db_port = os.environ.get("DB_PORT")
db_name = os.environ.get("DB_NAME")

# NOTE(review): user and password are inserted verbatim; presumably they contain
# no characters that would need URL-escaping (e.g. "@", ":", "/") — confirm.
db_url = f'postgresql+psycopg2://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}'

# Dask can't work with ORM models, so keep the `__table__` attribute of each
# model (presumably the underlying SQLAlchemy Table objects — confirm).
message_table = Message.__table__
chat_table = Chat.__table__
queue_table = Queue.__table__

In [None]:
# Create the SQLAlchemy engine for the PostgreSQL URL assembled above.
engine = create_engine(
    db_url,
    pool_pre_ping=True,  # good for long streaming jobs
    future=True,
)

In [None]:
# Load the per-chat language table.
df_lang = pd.read_parquet("../../data/chat_languages.parquet")
# If chat_id was stored as the index rather than a column, move it back into
# the columns so it can be used like any other field.
if df_lang.index.name == "chat_id" and "chat_id" not in df_lang.columns:
    df_lang = df_lang.reset_index()

In [None]:
time_window = "1 hour"

# Translate the window label into a gap length in seconds. Two consecutive
# edges from the same (src, sender) pair that are closer together than this
# gap are counted once in the query below.
try:
    delta_seconds = {
        "1 hour": 60 * 60,
        "1 day": 60 * 60 * 24,
        "1 week": 60 * 60 * 24 * 7,
        "0": 0,
        "inf": 1e9,
    }[time_window]
except KeyError:
    raise ValueError("Unknown time window") from None

# Recorded out-degree per source chat:
#   base      - edges with a timestamp whose source is a chat (src_is_chat = 1)
#   with_prev - previous timestamp for the same (src, sender) pair
#   outdeg    - count an edge when it is the first of its pair or when more than
#               `delta_seconds` passed since the previous one
# The result is left-joined with the rank/true-degree table on chat_id.
# (An earlier variant of this query was identical except that it did not apply
# the `src_is_chat = 1` filter.)
outdeg_query = f"""
WITH base AS (
  SELECT src, sender, ts
  FROM read_parquet('../../data/edges_sorted.parquet')
  WHERE ts IS NOT NULL
    AND src_is_chat = 1
),
with_prev AS (
  SELECT
    src,
    sender,
    ts,
    lag(ts) OVER (PARTITION BY src, sender ORDER BY ts) AS prev_ts
  FROM base
),
outdeg AS (
  SELECT
    src,
    CAST(SUM(
      CASE
        WHEN prev_ts IS NULL THEN 1
        WHEN ts - prev_ts > INTERVAL '{delta_seconds} seconds' THEN 1
        ELSE 0
      END
    ) AS BIGINT) AS out_degree
  FROM with_prev
  GROUP BY src
)
SELECT
  o.src,
  o.out_degree,
  c.true_out_deg,
  c.rank
FROM outdeg o
LEFT JOIN read_parquet('../../data/rank_degree.parquet') c
  ON o.src = c.chat_id
ORDER BY o.out_degree DESC
"""

df = con.execute(outdeg_query).df()

In [None]:
# Inspect the query result: one row per source chat with its recorded
# out-degree, true out-degree and rank.
df

In [None]:
only_chats = False
# The out-degree query already restricts to src_is_chat = 1 and does not return
# that column, so only filter when the column is actually present; without the
# guard, enabling this toggle raises a KeyError.
if only_chats and "src_is_chat" in df.columns:
    df = df[df["src_is_chat"] == 1].copy()

gray_name = "#333533"

# Recorded vs. true out-degree, colored by the mean rank of the chats in each
# bin (linear color scale), with the reference band x/200 ≤ y ≤ x/5.
fig, ax, frac_above, mask_pos = two_d_histogram_int_bins(
    df,
    x_col="true_out_deg",
    y_col="out_degree",
    gridsize=40,
    title=None,
    x_label="True out-degree",
    y_label="Recorded out-degree",
    cscale="linear",
    frac_guess=[5, 200],
    legendsize=15,
    legend_loc="upper left",
    cbar_labelsize=15,
    reverse_colormap=True,
    include_stat_box=False,
    stat_col="rank",
    include_legend=False,
)
plt.show()

In [None]:
# Display the current default colormap (what plt.get_cmap() returns with no argument).
plt.get_cmap()

# Correlation analysis


In [None]:
# Recorded vs. true out-degree as bin counts on a log color scale, with a
# RANSAC power-law fit overlaid and summarized in a text box.
fig, ax, frac_above, mask_pos = two_d_histogram_int_bins(
    df,
    x_col="true_out_deg",
    y_col="out_degree",
    gridsize=40,
    title=None,
    x_label="True out-degree",
    y_label="Recorded out-degree",
    cscale="log",
    frac_guess=[5, 200],
    legendsize=15,
    legend_loc="upper left",
    cbar_labelsize=15,
    reverse_colormap=True,
    include_stat_box=False,
    # stat_col="rank",  # left unset: color encodes counts, not a rank statistic
    include_legend=False,
    # Fit overlay: residual_threshold is in log10 units because log_base=10.
    fit_method="ransac",
    fit_kwargs={"min_samples": 0.5, "residual_threshold": 0.3},
    fit_color="#e41a1c",
    include_fit_box=True,
    log_base=10,
)

In [None]:
def binned_median_loglog_plot(
    df,
    x_col,
    y_col,
    *,
    bins=40,  # number of log-x bins
    x_edges=None,  # optional explicit bin edges (overrides bins)
    min_per_bin=50,  # drop bins with fewer points
    quantiles=(0.25, 0.75),  # band around median; set None for no band
    log_base=10,  # for binning in log space (axes are still log)
    figsize=(7, 6),
    title=None,
    x_label=None,
    y_label=None,
    ax=None,
    show_points=True,
    connect=True,  # connect medians with a line
    show_band=True,
    band_alpha=0.2,
    # Optional overlay of a straight-line fit in log-log space
    fit_method=None,  # "ols" or "tls"
    fit_to="binned",  # "binned" (default) or "raw"
):
    """
    Plot the median of y within log-spaced bins of x on log-log axes.

    Only rows where both x and y are finite and strictly positive are used.
    Bins are equally spaced in log(x). A point lying exactly on the last bin
    edge (always the case for the sample maximum when the edges are derived
    from the data) is counted in the last bin instead of being dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table; ``df[x_col]`` and ``df[y_col]`` are converted to float.
    x_col, y_col : str
        Column names of the x and y variables.
    bins : int
        Number of log-x bins, used when ``x_edges`` is None.
    x_edges : array-like, optional
        Explicit bin edges on the ORIGINAL x scale (all > 0); overrides ``bins``.
    min_per_bin : int
        Bins holding fewer points than this are skipped.
    quantiles : (float, float) or None
        Lower/upper quantile of y per bin, drawn as a band; None disables it.
    log_base : 10, np.e, None, or another number
        Log base used for binning and fitting (None means natural log).
    figsize : tuple
        Figure size used when a new figure is created.
    title, x_label, y_label : str, optional
        Labels; axis labels default to the column names.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    show_points, connect : bool
        Draw markers at the medians and/or connect them with a line.
    show_band, band_alpha
        Whether to draw the quantile band, and its transparency.
    fit_method : {"ols", "tls"} or None
        Straight-line fit of log y on log x (ordinary or total least squares).
    fit_to : {"binned", "raw"}
        Fit the bin medians or all retained raw points.

    Returns
    -------
    fig, ax, summary
        ``summary`` is a dict with ``x_median``, ``y_median``, ``count``,
        ``quantiles``; plus ``y_qlo``/``y_qhi`` when ``quantiles`` is given and
        ``fit`` (method, a, b, c, fit_to; model ``y = c * x**b``) when fitted.

    Raises
    ------
    ValueError
        No usable points, non-positive ``x_edges``, fewer than two populated
        bins, an unknown ``fit_method``/``fit_to``, or a near-vertical TLS line.
    """

    def _log(values):
        # Logarithm in the configured base (None is treated as natural log).
        if log_base == 10:
            return np.log10(values)
        if log_base in (np.e, None):
            return np.log(values)
        return np.log(values) / np.log(log_base)

    # Reject bad fit options up front so nothing is drawn on a caller-supplied
    # axes before the error surfaces.
    if fit_method is not None:
        fit_method = fit_method.lower()
        if fit_to not in ("binned", "raw"):
            raise ValueError("fit_to must be 'binned' or 'raw'")
        if fit_method not in ("ols", "tls"):
            raise ValueError("fit_method must be 'ols' or 'tls' in this standalone plot.")

    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    m = np.isfinite(x) & np.isfinite(y) & (x > 0) & (y > 0)
    x = x[m]
    y = y[m]
    if x.size == 0:
        raise ValueError("No strictly positive finite points in both x and y.")

    lx = _log(x)

    # Build bin edges in log-x
    if x_edges is None:
        edges_lx = np.linspace(lx.min(), lx.max(), bins + 1)
    else:
        # x_edges are given on the ORIGINAL scale; convert them to log space
        x_edges = np.asarray(x_edges, dtype=float)
        if np.any(x_edges <= 0):
            raise ValueError("All x_edges must be > 0 for log binning.")
        edges_lx = _log(x_edges)

    # Assign points to bins
    idx = np.digitize(lx, edges_lx) - 1
    nb = len(edges_lx) - 1
    if nb >= 1:
        # np.digitize bins are half-open [lo, hi): a value equal to the last
        # edge would land past the final bin. Fold it into the last bin, the
        # same convention np.histogram uses.
        idx[lx == edges_lx[-1]] = nb - 1
    good = (idx >= 0) & (idx < nb)

    # Summarize each sufficiently populated bin
    x_med, y_med, counts = [], [], []
    y_lo, y_hi = [], []

    for k in range(nb):
        sel = good & (idx == k)
        n = int(np.sum(sel))
        if n < min_per_bin:
            continue

        x_bin = x[sel]
        y_bin = y[sel]

        x_med.append(np.median(x_bin))
        y_med.append(np.median(y_bin))
        counts.append(n)

        if quantiles is not None:
            qlo, qhi = quantiles
            y_lo.append(np.quantile(y_bin, qlo))
            y_hi.append(np.quantile(y_bin, qhi))

    x_med = np.asarray(x_med)
    y_med = np.asarray(y_med)
    counts = np.asarray(counts)

    if x_med.size < 2:
        raise ValueError("Not enough populated bins to plot (try fewer bins or lower min_per_bin).")

    if quantiles is not None:
        y_lo = np.asarray(y_lo)
        y_hi = np.asarray(y_hi)

    # Make axes
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Plot band first so the median curve is drawn on top of it
    if show_band and quantiles is not None:
        ax.fill_between(x_med, y_lo, y_hi, alpha=band_alpha)

    # Plot median points (optionally connected)
    if connect:
        ax.plot(x_med, y_med, marker="o" if show_points else None)
    elif show_points:
        ax.scatter(x_med, y_med)

    ax.set_xscale("log")
    ax.set_yscale("log")

    ax.set_xlabel(x_label if x_label is not None else x_col)
    ax.set_ylabel(y_label if y_label is not None else y_col)
    if title is not None:
        ax.set_title(title)

    # Optional fit overlay: log y = a + b * log x
    fit = None
    if fit_method is not None:
        if fit_to == "binned":
            Xfit, Yfit = _log(x_med), _log(y_med)
        else:  # "raw" (validated above)
            Xfit, Yfit = lx, _log(y)

        if fit_method == "ols":
            b, a = np.polyfit(Xfit, Yfit, 1)  # slope, intercept
        else:  # "tls": direction of the first right-singular vector of the centered data
            Xc = Xfit - Xfit.mean()
            Yc = Yfit - Yfit.mean()
            A = np.vstack([Xc, Yc]).T
            _, _, Vt = np.linalg.svd(A, full_matrices=False)
            vx, vy = Vt[0, 0], Vt[0, 1]
            if np.isclose(vx, 0):
                raise ValueError("TLS failed: near-vertical line in log space.")
            b = vy / vx
            a = Yfit.mean() - b * Xfit.mean()

        # Convert the intercept to the multiplicative constant of y = c * x^b
        if log_base == 10:
            c = 10**a
        elif log_base in (np.e, None):
            c = np.exp(a)
        else:
            c = log_base**a

        # Draw the fitted power law over the current x-range
        xmin, xmax = ax.get_xlim()
        x_line = np.geomspace(xmin, xmax, 200)
        y_line = c * (x_line**b)
        ax.plot(
            x_line, y_line, linestyle="--", linewidth=2, label=f"{fit_method}: y={c:.3g}·x^{b:.3g}"
        )
        ax.legend()

        fit = {"method": fit_method, "a": float(a), "b": float(b), "c": float(c), "fit_to": fit_to}

    ax.grid(True, which="major", ls="--", alpha=0.4)

    summary = {
        "x_median": x_med,
        "y_median": y_med,
        "count": counts,
        "quantiles": quantiles,
    }
    if quantiles is not None:
        summary["y_qlo"] = y_lo
        summary["y_qhi"] = y_hi
    if fit is not None:
        summary["fit"] = fit

    fig.tight_layout()
    return fig, ax, summary

In [None]:
# Binned-median view of recorded vs. true out-degree: 40 log-x bins, bins with
# fewer than 200 points skipped, interquartile band, and a total-least-squares
# line fitted to the bin medians.
fig, ax, summary = binned_median_loglog_plot(
    df,
    x_col="true_out_deg",
    y_col="out_degree",
    bins=40,
    min_per_bin=200,
    quantiles=(0.25, 0.75),  # IQR band
    title="Binned-median trend: Recorded vs True out-degree",
    x_label="True out-degree",
    y_label="Recorded out-degree",
    fit_method="tls",  # try "ols" or "tls"
    fit_to="binned",  # fit line to binned medians
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def binned_mean_std_loglog_plot(
    df,
    x_col,
    y_col,
    *,
    bins=40,
    x_edges=None,  # optional explicit ORIGINAL-scale edges; overrides bins
    min_per_bin=50,
    log_base=10,
    band="std",  # "std" or "sem" or None
    figsize=(7, 6),
    title=None,
    x_label=None,
    y_label=None,
    ax=None,
    connect=True,
    show_points=True,
    band_alpha=0.2,
    clip_lower_at=0.0,  # for log y, negative/zero bands are invalid; we'll clip to this
):
    """
    Plot the mean of y within log-spaced bins of x, with a ±std or ±sem band.

    Only rows where both x and y are finite and strictly positive are used.
    Bins are equally spaced in log(x). A point lying exactly on the last bin
    edge (always the case for the sample maximum when the edges are derived
    from the data) is counted in the last bin instead of being dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table; ``df[x_col]`` and ``df[y_col]`` are converted to float.
    x_col, y_col : str
        Column names of the x and y variables.
    bins : int
        Number of log-x bins, used when ``x_edges`` is None.
    x_edges : array-like, optional
        Explicit bin edges on the ORIGINAL x scale (all > 0); overrides ``bins``.
    min_per_bin : int
        Bins holding fewer points than this are skipped.
    log_base : 10, np.e, None, or another number
        Log base used for binning (None means natural log).
    band : {"std", "sem", None}
        Half-width of the band around the mean: the sample standard deviation
        (ddof=1), the standard error (std / sqrt(n)), or no band.
    figsize : tuple
        Figure size used when a new figure is created.
    title, x_label, y_label : str, optional
        Labels; axis labels default to the column names.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    connect, show_points : bool
        Connect the bin means with a line and/or draw markers.
    band_alpha : float
        Transparency of the band.
    clip_lower_at : float
        Lower bound applied to the bottom of the band.

    Returns
    -------
    fig, ax, summary
        ``summary`` is a dict with ``x_center`` (mean x per bin), ``y_mean``,
        ``y_std``, ``count``, ``band``; plus ``y_lo``/``y_hi`` when a band is used.

    Raises
    ------
    ValueError
        Unknown ``band``, no usable points, non-positive ``x_edges``, or fewer
        than two populated bins.
    """

    def _log(values):
        # Logarithm in the configured base (None is treated as natural log).
        if log_base == 10:
            return np.log10(values)
        if log_base in (np.e, None):
            return np.log(values)
        return np.log(values) / np.log(log_base)

    # Validate once, up front, rather than on the first populated bin.
    if band not in ("std", "sem", None):
        raise ValueError("band must be 'std', 'sem', or None")

    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    m = np.isfinite(x) & np.isfinite(y) & (x > 0) & (y > 0)
    x = x[m]
    y = y[m]
    if x.size == 0:
        raise ValueError("No strictly positive finite points in both x and y.")

    # log-x for binning
    lx = _log(x)

    # bin edges in log-x
    if x_edges is None:
        edges_lx = np.linspace(lx.min(), lx.max(), bins + 1)
    else:
        x_edges = np.asarray(x_edges, dtype=float)
        if np.any(x_edges <= 0):
            raise ValueError("All x_edges must be > 0.")
        edges_lx = _log(x_edges)

    idx = np.digitize(lx, edges_lx) - 1
    nb = len(edges_lx) - 1
    if nb >= 1:
        # np.digitize bins are half-open [lo, hi): a value equal to the last
        # edge would land past the final bin. Fold it into the last bin, the
        # same convention np.histogram uses.
        idx[lx == edges_lx[-1]] = nb - 1
    good = (idx >= 0) & (idx < nb)

    x_center, y_mean, y_std, y_lo, y_hi, counts = [], [], [], [], [], []

    for k in range(nb):
        sel = good & (idx == k)
        n = int(np.sum(sel))
        if n < min_per_bin:
            continue

        x_bin = x[sel]
        y_bin = y[sel]

        xm = np.mean(x_bin)  # could also use geometric mean: np.exp(np.mean(np.log(x_bin)))
        mu = np.mean(y_bin)

        # sample std (ddof=1) needs at least two points; use 0 for a single point
        sd = np.std(y_bin, ddof=1) if n > 1 else 0.0

        x_center.append(xm)
        y_mean.append(mu)
        y_std.append(sd)
        counts.append(n)

        if band is not None:
            err = sd if band == "std" else sd / np.sqrt(n)
            # Keep the lower band edge from dropping below clip_lower_at
            y_lo.append(max(mu - err, clip_lower_at))
            y_hi.append(mu + err)

    x_center = np.asarray(x_center)
    y_mean = np.asarray(y_mean)
    y_std = np.asarray(y_std)
    counts = np.asarray(counts)

    if x_center.size < 2:
        raise ValueError("Not enough populated bins to plot (try fewer bins or lower min_per_bin).")

    if band is not None:
        y_lo = np.asarray(y_lo)
        y_hi = np.asarray(y_hi)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # band first, so the mean curve is drawn on top of it
    if band is not None:
        # Non-positive lower edges cannot be shown on a log axis; replace them
        # with NaN so fill_between leaves a gap there.
        lo_plot = y_lo.copy()
        lo_plot[lo_plot <= 0] = np.nan
        ax.fill_between(x_center, lo_plot, y_hi, alpha=band_alpha, label=f"±{band}")

    # mean curve / points
    if connect:
        ax.plot(x_center, y_mean, marker="o" if show_points else None, label="binned mean")
    elif show_points:
        ax.scatter(x_center, y_mean, label="binned mean")

    ax.set_xscale("log")
    ax.set_yscale("log")

    ax.set_xlabel(x_label if x_label is not None else x_col)
    ax.set_ylabel(y_label if y_label is not None else y_col)
    if title is not None:
        ax.set_title(title)

    ax.grid(True, which="major", ls="--", alpha=0.4)
    ax.legend()
    fig.tight_layout()

    summary = {
        "x_center": x_center,
        "y_mean": y_mean,
        "y_std": y_std,
        "count": counts,
        "band": band,
    }
    if band is not None:
        summary["y_lo"] = y_lo
        summary["y_hi"] = y_hi

    return fig, ax, summary

In [None]:
# Binned-mean view of recorded vs. true out-degree: 40 log-x bins, bins with
# fewer than 200 points skipped, band of ± one sample standard deviation.
fig, ax, summary = binned_mean_std_loglog_plot(
    df,
    x_col="true_out_deg",
    y_col="out_degree",
    bins=40,
    min_per_bin=200,
    band="std",  # or "sem" or None
    title="Binned mean ± std (log–log)",
    x_label="True out-degree",
    y_label="Recorded out-degree",
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def _log_edges(vmin, vmax, bins, base=10):
    """
    Return ``bins + 1`` logarithmically spaced edges spanning [vmin, vmax].

    The first and last edges are set to exactly ``vmin`` and ``vmax``: a
    log -> exp round trip is not guaranteed to reproduce its input bit-for-bit,
    and an edge that ends up a hair inside the data range would push the
    extreme points out of the histogram.

    ``base`` only selects which log/exp pair is used; the spacing of
    log-spaced edges does not depend on the base.

    Raises
    ------
    ValueError
        If ``vmin`` or ``vmax`` is not strictly positive.
    """
    if vmin <= 0 or vmax <= 0:
        raise ValueError("log bins require vmin>0 and vmax>0")
    if base == 10:
        edges = np.logspace(np.log10(vmin), np.log10(vmax), bins + 1)
    else:
        # natural log for every other base (including np.e and None)
        edges = np.exp(np.linspace(np.log(vmin), np.log(vmax), bins + 1))

    # Pin the endpoints to the requested limits.
    if edges.size > 0:
        edges[0] = vmin
    if edges.size > 1:
        edges[-1] = vmax
    return edges


def ridge_max_ybin_per_xbin(
    df,
    x_col,
    y_col,
    *,
    x_bins=40,
    y_bins=40,
    x_edges=None,  # optional explicit edges (original scale)
    y_edges=None,
    log_base=10,
    mincnt_x=1,  # require at least this many points in an x-bin to report a ridge point
    tie_break="lower",  # "lower", "upper", or "center" if multiple y-bins share the max
):
    """
    Compute ridge points from a 2D histogram: for each x-bin, the y-bin with the
    highest count.

    Only rows where both x and y are finite and strictly positive are used.
    When edges are not supplied they are log-spaced between the data minimum
    and maximum of each variable.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table; ``df[x_col]`` and ``df[y_col]`` are converted to float.
    x_col, y_col : str
        Column names of the x and y variables.
    x_bins, y_bins : int
        Number of log-spaced bins per axis when edges are not given.
    x_edges, y_edges : array-like, optional
        Explicit bin edges on the original scale.
    log_base : number
        Passed to ``_log_edges`` when edges are generated.
    mincnt_x : int
        x-bins with fewer points in total are skipped.
    tie_break : {"lower", "upper", "center"}
        Which y-bin to pick when several share the maximum count: the lowest
        index, the highest index, or the element at position ``len // 2`` of
        the tied indices.

    Returns
    -------
    ridge : dict with keys
      x_ridge      geometric-mean center of each reported x-bin
      y_ridge      geometric-mean center of the selected y-bin
      peak_counts  count in the selected y-bin
      xbin_counts  total count of the x-bin
      H            integer 2D histogram, shape (Nx, Ny)
      x_edges, y_edges  bin edges actually used
      mask_pos     boolean mask over the rows of ``df`` that were used

    Raises
    ------
    ValueError
        Unknown ``tie_break`` or no strictly positive finite points.
    """
    # Validate up front; otherwise a bad value is only noticed if a tie occurs.
    if tie_break not in ("lower", "upper", "center"):
        raise ValueError("tie_break must be 'lower', 'upper', or 'center'")

    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    mask_pos = np.isfinite(x) & np.isfinite(y) & (x > 0) & (y > 0)
    x = x[mask_pos]
    y = y[mask_pos]
    if x.size == 0:
        raise ValueError("No strictly positive finite points in both x and y.")

    if x_edges is None:
        x_edges = _log_edges(x.min(), x.max(), x_bins, base=log_base)
    else:
        x_edges = np.asarray(x_edges, dtype=float)

    if y_edges is None:
        y_edges = _log_edges(y.min(), y.max(), y_bins, base=log_base)
    else:
        y_edges = np.asarray(y_edges, dtype=float)

    H, x_edges, y_edges = np.histogram2d(x, y, bins=[x_edges, y_edges])  # H shape (Nx, Ny)
    H = H.astype(int)

    # bin centers (geometric mean is natural for log bins)
    x_centers = np.sqrt(x_edges[:-1] * x_edges[1:])
    y_centers = np.sqrt(y_edges[:-1] * y_edges[1:])

    Nx, Ny = H.shape
    peak_y_idx = np.full(Nx, -1, dtype=int)  # -1 marks "no ridge point for this x-bin"
    peak_counts = np.zeros(Nx, dtype=int)
    xbin_counts = H.sum(axis=1)

    for i in range(Nx):
        if xbin_counts[i] < mincnt_x:
            continue

        row = H[i, :]
        top = row.max()
        if top <= 0:
            continue

        # indices of all y-bins sharing the maximum, in ascending order
        ties = np.flatnonzero(row == top)
        if tie_break == "lower":
            j = int(ties[0])
        elif tie_break == "upper":
            j = int(ties[-1])
        else:  # "center"
            j = int(ties[ties.size // 2])

        peak_y_idx[i] = j
        peak_counts[i] = top

    keep = peak_y_idx >= 0
    return {
        "x_ridge": x_centers[keep],
        "y_ridge": y_centers[peak_y_idx[keep]],
        "peak_counts": peak_counts[keep],
        "xbin_counts": xbin_counts[keep],
        "H": H,
        "x_edges": x_edges,
        "y_edges": y_edges,
        "mask_pos": mask_pos,
    }


def plot_ridge(
    ridge,
    *,
    ax=None,
    figsize=(7, 6),
    title=None,
    x_label="x",
    y_label="y",
    show_sizes=True,  # point size proportional to peak count
    show_color=False,  # point color proportional to peak count
    connect=True,
):
    """
    Draw ridge points on log-log axes.

    Parameters
    ----------
    ridge : dict
        Must provide ``x_ridge``, ``y_ridge`` and ``peak_counts`` (equal-length
        arrays).
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    figsize : tuple
        Figure size used when a new figure is created.
    title, x_label, y_label : str
        Plot labels.
    show_sizes : bool
        Scale marker area with the peak count (20 to 100); otherwise a fixed
        size is used.
    show_color : bool
        Color markers by peak count and add a colorbar.
    connect : bool
        Join the ridge points with a line.

    Returns
    -------
    fig, ax
    """
    x_r = ridge["x_ridge"]
    y_r = ridge["y_ridge"]
    pc = np.asarray(ridge["peak_counts"])

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    if show_sizes:
        # scale marker sizes gently; an empty ridge has no maximum to normalize by
        peak = pc.max() if pc.size else 0
        s = 20 + 80 * (pc / peak if peak > 0 else 1.0)
    else:
        s = 30

    if show_color:
        sc = ax.scatter(x_r, y_r, s=s, c=pc)
        fig.colorbar(sc, ax=ax, label="peak bin count")
        if connect:
            ax.plot(x_r, y_r, alpha=0.5)
    elif connect and show_sizes:
        # Line markers cannot vary in size per point, so draw the connecting
        # line and the count-scaled markers separately, in the same color.
        (line,) = ax.plot(x_r, y_r)
        ax.scatter(x_r, y_r, s=s, color=line.get_color(), zorder=3)
    elif connect:
        ax.plot(x_r, y_r, marker="o", markersize=4)
    else:
        ax.scatter(x_r, y_r, s=s)

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    if title is not None:
        ax.set_title(title)
    ax.grid(True, which="major", ls="--", alpha=0.4)
    fig.tight_layout()
    return fig, ax

In [None]:
# Ridge of the recorded-vs-true out-degree histogram: 30 x 30 log bins, one
# ridge point per x-bin holding at least 50 points, ties resolved to the lower
# y-bin.
ridge = ridge_max_ybin_per_xbin(
    df,
    x_col="true_out_deg",
    y_col="out_degree",
    x_bins=30,
    y_bins=30,
    mincnt_x=50,  # ignore x-bins with <50 points total
    tie_break="lower",
)

# Plot the ridge points on log-log axes, joined by a line.
fig, ax = plot_ridge(
    ridge,
    title="Ridge: y-bin with max count per x-bin",
    x_label="True out-degree",
    y_label="Recorded out-degree",
    show_sizes=True,
    show_color=False,
    connect=True,
)