In [None]:
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pprint import pprint
import uuid
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_community.vectorstores import Chroma
from langchain.retrievers.multi_vector import MultiVectorRetriever
from ragatouille import RAGPretrainedModel
import requests
from langchain import hub
from langchain_core.runnables import RunnablePassthrough

## Multi-representation Indexing

### Web文書取得

In [None]:
# Blog posts to index; each URL is loaded with its own WebBaseLoader.
source_urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2024-02-05-human-data-quality/",
]

# Fetch the pages in order and accumulate everything in `docs`.
docs = []
for url in source_urls:
    loader = WebBaseLoader(url)
    docs += loader.load()

# Report how many documents were loaded, then show them.
print('文書数:', len(docs))
pprint(docs)

### 複数文書に対する要約処理のバッチ実行

In [None]:
# Summarization pipeline: take the text of each document, insert it into the
# prompt, send it to the chat model, and parse the reply into a plain string.
chain = (
    {"doc": lambda x: x.page_content}
    # Prompt wording fixed ("th" -> "the"); this text is sent to the model as-is.
    | ChatPromptTemplate.from_template("Summarize the following document:\n\n{doc}")
    | ChatOpenAI(model="gpt-3.5-turbo", max_retries=0)
    | StrOutputParser()
)

# Summarize every document in one batch call, with at most 5 running concurrently.
summaries = chain.batch(docs, {"max_concurrency": 5})

In [None]:
# Show the first two summaries with a row of asterisks between them.
divider = '*' * 80
pprint(summaries[0])
pprint(divider)
pprint(summaries[1])

### データベースの作成

In [None]:
# Metadata key under which each summary records the id of its source document.
id_key = "doc_id"

# One random UUID per source document.
doc_ids = [str(uuid.uuid4()) for _ in docs]

# Vector store for the summary embeddings and a byte store for the originals.
vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())
store = InMemoryByteStore()

retriever = MultiVectorRetriever(vectorstore=vectorstore, byte_store=store, id_key=id_key)

# Summary documents, each tagged with the id of the document it summarizes.
summary_docs = [
    Document(page_content=summary, metadata={id_key: doc_id})
    for doc_id, summary in zip(doc_ids, summaries)
]

# Add the summary documents to the vector store.
retriever.vectorstore.add_documents(summary_docs)

# Store the original documents under the same ids that the summaries carry.
retriever.docstore.mset(list(zip(doc_ids, docs)))

### 要約文書のベクトル検索による「要約」文書の取得

In [None]:
query = "Memory in agents"

# Search the summary vector store directly and keep only the closest match.
# `retriever.vectorstore` was constructed from this same `vectorstore` object,
# so `retriever.vectorstore.similarity_search(query, k=1)` could be used instead.
sub_docs = vectorstore.similarity_search(query, k=1)
best_summary = sub_docs[0]
pprint(best_summary)

### 要約文書のベクトル検索による「元」文書の取得

In [None]:
# Use `invoke`, the current retriever entry point; `get_relevant_documents` is deprecated.
# The old `n_results=1` keyword is dropped. NOTE(review): the retriever appears to take
# its result count from `search_kwargs` (e.g. {"k": 1}) set at construction, not from a
# per-call keyword; confirm against the installed langchain version.
retrieved_docs = retriever.invoke(query)

# Print the first 500 characters of the top original document.
print(retrieved_docs[0].page_content[0:500])

要約文書についてはベクトル化して検索に利用し、元文書は実際に返す文書にする、といったように1つの文書を複数の形で利用する。

## RAPTOR(Recursive Abstractive Processing for Tree-Organized Retrieval)

与えられた複数文書に対して、文書をまとめてクラスタにして要約することを繰り返して様々な抽象度の文書を作成し、  
それらの埋め込みを作成することによって、様々な粒度の質問に対応して文書を検索できるようにする手法。  

## ColBERT

### ColBERTモデル取得

In [None]:
# Name of the pretrained ColBERT checkpoint to load.
colbert_checkpoint = "colbert-ir/colbertv2.0"
RAG = RAGPretrainedModel.from_pretrained(colbert_checkpoint)

### Wikipedia記事取得

In [None]:
def get_wikipedia_page(title: str, timeout: float = 10.0):
    """
    Retrieve the full plain-text content of a Wikipedia page.

    :param title: str - Title of the Wikipedia page.
    :param timeout: float - Seconds to wait for the HTTP request before giving up.
    :return: str or None - The page's ``extract`` text, or None when the
        response contains no page or the page has no extract.
    :raises requests.HTTPError: if the API answers with an HTTP error status.
    """
    # Wikipedia API endpoint
    api_url = "https://en.wikipedia.org/w/api.php"

    # Parameters for the API request
    params = {
        "action": "query",
        "format": "json",
        "titles": title,
        "prop": "extracts",
        "explaintext": True,
    }

    # Custom User-Agent header to comply with Wikipedia's best practices
    headers = {"User-Agent": "RAGatouille_tutorial/0.0.1 (ben@clavie.eu)"}

    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.get(api_url, params=params, headers=headers, timeout=timeout)
    # Surface HTTP failures directly instead of failing later while parsing.
    response.raise_for_status()
    data = response.json()

    # Take the first page in the response, tolerating responses that lack
    # the "query"/"pages" keys or contain no pages at all.
    pages = data.get("query", {}).get("pages", {})
    page = next(iter(pages.values()), None)
    if page is None:
        return None
    return page.get("extract")

# Download the article text that will be indexed.
full_document = get_wikipedia_page("Hayao_Miyazaki")

# get_wikipedia_page returns None when the page has no extract; fail here with
# a clear message rather than later, when the text is passed on for indexing.
if full_document is None:
    raise ValueError("Wikipedia returned no text for 'Hayao_Miyazaki'")

### Index作成

In [None]:
# Settings for building the index over the downloaded article.
index_settings = {
    "collection": [full_document],
    "index_name": "Miyazaki-123",
    "max_document_length": 180,
    "split_documents": True,
}

# Left as the cell's last expression so its return value is still displayed.
RAG.index(**index_settings)

### ragatouilleのRAGの検索確認

In [None]:
# Query the index directly through RAGatouille and keep the top 3 hits.
search_question = "What animation studio did Miyazaki found?"
results = RAG.search(query=search_question, k=3)
results

### Langchainリトリーバへの変換、確認

In [None]:
# Wrap the index as a LangChain retriever returning 3 results.
# Note: this rebinds `retriever`, which earlier held the MultiVectorRetriever.
retriever = RAG.as_langchain_retriever(k=3)

retrieved = retriever.invoke("What animation studio did Miyazaki found?")
print(retrieved)

### チェーン作成、回答確認

In [None]:
# Prompt pulled from the LangChain hub, and the chat model used to answer.
prompt = hub.pull("rlm/rag-prompt")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# The retriever supplies the context and the question is passed through
# unchanged; both feed the prompt, and the model reply is parsed to a string.
rag_inputs = {"context": retriever, "question": RunnablePassthrough()}
chain = rag_inputs | prompt | llm | StrOutputParser()

answer = chain.invoke("What animation studio did Miyazaki found?")
print(answer)