# core

> Core functionality for `onprem`

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.vectorstores import Chroma
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
import chromadb
import os
import warnings
from typing import Any, Dict, Generator, List, Optional, Tuple, Union


In [None]:
#| export

# reference: https://github.com/langchain-ai/langchain/issues/5630#issuecomment-1574222564
class AnswerConversationBufferMemory(ConversationBufferMemory):
    """
    `ConversationBufferMemory` subclass that records only the `answer` entry of a chain's outputs.
    """
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """
        Forward `inputs` to the parent memory together with `outputs['answer']`,
        re-keyed as `response`; every other entry in `outputs` is dropped.
        """
        answer_only = {'response': outputs['answer']}
        return super().save_context(inputs, answer_only)

In [None]:
#| export

from onprem import utils as U
# Default (7B) GGML model that `LLM` downloads when no `model_url` is supplied.
DEFAULT_MODEL_URL = 'https://huggingface.co/TheBloke/Wizard-Vicuna-7B-Uncensored-GGML/resolve/main/Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_0.bin'
# Larger (13B) GGML model selected with `LLM(use_larger=True)`.
# NOTE: the URL must not carry leading/trailing whitespace.
DEFAULT_LARGER_URL = 'https://huggingface.co/TheBloke/WizardLM-13B-V1.2-GGML/resolve/main/wizardlm-13b-v1.2.ggmlv3.q4_0.bin'
# Default sentence-transformers model name used for embeddings.
DEFAULT_EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
# Default prompt for question-answering; must contain the variables
# "context" and "question".
DEFAULT_QA_PROMPT = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Helpful Answer:"""

class LLM:
    """
    Convenience wrapper around a local `langchain.llms.LlamaCpp` model that supports
    plain prompting (`prompt`), document ingestion (`ingest`), question-answering
    over ingested documents (`ask`) and conversational question-answering (`chat`).
    """
    def __init__(self, 
                 model_url:str=DEFAULT_MODEL_URL,
                 use_larger:bool=False,
                 n_gpu_layers:Optional[int]=None, 
                 model_download_path:Optional[str]=None,
                 vectordb_path:Optional[str]=None,
                 max_tokens:int=512, 
                 n_ctx:int=2048, 
                 n_batch:int=1024,
                 mute_stream:bool=False,
                 callbacks:Optional[list]=None,
                 embedding_model_name:str=DEFAULT_EMBEDDING_MODEL,
                 embedding_model_kwargs:Optional[dict]=None,
                 embedding_encode_kwargs:Optional[dict]=None,
                 confirm:bool=True,
                 verbose:bool=False,
                 **kwargs):
        """
        LLM Constructor.  Extra `kwargs` are fed directly to `langchain.llms.LlamaCpp`.
        
        **Args:**

        - *model_url*: URL to `.bin` model (currently must be GGML model).
        - *use_larger*: If True, a larger model than the default `model_url` will be used.
        - *n_gpu_layers*: Number of layers to be loaded into gpu memory. Default is `None`.
        - *model_download_path*: Path to download model. Default is `onprem_data` in user's home directory.
        - *vectordb_path*: Path to vector database (created if it doesn't exist). 
                           Default is `onprem_data/vectordb` in user's home directory.
        - *max_tokens*: The maximum number of tokens to generate.
        - *n_ctx*: Token context window.
        - *n_batch*: Number of tokens to process in parallel.
        - *mute_stream*: Mute ChatGPT-like token stream output during generation
        - *callbacks*: Callbacks to supply model. Default is `None` (no extra callbacks).
        - *embedding_model_name*: name of sentence-transformers model. Used for `LLM.ingest` and `LLM.ask`.
        - *embedding_model_kwargs*: arguments to embedding model.
                                    Default is `None`, which means `{'device': 'cpu'}`.
        - *embedding_encode_kwargs*: arguments to encode method of embedding model.
                                     Default is `None`, which means `{'normalize_embeddings': False}`.
        - *confirm*: whether or not to confirm with user before downloading a model
        - *verbose*: Verbosity
        """
        # Strip stray whitespace so the basename/download logic sees a clean URL.
        self.model_url = (DEFAULT_LARGER_URL if use_larger else model_url).strip()
        # Only announce the larger model when it was actually requested.
        if use_larger and verbose:
            print(f'Since use_larger=True, we are using: {os.path.basename(self.model_url)}')
        self.model_name = os.path.basename(self.model_url)
        self.model_download_path = model_download_path or U.get_datadir()
        # Download the model only when it is not already present on disk.
        if not os.path.isfile(os.path.join(self.model_download_path, self.model_name)):
            self.download_model(self.model_url, model_download_path=self.model_download_path, confirm=confirm)
        self.vectordb_path = vectordb_path
        # Heavy objects are created lazily by the corresponding `load_*` methods.
        self.llm = None
        self.ingester = None
        self.qa = None
        self._qa_config = None  # (num_source_docs, prompt_template) used to build `self.qa`
        self.chatqa = None
        self.n_gpu_layers = n_gpu_layers
        self.max_tokens = max_tokens
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.callbacks = [] if mute_stream else [StreamingStdOutCallbackHandler()]
        if callbacks: self.callbacks.extend(callbacks)
        self.embedding_model_name = embedding_model_name
        # Build fresh dicts per instance instead of sharing mutable default arguments.
        self.embedding_model_kwargs = ({'device': 'cpu'} if embedding_model_kwargs is None
                                       else embedding_model_kwargs)
        self.embedding_encode_kwargs = ({'normalize_embeddings': False} if embedding_encode_kwargs is None
                                        else embedding_encode_kwargs)
        self.verbose = verbose
        self.extra_kwargs = kwargs
 
    @classmethod
    def download_model(cls, model_url:str=DEFAULT_MODEL_URL, 
                       model_download_path:Optional[str]=None, 
                       confirm:bool=True, 
                       ssl_verify:bool=True):
        """
        Download an LLM in GGML format supported by [llama.cpp](https://github.com/ggerganov/llama.cpp).
        
        **Args:**
        
        - *model_url*: URL of model
        - *model_download_path*: Path to download model. Default is `onprem_data` in user's home directory.
        - *confirm*: whether or not to confirm with user before downloading
        - *ssl_verify*: If True, SSL certificates are verified. 
                        You can set to False if corporate firewall gives you problems.

        **Returns:** `None`. A warning is issued if the user declines the download.
        """
        datadir = model_download_path or U.get_datadir()
        model_name = os.path.basename(model_url)
        filename = os.path.join(datadir, model_name)
        if os.path.isfile(filename):
            confirm_msg = f'There is already a file {model_name} in {datadir}.\n Do you want to still download it?'
        else:
            confirm_msg = f"You are about to download the LLM {model_name} to the {datadir} folder. Are you sure?"
            
        shall = True
        if confirm:
            # Tolerate surrounding whitespace in the typed answer (e.g. "y ").
            shall = input("%s (y/N) " % confirm_msg).strip().lower() == "y"
        if shall:
            # A user-supplied download folder may not exist yet.
            os.makedirs(datadir, exist_ok=True)
            U.download(model_url, filename, verify=ssl_verify)
        else:
            warnings.warn(f'{model_name} was not downloaded because "Y" was not selected.')
        return

    def load_ingester(self):
        """
        Get `Ingester` instance (created on first use and then cached). 
        You can access the `langchain.vectorstores.Chroma` instance with `load_ingester().get_db()`.
        """
        if not self.ingester:
            # Imported lazily, only when an ingester is actually needed.
            from onprem.ingest import Ingester
            self.ingester = Ingester(embedding_model_name=self.embedding_model_name,
                                     embedding_model_kwargs=self.embedding_model_kwargs,
                                     embedding_encode_kwargs=self.embedding_encode_kwargs,
                                     persist_directory=self.vectordb_path)
        return self.ingester
        

    def load_vectordb(self):
        """
        Get Chroma db instance.

        Raises `ValueError` if no vector database has been created yet.
        """
        db = self.load_ingester().get_db()
        if not db:
            raise ValueError('A vector database has not yet been created. Please call the LLM.ingest method.')
        return db

    
    def ingest(self, 
               source_directory:str,
               chunk_size:int=500,
               chunk_overlap:int=50
              ):
        """
        Ingests all documents in `source_directory` into vector database.
        Previously-ingested documents are ignored.

        **Args:**
        
        - *source_directory*: path to folder containing document store
        - *chunk_size*: text is split to this many characters by `langchain.text_splitter.RecursiveCharacterTextSplitter`
        - *chunk_overlap*: character overlap between chunks in `langchain.text_splitter.RecursiveCharacterTextSplitter`
        
        **Returns:** `None`
        """
        ingester = self.load_ingester()
        ingester.ingest(source_directory, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        return

 
        
    def check_model(self):
        """
        Returns the path to the model.

        Raises `ValueError` if the model file is not found in the download folder.
        """
        model_path = os.path.join(self.model_download_path, self.model_name)
        if not os.path.isfile(model_path):
            raise ValueError(f'The LLM model {self.model_name} does not appear to have been downloaded. '
                             'Execute the download_model() method to download it.')
        return model_path
        
 
    def load_llm(self):
        """
        Loads the LLM from the model path (created on first use and then cached).
        """
        # Checked on every call so a missing model file is reported clearly.
        model_path = self.check_model()
        
        if not self.llm:
            self.llm = LlamaCpp(model_path=model_path, 
                                max_tokens=self.max_tokens, 
                                n_batch=self.n_batch, 
                                callback_manager=CallbackManager(self.callbacks),
                                verbose=self.verbose, 
                                n_gpu_layers=self.n_gpu_layers, 
                                n_ctx=self.n_ctx, **self.extra_kwargs)    

        return self.llm
        
        
    def prompt(self, prompt, prompt_template:Optional[str]=None):
        """
        Send prompt to LLM to generate a response
        
        **Args:**
        
        - *prompt*: The prompt to supply to the model
        - *prompt_template*: Optional prompt template (must have a variable named "prompt")

        **Returns:** the output of the model for the (optionally templated) prompt
        """
        llm = self.load_llm()
        if prompt_template:
            prompt = prompt_template.format(prompt=prompt)
        return llm(prompt)  
 

    def load_qa(self, num_source_docs:int=4, prompt_template:str=DEFAULT_QA_PROMPT):
        """
        Prepares and loads the `langchain.chains.RetrievalQA` object.
        The chain is cached and rebuilt only when called with settings that differ
        from the ones it was built with.
        
        **Args:**
        
        - *num_source_docs*: the number of ingested source documents use to generate answer
        - *prompt_template*: A string representing the prompt with variables "context" and "question"      
        """
        requested = (num_source_docs, prompt_template)
        built_with = getattr(self, '_qa_config', None)
        # Rebuild when the settings changed; otherwise a later `ask` call with a different
        # `num_source_docs`/`prompt_template` would silently reuse the first chain.
        # A chain assigned to `self.qa` from outside (no recorded settings) is left untouched.
        stale = built_with is not None and built_with != requested
        if self.qa is None or stale:
            db = self.load_vectordb()
            retriever = db.as_retriever(search_kwargs={"k": num_source_docs})
            llm = self.load_llm()
            qa_prompt = PromptTemplate(
                        template=prompt_template, input_variables=["context", "question"])
            self.qa = RetrievalQA.from_chain_type(llm=llm, 
                                                 chain_type="stuff", 
                                                 retriever=retriever, 
                                                 return_source_documents=True,
                                                 chain_type_kwargs={'prompt': qa_prompt})
            self._qa_config = requested
        return self.qa

    def load_chatqa(self, num_source_docs:int=4):
        """
        Prepares and loads a `langchain.chains.ConversationalRetrievalChain` instance.
        The chain (together with its conversational memory) is created on the first call
        and reused afterwards, so `num_source_docs` only takes effect on that first call.
        
        **Args:**
        
        - *num_source_docs*: the number of ingested source documents use to generate answer
        """
        if self.chatqa is None:
            db = self.load_vectordb()
            retriever = db.as_retriever(search_kwargs={"k": num_source_docs})
            llm = self.load_llm()
            # Memory that stores only the `answer` output of the chain.
            memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
            self.chatqa = ConversationalRetrievalChain.from_llm(llm, 
                                                                retriever,
                                                                memory=memory,
                                                                return_source_documents=True)
        return self.chatqa
    
    
    def ask(self, question:str, num_source_docs:int=4, prompt_template=DEFAULT_QA_PROMPT):
        """
        Answer a question based on source documents fed to the `ingest` method.
        
        **Args:**
        
        - *question*: a question you want to ask
        - *num_source_docs*: the number of ingested source documents use to generate answer
        - *prompt_template*: A string representing the prompt with variables "context" and "question"

        **Returns:**

        - A tuple consisting of the answer and the list of source documents used to generate the answer.
        """
        qa = self.load_qa(num_source_docs=num_source_docs, prompt_template=prompt_template)
        res = qa(question)
        return res['result'], res['source_documents']
    
    
    def chat(self, question:str, num_source_docs:int=4):
        """
        Chat with documents fed to the `ingest` method.
        Unlike `LLM.ask`, `LLM.chat` includes conversational memory.
        
        **Args:**
        
        - *question*: a question you want to ask
        - *num_source_docs*: the number of ingested source documents use to generate answer
        
        **Returns:**
        
        - A dictionary with keys: `question`, `answer`, `chat_history`, `source_documents`
        """
        chatqa = self.load_chatqa(num_source_docs=num_source_docs)
        return chatqa(question)

In [None]:
show_doc(LLM.download_model)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L98){target="_blank" style="float:right; font-size:smaller"}

### LLM.download_model

>      LLM.download_model
>                          (model_url:str='https://huggingface.co/TheBloke/Wizar
>                          d-Vicuna-7B-Uncensored-GGML/resolve/main/Wizard-
>                          Vicuna-7B-Uncensored.ggmlv3.q4_0.bin',
>                          model_download_path:Optional[str]=None,
>                          confirm:bool=True, ssl_verify:bool=True)

Download an LLM in GGML format supported by [llama.cpp](https://github.com/ggerganov/llama.cpp).

**Args:**

- *model_url*: URL of model
- *model_download_path*: Path to download model. Default is `onprem_data` in user's home directory.
- *confirm*: whether or not to confirm with user before downloading
- *ssl_verify*: If True, SSL certificates are verified. 
                You can set to False if corporate firewall gives you problems.

In [None]:
show_doc(LLM.load_llm)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L189){target="_blank" style="float:right; font-size:smaller"}

### LLM.load_llm

>      LLM.load_llm ()

Loads the LLM from the model path.

In [None]:
show_doc(LLM.load_ingester)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L129){target="_blank" style="float:right; font-size:smaller"}

### LLM.load_ingester

>      LLM.load_ingester ()

Get `Ingester` instance. 
You can access the `langchain.vectorstores.Chroma` instance with `load_ingester().get_db()`.

In [None]:
show_doc(LLM.load_qa)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L222){target="_blank" style="float:right; font-size:smaller"}

### LLM.load_qa

>      LLM.load_qa (num_source_docs:int=4, prompt_template:str='"Use the
>                   following pieces of context to answer the question at the
>                   end. If you don\'t know the answer, just say that you don\'t
>                   know, don\'t try to make up an
>                   answer.\n\n{context}\n\nQuestion: {question}\nHelpful
>                   Answer:')

Prepares and loads the `langchain.chains.RetrievalQA` object

**Args:**

- *num_source_docs*: the number of ingested source documents use to generate answer
- *prompt_template*: A string representing the prompt with variables "context" and "question"

In [None]:
show_doc(LLM.load_chatqa)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L244){target="_blank" style="float:right; font-size:smaller"}

### LLM.load_chatqa

>      LLM.load_chatqa (num_source_docs:int=4)

Prepares and loads a `langchain.chains.ConversationalRetrievalChain` instance

**Args:**

- *num_source_docs*: the number of ingested source documents use to generate answer

In [None]:
show_doc(LLM.prompt)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L207){target="_blank" style="float:right; font-size:smaller"}

### LLM.prompt

>      LLM.prompt (prompt, prompt_template:Optional[str]=None)

Send prompt to LLM to generate a response

**Args:**

- *prompt*: The prompt to supply to the model
- *prompt_template*: Optional prompt template (must have a variable named "prompt")

In [None]:
show_doc(LLM.ingest)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L154){target="_blank" style="float:right; font-size:smaller"}

### LLM.ingest

>      LLM.ingest (source_directory:str, chunk_size:int=500,
>                  chunk_overlap:int=50)

Ingests all documents in `source_directory` into vector database.
Previously-ingested documents are ignored.

**Args:**

- *source_directory*: path to folder containing document store
- *chunk_size*: text is split to this many characters by `langchain.text_splitter.RecursiveCharacterTextSplitter`
- *chunk_overlap*: character overlap between chunks in `langchain.text_splitter.RecursiveCharacterTextSplitter`

**Returns:** `None`

In [None]:
show_doc(LLM.ask)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L263){target="_blank" style="float:right; font-size:smaller"}

### LLM.ask

>      LLM.ask (question:str, num_source_docs:int=4, prompt_template='"Use the
>               following pieces of context to answer the question at the end.
>               If you don\'t know the answer, just say that you don\'t know,
>               don\'t try to make up an answer.\n\n{context}\n\nQuestion:
>               {question}\nHelpful Answer:')

Answer a question based on source documents fed to the `ingest` method.

**Args:**

- *question*: a question you want to ask
- *num_source_docs*: the number of ingested source documents use to generate answer
- *prompt_template*: A string representing the prompt with variables "context" and "question"

In [None]:
show_doc(LLM.chat)

---

[source](https://github.com/amaiya/onprem/blob/master/onprem/core.py#L278){target="_blank" style="float:right; font-size:smaller"}

### LLM.chat

>      LLM.chat (question:str, num_source_docs:int=4)

Chat with documents fed to the `ingest` method.
Unlike `LLM.ask`, `LLM.chat` includes conversational memory.

**Args:**

- *question*: a question you want to ask
- *num_source_docs*: the number of ingested source documents use to generate answer

**Returns:**
- A dictionary with keys: `question`, `answer`, `chat_history`, `source_documents`

## Example Usage

We'll use a small 3B-parameter model here for testing purposes. The vector database is stored under `~/onprem_data` by default. In this example, we will store the vector database in a temporary folder instead.

In [None]:
#| notest
import tempfile

In [None]:
#| notest
vectordb_path = tempfile.mkdtemp()

In [None]:
#| notest
# Small 3B-parameter GGML model, used to keep this example lightweight.
url = 'https://huggingface.co/TheBloke/orca_mini_3B-GGML/resolve/main/orca-mini-3b.ggmlv3.q4_1.bin'
# The vector store goes into the temporary folder created above; confirm=False
# skips the interactive (y/N) prompt shown before a model is downloaded.
llm = LLM(model_url=url,
          vectordb_path=vectordb_path,
          confirm=False)

In [None]:
#| notest
assert os.path.isfile(os.path.join(U.get_datadir(), os.path.basename(url))), "missing model"

In [None]:
#| notest

prompt = """Extract the names of people in the supplied sentences. Here is an example:
Sentence: James Gandolfini and Paul Newman were great actors.
People:
James Gandolfini, Paul Newman
Sentence:
I like Cillian Murphy's acting. Florence Pugh is great, too.
People:"""

In [None]:
#| notest
saved_output = llm.prompt(prompt)

ggml_init_cublas: found 2 CUDA devices:
  Device 0: NVIDIA TITAN V, compute capability 7.0
  Device 1: NVIDIA TITAN V, compute capability 7.0
llama.cpp: loading model from /home/amaiya/onprem_data/orca-mini-3b.ggmlv3.q4_1.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 3200
llama_model_load_internal: n_mult     = 240
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 26
llama_model_load_internal: n_rot      = 100
llama_model_load_internal: ftype      = 3 (mostly Q4_1)
llama_model_load_internal: n_ff       = 8640
llama_model_load_internal: model size = 3B
llama_model_load_internal: ggml ctx size =    0.06 MB
llama_model_load_internal: using CUDA for GPU acceleration
ggml_cuda_set_main_device: using device 0 (NVIDIA TITAN V) as main device
llama_model_load_internal: mem required  = 3066.94 MB (+  682.00 MB per stat

 Cillian Murphy, Florence Pugh

In [None]:
#| notest
assert saved_output.strip() == 'Cillian Murphy, Florence Pugh', "bad response"

In [None]:
#| notest
llm.ingest('./sample_data', chunk_size=500, chunk_overlap=50)

Creating new vectorstore at /tmp/tmpl6ww9w5p
Loading documents from ./sample_data


Loading new documents: 100%|██████████████████████| 3/3 [00:00<00:00, 24.82it/s]


Loaded 12 new documents from ./sample_data
Split into 153 chunks of text (max. 500 chars each)
Creating embeddings. May take some minutes...
Ingestion complete! You can now query your documents using the LLM.ask method


In [None]:
#| notest
# Ask a question about the ingested documents: `ask` returns the answer along
# with the source documents that were used to generate it.
question = """What is ktrain?""" 
answer, docs = llm.ask(question)
print('\n\nReferences:\n\n')
# List each source document (numbered from 1) followed by its text.
for rank, doc in enumerate(docs, start=1):
    source = doc.metadata["source"]
    print(f"\n{rank}.> " + source + ":")
    print(doc.page_content)

 ktrain is a low-code library for augmented machine learning that enables beginners and domain experts with minimal programming or data science expertise to further democratize machine learning by facilitating the full machine learning workow from curating and preprocessing inputs (i.e., ground-truth-labeled training data) to training, tuning, troubleshooting, and applying models.

References:



1.> ./sample_data/1/ktrain_paper.pdf:
lection (He et al., 2019). By contrast, ktrain places less emphasis on this aspect of au-
tomation and instead focuses on either partially or fully automating other aspects of the
machine learning (ML) workﬂow. For these reasons, ktrain is less of a traditional Au-
2

2.> ./sample_data/1/ktrain_paper.pdf:
possible, ktrain automates (either algorithmically or through setting well-performing de-
faults), but also allows users to make choices that best ﬁt their unique application require-
ments. In this way, ktrain uses automation to augment and complement hu

**Pro-Tip**: Smaller models like this tend to hallucinate more easily than larger ones. If you see the model hallucinating answers, you can supply `use_larger=True` to `LLM` and use the slightly larger default model better-suited to this use case (or supply the URL to a different model of your choosing to `LLM`), which can provide better performance.

The `LLM.chat` method answers questions with consideration of conversational memory. Note that `LLM.chat` is better suited to larger/better models than the one used below, as the models are required to do more work (e.g., condensing the question and chat history into a standalone question).

In [None]:
#| notest
question = "What is ktrain?"
result = llm.chat(question)

 ktrain is a low-code library for augmented machine learning that allows users to automate or semi-automate various aspects of the machine learning workow, such as curating and preprocessing inputs, training, tuning, troubleshooting, and applying models. It is designed to improve the strengths of both humans and machines by enabling beginners and domain experts with minimal programming or data science expertise to use machine learning in their applications.

In [None]:
#| notest
question = "Can it be used for image classification?"
result = llm.chat(question)

 What is ktrain and how can it be used for image classification? ktrain is a low-code library for augmented machine learning that enables the full machine learning workow, including automation or semi-automation of tasks such as data curation, preprocessing, model training, tuning, troubleshooting, and application. It can be used with any machine learning model implemented in TensorFlow Keras (tf.keras). ktrain includes out-of-the-box support for various data types and tasks, including image classification using custom models and data formats, as well.

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()