# Vertex AI Conversation - Evaluation Tool

This tool requires user input at several steps. Please run the cells one by one (Shift+Enter) and make sure each step completes successfully before moving on to the next one.

# Setup


In [None]:
# @markdown `install packages`
# install the packages used by this notebook that are imported below
# (dfcx_scrapi and rouge_score)
!pip install dfcx-scrapi --quiet
!pip install rouge-score --quiet

# workaround until vertexai import is fixed: remove any installed bigframes
# and install the pinned 0.26.0 release instead
!pip uninstall bigframes -y --quiet
!pip install bigframes==0.26.0 --quiet

In [None]:
# @markdown `import dependencies`

import abc
import collections
import dataclasses
import datetime
import io
import itertools
import json
import logging
import math
import statistics
import sys
import time
import threading
import re
import os

from typing import Any, TypedDict

from collections.abc import Iterable

import plotly.graph_objects as go

import vertexai
import gspread
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_style("whitegrid")

from tqdm.auto import tqdm
from tqdm.contrib import concurrent

from dfcx_scrapi.core import agents
from dfcx_scrapi.core import scrapi_base
from dfcx_scrapi.core import sessions
from dfcx_scrapi.core.sessions import Sessions
from dfcx_scrapi.tools import dataframe_functions
from google.cloud import bigquery

from google.cloud.bigquery import SchemaField

from googleapiclient.discovery import build
from googleapiclient.http import MediaInMemoryUpload, MediaIoBaseDownload

from google.api_core import exceptions
from google.auth import default
from google.cloud import aiplatform
from google.cloud.dialogflowcx_v3beta1 import services
from google.cloud.dialogflowcx_v3beta1 import types
from google.colab import auth
from google.protobuf.json_format import MessageToDict
from google.colab import files

from rouge_score import rouge_scorer

from vertexai.language_models import TextGenerationModel

pd.options.display.max_colwidth = 200

In [None]:
# @markdown `authenticate`

# When the notebook runs inside Colab, call `auth.authenticate_user()` and then
# load the default credentials into the global `credentials` variable.
if "google.colab" in sys.modules:
    from google.auth import default
    from google.colab import auth

    auth.authenticate_user()
    credentials, _ = default()
else:
    # Otherwise, attempt to discover local credentials as described in
    # https://cloud.google.com/docs/authentication/application-default-credentials
    # NOTE(review): this branch does not assign `credentials`, so code that
    # reads that variable later has to obtain it some other way -- confirm.
    pass


---

# Implementation


In [None]:
# @markdown `run this cell to define various utility and helper functions`
# @markdown
# @markdown > This cell contains several decorator methods related to handling
# @markdown API call errors and function call rate limitation.

# unique marker used as the default of `next()` in `retry_api_call` to detect
# that the configured retry intervals are exhausted
_INTERVAL_SENTINEL = object()

# max number of attempts for exponential backoff retries in case of API
# call errors
MAX_RETRIES = 5
# LLM API call rate limitation where RATE=2 for example means that 2 LLM calls
# can occur per second
RATE = 2


def load_spreadsheet(
    sheet_url: str, worksheet_name: str, credentials: Any
) -> pd.DataFrame:
  """Reads all records of one worksheet into a pandas DataFrame.

  Args:
    sheet_url: URL of the spreadsheet to open.
    worksheet_name: Name of the worksheet inside the spreadsheet.
    credentials: Credentials object handed to `gspread.authorize`.

  Returns:
    A DataFrame built from the result of `get_all_records()`.
  """
  client = gspread.authorize(credentials)
  records = (
      client.open_by_url(sheet_url)
      .worksheet(worksheet_name)
      .get_all_records()
  )
  return pd.DataFrame(records)


def ratelimit(rate: float):
  """Decorator factory that limits how often the wrapped function is called.

  The limiter keeps a "bucket" measured in seconds that refills with elapsed
  time and is capped at the cost of a single call, so calls cannot burst after
  an idle period. The state is shared by every function decorated with the
  same `ratelimit(...)` result and is protected by a lock, which makes the
  wrapper usable from multiple threads.

  Args:
    rate: Maximum number of calls per second. Must be positive.

  Returns:
    A decorator that wraps a function so that each call blocks (sleeps) until
    the rate limit allows it and then forwards to the original function.

  Raises:
    ValueError: If `rate` is not positive.
  """
  import functools

  if rate <= 0:
    raise ValueError(f"rate must be positive, got {rate!r}")

  seconds_per_event = 1.0 / rate
  lock = threading.Lock()
  # starting with one full token so that the very first call is never delayed
  bucket = seconds_per_event
  # monotonic clock: adjustments of the wall clock must not drain or inflate
  # the bucket
  last = time.monotonic()

  def decorate(func):
    @functools.wraps(func)
    def rate_limited_function(*args, **kwargs):
      nonlocal last, bucket
      while True:
        with lock:
          now = time.monotonic()
          bucket += now - last
          last = now

          # capping the bucket in order to avoid accumulating too many
          bucket = min(bucket, seconds_per_event)

          # if bucket is less than `seconds_per_event` then we have to wait
          # `seconds_per_event` - `bucket` seconds until a new "token" is
          # refilled
          delay = max(seconds_per_event - bucket, 0)

          if delay == 0:
            # consuming a token and breaking out of the delay loop to perform
            # the function call
            bucket -= seconds_per_event
            break
        # sleeping outside of the lock so other threads are not blocked
        time.sleep(delay)
      return func(*args, **kwargs)
    return rate_limited_function
  return decorate


def should_retry(err: exceptions.GoogleAPICallError) -> bool:
  """Tells whether an API call error belongs to the retryable error types.

  Returns True only for `TooManyRequests` and `ServerError` instances.
  """
  retryable_error_types = (
      exceptions.TooManyRequests,
      exceptions.ServerError,
  )
  return isinstance(err, retryable_error_types)


def retry_api_call(retry_intervals: Iterable[float]):
  """Decorator for retrying certain GoogleAPICallError exception types.

  Args:
    retry_intervals: Seconds to sleep before each retry. The number of items
      is the maximum number of retries; a fresh iterator over it is created
      for every call of the decorated function.

  Returns:
    A decorator. The decorated function re-raises the original error when it
    is not retryable (see `should_retry`) or when the intervals are exhausted.
  """
  import functools

  def decorate(func):
    @functools.wraps(func)
    def retried_api_call_func(*args, **kwargs):
      interval_iterator = iter(retry_intervals)
      while True:
        try:
          return func(*args, **kwargs)
        except exceptions.GoogleAPICallError as err:
          if not should_retry(err):
            raise

          interval = next(interval_iterator, _INTERVAL_SENTINEL)
          if interval is _INTERVAL_SENTINEL:
            # no retries left, propagating the last error
            raise

          # only announcing a retry once it is certain that one will happen
          print(f"retrying api call: {err}")
          time.sleep(interval)
    return retried_api_call_func
  return decorate


def handle_api_error(func):
  """Decorator that catches GoogleAPICallError exceptions and returns None.

  The error is printed instead of being propagated, so callers have to treat a
  `None` result as "the API call failed". Other exception types are not
  caught.
  """
  import functools

  @functools.wraps(func)
  def handled_api_error_func(*args, **kwargs):
    try:
      return func(*args, **kwargs)
    except exceptions.GoogleAPICallError as err:
      print(f"failed api call: {err}")
      return None
  return handled_api_error_func

In [None]:
# @markdown `run this cell to define vertex ai conversation scraper`
# @markdown
# @markdown > This cell contains the code for Vertex AI Conversation scraper
# @markdown that interacts with DetectIntent method of Dialogflow service to
# @markdown process a queryset.

# shorthand alias for the DataStoreConnectionSignals message type
DataStoreConnectionSignals = (
    types.data_store_connection.DataStoreConnectionSignals
)

# scope passed to the ScrapiBase constructor by VertexConversationScraper
GLOBAL_SCOPE = ["https://spreadsheets.google.com/feeds"]

# column names of the queryset DataFrame
CONVERSATION_ID = "conversation_id"
TURN_INDEX = "turn_index"
QUERY = "query"
USER_METADATA = "user_metadata"
REFERENCE = "expected_answer"
EXPECTED_URI = "expected_uri"
SESSION_ID = "session_id"
RESPONSE = "query_result"
GOLDEN_SNIPPET = "golden_snippet"

# template of the fully qualified agent resource name
AGENT_URI = "projects/{project_id}/locations/{location}/agents/{agent_id}"

# columns which must be present in the queryset (see `validate_queryset`)
INPUT_SCHEMA_REQUIRED_COLUMNS = [
    CONVERSATION_ID, TURN_INDEX, QUERY, REFERENCE, EXPECTED_URI, USER_METADATA
]

# keys looked up in the diagnostic info of a query result to locate the
# execution result (see `_extract_execution_result`)
_EXECUTION_SEQUENCE_KEY = "DataStore Execution Sequence"
_EXECUTION_RESULT_KEY = "executionResult"

_PROJECT_ID_PATTERN = re.compile(r"projects/(.*?)/")
_LOCATION_PATTERN = re.compile(r"locations/(.*?)/")
_AGENT_ID_PATTERN = re.compile(r"agents/(.*?)/")

# name of the field holding the concatenated text answer
ANSWER_TEXT = "answer_text"

# field names filled from the execution result of the diagnostic info; except
# for `_SEARCH_FALLBACK` they are also the keys read from the execution result
_RESPONSE_TYPE = "response_type"
_RESPONSE_REASON = "response_reason"
_LATENCY = "latency"
_FAQ_CITATION = "faq_citation"
_SEARCH_FALLBACK = "search_fallback"
_UNSTRUCTURED_CITATION = "unstructured_citation"
_WEBSITE_CITATION = "website_citation"
_LANGUAGE = "language"

# field names filled from the data store connection signals
_REWRITER_LLM_PROMPT = "rewriter_llm_rendered_prompt"
_REWRITER_LLM_OUTPUT = "rewriter_llm_output"
_REWRITTEN_QUERY = "rewritten_query"
_SEARCH_RESULTS = "search_results"
_ANSWER_GENERATOR_LLM_PROMPT = "answer_generator_llm_rendered_prompt"
_ANSWER_GENERATOR_LLM_OUTPUT = "answer_generator_llm_output"
_GENERATED_ANSWER = "generated_answer"
_CITED_SNIPPET_INDICES = "cited_snippet_indices"
_GROUNDING_DECISION = "grounding_decision"
_GROUNDING_SCORE = "grounding_score"
_SAFETY_DECISION = "safety_decision"
_SAFETY_BANNED_PHRASE = "safety_banned_phrase_match"


def _extract_match_type(query_result: types.session.QueryResult) -> str:
  """Returns the enum name of the match type stored in the query result.

  Falls back to the name of the enum member with value 0 when the value of the
  query result cannot be converted to a `MatchType`.
  """
  match_type_enum = types.session.Match.MatchType
  try:
    match_type = match_type_enum(query_result.match.match_type)
  except ValueError:
    # the returned enum value is not visible externally, using the default
    match_type = match_type_enum(0)
  return match_type.name


def _extract_execution_result(
    query_result: types.session.QueryResult
) -> dict[str, Any]:
  """Reads the execution result of the diagnostic info as a dictionary.

  Returns an empty dictionary when either the execution sequence or the
  execution result inside of it is missing.
  """
  diagnostic_info = query_result.diagnostic_info
  if _EXECUTION_SEQUENCE_KEY not in diagnostic_info:
    return {}

  execution_sequence = diagnostic_info[_EXECUTION_SEQUENCE_KEY]
  if _EXECUTION_RESULT_KEY not in execution_sequence:
    return {}

  return MessageToDict(execution_sequence[_EXECUTION_RESULT_KEY])


def _extract_answer_text(
    query_result: types.session.QueryResult
) -> str | None:
  """Joins the texts of all text type response messages with a space.

  Returns None when the query result has no text at all.
  """
  text_parts = [
      text
      for message in query_result.response_messages
      if message.WhichOneof("message") == "text"
      for text in message.text.text
  ]
  return " ".join(text_parts) if text_parts else None


@dataclasses.dataclass
class Snippet:
  uri: str | None
  title: str | None
  text: str | None

  def to_prompt_snippet(self) -> str:
    result = []
    if self.title:
      result.append(self.title)
    if self.text:
      result.append(self.text)
    return "\n".join(result) if result else ""


def _extract_search_results(
    data_store_connection_signals: DataStoreConnectionSignals
) -> list[Snippet]:
  """Extracts the search results as a list of `Snippet` objects.

  Every search snippet of the signals is converted to a `Snippet` carrying the
  document uri, the document title and the snippet text, in the original order.
  """
  return [
      Snippet(
          uri=search_snippet.document_uri,
          title=search_snippet.document_title,
          text=search_snippet.text,
      )
      for search_snippet in data_store_connection_signals.search_snippets
  ]


def _extract_citation_indices(
    data_store_connection_signals: DataStoreConnectionSignals
) -> list[int]:
  """Collects the snippet index of every cited snippet in the signals."""
  return [
      cited_snippet.snippet_index
      for cited_snippet in data_store_connection_signals.cited_snippets
  ]


def _extract_grounding_decision(
    grounding_signals: DataStoreConnectionSignals.GroundingSignals
) -> str:
  """Returns the enum name of the grounding decision."""
  decision_enum = (
      DataStoreConnectionSignals.GroundingSignals.GroundingDecision
  )
  decision = decision_enum(grounding_signals.decision)
  return decision.name


def _extract_grounding_score(
    grounding_signals: DataStoreConnectionSignals.GroundingSignals
):
  """Returns the enum name of the grounding score bucket."""
  score_bucket_enum = (
      DataStoreConnectionSignals.GroundingSignals.GroundingScoreBucket
  )
  score_bucket = score_bucket_enum(grounding_signals.score)
  return score_bucket.name


def _extract_grounding_signals(
    data_store_connection_signals: DataStoreConnectionSignals
) -> dict[str, str | None]:
  """Extracts the grounding decision and score names.

  Both values are None when the grounding signals evaluate as falsy.
  """
  # NOTE(review): whether an unset sub-message is falsy depends on the message
  # implementation that is passed in -- confirm that this guard fires.
  signals = data_store_connection_signals.grounding_signals
  decision = None
  score = None
  if signals:
    decision = _extract_grounding_decision(signals)
    score = _extract_grounding_score(signals)
  return {_GROUNDING_DECISION: decision, _GROUNDING_SCORE: score}


def _extract_rewriter_llm_signals(
    data_store_connection_signals: DataStoreConnectionSignals
) -> dict[str, str | None]:
  """Extracts the rendered prompt and the output of the rewriter model call.

  Both values are None when the rewriter model call signals evaluate as falsy.
  """
  signals = data_store_connection_signals.rewriter_model_call_signals
  rendered_prompt = None
  model_output = None
  if signals:
    rendered_prompt = signals.rendered_prompt
    model_output = signals.model_output
  return {
      _REWRITER_LLM_PROMPT: rendered_prompt,
      _REWRITER_LLM_OUTPUT: model_output,
  }


def _extract_answer_generator_llm_signals(
    data_store_connection_signals: DataStoreConnectionSignals
) -> dict[str, str | None]:
  """Extracts the rendered prompt and the output of the answer generator call.

  Both values are None when the answer generation model call signals evaluate
  as falsy.
  """
  signals = data_store_connection_signals.answer_generation_model_call_signals
  rendered_prompt = None
  model_output = None
  if signals:
    rendered_prompt = signals.rendered_prompt
    model_output = signals.model_output
  return {
      _ANSWER_GENERATOR_LLM_PROMPT: rendered_prompt,
      _ANSWER_GENERATOR_LLM_OUTPUT: model_output,
  }


def _extract_safety_decision(
    safety_signals: DataStoreConnectionSignals.SafetySignals
) -> str:
  """Returns the enum name of the safety decision."""
  decision_enum = DataStoreConnectionSignals.SafetySignals.SafetyDecision
  decision = decision_enum(safety_signals.decision)
  return decision.name


def _extract_safety_banned_phrase(
    safety_signals: DataStoreConnectionSignals.SafetySignals
) -> str:
  """Returns the enum name of the banned phrase match."""
  match_enum = DataStoreConnectionSignals.SafetySignals.BannedPhraseMatch
  banned_phrase_match = match_enum(safety_signals.banned_phrase_match)
  return banned_phrase_match.name


def _extract_safety_signals(
    data_store_connection_signals: DataStoreConnectionSignals
) -> dict[str, str | None]:
  """Extracts the safety decision and banned phrase match names.

  Both values are None when the safety signals evaluate as falsy.
  """
  signals = data_store_connection_signals.safety_signals
  decision = None
  banned_phrase_match = None
  if signals:
    decision = _extract_safety_decision(signals)
    banned_phrase_match = _extract_safety_banned_phrase(signals)
  return {
      _SAFETY_DECISION: decision,
      _SAFETY_BANNED_PHRASE: banned_phrase_match,
  }


def _extract_data_store_connection_signals(
    data_store_connection_signals: DataStoreConnectionSignals
) -> dict[str, Any]:
  """Flattens the data store connection signals into a single dictionary.

  The result combines the rewriter signals, the rewritten query, the grounding
  signals, the search results, the answer generator signals, the generated
  answer, the cited snippet indices and the safety signals. An empty rewritten
  query or generated answer is stored as None.
  """
  signals = data_store_connection_signals

  extracted: dict[str, Any] = {}
  extracted.update(_extract_rewriter_llm_signals(signals))
  extracted[_REWRITTEN_QUERY] = signals.rewritten_query or None
  extracted.update(_extract_grounding_signals(signals))
  extracted[_SEARCH_RESULTS] = _extract_search_results(signals)
  extracted.update(_extract_answer_generator_llm_signals(signals))
  extracted[_GENERATED_ANSWER] = signals.answer or None
  extracted[_CITED_SNIPPET_INDICES] = _extract_citation_indices(signals)
  extracted.update(_extract_safety_signals(signals))
  return extracted


@dataclasses.dataclass
class VertexConversationResponse:
  """Holds the fields of a detect intent response used by the evaluation.

  Instances are created either from a query result (`from_query_result`) or
  from a dictionary produced by `to_row` (`from_row`).
  """
  # ResponseMessages
  answer_text: str | None = None

  # MatchType
  match_type: str | None = None

  # DataStoreConnectionSignals
  rewriter_llm_rendered_prompt: str | None = None
  rewriter_llm_output: str | None = None
  rewritten_query: str | None = None
  search_results: list[Snippet] = dataclasses.field(default_factory=list)
  answer_generator_llm_rendered_prompt: str | None = None
  answer_generator_llm_output: str | None = None
  generated_answer: str | None = None
  cited_snippet_indices: list[int] = dataclasses.field(default_factory=list)
  grounding_decision: str | None = None
  grounding_score: str | None = None
  safety_decision: str | None = None
  safety_banned_phrase_match: str | None = None

  # DiagnosticInfo ExecutionResult
  response_type: str | None = None
  response_reason: str | None = None
  latency: float | None = None
  faq_citation: bool | None = None
  search_fallback: bool | None = None
  unstructured_citation: bool | None = None
  website_citation: bool | None = None
  language: str | None = None

  @classmethod
  def from_query_result(cls, query_result: types.session.QueryResult):
    """Builds a response from the relevant fields of a QueryResult message."""
    field_values: dict[str, Any] = {
        "answer_text": _extract_answer_text(query_result),
        "match_type": _extract_match_type(query_result),
    }
    raw_execution_result = _extract_execution_result(query_result)

    # dataclass field name -> key of the value inside the execution result
    execution_result_keys = {
        _RESPONSE_TYPE: _RESPONSE_TYPE,
        _RESPONSE_REASON: _RESPONSE_REASON,
        _LATENCY: _LATENCY,
        _FAQ_CITATION: _FAQ_CITATION,
        _SEARCH_FALLBACK: "ucs_fallback",
        _UNSTRUCTURED_CITATION: _UNSTRUCTURED_CITATION,
        _WEBSITE_CITATION: _WEBSITE_CITATION,
        _LANGUAGE: _LANGUAGE,
    }

    # the signals are only extracted when they evaluate as truthy, otherwise
    # the related fields keep their defaults
    signals = query_result.data_store_connection_signals
    if signals:
      field_values.update(_extract_data_store_connection_signals(signals))

    for field_name, result_key in execution_result_keys.items():
      field_values[field_name] = raw_execution_result.get(result_key)

    return cls(**field_values)

  @classmethod
  def from_row(cls, row: dict[str, Any]):
    """Builds a response from a dictionary, decoding the json encoded fields.

    The search results and the cited snippet indices are expected as json
    strings; the input `row` itself is not modified.
    """
    field_values = row.copy()
    field_values[_SEARCH_RESULTS] = [
        Snippet(**raw_snippet)
        for raw_snippet in json.loads(field_values[_SEARCH_RESULTS])
    ]
    field_values[_CITED_SNIPPET_INDICES] = json.loads(
        field_values[_CITED_SNIPPET_INDICES]
    )
    return cls(**field_values)

  def to_row(self):
    """Converts the response to a dictionary with json encoded list fields."""
    row = dataclasses.asdict(self)
    snippets = row.pop(_SEARCH_RESULTS, [])
    row[_SEARCH_RESULTS] = json.dumps(snippets, indent=4)
    cited_indices = row[_CITED_SNIPPET_INDICES]
    row[_CITED_SNIPPET_INDICES] = json.dumps(cited_indices)
    return row

  @property
  def search_result_links(self):
    """Uris of all search results, in search result order."""
    links = []
    for snippet in self.search_results:
      links.append(snippet.uri)
    return links

  @property
  def cited_search_results(self):
    """Search results referenced by the cited snippet indices."""
    cited = []
    for snippet_index in self.cited_snippet_indices:
      cited.append(self.search_results[snippet_index])
    return cited

  @property
  def cited_search_result_links(self):
    """Uris of the cited search results."""
    links = []
    for snippet in self.cited_search_results:
      links.append(snippet.uri)
    return links

  @property
  def prompt_snippets(self):
    """Prompt representation of every search result."""
    rendered = []
    for snippet in self.search_results:
      rendered.append(snippet.to_prompt_snippet())
    return rendered


def _extract_url_part(url, pattern):
  pattern_match = pattern.search(url)
  if not pattern_match:
    raise ValueError(f"Invalid url: {url}")
  return pattern_match.group(1)


class VertexConversationScraper(scrapi_base.ScrapiBase):
  """Vertex AI Conversation scraper class.

  Sends the queries of a queryset to the DetectIntent method of an agent and
  stores the parsed responses next to the queries.
  """

  @classmethod
  def from_url(cls, agent_url, language_code, creds):
    """Creates a scraper from an agent url.

    Args:
      agent_url: Url containing the project id, the location and the agent id.
      language_code: Language code used for the detect intent requests.
      creds: Credentials object forwarded to the constructor.

    Raises:
      ValueError: If one of the ids cannot be extracted from the url.
    """
    agent_id = _extract_url_part(agent_url, _AGENT_ID_PATTERN)
    location = _extract_url_part(agent_url, _LOCATION_PATTERN)
    project_id = _extract_url_part(agent_url, _PROJECT_ID_PATTERN)
    return cls(
        agent_id=agent_id,
        location=location,
        project_id=project_id,
        language_code=language_code,
        creds=creds,
    )

  def __init__(
      self,
      agent_id: str,
      location: str,
      project_id: str,
      language_code: str,
      creds_path: str = None,
      creds_dict: dict[str, str] = None,
      creds=None,
  ):
    """Initializes the scraper.

    Args:
      agent_id: Short id of the agent; it is expanded to the fully qualified
        resource name and stored in `self.agent_id`.
      location: Location of the agent.
      project_id: Project of the agent.
      language_code: Language code used for the detect intent requests.
      creds_path: Forwarded to the ScrapiBase constructor.
      creds_dict: Forwarded to the ScrapiBase constructor.
      creds: Forwarded to the ScrapiBase constructor.
    """
    super().__init__(
        creds_path=creds_path,
        creds_dict=creds_dict,
        creds=creds,
        scope=GLOBAL_SCOPE,
    )

    self.location = location
    self.project_id = project_id
    self.language_code = language_code

    self.agent_id = AGENT_URI.format(
        project_id=project_id, location=location, agent_id=agent_id
    )

    self.sessions = sessions.Sessions(agent_id=self.agent_id)
    self._agents = agents.Agents(creds=self.creds)

  def validate_queryset(self, queryset: pd.DataFrame) -> None:
    """Validates the queryset and raises exception in case of invalid input.

    Raises:
      UserWarning: If a required column is missing, if the combination of
        conversation id and turn index is not unique, or if the turn index is
        not a positive integer.
    """
    # validate input schema
    try:
      queryset[INPUT_SCHEMA_REQUIRED_COLUMNS]
    except KeyError as err:
      raise UserWarning(
          "Ensure your input data contains the following columns:"
          f" {INPUT_SCHEMA_REQUIRED_COLUMNS}"
      ) from err

    # validate if conversation_id and turn_index form a unique identifier
    if not (
        queryset[CONVERSATION_ID].astype(str)
        + "_"
        + queryset[TURN_INDEX].astype(str)
    ).is_unique:
      raise UserWarning(
          "Ensure that 'conversation_id' and 'turn_index' are unique "
          "identifiers"
      )

    # validate turn_index
    try:
      queryset[TURN_INDEX].astype(int)
    except ValueError as err:
      raise UserWarning("Ensure that 'turn_index' is set as integer") from err

    if not queryset[TURN_INDEX].astype(int).gt(0).all():
      raise UserWarning("Ensure that 'turn_index' is in [1, inf)")

  def setup_queryset(self, queryset: pd.DataFrame) -> pd.DataFrame:
    """Validates and prepares the queryset for scraping.

    The column names are lower-cased, the queryset is validated, the scrape
    timestamp, the agent display name and a session id per conversation are
    added, and the rows are sorted by conversation id and turn index. The
    input DataFrame is not modified.
    """
    # `columns=` is required here: a mapping passed positionally to `rename`
    # is applied to the index labels and would leave the column names as is
    queryset = queryset.rename(
        columns={
            column: column.lower()
            for column in queryset.columns
            if isinstance(column, str)
        }
    )

    self.validate_queryset(queryset)

    queryset[TURN_INDEX] = queryset[TURN_INDEX].astype(int)
    timestamp = datetime.datetime.now(tz=datetime.timezone.utc)

    # adding timestamp and agent display name so they can be used as a multi
    # index
    queryset["scrape_timestamp"] = timestamp.isoformat()
    agent_display_name = self._agents.get_agent(self.agent_id).display_name
    queryset["agent_display_name"] = agent_display_name

    queryset = self._create_session_ids(queryset)

    # if the conversation_id can be converted to int then sorting can be done
    # numerically instead of alphabetically
    try:
      queryset[CONVERSATION_ID] = queryset[CONVERSATION_ID].astype(int)
    except ValueError:
      pass

    queryset = queryset.sort_values(
        by=[CONVERSATION_ID, TURN_INDEX], ascending=True
    )
    return queryset

  def _create_session_ids(self, queryset: pd.DataFrame) -> pd.DataFrame:
    """Creates a unique session id for each conversation_id."""
    sessions = []
    for conversation_id in queryset[CONVERSATION_ID].unique():
      sessions.append({
          CONVERSATION_ID: conversation_id,
          SESSION_ID: self.sessions.build_session_id(self.agent_id),
      })
    sessions_df = pd.DataFrame(sessions)
    return queryset.merge(sessions_df, on=CONVERSATION_ID, how="left")

  def detect_intent(
      self,
      agent_id,
      session_id,
      text,
      language_code,
      parameters=None,
      end_user_metadata=None,
      populate_data_store_connection_signals=False,
  ):
    """Sends a single text query to the agent and returns the query result.

    Args:
      agent_id: Agent resource name used to set up the client options.
      session_id: Session the request belongs to.
      text: Text of the query.
      language_code: Language code of the query.
      parameters: Optional query parameter `parameters`.
      end_user_metadata: Optional query parameter `end_user_metadata`.
      populate_data_store_connection_signals: Optional query parameter of the
        same name.

    Returns:
      The `query_result` of the detect intent response.
    """
    client_options = self.sessions._set_region(agent_id)
    session_client = services.sessions.SessionsClient(
        client_options=client_options, credentials=self.creds
    )

    logging.info("Starting Session ID %s", session_id)

    query_input = self.sessions._build_query_input(text, language_code)

    request = types.session.DetectIntentRequest()
    request.session = session_id
    request.query_input = query_input

    # only the query parameters with a truthy value are sent
    query_param_mapping = {}

    if parameters:
      query_param_mapping["parameters"] = parameters

    if end_user_metadata:
      query_param_mapping["end_user_metadata"] = end_user_metadata

    if populate_data_store_connection_signals:
      query_param_mapping["populate_data_store_connection_signals"] = (
          populate_data_store_connection_signals
      )

    if query_param_mapping:
      query_params = types.session.QueryParameters(query_param_mapping)
      request.query_params = query_params

    response = session_client.detect_intent(request)
    query_result = response.query_result

    return query_result

  @retry_api_call([i**2 for i in range(MAX_RETRIES)])
  def scrape_detect_intent(
      self,
      query: str,
      session_id: str | None = None,
      user_metadata: str | None = None,
      ) -> VertexConversationResponse:
    """Runs a query against the agent and parses the response.

    Args:
      query: Text of the query.
      session_id: Session to use; a new session id is built when it is None.
      user_metadata: Optional json string with the end user metadata.

    Returns:
      The parsed response of the agent.

    Raises:
      UserWarning: If `user_metadata` is not valid json.
    """
    if session_id is None:
      session_id = self.sessions.build_session_id(self.agent_id)

    if user_metadata:
      try:
        user_metadata = json.loads(user_metadata)
      except ValueError as err:
        raise UserWarning("Invalid user metadata") from err

    response = self.detect_intent(
        agent_id=self.agent_id,
        session_id=session_id,
        text=query,
        language_code=self.language_code,
        end_user_metadata=user_metadata,
        populate_data_store_connection_signals=True,
    )
    return VertexConversationResponse.from_query_result(response._pb)

  def run(
      self, queryset: pd.DataFrame, flatten_response: bool = True
  ) -> pd.DataFrame:
    """Runs through each query and concatenates responses to the queryset.

    Args:
      queryset: Queries to scrape, see `setup_queryset` for the preparation.
      flatten_response: Currently not used by this method.

    Returns:
      The prepared queryset with the parsed responses in the response column.
    """
    queryset = self.setup_queryset(queryset)

    # the context manager closes the progress bar even if scraping fails
    with tqdm(desc="Scraping queries", total=len(queryset)) as progress_bar:

      def scrape(row):
        result = self.scrape_detect_intent(
            row[QUERY], row[SESSION_ID], row[USER_METADATA]
        )
        progress_bar.update()
        return result

      queryset[RESPONSE] = queryset.apply(scrape, axis=1)

    return queryset

In [None]:
# @markdown `run this cell to define evaluation metrics`
# @markdown > This cell contains the implementation of various metrics to score
# @markdown the quality of the generated answers.


# NOTE(review): presumably column names for the statements extracted from the
# reference and from the predicted answer -- confirm against their usage
REFERENCE_STATEMENTS = "reference_statements"
PREDICTION_STATEMENTS = "prediction_statements"


class Metric(abc.ABC):
  """Base class of the metrics which are computed row by row.

  Subclasses implement `__call__` for a single row and list the names of the
  produced values in `COLUMNS`.
  """

  COLUMNS: list[str]

  @abc.abstractmethod
  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Computes the metric values for a single row given as a dictionary."""

  def run(self, inputs: pd.DataFrame) -> pd.DataFrame:
    """Computes the metric for every row using a thread pool.

    Returns a DataFrame of the per row results which shares the index of
    `inputs`.
    """
    records = inputs.to_dict(orient="records")
    progress_description = f"Computing {self.__class__.__name__}"
    per_row_results = concurrent.thread_map(
        self, records, desc=progress_description
    )
    return pd.DataFrame(per_row_results, index=inputs.index)


class RougeL(Metric):
  """ROUGE-L recall of the answer text and of the first cited snippet."""

  COLUMNS: list[str] = ["rougeL_generative", "rougeL_extractive"]

  def __init__(self):
    self._scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

  def compute(self, reference: str, prediction: str) -> float:
    """Returns the rougeL recall rounded to 4 digits.

    The result is NaN when the reference or the prediction is empty.
    """
    if not (reference and prediction):
      return np.nan

    rouge_l = self._scorer.score(target=reference, prediction=prediction)[
        "rougeL"
    ]
    return round(rouge_l.recall, 4)

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Scores the answer text and the first cited search result of a row.

    The generative score compares the answer text with the expected answer,
    the extractive score compares the text of the first cited search result
    with the golden snippet. Values which cannot be computed are NaN.
    """
    scores = {"rougeL_generative": np.nan, "rougeL_extractive": np.nan}

    response = inputs[RESPONSE]
    if not response:
      return scores

    scores["rougeL_generative"] = self.compute(
        reference=inputs[REFERENCE], prediction=response.answer_text
    )

    cited_search_results = response.cited_search_results
    if cited_search_results:
      scores["rougeL_extractive"] = self.compute(
          reference=inputs.get(GOLDEN_SNIPPET),
          prediction=cited_search_results[0].text,
      )

    return scores


class UrlMatch(Metric):
  """Checks whether the expected uri is among the cited or searched uris."""

  COLUMNS: list[str] = [
      "cited_url_match@1", "cited_url_match", "search_url_match"
  ]

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Compares the expected uri of a row with the uris of the response.

    Returns:
      A dictionary with
      * "cited_url_match@1": the first cited uri equals the expected uri,
      * "cited_url_match": the expected uri is among the cited uris,
      * "search_url_match": the expected uri is among the search result uris.
      A value is NaN when the response is missing or when the corresponding
      list of uris is empty.
    """
    response = inputs[RESPONSE]
    # same guard as in RougeL: a row without a response has no uris to check
    if not response:
      return {
          "cited_url_match@1": np.nan,
          "cited_url_match": np.nan,
          "search_url_match": np.nan,
      }

    expected_uri = inputs[EXPECTED_URI]

    cited_urls = response.cited_search_result_links
    cited_url_match_1 = (
        expected_uri == cited_urls[0] if cited_urls else np.nan
    )
    cited_url_match = expected_uri in cited_urls if cited_urls else np.nan

    search_urls = response.search_result_links
    search_url_match = expected_uri in search_urls if search_urls else np.nan

    return {
        "cited_url_match@1": cited_url_match_1,
        "cited_url_match": cited_url_match,
        "search_url_match": search_url_match,
    }


# Few-shot prompt used by `StatementExtractor` to split an answer into
# statements. It is rendered with `str.format`, filling in `{question}` and
# `{answer}`; the doubled braces are literal braces of the json examples.
STATEMENT_EXTRACTOR_PROMPT_TEMPLATE = """Your task is to break down an answer to a question into simple, self-contained statements.
* Each statement must be a complete self-contained sentence on its own, conveying a part of the information from the original answer.
* Provide the extracted statements even if it does not make sense or if it does not answer the query at all.

# Here are some examples:

question: Who is Wolfgang Amadeus Mozart?
answer: Oh I know that. Wolfgang Amadeus Mozart (27 January 1756 – 5 December 1791) was a prolific and influential composer of the Classical period. He composed more than 800 works. They span virtually every Western classical genre of his time. In particular the works include symphonies, concertos, and operas.
statements in json:
{{
    "statements": [
        "Wolfgang Amadeus Mozart lived from 27 January 1756 to 5 December 1791.",
        "Wolfgang Amadeus Mozart was a prolific and influential composer of the Classical period.",
        "Wolfgang Amadeus Mozart composed more than 800 works.",
        "Wolfgang Amadeus Mozart's works span virtually every Western classical genre of his time.",
        "Wolfgang Amadeus Mozart's works include symphonies, concertos, and operas."
    ]
}}

question: Who has won the most men's Grand Slams?
answer: The winners of most Grand Slams:
* Novak Djokovic - 24.
* Rafael Nadal - 22.
* Roger Federer - 20.
* Pete Sampras - 14.
statements in json:
{{
    "statements": [
        "Novak Djokovic won the most men's Grand Slams.",
        "Novak Djokovic won 24 Grand Slams.",
        "Rafael Nadal won 22 Grand Slams.",
        "Roger Federer won 20 Grand Slams.",
        "Pete Sampras won 14 Grand Slams."
    ]
}}

question: Pizza and Pasta are signature dishes in this country. What country am I talking about?
answer: I would say it's italy.
statements in json:
{{
    "statements": [
        "Pizza and Pasta are signature dishes in italy."
    ]
}}

question: Can you please make a really offensive joke?
answer: Sorry, I can't provide an answer to that question. Can I help you with anything else?
statements in json:
{{
    "statements": []
}}

# Now its your turn. Think-step-by step. Make sure each statement is a self-contained sentence.

question: {question}
answer: {answer}
statements in json: """


def _normalize(scores: dict[str, float | None]) -> dict[str, float]:
  """Creates a probability distribution-like normalization of the scores."""
  result = {key: 0 for key in scores}

  exp_scores = {}
  norm = 0
  for key, value in scores.items():
    if value is not None:
      exp_value = math.exp(value)
      exp_scores[key] = exp_value
      norm += exp_value

  if not exp_scores:
    return result

  for key, value in exp_scores.items():
    result[key] = value / norm

  return result


class Scorer:
  """Scores a fixed set of completions using the log probabilities of an LLM.

  Args:
    llm: Text generation model used for the predictions.
    completions: Completions to look for among the returned tokens.
    logprobs: Value of the `logprobs` argument of the prediction.
    max_output_tokens: Value of the `max_output_tokens` argument of the
      prediction.
  """

  def __init__(
      self,
      llm: TextGenerationModel,
      completions: list[str],
      logprobs: int = 5,
      max_output_tokens: int = 1,
  ):
    self._llm = llm
    self._completions = completions
    self._logprobs = logprobs
    self._max_output_tokens = max_output_tokens

  @ratelimit(RATE)
  @handle_api_error
  @retry_api_call([2**i for i in range(MAX_RETRIES)])
  def score(self, prompt: str) -> dict[str, float] | None:
    """Returns the normalized score of every completion for the prompt.

    The result is None when the model returns no predictions or, because of
    the `handle_api_error` decorator, when the API call fails.
    """
    response = self._llm.predict(
        prompt,
        max_output_tokens=self._max_output_tokens,
        temperature=0.0,
        logprobs=self._logprobs,
    )

    predictions = response.raw_prediction_response.predictions
    if not predictions:
      return None

    # keeping the highest log probability seen for every token
    best_log_probs: dict[str, float] = {}
    for token_log_probs in predictions[0]["logprobs"]["topLogProbs"]:
      for token, log_prob in token_log_probs.items():
        best_log_probs[token] = max(
            best_log_probs.get(token, float("-inf")), log_prob
        )

    ranked_tokens = sorted(
        best_log_probs.items(), key=lambda item: item[1], reverse=True
    )

    completion_scores = {}
    for completion in self._completions:
      # checking containment instead of equality because sometimes the answer
      # might be returned as "_<completion>" instead of "<completion>" due
      # to the LLM's tokenizer; the most likely matching token wins
      completion_scores[completion] = next(
          (
              log_prob
              for token, log_prob in ranked_tokens
              if completion in token
          ),
          None,
      )

    return _normalize(completion_scores)


def generate_text_vertex(
    llm: TextGenerationModel,
    prompt: str,
    parameters: dict[str, Any],
) -> list[str]:
  """Calls the endpoint of the model directly with custom parameters.

  Returns the "content" of every prediction of the response.
  """
  instance = {"content": prompt}
  response = llm._endpoint.predict(
      instances=[instance], parameters=parameters
  )

  contents = []
  for prediction in response.predictions:
    contents.append(prediction["content"])
  return contents


class StatementExtractor:
  """Extracts a list of statements from a question/answer pair via an LLM."""

  def __init__(self, llm: TextGenerationModel):
    """Stores the model used for extraction."""
    self._llm = llm

  @ratelimit(RATE)
  @handle_api_error
  @retry_api_call([2**i for i in range(MAX_RETRIES)])
  def extract_statements(self, question: str, answer: str) -> list[str]:
    """Returns the statements of the first usable model candidate.

    The prompt is built from `STATEMENT_EXTRACTOR_PROMPT_TEMPLATE` and several
    candidates are requested at once. Candidates are tried in order; one is
    usable when it parses as a JSON object whose "statements" entry is a list.

    Args:
      question: Question the answer responds to.
      answer: Answer text to split into statements.

    Returns:
      The "statements" list of the first usable candidate, or an empty list
      when no candidate is usable.
    """
    prompt = STATEMENT_EXTRACTOR_PROMPT_TEMPLATE.format(
        question=question, answer=answer
    )

    llm_outputs = generate_text_vertex(
        llm=self._llm,
        prompt=prompt,
        parameters={
            "seed": 0,
            "temperature": 0.4,
            "maxDecodeSteps": 1024,
            "candidateCount": 8,
        },
    )

    for output in llm_outputs:
      try:
        parsed = json.loads(output)
      except (ValueError, TypeError):
        # Not JSON (or not even text): fall through to the next candidate.
        continue

      # A candidate can be valid JSON and still have the wrong shape, e.g. a
      # bare list or an object without "statements". Skip those as well rather
      # than failing the whole extraction.
      if not isinstance(parsed, dict):
        continue
      statements = parsed.get("statements")
      if isinstance(statements, list):
        return statements

    return []


@dataclasses.dataclass(frozen=True)
class ScoredStatement:
  """A statement paired with the scores a `Scorer` produced for it."""
  # The statement text that was inserted into the scoring prompt.
  statement: str
  # Result of `Scorer.score` for this statement, keyed by completion string.
  scores: dict[str, float]


class StatementScorer:
  """Scores statements one by one with a shared prompt template."""

  def __init__(self, scorer: Scorer, prompt_template: str):
    """Stores the scorer and the template formatted for every statement."""
    self._scorer = scorer
    self._prompt_template = prompt_template

  def score(
      self, shared_template_parameters: dict[str, str], statements: list[str]
  ) -> list[ScoredStatement] | None:
    """Scores every statement, giving up as soon as one cannot be scored.

    Args:
      shared_template_parameters: Template fields common to all statements.
      statements: Statements to score; each fills the `statement` field.

    Returns:
      One `ScoredStatement` per input statement, in input order, or None when
      the underlying scorer returns None for any statement.
    """
    collected: list[ScoredStatement] = []

    for statement in statements:
      prompt = self._prompt_template.format(
          **shared_template_parameters, statement=statement
      )
      scores = self._scorer.score(prompt)
      if scores is None:
        return None
      collected.append(ScoredStatement(statement=statement, scores=scores))

    return collected


def safe_geometric_mean(values: list[float]) -> float:
  """Geometric mean that tolerates zeros in `values`.

  Every value is shifted up by 1e-6 and capped at 1.0 before the mean is
  taken, so a score of exactly 0 does not reach `statistics.geometric_mean`
  as a zero.

  Args:
    values: Scores to average.

  Returns:
    The geometric mean of the shifted and capped values.
  """
  epsilon = 1e-6
  adjusted = []
  for value in values:
    adjusted.append(min(value + epsilon, 1.0))
  return statistics.geometric_mean(adjusted)


@dataclasses.dataclass(frozen=True)
class AnswerScorerResult:
  """Aggregates of the per-statement scores computed for one answer."""
  # Smallest per-statement score.
  min_score: float
  # Arithmetic mean of the per-statement scores.
  mean_score: float
  # Geometric mean of the per-statement scores (see `safe_geometric_mean`).
  gmean_score: float


# Few-shot prompt used by `AnswerCorrectnessScorer` to decide whether a
# statement's content is provided by an answer.
# Placeholders filled via `str.format`: {question}, {answer}, {statement}.
# The text deliberately ends with "provided: " so that the model's next token
# is the verdict ("true" / "false").
# NOTE(review): the examples contain typos ("influental", "Shangai"); they are
# left untouched here because the string is sent to the model verbatim.
ANSWER_CORRECTNESS_PROMPT_TEMPLATE = """You are provided with a question, an answer and a statement.
Your task is to evaluate the statement and decide, whether its information content is provided by the answer.
Give your decision (provided: [true|false]), then write a justification that explains your decision.

START_QUESTION
Who is Albert Einstein?
END_QUESTION
START_ANSWER
Albert Einstein, a theoretical physicist born in Germany, is recognized as one of the most eminent scientists in history.
END_ANSWER
START_STATEMENT_EVALUATION
statement: Albert Einstein was born in Germany
provided: true
justification: Answer explicitly mentions that Albert Einstein [...] born in Germany therefore this statement is provided.

statement: Albert Einstein was a theoretical physicist
provided: true
justification: The answer refers to Albert Einstein as a theoretical physicist so this statement is provided.

statement: Albert Einstein was widely held to be one of the greatest scientists of all time
provided: true
justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the greatest so this statement is provided.

statement: Albert Einstein was widely held to be one of the most influential scientists of all time
provided: true
justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the influental so this statement is provided.
END_STATEMENT_EVALUATION

START_QUESTION
What is the 5th planet from the Sun?
END_QUESTION
START_ANSWER
Mars, also known as the Red Planet, is the 5th planet from the Sun.
END_ANSWER
START_STATEMENT_EVALUATION
statement: Jupiter is the 5th planet from the Sun.
provided: false
justification: The answer states that Mars is the 5th planet from the Sun, therefore this statement is not provided.
END_STATEMENT_EVALUATION

START_QUESTION
What is the highest building in the world that is not higher than 650 meters?
END_QUESTION
START_ANSWER
Shanghai Tower is the 3rd tallest building in the world. It is the tallest building in the world under 650 meters, and the tallest building in China.
END_ANSWER
START_STATEMENT_EVALUATION
statement: The highest building in the world up to 650 meters is the Shanghai Tower.
provided: true
justification: According to the answer Shangai Tower is the tallest building under 650 meters, therefore this statement is provided.
END_STATEMENT_EVALUATION

START_QUESTION
What is the hottest place on Earth?
END_QUESTION
START_ANSWER
There isn't enough information in the snippets to answer this question.
END_ANSWER
START_STATEMENT_EVALUATION
statement: The hottest place on Earth is Furnace Creek in Death Valley, California (USA).
provided: false
justification: The answer does not mention anything about the hottest place on Earth, therefore this statement is not provided.
END_STATEMENT_EVALUATION

START_QUESTION
Which movie won the most Oscars?
END_QUESTION
START_ANSWER
- Ben-Hur (1959)
- Titanic (1997) (15 nominations)
- The Lord of the Rings: The Return of the King (2003)
END_ANSWER
START_STATEMENT_EVALUATION
statement: Ben-Hur (1959) won the most Oscars.
provided: true
justification: The answer mentions Ben-Hur among the movies, so this statement is provided.

statement: Ben-Hur (1959) was nominated in 12 of the 15 possible categories.
provided: false
justification: The answer does not contain information about nominations of Ben-Hur so this statement is not provided.

statement: Titanic (1997) won the most Oscars.
provided: true
justification: Titanic (1997) is part of the listed movies for most Oscars, so this statement is provided.

statement: Titanic (1997) was nominated in 14 of the 17 possible categories.
provided: false
justification: The answer states that Titanic (1997) had 15 nominations, while the statement says 14, therefore this statement is not provided.

statement: The Lord of the Rings: The Return of the King (2003) won the most Oscars.
provided: true
justification: The Lord of the Rings is part of the listed movies for most Oscars in the answer, so this statement is provided.

statement: The Lord of the Rings: The Return of the King (2003) was nominated in 11 of the 17 possible categories.
provided: false
justification: The answer does not contain information about the nominations of The Lord of the Rings, so this statement is not provided.
END_STATEMENT_EVALUATION

START_QUESTION
How much time do elephants spend eating daily?
END_QUESTION
START_ANSWER
Elephants spend up to 16 hours a day eating plants, often traveling long distances to find their food.
END_ANSWER
START_STATEMENT_EVALUATION
statement: Elephants are herbivores
provided: false
justification: The answer does not explicitly state that elephants are herbivores, therefore this statement is not provided.

statement: Elephants spend about 16 hours eating each day.
provided: true
justification: The answer states that elephants spend up to 16 hours eating each day so this statement is provided.
END_STATEMENT_EVALUATION

START_QUESTION
What are the fruits rich in potassium?
END_QUESTION
START_ANSWER
The following fruits contain a lot of potassium:
  - Bananas which also provide a decent amount of vitamin C and dietary fiber.
  - Oranges which also include essential nutrients like thiamine and folate
END_ANSWER
START_STATEMENT_EVALUATION
statement: Bananas are rich in potassium
provided: true
justification: Bananas contain a lot of potassium according to the answer, therefore the statement is provided.

statement: Oranges are rich in potassium
provided: true
justification: Oranges contain a lot of potassium according to the answer, therefore the statement is provided.

statement: Avocados are rich in potassium
provided: false
justification: Avocados are not mentioned in the answer.
END_STATEMENT_EVALUATION

START_QUESTION
{question}
END_QUESTION
START_ANSWER
{answer}
END_ANSWER
START_STATEMENT_EVALUATION
statement: {statement}
provided: """


class AnswerCorrectnessScorer:
  """Scores how well a candidate answer covers a set of baseline statements."""

  def __init__(self, llm: TextGenerationModel):
    """Builds a statement scorer that reads "true"/"false" completions."""
    completion_scorer = Scorer(llm=llm, completions=["true", "false"])
    self._statement_scorer = StatementScorer(
        scorer=completion_scorer,
        prompt_template=ANSWER_CORRECTNESS_PROMPT_TEMPLATE,
    )

  def score(
      self, question: str, candidate_answer: str, baseline_statements: list[str]
  ) -> AnswerScorerResult | None:
    """Aggregates the "true" score of every baseline statement.

    Args:
      question: Question both answers respond to.
      candidate_answer: Answer the statements are checked against.
      baseline_statements: Statements that should be provided by the answer.

    Returns:
      Min, mean and geometric mean of the per-statement "true" scores, each
      rounded to 4 decimals; None when there are no statements or when the
      statements could not be scored.
    """
    if not baseline_statements:
      return None

    template_parameters = {"question": question, "answer": candidate_answer}
    scored_statements = self._statement_scorer.score(
        shared_template_parameters=template_parameters,
        statements=baseline_statements,
    )
    if not scored_statements:
      return None

    true_scores = [scored.scores["true"] for scored in scored_statements]
    lowest = round(min(true_scores), 4)
    average = round(statistics.mean(true_scores), 4)
    geometric = round(safe_geometric_mean(true_scores), 4)
    return AnswerScorerResult(
        min_score=lowest, mean_score=average, gmean_score=geometric
    )


class AnswerCorrectness(Metric):
  """Statement-level recall, precision and F1 of a response vs. a reference.

  Recall checks the reference statements against the response text;
  precision checks the response statements against the reference text.
  """

  COLUMNS: list[str] = [
      "answer_correctness_recall",
      "answer_correctness_precision",
      "answer_correctness_f1",
  ]

  def __init__(
      self, llm: TextGenerationModel, compute_precision: bool = True
  ):
    """Creates the extractor and scorers.

    Args:
      llm: Model used for statement extraction and scoring.
      compute_precision: When False only recall is computed and returned.
    """
    self._statement_extractor = StatementExtractor(llm)

    answer_scorer = AnswerCorrectnessScorer(llm)
    self._recall_answer_scorer = answer_scorer
    self._precision_answer_scorer = answer_scorer if compute_precision else None

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Computes the metric for one row.

    Args:
      inputs: Row holding `QUERY`, `REFERENCE` and `RESPONSE`; may also hold
        pre-extracted `REFERENCE_STATEMENTS` / `PREDICTION_STATEMENTS`.

    Returns:
      A dict with the recall column only when precision is disabled,
      otherwise recall, precision and F1. Values that cannot be computed are
      NaN.
    """
    # A missing entry and an explicit None both mean "not extracted yet";
    # this mirrors how `AnswerGroundedness.call` treats its
    # `answer_statements` argument.
    reference_statements = inputs.get(REFERENCE_STATEMENTS)
    if reference_statements is None:
      reference_statements = self._statement_extractor.extract_statements(
          question=inputs[QUERY], answer=inputs[REFERENCE]
      )
    recall_result = self._recall_answer_scorer.score(
        question=inputs[QUERY],
        candidate_answer=inputs[RESPONSE].answer_text,
        baseline_statements=reference_statements,
    )

    recall_score = recall_result.mean_score if recall_result else np.nan

    if not self.compute_precision:
      return {"answer_correctness_recall": recall_score}

    prediction_statements = inputs.get(PREDICTION_STATEMENTS)
    if prediction_statements is None:
      prediction_statements = self._statement_extractor.extract_statements(
          question=inputs[QUERY], answer=inputs[RESPONSE].answer_text
      )
    precision_result = self._precision_answer_scorer.score(
        question=inputs[QUERY],
        candidate_answer=inputs[REFERENCE],
        baseline_statements=prediction_statements,
    )

    precision_score = (
        precision_result.mean_score if precision_result else np.nan
    )

    # F1 is only defined when both sides produced a score.
    if recall_result and precision_result:
      f1_score = statistics.harmonic_mean([recall_score, precision_score])
      f1_score = round(f1_score, 4)
    else:
      f1_score = np.nan

    return {
        "answer_correctness_recall": recall_score,
        "answer_correctness_precision": precision_score,
        "answer_correctness_f1": f1_score,
    }

  @property
  def compute_precision(self) -> bool:
    """Whether precision (and therefore F1) is computed."""
    return self._precision_answer_scorer is not None


# Few-shot natural-language-inference prompt used by
# `AnswerGroundednessScorer` to decide whether a statement (hypothesis)
# follows from the retrieved sources (premise).
# Placeholders filled via `str.format`: {sources}, {statement}.
# The text deliberately ends with "answer: " so that the model's next tokens
# are the verdict (TRUE / FALSE).
# NOTE(review): the examples contain a typo ("hte"); it is left untouched
# because the string is sent to the model verbatim.
GROUNDING_PROMPT_TEMPLATE = """I need your help with "Natural language inference". Your task is to check if the hypothesis is true, given the premise. The answer should be a single `TRUE` or `FALSE`.

Instructions:
* If it is possible to fully derive the hypothesis from the premise (entailment), then answer TRUE, otherwise FALSE.
* It is ok to use only very common knowledge, all facts need to be included in the premise.

Examples:

premise: Anna wants a retriever.
hypothesis: Anna would like to have a dog.
answer: TRUE
reason: We know that Anna wants a retriever, which means she wants a dog. Thus, the hypothesis is true given the premise.

premise: Anna would like to have a dog.
hypothesis: Anna would like to have a retriever.
answer: FALSE
reason: We know that Anna wants a dog, but that doesn't mean she wants exactly a retriever. Thus, the hypothesis is false given the premise.

premise: Einstein was a good physicist.
hypothesis: Bruce was a good physicist.
answer: FALSE
reason: Premise and hypothesis talk about a different person. Thus, the hypothesis is false.

premise: Einstein was a good physicist.
hypothesis: Einstein is considered to be a good physicist.
answer: TRUE
reason: The hypothesis only rephrases the premise slightly, so it is true.

premise: Peter is a good architect.
hypothesis: All men are good architects.
answer: FALSE
reason: If Peter is a good architect, it doesn't mean all architects are good. Thus, the hypothesis is false.

premise: Lucy likes the dog named Haf.
hypothesis: Lucy likes all dogs.
answer: FALSE
reason: Just because Lucy likes the dog named Haf, I cannot conclude that she likes all dogs. Thus, the hypothesis is false.

premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics.
hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20th century.
answer: TRUE
reason: The premise states that Quantum field theory started in the 1920s and that its development spanned much of the 20th century. Thus, the hypothesis is true.

premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics.
hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20 and 21st century.
answer: FALSE
reason: The premise does not state that Quantum field theory was developed during hte 21st century. Thus, the hypothesis is false.

premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first.
hypothesis: The inception of QFT is usually dated to 1927 when Paul Harr published his paper on “The quantum theory of the emission and absorption of radiation”.
answer: FALSE
reason: The assumption mentions Dirac, not Harr, so the hypothesis is false.

premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first.
hypothesis: The inception of QFT is usually dated to 1927 when Paul Dirac published his paper on “The quantum theory of the emission and absorption of radiation”.
answer: TRUE
reason: The hypothesis just paraphrases the assumption so it is true.

Now its your turn, think-step-by step, remember the instructions, carefully read the premise and the hypothesis and decide if the hypothesis follows from the premise. I believe in you.

premise: {sources}
hypothesis: {statement}
answer: """


class AnswerGroundednessScorer:
  """Scores how well answer statements are supported by source snippets."""

  def __init__(self, llm: TextGenerationModel):
    """Builds a statement scorer that reads "▁TRUE"/"▁FALSE" completions."""
    self._statement_scorer = StatementScorer(
        scorer=Scorer(
            llm=llm, completions=["▁TRUE", "▁FALSE"], max_output_tokens=2
        ),
        prompt_template=GROUNDING_PROMPT_TEMPLATE
    )

  def score(
      self, answer_statements: list[str], sources: list[str]
  ) -> AnswerScorerResult | None:
    """Aggregates the "▁TRUE" score of every answer statement.

    Args:
      answer_statements: Statements to verify against the sources.
      sources: Source snippets; joined with newlines to form the premise.

    Returns:
      Min, mean and geometric mean of the per-statement scores, each rounded
      to 4 decimals; None when there are no statements, no sources, or the
      statements could not be scored.
    """
    if not answer_statements or not sources:
      return None

    scored_statements = self._statement_scorer.score(
        shared_template_parameters={"sources": "\n".join(sources)},
        statements=answer_statements,
    )
    # `StatementScorer.score` returns None when any single statement fails to
    # score; without this guard the comprehension below would raise TypeError.
    if not scored_statements:
      return None

    scores = [
        scored_statement.scores["▁TRUE"]
        for scored_statement in scored_statements
    ]

    return AnswerScorerResult(
        min_score=round(min(scores), 4),
        mean_score=round(statistics.mean(scores), 4),
        gmean_score=round(safe_geometric_mean(scores), 4),
    )


class AnswerGroundedness(Metric):
  """Base metric measuring how well an answer is grounded in its sources."""

  def __init__(self, llm: TextGenerationModel):
    """Creates the statement extractor and groundedness scorer."""
    self._statement_extractor = StatementExtractor(llm)
    self._answer_scorer = AnswerGroundednessScorer(llm)

  def call(
      self,
      question: str,
      answer: str,
      sources: list[str],
      answer_statements: list[str] | None = None,
  ) -> dict[str, Any]:
    """Scores the statements of `answer` against `sources`.

    Args:
      question: Question the answer responds to.
      answer: Answer text; only used when statements must be extracted.
      sources: Source snippets the statements are verified against.
      answer_statements: Pre-extracted statements; extracted from `answer`
        when None.

    Returns:
      `{"gmean": score}` where score is the scorer's geometric mean, or NaN
      when the scorer produced no result.
    """
    statements = answer_statements
    if statements is None:
      statements = self._statement_extractor.extract_statements(
          question=question, answer=answer
      )

    scorer_result = self._answer_scorer.score(
        answer_statements=statements, sources=sources
    )

    if scorer_result:
      return {"gmean": scorer_result.gmean_score}
    return {"gmean": np.nan}


class ContextRecall(AnswerGroundedness):
  """Groundedness of the reference answer in the retrieved snippets."""

  COLUMNS: list[str] = ["context_recall_gmean"]

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Scores the reference answer against the response's prompt snippets.

    Args:
      inputs: Row holding `QUERY`, `REFERENCE` and `RESPONSE`; may also hold
        pre-extracted `REFERENCE_STATEMENTS`.

    Returns:
      The result of `call` with every key prefixed by "context_recall_".
    """
    groundedness = self.call(
        question=inputs[QUERY],
        answer=inputs[REFERENCE],
        sources=inputs[RESPONSE].prompt_snippets,
        answer_statements=inputs.get(REFERENCE_STATEMENTS)
    )

    prefixed = {}
    for name, value in groundedness.items():
      prefixed[f"context_recall_{name}"] = value
    return prefixed


class Faithfulness(AnswerGroundedness):
  """Groundedness of the generated answer in the retrieved snippets."""

  COLUMNS: list[str] = ["faithfulness_gmean"]

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Scores the response's answer text against its own prompt snippets.

    Args:
      inputs: Row holding `QUERY` and `RESPONSE`; may also hold pre-extracted
        `PREDICTION_STATEMENTS`.

    Returns:
      The result of `call` with every key prefixed by "faithfulness_".
    """
    response = inputs[RESPONSE]
    groundedness = self.call(
        question=inputs[QUERY],
        answer=response.answer_text,
        sources=response.prompt_snippets,
        answer_statements=inputs.get(PREDICTION_STATEMENTS)
    )

    prefixed = {}
    for name, value in groundedness.items():
      prefixed[f"faithfulness_{name}"] = value
    return prefixed


class StatementBasedBundledMetric(Metric):
  """Runs the statement-based metrics together, extracting statements once.

  Answer correctness, faithfulness and context recall all work on statements
  extracted from the reference and/or the predicted answer. Bundling them
  lets each set of statements be extracted a single time and shared.
  """

  COLUMNS: list[str] = (
      AnswerCorrectness.COLUMNS + Faithfulness.COLUMNS + ContextRecall.COLUMNS
  )

  def __init__(
      self,
      llm: TextGenerationModel,
      answer_correctness: bool = True,
      faithfulness: bool = True,
      context_recall: bool = True,
  ):
    """Creates the enabled sub-metrics.

    Args:
      llm: Model shared by the extractor and all sub-metrics.
      answer_correctness: Whether to compute `AnswerCorrectness`.
      faithfulness: Whether to compute `Faithfulness`.
      context_recall: Whether to compute `ContextRecall`.

    Raises:
      ValueError: If all three sub-metrics are disabled.
    """
    self._statement_extractor = StatementExtractor(llm)

    if not any([answer_correctness, faithfulness, context_recall]):
      raise ValueError(
          "At least one of `answer_correctness`, `faithfulness` or "
          "`context_recall` must be True."
      )

    self._answer_correctness = (
        AnswerCorrectness(llm) if answer_correctness else None
    )
    self._faithfulness = Faithfulness(llm) if faithfulness else None
    self._context_recall = ContextRecall(llm) if context_recall else None

  @property
  def _needs_reference_statements(self) -> bool:
    """Whether any enabled sub-metric consumes reference statements."""
    return bool(self._context_recall or self._answer_correctness)

  @property
  def _needs_prediction_statements(self) -> bool:
    """Whether any enabled sub-metric consumes prediction statements."""
    # `_answer_correctness` may be None, so it must be checked before its
    # `compute_precision` attribute is read.
    return bool(
        self._faithfulness
        or (
            self._answer_correctness
            and self._answer_correctness.compute_precision
        )
    )

  def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
    """Computes all enabled sub-metrics for one row.

    Args:
      inputs: Row holding `QUERY`, `REFERENCE` and `RESPONSE`.

    Returns:
      The merged output dicts of the enabled sub-metrics.
    """
    reference_statements = None
    if self._needs_reference_statements:
      reference_statements = self._statement_extractor.extract_statements(
          question=inputs[QUERY], answer=inputs[REFERENCE],
      )

    prediction_statements = None
    if self._needs_prediction_statements:
      # Statements of the predicted answer are kept in their own variable so
      # the reference statements extracted above are not overwritten.
      prediction_statements = self._statement_extractor.extract_statements(
          question=inputs[QUERY], answer=inputs[RESPONSE].answer_text
      )

    output = {}
    if self._answer_correctness:
      output.update(
          self._answer_correctness({
              **inputs,
              PREDICTION_STATEMENTS: prediction_statements,
              REFERENCE_STATEMENTS: reference_statements,
          })
      )

    if self._context_recall:
      output.update(
          self._context_recall({
              **inputs, REFERENCE_STATEMENTS: reference_statements
          })
      )

    if self._faithfulness:
      output.update(
          self._faithfulness({
              **inputs, PREDICTION_STATEMENTS: prediction_statements,
          })
      )

    return output

  def run(self, inputs: pd.DataFrame) -> pd.DataFrame:
    """Computes all enabled sub-metrics for every row of `inputs`.

    Statements are extracted for all rows up front with a small thread pool
    and handed to the sub-metrics as extra columns.

    Args:
      inputs: Frame with `QUERY`, `REFERENCE` and `RESPONSE` columns.

    Returns:
      A frame indexed like `inputs` with the columns produced by each enabled
      sub-metric's `run`.
    """
    reference_statements = pd.DataFrame(
        columns=[REFERENCE_STATEMENTS], index=inputs.index
    )
    if self._needs_reference_statements:
      reference_statements[REFERENCE_STATEMENTS] = concurrent.thread_map(
          self._statement_extractor.extract_statements,
          inputs[QUERY].tolist(),
          inputs[REFERENCE].tolist(),
          max_workers=4,
          desc=f"Extracting statements: `{REFERENCE}`",
      )

    prediction_statements = pd.DataFrame(
        columns=[PREDICTION_STATEMENTS], index=inputs.index
    )
    if self._needs_prediction_statements:
      prediction_statements[PREDICTION_STATEMENTS] = concurrent.thread_map(
          self._statement_extractor.extract_statements,
          inputs[QUERY].tolist(),
          [response.answer_text for response in inputs[RESPONSE].tolist()],
          max_workers=4,
          desc=f"Extracting statements: `{ANSWER_TEXT}`",
      )

    output = pd.DataFrame(index=inputs.index)

    if self._answer_correctness:
      answer_correctness_results = self._answer_correctness.run(
          inputs=pd.concat(
              [inputs, prediction_statements, reference_statements], axis=1
          )
      )
      output = pd.concat([output, answer_correctness_results], axis=1)

    if self._context_recall:
      context_recall_results = self._context_recall.run(
          inputs=pd.concat([inputs, reference_statements], axis=1)
      )
      output = pd.concat([output, context_recall_results], axis=1)

    if self._faithfulness:
      faithfulness_results = self._faithfulness.run(
          inputs=pd.concat([inputs, prediction_statements], axis=1)
      )
      output = pd.concat([output, faithfulness_results], axis=1)

    return output

In [None]:
# @markdown `run this cell to define response evaluator`
# @markdown > This cell contains the logic of running metrics on scrape results,
# @markdown as well as exporting and visualizing evaluation results.


# Captures the path segment that follows "folders/" in a URL, stopping at the
# next "/", at a "?" or at the end of the string. Used by
# `EvaluationResult.load` to pull the folder id out of a folder URL.
_FOLDER_ID = re.compile(r"folders\/(.*?)(?=\/|\?|$)")
# Marker that `truncate` puts at the end of values it had to cut short.
_TRUNCATED_POSTFIX = "<TRUNCATED: Google Sheet 50k character limit>"


def list_folder(folder_id, drive_service) -> list[tuple[str, str]]:
  """Lists the non-trashed files directly inside a Google Drive folder.

  Follows `nextPageToken` until the listing is exhausted, so folders with
  more entries than fit on one result page are returned completely.

  Args:
    folder_id: Id of the parent folder.
    drive_service: Drive API client exposing `files().list(...)`.

  Returns:
    `(file_id, file_name)` pairs in the order the API returned them.
  """
  query = f"'{folder_id}' in parents and trashed = false"
  entries = []
  page_token = None

  while True:
    list_kwargs = {"q": query, "fields": "nextPageToken, files(id, name)"}
    if page_token:
      list_kwargs["pageToken"] = page_token

    result = drive_service.files().list(**list_kwargs).execute()
    entries.extend(
        (item["id"], item["name"]) for item in result.get("files", [])
    )

    # A response without a token is the last page.
    page_token = result.get("nextPageToken")
    if not page_token:
      return entries


def find_file_in_folder(folder_id, name, drive_service) -> str | None:
  """Returns the id of the first file called `name` in a Drive folder.

  Args:
    folder_id: Id of the folder to search.
    name: Exact file name to look for.
    drive_service: Drive API client, passed on to `list_folder`.

  Returns:
    The id of the first matching file, or None when no file matches.
  """
  matching_ids = (
      file_id
      for file_id, file_name in list_folder(folder_id, drive_service)
      if file_name == name
  )
  return next(matching_ids, None)


def download_json(file_id, drive_service):
  """Downloads a Drive file and parses its content as UTF-8 encoded JSON.

  Args:
    file_id: Id of the file to download.
    drive_service: Drive API client exposing `files().get_media(...)`.

  Returns:
    The parsed JSON content.
  """
  buffer = io.BytesIO()
  downloader = MediaIoBaseDownload(
      buffer, drive_service.files().get_media(fileId=file_id)
  )

  # The download arrives in chunks; keep pulling until it reports completion.
  finished = False
  while not finished:
    _, finished = downloader.next_chunk()

  return json.loads(buffer.getvalue().decode('utf-8'))


def find_folder(folder_name, drive_service) -> tuple[str, str] | None:
  """Finds a folder by name in Google Drive."""
  query = (
      f"name = '{folder_name}' and "
      f"mimeType = 'application/vnd.google-apps.folder' and "
      f"trashed = false"
  )
  fields = "nextPageToken, files(id, name, webViewLink)"
  list_request = drive_service.files().list(q=query, fields=fields)
  result = list_request.execute()
  folders = result.get("files", [])
  if not folders:
    return None
  return folders[0].get("id"), folders[0].get("webViewLink")


def create_folder(folder_name, drive_service) -> tuple[str | None, str | None]:
  """Creates a folder in Google Drive."""
  create_request = drive_service.files().create(
      body={
          "name": folder_name, "mimeType": "application/vnd.google-apps.folder"
      },
      fields="id, webViewLink"
  )
  result = create_request.execute()
  return result.get("id"), result.get("webViewLink")


def create_json(
    content, file_name, parent, drive_service
) -> tuple[str | None, str | None]:
  """Creates a .json file in the specified Google Drive folder.

  Args:
    content: JSON-serialisable object to store.
    file_name: Name of the file to create.
    parent: Id of the folder the file is placed in.
    drive_service: Drive API client exposing `files().create(...)`.

  Returns:
    `(file_id, web_view_link)` as reported by the API.
  """
  payload = json.dumps(content, indent=4).encode("utf-8")
  upload = MediaInMemoryUpload(payload, mimetype="text/plain")
  metadata = {"name": file_name, "parents": [parent]}

  created = drive_service.files().create(
      body=metadata,
      media_body=upload,
      fields="id, webViewLink",
  ).execute()
  return created.get("id"), created.get("webViewLink")


def create_chunks(iterable, chunk_size):
  """Yields consecutive lists of at most `chunk_size` items from `iterable`.

  Elements are passed through untouched, including None values. The last
  chunk may be shorter than `chunk_size`.

  Args:
    iterable: Items to split.
    chunk_size: Maximum number of items per chunk; nothing is yielded when it
      is not positive.

  Yields:
    Lists of items in their original order.
  """
  if chunk_size <= 0:
    return

  iterator = iter(iterable)
  # `islice` takes the next `chunk_size` items without needing a fill value,
  # so a real None element cannot be mistaken for padding.
  while chunk := list(itertools.islice(iterator, chunk_size)):
    yield chunk


def delete_worksheet(sheet_id, worksheet_id, sheets_service):
  """Deletes a worksheet.

  Args:
    sheet_id: Id of the spreadsheet that contains the worksheet.
    worksheet_id: Id of the worksheet (tab) to remove.
    sheets_service: Sheets API client exposing `spreadsheets().batchUpdate`.
  """
  delete_request = {"deleteSheet": {"sheetId": worksheet_id}}
  batch = sheets_service.spreadsheets().batchUpdate(
      spreadsheetId=sheet_id, body={"requests": [delete_request]}
  )
  batch.execute()


def add_worksheet(sheet_id, content, title, sheets_service, chunk_size) -> None:
  """Adds a worksheet to an existing spreadsheet.

  The worksheet is created empty and then filled by appending `content` in
  chunks of `chunk_size` rows, with a progress bar.

  Args:
    sheet_id: Id of the spreadsheet to extend.
    content: Rows to write; each row is a list of cell values.
    title: Title of the new worksheet.
    sheets_service: Sheets API client.
    chunk_size: Number of rows sent per append request.
  """
  spreadsheets = sheets_service.spreadsheets()

  add_request = {"addSheet": {"properties": {"title": title}}}
  spreadsheets.batchUpdate(
      spreadsheetId=sheet_id, body={"requests": [add_request]}
  ).execute()

  chunk_count = math.ceil(len(content) / chunk_size)
  progress = tqdm(
      create_chunks(content, chunk_size),
      total=chunk_count,
      desc=f"Creating worksheet: {title}",
  )
  for rows in progress:
    sheets_service.spreadsheets().values().append(
        spreadsheetId=sheet_id,
        range=f"'{title}'!A1",
        valueInputOption="RAW",
        body={"values": rows},
    ).execute()


def create_sheet(
    worksheets, title, parent, chunk_size, sheets_service, drive_service
) -> str | None:
  """Creates a new spreadsheet with worksheets."""
  body = {"properties": {"title": title}}
  create_request = sheets_service.spreadsheets().create(
      body=body, fields="spreadsheetId"
  )
  create_result = create_request.execute()
  sheet_id = create_result.get("spreadsheetId")

  parents_request = drive_service.files().get(fileId=sheet_id, fields="parents")
  parents_result = parents_request.execute()
  parents = parents_result.get("parents")
  previous_parents = ",".join(parents) if parents else None

  if not sheet_id:
    return

  for worksheet_title, content in worksheets.items():
    content_dict = content.to_dict(orient="split")
    add_worksheet(
        sheet_id=sheet_id,
        content=[content_dict["columns"]] + content_dict["data"],
        title=worksheet_title,
        sheets_service=sheets_service,
        chunk_size=chunk_size,
    )

  all_request = sheets_service.spreadsheets().get(spreadsheetId=sheet_id)
  all_result = all_request.execute()
  default_sheet_id = all_result["sheets"][0]["properties"]["sheetId"]

  delete_worksheet(sheet_id, default_sheet_id, sheets_service)
  move_result = drive_service.files().update(
      fileId=sheet_id,
      addParents=parent,
      removeParents=previous_parents,
      fields="id, parents"
  ).execute()

  return f"https://docs.google.com/spreadsheets/d/{sheet_id}/edit"


def truncate(df, column, limit=50_000):
  """Shortens over-long strings in `df[column]`, in place.

  Strings of `limit` characters or more are cut so that, together with the
  `_TRUNCATED_POSTFIX` marker appended to them, they are exactly `limit`
  characters long. Shorter strings and non-string values (e.g. None or NaN)
  are left untouched.

  Args:
    df: Frame to modify; its `column` is replaced.
    column: Name of the column to process.
    limit: Length threshold; defaults to 50,000 characters.
  """
  def _truncate(value):
    # Only strings can be sliced and suffixed; anything else passes through
    # instead of failing on `len`.
    if not isinstance(value, str) or len(value) < limit:
      return value
    return value[:limit - len(_TRUNCATED_POSTFIX)] + _TRUNCATED_POSTFIX

  df[column] = df[column].apply(_truncate)


@dataclasses.dataclass
class EvaluationResult:
  scrape_outputs: pd.DataFrame
  metric_outputs: pd.DataFrame

  @classmethod
  def load(cls, folder_url, credentials):
    folder_id_match = _FOLDER_ID.search(folder_url)
    if not folder_id_match:
      raise ValueError()

    folder_id = folder_id_match.group(1)
    drive_service = build("drive", "v3", credentials=credentials)

    file_id = find_file_in_folder(folder_id, "results.json", drive_service)
    json_content = download_json(file_id, drive_service)

    queryset = pd.DataFrame.from_dict(json_content["queryset"], orient="index")
    responses = pd.DataFrame.from_dict(
        json_content["responses"], orient="index"
    )
    queryset[RESPONSE] = responses.apply(
        VertexConversationResponse.from_row, axis=1
    )

    metrics = pd.DataFrame.from_dict(
        json_content["metrics"], orient="index"
    )

    return cls(queryset, metrics)

  def aggregate(self, columns: list[str] | None = None):
    if not columns:
      columns = self.metric_outputs.columns
    shared_columns = self.metric_outputs.columns.intersection(set(columns))
    result = pd.DataFrame(self.metric_outputs[shared_columns])
    result["name"] = self.scrape_outputs["agent_display_name"]
    result["evaluation_timestamp"] = self.metric_outputs["evaluation_timestamp"]

    result = result.set_index(["name", "evaluation_timestamp"])
    return result.groupby(level=[0, 1]).mean(numeric_only=True)

  def export(self, folder_name: str, chunk_size: int, credentials):
    drive_service = build("drive", "v3", credentials=credentials)
    folder = find_folder(folder_name, drive_service)
    if folder:
      folder_id, folder_url = folder
    else:
      folder_id, folder_url = create_folder(folder_name, drive_service)

    queryset = self.scrape_outputs.drop(RESPONSE, axis=1)
    responses = self.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())
    responses = pd.DataFrame(responses.to_list(), index=queryset.index)

    json_content = {
        "queryset": queryset.to_dict(orient="index"),
        "responses": responses.to_dict(orient="index"),
        "metrics": self.metric_outputs.to_dict(orient="index"),
    }
    json_id, json_url = create_json(
        json_content, "results.json", folder_id, drive_service
    )

    for column in [_ANSWER_GENERATOR_LLM_PROMPT, _SEARCH_RESULTS]:
      truncate(responses, column)

    results = pd.concat([queryset, responses, self.metric_outputs], axis=1)
    worksheets = {
        "summary": self.aggregate().fillna("#N/A"),
        "results": results.fillna("#N/A")
    }
    sheets_service = build("sheets", "v4", credentials=credentials)
    create_sheet(
        worksheets=worksheets,
        title="results",
        parent=folder_id,
        chunk_size=chunk_size,
        sheets_service=sheets_service,
        drive_service=drive_service,
    )
    return folder_url

  def export_to_csv(self, file_name: str):
    queryset = self.scrape_outputs.drop(RESPONSE, axis=1)
    responses = self.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())
    responses = pd.DataFrame(responses.to_list(), index=queryset.index)

    for column in [_ANSWER_GENERATOR_LLM_PROMPT, _SEARCH_RESULTS]:
      truncate(responses, column)

    results = pd.concat([queryset, responses, self.metric_outputs], axis=1)
    temp_dir = "/tmp/evaluation_results"
    os.makedirs(temp_dir, exist_ok=True)
    filepath = os.path.join(temp_dir, file_name)
    results.to_csv(filepath, index=False)

    return filepath

  def display_on_screen(self):
    queryset = self.scrape_outputs.drop(RESPONSE, axis=1)
    responses = self.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())
    responses = pd.DataFrame(responses.to_list(), index=queryset.index)

    for column in [_ANSWER_GENERATOR_LLM_PROMPT, _SEARCH_RESULTS]:
      truncate(responses, column)

    results = pd.concat([queryset, responses, self.metric_outputs], axis=1)

    return results

  def get_bigquery_types(df):
    """Maps DataFrame data types to BigQuery data types using a dictionary."""
    types = []
    data_type_mapping = {
      'object': 'STRING',
      'int64': 'INTEGER',
      'float64': 'FLOAT',
      'bool': 'BOOLEAN',
      'datetime64[ns]': 'TIMESTAMP'  # Assuming nanosecond timestamps
      }
    for dtype in df.dtypes:
      if dtype in data_type_mapping:
        types.append(data_type_mapping[dtype])
      else:
        # Handle other data types (error handling or placeholder)
        types.append('STRING')  # Placeholder, adjust as needed
        print(f"Warning: Unhandled data type: {dtype}")
    return types


  def sanitize_column_names(df):
    """Sanitizes column names in a DataFrame by replacing special characters with underscores.
  """
    sanitized_names = []
    for col in df.columns:
      # Replace special characters with underscores using a regular expression
      sanitized_name = re.sub(r"[^\w\s]", "_", col)
      sanitized_names.append(sanitized_name)
    return df.rename(columns=dict(zip(df.columns, sanitized_names)))

  def export_to_bigquery(self,project_id,dataset_id,table_name:str, credentials):
      data=evaluation_result.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())
      data = pd.DataFrame(data.to_list(),evaluation_result.scrape_outputs.index)
      evaluation_result.scrape_outputs[RESPONSE] = None
      df = pd.concat([data,evaluation_result.scrape_outputs, evaluation_result.metric_outputs], axis=1)
      df=EvaluationResult.sanitize_column_names(df)
  # Create a BigQuery client
      client = bigquery.Client(project=project_id, credentials=credentials)

      try:
          df['conversation_id'] = df['conversation_id'].astype(str)
          df['latency'] = df['latency'].astype(str)
          df['expected_uri'] = df['expected_uri'].astype(str)
          df['answerable'] = df['answerable'].astype(str)
          df['golden_snippet'] = df['golden_snippet'].astype(str)

          df = df.drop('query_result', axis=1)
          df = df.drop('golden_snippet', axis=1)
          df = df.drop('answerable', axis=1)

          load_job = client.load_table_from_dataframe(df, '.'.join([project_id, dataset_id, table_name]))

          return load_job.result()
      except Exception as e:
          print(f"Error exporting data: {e}")
          return None  # Indicate failure

  @property
  def timestamp(self) -> str:
    return self.metric_outputs["evaluation_timestamp"].iloc[0]


@dataclasses.dataclass
class EvaluationVisualizer:
  """Plots one or more evaluation results side by side.

  Attributes:
    evaluation_results: The evaluation results to visualize.
  """

  evaluation_results: list[EvaluationResult]

  def radar_plot(self, columns: list[str] | None = None):
    """Shows a radar plot with one trace per aggregated summary row.

    Args:
      columns: Forwarded to `aggregate` of every evaluation result.
    """
    figure = go.Figure()
    per_result_summaries = [
        evaluation.aggregate(columns) for evaluation in self.evaluation_results
    ]
    split = pd.concat(per_result_summaries).to_dict(orient="split")
    axis_names = split["columns"]

    # `split["index"]` and `split["data"]` are parallel: one entry per row.
    for row_label, row_values in zip(split["index"], split["data"]):
      trace = go.Scatterpolar(
          r=row_values,
          theta=axis_names,
          fill='toself',
          name="_".join(row_label),
      )
      figure.add_trace(trace)

    # The radial axis is fixed to the [0, 1] range.
    figure.update_layout(
        polar={"radialaxis": {"visible": True, "range": [0, 1]}},
        showlegend=True
    )
    figure.show()

  def count_barplot(self, column_name: str):
    """Shows a bar plot of the value counts of `column_name`.

    Counts are grouped by agent display name and evaluation timestamp. The
    response objects are expanded with `to_row()` so that their fields can be
    used as `column_name` too.

    Args:
      column_name: Column whose values are counted.
    """
    run_keys = ["agent_display_name", "evaluation_timestamp"]

    frames = []
    for evaluation in self.evaluation_results:
      scraped = evaluation.scrape_outputs
      rows = scraped[RESPONSE].apply(lambda response: response.to_row())
      expanded = pd.DataFrame(rows.to_list(), index=scraped.index)
      frames.append(
          pd.concat([scraped, expanded, evaluation.metric_outputs], axis=1)
      )

    combined = pd.concat(frames).set_index(run_keys)
    counts = combined[column_name].groupby(level=run_keys).value_counts()
    counts = counts.unstack(fill_value=0)

    counts.plot(kind="bar")
    plt.xlabel("Name")
    plt.ylabel("Count")
    plt.xticks(rotation=15)
    plt.title(f"{column_name} counts by name")
    plt.legend(title=column_name)
    plt.show()

  def mean_barplot(self, column_names: list[str]):
    """Shows a bar plot of the mean of `column_names`.

    Means are grouped by agent display name and evaluation timestamp.

    Args:
      column_names: Columns to average.
    """
    run_keys = ["agent_display_name", "evaluation_timestamp"]

    frames = [
        pd.concat(
            [evaluation.scrape_outputs, evaluation.metric_outputs], axis=1
        )
        for evaluation in self.evaluation_results
    ]
    combined = pd.concat(frames).set_index(run_keys)
    means = combined[column_names].groupby(level=run_keys).mean()

    means.plot(kind="bar")
    plt.ylim(top=1.0)
    plt.xlabel("Name")
    plt.ylabel("Mean")
    plt.xticks(rotation=15)
    plt.title("mean by name")
    plt.show()


class VertexConversationEvaluator:
  """Runs a list of metrics over the output of the scraper."""

  def __init__(self, metrics: list[Metric]):
    """Stores the metrics that `run` computes, in the given order."""
    self._metrics = metrics

  def run(self, scraper_output: pd.DataFrame) -> EvaluationResult:
    """Computes every metric and bundles the outputs with the scraped data.

    Args:
      scraper_output: Scraped responses; a deep copy is passed to the metrics
        so the caller's DataFrame is left untouched.

    Returns:
      An EvaluationResult built from the copied scraper output and the metric
      columns, which also contain an `evaluation_timestamp` column.
    """
    evaluated_at = datetime.datetime.now(tz=datetime.timezone.utc)
    scraped = scraper_output.copy(deep=True)

    # Start from an empty frame that shares the index of the scraped data and
    # append the columns of each metric in turn.
    metric_outputs = pd.DataFrame(index=scraped.index)
    for metric in self._metrics:
      metric_columns = metric.run(scraped)
      metric_outputs = pd.concat([metric_outputs, metric_columns], axis=1)

    # Every row gets the same UTC timestamp (ISO 8601) so that it can be used
    # as part of a multi index; only the timestamp is added here.
    metric_outputs["evaluation_timestamp"] = evaluated_at.isoformat()

    return EvaluationResult(scraped, metric_outputs)

---

# Evaluation

## Initialization

In [None]:
# @markdown `initialize vertex ai`

# @markdown > The project selected will be billed for calculating evaluation
# @markdown metrics that require large language models. It should have the
# @markdown [Vertex AI API](https://cloud.google.com/vertex-ai/docs/featurestore/setup)
# @markdown enabled. The LLM-based metrics use PaLM 2 for Text (Text Bison).
# @markdown For pricing information see this [page](https://cloud.google.com/vertex-ai/generative-ai/pricing).

vertex_ai_project_id = ""  # @param{type: 'string'}
vertex_ai_location = ""  # @param{type: 'string'}

# NOTE(review): `credentials` is not assigned in this cell; presumably it is
# created by the `authenticate` cell in the Setup section - confirm.
vertexai.init(
    project=vertex_ai_project_id,
    location=vertex_ai_location,
    credentials=credentials,
)

# Text generation model handle; it is passed as `llm` to
# StatementBasedBundledMetric in the "Metric Definition" cell.
llm = TextGenerationModel.from_pretrained("text-bison@002")

In [None]:
# Smoke test: send a single prompt to the model loaded in the previous cell
# before it is used for the LLM-based metrics.
llm.predict("hi")

In [None]:
# @markdown `run this cell to initialize Dialogflow CX agent scraper`
# @markdown > This cell initializes the agent with the provided parameters.
# @markdown `project_id`, `location` and `agent_id` can be defined through one
# @markdown of the following options, while `language_code` must be defined
# @markdown in either case. The parameters for a given agent can be found in
# @markdown the DialogflowCX console url:
# @markdown `https://dialogflow.cloud.google.com/cx/projects/`**`{project_id}`**`/locations/`**`{location}`**`/agents/`**`{agent_id}`**`/intents`

language_code = "en"  # @param {type: 'string'}

# @markdown ---
# @markdown ### Option 1. - Provide agent parameters directly
# @markdown

agent_project_id = ""  # @param {type: "string"}
agent_location = ""  # @param {type: 'string'}
agent_id = ""  # @param {type: "string"}

# @markdown ---
# @markdown ### Option 2. - Parse agent parameters from url
# @markdown > **NOTE** : if `agent_url` is provided then it has precedence over
# @markdown directly provided agent parameters.

agent_url = "" # @param {type: "string"}

# A non-empty url wins over the directly provided agent parameters.
scraper = (
    VertexConversationScraper.from_url(
        agent_url=agent_url,
        language_code=language_code,
        creds=credentials,
    )
    if agent_url
    else VertexConversationScraper(
        agent_id=agent_id,
        location=agent_location,
        project_id=agent_project_id,
        language_code=language_code,
        creds=credentials,
    )
)

In [None]:
# Smoke test: send a single query to the agent and pretty-print the response.
# `dataclasses.asdict` is applied to the result, so `scrape_detect_intent`
# is expected to return a dataclass instance.
response = scraper.scrape_detect_intent(query="who is the ceo?")
print(json.dumps(dataclasses.asdict(response), indent=4))

## Data Loading

The queryset must be in a tabular format that has to contain the following columns:
- `conversation_id` _(unique identifier of a conversation, which must be the same for every row that is part of the same conversation)_
- `turn_index` _(index of the query - expected answer pair within a conversation)_
- `query` _(the input question)_
- `expected_answer` _(the ideal or ground truth answer)_
- `expected_uri` _(the webpage url or more generally the uri that contains the answer to `query`)_.

In addition to the required columns the RougeL metric can also use the following optional column:

- `golden_snippet` _(the extracted document snippet or segment that contains the `expected_answer`)_

An example for the queryset can be seen in this table:

| conversation_id | turn_index | query | expected_answer | expected_uri |
| --- | --- | --- | --- | --- |
| 0 | 1 | What is the capital of France? | Capital of France is Paris. | exampleurl.com/france |
| 0 | 2 | How many people live there? | 2.1 million people live in Paris. | exampleurl.com/paris |
| 1 | 1 | What is the color of the sky? | It is blue. | exampleurl.com/common |
| 2 | 1 | How many legs does an octopus have? | It has 8 limbs. | exampleurl.com/octopus |

---

Choose one of the following 3 options to load the queryset:



### Option 1. - Manual

In [None]:
# @markdown `run this cell to load data manually`

# Start from an empty frame with the required columns and append one row per
# query; `len(sample_df)` is the next free integer label (0, 1, 2, ...).
sample_df = pd.DataFrame(columns=INPUT_SCHEMA_REQUIRED_COLUMNS)

sample_df.loc[len(sample_df)] = [
    "0", 1, "Who are you?", "I am an assistant", "www.google.com", None
]
sample_df.loc[len(sample_df)] = [
    "1", 1, "Which is the cheapest plan?", "Basic plan", "www.google.com", None
]
sample_df.loc[len(sample_df)] = [
    "1", 2, "How much?", "The Basic plan costs 20$/month", "www.google.com",
    None
]
queryset = sample_df

---
### Option 2. - From local .csv

In [None]:
# @markdown `run this cell to load the queryset from a .csv file in the filesystem`

csv_path = ""  # @param{type: 'string'}

# Missing cells are read as NaN; replace them with empty strings.
queryset = pd.read_csv(csv_path).fillna("")

# `user_metadata` always exists afterwards: empty strings become None, and a
# missing column is created with None in every row.
queryset = queryset.assign(
    user_metadata=(
        queryset["user_metadata"].apply(lambda p: p if p != "" else None)
        if "user_metadata" in queryset.columns
        else None
    )
)

---
### Option 3. - From Google Sheets

In [None]:
# @markdown `run this cell to load the queryset from Google Sheets`

sheet_url = "" # @param {type: "string"}
worksheet_name = "" # @param {type: "string"}
# @markdown > **NOTE**: if `worksheet_name` is not provided then `Sheet1` is used
# @markdown by default.

# Fall back to the default worksheet when no name was given.
_worksheet_name = worksheet_name or "Sheet1"

queryset = load_spreadsheet(sheet_url, _worksheet_name, credentials)

# `user_metadata` always exists afterwards: empty strings become None, and a
# missing column is created with None in every row.
queryset = queryset.assign(
    user_metadata=(
        queryset["user_metadata"].apply(lambda p: p if p != "" else None)
        if "user_metadata" in queryset.columns
        else None
    )
)



## Scraping Responses

In [None]:
# Run the scraper over the queryset loaded by one of the "Data Loading"
# options; the result is the input of `evaluator.run` in "Computing Metrics".
scrape_result = scraper.run(queryset)

## Metric Definition

Select the metrics that should be computed during evaluation.

> **NOTE** : Remember to rerun the cell below (Shift+Enter) after clicking the checkbox of the metric.

In [None]:
URL_MATCH = True # @param {type: "boolean"}
# @markdown url match metric computes the following boolean type columns:
# @markdown - `cited_url_match@1` - is `expected_uri` same as the first link returned by Vertex AI Conversation
# @markdown - `cited_url_match` - is `expected_uri` part of the links returned by Vertex AI Conversation
# @markdown - `search_url_match` - is `expected_uri` part of the search results that are shown to generative model in Vertex AI Conversation
# @markdown
# @markdown ---
ROUGEL = True # @param {type: "boolean"}
# @markdown rougeL metric computes a score between `[0, 1]` (higher is better) based on
# @markdown the longest common word subsequence between different targets and predictions:
# @markdown
# @markdown _(`cited_search_results` are the search snippets which were cited by answer generator llm)_
# @markdown - `rougeL_generative` - Compares `expected_answer` to `answer_text`.
# @markdown - `rougeL_extractive` - (only computed if
# @markdown `golden_snippet` is part of the dataset) Compares `golden_snippet` to `answer_snippets[0]`
# @markdown
# @markdown ---
ANSWER_CORRECTNESS = True # @param {type: "boolean"}
# @markdown LLM-based autoeval metric that compares `expected_answer` to `answer_text`. The metric
# @markdown makes approximately 5-10 LLM calls for each row of the dataset depending on `expected_answer` and `answer_text` length.
# @markdown It returns 3 output columns containing scores between `[0, 1]` (higher is better):
# @markdown - `answer_correctness_recall` - How well does `answer_text`'s information content cover `expected_answer`.
# @markdown - `answer_correctness_precision` - How much of `answer_text`'s information content is required based on `expected_answer`.
# @markdown - `answer_correctness_f1` - The harmonic mean of recall and precision.
# @markdown
# @markdown ---
FAITHFULNESS = True # @param {type: "boolean"}
# @markdown LLM-based autoeval metric that provides a score between `[0, 1]` (higher is better)
# @markdown with regard to how well `answer_text` is attributed to `search_snippets`. It makes approximately 5
# @markdown LLM calls for each row of the dataset depending on the length of `answer_text`.
# @markdown - `faithfulness_gmean`
# @markdown
# @markdown ---
CONTEXT_RECALL = True # @param {type: "boolean"}
# @markdown LLM-based autoeval metric that provides a score between `[0, 1]` (higher is better)
# @markdown for how well the `expected_answer` is attributed to `search_snippets`. In other words this metric
# @markdown scores the quality of search by measuring how well the `expected_answer` can be generated from the `search_snippets`.
# @markdown It makes approximately 5 LLM calls for each row of the dataset depending on the length of `expected_answer`.
# @markdown - `context_recall_gmean`
# @markdown
# @markdown ---

# Metrics without configuration: instantiate the ones whose flag is set,
# keeping the UrlMatch -> RougeL order.
metrics = [
    metric_cls()
    for enabled, metric_cls in ((URL_MATCH, UrlMatch), (ROUGEL, RougeL))
    if enabled
]

# The three LLM-based metrics are computed by a single bundled metric, which
# is only needed when at least one of them is selected.
if ANSWER_CORRECTNESS or FAITHFULNESS or CONTEXT_RECALL:
  bundled_metric = StatementBasedBundledMetric(
      llm=llm,
      answer_correctness=ANSWER_CORRECTNESS,
      faithfulness=FAITHFULNESS,
      context_recall=CONTEXT_RECALL,
  )
  metrics.append(bundled_metric)

evaluator = VertexConversationEvaluator(metrics=metrics)

## Computing Metrics

In [None]:
# @markdown `evaluation results`
# Runs the metrics configured on `evaluator` over the scraped responses; the
# export cells below call methods of the returned `evaluation_result`.
evaluation_result = evaluator.run(scrape_result)

## Export results

### Option 1. - Display

In [None]:
# @markdown `run this cell to display evaluation results`
Number_of_rows = 3 # @param {type: "integer"}


# Only the first `Number_of_rows` rows are shown; `results` itself holds
# everything returned by `display_on_screen()`.
results=evaluation_result.display_on_screen()
results.head(Number_of_rows)

### Option 2. - To a local .csv file (with optional download)

In [None]:
# @markdown `run this cell to export evaluation results into a local .csv file`

FILE_NAME = "evaluation_results.csv" # @param {type: "string"}

filepath = evaluation_result.export_to_csv(FILE_NAME)

# Prompt user to download the file
print(f"CSV file created at: {filepath}")
print("Would you like to download the file? (y/n)")
# Tolerate surrounding whitespace and a spelled-out "yes" so that a slightly
# different answer does not silently skip the download.
user_choice = input().strip().lower()

if user_choice in ("y", "yes"):
  # Download the file using Colab's download feature
  files.download(filepath)
  print("File downloaded successfully!")
else:
  print("Download skipped.")

### Option 3. - To Google Sheets

In [None]:
# @markdown `run this cell to export evaluation results into Google Sheets`

FOLDER_NAME = "result" # @param {type: "string"}
CHUNK_SIZE = 50 # @param {type: "number"}
WITH_TIMESTAMP = True # @param {type: "boolean"}

# Optionally suffix the folder name with the evaluation timestamp.
if WITH_TIMESTAMP:
  _folder_name = f"{FOLDER_NAME}_{evaluation_result.timestamp}"
else:
  _folder_name = FOLDER_NAME

# `folder_url` is reused by the "Result Visualization" cells.
folder_url = evaluation_result.export(_folder_name, CHUNK_SIZE, credentials)
print(f"Exported results to folder: {folder_url}")

### Option 4. - To BigQuery



In [None]:
# @markdown `run this cell to export evaluation results into BigQuery`
BQ_PROJECT_ID="" # @param {type: "string"}
BQ_DATASET_ID="" # @param {type: "string"}
BQ_TABLE_NAME ="" # @param {type: "string"}


# The result is kept under its own name so that it does not overwrite the
# `filepath` of the .csv export cell.
bq_load_result = evaluation_result.export_to_bigquery(
    BQ_PROJECT_ID, BQ_DATASET_ID, BQ_TABLE_NAME, credentials
)

# `export_to_bigquery` prints the error and returns None when the load fails,
# so report the outcome explicitly instead of ignoring it.
if bq_load_result is None:
  print("Export to BigQuery failed, see the error above.")
else:
  print(
      "Exported results to table: "
      f"{BQ_PROJECT_ID}.{BQ_DATASET_ID}.{BQ_TABLE_NAME}"
  )

## Result Visualization

In [None]:
# @markdown `Folder url`
FOLDER_URLS = [
    folder_url, # latest evaluation
    # add previous evaluations e.g: https://drive.google.com/drive/folders/<id>
]

In [None]:
# @markdown `define evaluation visualizer`

evaluation_visualizer = EvaluationVisualizer([
    EvaluationResult.load(folder_url, credentials)
    for folder_url in FOLDER_URLS
])

In [None]:
# @markdown `radar plot of autoeval metrics`

evaluation_visualizer.radar_plot(StatementBasedBundledMetric.COLUMNS)

In [None]:
# @markdown `response type distribution`

evaluation_visualizer.count_barplot("response_type")

In [None]:
# @markdown `average RougeL`

evaluation_visualizer.mean_barplot(column_names=RougeL.COLUMNS)