diff --git a/.github/workflows/pr-tests.yml b/.github/workflows/pr-tests.yml
new file mode 100644
index 00000000..0e878fe3
--- /dev/null
+++ b/.github/workflows/pr-tests.yml
@@ -0,0 +1,50 @@
+name: PR Unit Tests
+
+on:
+ pull_request:
+ branches:
+ - develop
+ - master # Add any other branches where you want to enforce tests
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout Repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10" # Change this to match your setup
+
+ - name: Install Poetry
+ run: |
+ curl -sSL https://install.python-poetry.org | python3 -
+ echo "$HOME/.local/bin" >> $GITHUB_PATH
+ export PATH="$HOME/.local/bin:$PATH"
+
+ - name: Install Dependencies
+ run: |
+ python -m pip install --upgrade pip
+ poetry install
+ pip install pytest
+
+ - name: Install Package Locally
+ run: |
+ poetry build
+ pip install dist/*.whl # Install the built package to fix "No module named 'mambular'"
+
+ - name: Run Unit Tests
+ env:
+ PYTHONPATH: ${{ github.workspace }} # Ensure the package is discoverable
+ run: pytest tests/
+
+ - name: Verify Tests Passed
+ if: ${{ success() }}
+ run: echo "All tests passed! Pull request is allowed."
+
+ - name: Fail PR on Test Failure
+ if: ${{ failure() }}
+ run: exit 1 # This ensures the PR cannot be merged if tests fail
diff --git a/README.md b/README.md
index 29f57feb..ecf45eeb 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,17 @@
Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+
⚡ What's New ⚡
+
+ - Individual preprocessing: preprocess each feature differently, use pre-trained models for categorical encoding
+ - Extract latent representations of tables
+ - Use embeddings as inputs
+ - Define custom training metrics
+
+
+
+
+
Table of Contents
- [🏃 Quickstart](#-quickstart)
@@ -30,7 +41,6 @@ Mambular is a Python library for tabular deep learning. It includes models that
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
-- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)
@@ -103,6 +113,7 @@ pip install mamba-ssm
Preprocessing
Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
+Specify a default method, or a dictionary defining individual preprocessing methods for each feature.
Data Type Detection and Transformation
@@ -116,6 +127,7 @@ Mambular simplifies data preprocessing with a range of tools designed for easy t
- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
+- **Pre-trained Encoding**: Use sentence transformers to encode categorical features.
@@ -147,6 +159,28 @@ preds = model.predict(X)
preds = model.predict_proba(X)
```
+Get latent representations for each feature:
+```python
+# simple encoding
+model.encode(X)
+```
+
+Use unstructured data:
+```python
+# load pretrained models
+image_model = ...
+nlp_model = ...
+
+# create embeddings
+img_embs = image_model.encode(images)
+txt_embs = nlp_model.encode(texts)
+
+# fit model on tabular data and unstructured data
+model.fit(X_train, y_train, embeddings=[img_embs, txt_embs])
+```
+
+
+
Hyperparameter Optimization
Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimizatino from sklearn.
@@ -222,9 +256,11 @@ MambularLSS allows you to model the full distribution of a response variable, no
- **studentt**: For data with heavier tails, useful with small samples.
- **negativebinom**: For over-dispersed count data.
- **inversegamma**: Often used as a prior in Bayesian inference.
+- **johnsonsu**: Four parameter distribution defining location, scale, kurtosis and skewness.
- **categorical**: For data with more than two categories.
- **Quantile**: For quantile regression using the pinball loss.
+
These distribution classes make MambularLSS versatile in modeling various data types and distributions.
@@ -269,13 +305,16 @@ Here's how you can implement a custom model with Mambular:
```python
from dataclasses import dataclass
+ from mambular.configs import BaseConfig
@dataclass
- class MyConfig:
+ class MyConfig(BaseConfig):
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
- lr_factor: float = 0.1
+ n_layers: int = 4
+ pooling_method:str = "avg
+
```
2. **Second, define your model:**
@@ -290,22 +329,32 @@ Here's how you can implement a custom model with Mambular:
class MyCustomModel(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple,
num_classes: int = 1,
config=None,
**kwargs,
):
- super().__init__(**kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ super().__init__(**kwargs)
+ self.save_hyperparameters(ignore=["feature_information"])
+ self.returns_ensemble = False
+
+ # embedding layer
+ self.embedding_layer = EmbeddingLayer(
+ *feature_information,
+ config=config,
+ )
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
self.linear = nn.Linear(input_dim, num_classes)
- def forward(self, num_features, cat_features):
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ def forward(self, *data) -> torch.Tensor:
+ x = self.embedding_layer(*data)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+
# Pass through linear layer
output = self.linear(x)
@@ -329,60 +378,11 @@ Here's how you can implement a custom model with Mambular:
```python
regressor = MyRegressor(numerical_preprocessing="ple")
regressor.fit(X_train, y_train, max_epochs=50)
+
+ regressor.evaluate(X_test, y_test)
```
-# Custom Training
-If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `mambular.base_models`.
-Just be careful that all basemodels expect lists of features as inputs. More precisely as list for numerical features and a list for categorical features. A custom training loop, with random data could look like this.
-```python
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from mambular.base_models import Mambular
-from mambular.configs import DefaultMambularConfig
-
-# Dummy data and configuration
-cat_feature_info = {
- "cat1": {
- "preprocessing": "imputer -> continuous_ordinal",
- "dimension": 1,
- "categories": 4,
- }
-} # Example categorical feature information
-num_feature_info = {
- "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
-} # Example numerical feature information
-num_classes = 1
-config = DefaultMambularConfig() # Use the desired configuration
-
-# Initialize model, loss function, and optimizer
-model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
-criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task
-optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-# Example training loop
-for epoch in range(10): # Number of epochs
- model.train()
- optimizer.zero_grad()
-
- # Dummy Data
- num_features = [torch.randn(32, 1) for _ in num_feature_info]
- cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info]
- labels = torch.randn(32, num_classes)
-
- # Forward pass
- outputs = model(num_features, cat_features)
- loss = criterion(outputs, labels)
-
- # Backward pass and optimization
- loss.backward()
- optimizer.step()
-
- # Print loss for monitoring
- print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
-
-```
# 🏷️ Citation
diff --git a/mambular/__version__.py b/mambular/__version__.py
index a7623c91..629965c4 100644
--- a/mambular/__version__.py
+++ b/mambular/__version__.py
@@ -16,4 +16,4 @@
#
# The following line *must* be the last in the module, exactly as formatted:
-__version__ = "1.1.0"
+__version__ = "1.2.0"
diff --git a/mambular/arch_utils/layer_utils/attention_utils.py b/mambular/arch_utils/layer_utils/attention_utils.py
index bdfed319..1b50d720 100644
--- a/mambular/arch_utils/layer_utils/attention_utils.py
+++ b/mambular/arch_utils/layer_utils/attention_utils.py
@@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-from rotary_embedding_torch import RotaryEmbedding
class GEGLU(nn.Module):
@@ -25,7 +24,7 @@ def FeedForward(dim, mult=4, dropout=0.0):
class Attention(nn.Module):
- def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@@ -34,18 +33,13 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
- self.rotary = rotary
dim = np.int64(dim / 2)
- self.rotary_embedding = RotaryEmbedding(dim=dim)
def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) # type: ignore
- if self.rotary:
- q = self.rotary_embedding.rotate_queries_or_keys(q)
- k = self.rotary_embedding.rotate_queries_or_keys(k)
q = q * self.scale
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
@@ -61,7 +55,7 @@ def forward(self, x):
class Transformer(nn.Module):
- def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False):
+ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
super().__init__()
self.layers = nn.ModuleList([])
@@ -74,7 +68,6 @@ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
- rotary=rotary,
),
FeedForward(dim, dropout=ff_dropout),
]
diff --git a/mambular/arch_utils/layer_utils/embedding_layer.py b/mambular/arch_utils/layer_utils/embedding_layer.py
index 76afd1ec..476d7bc9 100644
--- a/mambular/arch_utils/layer_utils/embedding_layer.py
+++ b/mambular/arch_utils/layer_utils/embedding_layer.py
@@ -6,7 +6,7 @@
class EmbeddingLayer(nn.Module):
- def __init__(self, num_feature_info, cat_feature_info, config):
+ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config):
"""Embedding layer that handles numerical and categorical embeddings.
Parameters
@@ -22,8 +22,13 @@ def __init__(self, num_feature_info, cat_feature_info, config):
super().__init__()
self.d_model = getattr(config, "d_model", 128)
- self.embedding_activation = getattr(config, "embedding_activation", nn.Identity())
- self.layer_norm_after_embedding = getattr(config, "layer_norm_after_embedding", False)
+ self.embedding_activation = getattr(
+ config, "embedding_activation", nn.Identity()
+ )
+ self.layer_norm_after_embedding = getattr(
+ config, "layer_norm_after_embedding", False
+ )
+ self.embedding_projection = getattr(config, "embedding_projection", True)
self.use_cls = getattr(config, "use_cls", False)
self.cls_position = getattr(config, "cls_position", 0)
self.embedding_dropout = (
@@ -71,27 +76,47 @@ def __init__(self, num_feature_info, cat_feature_info, config):
# for splines and other embeddings
# splines followed by linear if n_knots actual knots is less than the defined knots
else:
- raise ValueError("Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'.")
+ raise ValueError(
+ "Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'."
+ )
self.cat_embeddings = nn.ModuleList(
[
- nn.Sequential(
- nn.Embedding(feature_info["categories"] + 1, self.d_model),
- self.embedding_activation,
- )
- if feature_info["dimension"] == 1
- else nn.Sequential(
- nn.Linear(
- feature_info["dimension"],
- self.d_model,
- bias=self.embedding_bias,
- ),
- self.embedding_activation,
+ (
+ nn.Sequential(
+ nn.Embedding(feature_info["categories"] + 1, self.d_model),
+ self.embedding_activation,
+ )
+ if feature_info["dimension"] == 1
+ else nn.Sequential(
+ nn.Linear(
+ feature_info["dimension"],
+ self.d_model,
+ bias=self.embedding_bias,
+ ),
+ self.embedding_activation,
+ )
)
for feature_name, feature_info in cat_feature_info.items()
]
)
+ if len(emb_feature_info) >= 1:
+ if self.embedding_projection:
+ self.emb_embeddings = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Linear(
+ feature_info["dimension"],
+ self.d_model,
+ bias=self.embedding_bias,
+ ),
+ self.embedding_activation,
+ )
+ for feature_name, feature_info in emb_feature_info.items()
+ ]
+ )
+
# Class token if required
if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_model))
@@ -100,15 +125,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
if self.layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(self.d_model)
- def forward(self, num_features=None, cat_features=None):
+ def forward(self, num_features, cat_features, emb_features):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor, optional
- Tensor containing the numerical features.
- cat_features : Tensor, optional
- Tensor containing the categorical features.
+ data: tuple of lists of tensors
Returns
-------
@@ -120,13 +142,12 @@ def forward(self, num_features=None, cat_features=None):
ValueError
If no features are provided to the model.
"""
+ num_embeddings, cat_embeddings, emb_embeddings = None, None, None
# Class token initialization
if self.use_cls:
batch_size = (
- cat_features[0].size( # type: ignore
- 0
- )
+ cat_features[0].size(0) # type: ignore
if cat_features != []
else num_features[0].size(0) # type: ignore
) # type: ignore
@@ -134,13 +155,15 @@ def forward(self, num_features=None, cat_features=None):
# Process categorical embeddings
if self.cat_embeddings and cat_features is not None:
- cat_embeddings = [emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)]
+ cat_embeddings = [
+ emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
+ for i, emb in enumerate(self.cat_embeddings)
+ ]
+
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
cat_embeddings = self.embedding_norm(cat_embeddings)
- else:
- cat_embeddings = None
# Process numerical embeddings based on embedding_type
if self.embedding_type == "plr":
@@ -153,8 +176,6 @@ def forward(self, num_features=None, cat_features=None):
num_embeddings = self.num_embeddings(num_features)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
- else:
- num_embeddings = None
else:
# For linear and ndt embeddings, handle each feature individually
if self.num_embeddings and num_features is not None:
@@ -162,16 +183,26 @@ def forward(self, num_features=None, cat_features=None):
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
+
+ if emb_features != []:
+ if self.embedding_projection:
+ emb_embeddings = [
+ emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
+ ]
+ emb_embeddings = torch.stack(emb_embeddings, dim=1)
else:
- num_embeddings = None
-
- # Combine categorical and numerical embeddings
- if cat_embeddings is not None and num_embeddings is not None:
- x = torch.cat([cat_embeddings, num_embeddings], dim=1)
- elif cat_embeddings is not None:
- x = cat_embeddings
- elif num_embeddings is not None:
- x = num_embeddings
+
+ emb_embeddings = torch.stack(emb_features, dim=1)
+ if self.layer_norm_after_embedding:
+ emb_embeddings = self.embedding_norm(emb_embeddings)
+
+ embeddings = [
+ e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None
+ ]
+
+ if embeddings:
+ x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
+
else:
raise ValueError("No features provided to the model.")
@@ -182,7 +213,9 @@ def forward(self, num_features=None, cat_features=None):
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1) # type: ignore
else:
- raise ValueError("Invalid cls_position value. It should be either 0 or 1.")
+ raise ValueError(
+ "Invalid cls_position value. It should be either 0 or 1."
+ )
# Apply dropout to embeddings if specified in config
if self.embedding_dropout is not None:
diff --git a/mambular/base_models/basemodel.py b/mambular/base_models/basemodel.py
index a3b18217..49b56cd1 100644
--- a/mambular/base_models/basemodel.py
+++ b/mambular/base_models/basemodel.py
@@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]):
List of keys to ignore while saving hyperparameters, by default [].
"""
# Filter the config and extra hparams for ignored keys
- config_hparams = {k: v for k, v in vars(self.config).items() if k not in ignore} if self.config else {}
+ config_hparams = (
+ {k: v for k, v in vars(self.config).items() if k not in ignore}
+ if self.config
+ else {}
+ )
extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore}
config_hparams.update(extra_hparams)
@@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs):
"""Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method."""
if self.hparams.pooling_method == "learned_flatten":
# Flattening + Linear layer
- self.learned_flatten_pooling = nn.Linear(n_inputs * config.dim_feedforward, config.dim_feedforward)
+ self.learned_flatten_pooling = nn.Linear(
+ n_inputs * config.dim_feedforward, config.dim_feedforward
+ )
elif self.hparams.pooling_method == "attention":
# Attention-based pooling with learnable attention weights
@@ -216,3 +222,29 @@ def pool_sequence(self, out):
return out
else:
raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")
+
+ def encode(self, data):
+ if not hasattr(self, "embedding_layer"):
+ raise ValueError("The model does not have an embedding layer")
+
+ # Check if at least one of the contextualized embedding methods exists
+ valid_layers = ["mamba", "rnn", "lstm", "encoder"]
+ available_layer = next(
+ (attr for attr in valid_layers if hasattr(self, attr)), None
+ )
+
+ if not available_layer:
+ raise ValueError("The model does not generate contextualized embeddings")
+
+ # Get the actual layer and call it
+ x = self.embedding_layer(*data)
+
+ if getattr(self.hparams, "shuffle_embeddings", False):
+ x = x[:, self.perm, :]
+
+ layer = getattr(self, available_layer)
+ if available_layer == "rnn":
+ embeddings, _ = layer(x)
+ else:
+ embeddings = layer(x)
+ return embeddings
diff --git a/mambular/base_models/ft_transformer.py b/mambular/base_models/ft_transformer.py
index 56e546d8..f0c7fb84 100644
--- a/mambular/base_models/ft_transformer.py
+++ b/mambular/base_models/ft_transformer.py
@@ -6,6 +6,7 @@
from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
from ..configs.fttransformer_config import DefaultFTTransformerConfig
from .basemodel import BaseModel
+import numpy as np
class FTTransformer(BaseModel):
@@ -52,22 +53,18 @@ class FTTransformer(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultFTTransformerConfig = DefaultFTTransformerConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
# embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
@@ -87,25 +84,23 @@ def __init__(
)
# pooling
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+ n_inputs = np.sum([len(info) for info in feature_information])
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
Tensor
The output predictions of the model.
"""
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
x = self.encoder(x)
diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py
index 1e8dd340..f1c836ca 100644
--- a/mambular/base_models/lightning_wrapper.py
+++ b/mambular/base_models/lightning_wrapper.py
@@ -3,7 +3,6 @@
import lightning as pl
import torch
import torch.nn as nn
-import torchmetrics
class TaskModel(pl.LightningModule):
@@ -31,8 +30,7 @@ def __init__(
self,
model_class: type[nn.Module],
config,
- cat_feature_info,
- num_feature_info,
+ feature_information,
num_classes=1,
lss=False,
family=None,
@@ -41,6 +39,8 @@ def __init__(
pruning_epoch=5,
optimizer_type: str = "Adam",
optimizer_args: dict | None = None,
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
**kwargs,
):
super().__init__()
@@ -53,6 +53,10 @@ def __init__(
self.pruning_epoch = pruning_epoch
self.val_losses = []
+ # Store custom metrics
+ self.train_metrics = train_metrics or {}
+ self.val_metrics = val_metrics or {}
+
self.optimizer_params = {
k.replace("optimizer_", ""): v
for k, v in optimizer_args.items() # type: ignore
@@ -65,16 +69,10 @@ def __init__(
if num_classes == 2:
if not self.loss_fct:
self.loss_fct = nn.BCEWithLogitsLoss()
- self.acc = torchmetrics.Accuracy(task="binary")
- self.auroc = torchmetrics.AUROC(task="binary")
- self.precision = torchmetrics.Precision(task="binary")
self.num_classes = 1
elif num_classes > 2:
if not self.loss_fct:
self.loss_fct = nn.CrossEntropyLoss()
- self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
- self.auroc = torchmetrics.AUROC(task="multiclass", num_classes=num_classes)
- self.precision = torchmetrics.Precision(task="multiclass", num_classes=num_classes)
else:
self.loss_fct = nn.MSELoss()
@@ -92,13 +90,12 @@ def __init__(
self.base_model = model_class(
config=config,
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ feature_information=feature_information,
num_classes=output_dim,
**kwargs,
)
- def forward(self, num_features, cat_features):
+ def forward(self, num_features, cat_features, embeddings):
"""Forward pass through the model.
Parameters
@@ -114,7 +111,7 @@ def forward(self, num_features, cat_features):
Model output.
"""
- return self.base_model.forward(num_features, cat_features)
+ return self.base_model.forward(num_features, cat_features, embeddings)
def compute_loss(self, predictions, y_true):
"""Compute the loss for the given predictions and true labels.
@@ -146,7 +143,10 @@ def compute_loss(self, predictions, y_true):
)
if getattr(self.base_model, "returns_ensemble", False): # Ensemble case
- if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3:
+ if (
+ self.loss_fct.__class__.__name__ == "CrossEntropyLoss"
+ and predictions.dim() == 3
+ ):
# Classification case with ensemble: predictions (N, E, k), y_true (N,)
N, E, k = predictions.shape
loss = 0.0
@@ -187,31 +187,32 @@ def training_step(self, batch, batch_idx): # type: ignore
Tensor
Training loss.
"""
- cat_features, num_features, labels = batch
+ data, labels = batch
# Check if the model has a `penalty_forward` method
if hasattr(self.base_model, "penalty_forward"):
- preds, penalty = self.base_model.penalty_forward(num_features=num_features, cat_features=cat_features)
+ preds, penalty = self.base_model.penalty_forward(*data)
loss = self.compute_loss(preds, labels) + penalty
else:
- preds = self(num_features=num_features, cat_features=cat_features)
+ preds = self(*data)
loss = self.compute_loss(preds, labels)
# Log the training loss
- self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-
- # Log additional metrics
- if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
- if self.num_classes > 1:
- acc = self.acc(preds, labels)
- self.log(
- "train_acc",
- acc,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
+ self.log(
+ "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+ )
+
+ # Log custom training metrics
+ for metric_name, metric_fn in self.train_metrics.items():
+ metric_value = metric_fn(preds, labels)
+ self.log(
+ f"train_{metric_name}",
+ metric_value,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ )
return loss
@@ -231,8 +232,8 @@ def validation_step(self, batch, batch_idx): # type: ignore
Validation loss.
"""
- cat_features, num_features, labels = batch
- preds = self(num_features=num_features, cat_features=cat_features)
+ data, labels = batch
+ preds = self(*data)
val_loss = self.compute_loss(preds, labels)
self.log(
@@ -244,18 +245,17 @@ def validation_step(self, batch, batch_idx): # type: ignore
logger=True,
)
- # Log additional metrics
- if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
- if self.num_classes > 1:
- acc = self.acc(preds, labels)
- self.log(
- "val_acc",
- acc,
- on_step=False,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
+ # Log custom validation metrics
+ for metric_name, metric_fn in self.val_metrics.items():
+ metric_value = metric_fn(preds, labels)
+ self.log(
+ f"val_{metric_name}",
+ metric_value,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=True,
+ logger=True,
+ )
return val_loss
@@ -274,8 +274,8 @@ def test_step(self, batch, batch_idx): # type: ignore
Tensor
Test loss.
"""
- cat_features, num_features, labels = batch
- preds = self(num_features=num_features, cat_features=cat_features)
+ data, labels = batch
+ preds = self(*data)
test_loss = self.compute_loss(preds, labels)
self.log(
@@ -287,21 +287,28 @@ def test_step(self, batch, batch_idx): # type: ignore
logger=True,
)
- # Log additional metrics
- if not self.lss and not hasattr(self.base_model, "returns_ensemble"):
- if self.num_classes > 1:
- acc = self.acc(preds, labels)
- self.log(
- "test_acc",
- acc,
- on_step=False,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
-
return test_loss
+ def predict_step(self, batch, batch_idx):
+ """Predict step for a single batch.
+
+ Parameters
+ ----------
+ batch : tuple
+ Batch of data containing numerical features, categorical features, and labels.
+ batch_idx : int
+ Index of the batch.
+
+ Returns
+ -------
+ Tensor
+ Predictions.
+ """
+
+ preds = self(*batch)
+
+ return preds
+
def on_validation_epoch_end(self):
"""Callback executed at the end of each validation epoch.
@@ -341,8 +348,13 @@ def on_validation_epoch_end(self):
# Apply pruning logic if needed
if self.current_epoch >= self.pruning_epoch:
- if self.early_pruning_threshold is not None and val_loss_value > self.early_pruning_threshold:
- print(f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}")
+ if (
+ self.early_pruning_threshold is not None
+ and val_loss_value > self.early_pruning_threshold
+ ):
+ print(
+ f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}"
+ )
self.trainer.should_stop = True # Stop training early
def epoch_val_loss_at(self, epoch):
diff --git a/mambular/base_models/mambatab.py b/mambular/base_models/mambatab.py
index fa1e231b..4314bab5 100644
--- a/mambular/base_models/mambatab.py
+++ b/mambular/base_models/mambatab.py
@@ -5,6 +5,7 @@
from ..arch_utils.mamba_utils.mamba_arch import Mamba
from ..arch_utils.mamba_utils.mamba_original import MambaOriginal
from ..arch_utils.mlp_utils import MLPhead
+from ..utils.get_feature_dimensions import get_feature_dimensions
from ..configs.mambatab_config import DefaultMambaTabConfig
from .basemodel import BaseModel
@@ -56,29 +57,22 @@ class MambaTab(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultMambaTabConfig = DefaultMambaTabConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
- input_dim = 0
- for feature_name, input_shape in num_feature_info.items():
- input_dim += 1
- for feature_name, input_shape in cat_feature_info.items():
- input_dim += 1
+ input_dim = get_feature_dimensions(*feature_information)
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
self.returns_ensemble = False
self.initial_layer = nn.Linear(input_dim, config.d_model)
self.norm_f = LayerNorm(config.d_model)
- self.embedding_activation = self.hparams.num_embedding_activation
+ self.embedding_activation = self.hparams.embedding_activation
self.axis = config.axis
@@ -93,9 +87,20 @@ def __init__(
else:
self.mamba = MambaOriginal(config)
- def forward(self, num_features, cat_features):
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ def forward(self, *data):
+ """Forward pass of the Mambatab model
+
+ Parameters
+ ----------
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
+ """
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
x = self.initial_layer(x)
if self.axis == 1:
diff --git a/mambular/base_models/mambattn.py b/mambular/base_models/mambattn.py
index f393154b..fd86eee8 100644
--- a/mambular/base_models/mambattn.py
+++ b/mambular/base_models/mambattn.py
@@ -1,5 +1,5 @@
import torch
-
+import numpy as np
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mamba_utils.mambattn_arch import MambAttn
@@ -52,14 +52,15 @@ class MambAttention(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
+
+ self.returns_ensemble = False
try:
self.pooling_method = self.hparams.pooling_method
@@ -76,8 +77,7 @@ def __init__(
# embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
@@ -101,25 +101,23 @@ def __init__(
self.perm = torch.randperm(self.embedding_layer.seq_len)
# pooling
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+ n_inputs = np.sum([len(info) for info in feature_information])
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
- Tensor
- The output predictions of the model.
+ torch.Tensor
+ Output tensor.
"""
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
if self.shuffle_embeddings:
x = x[:, self.perm, :]
diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py
index ee73b3d6..f24df962 100644
--- a/mambular/base_models/mambular.py
+++ b/mambular/base_models/mambular.py
@@ -6,6 +6,7 @@
from ..arch_utils.mlp_utils import MLPhead
from ..configs.mambular_config import DefaultMambularConfig
from .basemodel import BaseModel
+import numpy as np
class Mambular(BaseModel):
@@ -52,21 +53,19 @@ class Mambular(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (cat_feature_info, num_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultMambularConfig = DefaultMambularConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
# embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
@@ -85,25 +84,23 @@ def __init__(
self.perm = torch.randperm(self.embedding_layer.seq_len)
# pooling
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+ n_inputs = np.sum([len(info) for info in feature_information])
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
Tensor
The output predictions of the model.
"""
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
if self.hparams.shuffle_embeddings:
x = x[:, self.perm, :]
diff --git a/mambular/base_models/mlp.py b/mambular/base_models/mlp.py
index 0c9251fe..94194d82 100644
--- a/mambular/base_models/mlp.py
+++ b/mambular/base_models/mlp.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
-
+import numpy as np
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..configs.mlp_config import DefaultMLPConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
@@ -57,31 +57,29 @@ class MLP(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes: int = 1,
config: DefaultMLPConfig = DefaultMLPConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
# Initialize layers
self.layers = nn.ModuleList()
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
-
if self.hparams.use_embeddings:
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
- input_dim = len(num_feature_info) * self.hparams.d_model + len(cat_feature_info) * self.hparams.d_model
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
+ else:
+ input_dim = get_feature_dimensions(*feature_information)
# Input layer
self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0]))
@@ -97,7 +95,9 @@ def __init__(
# Hidden layers
for i in range(1, len(self.hparams.layer_sizes)):
- self.layers.append(nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i]))
+ self.layers.append(
+ nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i])
+ )
if self.hparams.batch_norm:
self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[i]))
if self.hparams.layer_norm:
@@ -112,26 +112,26 @@ def __init__(
# Output layer
self.layers.append(nn.Linear(self.hparams.layer_sizes[-1], num_classes))
- def forward(self, num_features, cat_features) -> torch.Tensor:
+ def forward(self, *data) -> torch.Tensor:
"""Forward pass of the MLP model.
Parameters
----------
- x : torch.Tensor
- Input tensor.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
torch.Tensor
Output tensor.
"""
+
if self.hparams.use_embeddings:
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
B, S, D = x.shape
x = x.reshape(B, S * D)
else:
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
for i in range(len(self.layers) - 1):
if isinstance(self.layers[i], nn.Linear):
diff --git a/mambular/base_models/ndtf.py b/mambular/base_models/ndtf.py
index e279dc05..c7509932 100644
--- a/mambular/base_models/ndtf.py
+++ b/mambular/base_models/ndtf.py
@@ -54,20 +54,17 @@ class NDTF(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes: int = 1,
config: DefaultNDTFConfig = DefaultNDTFConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
self.returns_ensemble = False
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+ input_dim = get_feature_dimensions(*feature_information)
self.input_dimensions = [input_dim]
@@ -78,10 +75,13 @@ def __init__(
[
NeuralDecisionTree(
input_dim=self.input_dimensions[idx],
- depth=np.random.randint(self.hparams.min_depth, self.hparams.max_depth),
+ depth=np.random.randint(
+ self.hparams.min_depth, self.hparams.max_depth
+ ),
output_dim=num_classes,
lamda=self.hparams.lamda,
- temperature=self.hparams.temperature + np.abs(np.random.normal(0, 0.1)),
+ temperature=self.hparams.temperature
+ + np.abs(np.random.normal(0, 0.1)),
node_sampling=self.hparams.node_sampling,
)
for idx in range(self.hparams.n_ensembles)
@@ -103,21 +103,20 @@ def __init__(
requires_grad=True,
)
- def forward(self, num_features, cat_features) -> torch.Tensor:
+ def forward(self, *data) -> torch.Tensor:
"""Forward pass of the NDTF model.
Parameters
----------
- x : torch.Tensor
- Input tensor.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
torch.Tensor
Output tensor.
"""
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
x = self.conv_layer(x.unsqueeze(2))
x = x.transpose(1, 2).squeeze(-1)
@@ -131,21 +130,20 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
return preds @ self.tree_weights
- def penalty_forward(self, num_features, cat_features) -> torch.Tensor:
+ def penalty_forward(self, *data) -> torch.Tensor:
"""Forward pass of the NDTF model.
Parameters
----------
- x : torch.Tensor
- Input tensor.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
torch.Tensor
Output tensor.
"""
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
x = self.conv_layer(x.unsqueeze(2))
x = x.transpose(1, 2).squeeze(-1)
diff --git a/mambular/base_models/node.py b/mambular/base_models/node.py
index 82cbf918..70104600 100644
--- a/mambular/base_models/node.py
+++ b/mambular/base_models/node.py
@@ -6,6 +6,7 @@
from ..configs.node_config import DefaultNODEConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .basemodel import BaseModel
+import numpy as np
class NODE(BaseModel):
@@ -52,8 +53,7 @@ class NODE(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes: int = 1,
config: DefaultNODEConfig = DefaultNODEConfig(), # noqa: B008
**kwargs,
@@ -63,16 +63,17 @@ def __init__(
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
-
if self.hparams.use_embeddings:
- input_dim = len(num_feature_info) * self.hparams.d_model + len(cat_feature_info) * self.hparams.d_model
-
- self.embedding_layer = EmbeddingLayer(config) # type: ignore
+ self.embedding_layer = EmbeddingLayer(
+ *feature_information,
+ config=config,
+ )
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
else:
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+ input_dim = get_feature_dimensions(*feature_information)
self.d_out = num_classes
self.block = DenseBlock(
@@ -90,7 +91,7 @@ def __init__(
output_dim=num_classes,
)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Forward pass through the NODE model.
Parameters
@@ -106,12 +107,11 @@ def forward(self, num_features, cat_features):
Model output of shape [batch_size, num_classes].
"""
if self.hparams.use_embeddings:
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
B, S, D = x.shape
x = x.reshape(B, S * D)
else:
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
x = self.block(x).squeeze(-1)
x = self.tabular_head(x)
diff --git a/mambular/base_models/resnet.py b/mambular/base_models/resnet.py
index a2e487e3..2a383bcf 100644
--- a/mambular/base_models/resnet.py
+++ b/mambular/base_models/resnet.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
-
+import numpy as np
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.resnet_utils import ResidualBlock
from ..configs.resnet_config import DefaultResNetConfig
@@ -56,30 +56,26 @@ class ResNet(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes: int = 1,
config: DefaultResNetConfig = DefaultResNetConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
if self.hparams.use_embeddings:
- input_dim = len(num_feature_info) * self.hparams.d_model + len(cat_feature_info) * self.hparams.d_model
- # embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
-
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
else:
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+ input_dim = get_feature_dimensions(*feature_information)
self.initial_layer = nn.Linear(input_dim, self.hparams.layer_sizes[0])
@@ -102,14 +98,25 @@ def __init__(
self.output_layer = nn.Linear(self.hparams.layer_sizes[-1], num_classes)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
+ """Forward pass of the ResNet model.
+
+ Parameters
+ ----------
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
+ """
if self.hparams.use_embeddings:
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
B, S, D = x.shape
x = x.reshape(B, S * D)
else:
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
x = self.initial_layer(x)
for block in self.blocks:
diff --git a/mambular/base_models/saint.py b/mambular/base_models/saint.py
index 38847fae..e6cfe19d 100644
--- a/mambular/base_models/saint.py
+++ b/mambular/base_models/saint.py
@@ -4,6 +4,7 @@
from ..arch_utils.transformer_utils import RowColTransformer
from ..configs.saint_config import DefaultSAINTConfig
from .basemodel import BaseModel
+import numpy as np
class SAINT(BaseModel):
@@ -50,25 +51,22 @@ class SAINT(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultSAINTConfig = DefaultSAINTConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+
+ n_inputs = np.sum([len(info) for info in feature_information])
if getattr(config, "use_cls", True):
n_inputs += 1
# embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
@@ -89,22 +87,20 @@ def __init__(
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
- Tensor
- The output predictions of the model.
+ torch.Tensor
+ Output tensor.
"""
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
x = self.encoder(x)
diff --git a/mambular/base_models/tabm.py b/mambular/base_models/tabm.py
index 7683b4be..ef6e6050 100644
--- a/mambular/base_models/tabm.py
+++ b/mambular/base_models/tabm.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
-
+import numpy as np
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.batch_ensemble_layer import LinearBatchEnsembleLayer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
@@ -11,10 +11,10 @@
class TabM(BaseModel):
+
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes: int = 1,
config: DefaultTabMConfig = DefaultTabMConfig(), # noqa: B008
**kwargs,
@@ -23,7 +23,7 @@ def __init__(
super().__init__(config=config, **kwargs)
# Save hparams including config attributes
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
if not self.hparams.average_ensembles:
self.returns_ensemble = True # Directly set ensemble flag
else:
@@ -35,18 +35,19 @@ def __init__(
# Conditionally initialize EmbeddingLayer based on self.hparams
if self.hparams.use_embeddings:
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
if self.hparams.average_embeddings:
input_dim = self.hparams.d_model
else:
- input_dim = (len(num_feature_info) + len(cat_feature_info)) * config.d_model
+ input_dim = np.sum(
+ [len(info) * self.hparams.d_model for info in feature_information]
+ )
else:
- input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
+ input_dim = get_feature_dimensions(*feature_information)
# Input layer with batch ensembling
self.layers.append(
@@ -71,7 +72,11 @@ def __init__(
if self.hparams.use_glu:
self.layers.append(nn.GLU())
else:
- self.layers.append(self.hparams.activation if hasattr(self.hparams, "activation") else nn.SELU())
+ self.layers.append(
+ self.hparams.activation
+ if hasattr(self.hparams, "activation")
+ else nn.SELU()
+ )
if self.hparams.dropout > 0.0:
self.layers.append(nn.Dropout(self.hparams.dropout))
@@ -105,7 +110,11 @@ def __init__(
if self.hparams.use_glu:
self.layers.append(nn.GLU())
else:
- self.layers.append(self.hparams.activation if hasattr(self.hparams, "activation") else nn.SELU())
+ self.layers.append(
+ self.hparams.activation
+ if hasattr(self.hparams, "activation")
+ else nn.SELU()
+ )
if self.hparams.dropout > 0.0:
self.layers.append(nn.Dropout(self.hparams.dropout))
@@ -118,15 +127,13 @@ def __init__(
num_classes,
)
- def forward(self, num_features, cat_features) -> torch.Tensor:
+ def forward(self, *data) -> torch.Tensor:
"""Forward pass of the TabM model with batch ensembling.
Parameters
----------
- num_features : torch.Tensor
- Numerical features tensor.
- cat_features : torch.Tensor
- Categorical features tensor.
+ data : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
@@ -135,7 +142,7 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
"""
# Handle embeddings if used
if self.hparams.use_embeddings:
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
# Option 1: Average over feature dimension (N)
if self.hparams.average_embeddings:
x = x.mean(dim=1) # Shape: (B, D)
@@ -145,15 +152,18 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
x = x.reshape(B, N * D) # Shape: (B, N * D)
else:
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ x = torch.cat([t for tensors in data for t in tensors], dim=1)
# Process through layers with optional skip connections
for i in range(len(self.layers) - 1):
if isinstance(self.layers[i], LinearBatchEnsembleLayer):
out = self.layers[i](x)
# `out` shape is expected to be (batch_size, ensemble_size, out_features)
- if hasattr(self, "skip_connections") and self.skip_connections and x.shape == out.shape:
+ if (
+ hasattr(self, "skip_connections")
+ and self.skip_connections
+ and x.shape == out.shape
+ ):
x = x + out
else:
x = out
diff --git a/mambular/base_models/tabtransformer.py b/mambular/base_models/tabtransformer.py
index df8104a0..0287203b 100644
--- a/mambular/base_models/tabtransformer.py
+++ b/mambular/base_models/tabtransformer.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
-
+import numpy as np
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
@@ -61,14 +61,14 @@ class TabTransformer(BaseModel):
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultTabTransformerConfig = DefaultTabTransformerConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
+ num_feature_info, cat_feature_info, emb_feature_info = feature_information
if cat_feature_info == {}:
raise ValueError(
"You are trying to fit a TabTransformer with no categorical features. \
@@ -76,13 +76,10 @@ def __init__(
)
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
# embedding layer
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *({}, cat_feature_info, emb_feature_info),
config=config,
)
@@ -96,8 +93,8 @@ def __init__(
)
mlp_input_dim = 0
- for feature_name, input_shape in num_feature_info.items():
- mlp_input_dim += input_shape
+ for feature_name, info in num_feature_info.items():
+ mlp_input_dim += info["dimension"]
mlp_input_dim += self.hparams.d_model
self.tabular_head = MLPhead(
@@ -107,25 +104,24 @@ def __init__(
)
# pooling
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+ n_inputs = n_inputs = [len(info) for info in feature_information]
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
----------
- num_features : Tensor
- Tensor containing the numerical features.
- cat_features : Tensor
- Tensor containing the categorical features.
+ ata : tuple
+ Input tuple of tensors of num_features, cat_features, embeddings.
Returns
-------
Tensor
The output predictions of the model.
"""
- cat_embeddings = self.embedding_layer(None, cat_features)
+ num_features, cat_features, emb_features = data
+ cat_embeddings = self.embedding_layer(*(None, cat_features, emb_features))
num_features = torch.cat(num_features, dim=1)
num_embeddings = self.norm_f(num_features) # type: ignore
diff --git a/mambular/base_models/tabularnn.py b/mambular/base_models/tabularnn.py
index d4824e95..6ac5c3a8 100644
--- a/mambular/base_models/tabularnn.py
+++ b/mambular/base_models/tabularnn.py
@@ -1,5 +1,4 @@
from dataclasses import replace
-
import torch
import torch.nn as nn
@@ -12,26 +11,23 @@
class TabulaRNN(BaseModel):
+
def __init__(
self,
- cat_feature_info,
- num_feature_info,
+ feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
num_classes=1,
config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
- self.cat_feature_info = cat_feature_info
- self.num_feature_info = num_feature_info
self.rnn = ConvRNN(config)
self.embedding_layer = EmbeddingLayer(
- num_feature_info=num_feature_info,
- cat_feature_info=cat_feature_info,
+ *feature_information,
config=config,
)
@@ -50,10 +46,10 @@ def __init__(
self.norm_f = get_normalization_layer(temp_config)
# pooling
- n_inputs = len(num_feature_info) + len(cat_feature_info)
+ n_inputs = [len(info) for info in feature_information]
self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
- def forward(self, num_features, cat_features):
+ def forward(self, *data):
"""Defines the forward pass of the model.
Parameters
@@ -69,7 +65,7 @@ def forward(self, num_features, cat_features):
The output predictions of the model.
"""
- x = self.embedding_layer(num_features, cat_features)
+ x = self.embedding_layer(*data)
# RNN forward pass
out, _ = self.rnn(x)
z = self.linear(torch.mean(x, dim=1))
diff --git a/mambular/configs/__init__.py b/mambular/configs/__init__.py
index b2e18708..cdda37ef 100644
--- a/mambular/configs/__init__.py
+++ b/mambular/configs/__init__.py
@@ -10,6 +10,7 @@
from .tabm_config import DefaultTabMConfig
from .tabtransformer_config import DefaultTabTransformerConfig
from .tabularnn_config import DefaultTabulaRNNConfig
+from .base_config import BaseConfig
__all__ = [
"DefaultFTTransformerConfig",
@@ -24,4 +25,5 @@
"DefaultTabMConfig",
"DefaultTabTransformerConfig",
"DefaultTabulaRNNConfig",
+ "BaseConfig"
]
diff --git a/mambular/configs/base_config.py b/mambular/configs/base_config.py
new file mode 100644
index 00000000..0e5a6396
--- /dev/null
+++ b/mambular/configs/base_config.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass, field
+from collections.abc import Callable
+import torch.nn as nn
+
+
+@dataclass
+class BaseConfig:
+ """
+ Base configuration class with shared hyperparameters for models.
+
+ This configuration class provides common hyperparameters for optimization,
+ embeddings, and categorical encoding, which can be inherited by specific
+ model configurations.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ lr_patience : int, default=10
+ Number of epochs with no improvement before reducing the learning rate.
+ weight_decay : float, default=1e-06
+ L2 regularization parameter for weight decay in the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate is reduced when patience is exceeded.
+ activation : Callable, default=nn.ReLU()
+ Activation function to use in the model's layers.
+ cat_encoding : str, default="int"
+ Method for encoding categorical features ('int', 'one-hot', or 'linear').
+
+ Embedding Parameters
+ --------------------
+ use_embeddings : bool, default=False
+ Whether to use embeddings for categorical or numerical features.
+ embedding_activation : Callable, default=nn.Identity()
+ Activation function applied to embeddings.
+ embedding_type : str, default="linear"
+ Type of embedding to use ('linear', 'plr', etc.).
+ embedding_bias : bool, default=False
+ Whether to use bias in embedding layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding layers.
+ d_model : int, default=32
+ Dimensionality of embeddings or model representations.
+ plr_lite : bool, default=False
+ Whether to use a lightweight version of Piecewise Linear Regression (PLR).
+ n_frequencies : int, default=48
+ Number of frequency components for embeddings.
+ frequencies_init_scale : float, default=0.01
+ Initial scale for frequency components in embeddings.
+ embedding_projection : bool, default=True
+ Whether to apply a projection layer after embeddings.
+
+ Notes
+ -----
+ - This base class is meant to be inherited by other configurations.
+ - Provides default values that can be overridden in derived configurations.
+
+ """
+
+ # Training Parameters
+ lr: float = 1e-04
+ lr_patience: int = 10
+ weight_decay: float = 1e-06
+ lr_factor: float = 0.1
+
+ # Embedding Parameters
+ use_embeddings: bool = False
+ embedding_activation: Callable = nn.Identity() # noqa: RUF009
+ embedding_type: str = "linear"
+ embedding_bias: bool = False
+ layer_norm_after_embedding: bool = False
+ d_model: int = 32
+ plr_lite: bool = False
+ n_frequencies: int = 48
+ frequencies_init_scale: float = 0.01
+ embedding_projection: bool = True
+
+ # Architecture Parameters
+ batch_norm: bool = False
+ layer_norm: bool = False
+ layer_norm_eps: float = 1e-05
+ activation: Callable = nn.ReLU() # noqa: RUF009
+ cat_encoding: str = "int"
diff --git a/mambular/configs/fttransformer_config.py b/mambular/configs/fttransformer_config.py
index d6aa11d2..37bdcf4b 100644
--- a/mambular/configs/fttransformer_config.py
+++ b/mambular/configs/fttransformer_config.py
@@ -1,25 +1,16 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
-
from ..arch_utils.transformer_utils import ReGLU
+from .base_config import BaseConfig
@dataclass
-class DefaultFTTransformerConfig:
+class DefaultFTTransformerConfig(BaseConfig):
"""Configuration class for the FT Transformer model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
d_model : int, default=128
Dimensionality of the transformer model.
n_layers : int, default=4
@@ -44,20 +35,6 @@ class DefaultFTTransformerConfig:
Whether to apply normalization before other operations in each transformer block.
bias : bool, default=True
Whether to use bias in linear layers.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', 'plr', etc.).
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- embedding_bias : bool, default=False
- Whether to use bias in embedding layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
head_layer_sizes : list, default=()
Sizes of the fully connected layers in the model's head.
head_dropout : float, default=0.5
@@ -76,12 +53,6 @@ class DefaultFTTransformerConfig:
Method for encoding categorical features ('int', 'one-hot', or 'linear').
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
d_model: int = 128
n_layers: int = 4
@@ -96,15 +67,6 @@ class DefaultFTTransformerConfig:
norm_first: bool = False
bias: bool = True
- # Embedding Parameters
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
- embedding_bias: bool = False
- layer_norm_after_embedding: bool = False
-
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
head_dropout: float = 0.5
diff --git a/mambular/configs/mambatab_config.py b/mambular/configs/mambatab_config.py
index c00d4bab..ccfe459b 100644
--- a/mambular/configs/mambatab_config.py
+++ b/mambular/configs/mambatab_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultMambaTabConfig:
+class DefaultMambaTabConfig(BaseConfig):
"""Configuration class for the Default MambaTab model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
d_model : int, default=64
Dimensionality of the model.
n_layers : int, default=1
@@ -50,18 +42,6 @@ class DefaultMambaTabConfig:
Activation function for the model.
axis : int, default=1
Axis along which operations are applied, if applicable.
- num_embedding_activation : callable, default=nn.ReLU()
- Activation function for numerical embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
head_layer_sizes : list, default=()
Sizes of the fully connected layers in the model's head.
head_dropout : float, default=0.0
@@ -82,12 +62,6 @@ class DefaultMambaTabConfig:
Whether to process data bidirectionally.
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
d_model: int = 64
n_layers: int = 1
@@ -106,14 +80,6 @@ class DefaultMambaTabConfig:
activation: Callable = nn.ReLU() # noqa: RUF009
axis: int = 1
- # Embedding Parameters
- num_embedding_activation: Callable = nn.ReLU() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
-
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
head_dropout: float = 0.0
diff --git a/mambular/configs/mambattention_config.py b/mambular/configs/mambattention_config.py
index b1f029ad..49e596e5 100644
--- a/mambular/configs/mambattention_config.py
+++ b/mambular/configs/mambattention_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultMambAttentionConfig:
+class DefaultMambAttentionConfig(BaseConfig):
"""Configuration class for the Default Mambular Attention model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
d_model : int, default=64
Dimensionality of the model.
n_layers : int, default=4
@@ -58,22 +50,6 @@ class DefaultMambAttentionConfig:
Type of normalization used in the model.
activation : callable, default=nn.SiLU()
Activation function for the model.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
- num_embedding_activation : callable, default=nn.ReLU()
- Activation function for numerical embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
head_layer_sizes : list, default=()
Sizes of the fully connected layers in the model's head.
head_dropout : float, default=0.5
@@ -106,12 +82,6 @@ class DefaultMambAttentionConfig:
Number of attention layers in the model.
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
d_model: int = 64
n_layers: int = 4
@@ -133,16 +103,6 @@ class DefaultMambAttentionConfig:
dt_init_floor: float = 1e-04
norm: str = "LayerNorm"
activation: Callable = nn.SiLU() # noqa: RUF009
- layer_norm_eps: float = 1e-05
-
- # Embedding Parameters
- num_embedding_activation: Callable = nn.ReLU() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
- layer_norm_after_embedding: bool = False
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py
index fcebca0e..a60b54e7 100644
--- a/mambular/configs/mambular_config.py
+++ b/mambular/configs/mambular_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultMambularConfig:
+class DefaultMambularConfig(BaseConfig):
"""Configuration class for the Default Mambular model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
d_model : int, default=64
Dimensionality of the model.
n_layers : int, default=4
@@ -28,6 +20,8 @@ class DefaultMambularConfig:
Whether to use bias in the linear layers.
dropout : float, default=0.0
Dropout rate for regularization.
+ d_conv : int, default=4
+ Size of convolution over columns.
dt_rank : str, default="auto"
Rank of the decision tree used in the model.
d_state : int, default=128
@@ -46,22 +40,6 @@ class DefaultMambularConfig:
Type of normalization used ('RMSNorm', etc.).
activation : callable, default=nn.SiLU()
Activation function for the model.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
shuffle_embeddings : bool, default=False
Whether to shuffle embeddings before being passed to Mamba layers.
head_layer_sizes : list, default=()
@@ -86,17 +64,18 @@ class DefaultMambularConfig:
Whether to use PSCAN for the state-space model.
mamba_version : str, default="mamba-torch"
Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2').
+ conv_bias : bool, default=False
+ Whether to use a bias in the 1D convolution before each mamba block
+ AD_weight_decay: bool = True
+ Whether to use weight decay als for the A and D matrices in Mamba
+ BC_layer_norm: bool = False
+ Whether to use layer norm on the B and C matrices
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
d_model: int = 64
n_layers: int = 4
+ d_conv: int = 4
expand_factor: int = 2
bias: bool = False
dropout: float = 0.0
@@ -109,16 +88,11 @@ class DefaultMambularConfig:
dt_init_floor: float = 1e-04
norm: str = "RMSNorm"
activation: Callable = nn.SiLU() # noqa: RUF009
- layer_norm_eps: float = 1e-05
+ conv_bias: bool = False
+ AD_weight_decay: bool = True
+ BC_layer_norm: bool = False
# Embedding Parameters
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
- layer_norm_after_embedding: bool = False
shuffle_embeddings: bool = False
# Head Parameters
diff --git a/mambular/configs/mlp_config.py b/mambular/configs/mlp_config.py
index 0c43cc11..1dda45fa 100644
--- a/mambular/configs/mlp_config.py
+++ b/mambular/configs/mlp_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultMLPConfig:
+class DefaultMLPConfig(BaseConfig):
"""Configuration class for the default Multi-Layer Perceptron (MLP) model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
layer_sizes : list, default=(256, 128, 32)
Sizes of the layers in the MLP.
activation : callable, default=nn.ReLU()
@@ -30,38 +22,8 @@ class DefaultMLPConfig:
Whether to use Gated Linear Units (GLU) in the MLP.
skip_connections : bool, default=False
Whether to use skip connections in the MLP.
- batch_norm : bool, default=False
- Whether to use batch normalization in the MLP layers.
- layer_norm : bool, default=False
- Whether to use layer normalization in the MLP layers.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
- use_embeddings : bool, default=False
- Whether to use embedding layers for all features.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', 'plr', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding.
- d_model : int, default=32
- Dimensionality of the embeddings.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
layer_sizes: list = field(default_factory=lambda: [256, 128, 32])
activation: Callable = nn.ReLU() # noqa: RUF009
@@ -69,17 +31,3 @@ class DefaultMLPConfig:
dropout: float = 0.2
use_glu: bool = False
skip_connections: bool = False
- batch_norm: bool = False
- layer_norm: bool = False
- layer_norm_eps: float = 1e-05
-
- # Embedding Parameters
- use_embeddings: bool = False
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- layer_norm_after_embedding: bool = False
- d_model: int = 32
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
diff --git a/mambular/configs/ndtf_config.py b/mambular/configs/ndtf_config.py
index 89fad29b..1fa1eec8 100644
--- a/mambular/configs/ndtf_config.py
+++ b/mambular/configs/ndtf_config.py
@@ -1,20 +1,13 @@
from dataclasses import dataclass
+from .base_config import BaseConfig
@dataclass
-class DefaultNDTFConfig:
+class DefaultNDTFConfig(BaseConfig):
"""Configuration class for the default Neural Decision Tree Forest (NDTF) model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) applied to the model's weights during optimization.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced when a plateau is reached.
min_depth : int, default=2
Minimum depth of trees in the forest. Controls the simplest model structure.
max_depth : int, default=10
@@ -33,10 +26,6 @@ class DefaultNDTFConfig:
Factor with which the penalty is multiplied
"""
- lr: float = 1e-4
- lr_patience: int = 5
- weight_decay: float = 1e-7
- lr_factor: float = 0.1
min_depth: int = 4
max_depth: int = 16
temperature: float = 0.1
diff --git a/mambular/configs/node_config.py b/mambular/configs/node_config.py
index 82a4bdac..2c93d30d 100644
--- a/mambular/configs/node_config.py
+++ b/mambular/configs/node_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultNODEConfig:
+class DefaultNODEConfig(BaseConfig):
"""Configuration class for the Neural Oblivious Decision Ensemble (NODE) model.
Parameters
----------
- lr : float, default=1e-03
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs without improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization penalty) applied by the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate is reduced when there is no improvement.
num_layers : int, default=4
Number of dense layers in the model.
layer_dim : int, default=128
@@ -28,24 +20,6 @@ class DefaultNODEConfig:
Depth of each decision tree in the ensemble.
norm : str, default=None
Type of normalization to use in the model.
- use_embeddings : bool, default=False
- Whether to use embedding layers for categorical features.
- embedding_activation : callable, default=nn.Identity()
- Activation function to apply to embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
- d_model : int, default=32
- Dimensionality of the embedding space.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
head_layer_sizes : list, default=()
Sizes of the layers in the model's head.
head_dropout : float, default=0.5
@@ -58,31 +32,13 @@ class DefaultNODEConfig:
Whether to use batch normalization in the head layers.
"""
- # Optimizer Parameters
- lr: float = 1e-03
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
num_layers: int = 4
layer_dim: int = 128
tree_dim: int = 1
depth: int = 6
-
norm: str | None = None
- # Embedding Parameters
- use_embeddings: bool = False
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- layer_norm_after_embedding: bool = False
- d_model: int = 32
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
-
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
head_dropout: float = 0.5
diff --git a/mambular/configs/resnet_config.py b/mambular/configs/resnet_config.py
index e904957e..7a458d59 100644
--- a/mambular/configs/resnet_config.py
+++ b/mambular/configs/resnet_config.py
@@ -1,23 +1,15 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultResNetConfig:
+class DefaultResNetConfig(BaseConfig):
"""Configuration class for the default ResNet model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization penalty) applied by the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate is reduced when there is no improvement.
layer_sizes : list, default=(256, 128, 32)
Sizes of the layers in the ResNet.
activation : callable, default=nn.SELU()
@@ -32,36 +24,13 @@ class DefaultResNetConfig:
Whether to use Gated Linear Units (GLU) in the ResNet.
skip_connections : bool, default=True
Whether to use skip connections in the ResNet.
- batch_norm : bool, default=True
- Whether to use batch normalization in the ResNet layers.
- layer_norm : bool, default=False
- Whether to use layer normalization in the ResNet layers.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
num_blocks : int, default=3
Number of residual blocks in the ResNet.
- use_embeddings : bool, default=True
- Whether to use embedding layers for all features.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
average_embeddings : bool, default=True
Whether to average embeddings during the forward pass.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
- d_model : int, default=64
- Dimensionality of the embeddings.
"""
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
+ # model params
layer_sizes: list = field(default_factory=lambda: [256, 128, 32])
activation: Callable = nn.SELU() # noqa: RUF009
skip_layers: bool = False
@@ -69,20 +38,7 @@ class DefaultResNetConfig:
norm: bool = False
use_glu: bool = False
skip_connections: bool = True
- batch_norm: bool = True
- layer_norm: bool = False
- layer_norm_eps: float = 1e-05
num_blocks: int = 3
# embedding params
- use_embeddings: bool = True
- embedding_type: str = "linear"
- embedding_bias = False
- plr_lite: bool = False
average_embeddings: bool = True
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- layer_norm_after_embedding: bool = False
- d_model: int = 64
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
diff --git a/mambular/configs/saint_config.py b/mambular/configs/saint_config.py
index 6c166cb5..3e903692 100644
--- a/mambular/configs/saint_config.py
+++ b/mambular/configs/saint_config.py
@@ -1,29 +1,21 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultSAINTConfig:
+class DefaultSAINTConfig(BaseConfig):
"""Configuration class for the SAINT model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 regularization) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
- d_model : int, default=128
- Dimensionality of the transformer model.
n_layers : int, default=4
Number of transformer layers.
n_heads : int, default=8
Number of attention heads in the transformer.
+ d_model : int, default=128
+ Dimensionality of embeddings or model representations.
attn_dropout : float, default=0.2
Dropout rate for the attention mechanism.
ff_dropout : float, default=0.1
@@ -36,26 +28,10 @@ class DefaultSAINTConfig:
Activation function for the transformer feed-forward layers.
transformer_dim_feedforward : int, default=256
Dimensionality of the feed-forward layers in the transformer.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization to improve numerical stability.
norm_first : bool, default=False
Whether to apply normalization before other operations in each transformer block.
bias : bool, default=True
Whether to use bias in linear layers.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', 'plr', etc.).
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- embedding_bias : bool, default=False
- Whether to use bias in embedding layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
head_layer_sizes : list, default=()
Sizes of the fully connected layers in the model's head.
head_dropout : float, default=0.5
@@ -74,32 +50,17 @@ class DefaultSAINTConfig:
Method for encoding categorical features ('int', 'one-hot', or 'linear').
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
- d_model: int = 32
+
n_layers: int = 1
n_heads: int = 2
attn_dropout: float = 0.2
ff_dropout: float = 0.1
norm: str = "LayerNorm"
activation: Callable = nn.GELU() # noqa: RUF009
- layer_norm_eps: float = 1e-05
norm_first: bool = False
bias: bool = True
-
- # Embedding Parameters
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
- embedding_bias: bool = False
- layer_norm_after_embedding: bool = False
+ d_model: int = 128
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
diff --git a/mambular/configs/tabm_config.py b/mambular/configs/tabm_config.py
index ee52dc8b..4c4a9314 100644
--- a/mambular/configs/tabm_config.py
+++ b/mambular/configs/tabm_config.py
@@ -1,24 +1,16 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Literal
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultTabMConfig:
+class DefaultTabMConfig(BaseConfig):
"""Configuration class for the TabM model with batch ensembling and predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which the learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate is reduced when there is no improvement.
layer_sizes : list, default=(512, 512, 128)
Sizes of the layers in the model.
activation : callable, default=nn.ReLU()
@@ -29,32 +21,6 @@ class DefaultTabMConfig:
Normalization method to be used, if any.
use_glu : bool, default=False
Whether to use Gated Linear Units (GLU) in the model.
- batch_norm : bool, default=False
- Whether to use batch normalization in the model layers.
- layer_norm : bool, default=False
- Whether to use layer normalization in the model layers.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
- use_embeddings : bool, default=True
- Whether to use embedding layers for all features.
- embedding_type : str, default="plr"
- Type of embedding to use ('plr', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- average_embeddings : bool, default=False
- Whether to average embeddings during the forward pass.
- embedding_activation : callable, default=nn.ReLU()
- Activation function for embeddings.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
- d_model : int, default=64
- Dimensionality of the embeddings.
ensemble_size : int, default=32
Number of ensemble members for batch ensembling.
ensemble_scaling_in : bool, default=True
@@ -71,34 +37,12 @@ class DefaultTabMConfig:
Model type to use ('mini' for reduced version, 'full' for complete model).
"""
- # lr params
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-05
- lr_factor: float = 0.1
-
# arch params
layer_sizes: list = field(default_factory=lambda: [256, 256, 128])
activation: Callable = nn.ReLU() # noqa: RUF009
dropout: float = 0.5
norm: str | None = None
use_glu: bool = False
- batch_norm: bool = False
- layer_norm: bool = False
- layer_norm_eps: float = 1e-05
-
- # embedding params
- use_embeddings: bool = True
- embedding_type: str = "linear"
- embedding_bias = False
- plr_lite: bool = False
- average_embeddings: bool = False
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- layer_norm_after_embedding: bool = False
- d_model: int = 32
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
# Batch ensembling specific configurations
ensemble_size: int = 32
diff --git a/mambular/configs/tabtransformer_config.py b/mambular/configs/tabtransformer_config.py
index 3cdea5c1..84f16c92 100644
--- a/mambular/configs/tabtransformer_config.py
+++ b/mambular/configs/tabtransformer_config.py
@@ -1,31 +1,22 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
-
from ..arch_utils.transformer_utils import ReGLU
+from .base_config import BaseConfig
@dataclass
-class DefaultTabTransformerConfig:
+class DefaultTabTransformerConfig(BaseConfig):
"""Configuration class for the default Tab Transformer model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
- d_model : int, default=128
- Dimensionality of the model.
n_layers : int, default=4
Number of layers in the transformer.
n_heads : int, default=8
Number of attention heads in the transformer.
+ d_model : int, default=128
+ Dimensionality of embeddings or model representations.
attn_dropout : float, default=0.2
Dropout rate for the attention mechanism.
ff_dropout : float, default=0.1
@@ -38,20 +29,10 @@ class DefaultTabTransformerConfig:
Activation function for the transformer layers.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
- layer_norm_eps : float, default=1e-05
- Epsilon value for layer normalization.
norm_first : bool, default=True
Whether to apply normalization before other operations in each transformer block.
bias : bool, default=True
Whether to use bias in the linear layers.
- embedding_activation : callable, default=nn.Identity()
- Activation function for embeddings.
- embedding_type : str, default="linear"
- Type of embedding to use ('linear', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding.
head_layer_sizes : list, default=()
Sizes of the layers in the model's head.
head_dropout : float, default=0.5
@@ -68,14 +49,7 @@ class DefaultTabTransformerConfig:
Encoding method for categorical features ('int', 'one-hot', etc.).
"""
- # Optimizer Parameters
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture Parameters
- d_model: int = 128
n_layers: int = 4
n_heads: int = 8
attn_dropout: float = 0.2
@@ -84,15 +58,9 @@ class DefaultTabTransformerConfig:
activation: Callable = nn.SELU() # noqa: RUF009
transformer_activation: Callable = ReGLU() # noqa: RUF009
transformer_dim_feedforward: int = 512
- layer_norm_eps: float = 1e-05
norm_first: bool = True
bias: bool = True
-
- # Embedding Parameters
- embedding_activation: Callable = nn.Identity() # noqa: RUF009
- embedding_type: str = "linear"
- embedding_bias: bool = False
- layer_norm_after_embedding: bool = False
+ d_model: int = 128
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
diff --git a/mambular/configs/tabularnn_config.py b/mambular/configs/tabularnn_config.py
index 037c96db..84a9a99f 100644
--- a/mambular/configs/tabularnn_config.py
+++ b/mambular/configs/tabularnn_config.py
@@ -1,51 +1,29 @@
from collections.abc import Callable
from dataclasses import dataclass, field
-
import torch.nn as nn
+from .base_config import BaseConfig
@dataclass
-class DefaultTabulaRNNConfig:
+class DefaultTabulaRNNConfig(BaseConfig):
"""Configuration class for the TabulaRNN model with predefined hyperparameters.
Parameters
----------
- lr : float, default=1e-04
- Learning rate for the optimizer.
- lr_patience : int, default=10
- Number of epochs with no improvement after which learning rate will be reduced.
- weight_decay : float, default=1e-06
- Weight decay (L2 penalty) for the optimizer.
- lr_factor : float, default=0.1
- Factor by which the learning rate will be reduced.
model_type : str, default="RNN"
Type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM".
- d_model : int, default=128
- Dimensionality of the model.
n_layers : int, default=4
Number of layers in the RNN.
rnn_dropout : float, default=0.2
Dropout rate for the RNN layers.
+ d_model : int, default=128
+ Dimensionality of embeddings or model representations.
norm : str, default="RMSNorm"
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the RNN layers.
residuals : bool, default=False
Whether to include residual connections in the RNN.
- embedding_type : str, default="linear"
- Type of embedding for features ('linear', 'plr', etc.).
- embedding_bias : bool, default=False
- Whether to use bias in the embedding layers.
- plr_lite : bool, default=False
- Whether to use a lightweight version of Piecewise Linear Regression (PLR).
- n_frequencies : int, default=48
- Number of frequencies for PLR embeddings.
- frequencies_init_scale : float, default=0.01
- Initial scale for frequency parameters in embeddings.
- embedding_activation : callable, default=nn.ReLU()
- Activation function for embeddings.
- layer_norm_after_embedding : bool, default=False
- Whether to apply layer normalization after embedding layers.
head_layer_sizes : list, default=()
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -74,12 +52,6 @@ class DefaultTabulaRNNConfig:
Whether to use bias in the convolutional layers.
"""
- # Optimizer params
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-
# Architecture params
model_type: str = "RNN"
d_model: int = 128
@@ -89,15 +61,6 @@ class DefaultTabulaRNNConfig:
activation: Callable = nn.SELU() # noqa: RUF009
residuals: bool = False
- # Embedding params
- embedding_type: str = "linear"
- embedding_bias: bool = False
- plr_lite: bool = False
- n_frequencies: int = 48
- frequencies_init_scale: float = 0.01
- embedding_activation: Callable = nn.ReLU() # noqa: RUF009
- layer_norm_after_embedding: bool = False
-
# Head params
head_layer_sizes: list = field(default_factory=list)
head_dropout: float = 0.5
diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py
index f78d6d8b..40e80f7b 100644
--- a/mambular/data_utils/datamodule.py
+++ b/mambular/data_utils/datamodule.py
@@ -78,6 +78,8 @@ def __init__(
# Initialize placeholders for data
self.X_train = None
self.y_train = None
+ self.embeddings_train = None
+ self.embeddings_val = None
self.test_preprocessor_fitted = False
self.dataloader_kwargs = dataloader_kwargs
@@ -87,6 +89,8 @@ def preprocess_data(
y_train,
X_val=None,
y_val=None,
+ embeddings_train=None,
+ embeddings_val=None,
val_size=0.2,
random_state=101,
):
@@ -98,10 +102,14 @@ def preprocess_data(
Training feature set.
y_train : array-like, shape (n_samples_train,)
Training target values.
+ embeddings_train : array-like or list of array-like, optional
+ Training embeddings if available.
X_val : DataFrame or array-like, shape (n_samples_val, n_features), optional
Validation feature set. If None, a validation set will be created from `X_train`.
y_val : array-like, shape (n_samples_val,), optional
Validation target values. If None, a validation set will be created from `y_train`.
+ embeddings_val : array-like or list of array-like, optional
+ Validation embeddings if available.
val_size : float, optional
Proportion of data to include in the validation split if `X_val` and `y_val` are None.
random_state : int, optional
@@ -113,123 +121,235 @@ def preprocess_data(
"""
if X_val is None or y_val is None:
- self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
- X_train, y_train, test_size=val_size, random_state=random_state
- )
+ split_data = [X_train, y_train]
+
+ if embeddings_train is not None:
+ if not isinstance(embeddings_train, list):
+ embeddings_train = [embeddings_train]
+ if embeddings_val is not None and not isinstance(embeddings_val, list):
+ embeddings_val = [embeddings_val]
+
+ split_data += embeddings_train
+ split_result = train_test_split(
+ *split_data, test_size=val_size, random_state=random_state
+ )
+
+ self.X_train, self.X_val, self.y_train, self.y_val = split_result[:4]
+ self.embeddings_train = split_result[4::2]
+ self.embeddings_val = split_result[5::2]
+ else:
+ self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
+ *split_data, test_size=val_size, random_state=random_state
+ )
+ self.embeddings_train = None
+ self.embeddings_val = None
else:
self.X_train = X_train
self.y_train = y_train
self.X_val = X_val
self.y_val = y_val
+ if embeddings_train is not None and embeddings_val is not None:
+ if not isinstance(embeddings_train, list):
+ embeddings_train = [embeddings_train]
+ if not isinstance(embeddings_val, list):
+ embeddings_val = [embeddings_val]
+ self.embeddings_train = embeddings_train
+ self.embeddings_val = embeddings_val
+ else:
+ self.embeddings_train = None
+ self.embeddings_val = None
+
# Fit the preprocessor on the combined training and validation data
- combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index(drop=True)
+ combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index(
+ drop=True
+ )
combined_y = np.concatenate((self.y_train, self.y_val), axis=0)
- # Fit the preprocessor
- self.preprocessor.fit(combined_X, combined_y)
+ if self.embeddings_train is not None and self.embeddings_val is not None:
+ combined_embeddings = [
+ np.concatenate((emb_train, emb_val), axis=0)
+ for emb_train, emb_val in zip(
+ self.embeddings_train, self.embeddings_val
+ )
+ ]
+ else:
+ combined_embeddings = None
+
+ self.preprocessor.fit(combined_X, combined_y, combined_embeddings)
# Update feature info based on the actual processed data
- (
- self.num_feature_info,
- self.cat_feature_info,
- ) = self.preprocessor.get_feature_info()
+ (self.num_feature_info, self.cat_feature_info, self.embedding_feature_info) = (
+ self.preprocessor.get_feature_info()
+ )
def setup(self, stage: str):
"""Transform the data and create DataLoaders."""
if stage == "fit":
- train_preprocessed_data = self.preprocessor.transform(self.X_train)
- val_preprocessed_data = self.preprocessor.transform(self.X_val)
+ train_preprocessed_data = self.preprocessor.transform(
+ self.X_train, self.embeddings_train
+ )
+ val_preprocessed_data = self.preprocessor.transform(
+ self.X_val, self.embeddings_val
+ )
# Initialize lists for tensors
train_cat_tensors = []
train_num_tensors = []
+ train_emb_tensors = []
val_cat_tensors = []
val_num_tensors = []
+ val_emb_tensors = []
# Populate tensors for categorical features, if present in processed data
for key in self.cat_feature_info: # type: ignore
dtype = (
torch.float32
- if "onehot"
- in self.cat_feature_info[key]["preprocessing"] # type: ignore
+ if any(
+ x in self.cat_feature_info[key]["preprocessing"]
+ for x in ["onehot", "pretrained"]
+ )
else torch.long
)
- cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_'
+ cat_key = "cat_" + str(
+ key
+ ) # Assuming categorical keys are prefixed with 'cat_'
if cat_key in train_preprocessed_data:
- train_cat_tensors.append(torch.tensor(train_preprocessed_data[cat_key], dtype=dtype))
+ train_cat_tensors.append(
+ torch.tensor(train_preprocessed_data[cat_key], dtype=dtype)
+ )
if cat_key in val_preprocessed_data:
- val_cat_tensors.append(torch.tensor(val_preprocessed_data[cat_key], dtype=dtype))
+ val_cat_tensors.append(
+ torch.tensor(val_preprocessed_data[cat_key], dtype=dtype)
+ )
- binned_key = "num_" + key # for binned features
+ binned_key = "num_" + str(key) # for binned features
if binned_key in train_preprocessed_data:
- train_cat_tensors.append(torch.tensor(train_preprocessed_data[binned_key], dtype=dtype))
+ train_cat_tensors.append(
+ torch.tensor(train_preprocessed_data[binned_key], dtype=dtype)
+ )
if binned_key in val_preprocessed_data:
- val_cat_tensors.append(torch.tensor(val_preprocessed_data[binned_key], dtype=dtype))
+ val_cat_tensors.append(
+ torch.tensor(val_preprocessed_data[binned_key], dtype=dtype)
+ )
# Populate tensors for numerical features, if present in processed data
for key in self.num_feature_info: # type: ignore
- num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_'
+ num_key = "num_" + str(
+ key
+ ) # Assuming numerical keys are prefixed with 'num_'
if num_key in train_preprocessed_data:
- train_num_tensors.append(torch.tensor(train_preprocessed_data[num_key], dtype=torch.float32))
+ train_num_tensors.append(
+ torch.tensor(
+ train_preprocessed_data[num_key], dtype=torch.float32
+ )
+ )
if num_key in val_preprocessed_data:
- val_num_tensors.append(torch.tensor(val_preprocessed_data[num_key], dtype=torch.float32))
-
- train_labels = torch.tensor(self.y_train, dtype=self.labels_dtype).unsqueeze(dim=1)
- val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze(dim=1)
+ val_num_tensors.append(
+ torch.tensor(
+ val_preprocessed_data[num_key], dtype=torch.float32
+ )
+ )
+
+ if self.embedding_feature_info is not None:
+ for key in self.embedding_feature_info:
+ if key in train_preprocessed_data:
+ train_emb_tensors.append(
+ torch.tensor(
+ train_preprocessed_data[key], dtype=torch.float32
+ )
+ )
+ if key in val_preprocessed_data:
+ val_emb_tensors.append(
+ torch.tensor(
+ val_preprocessed_data[key], dtype=torch.float32
+ )
+ )
+
+ train_labels = torch.tensor(
+ self.y_train, dtype=self.labels_dtype
+ ).unsqueeze(dim=1)
+ val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze(
+ dim=1
+ )
- # Create datasets
self.train_dataset = MambularDataset(
train_cat_tensors,
train_num_tensors,
+ train_emb_tensors,
train_labels,
regression=self.regression,
)
- self.val_dataset = MambularDataset(val_cat_tensors, val_num_tensors, val_labels, regression=self.regression)
- elif stage == "test":
- if not self.test_preprocessor_fitted:
- raise ValueError(
- "The preprocessor has not been fitted. Please fit the preprocessor before transforming the test data."
- )
-
- self.test_dataset = MambularDataset(
- self.test_cat_tensors,
- self.test_num_tensors,
- train_labels, # type: ignore
+ self.val_dataset = MambularDataset(
+ val_cat_tensors,
+ val_num_tensors,
+ val_emb_tensors,
+ val_labels,
regression=self.regression,
)
- def preprocess_test_data(self, X):
- self.test_cat_tensors = []
- self.test_num_tensors = []
- test_preprocessed_data = self.preprocessor.transform(X)
+ def preprocess_new_data(self, X, embeddings):
+ cat_tensors = []
+ num_tensors = []
+ emb_tensors = []
+ preprocessed_data = self.preprocessor.transform(X, embeddings)
# Populate tensors for categorical features, if present in processed data
for key in self.cat_feature_info: # type: ignore
dtype = (
torch.float32
- if "onehot"
- in self.cat_feature_info[key]["preprocessing"] # type: ignore
+ if any(
+ x in self.cat_feature_info[key]["preprocessing"]
+ for x in ["onehot", "pretrained"]
+ )
else torch.long
)
- cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_'
- if cat_key in test_preprocessed_data:
- self.test_cat_tensors.append(torch.tensor(test_preprocessed_data[cat_key], dtype=dtype))
+ cat_key = "cat_" + str(
+ key
+ ) # Assuming categorical keys are prefixed with 'cat_'
+ if cat_key in preprocessed_data:
+ cat_tensors.append(
+ torch.tensor(preprocessed_data[cat_key], dtype=dtype)
+ )
- binned_key = "num_" + key # for binned features
- if binned_key in test_preprocessed_data:
- self.test_cat_tensors.append(torch.tensor(test_preprocessed_data[binned_key], dtype=dtype))
+ binned_key = "num_" + str(key) # for binned features
+ if binned_key in preprocessed_data:
+ cat_tensors.append(
+ torch.tensor(preprocessed_data[binned_key], dtype=dtype)
+ )
# Populate tensors for numerical features, if present in processed data
for key in self.num_feature_info: # type: ignore
- num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_'
- if num_key in test_preprocessed_data:
- self.test_num_tensors.append(torch.tensor(test_preprocessed_data[num_key], dtype=torch.float32))
+ num_key = "num_" + str(
+ key
+ ) # Assuming numerical keys are prefixed with 'num_'
+ if num_key in preprocessed_data:
+ num_tensors.append(
+ torch.tensor(preprocessed_data[num_key], dtype=torch.float32)
+ )
- self.test_preprocessor_fitted = True
- return self.test_cat_tensors, self.test_num_tensors
+ if self.embedding_feature_info is not None:
+ for key in self.embedding_feature_info:
+ if key in preprocessed_data:
+ emb_tensors.append(
+ torch.tensor(preprocessed_data[key], dtype=torch.float32)
+ )
+
+ return MambularDataset(
+ cat_tensors,
+ num_tensors,
+ emb_tensors,
+ labels=None,
+ regression=self.regression,
+ )
+
+ def assign_predict_dataset(self, X, embeddings=None):
+ self.predict_dataset = self.preprocess_new_data(X, embeddings)
+
+ def assign_test_dataset(self, X, embeddings=None):
+ self.test_dataset = self.preprocess_new_data(X, embeddings)
def train_dataloader(self):
"""Returns the training dataloader.
@@ -237,13 +357,15 @@ def train_dataloader(self):
Returns:
DataLoader: DataLoader instance for the training dataset.
"""
-
- return DataLoader(
- self.train_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- **self.dataloader_kwargs,
- )
+ if hasattr(self, "train_dataset"):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ **self.dataloader_kwargs,
+ )
+ else:
+ raise ValueError("No training dataset provided!")
def val_dataloader(self):
"""Returns the validation dataloader.
@@ -251,7 +373,12 @@ def val_dataloader(self):
Returns:
DataLoader: DataLoader instance for the validation dataset.
"""
- return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs)
+ if hasattr(self, "val_dataset"):
+ return DataLoader(
+ self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs
+ )
+ else:
+ raise ValueError("No validation dataset provided!")
def test_dataloader(self):
"""Returns the test dataloader.
@@ -259,4 +386,19 @@ def test_dataloader(self):
Returns:
DataLoader: DataLoader instance for the test dataset.
"""
- return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs)
+ if hasattr(self, "test_dataset"):
+ return DataLoader(
+ self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs
+ )
+ else:
+ raise ValueError("No test dataset provided!")
+
+ def predict_dataloader(self):
+ if hasattr(self, "predict_dataset"):
+ return DataLoader(
+ self.predict_dataset,
+ batch_size=self.batch_size,
+ **self.dataloader_kwargs,
+ )
+ else:
+ raise ValueError("No predict dataset provided!")
diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py
index b1c07bfc..db6c63a7 100644
--- a/mambular/data_utils/dataset.py
+++ b/mambular/data_utils/dataset.py
@@ -11,28 +11,40 @@ class MambularDataset(Dataset):
----------
cat_features_list (list of Tensors): A list of tensors representing the categorical features.
num_features_list (list of Tensors): A list of tensors representing the numerical features.
- labels (Tensor): A tensor of labels.
+ embeddings_list (list of Tensors, optional): A list of tensors representing the embeddings.
+ labels (Tensor, optional): A tensor of labels. If None, the dataset is used for prediction.
regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True.
"""
- def __init__(self, cat_features_list, num_features_list, labels, regression=True):
+ def __init__(
+ self,
+ cat_features_list,
+ num_features_list,
+ embeddings_list=None,
+ labels=None,
+ regression=True,
+ ):
self.cat_features_list = cat_features_list # Categorical features tensors
self.num_features_list = num_features_list # Numerical features tensors
-
+ self.embeddings_list = embeddings_list # Embeddings tensors (optional)
self.regression = regression
- if not self.regression:
- self.num_classes = len(np.unique(labels))
- if self.num_classes > 2:
- self.labels = labels.view(-1)
+
+ if labels is not None:
+ if not self.regression:
+ self.num_classes = len(np.unique(labels))
+ if self.num_classes > 2:
+ self.labels = labels.view(-1)
+ else:
+ self.num_classes = 1
+ self.labels = labels
else:
- self.num_classes = 1
self.labels = labels
+ self.num_classes = 1
else:
- self.labels = labels
- self.num_classes = 1
+ self.labels = None # No labels in prediction mode
def __len__(self):
- return len(self.labels)
+ return len(self.num_features_list[0]) # Use numerical features length
def __getitem__(self, idx):
"""Retrieves the features and label for a given index.
@@ -43,21 +55,34 @@ def __getitem__(self, idx):
Returns
-------
- tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical
- features) and a single label (float if regression is True).
+ tuple: A tuple containing lists of tensors for numerical features, categorical features, embeddings
+ (if available), and a label (if available).
"""
- cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list]
+ cat_features = [
+ feature_tensor[idx] for feature_tensor in self.cat_features_list
+ ]
num_features = [
torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32)
for feature_tensor in self.num_features_list
]
- label = self.labels[idx]
- if self.regression:
- label = label.clone().detach().to(torch.float32)
- elif self.num_classes == 1:
- label = label.clone().detach().to(torch.float32)
+
+ if self.embeddings_list is not None:
+ embeddings = [
+ torch.as_tensor(embed_tensor[idx]).clone().detach().to(torch.float32)
+ for embed_tensor in self.embeddings_list
+ ]
else:
- label = label.clone().detach().to(torch.long)
+ embeddings = None
+
+ if self.labels is not None:
+ label = self.labels[idx]
+ if self.regression:
+ label = label.clone().detach().to(torch.float32)
+ elif self.num_classes == 1:
+ label = label.clone().detach().to(torch.float32)
+ else:
+ label = label.clone().detach().to(torch.long)
- # Keep categorical and numerical features separate
- return cat_features, num_features, label
+ return (num_features, cat_features, embeddings), label
+ else:
+ return (num_features, cat_features, embeddings)
diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py
index 35ee4c24..3188a578 100644
--- a/mambular/models/sklearn_base_classifier.py
+++ b/mambular/models/sklearn_base_classifier.py
@@ -1,4 +1,5 @@
import warnings
+from collections.abc import Callable
from typing import Optional
import lightning as pl
@@ -7,19 +8,26 @@
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from sklearn.base import BaseEstimator
-from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
+from sklearn.metrics import accuracy_score, log_loss
from skopt import gp_minimize
+from torch.utils.data import DataLoader
+from tqdm import tqdm
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
-from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16
+from ..utils.config_mapper import (
+ activation_mapper,
+ get_search_space,
+ round_to_nearest_16,
+)
class SklearnBaseClassifier(BaseEstimator):
def __init__(self, model, config, **kwargs):
self.preprocessor_arg_names = [
"n_bins",
+ "feature_preprocessing",
"numerical_preprocessing",
"categorical_preprocessing",
"use_decision_tree_bins",
@@ -27,16 +35,24 @@ def __init__(self, model, config, **kwargs):
"task",
"cat_cutoff",
"treat_all_integers_as_numerical",
- "knots",
"degree",
+ "scaling_strategy",
+ "n_knots",
+ "use_decision_tree_knots",
+ "knots_strategy",
+ "spline_implementation",
]
self.config_kwargs = {
- k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
+ k: v
+ for k, v in kwargs.items()
+ if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
}
self.config = config(**self.config_kwargs)
- preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}
+ preprocessor_kwargs = {
+ k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
+ }
self.preprocessor = Preprocessor(**preprocessor_kwargs)
self.task_model = None
@@ -56,7 +72,8 @@ def __init__(self, model, config, **kwargs):
self.optimizer_kwargs = {
k: v
for k, v in kwargs.items()
- if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
+ if k
+ not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"]
and k.startswith("optimizer_")
}
@@ -77,7 +94,10 @@ def get_params(self, deep=True):
params.update(self.config_kwargs)
if deep:
- preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()}
+ preprocessor_params = {
+ "prepro__" + key: value
+ for key, value in self.preprocessor.get_params().items()
+ }
params.update(preprocessor_params)
return params
@@ -95,8 +115,14 @@ def set_params(self, **parameters):
self : object
Estimator instance.
"""
- config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")}
- preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")}
+ config_params = {
+ k: v for k, v in parameters.items() if not k.startswith("prepro__")
+ }
+ preprocessor_params = {
+ k.split("__")[1]: v
+ for k, v in parameters.items()
+ if k.startswith("prepro__")
+ }
if config_params:
self.config_kwargs.update(config_params)
@@ -104,9 +130,7 @@ def set_params(self, **parameters):
for key, value in config_params.items():
setattr(self.config, key, value)
else:
- self.config = self.config_class( # type: ignore
- **self.config_kwargs
- )
+ self.config = self.config_class(**self.config_kwargs) # type: ignore
if preprocessor_params:
self.preprocessor.set_params(**preprocessor_params)
@@ -120,6 +144,8 @@ def build_model(
val_size: float = 0.2,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
@@ -127,6 +153,8 @@ def build_model(
lr_patience: int | None = None,
lr_factor: float | None = None,
weight_decay: float | None = None,
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
):
"""Builds the model using the provided training data.
@@ -146,7 +174,7 @@ def build_model(
The validation target values. Required if `X_val` is provided.
random_state : int, default=101
Controls the shuffling applied to the data before applying the split.
- batch_size : int, default=64
+ batch_size : int, default=128
Number of samples per gradient update.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch.
@@ -154,8 +182,12 @@ def build_model(
Learning rate for the optimizer.
lr_patience : int, default=10
Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
- factor : float, default=0.1
+ lr_factor : float, default=0.1
Factor by which the learning rate will be reduced.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
dataloader_kwargs: dict, default={}
@@ -190,7 +222,16 @@ def build_model(
**dataloader_kwargs,
)
- self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)
+ self.data_module.preprocess_data(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings_train=embeddings,
+ embeddings_val=embeddings_val,
+ val_size=val_size,
+ random_state=random_state,
+ )
num_classes = len(np.unique(np.array(y)))
@@ -198,12 +239,21 @@ def build_model(
model_class=self.base_model, # type: ignore
num_classes=num_classes,
config=self.config,
- cat_feature_info=self.data_module.cat_feature_info,
- num_feature_info=self.data_module.num_feature_info,
- lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
+ feature_information=(
+ self.data_module.num_feature_info,
+ self.data_module.cat_feature_info,
+ self.data_module.embedding_feature_info,
+ ),
+ lr_patience=(
+ lr_patience if lr_patience is not None else self.config.lr_patience
+ ),
lr=lr if lr is not None else self.config.lr,
lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
- weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
+ weight_decay=(
+ weight_decay if weight_decay is not None else self.config.weight_decay
+ ),
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
optimizer_type=self.optimizer_type,
optimizer_args=self.optimizer_kwargs,
)
@@ -232,7 +282,9 @@ def get_number_of_params(self, requires_grad=True):
If the model has not been built prior to calling this method.
"""
if not self.built:
- raise ValueError("The model must be built before the number of parameters can be estimated")
+ raise ValueError(
+ "The model must be built before the number of parameters can be estimated"
+ )
else:
if requires_grad:
return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore
@@ -246,6 +298,8 @@ def fit(
val_size: float = 0.2,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
max_epochs: int = 100,
random_state: int = 101,
batch_size: int = 128,
@@ -258,6 +312,8 @@ def fit(
lr_factor: float | None = None,
weight_decay: float | None = None,
checkpoint_path="model_checkpoints",
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
rebuild=True,
**trainer_kwargs,
@@ -302,6 +358,10 @@ def fit(
Weight decay (L2 penalty) coefficient.
checkpoint_path : str, default="model_checkpoints"
Path where the checkpoints are being saved.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
rebuild: bool, default=True
@@ -321,6 +381,8 @@ def fit(
val_size=val_size,
X_val=X_val,
y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
random_state=random_state,
batch_size=batch_size,
shuffle=shuffle,
@@ -328,6 +390,8 @@ def fit(
lr_patience=lr_patience,
lr_factor=lr_factor,
weight_decay=weight_decay,
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
dataloader_kwargs=dataloader_kwargs,
)
@@ -365,134 +429,103 @@ def fit(
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.task_model.load_state_dict( # type: ignore
- checkpoint["state_dict"]
- )
+ self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
return self
- def predict(self, X, device=None):
- """Predicts target values for the given input samples.
+ def predict(self, X, embeddings=None, device=None):
+ """Predicts target labels for the given input samples.
Parameters
----------
X : DataFrame or array-like, shape (n_samples, n_features)
The input samples for which to predict target values.
-
Returns
-------
- predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs)
- The predicted target values.
+ predictions : ndarray, shape (n_samples,)
+ The predicted class labels.
"""
# Ensure model and data module are initialized
if self.task_model is None or self.data_module is None:
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
- cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
-
- # Move tensors to appropriate device
- if device is None:
- device = next(self.task_model.parameters()).device
- if isinstance(cat_tensors, list):
- cat_tensors = [tensor.to(device) for tensor in cat_tensors]
- else:
- cat_tensors = cat_tensors.to(device)
-
- if isinstance(num_tensors, list):
- num_tensors = [tensor.to(device) for tensor in num_tensors]
- else:
- num_tensors = num_tensors.to(device)
+ self.data_module.assign_predict_dataset(X, embeddings)
# Set model to evaluation mode
self.task_model.eval()
- # Perform inference
- with torch.no_grad():
- logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
-
- # Check if ensemble is used
- if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble
- # Average logits across the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim))
- logits = logits.mean(dim=1)
- if logits.dim() == 1: # Check if logits has only one dimension (shape (N,))
- logits = logits.unsqueeze(1)
-
- # Check the shape of the logits to determine binary or multi-class classification
- if logits.shape[1] == 1:
- # Binary classification
- probabilities = torch.sigmoid(logits)
- predictions = (probabilities > 0.5).long().squeeze()
- else:
- # Multi-class classification
- probabilities = torch.softmax(logits, dim=1)
- predictions = torch.argmax(probabilities, dim=1)
+ # Perform inference using PyTorch Lightning's predict function
+ logits_list = self.trainer.predict(self.task_model, self.data_module)
+
+ # Concatenate predictions from all batches
+ logits = torch.cat(logits_list, dim=0) # type: ignore
+
+ # Check if ensemble is used
+ if getattr(self.base_model, "returns_ensemble", False): # If using ensemble
+ logits = logits.mean(dim=1) # Average over ensemble dimension
+ if logits.dim() == 1: # Ensure correct shape
+ logits = logits.unsqueeze(1)
+
+ # Check the shape of the logits to determine binary or multi-class classification
+ if logits.shape[1] == 1:
+ # Binary classification
+ probabilities = torch.sigmoid(logits)
+ predictions = (probabilities > 0.5).long().squeeze()
+ else:
+ # Multi-class classification
+ probabilities = torch.softmax(logits, dim=1)
+ predictions = torch.argmax(probabilities, dim=1)
# Convert predictions to NumPy array and return
return predictions.cpu().numpy()
- def predict_proba(self, X, device=None):
- """Predict class probabilities for the given input samples.
+ def predict_proba(self, X, embeddings=None, device=None):
+ """Predicts class probabilities for the given input samples.
Parameters
----------
- X : array-like or pd.DataFrame of shape (n_samples, n_features)
+ X : DataFrame or array-like, shape (n_samples, n_features)
The input samples for which to predict class probabilities.
-
- Notes
- -----
- The method preprocesses the input data using the same preprocessor used during training,
- sets the model to evaluation mode, and then performs inference to predict the class probabilities.
- Softmax is applied to the logits to obtain probabilities, which are then converted from a PyTorch tensor
- to a NumPy array before being returned.
-
Returns
-------
- probabilities : ndarray of shape (n_samples, n_classes)
- Predicted class probabilities for each input sample.
+ probabilities : ndarray, shape (n_samples, n_classes)
+ The predicted class probabilities.
"""
- # Preprocess the data
- if not isinstance(X, pd.DataFrame):
- X = pd.DataFrame(X)
- device = next(self.task_model.parameters()).device # type: ignore
- cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
- if isinstance(cat_tensors, list):
- cat_tensors = [tensor.to(device) for tensor in cat_tensors]
- else:
- cat_tensors = cat_tensors.to(device)
+ # Ensure model and data module are initialized
+ if self.task_model is None or self.data_module is None:
+ raise ValueError("The model or data module has not been fitted yet.")
- if isinstance(num_tensors, list):
- num_tensors = [tensor.to(device) for tensor in num_tensors]
- else:
- num_tensors = num_tensors.to(device)
+ # Preprocess the data using the data module
+ self.data_module.assign_predict_dataset(X)
- # Set the model to evaluation mode
- self.task_model.eval() # type: ignore
+ # Set model to evaluation mode
+ self.task_model.eval()
- # Perform inference
- with torch.no_grad():
- logits = self.task_model( # type: ignore
- num_features=num_tensors, cat_features=cat_tensors
- )
- # Check if ensemble is used
- # If using ensemble
- if hasattr(self.task_model.base_model, "returns_ensemble"): # type: ignore
- # Average logits across the ensemble dimension
- # (assuming shape: (batch_size, ensemble_size, output_dim))
- logits = logits.mean(dim=1)
- if logits.dim() == 1: # Check if logits has only one dimension (shape (N,))
- logits = logits.unsqueeze(1)
- if logits.shape[1] > 1:
- probabilities = torch.softmax(logits, dim=1)
- else:
- probabilities = torch.sigmoid(logits)
+ # Perform inference using PyTorch Lightning's predict function
+ logits_list = self.trainer.predict(self.task_model, self.data_module)
+
+ # Concatenate predictions from all batches
+ logits = torch.cat(logits_list, dim=0)
+
+ # Check if ensemble is used
+ if getattr(self.base_model, "returns_ensemble", False): # If using ensemble
+ logits = logits.mean(dim=1) # Average over ensemble dimension
+ if logits.dim() == 1: # Ensure correct shape
+ logits = logits.unsqueeze(1)
+
+ # Compute probabilities
+ if logits.shape[1] > 1:
+ probabilities = torch.softmax(logits, dim=1) # Multi-class classification
+ else:
+ probabilities = torch.sigmoid(logits) # Binary classification
# Convert probabilities to NumPy array and return
return probabilities.cpu().numpy()
- def evaluate(self, X, y_true, metrics=None):
+ def evaluate(self, X, y_true, embeddings=None, metrics=None):
"""Evaluate the model on the given data using specified metrics.
Parameters
@@ -501,6 +534,8 @@ def evaluate(self, X, y_true, metrics=None):
The input samples to predict.
y_true : array-like of shape (n_samples,)
The true class labels against which to evaluate the predictions.
+ embneddings : array-like or list of shape(n_samples, dimension)
+ List or array with embeddings for unstructured data inputs
metrics : dict
A dictionary where keys are metric names and values are tuples containing the metric function
and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
@@ -528,11 +563,11 @@ def evaluate(self, X, y_true, metrics=None):
# Generate class probabilities if any metric requires them
if any(use_proba for _, use_proba in metrics.values()):
- probabilities = self.predict_proba(X)
+ probabilities = self.predict_proba(X, embeddings)
# Generate class labels if any metric requires them
if any(not use_proba for _, use_proba in metrics.values()):
- predictions = self.predict(X)
+ predictions = self.predict(X, embeddings)
# Compute each metric
for metric_name, (metric_func, use_proba) in metrics.items():
@@ -543,7 +578,7 @@ def evaluate(self, X, y_true, metrics=None):
return scores
- def score(self, X, y, metric=(log_loss, True)):
+ def score(self, X, y, embeddings=None, metric=(log_loss, True)):
"""Calculate the score of the model using the specified metric.
Parameters
@@ -567,18 +602,61 @@ def score(self, X, y, metric=(log_loss, True)):
X = pd.DataFrame(X)
if use_proba:
- probabilities = self.predict_proba(X)
+ probabilities = self.predict_proba(X, embeddings)
return metric_func(y, probabilities)
else:
- predictions = self.predict(X)
+ predictions = self.predict(X, embeddings)
return metric_func(y, predictions)
+ def encode(self, X, embeddings=None, batch_size=64):
+ """
+ Encodes input data using the trained model's embedding layer.
+
+ Parameters
+ ----------
+ X : array-like or DataFrame
+ Input data to be encoded.
+ batch_size : int, optional, default=64
+ Batch size for encoding.
+
+ Returns
+ -------
+ torch.Tensor
+ Encoded representations of the input data.
+
+ Raises
+ ------
+ ValueError
+ If the model or data module is not fitted.
+ """
+ # Ensure model and data module are initialized
+ if self.task_model is None or self.data_module is None:
+ raise ValueError("The model or data module has not been fitted yet.")
+ encoded_dataset = self.data_module.preprocess_new_data(X, embeddings)
+
+ data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+ # Process data in batches
+ encoded_outputs = []
+ for batch in tqdm(data_loader):
+ embeddings = self.task_model.base_model.encode(
+ batch
+ ) # Call your encode function
+ encoded_outputs.append(embeddings)
+
+ # Concatenate all encoded outputs
+ encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+ return encoded_outputs
+
def optimize_hparams(
self,
X,
y,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
time=100,
max_epochs=200,
prune_by_epoch=True,
@@ -629,13 +707,25 @@ def optimize_hparams(
)
# Initial model fitting to get the baseline validation loss
- self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
+ self.fit(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
+ max_epochs=max_epochs,
+ )
best_val_loss = float("inf")
if X_val is not None and y_val is not None:
- val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"]
+ val_loss = self.evaluate(
+ X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
+ )["Accuracy"]
else:
- val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+ val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
+ "val_loss"
+ ]
best_val_loss = val_loss
best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore
@@ -661,7 +751,9 @@ def _objective(hyperparams):
if param_value in activation_mapper:
setattr(self.config, key, activation_mapper[param_value])
else:
- raise ValueError(f"Unknown activation function: {param_value}")
+ raise ValueError(
+ f"Unknown activation function: {param_value}"
+ )
else:
setattr(self.config, key, param_value)
@@ -670,11 +762,15 @@ def _objective(hyperparams):
self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]
# Build the model with updated hyperparameters
- self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)
+ self.build_model(
+ X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs
+ )
# Dynamically set the early pruning threshold
if prune_by_epoch:
- early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss
+ early_pruning_threshold = (
+ best_epoch_val_loss * 1.5
+ ) # Prune based on specific epoch loss
else:
# Prune based on the best overall validation loss
early_pruning_threshold = best_val_loss * 1.5
@@ -686,15 +782,26 @@ def _objective(hyperparams):
# Fit the model (limit epochs for faster optimization)
try:
# Wrap the risky operation (model fitting) in a try-except block
- self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False)
+ self.fit(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
+ max_epochs=max_epochs,
+ rebuild=False,
+ )
# Evaluate validation loss
if X_val is not None and y_val is not None:
- val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ # type: ignore
+ val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})[ # type: ignore
"Mean Squared Error"
]
else:
- val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"]
+ val_loss = self.trainer.validate(self.task_model, self.data_module)[
+ 0
+ ]["val_loss"]
# Pruning based on validation loss at specific epoch
epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore
@@ -711,15 +818,21 @@ def _objective(hyperparams):
except Exception as e:
# Penalize the hyperparameter configuration with a large value
- print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}")
- return best_val_loss * 100 # Large value to discourage this configuration
+ print(
+ f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+ )
+ return (
+ best_val_loss * 100
+ ) # Large value to discourage this configuration
# Perform Bayesian optimization using scikit-optimize
result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)
# Update the model with the best-found hyperparameters
best_hparams = result.x # type: ignore
- head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+ head_layer_sizes = (
+ [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None
+ )
layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None
# Iterate over the best hyperparameters found by optimization
diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py
index 8dbd4a71..cfb8d17e 100644
--- a/mambular/models/sklearn_base_lss.py
+++ b/mambular/models/sklearn_base_lss.py
@@ -1,4 +1,5 @@
import warnings
+from collections.abc import Callable
import lightning as pl
import numpy as np
@@ -9,11 +10,17 @@
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score, mean_squared_error
from skopt import gp_minimize
+from torch.utils.data import DataLoader
+from tqdm import tqdm
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
-from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16
+from ..utils.config_mapper import (
+ activation_mapper,
+ get_search_space,
+ round_to_nearest_16,
+)
from ..utils.distributional_metrics import (
beta_brier_score,
dirichlet_error,
@@ -34,6 +41,7 @@
PoissonDistribution,
Quantile,
StudentTDistribution,
+ JohnsonSuDistribution
)
@@ -41,6 +49,7 @@ class SklearnBaseLSS(BaseEstimator):
def __init__(self, model, config, **kwargs):
self.preprocessor_arg_names = [
"n_bins",
+ "feature_preprocessing",
"numerical_preprocessing",
"categorical_preprocessing",
"use_decision_tree_bins",
@@ -48,8 +57,12 @@ def __init__(self, model, config, **kwargs):
"task",
"cat_cutoff",
"treat_all_integers_as_numerical",
- "knots",
"degree",
+ "scaling_strategy",
+ "n_knots",
+ "use_decision_tree_knots",
+ "knots_strategy",
+ "spline_implementation",
]
self.config_kwargs = {
@@ -126,9 +139,7 @@ def set_params(self, **parameters):
for key, value in config_params.items():
setattr(self.config, key, value)
else:
- self.config = self.config_class( # type: ignore
- **self.config_kwargs
- )
+ self.config = self.config_class(**self.config_kwargs) # type: ignore
if preprocessor_params:
self.preprocessor.set_params(**preprocessor_params)
@@ -149,6 +160,8 @@ def build_model(
lr_patience: int | None = None,
lr_factor: float | None = None,
weight_decay: float | None = None,
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
):
"""Builds the model using the provided training data.
@@ -176,8 +189,12 @@ def build_model(
Learning rate for the optimizer.
lr_patience : int, default=10
Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
- factor : float, default=0.1
+ lr_factor : float, default=0.1
Factor by which the learning rate will be reduced.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
dataloader_kwargs: dict, default={}
@@ -224,6 +241,8 @@ def build_model(
lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
lss=True,
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
optimizer_type=self.optimizer_type,
optimizer_args=self.optimizer_kwargs,
)
@@ -280,6 +299,8 @@ def fit(
weight_decay: float | None = None,
checkpoint_path="model_checkpoints",
distributional_kwargs=None,
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
rebuild=True,
**trainer_kwargs,
@@ -327,6 +348,10 @@ def fit(
Weight decay (L2 penalty) coefficient.
distributional_kwargs : dict, default=None
any arguments taht are specific for a certain distribution.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
checkpoint_path : str, default="model_checkpoints"
Path where the checkpoints are being saved.
dataloader_kwargs: dict, default={}
@@ -350,6 +375,7 @@ def fit(
"inversegamma": InverseGammaDistribution,
"categorical": CategoricalDistribution,
"quantile": Quantile,
+ "johnsonsu": JohnsonSuDistribution,
}
if distributional_kwargs is None:
@@ -373,6 +399,8 @@ def fit(
lr=lr,
lr_patience=lr_patience,
lr_factor=lr_factor,
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
weight_decay=weight_decay,
dataloader_kwargs=dataloader_kwargs,
)
@@ -411,9 +439,7 @@ def fit(
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.task_model.load_state_dict( # type: ignore
- checkpoint["state_dict"]
- )
+ self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
return self
@@ -436,32 +462,20 @@ def predict(self, X, raw=False, device=None):
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
- cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
-
- # Move tensors to appropriate device
- if device is not None:
- device = next(self.task_model.parameters()).device
- if isinstance(cat_tensors, list):
- cat_tensors = [tensor.to(device) for tensor in cat_tensors]
- else:
- cat_tensors = cat_tensors.to(device)
-
- if isinstance(num_tensors, list):
- num_tensors = [tensor.to(device) for tensor in num_tensors]
- else:
- num_tensors = num_tensors.to(device)
+ self.data_module.assign_predict_dataset(X)
# Set model to evaluation mode
self.task_model.eval()
- # Perform inference
- with torch.no_grad():
- predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
+ # Perform inference using PyTorch Lightning's predict function
+ predictions_list = self.trainer.predict(self.task_model, self.data_module)
+
+ # Concatenate predictions from all batches
+ predictions = torch.cat(predictions_list, dim=0)
# Check if ensemble is used
if getattr(self.base_model, "returns_ensemble", False): # If using ensemble
- # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim))
- predictions = predictions.mean(dim=1)
+ predictions = predictions.mean(dim=1) # Average over ensemble dimension
if not raw:
result = self.task_model.family(predictions).cpu().numpy() # type: ignore
@@ -566,11 +580,48 @@ def score(self, X, y, metric="NLL"):
The score calculated using the specified metric.
"""
predictions = self.predict(X)
- score = self.task_model.family.evaluate_nll( # type: ignore
- y, predictions
- )
+ score = self.task_model.family.evaluate_nll(y, predictions) # type: ignore
return score
+ def encode(self, X, batch_size=64):
+ """
+ Encodes input data using the trained model's embedding layer.
+
+ Parameters
+ ----------
+ X : array-like or DataFrame
+ Input data to be encoded.
+ batch_size : int, optional, default=64
+ Batch size for encoding.
+
+ Returns
+ -------
+ torch.Tensor
+ Encoded representations of the input data.
+
+ Raises
+ ------
+ ValueError
+ If the model or data module is not fitted.
+ """
+ # Ensure model and data module are initialized
+ if self.task_model is None or self.data_module is None:
+ raise ValueError("The model or data module has not been fitted yet.")
+ encoded_dataset = self.data_module.preprocess_new_data(X)
+
+ data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+ # Process data in batches
+ encoded_outputs = []
+ for num_features, cat_features in tqdm(data_loader):
+ embeddings = self.task_model.base_model.encode(num_features, cat_features) # Call your encode function
+ encoded_outputs.append(embeddings)
+
+ # Concatenate all encoded outputs
+ encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+ return encoded_outputs
+
def optimize_hparams(
self,
X,
diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py
index de8999d6..210b785e 100644
--- a/mambular/models/sklearn_base_regressor.py
+++ b/mambular/models/sklearn_base_regressor.py
@@ -1,4 +1,5 @@
import warnings
+from collections.abc import Callable
import lightning as pl
import pandas as pd
@@ -7,6 +8,8 @@
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_squared_error
from skopt import gp_minimize
+from torch.utils.data import DataLoader
+from tqdm import tqdm
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
@@ -18,6 +21,7 @@ class SklearnBaseRegressor(BaseEstimator):
def __init__(self, model, config, **kwargs):
self.preprocessor_arg_names = [
"n_bins",
+ "feature_preprocessing",
"numerical_preprocessing",
"categorical_preprocessing",
"use_decision_tree_bins",
@@ -26,6 +30,7 @@ def __init__(self, model, config, **kwargs):
"cat_cutoff",
"treat_all_integers_as_numerical",
"degree",
+ "scaling_strategy",
"n_knots",
"use_decision_tree_knots",
"knots_strategy",
@@ -105,9 +110,7 @@ def set_params(self, **parameters):
for key, value in config_params.items():
setattr(self.config, key, value)
else:
- self.config = self.config_class( # type: ignore
- **self.config_kwargs
- )
+ self.config = self.config_class(**self.config_kwargs) # type: ignore
if preprocessor_params:
self.preprocessor.set_params(**preprocessor_params)
@@ -121,6 +124,8 @@ def build_model(
val_size: float = 0.2,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
@@ -128,6 +133,8 @@ def build_model(
lr_patience: int | None = None,
lr_factor: float | None = None,
weight_decay: float | None = None,
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
dataloader_kwargs={},
):
"""Builds the model using the provided training data.
@@ -159,6 +166,10 @@ def build_model(
Factor by which the learning rate will be reduced.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
@@ -191,17 +202,31 @@ def build_model(
**dataloader_kwargs,
)
- self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)
+ self.data_module.preprocess_data(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings_train=embeddings,
+ embeddings_val=embeddings_val,
+ val_size=val_size,
+ random_state=random_state,
+ )
self.task_model = TaskModel(
model_class=self.base_model, # type: ignore
config=self.config,
- cat_feature_info=self.data_module.cat_feature_info,
- num_feature_info=self.data_module.num_feature_info,
+ feature_information=(
+ self.data_module.num_feature_info,
+ self.data_module.cat_feature_info,
+ self.data_module.embedding_feature_info,
+ ),
lr=lr if lr is not None else self.config.lr,
lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
optimizer_type=self.optimizer_type,
optimizer_args=self.optimizer_kwargs,
)
@@ -244,6 +269,8 @@ def fit(
val_size: float = 0.2,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
max_epochs: int = 100,
random_state: int = 101,
batch_size: int = 128,
@@ -257,6 +284,8 @@ def fit(
weight_decay: float | None = None,
checkpoint_path="model_checkpoints",
dataloader_kwargs={},
+ train_metrics: dict[str, Callable] | None = None,
+ val_metrics: dict[str, Callable] | None = None,
rebuild=True,
**trainer_kwargs,
):
@@ -302,6 +331,12 @@ def fit(
Path where the checkpoints are being saved.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
+ train_metrics : dict, default=None
+ torch.metrics dict to be logged during training.
+ val_metrics : dict, default=None
+ torch.metrics dict to be logged during validation.
+ rebuild: bool, default=True
+ Whether to rebuild the model when it already was built.
**trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
@@ -317,6 +352,8 @@ def fit(
val_size=val_size,
X_val=X_val,
y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
random_state=random_state,
batch_size=batch_size,
shuffle=shuffle,
@@ -325,6 +362,8 @@ def fit(
lr_factor=lr_factor,
weight_decay=weight_decay,
dataloader_kwargs=dataloader_kwargs,
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
)
else:
@@ -361,13 +400,11 @@ def fit(
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.task_model.load_state_dict( # type: ignore
- checkpoint["state_dict"]
- )
+ self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore
return self
- def predict(self, X, device=None):
+ def predict(self, X, embeddings=None, device=None):
"""Predicts target values for the given input samples.
Parameters
@@ -386,37 +423,25 @@ def predict(self, X, device=None):
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
- cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
-
- # Move tensors to appropriate device
- if device is None:
- device = next(self.task_model.parameters()).device
- if isinstance(cat_tensors, list):
- cat_tensors = [tensor.to(device) for tensor in cat_tensors]
- else:
- cat_tensors = cat_tensors.to(device)
-
- if isinstance(num_tensors, list):
- num_tensors = [tensor.to(device) for tensor in num_tensors]
- else:
- num_tensors = num_tensors.to(device)
+ self.data_module.assign_predict_dataset(X, embeddings)
# Set model to evaluation mode
self.task_model.eval()
- # Perform inference
- with torch.no_grad():
- predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
+ # Perform inference using PyTorch Lightning's predict function
+ predictions_list = self.trainer.predict(self.task_model, self.data_module)
+
+ # Concatenate predictions from all batches
+ predictions = torch.cat(predictions_list, dim=0) # type: ignore
# Check if ensemble is used
- if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble
- # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim))
- predictions = predictions.mean(dim=1)
+ if getattr(self.base_model, "returns_ensemble", False): # If using ensemble
+ predictions = predictions.mean(dim=1) # Average over ensemble dimension
# Convert predictions to NumPy array and return
return predictions.cpu().numpy()
- def evaluate(self, X, y_true, metrics=None):
+ def evaluate(self, X, y_true, embeddings=None, metrics=None):
"""Evaluate the model on the given data using specified metrics.
Parameters
@@ -442,7 +467,7 @@ def evaluate(self, X, y_true, metrics=None):
metrics = {"Mean Squared Error": mean_squared_error}
# Generate predictions using the trained model
- predictions = self.predict(X)
+ predictions = self.predict(X, embeddings=embeddings)
# Initialize dictionary to store results
scores = {}
@@ -453,7 +478,7 @@ def evaluate(self, X, y_true, metrics=None):
return scores
- def score(self, X, y, metric=mean_squared_error):
+ def score(self, X, y, embeddings=None, metric=mean_squared_error):
"""Calculate the score of the model using the specified metric.
Parameters
@@ -470,15 +495,56 @@ def score(self, X, y, metric=mean_squared_error):
score : float
The score calculated using the specified metric.
"""
- predictions = self.predict(X)
+ predictions = self.predict(X, embeddings)
return metric(y, predictions)
+ def encode(self, X, embeddings=None, batch_size=64):
+ """
+ Encodes input data using the trained model's embedding layer.
+
+ Parameters
+ ----------
+ X : array-like or DataFrame
+ Input data to be encoded.
+ batch_size : int, optional, default=64
+ Batch size for encoding.
+
+ Returns
+ -------
+ torch.Tensor
+ Encoded representations of the input data.
+
+ Raises
+ ------
+ ValueError
+ If the model or data module is not fitted.
+ """
+ # Ensure model and data module are initialized
+ if self.task_model is None or self.data_module is None:
+ raise ValueError("The model or data module has not been fitted yet.")
+ encoded_dataset = self.data_module.preprocess_new_data(X, embeddings)
+
+ data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)
+
+ # Process data in batches
+ encoded_outputs = []
+ for batch in tqdm(data_loader):
+ embeddings = self.task_model.base_model.encode(batch) # Call your encode function
+ encoded_outputs.append(embeddings)
+
+ # Concatenate all encoded outputs
+ encoded_outputs = torch.cat(encoded_outputs, dim=0)
+
+ return encoded_outputs
+
def optimize_hparams(
self,
X,
y,
X_val=None,
y_val=None,
+ embeddings=None,
+ embeddings_val=None,
time=100,
max_epochs=200,
prune_by_epoch=True,
@@ -529,7 +595,15 @@ def optimize_hparams(
)
# Initial model fitting to get the baseline validation loss
- self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
+ self.fit(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
+ max_epochs=max_epochs,
+ )
best_val_loss = float("inf")
if X_val is not None and y_val is not None:
@@ -572,7 +646,16 @@ def _objective(hyperparams):
self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length]
# Build the model with updated hyperparameters
- self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs)
+ self.build_model(
+ X,
+ y,
+ X_val=X_val,
+ y_val=y_val,
+ embeddings=embeddings,
+ embeddings_val=embeddings_val,
+ lr=self.config.lr,
+ **optimize_kwargs,
+ )
# Dynamically set the early pruning threshold
if prune_by_epoch:
diff --git a/mambular/preprocessing/basis_expansion.py b/mambular/preprocessing/basis_expansion.py
index 59dab27f..8ee46ed4 100644
--- a/mambular/preprocessing/basis_expansion.py
+++ b/mambular/preprocessing/basis_expansion.py
@@ -1,8 +1,10 @@
import numpy as np
from scipy.interpolate import BSpline
from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import SplineTransformer
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn.utils.validation import check_array
class SplineExpansion(BaseEstimator, TransformerMixin):
@@ -41,6 +43,22 @@ def __init__(
if spline_implementation not in ["scipy", "sklearn"]:
raise ValueError("Invalid spline implementation. Choose 'scipy' or 'sklearn'.")
+ @staticmethod
+ def knot_identification_using_decision_tree(X, y, task="regression", n_knots=5):
+ # Use DecisionTreeClassifier for classification tasks
+ knots = []
+ if task == "classification":
+ tree = DecisionTreeClassifier(max_leaf_nodes=n_knots + 1)
+ elif task == "regression":
+ tree = DecisionTreeRegressor(max_leaf_nodes=n_knots + 1)
+ else:
+ raise ValueError("Invalid task type. Choose 'regression' or 'classification'.")
+ tree.fit(X, y)
+ # Extract thresholds from the decision tree
+ thresholds = tree.tree_.threshold[tree.tree_.threshold != -2] # type: ignore
+ knots.append(np.sort(thresholds))
+ return knots
+
def fit(self, X, y=None):
"""
Fit the preprocessor by determining the knot positions.
@@ -52,43 +70,58 @@ def fit(self, X, y=None):
Returns:
- self: Fitted preprocessor.
"""
- X = np.asarray(X)
+ if self.use_decision_tree and y is None:
+ raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.")
- if self.use_decision_tree:
- if y is None:
- raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.")
- y = np.asarray(y)
-
- self.knots = []
- for i in range(X.shape[1]):
- x_col = X[:, i].reshape(-1, 1)
-
- # Use DecisionTreeClassifier for classification tasks
- if self.task == "classification":
- tree = DecisionTreeClassifier(max_leaf_nodes=self.n_knots + 1)
- elif self.task == "regression":
- tree = DecisionTreeRegressor(max_leaf_nodes=self.n_knots + 1)
- else:
- raise ValueError("Invalid task type. Choose 'regression' or 'classification'.")
-
- tree.fit(x_col, y)
-
- # Extract thresholds from the decision tree
- thresholds = tree.tree_.threshold[tree.tree_.threshold != -2] # type: ignore
- self.knots.append(np.sort(thresholds))
- else:
- # Compute knots based on uniform spacing or quantile
- self.knots = []
- for i in range(X.shape[1]):
- if self.strategy == "quantile":
- # Use quantile to determine knot locations
- quantiles = np.linspace(0, 1, self.n_knots + 2)[1:-1]
- knots = np.quantile(X[:, i], quantiles)
- self.knots.append(knots)
- elif self.strategy == "uniform":
- # Use uniform spacing within the range of the feature
- knots = np.linspace(np.min(X[:, i]), np.max(X[:, i]), self.n_knots + 2)[1:-1]
- self.knots.append(knots)
+ self.knots = []
+ self.n_features_in_ = X.shape[1]
+
+ if self.use_decision_tree and self.spline_implementation == "scipy":
+ self.knots = self.knot_identification_using_decision_tree(X, y, self.task, self.n_knots)
+ self.fitted = True
+
+ elif self.spline_implementation == "scipy" and not self.use_decision_tree:
+ if self.strategy == "quantile":
+ # Use quantile to determine knot locations
+ quantiles = np.linspace(0, 1, self.n_knots + 2)[1:-1]
+ knots = np.quantile(X, quantiles)
+ self.knots.append(knots)
+ self.fitted = True
+ # print("Scipy spline implementation using quantile works in fit phase")
+ elif self.strategy == "uniform":
+ # Use uniform spacing within the range of the feature
+ knots = np.linspace(np.min(X), np.max(X), self.n_knots + 2)[1:-1]
+ self.knots.append(knots)
+ self.fitted = True
+ # print("Scipy spline implementation using uniform works in fit phase")
+
+ elif self.use_decision_tree and self.spline_implementation == "sklearn":
+ self.knots = self.knot_identification_using_decision_tree(X, y, self.task, self.n_knots)
+ knots = np.vstack(self.knots).T
+ self.transformer = SplineTransformer(
+ n_knots=self.n_knots, degree=self.degree, include_bias=False, knots=knots
+ )
+ self.transformer.fit(X)
+ self.fitted = True
+
+ elif self.spline_implementation == "sklearn" and not self.use_decision_tree:
+ if self.strategy == "quantile":
+ # print("Using sklearn spline transformer using quantile")
+ # print()
+ self.transformer = SplineTransformer(
+ n_knots=self.n_knots, degree=self.degree, include_bias=False, knots="quantile"
+ )
+ self.fitted = True
+ self.transformer.fit(X)
+
+ elif self.strategy == "uniform":
+ # print("Using sklearn spline transformer using uniform")
+ # print()
+ self.transformer = SplineTransformer(
+ n_knots=self.n_knots, degree=self.degree, include_bias=False, knots="uniform"
+ )
+ self.fitted = True
+ self.transformer.fit(X)
return self
@@ -105,43 +138,148 @@ def transform(self, X):
if self.knots is None:
raise ValueError("Knots have not been initialized. Please fit the preprocessor first.")
- X = np.asarray(X)
transformed_features = []
+ if self.fitted is False:
+ raise ValueError("Model has not been fitted. Please fit the model first.")
+
if self.spline_implementation == "scipy":
- for i in range(X.shape[1]):
- x_col = X[:, i]
- knots = self.knots[i] # type: ignore
+ # Extend the knots for boundary conditions
+ t = np.concatenate(([self.knots[0]] * self.degree, self.knots, [self.knots[-1]] * self.degree))
- # Extend the knots for boundary conditions
- t = np.concatenate(([knots[0]] * self.degree, knots, [knots[-1]] * self.degree))
+ # Create spline basis functions for this feature
+ spline_basis = [
+ BSpline.basis_element(t[j : j + self.degree + 2])(X) for j in range(len(t) - self.degree - 1)
+ ]
+ # Stack and append transformed features
+ transformed_features.append(np.vstack(spline_basis).T)
+ # Concatenate all transformed features
+ return np.hstack(transformed_features)
+ elif self.spline_implementation == "sklearn":
+ return self.transformer.transform(X)
- # Create spline basis functions for this feature
- spline_basis = [
- BSpline.basis_element(t[j : j + self.degree + 2])(x_col) for j in range(len(t) - self.degree - 1)
- ]
- # Stack and append transformed features
- transformed_features.append(np.vstack(spline_basis).T)
+def center_identification_using_decision_tree(X, y, task="regression", n_centers=5):
+ # Use DecisionTreeClassifier for classification tasks
+ centers = []
+ if task == "classification":
+ tree = DecisionTreeClassifier(max_leaf_nodes=n_centers + 1)
+ elif task == "regression":
+ tree = DecisionTreeRegressor(max_leaf_nodes=n_centers + 1)
+ else:
+ raise ValueError("Invalid task type. Choose 'regression' or 'classification'.")
+ tree.fit(X, y)
+ # Extract thresholds from the decision tree
+ thresholds = tree.tree_.threshold[tree.tree_.threshold != -2] # type: ignore
+ centers.append(np.sort(thresholds))
+ return centers
- # Concatenate all transformed features
- return np.hstack(transformed_features)
+
+class RBFExpansion(BaseEstimator, TransformerMixin):
+ def __init__(
+ self, n_centers=10, gamma: float = 1.0, use_decision_tree=True, task: str = "regression", strategy="uniform"
+ ):
+ """
+ Radial Basis Function Expansion.
+
+ Parameters:
+ - n_centers: Number of RBF centers.
+ - gamma: Width of the RBF kernel.
+ - use_decision_tree: If True, use a decision tree to determine RBF centers.
+ - task: Task type, 'regression' or 'classification'.
+ - strategy: If 'uniform', centers are uniformly spaced. If 'quantile', centers are
+ determined by data quantile.
+ """
+ self.n_centers = n_centers
+ self.gamma = gamma
+ self.use_decision_tree = use_decision_tree
+ self.strategy = strategy
+ self.task = task
+
+ if self.strategy not in ["uniform", "quantile"]:
+ raise ValueError("Invalid strategy. Choose 'uniform' or 'quantile'.")
+
+ def fit(self, X, y=None):
+ X = check_array(X)
+
+ if self.use_decision_tree and y is None:
+ raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.")
+
+ if self.use_decision_tree:
+ self.centers_ = center_identification_using_decision_tree(X, y, self.task, self.n_centers)
+ self.centers_ = np.vstack(self.centers_)
else:
- if self.use_decision_tree:
- knots = np.vstack(self.knots).T
- transformer = SplineTransformer(
- n_knots=self.n_knots, degree=self.degree, include_bias=False, knots=knots
- )
- else:
- if self.strategy == "quantile":
- transformer = SplineTransformer(
- n_knots=self.n_knots, degree=self.degree, include_bias=False, knots="quantile"
- )
- elif self.strategy == "uniform":
- transformer = SplineTransformer(
- n_knots=self.n_knots, degree=self.degree, include_bias=False, knots="uniform"
- )
- else:
- raise ValueError("Invalid strategy for knot location calculation. Choose 'quantile' or 'uniform'.")
-
- return transformer.fit_transform(X)
+ # Compute centers
+ if self.strategy == "quantile":
+ self.centers_ = np.percentile(X, np.linspace(0, 100, self.n_centers), axis=0)
+ elif self.strategy == "uniform":
+ self.centers_ = np.linspace(X.min(axis=0), X.max(axis=0), self.n_centers)
+
+ # Compute gamma if not provided
+ # if self.gamma is None:
+ # dists = pairwise_distances(self.centers_)
+ # self.gamma = 1 / (2 * np.mean(dists[dists > 0]) ** 2) # Mean pairwise distance
+ return self
+
+ def transform(self, X):
+ X = check_array(X)
+ transformed = []
+ self.centers_ = np.array(self.centers_)
+ for center in self.centers_.T:
+ rbf_features = np.exp(-self.gamma * (X - center) ** 2) # type: ignore
+ transformed.append(rbf_features)
+ return np.hstack(transformed)
+
+
+class SigmoidExpansion(BaseEstimator, TransformerMixin):
+ def __init__(
+ self, n_centers=10, scale: float = 1.0, use_decision_tree=True, task: str = "regression", strategy="uniform"
+ ):
+ """
+ Sigmoid Basis Expansion.
+
+ Parameters:
+ - n_centers: Number of sigmoid centers.
+ - scale: Scale parameter for sigmoid function.
+ - use_decision_tree: If True, use a decision tree to determine sigmoid centers.
+ - task: Task type, 'regression' or 'classification'.
+ - strategy: If 'uniform', centers are uniformly spaced. If 'quantile', centers are
+ determined by data quantile.
+ """
+ self.n_centers = n_centers
+ self.scale = scale
+ self.use_decision_tree = use_decision_tree
+ self.strategy = strategy
+ self.task = task
+
+ def fit(self, X, y=None):
+ X = check_array(X)
+
+ if self.use_decision_tree and y is None:
+ raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.")
+
+ if self.use_decision_tree:
+ self.centers_ = center_identification_using_decision_tree(X, y, self.task, self.n_centers)
+ self.centers_ = np.vstack(self.centers_)
+ else:
+ # Compute centers
+ if self.strategy == "quantile":
+ self.centers_ = np.percentile(X, np.linspace(0, 100, self.n_centers), axis=0)
+ elif self.strategy == "uniform":
+ self.centers_ = np.linspace(X.min(axis=0), X.max(axis=0), self.n_centers)
+
+ # Compute gamma if not provided
+ # if self.gamma is None:
+ # dists = pairwise_distances(self.centers_)
+ # self.gamma = 1 / (2 * np.mean(dists[dists > 0]) ** 2) # Mean pairwise distance
+ return self
+
+ def transform(self, X):
+ X = check_array(X)
+ transformed = []
+
+ self.centers_ = np.array(self.centers_)
+ for center in self.centers_.T:
+ sigmoid_features = 1 / (1 + np.exp(-(X - center) / self.scale))
+ transformed.append(sigmoid_features)
+ return np.hstack(transformed)
diff --git a/mambular/preprocessing/ple_encoding.py b/mambular/preprocessing/ple_encoding.py
index a75a47fa..3a70f248 100644
--- a/mambular/preprocessing/ple_encoding.py
+++ b/mambular/preprocessing/ple_encoding.py
@@ -74,6 +74,7 @@ def __init__(self, n_bins=20, tree_params={}, task="regression", conditions=None
self.pattern = r"-?\d+\.?\d*[eE]?[+-]?\d*"
def fit(self, feature, target):
+ self.n_features_in_ = 1
if self.task == "regression":
dt = DecisionTreeRegressor(max_leaf_nodes=self.n_bins)
elif self.task == "classification":
@@ -84,6 +85,7 @@ def fit(self, feature, target):
dt.fit(feature, target)
self.conditions = tree_to_code(dt, ["feature"])
+ # self.fitted = True
return self
def transform(self, feature):
diff --git a/mambular/preprocessing/prepro_utils.py b/mambular/preprocessing/prepro_utils.py
index 704091f7..4c9e9645 100644
--- a/mambular/preprocessing/prepro_utils.py
+++ b/mambular/preprocessing/prepro_utils.py
@@ -10,6 +10,7 @@ def __init__(self, bins):
def fit(self, X, y=None):
# Fit doesn't need to do anything as we are directly using provided bins
+ self.n_features_in_ = 1
return self
def transform(self, X):
@@ -172,6 +173,7 @@ def fit(self, X, y=None):
Returns:
self: Returns the instance itself.
"""
+ self.n_features_in_ = 1
return self
def transform(self, X):
@@ -203,7 +205,58 @@ class ToFloatTransformer(TransformerMixin, BaseEstimator):
"""A transformer that converts input data to float type."""
def fit(self, X, y=None):
+ self.n_features_in_ = 1
return self
def transform(self, X):
return X.astype(float)
+
+
+class LanguageEmbeddingTransformer(TransformerMixin, BaseEstimator):
+ """A transformer that encodes categorical text features into embeddings using a pre-trained language model."""
+
+ def __init__(self, model_name="paraphrase-MiniLM-L3-v2", model=None):
+ """
+ Initializes the transformer with a language embedding model.
+
+ Parameters:
+ - model_name (str): The name of the SentenceTransformer model to use (if model is None).
+ - model (object, optional): A preloaded SentenceTransformer model instance.
+ """
+ self.model_name = model_name
+ self.model = model # Allow user to pass a preloaded model
+
+ if self.model is None:
+ try:
+ from sentence_transformers import SentenceTransformer
+
+ self.model = SentenceTransformer(model_name)
+ except ImportError as e:
+ raise ImportError(
+ "sentence-transformers is not installed. Install it via `pip install sentence-transformers` or provide a preloaded model."
+ ) from e
+
+ def fit(self, X, y=None):
+ """Fit method (not required for a transformer but included for compatibility)."""
+ self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1
+ return self
+
+ def transform(self, X):
+ """
+ Transforms input categorical text features into numerical embeddings.
+
+ Parameters:
+ - X: A 1D or 2D array-like of categorical text features.
+
+ Returns:
+ - A 2D numpy array with embeddings for each text input.
+ """
+ if isinstance(X, np.ndarray):
+ X = X.flatten().astype(str).tolist() # Convert to a list of strings if passed as an array
+ elif isinstance(X, list):
+ X = [str(x) for x in X] # Ensure everything is a string
+
+ if self.model is None:
+ raise ValueError("Model is not initialized. Ensure that the model is properly loaded.")
+ embeddings = self.model.encode(X, convert_to_numpy=True) # Get sentence embeddings
+ return embeddings
diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py
index c3f9b6d7..44c21193 100644
--- a/mambular/preprocessing/preprocessor.py
+++ b/mambular/preprocessing/preprocessor.py
@@ -17,9 +17,16 @@
)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-from .basis_expansion import SplineExpansion
+from .basis_expansion import RBFExpansion, SigmoidExpansion, SplineExpansion
from .ple_encoding import PLE
-from .prepro_utils import ContinuousOrdinalEncoder, CustomBinner, NoTransformer, OneHotFromOrdinal, ToFloatTransformer
+from .prepro_utils import (
+ ContinuousOrdinalEncoder,
+ CustomBinner,
+ LanguageEmbeddingTransformer,
+ NoTransformer,
+ OneHotFromOrdinal,
+ ToFloatTransformer,
+)
class Preprocessor:
@@ -33,12 +40,20 @@ class Preprocessor:
Parameters
----------
+ feature_preprocessing: dict or None
+ Dictionary mapping column names to preprocessing techniques. Example:
+ {
+ "num_feature1": "minmax",
+ "num_feature2": "ple",
+ "cat_feature1": "one-hot",
+ "cat_feature2": "int"
+ }
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'.
numerical_preprocessing : str, default="ple"
The preprocessing strategy for numerical features. Valid options are
- 'ple', 'binning', 'one-hot', 'standardization', 'min-max', 'quantile', 'polynomial', 'robust',
+ 'ple', 'binning', 'one-hot', 'standardization', 'min-max', 'quantile', 'polynomial', 'robust', 'rbf', 'sigmoid'.
'splines', 'box-cox', 'yeo-johnson' and None
categorical_preprocessing : str, default="int"
The preprocessing strategy for categorical features. Valid options are
@@ -60,6 +75,9 @@ class Preprocessor:
treat_all_integers_as_numerical : bool, default=False
If True, all integer columns will be treated as numerical, regardless
of their unique value count or proportion.
+ scaling_strategy : str, default="minmax"
+ The scaling strategy to use for numerical features before applying PLE, Splines, RBF or Sigmoid.
+ Options include 'standardization', 'minmax', 'none'.
degree : int, default=3
The degree of the polynomial features to be used in preprocessing. It also affects the degree of
splines if splines are used.
@@ -67,7 +85,7 @@ class Preprocessor:
The number of knots to be used in spline transformations.
use_decision_tree_knots : bool, default=True
If True, uses decision tree regression to determine optimal knot positions for splines.
- knots_strategy : str, default="uniform"
+ knots_strategy : str, default="quantile"
Defines the strategy for determining knot positions in spline transformations
if `use_decision_tree_knots` is False. Options include 'uniform', 'quantile'.
spline_implementation : str, default="sklearn"
@@ -84,6 +102,7 @@ class Preprocessor:
def __init__(
self,
+ feature_preprocessing=None,
n_bins=64,
numerical_preprocessing="ple",
categorical_preprocessing="int",
@@ -93,6 +112,7 @@ def __init__(
cat_cutoff=0.03,
treat_all_integers_as_numerical=False,
degree=3,
+ scaling_strategy="minmax",
n_knots=64,
use_decision_tree_knots=True,
knots_strategy="uniform",
@@ -117,17 +137,28 @@ def __init__(
"splines",
"box-cox",
"yeo-johnson",
+ "rbf",
+ "sigmoid",
"none",
]:
raise ValueError(
"Invalid numerical_preprocessing value. Supported values are 'ple', 'binning', 'box-cox', \
- 'one-hot', 'standardization', 'quantile', 'polynomial', 'splines', 'minmax' , 'robust' or 'None'."
+ 'one-hot', 'standardization', 'quantile', 'polynomial', 'splines', 'minmax' , 'robust',\
+ 'rbf', 'sigmoid', or 'None'."
)
- if self.categorical_preprocessing not in ["int", "one-hot", "none"]:
- raise ValueError("invalid categorical_preprocessing value. Supported values are 'int' and 'one-hot'")
+ if self.categorical_preprocessing not in [
+ "int",
+ "one-hot",
+ "pretrained",
+ "none",
+ ]:
+ raise ValueError(
+ "invalid categorical_preprocessing value. Supported values are 'int', 'pretrained', 'none' and 'one-hot'"
+ )
self.use_decision_tree_bins = use_decision_tree_bins
+ self.feature_preprocessing = feature_preprocessing or {}
self.column_transformer = None
self.fitted = False
self.binning_strategy = binning_strategy
@@ -135,6 +166,7 @@ def __init__(
self.cat_cutoff = cat_cutoff
self.treat_all_integers_as_numerical = treat_all_integers_as_numerical
self.degree = degree
+ self.scaling_strategy = scaling_strategy
self.n_knots = n_knots
self.use_decision_tree_knots = use_decision_tree_knots
self.knots_strategy = knots_strategy
@@ -163,6 +195,7 @@ def get_params(self, deep=True):
"cat_cutoff": self.cat_cutoff,
"treat_all_integers_as_numerical": self.treat_all_integers_as_numerical,
"degree": self.degree,
+ "scaling_strategy": self.scaling_strategy,
"n_knots": self.n_knots,
"use_decision_tree_knots": self.use_decision_tree_knots,
"knots_strategy": self.knots_strategy,
@@ -227,7 +260,19 @@ def _detect_column_types(self, X):
return numerical_features, categorical_features
- def fit(self, X, y=None):
+ def _fit_embeddings(self, embeddings):
+ if embeddings is not None:
+ self.embeddings = True
+ self.embedding_dimensions = {}
+ if isinstance(embeddings, np.ndarray):
+ self.embedding_dimensions["embeddings_1"] = embeddings.shape[1]
+ elif isinstance(embeddings, list) and all(isinstance(e, np.ndarray) for e in embeddings):
+ for idx, e in enumerate(embeddings):
+ self.embedding_dimensions[f"embedding_{idx + 1}"] = e.shape[1]
+ else:
+ self.embeddings = False
+
+ def fit(self, X, y=None, embeddings=None):
"""Fits the preprocessor to the data by identifying feature types and configuring the appropriate
transformations for each feature. It sets up a column transformer with a pipeline of transformations for
numerical and categorical features based on the specified preprocessing strategy.
@@ -246,11 +291,15 @@ def fit(self, X, y=None):
if isinstance(X, dict):
X = pd.DataFrame(X)
+ self._fit_embeddings(embeddings)
+
numerical_features, categorical_features = self._detect_column_types(X)
transformers = []
if numerical_features:
for feature in numerical_features:
+ feature_preprocessing = self.feature_preprocessing.get(feature, self.numerical_preprocessing)
+
# extended the annotation list if new transformer is added, either from sklearn or custom
numeric_transformer_steps: list[
tuple[
@@ -268,10 +317,12 @@ def fit(self, X, y=None):
| PLE
| PowerTransformer
| NoTransformer
- | SplineExpansion,
+ | SplineExpansion
+ | RBFExpansion
+ | SigmoidExpansion,
]
] = [("imputer", SimpleImputer(strategy="mean"))]
- if self.numerical_preprocessing in ["binning", "one-hot"]:
+ if feature_preprocessing in ["binning", "one-hot"]:
bins = (
self._get_decision_tree_bins(X[[feature]], y, [feature])
if self.use_decision_tree_bins
@@ -301,20 +352,20 @@ def fit(self, X, y=None):
]
)
- if self.numerical_preprocessing == "one-hot":
+ if feature_preprocessing == "one-hot":
numeric_transformer_steps.extend(
[
("onehot_from_ordinal", OneHotFromOrdinal()),
]
)
- elif self.numerical_preprocessing == "standardization":
+ elif feature_preprocessing == "standardization":
numeric_transformer_steps.append(("scaler", StandardScaler()))
- elif self.numerical_preprocessing == "minmax":
+ elif feature_preprocessing == "minmax":
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
- elif self.numerical_preprocessing == "quantile":
+ elif feature_preprocessing == "quantile":
numeric_transformer_steps.append(
(
"quantile",
@@ -322,8 +373,11 @@ def fit(self, X, y=None):
)
)
- elif self.numerical_preprocessing == "polynomial":
- numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif feature_preprocessing == "polynomial":
+ if self.scaling_strategy == "standardization":
+ numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif self.scaling_strategy == "minmax":
+ numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
numeric_transformer_steps.append(
(
"polynomial",
@@ -331,11 +385,14 @@ def fit(self, X, y=None):
)
)
- elif self.numerical_preprocessing == "robust":
+ elif feature_preprocessing == "robust":
numeric_transformer_steps.append(("robust", RobustScaler()))
- elif self.numerical_preprocessing == "splines":
- numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif feature_preprocessing == "splines":
+ if self.scaling_strategy == "standardization":
+ numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif self.scaling_strategy == "minmax":
+ numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
numeric_transformer_steps.append(
(
"splines",
@@ -350,11 +407,51 @@ def fit(self, X, y=None):
),
)
- elif self.numerical_preprocessing == "ple":
+ elif feature_preprocessing == "rbf":
+ if self.scaling_strategy == "standardization":
+ numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif self.scaling_strategy == "minmax":
+ numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
+ numeric_transformer_steps.append(
+ (
+ "rbf",
+ RBFExpansion(
+ n_centers=self.n_knots,
+ use_decision_tree=self.use_decision_tree_knots,
+ strategy=self.knots_strategy,
+ task=self.task,
+ ),
+ )
+ )
+
+ elif feature_preprocessing == "sigmoid":
+ if self.scaling_strategy == "standardization":
+ numeric_transformer_steps.append(("scaler", StandardScaler()))
+ elif self.scaling_strategy == "minmax":
+ numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
+ numeric_transformer_steps.append(
+ (
+ "sigmoid",
+ SigmoidExpansion(
+ n_centers=self.n_knots,
+ use_decision_tree=self.use_decision_tree_knots,
+ strategy=self.knots_strategy,
+ task=self.task,
+ ),
+ )
+ )
+
+ elif feature_preprocessing == "ple":
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task)))
- elif self.numerical_preprocessing == "box-cox":
+ elif feature_preprocessing == "box-cox":
+ numeric_transformer_steps.append(
+ ("minmax", MinMaxScaler(feature_range=(1e-03, 1))) # type: ignore
+ )
+ numeric_transformer_steps.append(
+ ("check_positive", MinMaxScaler(feature_range=(1e-3, 1))) # type: ignore
+ )
numeric_transformer_steps.append(
(
"box-cox",
@@ -362,7 +459,7 @@ def fit(self, X, y=None):
)
)
- elif self.numerical_preprocessing == "yeo-johnson":
+ elif feature_preprocessing == "yeo-johnson":
numeric_transformer_steps.append(
(
"yeo-johnson",
@@ -370,7 +467,7 @@ def fit(self, X, y=None):
)
)
- elif self.numerical_preprocessing == "none":
+ elif feature_preprocessing == "none":
numeric_transformer_steps.append(
(
"none",
@@ -384,7 +481,8 @@ def fit(self, X, y=None):
if categorical_features:
for feature in categorical_features:
- if self.categorical_preprocessing == "int":
+ feature_preprocessing = self.feature_preprocessing.get(feature, self.categorical_preprocessing)
+ if feature_preprocessing == "int":
# Use ContinuousOrdinalEncoder for "int"
categorical_transformer = Pipeline(
[
@@ -392,7 +490,7 @@ def fit(self, X, y=None):
("continuous_ordinal", ContinuousOrdinalEncoder()),
]
)
- elif self.categorical_preprocessing == "one-hot":
+ elif feature_preprocessing == "one-hot":
# Use OneHotEncoder for "one-hot"
categorical_transformer = Pipeline(
[
@@ -402,7 +500,7 @@ def fit(self, X, y=None):
]
)
- elif self.categorical_preprocessing == "none":
+ elif feature_preprocessing == "none":
# Use OneHotEncoder for "one-hot"
categorical_transformer = Pipeline(
[
@@ -410,8 +508,15 @@ def fit(self, X, y=None):
("none", NoTransformer()),
]
)
+ elif feature_preprocessing == "pretrained":
+ categorical_transformer = Pipeline(
+ [
+ ("imputer", SimpleImputer(strategy="most_frequent")),
+ ("pretrained", LanguageEmbeddingTransformer()),
+ ]
+ )
else:
- raise ValueError(f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}")
+ raise ValueError(f"Unknown categorical_preprocessing type: {feature_preprocessing}")
# Append the transformer for the current categorical feature
transformers.append((f"cat_{feature}", categorical_transformer, [feature]))
@@ -451,7 +556,7 @@ def _get_decision_tree_bins(self, X, y, numerical_features):
bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])))
return bins
- def transform(self, X):
+ def transform(self, X, embeddings=None):
"""Transforms the input data using the preconfigured column transformer and converts the output into a
dictionary format with keys corresponding to transformed feature names and values as arrays of transformed data.
@@ -466,8 +571,7 @@ def transform(self, X):
Parameters
----------
X (DataFrame): The input data to be transformed.
- X (DataFrame): The input data to be transformed.
-
+ embeddings (np.array or list of np.arrays, optional): The embedding data to include in the transformation.
Returns
-------
@@ -482,6 +586,30 @@ def transform(self, X):
# Now let's convert this into a dictionary of arrays, one per column
transformed_dict = self._split_transformed_output(X, transformed_X)
+ if embeddings is not None:
+ if not self.embeddings:
+ raise ValueError("self.embeddings should be True but is not.")
+
+ if isinstance(embeddings, np.ndarray):
+ if self.embedding_dimensions["embedding_1"] != embeddings.shape[1]:
+ raise ValueError(
+ f"Expected embedding dimension {self.embedding_dimensions['embedding_1']}, "
+ f"but got {embeddings.shape[1]}"
+ )
+ transformed_dict["embedding_1"] = embeddings.astype(np.float32)
+ elif isinstance(embeddings, list) and all(isinstance(e, np.ndarray) for e in embeddings):
+ for idx, e in enumerate(embeddings):
+ key = f"embedding_{idx + 1}"
+ if self.embedding_dimensions[key] != e.shape[1]:
+ raise ValueError(
+ f"Expected embedding dimension {self.embedding_dimensions[key]} for {key}, but got {e.shape[1]}"
+ )
+ transformed_dict[key] = e.astype(np.float32)
+ else:
+ if self.embeddings is not False:
+ raise ValueError("self.embeddings should be False when embeddings are None.")
+ self.embeddings = False
+
return transformed_dict
def _split_transformed_output(self, X, transformed_X):
@@ -520,7 +648,7 @@ def _split_transformed_output(self, X, transformed_X):
start = end
return transformed_dict
- def fit_transform(self, X, y=None):
+ def fit_transform(self, X, y=None, embeddings=None):
"""Fits the preprocessor to the data and then transforms the data using the fitted preprocessing pipelines. This
is a convenience method that combines `fit` and `transform`.
@@ -535,9 +663,9 @@ def fit_transform(self, X, y=None):
dict: A dictionary with the transformed data, where keys are the base feature names and
values are the transformed features as arrays.
"""
- self.fit(X, y)
+ self.fit(X, y, embeddings)
self.fitted = True
- return self.transform(X)
+ return self.transform(X, embeddings)
def get_feature_info(self, verbose=True):
"""Retrieves information about how features are encoded within the model's preprocessor. This method identifies
@@ -547,24 +675,34 @@ def get_feature_info(self, verbose=True):
This method should only be called after the preprocessor has been fitted, as it relies on the structure and
configuration of the `column_transformer` attribute.
-
Raises
------
RuntimeError: If the `column_transformer` is not yet fitted, indicating that the preprocessor must be
fitted before invoking this method.
-
Returns
-------
- tuple of (dict, dict):
+ tuple of (dict, dict, dict):
- The first dictionary maps feature names to their respective number of bins or categories if they are
processed using discretization or ordinal encoding.
- The second dictionary includes feature names with other encoding details, such as the dimension of
features after encoding transformations (e.g., one-hot encoding dimensions).
+ - The third dictionary includes feature information for embeddings if available.
"""
numerical_feature_info = {}
categorical_feature_info = {}
+ if self.embeddings:
+ embedding_feature_info = {}
+ for key, dim in self.embedding_dimensions.items():
+ embedding_feature_info[key] = {
+ "preprocessing": None,
+ "dimension": dim,
+ "categories": None,
+ }
+ else:
+ embedding_feature_info = {}
+
if not self.column_transformer:
raise RuntimeError("The preprocessor has not been fitted yet.")
@@ -576,12 +714,10 @@ def get_feature_info(self, verbose=True):
steps = [step[0] for step in transformer_pipeline.steps]
for feature_name in columns:
- # Initialize common fields
preprocessing_type = " -> ".join(steps)
dimension = None
categories = None
- # Numerical features
if "discretizer" in steps or any(
step in steps
for step in [
@@ -590,27 +726,26 @@ def get_feature_info(self, verbose=True):
"quantile",
"polynomial",
"splines",
+ "box-cox",
]
):
last_step = transformer_pipeline.steps[-1][1]
if hasattr(last_step, "transform"):
- # Single-column input for dimension check
- dummy_input = np.zeros((1, 1))
+ dummy_input = np.zeros((1, 1)) + 1e-05
transformed_feature = last_step.transform(dummy_input)
dimension = transformed_feature.shape[1]
numerical_feature_info[feature_name] = {
"preprocessing": preprocessing_type,
"dimension": dimension,
- "categories": None, # Numerical features don't have categories
+ "categories": None,
}
if verbose:
print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}")
- # Categorical features
elif "continuous_ordinal" in steps:
step = transformer_pipeline.named_steps["continuous_ordinal"]
categories = len(step.mapping_[columns.index(feature_name)])
- dimension = 1 # Ordinal encoding always outputs one dimension
+ dimension = 1
categorical_feature_info[feature_name] = {
"preprocessing": preprocessing_type,
"dimension": dimension,
@@ -625,7 +760,7 @@ def get_feature_info(self, verbose=True):
step = transformer_pipeline.named_steps["onehot"]
if hasattr(step, "categories_"):
categories = sum(len(cat) for cat in step.categories_)
- dimension = categories # One-hot encoding expands into multiple dimensions
+ dimension = categories
categorical_feature_info[feature_name] = {
"preprocessing": preprocessing_type,
"dimension": dimension,
@@ -636,7 +771,6 @@ def get_feature_info(self, verbose=True):
f"Categorical Feature (One-Hot): {feature_name}, Info: {categorical_feature_info[feature_name]}"
)
- # Fallback for other transformations
else:
last_step = transformer_pipeline.steps[-1][1]
if hasattr(last_step, "transform"):
@@ -647,13 +781,13 @@ def get_feature_info(self, verbose=True):
categorical_feature_info[feature_name] = {
"preprocessing": preprocessing_type,
"dimension": dimension,
- "categories": None, # Categories not defined for unknown categorical transformations
+ "categories": None,
}
else:
numerical_feature_info[feature_name] = {
"preprocessing": preprocessing_type,
"dimension": dimension,
- "categories": None, # Numerical features don't have categories
+ "categories": None,
}
if verbose:
print(f"Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}")
@@ -661,4 +795,9 @@ def get_feature_info(self, verbose=True):
if verbose:
print("-" * 50)
- return numerical_feature_info, categorical_feature_info
+ if verbose and self.embeddings:
+ print("Embeddings:")
+ for key, value in embedding_feature_info.items():
+ print(f" Feature: {key}, Dimension: {value['dimension']}")
+
+ return numerical_feature_info, categorical_feature_info, embedding_feature_info
diff --git a/mambular/utils/distributions.py b/mambular/utils/distributions.py
index 374d0101..75395ce9 100644
--- a/mambular/utils/distributions.py
+++ b/mambular/utils/distributions.py
@@ -116,7 +116,9 @@ def forward(self, predictions):
"""
transformed_params = []
for idx, param_name in enumerate(self.param_names):
- transform_func = self.get_transform(getattr(self, f"{param_name}_transform", "none"))
+ transform_func = self.get_transform(
+ getattr(self, f"{param_name}_transform", "none")
+ )
transformed_params.append(
transform_func(predictions[:, idx]).unsqueeze( # type: ignore
1
@@ -153,7 +155,9 @@ def __init__(self, name="Normal", mean_transform="none", var_transform="positive
def compute_loss(self, predictions, y_true):
mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
- variance = self.variance_transform(predictions[:, self.param_names.index("variance")])
+ variance = self.variance_transform(
+ predictions[:, self.param_names.index("variance")]
+ )
normal_dist = dist.Normal(mean, variance)
@@ -167,10 +171,14 @@ def evaluate_nll(self, y_true, y_pred):
y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
- mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")])
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
+ )
rmse = np.sqrt(mse_loss.detach().numpy())
mae = (
- torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")])
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("mean")]
+ )
.detach()
.numpy()
)
@@ -228,7 +236,9 @@ def evaluate_nll(self, y_true, y_pred):
.detach()
.numpy() # type: ignore
) # type: ignore
- poisson_deviance = 2 * torch.sum(y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate))
+ poisson_deviance = 2 * torch.sum(
+ y_true_tensor * torch.log(y_true_tensor / rate) - (y_true_tensor - rate)
+ )
metrics["mse"] = mse_loss.detach().numpy()
metrics["mae"] = mae
@@ -367,7 +377,9 @@ class GammaDistribution(BaseDistribution):
rate_transform (str or callable): Transformation for the rate parameter to ensure it remains positive.
"""
- def __init__(self, name="Gamma", shape_transform="positive", rate_transform="positive"):
+ def __init__(
+ self, name="Gamma", shape_transform="positive", rate_transform="positive"
+ ):
param_names = ["shape", "rate"]
super().__init__(name, param_names)
@@ -434,10 +446,16 @@ def evaluate_nll(self, y_true, y_pred):
y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
- mse_loss = torch.nn.functional.mse_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")])
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
+ )
rmse = np.sqrt(mse_loss.detach().numpy())
mae = (
- torch.nn.functional.l1_loss(y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]).detach().numpy()
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("loc")]
+ )
+ .detach()
+ .numpy()
)
metrics["mse"] = mse_loss.detach().numpy()
@@ -478,7 +496,9 @@ def __init__(
def compute_loss(self, predictions, y_true):
# Apply transformations to ensure mean and dispersion parameters are positive
mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
- dispersion = self.dispersion_transform(predictions[:, self.param_names.index("dispersion")])
+ dispersion = self.dispersion_transform(
+ predictions[:, self.param_names.index("dispersion")]
+ )
# Calculate the probability (p) and number of successes (r) from mean and dispersion
# These calculations follow from the mean and variance of the negative binomial distribution
@@ -574,3 +594,77 @@ def compute_loss(self, predictions, y_true):
# Sum losses across quantiles and compute mean
loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
return loss
+
+
+class JohnsonSuDistribution(BaseDistribution):
+ """
+ Represents a Johnson's SU distribution with parameters for skewness, shape, location, and scale.
+
+ Parameters
+ ----------
+ name (str): The name of the distribution. Defaults to "JohnsonSu".
+ skew_transform (str or callable): The transformation for the skewness parameter. Defaults to "none".
+ shape_transform (str or callable): The transformation for the shape parameter. Defaults to "positive".
+ loc_transform (str or callable): The transformation for the location parameter. Defaults to "none".
+ scale_transform (str or callable): The transformation for the scale parameter. Defaults to "positive".
+ """
+
+ def __init__(
+ self,
+ name="JohnsonSu",
+ skew_transform="none",
+ shape_transform="positive",
+ loc_transform="none",
+ scale_transform="positive",
+ ):
+ param_names = ["skew", "shape", "location", "scale"]
+ super().__init__(name, param_names)
+
+ self.skew_transform = self.get_transform(skew_transform)
+ self.shape_transform = self.get_transform(shape_transform)
+ self.loc_transform = self.get_transform(loc_transform)
+ self.scale_transform = self.get_transform(scale_transform)
+
+ def log_prob(self, x, skew, shape, loc, scale):
+ """
+ Compute the log probability density of the Johnson's SU distribution.
+ """
+ z = skew + shape * torch.asinh((x - loc) / scale)
+ log_pdf = (
+ torch.log(shape / (scale * np.sqrt(2 * np.pi)))
+ - 0.5 * z**2
+ - 0.5 * torch.log(1 + ((x - loc) / scale) ** 2)
+ )
+ return log_pdf
+
+ def compute_loss(self, predictions, y_true):
+ skew = self.skew_transform(predictions[:, self.param_names.index("skew")])
+ shape = self.shape_transform(predictions[:, self.param_names.index("shape")])
+ loc = self.loc_transform(predictions[:, self.param_names.index("location")])
+ scale = self.scale_transform(predictions[:, self.param_names.index("scale")])
+
+ log_probs = self.log_prob(y_true, skew, shape, loc, scale)
+ nll = -log_probs.mean()
+ return nll
+
+ def evaluate_nll(self, y_true, y_pred):
+ metrics = super().evaluate_nll(y_true, y_pred)
+
+ y_true_tensor = torch.tensor(y_true, dtype=torch.float32)
+ y_pred_tensor = torch.tensor(y_pred, dtype=torch.float32)
+
+ mse_loss = torch.nn.functional.mse_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
+ )
+ rmse = np.sqrt(mse_loss.detach().numpy())
+ mae = (
+ torch.nn.functional.l1_loss(
+ y_true_tensor, y_pred_tensor[:, self.param_names.index("location")]
+ )
+ .detach()
+ .numpy()
+ )
+
+ metrics.update({"mse": mse_loss.detach().numpy(), "mae": mae, "rmse": rmse})
+
+ return metrics
diff --git a/mambular/utils/get_feature_dimensions.py b/mambular/utils/get_feature_dimensions.py
index 7ad000d3..b72980bc 100644
--- a/mambular/utils/get_feature_dimensions.py
+++ b/mambular/utils/get_feature_dimensions.py
@@ -1,8 +1,10 @@
-def get_feature_dimensions(num_feature_info, cat_feature_info):
+def get_feature_dimensions(num_feature_info, cat_feature_info, embedding_info):
input_dim = 0
- for feature_name, feature_info in num_feature_info.items():
+ for _, feature_info in num_feature_info.items():
input_dim += feature_info["dimension"]
- for feature_name, feature_info in cat_feature_info.items():
+ for _, feature_info in cat_feature_info.items():
+ input_dim += feature_info["dimension"]
+ for _, feature_info in embedding_info.items():
input_dim += feature_info["dimension"]
return input_dim
diff --git a/poetry.lock b/poetry.lock
index c0682a87..a620abea 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,13 +2,13 @@
[[package]]
name = "accelerate"
-version = "1.2.1"
+version = "1.3.0"
description = "Accelerate"
optional = false
python-versions = ">=3.9.0"
files = [
- {file = "accelerate-1.2.1-py3-none-any.whl", hash = "sha256:be1cbb958cf837e7cdfbde46b812964b1b8ae94c9c7d94d921540beafcee8ddf"},
- {file = "accelerate-1.2.1.tar.gz", hash = "sha256:03e161fc69d495daf2b9b5c8d5b43d06e2145520c04727b5bda56d49f1a43ab5"},
+ {file = "accelerate-1.3.0-py3-none-any.whl", hash = "sha256:5788d9e6a7a9f80fed665cf09681c4dddd9dc056bea656db4140ffc285ce423e"},
+ {file = "accelerate-1.3.0.tar.gz", hash = "sha256:518631c0adb80bd3d42fb29e7e2dc2256bcd7c786b0ba9119bbaa08611b36d9c"},
]
[package.dependencies]
@@ -18,7 +18,7 @@ packaging = ">=20.0"
psutil = "*"
pyyaml = "*"
safetensors = ">=0.4.3"
-torch = ">=1.10.0"
+torch = ">=2.0.0"
[package.extras]
deepspeed = ["deepspeed"]
@@ -33,98 +33,103 @@ testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized",
[[package]]
name = "aiohappyeyeballs"
-version = "2.4.4"
+version = "2.4.6"
description = "Happy Eyeballs for asyncio"
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
files = [
- {file = "aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8"},
- {file = "aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745"},
+ {file = "aiohappyeyeballs-2.4.6-py3-none-any.whl", hash = "sha256:147ec992cf873d74f5062644332c539fcd42956dc69453fe5204195e560517e1"},
+ {file = "aiohappyeyeballs-2.4.6.tar.gz", hash = "sha256:9b05052f9042985d32ecbe4b59a77ae19c006a78f1344d7fdad69d28ded3d0b0"},
]
[[package]]
name = "aiohttp"
-version = "3.11.11"
+version = "3.11.12"
description = "Async http client/server framework (asyncio)"
optional = false
python-versions = ">=3.9"
files = [
- {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"},
- {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"},
- {file = "aiohttp-3.11.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:731468f555656767cda219ab42e033355fe48c85fbe3ba83a349631541715ba2"},
- {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb23d8bb86282b342481cad4370ea0853a39e4a32a0042bb52ca6bdde132df43"},
- {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f047569d655f81cb70ea5be942ee5d4421b6219c3f05d131f64088c73bb0917f"},
- {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd7659baae9ccf94ae5fe8bfaa2c7bc2e94d24611528395ce88d009107e00c6d"},
- {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af01e42ad87ae24932138f154105e88da13ce7d202a6de93fafdafb2883a00ef"},
- {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5854be2f3e5a729800bac57a8d76af464e160f19676ab6aea74bde18ad19d438"},
- {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6526e5fb4e14f4bbf30411216780c9967c20c5a55f2f51d3abd6de68320cc2f3"},
- {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:85992ee30a31835fc482468637b3e5bd085fa8fe9392ba0bdcbdc1ef5e9e3c55"},
- {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:88a12ad8ccf325a8a5ed80e6d7c3bdc247d66175afedbe104ee2aaca72960d8e"},
- {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:0a6d3fbf2232e3a08c41eca81ae4f1dff3d8f1a30bae415ebe0af2d2458b8a33"},
- {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84a585799c58b795573c7fa9b84c455adf3e1d72f19a2bf498b54a95ae0d194c"},
- {file = "aiohttp-3.11.11-cp310-cp310-win32.whl", hash = "sha256:bfde76a8f430cf5c5584553adf9926534352251d379dcb266ad2b93c54a29745"},
- {file = "aiohttp-3.11.11-cp310-cp310-win_amd64.whl", hash = "sha256:0fd82b8e9c383af11d2b26f27a478640b6b83d669440c0a71481f7c865a51da9"},
- {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ba74ec819177af1ef7f59063c6d35a214a8fde6f987f7661f4f0eecc468a8f76"},
- {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4af57160800b7a815f3fe0eba9b46bf28aafc195555f1824555fa2cfab6c1538"},
- {file = "aiohttp-3.11.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffa336210cf9cd8ed117011085817d00abe4c08f99968deef0013ea283547204"},
- {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81b8fe282183e4a3c7a1b72f5ade1094ed1c6345a8f153506d114af5bf8accd9"},
- {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af41686ccec6a0f2bdc66686dc0f403c41ac2089f80e2214a0f82d001052c03"},
- {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70d1f9dde0e5dd9e292a6d4d00058737052b01f3532f69c0c65818dac26dc287"},
- {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:249cc6912405917344192b9f9ea5cd5b139d49e0d2f5c7f70bdfaf6b4dbf3a2e"},
- {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0eb98d90b6690827dcc84c246811feeb4e1eea683c0eac6caed7549be9c84665"},
- {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec82bf1fda6cecce7f7b915f9196601a1bd1a3079796b76d16ae4cce6d0ef89b"},
- {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9fd46ce0845cfe28f108888b3ab17abff84ff695e01e73657eec3f96d72eef34"},
- {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:bd176afcf8f5d2aed50c3647d4925d0db0579d96f75a31e77cbaf67d8a87742d"},
- {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:ec2aa89305006fba9ffb98970db6c8221541be7bee4c1d027421d6f6df7d1ce2"},
- {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:92cde43018a2e17d48bb09c79e4d4cb0e236de5063ce897a5e40ac7cb4878773"},
- {file = "aiohttp-3.11.11-cp311-cp311-win32.whl", hash = "sha256:aba807f9569455cba566882c8938f1a549f205ee43c27b126e5450dc9f83cc62"},
- {file = "aiohttp-3.11.11-cp311-cp311-win_amd64.whl", hash = "sha256:ae545f31489548c87b0cced5755cfe5a5308d00407000e72c4fa30b19c3220ac"},
- {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e595c591a48bbc295ebf47cb91aebf9bd32f3ff76749ecf282ea7f9f6bb73886"},
- {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3ea1b59dc06396b0b424740a10a0a63974c725b1c64736ff788a3689d36c02d2"},
- {file = "aiohttp-3.11.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8811f3f098a78ffa16e0ea36dffd577eb031aea797cbdba81be039a4169e242c"},
- {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7227b87a355ce1f4bf83bfae4399b1f5bb42e0259cb9405824bd03d2f4336a"},
- {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d40f9da8cabbf295d3a9dae1295c69975b86d941bc20f0a087f0477fa0a66231"},
- {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffb3dc385f6bb1568aa974fe65da84723210e5d9707e360e9ecb51f59406cd2e"},
- {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8f5f7515f3552d899c61202d99dcb17d6e3b0de777900405611cd747cecd1b8"},
- {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3499c7ffbfd9c6a3d8d6a2b01c26639da7e43d47c7b4f788016226b1e711caa8"},
- {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8e2bf8029dbf0810c7bfbc3e594b51c4cc9101fbffb583a3923aea184724203c"},
- {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b6212a60e5c482ef90f2d788835387070a88d52cf6241d3916733c9176d39eab"},
- {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d119fafe7b634dbfa25a8c597718e69a930e4847f0b88e172744be24515140da"},
- {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:6fba278063559acc730abf49845d0e9a9e1ba74f85f0ee6efd5803f08b285853"},
- {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:92fc484e34b733704ad77210c7957679c5c3877bd1e6b6d74b185e9320cc716e"},
- {file = "aiohttp-3.11.11-cp312-cp312-win32.whl", hash = "sha256:9f5b3c1ed63c8fa937a920b6c1bec78b74ee09593b3f5b979ab2ae5ef60d7600"},
- {file = "aiohttp-3.11.11-cp312-cp312-win_amd64.whl", hash = "sha256:1e69966ea6ef0c14ee53ef7a3d68b564cc408121ea56c0caa2dc918c1b2f553d"},
- {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:541d823548ab69d13d23730a06f97460f4238ad2e5ed966aaf850d7c369782d9"},
- {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:929f3ed33743a49ab127c58c3e0a827de0664bfcda566108989a14068f820194"},
- {file = "aiohttp-3.11.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0882c2820fd0132240edbb4a51eb8ceb6eef8181db9ad5291ab3332e0d71df5f"},
- {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63de12e44935d5aca7ed7ed98a255a11e5cb47f83a9fded7a5e41c40277d104"},
- {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa54f8ef31d23c506910c21163f22b124facb573bff73930735cf9fe38bf7dff"},
- {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a344d5dc18074e3872777b62f5f7d584ae4344cd6006c17ba12103759d407af3"},
- {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7fb429ab1aafa1f48578eb315ca45bd46e9c37de11fe45c7f5f4138091e2f1"},
- {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c341c7d868750e31961d6d8e60ff040fb9d3d3a46d77fd85e1ab8e76c3e9a5c4"},
- {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ed9ee95614a71e87f1a70bc81603f6c6760128b140bc4030abe6abaa988f1c3d"},
- {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:de8d38f1c2810fa2a4f1d995a2e9c70bb8737b18da04ac2afbf3971f65781d87"},
- {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a9b7371665d4f00deb8f32208c7c5e652059b0fda41cf6dbcac6114a041f1cc2"},
- {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:620598717fce1b3bd14dd09947ea53e1ad510317c85dda2c9c65b622edc96b12"},
- {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:bf8d9bfee991d8acc72d060d53860f356e07a50f0e0d09a8dfedea1c554dd0d5"},
- {file = "aiohttp-3.11.11-cp313-cp313-win32.whl", hash = "sha256:9d73ee3725b7a737ad86c2eac5c57a4a97793d9f442599bea5ec67ac9f4bdc3d"},
- {file = "aiohttp-3.11.11-cp313-cp313-win_amd64.whl", hash = "sha256:c7a06301c2fb096bdb0bd25fe2011531c1453b9f2c163c8031600ec73af1cc99"},
- {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3e23419d832d969f659c208557de4a123e30a10d26e1e14b73431d3c13444c2e"},
- {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:21fef42317cf02e05d3b09c028712e1d73a9606f02467fd803f7c1f39cc59add"},
- {file = "aiohttp-3.11.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1f21bb8d0235fc10c09ce1d11ffbd40fc50d3f08a89e4cf3a0c503dc2562247a"},
- {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1642eceeaa5ab6c9b6dfeaaa626ae314d808188ab23ae196a34c9d97efb68350"},
- {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2170816e34e10f2fd120f603e951630f8a112e1be3b60963a1f159f5699059a6"},
- {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8be8508d110d93061197fd2d6a74f7401f73b6d12f8822bbcd6d74f2b55d71b1"},
- {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eed954b161e6b9b65f6be446ed448ed3921763cc432053ceb606f89d793927e"},
- {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6c9af134da4bc9b3bd3e6a70072509f295d10ee60c697826225b60b9959acdd"},
- {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:44167fc6a763d534a6908bdb2592269b4bf30a03239bcb1654781adf5e49caf1"},
- {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:479b8c6ebd12aedfe64563b85920525d05d394b85f166b7873c8bde6da612f9c"},
- {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:10b4ff0ad793d98605958089fabfa350e8e62bd5d40aa65cdc69d6785859f94e"},
- {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:b540bd67cfb54e6f0865ceccd9979687210d7ed1a1cc8c01f8e67e2f1e883d28"},
- {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dac54e8ce2ed83b1f6b1a54005c87dfed139cf3f777fdc8afc76e7841101226"},
- {file = "aiohttp-3.11.11-cp39-cp39-win32.whl", hash = "sha256:568c1236b2fde93b7720f95a890741854c1200fba4a3471ff48b2934d2d93fd3"},
- {file = "aiohttp-3.11.11-cp39-cp39-win_amd64.whl", hash = "sha256:943a8b052e54dfd6439fd7989f67fc6a7f2138d0a2cf0a7de5f18aa4fe7eb3b1"},
- {file = "aiohttp-3.11.11.tar.gz", hash = "sha256:bb49c7f1e6ebf3821a42d81d494f538107610c3a705987f53068546b0e90303e"},
+ {file = "aiohttp-3.11.12-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:aa8a8caca81c0a3e765f19c6953416c58e2f4cc1b84829af01dd1c771bb2f91f"},
+ {file = "aiohttp-3.11.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:84ede78acde96ca57f6cf8ccb8a13fbaf569f6011b9a52f870c662d4dc8cd854"},
+ {file = "aiohttp-3.11.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:584096938a001378484aa4ee54e05dc79c7b9dd933e271c744a97b3b6f644957"},
+ {file = "aiohttp-3.11.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:392432a2dde22b86f70dd4a0e9671a349446c93965f261dbaecfaf28813e5c42"},
+ {file = "aiohttp-3.11.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:88d385b8e7f3a870146bf5ea31786ef7463e99eb59e31db56e2315535d811f55"},
+ {file = "aiohttp-3.11.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b10a47e5390c4b30a0d58ee12581003be52eedd506862ab7f97da7a66805befb"},
+ {file = "aiohttp-3.11.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b5263dcede17b6b0c41ef0c3ccce847d82a7da98709e75cf7efde3e9e3b5cae"},
+ {file = "aiohttp-3.11.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50c5c7b8aa5443304c55c262c5693b108c35a3b61ef961f1e782dd52a2f559c7"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d1c031a7572f62f66f1257db37ddab4cb98bfaf9b9434a3b4840bf3560f5e788"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:7e44eba534381dd2687be50cbd5f2daded21575242ecfdaf86bbeecbc38dae8e"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:145a73850926018ec1681e734cedcf2716d6a8697d90da11284043b745c286d5"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:2c311e2f63e42c1bf86361d11e2c4a59f25d9e7aabdbdf53dc38b885c5435cdb"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:ea756b5a7bac046d202a9a3889b9a92219f885481d78cd318db85b15cc0b7bcf"},
+ {file = "aiohttp-3.11.12-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:526c900397f3bbc2db9cb360ce9c35134c908961cdd0ac25b1ae6ffcaa2507ff"},
+ {file = "aiohttp-3.11.12-cp310-cp310-win32.whl", hash = "sha256:b8d3bb96c147b39c02d3db086899679f31958c5d81c494ef0fc9ef5bb1359b3d"},
+ {file = "aiohttp-3.11.12-cp310-cp310-win_amd64.whl", hash = "sha256:7fe3d65279bfbee8de0fb4f8c17fc4e893eed2dba21b2f680e930cc2b09075c5"},
+ {file = "aiohttp-3.11.12-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:87a2e00bf17da098d90d4145375f1d985a81605267e7f9377ff94e55c5d769eb"},
+ {file = "aiohttp-3.11.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b34508f1cd928ce915ed09682d11307ba4b37d0708d1f28e5774c07a7674cac9"},
+ {file = "aiohttp-3.11.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:936d8a4f0f7081327014742cd51d320296b56aa6d324461a13724ab05f4b2933"},
+ {file = "aiohttp-3.11.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de1378f72def7dfb5dbd73d86c19eda0ea7b0a6873910cc37d57e80f10d64e1"},
+ {file = "aiohttp-3.11.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9d45dbb3aaec05cf01525ee1a7ac72de46a8c425cb75c003acd29f76b1ffe94"},
+ {file = "aiohttp-3.11.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:930ffa1925393381e1e0a9b82137fa7b34c92a019b521cf9f41263976666a0d6"},
+ {file = "aiohttp-3.11.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8340def6737118f5429a5df4e88f440746b791f8f1c4ce4ad8a595f42c980bd5"},
+ {file = "aiohttp-3.11.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4016e383f91f2814e48ed61e6bda7d24c4d7f2402c75dd28f7e1027ae44ea204"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c0600bcc1adfaaac321422d615939ef300df81e165f6522ad096b73439c0f58"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:0450ada317a65383b7cce9576096150fdb97396dcfe559109b403c7242faffef"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:850ff6155371fd802a280f8d369d4e15d69434651b844bde566ce97ee2277420"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8fd12d0f989c6099e7b0f30dc6e0d1e05499f3337461f0b2b0dadea6c64b89df"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:76719dd521c20a58a6c256d058547b3a9595d1d885b830013366e27011ffe804"},
+ {file = "aiohttp-3.11.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:97fe431f2ed646a3b56142fc81d238abcbaff08548d6912acb0b19a0cadc146b"},
+ {file = "aiohttp-3.11.12-cp311-cp311-win32.whl", hash = "sha256:e10c440d142fa8b32cfdb194caf60ceeceb3e49807072e0dc3a8887ea80e8c16"},
+ {file = "aiohttp-3.11.12-cp311-cp311-win_amd64.whl", hash = "sha256:246067ba0cf5560cf42e775069c5d80a8989d14a7ded21af529a4e10e3e0f0e6"},
+ {file = "aiohttp-3.11.12-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e392804a38353900c3fd8b7cacbea5132888f7129f8e241915e90b85f00e3250"},
+ {file = "aiohttp-3.11.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8fa1510b96c08aaad49303ab11f8803787c99222288f310a62f493faf883ede1"},
+ {file = "aiohttp-3.11.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc065a4285307607df3f3686363e7f8bdd0d8ab35f12226362a847731516e42c"},
+ {file = "aiohttp-3.11.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cddb31f8474695cd61fc9455c644fc1606c164b93bff2490390d90464b4655df"},
+ {file = "aiohttp-3.11.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dec0000d2d8621d8015c293e24589d46fa218637d820894cb7356c77eca3259"},
+ {file = "aiohttp-3.11.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e3552fe98e90fdf5918c04769f338a87fa4f00f3b28830ea9b78b1bdc6140e0d"},
+ {file = "aiohttp-3.11.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dfe7f984f28a8ae94ff3a7953cd9678550dbd2a1f9bda5dd9c5ae627744c78e"},
+ {file = "aiohttp-3.11.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a481a574af914b6e84624412666cbfbe531a05667ca197804ecc19c97b8ab1b0"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1987770fb4887560363b0e1a9b75aa303e447433c41284d3af2840a2f226d6e0"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:a4ac6a0f0f6402854adca4e3259a623f5c82ec3f0c049374133bcb243132baf9"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c96a43822f1f9f69cc5c3706af33239489a6294be486a0447fb71380070d4d5f"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a5e69046f83c0d3cb8f0d5bd9b8838271b1bc898e01562a04398e160953e8eb9"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:68d54234c8d76d8ef74744f9f9fc6324f1508129e23da8883771cdbb5818cbef"},
+ {file = "aiohttp-3.11.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9fd9dcf9c91affe71654ef77426f5cf8489305e1c66ed4816f5a21874b094b9"},
+ {file = "aiohttp-3.11.12-cp312-cp312-win32.whl", hash = "sha256:0ed49efcd0dc1611378beadbd97beb5d9ca8fe48579fc04a6ed0844072261b6a"},
+ {file = "aiohttp-3.11.12-cp312-cp312-win_amd64.whl", hash = "sha256:54775858c7f2f214476773ce785a19ee81d1294a6bedc5cc17225355aab74802"},
+ {file = "aiohttp-3.11.12-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:413ad794dccb19453e2b97c2375f2ca3cdf34dc50d18cc2693bd5aed7d16f4b9"},
+ {file = "aiohttp-3.11.12-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4a93d28ed4b4b39e6f46fd240896c29b686b75e39cc6992692e3922ff6982b4c"},
+ {file = "aiohttp-3.11.12-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d589264dbba3b16e8951b6f145d1e6b883094075283dafcab4cdd564a9e353a0"},
+ {file = "aiohttp-3.11.12-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5148ca8955affdfeb864aca158ecae11030e952b25b3ae15d4e2b5ba299bad2"},
+ {file = "aiohttp-3.11.12-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:525410e0790aab036492eeea913858989c4cb070ff373ec3bc322d700bdf47c1"},
+ {file = "aiohttp-3.11.12-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bd8695be2c80b665ae3f05cb584093a1e59c35ecb7d794d1edd96e8cc9201d7"},
+ {file = "aiohttp-3.11.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0203433121484b32646a5f5ea93ae86f3d9559d7243f07e8c0eab5ff8e3f70e"},
+ {file = "aiohttp-3.11.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd36749a1035c34ba8d8aaf221b91ca3d111532e5ccb5fa8c3703ab1b967ed"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a7442662afebbf7b4c6d28cb7aab9e9ce3a5df055fc4116cc7228192ad6cb484"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:8a2fb742ef378284a50766e985804bd6adb5adb5aa781100b09befdbfa757b65"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2cee3b117a8d13ab98b38d5b6bdcd040cfb4181068d05ce0c474ec9db5f3c5bb"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f6a19bcab7fbd8f8649d6595624856635159a6527861b9cdc3447af288a00c00"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e4cecdb52aaa9994fbed6b81d4568427b6002f0a91c322697a4bfcc2b2363f5a"},
+ {file = "aiohttp-3.11.12-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:30f546358dfa0953db92ba620101fefc81574f87b2346556b90b5f3ef16e55ce"},
+ {file = "aiohttp-3.11.12-cp313-cp313-win32.whl", hash = "sha256:ce1bb21fc7d753b5f8a5d5a4bae99566386b15e716ebdb410154c16c91494d7f"},
+ {file = "aiohttp-3.11.12-cp313-cp313-win_amd64.whl", hash = "sha256:f7914ab70d2ee8ab91c13e5402122edbc77821c66d2758abb53aabe87f013287"},
+ {file = "aiohttp-3.11.12-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7c3623053b85b4296cd3925eeb725e386644fd5bc67250b3bb08b0f144803e7b"},
+ {file = "aiohttp-3.11.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:67453e603cea8e85ed566b2700efa1f6916aefbc0c9fcb2e86aaffc08ec38e78"},
+ {file = "aiohttp-3.11.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6130459189e61baac5a88c10019b21e1f0c6d00ebc770e9ce269475650ff7f73"},
+ {file = "aiohttp-3.11.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9060addfa4ff753b09392efe41e6af06ea5dd257829199747b9f15bfad819460"},
+ {file = "aiohttp-3.11.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34245498eeb9ae54c687a07ad7f160053911b5745e186afe2d0c0f2898a1ab8a"},
+ {file = "aiohttp-3.11.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8dc0fba9a74b471c45ca1a3cb6e6913ebfae416678d90529d188886278e7f3f6"},
+ {file = "aiohttp-3.11.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a478aa11b328983c4444dacb947d4513cb371cd323f3845e53caeda6be5589d5"},
+ {file = "aiohttp-3.11.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c160a04283c8c6f55b5bf6d4cad59bb9c5b9c9cd08903841b25f1f7109ef1259"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:edb69b9589324bdc40961cdf0657815df674f1743a8d5ad9ab56a99e4833cfdd"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:4ee84c2a22a809c4f868153b178fe59e71423e1f3d6a8cd416134bb231fbf6d3"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:bf4480a5438f80e0f1539e15a7eb8b5f97a26fe087e9828e2c0ec2be119a9f72"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:e6b2732ef3bafc759f653a98881b5b9cdef0716d98f013d376ee8dfd7285abf1"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:f752e80606b132140883bb262a457c475d219d7163d996dc9072434ffb0784c4"},
+ {file = "aiohttp-3.11.12-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ab3247d58b393bda5b1c8f31c9edece7162fc13265334217785518dd770792b8"},
+ {file = "aiohttp-3.11.12-cp39-cp39-win32.whl", hash = "sha256:0d5176f310a7fe6f65608213cc74f4228e4f4ce9fd10bcb2bb6da8fc66991462"},
+ {file = "aiohttp-3.11.12-cp39-cp39-win_amd64.whl", hash = "sha256:74bd573dde27e58c760d9ca8615c41a57e719bff315c9adb6f2a4281a28e8798"},
+ {file = "aiohttp-3.11.12.tar.gz", hash = "sha256:7603ca26d75b1b86160ce1bbe2787a0b706e592af5b2504e12caa88a217767b0"},
]
[package.dependencies]
@@ -167,13 +172,13 @@ files = [
[[package]]
name = "attrs"
-version = "24.3.0"
+version = "25.1.0"
description = "Classes Without Boilerplate"
optional = false
python-versions = ">=3.8"
files = [
- {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"},
- {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"},
+ {file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"},
+ {file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"},
]
[package.extras]
@@ -186,13 +191,13 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
[[package]]
name = "certifi"
-version = "2024.12.14"
+version = "2025.1.31"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
files = [
- {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"},
- {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"},
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
]
[[package]]
@@ -320,73 +325,74 @@ files = [
[[package]]
name = "coverage"
-version = "7.6.10"
+version = "7.6.12"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.9"
files = [
- {file = "coverage-7.6.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5c912978f7fbf47ef99cec50c4401340436d200d41d714c7a4766f377c5b7b78"},
- {file = "coverage-7.6.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a01ec4af7dfeb96ff0078ad9a48810bb0cc8abcb0115180c6013a6b26237626c"},
- {file = "coverage-7.6.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3b204c11e2b2d883946fe1d97f89403aa1811df28ce0447439178cc7463448a"},
- {file = "coverage-7.6.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32ee6d8491fcfc82652a37109f69dee9a830e9379166cb73c16d8dc5c2915165"},
- {file = "coverage-7.6.10-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675cefc4c06e3b4c876b85bfb7c59c5e2218167bbd4da5075cbe3b5790a28988"},
- {file = "coverage-7.6.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f4f620668dbc6f5e909a0946a877310fb3d57aea8198bde792aae369ee1c23b5"},
- {file = "coverage-7.6.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4eea95ef275de7abaef630c9b2c002ffbc01918b726a39f5a4353916ec72d2f3"},
- {file = "coverage-7.6.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e2f0280519e42b0a17550072861e0bc8a80a0870de260f9796157d3fca2733c5"},
- {file = "coverage-7.6.10-cp310-cp310-win32.whl", hash = "sha256:bc67deb76bc3717f22e765ab3e07ee9c7a5e26b9019ca19a3b063d9f4b874244"},
- {file = "coverage-7.6.10-cp310-cp310-win_amd64.whl", hash = "sha256:0f460286cb94036455e703c66988851d970fdfd8acc2a1122ab7f4f904e4029e"},
- {file = "coverage-7.6.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ea3c8f04b3e4af80e17bab607c386a830ffc2fb88a5484e1df756478cf70d1d3"},
- {file = "coverage-7.6.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:507a20fc863cae1d5720797761b42d2d87a04b3e5aeb682ef3b7332e90598f43"},
- {file = "coverage-7.6.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d37a84878285b903c0fe21ac8794c6dab58150e9359f1aaebbeddd6412d53132"},
- {file = "coverage-7.6.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a534738b47b0de1995f85f582d983d94031dffb48ab86c95bdf88dc62212142f"},
- {file = "coverage-7.6.10-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d7a2bf79378d8fb8afaa994f91bfd8215134f8631d27eba3e0e2c13546ce994"},
- {file = "coverage-7.6.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6713ba4b4ebc330f3def51df1d5d38fad60b66720948112f114968feb52d3f99"},
- {file = "coverage-7.6.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ab32947f481f7e8c763fa2c92fd9f44eeb143e7610c4ca9ecd6a36adab4081bd"},
- {file = "coverage-7.6.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7bbd8c8f1b115b892e34ba66a097b915d3871db7ce0e6b9901f462ff3a975377"},
- {file = "coverage-7.6.10-cp311-cp311-win32.whl", hash = "sha256:299e91b274c5c9cdb64cbdf1b3e4a8fe538a7a86acdd08fae52301b28ba297f8"},
- {file = "coverage-7.6.10-cp311-cp311-win_amd64.whl", hash = "sha256:489a01f94aa581dbd961f306e37d75d4ba16104bbfa2b0edb21d29b73be83609"},
- {file = "coverage-7.6.10-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:27c6e64726b307782fa5cbe531e7647aee385a29b2107cd87ba7c0105a5d3853"},
- {file = "coverage-7.6.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c56e097019e72c373bae32d946ecf9858fda841e48d82df7e81c63ac25554078"},
- {file = "coverage-7.6.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7827a5bc7bdb197b9e066cdf650b2887597ad124dd99777332776f7b7c7d0d0"},
- {file = "coverage-7.6.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:204a8238afe787323a8b47d8be4df89772d5c1e4651b9ffa808552bdf20e1d50"},
- {file = "coverage-7.6.10-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e67926f51821b8e9deb6426ff3164870976fe414d033ad90ea75e7ed0c2e5022"},
- {file = "coverage-7.6.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e78b270eadb5702938c3dbe9367f878249b5ef9a2fcc5360ac7bff694310d17b"},
- {file = "coverage-7.6.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:714f942b9c15c3a7a5fe6876ce30af831c2ad4ce902410b7466b662358c852c0"},
- {file = "coverage-7.6.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:abb02e2f5a3187b2ac4cd46b8ced85a0858230b577ccb2c62c81482ca7d18852"},
- {file = "coverage-7.6.10-cp312-cp312-win32.whl", hash = "sha256:55b201b97286cf61f5e76063f9e2a1d8d2972fc2fcfd2c1272530172fd28c359"},
- {file = "coverage-7.6.10-cp312-cp312-win_amd64.whl", hash = "sha256:e4ae5ac5e0d1e4edfc9b4b57b4cbecd5bc266a6915c500f358817a8496739247"},
- {file = "coverage-7.6.10-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05fca8ba6a87aabdd2d30d0b6c838b50510b56cdcfc604d40760dae7153b73d9"},
- {file = "coverage-7.6.10-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9e80eba8801c386f72e0712a0453431259c45c3249f0009aff537a517b52942b"},
- {file = "coverage-7.6.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a372c89c939d57abe09e08c0578c1d212e7a678135d53aa16eec4430adc5e690"},
- {file = "coverage-7.6.10-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec22b5e7fe7a0fa8509181c4aac1db48f3dd4d3a566131b313d1efc102892c18"},
- {file = "coverage-7.6.10-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26bcf5c4df41cad1b19c84af71c22cbc9ea9a547fc973f1f2cc9a290002c8b3c"},
- {file = "coverage-7.6.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4e4630c26b6084c9b3cb53b15bd488f30ceb50b73c35c5ad7871b869cb7365fd"},
- {file = "coverage-7.6.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2396e8116db77789f819d2bc8a7e200232b7a282c66e0ae2d2cd84581a89757e"},
- {file = "coverage-7.6.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:79109c70cc0882e4d2d002fe69a24aa504dec0cc17169b3c7f41a1d341a73694"},
- {file = "coverage-7.6.10-cp313-cp313-win32.whl", hash = "sha256:9e1747bab246d6ff2c4f28b4d186b205adced9f7bd9dc362051cc37c4a0c7bd6"},
- {file = "coverage-7.6.10-cp313-cp313-win_amd64.whl", hash = "sha256:254f1a3b1eef5f7ed23ef265eaa89c65c8c5b6b257327c149db1ca9d4a35f25e"},
- {file = "coverage-7.6.10-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2ccf240eb719789cedbb9fd1338055de2761088202a9a0b73032857e53f612fe"},
- {file = "coverage-7.6.10-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0c807ca74d5a5e64427c8805de15b9ca140bba13572d6d74e262f46f50b13273"},
- {file = "coverage-7.6.10-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bcfa46d7709b5a7ffe089075799b902020b62e7ee56ebaed2f4bdac04c508d8"},
- {file = "coverage-7.6.10-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e0de1e902669dccbf80b0415fb6b43d27edca2fbd48c74da378923b05316098"},
- {file = "coverage-7.6.10-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f7b444c42bbc533aaae6b5a2166fd1a797cdb5eb58ee51a92bee1eb94a1e1cb"},
- {file = "coverage-7.6.10-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b330368cb99ef72fcd2dc3ed260adf67b31499584dc8a20225e85bfe6f6cfed0"},
- {file = "coverage-7.6.10-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:9a7cfb50515f87f7ed30bc882f68812fd98bc2852957df69f3003d22a2aa0abf"},
- {file = "coverage-7.6.10-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6f93531882a5f68c28090f901b1d135de61b56331bba82028489bc51bdd818d2"},
- {file = "coverage-7.6.10-cp313-cp313t-win32.whl", hash = "sha256:89d76815a26197c858f53c7f6a656686ec392b25991f9e409bcef020cd532312"},
- {file = "coverage-7.6.10-cp313-cp313t-win_amd64.whl", hash = "sha256:54a5f0f43950a36312155dae55c505a76cd7f2b12d26abeebbe7a0b36dbc868d"},
- {file = "coverage-7.6.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:656c82b8a0ead8bba147de9a89bda95064874c91a3ed43a00e687f23cc19d53a"},
- {file = "coverage-7.6.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ccc2b70a7ed475c68ceb548bf69cec1e27305c1c2606a5eb7c3afff56a1b3b27"},
- {file = "coverage-7.6.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5e37dc41d57ceba70956fa2fc5b63c26dba863c946ace9705f8eca99daecdc4"},
- {file = "coverage-7.6.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0aa9692b4fdd83a4647eeb7db46410ea1322b5ed94cd1715ef09d1d5922ba87f"},
- {file = "coverage-7.6.10-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa744da1820678b475e4ba3dfd994c321c5b13381d1041fe9c608620e6676e25"},
- {file = "coverage-7.6.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c0b1818063dc9e9d838c09e3a473c1422f517889436dd980f5d721899e66f315"},
- {file = "coverage-7.6.10-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:59af35558ba08b758aec4d56182b222976330ef8d2feacbb93964f576a7e7a90"},
- {file = "coverage-7.6.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7ed2f37cfce1ce101e6dffdfd1c99e729dd2ffc291d02d3e2d0af8b53d13840d"},
- {file = "coverage-7.6.10-cp39-cp39-win32.whl", hash = "sha256:4bcc276261505d82f0ad426870c3b12cb177752834a633e737ec5ee79bbdff18"},
- {file = "coverage-7.6.10-cp39-cp39-win_amd64.whl", hash = "sha256:457574f4599d2b00f7f637a0700a6422243b3565509457b2dbd3f50703e11f59"},
- {file = "coverage-7.6.10-pp39.pp310-none-any.whl", hash = "sha256:fd34e7b3405f0cc7ab03d54a334c17a9e802897580d964bd8c2001f4b9fd488f"},
- {file = "coverage-7.6.10.tar.gz", hash = "sha256:7fb105327c8f8f0682e29843e2ff96af9dcbe5bab8eeb4b398c6a33a16d80a23"},
+ {file = "coverage-7.6.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:704c8c8c6ce6569286ae9622e534b4f5b9759b6f2cd643f1c1a61f666d534fe8"},
+ {file = "coverage-7.6.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ad7525bf0241e5502168ae9c643a2f6c219fa0a283001cee4cf23a9b7da75879"},
+ {file = "coverage-7.6.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06097c7abfa611c91edb9e6920264e5be1d6ceb374efb4986f38b09eed4cb2fe"},
+ {file = "coverage-7.6.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:220fa6c0ad7d9caef57f2c8771918324563ef0d8272c94974717c3909664e674"},
+ {file = "coverage-7.6.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3688b99604a24492bcfe1c106278c45586eb819bf66a654d8a9a1433022fb2eb"},
+ {file = "coverage-7.6.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d1a987778b9c71da2fc8948e6f2656da6ef68f59298b7e9786849634c35d2c3c"},
+ {file = "coverage-7.6.12-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:cec6b9ce3bd2b7853d4a4563801292bfee40b030c05a3d29555fd2a8ee9bd68c"},
+ {file = "coverage-7.6.12-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ace9048de91293e467b44bce0f0381345078389814ff6e18dbac8fdbf896360e"},
+ {file = "coverage-7.6.12-cp310-cp310-win32.whl", hash = "sha256:ea31689f05043d520113e0552f039603c4dd71fa4c287b64cb3606140c66f425"},
+ {file = "coverage-7.6.12-cp310-cp310-win_amd64.whl", hash = "sha256:676f92141e3c5492d2a1596d52287d0d963df21bf5e55c8b03075a60e1ddf8aa"},
+ {file = "coverage-7.6.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e18aafdfb3e9ec0d261c942d35bd7c28d031c5855dadb491d2723ba54f4c3015"},
+ {file = "coverage-7.6.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66fe626fd7aa5982cdebad23e49e78ef7dbb3e3c2a5960a2b53632f1f703ea45"},
+ {file = "coverage-7.6.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ef01d70198431719af0b1f5dcbefc557d44a190e749004042927b2a3fed0702"},
+ {file = "coverage-7.6.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e92ae5a289a4bc4c0aae710c0948d3c7892e20fd3588224ebe242039573bf0"},
+ {file = "coverage-7.6.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e695df2c58ce526eeab11a2e915448d3eb76f75dffe338ea613c1201b33bab2f"},
+ {file = "coverage-7.6.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d74c08e9aaef995f8c4ef6d202dbd219c318450fe2a76da624f2ebb9c8ec5d9f"},
+ {file = "coverage-7.6.12-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e995b3b76ccedc27fe4f477b349b7d64597e53a43fc2961db9d3fbace085d69d"},
+ {file = "coverage-7.6.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b1f097878d74fe51e1ddd1be62d8e3682748875b461232cf4b52ddc6e6db0bba"},
+ {file = "coverage-7.6.12-cp311-cp311-win32.whl", hash = "sha256:1f7ffa05da41754e20512202c866d0ebfc440bba3b0ed15133070e20bf5aeb5f"},
+ {file = "coverage-7.6.12-cp311-cp311-win_amd64.whl", hash = "sha256:e216c5c45f89ef8971373fd1c5d8d1164b81f7f5f06bbf23c37e7908d19e8558"},
+ {file = "coverage-7.6.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b172f8e030e8ef247b3104902cc671e20df80163b60a203653150d2fc204d1ad"},
+ {file = "coverage-7.6.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:641dfe0ab73deb7069fb972d4d9725bf11c239c309ce694dd50b1473c0f641c3"},
+ {file = "coverage-7.6.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e549f54ac5f301e8e04c569dfdb907f7be71b06b88b5063ce9d6953d2d58574"},
+ {file = "coverage-7.6.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959244a17184515f8c52dcb65fb662808767c0bd233c1d8a166e7cf74c9ea985"},
+ {file = "coverage-7.6.12-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda1c5f347550c359f841d6614fb8ca42ae5cb0b74d39f8a1e204815ebe25750"},
+ {file = "coverage-7.6.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1ceeb90c3eda1f2d8c4c578c14167dbd8c674ecd7d38e45647543f19839dd6ea"},
+ {file = "coverage-7.6.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f16f44025c06792e0fb09571ae454bcc7a3ec75eeb3c36b025eccf501b1a4c3"},
+ {file = "coverage-7.6.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b076e625396e787448d27a411aefff867db2bffac8ed04e8f7056b07024eed5a"},
+ {file = "coverage-7.6.12-cp312-cp312-win32.whl", hash = "sha256:00b2086892cf06c7c2d74983c9595dc511acca00665480b3ddff749ec4fb2a95"},
+ {file = "coverage-7.6.12-cp312-cp312-win_amd64.whl", hash = "sha256:7ae6eabf519bc7871ce117fb18bf14e0e343eeb96c377667e3e5dd12095e0288"},
+ {file = "coverage-7.6.12-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:488c27b3db0ebee97a830e6b5a3ea930c4a6e2c07f27a5e67e1b3532e76b9ef1"},
+ {file = "coverage-7.6.12-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d1095bbee1851269f79fd8e0c9b5544e4c00c0c24965e66d8cba2eb5bb535fd"},
+ {file = "coverage-7.6.12-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0533adc29adf6a69c1baa88c3d7dbcaadcffa21afbed3ca7a225a440e4744bf9"},
+ {file = "coverage-7.6.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53c56358d470fa507a2b6e67a68fd002364d23c83741dbc4c2e0680d80ca227e"},
+ {file = "coverage-7.6.12-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64cbb1a3027c79ca6310bf101014614f6e6e18c226474606cf725238cf5bc2d4"},
+ {file = "coverage-7.6.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:79cac3390bfa9836bb795be377395f28410811c9066bc4eefd8015258a7578c6"},
+ {file = "coverage-7.6.12-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:9b148068e881faa26d878ff63e79650e208e95cf1c22bd3f77c3ca7b1d9821a3"},
+ {file = "coverage-7.6.12-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8bec2ac5da793c2685ce5319ca9bcf4eee683b8a1679051f8e6ec04c4f2fd7dc"},
+ {file = "coverage-7.6.12-cp313-cp313-win32.whl", hash = "sha256:200e10beb6ddd7c3ded322a4186313d5ca9e63e33d8fab4faa67ef46d3460af3"},
+ {file = "coverage-7.6.12-cp313-cp313-win_amd64.whl", hash = "sha256:2b996819ced9f7dbb812c701485d58f261bef08f9b85304d41219b1496b591ef"},
+ {file = "coverage-7.6.12-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:299cf973a7abff87a30609879c10df0b3bfc33d021e1adabc29138a48888841e"},
+ {file = "coverage-7.6.12-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4b467a8c56974bf06e543e69ad803c6865249d7a5ccf6980457ed2bc50312703"},
+ {file = "coverage-7.6.12-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2458f275944db8129f95d91aee32c828a408481ecde3b30af31d552c2ce284a0"},
+ {file = "coverage-7.6.12-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a9d8be07fb0832636a0f72b80d2a652fe665e80e720301fb22b191c3434d924"},
+ {file = "coverage-7.6.12-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14d47376a4f445e9743f6c83291e60adb1b127607a3618e3185bbc8091f0467b"},
+ {file = "coverage-7.6.12-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b95574d06aa9d2bd6e5cc35a5bbe35696342c96760b69dc4287dbd5abd4ad51d"},
+ {file = "coverage-7.6.12-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:ecea0c38c9079570163d663c0433a9af4094a60aafdca491c6a3d248c7432827"},
+ {file = "coverage-7.6.12-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2251fabcfee0a55a8578a9d29cecfee5f2de02f11530e7d5c5a05859aa85aee9"},
+ {file = "coverage-7.6.12-cp313-cp313t-win32.whl", hash = "sha256:eb5507795caabd9b2ae3f1adc95f67b1104971c22c624bb354232d65c4fc90b3"},
+ {file = "coverage-7.6.12-cp313-cp313t-win_amd64.whl", hash = "sha256:f60a297c3987c6c02ffb29effc70eadcbb412fe76947d394a1091a3615948e2f"},
+ {file = "coverage-7.6.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e7575ab65ca8399c8c4f9a7d61bbd2d204c8b8e447aab9d355682205c9dd948d"},
+ {file = "coverage-7.6.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8161d9fbc7e9fe2326de89cd0abb9f3599bccc1287db0aba285cb68d204ce929"},
+ {file = "coverage-7.6.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a1e465f398c713f1b212400b4e79a09829cd42aebd360362cd89c5bdc44eb87"},
+ {file = "coverage-7.6.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f25d8b92a4e31ff1bd873654ec367ae811b3a943583e05432ea29264782dc32c"},
+ {file = "coverage-7.6.12-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a936309a65cc5ca80fa9f20a442ff9e2d06927ec9a4f54bcba9c14c066323f2"},
+ {file = "coverage-7.6.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aa6f302a3a0b5f240ee201297fff0bbfe2fa0d415a94aeb257d8b461032389bd"},
+ {file = "coverage-7.6.12-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f973643ef532d4f9be71dd88cf7588936685fdb576d93a79fe9f65bc337d9d73"},
+ {file = "coverage-7.6.12-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:78f5243bb6b1060aed6213d5107744c19f9571ec76d54c99cc15938eb69e0e86"},
+ {file = "coverage-7.6.12-cp39-cp39-win32.whl", hash = "sha256:69e62c5034291c845fc4df7f8155e8544178b6c774f97a99e2734b05eb5bed31"},
+ {file = "coverage-7.6.12-cp39-cp39-win_amd64.whl", hash = "sha256:b01a840ecc25dce235ae4c1b6a0daefb2a203dba0e6e980637ee9c2f6ee0df57"},
+ {file = "coverage-7.6.12-pp39.pp310-none-any.whl", hash = "sha256:7e39e845c4d764208e7b8f6a21c541ade741e2c41afabdfa1caa28687a3c98cf"},
+ {file = "coverage-7.6.12-py3-none-any.whl", hash = "sha256:eb8668cfbc279a536c633137deeb9435d2962caec279c3f8cf8b91fff6ff8953"},
+ {file = "coverage-7.6.12.tar.gz", hash = "sha256:48cfc4641d95d34766ad41d9573cc0f22a48aa88d22657a1fe01dca0dbae4de2"},
]
[package.dependencies]
@@ -426,13 +432,13 @@ tomli = ["tomli (>=2.0.0,<3.0.0)"]
[[package]]
name = "einops"
-version = "0.8.0"
+version = "0.8.1"
description = "A new flavour of deep learning operations"
optional = false
python-versions = ">=3.8"
files = [
- {file = "einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f"},
- {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"},
+ {file = "einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737"},
+ {file = "einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84"},
]
[[package]]
@@ -451,18 +457,18 @@ test = ["pytest (>=6)"]
[[package]]
name = "filelock"
-version = "3.16.1"
+version = "3.17.0"
description = "A platform independent file lock."
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
files = [
- {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
- {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"},
+ {file = "filelock-3.17.0-py3-none-any.whl", hash = "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338"},
+ {file = "filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e"},
]
[package.extras]
-docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
typing = ["typing-extensions (>=4.12.2)"]
[[package]]
@@ -568,13 +574,13 @@ files = [
[[package]]
name = "fsspec"
-version = "2024.12.0"
+version = "2025.2.0"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
- {file = "fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2"},
- {file = "fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f"},
+ {file = "fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b"},
+ {file = "fsspec-2025.2.0.tar.gz", hash = "sha256:1c24b16eaa0a1798afa0337aa0db9b256718ab2a89c425371f5628d22c3b6afd"},
]
[package.dependencies]
@@ -604,19 +610,19 @@ sftp = ["paramiko"]
smb = ["smbprotocol"]
ssh = ["paramiko"]
test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
-test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
tqdm = ["tqdm"]
[[package]]
name = "huggingface-hub"
-version = "0.27.0"
+version = "0.28.1"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "huggingface_hub-0.27.0-py3-none-any.whl", hash = "sha256:8f2e834517f1f1ddf1ecc716f91b120d7333011b7485f665a9a412eacb1a2a81"},
- {file = "huggingface_hub-0.27.0.tar.gz", hash = "sha256:902cce1a1be5739f5589e560198a65a8edcfd3b830b1666f36e4b961f0454fac"},
+ {file = "huggingface_hub-0.28.1-py3-none-any.whl", hash = "sha256:aa6b9a3ffdae939b72c464dbb0d7f99f56e649b55c3d52406f49e0a5a620c0a7"},
+ {file = "huggingface_hub-0.28.1.tar.gz", hash = "sha256:893471090c98e3b6efbdfdacafe4052b20b84d59866fb6f54c33d9af18c303ae"},
]
[package.dependencies]
@@ -629,13 +635,13 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.5.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
inference = ["aiohttp"]
-quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.5.0)"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
@@ -644,13 +650,13 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t
[[package]]
name = "identify"
-version = "2.6.4"
+version = "2.6.7"
description = "File identification library for Python"
optional = false
python-versions = ">=3.9"
files = [
- {file = "identify-2.6.4-py2.py3-none-any.whl", hash = "sha256:993b0f01b97e0568c179bb9196391ff391bfb88a99099dbf5ce392b68f42d0af"},
- {file = "identify-2.6.4.tar.gz", hash = "sha256:285a7d27e397652e8cafe537a6cc97dd470a970f48fb2e9d979aa38eae5513ac"},
+ {file = "identify-2.6.7-py2.py3-none-any.whl", hash = "sha256:155931cb617a401807b09ecec6635d6c692d180090a1cedca8ef7d58ba5b6aa0"},
+ {file = "identify-2.6.7.tar.gz", hash = "sha256:3fa266b42eba321ee0b2bb0936a6a6b9e36a1351cbb69055b3082f4193035684"},
]
[package.extras]
@@ -753,24 +759,24 @@ test = ["click (==8.1.7)", "cloudpickle (>=1.3,<3.0)", "coverage (==7.3.1)", "fa
[[package]]
name = "lightning-utilities"
-version = "0.11.9"
+version = "0.12.0"
description = "Lightning toolbox for across the our ecosystem."
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
files = [
- {file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"},
- {file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"},
+ {file = "lightning_utilities-0.12.0-py3-none-any.whl", hash = "sha256:b827f5768607e81ccc7b2ada1f50628168d1cc9f839509c7e87c04b59079e66c"},
+ {file = "lightning_utilities-0.12.0.tar.gz", hash = "sha256:95b5f22a0b69eb27ca0929c6c1d510592a70080e1733a055bf154903c0343b60"},
]
[package.dependencies]
packaging = ">=17.1"
setuptools = "*"
-typing-extensions = "*"
+typing_extensions = "*"
[package.extras]
cli = ["fire"]
docs = ["requests (>=2.0.0)"]
-typing = ["mypy (>=1.0.0)", "types-setuptools"]
+typing = ["fire", "mypy (>=1.0.0)", "types-setuptools"]
[[package]]
name = "markupsafe"
@@ -1447,32 +1453,25 @@ scipy = "*"
[[package]]
name = "psutil"
-version = "6.1.1"
-description = "Cross-platform lib for process and system monitoring in Python."
-optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
-files = [
- {file = "psutil-6.1.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8"},
- {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777"},
- {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4"},
- {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468"},
- {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca"},
- {file = "psutil-6.1.1-cp27-none-win32.whl", hash = "sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac"},
- {file = "psutil-6.1.1-cp27-none-win_amd64.whl", hash = "sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030"},
- {file = "psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8"},
- {file = "psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377"},
- {file = "psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003"},
- {file = "psutil-6.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160"},
- {file = "psutil-6.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3"},
- {file = "psutil-6.1.1-cp36-cp36m-win32.whl", hash = "sha256:384636b1a64b47814437d1173be1427a7c83681b17a450bfc309a1953e329603"},
- {file = "psutil-6.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8be07491f6ebe1a693f17d4f11e69d0dc1811fa082736500f649f79df7735303"},
- {file = "psutil-6.1.1-cp37-abi3-win32.whl", hash = "sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53"},
- {file = "psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649"},
- {file = "psutil-6.1.1.tar.gz", hash = "sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5"},
+version = "7.0.0"
+description = "Cross-platform lib for process and system monitoring in Python. NOTE: the syntax of this script MUST be kept compatible with Python 2.7."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25"},
+ {file = "psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993"},
+ {file = "psutil-7.0.0-cp36-cp36m-win32.whl", hash = "sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17"},
+ {file = "psutil-7.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e"},
+ {file = "psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99"},
+ {file = "psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553"},
+ {file = "psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456"},
]
[package.extras]
-dev = ["abi3audit", "black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "vulture", "wheel"]
+dev = ["abi3audit", "black (==24.10.0)", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest", "pytest-cov", "pytest-xdist", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "vulture", "wheel"]
test = ["pytest", "pytest-xdist", "setuptools"]
[[package]]
@@ -1578,13 +1577,13 @@ test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "numpy (>=1.17.2
[[package]]
name = "pytz"
-version = "2024.2"
+version = "2025.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
- {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"},
- {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"},
+ {file = "pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57"},
+ {file = "pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e"},
]
[[package]]
@@ -1672,53 +1671,53 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "ruff"
-version = "0.8.5"
+version = "0.9.6"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.8.5-py3-none-linux_armv6l.whl", hash = "sha256:5ad11a5e3868a73ca1fa4727fe7e33735ea78b416313f4368c504dbeb69c0f88"},
- {file = "ruff-0.8.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f69ab37771ea7e0715fead8624ec42996d101269a96e31f4d31be6fc33aa19b7"},
- {file = "ruff-0.8.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b5462d7804558ccff9c08fe8cbf6c14b7efe67404316696a2dde48297b1925bb"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d56de7220a35607f9fe59f8a6d018e14504f7b71d784d980835e20fc0611cd50"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9d99cf80b0429cbebf31cbbf6f24f05a29706f0437c40413d950e67e2d4faca4"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b75ac29715ac60d554a049dbb0ef3b55259076181c3369d79466cb130eb5afd"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c9d526a62c9eda211b38463528768fd0ada25dad524cb33c0e99fcff1c67b5dc"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:587c5e95007612c26509f30acc506c874dab4c4abbacd0357400bd1aa799931b"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:622b82bf3429ff0e346835ec213aec0a04d9730480cbffbb6ad9372014e31bbd"},
- {file = "ruff-0.8.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f99be814d77a5dac8a8957104bdd8c359e85c86b0ee0e38dca447cb1095f70fb"},
- {file = "ruff-0.8.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c01c048f9c3385e0fd7822ad0fd519afb282af9cf1778f3580e540629df89725"},
- {file = "ruff-0.8.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7512e8cb038db7f5db6aae0e24735ff9ea03bb0ed6ae2ce534e9baa23c1dc9ea"},
- {file = "ruff-0.8.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:762f113232acd5b768d6b875d16aad6b00082add40ec91c927f0673a8ec4ede8"},
- {file = "ruff-0.8.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:03a90200c5dfff49e4c967b405f27fdfa81594cbb7c5ff5609e42d7fe9680da5"},
- {file = "ruff-0.8.5-py3-none-win32.whl", hash = "sha256:8710ffd57bdaa6690cbf6ecff19884b8629ec2a2a2a2f783aa94b1cc795139ed"},
- {file = "ruff-0.8.5-py3-none-win_amd64.whl", hash = "sha256:4020d8bf8d3a32325c77af452a9976a9ad6455773bcb94991cf15bd66b347e47"},
- {file = "ruff-0.8.5-py3-none-win_arm64.whl", hash = "sha256:134ae019ef13e1b060ab7136e7828a6d83ea727ba123381307eb37c6bd5e01cb"},
- {file = "ruff-0.8.5.tar.gz", hash = "sha256:1098d36f69831f7ff2a1da3e6407d5fbd6dfa2559e4f74ff2d260c5588900317"},
+ {file = "ruff-0.9.6-py3-none-linux_armv6l.whl", hash = "sha256:2f218f356dd2d995839f1941322ff021c72a492c470f0b26a34f844c29cdf5ba"},
+ {file = "ruff-0.9.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b908ff4df65dad7b251c9968a2e4560836d8f5487c2f0cc238321ed951ea0504"},
+ {file = "ruff-0.9.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b109c0ad2ececf42e75fa99dc4043ff72a357436bb171900714a9ea581ddef83"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de4367cca3dac99bcbd15c161404e849bb0bfd543664db39232648dc00112dc"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3ee4d7c2c92ddfdaedf0bf31b2b176fa7aa8950efc454628d477394d35638b"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc1edd1775270e6aa2386119aea692039781429f0be1e0949ea5884e011aa8e"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4a091729086dffa4bd070aa5dab7e39cc6b9d62eb2bef8f3d91172d30d599666"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1bbc6808bf7b15796cef0815e1dfb796fbd383e7dbd4334709642649625e7c5"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:589d1d9f25b5754ff230dce914a174a7c951a85a4e9270613a2b74231fdac2f5"},
+ {file = "ruff-0.9.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc61dd5131742e21103fbbdcad683a8813be0e3c204472d520d9a5021ca8b217"},
+ {file = "ruff-0.9.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5e2d9126161d0357e5c8f30b0bd6168d2c3872372f14481136d13de9937f79b6"},
+ {file = "ruff-0.9.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:68660eab1a8e65babb5229a1f97b46e3120923757a68b5413d8561f8a85d4897"},
+ {file = "ruff-0.9.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c4cae6c4cc7b9b4017c71114115db0445b00a16de3bcde0946273e8392856f08"},
+ {file = "ruff-0.9.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19f505b643228b417c1111a2a536424ddde0db4ef9023b9e04a46ed8a1cb4656"},
+ {file = "ruff-0.9.6-py3-none-win32.whl", hash = "sha256:194d8402bceef1b31164909540a597e0d913c0e4952015a5b40e28c146121b5d"},
+ {file = "ruff-0.9.6-py3-none-win_amd64.whl", hash = "sha256:03482d5c09d90d4ee3f40d97578423698ad895c87314c4de39ed2af945633caa"},
+ {file = "ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a"},
+ {file = "ruff-0.9.6.tar.gz", hash = "sha256:81761592f72b620ec8fa1068a6fd00e98a5ebee342a3642efd84454f3031dca9"},
]
[[package]]
name = "safetensors"
-version = "0.5.0"
+version = "0.5.2"
description = ""
optional = false
python-versions = ">=3.7"
files = [
- {file = "safetensors-0.5.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c683b9b485bee43422ba2855f72777c37647190281e03da4c8d2a69fa5336558"},
- {file = "safetensors-0.5.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6106aa835deb7263f7014f74c05842ab828d6c11d789f2e7e98f26b1a305e72d"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1349611f74f55c5ee1c1c144c536a2743c38f7d8bf60b9fc8267e0efc0591a2"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:56d936028ac799e18644b08a91fd98b4b62ae3dcd0440b1cfcb56535785589f1"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2f26afada2233576ffea6b80042c2c0a8105c164254af56168ec14299ad3122"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20067e7a5e63f0cbc88457b2a1161e70ff73af4cc3a24bce90309430cd6f6e7e"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649d6a4aa34d5174ae87289068ccc2fec2a1a998ecf83425aa5a42c3eff69bcf"},
- {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:debff88f41d569a3e93a955469f83864e432af35bb34b16f65a9ddf378daa3ae"},
- {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:bdf6a3e366ea8ba1a0538db6099229e95811194432c684ea28ea7ae28763b8dc"},
- {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:0371afd84c200a80eb7103bf715108b0c3846132fb82453ae018609a15551580"},
- {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5ec7fc8c3d2f32ebf1c7011bc886b362e53ee0a1ec6d828c39d531fed8b325d6"},
- {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:53715e4ea0ef23c08f004baae0f609a7773de7d4148727760417c6760cfd6b76"},
- {file = "safetensors-0.5.0-cp38-abi3-win32.whl", hash = "sha256:b85565bc2f0456961a788d2f11d9d892eec46603db0e4923aa9512c2355aa727"},
- {file = "safetensors-0.5.0-cp38-abi3-win_amd64.whl", hash = "sha256:f451941f8aa11e7be5c3fa450e264609a2b1e65fa38ae590a74e55a94d646b76"},
- {file = "safetensors-0.5.0.tar.gz", hash = "sha256:c47b34c549fa1e0c655c4644da31332c61332c732c47c8dd9399347e9aac69d1"},
+ {file = "safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2"},
+ {file = "safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5"},
+ {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975"},
+ {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e"},
+ {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f"},
+ {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf"},
+ {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76"},
+ {file = "safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2"},
+ {file = "safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589"},
+ {file = "safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8"},
]
[package.extras]
@@ -1736,41 +1735,41 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
[[package]]
name = "scikit-learn"
-version = "1.6.0"
+version = "1.6.1"
description = "A set of python modules for machine learning and data mining"
optional = false
python-versions = ">=3.9"
files = [
- {file = "scikit_learn-1.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:366fb3fa47dce90afed3d6106183f4978d6f24cfd595c2373424171b915ee718"},
- {file = "scikit_learn-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:59cd96a8d9f8dfd546f5d6e9787e1b989e981388d7803abbc9efdcde61e47460"},
- {file = "scikit_learn-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efa7a579606c73a0b3d210e33ea410ea9e1af7933fe324cb7e6fbafae4ea5948"},
- {file = "scikit_learn-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a46d3ca0f11a540b8eaddaf5e38172d8cd65a86cb3e3632161ec96c0cffb774c"},
- {file = "scikit_learn-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:5be4577769c5dde6e1b53de8e6520f9b664ab5861dd57acee47ad119fd7405d6"},
- {file = "scikit_learn-1.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1f50b4f24cf12a81c3c09958ae3b864d7534934ca66ded3822de4996d25d7285"},
- {file = "scikit_learn-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:eb9ae21f387826da14b0b9cb1034f5048ddb9182da429c689f5f4a87dc96930b"},
- {file = "scikit_learn-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0baa91eeb8c32632628874a5c91885eaedd23b71504d24227925080da075837a"},
- {file = "scikit_learn-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c716d13ba0a2f8762d96ff78d3e0cde90bc9c9b5c13d6ab6bb9b2d6ca6705fd"},
- {file = "scikit_learn-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:9aafd94bafc841b626681e626be27bf1233d5a0f20f0a6fdb4bee1a1963c6643"},
- {file = "scikit_learn-1.6.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:04a5ba45c12a5ff81518aa4f1604e826a45d20e53da47b15871526cda4ff5174"},
- {file = "scikit_learn-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:21fadfc2ad7a1ce8bd1d90f23d17875b84ec765eecbbfc924ff11fb73db582ce"},
- {file = "scikit_learn-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30f34bb5fde90e020653bb84dcb38b6c83f90c70680dbd8c38bd9becbad7a127"},
- {file = "scikit_learn-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1dad624cffe3062276a0881d4e441bc9e3b19d02d17757cd6ae79a9d192a0027"},
- {file = "scikit_learn-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:2fce7950a3fad85e0a61dc403df0f9345b53432ac0e47c50da210d22c60b6d85"},
- {file = "scikit_learn-1.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e5453b2e87ef8accedc5a8a4e6709f887ca01896cd7cc8a174fe39bd4bb00aef"},
- {file = "scikit_learn-1.6.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5fe11794236fb83bead2af26a87ced5d26e3370b8487430818b915dafab1724e"},
- {file = "scikit_learn-1.6.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61fe3dcec0d82ae280877a818ab652f4988371e32dd5451e75251bece79668b1"},
- {file = "scikit_learn-1.6.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b44e3a51e181933bdf9a4953cc69c6025b40d2b49e238233f149b98849beb4bf"},
- {file = "scikit_learn-1.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:a17860a562bac54384454d40b3f6155200c1c737c9399e6a97962c63fce503ac"},
- {file = "scikit_learn-1.6.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:98717d3c152f6842d36a70f21e1468fb2f1a2f8f2624d9a3f382211798516426"},
- {file = "scikit_learn-1.6.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:34e20bfac8ff0ebe0ff20fb16a4d6df5dc4cc9ce383e00c2ab67a526a3c67b18"},
- {file = "scikit_learn-1.6.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eba06d75815406091419e06dd650b91ebd1c5f836392a0d833ff36447c2b1bfa"},
- {file = "scikit_learn-1.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b6916d1cec1ff163c7d281e699d7a6a709da2f2c5ec7b10547e08cc788ddd3ae"},
- {file = "scikit_learn-1.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:66b1cf721a9f07f518eb545098226796c399c64abdcbf91c2b95d625068363da"},
- {file = "scikit_learn-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b35b60cf4cd6564b636e4a40516b3c61a4fa7a8b1f7a3ce80c38ebe04750bc3"},
- {file = "scikit_learn-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a73b1c2038c93bc7f4bf21f6c9828d5116c5d2268f7a20cfbbd41d3074d52083"},
- {file = "scikit_learn-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c3fa7d3dd5a0ec2d0baba0d644916fa2ab180ee37850c5d536245df916946bd"},
- {file = "scikit_learn-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:df778486a32518cda33818b7e3ce48c78cef1d5f640a6bc9d97c6d2e71449a51"},
- {file = "scikit_learn-1.6.0.tar.gz", hash = "sha256:9d58481f9f7499dff4196927aedd4285a0baec8caa3790efbe205f13de37dd6e"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6849dd3234e87f55dce1db34c89a810b489ead832aaf4d4550b7ea85628be6c1"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e7be3fa5d2eb9be7d77c3734ff1d599151bb523674be9b834e8da6abe132f44e"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44a17798172df1d3c1065e8fcf9019183f06c87609b49a124ebdf57ae6cb0107"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b7a3b86e411e4bce21186e1c180d792f3d99223dcfa3b4f597ecc92fa1a422"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:7a73d457070e3318e32bdb3aa79a8d990474f19035464dfd8bede2883ab5dc3b"},
+ {file = "scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e"},
]
[package.dependencies]
@@ -1814,51 +1813,51 @@ plots = ["matplotlib (>=2.0.0)"]
[[package]]
name = "scipy"
-version = "1.15.0"
+version = "1.15.1"
description = "Fundamental algorithms for scientific computing in Python"
optional = false
python-versions = ">=3.10"
files = [
- {file = "scipy-1.15.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca"},
- {file = "scipy-1.15.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d"},
- {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c"},
- {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d"},
- {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8"},
- {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4"},
- {file = "scipy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37"},
- {file = "scipy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731"},
- {file = "scipy-1.15.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020"},
- {file = "scipy-1.15.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443"},
- {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136"},
- {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e"},
- {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f"},
- {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0"},
- {file = "scipy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b"},
- {file = "scipy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d"},
- {file = "scipy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6"},
- {file = "scipy-1.15.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913"},
- {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192"},
- {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054"},
- {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e"},
- {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1"},
- {file = "scipy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863"},
- {file = "scipy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479"},
- {file = "scipy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422"},
- {file = "scipy-1.15.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8"},
- {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b"},
- {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0"},
- {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111"},
- {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4"},
- {file = "scipy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c"},
- {file = "scipy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2"},
- {file = "scipy-1.15.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff"},
- {file = "scipy-1.15.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34"},
- {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52"},
- {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6"},
- {file = "scipy-1.15.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5"},
- {file = "scipy-1.15.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df"},
- {file = "scipy-1.15.0-cp313-cp313t-win_amd64.whl", hash = "sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2"},
- {file = "scipy-1.15.0.tar.gz", hash = "sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac"},
+ {file = "scipy-1.15.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:c64ded12dcab08afff9e805a67ff4480f5e69993310e093434b10e85dc9d43e1"},
+ {file = "scipy-1.15.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5b190b935e7db569960b48840e5bef71dc513314cc4e79a1b7d14664f57fd4ff"},
+ {file = "scipy-1.15.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:4b17d4220df99bacb63065c76b0d1126d82bbf00167d1730019d2a30d6ae01ea"},
+ {file = "scipy-1.15.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:63b9b6cd0333d0eb1a49de6f834e8aeaefe438df8f6372352084535ad095219e"},
+ {file = "scipy-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f151e9fb60fbf8e52426132f473221a49362091ce7a5e72f8aa41f8e0da4f25"},
+ {file = "scipy-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e10b1dd56ce92fba3e786007322542361984f8463c6d37f6f25935a5a6ef52"},
+ {file = "scipy-1.15.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5dff14e75cdbcf07cdaa1c7707db6017d130f0af9ac41f6ce443a93318d6c6e0"},
+ {file = "scipy-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:f82fcf4e5b377f819542fbc8541f7b5fbcf1c0017d0df0bc22c781bf60abc4d8"},
+ {file = "scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2"},
+ {file = "scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0"},
+ {file = "scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf"},
+ {file = "scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac"},
+ {file = "scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df"},
+ {file = "scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7"},
+ {file = "scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a"},
+ {file = "scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b"},
+ {file = "scipy-1.15.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c09aa9d90f3500ea4c9b393ee96f96b0ccb27f2f350d09a47f533293c78ea776"},
+ {file = "scipy-1.15.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0ac102ce99934b162914b1e4a6b94ca7da0f4058b6d6fd65b0cef330c0f3346f"},
+ {file = "scipy-1.15.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:09c52320c42d7f5c7748b69e9f0389266fd4f82cf34c38485c14ee976cb8cb04"},
+ {file = "scipy-1.15.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:cdde8414154054763b42b74fe8ce89d7f3d17a7ac5dd77204f0e142cdc9239e9"},
+ {file = "scipy-1.15.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c9d8fc81d6a3b6844235e6fd175ee1d4c060163905a2becce8e74cb0d7554ce"},
+ {file = "scipy-1.15.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb57b30f0017d4afa5fe5f5b150b8f807618819287c21cbe51130de7ccdaed2"},
+ {file = "scipy-1.15.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491d57fe89927fa1aafbe260f4cfa5ffa20ab9f1435025045a5315006a91b8f5"},
+ {file = "scipy-1.15.1-cp312-cp312-win_amd64.whl", hash = "sha256:900f3fa3db87257510f011c292a5779eb627043dd89731b9c461cd16ef76ab3d"},
+ {file = "scipy-1.15.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:100193bb72fbff37dbd0bf14322314fc7cbe08b7ff3137f11a34d06dc0ee6b85"},
+ {file = "scipy-1.15.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:2114a08daec64980e4b4cbdf5bee90935af66d750146b1d2feb0d3ac30613692"},
+ {file = "scipy-1.15.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6b3e71893c6687fc5e29208d518900c24ea372a862854c9888368c0b267387ab"},
+ {file = "scipy-1.15.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:837299eec3d19b7e042923448d17d95a86e43941104d33f00da7e31a0f715d3c"},
+ {file = "scipy-1.15.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82add84e8a9fb12af5c2c1a3a3f1cb51849d27a580cb9e6bd66226195142be6e"},
+ {file = "scipy-1.15.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:070d10654f0cb6abd295bc96c12656f948e623ec5f9a4eab0ddb1466c000716e"},
+ {file = "scipy-1.15.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:55cc79ce4085c702ac31e49b1e69b27ef41111f22beafb9b49fea67142b696c4"},
+ {file = "scipy-1.15.1-cp313-cp313-win_amd64.whl", hash = "sha256:c352c1b6d7cac452534517e022f8f7b8d139cd9f27e6fbd9f3cbd0bfd39f5bef"},
+ {file = "scipy-1.15.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0458839c9f873062db69a03de9a9765ae2e694352c76a16be44f93ea45c28d2b"},
+ {file = "scipy-1.15.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:af0b61c1de46d0565b4b39c6417373304c1d4f5220004058bdad3061c9fa8a95"},
+ {file = "scipy-1.15.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:71ba9a76c2390eca6e359be81a3e879614af3a71dfdabb96d1d7ab33da6f2364"},
+ {file = "scipy-1.15.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:14eaa373c89eaf553be73c3affb11ec6c37493b7eaaf31cf9ac5dffae700c2e0"},
+ {file = "scipy-1.15.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f735bc41bd1c792c96bc426dece66c8723283695f02df61dcc4d0a707a42fc54"},
+ {file = "scipy-1.15.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2722a021a7929d21168830790202a75dbb20b468a8133c74a2c0230c72626b6c"},
+ {file = "scipy-1.15.1-cp313-cp313t-win_amd64.whl", hash = "sha256:bc7136626261ac1ed988dca56cfc4ab5180f75e0ee52e58f1e6aa74b5f3eacd5"},
+ {file = "scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6"},
]
[package.dependencies]
@@ -1871,23 +1870,23 @@ test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis
[[package]]
name = "setuptools"
-version = "75.6.0"
+version = "75.8.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
files = [
- {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"},
- {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"},
+ {file = "setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3"},
+ {file = "setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6"},
]
[package.extras]
-check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
-test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
-type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "six"
@@ -2105,13 +2104,13 @@ files = [
[[package]]
name = "tzdata"
-version = "2024.2"
+version = "2025.1"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
- {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"},
- {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"},
+ {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
+ {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
]
[[package]]
@@ -2143,13 +2142,13 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "virtualenv"
-version = "20.28.1"
+version = "20.29.2"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
files = [
- {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
- {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
+ {file = "virtualenv-20.29.2-py3-none-any.whl", hash = "sha256:febddfc3d1ea571bdb1dc0f98d7b45d24def7428214d4fb73cc486c9568cce6a"},
+ {file = "virtualenv-20.29.2.tar.gz", hash = "sha256:fdaabebf6d03b5ba83ae0a02cfe96f48a716f4fae556461d180825866f75b728"},
]
[package.dependencies]
@@ -2259,5 +2258,5 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.0"
-python-versions = ">=3.10, <=3.12"
-content-hash = "c7606af7fb47a2fb5e856b23ef3e06a1740544bda46470dafeb7c7a3ca794d5e"
+python-versions = ">=3.10, <=3.13"
+content-hash = "b34f786ecb4e8e548d37bbaf30c2b4de024f20cba22399712f255f23ffc4e6d7"
diff --git a/pyproject.toml b/pyproject.toml
index 70b8acc3..b0e4c845 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mambular"
-version = "1.1.0"
+version = "1.2.0"
description = "A python package for tabular deep learning with mamba blocks."
authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]
readme = "README.md"
@@ -11,12 +11,12 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.dependencies]
-python = ">=3.10, <=3.12"
+python = ">=3.10, <=3.13"
numpy = "<=1.26.4"
pandas = "^2.0.3"
lightning = "^2.3.3"
scikit-learn = "^1.3.2"
-torch = "^2.5.1"
+torch = ">=2.2.2, <=2.5.1"
torchmetrics = "^1.5.2"
setuptools = "^75.3.0"
properscoring = "^0.1"
@@ -25,7 +25,6 @@ einops = "^0.8.0"
accelerate = "^1.2.1"
scipy = "^1.15.0"
-
[tool.poetry.group.dev.dependencies]
pytest = "^8.1"
pytest-cov = "^4.1"
@@ -56,19 +55,6 @@ line-length = 120
target-version = "py310"
exclude = ["*.ipynb", "mambular/arch_utils/mamba_utils.mamba_orginal.py"]
-ignore = [
- "B006",
- "F401", # Ignore unused imports
- "F841", # Ignore unused local variables
- "E501", # Ignore line length
- "D100", # Missing module-level docstring
- "D101", # Missing class-level docstring
- "D102", # Missing method-level docstring
- "D103", # Missing function-level docstring
- "B007",
- "S307",
-]
-
[tool.ruff.lint]
select = [
"A", # flake8-buildins
@@ -83,6 +69,19 @@ select = [
"W", # pycodestyle - warnings
]
+ignore = [
+ "B006",
+ "F401", # Ignore unused imports
+ "F841", # Ignore unused local variables
+ "E501", # Ignore line length
+ "D100", # Missing module-level docstring
+ "D101", # Missing class-level docstring
+ "D102", # Missing method-level docstring
+ "D103", # Missing function-level docstring
+ "B007",
+ "S307",
+]
+
[tool.ruff.lint.per-file-ignores]
# allow asserts in test files (bandit)
diff --git a/tests/test_base.py b/tests/test_base.py
new file mode 100644
index 00000000..20d97971
--- /dev/null
+++ b/tests/test_base.py
@@ -0,0 +1,155 @@
+import pytest
+import inspect
+import torch
+import os
+import importlib
+from mambular.base_models.basemodel import BaseModel
+
+# Paths for models and configs
+MODEL_MODULE_PATH = "mambular.base_models"
+CONFIG_MODULE_PATH = "mambular.configs"
+
+# Discover all models
+model_classes = []
+for filename in os.listdir(os.path.dirname(__file__) + "/../mambular/base_models"):
+ if filename.endswith(".py") and filename not in [
+ "__init__.py",
+ "basemodel.py",
+ "lightning_wrapper.py",
+ "bayesian_tabm.py",
+ ]:
+ module_name = f"{MODEL_MODULE_PATH}.{filename[:-3]}"
+ module = importlib.import_module(module_name)
+
+ for name, obj in inspect.getmembers(module, inspect.isclass):
+ if issubclass(obj, BaseModel) and obj is not BaseModel:
+ model_classes.append(obj)
+
+
+def get_model_config(model_class):
+ """Dynamically load the correct config class for each model."""
+ model_name = model_class.__name__ # e.g., "Mambular"
+ config_class_name = f"Default{model_name}Config" # e.g., "DefaultMambularConfig"
+
+ try:
+ config_module = importlib.import_module(
+ f"{CONFIG_MODULE_PATH}.{model_name.lower()}_config"
+ )
+ config_class = getattr(config_module, config_class_name)
+ return config_class() # Instantiate config
+ except (ModuleNotFoundError, AttributeError) as e:
+ pytest.fail(
+ f"Could not find or instantiate config {config_class_name} for {model_name}: {e}"
+ )
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_inherits_base_model(model_class):
+ """Test that each model correctly inherits from BaseModel."""
+ assert issubclass(
+ model_class, BaseModel
+ ), f"{model_class.__name__} should inherit from BaseModel."
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_has_forward_method(model_class):
+ """Test that each model has a forward method with *data."""
+ assert hasattr(
+ model_class, "forward"
+ ), f"{model_class.__name__} is missing a forward method."
+
+ sig = inspect.signature(model_class.forward)
+ assert any(
+ p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()
+ ), f"{model_class.__name__}.forward should have *data argument."
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_takes_config(model_class):
+ """Test that each model accepts a config argument."""
+ sig = inspect.signature(model_class.__init__)
+ assert (
+ "config" in sig.parameters
+ ), f"{model_class.__name__} should accept a 'config' parameter."
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_has_num_classes(model_class):
+ """Test that each model accepts a num_classes argument."""
+ sig = inspect.signature(model_class.__init__)
+ assert (
+ "num_classes" in sig.parameters
+ ), f"{model_class.__name__} should accept a 'num_classes' parameter."
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_calls_super_init(model_class):
+ """Test that each model calls super().__init__(config=config, **kwargs)."""
+ source = inspect.getsource(model_class.__init__)
+ assert (
+ "super().__init__(config=config" in source
+ ), f"{model_class.__name__} should call super().__init__(config=config, **kwargs)."
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_initialization(model_class):
+ """Test that each model can be initialized with its correct config."""
+ config = get_model_config(model_class)
+ feature_info = (
+ {
+ "A": {
+ "preprocessing": "imputer -> check_positive -> box-cox",
+ "dimension": 1,
+ "categories": None,
+ }
+ },
+ {
+ "sibsp": {
+ "preprocessing": "imputer -> continuous_ordinal",
+ "dimension": 1,
+ "categories": 8,
+ }
+ },
+ {},
+ ) # Mock feature info
+
+ try:
+ model = model_class(
+ feature_information=feature_info, num_classes=3, config=config
+ )
+ except Exception as e:
+ pytest.fail(f"Failed to initialize {model_class.__name__}: {e}")
+
+
+@pytest.mark.parametrize("model_class", model_classes)
+def test_model_defines_key_attributes(model_class):
+ """Test that each model defines expected attributes like returns_ensemble"""
+ config = get_model_config(model_class)
+ feature_info = (
+ {
+ "A": {
+ "preprocessing": "imputer -> check_positive -> box-cox",
+ "dimension": 1,
+ "categories": None,
+ }
+ },
+ {
+ "sibsp": {
+ "preprocessing": "imputer -> continuous_ordinal",
+ "dimension": 1,
+ "categories": 8,
+ }
+ },
+ {},
+ ) # Mock feature info
+
+ try:
+ model = model_class(
+ feature_information=feature_info, num_classes=3, config=config
+ )
+ except TypeError as e:
+ pytest.fail(f"Failed to initialize {model_class.__name__}: {e}")
+
+ expected_attrs = ["returns_ensemble"]
+ for attr in expected_attrs:
+ assert hasattr(model, attr), f"{model_class.__name__} should define '{attr}'."
diff --git a/tests/test_classifier.py b/tests/test_classifier.py
deleted file mode 100644
index 7243233d..00000000
--- a/tests/test_classifier.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import unittest
-from unittest.mock import MagicMock, patch
-
-import numpy as np
-import pandas as pd
-import torch
-from sklearn.metrics import accuracy_score, log_loss
-
-from mambular.models import MambularClassifier # Ensure correct import path
-
-
-class TestMambularClassifier(unittest.TestCase):
- def setUp(self):
- # Patching external dependencies
- self.patcher_pl_trainer = patch("lightning.Trainer")
- self.mock_pl_trainer = self.patcher_pl_trainer.start()
-
- self.patcher_base_model = patch("mambular.base_models.classifier.BaseMambularClassifier")
- self.mock_base_model = self.patcher_base_model.start()
-
- self.classifier = MambularClassifier(d_model=128, dropout=0.1)
-
- # Sample data
- self.X = pd.DataFrame(np.random.randn(100, 10))
- self.y = np.random.choice(["A", "B", "C"], size=100)
-
- self.classifier.cat_feature_info = {}
- self.classifier.num_feature_info = {}
-
- def tearDown(self):
- self.patcher_pl_trainer.stop()
- self.patcher_base_model.stop()
-
- def test_initialization(self):
- # This assumes MambularConfig is properly imported and used in the MambularRegressor class
- from mambular.utils.configs import DefaultMambularConfig
-
- self.assertIsInstance(self.classifier.config, DefaultMambularConfig)
- self.assertEqual(self.classifier.config.d_model, 128)
- self.assertEqual(self.classifier.config.dropout, 0.1)
-
- def test_split_data(self):
- """Test the data splitting functionality."""
- X_train, X_val, y_train, y_val = self.classifier.split_data(self.X, self.y, val_size=0.2, random_state=42)
- self.assertEqual(len(X_train), 80)
- self.assertEqual(len(X_val), 20)
- self.assertEqual(len(y_train), 80)
- self.assertEqual(len(y_val), 20)
-
- def test_fit(self):
- """Test the training setup and call."""
- # Mock the necessary parts to simulate training
- self.classifier.preprocess_data = MagicMock()
- self.classifier.model = self.mock_base_model
-
- self.classifier.fit(self.X, self.y)
-
- # Ensure that the fit method of the trainer is called
- self.mock_pl_trainer.return_value.fit.assert_called_once()
-
- def test_predict(self):
- # Create a mock tensor as the model output
- # Assuming three classes A, B, C as per self.y
- mock_logits = torch.rand(100, 3)
-
- # Mock the model and its method calls
- self.classifier.model = MagicMock()
- self.classifier.model.eval.return_value = None
- self.classifier.model.return_value = mock_logits
-
- # Mock preprocess_test_data to return dummy tensor data
- self.classifier.preprocess_test_data = MagicMock(return_value=([], []))
-
- predictions = self.classifier.predict(self.X)
-
- # Assert that predictions return as expected
- expected_predictions = torch.argmax(mock_logits, dim=1).numpy()
- np.testing.assert_array_equal(predictions, expected_predictions)
-
- def test_evaluate(self):
- # Mock predict and predict_proba to simulate classifier output
- mock_predictions = np.random.choice([0, 1, 2], size=100)
- raw_probabilities = np.random.rand(100, 3)
- # Normalize these probabilities so that each row sums to 1
- mock_probabilities = raw_probabilities / raw_probabilities.sum(axis=1, keepdims=True)
- self.classifier.predict = MagicMock(return_value=mock_predictions)
- self.classifier.predict_proba = MagicMock(return_value=mock_probabilities)
-
- # Define metrics to test
- metrics = {
- "Accuracy": (accuracy_score, False),
- # Log Loss requires probability scores
- "Log Loss": (log_loss, True),
- }
-
- # Call evaluate with the defined metrics
- result = self.classifier.evaluate(self.X, self.y, metrics=metrics)
-
- # Assert that predict and predict_proba were called correctly
- self.classifier.predict.assert_called_once()
- self.classifier.predict_proba.assert_called_once()
-
- # Check the results of evaluate
- expected_accuracy = accuracy_score(self.y, mock_predictions)
- expected_log_loss = log_loss(self.y, mock_probabilities)
- self.assertEqual(result["Accuracy"], expected_accuracy)
- self.assertAlmostEqual(result["Log Loss"], expected_log_loss)
-
- # Assert calls with appropriate arguments
- self.classifier.predict.assert_called_once_with(self.X)
- self.classifier.predict_proba.assert_called_once_with(self.X)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_configs.py b/tests/test_configs.py
new file mode 100644
index 00000000..5299534a
--- /dev/null
+++ b/tests/test_configs.py
@@ -0,0 +1,115 @@
+import pytest
+import inspect
+import importlib
+import os
+import dataclasses
+import typing
+from mambular.configs.base_config import BaseConfig # Ensure correct path
+
+CONFIG_MODULE_PATH = "mambular.configs"
+config_classes = []
+
+# Discover all config classes in mambular/configs/
+for filename in os.listdir(os.path.dirname(__file__) + "/../mambular/configs"):
+ if (
+ filename.endswith(".py")
+ and filename != "base_config.py"
+ and not filename.startswith("__")
+ ):
+ module_name = f"{CONFIG_MODULE_PATH}.{filename[:-3]}"
+ module = importlib.import_module(module_name)
+
+ for name, obj in inspect.getmembers(module, inspect.isclass):
+ if issubclass(obj, BaseConfig) and obj is not BaseConfig:
+ config_classes.append(obj)
+
+
+@pytest.mark.parametrize("config_class", config_classes)
+def test_config_inherits_baseconfig(config_class):
+ """Test that each config class correctly inherits from BaseConfig."""
+ assert issubclass(
+ config_class, BaseConfig
+ ), f"{config_class.__name__} should inherit from BaseConfig."
+
+
+@pytest.mark.parametrize("config_class", config_classes)
+def test_config_instantiation(config_class):
+ """Test that each config class can be instantiated without errors."""
+ try:
+ config = config_class()
+ except Exception as e:
+ pytest.fail(f"Failed to instantiate {config_class.__name__}: {e}")
+
+
+@pytest.mark.parametrize("config_class", config_classes)
+def test_config_has_expected_attributes(config_class):
+ """Test that each config has all required attributes from BaseConfig."""
+ base_attrs = {field.name for field in dataclasses.fields(BaseConfig)}
+ config_attrs = {field.name for field in dataclasses.fields(config_class)}
+
+ missing_attrs = base_attrs - config_attrs
+ assert (
+ not missing_attrs
+ ), f"{config_class.__name__} is missing attributes: {missing_attrs}"
+
+
+@pytest.mark.parametrize("config_class", config_classes)
+def test_config_default_values(config_class):
+ """Ensure that each config class has default values assigned correctly."""
+ config = config_class()
+
+ for field in dataclasses.fields(config_class):
+ attr = field.name
+ expected_type = field.type
+
+ assert hasattr(
+ config, attr
+ ), f"{config_class.__name__} is missing attribute '{attr}'."
+
+ value = getattr(config, attr)
+
+ # Handle generic types properly
+ origin = typing.get_origin(expected_type)
+
+ if origin is typing.Literal:
+ # If the field is a Literal, ensure the value is one of the allowed options
+ allowed_values = typing.get_args(expected_type)
+ assert (
+ value in allowed_values
+ ), f"{config_class.__name__}.{attr} has incorrect value: expected one of {allowed_values}, got {value}"
+ elif origin is typing.Union:
+ # For Union types (e.g., Optional[str]), check if value matches any type in the union
+ allowed_types = typing.get_args(expected_type)
+ assert any(
+ isinstance(value, t) for t in allowed_types
+ ), f"{config_class.__name__}.{attr} has incorrect type: expected one of {allowed_types}, got {type(value)}"
+ elif origin is not None:
+ # If it's another generic type (e.g., list[str]), check against the base type
+ assert (
+ isinstance(value, origin) or value is None
+ ), f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
+ else:
+ # Standard type check
+ assert (
+ isinstance(value, expected_type) or value is None
+ ), f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
+
+
+@pytest.mark.parametrize("config_class", config_classes)
+def test_config_allows_updates(config_class):
+ """Ensure that config values can be updated and remain type-consistent."""
+ config = config_class()
+
+ update_values = {
+ "lr": 0.01,
+ "d_model": 128,
+ "embedding_type": "plr",
+ "activation": lambda x: x, # Function update
+ }
+
+ for attr, new_value in update_values.items():
+ if hasattr(config, attr):
+ setattr(config, attr, new_value)
+ assert (
+ getattr(config, attr) == new_value
+ ), f"{config_class.__name__}.{attr} did not update correctly."
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
deleted file mode 100644
index 1a8f2ca7..00000000
--- a/tests/test_distributions.py
+++ /dev/null
@@ -1,311 +0,0 @@
-import unittest
-
-import torch
-
-from mambular.utils.distributions import (
- BetaDistribution,
- CategoricalDistribution,
- DirichletDistribution,
- GammaDistribution,
- InverseGammaDistribution,
- NegativeBinomialDistribution,
- NormalDistribution,
- PoissonDistribution,
- StudentTDistribution,
-)
-
-
-class TestNormalDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the NormalDistribution object with default transforms."""
- self.normal = NormalDistribution()
-
- def test_initialization(self):
- """Test the initialization and default parameter settings."""
- self.assertEqual(self.normal._name, "Normal")
- self.assertEqual(self.normal.param_names, ["mean", "variance"])
- self.assertIsInstance(self.normal.mean_transform, type(lambda x: x))
- self.assertIsInstance(self.normal.variance_transform, type(torch.nn.functional.softplus))
-
- def test_predefined_transforms(self):
- """Test if predefined transformations are correctly applied."""
- x = torch.tensor([-1.0, 0.0, 1.0])
- self.assertTrue(torch.allclose(self.normal.mean_transform(x), x)) # 'none' should change nothing
- self.assertTrue(
- torch.all(torch.ge(self.normal.variance_transform(x), 0))
- ) # 'positive' should make all values non-negative
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- predictions = torch.tensor([[0.0, 1.0]]) # mean = 0, variance = 1
- y_true = torch.tensor([0.0])
- self.normal = NormalDistribution()
- loss = self.normal.compute_loss(predictions, y_true)
- test_dist = torch.distributions.Normal(
- loc=predictions[:, 0], scale=torch.nn.functional.softplus(predictions[:, 1])
- )
- expected_loss = -test_dist.log_prob(torch.tensor(0.0)).mean()
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
- def test_evaluate_nll(self):
- """Test the evaluate NLL function."""
- y_true = [0.0]
- y_pred = [[0.0, 1.0]] # mean=0, variance=1
- result = self.normal.evaluate_nll(y_true, y_pred)
- self.assertIn("NLL", result)
- self.assertIn("mse", result)
- self.assertIn("mae", result)
- self.assertIn("rmse", result)
-
-
-class TestPoissonDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the PoissonDistribution object with default transform."""
- self.poisson = PoissonDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.poisson._name, "Poisson")
- self.assertEqual(self.poisson.param_names, ["rate"])
- self.assertIsInstance(self.poisson.rate_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- predictions = torch.tensor([[1.0]]) # rate = 1
- y_true = torch.tensor([1.0])
- loss = self.poisson.compute_loss(predictions, y_true)
- expected_loss = (
- -torch.distributions.Poisson(torch.nn.functional.softplus(predictions[:, 0]))
- .log_prob(torch.tensor(1.0))
- .mean()
- )
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestBetaDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the BetaDistribution object with default transforms."""
- self.beta = BetaDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.beta._name, "Beta")
- self.assertEqual(self.beta.param_names, ["alpha", "beta"])
- self.assertIsInstance(self.beta.alpha_transform, type(torch.nn.functional.softplus))
- self.assertIsInstance(self.beta.beta_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- predictions = torch.tensor([[1.0, 1.0]]) # alpha = 1, beta = 1 (uniform distribution)
- y_true = torch.tensor([0.5])
- loss = self.beta.compute_loss(predictions, y_true)
- expected_loss = (
- -torch.distributions.Beta(
- torch.nn.functional.softplus(predictions[:, 0]),
- torch.nn.functional.softplus(predictions[:, 1]),
- )
- .log_prob(torch.tensor(0.5))
- .mean()
- )
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestInverseGammaDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the InverseGammaDistribution object with default transforms."""
- self.inverse_gamma = InverseGammaDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.inverse_gamma._name, "InverseGamma")
- self.assertEqual(self.inverse_gamma.param_names, ["shape", "scale"])
- self.assertIsInstance(self.inverse_gamma.shape_transform, type(torch.nn.functional.softplus))
- self.assertIsInstance(self.inverse_gamma.scale_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- # These values for shape and scale parameters are chosen to be feasible and testable.
- predictions = torch.tensor([[3.0, 2.0]]) # shape = 3, scale = 2
- y_true = torch.tensor([0.5])
-
- loss = self.inverse_gamma.compute_loss(predictions, y_true)
- # Manually calculate the expected loss using torch's distribution functions
- shape = torch.nn.functional.softplus(predictions[:, 0])
- scale = torch.nn.functional.softplus(predictions[:, 1])
- inverse_gamma_dist = torch.distributions.InverseGamma(shape, scale)
- expected_loss = -inverse_gamma_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestDirichletDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the DirichletDistribution object with default transforms."""
- self.dirichlet = DirichletDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.dirichlet._name, "Dirichlet")
- # Concentration param_name is a simplification as mentioned in your class docstring
- self.assertEqual(self.dirichlet.param_names, ["concentration"])
- self.assertIsInstance(self.dirichlet.concentration_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- # These values are chosen to be feasible and testable.
- # Example: Concentrations for a 3-dimensional Dirichlet distribution
- predictions = torch.tensor(
- [[1.0, 1.0, 1.0]]
- ) # Equal concentration, should resemble uniform distribution over simplex
- y_true = torch.tensor([[0.33, 0.33, 0.34]]) # Example point in the probability simplex
-
- loss = self.dirichlet.compute_loss(predictions, y_true)
- # Manually calculate the expected loss using torch's distribution functions
- concentration = torch.nn.functional.softplus(predictions)
- dirichlet_dist = torch.distributions.Dirichlet(concentration)
- expected_loss = -dirichlet_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestGammaDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the GammaDistribution object with default transforms."""
- self.gamma = GammaDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.gamma._name, "Gamma")
- self.assertEqual(self.gamma.param_names, ["shape", "rate"])
- self.assertIsInstance(self.gamma.shape_transform, type(torch.nn.functional.softplus))
- self.assertIsInstance(self.gamma.rate_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- # Set some test parameters and observations
- predictions = torch.tensor([[2.0, 3.0]]) # shape = 2, rate = 3
- y_true = torch.tensor([0.5]) # Test value
-
- loss = self.gamma.compute_loss(predictions, y_true)
- # Manually calculate the expected loss using torch's distribution functions
- shape = torch.nn.functional.softplus(predictions[:, 0])
- rate = torch.nn.functional.softplus(predictions[:, 1])
- gamma_dist = torch.distributions.Gamma(shape, rate)
- expected_loss = -gamma_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestStudentTDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the StudentTDistribution object with default transforms."""
- self.student_t = StudentTDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.student_t._name, "StudentT")
- self.assertEqual(self.student_t.param_names, ["df", "loc", "scale"])
- self.assertIsInstance(self.student_t.df_transform, type(torch.nn.functional.softplus))
- self.assertIsInstance(
- self.student_t.loc_transform,
- type(lambda x: x), # Assuming 'none' transformation
- )
- self.assertIsInstance(self.student_t.scale_transform, type(torch.nn.functional.softplus))
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- # Set some test parameters and observations
- predictions = torch.tensor([[10.0, 0.0, 1.0]]) # df=10, loc=0, scale=1
- y_true = torch.tensor([0.5]) # Test value
-
- loss = self.student_t.compute_loss(predictions, y_true)
- # Manually calculate the expected loss using torch's distribution functions
- df = torch.nn.functional.softplus(predictions[:, 0])
- loc = predictions[:, 1] # 'none' transformation
- scale = torch.nn.functional.softplus(predictions[:, 2])
- student_t_dist = torch.distributions.StudentT(df, loc, scale)
- expected_loss = -student_t_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
- def test_evaluate_nll(self):
- """Test the evaluate NLL function and additional metrics."""
- y_true = [0.5]
- y_pred = [[10.0, 0.0, 1.0]] # df=10, loc=0, scale=1
- result = self.student_t.evaluate_nll(y_true, y_pred)
-
- self.assertIn("NLL", result)
- self.assertIn("mse", result)
- self.assertIn("mae", result)
- self.assertIn("rmse", result)
-
- # Check that MSE, MAE, RMSE calculations are reasonable
- self.assertGreaterEqual(result["mse"], 0)
- self.assertGreaterEqual(result["mae"], 0)
- self.assertGreaterEqual(result["rmse"], 0)
-
-
-class TestNegativeBinomialDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the NegativeBinomialDistribution object with default transforms."""
- self.negative_binomial = NegativeBinomialDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.negative_binomial._name, "NegativeBinomial")
- self.assertEqual(self.negative_binomial.param_names, ["mean", "dispersion"])
- self.assertIsInstance(self.negative_binomial.mean_transform, type(torch.nn.functional.softplus))
- self.assertIsInstance(
- self.negative_binomial.dispersion_transform,
- type(torch.nn.functional.softplus),
- )
-
- def test_compute_loss_known_values(self):
- """Test the loss computation against known values."""
- # Set some test parameters and observations
- predictions = torch.tensor([[10.0, 0.1]]) # mean=10, dispersion=0.1
- y_true = torch.tensor([5.0]) # Test value
-
- loss = self.negative_binomial.compute_loss(predictions, y_true)
- # Manually calculate the expected loss using torch's distribution functions
- mean = torch.nn.functional.softplus(predictions[:, 0])
- dispersion = torch.nn.functional.softplus(predictions[:, 1])
- r = 1 / dispersion
- p = r / (r + mean)
- negative_binomial_dist = torch.distributions.NegativeBinomial(total_count=r, probs=p)
- expected_loss = -negative_binomial_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-class TestCategoricalDistribution(unittest.TestCase):
- def setUp(self):
- """Initialize the CategoricalDistribution object with a probability transformation."""
- self.categorical = CategoricalDistribution()
-
- def test_initialization(self):
- """Test the initialization and parameter settings."""
- self.assertEqual(self.categorical._name, "Categorical")
- self.assertEqual(self.categorical.param_names, ["probs"])
- # The transformation function will need to ensure the probabilities are valid (non-negative and sum to 1)
- # Typically, this might involve applying softmax to ensure the constraints are met.
- # Here, we assume `prob_transform` is something akin to softmax for the sake of test setup.
- self.assertIsInstance(self.categorical.probs_transform, type(torch.nn.functional.softmax))
-
- def test_compute_loss_known_values(self):
- # Example with three categories
- logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]) # Logits for three categories
- y_true = torch.tensor([2, 1])
-
- loss = self.categorical.compute_loss(logits, y_true)
- # Apply softmax to logits to convert them into probabilities
- probs = torch.nn.functional.softmax(logits, dim=1)
- cat_dist = torch.distributions.Categorical(probs=probs)
- expected_loss = -cat_dist.log_prob(y_true).mean()
-
- self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)
-
-
-# Running the tests
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_lss.py b/tests/test_lss.py
deleted file mode 100644
index 01192db4..00000000
--- a/tests/test_lss.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import unittest
-from unittest.mock import MagicMock, patch
-
-import numpy as np
-import pandas as pd
-from properscoring import (
- crps_gaussian,
-)
-
-# Assuming this is the source of the CRPS function
-from sklearn.metrics import mean_poisson_deviance, mean_squared_error
-
-from mambular.models import MambularLSS # Update the import path
-
-
-class TestMambularLSS(unittest.TestCase):
- def setUp(self):
- # Patch PyTorch Lightning's Trainer and any other external dependencies
- self.patcher_trainer = patch("lightning.Trainer")
- self.mock_trainer = self.patcher_trainer.start()
-
- self.patcher_base_model = patch("mambular.base_models.distributional.BaseMambularLSS")
- self.mock_base_model = self.patcher_base_model.start()
-
- # Initialize MambularLSS with example parameters
- self.model = MambularLSS(d_model=128, dropout=0.1, n_layers=4)
-
- # Sample data
- self.X = pd.DataFrame(np.random.randn(100, 10))
- self.y = np.random.rand(100)
-
- self.model.cat_feature_info = {}
- self.model.num_feature_info = {}
-
- self.X_test = pd.DataFrame(np.random.randn(100, 10))
- self.y_test = np.random.rand(100) ** 2
-
- def tearDown(self):
- self.patcher_trainer.stop()
- self.patcher_base_model.stop()
-
- def test_initialization(self):
- from mambular.utils.configs import DefaultMambularConfig
-
- self.assertIsInstance(self.model.config, DefaultMambularConfig)
- self.assertEqual(self.model.config.d_model, 128)
- self.assertEqual(self.model.config.dropout, 0.1)
- self.assertEqual(self.model.config.n_layers, 4)
-
- def test_split_data(self):
- X_train, X_val, y_train, y_val = self.model.split_data(self.X, self.y, val_size=0.2, random_state=42)
- self.assertEqual(len(X_train), 80)
- self.assertEqual(len(X_val), 20)
- self.assertEqual(len(y_train), 80)
- self.assertEqual(len(y_val), 20)
-
- def test_fit(self):
- # Mock preprocessing and model setup to focus on testing training logic
- self.model.preprocess_data = MagicMock()
- self.model.model = self.mock_base_model
-
- self.model.fit(self.X, self.y, family="normal")
-
- # Ensure the fit method of the trainer is called
- self.mock_trainer.return_value.fit.assert_called_once()
-
- def test_normal_metrics(self):
- # Mock predictions for the normal distribution: [mean, variance]
- mock_predictions = np.column_stack((np.random.normal(size=100), np.abs(np.random.normal(size=100))))
- self.model.predict = MagicMock(return_value=mock_predictions)
-
- # Define custom metrics or use a function that fetches appropriate metrics
- self.model.get_default_metrics = MagicMock(
- return_value={
- "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]),
- "CRPS": lambda y, pred: np.mean(
- [crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) for i in range(len(y))]
- ),
- }
- )
-
- results = self.model.evaluate(self.X_test, self.y_test, distribution_family="normal")
-
- # Validate the MSE
- expected_mse = mean_squared_error(self.y_test, mock_predictions[:, 0])
- self.assertAlmostEqual(results["MSE"], expected_mse, places=4)
- self.assertIn("CRPS", results) # Check for existence but not the exact value in this test
-
- def test_poisson_metrics(self):
- # Mock predictions for Poisson
- mock_predictions = np.random.poisson(lam=3, size=100) + 1e-3
- self.model.predict = MagicMock(return_value=mock_predictions)
-
- self.model.get_default_metrics = MagicMock(return_value={"Poisson Deviance": mean_poisson_deviance})
-
- results = self.model.evaluate(self.X_test, self.y_test, distribution_family="poisson")
- self.assertIn("Poisson Deviance", results)
- # Optionally calculate expected deviance and check
- expected_deviance = mean_poisson_deviance(self.y_test, mock_predictions)
- self.assertAlmostEqual(results["Poisson Deviance"], expected_deviance)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_preprocessor.py b/tests/test_preprocessor.py
index 68c32656..5c046304 100644
--- a/tests/test_preprocessor.py
+++ b/tests/test_preprocessor.py
@@ -1,82 +1,111 @@
-import unittest
-
+import pytest
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
+from mambular.preprocessing import Preprocessor
+
+
+@pytest.fixture
+def sample_data():
+ return pd.DataFrame(
+ {
+ "numerical": np.random.randn(100),
+ "categorical": np.random.choice(["A", "B", "C"], size=100),
+ "integer": np.random.randint(0, 5, size=100),
+ }
+ )
+
+
+@pytest.fixture
+def sample_target():
+ return np.random.randn(100)
+
+
+@pytest.fixture(
+ params=[
+ "ple",
+ "binning",
+ "one-hot",
+ "standardization",
+ "minmax",
+ "quantile",
+ "polynomial",
+ "robust",
+ "splines",
+ "yeo-johnson",
+ "box-cox",
+ "rbf",
+ "sigmoid",
+ "none",
+ ]
+)
+def preprocessor(request):
+ return Preprocessor(
+ numerical_preprocessing=request.param, categorical_preprocessing="one-hot"
+ )
+
+
+def test_preprocessor_initialization(preprocessor):
+ assert preprocessor.numerical_preprocessing in [
+ "ple",
+ "binning",
+ "one-hot",
+ "standardization",
+ "minmax",
+ "quantile",
+ "polynomial",
+ "robust",
+ "splines",
+ "yeo-johnson",
+ "box-cox",
+ "rbf",
+ "sigmoid",
+ "none",
+ ]
+ assert preprocessor.categorical_preprocessing == "one-hot"
+ assert not preprocessor.fitted
+
+
+def test_preprocessor_fit(preprocessor, sample_data, sample_target):
+ preprocessor.fit(sample_data, sample_target)
+ assert preprocessor.fitted
+ assert preprocessor.column_transformer is not None
+
+
+def test_preprocessor_transform(preprocessor, sample_data, sample_target):
+ preprocessor.fit(sample_data, sample_target)
+ transformed = preprocessor.transform(sample_data)
+ assert isinstance(transformed, dict)
+ assert len(transformed) > 0
+
+
+def test_preprocessor_fit_transform(preprocessor, sample_data, sample_target):
+ transformed = preprocessor.fit_transform(sample_data, sample_target)
+ assert isinstance(transformed, dict)
+ assert len(transformed) > 0
+
+
+def test_preprocessor_get_params(preprocessor):
+ params = preprocessor.get_params()
+ assert "n_bins" in params
+ assert "numerical_preprocessing" in params
+
+
+def test_preprocessor_set_params(preprocessor):
+ preprocessor.set_params(n_bins=128)
+ assert preprocessor.n_bins == 128
+
+
+def test_transform_before_fit_raises_error(preprocessor, sample_data):
+ with pytest.raises(NotFittedError):
+ preprocessor.transform(sample_data)
+
-from mambular.utils.preprocessor import Preprocessor
-
-
-class TestPreprocessor(unittest.TestCase):
- def setUp(self):
- # Sample data for testing
- self.data = pd.DataFrame(
- {
- "numerical": np.random.randn(500),
- "categorical": np.random.choice(["A", "B", "C"], size=500),
- "mixed": np.random.choice([1, "A", "B"], size=500),
- }
- )
- self.target = np.random.randn(500)
-
- def test_initialization(self):
- """Test initialization of the Preprocessor with default parameters."""
- pp = Preprocessor(n_bins=20, numerical_preprocessing="binning")
- self.assertEqual(pp.n_bins, 20)
- self.assertEqual(pp.numerical_preprocessing, "binning")
- self.assertFalse(pp.use_decision_tree_bins)
-
- def test_fit(self):
- """Test the fitting process of the preprocessor."""
- pp = Preprocessor(numerical_preprocessing="binning", n_bins=20)
- pp.fit(self.data, self.target)
- self.assertIsNotNone(pp.column_transformer)
-
- def test_transform_not_fitted(self):
- """Test that transform raises an error if called before fitting."""
- pp = Preprocessor()
- with self.assertRaises(NotFittedError):
- pp.transform(self.data)
-
- def test_fit_transform(self):
- """Test fitting and transforming the data."""
- pp = Preprocessor(numerical_preprocessing="standardization")
- transformed_data = pp.fit_transform(self.data, self.target)
- self.assertIsInstance(transformed_data, dict)
- self.assertTrue("num_numerical" in transformed_data)
- self.assertTrue("cat_categorical" in transformed_data)
-
- def test_ple(self):
- """Test fitting and transforming the data."""
- pp = Preprocessor(numerical_preprocessing="ple", n_bins=20)
- transformed_data = pp.fit_transform(self.data, self.target)
- self.assertIsInstance(transformed_data, dict)
- self.assertTrue("num_numerical" in transformed_data)
- self.assertTrue("cat_categorical" in transformed_data)
-
- def test_transform_with_missing_values(self):
- """Ensure the preprocessor can handle missing values."""
- data_with_missing = self.data.copy()
- data_with_missing.loc[0, "numerical"] = np.nan
- data_with_missing.loc[1, "categorical"] = np.nan
- pp = Preprocessor(numerical_preprocessing="normalization")
- transformed_data = pp.fit_transform(data_with_missing, self.target)
- self.assertNotIn(np.nan, transformed_data["num_numerical"])
- self.assertNotIn(np.nan, transformed_data["cat_categorical"])
-
- def test_decision_tree_bins(self):
- """Test the usage of decision tree for binning."""
- pp = Preprocessor(use_decision_tree_bins=True, numerical_preprocessing="binning", n_bins=5)
- pp.fit(self.data, self.target)
- # Checking if the preprocessor setup decision tree bins properly
- self.assertTrue(
- all(
- isinstance(x, np.ndarray)
- for x in pp._get_decision_tree_bins(self.data[["numerical"]], self.target, ["numerical"])
- )
- )
-
-
-# Running the tests
-if __name__ == "__main__":
- unittest.main()
+def test_get_feature_info(preprocessor, sample_data, sample_target):
+ preprocessor.fit(sample_data, sample_target)
+ numerical_info, categorical_info, embedding_info = preprocessor.get_feature_info(
+ verbose=False
+ )
+ assert isinstance(numerical_info, dict)
+ assert isinstance(categorical_info, dict)
+ assert isinstance(embedding_info, dict)
diff --git a/tests/test_regressor.py b/tests/test_regressor.py
deleted file mode 100644
index 86260398..00000000
--- a/tests/test_regressor.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import unittest
-from unittest.mock import MagicMock, patch
-
-import numpy as np
-import pandas as pd
-from sklearn.metrics import mean_squared_error, r2_score
-
-from mambular.models import MambularRegressor # Ensure correct import path
-
-
-class TestMambularRegressor(unittest.TestCase):
- def setUp(self):
- # Patching external dependencies
- self.patcher_pl_trainer = patch("lightning.Trainer")
- self.mock_pl_trainer = self.patcher_pl_trainer.start()
-
- self.patcher_base_model = patch("mambular.base_models.regressor.BaseMambularRegressor")
- self.mock_base_model = self.patcher_base_model.start()
-
- self.regressor = MambularRegressor(d_model=128, dropout=0.1)
-
- # Sample data
- self.X = pd.DataFrame(np.random.randn(100, 10))
- self.y = np.random.rand(100)
-
- self.regressor.cat_feature_info = {}
- self.regressor.num_feature_info = {}
-
- def tearDown(self):
- self.patcher_pl_trainer.stop()
- self.patcher_base_model.stop()
-
- def test_initialization(self):
- # This assumes MambularConfig is properly imported and used in the MambularRegressor class
- from mambular.utils.configs import DefaultMambularConfig
-
- self.assertIsInstance(self.regressor.config, DefaultMambularConfig)
- self.assertEqual(self.regressor.config.d_model, 128)
- self.assertEqual(self.regressor.config.dropout, 0.1)
-
- def test_split_data(self):
- """Test the data splitting functionality."""
- X_train, X_val, y_train, y_val = self.regressor.split_data(self.X, self.y, val_size=0.2, random_state=42)
- self.assertEqual(len(X_train), 80)
- self.assertEqual(len(X_val), 20)
- self.assertEqual(len(y_train), 80)
- self.assertEqual(len(y_val), 20)
-
- def test_fit(self):
- """Test the training setup and call."""
- # Mock the necessary parts to simulate training
- self.regressor.preprocess_data = MagicMock()
- self.regressor.model = self.mock_base_model
-
- self.regressor.fit(self.X, self.y)
-
- # Ensure that the fit method of the trainer is called
- self.mock_pl_trainer.return_value.fit.assert_called_once()
-
- def test_predict(self):
- # Create mock return objects that mimic tensor behavior
- mock_prediction = MagicMock()
- mock_prediction.cpu.return_value = MagicMock()
- mock_prediction.cpu.return_value.numpy.return_value = np.array([0.5] * 100)
-
- # Mock the model and its method calls
- self.regressor.model = MagicMock()
- self.regressor.model.eval.return_value = None
- self.regressor.model.return_value = mock_prediction
-
- # Mock preprocess_test_data to return dummy tensor data
- self.regressor.preprocess_test_data = MagicMock(return_value=([], []))
-
- predictions = self.regressor.predict(self.X)
-
- # Assert that predictions return as expected
- np.testing.assert_array_equal(predictions, np.array([0.5] * 100))
-
- def test_evaluate(self):
- # Mock the predict method to simulate regressor output
- mock_predictions = np.random.rand(100)
- self.regressor.predict = MagicMock(return_value=mock_predictions)
-
- # Define metrics to test
- metrics = {"Mean Squared Error": mean_squared_error, "R2 Score": r2_score}
-
- # Call evaluate with the defined metrics
- result = self.regressor.evaluate(self.X, self.y, metrics=metrics)
-
- # Compute expected metrics directly
- expected_mse = mean_squared_error(self.y, mock_predictions)
- expected_r2 = r2_score(self.y, mock_predictions)
-
- # Check the results of evaluate
- self.assertAlmostEqual(result["Mean Squared Error"], expected_mse)
- self.assertAlmostEqual(result["R2 Score"], expected_r2)
-
- # Ensure predict was called correctly
- self.regressor.predict.assert_called_once_with(self.X)
-
-
-if __name__ == "__main__":
- unittest.main()