In [None]:
import os

import dotenv
import bs4
import requests
import re
import yaml
import chromadb
import dspy
from google.oauth2 import service_account
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from dspy.retrieve.chromadb_rm import ChromadbRM
from dsp.modules import GoogleVertexAI
from chromadb.utils import embedding_functions
from tqdm.notebook import tqdm
from redis import Redis


In [None]:
# !gcloud auth login
# !gcloud auth application-default login

## Crawler

In [None]:
# Page to index, and a slug built from its last path segment with every
# non-word character removed.
URL = "https://en.wikipedia.org/wiki/Go_(game)"
URL_LOCAL = re.sub(r'\W+', '', URL.rsplit("/", 1)[-1])

In [None]:
def _heading_id(h2):
    """Return the section id carried by an <h2>, or None when it has none.

    The id is looked up on the first <span> nested in the heading and, failing
    that, on the <h2> tag itself, so both heading layouts are supported.
    """
    span = h2.span
    raw_id = span.get("id") if span is not None else None
    if not raw_id:
        raw_id = h2.get("id")
    return raw_id.strip() if raw_id else None


def _section_paragraphs(h2):
    """Return the stripped text of every <p> between this <h2> and the next one."""
    # Newer MediaWiki markup wraps the heading in <div class="mw-heading ...">;
    # the section's paragraphs are then siblings of the wrapper, not of the <h2>.
    wrapper = h2.parent
    if wrapper is not None and "mw-heading" in (wrapper.get("class") or []):
        anchor = wrapper
    else:
        anchor = h2

    texts = []
    for node in anchor.find_next_siblings(True):
        is_wrapped_h2 = node.name == "div" and "mw-heading2" in (node.get("class") or [])
        if node.name == "h2" or is_wrapped_h2:
            break  # reached the next top-level section
        if node.name == "p":
            texts.append(node.text.strip())
    return texts


# Fail fast on network problems or HTTP errors instead of parsing an error page.
response = requests.get(url=URL, timeout=30)
response.raise_for_status()
soup = bs4.BeautifulSoup(response.content, "html.parser")

# parsed maps section id -> {global paragraph number: paragraph text}.
parsed = {}
p_counter = 0
all_titles = soup.find_all("h2")[1:]  # the first <h2> is skipped
for title in all_titles:
    header = _heading_id(title)
    if header is None:
        continue  # heading without an id cannot be used as a key
    textContent = {}
    for paragraph_text in _section_paragraphs(title):
        textContent[p_counter] = paragraph_text
        p_counter += 1
    if textContent:
        parsed[header] = textContent


## Vectorize

In [None]:
# Directory of the persisted Chroma database, and the per-page collection name
# derived from the URL slug.
CHROMADB_DIR = "../db/"
CHROMA_COLLECTION_NAME = f"wiki_{URL_LOCAL}"

In [None]:
# Splitter (library defaults) used to cut each paragraph into chunks.
text_splitter = SentenceTransformersTokenTextSplitter()

# Chroma client rooted at CHROMADB_DIR; the collection is fetched if it already
# exists and created otherwise.
chroma_client = chromadb.PersistentClient(path=CHROMADB_DIR)
collection = chroma_client.get_or_create_collection(name=CHROMA_COLLECTION_NAME)

In [None]:
# Last paragraph id of the last parsed section (0 when nothing was parsed).
# Kept as a handy denominator for ad-hoc progress reporting.
last_section = list(parsed.values())[-1] if parsed else {}
num_paragraphs = list(last_section)[-1] if last_section else 0

for header, paragraphs in tqdm(parsed.items(), desc="sections"):
    for paragraph_id, text in tqdm(paragraphs.items(), desc=header, leave=False):
        # Split the paragraph into chunks (create_documents takes a list of texts).
        chunks = text_splitter.create_documents([text])
        if not chunks:
            continue  # nothing to store for an empty paragraph
        # Ids are derived from paragraph id and chunk number, so re-running this
        # cell upserts the same ids again.
        collection.upsert(
            ids=[f"pid_{paragraph_id}#{chunk_no}" for chunk_no in range(len(chunks))],
            documents=[chunk.page_content for chunk in chunks],
            metadatas=[{"title": header, "source": URL} for _ in chunks],
        )



### Test retriever

In [None]:
def Retriever(collection, db_dir, k=3):
    """
    Build a Chroma-backed retriever for a persisted collection.

    This is just a retriever and does not involve any language model.

    Args:
        collection: Name of the Chroma collection to query.
        db_dir: Directory of the persisted Chroma database.
        k: Number of results requested from the retriever (defaults to 3).

    Returns:
        A ``ChromadbRM`` instance; call it with a query string to retrieve.
    """
    default_ef = embedding_functions.DefaultEmbeddingFunction()
    return ChromadbRM(collection, db_dir, default_ef, k=k)

In [None]:
# Query the vector store on its own, without a language model.
question = "What is GO?"
retriever = Retriever(CHROMA_COLLECTION_NAME, CHROMADB_DIR)
retrieved = retriever(question)
print("vector store:", retrieved)

## Test LM

In [None]:
# Read the assistant's YAML configuration and display it.
# NOTE(review): FullLoader is more permissive than the safe loader; if the
# config only holds plain data, yaml.safe_load would be the stricter choice.
with open("../ai_assistant/config.yaml") as config_file:
    cfg = yaml.load(config_file, Loader=yaml.FullLoader)
cfg

In [None]:
# Load development settings from .env_dev. Despite its name, GOOGLE_API_KEY is
# used as the *path* to a service-account file (it is passed to
# from_service_account_file below).
dotenv.load_dotenv(dotenv.find_dotenv(".env_dev"))
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Fail with an actionable message instead of an opaque error from the
    # credentials loader when the variable is missing or empty.
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; it must hold the path to a service-account "
        "file (define it in .env_dev or in the environment)."
    )
credentials = service_account.Credentials.from_service_account_file(api_key)

In [None]:
# Vertex AI Gemini client authenticated with the service-account credentials.
# NOTE(review): the project id and region are hard-coded here; consider moving
# them to configuration.
vertex_settings = {
    "model_name": "gemini-1.0-pro-002",
    "project": "deft-weaver-396616",
    "location": "us-central1",
    "credentials": credentials,
}
gemini = GoogleVertexAI(**vertex_settings)

# Register Gemini and the generation settings in dspy's global settings.
dspy.settings.configure(lm=gemini, temperature=0.25, max_tokens=1024)

In [None]:
class AdvisorSignature(dspy.Signature):
    # Signature with two inputs (context, question) and one output (answer).
    # NOTE(review): no class docstring is added on purpose; dspy presumably uses a
    # Signature's docstring as prompt instructions, which would change behaviour. Confirm.
    context = dspy.InputField(format=str) # format: function to call on the input to make it a string
    question = dspy.InputField() # no custom formatter is given for the question
    answer = dspy.OutputField()

In [None]:
class ZeroShot(dspy.Module):
    """
    Baseline without retrieval: answers the question directly.
    """

    def __init__(self):
        super().__init__()
        # String signature: one `question` input mapped to one `answer` output.
        self.prog = dspy.Predict("question -> answer")

    def forward(self, question):
        """Run the predictor on `question` and return its prediction."""
        prediction = self.prog(question=question)
        return prediction

class wiki_assistant(dspy.Module):
    """
    Retrieval-augmented assistant: retrieves context from the vector store,
    then answers with a chain-of-thought predictor.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): n=3 presumably requests three completions, while forward
        # only reads `.answer` from the prediction. Confirm this is intended.
        self.prog = dspy.ChainOfThought(AdvisorSignature, n=3)

    def forward(self, question, retriver_collection, database_loc):
        """
        Answer `question` using context retrieved from a Chroma collection.

        Args:
            question: The question to answer; also used as the retrieval query.
            retriver_collection: Name of the Chroma collection to search.
            database_loc: Directory of the persisted Chroma database.

        Returns:
            A ``dspy.Prediction`` with the retrieved `context` and the `answer`.
        """
        # A fresh retriever is built on every call.
        retriever = Retriever(retriver_collection, database_loc)
        retrieved = retriever(question)
        prediction = self.prog(context=retrieved, question=question)
        return dspy.Prediction(context=retrieved, answer=prediction.answer)

In [None]:
# Question put to both the plain and the retrieval-augmented assistant.
QUESTION: str = "Which among the three is more difficult, chess, backgammon or GO?"

In [None]:
# Baseline: the language model answers without any retrieved context.
zs_assistant = ZeroShot()
response = zs_assistant(question=QUESTION)
print(f"The answer of the non RAG agent is: \n {response.answer}")

print("\n")

# RAG: the same question, answered with context retrieved from the wiki collection.
cot_assistant = wiki_assistant()
response = cot_assistant(
    question=QUESTION,
    retriver_collection=CHROMA_COLLECTION_NAME,
    database_loc=CHROMADB_DIR,
)
print(f"The answer of the RAG agent is: \n {response.answer}")

In [None]:
# gemini.inspect_history(n=3)

## Cache

In [None]:
# Connection settings for a Redis instance on the local machine.
redis_host = "127.0.0.1"
redis_port = 6379
r = Redis(host=redis_host, port=redis_port, decode_responses=True)

In [None]:
# Write a throwaway key and read it back to check the connection.
r.set(name="foo", value="bar")
r.get(name="foo")

### TODO:
- check if given url exists in Redis
  - If yes: Fetch path of persisted vector db
  - If no or update=True: scrape the wiki page, preprocess it, persist it to the vector db, and persist the vdb object to storage