[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/rag/rag_mongodb_llama3_huggingface_open_source.ipynb)

# Implementing RAG pipelines with MongoDB and Llama3 and Open models From Hugging Face

This notebook is designed to demonstrate how to integrate and utilize Hugging Face's open-source models, specifically Llama3, with MongoDB to implement Retrieval-Augmented Generation (RAG) pipelines for enhanced question answering capabilities.

The process involves preparing a dataset of arXiv papers, transforming their data for effective retrieval, setting up a MongoDB database with vector search capabilities, and using llama3 model for generating answers based on the retrieved documents.

Key Highlights:
- Usage of Hugging Face open-source models and MongoDB for creating RAG pipelines.
- Steps include dataset preparation, database setup, data ingestion, and query processing.
- Detailed guidance on setting up MongoDB collections and vector search indexes.
- Integration with the Llama3 model from Hugging Face for answering complex queries.

Follow the following instruction to set up a MongoDB database and enable vector search:
1. [Register a free Atlas account](https://account.mongodb.com/account/register?utm_campaign=devrel&utm_source=community&utm_medium=cta&utm_content=GitHub%20Cookbook&utm_term=richmond.alake)
 or sign in to your existing Atlas account.

2. [Follow the instructions](https://www.mongodb.com/docs/atlas/tutorial/deploy-free-tier-cluster/)
 (select Atlas UI as the procedure) to deploy your first cluster, which distributes your data across multiple servers for improved performance and redundancy.

 ![image.png](attachment:image.png)

3. For a free Cluser, be sure to select "Shared" option when creating your new cluster. See image below for details

![image-2.png](attachment:image-2.png)

4. Create the database: `knowledge_base`, and collection `research_papers`



## Import Libraries

Import libaries into development environment

In [2]:
!pip install datasets pandas pymongo sentence_transformers
!pip install -U transformers
# Install the library below if using GPU, if using CPU, please comment out below
!pip install accelerate

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting pymongo
  Downloading pymongo-4.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting dnspython<3.0.0,>=1.16.0 (from pymongo)
  Downloading dnspython-2.7.0-py3-none-any.whl.metadata (5.8 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m10.2 MB/s[0m eta [36

## Dataset Loading and Preparation

Load the dataset from HuggingFace.

Only using the first 100 datapoint for demo purposes.

In [4]:
# Load Dataset
import os

import pandas as pd
from datasets import load_dataset

# Make sure you have an Hugging Face token(HF_TOKEN) in your development environemnt before running the code below
# How to get a token: https://huggingface.co/docs/hub/en/security-tokens
# Dataset Location: https://huggingface.co/datasets/MongoDB/subset_arxiv_papers_with_embeddings
os.environ["HF_TOKEN"] = (
    "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"  # Do not use this in production environment, use a .env file instead
)

dataset = load_dataset("Xmm/stock-data-dataset", 'MSFT')

# Convert the dataset to a pandas dataframe
dataset_df = pd.DataFrame(dataset["train"])

dataset_df.head(5)

train-00000-of-00001.parquet:   0%|          | 0.00/95.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2517 [00:00<?, ? examples/s]

Unnamed: 0,Date,Adj Close,Volume,Open,High,Low
0,09/06/2024,401.7,19609530,409.06,410.65,400.8
1,09/05/2024,408.39,14195520,407.62,413.1,406.13
2,09/04/2024,408.9,15135810,405.91,411.24,404.37
3,09/03/2024,409.44,20313600,417.91,419.88,407.03
4,08/30/2024,417.14,24308320,415.6,417.49,412.13


In [34]:
# Data Preparation

# # Only use the first 100 for demo and POC purposes
dataset_df = dataset_df.head(100)

# # Remove the embedding from each data point in the dataset as we are going to create new embeddings with an open source embedding model from Hugging Face
dataset_df = dataset_df.drop(columns=["embedding"])
dataset_df.head(5)

Unnamed: 0,Date,Adj Close,Volume,Open,High,Low
0,09/06/2024,401.7,19609530,409.06,410.65,400.8
1,09/05/2024,408.39,14195520,407.62,413.1,406.13
2,09/04/2024,408.9,15135810,405.91,411.24,404.37
3,09/03/2024,409.44,20313600,417.91,419.88,407.03
4,08/30/2024,417.14,24308320,415.6,417.49,412.13


## Generate Embeddings

In [35]:
from sentence_transformers import SentenceTransformer

# https://huggingface.co/thenlper/gte-large
embedding_model = SentenceTransformer("thenlper/gte-large")


def get_embedding(text: str) -> list[float]:
    if not text.strip():
        print("Attempted to get embedding for empty text.")
        return []

    embedding = embedding_model.encode(text)

    return embedding.tolist()


dataset_df["embedding"] = dataset_df.apply(
    lambda x: get_embedding(str(x["Adj Close"]) + " " + str(x["Open"]) + " " + str(x["High"])),
    axis=1,
)

dataset_df.head()

Unnamed: 0,Date,Adj Close,Volume,Open,High,Low,embedding
0,09/06/2024,401.7,19609530,409.06,410.65,400.8,"[0.0006415346870198846, 0.010343987494707108, ..."
1,09/05/2024,408.39,14195520,407.62,413.1,406.13,"[0.0012050988152623177, 0.006467541214078665, ..."
2,09/04/2024,408.9,15135810,405.91,411.24,404.37,"[-0.010366330854594707, 0.02119181677699089, -..."
3,09/03/2024,409.44,20313600,417.91,419.88,407.03,"[0.0059565408155322075, 0.014039739966392517, ..."
4,08/30/2024,417.14,24308320,415.6,417.49,412.13,"[0.007199700456112623, 0.016408847644925117, 0..."


## Database and Collection Setup

Complete the steps below if not already carried out previously.
Creating a database and collection within MongoDB is made simple with MongoDB Atlas.

1. [Register a free Atlas account](https://account.mongodb.com/account/register?utm_campaign=devrel&utm_source=community&utm_medium=cta&utm_content=GitHub%20Cookbook&utm_term=richmond.alake)
 or sign in to your existing Atlas account.

2. [Follow the instructions](https://www.mongodb.com/docs/atlas/tutorial/deploy-free-tier-cluster/)
 (select Atlas UI as the procedure)  to deploy your first cluster.

3. Create the database: `knowledge_base`.
4. Within the database` knowledge_base`, create the following collections: `research_papers`
5. Create a
[vector search index](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure)
 named `vector_index` for the `research_papers` collection. This index enables the RAG application to retrieve records as additional context to supplement user queries via vector search. Below is the JSON definition of the data collection vector search index.


 Below is a snipper of what the vector search index definition should looks like:
 ```
    {
      "fields": [
        {
          "numDimensions": 1024,
          "path": "embedding",
          "similarity": "cosine",
          "type": "vector"
        }
      ]
    }
  ```

In [49]:
import pymongo


def get_mongo_client(mongo_uri):
    """Establish connection to the MongoDB."""
    try:
        client = pymongo.MongoClient(mongo_uri, appname="devrel.content.python",tls=True, tlsAllowInvalidCertificates=True)
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None


# Ensure connection strings are placed securely in environment variables and not disclosed in production environments.
mongo_uri = "mongodb+srv://tharunkumargofficial:Cpj4qoCLPe8WJzSF@stockdata.n8xo4.mongodb.net/"  # Placeholder, replace with your connection string or actual environment variable fetching method.

if not mongo_uri:
    print("MONGO_URI not set in environment variables.")

mongo_client = get_mongo_client(mongo_uri)

# Ingest data into MongoDB
db = mongo_client["sample_mflix"]
collection = db["theaters"]






Connection to MongoDB successful


In [50]:
# Delete any existing records in the collection
collection.delete_many({})

ServerSelectionTimeoutError: SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30s, Topology Description: <TopologyDescription id: 677150a1c4cc89eb0f93eb17, topology_type: ReplicaSetNoPrimary, servers: [<ServerDescription ('stockdata-shard-00-00.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-01.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-02.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]>

## Data Ingestion

In [41]:

documents = dataset_df.to_dict("records")
collection.insert_many(documents)

print("Data ingestion into MongoDB completed")

ServerSelectionTimeoutError: SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30s, Topology Description: <TopologyDescription id: 67714c2ac4cc89eb0f93eab0, topology_type: ReplicaSetNoPrimary, servers: [<ServerDescription ('stockdata-shard-00-00.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-01.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-02.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]>

## Vector Search

In [51]:
def vector_search(user_query, collection):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
    user_query (str): The user's query string.
    collection (MongoCollection): The MongoDB collection to search.

    Returns:
    list: A list of matching documents.
    """

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search pipeline
    vector_search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            "numCandidates": 150,  # Number of candidate matches to consider
            "limit": 4,  # Return top 4 matches
        }
    }

    unset_stage = {
        "$unset": "embedding"  # Exclude the 'embedding' field from the results
    }

    project_stage = {
        "$project": {
            "_id": 0,  # Exclude the _id field
            "fullplot": 1,  # Include the plot field
            "title": 1,  # Include the title field
            "genres": 1,  # Include the genres field
            "score": {
                "$meta": "vectorSearchScore"  # Include the search score
            },
        }
    }

    pipeline = [vector_search_stage, unset_stage, project_stage]

    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)

In [30]:
def get_search_result(query, collection):
    get_knowledge = vector_search(query, collection)

    search_result = ""
    for result in get_knowledge:
        search_result += f"Title: {result.get('title', 'N/A')}, Plot: {result.get('abstract', 'N/A')}\n"

    return search_result

## Handling User Queries

In [52]:
# Conduct query with retrival of sources
query = "Get me papers on Artificial Intelligence?"
source_information = get_search_result(query, collection)
combined_information = f"Query: {query}\nContinue to answer the query by using the Search Results:\n{source_information}."
messages = [
    {"role": "system", "content": "You are a research assitant!"},
    {"role": "user", "content": combined_information},
]
print(messages)

ServerSelectionTimeoutError: SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms),SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30s, Topology Description: <TopologyDescription id: 677150a1c4cc89eb0f93eb17, topology_type: ReplicaSetNoPrimary, servers: [<ServerDescription ('stockdata-shard-00-00.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-00.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-01.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-01.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>, <ServerDescription ('stockdata-shard-00-02.n8xo4.mongodb.net', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('SSL handshake failed: stockdata-shard-00-02.n8xo4.mongodb.net:27017: [SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1007) (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]>

## Loading and Using Llama3

In [43]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# CPU Enabled uncomment below 👇🏽
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# GPU Enabled use below 👇🏽
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct.
403 Client Error. (Request ID: Root=1-67714df8-0ed12c7a09964c8f19f7ab55;b428aa83-efe0-49a0-a62c-f07e5191e448)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct to ask for access.

In [33]:
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
response = outputs[0][input_ids.shape[-1] :]
print(tokenizer.decode(response, skip_special_tokens=True))

NameError: name 'tokenizer' is not defined