
Commit 185292c

feat(KDP): Integrate Advanced Numerical Embedding (#25)
2 parents bd90f11 + 55cbbb3 commit 185292c


5 files changed: +547 -6 lines


kdp/custom_layers.py

Lines changed: 188 additions & 0 deletions
@@ -6,6 +6,9 @@
 import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
+from tensorflow.keras import layers
+
+from loguru import logger
 
 
 class TextPreprocessingLayer(tf.keras.layers.Layer):
@@ -1945,3 +1948,188 @@ def from_config(cls, config: dict) -> "VariableSelection":
             VariableSelection: A new instance of the layer.
         """
         return cls(**config)
+
+
+class AdvancedNumericalEmbedding(layers.Layer):
+    """Advanced numerical embedding layer for continuous features.
+
+    This layer embeds each continuous numerical feature into a higher-dimensional space by
+    combining two branches:
+
+    1. Continuous Branch: Each feature is processed via a small MLP (using TimeDistributed layers).
+    2. Discrete Branch: Each feature is discretized into bins using learnable min/max boundaries
+       and then an embedding is looked up for its bin.
+
+    A learnable gate (of shape (num_features, embedding_dim)) combines the two branch outputs
+    per feature and per embedding dimension. Additionally, the continuous branch uses a residual
+    connection and optional batch normalization to improve training stability.
+
+    The layer supports inputs of shape (batch, num_features) for any number of features and returns
+    outputs of shape (batch, num_features, embedding_dim).
+
+    Args:
+        embedding_dim (int): Output embedding dimension per feature.
+        mlp_hidden_units (int): Hidden units for the continuous branch MLP.
+        num_bins (int): Number of bins for discretization.
+        init_min (float or list): Initial minimum values for discretization boundaries. If a scalar is
+            provided, it is applied to all features.
+        init_max (float or list): Initial maximum values for discretization boundaries.
+        dropout_rate (float): Dropout rate applied to the continuous branch.
+        use_batch_norm (bool): Whether to apply batch normalization to the continuous branch.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_hidden_units: int,
+        num_bins: int,
+        init_min,
+        init_max,
+        dropout_rate: float = 0.0,
+        use_batch_norm: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embedding_dim = embedding_dim
+        self.mlp_hidden_units = mlp_hidden_units
+        self.num_bins = num_bins
+        self.dropout_rate = dropout_rate
+        self.use_batch_norm = use_batch_norm
+        self.init_min = init_min
+        self.init_max = init_max
+
+        if self.num_bins is None:
+            raise ValueError(
+                "num_bins must be provided to activate the discrete branch."
+            )
+
+    def build(self, input_shape):
+        # input_shape: (batch, num_features)
+        self.num_features = input_shape[-1]
+        # Continuous branch: process each feature independently using TimeDistributed MLP.
+        self.cont_mlp = tf.keras.Sequential(
+            [
+                layers.TimeDistributed(
+                    layers.Dense(self.mlp_hidden_units, activation="relu")
+                ),
+                layers.TimeDistributed(layers.Dense(self.embedding_dim)),
+            ],
+            name="cont_mlp",
+        )
+        self.dropout = (
+            layers.Dropout(self.dropout_rate)
+            if self.dropout_rate > 0
+            else lambda x, training: x
+        )
+        if self.use_batch_norm:
+            self.batch_norm = layers.TimeDistributed(
+                layers.BatchNormalization(), name="cont_batch_norm"
+            )
+        # Residual projection to match embedding_dim.
+        self.residual_proj = layers.TimeDistributed(
+            layers.Dense(self.embedding_dim, activation=None), name="residual_proj"
+        )
+        # Discrete branch: Create one Embedding layer per feature.
+        self.bin_embeddings = []
+        for i in range(self.num_features):
+            embed_layer = layers.Embedding(
+                input_dim=self.num_bins,
+                output_dim=self.embedding_dim,
+                name=f"bin_embed_{i}",
+            )
+            self.bin_embeddings.append(embed_layer)
+        # Learned bin boundaries for each feature, shape: (num_features,)
+        init_min_tensor = tf.convert_to_tensor(self.init_min, dtype=tf.float32)
+        init_max_tensor = tf.convert_to_tensor(self.init_max, dtype=tf.float32)
+        if init_min_tensor.shape.ndims == 0:
+            init_min_tensor = tf.fill([self.num_features], init_min_tensor)
+        if init_max_tensor.shape.ndims == 0:
+            init_max_tensor = tf.fill([self.num_features], init_max_tensor)
+        # Convert tensors to numpy arrays, which are acceptable by tf.constant_initializer.
+        init_min_value = (
+            init_min_tensor.numpy()
+            if hasattr(init_min_tensor, "numpy")
+            else init_min_tensor
+        )
+        init_max_value = (
+            init_max_tensor.numpy()
+            if hasattr(init_max_tensor, "numpy")
+            else init_max_tensor
+        )
+
+        self.learned_min = self.add_weight(
+            name="learned_min",
+            shape=(self.num_features,),
+            initializer=tf.constant_initializer(init_min_value),
+            trainable=True,
+        )
+        self.learned_max = self.add_weight(
+            name="learned_max",
+            shape=(self.num_features,),
+            initializer=tf.constant_initializer(init_max_value),
+            trainable=True,
+        )
+        # Gate to combine continuous and discrete branches, shape: (num_features, embedding_dim)
+        self.gate = self.add_weight(
+            name="gate",
+            shape=(self.num_features, self.embedding_dim),
+            initializer="zeros",
+            trainable=True,
+        )
+        logger.debug(
+            "AdvancedNumericalEmbedding built for {} features with embedding_dim={}",
+            self.num_features,
+            self.embedding_dim,
+        )
+        super().build(input_shape)
+
+    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+        # Continuous branch.
+        inputs_expanded = tf.expand_dims(inputs, axis=-1)  # (batch, num_features, 1)
+        cont = self.cont_mlp(inputs_expanded)
+        cont = self.dropout(cont, training=training)
+        if self.use_batch_norm:
+            cont = self.batch_norm(cont, training=training)
+        # Residual connection.
+        cont_res = self.residual_proj(inputs_expanded)
+        cont = cont + cont_res  # (batch, num_features, embedding_dim)
+
+        # Discrete branch.
+        inputs_float = tf.cast(inputs, tf.float32)
+        # Use learned min and max for scaling.
+        scaled = (inputs_float - self.learned_min) / (
+            self.learned_max - self.learned_min + 1e-6
+        )
+        # Compute bin indices.
+        bin_indices = tf.floor(scaled * self.num_bins)
+        bin_indices = tf.cast(bin_indices, tf.int32)
+        bin_indices = tf.clip_by_value(bin_indices, 0, self.num_bins - 1)
+        disc_embeddings = []
+        for i in range(self.num_features):
+            feat_bins = bin_indices[:, i]  # (batch,)
+            feat_embed = self.bin_embeddings[i](
+                feat_bins
+            )  # i is a Python integer here.
+            disc_embeddings.append(feat_embed)
+        disc = tf.stack(disc_embeddings, axis=1)  # (batch, num_features, embedding_dim)
+
+        # Combine branches via a per-feature, per-dimension gate.
+        gate = tf.nn.sigmoid(self.gate)  # (num_features, embedding_dim)
+        output = gate * cont + (1 - gate) * disc  # (batch, num_features, embedding_dim)
+        return output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "embedding_dim": self.embedding_dim,
+                "mlp_hidden_units": self.mlp_hidden_units,
+                "num_bins": self.num_bins,
+                "init_min": self.init_min,
+                "init_max": self.init_max,
+                "dropout_rate": self.dropout_rate,
+                "use_batch_norm": self.use_batch_norm,
+            }
+        )
+        return config
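As a quick reference, here is a minimal standalone sketch of exercising the layer added above; it is not part of the commit, the input values and hyperparameters are illustrative, and it assumes an eager TF2 environment:

import tensorflow as tf

from kdp.custom_layers import AdvancedNumericalEmbedding

# Two rows with three numerical features each (already scaled values, chosen arbitrarily).
x = tf.constant(
    [[0.1, -1.2, 2.5],
     [1.3, 0.4, -0.7]],
    dtype=tf.float32,
)

layer = AdvancedNumericalEmbedding(
    embedding_dim=8,
    mlp_hidden_units=16,
    num_bins=10,
    init_min=-3.0,  # scalar boundaries are broadcast to all features
    init_max=3.0,
    dropout_rate=0.1,
    use_batch_norm=True,
)

y = layer(x, training=False)
print(y.shape)  # (2, 3, 8), i.e. (batch, num_features, embedding_dim)

The output shape matches the contract stated in the docstring: one embedding_dim-sized vector per feature, gated between the continuous MLP branch and the binned embedding branch.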

kdp/features.py

Lines changed: 28 additions & 2 deletions
@@ -117,26 +117,52 @@ def from_string(type_str: str) -> "FeatureType":
 
 
 class NumericalFeature(Feature):
-    """NumericalFeature with dynamic kwargs passing."""
+    """NumericalFeature with dynamic kwargs passing and embedding support."""
 
     def __init__(
         self,
         name: str,
         feature_type: FeatureType = FeatureType.FLOAT_NORMALIZED,
         preferred_distribution: DistributionType | None = None,
+        use_embedding: bool = False,
+        embedding_dim: int = 8,
+        num_bins: int = 10,
         **kwargs,
     ) -> None:
         """Initializes a NumericalFeature instance.
 
         Args:
             name (str): The name of the feature.
             feature_type (FeatureType): The type of the feature.
-            preferred_distribution (DistributionType | None): The preferred distribution type for the feature.
+            preferred_distribution (DistributionType | None): The preferred distribution type.
+            use_embedding (bool): Whether to use advanced numerical embedding.
+            embedding_dim (int): Dimension of the embedding space.
+            num_bins (int): Number of bins for discretization.
             **kwargs: Additional keyword arguments for the feature.
         """
         super().__init__(name, feature_type, **kwargs)
         self.dtype = tf.float32
         self.preferred_distribution = preferred_distribution
+        self.use_embedding = use_embedding
+        self.embedding_dim = embedding_dim
+        self.num_bins = num_bins
+
+    def get_embedding_layer(self, input_shape: tuple) -> tf.keras.layers.Layer:
+        """Creates and returns an AdvancedNumericalEmbedding layer configured for this feature."""
+        from kdp.custom_layers import (
+            AdvancedNumericalEmbedding,
+        )  # Avoid circular import
+
+        return AdvancedNumericalEmbedding(
+            embedding_dim=self.embedding_dim,
+            mlp_hidden_units=max(16, self.embedding_dim * 2),
+            num_bins=self.num_bins,
+            init_min=self.kwargs.get("init_min", -3.0),
+            init_max=self.kwargs.get("init_max", 3.0),
+            dropout_rate=self.kwargs.get("dropout_rate", 0.1),
+            use_batch_norm=self.kwargs.get("use_batch_norm", True),
+            name=f"{self.name}_embedding",
+        )
 
 
 class CategoricalFeature(Feature):
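For context, a hedged usage sketch of the extended NumericalFeature (not part of the commit); it assumes, as get_embedding_layer implies, that extra keyword arguments such as init_min are retained on self.kwargs by the Feature base class, and the values are illustrative:

from kdp.features import FeatureType, NumericalFeature

income = NumericalFeature(
    name="income",
    feature_type=FeatureType.FLOAT_NORMALIZED,
    use_embedding=True,   # opt this feature into advanced numerical embedding
    embedding_dim=16,
    num_bins=20,
    init_min=-2.0,        # read back via self.kwargs in get_embedding_layer
    init_max=2.0,
)

# Builds an AdvancedNumericalEmbedding named "income_embedding" with
# mlp_hidden_units=max(16, 16 * 2) == 32; input_shape is currently unused.
embedding_layer = income.get_embedding_layer(input_shape=(None, 1))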

kdp/processor.py

Lines changed: 48 additions & 4 deletions
@@ -192,6 +192,14 @@ def __init__(
         distribution_aware_bins: int = 1000,
         feature_selection_units: int = 32,
         feature_selection_dropout: float = 0.2,
+        use_advanced_numerical_embedding: bool = False,
+        embedding_dim: int = 8,
+        mlp_hidden_units: int = 16,
+        num_bins: int = 10,
+        init_min: float = -3.0,
+        init_max: float = 3.0,
+        dropout_rate: float = 0.1,
+        use_batch_norm: bool = True,
     ) -> None:
         """Initialize a preprocessing model.
 
@@ -225,6 +233,12 @@ def __init__(
             feature_selection_dropout (float): Dropout rate for feature selection.
             use_distribution_aware (bool): Whether to use distribution-aware encoding for features.
             distribution_aware_bins (int): Number of bins to use for distribution-aware encoding.
+            use_advanced_numerical_embedding (bool): Whether to use advanced numerical embedding.
+            embedding_dim (int): Dimension of the embedding for advanced numerical embedding.
+            mlp_hidden_units (int): Number of units for the MLP in advanced numerical embedding.
+            num_bins (int): Number of bins for discretization in advanced numerical embedding.
+            init_min (float): Minimum value for the embedding in advanced numerical embedding.
+            init_max (float): Maximum value for the embedding in advanced numerical embedding.
         """
         self.path_data = path_data
         self.batch_size = batch_size or 50_000
@@ -258,6 +272,16 @@ def __init__(
         self.distribution_aware_bins = distribution_aware_bins
         self.feature_selection_dropout = feature_selection_dropout
 
+        # advanced numerical embedding control
+        self.use_advanced_numerical_embedding = use_advanced_numerical_embedding
+        self.embedding_dim = embedding_dim
+        self.mlp_hidden_units = mlp_hidden_units
+        self.num_bins = num_bins
+        self.init_min = init_min
+        self.init_max = init_max
+        self.dropout_rate = dropout_rate
+        self.use_batch_norm = use_batch_norm
+
         # PLACEHOLDERS
         self.preprocessors = {}
         self.inputs = {}
@@ -576,13 +600,13 @@ def _add_pipeline_numeric(
             stats (dict): A dictionary containing the metadata of the feature, including
                 the mean and variance of the feature.
         """
-        # getting feature object
+        # Get the feature specifications
         _feature = self.features_specs[feature_name]
 
-        # initializing preprocessor
+        # Initialize preprocessor
         preprocessor = FeaturePreprocessor(name=feature_name)
 
-        # Add cast to float32 first for all numeric features
+        # First, cast to float32 is applied to all numeric features.
         preprocessor.add_processing_step(
             layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
             name=f"cast_to_float_{feature_name}",
@@ -676,10 +700,30 @@ def _add_pipeline_numeric(
                 name=f"norm_{feature_name}",
             )
 
+        # Check for advanced numerical embedding.
+        if self.use_advanced_numerical_embedding:
+            logger.info(f"Using AdvancedNumericalEmbedding for {feature_name}")
+            # Obtain the embedding layer.
+            embedding_layer = _feature.get_embedding_layer(
+                input_shape=input_layer.shape
+            )
+            preprocessor.add_processing_step(
+                layer_creator=lambda **kwargs: embedding_layer,
+                layer_class="AdvancedNumericalEmbedding",
+                name=f"advanced_embedding_{feature_name}",
+                embedding_dim=self.embedding_dim,
+                mlp_hidden_units=self.mlp_hidden_units,
+                num_bins=self.num_bins,
+                init_min=self.init_min,
+                init_max=self.init_max,
+                dropout_rate=self.dropout_rate,
+                use_batch_norm=self.use_batch_norm,
+            )
+
         # Process the feature
         _output_pipeline = preprocessor.chain(input_layer=input_layer)
 
-        # Apply feature selection if enabled for numeric features
+        # Optionally, apply feature selection for numeric features.
         if (
             self.feature_selection_placement == FeatureSelectionPlacementOptions.NUMERIC
             or self.feature_selection_placement
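Finally, a sketch of enabling the new path end to end (not part of the commit). The PreprocessingModel class name and the build_preprocessor() call are assumed from the surrounding library rather than visible in this hunk, and the data path and feature names are illustrative:

from kdp.features import FeatureType, NumericalFeature
from kdp.processor import PreprocessingModel  # assumed entry point, not shown in this diff

features_specs = {
    "age": NumericalFeature(name="age", feature_type=FeatureType.FLOAT_NORMALIZED),
    "income": NumericalFeature(name="income", feature_type=FeatureType.FLOAT_NORMALIZED),
}

ppr = PreprocessingModel(
    path_data="data/train.csv",              # illustrative path
    features_specs=features_specs,
    use_advanced_numerical_embedding=True,   # new flag introduced in this commit
    embedding_dim=8,
    mlp_hidden_units=16,
    num_bins=10,
    init_min=-3.0,
    init_max=3.0,
    dropout_rate=0.1,
    use_batch_norm=True,
)
result = ppr.build_preprocessor()  # assumed build call

With the flag on, each numeric pipeline appends an AdvancedNumericalEmbedding step after normalization, wrapped in a lambda layer_creator so the pre-built layer instance from NumericalFeature.get_embedding_layer is reused as-is.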
