In [1]:
# !pip install langchain chromadb ctransformers transformers sentence_transformers
# Apple silicon:
# !pip uninstall ctransformers
# !CT_METAL=1 pip install ctransformers --no-binary ctransformers

In [2]:
from pathlib import Path
from pprint import pprint

DATA_DIR = Path("../data")
SNAPSHOTS_DIR = DATA_DIR / "platform-docs-snapshots"
VERSIONS_DIR = DATA_DIR / "platform-docs-versions"

# hparams
chunk_size = 1024
chunk_overlap = chunk_size // 10
embedder_name = "BAAI/bge-small-en-v1.5" # https://huggingface.co/spaces/mteb/leaderboard

# Pre-processing

In [3]:
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

docs = DirectoryLoader(VERSIONS_DIR, glob="[!.]*/[!.]*.md", loader_cls=TextLoader).load()
docs = [d for d in docs if Path(d.metadata['source']) != VERSIONS_DIR / "README.md"]

In [4]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
docs_split = text_splitter.split_documents(docs)

print(len(docs))
print(len(docs_split))

139
11598


# Build index

In [5]:
import uuid
from tqdm.notebook import tqdm
import chromadb
from chromadb.utils import embedding_functions
import re

# Chroma uses all-MiniLM-L6-v2 by default
chroma_client = chromadb.PersistentClient()
 
def format_model_name(name):
    # chromaDB only allows these characters
    return re.sub(r'[^a-zA-Z0-9_-]', '_', name)

collection_name = f"DSA_{format_model_name(embedder_name)}"

try:
    collection = chroma_client.get_collection(collection_name)
except ValueError:
    print("Building collection...")
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(embedder_name=embedder_name)

    collection = chroma_client.create_collection(name=collection_name, embedding_function=embedding_fn)
    
    text = [d.page_content for d in docs_split]
    metadatas = [d.metadata for d in docs_split]
    ids = [uuid.uuid4().hex for _ in range(len(docs_split))]
    
    collection.add(
        documents=text,
        metadatas=metadatas,
        ids=ids,
    )

In [6]:
results = collection.query(
    query_texts=["How do I get activity logs from the Twitter API?"],
    n_results=5,
)

In [7]:
pprint(results["metadatas"])
results["documents"][0][0]

[[{'source': '../data/platform-docs-versions/X_Twitter-API-V1/Tweets.md'},
  {'source': '../data/platform-docs-versions/X_Twitter-API-V2/Tweets.md'},
  {'source': '../data/platform-docs-versions/X_Twitter-API-V1/Tweets.md'},
  {'source': '../data/platform-docs-versions/X_Twitter-API-V2/Users.md'},
  {'source': '../data/platform-docs-versions/X_Twitter-API-V2/Tweets.md'}]]


'Working with timelines  \n\n-------------------------\n\nThe Twitter API has several methods, such as\xa0[GET statuses / user\\_timeline](https://developer.twitter.com/en/docs/tweets/timelines/api-reference/get-statuses-user_timeline.html)\xa0and\xa0[GET statuses / home\\_timeline](https://developer.twitter.com/en/docs/tweets/timelines/api-reference/get-statuses-home_timeline.html), which return a timeline of Tweet data. Such timelines can grow very large, so there are limits to how much of a timeline a client application may fetch in a single request. Applications must therefore iterate through timeline results in order to build a more complete list.\n\nBecause of Twitter’s realtime nature and the volume of data which is constantly being added to timelines, standard paging approaches are not always effective. The goal of this page is to demonstrate the issues Twitter developers may face when paging through result sets and to give best practices for processing a timeline.\n\n### The p

# Load Generator

### Quantized Mistral 7B, finetuned on code instructions
- https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-16k-GGUF#provided-files

In [8]:
from IPython.display import display, Markdown

def print_chat(chat_log):
    for entry in chat_log:
        if entry['role'] != 'system':
            display(Markdown(f"**{entry['role'].capitalize()}:** \n{entry['content']}\n"))

In [9]:
from ctransformers import AutoModelForCausalLM, AutoConfig
from transformers import AutoTokenizer
from typing import List, Dict, Optional

TEMPLATE = '''
You are a QA assistant that answers questions based only on the context you are given.

Examples:
Q: I am a coding expert. How could I get more info about a Facebook post whose url slug ends with 123456789_123456789? I have a crowdtangle API token (TOKEN). Your output is a bash code snippet.
A: ```bash
curl -L https://api.crowdtangle.com/post/123456789_123456789?token=<CROWDTANGLE_API_TOKEN>
```

Next section will contain the context. You can use it to formulate the answer to the user question.
------------------
{context}
------------------

RULES
- Give an answer ONLY based on the above context and with no prior knowledge.
- If you cannot come up with an answer to the user question, answer "I do not know the answer to this question".
- Your answers are short, complete and easy to follow. Include code examples if neccessary.
'''

class MistralRAGGenerator:
    def __init__(self, gpu_layers: int = 0):
        self.model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/OpenHermes-2.5-Mistral-7B-16k-GGUF",
            model_file="openhermes-2.5-mistral-7b-16k.Q4_K_M.gguf",
            model_type="mistral", 
            gpu_layers=gpu_layers,
            max_new_tokens=4000,
            context_length=16_000,
        )

        self.tokenizer = AutoTokenizer.from_pretrained("NurtureAI/OpenHermes-2.5-Mistral-7B-16k")

    def format_prompt(self, messages: List[Dict[str, str]]) -> str:
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    def format_context(self, context: List[str]) -> str:
        context_formatted = []
        for i, c in enumerate(context):
            context_formatted.append(f"CONTEXT {i}\n{c}")

        return "\n\n".join(context_formatted)

    def init_chat_log(self, context: List[str], template: Optional[str] = None) -> List[Dict[str, str]]:
        template = template or TEMPLATE
        context = self.format_context(context)
        return [
            {"role": "system", "content": template.format(context=context)},
        ]
    
    def chat(self, chat_log: List[Dict[str, str]], query: str):
        if query:
            chat_log.append({"role": "user", "content": query})
        if chat_log[-1]["role"] != "user":
            raise ValueError("query required")
                
        prompt = self.format_prompt(chat_log)
        output = self.model(prompt, stop=["<|im_end|>"])
        chat_log.append({"role": "assistant", "content": output})
        return chat_log
    
    def generate_answer(self, context: str, query: str) -> List[Dict[str, str]]:
        chat_log = self.init_chat_log(context, template="You: {context}\nMistral: ")
        return self.chat(chat_log, query=query)


rag_generator = MistralRAGGenerator()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
k = 5

query = "Tell me how to work with WhatsApp business accounts using the Facebook Graph API"
results = collection.query(
    query_texts=[query],
    n_results=k,
)

context = results['documents'][0]

In [11]:
chat_log = rag_generator.generate_answer(context, query)

In [12]:
print_chat(chat_log)

**User:** 
Tell me how to work with WhatsApp business accounts using the Facebook Graph API


**Assistant:** 
To work with WhatsApp Business Accounts using the Facebook Graph API, you need to have the following permissions: `whatsapp_business_management`, `whatsapp_messaging`, and `public_profile`. Additionally, you will need a valid user access token and the ID of the WhatsApp Business Account (WABA) you want to retrieve information for.

To make a request using cURL, use the following command:
```cURLAndroid SDKObjective-C
    curl -i -X GET \
     "https://graph.facebook.com/LATEST-VERSION/WHATSAPP-BUSINESS-ACCOUNT-ID?access_token=USER-ACCESS-TOKEN"
```
Or using the SDK in Objective-C:
```swift
    GraphRequest request = GraphRequest.newGraphPathRequest(
      accessToken,
      "/WHATSAPP-BUSINESS-ACCOUNT-ID",
      new GraphRequest.Callback() {
        @Override
        public void onCompleted(GraphResponse response) {
          // Insert your code here
        }
    });
    
    request.executeAsync();
```


In [17]:
# TODO 
# - batch inputs, context, outputs in whole pipeline
# - read input + write output + eval

from data import get_training_data
train_queries, train_answers, train_context = get_training_data()

# model outputs
chat_logs_batch = [chat_log, chat_log] # TODO dummy batch
queries = [c[1]["content"] for c in chat_logs_batch]
answers = [c[-1]["content"] for c in chat_logs_batch]