# [7] Model Serving

### Batch inference

In [None]:
import ray.data
from ray.train.torch import TorchPredictor
from ray.data import ActorPoolStrategy

In [None]:
# Load the checkpoint of the top-ranked run (first row of `sorted_runs`).
# NOTE(review): `sorted_runs` is only assigned in the "Run service" cell further
# down (sorted by val_loss ascending); here it must come from an earlier cell.
run_id = sorted_runs.iloc[0]["run_id"]
best_checkpoint = get_best_checkpoint(run_id=run_id)

In [None]:
class Predictor:
    """Callable used with ``Dataset.map_batches`` to classify batches.

    Ray instantiates it inside each pool actor (constructor arguments come
    from ``fn_constructor_kwargs``) and then calls it once per batch.
    """

    def __init__(self, checkpoint):
        """Load a TorchPredictor from *checkpoint*."""
        self.predictor = TorchPredictor.from_checkpoint(checkpoint)
        # Use the preprocessor stored with this checkpoint rather than a
        # notebook-level global, so the index -> class mapping always belongs
        # to the model being served (same approach as the Serve deployments).
        self.preprocessor = self.predictor.get_preprocessor()

    def __call__(self, batch):
        """Return ``{"prediction": ...}`` for *batch*.

        The model scores are stacked and argmax'ed per row, then passed
        through ``decode`` with the preprocessor's ``index_to_class`` mapping.
        """
        z = self.predictor.predict(batch)["predictions"]
        y_pred = np.stack(z).argmax(1)  # index of the highest score per row
        prediction = decode(y_pred, self.preprocessor.index_to_class)
        return {"prediction": prediction}

In [None]:
# Run the predictor locally on a single batch.
# Predictor.__init__ requires a checkpoint, so pass the one loaded above
# (calling Predictor() with no arguments raises TypeError).
# NOTE(review): `batch` is not defined in this section; it presumably comes
# from an earlier cell — confirm.
predictor = Predictor(checkpoint=best_checkpoint)
prediction = predictor(batch)

In [None]:
# Batch predict: Ray builds Predictor(checkpoint=best_checkpoint) in each actor
# of a pool that scales between 1 and 2 actors, and feeds it pandas batches
# of 128 rows from test_ds.
predictions = test_ds.map_batches(
    Predictor,
    fn_constructor_kwargs={"checkpoint": best_checkpoint},
    compute=ActorPoolStrategy(min_size=1, max_size=2),
    batch_format="pandas",
    batch_size=128,
)

In [None]:
# Sample predictions: fetch and display 3 rows of the prediction Dataset.
predictions.take(limit=3)

### Online inference

In [None]:
from fastapi import FastAPI
from ray import serve
import requests
from starlette.requests import Request

In [None]:
# Define application: the FastAPI app that the Serve deployment classes
# below register their routes on.
app = FastAPI(
    version="0.1",
    title="Made With ML",
    description="Classify machine learning projects.",
)

In [None]:
class ModelDeployment:
    """Serve the best checkpoint of an MLflow run on POST ``/predict/``.

    The route is registered on the module-level ``app`` by ``@app.post``.
    """

    def __init__(self, run_id):
        """Load the best checkpoint of MLflow run *run_id*."""
        self.run_id = run_id
        # Point this process at the tracking server so workers have access
        # to the model registry before the checkpoint lookup.
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        checkpoint = get_best_checkpoint(run_id=run_id)
        self.predictor = TorchPredictor.from_checkpoint(checkpoint)
        self.preprocessor = self.predictor.get_preprocessor()

    @app.post("/predict/")
    async def _predict(self, request: Request):
        """Classify the project whose title/description are in the JSON body."""
        payload = await request.json()
        # Missing fields default to ""; "tag" is left empty (presumably a
        # column the preprocessor expects — confirm).
        row = {
            "title": payload.get("title", ""),
            "description": payload.get("description", ""),
            "tag": "",
        }
        frame = pd.DataFrame([row])
        return {"results": predict_with_proba(df=frame, predictor=self.predictor)}

In [None]:
# Turn the ModelDeployment class defined above into a Ray Serve deployment
# with the FastAPI `app` as its ingress. The decorators are applied as plain
# calls (innermost first, exactly like stacked decorators) instead of
# re-declaring the class: a fresh `class ModelDeployment: pass` would replace
# the real class, losing __init__(run_id) and the /predict/ handler.
# num_replicas is an int per the Serve API (was the string "1").
ModelDeployment = serve.deployment(
    route_prefix="/",
    num_replicas=1,
    ray_actor_options={"num_cpus": 8, "num_gpus": 0},
)(serve.ingress(app)(ModelDeployment))

In [None]:
# Run service: rank the experiment's runs by validation loss (lowest first)
# and deploy the best one.
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name],
    order_by=["metrics.val_loss ASC"],
)
run_id = sorted_runs["run_id"].iloc[0]
serve.run(ModelDeployment.bind(run_id=run_id))

In [None]:
# Query the local service with a sample project.
title = "Transfer learning with transformers"
description = "Using transformers for transfer learning on text classification tasks."
json_data = json.dumps(dict(title=title, description=description))
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()

In [None]:
# Query (noise): send a meaningless title with an empty description.
title = " 65n7r5675"  # random noise
json_data = json.dumps(dict(title=title, description=""))
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()

In [None]:
# Shutdown: stop Ray Serve, tearing down the deployment started by serve.run above.
serve.shutdown()

### Custom logic

In [None]:
# num_replicas is an int per the Serve API (was the string "1").
@serve.deployment(
    route_prefix="/",
    num_replicas=1,
    ray_actor_options={"num_cpus": 8, "num_gpus": 0},
)
@serve.ingress(app)
class ModelDeploymentRobust:
    """Serve the best checkpoint of an MLflow run, relabeling low-confidence output.

    Any prediction whose probability is below ``threshold`` is reported as
    ``"other"``.
    """

    def __init__(self, run_id, threshold=0.9):
        """Load the best checkpoint of MLflow run *run_id*.

        Args:
            run_id: MLflow run whose best checkpoint is served.
            threshold: minimum probability the predicted class must reach;
                predictions below it are replaced with ``"other"``.
        """
        self.run_id = run_id
        self.threshold = threshold
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)  # so workers have access to model registry
        best_checkpoint = get_best_checkpoint(run_id=run_id)
        self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)
        self.preprocessor = self.predictor.get_preprocessor()

    # NOTE(review): the same `app` already has a POST /predict/ route from
    # ModelDeployment; confirm Serve dispatches to this class's handler.
    @app.post("/predict/")
    async def _predict(self, request: Request):
        """Classify the JSON body, falling back to "other" when unsure."""
        data = await request.json()
        df = pd.DataFrame([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        results = predict_with_proba(df=df, predictor=self.predictor)

        # Apply custom logic: each result's "probabilities" is indexed by its
        # "prediction"; mutate the result in place when that probability is
        # below the threshold.
        for result in results:
            if result["probabilities"][result["prediction"]] < self.threshold:
                result["prediction"] = "other"

        return {"results": results}

In [None]:
# Run service: deploy the same best run with a 0.9 confidence threshold.
serve.run(ModelDeploymentRobust.bind(threshold=0.9, run_id=run_id))

In [None]:
# Query (noise): the same meaningless input, now against the robust service.
title = " 65n7r5675"  # random noise
json_data = json.dumps(dict(title=title, description=""))
requests.post("http://127.0.0.1:8000/predict/", data=json_data).json()

In [None]:
# Shutdown: stop Ray Serve, tearing down the robust deployment started above.
serve.shutdown()