In [None]:
# ! sh ../bin/install_requirements_databricks.sh
# dbutils.library.restartPython()

In [None]:
import sys
import os

# Make the repository root (one directory above the notebook's working
# directory) importable, so the `lib.*` imports resolve without installation.
cwd = os.getcwd()
repo_path = os.path.normpath(os.path.join(cwd, os.pardir))
if repo_path not in sys.path:
    sys.path.append(repo_path)

In [None]:
import chromadb as db 
import pandas as pd
from langchain_chroma import Chroma
from tqdm import tqdm
import tomli

from lib.llm.model import (
    model_api_client,
    make_description_of_instrument,
)
from lib.vector.structure import build_vector_db_structure
from lib.embedding.custom_embedding import CustomHuggingFaceEmbeddings


In [None]:
# Load the project configuration from <repo_root>/config/config.toml.
# Reuse `repo_path` (computed in the sys.path cell) rather than re-deriving the
# parent directory of the working directory a second time.
config_path = os.path.join(repo_path, "config", "config.toml")
with open(config_path, "rb") as f:  # tomli.load requires a binary file handle
    config = tomli.load(f)

In [None]:
# Run-time switches.
# NOTE(review): `UPDATE` is read from config but nothing in this notebook uses it,
# while `UPDATE_DB` is hard-coded — confirm which one should gate the DB update.
UPDATE = config['general']['UPDATE_STOCKS_DB']
UPDATE_PORTFOLIO_DATA = False  # re-run make_description_of_instrument on every S&P 500 row
UPDATE_DB = True               # gates the existing-collection update branch in the indexing cell

# Vector-store layout: the id column, the column that gets embedded, and the
# columns stored as Chroma metadata.
collection_name = config['data']['vector_db']['stocks']['stocks_collection_name']
id_column, to_be_embedded_column = "stockID", "description"
metadatas_cols = "name ticker sector industry headquarters description".split()

In [None]:
# Clients for description generation (LLM) and for embedding documents/queries.
models_config = config['models']
llm_client = model_api_client(models_config['llm_model_api'])
# The config key is spelled 'embdelling_model_name' in config.toml; keep it in sync.
# Example value: sentence-transformers/all-MiniLM-l6-v2
embedding_model = CustomHuggingFaceEmbeddings(model_name=models_config['embdelling_model_name'])

In [None]:
# Optionally regenerate the enriched S&P 500 file: map the raw CSV columns onto
# the schema used downstream, then ask the LLM for a description of each row.
if UPDATE_PORTFOLIO_DATA:
    sp500_raw = pd.read_csv(os.path.join(repo_path, 'data', 'sp500.csv'))
    sp500_clean = (
        sp500_raw
        .dropna()
        .reset_index(drop=True)
        .rename(columns={
            'Symbol': 'ticker',
            'Security': 'name',
            'GICS Sector': 'sector',
            'GICS Sub-Industry': 'industry',
            'Headquarters Location': 'headquarters',
        })
    )
    tqdm.pandas(desc='Get Financial Instrument Description')
    sp500_enriched = sp500_clean.assign(
        description=sp500_clean.progress_apply(
            lambda row: make_description_of_instrument(row, llm_client), axis=1
        )
    )
    # index=False: otherwise the RangeIndex is written as an extra unnamed
    # column that reappears as "Unnamed: 0" when the file is read back.
    # NOTE(review): this path is hard-coded, while the loader cell reads the
    # stocks CSV from config['data'] paths — confirm both point to the same file.
    sp500_enriched.to_csv(
        os.path.join(repo_path, 'data', 'stocks', 'sp500_enriched.csv'),
        index=False,
    )

In [None]:
# Load the stock universe named in the config and give each complete row an
# integer id (0..n-1 after dropna) in the `id_column` column.
stocks_csv_path = os.path.join(
    repo_path,
    config['data']['location'],
    config['data']['stocks']['location'],
    config['data']['stocks']['filename'],
)
portfolio_data = (
    pd.read_csv(stocks_csv_path)
    .dropna()
    .reset_index(drop=True)
    .reset_index()
    .rename(columns={'index': id_column})
)

# Persistent Chroma store for the stocks collection, under the repository root.
vector_db_config = config['data']['vector_db']
stocks_db_path = os.path.join(
    repo_path,
    config['data']['location'],
    vector_db_config['location'],
    vector_db_config['stocks']['location'],
)
stocks_chroma_client = db.PersistentClient(path=stocks_db_path)

In [None]:
def add_stocks_to_collection(collection, stocks):
    """Embed the rows of `stocks` and add them to a Chroma `collection`.

    `build_vector_db_structure` turns the frame into the lists Chroma expects
    (keys 'datas', 'metadatas', 'ids'). The structure is returned so it can
    be inspected after the cell runs.
    """
    structure = build_vector_db_structure(stocks, metadatas_cols, id_column, to_be_embedded_column)
    collection.add(
        documents=structure['datas'],
        metadatas=structure['metadatas'],
        ids=structure['ids'],
    )
    return structure


# get_or_create keeps the cell re-runnable. An existing-but-empty collection is
# loaded like a fresh one instead of crashing on max() of an empty id list.
collection_one = stocks_chroma_client.get_or_create_collection(
    name=collection_name,
    embedding_function=embedding_model,
)
# include=[] returns ids only, without pulling documents/metadatas/embeddings.
existing_ids = collection_one.get(include=[])['ids']

if not existing_ids:
    # Initial load: ids are the 0..n-1 values assigned when the CSV was read.
    vect_db_structure = add_stocks_to_collection(collection_one, portfolio_data)
elif UPDATE_DB:
    # New ids start one past the highest stored id, so the first new row does
    # not collide with an existing entry.
    first_new_id = max(int(existing_id) for existing_id in existing_ids) + 1
    # Shift a copy; portfolio_data itself is left untouched so re-running this
    # cell does not compound the offset.
    portfolio_update = portfolio_data.assign(
        **{id_column: portfolio_data[id_column] + first_new_id}
    )
    # NOTE(review): each run with UPDATE_DB=True appends the whole portfolio
    # again (no de-duplication) — switch to an upsert on a stable key if that
    # is not the intent.
    vect_db_structure = add_stocks_to_collection(collection_one, portfolio_update)

In [None]:
# LangChain vector-store wrapper over the same client and collection.
stocks_langchain_chroma = Chroma(
    client=stocks_chroma_client,
    collection_name=collection_name,
    embedding_function=embedding_model,
)

# Count through the public chromadb Collection API instead of the wrapper's
# private `_collection` attribute; both refer to `collection_name` on this client.
print("There are", collection_one.count(), "documents in the collection")

In [None]:
# Smoke-test retrieval: embed a free-text query and fetch the nearest stocks.
QUERY_TEXT = "oil"
N_RESULTS = 2

# Reuse the client opened earlier instead of creating a second
# PersistentClient on the same on-disk path.
collection = stocks_chroma_client.get_collection(name=collection_name, embedding_function=embedding_model)
result = collection.query(
    query_embeddings=embedding_model.embed_query(QUERY_TEXT),
    n_results=N_RESULTS,
)

In [None]:
# Metadata of the nearest neighbours returned for the single query above.
first_query_metadatas = result['metadatas'][0]
first_query_metadatas