This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Predictions (inference)

You can use Flash to get predictions from pretrained or finetuned models.

First, create a DataModule (flash.core.data.data_module.DataModule) containing the data to predict on, then pass it to the Trainer.predict method (flash.core.trainer.Trainer.predict).

from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassifier, ImageClassificationData

# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [["bees"]]
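
The output above is nested: it appears to be a list of batches, each holding one entry per input. As a minimal sketch under that assumption, the batches can be flattened and paired with the input files. The helper name flatten_predictions is hypothetical and not part of Flash; the example uses only the standard library and a hard-coded copy of the output shown above.

```python
from itertools import chain


def flatten_predictions(batched):
    """Flatten a list of per-batch prediction lists into one flat list.

    Hypothetical helper, not part of Flash. Assumes the nesting shown in
    the example output: [[pred, pred, ...], [pred, ...], ...].
    """
    return list(chain.from_iterable(batched))


# Same file list and output as in the example above, hard-coded here so
# the sketch runs without Flash installed.
predict_files = ["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
predictions = [["bees"]]

# Pair each input file with its prediction, assuming order is preserved.
paired = list(zip(predict_files, flatten_predictions(predictions)))
for path, label in paired:
    print(f"{path}: {label}")
```

Pairing by position relies on predictions being returned in the same order as predict_files, which is an assumption of this sketch.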

Serializing predictions

To change the output format of predictions, you can attach an Output (flash.core.data.io.output.Output) to your Task (flash.core.model.Task). For example, you can choose to output probabilities (for more options, see the API reference below).

from flash import Trainer
from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier, ImageClassificationData


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Attach the Output
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [[[0.5926494598388672, 0.40735048055648804]]]
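
With probabilities as the output, each input yields one list of per-class scores. The following sketch shows one way to reduce each list to the index of its highest-scoring class and that class's probability. The helper name top_class is hypothetical and not part of Flash, and the nesting (batches, then samples, then classes) is assumed from the output shown above. Which class each index refers to depends on the model's label ordering, which is not shown here.

```python
def top_class(probabilities):
    """Return (index, probability) of the highest-scoring class.

    Hypothetical helper, not part of Flash.
    """
    index = max(range(len(probabilities)), key=probabilities.__getitem__)
    return index, probabilities[index]


# Same output as in the example above, hard-coded so the sketch runs
# without Flash installed.
predictions = [[[0.5926494598388672, 0.40735048055648804]]]

# Walk batches, then samples within each batch.
results = [top_class(sample) for batch in predictions for sample in batch]
for index, probability in results:
    print(f"class index {index} with probability {probability:.3f}")
```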