Skip to content

Commit c17acd2

Browse files
committed
feat(KDP): add selective retention of outputs based on dependencies among layers
1 parent 448f63f commit c17acd2

File tree

1 file changed

+64
-26
lines changed

1 file changed

+64
-26
lines changed

kdp/dynamic_pipeline.py

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,84 @@
1+
import tensorflow as tf
2+
3+
14
class DynamicPreprocessingPipeline:
    """Manage a sequence of Keras preprocessing layers with selective
    retention of intermediate outputs.

    Each layer's output is stored under the layer's name; later layers can
    consume the outputs of earlier layers according to a dependency map
    that is computed once at construction time.
    """

    def __init__(self, layers):
        """Initialize the pipeline with a list of preprocessing layers.

        Args:
            layers (list): A list of TensorFlow preprocessing layers. Each
                layer must expose a ``name`` attribute.
        """
        self.layers = layers
        # Computed once up front so process() does not have to re-inspect
        # the layers for every dataset element.
        self.dependency_map = self._analyze_dependencies()

    def _analyze_dependencies(self):
        """Determine which prior outputs each layer depends on.

        A layer that declares an ``input_spec`` contributes the names found
        in that spec; a layer without one is assumed to depend on every
        output produced so far.

        Returns:
            dict: Maps each layer's name to the set of output keys it
            depends on.
        """
        dependencies = {}
        all_outputs = set()
        for layer in self.layers:
            if hasattr(layer, "input_spec") and layer.input_spec is not None:
                # Pull names out of the (possibly nested) input_spec,
                # dropping elements that carry no 'name' attribute.
                required_inputs = {
                    name
                    for name in tf.nest.flatten(
                        tf.nest.map_structure(
                            lambda x: getattr(x, "name", None), layer.input_spec
                        )
                    )
                    if name is not None
                }
            else:
                # BUG FIX: take a snapshot copy instead of aliasing.
                # The original stored `all_outputs` itself here, so every
                # later `all_outputs.add(...)` silently rewrote the
                # dependency sets already recorded for earlier layers
                # (all spec-less layers ended up sharing ONE set that
                # contained every layer name).
                required_inputs = set(all_outputs)
            dependencies[layer.name] = required_inputs
            all_outputs.update(required_inputs)
            all_outputs.add(layer.name)
        return dependencies

    def process(self, dataset):
        """Process a dataset through the pipeline using the tf.data API.

        Args:
            dataset (tf.data.Dataset): Dataset whose elements are dicts of
                features keyed by name.

        Returns:
            tf.data.Dataset: Dataset in which each layer's output has been
            added under the layer's name.
        """

        def _apply_transformations(features):
            # Work on a shallow copy so the caller's element dict is never
            # mutated in place.
            current_data = dict(features)
            for layer in self.layers:
                required_keys = self.dependency_map[layer.name]
                available = [k for k in required_keys if k in current_data]
                # NOTE(review): the layer is applied to each available input
                # separately, but every result is stored under the single
                # key `layer.name`, so when a layer has several available
                # inputs only the last one survives — confirm this is the
                # intended semantics.
                transformed_output = {
                    layer.name: layer(current_data[k]) for k in available
                }
                current_data.update(transformed_output)
            return current_data

        return dataset.map(_apply_transformations)

0 commit comments

Comments
 (0)