In [None]:
from ssec_tutorials import OLMO_MODEL

In [None]:
from pathlib import Path
from qdrant_client import QdrantClient
from uuid import uuid4

In [None]:
import panel as pn
from langchain.llms import LlamaCpp
from langchain.schema.runnable import RunnablePassthrough
from langchain_core.callbacks import CallbackManager
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Qdrant
from langchain.embeddings import HuggingFaceEmbeddings

# Repository root, resolved relative to the current working directory.
# NOTE(review): assumes the kernel / `panel serve` is started two directory
# levels below the repo root — TODO confirm or make this configurable.
repo_root = Path("../..").resolve()

# Local on-disk Qdrant database and the collection of arXiv astro-ph abstracts.
qdrant_path = repo_root / "scipy_qdrant"
qdrant_collection = "arxiv_astro-ph_abstracts"

# Embedding model used to embed queries for similarity search.
# NOTE(review): presumably the same model that was used to build the
# collection — verify, since mismatched embeddings make search meaningless.
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")

pn.extension()

# Local model file handed to LlamaCpp in get_chain.
model_path = OLMO_MODEL

In [None]:
@pn.cache
def get_vector_store():
    """Open the existing on-disk Qdrant collection as a LangChain vector store.

    Wrapped in ``pn.cache`` so the client is created on the first call and the
    same store is returned on later calls instead of reopening the database.

    Returns:
        Qdrant: a LangChain ``Qdrant`` store over ``qdrant_collection`` that
        embeds queries with ``embedding``.

    Raises:
        ValueError: if ``qdrant_collection`` is not present in the database at
            ``qdrant_path``. This notebook only reads the collection; it must
            be built beforehand.
    """
    client = QdrantClient(path=str(qdrant_path))
    # Fail fast with a clear message rather than letting a missing collection
    # surface later as an error during the first search in a chat turn.
    existing = {collection.name for collection in client.get_collections().collections}
    if qdrant_collection not in existing:
        raise ValueError(
            f"Qdrant collection {qdrant_collection!r} not found in {qdrant_path}; "
            "build the vector database before running this notebook."
        )
    return Qdrant(
        client=client,
        collection_name=qdrant_collection,
        embeddings=embedding,
    )

In [None]:
# Shared vector store used by get_chain; get_vector_store is cached with
# pn.cache, so re-running this cell reuses the same store.
db = get_vector_store()

In [None]:
def get_chain(callbacks):
    """Build the retrieval-augmented generation chain for one chat turn.

    The returned runnable takes the user's question and:
    retrieves the 3 most similar documents from ``db``, joins their text into
    a context block, wraps context + question as a single user chat message,
    renders that with the chat template stored in the model's metadata, and
    runs it through a local ``LlamaCpp`` model.

    Args:
        callbacks: list of LangChain callback handlers. They are attached to
            the retriever, to the LLM (through a ``CallbackManager``, which is
            how tokens are streamed), and are also called directly by ``hack``.

    Returns:
        A LangChain runnable; invoking it yields the model's completion text.
    """
    # RunnableParallel turns the {"context", "question"} mapping below into a
    # runnable. A bare dict cannot be piped into a plain function:
    # `{...} | create_format` raises TypeError.
    from langchain_core.runnables import RunnableParallel

    retriever = db.as_retriever(callbacks=callbacks, search_type="similarity", search_kwargs={"k": 3})
    # Callbacks support token-wise streaming
    callback_manager = CallbackManager(callbacks)
    # NOTE(review): this constructs (and loads) the model on every call, i.e.
    # once per chat message — consider caching the model and passing the
    # per-message callbacks through the invoke config instead.
    model = LlamaCpp(
        model_path=str(model_path),
        callback_manager=callback_manager,
        temperature=0.8,
        n_ctx=2048,
        max_tokens=512,
        verbose=False,
        echo=False
    )
    # Render prompts with the model's own Jinja2 chat template (read from the
    # model metadata) so the role markers match what the model was tuned on.
    # NOTE(review): if that template references variables other than
    # `messages` / `add_generation_prompt` (e.g. `eos_token`), create_format
    # must supply them as well or the prompt will reject its input — TODO confirm.
    prompt = PromptTemplate.from_template(
        template=model.client.metadata['tokenizer.chat_template'],
        template_format="jinja2"
    )

    def create_format(input_dict):
        """Map ``{"context", "question"}`` to the chat-template variables."""
        context = input_dict.get('context')
        question = input_dict.get('question')
        # Built from explicit lines so source-code indentation and a stray
        # leading newline do not leak into the prompt text.
        content = (
            "You are an astrophysics expert. "
            "Answer the question based only on the following context:\n"
            "\n"
            f"{context}\n"
            "\n"
            f"Question: {question}"
        )
        return dict(
            add_generation_prompt=True,
            messages=[{"role": "user", "content": content}],
        )

    def format_docs(docs):
        """Join the retrieved documents' text, separated by blank lines."""
        return "\n\n".join(d.page_content for d in docs)

    def hack(docs):
        """Forward the retrieved docs to every callback's ``on_retriever_end``.

        Workaround for https://github.com/langchain-ai/langchain/issues/7290.
        NOTE(review): the retriever above is also given ``callbacks``; if it
        already fires ``on_retriever_end`` itself, handlers may receive the
        documents twice — confirm against the installed LangChain version.
        """
        for handler in callbacks:
            handler.on_retriever_end(docs, run_id=uuid4())
        return docs

    return (
        RunnableParallel({
            "context": retriever | hack | format_docs,
            "question": RunnablePassthrough(),
        })
        | create_format
        | prompt
        | model
    )

In [None]:
async def callback(contents, user, instance):
    """Answer one chat message with the RAG chain, streaming into the chat.

    Args:
        contents: the user's message; passed to the chain as the question
            (presumably a str — confirm for non-text inputs).
        user: name of the sender; unused, but part of the ChatInterface
            callback signature.
        instance: the ChatInterface; the callback handler streams into it.

    Returns:
        None. The answer is delivered through the callback handler's token
        stream rather than through the return value.
    """
    callback_handler = pn.chat.langchain.PanelCallbackHandler(instance, user='OLMo', avatar='🌳')
    # Suppress the handler's end-of-generation output: the tokens have already
    # been streamed, so emitting the final result would repeat the answer.
    callback_handler.on_llm_end = lambda response, *args, **kwargs: None
    # NOTE(review): get_chain loads the model synchronously, which blocks the
    # event loop while it loads — consider caching or offloading to a thread.
    chain = get_chain(callbacks=[callback_handler])
    # The completion is streamed by callback_handler; its return value is unused.
    await chain.ainvoke(contents)

In [None]:
# Chat UI wired to the RAG callback. `.servable()` registers it for
# `panel serve` and, as the cell's last expression, also renders it inline.
chat_ui = pn.chat.ChatInterface(callback=callback)
chat_ui.servable()