In [None]:
! pip install -q transformers openai sentence-transformers datasets langchain==0.0.310

In [None]:
! pip install -q -U google-api-python-client

In [None]:
import os

from langchain.retrievers.you import YouRetriever
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI


# Credentials: a key that is already exported in the environment is kept as is.
# The placeholder strings are only a fallback and must be replaced with real
# keys (or the variables exported beforehand) for the API calls to succeed.
os.environ.setdefault("YDC_API_KEY", "YOUR YOU.COM API KEY")
os.environ.setdefault("OPENAI_API_KEY", "YOUR OPENAI API KEY")

# You.com retriever feeding a map-reduce RetrievalQA chain.
yr = YouRetriever()
model = "gpt-3.5-turbo-16k"  # also used when the Google-backed chain is built
qa = RetrievalQA.from_chain_type(llm=ChatOpenAI(model=model), chain_type="map_reduce", retriever=yr)

In [None]:
from datasets import load_dataset


# "train" split of the "fullwiki" configuration of the hotpot_qa dataset.
ds = load_dataset(path="hotpot_qa", name="fullwiki")["train"]

In [None]:
from langchain.utilities import GoogleSearchAPIWrapper


# Credentials: values already exported in the environment are kept as is.
# The placeholder strings are only a fallback and must be replaced with a real
# CSE id / API key (or the variables exported beforehand) for searches to work.
os.environ.setdefault("GOOGLE_CSE_ID", "Your Google CSE ID")
os.environ.setdefault("GOOGLE_API_KEY", "Your Google API Key")
search = GoogleSearchAPIWrapper()

def top10_results(query):
    """Ask the module-level ``search`` wrapper for 10 results for ``query``.

    Args:
        query: Search string forwarded unchanged to ``search.results``.

    Returns:
        Whatever ``search.results`` returns; the retriever in this notebook
        iterates over it and reads each item's ``"snippet"`` entry.
    """
    requested_results = 10
    return search.results(query, requested_results)

In [None]:
from langchain.schema.retriever import BaseRetriever, Document
from typing import TYPE_CHECKING, Any, Dict, List, Optional 
from langchain.callbacks.manager import CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun


class GoogleRetriever(BaseRetriever):
    """Retriever that serves Google search snippets as documents.

    The retriever holds no state of its own, so it relies on the constructor
    inherited from ``BaseRetriever`` and defines no ``__init__``.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Turn the Google results for ``query`` into documents.

        Args:
            query: Search string passed to ``top10_results``.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            One ``Document`` per search result, with the result's ``"snippet"``
            as ``page_content``. A result without a snippet yields a document
            with empty content.
        """
        return [
            Document(page_content=result.get("snippet", ""))
            for result in top10_results(query)
        ]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("GoogleRetriever does not support async retrieval")

In [None]:
# Same map-reduce QA chain as the You.com one, but backed by Google snippets.
goog_qa = RetrievalQA.from_chain_type(
    retriever=GoogleRetriever(),
    chain_type="map_reduce",
    llm=ChatOpenAI(model=model),
)

In [None]:
SAMPLE_SIZE = 100
RANDOM_SEED = 42  # fixed so the same questions are drawn on every run
pds = ds.to_pandas()
# reset_index() keeps the original row labels in an "index" column.
pds_sample = pds.sample(SAMPLE_SIZE, random_state=RANDOM_SEED).reset_index()

In [None]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm


def parallel_progress_apply(column, callback, num_workers):
    """Apply ``callback`` to every item of ``column`` on a thread pool.

    Progress is reported through ``tqdm`` while results are collected.

    Args:
        column: Sized iterable of inputs (``len(column)`` sets the bar total).
        callback: Function called once per item.
        num_workers: Maximum number of worker threads.

    Returns:
        A list with the callback results, in the order of ``column``.
    """
    pool = ThreadPoolExecutor(max_workers=num_workers)
    with pool:
        pending_results = pool.map(callback, column)
        progress_bar = tqdm(pending_results, total=len(column))
        return list(progress_bar)

In [None]:
def get_run_chain_function(chain):
    """Wrap ``chain`` in a best-effort callable that returns its answer text.

    Args:
        chain: Callable that takes one example and returns a mapping with a
            ``"result"`` entry.

    Returns:
        A function ``run_chain(example)`` that returns ``chain(example)["result"]``,
        or ``""`` when the call fails with an ordinary exception.
    """
    import logging

    logger = logging.getLogger(__name__)

    def run_chain(example):
        try:
            return chain(example)["result"]
        except Exception as exc:
            # Best effort: one failed example must not abort the whole run, but
            # the failure is logged because it is scored as an empty prediction.
            # KeyboardInterrupt / SystemExit are deliberately not caught, so the
            # notebook can still be interrupted.
            logger.warning("Chain call failed for %r: %r", example, exc)
            return ""

    return run_chain

In [None]:
# Answer every sampled question with the You.com-backed chain, 8 threads at a time.
pds_sample["ydc_prediction"] = parallel_progress_apply(
    pds_sample["question"],
    get_run_chain_function(qa),
    num_workers=8,
)

In [None]:
# Run sequentially on purpose: per the original author's note, the Google API
# is too slow for parallel calls to be usable here.
run_google_chain = get_run_chain_function(goog_qa)
pds_sample["google_prediction"] = pds_sample["question"].apply(run_google_chain)

In [None]:
import re
import string
from collections import Counter


# Adapted from the HotpotQA evaluation script, with minor modifications so that
# only the F1 score is returned instead of the full precision/recall/F1 tuple:
# https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py#L26
def calculate_f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted answer and the reference answer.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace; overlap is counted with multiplicity.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        The F1 score as a float, or ``0`` when the answers share no token or
        when a "yes" / "no" / "noanswer" answer is not matched exactly.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # For "yes" / "no" / "noanswer" only an exact match counts: if either side
    # is one of them and the two differ, the score is zero.
    special_answers = ("yes", "no", "noanswer")
    if normalized_prediction != normalized_ground_truth and (
        normalized_prediction in special_answers
        or normalized_ground_truth in special_answers
    ):
        return 0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    # Multiset intersection: a token counts as often as it appears on both sides.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also covers an empty prediction or reference, avoiding division by zero.
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def normalize_answer(s):
    """Normalize an answer string for comparison.

    Steps, in order: lowercase, delete ASCII punctuation, replace the standalone
    words "a", "an" and "the" with a space, then collapse all whitespace runs
    to single spaces and trim the ends.

    Args:
        s: Answer text.

    Returns:
        The normalized string.
    """
    lowered = s.lower()
    punctuation_free = lowered.translate(str.maketrans("", "", string.punctuation))
    article_free = re.sub(r"\b(a|an|the)\b", " ", punctuation_free)
    return " ".join(article_free.split())


def exact_match_score(prediction, ground_truth):
    """Tell whether two answers are identical after ``normalize_answer``.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``True`` when the normalized strings are equal, ``False`` otherwise.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_ground_truth


def filter_wiki_citation(snip):
    """Filter predicate: keep snippets that do not start with ``"- ^"``.

    Args:
        snip: Snippet text.

    Returns:
        ``False`` when ``snip`` begins with ``"- ^"`` (presumably a Wikipedia
        citation line, going by the function name), ``True`` otherwise.
    """
    looks_like_citation = snip.startswith("- ^")
    return not looks_like_citation

In [None]:
# Score each engine's predictions against the reference answers, row by row.
# iterrows() yields (index, row) pairs, hence the [1] to reach the row.
sample_rows = list(pds_sample.iterrows())
pds_sample["ydc_f1"] = parallel_progress_apply(
    sample_rows,
    lambda pair: calculate_f1_score(pair[1]["ydc_prediction"], pair[1]["answer"]),
    num_workers=8,
)
pds_sample["google_f1"] = parallel_progress_apply(
    sample_rows,
    lambda pair: calculate_f1_score(pair[1]["google_prediction"], pair[1]["answer"]),
    num_workers=8,
)

In [None]:
# Mean F1 over the sample for each engine: label on one line, value on the next.
print("YDC F1", pds_sample["ydc_f1"].mean(), sep="\n")
print("Google F1", pds_sample["google_f1"].mean(), sep="\n")