@@ -2046,17 +2046,22 @@ def build(self, input_shape):
             init_min_tensor = tf.fill([self.num_features], init_min_tensor)
         if init_max_tensor.shape.ndims == 0:
             init_max_tensor = tf.fill([self.num_features], init_max_tensor)
-        # Convert tensors to numpy arrays, which are acceptable by tf.constant_initializer.
-        init_min_value = (
-            init_min_tensor.numpy()
-            if hasattr(init_min_tensor, "numpy")
-            else init_min_tensor
-        )
-        init_max_value = (
-            init_max_tensor.numpy()
-            if hasattr(init_max_tensor, "numpy")
-            else init_max_tensor
-        )
+
+        if tf.executing_eagerly():
+            init_min_value = init_min_tensor.numpy()
+            init_max_value = init_max_tensor.numpy()
+        else:
+            # Fallback: if not executing eagerly, force conversion to list
+            init_min_value = (
+                init_min_tensor.numpy().tolist()
+                if hasattr(init_min_tensor, "numpy")
+                else self.init_min
+            )
+            init_max_value = (
+                init_max_tensor.numpy().tolist()
+                if hasattr(init_max_tensor, "numpy")
+                else self.init_max
+            )

         self.learned_min = self.add_weight(
             name="learned_min",
@@ -2133,3 +2138,106 @@ def get_config(self):
             }
         )
         return config
+
+
+class GlobalAdvancedNumericalEmbedding(tf.keras.layers.Layer):
+    """
+    Global AdvancedNumericalEmbedding processes concatenated numeric features.
+    It applies an inner AdvancedNumericalEmbedding over the flattened input and then
+    performs global pooling (average or max) to produce a compact representation.
+    """
+
+    def __init__(
+        self,
+        global_embedding_dim: int,
+        global_mlp_hidden_units: int,
+        global_num_bins: int,
+        global_init_min,
+        global_init_max,
+        global_dropout_rate: float,
+        global_use_batch_norm: bool,
+        global_pooling: str = "average",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.global_embedding_dim = global_embedding_dim
+        self.global_mlp_hidden_units = global_mlp_hidden_units
+        self.global_num_bins = global_num_bins
+
+        # Ensure initializer parameters are Python scalars, lists, or numpy arrays.
+        if not isinstance(global_init_min, (list, tuple, np.ndarray)):
+            try:
+                global_init_min = float(global_init_min)
+            except Exception:
+                raise ValueError(
+                    "global_init_min must be a Python scalar, list, tuple or numpy array"
+                )
+        if not isinstance(global_init_max, (list, tuple, np.ndarray)):
+            try:
+                global_init_max = float(global_init_max)
+            except Exception:
+                raise ValueError(
+                    "global_init_max must be a Python scalar, list, tuple or numpy array"
+                )
+        self.global_init_min = global_init_min
+        self.global_init_max = global_init_max
+        self.global_dropout_rate = global_dropout_rate
+        self.global_use_batch_norm = global_use_batch_norm
+        self.global_pooling = global_pooling
+
+        # Use the existing advanced numerical embedding block
+        self.inner_embedding = AdvancedNumericalEmbedding(
+            embedding_dim=self.global_embedding_dim,
+            mlp_hidden_units=self.global_mlp_hidden_units,
+            num_bins=self.global_num_bins,
+            init_min=self.global_init_min,
+            init_max=self.global_init_max,
+            dropout_rate=self.global_dropout_rate,
+            use_batch_norm=self.global_use_batch_norm,
+            name="global_numeric_embedding",
+        )
+        if self.global_pooling == "average":
+            self.global_pooling_layer = tf.keras.layers.GlobalAveragePooling1D(
+                name="global_avg_pool"
+            )
+        elif self.global_pooling == "max":
+            self.global_pooling_layer = tf.keras.layers.GlobalMaxPooling1D(
+                name="global_max_pool"
+            )
+        else:
+            raise ValueError(f"Unsupported pooling method: {self.global_pooling}")
+
+    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+        """
+        Expects inputs with shape (batch, ...) and flattens them (except for the batch dim).
+        Then, the inner embedding produces a 3D output (batch, num_features, embedding_dim),
+        which is finally pooled to yield (batch, embedding_dim).
+        """
+        # If inputs have more than 2 dimensions, flatten them (except for the batch dimension).
+        if len(inputs.shape) > 2:
+            inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
+        # Pass through the inner advanced embedding.
+        x_embedded = self.inner_embedding(inputs, training=training)
+        # Global pooling over the numeric features axis.
+        x_pooled = self.global_pooling_layer(x_embedded)
+        return x_pooled
+
+    def compute_output_shape(self, input_shape):
+        # Regardless of the input shape, the output shape is (batch_size, embedding_dim)
+        return (input_shape[0], self.global_embedding_dim)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "global_embedding_dim": self.global_embedding_dim,
+                "global_mlp_hidden_units": self.global_mlp_hidden_units,
+                "global_num_bins": self.global_num_bins,
+                "global_init_min": self.global_init_min,
+                "global_init_max": self.global_init_max,
+                "global_dropout_rate": self.global_dropout_rate,
+                "global_use_batch_norm": self.global_use_batch_norm,
+                "global_pooling": self.global_pooling,
+            }
+        )
+        return config
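
A minimal usage sketch of the new layer, assuming the module imports tensorflow as tf and numpy as np and already defines AdvancedNumericalEmbedding; the batch size, feature count, and hyperparameter values below are illustrative only, not taken from the change itself:

import tensorflow as tf

# Hypothetical batch of 4 samples with 8 concatenated numeric features.
dummy_inputs = tf.random.normal((4, 8))

global_embedding = GlobalAdvancedNumericalEmbedding(
    global_embedding_dim=16,
    global_mlp_hidden_units=32,
    global_num_bins=10,
    global_init_min=-3.0,
    global_init_max=3.0,
    global_dropout_rate=0.1,
    global_use_batch_norm=True,
    global_pooling="average",
)

# The inner embedding yields (batch, num_features, embedding_dim); global
# pooling then collapses the feature axis, so the output here is (4, 16).
outputs = global_embedding(dummy_inputs, training=False)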