Our next step is to create some tests for our ~flash.core.model.Task. For the TemplateSKLearnClassifier, we will just create some basic tests. You should expand on these to include tests for any specific functionality you have in your ~flash.core.model.Task.
We use smoke tests, usually called test_smoke, throughout. These simply instantiate the class under test to check that it can be created without raising any errors.
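A smoke test can be very small. In the sketch below, ToyClassifier is a hypothetical stand-in for a task such as TemplateSKLearnClassifier, written with the standard library only so it runs without Flash installed. By default, pytest collects functions whose names start with test.

```python
class ToyClassifier:
    """Hypothetical stand-in for a Flash task; not part of Flash."""

    def __init__(self, num_features: int, num_classes: int) -> None:
        if num_features <= 0 or num_classes <= 1:
            raise ValueError("need at least one feature and two classes")
        self.num_features = num_features
        self.num_classes = num_classes


def test_smoke():
    """A smoke test only checks that construction does not raise."""
    model = ToyClassifier(num_features=4, num_classes=2)
    assert model is not None
```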
Before we write our custom tests, we should add our examples to the CI. To do this, add a line for each example (finetuning and predict) to the parameter list of test_example in tests/examples/test_scripts.py. Here's how those lines look for our template.py examples:
pytest.param(
"finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
"predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
The most important tests in test_data.py check that the from_* methods work correctly. In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn. In general, there should be one test_from_* method for each ~flash.core.data.io.input you have configured.
Here's the code for test_from_numpy:
../../../tests/template/classification/test_data.py
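A from_* test builds a data module from raw inputs and then checks the batches it produces. The sketch below shows that structure without Flash. ToyDataModule and its from_lists constructor are hypothetical stand-ins for TemplateData and one of its from_* methods, and the "input" and "target" batch keys are assumptions chosen for illustration.

```python
from typing import Dict, Iterator, List, Sequence


class ToyDataModule:
    """Hypothetical stand-in for a Flash DataModule with one from_* constructor."""

    def __init__(self, rows: List[Sequence[float]], targets: List[int], batch_size: int) -> None:
        self.rows = rows
        self.targets = targets
        self.batch_size = batch_size

    @classmethod
    def from_lists(cls, train_data, train_targets, batch_size: int = 2) -> "ToyDataModule":
        if len(train_data) != len(train_targets):
            raise ValueError("data and targets must have the same length")
        return cls(list(train_data), list(train_targets), batch_size)

    def train_dataloader(self) -> Iterator[Dict[str, list]]:
        # Yield dictionaries of batched inputs and targets
        for start in range(0, len(self.rows), self.batch_size):
            end = start + self.batch_size
            yield {"input": self.rows[start:end], "target": self.targets[start:end]}


def test_from_lists():
    num_features, num_samples = 4, 10
    data = [[float(i + j) for j in range(num_features)] for i in range(num_samples)]
    targets = [i % 3 for i in range(num_samples)]

    dm = ToyDataModule.from_lists(train_data=data, train_targets=targets, batch_size=2)

    # Check that the first batch has the expected keys and sizes
    batch = next(iter(dm.train_dataloader()))
    assert set(batch) == {"input", "target"}
    assert len(batch["input"]) == 2
    assert len(batch["input"][0]) == num_features
    assert len(batch["target"]) == 2
```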
In test_model.py, we first have test_forward and test_train. These check that tensors can be passed to the forward method and that the ~flash.core.model.Task can be trained. Here's the code for test_forward and test_train:
../../../tests/template/classification/test_model.py
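The core of a forward test is a shape check: a batch of inputs should produce one score per class for each row. A dependency-free sketch of that check follows. toy_forward is a hypothetical linear layer on nested lists standing in for a model's forward method, and the parameter grid is arbitrary. In a real test suite, the loops would typically become pytest.mark.parametrize decorators.

```python
import random
from typing import List


def toy_forward(batch: List[List[float]], num_classes: int, seed: int = 0) -> List[List[float]]:
    """Hypothetical linear layer: maps each row of features to num_classes scores."""
    rng = random.Random(seed)
    num_features = len(batch[0])
    weights = [[rng.uniform(-1, 1) for _ in range(num_classes)] for _ in range(num_features)]
    return [
        [sum(row[f] * weights[f][c] for f in range(num_features)) for c in range(num_classes)]
        for row in batch
    ]


def test_forward():
    # Cover several batch sizes, feature counts and class counts
    for num_classes in (4, 256):
        for batch_size, num_features in ((1, 3), (2, 128)):
            batch = [[random.random() for _ in range(num_features)] for _ in range(batch_size)]
            out = toy_forward(batch, num_classes)
            assert len(out) == batch_size
            assert all(len(scores) == num_classes for scores in out)
```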
We also include tests for validation and testing: test_val and test_test. These tests are very similar to test_train, but we include them here for completeness:
../../../tests/template/classification/test_model.py
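Since the training, validation, and test checks differ mainly in which entry point they call, one way to keep them consistent is to route all three through a shared helper. The sketch below is hypothetical throughout: ToyTrainer stands in for a trainer exposing fit, validate, and test methods, and it runs a single batch per stage and records the batch size it saw.

```python
from typing import Callable, Dict, List


class ToyTrainer:
    """Hypothetical trainer; each stage consumes one batch and records what it saw."""

    def __init__(self) -> None:
        self.seen: Dict[str, int] = {}

    def _run(self, stage: str, step: Callable[[list], float], loader: List[list]) -> float:
        batch = loader[0]  # run only the first batch, as a quick check
        loss = step(batch)
        self.seen[stage] = len(batch)
        return loss

    def fit(self, step, loader):
        return self._run("fit", step, loader)

    def validate(self, step, loader):
        return self._run("validate", step, loader)

    def test(self, step, loader):
        return self._run("test", step, loader)


def mean_squared(batch: list) -> float:
    """Toy loss: mean of the squared values in the batch."""
    return sum(x * x for x in batch) / len(batch)


def check_stage(stage: str) -> None:
    trainer = ToyTrainer()
    loader = [[1.0, 2.0], [3.0, 4.0]]
    loss = getattr(trainer, stage)(mean_squared, loader)
    assert loss >= 0.0
    assert trainer.seen[stage] == 2


def test_train():
    check_stage("fit")


def test_val():
    check_stage("validate")


def test_test():
    check_stage("test")
```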
We also include one prediction test, named test_predict_*, for each of our data sources. In our case, we have test_predict_numpy and test_predict_sklearn. These tests should load the data with a ~flash.core.data.data_module.DataModule and generate predictions with Trainer.predict (flash.core.trainer.Trainer.predict). Here's test_predict_sklearn as an example:
../../../tests/template/classification/test_model.py
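A prediction test loads some inputs, runs them through a model, and checks the predictions. The sketch below uses a hypothetical nearest-centroid classifier on plain lists in place of a Flash task, a DataModule, and Trainer.predict; only the structure of the checks (one prediction per input, each prediction a valid class) carries over to a real test.

```python
from typing import Dict, List, Sequence


class NearestCentroid:
    """Hypothetical classifier: predicts the class whose mean feature vector is closest."""

    def fit(self, rows: List[Sequence[float]], targets: List[int]) -> "NearestCentroid":
        sums: Dict[int, List[float]] = {}
        counts: Dict[int, int] = {}
        for row, target in zip(rows, targets):
            acc = sums.setdefault(target, [0.0] * len(row))
            for i, value in enumerate(row):
                acc[i] += value
            counts[target] = counts.get(target, 0) + 1
        self.centroids = {c: [v / counts[c] for v in acc] for c, acc in sums.items()}
        return self

    def predict(self, rows: List[Sequence[float]]) -> List[int]:
        def sq_dist(a, b):
            return sum((x - y) ** 2 for x, y in zip(a, b))

        return [min(self.centroids, key=lambda c: sq_dist(row, self.centroids[c])) for row in rows]


def test_predict_lists():
    train_rows = [[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]]
    train_targets = [0, 0, 1, 1]
    model = NearestCentroid().fit(train_rows, train_targets)

    predictions = model.predict([[0.05, 0.1], [4.8, 5.1]])
    # One prediction per input, each one a known class
    assert len(predictions) == 2
    assert all(p in (0, 1) for p in predictions)
    assert predictions == [0, 1]
```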
Now that you've written the tests, it's time to add some docs; see contributing_docs.