From c3e9c90448b6eb71db7dcdec1c9a8ed7fc1dd54c Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 14 Feb 2025 18:46:53 +0100 Subject: [PATCH 1/2] add JohnsonSU and individual preprocessing --- mambular/preprocessing/preprocessor.py | 135 ++++++++++++++++++------- mambular/utils/distributions.py | 112 ++++++++++++++++++-- 2 files changed, 200 insertions(+), 47 deletions(-) diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index a6916495..7d8d59ce 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -40,6 +40,14 @@ class Preprocessor: Parameters ---------- + feature_preprocessing: dict or None + Dictionary mapping column names to preprocessing techniques. Example: + { + "num_feature1": "minmax", + "num_feature2": "ple", + "cat_feature1": "one-hot", + "cat_feature2": "int" + } n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'. @@ -94,6 +102,7 @@ class Preprocessor: def __init__( self, + feature_preprocessing=None, n_bins=64, numerical_preprocessing="ple", categorical_preprocessing="int", @@ -111,10 +120,14 @@ def __init__( ): self.n_bins = n_bins self.numerical_preprocessing = ( - numerical_preprocessing.lower() if numerical_preprocessing is not None else "none" + numerical_preprocessing.lower() + if numerical_preprocessing is not None + else "none" ) self.categorical_preprocessing = ( - categorical_preprocessing.lower() if categorical_preprocessing is not None else "none" + categorical_preprocessing.lower() + if categorical_preprocessing is not None + else "none" ) if self.numerical_preprocessing not in [ "ple", @@ -149,6 +162,7 @@ def __init__( ) self.use_decision_tree_bins = use_decision_tree_bins + self.feature_preprocessing = feature_preprocessing or {} self.column_transformer = None self.fitted = False self.binning_strategy = binning_strategy @@ -237,13 +251,19 @@ def _detect_column_types(self, X): numerical_features.append(col) else: if isinstance(self.cat_cutoff, float): - cutoff_condition = (num_unique_values / total_samples) < self.cat_cutoff + cutoff_condition = ( + num_unique_values / total_samples + ) < self.cat_cutoff elif isinstance(self.cat_cutoff, int): cutoff_condition = num_unique_values < self.cat_cutoff else: - raise ValueError("cat_cutoff should be either a float or an integer.") + raise ValueError( + "cat_cutoff should be either a float or an integer." + ) - if X[col].dtype.kind not in "iufc" or (X[col].dtype.kind == "i" and cutoff_condition): + if X[col].dtype.kind not in "iufc" or ( + X[col].dtype.kind == "i" and cutoff_condition + ): categorical_features.append(col) else: numerical_features.append(col) @@ -274,6 +294,10 @@ def fit(self, X, y=None): if numerical_features: for feature in numerical_features: + feature_preprocessing = self.feature_preprocessing.get( + feature, self.numerical_preprocessing + ) + # extended the annotation list if new transformer is added, either from sklearn or custom numeric_transformer_steps: list[ tuple[ @@ -296,7 +320,7 @@ def fit(self, X, y=None): | SigmoidExpansion, ] ] = [("imputer", SimpleImputer(strategy="mean"))] - if self.numerical_preprocessing in ["binning", "one-hot"]: + if feature_preprocessing in ["binning", "one-hot"]: bins = ( self._get_decision_tree_bins(X[[feature]], y, [feature]) if self.use_decision_tree_bins @@ -308,7 +332,11 @@ def fit(self, X, y=None): ( "discretizer", KBinsDiscretizer( - n_bins=(bins if isinstance(bins, int) else len(bins) - 1), + n_bins=( + bins + if isinstance(bins, int) + else len(bins) - 1 + ), encode="ordinal", strategy=self.binning_strategy, # type: ignore subsample=200_000 if len(X) > 200_000 else None, @@ -326,32 +354,38 @@ def fit(self, X, y=None): ] ) - if self.numerical_preprocessing == "one-hot": + if feature_preprocessing == "one-hot": numeric_transformer_steps.extend( [ ("onehot_from_ordinal", OneHotFromOrdinal()), ] ) - elif self.numerical_preprocessing == "standardization": + elif feature_preprocessing == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) - elif self.numerical_preprocessing == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + elif feature_preprocessing == "minmax": + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) - elif self.numerical_preprocessing == "quantile": + elif feature_preprocessing == "quantile": numeric_transformer_steps.append( ( "quantile", - QuantileTransformer(n_quantiles=self.n_bins, random_state=101), + QuantileTransformer( + n_quantiles=self.n_bins, random_state=101 + ), ) ) - elif self.numerical_preprocessing == "polynomial": + elif feature_preprocessing == "polynomial": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "polynomial", @@ -359,14 +393,16 @@ def fit(self, X, y=None): ) ) - elif self.numerical_preprocessing == "robust": + elif feature_preprocessing == "robust": numeric_transformer_steps.append(("robust", RobustScaler())) - elif self.numerical_preprocessing == "splines": + elif feature_preprocessing == "splines": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "splines", @@ -381,11 +417,13 @@ def fit(self, X, y=None): ), ) - elif self.numerical_preprocessing == "rbf": + elif feature_preprocessing == "rbf": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "rbf", @@ -398,11 +436,13 @@ def fit(self, X, y=None): ) ) - elif self.numerical_preprocessing == "sigmoid": + elif feature_preprocessing == "sigmoid": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "sigmoid", @@ -415,11 +455,15 @@ def fit(self, X, y=None): ) ) - elif self.numerical_preprocessing == "ple": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) - numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task))) + elif feature_preprocessing == "ple": + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) + numeric_transformer_steps.append( + ("ple", PLE(n_bins=self.n_bins, task=self.task)) + ) - elif self.numerical_preprocessing == "box-cox": + elif feature_preprocessing == "box-cox": numeric_transformer_steps.append( ( "box-cox", @@ -427,7 +471,7 @@ def fit(self, X, y=None): ) ) - elif self.numerical_preprocessing == "yeo-johnson": + elif feature_preprocessing == "yeo-johnson": numeric_transformer_steps.append( ( "yeo-johnson", @@ -435,7 +479,7 @@ def fit(self, X, y=None): ) ) - elif self.numerical_preprocessing == "none": + elif feature_preprocessing == "none": numeric_transformer_steps.append( ( "none", @@ -449,7 +493,10 @@ def fit(self, X, y=None): if categorical_features: for feature in categorical_features: - if self.categorical_preprocessing == "int": + feature_preprocessing = self.feature_preprocessing.get( + feature, self.categorical_preprocessing + ) + if feature_preprocessing == "int": # Use ContinuousOrdinalEncoder for "int" categorical_transformer = Pipeline( [ @@ -457,7 +504,7 @@ def fit(self, X, y=None): ("continuous_ordinal", ContinuousOrdinalEncoder()), ] ) - elif self.categorical_preprocessing == "one-hot": + elif feature_preprocessing == "one-hot": # Use OneHotEncoder for "one-hot" categorical_transformer = Pipeline( [ @@ -467,7 +514,7 @@ def fit(self, X, y=None): ] ) - elif self.categorical_preprocessing == "none": + elif feature_preprocessing == "none": # Use OneHotEncoder for "one-hot" categorical_transformer = Pipeline( [ @@ -475,7 +522,7 @@ def fit(self, X, y=None): ("none", NoTransformer()), ] ) - elif self.categorical_preprocessing == "pretrained": + elif feature_preprocessing == "pretrained": categorical_transformer = Pipeline( [ ("imputer", SimpleImputer(strategy="most_frequent")), @@ -483,12 +530,18 @@ def fit(self, X, y=None): ] ) else: - raise ValueError(f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}") + raise ValueError( + f"Unknown categorical_preprocessing type: {feature_preprocessing}" + ) # Append the transformer for the current categorical feature - transformers.append((f"cat_{feature}", categorical_transformer, [feature])) + transformers.append( + (f"cat_{feature}", categorical_transformer, [feature]) + ) - self.column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough") + self.column_transformer = ColumnTransformer( + transformers=transformers, remainder="passthrough" + ) self.column_transformer.fit(X, y) self.fitted = True @@ -514,13 +567,17 @@ def _get_decision_tree_bins(self, X, y, numerical_features): bins = [] for feature in numerical_features: tree_model = ( - DecisionTreeClassifier(max_depth=3) if y.dtype.kind in "bi" else DecisionTreeRegressor(max_depth=3) + DecisionTreeClassifier(max_depth=3) + if y.dtype.kind in "bi" + else DecisionTreeRegressor(max_depth=3) ) tree_model.fit(X[[feature]], y) thresholds = tree_model.tree_.threshold[tree_model.tree_.feature != -2] # type: ignore bin_edges = np.sort(np.unique(thresholds)) - bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()]))) + bins.append( + np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])) + ) return bins def transform(self, X): @@ -676,7 +733,9 @@ def get_feature_info(self, verbose=True): "categories": None, # Numerical features don't have categories } if verbose: - print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}") + print( + f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}" + ) # Categorical features elif "continuous_ordinal" in steps: diff --git a/mambular/utils/distributions.py b/mambular/utils/distributions.py index 374d0101..75395ce9 100644 --- a/mambular/utils/distributions.py +++ b/mambular/utils/distributions.py @@ -116,7 +116,9 @@ def forward(self, predictions): """ transformed_params = [] for idx, param_name in enumerate(self.param_names): - transform_func = self.get_transform(getattr(self, f"{param_name}_transform", "none")) + transform_func = self.get_transform( + getattr(self, f"{param_name}_transform", "none") + ) transformed_params.append( transform_func(predictions[:, idx]).unsqueeze( # type: ignore 1 @@ -153,7 +155,9 @@ def __init__(self, name="Normal", mean_transform="none", var_transform="positive def compute_loss(self, predictions, y_true): mean = self.mean_transform(predictions[:, self.param_names.index("mean")]) - variance = self.variance_transform(predictions[:, self.param_names.index("variance")]) + variance = self.variance_transform( + predictions[:, self.param_names.index("variance")] + ) normal_dist = dist.Normal(mean, variance) @@ -167,10 +171,14 @@ def evaluate_nll(self, y_true, y_pred): y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) - mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]) + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")] + ) rmse = np.sqrt(mse_loss.detach().numpy()) mae = ( - torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]) + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")] + ) .detach() .numpy() ) @@ -228,7 +236,9 @@ def evaluate_nll(self, y_true, y_pred): .detach() .numpy() # type: ignore ) # type: ignore - poisson_deviance = 2 * torch.sum(y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)) + poisson_deviance = 2 * torch.sum( + y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate) + ) metrics["mse"] = mse_loss.detach().numpy() metrics["mae"] = mae @@ -367,7 +377,9 @@ class GammaDistribution(BaseDistribution): rate_transform (str or callable): Transformation for the rate parameter to ensure it remains positive. """ - def __init__(self, name="Gamma", shape_transform="positive", rate_transform="positive"): + def __init__( + self, name="Gamma", shape_transform="positive", rate_transform="positive" + ): param_names = ["shape", "rate"] super().__init__(name, param_names) @@ -434,10 +446,16 @@ def evaluate_nll(self, y_true, y_pred): y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) - mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]) + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")] + ) rmse = np.sqrt(mse_loss.detach().numpy()) mae = ( - torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]).detach().numpy() + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")] + ) + .detach() + .numpy() ) metrics["mse"] = mse_loss.detach().numpy() @@ -478,7 +496,9 @@ def __init__( def compute_loss(self, predictions, y_true): # Apply transformations to ensure mean and dispersion parameters are positive mean = self.mean_transform(predictions[:, self.param_names.index("mean")]) - dispersion = self.dispersion_transform(predictions[:, self.param_names.index("dispersion")]) + dispersion = self.dispersion_transform( + predictions[:, self.param_names.index("dispersion")] + ) # Calculate the probability (p) and number of successes (r) from mean and dispersion # These calculations follow from the mean and variance of the negative binomial distribution @@ -574,3 +594,77 @@ def compute_loss(self, predictions, y_true): # Sum losses across quantiles and compute mean loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1)) return loss + + +class JohnsonSuDistribution(BaseDistribution): + """ + Represents a Johnson's SU distribution with parameters for skewness, shape, location, and scale. + + Parameters + ---------- + name (str): The name of the distribution. Defaults to "JohnsonSu". + skew_transform (str or callable): The transformation for the skewness parameter. Defaults to "none". + shape_transform (str or callable): The transformation for the shape parameter. Defaults to "positive". + loc_transform (str or callable): The transformation for the location parameter. Defaults to "none". + scale_transform (str or callable): The transformation for the scale parameter. Defaults to "positive". + """ + + def __init__( + self, + name="JohnsonSu", + skew_transform="none", + shape_transform="positive", + loc_transform="none", + scale_transform="positive", + ): + param_names = ["skew", "shape", "location", "scale"] + super().__init__(name, param_names) + + self.skew_transform = self.get_transform(skew_transform) + self.shape_transform = self.get_transform(shape_transform) + self.loc_transform = self.get_transform(loc_transform) + self.scale_transform = self.get_transform(scale_transform) + + def log_prob(self, x, skew, shape, loc, scale): + """ + Compute the log probability density of the Johnson's SU distribution. + """ + z = skew + shape * torch.asinh((x - loc) / scale) + log_pdf = ( + torch.log(shape / (scale * np.sqrt(2 * np.pi))) + - 0.5 * z**2 + - 0.5 * torch.log(1 + ((x - loc) / scale) ** 2) + ) + return log_pdf + + def compute_loss(self, predictions, y_true): + skew = self.skew_transform(predictions[:, self.param_names.index("skew")]) + shape = self.shape_transform(predictions[:, self.param_names.index("shape")]) + loc = self.loc_transform(predictions[:, self.param_names.index("location")]) + scale = self.scale_transform(predictions[:, self.param_names.index("scale")]) + + log_probs = self.log_prob(y_true, skew, shape, loc, scale) + nll = -log_probs.mean() + return nll + + def evaluate_nll(self, y_true, y_pred): + metrics = super().evaluate_nll(y_true, y_pred) + + y_true_tensor = torch.tensor(y_true, dtype=torch.float32) + y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) + + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("location")] + ) + rmse = np.sqrt(mse_loss.detach().numpy()) + mae = ( + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("location")] + ) + .detach() + .numpy() + ) + + metrics.update({"mse": mse_loss.detach().numpy(), "mae": mae, "rmse": rmse}) + + return metrics From b10ff52406798f94a9a03083e0ce7833a2d22d96 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 14 Feb 2025 19:02:40 +0100 Subject: [PATCH 2/2] adapt embedding layer to new preprocessing --- mambular/arch_utils/layer_utils/embedding_layer.py | 5 ++++- mambular/preprocessing/preprocessor.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mambular/arch_utils/layer_utils/embedding_layer.py b/mambular/arch_utils/layer_utils/embedding_layer.py index 0fb93fd1..cb9bf180 100644 --- a/mambular/arch_utils/layer_utils/embedding_layer.py +++ b/mambular/arch_utils/layer_utils/embedding_layer.py @@ -141,8 +141,10 @@ def forward(self, num_features=None, cat_features=None): # Process categorical embeddings if self.cat_embeddings and cat_features is not None: cat_embeddings = [ - emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings) + emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1) + for i, emb in enumerate(self.cat_embeddings) ] + cat_embeddings = torch.stack(cat_embeddings, dim=1) cat_embeddings = torch.squeeze(cat_embeddings, dim=2) if self.layer_norm_after_embedding: @@ -175,6 +177,7 @@ def forward(self, num_features=None, cat_features=None): # Combine categorical and numerical embeddings if cat_embeddings is not None and num_embeddings is not None: + x = torch.cat([cat_embeddings, num_embeddings], dim=1) elif cat_embeddings is not None: x = cat_embeddings diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index 7d8d59ce..15f6386d 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -464,6 +464,9 @@ def fit(self, X, y=None): ) elif feature_preprocessing == "box-cox": + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(1e-03, 1))) + ) numeric_transformer_steps.append( ( "box-cox",