In [3]:
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data

In [5]:
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 2. Load the data

In [7]:
datamodule = TabularData.from_csv(
    "./data/titanic/titanic.csv",
    test_csv="./data/titanic/test.csv",
    categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_input=["Fare"],
    target="Survived",
    val_size=0.25,
)


# 3. Build the model

In [9]:
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 4. Create the trainer. Run 10 times on data

In [11]:
trainer = flash.Trainer(max_epochs=10)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores


# 5. Train the model

In [13]:
trainer.fit(model, datamodule=datamodule)


  | Name    | Type        | Params
----------------------------------------
0 | model   | Sequential  | 59.1 K
1 | metrics | ModuleDict  | 0     
2 | embs    | ModuleList  | 15.1 K
3 | bn_num  | BatchNorm1d | 2     
----------------------------------------
74.2 K    Trainable params
0         Non-trainable params
74.2 K    Total params
0.297     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

1

# 6. Test model

In [15]:
trainer.test()

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.7333333492279053,
 'test_cross_entropy': 0.6910470128059387,
 'test_precision': 0.7333333492279053,
 'test_recall': 0.7333333492279053}
--------------------------------------------------------------------------------


[{'test_accuracy': 0.7333333492279053,
  'test_precision': 0.7333333492279053,
  'test_recall': 0.7333333492279053,
  'test_cross_entropy': 0.6910470128059387}]

# 7. Save it!

In [17]:
trainer.save_checkpoint("tabular_classification_model.pt")