# Gym Style toolkit


In [1]:
import gymnasium as gym
import numpy as np

# Static Solver

In [2]:
import numpy as np
! pip install pulp
import pulp
from pulp import value

class StaticSolver:
    def __init__(self, matches: np.ndarray, arrival_rates: np.ndarray, rewards: np.ndarray):
        """Initialize static matching model.

        Args:
            matches: Match matrix (d x n) where d is num match types, n is num agent types
            arrival_rates: Arrival rates vector of length n
            rewards: Rewards vector of length d
        """
        self.matches = matches
        self.d = matches.shape[0]  # number of match types
        self.n = matches.shape[1]  # number of agent types

        # Create optimization model
        self.model = pulp.LpProblem("StaticMatching", pulp.LpMaximize)

        # Decision variables
        self.x = [pulp.LpVariable(f'x_{i}', lowBound=0) for i in range(self.d)]

        # Objective
        self.model += pulp.lpSum(rewards[i] * self.x[i] for i in range(self.d))

        # Constraints
        for j in range(self.n):
            self.model += (
                pulp.lpSum(matches[i,j] * self.x[i] for i in range(self.d)) == arrival_rates[j],
                f'flow_conservation_{j}'
            )

    def solve(self, use_gurobi: bool = True) -> None:
        """Solve the optimization problem.

        Args:
            use_gurobi: If True, use Gurobi solver; otherwise use CBC (default False)
        """
        if use_gurobi:
            try:
                from pulp import GUROBI
                try:
                    self.model.solve(GUROBI(msg=0))
                except Exception as e:
                    print(f"Warning: Gurobi solver failed: {e}, falling back to CBC solver")
                    self.model.solve(pulp.PULP_CBC_CMD(msg=0))
            except ImportError:
                print("Warning: Gurobi not available, falling back to CBC solver")
                self.model.solve(pulp.PULP_CBC_CMD(msg=0))
        else:
            self.model.solve(pulp.PULP_CBC_CMD(msg=0))

    def get_primal_solution(self) -> np.ndarray:
        """Get primal solution (matching rates)."""
        if self.model.status != pulp.LpStatusOptimal:
            raise ValueError(f"Model not solved optimally. Status: {pulp.LpStatus[self.model.status]}")
        return np.array([v.varValue for v in self.x])

    def get_dual_solution(self) -> np.ndarray:
        """Get dual solution (shadow prices)."""
        if self.model.status != pulp.LpStatusOptimal:
            raise ValueError(f"Model not solved optimally. Status: {pulp.LpStatus[self.model.status]}")
        return np.array([self.model.constraints[f'flow_conservation_{j}'].pi for j in range(self.n)])

    def update_arrival_rates(self, new_rates: np.ndarray) -> None:
        """Update arrival rates in the model."""
        for j in range(self.n):
            self.model.constraints[f'flow_conservation_{j}'].changeRHS(new_rates[j])
        # Reset solution status
        self.model.status = pulp.LpStatusNotSolved

    def opt_obj(self):
        return value(self.model.objective)



# Config of the Env

In [3]:
"""
The config of the environment of Dynamic Matching problem is defined as follows:

Config = {
    T: int, # number of time periods
    AR: np.ndarry, # Arrival Rate -> shape: (T, d)
    M: np.ndarry, # Mathces -> shape: (m, d)
    R: np.ndarry, # Reward -> shape: (T, m)
    IQ: np.ndarry, # Initial Queue -> shape: (d)
    r: float, # Discount Rate: float
    a: float  # Abandon Rate
}

where T is the number of time periods, m is the number of match types, d is the number of agent types.
"""

matches = np.array([
    [1,0,0,0,0,0,0,0],
    [0,1,0,0,0,0,0,0],
    [0,0,1,0,0,0,0,0],
    [0,0,0,1,0,0,0,0],
    [0,0,0,0,1,0,0,0],
    [0,0,0,0,0,1,0,0],
    [0,0,0,0,0,0,1,0],
    [0,0,0,0,0,0,0,1],
    [1,0,0,0,0,1,0,0],
    [1,0,0,0,0,0,1,0],
    [1,0,0,0,0,0,0,1],
    [0,1,0,0,0,0,0,1],
    [0,0,1,0,1,0,0,0],
    [0,0,1,0,0,1,0,0],
    [0,0,1,0,0,0,0,1],
    [0,0,0,1,1,0,0,0],
])

rewards = [0., 0., 0., 0., 0., 0.,
           0., 0., 0.99530526, 1.0054256,
           0.99536582, 0.9953427, 1.00241962,
           0.9808672, 0.98275082, 0.99437712]

arrival_rates = np.array([
    0.12493259, 0.12414321, 0.12512027, 0.12620845,
    0.12402401, 0.12402403, 0.1262783, 0.12526913
])

arrival = np.array([0, 0, 0, 0, 0, 0, 1, 0])


DM_config ={
    "T": 10,
    "AR": np.tile(arrival_rates, (10, 1)),
    "M": matches,
    "R": np.tile(rewards, (10, 1)),
    'IQ': arrival,
    "r": 1.,
    "a": 0.
}

print("Matches shape:", matches.shape)
print("\nExample matches:")
print(matches)
print("\nRewards:", rewards)
print("\nArrival rates:", arrival_rates)


Matches shape: (16, 8)

Example matches:
[[1 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1]
 [1 0 0 0 0 1 0 0]
 [1 0 0 0 0 0 1 0]
 [1 0 0 0 0 0 0 1]
 [0 1 0 0 0 0 0 1]
 [0 0 1 0 1 0 0 0]
 [0 0 1 0 0 1 0 0]
 [0 0 1 0 0 0 0 1]
 [0 0 0 1 1 0 0 0]]

Rewards: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.99530526, 1.0054256, 0.99536582, 0.9953427, 1.00241962, 0.9808672, 0.98275082, 0.99437712]

Arrival rates: [0.12493259 0.12414321 0.12512027 0.12620845 0.12402401 0.12402403
 0.1262783  0.12526913]


#Creat Env

In [4]:
class DynamicMatching(gym.Env):
    """
    An environment representing the dynamic mathcing problem

    Attributes:
        T: int, # number of time periods
        AR: np.ndarry, # Arrival Rate -> shape: (T, d)
        M: np.ndarry, # Mathces -> shape: (m, d)
        R: np.ndarry, # Reward -> shape: (T, m)
        PQ: np.ndarry, # Pysical Queue -> shape: (d,)
        VQ: np.ndarry, # Virtual Queue -> shape: (d,)
        r: float, # Discount Rate: float
        a: float  # Abandon Rate

    where T is the number of time periods, m is the number of match types, d is the number of agent types.
    """
    def __init__(self, config):

        # Intializes Env parameters based on configuration dictionary
        self.config = config
        self.T = config["T"]
        self.AR = config["AR"]
        self.M = config["M"]
        self.R = config["R"]
        self.r = config["r"]
        self.a = config["a"]

        self.m = self.M.shape[0] # m is the number of match types
        self.d = self.M.shape[1] # d is the number of agent types
        self.seed = 16

        self.action_space = gym.spaces.Discrete(self.m, seed=self.seed)
        self.observation_space = gym.spaces.Discrete(self.d, seed=self.seed)
        self.initial_queue = config["IQ"].copy()
        self.physical_queue = self.initial_queue.copy()
        self.virtual_queue = self.initial_queue.copy()

        self.timestep = 0
        self.reset()

    def get_config(self):
        return self.config

    # Reset the environemnt to initial state
    def reset(self):
        self.timestep = 0
        self.physical_queue = self.initial_queue.copy()
        self.virtual_queue = self.initial_queue.copy()
        return

    # Defines one step of the DM, returning the new state, reward, whether time horizon is finished and unrealized action list
    def step(self, action_list, update=True):


        if self.timestep >= self.T:
            raise ValueError("Time horizon is finished") # Check the time horizon
        for i, action in enumerate(action_list):
            assert self.action_space.contains(action) or (action == None), f"Invalid action at index {i}: {action}" # Check the Action List is Valid

        # Update the Physical Queue
        # Realize the Action in the Action list.
        physical_rewards = 0
        actions_to_remove = []
        for action in action_list:
            if action == None:
                continue
            elif np.all(self.physical_queue >= self.M[action]):
                self.physical_queue = self.physical_queue - self.M[action]
                physical_rewards = physical_rewards + self.R[self.timestep, action]
                actions_to_remove.append(action)
        physical_rewards *= self.r ** (self.timestep) # Consider the discount rate at timestep t: reward * (r**t)

        # Remove the realized actions
        for action in actions_to_remove:
            action_list.remove(action)

        # Update the Virtual Queue
        # Do not need to realize the action list
        virtual_rewards = 0
        for action in action_list:
            if action == None:
                continue
            else:
                self.virtual_queue = self.virtual_queue - self.M[action]
                virtual_rewards = virtual_rewards + self.R[self.timestep, action]
        virtual_rewards *= self.r ** (self.timestep) # Consider the discount rate at timestep t: reward * (r**t)



        # Move to next period
        episode_over = False
        new_arrival = np.zeros(self.d)
        if update == True:
            self.timestep += 1
            if self.timestep == self.T:
                episode_over = True
            else:
                new_arrival_idx = self.observation_space.sample(probability=self.AR[self.timestep])
                new_arrival[int(new_arrival_idx)] = 1
                self.physical_queue = self.physical_queue + new_arrival
                self.virtual_queue = self.virtual_queue + new_arrival


        return [self.physical_queue.copy(), self.virtual_queue.copy()], physical_rewards, virtual_rewards, episode_over, new_arrival.copy(), action_list



# Experiments

In [5]:
class Experiment(object):
    """Optional instrumentation for running an experiment.

    Runs a simulation between an arbitrary openAI Gym environment and an algorithm, saving a dataset of (reward, time, space) complexity across each episode,
    and optionally saves trajectory information.

    Attributes:
        seed: random seed set to allow reproducibility
        dirPath: (string) location to store the data files
        nEps: (int) number of episodes for the simulation
        deBug: (bool) boolean, when set to true causes the algorithm to print information to the command line
        env: (openAI env) the environment to run the simulations on
        epLen: (int) the length of each episode
        numIters: (int) the number of iterations of (nEps, epLen) pairs to iterate over with the environment
        save_trajectory: (bool) boolean, when set to true saves the entire trajectory information
        render_flag: (bool) boolean, when set to true renders the simulations
        agent: (or_suite.agent.Agent) an algorithm to run the experiments with
        data: (np.array) an array saving the metrics along the sample paths (rewards, time, space)
        trajectory_data: (list) a list saving the trajectory information
    """

    def __init__(self, env, agent):
        '''
        Args:
            env: (openAI env) the environment to run the simulations on
            agent: (or_suite.agent.Agent) an algorithm to run the experiments with
            dict: a dictionary containing the arguments to send for the experiment, including:
                dirPath: (string) location to store the data files
                nEps: (int) number of episodes for the simulation
                deBug: (bool) boolean, when set to true causes the algorithm to print information to the command line
                env: (openAI env) the environment to run the simulations on
                epLen: (int) the length of each episode
                numIters: (int) the number of iterations of (nEps, epLen) pairs to iterate over with the environment
                save_trajectory: (bool) boolean, when set to true saves the entire trajectory information
                render: (bool) boolean, when set to true renders the simulations
                pickle: (bool) when set to true saves data to a pickle file
        '''

        self.seed = 12
        self.env = env
        self.epLen = 10
        self.num_iters = 1
        self.agent = agent

        np.random.seed(self.seed)  # sets seed for experiment

    # Runs the experiment
    def run(self):
        '''
            Runs the simulations between an environment and an algorithm
        '''
        for ite in range(self.num_iters):  # loops over the episodes

            # Reset the environment
            self.env.reset()

            # Reset the agent
            self.agent.reset()
            self.agent.update_config(self.env, self.env.get_config())


            oldState = self.env.initial_queue # obtains old state
            arrival = self.env.initial_queue
            epReward = 0


            # repeats until episode is finished
            for t in range(self.epLen):
                print("\n"+"="*10 + f"{t}-th period" + "="*10)
                # Select action list
                action_list = self.agent.pick_action(
                        queue=oldState, arrival=arrival,t=t)
                print(f"Proposed Action List: {action_list}")

                # steps based on the action chosen by the algorithm
                queue, physical_rewards, virtual_rewards, episode_over, new_arrival, remain_action_list = self.env.step(action_list)
                print(f"Remain Action List: {remain_action_list}")

                epReward += physical_rewards

                oldState = queue[self.agent.virtual]
                arrival = new_arrival

                # Update the Policy
                self.agent.update_policy(arrival=arrival, remain_action_list=remain_action_list, t=t)

            print(f"\nTotal Rewards: {epReward}")

            self.env.close()



# Agents

In [6]:
'''
All agents should inherit from the Agent class.
'''
class Agent(object):

    def __init__(self):
        pass
    def reset(self):
        pass
    def update_config(self, env, config):
        ''' Update agent information based on the config__file'''
        pass
    def update_parameters(self, param):
        pass
    def update_obs(self, obs, action, reward, newObs, timestep, info):
        '''Add observation to records'''
        pass
    def update_policy(self, h):
        '''Update internal policy based upon records'''
        pass
    def pick_action(self, obs, h):
        '''Select an action based upon the observation'''
        pass

Greedy Agent

In [7]:
class GreedyAgent(Agent):
    def __init__(self):
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.virtual = False

    def reset(self):
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.virtual = False

    def update_config(self, env, config):
        ''' Update agent information based on the config__file'''
        self.matches = config["M"]
        self.m = self.matches.shape[0]
        self.rewards = config["R"]
        self.arrivals = config["IQ"] # Inital Queue

    def update_parameters(self, param):
        pass

    def update_obs(self, obs, action, reward, newObs, timestep, info):
        '''Add observation to records'''
        pass

    def update_policy(self, arrival=None, remain_action_list=None, t=None):
        '''Update internal policy based upon records'''
        pass

    def pick_action(self, queue, arrival, t):
        '''Select an action based upon the observation'''
        best_reward = 0
        best_match_idx = None

        for i in range(self.m):
            # Check if match uses current arrival
            if np.inner(self.matches[i], arrival) > 1e-5:
                # Check if we have enough agents
                if np.all(queue >= self.matches[i]) and self.rewards[t][i] > best_reward:
                    best_match_idx = i
                    best_reward = self.rewards[t][i]
        return [best_match_idx]



In [8]:
DM_ENV = DynamicMatching(DM_config)
DM_AGENT = GreedyAgent()

Greedy_Exp = Experiment(env = DM_ENV, agent = DM_AGENT)
Greedy_Exp.run()


Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [15]
Remain Action List: []

Proposed Action List: [9]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [12]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [9]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [9]
Remain Action List: []

Total Rewards: 5.01307354


MaxQueue Agent

In [9]:
class MaxQueueAgent(Agent):
    def __init__(self):
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.virtual = False
        self.valid_indices = None

    def reset(self):
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.virtual = False
        self.valid_indices = None

    def update_config(self, env, config):
        ''' Update agent information based on the config__file'''
        self.matches = config["M"]
        self.m = self.matches.shape[0]
        self.rewards = config["R"]
        self.arrivals = config["IQ"] # Inital Queue
        self.arrival_rates = config["AR"] # arrival rates

        # Get optimal offline solution
        solver = StaticSolver(self.matches, self.arrival_rates[0], self.rewards[0])
        solver.solve()  # Will use gurobi if available
        primal_soln = solver.get_primal_solution()

        # Get valid matches (positive in optimal solution)
        self.valid_indices = np.where(primal_soln > 1e-10)[0]
        print(f"Valid indices: {self.valid_indices}")

        # Validate solution
        if len(self.valid_indices) != self.matches.shape[1]:
            raise ValueError("Primal solution not basic feasible")


    def update_parameters(self, param):
        pass

    def update_obs(self, obs, action, reward, newObs, timestep, info):
        '''Add observation to records'''
        pass

    def update_policy(self, arrival=None, remain_action_list=None, t=None):
        '''Update internal policy based upon records'''
        pass

    def pick_action(self, queue, arrival, t):
        '''Select an action based upon the observation'''
        best_match_idx = None
        highest_sum = 0

        for i in self.valid_indices:
            if np.all(queue >= self.matches[i]):
                queue_sum = np.inner(self.matches[i], queue)
                if queue_sum > highest_sum:
                    best_match_idx = i
                    highest_sum = queue_sum

        return [best_match_idx]



In [10]:
DM_ENV = DynamicMatching(DM_config)
DM_AGENT = MaxQueueAgent()

Greey_Exp = Experiment(env = DM_ENV, agent = DM_AGENT)
Greey_Exp.run()

Valid indices: [ 3  5  6  9 11 13 14 15]

Proposed Action List: [np.int64(6)]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [np.int64(15)]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [None]
Remain Action List: [None]

Proposed Action List: [np.int64(9)]
Remain Action List: []

Proposed Action List: [np.int64(9)]
Remain Action List: []

Proposed Action List: [None]
Remain Action List: [None]

Total Rewards: 3.0052283199999996


Primal-Dual Blind Agent

In [11]:
class PrimalDualBlindAgent(Agent):
    def __init__(self):
        self.T = None
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.arrival_sum = None
        self.virtual = True # Using Virtual Queue as State
        self.valid_indices = None
        self.unrealized_matches = []
        self.v_vec = None
        self.solver = None

    def reset(self):
        self.T = None
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.arrival_sum = None
        self.virtual = True # Using Virtual Queue as State
        self.valid_indices = None
        self.unrealized_matches = []
        self.v_vec = None
        self.solver = None

    def update_config(self, env, config):
        ''' Update agent information based on the config__file'''
        self.T = config["T"]
        self.matches = config["M"]
        self.m = self.matches.shape[0]
        self.rewards = config["R"]
        self.arrivals = config["IQ"] # Inital Queue
        self.arrival_rates = config["AR"] # arrival rates
        self.arrival_sum = config["IQ"]
        self.v_vec = [100.0] * self.T

        # Initialize solver with empirical arrival rates
        self.solver = StaticSolver(self.matches, self.arrival_sum, self.rewards[0])



    def update_parameters(self, param):
        pass

    def update_obs(self, obs, action, reward, newObs, timestep, info):
        '''Add observation to records'''
        pass

    def update_policy(self, arrival=None, remain_action_list=None, t=None):
        '''Update internal policy based upon records'''

        # Update empirical arrival rates with a minimum threshold to prevent numerical issues
        self.arrival_sum = self.arrival_sum + arrival
        empirical_rates = self.arrival_sum / (t + 2)
        self.solver.update_arrival_rates(empirical_rates)

        self.unrealized_matches = remain_action_list


    def pick_action(self, queue, arrival, t):
        '''Select an action based upon the observation'''


        # Get empirical dual values
        self.solver.solve()
        dual_values = self.solver.get_dual_solution()

        # Find best match based on reduced reward
        best_match_idx = None
        highest_reduced_reward = 1e-6

        for m in range(self.m):
            # Only consider matches using current arrival
            if np.inner(self.matches[m], arrival) > 0:
                # Compute reduced reward with queue pressure
                reduced_reward = (
                    self.rewards[t][m] -
                    np.inner(dual_values, self.matches[m]) +
                    np.inner(queue, self.matches[m]) / self.v_vec[t]
                )
                if reduced_reward > highest_reduced_reward:
                    best_match_idx = m
                    highest_reduced_reward = reduced_reward

        self.unrealized_matches.append(best_match_idx)

        return self.unrealized_matches



In [12]:
DM_ENV = DynamicMatching(DM_config)
DM_AGENT = PrimalDualBlindAgent()

PrimalDual_Blind_Exp = Experiment(env = DM_ENV, agent = DM_AGENT)
PrimalDual_Blind_Exp.run()


Proposed Action List: [6]
Remain Action List: []

Proposed Action List: [12]
Remain Action List: [12]

Proposed Action List: [12, 15]
Remain Action List: [12]

Proposed Action List: [12, 9]
Remain Action List: [12, 9]

Proposed Action List: [12, 9, None]
Remain Action List: [12, 9, None]

Proposed Action List: [12, 9, None, None]
Remain Action List: [9, None, None]

Proposed Action List: [9, None, None, None]
Remain Action List: [9, None, None, None]

Proposed Action List: [9, None, None, None, None]
Remain Action List: [None, None, None, None]

Proposed Action List: [None, None, None, None, None]
Remain Action List: [None, None, None, None, None]

Proposed Action List: [None, None, None, None, None, None]
Remain Action List: [None, None, None, None, None, None]

Total Rewards: 3.00222234


PrimalDualAgent

In [13]:
class PrimalDualAgent(Agent):
    def __init__(self):
        self.T = None
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.virtual = True # Using Virtual Queue as State
        self.unrealized_matches = []
        self.v_vec = None
        self.solver = None

    def reset(self):
        self.T = None
        self.matches = None
        self.m = None
        self.rewards = None
        self.arrivals = None
        self.arrival_rates = None
        self.virtual = True # Using Virtual Queue as State
        self.unrealized_matches = []
        self.v_vec = None
        self.solver = None

    def update_config(self, env, config):
        ''' Update agent information based on the config__file'''
        self.T = config["T"]
        self.matches = config["M"]
        self.m = self.matches.shape[0]
        self.rewards = config["R"]
        self.arrivals = config["IQ"] # Inital Queue
        self.arrival_rates = config["AR"] # arrival rates
        self.arrival_sum = config["IQ"]
        self.v_vec = [100.0] * self.T

        # Initialize solver with empirical arrival rates
        self.solver = StaticSolver(self.matches, self.arrival_rates[0], self.rewards[0])



    def update_parameters(self, param):
        pass

    def update_obs(self, obs, action, reward, newObs, timestep, info):
        '''Add observation to records'''
        pass

    def update_policy(self, arrival=None, remain_action_list=None, t=None):
        '''Update internal policy based upon records'''
        self.unrealized_matches = remain_action_list


    def pick_action(self, queue, arrival, t):
        '''Select an action based upon the observation'''


        # Get empirical dual values
        self.solver.solve()
        dual_values = self.solver.get_dual_solution()

        # Find best match based on reduced reward
        best_match_idx = None
        highest_reduced_reward = 1e-6

        for m in range(self.m):
            # Only consider matches using current arrival
            if np.inner(self.matches[m], arrival) > 0:
                # Compute reduced reward with queue pressure
                reduced_reward = (
                    self.rewards[t][m] -
                    np.inner(dual_values, self.matches[m]) +
                    np.inner(queue, self.matches[m]) / self.v_vec[t]
                )
                if reduced_reward > highest_reduced_reward:
                    best_match_idx = m
                    highest_reduced_reward = reduced_reward

        self.unrealized_matches.append(best_match_idx)

        return self.unrealized_matches



In [14]:
DM_ENV = DynamicMatching(DM_config)
DM_AGENT = PrimalDualAgent()

PrimalDual_Exp = Experiment(env = DM_ENV, agent = DM_AGENT)
PrimalDual_Exp.run()


Proposed Action List: [6]
Remain Action List: []

Proposed Action List: [15]
Remain Action List: [15]

Proposed Action List: [15, None]
Remain Action List: [None]

Proposed Action List: [None, 9]
Remain Action List: [None, 9]

Proposed Action List: [None, 9, 13]
Remain Action List: [None, 9, 13]

Proposed Action List: [None, 9, 13, 15]
Remain Action List: [None, 9, 13, 15]

Proposed Action List: [None, 9, 13, 15, None]
Remain Action List: [None, 9, 13, 15, None]

Proposed Action List: [None, 9, 13, 15, None, None]
Remain Action List: [None, 13, 15, None, None]

Proposed Action List: [None, 13, 15, None, None, None]
Remain Action List: [None, 13, 15, None, None, None]

Proposed Action List: [None, 13, 15, None, None, None, None]
Remain Action List: [None, 13, 15, None, None, None, None]

Total Rewards: 1.99980272


# Kidney Exchange Simulator

Import Packages

In [15]:
import numpy as np
import copy

 Basic data structures

In [16]:
class Vertex:
    def __init__(self, vid, is_patient, is_altruist, features, arrival_time=0):
        self.id = vid
        self.is_patient = is_patient
        self.is_altruist = is_altruist
        self.features = features
        self.arrival_time = arrival_time


class Graph:
    def __init__(self):
        self.V = {}  # {int: Vertex}
        self.E = {}  # {(int, int): float}

Matching (Naive Version)

In [17]:
import collections
import numpy as np
import scipy.sparse as sp
from scipy.optimize import milp, LinearConstraint, Bounds

def SolveIP(graph, max_cycle_len=3, max_chain_len=4):
    # Build outgoing adjacency
    out = collections.defaultdict(list)
    for (u, v), w in graph.E.items():
        if u in graph.V and v in graph.V:
            out[u].append(v)

    structures = []  # each: {"nodes": tuple, "edges": tuple[(u,v),...], "w": float}

    # Enumerate cycles (pairs only)
    pairs = [vid for vid, vv in graph.V.items() if not vv.is_altruist]
    for start in pairs:
        stack = [(start, [start])]
        while stack:
            cur, path = stack.pop()
            if len(path) > max_cycle_len:
                continue
            for nxt in out.get(cur, []):
                if nxt == start:
                    # found a cycle
                    if len(path) >= 2 and start == min(path):  # de-duplicate cycles
                        cyc = path[:]
                        edges = [(cyc[i], cyc[(i + 1) % len(cyc)]) for i in range(len(cyc))]
                        if all(e in graph.E for e in edges):
                            w = float(sum(graph.E[e] for e in edges))
                            structures.append({"nodes": tuple(cyc), "edges": tuple(edges), "w": w})
                else:
                    if nxt in path:
                        continue
                    if nxt not in graph.V:
                        continue
                    if graph.V[nxt].is_altruist:
                        continue
                    stack.append((nxt, path + [nxt]))

    #  Enumerate chains (altruist-start only)
    altruists = [vid for vid, vv in graph.V.items() if vv.is_altruist]
    for a in altruists:
        stack = [(a, [a], [], 0.0)]  # (cur, nodes_path, edges_path, w)
        while stack:
            cur, nodes_path, edges_path, w = stack.pop()
            if len(edges_path) >= max_chain_len:
                continue
            for nxt in out.get(cur, []):
                if nxt not in graph.V:
                    continue
                if graph.V[nxt].is_altruist:
                    continue
                if nxt in nodes_path:
                    continue
                e = (cur, nxt)
                if e not in graph.E:
                    continue

                new_nodes = nodes_path + [nxt]
                new_edges = edges_path + [e]
                new_w = w + float(graph.E[e])

                # every prefix is a valid chain
                structures.append({"nodes": tuple(new_nodes), "edges": tuple(new_edges), "w": float(new_w)})

                # continue from nxt
                stack.append((nxt, new_nodes, new_edges, new_w))

    if not structures:
        return []  # nothing feasible

    # Build A matrix for Ax <= 1
    vids = list(graph.V.keys())
    vid_to_row = {vid: i for i, vid in enumerate(vids)}
    n_rows = len(vids)
    n_vars = len(structures)

    rows, cols, data = [], [], []
    for j, s in enumerate(structures):
        for vid in set(s["nodes"]):
            rows.append(vid_to_row[vid])
            cols.append(j)
            data.append(1.0)

    A = sp.csr_matrix((data, (rows, cols)), shape=(n_rows, n_vars))

    lc = LinearConstraint(A, -np.inf * np.ones(n_rows), np.ones(n_rows)) # Linear constraints: Ax <= 1
    bounds = Bounds(np.zeros(n_vars), np.ones(n_vars)) # bound decision vars:0 <= x <= 1
    integrality = np.ones(n_vars, dtype=int)   # integer decision vars: {0, 1}

    # Objective
    c = -np.array([s["w"] for s in structures], dtype=float)

    res = milp(c=c, integrality=integrality, bounds=bounds, constraints=[lc]) # Utilize the Scipy Mixed-integer linear programming packages

    # Check Slover Success
    if res.x is None or res.status != 0:
        msg = getattr(res, "message", "MILP failed")
        raise RuntimeError(f"SolveIP MILP failed (status={res.status}). {msg}")

    chosen = [structures[j] for j in range(n_vars) if res.x[j] > 0.5]

    edges = []
    for s in chosen:
        edges.extend(list(s["edges"]))
    return edges


In [18]:
def greedy_matching(graph):
    """
    Toy max-weight matching:
      - sort edges by weight desc
      - each vertex at most once as donor and once as patient
    => The chosen edges form disjoint directed paths and directed cycles.
    """
    edges_sorted = sorted(graph.E.items(), key=lambda kv: kv[1], reverse=True)

    donors_used = set()
    patients_used = set()
    matching = []

    for (u, v), w in edges_sorted:
        if u in donors_used or v in patients_used:
            continue
        if u == v:
            continue
        matching.append((u, v))
        donors_used.add(u)
        patients_used.add(v)

    return matching

Events (Expire, negative crossmatch and renege)

In [19]:
def expire(vertex, rng, prob=0.01):
    """Paper: expire with (calibrated) constant probability. Here keep as parameter."""
    return rng.random() < prob


def negative_crossmatch(patient_vertex, rng):
    """
    Paper: failure probability depends on patient's CPRA.
      P(fail) = CPRA/100
      CPRA=100 -> fail prob = 1
    """
    cpra = int(patient_vertex.features.get("cpra", 0))
    cpra = max(0, min(100, cpra))
    return rng.random() < (cpra / 100.0)


def renege(pair_vertex, rng, default_prob=0.02):
    """
    Paper: only relevant for CHAINS (non-simultaneous) â€” the paired donor may renege
    on continuing the chain.
    We model it as a Bernoulli event whose probability can be a constant, or stored
    per-vertex in features['renege_prob'].
    """
    p = float(pair_vertex.features.get("renege_prob", default_prob))
    p = max(0.0, min(1.0, p))
    return rng.random() < p

Empirical samplers for f_p and f_a

In [20]:
class EmpiricalSampler:
    def __init__(self, bank, rng):
        if not bank:
            raise ValueError("bank is empty. Provide at least 1 record.")
        self.bank = bank
        self.rng = rng

    def __call__(self):
        idx = int(self.rng.integers(0, len(self.bank)))
        return copy.deepcopy(self.bank[idx])

ABOCompatible & w_OPTN

In [21]:
def abo_compatible(donor_abo: str, cand_abo: str) -> bool:
    '''
    Blood Type Compatiable Funciton:
        O compatible to A, B, AB, O
        A compatible to A, AB
        B compatible to B, AB
        AB compatible to AB
    '''
    d = donor_abo.upper()
    c = cand_abo.upper()
    if d == "O":
        return True
    if d == "A":
        return c in ("A", "AB")
    if d == "B":
        return c in ("B", "AB")
    if d == "AB":
        return c == "AB"
    raise ValueError(f"Unknown ABO: donor={donor_abo}, candidate={cand_abo}")


CAND_ABO_POINTS = {"O": 100, "B": 50, "A": 25, "AB": 0} # O only can recieve O type kidney -> higher priority
PAIRED_DONOR_ABO_POINTS = {"O": 0, "B": 100, "A": 250, "AB": 500} # AB only can donor their kidney to AB -> higher priority


def cpra_points(cpra: int) -> int:
    cpra = int(cpra)
    if not (0 <= cpra <= 100):
        raise ValueError("CPRA must be in [0, 100].")

    if 0 <= cpra <= 19:  return 0
    if 20 <= cpra <= 29: return 5
    if 30 <= cpra <= 39: return 10
    if 40 <= cpra <= 49: return 15
    if 50 <= cpra <= 59: return 20
    if 60 <= cpra <= 69: return 25
    if 70 <= cpra <= 74: return 50
    if 75 <= cpra <= 79: return 75
    if 80 <= cpra <= 84: return 125
    if 85 <= cpra <= 89: return 200
    if 90 <= cpra <= 94: return 300
    if cpra == 95: return 500
    if cpra == 96: return 700
    if cpra == 97: return 900
    if cpra == 98: return 1250
    if cpra == 99: return 1500
    if cpra == 100: return 2000
    raise RuntimeError("unreachable")


def _paired_donor_abo_points(paired_donor_abo):
    """If multiple ABO candidates exist, take the fewest points (conservative)."""
    if paired_donor_abo is None:
        return 0
    if isinstance(paired_donor_abo, str):
        return PAIRED_DONOR_ABO_POINTS[paired_donor_abo.upper()]
    pts = [PAIRED_DONOR_ABO_POINTS[a.upper()] for a in paired_donor_abo]
    return min(pts)


def donor_abo_of(vertex: Vertex) -> str:
    f = vertex.features
    if vertex.is_altruist:
        return f["donor_abo"]
    return f["paired_donor_abo"]


def candidate_abo_of(vertex: Vertex) -> str:
    return vertex.features["candidate_abo"]


def w_optn(donor_vertex: Vertex, cand_vertex: Vertex) -> float:
    f = cand_vertex.features
    cand_abo = f["candidate_abo"].upper()
    cpra = int(f["cpra"])
    wait_days = int(f.get("wait_days", 0))

    w = 100.0 + 0.07 * max(0, wait_days)

    # Check 0-ABDR mismatch
    z = f.get("zero_abdr_mismatch", False)
    if isinstance(z, dict):
        z = bool(z.get(donor_vertex.id, False))
    if z:
        w += 10.0

    # Check the Same Hosipital
    c1 = donor_vertex.features.get("center", None)
    c2 = f.get("center", None)
    if c1 is not None and c2 is not None and c1 == c2:
        w += 75.0

    # Check the Previous Cross Match record
    px = f.get("prev_crossmatch_ok", False)
    if isinstance(px, dict):
        px = bool(px.get(donor_vertex.id, False))
    if px:
        w += 75.0

    # Check the Candidate's Age
    age = f.get("candidate_age", f.get("age", None))
    if age is not None and int(age) < 18:
        w += 100.0

    # Check the Candidate's donor history
    if bool(f.get("prior_living_donor", False)):
        w += 150.0

    w += float(CAND_ABO_POINTS[cand_abo])
    w += float(_paired_donor_abo_points(f.get("paired_donor_abo", None)))
    w += float(cpra_points(cpra))

    if bool(f.get("orphan", False)):
        w += 1000000.0

    return w

SetPool Function

In [22]:
def sample_arrivals(t, graph, lam_p, lam_a, f_p, f_a, rng):
    next_id = max(graph.V.keys()) + 1 if graph.V else 0
    num_pairs = rng.poisson(lam_p)
    num_altruists = rng.poisson(lam_a)

    for _ in range(num_pairs):
        vid = next_id
        next_id += 1
        graph.V[vid] = Vertex(vid, True, False, f_p(), arrival_time=t + 1)

    for _ in range(num_altruists):
        vid = next_id
        next_id += 1
        graph.V[vid] = Vertex(vid, False, True, f_a(), arrival_time=t + 1)


def _ensure_dict_feature(cand_features: dict, key: str) -> dict:
    if key not in cand_features or cand_features[key] is None:
        cand_features[key] = {}
    elif not isinstance(cand_features[key], dict):
        cand_features[key] = {}
    return cand_features[key]


def build_edges(graph, rng, p_zero_abdr=0.02, p_prev_xm=0.05):
    """
    E(t) = ABOCompatible(V(t)) and weight = w_optn(u,v)
    Also fill match-specific dict flags per donor id for w_optn.
    """
    graph.E.clear()
    vids = list(graph.V.keys())

    for i in vids:
        u = graph.V[i]
        try:
            d_abo = donor_abo_of(u)
        except KeyError:
            continue

        for j in vids:
            if i == j:
                continue
            v = graph.V[j]
            if v.is_altruist:
                continue  # altruists cannot receive

            try:
                c_abo = candidate_abo_of(v)
            except KeyError:
                continue

            if not abo_compatible(d_abo, c_abo):
                continue

            # match-specific features for this donor u
            z_dict = _ensure_dict_feature(v.features, "zero_abdr_mismatch")
            px_dict = _ensure_dict_feature(v.features, "prev_crossmatch_ok")
            if u.id not in z_dict:
                z_dict[u.id] = (rng.random() < p_zero_abdr)
            if u.id not in px_dict:
                px_dict[u.id] = (rng.random() < p_prev_xm)

            graph.E[(i, j)] = w_optn(u, v)

def _matching_to_succ_pred(matching):
    succ, pred = {}, {}
    for u, v in matching:
        succ[u] = v
        pred[v] = u
    return succ, pred


def _find_cycles_from_succ(succ):
    """
    Since outdegree<=1 and indegree<=1, each component is either a simple path or a simple cycle.
    A cycle here must return to its start.
    Returns list of cycles, each cycle is list of vertex ids in cycle order.
    """
    visited = set()
    cycles = []

    for start in succ.keys():
        if start in visited:
            continue
        cur = start
        path = []
        seen = set()
        while cur in succ and cur not in seen:
            seen.add(cur)
            path.append(cur)
            cur = succ[cur]
        # if we returned to start, it's a cycle
        if cur == start:
            cycles.append(path)

        visited.update(path)

    return cycles




def step_pool(graph, t, lam_p, lam_a, f_p, f_a, rng,
              expire_prob=0.01, renege_prob=0.02,
              max_cycle_len=3, max_chain_len=4):
    # 1) SolveIP
    matching_edges = SolveIP(graph, max_cycle_len=max_cycle_len, max_chain_len=max_chain_len)

    # 2) Expire on V(t)
    expired = set()
    for vid, v in list(graph.V.items()):
        if expire(v, rng, prob=expire_prob):
            expired.add(vid)

    # remove expired vertices from matching
    matching_edges = [(u, v) for (u, v) in matching_edges if (u not in expired and v not in expired)]
    succ, pred = _matching_to_succ_pred(matching_edges)

    departures = set(expired)

    # 3) Cycles: all-or-nothing
    cycles = _find_cycles_from_succ(succ)
    for cyc in cycles:
        # cycles should contain only pairs; if not, skip
        if any(graph.V[x].is_altruist for x in cyc):
            continue

        ok = True
        for v_id in cyc:
            if negative_crossmatch(graph.V[v_id], rng):
                ok = False
                break

        if ok:
            # all vertices in a successful cycle depart
            for v_id in cyc:
                departures.add(v_id)

    # 4) Chains: sequential with tail cut
    chain_starts = [u for u in succ if graph.V[u].is_altruist]
    for start in chain_starts:
        cur = start
        chain_executed = False

        while cur in succ:
            nxt = succ[cur]
            if nxt not in graph.V or graph.V[nxt].is_altruist:
                break

            # recipient crossmatch fail => stop, recipient doesn't depart
            if negative_crossmatch(graph.V[nxt], rng):
                break

            # first successful transplant into nxt
            chain_executed = True
            departures.add(nxt)

            # if nxt reneges, stop here (tail removed)
            if renege(graph.V[nxt], rng, default_prob=renege_prob):
                break

            cur = nxt

        if chain_executed:
            departures.add(start)  # altruist departs only if chain started

    # IMPORTANT: after failures, some matched vertices might remain in succ/pred
    # but since we're ignoring OWQ we simply update pool by departures.

    # 5) Remove departures
    for vid in departures:
        graph.V.pop(vid, None)

    # 6) New arrivals + rebuild edges/weights
    sample_arrivals(t, graph, lam_p, lam_a, f_p, f_a, rng)
    build_edges(graph, rng)

    return graph, departures


Demo

In [23]:
if __name__ == "__main__":
    setpool = step_pool
    rng = np.random.default_rng(0)
    g = Graph()

    # Pair records must include: candidate_abo, paired_donor_abo, cpra, wait_days
    pair_bank = [
        {
            "type": "pair",
            "candidate_abo": "B",
            "paired_donor_abo": "A",
            "cpra": 20,
            "wait_days": 60,
            "candidate_age": 45,
            "prior_living_donor": False,
            "orphan": False,
            "center": 1,
            "zero_abdr_mismatch": {},
            "prev_crossmatch_ok": {},
            "renege_prob": 0.02,
        },
        {
            "type": "pair",
            "candidate_abo": "O",
            "paired_donor_abo": "B",
            "cpra": 80,
            "wait_days": 200,
            "candidate_age": 30,
            "prior_living_donor": False,
            "orphan": False,
            "center": 2,
            "zero_abdr_mismatch": {},
            "prev_crossmatch_ok": {},
            "renege_prob": 0.02,
        },
        {
            "type": "pair",
            "candidate_abo": "AB",
            "paired_donor_abo": "O",
            "cpra": 95,
            "wait_days": 10,
            "candidate_age": 12,
            "prior_living_donor": False,
            "orphan": False,
            "center": 1,
            "zero_abdr_mismatch": {},
            "prev_crossmatch_ok": {},
            "renege_prob": 0.02,
        },
    ]

    altruist_bank = [
        {"type": "altruist", "donor_abo": "O", "center": 1},
        {"type": "altruist", "donor_abo": "A", "center": 2},
        {"type": "altruist", "donor_abo": "B", "center": 3},
    ]

    f_p = EmpiricalSampler(pair_bank, rng)
    f_a = EmpiricalSampler(altruist_bank, rng)

    sample_arrivals(0, g, lam_p=3.0, lam_a=1.0, f_p=f_p, f_a=f_a, rng=rng)
    build_edges(g, rng)

    print("Initial: |V(0)| =", len(g.V), ", |E(0)| =", len(g.E))

    for t in range(5):
        g, D = step_pool(
            graph=g,
            t=t,
            lam_p=3.0,
            lam_a=1.0,
            f_p=f_p,
            f_a=f_a,
            rng=rng,
            expire_prob=0.01,
            renege_prob=0.02,
        )
        print(
            "t =", t + 1,
            "|departures| =", len(D),
            ", |V(t)| =", len(g.V),
            ", |E(t)| =", len(g.E),
        )


Initial: |V(0)| = 2 , |E(0)| = 2
t = 1 |departures| = 0 , |V(t)| = 4 , |E(t)| = 7
t = 2 |departures| = 2 , |V(t)| = 8 , |E(t)| = 36
t = 3 |departures| = 0 , |V(t)| = 12 , |E(t)| = 80
t = 4 |departures| = 2 , |V(t)| = 14 , |E(t)| = 126
t = 5 |departures| = 0 , |V(t)| = 17 , |E(t)| = 184
