-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner.py
132 lines (101 loc) · 5.28 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, NamedTuple, Optional
from haystack import Pipeline
from haystack import component
from haystack.core.component import Component
from haystack.core.component.sockets import Sockets
from haystack.core.component import InputSocket, OutputSocket
from copy import deepcopy
class NamedComponent(NamedTuple):
    """Pair a Haystack component with the name used to prefix its sockets in a runner."""

    # Prefix for this component's sockets; must be unique within one runner.
    name: str
    # The wrapped component instance.
    component: Component
class NamedPipeline(NamedTuple):
    """Pair a Haystack pipeline with the name used for its input and output sockets in a runner."""

    # Socket name for this pipeline; must be unique within one runner.
    name: str
    # The wrapped pipeline instance.
    pipeline: Pipeline
@component
class ConcurrentComponentRunner:
    """
    This component allows you to run multiple components concurrently in a thread pool.

    Every input and output socket of a wrapped component is re-exposed on this component
    under the name ``f"{component_name}_{socket_name}"``. ``run`` routes each prefixed input
    back to the component that declared it and prefixes that component's outputs the same way.
    """

    def __init__(self, named_components: List[NamedComponent], executor: Optional[ThreadPoolExecutor] = None):
        """
        :param named_components: List of NamedComponent instances with unique names.
        :param executor: ThreadPoolExecutor instance. If not provided, a new one with default
            values is created for (and shut down after) every call to ``run``.
        :raises ValueError: If ``named_components`` is not a list of NamedComponent instances,
            if two components share a name, or if two prefixed socket names collide
            (e.g. component ``"a"`` with socket ``"b_x"`` and component ``"a_b"`` with socket ``"x"``).
        """
        if not isinstance(named_components, list) or not all(
            isinstance(named_component, NamedComponent) for named_component in named_components
        ):
            raise ValueError("named_components must be a list of NamedComponent instances")
        names = [named_component.name for named_component in named_components]
        if len(names) != len(set(names)):
            raise ValueError("All components must have unique names")
        self._create_input_types(named_components)
        self._create_output_types(named_components)
        self.executor = executor
        self.components = named_components

    def _create_input_types(self, named_components: List[NamedComponent]):
        """Expose the input sockets of every wrapped component as prefixed inputs of this component."""
        if not hasattr(self, "__haystack_input__"):
            self.__haystack_input__ = Sockets(self, {}, InputSocket)
        self._copy_prefixed_sockets(self.__haystack_input__, named_components, "__haystack_input__")

    def _create_output_types(self, named_components: List[NamedComponent]):
        """Expose the output sockets of every wrapped component as prefixed outputs of this component."""
        if not hasattr(self, "__haystack_output__"):
            self.__haystack_output__ = Sockets(self, {}, OutputSocket)
        self._copy_prefixed_sockets(self.__haystack_output__, named_components, "__haystack_output__")

    @staticmethod
    def _copy_prefixed_sockets(target: Sockets, named_components: List[NamedComponent], sockets_attr: str):
        """Copy each component's ``sockets_attr`` sockets into ``target``, renamed ``<component>_<socket>``."""
        for named_component in named_components:
            # Deep-copy so the renaming below does not mutate the wrapped component's own sockets.
            socket_dict = deepcopy(getattr(named_component.component, sockets_attr)._sockets_dict)
            for name, socket in socket_dict.items():
                prefixed_name = f"{named_component.name}_{name}"
                # Without this check a later socket would silently overwrite an earlier one.
                if prefixed_name in target._sockets_dict:
                    raise ValueError(
                        f"Socket name '{prefixed_name}' is produced by more than one component; "
                        "rename the components so their prefixed socket names are unique"
                    )
                socket.name = prefixed_name
                target[prefixed_name] = socket

    def run(self, **inputs):
        """
        Run every wrapped component concurrently.

        :param inputs: Prefixed inputs, ``f"{component_name}_{socket_name}"`` -> value.
        :returns: Every output of every component, keyed ``f"{component_name}_{output_name}"``.
        """
        if self.executor is None:
            with ThreadPoolExecutor() as executor:
                final_results = self._run_in_executor(executor, inputs)
        else:
            final_results = self._run_in_executor(self.executor, inputs)
        outputs = {}
        for named_component, result in zip(self.components, final_results):
            for key, value in result.items():
                outputs[f"{named_component.name}_{key}"] = value
        return outputs

    def _run_in_executor(self, executor, inputs):
        """Run each wrapped component in ``executor``; return their result dicts in component order."""

        def _inputs_for(named_component: NamedComponent) -> Dict[str, Any]:
            # Route by the wrapped component's declared socket names instead of by string prefix:
            # a bare prefix test let component "a" capture inputs meant for component "ab"/"a_b",
            # and str.replace stripped the prefix everywhere inside the key, not only at its start.
            real_input = {}
            for socket_name in named_component.component.__haystack_input__._sockets_dict:
                key = f"{named_component.name}_{socket_name}"
                if key in inputs:
                    real_input[socket_name] = inputs[key]
            return real_input

        # Build the per-component kwargs up front so routing errors surface in the caller's thread.
        per_component_inputs = [_inputs_for(named_component) for named_component in self.components]
        results = executor.map(
            lambda named_component, kwargs: named_component.component.run(**kwargs),
            self.components,
            per_component_inputs,
        )
        return list(results)
@component
class ConcurrentPipelineRunner:
    """
    This component allows you to run multiple pipelines concurrently in a thread pool.

    Each wrapped pipeline gets one input and one output socket, both named after the pipeline
    and typed ``Dict[str, Any]``. The input value is passed as ``data`` to that pipeline's
    ``run``, and the output is whatever that ``run`` call returns.
    """

    def __init__(self, named_pipelines: List[NamedPipeline], executor: Optional[ThreadPoolExecutor] = None):
        """
        :param named_pipelines: List of NamedPipeline instances with unique names.
        :param executor: ThreadPoolExecutor instance. If not provided, a new one with default
            values is created for (and shut down after) every call to ``run``.
        :raises ValueError: If ``named_pipelines`` is not a list of NamedPipeline instances
            or if two pipelines share a name.
        """
        if not isinstance(named_pipelines, list) or not all(
            isinstance(named_pipeline, NamedPipeline) for named_pipeline in named_pipelines
        ):
            raise ValueError("named_pipelines must be a list of NamedPipeline instances")
        names = [named_pipeline.name for named_pipeline in named_pipelines]
        if len(names) != len(set(names)):
            raise ValueError("All pipelines must have unique names")
        for name in names:
            # The third argument is the socket *type* (previously a dict instance was passed here).
            component.set_input_type(self, name, Dict[str, Any])
        # Register one output per pipeline; previously these types were built but never declared.
        component.set_output_types(self, **{name: Dict[str, Any] for name in names})
        self.pipelines = named_pipelines
        self.executor = executor

    def run(self, **inputs):
        """
        Run every wrapped pipeline concurrently.

        :param inputs: One entry per pipeline, keyed by pipeline name; each value is passed
            as ``data`` to that pipeline's ``run``.
        :returns: Dict mapping each pipeline name to that pipeline's result.
        :raises ValueError: If input data for any pipeline is missing.
        """
        if self.executor is None:
            with ThreadPoolExecutor() as executor:
                final_results = self._run_in_executor(executor, inputs)
        else:
            final_results = self._run_in_executor(self.executor, inputs)
        return {named_pipeline.name: result for named_pipeline, result in zip(self.pipelines, final_results)}

    def _run_in_executor(self, executor: ThreadPoolExecutor, inputs: Dict[str, Any]):
        """Run each pipeline in ``executor`` with its own input; return results in pipeline order."""
        missing = [named_pipeline.name for named_pipeline in self.pipelines if named_pipeline.name not in inputs]
        if missing:
            raise ValueError(f"Missing input data for pipelines: {missing}")
        # Look inputs up by pipeline name. Zipping with inputs.keys() paired data with pipelines
        # by position, so a different kwargs order sent data to the wrong pipeline.
        results = executor.map(
            lambda named_pipeline: named_pipeline.pipeline.run(data=inputs[named_pipeline.name]),
            self.pipelines,
        )
        return list(results)