-
Notifications
You must be signed in to change notification settings - Fork 32
/
service.py
1015 lines (884 loc) · 50.6 KB
/
service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
############################################################################################
#
# Copyright 2020, University of Stuttgart: Institute for Natural Language Processing (IMS)
#
# This file is part of Adviser.
# Adviser is free software: you can redistribute it and/or modify'
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3.
#
# Adviser is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Adviser. If not, see <https://www.gnu.org/licenses/>.
#
############################################################################################
import copy
import datetime
import inspect
import pickle
import threading
import time
from threading import Thread
from typing import List, Dict, Union, Iterable, Any
import zmq
from zmq import Context, Socket
from zmq.devices import ThreadProxy, ProcessProxy
from utils.domain.domain import Domain
from utils.logger import DiasysLogger
from utils.topics import Topic
def _send_msg(pub_channel: Socket, topic: str, content: Any):
    """ Publishes `content` to `topic`, stamped with the current time.

    The message consists of two frames: the ASCII-encoded topic and a pickled `(timestamp, content)` tuple,
    where the timestamp is a POSIX float.
    All internal message passing is expected to go through this function.

    Args:
        pub_channel (Socket): publisher socket
        topic (str): topic to publish to (has to be ASCII-encodable)
        content (Any): message content (has to be picklable)
    """
    sent_at = datetime.datetime.now().timestamp()  # POSIX float
    payload_frame = pickle.dumps((sent_at, content))
    topic_frame = topic.encode("ascii")
    pub_channel.send_multipart((topic_frame, payload_frame))
def _send_ack(pub_channel: Socket, topic: str, content: bool = True):
    """ Publishes an acknowledge-message (ACK) for `topic`.

    The message is sent to the topic `ACK/<topic>`; its counterpart `_recv_ack` waits for it,
    which allows services to synchronize.

    Args:
        pub_channel (Socket): publisher socket
        topic (str): topic to acknowledge (the `ACK/` prefix is added here)
        content (bool): `True` for an ACK, `False` for a NACK
    """
    ack_topic = f"ACK/{topic}"
    _send_msg(pub_channel, ack_topic, content)
def _recv_ack(sub_channel: Socket, topic: str, expected_content: bool = True):
    """ Waits (blocking) for an acknowledge-message on `topic` carrying the expected content.

    Every message arriving on the subscriber channel is consumed; messages for other topics
    or with a different content are dropped.

    Args:
        sub_channel (Socket): subscriber socket
        topic (str): topic to wait for an ACK on (with or without the leading `ACK/`)
        expected_content (bool): wait for `True` (ACK) or `False` (NACK)
    """
    if topic.startswith("ACK/"):
        ack_topic = topic
    else:
        ack_topic = f"ACK/{topic}"

    acknowledged = False
    while not acknowledged:
        frames = sub_channel.recv_multipart(copy=True)
        recv_topic = frames[0].decode("ascii")
        # payload is a pickled (timestamp, content) tuple - only the content is of interest here
        content = pickle.loads(frames[1])[1]
        acknowledged = recv_topic == ack_topic and content == expected_content
class RemoteService:
    """
    Placeholder for a `Service` running on a remote node, to be used in the service list argument
    when constructing a `DialogSystem`.

    Usage:
        * Run the real `Service` instance on a remote node and give it a *UNIQUE* identifier
        * Call `run_standalone()` on that instance
        * Instantiate a `RemoteService` with the *SAME* identifier on the node about to run the `DialogSystem`
        * Add it to the `DialogSystem` service list
        * When calling the constructor of `DialogSystem`, you should now see messages informing you about
          the successful connection; if the system is still trying to connect, the constructor will block
          until it is connected to the remote service.
    """

    def __init__(self, identifier: str):
        """
        Args:
            identifier (str): the *UNIQUE* identifier of the remote service instance this placeholder stands for
        """
        self.identifier: str = identifier
class Service:
    """
    Service base class.

    Inherit from this class, if you want to publish / subscribe to topics
    *(Don't forget to call the super constructor!)*.
    You may decorate arbitrary functions in the child class with the services.service.PublishSubscribe decorator
    for this purpose.

    Note: A `Service` will only start listening to messages once it is added to a `DialogSystem`
    (or calling `run_standalone()` in the remote case and adding a corresponding `RemoteService`
    to the `DialogSystem`).
    """

    def __init__(self, domain: Union[str, Domain] = "", sub_topic_domains: Dict[str, str] = None,
                 pub_topic_domains: Dict[str, str] = None, ds_host_addr: str = "127.0.0.1",
                 sub_port: int = 65533, pub_port: int = 65534, protocol: str = "tcp",
                 debug_logger: DiasysLogger = None, identifier: str = None):
        """
        Create a new service instance *(call this super constructor from your inheriting classes!)*.

        Args:
            domain (Union[str, Domain]): The domain(-name) of your service (or empty string, if domain-agnostic).
                                         If a domain(-name) is set, it will automatically filter out all messages
                                         from other domains.
                                         If no domain(-name) is set, messages from all domains will be received.
            sub_topic_domains (Dict[str, str]): change subscribed to topics to listen to a specific domain
                                                (e.g. 'erase'/append a domain for a specific topic).
                                                `None` (default) means no overrides.
            pub_topic_domains (Dict[str, str]): change published topics to a specific domain
                                                (e.g. 'erase'/append a domain for a specific topic).
                                                `None` (default) means no overrides.
            ds_host_addr (str): IP-address of the parent `DialogSystem` (default: localhost)
            sub_port (int): subscriber port following zmq's XSUB/XPUB pattern
            pub_port (int): publisher port following zmq's XSUB/XPUB pattern
            protocol (string): communication protocol with `DialogSystem` - has to match!
                               Possible options: `tcp`, `inproc`, `ipc`
            debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger,
                                         including send/receive events.
                                         Can be useful for debugging because you can still see messages received
                                         by the `DialogSystem` even if they are never forwarded (as expected)
                                         to your `Service`.
            identifier (str): Set this to a *UNIQUE* identifier per service to be run remotely.
                              See `RemoteService` for more details.
        """
        self.is_training = False
        self.domain = domain
        # get domain name (gets appended to all sub/pub topics so that different domain topics don't get shared)
        if domain is not None:
            self._domain_name = domain.get_domain_name() if isinstance(domain, Domain) else domain
        else:
            self._domain_name = ""
        # NOTE: fresh dicts per instance - a `{}` default in the signature would be shared by all instances
        self._sub_topic_domains = {} if sub_topic_domains is None else sub_topic_domains
        self._pub_topic_domains = {} if pub_topic_domains is None else pub_topic_domains

        # socket information
        self._host_addr = ds_host_addr
        self._sub_port = sub_port
        self._pub_port = pub_port
        self._protocol = protocol
        self._identifier = identifier

        self.debug_logger = debug_logger

        self._sub_topics = set()
        self._pub_topics = set()
        self._publish_sockets = dict()

        # control topic -> str(function instance), one entry per subscribing function
        self._internal_start_topics = dict()
        self._internal_end_topics = dict()
        self._internal_terminate_topics = dict()

        # NOTE: class name + memory pointer make topic unique
        # (required, e.g. for running multiple instances of same module!)
        topic_prefix = f"{type(self).__name__}/{id(self)}"
        self._start_topic = f"{topic_prefix}/START"
        self._end_topic = f"{topic_prefix}/END"
        self._terminate_topic = f"{topic_prefix}/TERMINATE"
        self._train_topic = f"{topic_prefix}/TRAIN"
        self._eval_topic = f"{topic_prefix}/EVAL"

    def _init_pubsub(self):
        """ Search for all functions decorated with the `PublishSubscribe` decorator
        and call the setup methods for them """
        for func_name in dir(self):
            func_inst = getattr(self, func_name)
            if hasattr(func_inst, "pubsub"):
                # found decorated publisher / subscriber function -> setup sockets and listeners
                self._setup_listener(func_inst, getattr(func_inst, "sub_topics"),
                                     getattr(func_inst, 'queued_sub_topics'))
                self._setup_publishers(func_inst, getattr(func_inst, "pub_topics"))

    def _register_with_dialogsystem(self):
        """ Start listening to dialog system control channel messages (in a new thread) """
        self._setup_dialog_ctrl_msg_listener()
        Thread(target=self._control_channel_listener).start()

    def _setup_listener(self, func_instance, topics: List[str], queued_topics: List[str]):
        """
        Starts a new subscription thread for a function decorated with `services.service.PublishSubscribe`.

        Args:
            func_instance (function): instance of the function that was decorated with
                                      `services.service.PublishSubscribe`.
            topics (List[str]): list of subscribed topics (drops all but most recent messages before function call)
            queued_topics (List[str]): list for subscribed topics (drops no messages, forward a list of
                                       received messages to function call)
        """
        all_topics = list(topics) + list(queued_topics)
        if not all_topics:
            # no subscribed to topics - no need to setup anything (e.g. only publisher)
            return
        # ensure that sub_topics and queued_sub_topics don't intersect
        # (otherwise, both would set same function argument value)
        assert set(topics).isdisjoint(queued_topics), "sub_topics and queued_sub_topics have to be disjoint!"

        # setup socket
        ctx = Context.instance()
        subscriber = ctx.socket(zmq.SUB)
        # subscribe to all listed topics
        for topic in all_topics:
            topic_domain_str = f"{topic}/{self._domain_name}" if self._domain_name else topic
            if topic in self._sub_topic_domains:
                # overwrite domain for this specific topic and service instance
                domain_override = self._sub_topic_domains[topic]
                topic_domain_str = f"{topic}/{domain_override}" if domain_override else topic
            subscriber.setsockopt(zmq.SUBSCRIBE, bytes(topic_domain_str, encoding="ascii"))

        # subscribe to control channels (named after the string representation of the function instance)
        func_id = str(func_instance)
        start_topic = f"{func_id}/START"
        end_topic = f"{func_id}/END"
        terminate_topic = f"{func_id}/TERMINATE"
        for control_topic in (start_topic, end_topic, terminate_topic):
            subscriber.setsockopt(zmq.SUBSCRIBE, bytes(control_topic, encoding="ascii"))
        subscriber.connect(f"{self._protocol}://{self._host_addr}:{self._sub_port}")
        self._internal_start_topics[start_topic] = func_id
        self._internal_end_topics[end_topic] = func_id
        self._internal_terminate_topics[terminate_topic] = func_id

        # register and run listener thread
        listener_thread = Thread(target=self._receiver_thread,
                                 args=(subscriber, func_instance, topics, queued_topics,
                                       start_topic, end_topic, terminate_topic))
        listener_thread.start()

        # add to list of local topics
        # TODO maybe add topic_domain_str instead for more clarity?
        self._sub_topics.update(all_topics)

    def _setup_publishers(self, func_instance, topics):
        """ Creates a publish socket for a function decorated with `services.service.PublishSubscribe`. """
        if len(topics) == 0:
            return  # no topics - no need for a socket

        # setup publish socket
        ctx = Context.instance()
        publisher = ctx.socket(zmq.PUB)
        publisher.sndhwm = 1100000
        publisher.connect(f"{self._protocol}://{self._host_addr}:{self._pub_port}")
        self._publish_sockets[func_instance] = publisher

        # add to list of local topics
        self._pub_topics.update(topics)

    def _setup_dialog_ctrl_msg_listener(self):
        """ Setup the sockets used to receive and acknowledge `DialogSystem` control messages """
        ctx = Context.instance()
        sub_endpoint = f"{self._protocol}://{self._host_addr}:{self._sub_port}"
        pub_endpoint = f"{self._protocol}://{self._host_addr}:{self._pub_port}"

        # setup receiver for dialog system control messages
        self._control_channel_sub = ctx.socket(zmq.SUB)
        for ctrl_topic in (self._start_topic, self._end_topic, self._terminate_topic):
            self._control_channel_sub.setsockopt(zmq.SUBSCRIBE, bytes(ctrl_topic, encoding="ascii"))
        self._control_channel_sub.connect(sub_endpoint)

        # setup sender for dialog system control message acknowledgements
        self._control_channel_pub = ctx.socket(zmq.PUB)
        self._control_channel_pub.sndhwm = 1100000
        self._control_channel_pub.connect(pub_endpoint)

        # setup receiver for internal ACK messages (sent by the listener threads of this service)
        self._internal_control_channel_sub = ctx.socket(zmq.SUB)
        internal_ctrl_topics = (list(self._internal_end_topics.keys())
                                + list(self._internal_start_topics.keys())
                                + list(self._internal_terminate_topics.keys()))
        for internal_ctrl_topic in internal_ctrl_topics:
            self._internal_control_channel_sub.setsockopt(zmq.SUBSCRIBE,
                                                          bytes(f"ACK/{internal_ctrl_topic}", encoding="ascii"))
        self._internal_control_channel_sub.connect(sub_endpoint)

    def _control_channel_listener(self):
        """ Using the control message subscription socket, listen to control messages
        from the `DialogSystem` in a loop.

        Each handled control message is acknowledged on the same topic (prefixed with `ACK/`).
        The loop ends after a terminate message was handled.
        Meant to be called in a thread.
        """
        listen = True
        while listen:
            try:
                # receive message for subscribed control topic
                msg = self._control_channel_sub.recv_multipart(copy=True)
                topic = msg[0].decode("ascii")
                _timestamp, content = pickle.loads(msg[1])

                if topic == self._start_topic:
                    # initialize dialog state
                    self.dialog_start()
                    # set all listeners of this service to listening mode (block until they are listening)
                    for internal_start_topic in self._internal_start_topics:
                        _send_msg(self._control_channel_pub, internal_start_topic, True)
                        _recv_ack(self._internal_control_channel_sub, internal_start_topic)
                    _send_ack(self._control_channel_pub, self._start_topic)
                elif topic == self._end_topic:
                    # stop all listeners of this service (block until they stopped)
                    for internal_end_topic in self._internal_end_topics:
                        _send_msg(self._control_channel_pub, internal_end_topic, True)
                        _recv_ack(self._internal_control_channel_sub, internal_end_topic, True)
                    self.dialog_end()
                    _send_ack(self._control_channel_pub, self._end_topic)
                elif topic == self._terminate_topic:
                    # terminate all listeners of this service (block until they stopped)
                    for internal_terminate_topic in self._internal_terminate_topics:
                        _send_msg(self._control_channel_pub, internal_terminate_topic, True)
                        _recv_ack(self._internal_control_channel_sub, internal_terminate_topic, True)
                    self.dialog_exit()
                    _send_ack(self._control_channel_pub, self._terminate_topic)
                    listen = False
                # NOTE(review): `_setup_dialog_ctrl_msg_listener` only subscribes to the start / end / terminate
                # topics, so the train / eval branches below are presumably never reached via this socket - confirm
                elif topic == self._train_topic:
                    self.train()
                    _send_ack(self._control_channel_pub, self._train_topic)
                elif topic == self._eval_topic:
                    self.eval()
                    _send_ack(self._control_channel_pub, self._eval_topic)
                else:
                    if self.debug_logger:
                        # single pre-formatted string: extra positional arguments would be treated as
                        # %-format arguments by a standard logger
                        self.debug_logger.info(
                            f"- (Service): received unknown control message from topic {topic}"
                            f" with content {content}")
            except KeyboardInterrupt:
                break
            except Exception:
                # best effort: report the problem, but keep the control channel alive
                import traceback
                print("ERROR in Service: _control_channel_listener")
                traceback.print_exc()

    def dialog_start(self):
        """ This function is called before the first message to a new dialog is published.
        You should overwrite this function to set/reset dialog-level variables. """
        pass

    def dialog_end(self):
        """ This function is called after a dialog ended (Topics.DIALOG_END message was received).
        You should overwrite this function to record dialog-level information. """
        pass

    def dialog_exit(self):
        """ This function is called when the dialog system is shutting down.
        You should overwrite this function to stop your threads and cleanup any open resources. """
        pass

    def train(self):
        """ Sets module to training mode """
        self.is_training = True

    def eval(self):
        """ Sets module to eval mode """
        self.is_training = False

    def run_standalone(self, host_reg_port: int = 65535):
        """
        Run this service as a standalone service (without a `DialogSystem`) on a remote node.
        Use a `RemoteService` with *corresponding identifier* on the `DialogSystem` node to connect both.

        Note: this call is blocking!

        Args:
            host_reg_port (int): The port on the `DialogSystem` node listening for `Service` register requests
        """
        assert self._identifier is not None, "running a service on a remote node requires a unique identifier"
        print("Waiting for dialog system host...")

        # send service info to dialog system node
        self._init_pubsub()
        ctx = Context.instance()
        sync_endpoint = ctx.socket(zmq.REQ)
        sync_endpoint.connect(f"tcp://{self._host_addr}:{host_reg_port}")
        data = pickle.dumps((self._domain_name, self._sub_topics, self._pub_topics, self._start_topic,
                             self._end_topic, self._terminate_topic))
        sync_endpoint.send_multipart((bytes(f"REGISTER_{self._identifier}", encoding="ascii"), data))

        # wait for registration confirmation
        registered = False
        while not registered:
            msg = sync_endpoint.recv()
            msg = msg.decode("utf-8")
            if msg.startswith("ACK_REGISTER_"):
                remote_service_identifier = msg[len("ACK_REGISTER_"):]
                if remote_service_identifier == self._identifier:
                    self._register_with_dialogsystem()
                    sync_endpoint.send_multipart(
                        (bytes(f"CONF_REGISTER_{self._identifier}", encoding="ascii"), pickle.dumps(True)))
                    registered = True
        print("Done")

    def get_all_subscribed_topics(self):
        """
        Returns:
            Set of all topics subscribed to by this `Service` (a copy)
        """
        return copy.deepcopy(self._sub_topics)

    def get_all_published_topics(self):
        """
        Returns:
            Set of all topics published to by this `Service` (a copy)
        """
        return copy.deepcopy(self._pub_topics)

    @staticmethod
    def _longest_matching_prefix(topic: str, candidates: Iterable[str]) -> str:
        """ Returns the longest entry of `candidates` that is a prefix of `topic` (empty string, if none is). """
        best = ""
        for key in candidates:
            if topic.startswith(key) and len(key) > len(best):
                best = key
        return best

    def _receiver_thread(self, subscriber: Socket, func_instance,
                         topics: Iterable[str], queued_topics: Iterable[str],
                         start_topic: str, end_topic: str, terminate_topic: str):
        """
        Loop for receiving messages.
        Will continue until a message for `terminate_topic` is received.

        Handles waiting for messages, decoding, unpickling and subscription topic to
        service function keyword mapping.

        Meant to be run in a Thread!

        Args:
            subscriber (Socket): subscriber socket
            func_instance (function instance): the decorated subscriber function instance to be called
                                                with the received messages
            topics (Iterable[str]): all last-message-only topics the decorated `func_instance` subscribes to
            queued_topics (Iterable[str]): all collect-all-messages-since-last-call topics the decorated
                                           `func_instance` subscribes to
            start_topic (str): Control message topic to set this specific `function_instance` into listening mode
                               (receive all non-control messages)
            end_topic (str): Control message topic to set this specific `function_instance` into
                             non-listening mode (ignore all non-control messages)
            terminate_topic (str): Control message topic to end the listener loop for this specific
                                   `function_instance`. Also closes the subscriber socket before returning.
        """
        # publisher used to acknowledge control messages
        ctx = Context.instance()
        control_channel_pub = ctx.socket(zmq.PUB)
        control_channel_pub.sndhwm = 1100000
        control_channel_pub.connect(f"{self._protocol}://{self._host_addr}:{self._pub_port}")

        # materialize once, so any iterable (not only lists) is accepted
        topics = list(topics)
        all_sub_topics = topics + list(queued_topics)
        num_topics = len(all_sub_topics)

        values = {}
        timestamps = {}
        active = False
        terminating = False
        while not terminating:
            try:
                msg = subscriber.recv_multipart(copy=True)
                topic = msg[0].decode("ascii")

                # based on topic, decide what to do
                if topic == start_topic:
                    # reset values and start listening to non-control messages
                    values = {}
                    timestamps = {}
                    active = True
                    _send_ack(control_channel_pub, start_topic)
                elif topic == end_topic:
                    # ignore all non-control messages
                    active = False
                    _send_ack(control_channel_pub, end_topic)
                elif topic == terminate_topic:
                    # shutdown listener thread by exiting loop
                    active = False
                    _send_ack(control_channel_pub, terminate_topic)
                    terminating = True
                elif active:
                    # non-control message while listening -> process message
                    timestamp, content = pickle.loads(msg[1])
                    if self.debug_logger:
                        self.debug_logger.info(
                            f"- (DS): listener thread for function {func_instance}:\n received for topic {topic}:\n {content}")

                    # simple synchronization mechanism: remember only newest values,
                    # store them until there was at least 1 new value received per topic.
                    # Then call callback function with complete set of values.
                    # Reset values afterwards and start collecting again.

                    # problem: routing based on prefixes -> function argument names may differ
                    # solution: find longest subscribed topic name that is a prefix of the received topic
                    common_prefix = self._longest_matching_prefix(topic, all_sub_topics)
                    if not common_prefix:
                        # message does not belong to any subscribed topic -> nothing to pass to the function
                        continue

                    if common_prefix in topics:
                        # store only latest value
                        values[common_prefix] = content  # set value for received topic
                        timestamps[common_prefix] = timestamp  # set timestamp for received value
                    else:
                        # topic is a queued_topic - queue all values and their timestamps
                        if common_prefix not in values:
                            values[common_prefix] = []
                            timestamps[common_prefix] = []
                        values[common_prefix].append(content)
                        timestamps[common_prefix].append(timestamp)

                    if len(values) == num_topics:
                        # received a new value for each topic -> call callback function
                        if func_instance.timestamp_enabled:
                            # append timestamps, if required
                            values['timestamps'] = timestamps
                        if self.debug_logger:
                            self.debug_logger.info(
                                f"- (DS): received all messages for function {func_instance}\n -> CALLING function")
                        if self.__class__ == Service:
                            # NOTE workaround for publisher / subscriber without being an instance method
                            func_instance(**values)
                        else:
                            func_instance(self, **values)
                        # reset values
                        values = {}
                        timestamps = {}
            except KeyboardInterrupt:
                break
            except Exception:
                # best effort: report the problem, but keep the listener alive
                print("THREAD ERROR")
                import traceback
                traceback.print_exc()
        # shutdown
        subscriber.close()
# Each decorated function should return a dictionary with the keys matching the pub_topics names
def PublishSubscribe(sub_topics: List[str] = None, pub_topics: List[str] = None,
                     queued_sub_topics: List[str] = None):
    """
    Decorator function for services.

    To be able to publish / subscribe to / from topics,
    your class is required to inherit from services.service.Service.
    Then, decorate any function you like.

    Your function will be called as soon as:
        * at least one message is received for each topic in sub_topics
          (only latest message will be forwarded, others dropped)
        * at least one message is received for each topic in queued_sub_topics
          (all messages since the previous function call will be forwarded as a list)

    Args:
        sub_topics(List[str or utils.topics.Topic]): The topics you want to get the latest messages from.
                                                     If multiple messages are received until your function is
                                                     called, you will only receive the value of the latest
                                                     message, previously received values will be discarded.
                                                     `None` (default) means no topics.
        pub_topics(List[str or utils.topics.Topic]): The topics you want to publish messages to.
                                                     `None` (default) means no topics.
        queued_sub_topics(List[str or utils.topics.Topic]): The topics you want to get all messages from.
                                                            If multiple messages are received until your function
                                                            is called, you will receive all values since the
                                                            previous function call as a list.
                                                            `None` (default) means no topics.

    Notes:
        * Subscription topic names have to match your function keywords
        * Your function should return a dictionary with the keys matching your publish topics names
          and the value being any arbitrary python object or primitive type you want to send
        * sub_topics and queued_sub_topics have to be disjoint!
        * If you need timestamps for your messages, specify a 'timestamps' argument in your subscribing function.
          It will be filled by a dictionary providing timestamps for each received value, indexed by name.

    Technical notes:
        * Data will be automatically pickled / unpickled during send / receive.
          However, some python objects are not serializable (e.g. database connections) for good reasons
          and will throw an error if you try to publish them.
        * The domain name of your service class will be appended to your publish topics.
          Subscription topics are prefix-matched, so you will receive all messages from 'topic/suffix'
          if you subscribe to 'topic'.
    """
    # fresh lists per decoration - list defaults in the signature would be shared by all decorated functions
    sub_topics = [] if sub_topics is None else list(sub_topics)
    pub_topics = [] if pub_topics is None else list(pub_topics)
    queued_sub_topics = [] if queued_sub_topics is None else list(queued_sub_topics)

    def wrapper(func):
        # NOTE: `functools.wraps` is deliberately not applied here: `Service._setup_listener` builds its
        # control topic names from `str()` of the bound delegate, which `wraps` would change.
        def delegate(self, *args, **kwargs):
            func_inst = getattr(self, func.__name__)

            callargs = list(args)
            if self in callargs:  # remove self when in *args, because already known to function
                callargs.remove(self)
            result = func(self, *callargs, **kwargs)
            if result:
                # result keys may be given as 'topic/domain': remember the domain per topic,
                # then strip it from the keys
                # fix! (user could have multiple "/" characters in topic - only use last one )
                domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
                result = {key.split("/")[0]: result[key] for key in result}

            if func_inst not in self._publish_sockets:
                # not a publisher, just normal function
                return result

            socket = self._publish_sockets[func_inst]
            if socket and result:
                # publish messages
                for topic in pub_topics:
                    # for topic in result: # NOTE publish any returned value in dict with it's key as topic
                    if topic not in result:
                        continue
                    # resolve the domain per topic: the service domain takes precedence, otherwise use the
                    # domain given in the result key (must not leak into the following topics)
                    topic_domain = self._domain_name if self._domain_name else domains[topic]
                    topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
                    if topic in self._pub_topic_domains:
                        # overwrite domain for this specific topic and service instance
                        domain_override = self._pub_topic_domains[topic]
                        topic_domain_str = f"{topic}/{domain_override}" if domain_override else topic
                    _send_msg(socket, topic_domain_str, result[topic])
                    if self.debug_logger:
                        self.debug_logger.info(
                            f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
            return result

        # declare function as publish / subscribe functions and attach the respective topics
        delegate.pubsub = True
        delegate.sub_topics = sub_topics
        delegate.queued_sub_topics = queued_sub_topics
        delegate.pub_topics = pub_topics
        # check arguments: is subscriber interested in timestamps?
        delegate.timestamp_enabled = 'timestamps' in inspect.getfullargspec(func)[0]

        return delegate

    return wrapper
class DialogSystem:
"""
This class will construct a dialog system from the list of services provided to the constructor.
It will also handle synchronization for initialization of services before dialog start / after dialog end / on system shutdown
and lets you discover potential conflicts in your messaging pipeline.
This class is also used to communicate / synchronize with services running on different nodes.
"""
def __init__(self, services: List[Union[Service, RemoteService]], sub_port: int = 65533, pub_port: int = 65534,
             reg_port: int = 65535, protocol: str = 'tcp', debug_logger: DiasysLogger = None):
    """
    Starts the message proxy, registers all given services and connects the control channels.

    Args:
        services (List[Union[Service, RemoteService]]): List of all (remote) services to connect to.
                                                        Only once they're specified here will they start
                                                        listening for messages.
        sub_port (int): subscriber port (proxy output, bound on 127.0.0.1)
        pub_port (int): publisher port (proxy input, bound on 127.0.0.1)
        reg_port (int): registration port for remote services
        protocol (str): communication protocol, either 'inproc' or 'tcp' or `ipc`
        debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger,
                                     including send/receive events.
                                     Can be useful for debugging because you can still see messages received
                                     by the `DialogSystem` even if they are never forwarded (as expected)
                                     to your `Service`
    """
    pub_endpoint = f"{protocol}://127.0.0.1:{pub_port}"
    sub_endpoint = f"{protocol}://127.0.0.1:{sub_port}"

    # node-local topics
    self.debug_logger = debug_logger
    self.protocol = protocol
    self._sub_topics = {}
    self._pub_topics = {}
    self._remote_identifiers = set()
    self._services = []  # collects names and instances of local services
    self._start_dialog_services = set()  # collects names of local services that subscribe to dialog_start

    # node-local sockets
    self._domains = set()

    # start proxy: everything published to `pub_port` is forwarded to the subscribers on `sub_port`
    self._proxy_dev = ProcessProxy(in_type=zmq.XSUB, out_type=zmq.XPUB)  # , mon_type=zmq.XSUB)
    self._proxy_dev.bind_in(pub_endpoint)
    self._proxy_dev.bind_out(sub_endpoint)
    self._proxy_dev.start()
    self._sub_port = sub_port
    self._pub_port = pub_port

    # thread control
    self._start_topics = set()
    self._end_topics = set()
    self._terminate_topics = set()
    self._stopEvent = threading.Event()

    # control channels
    zmq_context = Context.instance()
    self._control_channel_pub = zmq_context.socket(zmq.PUB)
    self._control_channel_pub.sndhwm = 1100000
    self._control_channel_pub.connect(pub_endpoint)
    self._control_channel_sub = zmq_context.socket(zmq.SUB)

    # register services: local ones immediately, remote ones are collected by identifier first
    remote_services = {}
    for candidate in services:
        if isinstance(candidate, Service):
            # register local service
            if candidate._identifier is None:
                service_name = type(candidate).__name__
            else:
                service_name = candidate._identifier
            candidate._init_pubsub()
            self._add_service_info(service_name, candidate._domain_name, candidate._sub_topics,
                                   candidate._pub_topics, candidate._start_topic, candidate._end_topic,
                                   candidate._terminate_topic)
            candidate._register_with_dialogsystem()
        elif isinstance(candidate, RemoteService):
            remote_services[candidate.identifier] = candidate
    self._register_remote_services(remote_services, reg_port)

    # the control subscriber is connected only after all services are registered
    self._control_channel_sub.connect(sub_endpoint)
    self._setup_dialog_end_listener()

    time.sleep(0.25)
def _register_pub_topic(self, publisher, topic: str):
    """ Record `publisher` as one of the publishers of `topic` """
    # creates the publisher set on first use of the topic
    self._pub_topics.setdefault(topic, set()).add(publisher)
def _register_sub_topic(self, subscriber, topic):
    """ Record `subscriber` as one of the subscribers of `topic`.

    Args:
        subscriber: the subscriber to associate with the topic (must be hashable)
        topic: topic the subscriber subscribes to
    """
    # create the subscriber set on first use of the topic, then extend it
    self._sub_topics.setdefault(topic, set()).add(subscriber)
def _register_remote_services(self, remote_services: dict, reg_port: int):
    """
    Register all remote services.

    *Blocking* until a `CONF_REGISTER_<identifier>` message was received for every entry of
    `remote_services`, confirming they're set up and ready.

    Protocol handled on the registration socket (two-part messages: text frame + payload frame):
        * `REGISTER_<identifier>`: the payload is unpickled into the service's interface description,
          which is stored via `_add_service_info`; answered with `ACK_REGISTER_<identifier>`.
        * `CONF_REGISTER_<identifier>`: completes the registration; answered with an empty message.
        * anything else (unknown identifier or unknown message): answered with an empty message.

    Args:
        remote_services (Dict[str, RemoteService]): mapping from remote service identifier to the
            remote service to register. Entries are *removed* from this mapping as their
            registration completes.
        reg_port (int): registration port for remote services
    """
    if len(remote_services) == 0:
        return  # nothing to register

    # Socket to receive registration requests
    ctx = Context.instance()
    reg_service = ctx.socket(zmq.REP)
    try:
        reg_service.bind(f'tcp://127.0.0.1:{reg_port}')

        while len(remote_services) > 0:
            # wait for the next message of a remote service
            msg, data = reg_service.recv_multipart()
            msg = msg.decode("utf-8")
            # A REP socket has to answer every request before it may receive the next one,
            # so there is always a reply; it stays empty unless a registration is acknowledged.
            reply = b""
            if msg.startswith("REGISTER_"):
                # make sure we have a register message
                remote_service_identifier = msg[len("REGISTER_"):]
                if remote_service_identifier in remote_services:
                    print(f"registering service {remote_service_identifier}...")
                    # add remote service interface info
                    # NOTE(review): pickle.loads on data read from a socket executes arbitrary code
                    # if the sender is not trusted - confirm that only trusted peers can reach this port.
                    domain_name, sub_topics, pub_topics, start_topic, end_topic, terminate_topic = pickle.loads(data)
                    self._add_service_info(remote_service_identifier, domain_name, sub_topics, pub_topics,
                                           start_topic, end_topic, terminate_topic)
                    self._remote_identifiers.add(remote_service_identifier)
                    # acknowledge service registration
                    reply = bytes(f'ACK_REGISTER_{remote_service_identifier}', encoding="ascii")
            elif msg.startswith("CONF_REGISTER_"):
                # complete registration
                remote_service_identifier = msg[len("CONF_REGISTER_"):]
                if remote_service_identifier in remote_services:
                    del remote_services[remote_service_identifier]
                    print(f"successfully registered service {remote_service_identifier}")
            reg_service.send(reply)
    finally:
        # release the registration port, also when registration is aborted by an exception
        reg_service.close()
    print("########## Finished registering all remote services ##########")
def _add_service_info(self, service_name: str, domain_name: str, sub_topics: List[str], pub_topics: List[str],
                      start_topic: str, end_topic:str, terminate_topic: str):
    """ Store the interface description of one service and subscribe to its control channel ACKs.

    The topic information is what the graph / inconsistency helpers of this class work on.

    Args:
        service_name (str): service name
        domain_name (str): domain name
        sub_topics (List[str]): all topics the given service subscribes to
        pub_topics (List[str]): all topics the given service publishes to
        start_topic (str): control channel topic for setting given service into `listening` mode
        end_topic (str): control channel topic for setting given service into `non-listening` mode
        terminate_topic (str): control channel topic for stopping given service's listener loops and
                               closing the listener sockets
    """
    self._domains.add(domain_name)

    # remember who listens to / publishes what
    for subscribed in sub_topics:
        self._register_sub_topic(service_name, subscribed)
    for published in pub_topics:
        self._register_pub_topic(service_name, published)

    # remember the control topics of this service
    self._start_topics.add(start_topic)
    self._end_topics.add(end_topic)
    self._terminate_topics.add(terminate_topic)

    # listen for the acknowledgement ("ACK/<topic>") of each control topic
    for control_topic in (start_topic, end_topic, terminate_topic):
        ack_filter = bytes(f"ACK/{control_topic}", encoding="ascii")
        self._control_channel_sub.setsockopt(zmq.SUBSCRIBE, ack_filter)
def _setup_dialog_end_listener(self):
    """ Create `self._end_socket`, a SUB socket subscribed to `Topic.DIALOG_END`.

    The socket is connected to the local endpoint `<protocol>://127.0.0.1:<sub_port>`.
    """
    end_socket = Context.instance().socket(zmq.SUB)
    self._end_socket = end_socket
    # subscribe to dialog end from all domains
    dialog_end_filter = bytes(Topic.DIALOG_END, encoding="ascii")
    end_socket.setsockopt(zmq.SUBSCRIBE, dialog_end_filter)
    endpoint = f"{self.protocol}://127.0.0.1:{self._sub_port}"
    end_socket.connect(endpoint)
def stop(self):
    """ Raise the stop flag; its state can be queried via the `terminating()` function. """
    stop_event = self._stopEvent
    stop_event.set()
def terminating(self):
    """ Tell whether the stop flag is currently raised.

    Returns:
        True if the stop event is set (the system is stopping), else False
    """
    stop_requested = self._stopEvent.is_set()
    return stop_requested
def shutdown(self):
    """ Shutdown dialog system.

    Raises the stop flag, then sends a `terminate` message (value `True`) on every registered
    terminate topic and collects the matching ACK via `_recv_ack` before moving on to the next topic.
    Should be called in the end before exiting your program.
    """
    self._stopEvent.set()
    publisher = self._control_channel_pub
    ack_listener = self._control_channel_sub
    for topic in self._terminate_topics:
        # one service at a time: request termination, then wait for its confirmation
        _send_msg(publisher, topic, True)
        _recv_ack(ack_listener, topic)
def _end_dialog(self):
    """ Block until the dialog ends, then put all services into non-listening mode.

    Reads two-part messages from `self._end_socket`: the first frame is the topic, the second a
    pickled `(timestamp, content)` pair. The first message with a truthy `content` sets the stop
    flag (via `stop()`) and ends the wait; a `KeyboardInterrupt` ends the wait as well, without
    setting the stop flag. Any other `Exception` while handling a message is printed and the wait
    continues.

    Afterwards a message with value `True` is sent on every registered end topic and the matching
    ACK is collected via `_recv_ack`, one topic after the other.
    """
    # listen for Topic.DIALOG_END messages
    while True:
        try:
            msg = self._end_socket.recv_multipart(copy=True)
            # receive message for subscribed topic
            topic = msg[0].decode("ascii")
            # NOTE(review): pickle.loads on data read from a socket executes arbitrary code if the
            # sender is not trusted - confirm that only trusted peers can publish to this socket.
            _timestamp, content = pickle.loads(msg[1])
            if content:
                if self.debug_logger:
                    self.debug_logger.info(f"- (DS): received DIALOG_END message in _end_dialog from topic {topic}")
                self.stop()
                break
        except KeyboardInterrupt:
            break
        except Exception:
            # Only regular errors are reported and survived here: a bare `except` would also
            # swallow SystemExit & co. and keep this loop spinning during interpreter shutdown.
            import traceback
            traceback.print_exc()
            print("ERROR in _end_dialog ")

    # stop receivers (blocking)
    for end_topic in self._end_topics:
        _send_msg(self._control_channel_pub, end_topic, True)
        _recv_ack(self._control_channel_sub, end_topic)
    if self.debug_logger:
        self.debug_logger.info(f"- (DS): all services STOPPED listening")
def _start_dialog(self, start_signals: dict):
    """ Put all services into listening mode, then publish the given start signals.

    Clears the stop flag first. Then sends a message with value `True` on every registered
    start topic and collects the matching ACK via `_recv_ack`, one topic after the other.
    Finally, each value of `start_signals` is published to its topic.

    Args:
        start_signals (dict): mapping from topic -> value to publish once all services listen
    """
    self._stopEvent.clear()

    publisher = self._control_channel_pub
    ack_listener = self._control_channel_sub

    # start receivers (blocking)
    for listen_topic in self._start_topics:
        _send_msg(publisher, listen_topic, True)
        _recv_ack(ack_listener, listen_topic)

    logger = self.debug_logger
    if logger:
        logger.info(f"- (DS): all services STARTED listening")

    # publish first turn trigger
    for topic, value in start_signals.items():
        _send_msg(publisher, f"{topic}", value)
def run_dialog(self, start_signals: dict = None):
    """ Run a complete dialog (blocking).

    Dialog will be started via messages to the topics specified in `start_signals`.
    The dialog will end on receiving any `Topic.DIALOG_END` message with value 'True',
    so make sure at least one service in your dialog graph will publish this message eventually.

    Args:
        start_signals (Dict[str, Any]): mapping from topic -> value
            Publishes the value given for each topic to the respective topic.
            Use this to trigger the start of your dialog system.
            Defaults to `{Topic.DIALOG_END: False}` when omitted (or `None`).
    """
    if start_signals is None:
        # built per call instead of as a default argument, so no dict instance is shared between calls
        start_signals = {Topic.DIALOG_END: False}
    self._start_dialog(start_signals)
    self._end_dialog()
def list_published_topics(self):
    """ Get all declared publisher topics.

    Returns:
        A dictionary with mapping
            topic (str) -> publishing services (Set[str]).
        The result is a deep copy, so changing it does not affect the system.

    Note:
        * Call this method after instantiating all services.
        * Even though a publishing topic might be listed here, there is no guarantee that
          its publisher(s) will ever publish to it.
    """
    snapshot = copy.deepcopy(self._pub_topics)
    return snapshot
def list_subscribed_topics(self):
    """ Get all declared subscribed topics.

    Returns:
        A dictionary with mapping
            topic (str) -> subscribing services (Set[str]).
        The result is a deep copy, so changing it does not affect the system.

    Notes:
        * Call this method after instantiating all services.
    """
    snapshot = copy.deepcopy(self._sub_topics)
    return snapshot
def draw_system_graph(self, name: str = 'system', format: str = "png", show: bool = True):
    """ Draws a graph of the system as a directed graph.

    Services are represented by box nodes, messages by labelled directed edges
    (from publisher to subscriber, for topics with identical names on both sides).
    Everything reported by `list_inconsistencies()` is attached to an extra
    'UNCONNECTED SERVICES' node using coloured edges:
        * errors (subscribed topics without publisher): edge from 'UNCONNECTED SERVICES' to the subscriber
        * warnings (published topics without subscriber): edge from the publisher to 'UNCONNECTED SERVICES'
    Remote services are drawn as filled nodes, local services as outlined ones.

    Args:
        name (str): used to construct the name of your output file
        format (str): output file format (e.g. png, pdf, jpg, ...)
        show (bool): if True, the graph image will be opened in your default image viewer application

    Requires:
        graphviz library (pip install graphviz)
    """
    from graphviz import Digraph

    unconnected = 'UNCONNECTED SERVICES'
    graph = Digraph(name=name, format=format)

    # collect all services, errors and warnings
    services = set()
    for registry in (self._pub_topics, self._sub_topics):
        for service_set in registry.values():
            services = services.union(service_set)
    errors, warnings = self.list_inconsistencies()

    # add services as nodes
    for service in services:
        if service in self._remote_identifiers:
            # remote service
            node_style = dict(color='#1f618d', style='filled', fontcolor='white', shape='box')
        else:
            # local service
            node_style = dict(color='#1c2833', shape='box')
        graph.node(service, **node_style)
    if errors or warnings:
        graph.node(unconnected, style='filled', color='#922b21', fontcolor='white', shape='box')

    # draw connections from publisher to subscribers as edges
    for topic, publishers in self._pub_topics.items():
        for receiver in self._sub_topics.get(topic, []):
            for publisher in publishers:
                graph.edge(publisher, receiver, label=topic)

    # draw warnings and errors as edges to node 'UNCONNECTED SERVICES'
    for topic, receivers in errors.items():
        for receiver in receivers:
            graph.edge(unconnected, receiver, color='#c34400', fontcolor='#c34400', label=topic)
    for topic, publishers in warnings.items():
        for publisher in publishers:
            graph.edge(publisher, unconnected, color='#e37c02', fontcolor='#e37c02', label=topic)

    # draw graph
    graph.render(view=show, cleanup=True)
def list_inconsistencies(self):
    """ Checks for potential errors in the current messaging pipeline:
    e.g. len(list_inconsistencies()[0]) == 0 -> error free pipeline

    (Potential) Errors are defined in this context as subscribed topics without publishers.
    Warnings are defined in this context as published topics without subscribers.
    Topics are matched by prefix: a published topic counts as received by a subscription
    if the published topic starts with the subscribed topic.

    Returns:
        A tuple of dictionaries:
        * the first dictionary contains potential errors (with the mapping topics -> subscribing services)
        * the second dictionary contains warnings (with the mapping topics -> publishing services).
        The service sets are copies, so changing them does not affect the system.

    Notes:
        * Call this method after instantiating all services.
        * Even if there are no errors returned by this method, there is no guarantee that all publishers
          eventually publish to their respective topics.
    """
    pub_topics = self._pub_topics
    sub_topics = self._sub_topics

    # look for subscribers w/o publishers by checking topic prefixes
    errors = {
        sub_topic: set(subscribers)
        for sub_topic, subscribers in sub_topics.items()
        if not any(pub_topic.startswith(sub_topic) for pub_topic in pub_topics)
    }

    # look for publishers w/o subscribers by checking topic prefixes
    warnings = {
        pub_topic: set(publishers)
        for pub_topic, publishers in pub_topics.items()
        if not any(pub_topic.startswith(sub_topic) for sub_topic in sub_topics)
    }
    return errors, warnings
def print_inconsistencies(self):
    """ Prints the result of `list_inconsistencies()` to the console.

    (Potential) Errors are defined in this context as subscribed topics without publishers.
    Warnings are defined in this context as published topics without subscribers.
    Each section is wrapped in ANSI colour escape codes.

    Notes:
        * Call this method after instantiating all services.
        * Even if there are no errors printed by this method, there is no guarantee that all publishers
          eventually publish to their respective topics.
    """
    # console colors (ANSI escape sequences)
    error_color, warning_color, reset_color = '\033[91m', '\033[93m', '\033[0m'

    errors, warnings = self.list_inconsistencies()

    print(error_color)
    print("(Potential) Errors (subscribed topics without publishers):")
    for topic, subscribers in errors.items():
        print(f" topic: '{topic}', subscribed to in services: {subscribers}")
    print(reset_color)

    print(warning_color)
    print("Warnings (published topics without subscribers):")
    for topic, publishers in warnings.items():
        print(f" topic: '{topic}', published in services: {publishers}")
    print(reset_color)