124 changes: 62 additions & 62 deletions README.md

Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207), which analyzes the efficiency of NLP-inspired tabular models.

<h3>⚡ What's New ⚡</h3>
<ul>
<li>Individual preprocessing: preprocess each feature differently, use pre-trained models for categorical encoding</li>
<li>Extract latent representations of tables</li>
<li>Use embeddings as inputs</li>
<li>Define custom training metrics</li>
</ul>




<h3> Table of Contents </h3>

- [🏃 Quickstart](#-quickstart)
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)

<h2> Preprocessing </h2>

Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
Specify a default method, or a dictionary defining individual preprocessing methods for each feature.
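
A minimal sketch of both options (the per-feature dictionary argument name shown here is an assumption; check the documentation for the exact keyword):

```python
from mambular.models import MambularRegressor

# one default method applied to all numerical features
model = MambularRegressor(numerical_preprocessing="ple")

# hypothetical per-feature dictionary; the exact argument name may differ
model = MambularRegressor(
    feature_preprocessing={"age": "standardization", "income": "ple"}
)
```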

<h3> Data Type Detection and Transformation </h3>

- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
- **Pre-trained Encoding**: Use sentence transformers to encode categorical features.



<h2> Fit a Model </h2>

Fit a model on tabular data in just a few lines, then get predictions:

```python
preds = model.predict(X)
preds = model.predict_proba(X)
```

Get latent representations for each feature:
```python
# simple encoding
model.encode(X)
```

Use unstructured data:
```python
# load pretrained models
image_model = ...
nlp_model = ...

# create embeddings
img_embs = image_model.encode(images)
txt_embs = nlp_model.encode(texts)

# fit model on tabular data and unstructured data
model.fit(X_train, y_train, embeddings=[img_embs, txt_embs])
```



<h3> Hyperparameter Optimization</h3>
Since all of the models are sklearn base estimators, you can use sklearn's built-in hyperparameter optimization.
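
For example, a randomized search with sklearn (a sketch; the searched parameter names are illustrative and depend on the fields of your model's config):

```python
from sklearn.model_selection import RandomizedSearchCV
from mambular.models import MambularClassifier

# illustrative hyperparameters; use the fields of your model's config
param_distributions = {
    "lr": [1e-4, 5e-4, 1e-3],
    "d_model": [32, 64, 128],
}

search = RandomizedSearchCV(
    MambularClassifier(),
    param_distributions,
    n_iter=5,
    cv=3,
)
search.fit(X_train, y_train)
print(search.best_params_)
```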

<h2> Distributional Regression with MambularLSS </h2>

MambularLSS allows you to model the full distribution of a response variable, not just its mean. Available distribution classes include:
- **studentt**: For data with heavier tails, useful with small samples.
- **negativebinom**: For over-dispersed count data.
- **inversegamma**: Often used as a prior in Bayesian inference.
- **johnsonsu**: A four-parameter distribution with location, scale, skewness, and kurtosis parameters.
- **categorical**: For data with more than two categories.
- **quantile**: For quantile regression using the pinball loss.


These distribution classes make MambularLSS versatile in modeling various data types and distributions.
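
A minimal usage sketch, assuming the distribution is selected via the `family` argument:

```python
from mambular.models import MambularLSS

# model the full response distribution with a Student-t likelihood
model = MambularLSS(family="studentt")
model.fit(X_train, y_train, max_epochs=50)
```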


<h2> 💻 Implement Your Own Model </h2>

Here's how you can implement a custom model with Mambular:

1. **First, define your config:** It should inherit from `BaseConfig`:

```python
from dataclasses import dataclass
from mambular.configs import BaseConfig

@dataclass
class MyConfig(BaseConfig):
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1
    n_layers: int = 4
    pooling_method: str = "avg"
```

2. **Second, define your model:**
```python
import numpy as np
import torch
import torch.nn as nn

from mambular.base_models import BaseModel
# NOTE: the EmbeddingLayer import path is an assumption; adjust it to your mambular version
from mambular.arch_utils.layer_utils.embedding_layer import EmbeddingLayer


class MyCustomModel(BaseModel):
    def __init__(
        self,
        feature_information: tuple,
        num_classes: int = 1,
        config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        # embedding layer turns every feature into a d_model-dimensional token
        self.embedding_layer = EmbeddingLayer(
            *feature_information,
            config=config,
        )

        input_dim = np.sum(
            [len(info) * self.hparams.d_model for info in feature_information]
        )

        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, *data) -> torch.Tensor:
        x = self.embedding_layer(*data)
        B, S, D = x.shape
        x = x.reshape(B, S * D)

        # Pass through linear layer
        output = self.linear(x)

        return output
```
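
3. **Third, wrap your model in one of the Mambular estimator base classes** to inherit preprocessing, training, and the sklearn-style API. A sketch, assuming the `SklearnBaseRegressor` wrapper from `mambular.models` (use the corresponding classification or LSS base class for other tasks):

```python
from mambular.models import SklearnBaseRegressor

class MyRegressor(SklearnBaseRegressor):
    def __init__(self, **kwargs):
        super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
```

4. **Fourth, train and evaluate your model like any other Mambular model:**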
```python
regressor = MyRegressor(numerical_preprocessing="ple")
regressor.fit(X_train, y_train, max_epochs=50)

regressor.evaluate(X_test, y_test)
```

# Custom Training
If you prefer to set up custom training, preprocessing, and evaluation, you can simply use the `mambular.base_models`. Just be aware that all base models expect lists of features as inputs: one list for numerical features and one list for categorical features. A custom training loop with random data could look like this:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from mambular.base_models import Mambular
from mambular.configs import DefaultMambularConfig

# Dummy data and configuration
cat_feature_info = {
    "cat1": {
        "preprocessing": "imputer -> continuous_ordinal",
        "dimension": 1,
        "categories": 4,
    }
}  # Example categorical feature information
num_feature_info = {
    "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
}  # Example numerical feature information
num_classes = 1
config = DefaultMambularConfig() # Use the desired configuration

# Initialize model, loss function, and optimizer
model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(10):  # Number of epochs
    model.train()
    optimizer.zero_grad()

    # Dummy data
    num_features = [torch.randn(32, 1) for _ in num_feature_info]
    cat_features = [
        torch.randint(0, 4, (32,)) for _ in cat_feature_info
    ]  # 4 categories -> valid indices 0..3
    labels = torch.randn(32, num_classes)

    # Forward pass
    outputs = model(num_features, cat_features)
    loss = criterion(outputs, labels)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Print loss for monitoring
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

```
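
After training, put the model into evaluation mode and disable gradient tracking for inference:

```python
# Inference with the trained base model
model.eval()
with torch.no_grad():
    preds = model(num_features, cat_features)
```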

# 🏷️ Citation

2 changes: 2 additions & 0 deletions mambular/configs/__init__.py
from .tabm_config import DefaultTabMConfig
from .tabtransformer_config import DefaultTabTransformerConfig
from .tabularnn_config import DefaultTabulaRNNConfig
from .base_config import BaseConfig

__all__ = [
"DefaultFTTransformerConfig",
"DefaultTabMConfig",
"DefaultTabTransformerConfig",
"DefaultTabulaRNNConfig",
"BaseConfig"
]