In [None]:
!pip -qq install chromadb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m399.0/399.0 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.6/62.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.4/58.4 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.5/59.5 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m54.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m57.5 MB/s[0m eta [

# The following cells create a database and a few helper functions.


In [None]:
import random
import time
import uuid

import chromadb
from chromadb import Settings
from chromadb.utils import embedding_functions

## This code is setting up a Chroma database and preparing it for use.

In [None]:
# On-disk location of the database files, and the name of the one collection
# this notebook reads from and writes to.
db_directory = "my_db_directory"
collection_name = "persisted_collection"

# Client used to talk to the database; files live under `db_directory`.
client = chromadb.Client(
    Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=db_directory,
    )
)

# Embedding function applied to both stored documents and incoming queries.
ef = embedding_functions.DefaultEmbeddingFunction()


## These functions are for searching, adding to, and resetting the db

In [None]:
def query_db(query_text, n_results=5):
    """Search the persisted collection for the documents closest to a query.

    Args:
        query_text: Text to embed and look up.
        n_results: Maximum number of matches to request (defaults to 5).

    Returns:
        Whatever ``collection.query`` returns for the query embedding, or
        ``None`` when the collection cannot be loaded (a notice is printed).
    """
    try:
        collection = client.get_collection(collection_name)
    except Exception as exc:
        # Only ordinary errors are handled here, so a kernel interrupt still
        # propagates; the reason is reported instead of being discarded.
        print(f"There was an issue loading the collection: {exc}")
        return None

    query_embedding = ef([query_text])
    return collection.query(
        query_embeddings=query_embedding,
        n_results=n_results,
    )

def save_messages_to_db(messages, duplicate_tolerance=1e-6):
    """Embed each message and store it unless an identical one already exists.

    A message counts as already stored when the nearest existing entry lies
    within ``duplicate_tolerance`` of it; distances are floats, so an exact
    comparison with zero can miss true duplicates.

    Args:
        messages: Iterable of message strings to store.
        duplicate_tolerance: Largest nearest-neighbour distance still treated
            as "same message".

    Side effects:
        Adds documents to the collection, prints one status line per message,
        and calls ``client.persist()`` once at the end.
    """
    collection = client.get_or_create_collection(name=collection_name)

    for message in messages:
        embedding = ef([message])  # computed once, reused for lookup and insert

        try:
            result = collection.query(query_embeddings=embedding, n_results=1)
            nearest = (result or {}).get("distances") or []
            # No neighbours (e.g. an empty collection) simply means "not a duplicate".
            closest = nearest[0][0] if nearest and nearest[0] else None
        except Exception as e:
            # Best effort: a failed lookup must not prevent the message being stored.
            print(f"An error occurred while querying the database: {e}. Adding message '{message}' to the database.")
            closest = None

        if closest is not None and abs(closest) <= duplicate_tolerance:
            print(f"Message '{message}' is already in database. Skipped.")
        else:
            add_message_to_collection(collection, embedding, message)

    client.persist()  # flush the changes to disk

def add_message_to_collection(collection, embedding, message):
    """Store one message and its embedding in ``collection`` under a fresh ID.

    Args:
        collection: Object exposing ``add(embeddings=..., documents=..., ids=...)``.
        embedding: Embedding for ``message``, passed through unchanged.
        message: The message text to store.
    """
    # uuid4 keeps IDs unique even when many messages are added within the same
    # second, where a timestamp plus a small random number can repeat.
    message_id = f"id{uuid.uuid4().hex}"
    collection.add(
        embeddings=embedding,
        documents=[message],
        ids=[message_id],
    )
    print(f"Message '{message}' added to database.")

def reset_db(client):
    """Delete all data in the database by calling ``client.reset()``.

    Args:
        client: The database client to reset.

    Prints "db reset" once the call returns.
    """
    client.reset()
    print("db reset")

# Save some messages to the database as embeddings

In [None]:
# Sample messages to embed and store
messages = [
    "Hello, world!",
    "How are you?",
    "Goodbye!",
]
save_messages_to_db(messages)

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 64.2MiB/s]


An error occurred while querying the database: list index out of range. Adding message 'Hello, world!' to the database.
Message 'Hello, world!' added to database.
Message 'How are you?' added to database.




Message 'Goodbye!' added to database.


# Query the database
A lower distance means a closer match. The next cell does the same thing but prints the results in an easier-to-read format.

In [None]:
# Define the text to query
query_text = "Hello"

# Search the database with the `query_db` helper defined earlier in this notebook.
# It returns a dictionary describing the matching documents, or None if the
# collection could not be loaded.
results = query_db(query_text)

# Print the query text and the raw results
# The results include a list of document IDs, the text of the documents, and the "distance" from the query text for each document
# The "distance" is a measure of how closely the document matches the query text, with a lower value indicating a closer match
print(f"Query results for '{query_text}': {results}")




Query results for 'Hello': {'ids': [['id1689442926429066', 'id1689442926436832', 'id1689442926995715']], 'embeddings': None, 'documents': [['Hello, world!', 'Goodbye!', 'How are you?']], 'metadatas': [[None, None, None]], 'distances': [[0.6042104363441467, 1.148119568824768, 1.214174509048462]]}


# Same as above but with a nicer output structure 😎

In [None]:
# Define the text to query
query_text = "Hello"

# Search the database with the `query_db` helper defined earlier in this notebook.
# It returns a dictionary describing the matching documents, or None if the
# collection could not be loaded.
results = query_db(query_text)

# Each field holds one inner list per query; a single query was sent, so take
# index 0. When nothing could be loaded, fall back to empty lists instead of
# failing on None.
if results is None:
    documents, ids, distances = [], [], []
else:
    documents = results['documents'][0]
    ids = results['ids'][0]
    distances = results['distances'][0]

# Print the query text
print(f"Query results for '{query_text}':\n")

# Print each match, numbered from 1
for position, (doc_id, text, distance) in enumerate(zip(ids, documents, distances), start=1):
    print(f"Document {position}:")
    print(f"\tID: {doc_id}")
    print(f"\tText: {text}")
    print(f"\tDistance: {distance}\n")




Query results for 'Hello':

Document 1:
	ID: id1689442926429066
	Text: Hello, world!
	Distance: 0.6042104363441467

Document 2:
	ID: id1689442926436832
	Text: Goodbye!
	Distance: 1.148119568824768

Document 3:
	ID: id1689442926995715
	Text: How are you?
	Distance: 1.214174509048462



# Wipe the database

In [None]:
# Clears the data saved by the cells above; re-run the "save" cell to repopulate it.
reset_db(client)

db reset


# BONUS Summarization 🐵

In [None]:
!pip -qq install sentencepiece
!pip -qq install transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m83.5 MB/s[0m eta [36m0:00:00[0m
[?25h

## Change the string below to the text you want to summarize

In [None]:
# Text that the summarization cells below will condense; replace it with your own.
input_text = "The FitnessGram™ Pacer Test is a multistage aerobic capacity test that progressively gets more difficult as it continues. The 20 meter pacer test will begin in 30 seconds. Line up at the start. The running speed starts slowly, but gets faster each minute after you hear this signal. [beep] A single lap should be completed each time you hear this sound. [ding] Remember to run in a straight line, and run as long as possible. The second time you fail to complete a lap before the sound, your test is over. The test will begin on the word start. On your mark, get ready, start."
# Display the text as the cell's output
input_text

'The FitnessGram™ Pacer Test is a multistage aerobic capacity test that progressively gets more difficult as it continues. The 20 meter pacer test will begin in 30 seconds. Line up at the start. The running speed starts slowly, but gets faster each minute after you hear this signal. [beep] A single lap should be completed each time you hear this sound. [ding] Remember to run in a straight line, and run as long as possible. The second time you fail to complete a lap before the sound, your test is over. The test will begin on the word start. On your mark, get ready, start.'

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import unicodedata

## Tweak the params below to fit your needs

In [None]:
# Defaults used by `local_summarize` when the caller passes no params.
DEFAULT_SUMMARIZE_PARAMS = {
    # Generation settings that `summarize` forwards to the model's `generate`
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "max_length": 150,
    "min_length": 50,
    "length_penalty": 1.5,
    # Strings the generated summary must not contain
    "bad_words": [
        "\n",
        '"',
        "*",
        "[",
        "]",
        "{",
        "}",
        ":",
        "(",
        ")",
        "<",
        ">",
        "Â",
        "The text ends",
        "The story ends",
        "The text is",
        "The story is",
    ],
}

## Run this cell to set up the functions

In [None]:
# Run on the GPU when one is available; otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and pre-trained summarization model from Hugging Face (same checkpoint for both)
summarization_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
summarization_transformer = summarization_transformer.to(device)

def _split_near_middle(text: str) -> tuple:
    """Split ``text`` in two at the whitespace nearest its midpoint, or at the midpoint if it has none."""
    middle = len(text) // 2
    for offset in range(middle + 1):
        for index in (middle - offset, middle + offset):
            # 0 < index < len(text) keeps both pieces strictly shorter than text
            if 0 < index < len(text) and text[index].isspace():
                return text[:index], text[index:]
    return text[:middle], text[middle:]


def summarize_chunks(text: str, params: dict) -> str:
    """Summarize ``text``, splitting it recursively when it is too long.

    ``summarize`` is tried on the whole text first. If it raises ``IndexError``
    the text is split in two near its middle (on whitespace where possible, so
    words stay intact), each part is summarized with halved length limits, and
    the partial summaries are joined with a single space.

    Args:
        text: The text to summarize.
        params: Summarization parameters; ``max_length`` and ``min_length``
            are halved (on a copy) at each level of splitting.

    Returns:
        The summary, or the space-joined summaries of the parts.

    Raises:
        IndexError: If ``summarize`` still fails on a text that cannot be
            split any further (fewer than two characters).
    """
    try:
        return summarize(text, params)
    except IndexError:
        if len(text) < 2:
            # Nothing left to split: re-raise instead of recursing forever.
            raise
        print("Sequence length too large for model, cutting text in half and calling again")
        new_params = params.copy()
        # Keep max_length at least 1 so repeated halving never asks for a zero-length summary.
        new_params["max_length"] = max(1, int(new_params["max_length"]) // 2)
        new_params["min_length"] = int(new_params["min_length"]) // 2

        halves = [half.strip() for half in _split_near_middle(text)]
        summaries = [summarize_chunks(half, new_params) for half in halves if half]
        # Join with a space so the last sentence of one part does not run into the next.
        return " ".join(part for part in summaries if part)

def summarize(text: str, params: dict) -> str:
    """Generate a summary of ``text`` with the loaded summarization model.

    Args:
        text: The text to summarize.
        params: Dict with ``max_length``, ``min_length``, ``repetition_penalty``,
            ``temperature``, ``length_penalty`` and ``bad_words`` (a list of
            strings that must not appear in the summary).

    Returns:
        The decoded summary after ``normalize_string`` has been applied.

    Raises:
        IndexError: If the tokenized text is longer than the model's
            ``max_position_embeddings``. ``summarize_chunks`` catches this to
            split the text and retry.
    """
    # Tokenize the input text; with return_tensors="pt" and a single text,
    # input_ids has one row whose length is the token count.
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    token_count = inputs["input_ids"].shape[1]

    # Fail early and explicitly rather than depending on the model's embedding
    # lookup to raise, which is not guaranteed to surface as IndexError on
    # every device.
    max_positions = getattr(summarization_transformer.config, "max_position_embeddings", None)
    if max_positions is not None and token_count > max_positions:
        raise IndexError(
            f"Input has {token_count} tokens but the model accepts at most {max_positions}."
        )

    # Token-id sequences for the words we don't want in the summary
    bad_words_ids = [
        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in params["bad_words"]
    ]

    # NOTE(review): max() lets the cap grow to the input's token count whenever
    # that exceeds params["max_length"] - confirm this is intended rather than min().
    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        max_length=max(token_count, int(params["max_length"])),
        min_length=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        bad_words_ids=bad_words_ids,
    )

    # Decode the first (only) generated sequence, then normalize it
    summary = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return normalize_string(summary)

def normalize_string(input: str) -> str:
    """Return ``input`` NFKC-normalized, with each run of whitespace collapsed to one space.

    Leading and trailing whitespace is dropped.
    """
    canonical = unicodedata.normalize("NFKC", input)
    # split() with no arguments discards leading/trailing whitespace and
    # treats any run of whitespace as a single separator.
    words = canonical.split()
    return " ".join(words)

def local_summarize(text, params=None):
    """Summarize ``text`` via ``summarize_chunks``, printing the input and the result.

    Args:
        text: The text to summarize.
        params: Summarization parameters; when omitted, a copy of
            ``DEFAULT_SUMMARIZE_PARAMS`` is used.

    Returns:
        The summary produced by ``summarize_chunks``.
    """
    effective_params = DEFAULT_SUMMARIZE_PARAMS.copy() if params is None else params

    print("Summary input:", text, sep="\n")
    result = summarize_chunks(text, effective_params)
    print("Summary output:", result, sep="\n")
    return result


## Finally execute the summarization

In [None]:
# Go through summarize_chunks so that text too long for the model is split and
# retried instead of failing; for short text the result is the same as calling
# summarize directly.
summary = summarize_chunks(input_text, DEFAULT_SUMMARIZE_PARAMS.copy())
summary

'The 20 meter pacer test will begin in 30 seconds. Line up at the start. The running speed starts slowly, but gets faster each minute after you hear this signal. A single lap should be completed each time you hear the signal. Remember to run in a straight line, and run as long as possible.'