In [62]:
import os
import torch
import uuid
import chromadb
import pandas as pd
from torch import cuda, bfloat16
from dotenv import find_dotenv, load_dotenv
from langchain.document_loaders import CSVLoader

In [63]:
# Setting up environment variables

load_dotenv(find_dotenv())
HF_KEY = os.environ['HUGGINGFACE_API_KEY']

In [4]:
device = f'cuda: {cuda.current_device()}' if cuda.is_available() else 'cpu'
device

'cpu'

### Connecting to Chroma DB Server

In [5]:
# Connecting to Chroma DB server through HTTP client
client = chromadb.HttpClient(host="localhost", port=8000)
client.list_collections()

[Collection(name=test_collection), Collection(name=priceHistory_collection)]

In [11]:
# Creating a new collection
pH_collection = client.create_collection(name="priceHistory_collection")
pH_collection.peek()

{'ids': [],
 'embeddings': [],
 'metadatas': [],
 'documents': [],
 'data': None,
 'uris': None}

### Loading CSV Files and embedding

In [74]:
# Embedding model
import chromadb.utils.embedding_functions as embedding_functions

# Embedding function
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
    api_key=HF_KEY,
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

huggingface_ef_1 = embedding_functions.HuggingFaceEmbeddingFunction(
    api_key=HF_KEY,
    model_name="sentence-transformers/all-mpnet-base-v2"
)

In [76]:
def extract_file_info(csv_pth):
    # Splitting the file path by '/'
    parts = csv_pth.split('/')

    # Extracting required components
    sector = parts[1]
    stock = parts[2]
    start_date = parts[-1].split('_to_')[0]
    last_date = parts[-1].split('_to_')[1].split('.csv')[0]
    
    return sector, stock, start_date, last_date

In [78]:
def process_csv(stock_pth):
    pH_dir = "Price_History"
    csvS_pth = os.path.join(stock_pth, pH_dir)
    pH_collection = client.get_or_create_collection(name='mpNet-PriceHistory_collection',
                                                    embedding_function=huggingface_ef_1)
    
    docs = []
    for csv_file in os.listdir(csvS_pth):
        if csv_file.endswith('.csv'):
            csv_pth = os.path.join(csvS_pth, csv_file)
            csv_load = CSVLoader(f"{csvS_pth}/{csv_file}", encoding="windows-1252")
            csv_data = csv_load.load()
            # print(csv_data)

            sector, stock, start_dt, last_dt = extract_file_info(csv_pth)
            for row in csv_data:
                # S.N.: 3051, Date: 2010-11-09, Open: 294.00, High: 300.00, Low: 299.00, Ltp: 300.00, % Change: 0.00, Qty: 110.00, Turnover: 0.0

                # Splitting the string by \n and then extracting key-value pairs
                pairs = [pair.strip().split('\n') for pair in row.page_content.split('\n')]
                # print(pairs)
                data_dict = {}
                for pair in pairs:
                    for item in pair:
                        key, value = item.split(': ', 1)
                        data_dict[key.strip()] = value.strip()

                # print(data_dict)
                data = f"Opening Price of {stock} on {data_dict['Date']} was {data_dict['Open']}, with a high of {data_dict['High']}, a low of {data_dict['Low']}, and a last traded price (LTP) of {data_dict['Ltp']}. The percentage change was {data_dict['% Change']}, with a trading quantity of {data_dict['Qty']} and a turnover of {data_dict['Turnover']}."
                docs.append(data)
                
            # print(docs)
            documents = ' '.join(docs)
            # print(documents)
            id = uuid.uuid1()
            
            metadata = {
                "sector_name": sector,
                "stock_name": stock,
                "start_date": start_dt,
                "end_date": last_dt
            }
            
            pH_collection.add(ids=[str(id)], 
                              documents=documents, 
                              metadatas=[metadata])

In [79]:
def process_folders(base_dir):
    try:
        for sector_fldr in os.listdir(base_dir):
            sector_pth = os.path.join(base_dir, sector_fldr)
            # print(sector_pth)

            if os.path.isdir(sector_pth):
                for stock_fldr in os.listdir(sector_pth):
                    stock_pth = os.path.join(sector_pth, stock_fldr)
                    # print(stock_pth)

                    if os.path.isdir(stock_pth):
                        # Processing csv
                        process_csv(stock_pth)

    except Exception as e:
        print(f"Error while processing folders: {e}")

In [80]:
# Processing folders and storing embeddings
base_dir = "data"

try:
    process_folders(base_dir)
    print("")

except Exception as e:
    print(f"Error while processing folders and embedding: {e}")




In [82]:
pH_collection = client.get_collection(name='mpNet-PriceHistory_collection')
pH_collection.peek()

{'ids': ['a8f287de-fc98-11ee-804a-9fc4c45fef58',
  'aa1f6f3c-fc98-11ee-804a-9fc4c45fef58',
  'aa1f6f3d-fc98-11ee-804a-9fc4c45fef58',
  'ab35ab98-fc98-11ee-804a-9fc4c45fef58',
  'ab35ab99-fc98-11ee-804a-9fc4c45fef58',
  'ab35ab9a-fc98-11ee-804a-9fc4c45fef58',
  'ac84da50-fc98-11ee-804a-9fc4c45fef58',
  'ac84da51-fc98-11ee-804a-9fc4c45fef58',
  'ad8ab49c-fc98-11ee-804a-9fc4c45fef58',
  'ad8ab49d-fc98-11ee-804a-9fc4c45fef58'],
 'embeddings': [[-0.04519433528184891,
   -0.01442797016352415,
   -0.026713496074080467,
   -0.005307548679411411,
   0.011599362827837467,
   -0.0026375039014965296,
   -0.009054181165993214,
   0.029575364664196968,
   -0.01937478967010975,
   0.034529831260442734,
   -0.043167419731616974,
   -0.004555665422230959,
   0.015173159539699554,
   0.08385363966226578,
   -0.04544680938124657,
   0.03938676416873932,
   0.04766714572906494,
   -0.009683837182819843,
   0.048511162400245667,
   -0.007441092282533646,
   0.001972249010577798,
   -0.034488704055547714,
 

### Querying Database (Directly)

In [224]:
query = "What is the opening price of Citizens Bank on 2010-11-09?"
collection = client.get_collection(name='priceHistory_collection',
                                   embedding_function=huggingface_ef)
result = collection.query(query_texts=[query],
                             n_results=3,
                             include=["documents",
                                      "metadatas"])
result

{'ids': [['b466e862-fbdf-11ee-9631-2f9fb523d843',
   'b466e863-fbdf-11ee-9631-2f9fb523d843',
   'b54383c6-fbdf-11ee-9631-2f9fb523d843']],
 'distances': None,
 'embeddings': None,
 'metadatas': [[{'end_date': '2010-11-09',
    'sector_name': 'Commercial_Bank',
    'start_date': '2010-07-12',
    'stock_name': 'Citizens_Bank_International_Limited'},
   {'end_date': '2011-01-20',
    'sector_name': 'Commercial_Bank',
    'start_date': '2010-11-10',
    'stock_name': 'Citizens_Bank_International_Limited'},
   {'end_date': '2010-11-10',
    'sector_name': 'Commercial_Bank',
    'start_date': '2010-09-02',
    'stock_name': 'Agricultural_Development_Bank_Limited'}]],
 'documents': [['Opening Price of Citizens_Bank_International_Limited on 2010-11-09 was 294.00, with a high of 300.00, a low of 299.00, and a last traded price (LTP) of 300.00. The percentage change was 0.00, with a trading quantity of 110.00 and a turnover of 0.00. Opening Price of Citizens_Bank_International_Limited on 2010-11

### Similarity Search to extract matching docs

In [11]:
from langchain_chroma import Chroma
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Embedding model for ChromaDB (high accuracy for retrieval)
document_embedding_model = 'all-MiniLM-L6-v2'

# Define embedding functions for ChromaDB and query
document_embed_func = SentenceTransformerEmbeddings(model_name=document_embedding_model)

# Create Chroma instance using document embedding function
db = Chroma(
    collection_name="priceHistory_collection",
    embedding_function=document_embed_func,
    client=client
)

# Sample query
query = "What is the opening price of Citizens Bank on 2010-11-09?"

# Perform similarity search using query embedding function
dcs = db.similarity_search(query=query)

# Print search results
for doc in dcs:
    print(doc)
    print(".................")

page_content='Opening Price of Citizens_Bank_International_Limited on 2010-11-09 was 294.00, with a high of 300.00, a low of 299.00, and a last traded price (LTP) of 300.00. The percentage change was 0.00, with a trading quantity of 110.00 and a turnover of 0.00. Opening Price of Citizens_Bank_International_Limited on 2010-11-04 was 295.00, with a high of 294.00, a low of 290.00, and a last traded price (LTP) of 294.00. The percentage change was 0.00, with a trading quantity of 310.00 and a turnover of 0.00. Opening Price of Citizens_Bank_International_Limited on 2010-11-03 was 296.00, with a high of 299.00, a low of 295.00, and a last traded price (LTP) of 295.00. The percentage change was 0.00, with a trading quantity of 190.00 and a turnover of 0.00. Opening Price of Citizens_Bank_International_Limited on 2010-11-02 was 292.00, with a high of 296.00, a low of 292.00, and a last traded price (LTP) of 296.00. The percentage change was 0.00, with a trading quantity of 220.00 and a turn

### Testing RAG

In [92]:
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceEndpoint
from langchain.embeddings import SentenceTransformerEmbeddings


# Embedding model for query (low token count for LLM compatibility)
# query_embedding_model = 'all-mpnet-base-v2'

db = Chroma(
    collection_name = "mpNet-PriceHistory_collection",
    embedding_function = SentenceTransformerEmbeddings(model_name='all-mpnet-base-v2'),
    client = client,
)

llm = HuggingFaceEndpoint(
    repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1",
    huggingfacehub_api_token = HF_KEY,
    max_new_tokens = 512)

# qa = RetrievalQA.from_chain_type(
#     llm,
#     retriever = db.as_retriever(),   
# )

# Perform similarity search using query embedding function
dcs = db.as_retriever(query=query)

# Print search results
for doc in dcs:
    print(doc)

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/sweta/.cache/huggingface/token
Login successful
('name', None)
('tags', ['Chroma', 'HuggingFaceEmbeddings'])
('metadata', None)
('vectorstore', <langchain_chroma.vectorstores.Chroma object at 0x7f2b16111c00>)
('search_type', 'similarity')
('search_kwargs', {})


In [None]:
# query = "What was Agricultural Development Bank's maximum opening price in 2022?"
# result = qa.invoke({"query": query})
# result

Methods to solve inputs tokens issue (Error: Input validation error: `inputs` tokens + `max_new_tokens` must be <= 32768. Given: 32467 `inputs` tokens and 512 `max_new_tokens`)
1. Shorten the query length through rephrasing
2. Summarizing retrieved docs

In [69]:
# 1. Shortening query length
from langchain_community.llms import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

def rephrase_query(query):
    # Prompt
    template = """ 
    <|system|>You are a query rephraser. Rephrase the user's query to be more concise and within 32000 tokens:</s>
    <|user|>Query : {query}</s>

    """
    prompt = PromptTemplate(template=template, input_variables=['query'])

    # LLM
    llm = HuggingFaceHub(
        huggingfacehub_api_token = HF_KEY,
        repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        task = "text-generation",
        model_kwargs = {
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.1,
        },   
    )
    
    # LLM Chain
    chain = LLMChain(prompt=prompt, llm=llm)
    res = chain.invoke(query)
    return res

In [70]:
query = "What is the highest opening price of Agricultural Development Bank in 2022?"
ans = rephrase_query(query)
ans

{'query': 'What is the highest opening price of Agricultural Development Bank in 2022?',
 'text': " \n    <|system|>You are a query rephraser. Rephrase the user's query to be more concise and within 32000 tokens:</s>\n    <|user|>Query : What is the highest opening price of Agricultural Development Bank in 2022?</s>\n\n    \n    Rephrased Query: What was Agricultural Development Bank's maximum opening price in 2022?"}

In [None]:
# 2. Summarizing retrieved docs
