Skip to content

Commit

Permalink
allow methods as process-based routing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
geraintpalmer committed Apr 8, 2024
1 parent 1dc94c2 commit 5650e40
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
History
-------

+ **3.1.3** (2024-04-08)**
+ Allows class methods as generator functions for process-based routing.

+ **3.1.2 (2024-04-08)**
+ Fix bug when using Mixture distribution.

Expand Down
12 changes: 6 additions & 6 deletions ciw/import_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def create_network_from_dictionary(params_input):
)
)
for clss_name in params['customer_class_names']:
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
classes[clss_name] = CustomerClass(
params['arrival_distributions'][clss_name],
params['service_distributions'][clss_name],
Expand All @@ -140,7 +140,7 @@ def create_network_from_dictionary(params_input):
class_change_time_distributions[clss_name],
)
n = Network(nodes, classes)
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
n.process_based = True
else:
n.process_based = False
Expand Down Expand Up @@ -220,7 +220,7 @@ def validify_dictionary(params):
Raises errors if there is something wrong with the
parameters dictionary.
"""
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
consistant_num_classes = (
params["number_of_classes"]
== len(params["arrival_distributions"])
Expand All @@ -241,7 +241,7 @@ def validify_dictionary(params):
)
if not consistant_num_classes:
raise ValueError("Ensure consistant number of classes is used throughout.")
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
consistant_class_names = (
set(params["arrival_distributions"])
== set(params["service_distributions"])
Expand All @@ -266,7 +266,7 @@ def validify_dictionary(params):
)
if not consistant_class_names:
raise ValueError("Ensure consistant names for customer classes.")
if all(isinstance(f, types.FunctionType) for f in params["routing"]):
if all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
num_nodes_count = (
[params["number_of_nodes"]]
+ [len(obs) for obs in params["arrival_distributions"].values()]
Expand Down Expand Up @@ -296,7 +296,7 @@ def validify_dictionary(params):
)
if len(set(num_nodes_count)) != 1:
raise ValueError("Ensure consistant number of nodes is used throughout.")
if not all(isinstance(f, types.FunctionType) for f in params["routing"]):
if not all(isinstance(f, types.FunctionType) or isinstance(f, types.MethodType) for f in params["routing"]):
for clss in params["routing"].values():
for row in clss:
if sum(row) > 1.0 or min(row) < 0.0 or max(row) > 1.0:
Expand Down
22 changes: 22 additions & 0 deletions ciw/tests/test_process_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def generator_function_8(ind):
return [1]
return [1, 1, 1]

class ClassForProcessBasedMethod:
def __init__(self, n):
self.n = n
def generator_method(self, ind):
return [1, 1, 1]


class TestProcessBased(unittest.TestCase):
def test_network_takes_routing_function(self):
Expand Down Expand Up @@ -294,3 +300,19 @@ def test_customer_class_based_routing(self):
inds = Q.nodes[-1].all_individuals
routes_counter = set([tuple([ind.customer_class, tuple(dr.node for dr in ind.data_records)]) for ind in inds])
self.assertEqual(routes_counter, {('Class 1', (1, 1, 1)), ('Class 0', (1,))})

def test_process_based_takes_methods(self):
import types
G = ClassForProcessBasedMethod(5)
self.assertTrue(isinstance(G.generator_method, types.MethodType))
N = ciw.create_network(
arrival_distributions=[ciw.dists.Deterministic(1)],
service_distributions=[ciw.dists.Deterministic(1000)],
number_of_servers=[1],
routing=[G.generator_method],
)
Q = ciw.Simulation(N)
Q.simulate_until_max_time(4.5)
inds = Q.nodes[1].all_individuals
for ind in inds:
self.assertEqual(ind.route, [1, 1, 1])
2 changes: 1 addition & 1 deletion ciw/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.1.2"
__version__ = "3.1.3"

0 comments on commit 5650e40

Please sign in to comment.