In [None]:
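# Databricks only: uncomment to install the project's requirements, then restart Python so they are picked up.
# ! sh ../bin/install_requirements_databricks.sh
# dbutils.library.restartPython()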

In [None]:
import sys
import os

# Put the repository root (the parent of the notebook's working directory) on the
# import path so the project's `lib` package can be imported in the cells below.
cwd = os.getcwd()
repo_path = os.path.normpath(os.path.join(cwd, os.pardir))
if repo_path not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(repo_path)

In [None]:
# Turn off HuggingFace `tokenizers` parallelism; set before any tokenizer is loaded.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [None]:
from lib.llm.model import (
    model_api_client,
    make_impact_from_news,
    make_reasons_from_news,
    make_title_from_news,
    make_summary_from_news,
)
from lib.vector.querying import make_news_retrieval
import pandas as pd
from pathlib import Path
import tomli
from tqdm import tqdm
import chromadb as db 
from langchain_chroma import Chroma
from lib.embedding.custom_embedding import CustomHuggingFaceEmbeddings

In [None]:
# Load project settings from <repo>/config/config.toml (tomli.load needs a binary file handle).
config_path = os.path.join(os.path.dirname(os.path.abspath("")), "config", "config.toml")
with open(config_path, "rb") as config_file:
    config = tomli.load(config_file)

In [None]:
# Shared model clients: the LLM client handed to the make_*_from_news helpers, and the
# embedding model used as the embedding function of both Chroma vector stores.
models_config = config['models']
llm_client = model_api_client(models_config['llm_model_api'])
# NOTE: "embdelling_model_name" (sic) is the key this cell reads from config.toml; keep the two in sync.
embedding_model = CustomHuggingFaceEmbeddings(
    model_name=models_config['embdelling_model_name'],  # e.g. sentence-transformers/all-MiniLM-l6-v2
)

In [None]:
# Build the user's portfolio: a pandas Series of ticker symbols named "ticker".
# - UPDATE_USER_PORTFOLIO true -> draw a fresh random sample from the stocks universe and save it.
# - otherwise                  -> reuse the saved portfolio so re-runs analyse the same tickers;
#                                 if nothing has been saved yet, fall back to a fresh (unsaved) sample.
USER_PORTFOLIO_SIZE = 10    # number of tickers drawn when sampling a new portfolio
USER_PORTFOLIO_SEED = None  # set to an int for a reproducible sample

data_root = os.path.join(os.path.dirname(os.path.abspath("")), config["data"]["location"])
user_portfolio_path = os.path.join(
    data_root,
    config["data"]['user_portfolio']["location"],
    config["data"]['user_portfolio']["filename"],
)
update_user_portfolio = config['general']["UPDATE_USER_PORTFOLIO"]

if update_user_portfolio or not os.path.exists(user_portfolio_path):
    stocks_path = os.path.join(
        data_root,
        config["data"]['stocks']["location"],
        config["data"]['stocks']["filename"],
    )
    user_portfolio_df = (
        pd.read_csv(stocks_path)['ticker']
        .sample(USER_PORTFOLIO_SIZE, random_state=USER_PORTFOLIO_SEED)
        .reset_index(drop=True)
    )
    if update_user_portfolio:
        # Series.to_csv writes a "ticker" header row, which the read below relies on.
        user_portfolio_df.to_csv(user_portfolio_path, index=False)
else:
    user_portfolio_df = pd.read_csv(user_portfolio_path)['ticker']

In [None]:
# Open the persistent Chroma store holding stock records and wrap it for LangChain,
# embedding queries with the shared embedding model.
vector_db_config = config["data"]['vector_db']
stocks_collection_name = vector_db_config["stocks"]["stocks_collection_name"]
stocks_db_path = os.path.join(
    os.path.dirname(os.path.abspath("")),
    config["data"]["location"],
    vector_db_config["location"],
    vector_db_config["stocks"]["location"],
)
stocks_chroma_client = db.PersistentClient(path=stocks_db_path)
stocks_langchain_chroma = Chroma(
    client=stocks_chroma_client,
    collection_name=stocks_collection_name,
    embedding_function=embedding_model,
)

In [None]:
# Open the persistent Chroma store holding news articles and wrap it for LangChain,
# embedding queries with the shared embedding model.
vector_db_config = config["data"]['vector_db']
news_collection_name = vector_db_config["news"]["news_collection_name"]
news_db_path = os.path.join(
    os.path.dirname(os.path.abspath("")),
    config["data"]["location"],
    vector_db_config["location"],
    vector_db_config["news"]["location"],
)
news_chroma_client = db.PersistentClient(path=news_db_path)
news_langchain_chroma = Chroma(
    client=news_chroma_client,
    collection_name=news_collection_name,
    embedding_function=embedding_model,
)

In [None]:
# Look up the stored metadata for each portfolio ticker in the stocks vector store.
# The retriever is used only as a filtered lookup: the filter pins the ticker, and k=1
# takes a single matching document (ranking against an empty query is not meaningful).
import warnings

user_stocks = []
missing_tickers = []
for ticker in user_portfolio_df.to_list():
    filter_criteria = {'ticker': {'$eq': ticker}}
    retriever = stocks_langchain_chroma.as_retriever(search_kwargs={"k": 1, "filter": filter_criteria})
    docs = retriever.invoke("")
    if not docs:
        # Previously an opaque IndexError; record the ticker and keep going instead.
        missing_tickers.append(ticker)
        continue
    user_stocks.append(docs[0].metadata)

if missing_tickers:
    warnings.warn(f"No stock record found in the stocks vector store for tickers {missing_tickers}; skipping them.")
if not user_stocks:
    raise ValueError("None of the portfolio tickers were found in the stocks vector store.")

user_stocks_df = pd.DataFrame(user_stocks)

In [None]:
# Retrieve recent news for every portfolio stock, then attach the stock's own columns.
days_threshold = config["rag"]['news']["news_date_threshold"]
top_articles_k = config["rag"]['news']["top_articles_k"]
# Keep only articles whose "Published" metadata is within the last `days_threshold` days.
# NOTE(review): assumes "Published" is stored as a POSIX timestamp in seconds — confirm.
# Timestamp.now(tz="UTC") replaces Timestamp.utcnow(), which is deprecated since pandas 2.2.
timestamp_threshold = int((pd.Timestamp.now(tz="UTC") - pd.Timedelta(days=days_threshold)).timestamp())
filter_criteria = {'Published': {'$gte': timestamp_threshold}}

match_list = []
for _, stock in tqdm(user_stocks_df.iterrows(), total=len(user_stocks_df), desc="Retrieving news"):
    matches = make_news_retrieval(stock, news_langchain_chroma, filters=filter_criteria, top_k_results=top_articles_k)
    match_list.append(pd.DataFrame(matches))

if not match_list:
    raise ValueError("user_stocks_df is empty: there are no portfolio stocks to retrieve news for.")

# Inner join on ticker keeps only news rows whose ticker also appears in user_stocks_df.
news_x_portfolio_df = pd.concat(match_list, axis=0).merge(user_stocks_df, on='ticker', how='inner')

In [None]:
# Row-wise LLM enrichment: make_impact_from_news(row, llm_client) fills the 'impact' column.
news_x_portfolio_df['impact'] = news_x_portfolio_df.apply(make_impact_from_news, axis=1, args=(llm_client,))

In [None]:
# Row-wise LLM enrichment: make_reasons_from_news(row, llm_client) fills the 'reasons' column.
news_x_portfolio_df['reasons'] = news_x_portfolio_df.apply(make_reasons_from_news, axis=1, args=(llm_client,))

In [None]:
# Row-wise LLM enrichment: make_summary_from_news(row, llm_client) fills the 'news_summary' column.
news_x_portfolio_df['news_summary'] = news_x_portfolio_df.apply(make_summary_from_news, axis=1, args=(llm_client,))

In [None]:
# Row-wise LLM enrichment: make_title_from_news(row, llm_client) fills the 'news_title' column.
news_x_portfolio_df['news_title'] = news_x_portfolio_df.apply(make_title_from_news, axis=1, args=(llm_client,))

In [None]:
# Persist the enriched news-by-portfolio table to the configured dashboard CSV (no index column).
dashboard_path = os.path.join(
    os.path.dirname(os.path.abspath("")),
    config["data"]["location"],
    config["data"]['dashboard']["location"],
    config["data"]['dashboard']["filename"],
)
news_x_portfolio_df.to_csv(dashboard_path, index=False)

In [None]:
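# Databricks only: uncomment to also publish the dashboard as the `default.dashboard` table (overwriting it).
# spark.createDataFrame(news_x_portfolio_df).write.mode("overwrite").saveAsTable("default.dashboard")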