# Libraries

In [19]:
%pip install llama-index
%pip install transformers
%pip install torch
%pip install llama-index-llms-groq
%pip install sentence-transformers
%pip install "llama-index-embeddings-huggingface"
%pip install kdbai-client
%pip install llama-index-vector-stores-kdbai
%pip install kdbai_client pandas



In [20]:
import pandas as pd
from typing import List, Dict
from llama_index.core import VectorStoreIndex, ServiceContext, Document
from llama_index.core.node_parser import SentenceSplitter, MarkdownNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.core.llms import ChatMessage
import kdbai_client as kdbai

import time
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Data Loading

In [21]:
def load_data(csv_path: str, text_col: str, metadata_cols: List[str]) -> List[Document]:
    """
    Load a CSV file into a list of LlamaIndex documents, one per row.

    Args:
        csv_path: Path of the CSV file to read.
        text_col: Name of the single column holding the document text.
        metadata_cols: CSV columns to copy into each document's metadata. They are
            paired positionally with the fixed metadata keys 'document_id', 'class',
            'issuing_authority', 'title', 'issue_date', 'reference_number'; passing
            fewer columns fills only the leading keys.

    Returns:
        One Document per CSV row, in file order.

    Raises:
        ValueError: If more metadata columns are given than there are metadata keys.
    """
    metadata_keys = ['document_id', 'class', 'issuing_authority', 'title', 'issue_date', 'reference_number']
    if len(metadata_cols) > len(metadata_keys):
        # Fail up front with a clear message instead of an IndexError on the first row.
        raise ValueError(
            f"Expected at most {len(metadata_keys)} metadata columns, got {len(metadata_cols)}"
        )

    df = pd.read_csv(csv_path)
    # zip stops at the shorter sequence, i.e. at the end of metadata_cols.
    key_column_pairs = list(zip(metadata_keys, metadata_cols))

    documents = []
    for _, row in df.iterrows():
        doc = Document(
            # str() keeps the original behaviour: a missing text cell becomes the string 'nan'.
            text=str(row[text_col]),
            metadata={key: row[col] for key, col in key_column_pairs},
        )
        documents.append(doc)
    return documents

# Location of the dataset CSV. NOTE(review): absolute path that looks like a Colab
# Google Drive mount — it will not resolve on other machines.
DATA_PATH = "/content/drive/MyDrive/Omdena/Regulatory RAG (SL Chapter)/code/model dev/data/2024_11_28 v0_LK_tea_dataset.csv"
# Column holding each document's body text.
text_col = 'markdown_content'
# CSV columns copied into Document.metadata. load_data pairs them by position with its
# own key list, so e.g. 'id' is stored as 'document_id' and 'llama_title' as 'title'.
metadata_cols = ['id', 'class', 'issuing_authority', 'llama_title', 'llama_issue_date', 'llama_reference_number']

all_documents = load_data(DATA_PATH, text_col, metadata_cols)
# Number of documents loaded (one per CSV row).
len(all_documents)

167

In [22]:
# Subset of the corpus whose metadata class is 'circular'.
circulars_docs = [
    document
    for document in all_documents
    if document.metadata['class'] == 'circular'
]
len(circulars_docs)

107

In [23]:
# Distinct issuing authorities found among the circulars.
{circular.metadata['issuing_authority'] for circular in circulars_docs}

{'Tea Board', 'Tea Board Analytical Lab', 'Tea Research Institute'}

In [24]:
# Circulars issued by the Tea Research Institute; both metadata conditions must hold.
tri_circulars_docs = [
    document
    for document in all_documents
    if document.metadata['class'] == 'circular'
    and document.metadata['issuing_authority'] == 'Tea Research Institute'
]
len(tri_circulars_docs)

50

Edge case: when two dates are available, the first date is taken. TODO: confirm how this should be handled.<br>
For example:

```nodes[568].metadata['issue_date']```
> January 1996 and July 2000 (two dates available)

In [25]:
# Collects every parsed timestamp so the date distribution can be inspected afterwards.
date_list = []

def convert_to_datetime64(docs):
  """
  Add a parsed `issue_date_ts` timestamp to the metadata of every document in `docs`.

  The `issue_date` metadata value is expected in "<Month name> <year>" form
  (e.g. "February 2024"). When several dates are present only the first two
  whitespace-separated tokens are used. Missing dates are passed to
  pd.to_datetime unchanged.

  Side effects: mutates each document's metadata in place and appends every
  parsed value to the module-level `date_list`.

  Returns the same `docs` object that was passed in.
  """
  # Iterate the argument, not the global list, so the function works for any collection.
  for doc in tqdm(docs):
    doc_date = doc.metadata['issue_date']
    if not pd.isna(doc_date):
      # pick first date if multiple available
      doc_date = " ".join(str(doc_date).split()[0:2])
    doc.metadata['issue_date_ts'] = pd.to_datetime(doc_date, format="%B %Y")
    date_list.append(doc.metadata['issue_date_ts'])
  return docs

tri_circulars_docs = convert_to_datetime64(tri_circulars_docs)

  0%|          | 0/50 [00:00<?, ?it/s]

In [26]:
# Spot check: raw issue-date string of the first TRI circular.
tri_circulars_docs[0].metadata['issue_date']

'February 2024'

In [27]:
# Spot check: timestamp that convert_to_datetime64 stored for the same document.
tri_circulars_docs[0].metadata['issue_date_ts']

Timestamp('2024-02-01 00:00:00')

In [28]:
# How many documents share each parsed issue date.
pd.Series(date_list).value_counts()

Unnamed: 0,count
2024-02-01,20
2003-09-01,4
2000-07-01,4
2003-03-01,3
2009-05-01,3
2013-06-01,2
1996-01-01,2
2002-10-01,1
2001-02-01,1
2011-01-01,1


# Chunking

In [29]:
# Chunk the TRI circulars into nodes with LlamaIndex's MarkdownNodeParser.
node_parser = MarkdownNodeParser()
nodes = node_parser.get_nodes_from_documents(tri_circulars_docs)
# Total number of chunks produced.
len(nodes)

725

In [30]:
# Whitespace-separated word count of every chunk, summarised to judge chunk sizes.
chunk_word_counts = pd.Series(
    [len(chunk.text.split()) for chunk in nodes]
)
chunk_word_counts.describe()

Unnamed: 0,0
count,725.0
mean,71.195862
std,77.791443
min,2.0
25%,17.0
50%,48.0
75%,96.0
max,833.0


# Embedding Model

In [31]:
def setup_embedding_model(
    model_name: str = 'BAAI/bge-small-en-v1.5',
    cache_folder: str = "/content/drive/MyDrive/Omdena/Regulatory RAG (SL Chapter)/code/model dev/cached_models/",
):
    """
    Setup HuggingFace embedding model

    Args:
        model_name: Model identifier passed to HuggingFaceEmbedding.
        cache_folder: Directory passed to HuggingFaceEmbedding as its cache folder.
            The default is the notebook's original absolute path; override it when
            running outside that environment.

    Returns:
        The configured HuggingFaceEmbedding instance.
    """
    return HuggingFaceEmbedding(
        model_name=model_name,
        # NOTE(review): trust_remote_code=True permits code shipped with the model
        # repository to run — confirm it is actually required for this model.
        trust_remote_code=True,
        cache_folder=cache_folder,
    )

embed_model = setup_embedding_model()

# Groq + KDBAI API Setup

In [32]:
# Colab helper used to read values stored with the notebook (only importable inside Colab).
from google.colab import userdata
# Groq API key, looked up under the name 'GROQ_API_KEY'.
GROQ_API_KEY = userdata.get('GROQ_API_KEY')

In [33]:
def setup_groq_llm():
    """
    Build the Groq chat model used throughout the notebook.

    Reads the module-level GROQ_API_KEY and raises ValueError when it is empty.
    Returns a Groq client for "llama-3.1-8b-instant" with temperature 0.0.
    """
    if not GROQ_API_KEY:
        raise ValueError("Please set GROQ_API_KEY environment variable")

    llm_options = dict(
        api_key=GROQ_API_KEY,
        model="llama-3.1-8b-instant",
        temperature=0.0,
    )
    return Groq(**llm_options)

llm = setup_groq_llm()

# KDBAI API + Session Setup

In [34]:
# KDB.AI credentials read through the Colab `userdata` helper.
KDBAI_API_KEY = userdata.get('KDBAI_API_KEY')
# Instance identifier; setup_kdbai_api appends it to "https://cloud.kdb.ai/instance/".
KDBAI_SESSION_ENDPOINT = userdata.get('KDBAI_SESSION_ENDPOINT')

In [35]:
def setup_kdbai_api():
  """
  Open a KDB.AI session from the module-level endpoint and API key.

  Raises ValueError when KDBAI_SESSION_ENDPOINT or KDBAI_API_KEY is empty
  (the endpoint is checked first). Returns the kdbai.Session object.
  """
  if not KDBAI_SESSION_ENDPOINT:
    raise ValueError("Please set KDBAI_SESSION_ENDPOINT environment variable")
  if not KDBAI_API_KEY:
    raise ValueError("Please set KDBAI_API_KEY environment variable")

  # The stored endpoint value is only the instance part of the cloud URL.
  instance_url = f"https://cloud.kdb.ai/instance/{KDBAI_SESSION_ENDPOINT}"
  return kdbai.Session(endpoint=instance_url, api_key=f"{KDBAI_API_KEY}")

session = setup_kdbai_api()

# KDBAI Vector Store Setup

## Session Database

In [86]:
# List the databases visible to this KDB.AI session.
session.databases()

[KDBAI database "default", KDBAI database "srilanka_tri_circulars"]

In [87]:
# Ensure no database called "srilanka_tri_circulars" exists.
# WARNING: this drops the database on every run of the cell.
try:
    session.database("srilanka_tri_circulars").drop()
except kdbai.KDBAIException:
    # Best effort: a failing drop is ignored (presumably the database did not exist yet).
    pass

# Create the database
db = session.create_database("srilanka_tri_circulars")
# Show the databases again to confirm the new one is listed.
session.databases()

[KDBAI database "default", KDBAI database "srilanka_tri_circulars"]

## Table Schema + Creation

In [88]:
# List all of the tables in the db
# List the tables in the freshly created database (empty in the recorded run).
db.tables

[]

In [89]:
# Table - name & schema
# Table - name & schema
table_name = "rag_baseline"

# One row per chunk: an id, the chunk text, its embedding vector, and the parsed issue
# date ('issue_date_ts' is the metadata key added by convert_to_datetime64).
table_schema = [
        dict(name="document_id", type="bytes"),
        dict(name="text", type="bytes"),
        dict(name="embeddings", type="float32s"),
        dict(name="issue_date_ts", type="datetime64[ns]"),
    ]

# Flat index over the `embeddings` column.
# NOTE(review): 'dims' presumably has to equal the embedding model's vector size —
# confirm it matches the model chosen in setup_embedding_model.
indexFlat = {
        "name": "flat_index",
        "type": "flat",
        "column": "embeddings",
        "params": {'dims': 384, 'metric': 'CS'} # For similarity metric, choose from Euclidean Distance (L2), Dot Product (IP), or Cosine Similarity (CS).
    }

In [90]:
# First ensure the table does not already exist
# Drop any existing table with the configured name so the cell can be re-run.
# `table_name` is used for both the drop and the create, so the two cannot drift apart.
try:
    db.table(table_name).drop()
except kdbai.KDBAIException:
    # Best effort: a failing drop is ignored (presumably the table did not exist yet).
    pass

# Create table
table = db.create_table(table_name, table_schema, indexes=[indexFlat])
db.tables

[KDBAI table "rag_baseline"]

In [91]:
# Inspect the indexes defined on the new table.
table.indexes

[{'name': 'flat_index',
  'type': 'flat',
  'column': 'embeddings',
  'params': {'metric': 'CS', 'dims': 384}}]

## Insert Data into Tables

In [92]:
from llama_index.vector_stores.kdbai import KDBAIVectorStore
from llama_index.core import StorageContext, Settings
from llama_index.core.indices import VectorStoreIndex

In [93]:
# Register the Groq LLM and the HuggingFace embedding model on LlamaIndex's global Settings.
Settings.llm = llm
Settings.embed_model = embed_model

In [94]:
%%time

# Vector Store
# NOTE(review): index_name "circular_baseline_index" differs from the index created on
# the table ("flat_index") — confirm whether KDBAIVectorStore uses this argument at all.
vector_store = KDBAIVectorStore(
    table=table,
    index_name="circular_baseline_index"
    )

# Build the index straight from the documents: they are chunked here with a new
# MarkdownNodeParser and written to the KDB.AI table through the storage context.
# The `nodes` list computed in the Chunking section is not reused.
# This is the slow step of the notebook (about 5 minutes wall time in the recorded run).
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
    tri_circulars_docs,
    storage_context=storage_context,
    transformations=[MarkdownNodeParser()]
)

CPU times: user 4min 43s, sys: 8.86 s, total: 4min 52s
Wall time: 4min 56s


In [95]:
# Read the table back to verify the inserted rows.
# NOTE(review): called without arguments this displays the whole table, embeddings included.
table.query()

Unnamed: 0,document_id,text,embeddings,issue_date_ts
0,b'246000ca-0ef6-4a5f-8d10-fc1f6a131e29',b'# ADVISORY CIRCULAR',"[-0.007562597, -0.026492892, 0.03072106, 0.018...",2024-02-01
1,b'bb5a71bf-1eb5-4a79-8465-6315702da723',b'# No.DM JHL 925VynvT\r\n\r\nIssued in: Febru...,"[-0.0121435, -0.023042029, 0.02277239, 0.01008...",2024-02-01
2,b'cd46e8a7-e016-4029-a670-86a30cc883ff',b'# PROTECTION OF TEA FROM BLISTER BLIGHT\r\n\...,"[-0.009545566, -0.013665088, 0.028926225, 0.02...",2024-02-01
3,b'7bb0aa01-3dc6-4dcd-8e73-c11ed0da1153',b'# 1. Introduction\r\n\r\nBlister blight dise...,"[0.01621599, -0.022643894, 0.051311307, 0.0504...",2024-02-01
4,b'3d9dca48-c5fd-4c06-9c76-ca231703105f',b'# 2. Disease Management\r\n\r\nIntegrated di...,"[-0.011557567, -0.017213918, 0.047385782, 0.03...",2024-02-01
...,...,...,...,...
720,b'372ff8c2-6842-4f30-952d-da1d591098c2',b'# 3.4 Cultural Ecological Weed Control Metho...,"[0.03683722, 0.023794655, 0.018897794, 0.05579...",2024-02-01
721,b'b91b23fb-9703-4c84-8095-977e69c07ea9',b'# 3.5 Manual Weeding\r\n\r\nManual weeding c...,"[-0.008825304, -0.06506193, 0.017658412, 0.038...",2024-02-01
722,b'a3f9c792-31a7-4742-a965-681b4fd6e957',b'# 3.6 Mechanical Weeding\r\n\r\nSlash weedin...,"[-0.0036664123, -0.043851368, 0.032723363, 0.0...",2024-02-01
723,b'6d9aefc1-b63a-4ee4-8eb2-21bf282e3804',b'# 3.7 Chemical Weed Control\r\n\r\nChemical ...,"[0.021501746, -0.0641808, 0.016451132, 0.04455...",2024-02-01


## Setting up Query Engine

In [96]:
%%time

# Using llama-3.1-8b-instant, the 128k tokens context size can take 100 pages.
# Number of most-similar chunks retrieved for each query.
K = 15

# query_engine = index.as_query_engine(llm=llm)

# "index" names the table index to search; "flat_index" is the one created with the table.
# NOTE(review): the commented-out filter/sort options refer to a "publication_date"
# column, but the table schema defines "issue_date_ts" — adjust before enabling them.
query_engine = index.as_query_engine(
    similarity_top_k=K,
    llm=llm,
    vector_store_kwargs={
        "index": "flat_index"#,
        # "filter": [["<", "publication_date", pd.to_datetime("")]],
        # "sort_columns": "publication_date",
    },
)

CPU times: user 463 µs, sys: 0 ns, total: 463 µs
Wall time: 472 µs


## Querying Vector Store with Questions

In [97]:
%%time

# Example question sent through the query engine.
input_query = "Were there any major circulars for special export duties changes?"

# NOTE(review): in the recorded run this call failed with
# "TypeError: Client.__init__() got an unexpected keyword argument 'proxies'",
# so no answer was produced — the notebook does not currently run end to end.
result = query_engine.query(input_query)
print(result.response)

TypeError: Client.__init__() got an unexpected keyword argument 'proxies'

In [None]:
# table.drop()