In [90]:
!pip install -qU langchain langchain-community langchain_chroma

In [91]:
import os
import warnings

# GigaChat authorization key, read from the environment so the secret never
# lives in the notebook file. Export GIGACHAT_CREDENTIALS before starting the
# kernel. NOTE(review): the key that used to be hardcoded here has been
# committed in plain text and should be treated as leaked (revoke / rotate it).
CREDENTIALS = os.environ.get("GIGACHAT_CREDENTIALS")
if not CREDENTIALS:
    warnings.warn(
        "GIGACHAT_CREDENTIALS is not set; calls to GigaChat will not be able to authenticate."
    )

# Relative path to the PDF used as the test document.
TESTPDF = "../data/papers/10.1002@solr.201900061.pdf"

In [92]:
import json
import os
import sys

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import GigaChatEmbeddings
from langchain.agents import AgentExecutor, ZeroShotAgent
from langchain.tools import Tool
from langchain_community.chat_models import GigaChat
from spotipy.oauth2 import SpotifyClientCredentials
from langchain.tools.retriever import create_retriever_tool

from langchain_chroma import Chroma
from PyPDF2 import PdfReader
# from langchain_community.document_loaders import TextLoader, PDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [93]:
from typing import Any
def extract_raw_text_from_pdf(path) -> str:
    """Return the text of every page of a PDF as one string.

    Args:
        path: value handed to ``PdfReader`` as its ``stream`` argument.

    Returns:
        The page texts joined with no separator, in page order. Pages whose
        extraction yields nothing (``None`` or an empty string) are skipped.
    """
    pdf = PdfReader(stream=path)
    page_texts = (page.extract_text() for page in pdf.pages)
    return ''.join(chunk for chunk in page_texts if chunk)

def get_index_from_pdf(pdf_path) -> Any:
    """Build a FAISS index over the text of one PDF.

    The text is pulled out with ``extract_raw_text_from_pdf``, split by a
    ``RecursiveCharacterTextSplitter`` (chunk_size=1000, chunk_overlap=200,
    lengths measured with ``len``), embedded with ``GigaChatEmbeddings`` using
    the module-level ``CREDENTIALS``, and stored via ``FAISS.from_texts``.

    Args:
        pdf_path: location of the PDF, forwarded to ``extract_raw_text_from_pdf``.

    Returns:
        The object returned by ``FAISS.from_texts``.
    """
    document_text = extract_raw_text_from_pdf(path=pdf_path)

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200, length_function=len
    )
    chunks = splitter.split_text(text=document_text)

    # NOTE(review): verify_ssl_certs=False is passed through unchanged; by its
    # name it turns off certificate verification — confirm this is intended.
    embedder = GigaChatEmbeddings(
        credentials=CREDENTIALS,
        verify_ssl_certs=False,
        scope='GIGACHAT_API_CORP',
    )
    return FAISS.from_texts(texts=chunks, embedding=embedder)

# Query that used to be hardcoded inside the function; kept as the fallback
# for callers that do not supply their own question.
_DEFAULT_QUERY = "For each of these {Spiro HTM, Spiro-CB, Spiro-THF} report efficiency (PCE or optimized efficiency or η). Put that data in the markdown table with columns 'HTM' - 'PCE'"


def invoke_chain_with_index(chain, index, query=None) -> dict:
    """Answer ``query`` with ``chain`` using documents retrieved from ``index``.

    Previously the ``query`` argument was overwritten by a fixed string, so the
    caller's question was silently ignored. The caller's query is now used; the
    fixed string only applies when no query (``None`` or empty) is given.

    Args:
        chain: object exposing ``invoke(dict)``.
        index: object exposing ``similarity_search(query)``.
        query: question text; falls back to ``_DEFAULT_QUERY`` when falsy.

    Returns:
        Whatever ``chain.invoke`` returns for the payload
        ``{"input_documents": <search results>, "question": <query>}``.
    """
    if not query:
        query = _DEFAULT_QUERY
    docs = index.similarity_search(query)
    return chain.invoke({"input_documents": docs, "question": query})

In [94]:
# Relative path to another sample paper PDF.
# NOTE(review): no cell visible in this notebook reads this module-level
# `path` — confirm whether it is still needed.
path = "../data/papers/2014/2014_1.pdf"

from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain

# Prompt template for answering a question strictly from supplied context.
# It has two placeholders: {context} and {input}.
# The opening delimiter must be exactly three quotes; a fourth one puts a
# literal '"' at the very start of the prompt text.
CONTEXT_BASED_Q_TEMPLATE = """Answer the following question based only on the provided context:

<context>
{context}
</context>

Question: {input}
"""

def get_doc_chain(llm, template, prompt=None):
    """Create a stuff-documents chain for ``llm``.

    Previously ``prompt`` was a required argument that was immediately
    overwritten. It is now optional and honoured when supplied.

    Args:
        llm: model passed to ``create_stuff_documents_chain``.
        template: template string turned into a prompt with
            ``ChatPromptTemplate.from_template`` when ``prompt`` is ``None``.
        prompt: ready-made prompt object; when given, ``template`` is not used.

    Returns:
        The result of ``create_stuff_documents_chain(llm, prompt)``.
    """
    if prompt is None:
        prompt = ChatPromptTemplate.from_template(template)
    return create_stuff_documents_chain(llm, prompt)
def get_retriever(index, doc_chain=None):
    """Return ``index.as_retriever()``.

    Args:
        index: object exposing ``as_retriever()``.
        doc_chain: unused; kept (now optional) so existing two-argument calls
            keep working.

    Returns:
        The retriever produced by ``index.as_retriever()``.
    """
    return index.as_retriever()
def get_retrieval_chain(index, doc_chain):
    """Combine a retriever from ``index`` with ``doc_chain``.

    Args:
        index: object exposing ``as_retriever()``.
        doc_chain: documents chain passed as the second argument to
            ``create_retrieval_chain``.

    Returns:
        The result of ``create_retrieval_chain(index.as_retriever(), doc_chain)``.
    """
    return create_retrieval_chain(index.as_retriever(), doc_chain)

In [95]:
# create_retriever_tool needs a retriever object. The previous code passed the
# factory function `get_retriever` itself, and no index was ever built, so the
# tool had nothing it could query. Build the index first, then wrap its retriever.
# NOTE(review): TESTPDF is assumed to be the document to query — confirm
# (a second path, `path`, is also defined in this notebook).
paper_index = get_index_from_pdf(pdf_path=TESTPDF)

# The tool name and description are what the model sees when deciding whether
# to call the tool, so they describe the indexed paper rather than the
# unrelated "state of the union" example they were copied from.
retriever_tool = create_retriever_tool(
    paper_index.as_retriever(),
    "paper_retriever",
    "Query a retriever to get information from the indexed scientific paper (PDF)",
)

In [96]:
from typing import List

from langchain_core.pydantic_v1 import BaseModel, Field

# Structured final-answer schema. It is passed to `giga.bind_functions`
# together with the retriever tool, and `parse` ends the agent run when a
# function call named "Response" comes back.
# NOTE(review): the class docstring and the Field descriptions are presumably
# forwarded to the model as part of the function schema — confirm before
# rewording any of them.
class Response(BaseModel):
    """Final response to the question being asked"""

    # Free-text answer that is returned to the user.
    answer: str = Field(description="The final answer to respond to the user")
    # Integer identifiers of the page chunks that support the answer.
    sources: List[int] = Field(
        description="List of page chunks that contain answer to the question. Only include a page chunk if it contains relevant information"
    )

In [97]:
import json
from langchain_core.agents import AgentActionMessageLog, AgentFinish

In [98]:
def parse(output):
    """Turn a chat-model message into the agent's next step.

    Args:
        output: message object exposing ``content`` and ``additional_kwargs``.

    Returns:
        ``AgentFinish`` with ``{"output": output.content}`` when the message
        carries no function call; ``AgentFinish`` with the call's arguments when
        the function named ``"Response"`` was called; otherwise an
        ``AgentActionMessageLog`` requesting the named tool.

    Raises:
        json.JSONDecodeError: if the arguments are a non-empty string that is
            not valid JSON.
    """
    function_call = output.additional_kwargs.get("function_call")

    # No function was invoked (key missing or explicitly None): answer the user.
    if function_call is None:
        return AgentFinish(return_values={"output": output.content}, log=output.content)

    name = function_call["name"]

    # Arguments may arrive as a JSON string or, with some providers, already
    # decoded into a mapping; an empty/missing payload means "no arguments".
    raw_args = function_call.get("arguments")
    if isinstance(raw_args, (str, bytes, bytearray)):
        inputs = json.loads(raw_args) if raw_args.strip() else {}
    elif raw_args is None:
        inputs = {}
    else:
        inputs = raw_args

    # The Response function carries the final answer back to the user.
    if name == "Response":
        return AgentFinish(return_values=inputs, log=str(function_call))

    # Any other function name is a tool request.
    return AgentActionMessageLog(
        tool=name, tool_input=inputs, log="", message_log=[output]
    )

In [99]:
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chat_models.gigachat import GigaChat

In [100]:
# Agent prompt: a fixed system message, the user's "{input}", then a
# placeholder named "agent_scratchpad" for the messages of intermediate steps.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant"),
    ("user", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Chat model used by the agent, authenticated with the module-level CREDENTIALS.
# NOTE(review): verify_ssl_certs=False is kept as-is; by its name it turns off
# certificate verification — confirm this is intended.
giga = GigaChat(
    model="GigaChat-Pro",
    scope='GIGACHAT_API_CORP',
    credentials=CREDENTIALS,
    verify_ssl_certs=False,
)

In [101]:
# Bind two functions to the model: the retriever tool and the `Response`
# schema (the function name that `parse` treats as the final answer).
llm_with_tools = giga.bind_functions([retriever_tool, Response])

In [102]:
# Agent pipeline: map the executor's state onto the prompt variables, render
# the prompt, call the function-bound model, then convert its message into an
# agent step with `parse`.
agent = (
    {
        # The user's question is passed through unchanged.
        "input": lambda state: state["input"],
        # Intermediate steps are rendered as function messages for the scratchpad.
        "agent_scratchpad": lambda state: format_to_openai_function_messages(
            state["intermediate_steps"]
        ),
    }
    | prompt | llm_with_tools | parse
)

In [103]:
# Executor that runs `agent` with the retriever tool available; verbose=True
# is set so the run is logged.
agent_executor = AgentExecutor(tools=[retriever_tool], agent=agent, verbose=True)


In [None]:
# Ask the agent a single question about the paper.
# FIXME: this call currently fails with error 422 when executed; the cause is
# not determined in this notebook.
agent_executor.invoke(
    {"input": "Did the device reach the PCE over 35%?"},
    return_only_outputs=True,
)
