# Training

In [None]:
# %%capture
! pip install git+https://github.com/PyTorchLightning/lightning-flash.git

In [None]:
from torchmetrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData

###  1. Download the data
The dataset is downloaded from a URL and saved in the `data/` directory.

In [None]:
# Fetch the Titanic archive into ./data/ — later cells read ./data/titanic/*.csv from it.
TITANIC_ZIP_URL = "https://pl-flash-data.s3.amazonaws.com/titanic.zip"
download_data(TITANIC_ZIP_URL, "data/")

###  2. Load the data
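Flash tasks have built-in DataModules that you can use to organize your data. Pass in the feature columns, the target column, and the train and test CSV files, and Flash will take care of the rest. Instead of a separate validation file, `val_split=0.25` holds out 25% of the training data for validation.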

Creating a `TabularData` relies on [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html).

In [None]:
# Feature columns, passed as the first two positional arguments of `TabularData.from_csv`
# (presumably categorical fields, then numerical fields — confirm against the Flash docs).
# NOTE(review): "Age" is numeric but is listed with the categorical fields here.
CATEGORICAL_FIELDS = ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"]
NUMERICAL_FIELDS = ["Fare"]
TRAIN_CSV = "./data/titanic/titanic.csv"
TEST_CSV = "./data/titanic/test.csv"

datamodule = TabularData.from_csv(
    CATEGORICAL_FIELDS,
    NUMERICAL_FIELDS,
    target_fields="Survived",  # label column to predict
    train_file=TRAIN_CSV,
    test_file=TEST_CSV,
    val_split=0.25,  # no separate validation file: split it off the training data
)


###  3. Build the model

Note: categorical columns are mapped into an embedding space — a set of trainable tensors, one associated with each categorical column.

In [None]:
# Build the classifier from the datamodule and track these torchmetrics metrics during evaluation.
classification_metrics = [Accuracy(), Precision(), Recall()]
model = TabularClassifier.from_data(datamodule, metrics=classification_metrics)

###  4. Create the trainer. Train for 10 epochs (10 passes over the training data)

In [None]:
MAX_EPOCHS = 10  # number of passes over the training data
trainer = flash.Trainer(max_epochs=MAX_EPOCHS)

###  5. Train the model

In [None]:
# Train `model` on the data organized by `datamodule`.
trainer.fit(model=model, datamodule=datamodule)

###  6. Test model

In [None]:
# Evaluate on the test split. The datamodule is passed explicitly so this cell does not depend on
# data state attached to the trainer by `trainer.fit` in an earlier cell. `model` is still omitted,
# so the trainer keeps choosing which weights to evaluate exactly as before.
trainer.test(datamodule=datamodule)

###  7. Save it!

In [None]:
# Write a checkpoint of the trained model to a local file.
CHECKPOINT_PATH = "tabular_classification_model.pt"
trainer.save_checkpoint(CHECKPOINT_PATH)

# Predicting

###  8. Load the model from a checkpoint

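`TabularClassifier.load_from_checkpoint` accepts either a URL or a local path to a checkpoint. If given a URL, the checkpoint is first downloaded and then loaded to re-create the model. The cell below loads a checkpoint hosted online; to use the model trained above instead, pass the local path `"tabular_classification_model.pt"`.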

In [None]:
# Loads the checkpoint hosted at this URL — not the local file written by `trainer.save_checkpoint`.
# Pass "tabular_classification_model.pt" instead to use the model trained in this notebook.
HOSTED_CHECKPOINT_URL = "https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt"
model = TabularClassifier.load_from_checkpoint(HOSTED_CHECKPOINT_URL)

###  9. Generate predictions from a CSV file! Who would survive?

`TabularClassifier.predict` supports both a DataFrame and a path to a `.csv` file.

In [None]:
# NOTE: this is the same CSV used for training, so these predictions are a demo, not an evaluation.
PREDICT_CSV = "data/titanic/titanic.csv"
predictions = model.predict(PREDICT_CSV)

In [None]:
# Show a short preview instead of printing every prediction (presumably one per CSV row).
# Left as the cell's last expression so Jupyter renders it with rich display.
predictions[:10]