Tabular classification is the task of assigning a class to samples of structured or relational data. The ~flash.tabular.classification.model.TabularClassifier task can be used for classification of samples in more than two classes (multi-class classification).
Let's look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:
PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
We can create the ~flash.tabular.classification.data.TabularData from our CSV files using the ~flash.tabular.classification.data.TabularData.from_csv method. From the API reference <flash.tabular.classification.data.TabularData.from_csv>, we need to provide:
- cat_cols - A list of the names of columns that contain categorical data (strings or integers).
- num_cols - A list of the names of columns that contain continuous numerical data (floats).
- target - The name of the column we want to predict.
- train_csv - A CSV file containing the training data, which is converted to a Pandas DataFrame.
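To make the distinction between cat_cols and num_cols concrete, here is a minimal sketch that sorts the columns of the sample rows shown earlier into the two groups. It uses only the Python standard library and does not call Flash. The helper names (is_float_literal, split_columns) and the rule "a column is numerical only if every non-empty value is written with a decimal point" are assumptions made for illustration, not how Flash itself decides. Ignoring PassengerId and Name is also an assumption, since they are identifiers rather than features.

```python
import csv
import io

# Sample rows copied from the CSV excerpt earlier on this page.
SAMPLE = '''PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
'''


def is_float_literal(value):
    """Return True if value parses as a number and is written with a decimal point."""
    try:
        float(value)
    except ValueError:
        return False
    return "." in value


def split_columns(rows, target, ignore=()):
    """Heuristically split feature columns into (cat_cols, num_cols).

    Hypothetical helper: the target column and any ignored columns are skipped.
    """
    cat_cols, num_cols = [], []
    for name in rows[0].keys():
        if name == target or name in ignore:
            continue
        # Missing values (empty strings) do not count for or against a column.
        values = [row[name] for row in rows if row[name] != ""]
        if values and all(is_float_literal(v) for v in values):
            num_cols.append(name)
        else:
            cat_cols.append(name)
    return cat_cols, num_cols


rows = list(csv.DictReader(io.StringIO(SAMPLE)))
cat_cols, num_cols = split_columns(
    rows, target="Survived", ignore=("PassengerId", "Name")
)
print("cat_cols:", cat_cols)
print("num_cols:", num_cols)
```

With only four sample rows this is a rough guide at best. On the full data set you would normally choose the lists by hand, for example deciding whether an integer-valued column such as Age is better treated as categorical or numerical.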
Next, we create the ~flash.tabular.classification.model.TabularClassifier and fine-tune it on the Titanic data. We then use the trained ~flash.tabular.classification.model.TabularClassifier for inference. Finally, we save the model. Here's the full example:
../../../flash_examples/tabular_classification.py
To learn how to view the available backbones / heads for this task, see backbones_heads.
The tabular classifier can be used directly from the command line with zero code using flash_zero. You can run the above example with:
flash tabular_classifier
To view the configuration options, including those for running the tabular classifier with your own data, use:
flash tabular_classifier --help
The ~flash.tabular.classification.model.TabularClassifier is servable. This means you can call .serve to serve your ~flash.core.model.Task. Here's an example:
../../../flash_examples/serve/tabular_classification/inference_server.py
You can now perform inference from your client like this:
../../../flash_examples/serve/tabular_classification/client.py