In [1]:
import logging
import logging.handlers as handlers
import sys
import os
def init_logger(name:str, filename:str, path:str='./log', level:int=logging.INFO):
    """Configure the logger ``name`` to write to stdout and to a rotating file.

    Safe to call repeatedly (for example when a notebook cell is re-run):
    handlers already attached to the logger are removed and closed before the
    new ones are added, so each record is emitted once per destination and the
    most recent ``filename``/``path`` wins.

    Args:
        name: Name passed to ``logging.getLogger``.
        filename: Log file name without extension; ``.log`` is appended.
        path: Directory that holds the log file; created when missing.
        level: Threshold set on the logger.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Keep records out of the root logger so they are not printed a second time.
    logger.propagate = False

    # Drop handlers left over from a previous call; otherwise every call would
    # add another console/file pair and duplicate each log line.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

    file = os.path.join(path, f"{filename}.log")
    formatter = logging.Formatter(
        '%(asctime)s|%(filename)s:%(lineno)d|%(levelname)s|%(message)s',
        datefmt='%d-%m-%Y %H:%M:%S',
    )

    # Console output goes to stdout (StreamHandler defaults to stderr).
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setFormatter(formatter)

    # Rotate once the file reaches maxBytes, keeping up to 10 backups.
    fileHandler = handlers.RotatingFileHandler(filename=file, mode='a', maxBytes=10240000, backupCount=10)
    fileHandler.setFormatter(formatter)

    logger.addHandler(consoleHandler)
    logger.addHandler(fileHandler)

    return logger

In [4]:
import datetime
import pickle
# from components.logger import init_logger
# from components.model import load_model 
import logging
import os
from tqdm import tqdm
import torch

from typing import Union

# from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceInstructEmbeddings


from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory

from langchain.llms.base import LLM
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.chains import LLMChain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

# ---- Module-level configuration and shared state ----

# Settings that appear to belong to the MLflow-based load_model path, which is
# currently commented out in main() — presumably kept for later use.
MODEL_NAME: str = "mlflow-example"
STAGE: str = "Production"
MLFLOW_URL: str = "http://la.cs.ait.ac.th"
CACHE_FOLDER: str = os.path.join("/root", "cache")

# Annotation-only declarations (no value assigned here); main() binds each of
# these names through its ``global`` statement.
embedding_model: HuggingFaceInstructEmbeddings
vector_database: FAISS
llm_model: LLM
qa_retriever: BaseRetrievalQA
conversational_qa_memory_retriever: ConversationalRetrievalChain
question_generator: LLMChain

# GPU selection: ``device_id`` is read by load_llm_model_gpu and ``device`` by
# load_embedding_model. Keep the two in sync when changing either.
device_id = 1
device = "cuda:1"

# Instruction text for the answer-generating chain. ``{context}`` and
# ``{question}`` are the two placeholders declared as input variables below.
prompt_template = """
You are the chatbot and the face of Asian Institute of Technology (AIT). Your job is to give answers to prospective and current students about the school.
Your job is to answer questions only and only related to the AIT. Anything unrelated should be responded with the fact that your main job is solely to provide assistance regarding AIT.
MUST only use the following pieces of context to answer the question at the end. If the answers are not in the context or you are not sure of the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
When encountering abusive, offensive, or harmful language, such as fuck, bitch,etc,  just politely ask the users to maintain appropriate behaviours.
Always make sure to elaborate your response and use vibrant, positive tone to represent good branding of the school.
Never answer with any unfinished response
Answer:
"""

# Template object shared by load_conversational_qa_memory_retriever (directly)
# and load_retriever (through chain_type_kwargs).
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
chain_type_kwargs = dict(prompt=PROMPT)

def load_scraped_web_info(path: str = "ait-web-document"):
    """Load the pickled scraped documents and split them into overlapping chunks.

    Args:
        path: Pickle file to read. Defaults to ``"ait-web-document"``, resolved
            against the current working directory.

    Returns:
        The chunk documents produced by
        ``RecursiveCharacterTextSplitter.create_documents``. Earlier versions
        computed this value and discarded it.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at a file you produced yourself or otherwise trust.
    with open(path, "rb") as fp:
        ait_web_documents = pickle.load(fp)

    # Chunks of up to 500 characters (measured with len) with a 100-character overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        length_function=len,
    )

    # tqdm only reports progress while the documents are materialised into a list.
    # NOTE(review): create_documents is given the items as-is, so the pickle is
    # assumed to hold an iterable of strings — confirm against the scraper.
    texts = list(tqdm(ait_web_documents))
    return text_splitter.create_documents(texts)

def load_embedding_model():
    """Create the ``hkunlp/instructor-base`` embedding model.

    The model is placed on the torch device named by the module-level
    ``device`` string and its files are cached under ``./.cache``.

    Returns:
        The constructed ``HuggingFaceInstructEmbeddings`` instance.
    """
    embedder_options = {
        "model_name": 'hkunlp/instructor-base',
        "cache_folder": './.cache',
        "model_kwargs": {'device': torch.device(device)},
    }
    return HuggingFaceInstructEmbeddings(**embedder_options)

def load_faiss_index():
    """Load the saved FAISS index from disk.

    Reads the module-level ``embedding_model``, so that name must be bound
    (main() does this) before the function is called.

    Returns:
        Whatever ``FAISS.load_local`` returns for the index folder.
    """
    # CHANGE THIS FAISS EMBEDDED KNOWLEDGE: folder name of the saved index.
    index_folder = "faiss_index_web_and_curri_new"
    return FAISS.load_local(index_folder, embedding_model)

def load_llm_model_cpu():
    """Build the ``lmsys/fastchat-t5-3b-v1.0`` text2text pipeline without a device argument.

    No ``device`` is passed, so the pipeline's own default placement applies
    (CPU according to this function's name — confirm for the installed
    LangChain version).

    Returns:
        The ``HuggingFacePipeline`` created by ``from_model_id``.
    """
    model_settings = {
        "max_length": 256,
        "temperature": 0,
        "torch_dtype": torch.float32,
        "repetition_penalty": 1.3,
    }
    return HuggingFacePipeline.from_model_id(
        model_id='lmsys/fastchat-t5-3b-v1.0',
        task='text2text-generation',
        model_kwargs=model_settings,
    )

def load_llm_model_gpu(gpu_id:int ):
    """Build the ``lmsys/fastchat-t5-3b-v1.0`` text2text pipeline on a chosen GPU.

    Args:
        gpu_id: Device index handed to ``HuggingFacePipeline.from_model_id`` as
            ``device``. Earlier versions ignored this argument and always used
            the module-level ``device_id``.

    Returns:
        The ``HuggingFacePipeline`` created by ``from_model_id``.
    """
    return HuggingFacePipeline.from_model_id(
        model_id='lmsys/fastchat-t5-3b-v1.0',
        task='text2text-generation',
        # Honour the caller's choice instead of the module-level device_id.
        device=gpu_id,
        model_kwargs={
            # Options previously tried and left disabled:
            # "device_map": "auto",
            # "load_in_8bit": True,
            "max_length": 256,
            "offload_folder": "offload",
            "temperature": 0,
            "repetition_penalty": 1.5,
        },
    )

def load_conversational_qa_memory_retriever():
    """Assemble the conversational retrieval chain with windowed chat memory.

    Reads the module-level ``llm_model`` and ``vector_database``; both must be
    bound before this is called.

    Returns:
        A ``(chain, question_generator)`` tuple: the
        ``ConversationalRetrievalChain`` and the ``LLMChain`` it uses with
        ``CONDENSE_QUESTION_PROMPT``.
    """
    # Chain built from CONDENSE_QUESTION_PROMPT; passed as the question generator.
    condense_chain = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
    # "stuff" QA chain driven by the module-level PROMPT.
    answer_chain = load_qa_chain(llm_model, chain_type="stuff", prompt=PROMPT)
    # Window memory configured with k=3, stored under "chat_history".
    window_memory = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",
        return_messages=True,
        output_key='answer',
    )

    def _history_as_is(h):
        # Hand the stored history to the chain unchanged.
        return h

    chain = ConversationalRetrievalChain(
        retriever=vector_database.as_retriever(),
        question_generator=condense_chain,
        combine_docs_chain=answer_chain,
        return_source_documents=True,
        memory=window_memory,
        get_chat_history=_history_as_is,
    )
    return chain, condense_chain

def load_retriever(llm, db):
    """Create a "stuff" ``RetrievalQA`` chain over ``db``.

    Args:
        llm: Language model handed to the chain.
        db: Vector store; its ``as_retriever()`` result supplies the documents.

    Returns:
        The chain built by ``RetrievalQA.from_chain_type`` using the
        module-level ``chain_type_kwargs``.
    """
    doc_retriever = db.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=doc_retriever,
        chain_type_kwargs=chain_type_kwargs,
    )

def retrieve_document(query_input):
    """Run a similarity search for ``query_input`` on the module-level vector store.

    Returns:
        The result of ``vector_database.similarity_search(query_input)``.
    """
    return vector_database.similarity_search(query_input)

def retrieve_answer(my_text_input:str):
    """Answer a question with the module-level ``qa_retriever`` chain.

    A leading ``"<pad> "`` (or bare ``"<pad>"``) marker is removed from the
    chain's output. Earlier versions dropped the first six characters
    unconditionally, which truncated any answer that did not start with the
    marker.

    Args:
        my_text_input: The user's question, passed to ``qa_retriever.run``.

    Returns:
        The generated answer without the leading pad marker.
    """
    answer = qa_retriever.run(my_text_input)

    # TODO: record each exchange (timestamp, question, generated answer, rating)
    # in session state once session handling exists; the disabled
    # Streamlit-based code that did this previously:
    # st.session_state.history.append(log)
    # update_worksheet_qa()
    # st.session_state.chat_history.append({"message": st.session_state.my_text_input, "is_user": True})
    # st.session_state.chat_history.append({"message": answer[6:] , "is_user": False})
    # st.session_state.my_text_input = ""

    # Strip the pad marker only when it is really there; the longer form is
    # tried first so the separating space is removed together with the marker.
    for pad_prefix in ("<pad> ", "<pad>"):
        if answer.startswith(pad_prefix):
            return answer[len(pad_prefix):]
    return answer

def main():
    """Initialise logging and populate the module-level models and chains.

    Binds ``embedding_model``, ``vector_database``, ``llm_model``,
    ``qa_retriever``, ``conversational_qa_memory_retriever`` and
    ``question_generator`` as globals, in dependency order.

    Returns:
        ``0`` once every component has been created.
    """
    global embedding_model, vector_database, llm_model, qa_retriever, conversational_qa_memory_retriever, question_generator
    # Init Logger
    init_logger(name="main", filename="main", path="./logs/", level=logging.DEBUG)
    logger = logging.getLogger(name="main")

    # Now prepare model
    # MODEL = load_model(model_name=MODEL_NAME, stage=STAGE, cache_folder=CACHE_FOLDER)

    # The return value of this call is not used by the steps below.
    logger.info("Loading scraped web documents")
    load_scraped_web_info()

    logger.info("Loading embedding model")
    embedding_model = load_embedding_model()

    logger.info("Loading FAISS index")
    vector_database = load_faiss_index()

    # llm_model = load_llm_model_cpu()
    # Pass the configured device index rather than a literal, so the requested
    # GPU always matches the module-level device_id.
    logger.info("Loading LLM on GPU %s", device_id)
    llm_model = load_llm_model_gpu(device_id)

    logger.info("Building retrieval chains")
    qa_retriever = load_retriever(llm= llm_model, db= vector_database)
    conversational_qa_memory_retriever, question_generator = load_conversational_qa_memory_retriever()

    logger.info("Initialisation complete")
    return 0

def get_root(text: str):
    """Answer ``text`` via ``retrieve_answer`` and print how long it took.

    Args:
        text: Question forwarded to ``retrieve_answer``.

    Returns:
        The value returned by ``retrieve_answer(text)``.
    """
    started_at = datetime.datetime.now()
    response = retrieve_answer(text)
    elapsed = datetime.datetime.now() - started_at
    print(f"Time taken is {elapsed.total_seconds()} seconds")
    return response

In [5]:
# One-time setup: configures the "main" logger and binds the module-level
# embedding model, FAISS index, LLM and retrieval chains. Returns 0.
main()

100%|██████████| 205/205 [00:00<00:00, 4478293.33it/s]


load INSTRUCTOR_Transformer
max_seq_length  512


0

In [7]:
# Ask a question through get_root, which prints the elapsed time and returns the answer text.
get_root("What is AIT?")

Time taken is 2.542579 seconds


'The  Asian  Institute  of  Technology  (AIT)  is  an  international  English-speaking  postgraduate  institution,  focusing  on  engineering,  environment,  and  management  studies.\n'

In [9]:
# A longer, curriculum-related question; get_root prints the elapsed time and returns the answer text.
get_root("what are the subjects that I need to take if I am taking computer science as my major")


Time taken is 6.198089 seconds


'If  you  are  taking  computer  science  as  your  major,  the  subjects  that  you  need  to  take  include:\n *  Calculus\n *  Discrete  Mathematics\n *  Linear  Algebra\n *  Basic  Computer  Programming\n *  Data  Modeling  and  Management\n *  Business  Intelligence  and  Analytics\n The  core  curriculum  in  computer  science  covers  all  aspects  of  computing,  with  the  faculty  particularly  active  in  artificial  intelligence,  software  engineering,  networking,  and  information  systems.  Students  are  also  encouraged  to  take  courses  and  conduct  research  in  areas  of  computer  science  that  interact  with  Information  Management,  Industrial  Engineering,  Manufacturing  Systems  Engineering,  Telecommunications,  Mechatronics,  and  other  fields  covered  at  the  Institute.\n'