In [52]:
!pip install mlx-lm
!pip install -q langchain-cohere  langchain pypdf faiss-cpu python-dotenv

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


⚠️ The LangChain + MLX integration is still under review, so the next cell installs it from the pull-request branch.

You can check the progress here: https://github.com/langchain-ai/langchain/pull/18152

In [17]:
!pip install -q git+https://github.com/Blaizzy/langchain.git@pc/mlx#subdirectory=libs/community --use-pep517

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [1]:
from operator import itemgetter

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.output_parsers import XMLOutputParser
from langchain_core.output_parsers import StrOutputParser


from langchain.schema import (
    HumanMessage,
    SystemMessage,
    AIMessage
)

# Router

In [2]:
# Model used for routing: a 4-bit Mistral instruct checkpoint, capped at
# 200 generated tokens and sampled at temperature 0.1.
router_llm = MLXPipeline.from_model_id(
    "mlx-community/Mistral-7B-Instruct-v0.2-4bit",
    pipeline_kwargs=dict(max_tokens=200, temp=0.1),
)

def _function_tool(name, description, user_query_schema):
    """Build one function-style tool spec taking a single required `user_query`."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {"user_query": user_query_schema},
                "required": ["user_query"],
            },
        },
    }


# Tool specs shown to the router model; both tools take the user's query as
# their only argument.
tools = [
    _function_tool(
        "retriever",
        "Useful for retrieving factual documents and information from a database or API about a user's request to answer their queries.",
        {
            "type": "string",
            "description": "The query to retrieve and ground information for.",
        },
    ),
    _function_tool(
        "direct_response",
        "Useful for providing a direct response without retrieving additional information.",
        {"type": "string"},
    ),
]


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

In [15]:
# Router prompt. `{tools}` is filled with the tool specs and `{question}` with
# the user's message. The model is instructed to reply with a single
# <tool_call> XML snippet, which the router chain parses with XMLOutputParser.
template = """ You are Mistral with function-calling supported. You are provided with function signatures within <tools></tools> XML tags.
    You may call one function to assist with the user query. Don't make assumptions about what values to plug into functions.
    Here are the available tools:
    <tools>
    {tools}
    </tools>

    For each function call, return a XML object with the function name and arguments within <tool_call></tool_call> XML tags as follows:
    <tool_call>
        <function_name></function_name>
        <arguments></arguments>
    </tool_call>

    If the user question requires factual answers, use the retriever tool only.

    <tool_call>
        <function_name>retriever</function_name>
        <arguments>
        </arguments>
    </tool_call>

    Please respond using XML ONLY.


    Question: {question}
"""


# Turn the template into a chat prompt and wrap the router pipeline in the
# ChatMLX chat-model interface.
prompt = ChatPromptTemplate.from_template(template)
router = ChatMLX(llm=router_llm)

In [11]:
# Routing pipeline: fill the prompt, query the router model, then parse the
# model's XML reply into nested dicts/lists.
#
# Each prompt variable is pulled out of the input dict explicitly. A bare
# RunnablePassthrough() for "question" would forward the *whole* input dict
# (tool specs included) into the {question} slot of the prompt.
router_chain = (
    {"tools": itemgetter("tools"), "question": itemgetter("question")}
    | prompt
    | router
    | XMLOutputParser()
)

In [5]:
# Factual question: in the recorded output the router picks the `retriever` tool.
router_output = router_chain.invoke({"tools": tools, "question": "What's the weather like in Maputo?"})

In [6]:
router_output

{'response': [{'tool_call': [{'function_name': 'retriever'},
    {'arguments': [{'user_query': "What's the weather like in Maputo?"}]}]}]}

In [None]:
# Small talk: in the recorded output the router picks the `direct_response` tool.
router_output = router_chain.invoke({"tools": tools, "question": "Hi, how are you??"})

In [13]:
router_output

{'Root': [{'tool_call': [{'function_name': 'direct_response'},
    {'arguments': [{'user_query': 'Hi, how are you?'}]}]}]}

In [54]:
# Question about Gemma: in the recorded output the router picks `retriever`.
# In top-to-bottom order this is the last value assigned to `router_output`.
router_output = router_chain.invoke({"tools": tools, "question": "Explain how Gemma model works?"})

In [55]:
router_output

{'response': [{'tool_call': [{'function_name': 'retriever'},
    {'arguments': [{'user_query': 'Explain how Gemma model works'}]}]}]}

# Retriever

In [2]:
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_cohere import CohereEmbeddings
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env.

def load_file(file_name, file_type, chunk_size=5000, chunk_overlap=20):
    """Load ``./assets/<file_name>.<file_type>`` and split it into chunks.

    The file is always read with ``PyPDFLoader``; ``file_type`` only supplies
    the file extension and does not select a different loader.

    Args:
        file_name: File name without extension, relative to ``./assets``.
        file_type: File extension without the leading dot (e.g. ``'pdf'``).
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Overlap between consecutive chunks, same unit.

    Returns:
        The result of ``loader.load_and_split`` using the configured splitter.
    """
    loader = PyPDFLoader(f"./assets/{file_name}.{file_type}")

    text_splitter = RecursiveCharacterTextSplitter(
        # Sizes are configurable; the defaults keep the notebook's original
        # 5000-character chunks with a 20-character overlap.
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    return loader.load_and_split(text_splitter)


# Chunk ./assets/gemma-report.pdf, embed the chunks with Cohere, index them in
# FAISS, and expose the index through the retriever interface.
# NOTE(review): CohereEmbeddings presumably reads its API key from the
# environment populated by load_dotenv() — confirm.
documents = load_file(file_name='gemma-report', file_type='pdf')
cohere_embeddings = CohereEmbeddings(model="embed-english-light-v3.0")
vectorstore = FAISS.from_documents(documents, embedding=cohere_embeddings)
retriever = vectorstore.as_retriever()

In [3]:
# Smoke-test the retriever. `invoke` is the Runnable entry point that replaces
# the deprecated `get_relevant_documents`.
retriever.invoke("Gemma")

[Document(page_content='Gemma: Open Models Based on Gemini Research and Technology\ndevelopment ecosystem will enable downstream\ndevelopers to create a host of beneficial appli-\ncations, in areas such as science, education and\nthe arts. Our instruction-tuned offerings should\nencourage a range of developers to leverage\nGemma’s chat and code capabilities to support\ntheir own beneficial applications, while allowing\nforcustomfine-tuningtospecializethemodel’sca-\npabilities for specific use cases. To ensure Gemma\nsupports a wide range of developer needs, we are\nalso releasing two model sizes to optimally sup-\nportdifferentenvironments,andhavemadethese\nmodels available across a number of platforms\n(seeKagglefordetails). Providingbroadaccessto\nGemma in this way should reduce the economic\nand technical barriers that newer ventures or in-\ndependent developers face when incorporating\nthese technologies into their workstreams.\nAs well as serving developers with our\ninstruction-t

# Tool calling + RAG

In [None]:
# Answering model: a 4-bit Gemma 1.1 2B instruct checkpoint, capped at 100
# generated tokens and sampled at temperature 0.1, wrapped as a chat model.
gemma_llm = MLXPipeline.from_model_id(
    "mlx-community/gemma-1.1-2b-it-4bit",
    pipeline_kwargs=dict(max_tokens=100, temp=0.1),
)
gemma_chat_model = ChatMLX(llm=gemma_llm)

In [5]:
def format_docs(docs):
    """Render documents as newline-separated ``<doc id='N'>…</doc>`` elements.

    Each document's ``page_content`` is wrapped in a ``<doc>`` tag whose ``id``
    is its zero-based position in ``docs``. An empty input yields ``""``.
    """
    return "\n".join(
        f"<doc id='{position}'>{document.page_content}</doc>"
        for position, document in enumerate(docs)
    )


def create_retriever_chain():
    """Assemble the retrieval-augmented answering chain.

    The chain's input is handed both to the retriever (whose documents are
    rendered by ``format_docs`` into the ``{context}`` slot) and, unchanged,
    to the ``{question}`` slot. The filled prompt goes to the Gemma chat model
    and the reply is passed through ``StrOutputParser``.

    Returns:
        The composed runnable.
    """
    rag_prompt = ChatPromptTemplate.from_template(
        """Answer the question based only on the following context:
            {context}

            Question: {question}
        """
    )

    # Both branches receive the same chain input.
    prompt_inputs = {
        "context": retriever | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }

    return prompt_inputs | rag_prompt | gemma_chat_model | StrOutputParser()


def _extract_user_query(arguments):
    """Return the user's query text from a tool call's arguments.

    Accepted shapes:
      * ``{"user_query": "text"}`` — matches the tool schema in ``tools``;
      * ``[{"user_query": "text"}]`` — the list of single-key dicts that
        XMLOutputParser produces for child elements;
      * ``{"user_query": {"question": "text"}}`` — the nested shape this
        function originally indexed, kept for backward compatibility.

    Raises:
        KeyError: if no ``user_query`` (or nested ``question``) is present.
    """
    if isinstance(arguments, (list, tuple)):
        merged = {}
        for item in arguments:
            if isinstance(item, dict):
                merged.update(item)
        arguments = merged
    elif not isinstance(arguments, dict):
        # e.g. None for an empty <arguments/> element: treat as "no arguments".
        arguments = {}

    user_query = arguments['user_query']
    if isinstance(user_query, dict):
        user_query = user_query['question']
    return user_query


def tool_call(function_name, arguments):
    """Execute the tool selected by the router and return its text answer.

    Args:
        function_name: Tool name chosen by the router. ``'retriever'`` answers
            via the RAG chain, ``'get_weather'`` returns a canned reply, and
            any other name falls back to a direct answer from the Gemma model.
        arguments: Tool arguments carrying ``user_query`` (see
            ``_extract_user_query`` for the accepted shapes).

    Returns:
        The answer as a string.

    Raises:
        KeyError: if a tool that needs the query is called without one.
    """
    if function_name == 'get_weather':
        return "The weather is good!"

    question = _extract_user_query(arguments)

    if function_name == 'retriever':
        # The chain forwards its input straight to the retriever, so it must
        # receive the query string itself, not a dict. It ends in
        # StrOutputParser, so the result is already a plain string.
        retriever_chain = create_retriever_chain()
        return retriever_chain.invoke(question)

    # Fallback: answer directly with the chat model; it returns a message
    # object, so unwrap its text.
    messages = [HumanMessage(content=question)]
    output = gemma_chat_model.invoke(messages)
    return output.content

In [None]:
def _merge_xml_children(children):
    """Collapse XMLOutputParser's list of single-key dicts into one dict."""
    merged = {}
    if isinstance(children, (list, tuple)):
        for child in children:
            if isinstance(child, dict):
                merged.update(child)
    return merged


# XMLOutputParser nests the reply under whatever root tag the model emitted
# ('response' and 'Root' both appear in the outputs above) and represents child
# elements as a list of single-key dicts. Accept a bare <tool_call> root too,
# since that is what the prompt asks the model for.
if "tool_call" in router_output:
    tool_call_children = router_output["tool_call"]
else:
    (root_children,) = router_output.values()
    tool_call_children = _merge_xml_children(root_children)["tool_call"]

tool_call_fields = _merge_xml_children(tool_call_children)
function_name = tool_call_fields["function_name"]
# Flatten <arguments> into a plain {"user_query": ...} dict, matching the tool schema.
args = _merge_xml_children(tool_call_fields["arguments"])
tool_call_out = tool_call(function_name, args)