# Kidney Exchange Simulator

Import Packages

In [55]:
import numpy as np
import copy

 Basic data structures

In [56]:
class Vertex:
    """One participant in the exchange pool.

    Attributes:
        id: identifier passed in as ``vid``.
        is_altruist: True for an altruistic donor, False for a patient-donor pair.
        features: record describing the participant (the demo uses plain dicts).
        arrival_time: period in which the vertex entered the pool (default 0).
    """

    def __init__(self, vid, is_altruist, features, arrival_time=0):
        self.id = vid
        self.is_altruist = is_altruist
        self.features = features
        self.arrival_time = arrival_time

    def __repr__(self):
        # Keep it short: features can be large, so only the identifying fields are shown.
        kind = "altruist" if self.is_altruist else "pair"
        return f"Vertex(id={self.id!r}, {kind}, arrival_time={self.arrival_time!r})"


class Graph:
    """Mutable state of the exchange pool.

    Attributes:
        V: vertex id (int) -> Vertex.
        E: directed edge (int, int) -> float weight.
        crossmatch_hist: (u, v) -> True/False result of an earlier crossmatch.
    """

    def __init__(self):
        # Three independent, initially empty dictionaries (see class docstring).
        self.V, self.E, self.crossmatch_hist = {}, {}, {}

Matching (Naive Version)

In [57]:
import collections
import numpy as np
import scipy.sparse as sp
from scipy.optimize import milp, LinearConstraint, Bounds

def _enumerate_cycles(graph, out, max_cycle_len):
    """Return every directed cycle over patient vertices with 2..max_cycle_len nodes.

    Each cycle is reported once, starting from its smallest vertex id.
    """
    found = []
    patients = [vid for vid, vv in graph.V.items() if not vv.is_altruist]
    for start in patients:
        stack = [(start, [start])]
        while stack:
            cur, path = stack.pop()
            for nxt in out.get(cur, []):
                if nxt == start:
                    # Closing edge. A lone vertex with a self-loop is not a cycle.
                    if len(path) >= 2:
                        edges = tuple(zip(path, path[1:] + [start]))
                        w = float(sum(graph.E[e] for e in edges))
                        found.append({"type": "cycle", "nodes": tuple(path), "edges": edges, "w": w})
                    continue
                if len(path) >= max_cycle_len:
                    # Extending would exceed the cap, so do not push a dead path.
                    continue
                # A cycle is only reported from its smallest id, so any path through a
                # smaller id can never be emitted from this start: prune it early.
                if nxt < start or nxt in path or graph.V[nxt].is_altruist:
                    continue
                stack.append((nxt, path + [nxt]))
    return found


def _enumerate_chains(graph, out, max_chain_len):
    """Return every altruist-initiated simple path with 1..max_chain_len edges.

    Every prefix of a chain is itself reported as a separate chain.
    """
    found = []
    altruists = [vid for vid, vv in graph.V.items() if vv.is_altruist]
    for a in altruists:
        stack = [((a,), (), 0.0)]  # (nodes so far, edges so far, accumulated weight)
        while stack:
            nodes, edges, w = stack.pop()
            if len(edges) >= max_chain_len:
                continue
            cur = nodes[-1]
            for nxt in out.get(cur, []):
                if graph.V[nxt].is_altruist or nxt in nodes:
                    continue
                e = (cur, nxt)
                new_nodes = nodes + (nxt,)
                new_edges = edges + (e,)
                new_w = w + float(graph.E[e])
                found.append({"type": "chain", "nodes": new_nodes, "edges": new_edges, "w": float(new_w)})
                stack.append((new_nodes, new_edges, new_w))
    return found


def SolveIP_structures(graph, max_cycle_len=3, max_chain_len=4):
    """Pick a maximum-weight set of vertex-disjoint cycles and chains.

    Cycles (patients only) and chains (starting at an altruist) are enumerated
    explicitly, then a 0/1 program selects the subset with the largest total
    edge weight such that no vertex appears in more than one selected structure.

    Args:
        graph: object exposing ``V`` (id -> vertex with ``is_altruist``) and
            ``E`` ((u, v) -> weight).
        max_cycle_len: largest number of vertices allowed in a cycle.
        max_chain_len: largest number of edges allowed in a chain.

    Returns:
        List of chosen structures, each a dict with keys ``"type"``
        (``"cycle"`` or ``"chain"``), ``"nodes"``, ``"edges"`` and ``"w"``.
        Empty when no cycle or chain exists.

    Raises:
        RuntimeError: if the MILP solver does not report an optimal solution.
    """
    # Outgoing adjacency restricted to edges whose endpoints are both still in
    # the pool. Everything downstream walks this map, so every edge visited is
    # guaranteed to exist in graph.E and to point at a live vertex.
    out = collections.defaultdict(list)
    for (u, v) in graph.E:
        if u in graph.V and v in graph.V:
            out[u].append(v)

    structures = _enumerate_cycles(graph, out, max_cycle_len)
    structures += _enumerate_chains(graph, out, max_chain_len)

    if not structures:
        return []

    # Incidence matrix A (vertices x structures) for the packing constraint Ax <= 1.
    vid_to_row = {vid: i for i, vid in enumerate(graph.V)}
    n_rows = len(vid_to_row)
    n_vars = len(structures)

    rows, cols = [], []
    for j, s in enumerate(structures):
        for vid in s["nodes"]:  # nodes within one structure are distinct
            rows.append(vid_to_row[vid])
            cols.append(j)

    A = sp.csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(n_rows, n_vars))
    lc = LinearConstraint(A, np.full(n_rows, -np.inf), np.ones(n_rows))  # Ax <= 1
    bounds = Bounds(np.zeros(n_vars), np.ones(n_vars))
    integrality = np.ones(n_vars, dtype=int)

    # milp minimises, so maximise sum(w_j x_j) by minimising its negation.
    c = -np.array([s["w"] for s in structures], dtype=float)
    res = milp(c=c, integrality=integrality, bounds=bounds, constraints=[lc])

    if res.x is None or res.status != 0:
        msg = getattr(res, "message", "MILP failed")
        raise RuntimeError(f"SolveIP MILP failed (status={res.status}). {msg}")

    # The solver returns floats; anything above 0.5 is a selected binary variable.
    return [s for j, s in enumerate(structures) if res.x[j] > 0.5]

In [58]:
def greedy_matching(graph):
    """
    Greedy toy matching over ``graph.E``.

    Edges are visited from heaviest to lightest (ties keep dictionary order).
    An edge (u, v) is accepted unless it is a self-loop, u has already donated,
    or v has already received. Every vertex therefore has at most one outgoing
    and one incoming chosen edge, so the result decomposes into disjoint
    directed paths and directed cycles.

    Returns the accepted edges as a list of (u, v) tuples in acceptance order.
    """
    ranked = sorted(graph.E, key=graph.E.get, reverse=True)

    gave, received = set(), set()
    picked = []

    for donor, recipient in ranked:
        blocked = donor in gave or recipient in received or donor == recipient
        if blocked:
            continue
        gave.add(donor)
        received.add(recipient)
        picked.append((donor, recipient))

    return picked

Events (Expire, negative crossmatch and renege)

In [59]:
def expire(vertex, rng, prob=0.01):
    """Bernoulli expiry draw.

    The paper uses a calibrated constant probability; here it is the ``prob``
    parameter. ``vertex`` is accepted for interface symmetry but not consulted.
    Consumes exactly one ``rng.random()`` draw and returns True when the vertex
    expires.
    """
    draw = rng.random()
    return draw < prob


def negative_crossmatch(patient_vertex, rng):
    """
    Draw whether the crossmatch into this patient fails.

    Following the paper, the failure probability is the patient's CPRA as a
    fraction: P(fail) = CPRA / 100, so CPRA = 100 always fails. The CPRA is
    read from ``features["cpra"]`` (default 0), truncated to an int and clamped
    to [0, 100]. Returns True on failure; consumes one ``rng.random()`` draw.
    """
    raw = patient_vertex.features.get("cpra", 0)
    percent = min(100, max(0, int(raw)))
    fail_probability = percent / 100.0
    return rng.random() < fail_probability


def renege(pair_vertex, rng, default_prob=0.02):
    """
    Draw whether the paired donor of this vertex backs out.

    Per the paper this only matters for CHAINS, which are executed
    non-simultaneously: a paired donor may renege on continuing the chain.
    Modelled as a Bernoulli event whose probability is
    ``features['renege_prob']`` when present, otherwise ``default_prob``;
    the value is clamped to [0, 1]. Consumes one ``rng.random()`` draw.
    """
    configured = pair_vertex.features.get("renege_prob", default_prob)
    probability = max(0.0, min(1.0, float(configured)))
    return rng.random() < probability

Empirical samplers for f_p and f_a

In [60]:
class EmpiricalSampler:
    """Callable that draws records uniformly at random from a fixed bank.

    Each call returns a deep copy, so callers may mutate the result without
    affecting the bank. Construction fails with ValueError on an empty bank.
    """

    def __init__(self, bank, rng):
        if not bank:
            raise ValueError("bank is empty. Provide at least 1 record.")
        self.rng = rng
        self.bank = bank

    def __call__(self):
        size = len(self.bank)
        position = int(self.rng.integers(0, size))
        record = self.bank[position]
        return copy.deepcopy(record)

ABOCompatible & w_OPTN

In [61]:
from collections import Counter

# Donor blood type -> candidate blood types that may receive from it.
_ABO_RECIPIENTS = {
    "O": frozenset({"O", "A", "B", "AB"}),
    "A": frozenset({"A", "AB"}),
    "B": frozenset({"B", "AB"}),
    "AB": frozenset({"AB"}),
}


def abo_compatible(donor_abo: str, cand_abo: str) -> bool:
    '''
    Blood type compatibility function (case-insensitive):
        O is compatible with A, B, AB, O
        A is compatible with A, AB
        B is compatible with B, AB
        AB is compatible with AB

    Returns True when a donor of type ``donor_abo`` may give to a candidate of
    type ``cand_abo``.

    Raises:
        ValueError: if either blood type is not one of O, A, B, AB. Both sides
        are checked, so a malformed candidate type is never silently treated
        as compatible (or incompatible).
    '''
    d = donor_abo.upper()
    c = cand_abo.upper()
    if d not in _ABO_RECIPIENTS or c not in _ABO_RECIPIENTS:
        raise ValueError(f"Unknown ABO: donor={donor_abo}, candidate={cand_abo}")
    return c in _ABO_RECIPIENTS[d]


# Points added to an edge weight according to the candidate's blood type.
CAND_ABO_POINTS = {"O": 100, "B": 50, "A": 25, "AB": 0} # O candidates can only receive an O kidney -> highest priority
# Points added according to the blood type of the candidate's paired donor.
PAIRED_DONOR_ABO_POINTS = {"O": 0, "B": 100, "A": 250, "AB": 500} # AB donors can only donate to AB candidates -> highest priority


def cpra_points(cpra: int) -> int:
    """Map a CPRA percentage to its priority points.

    The input is truncated with ``int()``; values outside [0, 100] raise
    ValueError. Points grow stepwise with CPRA, from 0 (CPRA 0-19) up to
    2000 (CPRA 100).
    """
    value = int(cpra)
    if value < 0 or value > 100:
        raise ValueError("CPRA must be in [0, 100].")

    # (inclusive upper bound of the band, points), in ascending order.
    bands = (
        (19, 0), (29, 5), (39, 10), (49, 15), (59, 20), (69, 25),
        (74, 50), (79, 75), (84, 125), (89, 200), (94, 300),
        (95, 500), (96, 700), (97, 900), (98, 1250), (99, 1500), (100, 2000),
    )
    for upper, points in bands:
        if value <= upper:
            return points
    raise RuntimeError("unreachable")


def _paired_donor_abo_points(paired_donor_abo):
    """Points for the paired donor's blood type.

    ``None`` scores 0. A single string is looked up directly. Any other
    iterable is treated as several candidate blood types and the smallest
    score among them is returned (the conservative choice).
    """
    if paired_donor_abo is None:
        return 0
    if isinstance(paired_donor_abo, str):
        options = [paired_donor_abo]
    else:
        options = paired_donor_abo
    return min(PAIRED_DONOR_ABO_POINTS[abo.upper()] for abo in options)


def donor_abo_of(vertex: Vertex) -> str:
    """Blood type of the donor attached to ``vertex``.

    Altruists store it under ``"donor_abo"``; pairs under ``"paired_donor_abo"``.
    A missing key raises KeyError.
    """
    key = "donor_abo" if vertex.is_altruist else "paired_donor_abo"
    return vertex.features[key]


def candidate_abo_of(vertex: Vertex) -> str:
    """Blood type of the candidate (recipient) stored under ``"candidate_abo"``."""
    features = vertex.features
    return features["candidate_abo"]


def _extract_donor_hla(donor_vertex: Vertex) -> dict:
    """
    Collect the donor's HLA antigens as {"A": [...], "B": [...], "DR": [...]}.

    The source keys depend on the vertex kind:
    - altruist: donor_hla_A / donor_hla_B / donor_hla_DR
    - pair:     paired_donor_hla_A / paired_donor_hla_B / paired_donor_hla_DR

    Raises KeyError when any of the three loci is absent or None.
    """
    features = donor_vertex.features or {}
    if donor_vertex.is_altruist:
        sources = {"A": "donor_hla_A", "B": "donor_hla_B", "DR": "donor_hla_DR"}
    else:
        sources = {"A": "paired_donor_hla_A", "B": "paired_donor_hla_B", "DR": "paired_donor_hla_DR"}

    antigens = {locus: features.get(key, None) for locus, key in sources.items()}
    if any(values is None for values in antigens.values()):
        raise KeyError(f"Missing donor HLA features for vertex {donor_vertex.id}")

    # Copy into fresh lists so callers never alias the stored feature values.
    return {locus: list(values) for locus, values in antigens.items()}


def _extract_candidate_hla(cand_vertex: Vertex) -> dict:
    """
    Collect the candidate's (recipient's) HLA antigens as
    {"A": [...], "B": [...], "DR": [...]}, read from candidate_hla_A/B/DR.

    Raises KeyError when any of the three loci is absent or None.
    """
    features = cand_vertex.features or {}
    sources = {"A": "candidate_hla_A", "B": "candidate_hla_B", "DR": "candidate_hla_DR"}

    antigens = {locus: features.get(key, None) for locus, key in sources.items()}
    if any(values is None for values in antigens.values()):
        raise KeyError(f"Missing candidate HLA features for vertex {cand_vertex.id}")

    # Copy into fresh lists so callers never alias the stored feature values.
    return {locus: list(values) for locus, values in antigens.items()}


def _zero_abdr_mismatch(donor_hla: dict, cand_hla: dict) -> bool:
    """
    Report whether donor and candidate carry exactly the same antigens at
    each of the A, B and DR loci. Order within a locus is ignored, but
    multiplicity is not (e.g. [1, 1] differs from [1, 2] and from [1]).
    """
    return all(
        Counter(donor_hla[locus]) == Counter(cand_hla[locus])
        for locus in ("A", "B", "DR")
    )


def w_optn(donor_vertex: Vertex, cand_vertex: Vertex, graph: Graph) -> float:
    """
    Weight of the directed edge donor_vertex -> cand_vertex.

    Starts from 100 plus 0.07 per (non-negative) waiting day, then adds:
      - 10 if donor and candidate have a 0-ABDR mismatch (computed from HLA antigens)
      - 75 if both vertices report the same, non-None ``center``
      - 75 if graph.crossmatch_hist[(donor id, candidate id)] is True
      - 100 if the candidate's age (``candidate_age``, else ``age``) is below 18
      - 150 if the candidate is a prior living donor
      - candidate ABO points, paired-donor ABO points and CPRA points
      - 1,000,000 if the candidate is flagged ``orphan``

    Raises KeyError when ``candidate_abo`` or any HLA feature is missing, and
    ValueError when CPRA is outside [0, 100].
    """
    cand = cand_vertex.features or {}
    donor = donor_vertex.features or {}

    abo = str(cand["candidate_abo"]).upper()
    cpra = int(cand.get("cpra", 0))
    waited = int(cand.get("wait_days", 0) or 0)

    score = 100.0 + 0.07 * max(0, waited)

    # 0-ABDR mismatch is an edge-level property derived from both parties' antigens.
    donor_antigens = _extract_donor_hla(donor_vertex)
    cand_antigens = _extract_candidate_hla(cand_vertex)
    if _zero_abdr_mismatch(donor_antigens, cand_antigens):
        score += 10.0

    # Same hospital/center bonus requires both sides to actually state a center.
    donor_center = donor.get("center", None)
    cand_center = cand.get("center", None)
    if donor_center is not None and cand_center is not None and donor_center == cand_center:
        score += 75.0

    # Only an explicitly recorded successful crossmatch earns the bonus.
    edge = (donor_vertex.id, cand_vertex.id)
    if graph.crossmatch_hist.get(edge, None) is True:
        score += 75.0

    # Pediatric candidate.
    age = cand.get("candidate_age", cand.get("age", None))
    if age is not None and int(age) < 18:
        score += 100.0

    if cand.get("prior_living_donor", False):
        score += 150.0

    # Table-driven blood-type and sensitisation points.
    score += float(CAND_ABO_POINTS[abo])
    score += float(_paired_donor_abo_points(cand.get("paired_donor_abo", None)))
    score += float(cpra_points(cpra))

    if cand.get("orphan", False):
        score += 1_000_000.0
    return float(score)




def build_edges(graph: Graph, rng, disallow_known_failed_crossmatch: bool = False):
    """
    Rebuild ``graph.E`` from scratch with every feasible directed edge u -> v.

    An edge is feasible when:
      - u and v are different vertices,
      - v is NOT an altruist (altruists never receive),
      - u's donor blood type is ABO compatible with v's candidate blood type.

    When ``disallow_known_failed_crossmatch`` is True, an edge whose entry in
    ``graph.crossmatch_hist`` is exactly False is left out.

    Each kept edge gets weight ``w_optn(u, v, graph)``. ``rng`` is not used.
    """
    graph.E.clear()

    pool = list(graph.V.items())
    recipients = [(vid, vertex) for vid, vertex in pool if not vertex.is_altruist]

    for donor_id, donor in pool:
        donor_type = donor_abo_of(donor)

        for cand_id, cand in recipients:
            if cand_id == donor_id:
                continue
            if not abo_compatible(donor_type, candidate_abo_of(cand)):
                continue

            edge = (donor_id, cand_id)
            if disallow_known_failed_crossmatch and graph.crossmatch_hist.get(edge, None) is False:
                continue

            graph.E[edge] = float(w_optn(donor, cand, graph))


Sample Arrivals

In [62]:
def sample_arrivals(t, graph, lam_p, lam_a, f_p, f_a, rng):
    """Add Poisson-distributed new pairs and altruists to ``graph.V``.

    Draws ``Poisson(lam_p)`` pairs, then ``Poisson(lam_a)`` altruists, and
    creates one vertex per arrival with features from ``f_p()`` / ``f_a()``
    and ``arrival_time = t + 1``. Pairs are created before altruists.

    Vertex ids are never reused. The next free id is remembered on the graph
    as ``graph.next_vid``; deriving it from the current pool alone would hand
    a departed vertex's id to a newcomer, which would then inherit that
    vertex's entries in ``graph.crossmatch_hist`` (keyed by id pairs).
    """
    # Also respect ids already present in the pool, in case vertices were
    # inserted by hand or the counter has not been initialised yet.
    floor_from_pool = max(graph.V) + 1 if graph.V else 0
    next_id = max(getattr(graph, "next_vid", 0), floor_from_pool)

    # Both counts are drawn before any features so the rng sequence is fixed.
    num_pairs = rng.poisson(lam_p)
    num_altruists = rng.poisson(lam_a)

    arrivals = ((False, num_pairs, f_p), (True, num_altruists, f_a))
    for is_altruist, count, sampler in arrivals:
        for _ in range(count):
            graph.V[next_id] = Vertex(next_id, is_altruist, sampler(), arrival_time=t + 1)
            next_id += 1

    graph.next_vid = next_id

step_pool Function

In [63]:
def step_pool(graph, t, lam_p, lam_a, f_p, f_a, rng,
              expire_prob=0.01, renege_prob=0.02,
              max_cycle_len=3, max_chain_len=4):
    """
    Advance the pool by one period: match, resolve failures, remove
    departures, admit arrivals and rebuild the edge set.

    The steps below run in a fixed order and most of them consume draws from
    ``rng``, so the order is part of the simulated behaviour:
      1. Solve for vertex-disjoint cycles/chains on the current graph.
      2. Draw expiry for every vertex; drop cycles touching an expired vertex
         and truncate chains before the first expired recipient.
      3. Crossmatch each cycle edge by edge; one failure cancels the cycle.
      4. Execute each chain edge by edge; a failed crossmatch stops the chain
         before that edge, a renege at the recipient stops it after that edge.
      5. Recipients of executed edges and altruists who donated depart,
         together with the expired vertices.
      6. Remove departures from ``graph.V``.
      7. Sample new arrivals and rebuild all edges and weights.

    Every crossmatch that is attempted is recorded in
    ``graph.crossmatch_hist[(u, v)]`` (True = passed).

    Args:
      graph: pool state, mutated in place.
      t: current period, forwarded to ``sample_arrivals``.
      lam_p, lam_a: arrival rates forwarded to ``sample_arrivals``.
      f_p, f_a: feature samplers forwarded to ``sample_arrivals``.
      rng: random generator shared by all stochastic events.
      expire_prob: per-vertex expiry probability for this period.
      renege_prob: default renege probability passed to ``renege``.
      max_cycle_len, max_chain_len: caps forwarded to ``SolveIP_structures``.

    Returns:
      graph (updated in-place),
      departures: set[int]
    """

    # 1) SolveIP -> return chosen STRUCTURES (cycle/chain grouped with order)
    chosen = SolveIP_structures(graph, max_cycle_len=max_cycle_len, max_chain_len=max_chain_len)

    # 2) Expire on V(t): one draw per vertex, in dictionary order
    expired = set()
    for vid, v in list(graph.V.items()):
        if expire(v, rng, prob=expire_prob):
            expired.add(vid)

    # Expired vertices always leave, whether or not they were matched.
    departures = set(expired)

    # Cut a chain before an expired recipient
    def _cut_chain_on_expire(nodes, edges, expired_set):
        # nodes: (a0, v1, v2, ...), edges: ((a0,v1),(v1,v2),...)
        # Returns (nodes, edges) for the surviving prefix, or None if nothing survives.
        if nodes[0] in expired_set:
            return None  # altruist expired => drop whole chain

        for i in range(1, len(nodes)):
            if nodes[i] in expired_set:
                # Edge i-1 points INTO the expired node, so it is excluded too.
                new_nodes = nodes[:i]
                new_edges = edges[:i-1]
                if len(new_edges) >= 1:
                    return (new_nodes, new_edges)
                return None

        return (nodes, edges)

    # Remove/trim structures that touch expired vertices
    filtered = []
    for s in chosen:
        if s["type"] == "cycle":
            # A cycle needs every member, so any expiry cancels it.
            if any(x in expired for x in s["nodes"]):
                continue
            filtered.append(s)
        else:  # chain
            cut = _cut_chain_on_expire(s["nodes"], s["edges"], expired)
            if cut is None:
                continue
            new_nodes, new_edges = cut
            # The weight key "w" is not carried over; it is not read below.
            filtered.append({"type": "chain", "nodes": new_nodes, "edges": new_edges})

    chosen = filtered

    # 3) Cycles: all-or-nothing crossmatch
    # The crossmatch is evaluated per edge (u->v) and recorded in graph.crossmatch_hist[(u,v)].
    # Edges after the first failure are not tested and therefore not recorded.
    kept = []
    for s in chosen:
        if s["type"] != "cycle":
            kept.append(s)
            continue

        ok = True
        for (u, v) in s["edges"]:
            fail = negative_crossmatch(graph.V[v], rng)   # True means fail
            graph.crossmatch_hist[(u, v)] = (not fail)    # True means OK (negative crossmatch)
            if fail:
                ok = False
                break

        if ok:
            kept.append(s)

    chosen = kept

    # 4) Chains: sequential with tail cut (crossmatch/renege)
    # The crossmatch outcome is recorded for each attempted edge (u->v).
    kept = []
    for s in chosen:
        if s["type"] != "chain":
            kept.append(s)
            continue

        nodes = s["nodes"]  # (a0, v1, v2, ...)
        edges = s["edges"]  # ((a0,v1),(v1,v2),...)

        if len(edges) == 0:
            continue

        # Default: the whole chain executes unless a failure/renege cuts it.
        cut_nodes = nodes
        cut_edges = edges

        for i, (u, v) in enumerate(edges):
            fail = negative_crossmatch(graph.V[v], rng)
            graph.crossmatch_hist[(u, v)] = (not fail)

            if fail:
                # stop BEFORE this edge executes => keep prefix
                cut_edges = edges[:i]
                cut_nodes = nodes[:i+1]
                break

            if renege(graph.V[v], rng, default_prob=renege_prob):
                # edge executes into v; stop AFTER v
                cut_edges = edges[:i+1]
                cut_nodes = nodes[:i+2]
                break

        # A chain whose very first edge failed executes nothing and is dropped.
        if len(cut_edges) >= 1:
            kept.append({"type": "chain", "nodes": cut_nodes, "edges": cut_edges})

    chosen = kept

    # 5) Collect executed edges, then departures
    executed_edges = []
    for s in chosen:
        executed_edges.extend(list(s["edges"]))

    # recipients who actually receive depart; altruists depart iff they donated
    for (u, v) in executed_edges:
        departures.add(v)
        if u in graph.V and graph.V[u].is_altruist:
            departures.add(u)

    # 6) Remove departures
    for vid in departures:
        graph.V.pop(vid, None)

    # 7) New arrivals + rebuild edges/weights
    # NOTE(review): build_edges runs with its default flag, so edges whose
    # crossmatch already failed are rebuilt and can be selected again - confirm intended.
    sample_arrivals(t, graph, lam_p, lam_a, f_p, f_a, rng)
    build_edges(graph, rng)

    return graph, departures


Demo

In [64]:
# Demo: seed a pool from two small record banks and simulate five periods,
# printing the number of departures and the pool/edge sizes after each one.
if __name__ == "__main__":
    # Fixed seed so the printed trace is reproducible.
    rng = np.random.default_rng(0)
    g = Graph()

    # Pair records must include: candidate_abo, paired_donor_abo, cpra, wait_days
    # We additionally include HLA antigens to enable true 0-ABDR mismatch computation in w_optn.
    pair_bank = [
        {
            "type": "pair",
            "candidate_abo": "B",
            "paired_donor_abo": "A",
            "cpra": 20,
            "wait_days": 60,
            "candidate_age": 45,
            "prior_living_donor": False,
            "orphan": False,
            "center": 1,
            "renege_prob": 0.02,

            # recipient (candidate) HLA antigens (A/B/DR), 2 each
            "candidate_hla_A": [2, 24],
            "candidate_hla_B": [7, 44],
            "candidate_hla_DR": [4, 15],

            # paired donor HLA antigens (A/B/DR), 2 each
            "paired_donor_hla_A": [1, 24],
            "paired_donor_hla_B": [8, 44],
            "paired_donor_hla_DR": [4, 11],
        },
        {
            "type": "pair",
            "candidate_abo": "O",
            "paired_donor_abo": "B",
            "cpra": 30,
            "wait_days": 200,
            "candidate_age": 30,
            "prior_living_donor": False,
            "orphan": False,
            "center": 2,
            "renege_prob": 0.02,

            "candidate_hla_A": [3, 11],
            "candidate_hla_B": [35, 51],
            "candidate_hla_DR": [1, 13],

            "paired_donor_hla_A": [3, 26],
            "paired_donor_hla_B": [35, 60],
            "paired_donor_hla_DR": [1, 4],
        },
        {
            "type": "pair",
            "candidate_abo": "AB",
            "paired_donor_abo": "O",
            "cpra": 10,
            "wait_days": 10,
            "candidate_age": 12,
            "prior_living_donor": False,
            "orphan": False,
            "center": 1,
            "renege_prob": 0.02,

            "candidate_hla_A": [1, 2],
            "candidate_hla_B": [7, 8],
            "candidate_hla_DR": [4, 11],

            "paired_donor_hla_A": [1, 2],
            "paired_donor_hla_B": [7, 27],
            "paired_donor_hla_DR": [4, 11],
        },
    ]

    # Altruist records carry the donor's own blood type, center and HLA antigens.
    altruist_bank = [
        {
            "type": "altruist",
            "donor_abo": "O",
            "center": 1,
            "donor_hla_A": [2, 24],
            "donor_hla_B": [7, 44],
            "donor_hla_DR": [4, 15],
        },
        {
            "type": "altruist",
            "donor_abo": "A",
            "center": 2,
            "donor_hla_A": [1, 3],
            "donor_hla_B": [8, 35],
            "donor_hla_DR": [1, 4],
        },
        {
            "type": "altruist",
            "donor_abo": "B",
            "center": 3,
            "donor_hla_A": [11, 26],
            "donor_hla_B": [51, 60],
            "donor_hla_DR": [13, 15],
        },
    ]

    # Both samplers share the demo rng, so arrivals and events form one random stream.
    f_p = EmpiricalSampler(pair_bank, rng)
    f_a = EmpiricalSampler(altruist_bank, rng)

    # Initial pool: one round of arrivals, then the edge set for period 0.
    sample_arrivals(0, g, lam_p=3.0, lam_a=1.0, f_p=f_p, f_a=f_a, rng=rng)
    build_edges(g, rng)

    print("Initial: |V(0)| =", len(g.V), ", |E(0)| =", len(g.E))

    # Five periods; step_pool mutates g in place and returns it with the departures.
    for t in range(5):
        g, D = step_pool(
            graph=g,
            t=t,
            lam_p=3.0,
            lam_a=1.0,
            f_p=f_p,
            f_a=f_a,
            rng=rng,
            expire_prob=0.01,
            renege_prob=0.02,
        )
        print(
            "t =", t,
            "|departures| =", len(D),
            ", |V(t)| =", len(g.V),
            ", |E(t)| =", len(g.E),
        )


Initial: |V(0)| = 2 , |E(0)| = 2
t = 0 |departures| = 2 , |V(t)| = 3 , |E(t)| = 4
t = 1 |departures| = 0 , |V(t)| = 9 , |E(t)| = 35
t = 2 |departures| = 5 , |V(t)| = 8 , |E(t)| = 9
t = 3 |departures| = 0 , |V(t)| = 12 , |E(t)| = 33
t = 4 |departures| = 5 , |V(t)| = 13 , |E(t)| = 18
