In [1]:
# !pip install langchain chromadb ctransformers transformers sentence_transformers
# Apple silicon:
# !pip uninstall ctransformers
# !CT_METAL=1 pip install ctransformers --no-binary ctransformers

In [2]:
from pathlib import Path
from pprint import pprint

# Root data directory, given relative to the notebook's working directory.
DATA_DIR = Path("..", "data")
# Sub-directories of DATA_DIR used by the rest of the notebook.
SNAPSHOTS_DIR = DATA_DIR.joinpath("platform-docs-snapshots")
VERSIONS_DIR = DATA_DIR.joinpath("platform-docs-versions")

# Load data

In [3]:
from data import get_training_data

train_queries, train_answers, train_context = get_training_data()

# Use a single training query as the running example: the third from the end.
# It is kept as a one-element list, the same shape that is passed to
# retriever.query elsewhere in this notebook.
example_position = -3
query = [train_queries[example_position]]
query

["How to get the number of likes of a TikTok post. I'd like an example query and an explanation of the parameters used."]

In [4]:
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load every Markdown file matched by the glob: non-hidden files located inside
# non-hidden directories directly under VERSIONS_DIR.
# The encoding is pinned to UTF-8 so that non-ASCII text (the corpus contains French
# documents) is decoded the same way on every platform instead of relying on the
# operating system's default encoding.
docs = DirectoryLoader(
    VERSIONS_DIR,
    glob="[!.]*/[!.]*.md",
    loader_cls=TextLoader,
    loader_kwargs={"encoding": "utf-8"},
).load()
# Exclude VERSIONS_DIR/README.md in case it was picked up by the loader.
docs = [d for d in docs if Path(d.metadata['source']) != VERSIONS_DIR / "README.md"]

In [5]:
import langdetect
from collections import Counter

# langdetect's algorithm is randomized: without a fixed seed, short or ambiguous texts
# can be labelled differently from one run to the next. Seeding before the first
# detection makes the language counts printed below reproducible.
langdetect.DetectorFactory.seed = 0

languages = [langdetect.detect(d.page_content) for d in docs]

# Store the detected language code in each document's metadata.
for doc, language in zip(docs, languages):
    doc.metadata["language"] = language

# Summary of how many documents were detected per language.
print(Counter(languages))

Counter({'en': 131, 'fr': 8})


# Load Retriever, build index

In [6]:
from retrievers.chroma_dpr import ChromaRetriever
from retrievers.colbert import ColbertRetriever
import re

# Name under which the index is stored.
collection_name = "DSA"

# Create the ColbertRetriever and hand it the loaded documents.
# (In the recorded output, this step loads an already existing index.)
retriever = ColbertRetriever()
retriever.build(docs, collection_name)

Loading index from .ragatouille/colbert/indexes/DSA...
[Feb 02, 10:33:22] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




In [7]:
# Sanity-check the index with a single hand-written question, using k=5.
results = retriever.query(query=["How do I get activity logs from the Twitter API?"], k=5)

Loading searcher for index DSA for the first time... This may take a few seconds
[Feb 02, 10:33:26] #> Loading codec...
[Feb 02, 10:33:26] #> Loading IVF...
[Feb 02, 10:33:26] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Feb 02, 10:33:26] #> Loading doclens...


100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2193.67it/s]

[Feb 02, 10:33:26] #> Loading codes and residuals...



100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 56.22it/s]

[Feb 02, 10:33:26] Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...





[Feb 02, 10:33:26] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




Searcher loaded!

#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . How do I get activity logs from the Twitter API?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101,     1,  2129,  2079,  1045,  2131,  4023, 15664,  2013,  1996,
        10474, 17928,  1029,   102,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])



1it [00:00, 107.98it/s]


# Load Generator

### Quantized Mistral 7B, finetuned on code instructions
- https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-16k-GGUF#provided-files

In [10]:
from IPython.display import display, Markdown

def print_chat(chat_log):
    """Render a chat transcript as Markdown, one display call per message.

    Args:
        chat_log: iterable of mappings, each providing a 'role' and a 'content' key.

    Entries whose role is 'system' are skipped. Every other entry is shown as the
    capitalized role in bold, followed by the message content on the next line.
    """
    # Lazy filter: system entries never reach the display step.
    visible_messages = (message for message in chat_log if message['role'] != 'system')
    for message in visible_messages:
        speaker = message['role'].capitalize()
        body = message['content']
        display(Markdown(f"**{speaker}:** \n{body}\n"))

In [11]:
from generators.mistral import MistralRAGGenerator

# Instantiate the project's MistralRAGGenerator with no arguments
# (the recorded output of this cell shows files being fetched at this step).
rag_generator = MistralRAGGenerator()

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [15]:
# Value forwarded as the retriever's `k` argument
# (presumably the number of hits to return -- confirm against ColbertRetriever.query).
k = 5

# Retrieve for the example query list defined earlier in the notebook.
results = retriever.query(query=query, k=k)

# Keep only the 'content' field of each hit; this is what is passed to the generator.
context = [hit['content'] for hit in results]

1it [00:00, 136.74it/s]


In [None]:
# Run the generator on the retrieved context and the query list.
# The result is used as a sequence of chat logs by the following cells.
chat_logs = rag_generator.generate_batch(context, query)

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Display the last chat log of the batch (print_chat hides 'system' entries).
print_chat(chat_logs[-1])

In [None]:
# For evaluation

# The content of the last message in each chat log is taken as the generated answer.
answer = [c[-1]["content"] for c in chat_logs]

# Print each query next to its answer, with matching "Q:" / "A:" labels
# (the answer label was previously missing its colon).
for q, a in zip(query, answer):
    print("Q:", q)
    print("A:", a)