Skip to content

Commit a700bad

Browse files
Add preserve dtype layer and update passthrough feature handling
Co-authored-by: piotr.laczkowski <piotr.laczkowski@gmail.com>
1 parent b9d237e commit a700bad

File tree

8 files changed

+1847
-861
lines changed

8 files changed

+1847
-861
lines changed

PASSTHROUGH_FIX_SUMMARY.md

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# PASSTHROUGH Feature Fix Summary
2+
3+
## Problem Description
4+
5+
The PASSTHROUGH features in KDP were being incorrectly cast to float32 during processing, which prevented them from preserving their original data types (strings, integers, etc.). This was problematic because passthrough features are designed to pass data through the pipeline without any preprocessing modifications.
6+
7+
## Root Cause
8+
9+
In the `_add_pipeline_passthrough` method in `kdp/processor.py` (line ~1416), the code was automatically casting all passthrough features to float32:
10+
11+
```python
12+
# For passthrough features, we only ensure type consistency by casting to float32
13+
preprocessor.add_processing_step(
14+
layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
15+
name=f"cast_to_float_{feature_name}",
16+
)
17+
```
18+
19+
This was incorrect because passthrough features should preserve their original data type.
20+
21+
## Solution
22+
23+
### 1. Created a New Layer: `PreserveDtypeLayer`
24+
25+
**File**: `kdp/layers/preserve_dtype.py`
26+
27+
This new layer can either:
28+
- Preserve the original dtype when `target_dtype=None` (default behavior)
29+
- Cast to a specific dtype when `target_dtype` is specified
30+
31+
```python
32+
@tf.keras.utils.register_keras_serializable(package="kdp.layers")
33+
class PreserveDtypeLayer(keras.layers.Layer):
34+
def __init__(self, target_dtype=None, **kwargs):
35+
super().__init__(**kwargs)
36+
self.target_dtype = target_dtype
37+
38+
def call(self, inputs, **kwargs):
39+
if self.target_dtype is not None:
40+
return tf.cast(inputs, self.target_dtype)
41+
return inputs
42+
```
43+
44+
### 2. Added Factory Method
45+
46+
**File**: `kdp/layers_factory.py`
47+
48+
Added a new factory method to create `PreserveDtypeLayer` instances:
49+
50+
```python
51+
@staticmethod
52+
def preserve_dtype_layer(
53+
name: str = "preserve_dtype", target_dtype=None, **kwargs: dict
54+
) -> tf.keras.layers.Layer:
55+
"""Create a PreserveDtypeLayer layer."""
56+
return PreprocessorLayerFactory.create_layer(
57+
layer_class=PreserveDtypeLayer,
58+
name=name,
59+
target_dtype=target_dtype,
60+
**kwargs,
61+
)
62+
```
63+
64+
### 3. Updated Processor Logic
65+
66+
**File**: `kdp/processor.py`
67+
68+
Modified the `_add_pipeline_passthrough` method to use the new `PreserveDtypeLayer` instead of casting to float32:
69+
70+
```python
71+
# For passthrough features, preserve the original dtype or cast to specified dtype
72+
target_dtype = getattr(_feature, 'dtype', None)
73+
preprocessor.add_processing_step(
74+
layer_creator=PreprocessorLayerFactory.preserve_dtype_layer,
75+
name=f"preserve_dtype_{feature_name}",
76+
target_dtype=target_dtype,
77+
)
78+
```
79+
80+
## Testing
81+
82+
### 1. Unit Tests for PreserveDtypeLayer
83+
84+
**File**: `test/layers/test_preserve_dtype_layer.py`
85+
86+
Comprehensive tests covering:
87+
- Preserving original dtypes (string, int, float)
88+
- Casting to target dtypes
89+
- Batch processing
90+
- Serialization/deserialization
91+
- Model integration
92+
93+
### 2. Factory Method Tests
94+
95+
**File**: `test/layers/test_layer_factory.py`
96+
97+
Added tests for the new `preserve_dtype_layer` factory method.
98+
99+
### 3. Integration Tests
100+
101+
**File**: `test/test_processor.py`
102+
103+
Added comprehensive tests for passthrough features:
104+
- `test_passthrough_feature_preserves_string_dtype`
105+
- `test_passthrough_feature_preserves_int_dtype`
106+
- `test_passthrough_feature_preserves_float_dtype`
107+
- `test_passthrough_feature_mixed_types`
108+
109+
### 4. Simple Test Script
110+
111+
**File**: `test_passthrough_fix.py`
112+
113+
A standalone test script that can be run without the full test environment to verify the fix works correctly.
114+
115+
## Usage Examples
116+
117+
### String Passthrough Feature
118+
119+
```python
120+
from kdp.features import PassthroughFeature, FeatureType
121+
import tensorflow as tf
122+
123+
# Create a string passthrough feature
124+
string_feature = PassthroughFeature(
125+
name="string_feature",
126+
feature_type=FeatureType.PASSTHROUGH,
127+
dtype=tf.string,
128+
)
129+
130+
# The feature will now preserve its string dtype through the pipeline
131+
```
132+
133+
### Integer Passthrough Feature
134+
135+
```python
136+
# Create an integer passthrough feature
137+
int_feature = PassthroughFeature(
138+
name="int_feature",
139+
feature_type=FeatureType.PASSTHROUGH,
140+
dtype=tf.int32,
141+
)
142+
143+
# The feature will now preserve its int32 dtype through the pipeline
144+
```
145+
146+
### Mixed Types
147+
148+
```python
149+
features = {
150+
"string_feature": PassthroughFeature(
151+
name="string_feature",
152+
feature_type=FeatureType.PASSTHROUGH,
153+
dtype=tf.string,
154+
),
155+
"int_feature": PassthroughFeature(
156+
name="int_feature",
157+
feature_type=FeatureType.PASSTHROUGH,
158+
dtype=tf.int32,
159+
),
160+
"float_feature": PassthroughFeature(
161+
name="float_feature",
162+
feature_type=FeatureType.PASSTHROUGH,
163+
dtype=tf.float64,
164+
),
165+
}
166+
167+
# All features will preserve their respective dtypes
168+
```
169+
170+
## Benefits
171+
172+
1. **Data Type Preservation**: Passthrough features now correctly preserve their original data types
173+
2. **Backward Compatibility**: Existing code continues to work, but now with correct behavior
174+
3. **Flexibility**: The `PreserveDtypeLayer` can be used for both preserving and casting dtypes as needed
175+
4. **Comprehensive Testing**: Full test coverage ensures the fix works correctly
176+
177+
## Running Tests
178+
179+
### Full Test Suite (requires TensorFlow)
180+
```bash
181+
# Install dependencies
182+
poetry install
183+
184+
# Run all tests
185+
poetry run pytest
186+
187+
# Run specific test categories
188+
poetry run pytest -m "layers" # Layer tests
189+
poetry run pytest test/test_processor.py::TestPreprocessingModel::test_passthrough_feature_preserves_string_dtype
190+
```
191+
192+
### Simple Test Script
193+
```bash
194+
python3 test_passthrough_fix.py
195+
```
196+
197+
## Files Modified
198+
199+
1. `kdp/layers/preserve_dtype.py` - New layer implementation
200+
2. `kdp/layers_factory.py` - Added factory method
201+
3. `kdp/processor.py` - Updated passthrough processing logic
202+
4. `test/layers/test_preserve_dtype_layer.py` - New unit tests
203+
5. `test/layers/test_layer_factory.py` - Added factory tests
204+
6. `test/test_processor.py` - Added integration tests
205+
7. `test_passthrough_fix.py` - Standalone test script
206+
207+
## Verification
208+
209+
The fix ensures that:
210+
- String passthrough features remain as strings
211+
- Integer passthrough features remain as integers
212+
- Float passthrough features remain as floats
213+
- Mixed data types are handled correctly
214+
- The pipeline continues to work for all other feature types
215+
- No breaking changes to existing functionality

kdp/layers/preserve_dtype.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
4+
5+
@tf.keras.utils.register_keras_serializable(package="kdp.layers")
6+
class PreserveDtypeLayer(keras.layers.Layer):
7+
"""Custom Keras layer that preserves the original dtype of input tensors.
8+
9+
This is useful for passthrough features where we want to maintain the original
10+
data type without any casting.
11+
"""
12+
13+
def __init__(self, target_dtype=None, **kwargs):
14+
"""Initialize the layer.
15+
16+
Args:
17+
target_dtype: Optional target dtype to cast to. If None, preserves original dtype.
18+
**kwargs: Additional keyword arguments
19+
"""
20+
super().__init__(**kwargs)
21+
self.target_dtype = target_dtype
22+
23+
def call(self, inputs, **kwargs):
24+
"""Process the input tensor, optionally casting to target_dtype.
25+
26+
Args:
27+
inputs: Input tensor of any dtype
28+
**kwargs: Additional keyword arguments
29+
30+
Returns:
31+
Tensor with preserved or target dtype
32+
"""
33+
if self.target_dtype is not None:
34+
return tf.cast(inputs, self.target_dtype)
35+
return inputs
36+
37+
def get_config(self):
38+
"""Return the config dictionary for serialization.
39+
40+
Returns:
41+
A dictionary with the layer configuration
42+
"""
43+
config = super().get_config()
44+
config.update({
45+
'target_dtype': self.target_dtype
46+
})
47+
return config
48+
49+
@classmethod
50+
def from_config(cls, config):
51+
"""Create a new instance from the serialized configuration.
52+
53+
Args:
54+
config: Layer configuration dictionary
55+
56+
Returns:
57+
A new instance of the layer
58+
"""
59+
return cls(**config)

kdp/layers_factory.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from kdp.layers.text_preprocessing_layer import TextPreprocessingLayer
1111
from kdp.layers.cast_to_float import CastToFloat32Layer
12+
from kdp.layers.preserve_dtype import PreserveDtypeLayer
1213
from kdp.layers.date_parsing_layer import DateParsingLayer
1314
from kdp.layers.date_encoding_layer import DateEncodingLayer
1415
from kdp.layers.season_layer import SeasonLayer
@@ -183,6 +184,27 @@ def cast_to_float32_layer(
183184
**kwargs,
184185
)
185186

187+
@staticmethod
188+
def preserve_dtype_layer(
189+
name: str = "preserve_dtype", target_dtype=None, **kwargs: dict
190+
) -> tf.keras.layers.Layer:
191+
"""Create a PreserveDtypeLayer layer.
192+
193+
Args:
194+
name: The name of the layer.
195+
target_dtype: Optional target dtype to cast to. If None, preserves original dtype.
196+
**kwargs: Additional keyword arguments to pass to the layer constructor.
197+
198+
Returns:
199+
An instance of the PreserveDtypeLayer layer.
200+
"""
201+
return PreprocessorLayerFactory.create_layer(
202+
layer_class=PreserveDtypeLayer,
203+
name=name,
204+
target_dtype=target_dtype,
205+
**kwargs,
206+
)
207+
186208
@staticmethod
187209
def date_parsing_layer(
188210
name: str = "date_parsing_layer", **kwargs: dict

kdp/processor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,10 +1414,12 @@ def _add_pipeline_passthrough(self, feature_name: str, input_layer) -> None:
14141414
feature_name=feature_name,
14151415
)
14161416
else:
1417-
# For passthrough features, we only ensure type consistency by casting to float32
1417+
# For passthrough features, preserve the original dtype or cast to specified dtype
1418+
target_dtype = getattr(_feature, 'dtype', None)
14181419
preprocessor.add_processing_step(
1419-
layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
1420-
name=f"cast_to_float_{feature_name}",
1420+
layer_creator=PreprocessorLayerFactory.preserve_dtype_layer,
1421+
name=f"preserve_dtype_{feature_name}",
1422+
target_dtype=target_dtype,
14211423
)
14221424

14231425
# Optionally reshape if needed

0 commit comments

Comments
 (0)