In [74]:
# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the "License")

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License

# Apache Beam RunInference in a Streaming Pipeline

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_streaming_pipeline.ipynb"><img src="https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_streaming_pipeline.ipynb"><img src="https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png" />View source on GitHub</a>
  </td>
</table>


This notebook shows how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform in a streaming pipeline with [Google Cloud Pub/Sub](https://cloud.google.com/pubsub).

This notebook demonstrates the following steps:
- Load and save a model from Hugging Face Models Hub.
- Use Pub/Sub IO in Python SDK as a streaming source.
- Use PyTorch model handler for RunInference.

For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.

## Before you begin
Set up your environment and download dependencies.

In [None]:
!pip install apache_beam[gcp]==2.48.0
!pip install torch
!pip install transformers
!pip install tensorflow

### Authenticate with Google Cloud
This notebook relies on Google Cloud Pub/Sub both as the input to the pipeline and as the destination for the results. To use your Google Cloud account, authenticate this notebook.

In [63]:
# NOTE(review): `google.colab` is presumably only importable inside a Google
# Colab runtime — use a different authentication method elsewhere.
from google.colab import auth
# Authenticate this notebook so it can use your Google Cloud account.
auth.authenticate_user()

### Import dependencies and set up your Pub/Sub topics
Use the following code to import dependencies and to configure the Pub/Sub topics and the model that the pipeline uses.

Replace `MESSAGE_TOPIC` and `RESPONSE_TOPIC` with the Pub/Sub topics in your project.

**Important**: If an error occurs, restart your runtime.

In [75]:
import os
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
import torch
from transformers import AutoConfig
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

# Pub/Sub topic that the pipeline reads incoming messages from. Replace the
# placeholder with a topic in your project.
# NOTE(review): Beam's Pub/Sub IO appears to expect the full resource path,
# "projects/<PROJECT_ID>/topics/<TOPIC_NAME>" — confirm against the Beam docs.
message_topic = "MESSAGE_TOPIC"
# Pub/Sub topic that the pipeline writes the model's responses to.
response_topic = "RESPONSE_TOPIC"

# Passed to RunInference as the `max_new_tokens` inference argument.
MAX_RESPONSE_TOKENS = 256

# Hugging Face identifier used to load the model, its config, and its tokenizer.
model_name = "google/flan-t5-small"
# Path where the model's state dict is saved; the model handler is given the same path.
state_dict_path = "saved_model"

### Download and save the model
We will use the AutoClasses from Hugging Face to instantly load the model in memory and later save it to the path defined above.

In [76]:
# Load the pretrained model from Hugging Face, requesting bfloat16 weights.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
)

# Make sure the parent directory of the output path exists before saving.
# For a bare file name such as "saved_model", dirname() returns "" and there
# is nothing to create.
directory = os.path.dirname(state_dict_path)
if directory:
    os.makedirs(directory, exist_ok=True)

# Save only the state dict; the model handler is later pointed at this path.
torch.save(model.state_dict(), state_dict_path)

### Utility functions to run before and after RunInference

In [78]:
def to_tensors(input_text: str, tokenizer) -> torch.Tensor:
    """Tokenize a single prompt for the model.

    Args:
        input_text: Prompt text to send to the LLM.
        tokenizer: Tokenizer for the LLM; it is called with
            ``return_tensors="pt"``.

    Returns:
        The first entry of the tokenizer's ``input_ids`` output.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    token_ids = encoded.input_ids
    return token_ids[0]


def get_response(result: PredictionResult, tokenizer) -> str:
    """Convert a RunInference prediction back into text.

    Args:
        result: A prediction from the RunInference transform; its
            ``inference`` attribute holds the output tokens.
        tokenizer: Tokenizer for the LLM; its ``decode`` method turns the
            tokens into a string.

    Returns:
        The decoded response, with special tokens skipped.
    """
    return tokenizer.decode(result.inference, skip_special_tokens=True)

### Run Inference Pipeline
Run the following cell, and then publish messages to the `message_topic` topic by using the Google Cloud console.

(Because this is a streaming pipeline, the cell keeps running until you stop it manually.)

In [None]:
# The same tokenizer encodes prompts before inference and decodes the output after it.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# PyTorch model handler: configured with the saved state dict path, a
# config-based model factory, and "generate" as the inference method.
model_config = AutoConfig.from_pretrained(model_name)
model_handler = PytorchModelHandlerTensor(
    state_dict_path=state_dict_path,
    model_class=AutoModelForSeq2SeqLM.from_config,
    model_params={"config": model_config},
    inference_fn=make_tensor_model_fn("generate"),
)

# Pipeline options; streaming=True enables streaming mode.
pipeline = beam.Pipeline(
    options=PipelineOptions(
        save_main_session=True,
        pickle_library="cloudpickle",
        streaming=True,
    )
)

with pipeline as p:
    # Stage 1: read raw Pub/Sub payloads and decode them into text prompts.
    prompts = (
        p
        | "Read from Pub/Sub" >> beam.io.ReadFromPubSub(message_topic)
        | "Decode bytes" >> beam.Map(lambda payload: payload.decode("utf-8"))
    )

    # Stage 2: tokenize each prompt and run it through the model.
    predictions = (
        prompts
        | "To tensors" >> beam.Map(to_tensors, tokenizer)
        | "RunInference" >> RunInference(
            model_handler,
            inference_args={"max_new_tokens": MAX_RESPONSE_TOKENS},
        )
    )

    # Stage 3: decode the predictions to text and publish them as bytes.
    _ = (
        predictions
        | "Get response" >> beam.Map(get_response, tokenizer)
        | "Encode bytes" >> beam.Map(lambda text: text.encode("utf-8"))
        | "Write to Pub/Sub" >> beam.io.WriteToPubSub(response_topic)
    )