In [33]:
import os, time
import chromadb

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.llms import Ollama
from langchain_core.runnables import RunnablePassthrough, RunnablePick

from langchain import hub

from chromadb.errors import InvalidDimensionException

In [58]:
# Root directory of the persisted Chroma indexes. Set the CHROMA_DB_PATH
# environment variable to override this machine-specific default instead of
# editing the notebook.
DATABASE_PATH = os.environ.get("CHROMA_DB_PATH", "/home/raj/nlp/cmu-rag/rag/chroma/txt/")
# Ollama model used to embed queries; also the name of the index subdirectory.
embedding_name = 'llama2'
# os.path.join works whether or not DATABASE_PATH ends with a slash,
# unlike plain string concatenation.
persist_directory = os.path.join(DATABASE_PATH, embedding_name)
# Fail loudly on a wrong path: otherwise Chroma would just start from an empty
# store and every retrieval would silently return no context.
if not os.path.isdir(persist_directory):
    raise FileNotFoundError(f"Chroma index not found at {persist_directory!r}")
embedding = OllamaEmbeddings(model=embedding_name)

vector_store = Chroma(persist_directory=persist_directory, embedding_function=embedding)

In [118]:
# Start from the hub's llama RAG prompt and overwrite the template text of its
# first message with a terser instruction (below).
# NOTE(review): if the hub template carries llama-2 [INST]/<<SYS>> markup,
# this replacement drops it — confirm that is intended.
rag_prompt_llama = hub.pull("rlm/rag-prompt-llama")

prompt_message = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use as few words as possible and keep the answer concise. Do not mention the context in your response.
Question: {question} 
Context: {context} 
Answer:"""

# Editing the template string in place does not recompute the prompt's
# declared input variables, so make sure they still match the placeholders
# used in prompt_message.
assert set(rag_prompt_llama.input_variables) == {"context", "question"}, (
    f"unexpected prompt variables: {rag_prompt_llama.input_variables}"
)
rag_prompt_llama.messages[0].prompt.template = prompt_message

# Generation model served by Ollama.
llm = Ollama(model='llama2')

def format_docs(docs):
    """Concatenate the text of retrieved documents into one context string.

    Args:
        docs: iterable of objects exposing a ``page_content`` string
            (e.g. the documents returned by the retriever).

    Returns:
        The ``page_content`` values joined with a blank line between them;
        an empty string when ``docs`` is empty.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

# Retriever over the Chroma store opened earlier (default search settings).
retriever = vector_store.as_retriever()

# LCEL pipeline: the incoming question is mapped to a {"context", "question"}
# dict that fills the prompt; the LLM completion is returned as a plain string.
qa_chain = (
    dict(
        context=retriever | format_docs,   # retrieved docs -> one text blob
        question=RunnablePassthrough(),    # the raw question, unchanged
    )
    | rag_prompt_llama
    | llm
    | StrOutputParser()
)


In [122]:
import nltk
# nltk.download('punkt')

IDK_MARKER = "i don't know"   # lowercase phrase the prompt tells the model to use
IDK_ANSWER = "I do not know"  # canonical answer recorded when the model abstains


def postprocess_answer(answer_raw):
    """Normalize a raw LLM answer into a single-line answer string.

    Blank lines and lines containing the "I don't know" phrase are dropped.
    The phrase match is case-insensitive and accepts a straight or curly
    apostrophe. The remaining lines are stripped and joined with single
    spaces. If nothing remains, ``IDK_ANSWER`` is returned rather than an
    empty string.

    Args:
        answer_raw: text returned by ``qa_chain.invoke``.

    Returns:
        The processed answer string.
    """
    kept = [
        line.strip()
        for line in answer_raw.split("\n")
        # Normalize the curly apostrophe LLMs often emit before matching.
        if line.strip() and IDK_MARKER not in line.lower().replace("\u2019", "'")
    ]
    return " ".join(kept) if kept else IDK_ANSWER


# NOTE(review): `questions` is not defined in any visible cell; an earlier cell
# must create it for Restart & Run All to work.
answers = []
for question in questions:
    answer_raw = qa_chain.invoke(question)
    answers.append({
        "raw": answer_raw,
        "num_lines": answer_raw.count("\n"),  # count of newline characters
        "processed": postprocess_answer(answer_raw),
    })

for answer in answers:
    for k, v in answer.items():
        print(f"{k}:\n\t {v}")



raw:
	 The LTI PhD program requires 48 units of core courses.
num_lines:
	 0
processed:
	 The LTI PhD program requires 48 units of core courses.
