In [None]:
# ! sh ../bin/install_requirements_databricks.sh
# dbutils.library.restartPython()

In [None]:
import sys
import os

# Put the repository root (one level above the notebook's working directory)
# on sys.path so the project-local `lib.*` imports in the next cell resolve.
cwd = os.getcwd()
parent_dir = os.path.join(cwd, os.pardir)
repo_path = os.path.abspath(parent_dir)

already_importable = repo_path in sys.path
if not already_importable:
    sys.path.append(repo_path)

In [None]:
import chromadb as db 
import pandas as pd
from langchain_chroma import Chroma

from lib.llm.model import (
    model_api_client,
    make_description_of_instrument,
)
from lib.vector.structure import build_vector_db_structure
from lib.embedding.custom_embedding import CustomHuggingFaceEmbeddings


In [None]:
# --- Pipeline configuration ------------------------------------------------
# When True, the raw S&P 500 CSV is re-read and every row is re-described via
# make_description_of_instrument (using the model client).
UPDATE = True

# Chroma collection name and the DataFrame columns that feed it.
collection_name = "portfolio"
id_column = "ticker"                   # presumably the source of document ids
to_be_embedded_column = "description"  # text column handed to the embedder
metadatas_cols = [
    'name',
    'sector',
    'industry',
    'headquarters',
    'description',
]

In [None]:
# Model client (passed to make_description_of_instrument) and the embedding
# model (handed to both the Chroma collection and the LangChain wrapper).
# Alternative embedding model: sentence-transformers/all-MiniLM-l6-v2
embedding_model_name = "thenlper/gte-small"
client = model_api_client()
embedding_model = CustomHuggingFaceEmbeddings(model_name=embedding_model_name)

In [None]:
# Build (UPDATE=True) or reload (UPDATE=False) the enriched S&P 500 table.
enriched_path = os.path.join(repo_path, 'data', 'sp500_enriched.csv')

if UPDATE:
    # NOTE(review): dropna() drops rows with a NaN in *any* column, not only the
    # columns used below -- confirm that is intended.
    portfolio_data = pd.read_csv(os.path.join(repo_path, 'data', 'sp500.csv')).dropna().reset_index(drop=True)
    # Short column names, matching `id_column` / `metadatas_cols`.
    portfolio_data = portfolio_data.rename(columns={
        'Symbol': 'ticker',
        'Security': 'name',
        'GICS Sector': 'sector',
        'GICS Sub-Industry': 'industry',
        'Headquarters Location': 'headquarters',
    })

    # One make_description_of_instrument call per row (presumably a model request,
    # hence slow); the result is cached to CSV so UPDATE=False can reuse it.
    portfolio_data['description'] = portfolio_data.apply(lambda row: make_description_of_instrument(row, client), axis=1)
    portfolio_data.to_csv(enriched_path)
else:
    # Without this branch `portfolio_data` would be undefined when UPDATE is False.
    # to_csv above writes the index as the first column, so restore it as the index.
    portfolio_data = pd.read_csv(enriched_path, index_col=0)

In [None]:
# Persistent Chroma store under <repo>/data/portfolio_vector_db.
db_path = os.path.join(repo_path, "data", "portfolio_vector_db")
chroma_client = db.PersistentClient(path=db_path)

# Always open the collection WITH the embedding model. A bare get_collection()
# falls back to Chroma's default embedding function, so documents added through
# it would not be embedded with `embedding_model` -- the model LangChain uses to
# embed queries. get_or_create_collection also avoids list_collections(), whose
# return type (Collection objects vs. names) differs across chromadb versions.
collection_one = chroma_client.get_or_create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"},
    embedding_function=embedding_model,
)

# Populate a fresh collection, or refresh it when UPDATE rebuilt the data.
if UPDATE or collection_one.count() == 0:
    # NOTE(review): assumes build_vector_db_structure derives 'ids' from
    # `id_column` (tickers) -- confirm.
    vect_db_structure = build_vector_db_structure(portfolio_data, metadatas_cols, id_column, to_be_embedded_column)
    # Upsert keyed on the ids: re-running replaces existing entries instead of
    # failing on or duplicating them (the old int-offset id scheme could not
    # work with string tickers, and its result was never written).
    collection_one.upsert(
        documents=vect_db_structure['datas'],
        metadatas=vect_db_structure['metadatas'],
        ids=vect_db_structure['ids'],
    )

In [None]:
# LangChain view over the same persistent collection; queries made through it
# are embedded with `embedding_model`.
langchain_chroma = Chroma(
    client=chroma_client,
    collection_name=collection_name,
    embedding_function=embedding_model,
)

# Count via Chroma's public collection API instead of the wrapper's private `_collection`.
n_documents = chroma_client.get_collection(name=collection_name).count()
print("There are", n_documents, "documents in the collection")

In [None]:
# Free-text probe of the index; k caps how many (document, score) pairs come back.
query = "Tesla"
docs_chroma = langchain_chroma.similarity_search_with_score(query=query, k=5)

In [None]:
# Show the full ranked list of (Document, score) pairs rather than only the last,
# lowest-ranked hit. For Chroma the score is a distance: lower means more similar.
docs_chroma