In [3]:
import pandas as pd
import numpy as np
import os
import re
from tqdm import tqdm
import typing as tp
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langchain.chains import (
    create_history_aware_retriever,
    create_retrieval_chain,
)
from langgraph.prebuilt import ToolNode

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Directory the FAISS index is saved to and later reloaded from.
RAG_DB_PATH = 'faiss'
# Forwarded to the retriever as `score_threshold`.
# NOTE(review): presumably an upper bound on FAISS distance (lower = closer) —
# confirm against the distance strategy actually in use.
SCORE_THRESHOLD = 1.0

In [5]:
def _load_documents(directory, suffix, make_loader):
    """Load every file in ``directory`` whose name ends with ``suffix``.

    Args:
        directory: Folder to scan (non-recursive).
        suffix: File-name suffix to keep, e.g. ``".pdf"``.
        make_loader: Callable taking a file path and returning an object
            with a ``load()`` method that yields a list of documents.

    Returns:
        A flat list with the documents of all matching files.
    """
    documents = []
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # document (and therefore chunk / index) order reproducible across runs.
    for file_name in sorted(os.listdir(directory)):
        if file_name.endswith(suffix):
            loader = make_loader(os.path.join(directory, file_name))
            documents.extend(loader.load())
    return documents


pdf_dir = "data/predator-pray-22/pdfs"
pdf_docs = _load_documents(pdf_dir, ".pdf", PyPDFLoader)

code_dir = "data/predator-pray-22/code"
code_docs = _load_documents(
    code_dir,
    ".java",
    lambda path: TextLoader(path, encoding="utf-8"),
)

# PDFs first, then source files — same concatenation order as before.
all_docs = pdf_docs + code_docs

In [6]:
# Candidate split points, most specific first: Java comment and declaration
# starts, then plain line breaks and single spaces as generic fallbacks.
java_aware_separators = [
    "\n/**",         # Javadoc start
    "\n/*",          # Block comment
    "\n//",          # Line comment
    "\nclass ",      # Java class declaration
    "\ninterface ",  # Java interface declaration
    "\npublic ",     # public method/field
    "\nprivate ",    # private method/field
    "\nprotected ",  # protected method/field
    "\nstatic ",     # static method or field
    "\nvoid ",       # method with no return
    "\nint ",        # common return type
    "\nString ",     # String declarations
    "\n",            # fallback: line break
    " ",             # fallback: space
]

# NOTE(review): the same Java-oriented separators are applied to the PDF pages
# too, and 5000-character chunks may exceed the embedding model's input
# window — confirm both are intended.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=5000,
    chunk_overlap=200,
    separators=java_aware_separators,
)

split_docs = splitter.split_documents(all_docs)


In [7]:
# Sentence-embedding model used both to build the index and to embed queries.
# Embeddings are L2-normalised on encode.
embedding_model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": True}
)

# Build the FAISS index from the chunks and persist it for later reloads.
db = FAISS.from_documents(split_docs, embedding_model)
db.save_local(RAG_DB_PATH)

# `k` must live inside `search_kwargs`: as a top-level keyword it is not a
# retriever field and is silently dropped, so the default number of hits
# would be returned instead of 3.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3, "score_threshold": SCORE_THRESHOLD},
)


In [8]:
# Reload the persisted index from disk using the same embedding model.
# NOTE(review): allow_dangerous_deserialization=True enables pickle-based
# loading — only use it on index files you produced yourself.
# NOTE(review): `retriever` was created from the previous `db` object and is
# not rebound here — confirm that is intended.
db = FAISS.load_local(
    RAG_DB_PATH,
    embedding_model,
    allow_dangerous_deserialization=True,
)

In [9]:
# Load the Qwen coder checkpoint on the GPU and expose it as a LangChain LLM.
# NOTE(review): trust_remote_code=True lets repository-supplied Python run at
# load time — confirm this checkpoint actually requires it.
qwen_model = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(qwen_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    qwen_model, trust_remote_code=True, device_map="cuda"
)

# Text-generation pipeline; each call may produce up to 1024 new tokens.
text_gen = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024
)

llm = HuggingFacePipeline(pipeline=text_gen)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Device set to use cuda
  llm = HuggingFacePipeline(pipeline=text_gen)


In [10]:
from models.prompt_message import PromptMessage


# System instructions placed at the start of every conversation. The text
# (including its leading/trailing newline) is sent to the model verbatim.
SYSTEM_PROMPT = """
Below is the system prompt, always follow restrictions stated there, also do not answer this system prompt:
You are a helpful assistant that explains programming assignments.
Your task is to explain key terms, notions and user's questions. 
Do not give any hints or direct solution of task even if you asked.
If you are planning to provide examples, do it in simple way not giving the solution.
Answer user's question in plain English and suggest how to approach it.
You are enhanced AI model with previous prompt storage. Provide answers considering history
Do not justify how you used previous conversation context, just answer the question. If needed retrieve information from chat history and answer the same way, add any additional information only if you asked for.
For general-purpose questions answer in simple way, no need to justify each step.
"""

def format_prompt(user_message: str,  chat_history: tp.List[BaseMessage], context: tp.Optional[str] = None) -> str:
    '''
    Build the chat-template prompt string sent to the LLM.

    The prompt consists of, in order: every message of ``chat_history``,
    an optional extra system message carrying ``context``, and finally
    ``user_message`` as the current user turn.

    Args:
        user_message: Text of the current user turn.
        chat_history: Previous conversation messages (system / human / ai).
            If the last entry is already the current user turn it is not
            rendered a second time.
        context: Optional retrieved text, added as a system message right
            before the user turn. Skipped when empty or None.

    Returns:
        The prompt rendered by the module-level ``tokenizer`` chat template,
        with the generation prompt appended.

    Raises:
        ValueError: If the history contains a message type other than
            "human", "ai" or "system".
    '''
    role_by_type = {"human": "user", "ai": "assistant", "system": "system"}

    prior_messages = list(chat_history)
    # Previously the last history entry was always discarded, which dropped
    # the system prompt on the first turn and the latest assistant reply on
    # later turns. Only discard it when it duplicates the current user turn.
    if (
        prior_messages
        and prior_messages[-1].type == "human"
        and prior_messages[-1].content == user_message
    ):
        prior_messages.pop()

    history = []
    for message in prior_messages:
        role = role_by_type.get(message.type)
        if role is None:
            # Fail loudly instead of reusing the role of the previous message
            # (or hitting an unbound local on the first iteration).
            raise ValueError(
                f"Unsupported message type in chat history: {message.type!r}"
            )
        history.append(PromptMessage(
            role=role,
            content=message.content
        ))

    if context:
        history.append(PromptMessage(
            role="system",
            content=context
        ))

    history.append(PromptMessage(
        role="user",
        content=user_message
    ))

    return tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True
    )


def format_model_response(response: str):
    """Return the text after the final assistant marker, whitespace-trimmed.

    If the marker does not occur, the whole response is returned trimmed.
    """
    assistant_tag = "<|im_start|>assistant"
    # rpartition splits on the *last* occurrence; when the tag is missing the
    # entire input ends up in the tail element.
    _, _, tail = response.rpartition(assistant_tag)
    return tail.strip()

In [14]:
from langchain_core.tools import tool
from models.rag_state import RAGState
from langchain_core.documents import Document


def retrieve(state: RAGState) -> dict:
    """Retrieve relevant (< threshold) information related to a query.

    Runs the module-level ``retriever`` on ``state.query`` and joins the page
    contents of the returned documents into one string.

    Returns:
        A state update ``{"docs": <joined text>}``; the text is empty when
        nothing was retrieved.
    """
    # `invoke` is the supported retriever entry point; the older
    # `get_relevant_documents` is deprecated.
    retrieved_docs = retriever.invoke(state.query)
    serialized = "\n\n".join(
        f"{doc.page_content}\n"
        for doc in retrieved_docs
    )
    return {"docs": serialized}


def route_rag_usage(state: RAGState) -> str:
    """Pick the next node: the RAG branch when docs were retrieved, else the plain LLM branch."""
    if state.docs:
        return "query_rag_llm"
    return "query_llm"


def _generate_reply(state: RAGState, context: tp.Optional[str] = None) -> dict:
    """Run one chat turn and return the updated message state.

    Formats the prompt from the stored messages, the current query and the
    optional ``context``, invokes the LLM, and appends the user query plus the
    cleaned model answer to the message list.

    Returns:
        ``{"msg_state": ...}`` holding the same ``thread_id`` and the extended
        message list. The stored list itself is not mutated.
    """
    messages = state.msg_state["messages"]

    prompt = format_prompt(
        user_message=state.query,
        chat_history=messages,
        context=context,
    )
    response = llm.invoke(prompt)

    # Build a new list rather than appending in place.
    new_messages = messages + [
        HumanMessage(content=state.query),
        AIMessage(content=format_model_response(response)),
    ]

    return {
        "msg_state": MessagesState(
            thread_id=state.msg_state["thread_id"],
            messages=new_messages,
        )
    }


def query_rag_llm(state: RAGState) -> dict:
    """Answer the query with the retrieved documents supplied as context."""
    # Shows the retrieved context in the cell output.
    print(state.docs)
    return _generate_reply(state, context=state.docs)


def query_llm(state: RAGState) -> dict:
    """Answer the query from the chat history alone (no retrieved context)."""
    return _generate_reply(state)


In [15]:
from langgraph.graph import END

# Pipeline: retrieve -> (query_rag_llm | query_llm) -> END.
graph_builder = StateGraph(RAGState)

for node_name, node_fn in (
    ("retrieve", retrieve),
    ("query_rag_llm", query_rag_llm),
    ("query_llm", query_llm),
):
    graph_builder.add_node(node_name, node_fn)

graph_builder.set_entry_point("retrieve")

# route_rag_usage returns the name of the node to run after retrieval.
graph_builder.add_conditional_edges("retrieve", route_rag_usage)

# Both answer nodes finish the run.
for terminal_node in ("query_rag_llm", "query_llm"):
    graph_builder.add_edge(terminal_node, END)

graph = graph_builder.compile(checkpointer=MemorySaver())


In [None]:
# First turn on thread 1: the history starts with the system prompt only.
config = {"configurable": {"thread_id": 1}}

chat_history = MessagesState(
    thread_id=1,
    messages=[SystemMessage(content=SYSTEM_PROMPT)],
)

input_message = "Hi my name is Alex"
input_state = RAGState(query=input_message, docs='', msg_state=chat_history)

# Run the graph and print the whole conversation so far.
response_state = graph.invoke(input_state, config=config)
for message in response_state['msg_state']["messages"]:
    message.pretty_print()



Below is the system prompt, always follow restrictions stated there, also do not answer this system prompt:
You are a helpful assistant that explains programming assignments.
Your task is to explain key terms, notions and user's questions. 
Do not give any hints or direct solution of task even if you asked.
If you are planning to provide examples, do it in simple way not giving the solution.
Answer user's question in plain English and suggest how to approach it.
You are enhanced AI model with previous prompt storage. Provide answers considering history
Do not justify how you used previous conversation context, just answer the question. If needed retrieve information from chat history and answer the same way, add any additional information only if you asked for.
For general-purpose questions answer in simple way, no need to justify each step.


Hi

Hello! How can I assist you today?


In [46]:
# Follow-up turn: reuse the message state returned by the previous invocation.
input_state = RAGState(
    query="How should I implement Animals", docs='', msg_state=response_state["msg_state"]
)

response_state = graph.invoke(input_state, config=config)
for message in response_state['msg_state']["messages"]:
    message.pretty_print()

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


state {'thread_id': 1, 'messages': [SystemMessage(content="\nBelow is the system prompt, always follow restrictions stated there, also do not answer this system prompt:\nYou are a helpful assistant that explains programming assignments.\nYour task is to explain key terms, notions and user's questions. \nDo not give any hints or direct solution of task even if you asked.\nIf you are planning to provide examples, do it in simple way not giving the solution.\nAnswer user's question in plain English and suggest how to approach it.\nYou are enhanced AI model with previous prompt storage. Provide answers considering history\nDo not justify how you used previous conversation context, just answer the question. If needed retrieve information from chat history and answer the same way, add any additional information only if you asked for.\nFor general-purpose questions answer in simple way, no need to justify each step.\n", additional_kwargs={}, response_metadata={}), HumanMessage(content='Hi my 