# Step 1: Install libraries


In [4]:
! pip install -qU pymongo datasets langchain tiktoken sentence_transformers tqdm torch

# Step 2: Setup prerequisites

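Store your **MongoDB connection string** as a Colab secret named `MONGODB_URI` (key icon in the Colab sidebar) and grant this notebook access to it; the next cell reads it with `userdata.get("MONGODB_URI")`.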

In [5]:
import os
from pymongo import MongoClient

In [None]:
# Read the connection string from the Colab secret "MONGODB_URI" when running in
# Colab; anywhere else, fall back to the MONGODB_URI environment variable.
try:
    from google.colab import userdata

    MONGODB_URI = userdata.get("MONGODB_URI")
except ImportError:
    MONGODB_URI = os.environ["MONGODB_URI"]
# Initialize a MongoDB Python client
mongodb_client = MongoClient(MONGODB_URI, appname="devrel.workshop.rag")
# Check the connection to the server
mongodb_client.admin.command("ping")

{'ok': 1}

# Step 3: Load the dataset


In [7]:
import pandas as pd
from datasets import load_dataset

In [41]:
# Number of records to pull from the streamed dataset
NUM_DOCS = 200
data = load_dataset("durrah/flutter_questions_answers_", split="train", streaming=True)
data_head = data.take(NUM_DOCS)
# Iterating the streamed dataset already yields plain dicts, so materialize it
# directly; a pandas round-trip is unnecessary and can coerce missing values to NaN.
docs = list(data_head)

In [42]:
# Check the number of documents in the dataset
len(docs)

200

In [43]:
# Preview a document
docs[0]

{'answer': 'Flutter is an open-source mobile application development framework created by Google. It is used to build high-performance, visually attractive, and natively compiled mobile applications for both iOS and Android platforms from a single codebase.',
 'question': 'What is Flutter?'}

# Step 4: Chunk up the data


In [8]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import Dict, List

In [45]:
# Separators to split on
# Separators the text splitter tries, in order. The empty string "" matches
# every text (it falls back to splitting into characters), so it must be the
# LAST entry: any separator listed after it is never reached. Add extra
# separators (e.g. markdown headers) before "".
separators = ["\n\n", "\n", " ", ""]

📚 https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/split_by_token/#tiktoken


In [47]:
# Token-aware splitter: chunk lengths are measured with tiktoken's `cl100k_base`
# encoding instead of in characters, using the `separators` list defined above.
# A ~200-token chunk holds roughly 1-2 paragraphs of text, and an overlap of
# 30 tokens (15% of the chunk size) carries context across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base",
    separators=separators,
    chunk_size=200,
    chunk_overlap=30,
)

📚 https://api.python.langchain.com/en/latest/character/langchain_text_splitters.character.RecursiveCharacterTextSplitter.html

In [6]:
def get_chunks(doc: Dict, text_field: str, splitter=None) -> List[Dict]:
    """
    Chunk up a document.

    Each chunk becomes its own document: a shallow copy of `doc` whose
    `text_field` value is replaced by the chunk text, so every other field
    (e.g. the question) is carried over unchanged.

    Args:
        doc (Dict): Parent document to generate chunks from.
        text_field (str): Text field to chunk.
        splitter: Optional object with a `split_text(str) -> List[str]` method.
            Defaults to the notebook-level `text_splitter`.

    Returns:
        List[Dict]: List of chunked documents (may be empty if the splitter
        yields no chunks).
    """
    if splitter is None:
        splitter = text_splitter
    # `split_text` takes a plain string and returns a list of string chunks
    chunks = splitter.split_text(doc[text_field])
    # Shallow copy per chunk: `doc` itself is never modified
    return [{**doc, text_field: chunk} for chunk in chunks]

In [49]:
# Initialize the list of chunked documents (the next cell also resets it, so this cell is optional)
split_docs = []

In [51]:
# Chunk every document in `docs` on its "answer" field with `get_chunks` and
# flatten all chunks into `split_docs`. The list is rebuilt from scratch, so
# re-running this cell never duplicates chunks.
split_docs = [chunk for doc in docs for chunk in get_chunks(doc, "answer")]

In [53]:
# Check the number of chunked documents: it exceeds `len(docs)` only when some
# answers are longer than one chunk (~200 tokens); otherwise it is typically equal
len(split_docs)

200

In [54]:
# Preview one of the items in `split_docs` and ensure that it is a Python dictionary
assert isinstance(split_docs[0], dict), type(split_docs[0])
split_docs[0]

{'answer': 'Flutter is an open-source mobile application development framework created by Google. It is used to build high-performance, visually attractive, and natively compiled mobile applications for both iOS and Android platforms from a single codebase.',
 'question': 'What is Flutter?'}

# Step 5: Generate embeddings


In [9]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [10]:
# Load the `gte-large` model using the Sentence Transformers library
# NOTE: the Atlas vector search index must be configured with this model's
# embedding size (1024 dimensions for gte-large; gte-small would be 384).
embedding_model = SentenceTransformer("thenlper/gte-large")

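📚 https://huggingface.co/thenlper/gte-large (the model loaded above; https://huggingface.co/thenlper/gte-small is a smaller 384-dimension alternative)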

In [11]:
# Embed a piece of text with `embedding_model` and return it as a plain list
def get_embedding(text: str) -> List[float]:
    """
    Embed a piece of text with the notebook-level `embedding_model`.

    Args:
        text (str): Text to embed.

    Returns:
        List[float]: The embedding vector, converted from the array returned by
        `encode` into a plain Python list (e.g. so it can be stored in MongoDB).
    """
    return embedding_model.encode(text).tolist()

In [58]:
# Initialize the list that will hold the documents together with their embeddings
embedded_docs = []

In [59]:
# Build `embedded_docs`: a copy of each dictionary in `split_docs` with an added
# `embedding` field holding the embedding of its "answer" text (via `get_embedding`).
# The list is rebuilt from scratch and `split_docs` is left untouched, so
# re-running this cell neither duplicates entries nor mutates the chunks.
embedded_docs = [
    {**doc, "embedding": get_embedding(doc["answer"])}
    for doc in tqdm(split_docs)
]

100%|██████████| 200/200 [03:44<00:00,  1.12s/it]


In [60]:
# Check that the length of `embedded_docs` is the same as that of `split_docs`
assert len(embedded_docs) == len(split_docs), (len(embedded_docs), len(split_docs))
len(embedded_docs)

200

# Step 6: Ingest data into MongoDB


In [12]:
# Name of the database -- Change if needed or leave as is
DB_NAME = "mongodb_rag_lab"
# Name of the collection -- Change if needed or leave as is
COLLECTION_NAME = "knowledge_base"
# Name of the vector search index -- Change if needed or leave as is
ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"

📚 https://pymongo.readthedocs.io/en/stable/tutorial.html#getting-a-database

📚 https://pymongo.readthedocs.io/en/stable/tutorial.html#getting-a-collection

In [13]:
# Connect to the collection defined above using the `mongodb_client` defined in Step 2
collection = mongodb_client[DB_NAME][COLLECTION_NAME]

In [63]:
# Bulk delete all existing records from the collection defined above
collection.delete_many({})

DeleteResult({'n': 202, 'electionId': ObjectId('7fffffff0000000000000115'), 'opTime': {'ts': Timestamp(1732898895, 42), 't': 277}, 'ok': 1.0, '$clusterTime': {'clusterTime': Timestamp(1732898895, 42), 'signature': {'hash': b'\xcc\xaf\xd1\xca\xae\x04\x1e\r\x04;\x9ef\xd9\xf5/B5\x1e0\x9f', 'keyId': 7390364571417968669}}, 'operationTime': Timestamp(1732898895, 42)}, acknowledged=True)

📚 https://pymongo.readthedocs.io/en/stable/examples/bulk.html#bulk-insert


In [64]:
# Bulk insert `embedded_docs` into the collection defined above.
# `insert_many` adds an `_id` to every dict it is given, so pass shallow copies
# to keep `embedded_docs` free of server-side ids. It also rejects an empty
# list, so fail early with a clearer message.
if not embedded_docs:
    raise ValueError("`embedded_docs` is empty -- run the embedding step first")
collection.insert_many([dict(doc) for doc in embedded_docs])

print("Data ingestion into MongoDB completed")

Data ingestion into MongoDB completed


# Step 7: Create a vector search index

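Follow the instructions in the documentation to create a Vector Search index in the Atlas UI. Create it on the `knowledge_base` collection of the `mongodb_rag_lab` database and name it `vector_index`. Index the `embedding` field as a `vector` field with 1024 dimensions (the output size of `thenlper/gte-large`) and, for example, `cosine` similarity.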


# Step 8: Perform semantic search on your data


### Define a vector search function

📚 https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields

📚 https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples (Refer to the "Basic Example")


In [14]:
# Define a function to retrieve relevant documents for a user query using vector search
def vector_search(user_query: str, limit: int = 5, num_candidates: int = 150) -> List[Dict]:
    """
    Retrieve relevant documents for a user query using Atlas Vector Search.

    Args:
        user_query (str): The user's query string.
        limit (int): Maximum number of documents to return. Defaults to 5.
        num_candidates (int): Number of nearest-neighbor candidates considered
            before keeping the top `limit`. Defaults to 150; must be >= `limit`.

    Returns:
        List[Dict]: Matching documents, each containing only the `answer` field
        and its vector search `score`.

    Raises:
        ValueError: If `limit` < 1 or `num_candidates` < `limit`.
    """
    if limit < 1 or num_candidates < limit:
        raise ValueError(
            f"Expected 1 <= limit <= num_candidates, got limit={limit}, num_candidates={num_candidates}"
        )

    # Embed the query with the same function used to embed the stored documents
    query_embedding = get_embedding(user_query)

    # $vectorSearch finds the nearest neighbours; $project keeps only the answer
    # text and the similarity score (dropping `_id` and the stored embedding)
    agg_pipeline = [
        {
            "$vectorSearch": {
                "index": ATLAS_VECTOR_SEARCH_INDEX_NAME,
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": num_candidates,
                "limit": limit,
            }
        },
        {
            "$project": {
                "_id": 0,
                "answer": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]

    return list(collection.aggregate(agg_pipeline))

### Run vector search queries


In [77]:
vector_search("What is Flutter?")

[{'answer': 'Flutter is an open-source mobile application development framework created by Google. It is used to build high-performance, high-fidelity, mobile apps for both Android and iOS platforms from a single codebase.',
  'score': 0.9542860984802246},
 {'answer': 'Flutter Dart Animation is the process of creating smooth, fluid, and visually appealing animations in a Flutter application, by leveraging the built-in animation widgets and APIs, as well as custom animation techniques and libraries.',
  'score': 0.9531720876693726},
 {'answer': 'Flutter is an open-source mobile application development framework created by Google. It is used to build high-performance, visually attractive, and natively compiled mobile applications for both iOS and Android platforms from a single codebase.',
  'score': 0.9514681100845337},
 {'answer': 'Flutter Dart Animations, Gestures, User Interactions, and Motion Design is the system that allows Flutter applications to create smooth, fluid, and visually

In [78]:
vector_search("Explain stateless and stateful in Flutter")

[{'answer': "Flutter's widget system is based on the concept of stateful and stateless widgets. Stateless widgets are immutable and do not maintain any internal state, while stateful widgets can manage their own internal state and respond to changes dynamically. Understanding the differences between these two types of widgets is crucial for building efficient and responsive Flutter applications that can handle user interactions and state changes effectively.",
  'score': 0.9704997539520264},
 {'answer': "Flutter's widget system is based on the concept of stateful and stateless widgets. Stateless widgets are immutable and do not maintain any internal state, while stateful widgets can manage their own internal state and respond to changes dynamically.",
  'score': 0.9669656753540039},
 {'answer': 'Stateless widgets are immutable and do not change their appearance during runtime, while Stateful widgets can change their appearance based on user interactions or other events.',
  'score': 0.

# 🦹‍♀️ Combine pre-filtering with vector search

📚 https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#about-the-filter-type

📚 https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples (Refer to the "Filter Example")


### Filter for documents where the content type is `Video`

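Modify the vector search index definition (from Step 7) in the Atlas UI to include the `metadata.contentType` field as a `filter` field. Note: the Flutter Q&A documents ingested above contain only `question`, `answer` and `embedding` fields, so this exercise requires a dataset that includes a `metadata.contentType` field.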

# Step 9: Build the RAG application


### Instantiate a chat model


In [23]:
# Initializing the client and model string
# model = "google/gemma-2-27b"
model = "llama-duo/gemma7b-summarize-gemini1_5flash-128k"

In [2]:
! pip install -q -U google-generativeai

In [18]:
import google.generativeai as genai
from google.colab import userdata

gemini_api = userdata.get("GEMINI_API")

genai.configure(api_key=gemini_api)

### Define a function to create the chat prompt

In [19]:
# Define a function to create the user prompt for our RAG application
def create_prompt(user_query: str) -> str:
    """
    Create a chat prompt that includes the user query and retrieved context.

    Args:
        user_query (str): The user's query string.

    Returns:
        str: The chat prompt string.
    """
    # Retrieve the most relevant documents for the `user_query` using `vector_search` (Step 8)
    retrieved_docs = vector_search(user_query)
    # Join the answers into a single string separated by blank lines; a missing or
    # null `answer` contributes "" instead of making `join` fail on None
    context = "\n\n".join(doc.get("answer") or "" for doc in retrieved_docs)
    # Prompt consisting of the question and relevant context to answer it
    prompt = f"Answer the question based only on the following context. If the context is empty, say I DON'T KNOW\n\nContext:\n{context}\n\nQuestion:{user_query}"
    return prompt

### Define a function to answer user queries

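📚 https://docs.fireworks.ai/guides/querying-text-models#chat-completions-api (reference for the Fireworks variant of this lab; the code below calls the Gemini API via `google.generativeai`, see https://ai.google.dev/gemini-api/docs/text-generation)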

In [22]:
import torch
from transformers import pipeline


# Define a function to answer user queries using the Gemini API
def generate_answer(user_query: str, model_name: str = "gemini-1.5-flash") -> None:
    """
    Generate an answer to the user query with Gemini and print it.

    Args:
        user_query (str): The user's query string.
        model_name (str): Gemini model to use. Defaults to "gemini-1.5-flash".
    """
    # Build the RAG prompt (retrieved context + question) with `create_prompt`
    prompt = create_prompt(user_query)

    # Use a local name distinct from the notebook-level `model` string
    llm = genai.GenerativeModel(model_name)
    response = llm.generate_content(prompt)

    print(response.text)


### Query the RAG application


In [23]:
generate_answer("What is Flutter?")

Flutter is an open-source mobile application development framework created by Google.  It is used to build high-performance, high-fidelity, visually attractive, and natively compiled mobile apps for both Android and iOS platforms from a single codebase.



In [24]:
generate_answer("Explain Scaffold?")

I DON'T KNOW



# 🦹‍♀️ Re-rank retrieved results


In [None]:
from sentence_transformers import CrossEncoder

In [None]:
rerank_model = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")

config.json:   0%|          | 0.00/968 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/142M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.65M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/970 [00:00<?, ?B/s]

📚 https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1

In [None]:
# Add a re-ranking step to the following function
def create_prompt(user_query: str) -> str:
    """
    Create a chat prompt that includes the user query and re-ranked retrieved context.

    Documents are first retrieved with vector search, then re-ordered by the
    cross-encoder `rerank_model` according to their relevance to the query.

    Args:
        user_query (str): The user's query string.

    Returns:
        str: The chat prompt string.
    """
    # Retrieve candidate documents using `vector_search` (Step 8)
    retrieved_docs = vector_search(user_query)
    # Extract the "answer" text from each document (None/missing -> "")
    documents = [doc.get("answer") or "" for doc in retrieved_docs]
    if documents:
        # With `return_documents=True`, `rank` returns dicts with "corpus_id",
        # "score" and "text" keys, sorted by score. The document text is under
        # "text", not "answer".
        reranked_documents = rerank_model.rank(
            user_query, documents, return_documents=True, top_k=5
        )
        context = "\n\n".join(d["text"] for d in reranked_documents)
    else:
        # Nothing retrieved: leave the context empty so the prompt's
        # I DON'T KNOW instruction applies
        context = ""
    # Prompt consisting of the question and relevant context to answer it
    prompt = f"Answer the question based only on the following context. If the context is empty, say I DON'T KNOW\n\nContext:\n{context}\n\nQuestion:{user_query}"
    return prompt

In [None]:
# Note the impact of re-ranking on the generated answer
generate_answer("What are triggers in MongoDB Atlas?")

According to the context, triggers in MongoDB Atlas are a feature that allows you to create a trigger that monitors all changes in a certain collection (in this case, the `test` collection) for `insert`, `update`, and `delete` operations.


# Step 10: Add memory to the RAG application


In [25]:
from datetime import datetime

In [26]:
history_collection = mongodb_client[DB_NAME]["chat_history"]

📚 https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.create_index


In [45]:
# Create an index on the key `session_id` for the `history_collection` collection
history_collection.create_index("session_id")

'session_id_1'

### Define a function to store chat messages in MongoDB

📚 https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_one

In [46]:
def store_chat_message(session_id: str, role: str, content: str) -> None:
    """
    Store a chat message in the `history_collection` MongoDB collection.

    The message text is stored under the "parts" key (the key used when the
    history is replayed to Gemini), together with the session ID, the role and
    a UTC timestamp used to order the session's messages.

    Args:
        session_id (str): Session ID of the message.
        role (str): Role for the message. One of `system`, `user` or `assistant`.
        content (str): Content of the message.
    """
    from datetime import timezone

    message = {
        "session_id": session_id,
        "role": role,
        "parts": content,
        # PyMongo stores naive datetimes as if they were UTC, so record an explicit
        # UTC time rather than the machine's local time from `datetime.now()`
        "timestamp": datetime.now(timezone.utc),
    }
    history_collection.insert_one(message)



### Define a function to retrieve chat history from MongoDB

📚 https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find

📚 https://pymongo.readthedocs.io/en/stable/api/pymongo/cursor.html#pymongo.cursor.Cursor.sort

In [47]:
def retrieve_session_history(session_id: str) -> List:
    """
    Retrieve chat message history for a particular session, oldest first.

    Args:
        session_id (str): Session ID to retrieve chat message history for.

    Returns:
        List: Chat messages formatted as {"role": <role>, "parts": <text>};
        empty if the session has no stored messages.
    """
    # Sort by `timestamp`, breaking ties on `_id`. BSON dates only have millisecond
    # precision, and a turn's user/assistant messages are stored back-to-back, so
    # their timestamps can be equal. ObjectIds created by one client process
    # increase with insertion order.
    cursor = history_collection.find({"session_id": session_id}).sort(
        [("timestamp", 1), ("_id", 1)]
    )
    # A pymongo Cursor defines no __bool__/__len__, so it is truthy even when
    # nothing matches; iterating an empty cursor simply yields an empty list.
    return [{"role": msg["role"], "parts": msg["parts"]} for msg in cursor]

### Handle chat history in the `generate_answer` function

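📚 https://docs.fireworks.ai/guides/querying-text-models#chat-completions-api (reference for the Fireworks variant of this lab; the code below uses Gemini multi-turn chat via `google.generativeai`, see https://ai.google.dev/gemini-api/docs/text-generation)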


In [55]:
def generate_answer(session_id: str, user_query: str) -> None:
    """
    Generate an answer to the user's query taking chat history into account.

    Relevant documents are retrieved with `vector_search` and passed to Gemini
    as a system instruction. The session's earlier turns are replayed as chat
    history. The user query and the generated answer are then stored with
    `store_chat_message`, and the answer is printed.

    Args:
        session_id (str): Session ID to retrieve chat history for.
        user_query (str): The user's query string.
    """
    # Retrieve documents relevant to the user query and join them into a single string
    context = "\n\n".join(
        (doc.get("answer") or "") for doc in vector_search(user_query)
    )

    # Gemini chat history only accepts the roles "user" and "model", while
    # generated answers are stored with the role "assistant". Map the roles and
    # skip anything else (e.g. "system" messages).
    gemini_roles = {"user": "user", "assistant": "model", "model": "model"}
    history = [
        {"role": gemini_roles[msg["role"]], "parts": msg["parts"]}
        for msg in retrieve_session_history(session_id)
        if msg["role"] in gemini_roles
    ]

    # The retrieved context goes into the system instruction, not the history
    llm = genai.GenerativeModel(
        "gemini-1.5-pro",
        system_instruction=[
            f"Answer the question based only on the following context. If the context is empty, say I DON'T KNOW\n\nContext:\n{context}"
        ],
    )

    # The current query is sent with `send_message`. It must not also be appended
    # to `history`, or the model would receive the same question twice.
    chat = llm.start_chat(history=history)
    answer = chat.send_message(user_query).text

    # Persist this turn: "user" for the query, "assistant" for the generated answer
    store_chat_message(session_id, "user", user_query)
    store_chat_message(session_id, "assistant", answer)

    print(answer)

In [52]:
generate_answer(
    session_id="2",
    user_query="How to work with Flutter Button",
)

response:
GenerateContentResponse(
    done=True,
    iterator=None,
    result=protos.GenerateContentResponse({
      "candidates": [
        {
          "content": {
            "parts": [
              {
                "text": "This document does not contain the answer to how to work with Flutter Buttons. While it mentions that buttons are UI elements included in the Flutter Widgets library, it does not provide any implementation details or examples of how to use them.\n"
              }
            ],
            "role": "model"
          },
          "finish_reason": "STOP",
          "avg_logprobs": -0.22818650370058807
        }
      ],
      "usage_metadata": {
        "prompt_token_count": 252,
        "candidates_token_count": 46,
        "total_token_count": 298
      }
    }),
)
This document does not contain the answer to how to work with Flutter Buttons. While it mentions that buttons are UI elements included in the Flutter Widgets library, it does not provide any imple

In [54]:
generate_answer(
    session_id="2",
    user_query="Flutter difference towards other framework",
)

response:
GenerateContentResponse(
    done=True,
    iterator=None,
    result=protos.GenerateContentResponse({
      "candidates": [
        {
          "content": {
            "parts": [
              {
                "text": "Flutter differs from other cross-platform frameworks in several key ways:\n\n* **Rendering:** Unlike frameworks that use WebViews, Flutter uses its own rendering engine and widgets. This leads to faster and more efficient performance.\n* **Native Compilation:** Flutter compiles to native ARM code, avoiding the need for a bridge to communicate with the native platform. This results in better performance and a more native look and feel.\n* **Cross-Platform Reach:** Flutter supports building applications for mobile (iOS and Android), web, and desktop platforms from a single codebase.  This allows developers to create unified user experiences across multiple platforms.\n"
              }
            ],
            "role": "model"
          },
          "finish_r

In [56]:
generate_answer(
    session_id="2",
    user_query="What did I just ask?",
)

You just asked about the differences between Flutter and other frameworks.

