# Dataflow NER Pipeline with BERT

#### Install Required Libraries

In [None]:
%pip install --quiet scikit-learn
%pip install --quiet transformers[torch]
%pip install --quiet seqeval
%pip install --quiet tensorflow
%pip install --quiet tf-keras
%pip install --quiet torch --quiet
%pip install --quiet datasets --quiet
%pip install --quiet evaluate --quiet

In [None]:
PROJECT_ID = !gcloud config list --format 'value(core.project)'
PROJECT_ID = PROJECT_ID[0]
REGION = "us-central1"
%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}
#BUCKET_NAME=f'dataflow_demo_{PROJECT_ID}'

In [None]:
# Checkpoint to fine-tune -- use the exact model you plan to train
# (bert-base, roberta, etc.) -- and the local training data file.
MODEL_NAME = "google-bert/bert-base-multilingual-cased"
DATASET = "./train_bert_ner_1k.txt"

# Bucket for fine-tuning artifacts, suffixed with the project id.
OUTPUT_BUCKET_NAME = "bert-finetuning-ner-{}".format(PROJECT_ID)
gcs_bucket = "gs://{}".format(OUTPUT_BUCKET_NAME)

In [None]:
!gsutil mb -l {REGION} gs://{OUTPUT_BUCKET_NAME}

#### Inference pipeline based on Apache Beam
The pipeline can be orchestrated with Cloud Composer or Vertex AI Pipelines.

In [None]:
# Bucket used for Dataflow I/O, plus one prefix each for output, temp and staging files.
DATAFLOW_BUCKET = "bert-ner-demo-io-storage-{}".format(PROJECT_ID)
OUTPUT_GCS_BUCKET, TEMP_LOCATION, STAGING_LOCATION = (
    "gs://{}/{}/".format(DATAFLOW_BUCKET, prefix)
    for prefix in ("output", "temp", "staging")
)

In [None]:
!gsutil mb -l {REGION} gs://{DATAFLOW_BUCKET}

In [None]:

# Vertex AI endpoint serving the fine-tuned NER model -- replace with your own endpoint id.
ENDPOINT_ID = "4897865805393297408"

# Pub/Sub input. Messages are JSON objects with "id" and "message" keys (see the
# Dataflow cell). Derived from PROJECT_ID like every other resource instead of being
# hard-wired to another project.
PUBSUB_SUBSCRIPTION_NAME = "input_messages-sub"
PUBSUB_SUBSCRIPTION = f"projects/{PROJECT_ID}/subscriptions/{PUBSUB_SUBSCRIPTION_NAME}"

# Output table as "project:dataset.table". The `genai` dataset must already exist:
# CREATE_IF_NEEDED creates the table, not the dataset.
BIGQUERY_NER_TABLE = f"{PROJECT_ID}:genai.ner_extract"
DATAFLOW_JOB_NAME = "bert-ner-inference"

In [None]:
import argparse
import json
import os
import re
from typing import Iterable, Tuple

import apache_beam as beam
import apache_beam.runners.interactive.interactive_beam as ib
from apache_beam import Create
from apache_beam.ml.inference import RunInference
from apache_beam.ml.inference.base import KeyedModelHandler, ModelHandler, PredictionResult
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
from google.cloud import storage

In [None]:
def _get_core_type(pred):
    """
    Helper to extract 'PHONE' from 'B-PHONE', 'I-PHONE', or just 'PHONE'
    Handles keys 'entity', 'entity_group', or 'label'.
    """
    label = pred.get('entity_group') or pred.get('entity') or pred.get('label')
    if not label: return "UNKNOWN"

    # Strip B- or I- prefixes
    if "-" in label and (label.startswith("B-") or label.startswith("I-")):
        return label.split("-", 1)[1]
    return label

In [None]:
def _start_group(pred):
    """Create a new merged-entity record seeded from a single raw prediction.

    The BIO prefix is stripped from the label via ``_get_core_type`` and the
    character offsets are coerced to ``int``.
    """
    return {
        "entity_group": _get_core_type(pred),
        "score": pred['score'],
        "word": pred['word'],
        "start": int(pred['start']),
        "end": int(pred['end']),
    }


def postprocess_predictions(predictions_result, merge_distance=1):
    """
    Merge adjacent entities of the same type and build one BigQuery row per message.

    Args:
        predictions_result: ``(key, PredictionResult)`` pair emitted by ``RunInference``
            with a ``KeyedModelHandler``. ``key`` becomes the row ``id``, ``.example``
            the row ``message``, and ``.inference`` is the list of entity dicts
            returned by the endpoint. Each dict is read for ``score``, ``word``,
            ``start``, ``end`` and one of ``entity_group``/``entity``/``label``.
        merge_distance: Max characters allowed between entities to merge them.
                        0 = must be touching (e.g., "123" + "-")
                        1 = allows spaces (e.g., "John" + " " + "Doe")

    Returns:
        dict with keys ``id``, ``message`` and ``ner``. ``ner`` is a list of merged
        entities, each with ``entity_group``, ``score``, ``word``, ``start`` and
        ``end``. It is an empty list when the model found no entities.
    """
    key, result = predictions_result
    row = {"id": key, "message": result.example, "ner": []}

    predictions = result.inference
    if not predictions:
        # Still emit a well-formed row. A bare [] is not a valid BigQuery row and
        # would lose the message id/text.
        return row

    # Process entities left to right so adjacency can be checked pairwise.
    sorted_preds = sorted(predictions, key=lambda p: int(p['start']))

    merged = []
    current = _start_group(sorted_preds[0])

    for next_pred in sorted_preds[1:]:
        next_type = _get_core_type(next_pred)
        # Coerce before subtracting, consistent with the int() used in the sort key;
        # subtracting raw values would fail if offsets arrive as strings.
        gap = int(next_pred['start']) - current['end']

        # MERGE CONDITION: same core type (e.g. PHONE == PHONE) and close enough.
        if next_type == current['entity_group'] and gap <= merge_distance:
            # Extend the span; max() keeps a nested/overlapping token from shrinking it.
            current['end'] = max(current['end'], int(next_pred['end']))
            # Re-insert the gap as spaces (" " * n is "" for n <= 0, i.e. touching or
            # overlapping tokens) and drop BERT subword "##" markers.
            current['word'] += " " * gap + next_pred['word'].replace("##", "")
            # Keep the most confident score among the merged pieces.
            current['score'] = max(current['score'], next_pred['score'])
        else:
            # NO MERGE: close the current group and start a new one.
            merged.append(current)
            current = _start_group(next_pred)

    merged.append(current)

    # The first token of each group was copied verbatim; strip any "##" left there.
    for entity in merged:
        entity['word'] = entity['word'].replace("##", "")

    row["ner"] = merged
    return row

#### Define BigQuery Schema
The schema uses a nested, repeated `RECORD` field (`ner`) to store the list of extracted NER entities.

In [None]:
# Table schema for WriteToBigQuery. Every column is NULLABLE except `ner`, which is a
# REPEATED RECORD holding one entry per entity produced by postprocess_predictions.
BIGQUERY_NER_SCHEMA = {
    "fields": [
        # Root level: the message's "id" and its raw text.
        *(
            {"name": column, "type": "STRING", "mode": "NULLABLE"}
            for column in ("id", "message")
        ),
        # Nested repeated field for the list of entities.
        {
            "name": "ner",
            "type": "RECORD",
            "mode": "REPEATED",
            "fields": [
                {"name": column, "type": bq_type, "mode": "NULLABLE"}
                for column, bq_type in (
                    ("entity_group", "STRING"),
                    ("score", "FLOAT"),
                    ("word", "STRING"),
                    ("start", "INTEGER"),
                    ("end", "INTEGER"),
                )
            ],
        },
    ]
}

#### Create a basic pipeline with NER inference (local test with the InteractiveRunner)

In [None]:
# Two sample messages with the same {"id", "message"} shape the streaming job reads from Pub/Sub.
elements = [
    {"id": "0e42fec7-71ce-40bd-ac83-e2ffc36c6fe0", "message": "Is kweaver@example.org the correct address for Christina Ramirez ?"},
    {"id": "3590d10c-8c1d-4df0-817d-4184387d3a54", "message": "Is davidlucas@example.com the correct address for Patricia Parsons ?"},
]

# Keyed handler: RunInference consumes (id, message) pairs and emits (id, PredictionResult),
# so the id survives inference and is written next to the entities.
# NOTE: the Dataflow cell below reuses ner_model_handler, so run this cell first.
ner_model_handler = KeyedModelHandler(VertexAIModelHandlerJSON(
    endpoint_id=ENDPOINT_ID,
    project=PROJECT_ID,
    location=REGION))

# An explicit empty flag list stops Beam from parsing the Jupyter kernel's own
# command-line arguments (sys.argv) as pipeline options.
options = PipelineOptions(flags=[])

# Run locally with the InteractiveRunner to sanity-check the pipeline before submitting to Dataflow.
p = beam.Pipeline(InteractiveRunner(), options=options)

p_output = (
    p
    | "Create elements" >> Create(elements)
    | "ToKeyValue" >> beam.Map(lambda x: (x["id"], x["message"]))
    | "NER Model" >> RunInference(ner_model_handler)
    | "Postprocess" >> beam.Map(postprocess_predictions)
    | "Format Output" >> beam.Map(json.dumps)
)

ib.show(p_output)

#### Submit Dataflow Streaming Job

In [None]:
# Set up Beam PipelineOptions for Dataflow Runner
pipeline_options = PipelineOptions(
    runner="DataflowRunner",
    project=PROJECT_ID,
    region=REGION,
    temp_location=TEMP_LOCATION,
    staging_location=STAGING_LOCATION,
    job_name=DATAFLOW_JOB_NAME,
    save_main_session=True,
    max_num_workers=1,
    streaming=True
)

p = beam.Pipeline(options=pipeline_options)

pp = (p
        | 'ReadPubSub' >> beam.io.ReadFromPubSub(subscription=PUBSUB_SUBSCRIPTION)
        | 'DecodeMsg' >> beam.Map(lambda row: json.loads(row.decode('utf-8')))
        | "ToKeyValue" >> beam.Map(lambda x: (x["id"], x["message"]))
        | "NER Model" >> RunInference(ner_model_handler)
        | "Postprocess" >> beam.Map(postprocess_predictions)
        | "WriteBigQuery" >> beam.io.gcp.bigquery.WriteToBigQuery(
            table=BIGQUERY_NER_TABLE,
            schema=BIGQUERY_NER_SCHEMA,
            method=beam.io.gcp.bigquery.WriteToBigQuery.Method.STREAMING_INSERTS,
            write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED)
        )

p.run().wait_until_finish()