diff --git a/mambular/arch_utils/layer_utils/embedding_layer.py b/mambular/arch_utils/layer_utils/embedding_layer.py index 6098adb4..476d7bc9 100644 --- a/mambular/arch_utils/layer_utils/embedding_layer.py +++ b/mambular/arch_utils/layer_utils/embedding_layer.py @@ -156,8 +156,10 @@ def forward(self, num_features, cat_features, emb_features): # Process categorical embeddings if self.cat_embeddings and cat_features is not None: cat_embeddings = [ - emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings) + emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1) + for i, emb in enumerate(self.cat_embeddings) ] + cat_embeddings = torch.stack(cat_embeddings, dim=1) cat_embeddings = torch.squeeze(cat_embeddings, dim=2) if self.layer_norm_after_embedding: @@ -189,6 +191,7 @@ def forward(self, num_features, cat_features, emb_features): ] emb_embeddings = torch.stack(emb_embeddings, dim=1) else: + emb_embeddings = torch.stack(emb_features, dim=1) if self.layer_norm_after_embedding: emb_embeddings = self.embedding_norm(emb_embeddings) @@ -199,6 +202,7 @@ def forward(self, num_features, cat_features, emb_features): if embeddings: x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0] + else: raise ValueError("No features provided to the model.") diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index cbb9f2f2..0e69f815 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -40,6 +40,14 @@ class Preprocessor: Parameters ---------- + feature_preprocessing: dict or None + Dictionary mapping column names to preprocessing techniques. Example: + { + "num_feature1": "minmax", + "num_feature2": "ple", + "cat_feature1": "one-hot", + "cat_feature2": "int" + } n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'. @@ -94,6 +102,7 @@ class Preprocessor: def __init__( self, + feature_preprocessing=None, n_bins=64, numerical_preprocessing="ple", categorical_preprocessing="int", @@ -153,6 +162,7 @@ def __init__( ) self.use_decision_tree_bins = use_decision_tree_bins + self.feature_preprocessing = feature_preprocessing or {} self.column_transformer = None self.fitted = False self.binning_strategy = binning_strategy @@ -300,6 +310,10 @@ def fit(self, X, y=None, embeddings=None): if numerical_features: for feature in numerical_features: + feature_preprocessing = self.feature_preprocessing.get( + feature, self.numerical_preprocessing + ) + # extended the annotation list if new transformer is added, either from sklearn or custom numeric_transformer_steps: list[ tuple[ @@ -322,7 +336,7 @@ def fit(self, X, y=None, embeddings=None): | SigmoidExpansion, ] ] = [("imputer", SimpleImputer(strategy="mean"))] - if self.numerical_preprocessing in ["binning", "one-hot"]: + if feature_preprocessing in ["binning", "one-hot"]: bins = ( self._get_decision_tree_bins(X[[feature]], y, [feature]) if self.use_decision_tree_bins @@ -356,22 +370,22 @@ def fit(self, X, y=None, embeddings=None): ] ) - if self.numerical_preprocessing == "one-hot": + if feature_preprocessing == "one-hot": numeric_transformer_steps.extend( [ ("onehot_from_ordinal", OneHotFromOrdinal()), ] ) - elif self.numerical_preprocessing == "standardization": + elif feature_preprocessing == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) - elif self.numerical_preprocessing == "minmax": + elif feature_preprocessing == "minmax": numeric_transformer_steps.append( ("minmax", MinMaxScaler(feature_range=(-1, 1))) ) - elif self.numerical_preprocessing == "quantile": + elif feature_preprocessing == "quantile": numeric_transformer_steps.append( ( "quantile", @@ -381,7 +395,7 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "polynomial": + elif feature_preprocessing == "polynomial": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": @@ -395,10 +409,10 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "robust": + elif feature_preprocessing == "robust": numeric_transformer_steps.append(("robust", RobustScaler())) - elif self.numerical_preprocessing == "splines": + elif feature_preprocessing == "splines": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": @@ -419,7 +433,7 @@ def fit(self, X, y=None, embeddings=None): ), ) - elif self.numerical_preprocessing == "rbf": + elif feature_preprocessing == "rbf": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": @@ -438,7 +452,7 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "sigmoid": + elif feature_preprocessing == "sigmoid": if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": @@ -457,7 +471,8 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "ple": + + elif feature_preprocessing == "ple": numeric_transformer_steps.append( ("minmax", MinMaxScaler(feature_range=(-1, 1))) ) @@ -465,7 +480,10 @@ def fit(self, X, y=None, embeddings=None): ("ple", PLE(n_bins=self.n_bins, task=self.task)) ) - elif self.numerical_preprocessing == "box-cox": + elif feature_preprocessing == "box-cox": + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(1e-03, 1))) + ) numeric_transformer_steps.append( ("check_positive", MinMaxScaler(feature_range=(1e-3, 1))) ) @@ -476,7 +494,7 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "yeo-johnson": + elif feature_preprocessing == "yeo-johnson": numeric_transformer_steps.append( ( "yeo-johnson", @@ -484,7 +502,7 @@ def fit(self, X, y=None, embeddings=None): ) ) - elif self.numerical_preprocessing == "none": + elif feature_preprocessing == "none": numeric_transformer_steps.append( ( "none", @@ -498,7 +516,10 @@ def fit(self, X, y=None, embeddings=None): if categorical_features: for feature in categorical_features: - if self.categorical_preprocessing == "int": + feature_preprocessing = self.feature_preprocessing.get( + feature, self.categorical_preprocessing + ) + if feature_preprocessing == "int": # Use ContinuousOrdinalEncoder for "int" categorical_transformer = Pipeline( [ @@ -506,7 +527,7 @@ def fit(self, X, y=None, embeddings=None): ("continuous_ordinal", ContinuousOrdinalEncoder()), ] ) - elif self.categorical_preprocessing == "one-hot": + elif feature_preprocessing == "one-hot": # Use OneHotEncoder for "one-hot" categorical_transformer = Pipeline( [ @@ -516,7 +537,7 @@ def fit(self, X, y=None, embeddings=None): ] ) - elif self.categorical_preprocessing == "none": + elif feature_preprocessing == "none": # Use OneHotEncoder for "one-hot" categorical_transformer = Pipeline( [ @@ -524,7 +545,7 @@ def fit(self, X, y=None, embeddings=None): ("none", NoTransformer()), ] ) - elif self.categorical_preprocessing == "pretrained": + elif feature_preprocessing == "pretrained": categorical_transformer = Pipeline( [ ("imputer", SimpleImputer(strategy="most_frequent")), @@ -533,7 +554,7 @@ def fit(self, X, y=None, embeddings=None): ) else: raise ValueError( - f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}" + f"Unknown categorical_preprocessing type: {feature_preprocessing}" ) # Append the transformer for the current categorical feature diff --git a/mambular/utils/distributions.py b/mambular/utils/distributions.py index 374d0101..75395ce9 100644 --- a/mambular/utils/distributions.py +++ b/mambular/utils/distributions.py @@ -116,7 +116,9 @@ def forward(self, predictions): """ transformed_params = [] for idx, param_name in enumerate(self.param_names): - transform_func = self.get_transform(getattr(self, f"{param_name}_transform", "none")) + transform_func = self.get_transform( + getattr(self, f"{param_name}_transform", "none") + ) transformed_params.append( transform_func(predictions[:, idx]).unsqueeze( # type: ignore 1 @@ -153,7 +155,9 @@ def __init__(self, name="Normal", mean_transform="none", var_transform="positive def compute_loss(self, predictions, y_true): mean = self.mean_transform(predictions[:, self.param_names.index("mean")]) - variance = self.variance_transform(predictions[:, self.param_names.index("variance")]) + variance = self.variance_transform( + predictions[:, self.param_names.index("variance")] + ) normal_dist = dist.Normal(mean, variance) @@ -167,10 +171,14 @@ def evaluate_nll(self, y_true, y_pred): y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) - mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]) + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")] + ) rmse = np.sqrt(mse_loss.detach().numpy()) mae = ( - torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]) + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")] + ) .detach() .numpy() ) @@ -228,7 +236,9 @@ def evaluate_nll(self, y_true, y_pred): .detach() .numpy() # type: ignore ) # type: ignore - poisson_deviance = 2 * torch.sum(y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)) + poisson_deviance = 2 * torch.sum( + y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate) + ) metrics["mse"] = mse_loss.detach().numpy() metrics["mae"] = mae @@ -367,7 +377,9 @@ class GammaDistribution(BaseDistribution): rate_transform (str or callable): Transformation for the rate parameter to ensure it remains positive. """ - def __init__(self, name="Gamma", shape_transform="positive", rate_transform="positive"): + def __init__( + self, name="Gamma", shape_transform="positive", rate_transform="positive" + ): param_names = ["shape", "rate"] super().__init__(name, param_names) @@ -434,10 +446,16 @@ def evaluate_nll(self, y_true, y_pred): y_true_tensor = torch.tensor(y_true, dtype=torch.float32) y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) - mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]) + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")] + ) rmse = np.sqrt(mse_loss.detach().numpy()) mae = ( - torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]).detach().numpy() + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")] + ) + .detach() + .numpy() ) metrics["mse"] = mse_loss.detach().numpy() @@ -478,7 +496,9 @@ def __init__( def compute_loss(self, predictions, y_true): # Apply transformations to ensure mean and dispersion parameters are positive mean = self.mean_transform(predictions[:, self.param_names.index("mean")]) - dispersion = self.dispersion_transform(predictions[:, self.param_names.index("dispersion")]) + dispersion = self.dispersion_transform( + predictions[:, self.param_names.index("dispersion")] + ) # Calculate the probability (p) and number of successes (r) from mean and dispersion # These calculations follow from the mean and variance of the negative binomial distribution @@ -574,3 +594,77 @@ def compute_loss(self, predictions, y_true): # Sum losses across quantiles and compute mean loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1)) return loss + + +class JohnsonSuDistribution(BaseDistribution): + """ + Represents a Johnson's SU distribution with parameters for skewness, shape, location, and scale. + + Parameters + ---------- + name (str): The name of the distribution. Defaults to "JohnsonSu". + skew_transform (str or callable): The transformation for the skewness parameter. Defaults to "none". + shape_transform (str or callable): The transformation for the shape parameter. Defaults to "positive". + loc_transform (str or callable): The transformation for the location parameter. Defaults to "none". + scale_transform (str or callable): The transformation for the scale parameter. Defaults to "positive". + """ + + def __init__( + self, + name="JohnsonSu", + skew_transform="none", + shape_transform="positive", + loc_transform="none", + scale_transform="positive", + ): + param_names = ["skew", "shape", "location", "scale"] + super().__init__(name, param_names) + + self.skew_transform = self.get_transform(skew_transform) + self.shape_transform = self.get_transform(shape_transform) + self.loc_transform = self.get_transform(loc_transform) + self.scale_transform = self.get_transform(scale_transform) + + def log_prob(self, x, skew, shape, loc, scale): + """ + Compute the log probability density of the Johnson's SU distribution. + """ + z = skew + shape * torch.asinh((x - loc) / scale) + log_pdf = ( + torch.log(shape / (scale * np.sqrt(2 * np.pi))) + - 0.5 * z**2 + - 0.5 * torch.log(1 + ((x - loc) / scale) ** 2) + ) + return log_pdf + + def compute_loss(self, predictions, y_true): + skew = self.skew_transform(predictions[:, self.param_names.index("skew")]) + shape = self.shape_transform(predictions[:, self.param_names.index("shape")]) + loc = self.loc_transform(predictions[:, self.param_names.index("location")]) + scale = self.scale_transform(predictions[:, self.param_names.index("scale")]) + + log_probs = self.log_prob(y_true, skew, shape, loc, scale) + nll = -log_probs.mean() + return nll + + def evaluate_nll(self, y_true, y_pred): + metrics = super().evaluate_nll(y_true, y_pred) + + y_true_tensor = torch.tensor(y_true, dtype=torch.float32) + y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32) + + mse_loss = torch.nn.functional.mse_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("location")] + ) + rmse = np.sqrt(mse_loss.detach().numpy()) + mae = ( + torch.nn.functional.l1_loss( + y_true_tensor, y_pred_tensor[:, self.param_names.index("location")] + ) + .detach() + .numpy() + ) + + metrics.update({"mse": mse_loss.detach().numpy(), "mae": mae, "rmse": rmse}) + + return metrics