In [None]:
import logging

import apache_beam as beam
from apache_beam import Create, FlatMap, Map, ParDo, Filter, Flatten, Partition, MapTuple, FlatMapTuple
from apache_beam import Keys, Values
from apache_beam.transforms.util import WithKeys

from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
import apache_beam.runners.interactive.interactive_beam as ib
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.gemini_inference import GeminiModelHandler, generate_from_string
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Optional

from google import genai
from google.genai import errors

from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RemoteModelHandler
from apache_beam.options import pipeline_options
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.runners import DataflowRunner

In [None]:
# Ask gcloud for the active project; the `!` capture yields a list of output lines.
PROJECT_ID = !gcloud config list --format 'value(core.project)'
# Keep only the first output line as the project ID string.
PROJECT_ID = PROJECT_ID[0]
# Region used for the GCS bucket, the Gemini model handler, and the Dataflow job.
REGION = "us-central1"
# Export the project ID as GOOGLE_CLOUD_PROJECT for this kernel session.
%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}
# Bucket that holds the Dataflow staging and temp files.
BUCKET_NAME=f'dataflow_demo_{PROJECT_ID}'

### Creating a GCS bucket

In [None]:
# Create the bucket in REGION; it is used for Dataflow staging and temp files.
!gsutil mb -l {REGION} gs://{BUCKET_NAME}

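#### Creating the BigQuery table
Optionally, create the output table ahead of time. If you skip this step, the pipeline's `CREATE_IF_NEEDED` disposition creates the table on the first write.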

In [None]:
%%bigquery
-- Output table for the pipeline. The column names must match the row keys
-- the pipeline writes ('input_text' and 'output_text').
-- Note: CREATE OR REPLACE drops any rows already stored in the table.
CREATE OR REPLACE TABLE `genai.genai_pipeline`
(
input_text STRING,
output_text STRING
)
OPTIONS(
    description="Table to store results"
);

In [None]:
# Build the pipeline on the InteractiveRunner so it can be inspected from the notebook.
p = beam.Pipeline(InteractiveRunner())

# Gemini model name handed to the model handler defined later in this cell.
MODEL_NAME = "gemini-2.5-flash"

# NOTE(review): `json` is not used in the visible cells — confirm before keeping it.
import json

# NOTE(review): `logging` is already imported in the first cell; this import is redundant.
import logging

# Custom utility method to specify the Gemini client invocation
def response_to_text(
    model_name: str, batch: Sequence[str],
    model: genai.Client, inference_args: dict[str, Any]) -> list[str]:
    """Send each prompt in ``batch`` to Gemini and return the response texts.

    One request is issued per prompt, so the returned list always has the
    same length and order as ``batch`` and every input stays paired with
    its own output.

    Args:
        model_name: Name of the Gemini model to call.
        batch: Prompts to send to the model.
        model: Client used to issue the ``generate_content`` calls.
        inference_args: Extra keyword arguments forwarded to
            ``generate_content``.

    Returns:
        One string per prompt: the response text, ``""`` when the response
        has no text, or ``"ERROR"`` when reading the text raises
        ``ValueError``.
    """
    results = []
    for prompt in batch:
        response_object = model.models.generate_content(
            model=model_name, contents=prompt, **inference_args
        )
        try:
            # getattr only shields against a missing attribute; a ValueError
            # raised while reading `text` is handled below.
            text = getattr(response_object, 'text', None)
        except ValueError:
            # TODO: Implement proper error handling
            logging.warning(
                "Could not read response text for prompt %r", prompt,
                exc_info=True)
            results.append("ERROR")
            continue
        results.append(text if text else "")
    return results

# Implementation of the ModelHandler interface for Google Gemini.
# `request_fn` plugs in the custom invocation function defined above;
# project and location come from the configuration cell.
model_handler = GeminiModelHandler(
        model_name=MODEL_NAME,
        request_fn=response_to_text,
        project=PROJECT_ID,
        location=REGION,
    )

# Output BigQuery table, written as <project>.<dataset>.<table>
output_table = f'{PROJECT_ID}.genai.genai_pipeline'

# Output BigQuery table schema; the field names match the row keys
# produced by PrepareRecords
output_table_schema = 'input_text:STRING, output_text:STRING'

# Input elements (prompts) to process
elements = [
    "What is the capital of Ireland?", 
    "What is the capital of France?",
]

# Custom DoFn that shapes inference results into records to store in BigQuery
class PrepareRecords(beam.DoFn):
    """
    Turn each PredictionResult into a row dictionary for the output table.
    """
    def process(self, element: PredictionResult) -> Iterable[dict[str, str]]:
        """Yield one row holding the element's example and its inference."""
        row = {
            'input_text': element.example,
            'output_text': element.inference,
        }
        yield row

# The pipeline: create the prompts, run them through Gemini, shape each
# result into a row, and append the rows to the BigQuery output table.
lines = (p  | "Create elements" >> Create(elements)
            # Calls the model through the handler configured above.
            | "RunInference" >> RunInference(model_handler)
            # Converts each PredictionResult into a row dictionary.
            | "Prepare Record" >> beam.ParDo(PrepareRecords())
            # Streams rows into the table; the table is created if missing
            # and existing rows are kept (append).
            | "Write BigQuery" >> beam.io.gcp.bigquery.WriteToBigQuery(
                table=output_table,
                schema=output_table_schema,
                method=beam.io.gcp.bigquery.WriteToBigQuery.Method.STREAMING_INSERTS,
                write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
                create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED)
        )

# Uncomment if you want to preview the results with the Interactive Runner:
# ib.show(lines)

# Run the pipeline on Dataflow (comment out the rest of this cell if you only
# want the Interactive Runner preview above).
# flags=[] keeps PipelineOptions from parsing the notebook kernel's own
# command-line arguments; every option is given explicitly as a keyword.
options = pipeline_options.PipelineOptions(
    flags=[],
    project=PROJECT_ID,
    region=REGION,
    staging_location=f'gs://{BUCKET_NAME}/staging',
    temp_location=f'gs://{BUCKET_NAME}/temp',
    #sdk_container_image=f'gcr.io/{PROJECT_ID}/cc_gpu:latest',
    machine_type='n1-standard-4',
    disk_size_gb=50)

# Submit the job to Dataflow and block until it reaches a terminal state.
pipeline_result = DataflowRunner().run_pipeline(p, options=options)
pipeline_result.wait_until_finish()