In [2]:
"""
Create a custom Apache Beam ModelHandler that uses spaCy for inference.
"""
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.ml.inference.base import RunInference, ModelHandler, PredictionResult

import spacy
from spacy import displacy, Language
from typing import Any, Dict, Iterable, Optional, Sequence

import warnings
# Suppress all warnings so the notebook output stays readable.
# NOTE(review): no category or module is given, so this hides every warning
# raised after this cell runs, including deprecation notices from the imported
# libraries — consider narrowing the filter.
warnings.filterwarnings("ignore")

In [5]:
# Load the spaCy pipeline used for the interactive entity previews.
nlp = spacy.load("en_core_web_sm")

# Sample sentences; the same list is fed to the Beam pipeline further down.
text_strings = [
    "The New York Times is an American daily newspaper based in New York City with a worldwide readership.",
    "It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.",
]

# Parse the first sentence and highlight its named entities.
first_doc = nlp(text_strings[0])
displacy.render(first_doc, style="ent")

In [6]:
# Parse the second sentence and highlight its named entities.
second_doc = nlp(text_strings[1])
displacy.render(second_doc, style="ent")

In [7]:
# Start building the pipeline.
# No PipelineOptions are passed, so Beam's default options apply.
pipeline = beam.Pipeline()

In [8]:
class SpacyModelHandler(ModelHandler[str,
                                     PredictionResult,
                                     Language]):
    """Beam ``ModelHandler`` that runs a spaCy pipeline to extract named entities.

    Each input element is a text string. Each output is a ``PredictionResult``
    whose ``example`` is the input text and whose ``inference`` is a list of
    ``(text, start_char, end_char, label)`` tuples, one per entity found.
    """

    def __init__(self, model_name: str = 'en_core_web_sm'):
        """Remember which spaCy pipeline to load.

        Args:
            model_name: Name or path passed to ``spacy.load`` by ``load_model``.
        """
        # Let the base handler initialise whatever state it keeps on the
        # instance; skipping this leaves that state undefined.
        super().__init__()
        self.model_name = model_name

    def load_model(self) -> Language:
        """Load and return the spaCy pipeline named by ``self.model_name``."""
        return spacy.load(self.model_name)

    def run_inference(self, batch: Sequence[str], model: Language,
                      inference_args: Optional[Dict[str, Any]] = None
                      ) -> Iterable[PredictionResult]:
        """Extract the named entities of every text in ``batch``.

        Args:
            batch: Texts to analyse.
            model: The spaCy pipeline to run over the texts.
            inference_args: Accepted for interface compatibility; not used.

        Returns:
            A list with one ``PredictionResult`` per input text, in input
            order, pairing the text with its list of entity tuples.
        """
        # Language.pipe handles the texts as a batch and yields the docs in
        # input order, which avoids one full pipeline call per text.
        docs = model.pipe(batch)
        return [
            PredictionResult(
                text,
                [(ent.text, ent.start_char, ent.end_char, ent.label_)
                 for ent in doc.ents],
            )
            for text, doc in zip(batch, docs)
        ]
        

In [9]:
# Print the results for verification.
# A fresh Pipeline is created inside this cell instead of re-entering the
# module-level `pipeline` object, so the cell can be re-run safely: applying
# the same labelled transforms to an already-built pipeline fails with a
# duplicate-label error.
with beam.Pipeline() as p:
    (p
     | "CreateSentences" >> beam.Create(text_strings)
     | "RunInferenceSpacy" >> RunInference(SpacyModelHandler("en_core_web_sm"))
     | beam.Map(print)
    )

PredictionResult(example='The New York Times is an American daily newspaper based in New York City with a worldwide readership.', inference=[('The New York Times', 0, 18, 'ORG'), ('American', 25, 33, 'NORP'), ('daily', 34, 39, 'DATE'), ('New York City', 59, 72, 'GPE')], model_id=None)
PredictionResult(example='It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.', inference=[('1851', 18, 22, 'DATE'), ('Henry Jarvis Raymond', 26, 46, 'PERSON'), ('George Jones', 51, 63, 'PERSON'), ('Raymond, Jones & Company', 96, 120, 'ORG')], model_id=None)
