# In[ ]:
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple, List

# In[ ]:
# ---------- Metrics (primary = WMAE) ----------
def wmae(y, yhat, w):
    """Weighted mean absolute error: sum(w * |y - yhat|) / sum(w).

    All inputs are converted to float arrays first, so plain lists work and
    pandas Series are compared positionally rather than index-aligned.
    If sum(w) == 0 numpy yields nan (with a RuntimeWarning).
    """
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    w = np.asarray(w, float)
    return (np.abs(y - yhat) * w).sum() / w.sum()

def wrmse(y, yhat, w):
    """Weighted root mean squared error: sqrt(sum(w * (y - yhat)**2) / sum(w)).

    Inputs are converted to float arrays first (lists accepted; pandas Series
    are compared positionally, not index-aligned).
    """
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    w = np.asarray(w, float)
    return np.sqrt(((y - yhat)**2 * w).sum() / w.sum())

def weighted_r2(y, yhat, w):
    """Weighted coefficient of determination.

    R² = 1 - SS_res / SS_tot, where both sums of squares are weighted by ``w``
    and SS_tot is taken around the weighted mean of ``y``. Returns 0.0 when
    SS_tot is not positive (e.g. constant ``y``). Inputs are converted to float
    arrays first (lists accepted; pandas Series compared positionally).
    """
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    w = np.asarray(w, float)
    ybar = (y * w).sum() / w.sum()
    ss_res = ((y - yhat)**2 * w).sum()
    ss_tot = ((y - ybar)**2 * w).sum()
    return 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

def poisson_deviance(y_counts, exposure, yhat_rate):
    """Total Poisson deviance of predicted rates against observed counts.

    D = 2 * sum( y * log(y / lam) - (y - lam) ),  lam = rate * exposure,
    with the y*log(y/lam) term defined as 0 when y == 0.

    Parameters
    ----------
    y_counts : observed counts (array-like).
    exposure : exposure per observation (array-like or scalar).
    yhat_rate : predicted rate per unit exposure; clipped below at 1e-12.

    The log is evaluated only where y > 0, so zero counts no longer trigger
    divide-by-zero / invalid-value RuntimeWarnings.
    """
    y = np.asarray(y_counts, float)
    lam = np.clip(np.asarray(yhat_rate, float), 1e-12, None) * np.asarray(exposure, float)
    # ratio = y / lam where y > 0, and 1 elsewhere (so log(ratio) == 0 there).
    ratio = np.divide(y, lam, out=np.ones(np.broadcast(y, lam).shape), where=y > 0)
    term = y * np.log(ratio) - (y - lam)
    return 2.0 * term.sum()

# ---------- Small helpers ----------
def wmean(y, w):
    """Weighted average of ``y`` under weights ``w``.

    Falls back to 0.0 whenever the total weight is not strictly positive
    (zero, negative, or nan), instead of dividing by it.
    """
    weights = np.asarray(w, float)
    total = weights.sum()
    if not total > 0:
        return 0.0
    return (y * weights).sum() / total

def leaf_sse(y, w):
    """Weighted sum of squared deviations of ``y`` from its weighted mean.

    This is the impurity of a leaf that predicts ``wmean(y, w)``.
    """
    centre = wmean(y, w)
    resid = y - centre
    return (resid * resid * w).sum()

# ---------- Node ----------
@dataclass
class Node:
    """One node of the binary regression tree built in this file.

    A node is a leaf when ``feat`` is None. For an internal node, prediction
    routes a sample ``x`` to ``left`` when ``x[feat] < thr`` and to ``right``
    otherwise. Builders in this file construct it positionally as
    ``Node(feat, thr, left, right, pred, idx)`` and then fill ``sse``/``leaves``.
    """
    feat: Optional[int]     # column index tested at this split; None for a leaf
    thr: Optional[float]    # split threshold; None for a leaf
    left: Optional["Node"]  # child for samples with x[feat] < thr; None for a leaf
    right: Optional["Node"] # child for all other samples; None for a leaf
    pred: float             # node prediction: weighted mean of y over idx
    idx: np.ndarray       # indices of samples at this node
    sse: float = 0.0      # subtree SSE (sum of leaf SSEs)
    leaves: int = 1       # number of leaves in subtree

# ---------- Split search ----------
def best_split(X, y, w, idx, min_leaf_w):
    """
    Find the weighted-SSE-minimising binary split of the rows ``idx``.

    Every feature is scanned; candidate thresholds are the midpoints between
    consecutive distinct values of that feature within ``idx`` (for a 0/1
    one-hot column this is exactly 0.5). Rows with ``x < t`` go left.
    A split is admissible only if both children have total weight
    >= ``min_leaf_w`` (exposure-weighted minimum leaf size).

    Returns (j, t, L_idx, R_idx, child_sse) for the best admissible split
    (ties keep the first one found: lowest feature, then lowest threshold),
    or None if idx has <= 1 row, y is (numerically) constant on idx, or no
    split is admissible.
    """
    if idx.size <= 1:
        return None
    yb = y[idx]
    if np.allclose(yb, yb[0]):
        return None
    # Gather the node's weights once; per-candidate work uses boolean masks
    # on these local arrays instead of re-indexing the full y / w.
    wb = w[idx]

    best = None
    for j in range(X.shape[1]):
        xj = X[idx, j]
        uniq = np.unique(xj)  # already sorted and de-duplicated
        if uniq.size <= 1:
            continue
        candidates = (uniq[:-1] + uniq[1:]) / 2.0  # midpoints

        for t in candidates:
            L_mask = xj < t
            # Guard against midpoints that round onto a neighbouring value.
            if not L_mask.any() or L_mask.all():
                continue
            R_mask = ~L_mask
            wL = wb[L_mask]
            wR = wb[R_mask]
            if wL.sum() < min_leaf_w or wR.sum() < min_leaf_w:
                continue

            sse_total = leaf_sse(yb[L_mask], wL) + leaf_sse(yb[R_mask], wR)
            if best is None or sse_total < best[4]:
                # Only materialise the child index arrays for a new best.
                best = (j, t, idx[L_mask], idx[R_mask], sse_total)
    return best

# ---------- Build tree (pre-pruned) ----------
def _make_leaf(y, w, idx, pred):
    """Return a terminal Node over ``idx`` predicting ``pred``, with its SSE filled in."""
    return Node(None, None, None, None, pred, idx.copy(),
                sse=leaf_sse(y[idx], w[idx]), leaves=1)


def build_tree(X, y, w, idx=None, max_depth=6, min_leaf_w=10.0, depth=0):
    """Recursively grow a weighted least-squares regression tree (pre-pruned).

    Parameters
    ----------
    X : 2-D feature array; rows are selected by ``idx``.
    y : 1-D target array.
    w : 1-D sample-weight array (exposure, per the driver cell's comment).
    idx : row indices reaching this node; all rows when None.
    max_depth : depth cap; nodes at this depth become leaves.
    min_leaf_w : minimum total weight per child, enforced by ``best_split``.
    depth : current depth (recursion counter; callers normally omit it).

    Returns
    -------
    Node
        Root of the grown subtree. Every node's ``pred`` is the weighted mean
        of ``y`` over its rows; ``sse`` / ``leaves`` summarise its subtree.
    """
    if idx is None:
        idx = np.arange(X.shape[0], dtype=int)

    y_node = y[idx]
    pred = wmean(y_node, w[idx])

    # Stop rules: depth cap reached; too little weight for two children that
    # each meet min_leaf_w; empty node (also guards y_node[0] below); pure node.
    if (depth >= max_depth
            or w[idx].sum() < 2 * min_leaf_w
            or idx.size == 0
            or np.allclose(y_node, y_node[0])):
        return _make_leaf(y, w, idx, pred)

    split = best_split(X, y, w, idx, min_leaf_w)
    if split is None:
        return _make_leaf(y, w, idx, pred)

    j, t, L, R, _ = split
    left = build_tree(X, y, w, L, max_depth, min_leaf_w, depth + 1)
    right = build_tree(X, y, w, R, max_depth, min_leaf_w, depth + 1)

    return Node(j, t, left, right, pred, idx.copy(),
                sse=left.sse + right.sse,
                leaves=left.leaves + right.leaves)

# ---------- Predict ----------
def predict_one(x, node: Node):
    """Route one sample ``x`` from ``node`` down to a leaf and return its prediction.

    At each internal node the sample goes left when ``x[feat] < thr``,
    otherwise right.
    """
    cur = node
    while cur.feat is not None:
        if x[cur.feat] < cur.thr:
            cur = cur.left
        else:
            cur = cur.right
    return cur.pred

def predict_tree(X, root: Node):
    """Predict every row of the 2-D array ``X`` with the tree at ``root``.

    Returns a 1-D float array with one prediction per row.
    """
    n_rows = X.shape[0]
    out = np.empty(n_rows, dtype=float)
    for r in range(n_rows):
        out[r] = predict_one(X[r], root)
    return out

# ---------- Weakest-link pruning to a target #leaves ----------
def _compute_stats(node: Node, X, y, w):
    """Refresh ``sse`` and ``leaves`` bottom-up for the subtree at ``node``.

    Leaves get their SSE recomputed from their own samples (``node.idx``);
    internal nodes sum their children (left first, then right). Use this after
    the tree structure changes. ``X`` is not used. Returns ``(sse, leaves)``.
    """
    if node.feat is None:
        node.sse, node.leaves = leaf_sse(y[node.idx], w[node.idx]), 1
    else:
        child_stats = [_compute_stats(child, X, y, w) for child in (node.left, node.right)]
        node.sse = child_stats[0][0] + child_stats[1][0]
        node.leaves = child_stats[0][1] + child_stats[1][1]
    return node.sse, node.leaves

def _alpha(node: Node, y, w):
    """
    Weakest-link cost of collapsing ``node`` into a leaf.

    α_t = (SSE_if_pruned_to_leaf - SSE_subtree) / (leaves_subtree - 1),
    i.e. the SSE increase per leaf removed. Leaves return ∞ (never pruned).
    Relies on ``node.sse`` / ``node.leaves`` being up to date.
    """
    if node.feat is not None:
        collapsed = leaf_sse(y[node.idx], w[node.idx])
        sse_gain = collapsed - node.sse
        return sse_gain / max(node.leaves - 1, 1e-12)
    return np.inf

def _collect_internal(node: Node) -> List[Node]:
    """Return the internal (split) nodes under ``node`` in pre-order.

    Order is: the node itself, then its left subtree, then its right subtree.
    Leaves and None are skipped. ``prune_to_leaves`` breaks α ties by this order.
    """
    found: List[Node] = []
    stack = [node]
    while stack:
        cur = stack.pop()
        if cur is None or cur.feat is None:
            continue
        found.append(cur)
        # Push right first so the left subtree is fully visited before it.
        stack.append(cur.right)
        stack.append(cur.left)
    return found

def prune_to_leaves(root: Node, target_leaves: int, X, y, w):
    """
    Return a pruned copy of ``root`` with at most ``target_leaves`` leaves.

    Greedy weakest-link pruning: while the tree is too large, collapse the
    internal node with the smallest α (SSE increase per leaf removed; see
    ``_alpha``), breaking ties by pre-order position. The input tree is
    deep-copied and left untouched. Stops once no internal nodes remain.
    """
    import copy

    tree = copy.deepcopy(root)
    _compute_stats(tree, X, y, w)

    while tree.leaves > target_leaves:
        internal = _collect_internal(tree)
        if not internal:
            break
        scores = np.array([_alpha(cand, y, w) for cand in internal])
        weakest = internal[int(np.argmin(scores))]

        # Collapse the weakest subtree into a single leaf.
        for attr in ("feat", "thr", "left", "right"):
            setattr(weakest, attr, None)
        node_y = y[weakest.idx]
        node_w = w[weakest.idx]
        weakest.pred = wmean(node_y, node_w)
        weakest.sse = leaf_sse(node_y, node_w)
        weakest.leaves = 1

        # Ancestors' sse/leaves are now stale; refresh the whole tree.
        _compute_stats(tree, X, y, w)

    return tree

# ---------- Simple K-fold (no sklearn) ----------
def kfold_indices(n_samples: int, n_splits: int, shuffle=True, seed=42):
    """Yield ``(train_idx, val_idx)`` index pairs for K-fold cross-validation.

    Rows ``0..n_samples-1`` are (optionally) shuffled with
    ``np.random.default_rng(seed)`` and split into ``n_splits`` contiguous,
    near-equal folds (``np.array_split``). Each fold serves once as the
    validation set; the remaining folds, in order, form the training set.

    Raises
    ------
    ValueError
        If ``n_splits < 2`` (no training data would remain) or
        ``n_splits > n_samples`` (some validation folds would be empty and
        produce nan scores). Being a generator, this raises on first iteration.
    """
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")
    if n_splits > n_samples:
        raise ValueError(
            f"n_splits={n_splits} cannot exceed n_samples={n_samples}"
        )
    rng = np.random.default_rng(seed)
    idx = np.arange(n_samples, dtype=int)
    if shuffle:
        rng.shuffle(idx)
    folds = np.array_split(idx, n_splits)
    for i, val_idx in enumerate(folds):
        train_idx = np.concatenate(folds[:i] + folds[i + 1:])
        yield train_idx, val_idx

# ---------- CV #1: tune (max_depth, min_leaf_w) ----------
def cv_preprune(X, y, w, depths=(3,5,7,9), min_leafs=(5.0,10.0,20.0), folds=5):
    """Grid-search the pre-pruning caps (max_depth, min_leaf_w) by K-fold CV.

    Each grid point is scored by the mean validation WMAE over the same folds
    (``kfold_indices`` with seed 42). The lowest mean wins; ties keep the
    earliest grid point. Returns a dict with keys "max_depth", "min_leaf_w",
    "wmae", or None if the grid is empty.
    """
    def _mean_cv_wmae(depth_cap, leaf_w):
        # Mean validation WMAE of a tree grown on each training fold.
        fold_scores = []
        for tr, va in kfold_indices(len(y), folds, shuffle=True, seed=42):
            tree = build_tree(X, y, w, idx=tr, max_depth=depth_cap, min_leaf_w=leaf_w)
            fold_scores.append(wmae(y[va], predict_tree(X[va], tree), w[va]))
        return float(np.mean(fold_scores))

    best = None
    for depth_cap in depths:
        for leaf_w in min_leafs:
            score = _mean_cv_wmae(depth_cap, leaf_w)
            if best is not None and not score < best["wmae"]:
                continue
            best = {"max_depth": depth_cap, "min_leaf_w": leaf_w, "wmae": score}
    return best

# ---------- CV #2: choose post-pruning by target #leaves ----------
def count_leaves(node: Node) -> int:
    """Return the number of leaves (nodes with ``feat is None``) under ``node``."""
    total = 0
    pending = [node]
    while pending:
        current = pending.pop()
        if current.feat is None:
            total += 1
        else:
            pending.extend((current.left, current.right))
    return total

def cv_prune_leaves(X, y, w, base_params, candidate_leaves=(2,3,4,6,8,12), folds=5):
    """Choose the post-pruning leaf budget by K-fold CV on weighted MAE.

    For each fold (``kfold_indices`` with seed 123), a tree is grown on the
    training rows with ``base_params["max_depth"]`` / ``["min_leaf_w"]``. It is
    then pruned to each candidate leaf count and scored on the validation rows.
    The candidate with the lowest mean WMAE wins; ties keep the earliest.

    Returns a dict with keys "max_leaves" and "wmae", or None if
    ``candidate_leaves`` is empty.
    """
    candidate_leaves = tuple(candidate_leaves)
    if not candidate_leaves:
        return None

    # The unpruned tree does not depend on the leaf budget, so grow it once
    # per fold and reuse it for every candidate. prune_to_leaves works on a
    # deep copy, so the cached trees are never modified.
    fold_trees = []
    for tr, va in kfold_indices(len(y), folds, shuffle=True, seed=123):
        root0 = build_tree(X, y, w, idx=tr,
                           max_depth=base_params["max_depth"],
                           min_leaf_w=base_params["min_leaf_w"])
        fold_trees.append((root0, va))

    best = None
    for L in candidate_leaves:
        scores = []
        for root0, va in fold_trees:
            # if tree already has <= L leaves, pruning does nothing
            rootL = prune_to_leaves(root0, target_leaves=L, X=X, y=y, w=w)
            yhat = predict_tree(X[va], rootL)
            scores.append(wmae(y[va], yhat, w[va]))
        mean_score = float(np.mean(scores))
        if best is None or mean_score < best["wmae"]:
            best = {"max_leaves": L, "wmae": mean_score}
    return best

# ---------- Fit final tree ----------
def fit_final_tree(X, y, w, depths=(3,5,7,9), min_leafs=(5.0,10.0,20.0), leaves=(2,3,4,6,8,12), folds=5):
    """Tune, grow and prune the final tree on all of (X, y, w).

    Steps:
      1. ``cv_preprune`` picks (max_depth, min_leaf_w) from the grids.
      2. One tree is grown on all rows with those caps.
      3. ``cv_prune_leaves`` picks the leaf budget from ``leaves``.
      4. The full tree is pruned to that budget.

    Returns ``(final_root, pre, pr)``, where ``pre`` and ``pr`` are the dicts
    returned by the two CV searches.
    """
    pre = cv_preprune(X, y, w, depths=depths, min_leafs=min_leafs, folds=folds)
    caps = {"max_depth": pre["max_depth"], "min_leaf_w": pre["min_leaf_w"]}
    full_tree = build_tree(X, y, w, **caps)

    pr = cv_prune_leaves(X, y, w, pre, candidate_leaves=leaves, folds=folds)
    pruned = prune_to_leaves(full_tree, target_leaves=pr["max_leaves"], X=X, y=y, w=w)
    return pruned, pre, pr


# In[ ]:
# From your preprocessing (presumably the training-split versions of
# X / y_rate / w_expo -- confirm against the preprocessing cell):
#   X_tr: pandas DataFrame of (one-hot) features
#   y_tr: Series of claim rates (ClaimNb / Exposure)
#   w_tr: Series of exposures, used as sample weights
X_np = X_tr.to_numpy(dtype=float)
y_np = y_tr.to_numpy(dtype=float)
w_np = w_tr.to_numpy(dtype=float)

# Fit the final manual tree (the small grids below can be tweaked).
final_tree, pre_caps, post_caps = fit_final_tree(
    X_np, y_np, w_np,
    depths=(5, 7, 9),
    min_leafs=(5.0, 10.0, 20.0),
    leaves=(2, 3, 4, 6, 8, 12),
    folds=5,
)
print("Pre-pruning picked:", pre_caps)
print("Post-pruning picked:", post_caps)

# Predict on validation/test numpy arrays:
# yhat_val = predict_tree(X_val_np, final_tree)
# Evaluate the primary metric (WMAE); also report the others:
# print("WMAE :", wmae(y_val_np, yhat_val, w_val_np))
# print("WRMSE:", wrmse(y_val_np, yhat_val, w_val_np))
# print("R2_w :", weighted_r2(y_val_np, yhat_val, w_val_np))
# poisson_deviance takes claim counts and exposure plus the predicted rate:
# print("Poisson dev.:", poisson_deviance(ClaimNb_val, w_val_np, yhat_val))
