Commit 83f6996

feat(KDP): global embedding for numeric features option added
1 parent: 55cbbb3

File tree

2 files changed: +151 -11 lines

  kdp/custom_layers.py
  kdp/processor.py


kdp/custom_layers.py

Lines changed: 119 additions & 11 deletions
@@ -2046,17 +2046,22 @@ def build(self, input_shape):
             init_min_tensor = tf.fill([self.num_features], init_min_tensor)
         if init_max_tensor.shape.ndims == 0:
             init_max_tensor = tf.fill([self.num_features], init_max_tensor)
-        # Convert tensors to numpy arrays, which are acceptable by tf.constant_initializer.
-        init_min_value = (
-            init_min_tensor.numpy()
-            if hasattr(init_min_tensor, "numpy")
-            else init_min_tensor
-        )
-        init_max_value = (
-            init_max_tensor.numpy()
-            if hasattr(init_max_tensor, "numpy")
-            else init_max_tensor
-        )
+
+        if tf.executing_eagerly():
+            init_min_value = init_min_tensor.numpy()
+            init_max_value = init_max_tensor.numpy()
+        else:
+            # Fallback: if not executing eagerly, force conversion to list
+            init_min_value = (
+                init_min_tensor.numpy().tolist()
+                if hasattr(init_min_tensor, "numpy")
+                else self.init_min
+            )
+            init_max_value = (
+                init_max_tensor.numpy().tolist()
+                if hasattr(init_max_tensor, "numpy")
+                else self.init_max
+            )

         self.learned_min = self.add_weight(
             name="learned_min",
@@ -2133,3 +2138,106 @@ def get_config(self):
             }
         )
         return config
+
+
+class GlobalAdvancedNumericalEmbedding(tf.keras.layers.Layer):
+    """
+    Global AdvancedNumericalEmbedding processes concatenated numeric features.
+    It applies an inner AdvancedNumericalEmbedding over the flattened input and then
+    performs global pooling (average or max) to produce a compact representation.
+    """
+
+    def __init__(
+        self,
+        global_embedding_dim: int,
+        global_mlp_hidden_units: int,
+        global_num_bins: int,
+        global_init_min,
+        global_init_max,
+        global_dropout_rate: float,
+        global_use_batch_norm: bool,
+        global_pooling: str = "average",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.global_embedding_dim = global_embedding_dim
+        self.global_mlp_hidden_units = global_mlp_hidden_units
+        self.global_num_bins = global_num_bins
+
+        # Ensure initializer parameters are Python scalars, lists, or numpy arrays.
+        if not isinstance(global_init_min, (list, tuple, np.ndarray)):
+            try:
+                global_init_min = float(global_init_min)
+            except Exception:
+                raise ValueError(
+                    "init_min must be a Python scalar, list, tuple or numpy array"
+                )
+        if not isinstance(global_init_max, (list, tuple, np.ndarray)):
+            try:
+                global_init_max = float(global_init_max)
+            except Exception:
+                raise ValueError(
+                    "init_max must be a Python scalar, list, tuple or numpy array"
+                )
+        self.global_init_min = global_init_min
+        self.global_init_max = global_init_max
+        self.global_dropout_rate = global_dropout_rate
+        self.global_use_batch_norm = global_use_batch_norm
+        self.global_pooling = global_pooling
+
+        # Use the existing advanced numerical embedding block
+        self.inner_embedding = AdvancedNumericalEmbedding(
+            embedding_dim=self.global_embedding_dim,
+            mlp_hidden_units=self.global_mlp_hidden_units,
+            num_bins=self.global_num_bins,
+            init_min=self.global_init_min,
+            init_max=self.global_init_max,
+            dropout_rate=self.global_dropout_rate,
+            use_batch_norm=self.global_use_batch_norm,
+            name="global_numeric_embedding",
+        )
+        if self.global_pooling == "average":
+            self.global_pooling_layer = tf.keras.layers.GlobalAveragePooling1D(
+                name="global_avg_pool"
+            )
+        elif self.global_pooling == "max":
+            self.global_pooling_layer = tf.keras.layers.GlobalMaxPooling1D(
+                name="global_max_pool"
+            )
+        else:
+            raise ValueError(f"Unsupported pooling method: {self.global_pooling}")
+
+    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+        """
+        Expects inputs with shape (batch, ...) and flattens them (except for the batch dim).
+        The inner embedding then produces a 3D output (batch, num_features, embedding_dim),
+        which is finally pooled to yield (batch, embedding_dim).
+        """
+        # If inputs have more than 2 dimensions, flatten them (except for the batch dimension).
+        if len(inputs.shape) > 2:
+            inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
+        # Pass through the inner advanced embedding.
+        x_embedded = self.inner_embedding(inputs, training=training)
+        # Global pooling over the numeric features axis.
+        x_pooled = self.global_pooling_layer(x_embedded)
+        return x_pooled
+
+    def compute_output_shape(self, input_shape):
+        # Regardless of the input shape, the output shape is (batch_size, embedding_dim).
+        return (input_shape[0], self.global_embedding_dim)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "global_embedding_dim": self.global_embedding_dim,
+                "global_mlp_hidden_units": self.global_mlp_hidden_units,
+                "global_num_bins": self.global_num_bins,
+                "global_init_min": self.global_init_min,
+                "global_init_max": self.global_init_max,
+                "global_dropout_rate": self.global_dropout_rate,
+                "global_use_batch_norm": self.global_use_batch_norm,
+                "global_pooling": self.global_pooling,
+            }
+        )
+        return config
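
To make the new layer's contract concrete, a hedged usage sketch follows. The random input is illustrative, and it assumes the inner AdvancedNumericalEmbedding returns (batch, num_features, embedding_dim) exactly as the docstring above describes:

import numpy as np
import tensorflow as tf

from kdp.custom_layers import GlobalAdvancedNumericalEmbedding

# Illustrative batch: 4 samples, 5 numeric features each.
x = tf.constant(np.random.randn(4, 5), dtype=tf.float32)

layer = GlobalAdvancedNumericalEmbedding(
    global_embedding_dim=8,
    global_mlp_hidden_units=16,
    global_num_bins=10,
    global_init_min=-3.0,
    global_init_max=3.0,
    global_dropout_rate=0.1,
    global_use_batch_norm=True,
    global_pooling="average",
)

y = layer(x, training=False)
print(y.shape)  # (4, 8): one pooled vector per sample, independent of the feature count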

kdp/processor.py

Lines changed: 32 additions & 0 deletions
@@ -12,6 +12,7 @@
 import tensorflow as tf
 from loguru import logger

+from kdp.custom_layers import GlobalAdvancedNumericalEmbedding
 from kdp.features import (
     CategoricalFeature,
     CategoryEncodingOptions,
@@ -200,6 +201,15 @@ def __init__(
         init_max: float = 3.0,
         dropout_rate: float = 0.1,
         use_batch_norm: bool = True,
+        use_global_numerical_embedding: bool = False,
+        global_embedding_dim: int = 8,
+        global_mlp_hidden_units: int = 16,
+        global_num_bins: int = 10,
+        global_init_min: float = -3.0,
+        global_init_max: float = 3.0,
+        global_dropout_rate: float = 0.1,
+        global_use_batch_norm: bool = True,
+        global_pooling: str = "average",
     ) -> None:
         """Initialize a preprocessing model.
@@ -282,6 +292,17 @@ def __init__(
         self.dropout_rate = dropout_rate
         self.use_batch_norm = use_batch_norm

+        # advanced global numerical embedding control
+        self.use_global_numerical_embedding = use_global_numerical_embedding
+        self.global_embedding_dim = global_embedding_dim
+        self.global_mlp_hidden_units = global_mlp_hidden_units
+        self.global_num_bins = global_num_bins
+        self.global_init_min = global_init_min
+        self.global_init_max = global_init_max
+        self.global_dropout_rate = global_dropout_rate
+        self.global_use_batch_norm = global_use_batch_norm
+        self.global_pooling = global_pooling
+
         # PLACEHOLDERS
         self.preprocessors = {}
         self.inputs = {}
@@ -1046,6 +1067,17 @@ def _prepare_outputs(self) -> None:
                 name="ConcatenateNumeric",
                 axis=-1,
             )(numeric_features)
+            if self.use_global_numerical_embedding:
+                concat_num = GlobalAdvancedNumericalEmbedding(
+                    global_embedding_dim=self.global_embedding_dim,
+                    global_mlp_hidden_units=self.global_mlp_hidden_units,
+                    global_num_bins=self.global_num_bins,
+                    global_init_min=self.global_init_min,
+                    global_init_max=self.global_init_max,
+                    global_dropout_rate=self.global_dropout_rate,
+                    global_use_batch_norm=self.global_use_batch_norm,
+                    global_pooling=self.global_pooling,
+                )(concat_num)
         else:
             concat_num = None
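
And a hedged sketch of switching the option on from the processor side. The class name PreprocessingModel and the path_data/features_specs arguments are assumptions about KDP's public API and are not shown in this diff; only the use_global_numerical_embedding and global_* keywords come from this commit:

from kdp.processor import PreprocessingModel  # assumed class name; the diff only shows its __init__ kwargs

# Schematic feature specs -- placeholder only; real specs come from kdp.features.
features_specs = {"age": "float", "income": "float", "score": "float"}

ppr = PreprocessingModel(
    path_data="data/train.csv",            # illustrative, not from this commit
    features_specs=features_specs,
    use_global_numerical_embedding=True,   # new flag added by this commit
    global_embedding_dim=8,                # width of the pooled numeric embedding
    global_num_bins=10,
    global_pooling="average",              # or "max"
)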
