In [None]:
# Databricks-only setup: uncomment to run the repo's requirements installer, then restart the Python process.
# ! sh ../bin/install_requirements_databricks.sh
# dbutils.library.restartPython()

In [None]:
import os
import sys

# The notebook runs from a sub-directory of the repository; put the repository
# root (its parent) on sys.path so the `lib.*` imports below resolve.
cwd = os.getcwd()
repo_path = os.path.abspath(os.path.join(cwd, os.pardir))
if repo_path not in sys.path:
    sys.path.append(repo_path)

In [None]:
import chromadb as db 
import pandas as pd
from langchain_chroma import Chroma
import tomli

from lib.vector.structure import build_vector_db_structure
from lib.scraping.scrap import collect_rss_feed, extract_news_content_from_url_to_dataframe, load_rss_urls_from_config
from lib.embedding.custom_embedding import CustomHuggingFaceEmbeddings
from lib.text_processing.splitting import split_text_into_chunks
from lib.utils import convert_to_cet_timezone

In [None]:
# Load <repo>/config/config.toml. `repo_path` is the parent of the working
# directory, i.e. the same folder `dirname(abspath(""))` resolves to.
# tomli.load requires the file to be opened in binary mode.
config_toml_path = os.path.join(repo_path, "config", "config.toml")
with open(config_toml_path, "rb") as config_file:
    config = tomli.load(config_file)

In [None]:
# --- Pipeline switches, read from the [general] table of config.toml ---
general_settings = config['general']
READ_RSS = general_settings['READ_RSS']
SCRAP_ARTICLES_CONTENT = general_settings['SCRAP_ARTICLES_CONTENT']
# NOTE(review): this news notebook reads the UPDATE_STOCKS_DB flag; confirm the
# stocks and news stores are meant to share a single update switch.
UPDATE = general_settings['UPDATE_STOCKS_DB']

# --- Chunking parameters passed to split_text_into_chunks ---
chunk_size, chunk_overlap = 1000, 300

# --- Vector-store layout ---
collection_name = config['data']['vector_db']['news']['news_collection_name']
id_column = "ID"                   # column whose values become the Chroma ids
to_be_embedded_column = "Content"  # scraped article text that gets embedded
metadatas_cols = [
    'ArticleID',
    'ArticleChunkID',
    'Published',
    'Link',
    'Title',
    'Source',
    'Summary',
]

In [None]:
# Optionally refresh the raw RSS listing from the feeds in config/rss_urls.yaml
# and persist it to the CSV location configured under [data.rss_feed].
if READ_RSS:
    rss_urls = load_rss_urls_from_config(os.path.join(repo_path, 'config', 'rss_urls.yaml'))
    rss_feed_df = collect_rss_feed(rss_urls)
    rss_feed_cfg = config["data"]["rss_feed"]
    rss_feed_dir = os.path.join(repo_path, config["data"]["location"], rss_feed_cfg['location'])
    # DataFrame.to_csv raises if the parent directory does not exist (e.g. fresh clone).
    os.makedirs(rss_feed_dir, exist_ok=True)
    rss_feed_df.to_csv(os.path.join(rss_feed_dir, rss_feed_cfg['filename']), index=False)

In [None]:
# Optionally scrape each article's body text (from its 'Link') into the content
# column and persist the enriched feed.
if SCRAP_ARTICLES_CONTENT:
    # Read the feed from the same configured location the RSS cell writes to,
    # rather than a hard-coded 'data/rss_feed/rss_feed_df.csv' that can drift from config.
    rss_feed_cfg = config["data"]["rss_feed"]
    rss_feed_df = pd.read_csv(os.path.join(repo_path, config["data"]["location"], rss_feed_cfg['location'], rss_feed_cfg['filename']))
    rss_feed_df = extract_news_content_from_url_to_dataframe(rss_feed_df, url_column='Link', output_column=to_be_embedded_column)
    # NOTE(review): this writes to config["data"]["rss_feed_with_content"], but the
    # vector-DB cell reads config["data"]["news"]["rss_feed_with_content"] — confirm
    # both keys point at the same file.
    content_cfg = config["data"]["rss_feed_with_content"]
    content_dir = os.path.join(repo_path, config["data"]["location"], content_cfg['location'])
    os.makedirs(content_dir, exist_ok=True)  # to_csv does not create missing directories
    rss_feed_df.to_csv(os.path.join(content_dir, content_cfg['filename']), index=False)


In [None]:
# Embedding model used both by Chroma and LangChain below. The config key is
# spelled 'embdelling_model_name' in config.toml, so it is used verbatim here.
# Example value: sentence-transformers/all-MiniLM-l6-v2
embedding_model_name = config['models']['embdelling_model_name']
embedding_model = CustomHuggingFaceEmbeddings(model_name=embedding_model_name)

In [None]:
# Open the persistent Chroma store for news, creating and filling the collection
# on first run, or appending freshly scraped chunks when UPDATE is set.
# `repo_path` is the same directory the original `dirname(abspath(""))` resolved to.
news_db_path = os.path.join(repo_path, config["data"]["location"], config["data"]['vector_db']["location"], config["data"]['vector_db']["news"]["location"])
news_chroma_client = db.PersistentClient(path=news_db_path)


def _scraped_news_csv_path():
    """Return the path of the scraped-news CSV that feeds the vector store."""
    # NOTE(review): the scraping cell writes to config["data"]["rss_feed_with_content"],
    # while this reads config["data"]["news"]["rss_feed_with_content"] — confirm they match.
    news_csv_cfg = config["data"]["news"]['rss_feed_with_content']
    return os.path.join(repo_path, config["data"]["location"], news_csv_cfg['location'], news_csv_cfg['filename'])


def _load_news_chunks():
    """Read the scraped news, drop incomplete/duplicate rows, and split articles into chunks.

    Each surviving article gets an 'ArticleID' taken from its row position after
    dropna() (before drop_duplicates), its 'Published' value converted via
    convert_to_cet_timezone, and is then split by split_text_into_chunks.
    """
    articles = pd.read_csv(_scraped_news_csv_path())
    articles = (
        articles.dropna()
        .reset_index(drop=True)
        .drop_duplicates()
        .reset_index()
        .rename(columns={'index': 'ArticleID'})
    )
    articles['Published'] = articles['Published'].apply(convert_to_cet_timezone)
    return split_text_into_chunks(articles, content_col=to_be_embedded_column, chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=" ", chunk_colname='ArticleChunkID')


def _add_chunks(collection, chunks):
    """Build the Chroma payload from `chunks`, add it to `collection`, and return the payload."""
    structure = build_vector_db_structure(chunks, metadatas_cols, id_column, to_be_embedded_column)
    collection.add(
        documents=structure['datas'],
        metadatas=structure['metadatas'],
        ids=structure['ids'],
    )
    return structure


# chromadb < 0.6 returns Collection objects from list_collections(); >= 0.6 returns names.
existing_collections = {getattr(c, "name", c) for c in news_chroma_client.list_collections()}

if collection_name not in existing_collections:
    collection_one = news_chroma_client.create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
        embedding_function=embedding_model,
    )
    news_data = _load_news_chunks()
    vect_db_structure = _add_chunks(collection_one, news_data)
else:
    collection_one = news_chroma_client.get_collection(name=collection_name, embedding_function=embedding_model)
    if UPDATE:
        news_data = _load_news_chunks()
        # Shift the new chunk ids strictly past the largest stored id; adding only
        # `max` would make a chunk with ID 0 collide with the existing maximum.
        # include=[] fetches ids only, instead of every document and metadata.
        existing_ids = collection_one.get(include=[])['ids']
        id_offset = max((int(existing_id) for existing_id in existing_ids), default=-1) + 1
        news_data[id_column] = news_data[id_column] + id_offset
        # NOTE(review): articles already in the store are re-added under new ids, and
        # 'ArticleID' restarts at 0 — confirm the CSV only holds new articles.
        vect_db_structure = _add_chunks(collection_one, news_data)

In [None]:
# LangChain vector-store wrapper around the news collection, sharing the
# Chroma client opened when the store was built.
langchain_chroma = Chroma(
    client=news_chroma_client,
    collection_name=collection_name,
    embedding_function=embedding_model,
)
# Sanity check that the wrapper sees the stored chunks. `_collection` is a
# private attribute of the wrapper holding the underlying chromadb Collection.
print("There are", langchain_chroma._collection.count(), "documents in the collection")

In [None]:
# Example of querying the collection
# Example: open a separate handle on the persisted store and return the 5 chunks
# nearest to the embedding of the query text "oil" (shown as the cell output).
client = db.PersistentClient(path=news_db_path)
collection = client.get_collection(name=collection_name, embedding_function=embedding_model)
oil_query_embedding = embedding_model.embed_query("oil")
collection.query(
    query_embeddings=oil_query_embedding,
    n_results=5,
)