# RAG-on-GKE Application

This is a Python notebook for generating the vector embeddings based on [Kubernetes docs](https://github.com/dohsimpson/kubernetes-doc-pdf/) used by the RAG on GKE application.   
For full information, please check out the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).


## Clone the kubernetes docs repo

In [None]:
!mkdir /data/kubernetes-docs -p
!git clone https://github.com/dohsimpson/kubernetes-doc-pdf /data/kubernetes-docs


## Install the required packages

In [None]:
!pip install langchain langchain-community sentence_transformers pypdf

## Writing the job to be used on the Ray Cluster

In [None]:
# Create a directory to package the contents that need to be downloaded in ray worker
! mkdir -p rag-app

In [None]:
%%writefile rag-app/job.py

import os
import uuid
import glob

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader

from google.cloud.sql.connector import Connector, IPTypes
import sqlalchemy

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Text, text, JSON
from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column
from pgvector.sqlalchemy import Vector

# initialize parameters

# Cloud SQL instance to connect to; the job fails fast with KeyError when the
# environment variable is not set.
INSTANCE_CONNECTION_NAME = os.environ["CLOUDSQL_INSTANCE_CONNECTION_NAME"]
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")
VECTOR_EMBEDDINGS_TABLE_NAME = "rag_embeddings_db"
DB_NAME = "pgvector-database"

# Database credentials are read from files under /etc/secret-volume.
# `with` guarantees each file handle is closed even if read() raises.
with open("/etc/secret-volume/username", "r") as db_username_file:
    DB_USER = db_username_file.read()

with open("/etc/secret-volume/password", "r") as db_password_file:
    DB_PASS = db_password_file.read()

# initialize Connector object
connector = Connector()

# function to return the database connection object
def getconn():
    """Open and return a new pg8000 connection to the Cloud SQL instance.

    Passed as the ``creator`` callback to ``sqlalchemy.create_engine`` in this
    script, so the engine calls it whenever it needs a fresh connection.
    """
    connection_settings = {
        "user": DB_USER,
        "password": DB_PASS,
        "db": DB_NAME,
        "ip_type": IPTypes.PRIVATE,
    }
    return connector.connect(INSTANCE_CONNECTION_NAME, "pg8000", **connection_settings)

# create connection pool with 'creator' argument to our connection object function
# The URL carries no host or credentials; every connection is produced by
# getconn through the `creator` argument.
pool = sqlalchemy.create_engine(
    "postgresql+pg8000://",
    creator=getconn,
)

# Declarative base class that the ORM model in this script inherits from.
Base = declarative_base()
# Session registry created without a bind; it is attached to the engine
# further down via DBSession.configure(bind=pool, ...).
DBSession = scoped_session(sessionmaker())

class TextEmbedding(Base):
    """ORM model for the table named by ``VECTOR_EMBEDDINGS_TABLE_NAME``.

    The attribute names declared here are the keys that the mappings passed to
    ``DBSession.bulk_insert_mappings(TextEmbedding, ...)`` must use.
    """
    __tablename__ = VECTOR_EMBEDDINGS_TABLE_NAME
    # String primary key (up to 255 characters).
    id = Column(String(255), primary_key=True)
    # Text column.
    text = Column(Text)
    # pgvector column with 384 dimensions; the same value is declared as
    # VECTOR_DIMENSION further down in this script.
    text_embedding = mapped_column(Vector(384))

# Enable the pgvector extension (IF NOT EXISTS makes this safe to repeat)
# before the table that uses the Vector column type is created.
with pool.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit() 
    
# Attach the session registry to the engine.
DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)
# Drop and then recreate every table registered on Base: each run of this job
# starts from an empty embeddings table and previously stored rows are lost.
Base.metadata.drop_all(pool)
Base.metadata.create_all(pool)

SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small"  # Transformer to use for converting text chunks to vector embeddings

# The dataset has been pre-downloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.
SHARED_DATASET_BASE_PATH = "/data/kubernetes-docs/"

CHUNK_SIZE = 1000  # text chunk sizes which will be converted to vector embeddings (measured with len, i.e. in characters)
CHUNK_OVERLAP = 10  # overlap between consecutive chunks, in the same unit as CHUNK_SIZE
# NOTE(review): VECTOR_DIMENSION is not referenced anywhere else in this
# script; the TextEmbedding model hard-codes the same value (384).
VECTOR_DIMENSION = 384  # Embeddings size

# Shared splitter and embedding model, created once and reused by process_pdf
# and by the verification query at the end of the script.
splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)
embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL)

def process_pdf(file_path):
    """Load, split and embed a single PDF file.

    Args:
        file_path: Path of the PDF file to ingest.

    Returns:
        A list of dicts, one per text chunk, keyed by the ``TextEmbedding``
        attribute names (``id``, ``text``, ``text_embedding``) so the list can
        be passed directly to ``DBSession.bulk_insert_mappings``.
    """
    loader = PyPDFLoader(file_path)
    print(f"Loading {file_path}")
    pages = loader.load_and_split()

    splits = splitter.split_documents(pages)

    # The keys must match the mapped attribute names of TextEmbedding;
    # bulk_insert_mappings does not translate other key names, so the previous
    # "langchain_id"/"content"/"embedding" keys never reached the table.
    # The primary key is a String column, hence the UUID is stringified.
    # split.metadata is not stored because the table has no column for it.
    return [
        {
            "id": str(uuid.uuid4()),
            "text": split.page_content,
            "text_embedding": embeddings_service.embed_query(split.page_content),
        }
        for split in splits
    ]

# Embed every *.pdf found in the PDFs directory of the shared dataset and
# stage the resulting rows in the session, one bulk insert per file.
# NOTE(review): the dict keys returned by process_pdf must match the
# TextEmbedding attribute names for the rows to be populated.
documents_file_paths = glob.glob(f"{SHARED_DATASET_BASE_PATH}/PDFs/*.pdf")
for file_path in documents_file_paths:
    processed_result = process_pdf(file_path)
    DBSession.bulk_insert_mappings(TextEmbedding, processed_result)
        
# The commit is issued once, after all files have been processed.
DBSession.commit()

# Verifying the results: embed a sample question and print the five stored
# chunks with the highest cosine similarity to it.

query_text = "What's kubernetes?"
# embed_query returns a plain Python list of floats (LangChain Embeddings
# interface), so it must not be converted with .tolist().
query_emb = embeddings_service.embed_query(query_text)
# pgvector text literal for the query vector, e.g. "[0.1,0.2,...]".
query_vector = "[" + ",".join(map(str, query_emb)) + "]"
# `<=>` is pgvector's cosine distance, so 1 - distance is the cosine similarity.
# The table name comes from the VECTOR_EMBEDDINGS_TABLE_NAME constant that the
# TextEmbedding model is bound to (TABLE_NAME is not defined in this script).
query_request = (
    "SELECT id, text, text_embedding, "
    f"1 - ('{query_vector}' <=> text_embedding) AS cosine_similarity "
    f"FROM {VECTOR_EMBEDDINGS_TABLE_NAME} "
    "ORDER BY cosine_similarity DESC LIMIT 5;"
)
query_results = DBSession.execute(sqlalchemy.text(query_request)).fetchall()
DBSession.commit()
print("print query_results, the 1st one is the hit")
for row in query_results:
    print(row)

print ("end job")

## Submitting the job to the Ray Cluster:

In [None]:
import ray, time
from ray.job_submission import JobSubmissionClient
# Client used by the next cell to submit and monitor the job; it is pointed at
# the `ray-cluster-kuberay-head-svc` host through a ray:// address on port 10001.
client = JobSubmissionClient("ray://ray-cluster-kuberay-head-svc:10001")

In [None]:
# Port forward to the Ray dashboard and go to `localhost:8265` in a browser to see job status: kubectl port-forward -n <namespace> service/ray-cluster-kuberay-head-svc 8265:8265
import time

start_time = time.time()

# Environment shipped with the job: the packaged job directory plus the pip
# packages to install for it.
job_runtime_env = {
    # Path to the local directory that contains the entrypoint file.
    "working_dir": "/home/jovyan/rag-app",  # upload the local working directory to ray workers
    "pip": [
        "langchain",
        "langchain-community",
        "sentence-transformers",
        "pypdf",
        "pgvector",
    ],
}
job_id = client.submit_job(entrypoint="python job.py", runtime_env=job_runtime_env)

# The Ray job typically takes 5m-10m to complete.
print("Job submitted with ID:", job_id)

# Poll every 30 seconds, reporting status and info each time, until the job
# reaches a terminal state.
finished = False
while not finished:
    status = client.get_job_status(job_id)
    print("Job status:", status)
    print("Job info:", client.get_job_info(job_id).message)
    finished = status.is_terminal()
    if not finished:
        time.sleep(30)

end_time = time.time()
job_duration = end_time - start_time
print(f"Job  completed in {job_duration} seconds.")

ray.shutdown()