# Building Machine Learning APIs with FastAPI

In [1]:
# Train the classifier and save it to disk
import joblib
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Load a subset of the 20 Newsgroups dataset restricted to four categories.
categories = [
    "soc.religion.christian",
    "talk.religion.misc",
    "comp.sys.mac.hardware",
    "sci.crypt",
]

newsgroups_training = fetch_20newsgroups(
    subset="train", categories=categories, random_state=0
)
newsgroups_testing = fetch_20newsgroups(
    subset="test", categories=categories, random_state=0
)

# Pipeline: TF-IDF text features feeding a multinomial Naive Bayes classifier.
model = make_pipeline(
    TfidfVectorizer(),
    MultinomialNB(),
)

# Train on the training split only.
model.fit(newsgroups_training.data, newsgroups_training.target)

# Report accuracy on the held-out test split so the saved model's quality is
# known before it is served.
test_accuracy = model.score(newsgroups_testing.data, newsgroups_testing.target)
print(f"Test accuracy: {test_accuracy:.3f}")

# Serialize the model together with the target names: predict() returns class
# indices, and the names let the loader map those indices back to labels.
model_file = "newsgroups_model.joblib"
model_targets_tuple = (model, newsgroups_training.target_names)
joblib.dump(model_targets_tuple, model_file)


['newsgroups_model.joblib']

In [8]:
# Load the saved model and run a single prediction
import os
from typing import List, Tuple

import joblib
from sklearn.pipeline import Pipeline

# Load the (pipeline, target_names) tuple written by the training cell.
model_file = "newsgroups_model.joblib"
loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file)
model, targets = loaded_model

# predict() returns an array of class indices (one per input text); map the
# single index back to its category name.
predicted_indices = model.predict(["computer cpu memory ram"])
print(targets[predicted_indices[0]])


comp.sys.mac.hardware


In [10]:
%%writefile prediction_endpoint.py
# FastAPI prediction endpoint for the newsgroups classifier
import os
from typing import List, Optional, Tuple

import joblib
from fastapi import FastAPI, Depends, status
from pydantic import BaseModel
from sklearn.pipeline import Pipeline


class PredictionInput(BaseModel):
    """Request body for ``POST /prediction``: the text to classify."""

    text: str  # raw text to classify, e.g. {"text": "ram"}


class PredictionOutput(BaseModel):
    """Response body for ``POST /prediction``: the predicted category label."""

    category: str  # one of the model's target names, e.g. "comp.sys.mac.hardware"


# On-disk prediction cache; the relative path resolves against the server's
# current working directory.
memory = joblib.Memory(location="cache.joblib")

@memory.cache(ignore=["model"])
def predict(model: Pipeline, text: str) -> int:
    """Return the predicted class index for a single text, memoized on disk.

    ``model`` is excluded from the cache key (presumably to avoid hashing the
    fitted pipeline on every call), so cached results depend on ``text``
    only. After replacing the model file, clear the cache (``DELETE /cache``)
    or stale predictions will be served.

    Args:
        model: Fitted scikit-learn pipeline; only called on cache misses.
        text: Raw input text to classify.

    Returns:
        Index of the predicted class in the model's target-name list.
    """
    prediction = model.predict([text])
    # predict() yields a NumPy integer scalar; convert it so the value honors
    # the ``int`` annotation and is stored in the cache as a plain int.
    return int(prediction[0])

class NewsgroupsModel:
    """Holds the trained newsgroups pipeline and its category labels."""

    # Both stay ``None`` until ``load_model`` runs (called from the startup
    # hook). The defaults make ``predict`` raise its intended RuntimeError,
    # rather than an AttributeError, when used before loading.
    model: Optional[Pipeline] = None
    targets: Optional[List[str]] = None

    def load_model(self):
        """Load the ``(pipeline, target_names)`` tuple stored next to this file.

        NOTE: ``joblib.load`` unpickles arbitrary objects; only point it at a
        trusted model file.
        """
        model_file = os.path.join(
            os.path.dirname(__file__), "newsgroups_model.joblib"
        )
        loaded_model: Tuple[Pipeline, List[str]] = joblib.load(model_file)
        model, targets = loaded_model
        self.model = model
        self.targets = targets

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Classify ``input.text`` and return its category label.

        Used as a FastAPI dependency of ``POST /prediction``.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        if self.model is None or self.targets is None:
            raise RuntimeError("Model is not loaded")
        # Module-level, disk-cached helper; returns a class index.
        # NOTE(review): inference runs synchronously inside this async method,
        # i.e. on the event loop; consider a threadpool if latency matters.
        prediction = predict(self.model, input.text)
        category = self.targets[prediction]
        return PredictionOutput(category=category)


# ASGI application, served with ``uvicorn prediction_endpoint:app``.
app = FastAPI()
# Single shared model holder: filled by the startup hook and used as the
# dependency behind POST /prediction.
# NOTE(review): "newgroups" looks like a typo of "newsgroups"; renaming it
# requires updating every reference in this module.
newgroups_model = NewsgroupsModel()


@app.post("/prediction")
async def prediction(
    output: PredictionOutput = Depends(newgroups_model.predict),
) -> PredictionOutput:
    """Return the category predicted for the posted text.

    All work happens in the ``newgroups_model.predict`` dependency, which
    receives the ``PredictionInput`` body and runs the classifier; this
    handler only returns that dependency's result.
    """
    return output


@app.delete("/cache", status_code=status.HTTP_204_NO_CONTENT)
def delete_cache():
    """Erase every cached prediction; responds with 204 No Content.

    Needed after swapping the model file, because cached results are keyed
    on the input text only (the model is ignored by the cache).
    """
    memory.clear()

@app.on_event("startup")
async def startup():
    """Load the model once when the application starts."""
    # NOTE(review): newer FastAPI releases recommend ``lifespan`` handlers over
    # ``on_event``; consider migrating once the FastAPI version is pinned.
    newgroups_model.load_model()


Overwriting prediction_endpoint.py


# Using the API from a terminal.
# Start the development server (--reload restarts it when the source changes).
# This command blocks, so run the curl commands below in a second terminal.
uvicorn prediction_endpoint:app --reload

# Predict the category of a text snippet
curl --request 'POST' \
  --header 'accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
  "text": "ram"
}' \
  'http://127.0.0.1:8000/prediction'

# Clear the prediction cache (the server answers 204 No Content)
curl --request 'DELETE' \
  --header 'accept: */*' \
  'http://127.0.0.1:8000/cache'

# Using the API from Python

In [1]:
import requests

base_url = "http://127.0.0.1:8000"

input_data_batch = [
    {"text": "computer ram"},
    {"text": "I love bible"},
    {"text": "ram"},
    {"text": "I love cryptography "},
]

# The route is declared as "/prediction" (no trailing slash); posting to
# "/prediction/" depends on the server redirecting and costs an extra round-trip.
# A timeout keeps the loop from hanging forever if the server is unreachable,
# and raise_for_status() surfaces 4xx/5xx responses instead of printing them.
for data in input_data_batch:
    response = requests.post(f"{base_url}/prediction", json=data, timeout=10)
    response.raise_for_status()
    print(response.json())


{'category': 'comp.sys.mac.hardware'}
{'category': 'soc.religion.christian'}
{'category': 'comp.sys.mac.hardware'}
{'category': 'sci.crypt'}


In [21]:
# Clear the server-side prediction cache (the endpoint answers 204 No Content).
# Bounded timeout so the cell cannot hang; fail loudly on an error status.
response = requests.delete(f"{base_url}/cache", timeout=10)
response.raise_for_status()