In [None]:
import os
import tempfile

import ads
from ads.llm.deploy import ChainDeployment
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import OCIGenAIEmbeddings
from langchain_community.chat_models import ChatOCIGenAI
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Configure ADS to authenticate with the resource principal.
ads.set_auth(auth="resource_principal")

### Load Documents
## Load Text
# Source text for the RAG index, relative to the notebook's working directory.
SOURCE_TEXT_PATH = "txt/mgs5_cassette_tapes.txt"

loader = TextLoader(SOURCE_TEXT_PATH)
# Read the file into a list of LangChain Document objects.
documents = loader.load()

## Split Text
# Chunking parameters, both measured with `len` (i.e. in characters):
#   CHUNK_SIZE    - maximum size of each chunk.
#   CHUNK_OVERLAP - characters shared between consecutive chunks, to preserve context.
CHUNK_SIZE = 100
CHUNK_OVERLAP = 20
# Set to True to print every chunk instead of only the first one.
SHOW_ALL_CHUNKS = False

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,  # Use character length for chunk size
    is_separator_regex=False,  # Treat separators literally
)

# Split the loaded documents
split_docs = text_splitter.split_documents(documents)

# Report how many documents went in and how many chunks came out.
print(f"Number of original documents: {len(documents)}")
print(f"Number of split chunks: {len(split_docs)}\n")

# Without this guard an empty source would surface as a bare IndexError on split_docs[0].
if not split_docs:
    raise ValueError(
        "Text splitting produced no chunks; check that the source document is non-empty."
    )

if SHOW_ALL_CHUNKS:
    for i, chunk in enumerate(split_docs, start=1):
        print(f"Chunk {i}:\n{chunk.page_content}\n---")
print(f"split_docs[0].page_content: {split_docs[0].page_content}\n")
print(f"split_docs[0].metadata: {split_docs[0].metadata}\n")


In [None]:
### Connect to OCI embeddings and generative AI
# Connection settings shared by both OCI Generative AI clients. The compartment can be
# overridden with the GENAI_COMPARTMENT_OCID environment variable so the notebook can
# target another compartment without editing code. It defaults to the original value.
GENAI_SERVICE_ENDPOINT = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
GENAI_COMPARTMENT_ID = os.environ.get(
    "GENAI_COMPARTMENT_OCID",
    "ocid1.compartment.oc1..aaaaaaaa52sp42nqmtwwzzvmp5mmldri26razhrbyw7cvixmims7p5crsg7a",
)

# NOTE(review): ads.set_auth() above configures ADS only; these LangChain clients
# take their own `auth_type` argument, which is not passed here. Confirm that
# the intended auth (e.g. resource principal for deployment) is what they use.
oci_embeddings = OCIGenAIEmbeddings(
    model_id="cohere.embed-english-light-v3.0",
    service_endpoint=GENAI_SERVICE_ENDPOINT,
    compartment_id=GENAI_COMPARTMENT_ID,
)

oci_chat = ChatOCIGenAI(
    model_id="cohere.command-a-03-2025",
    service_endpoint=GENAI_SERVICE_ENDPOINT,
    compartment_id=GENAI_COMPARTMENT_ID,
    model_kwargs={"temperature": 0.7, "max_tokens": 500},
)

print("Connected to OCI embeddings and generative AI.")

In [None]:
### Build Vector Store (FAISS)
# Embed every split chunk with the OCI embedding model and index the resulting
# vectors in a FAISS store for similarity search.
vectorstore = FAISS.from_documents(
    documents=split_docs,
    embedding=oci_embeddings,
)
print("FAISS vector store built.")

In [None]:
### Build Chain (OCI chat-based Retrieval QA)
# Retrieve the 3 most similar chunks for each question. VectorStoreRetriever reads
# its search parameters from `search_kwargs`. The previous
# `retriever_kwargs={"similarity_top_k": 3}` is not a VectorStoreRetriever option
# (`similarity_top_k` is LlamaIndex naming), so k=3 never took effect.
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)

# Alternative prompts kept for experimentation:
#rag_prompt_template = """Answer the question based only on the following context:
#{context}
#Question: {question}
#"""

#rag_prompt_template = """Try to answer the question using only the following context, but, if that fails, use your general knowledge.  You don't have to mention whether you did or did not use the context.
#{context}
#Question: {question}
#"""

# Active prompt: prefer the retrieved context, falling back to the model's general
# knowledge. {context} receives the retrieved chunks and {question} the user query.
rag_prompt_template = """Try to answer the question using only the following context, but, if that fails, use your general knowledge.
{context}
Question: {question}
"""

rag_prompt = PromptTemplate.from_template(rag_prompt_template)

rag = RetrievalQA.from_chain_type(
    llm=oci_chat,
    chain_type="stuff",  # the default, spelled out because the prompt relies on {context}
    retriever=retriever,
    chain_type_kwargs={"prompt": rag_prompt},
)

# Variant that also returns the retrieved chunks alongside the answer:
#rag = RetrievalQA.from_chain_type(
#    llm=oci_chat,
#    chain_type="stuff",
#    retriever=retriever,
#    return_source_documents=True,
#    verbose=False,
#)

print("RAG chain built.")

In [None]:
### Invoke the chain (smoke test)
# Run a fixed set of questions through the chain and print whatever it returns for each.
smoke_test_questions = [
    "Who is Kazuhira?",
    "Are Venom Snake and Revolver Ocelot friends?",
    "What is the name of Anderson's AI?",
    "What is Zero's fear once the Cold War is over?",
    "How many cassette tapes are there in Metal Gear Solid 5?",
]
for question in smoke_test_questions:
    print(rag.invoke(question))

In [None]:
#### Use ADS to deploy the unit-tested chain
### Create the ADS deployment object
#artifact_dir = tempfile.mkdtemp()
#
#ads_deployment = ChainDeployment(
#    chain=rag,
#    artifact_dir=artifact_dir,
#    force_overwrite=True
#)
#
#ads_deployment.summary_status()

In [None]:
## Prepare the ADS deployment
#ads_deployment.prepare(
#    inference_conda_env="automlx251_p311_cpu_x86_64_v2",
#    inference_python_version="3.11",
#    force_overwrite=True,
#)
#
## Summarize the checkpoint ADS workflow status
#ads_deployment.summary_status()

In [None]:
## Verify the ADS model
#ads_deployment.verify("Who wants to cook hamburgers once the conflict is over?")
#
## Summarize the checkpoint ADS workflow status
#ads_deployment.summary_status()

In [None]:
## Save the ADS model
#model_id = ads_deployment.save()
#
## Summarize the checkpoint ADS workflow status
#ads_deployment.summary_status()