# LLM Finance Prediction RAG implementation

In [None]:
!pip install finnhub-python yfinance --quiet

In [1]:
!pip install pysqlite3-binary chromadb sentence_transformers --quiet

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.8 which is incompatible.
apache-beam 2.46.0 requires numpy<1.25.0,>=1.14.3, but you have numpy 1.26.4 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have pyarrow 16.1.0 which is incompatible.
kfp 2.5.0 requires google-cloud-storage<3,>=2.2.1, but you have google-cloud-storage 1.44.0 which is incompatible.
kfp 2.5.0 requires kubernetes<27,>=8.0.0, but you have kubernetes 30.1.0 which is incompatible.[0m[31m
[0m

In [2]:
!pip install transformers bitsandbytes accelerate peft --quiet

In [3]:
import os
import re
import csv
import math
import time
import json
import random
import pandas as pd
from datetime import datetime, timedelta
from collections import defaultdict
from inspect import cleandoc

In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

## Util methods

In [5]:
def time2date(timestamp):
    """Converts a timestamp to %Y-%m-%d string"""
    return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d')

def date2time(date):
    """Parses %Y-%m-%d string to timestamp"""
    return int(datetime.strptime(date, '%Y-%m-%d').timestamp())

def time_before_days(timestamp, days):
    past_day = datetime.fromtimestamp(timestamp) - timedelta(days)
    return int(past_day.timestamp())

def time_after_days(timestamp, days):
    future_day = datetime.fromtimestamp(timestamp) + timedelta(days)
    return int(future_day.timestamp())

def sort_chroma_by_date(results):
    """
    Sorts chroma query results by metadatas `start_date`
    
    Returns: dict with values sorted by `start_date` { 'metadatas': [[ ]], 'documents': [[ ]] }
    """
    sorted_indices = sorted(
        range(len(results['metadatas'][0])), 
        key=lambda i: results['metadatas'][0][i]['start_date']
    )

    sorted_result = {
        'metadatas': [list(map(results['metadatas'][0].__getitem__, sorted_indices))],
        'documents': [list(map(results['documents'][0].__getitem__, sorted_indices))],
    }
    
    return sorted_result

## Collect data

In [None]:
import finnhub
import yfinance as yf

### Collect fin data 

In [None]:
finnhub_client = finnhub.Client(api_key=user_secrets.get_secret("FINNHUB_API_KEY"))

In [None]:
def bin_mapping(ret):
    up_down = 'U' if ret >= 0 else 'D'
    integer = math.ceil(abs(100 * ret))
    return up_down + (str(integer) if integer <= 5 else '5+')

In [None]:
def get_returns(stock_ticker, start_date, end_date):
    stock_data = yf.download(stock_ticker, start=start_date, end=end_date)

    weekly_data = stock_data['Adj Close'].resample('W').ffill()
    weekly_returns = weekly_data.pct_change()[1:]
    weekly_start_prices = weekly_data[:-1]
    weekly_end_prices = weekly_data[1:]

    weekly_data = pd.DataFrame({
        'start_date': weekly_start_prices.index,
        'start_price': weekly_start_prices.values,
        'end_date': weekly_end_prices.index,
        'end_price': weekly_end_prices.values,
        'weekly_returns': weekly_returns.values
    })

    weekly_data['bin_label'] = weekly_data['weekly_returns'].map(bin_mapping)

    return weekly_data

In [None]:
def get_news(ticker, data):
    news_list = []

    for _, row in data.iterrows():
        start_date = row['start_date'].strftime('%Y-%m-%d')
        end_date = row['end_date'].strftime('%Y-%m-%d')
        time.sleep(1) # control qpm
        weekly_news = finnhub_client.company_news(ticker, _from=start_date, to=end_date)
        weekly_news = [
            {
                "date": datetime.fromtimestamp(n['datetime']).strftime('%Y%m%d%H%M%S'),
                "headline": n['headline'],
                "summary": n['summary'],
            } for n in weekly_news
        ]
        weekly_news.sort(key=lambda x: x['date'])
        news_list.append(json.dumps(weekly_news))

    data['news'] = news_list

    return data

In [None]:
def get_basics(ticker, data, start_date, always=False):
    basic_financials = finnhub_client.company_basic_financials(ticker, 'all')

    final_basics, basic_list, basic_dict = [], [], defaultdict(dict)

    for metric, value_list in basic_financials['series']['quarterly'].items():
        for value in value_list:
            basic_dict[value['period']].update({metric: value['v']})

    for k, v in basic_dict.items():
        v.update({'period': k})
        basic_list.append(v)

    basic_list.sort(key=lambda x: x['period'])

    for i, row in data.iterrows():

        start_date = row['end_date'].strftime('%Y-%m-%d')
        last_start_date = start_date if i < 2 else data.loc[i-2, 'start_date'].strftime('%Y-%m-%d')

        used_basic = {}
        for basic in basic_list[::-1]:
            if (always and basic['period'] < start_date) or (last_start_date <= basic['period'] < start_date):
                used_basic = basic
                break
        final_basics.append(json.dumps(used_basic))

    data['basics'] = final_basics

    return data

In [None]:
def prep_data_for_ticker(ticker, data_dir, start_date, end_date):

    _ = get_returns(ticker, start_date, end_date)
    data = get_news(ticker, _)

    data = get_basics(ticker, data, start_date)
    data.to_csv(f"{data_dir}/{ticker}_{start_date}_{end_date}.csv")

### Prepare csv data

In [None]:
TICKERS = [
    "AXP", "AMGN", "AAPL", "BA", "CAT", "CSCO", "CVX", "GS", "HD", "HON",
    "IBM", "INTC", "JNJ", "KO", "JPM", "MCD", "MMM", "MRK", "MSFT", "NKE",
    "PG", "TRV", "UNH", "CRM", "VZ", "V", "WBA", "WMT", "DIS", "DOW"

    ## With my account there is access only to the US tickers

    # "ADS.DE", "ADYEN.AS", "AD.AS", "AI.PA", "AIR.PA", "ALV.DE",
    # "ABI.BR", "ASML.AS", "CS.PA", "BAS.DE", "BAYN.DE", "BBVA.MC",
    # "SAN.MC", "BMW.DE", "BNP.PA", "BN.PA", "DAI.DE", "DPW.DE", "DTE.DE",
    # "ENEL.MI", "ENGI.PA", "EL.PA", "FRE.DE", "IBE.MC", "ITX.MC", "IFX.DE",
    # "INGA.AS", "ISP.MI", "KER.PA", "AD.AS", "PHIA.AS", "OR.PA", "LIN.DE",
    # "MC.PA", "MUV2.DE", "NOKIA.SE", "ORA.PA", "RI.PA", "SAF.PA", "SAN.PA",
    # "SAP.DE", "SU.PA", "SIE.DE", "GLE.PA", "STM.PA", "TEF.MC", "TTE.PA",
    # "UNA.AS", "DG.PA", "VOW3.DE"
]

In [None]:
START_DATE = "2023-02-01"
END_DATE = "2024-07-01"

DATA_DIR = f"./llama_{START_DATE}_{END_DATE}"
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
for ticker in TICKERS:
    prep_data_for_ticker(ticker, DATA_DIR, START_DATE, END_DATE)

### Prepare data 

In [None]:
def create_company_profile(ticker):
    profile = finnhub_client.company_profile2(symbol=ticker)
    company_template = "[Company Introduction]:\n\n{name} is a leading entity in the {finnhubIndustry} sector. " \
                      "Incorporated and publicly traded since {ipo}, the company has established its reputation " \
                      "as one of the key players in the market. \n\n{name} operates primarily in the {country}, " \
                      "trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, " \
                      "the company continues to innovate and drive progress within the industry."

    formatted_str = company_template.format(**profile)

    return formatted_str

In [None]:
def map_bin_label(bin_lb):
    lb = bin_lb.replace('U', 'up by ')
    lb = lb.replace('D', 'down by ')
    lb = lb.replace('1', '0-1%')
    lb = lb.replace('2', '1-2%')
    lb = lb.replace('3', '2-3%')
    lb = lb.replace('4', '3-4%')
    if lb.endswith('+'):
        lb = lb.replace('5+', 'more than 5%')
    else:
        lb = lb.replace('5', '4-5%')

    return lb

In [None]:
def sample_news(news, n=5):
    if not 0 <= n <= len(news):
        raise ValueError(f"Bad N")
    sampled_indices = random.sample(range(len(news)), n)
    return [news[i] for i in sampled_indices]

In [None]:
def get_info_by_row(ticker, row):
    start_date = row['start_date'].strftime('%Y-%m-%d') if isinstance(row['start_date'], datetime) else str(row['start_date'])
    end_date = row['end_date'].strftime('%Y-%m-%d') if isinstance(row['end_date'], datetime) else str(row['end_date'])

    term = 'increased' if row['end_price'] > row['start_price'] else 'decreased'
    head = f"From {start_date} to {end_date}, {ticker}'s stock price {term} " \
           f"from {row['start_price']:.2f} to {row['end_price']:.2f} with final estimation - {map_bin_label(row['bin_label'])}. " \
           f"News during this period are listed below:\n"

    news = json.loads(row["news"])
    news = [f"[Headline]: {n['headline']}\n[Summary]: {n['summary']}\n"
            for n in news
            if n['date'][:8] <= end_date.replace('-', '')
            and not n['summary'].startswith("Looking for stock market analysis and research with proves results?")]

    basics = json.loads(row['basics'])
    if not basics:
        basics_str = "[Basic Financials]:\nNo basic financial reported."
    else:
        basics_str = f"Some recent basic financials of {ticker}, reported at {basics['period']}, are presented below:\n\n[Basic Financials]:\n\n"
        basics_str += "\n".join(f"{k}: {v}" for k, v in basics.items() if k != 'period')

    return head, news, basics_str

In [None]:
def build_info(ticker, row, prev_rows, max_weeks):
    prompt = ""
    if prev_rows:
        num_prev_rows = min(random.choice(range(1, max_weeks+1)), len(prev_rows))
        for i in range(-num_prev_rows, 0):
            prompt += f"\n{prev_rows[i][0]}"  # Add Price Movement (Head)
            sampled_news = sample_news(prev_rows[i][1], min(5, len(prev_rows[i][1])))
            if sampled_news:
                prompt += "\n".join(sampled_news)
            else:
                prompt += "No relative news reported.\n"

    head, news, basics = get_info_by_row(ticker, row)
    prev_rows.append((head, news, basics))

    if len(prev_rows) > max_weeks:
        prev_rows.pop(0)

    if not prompt:
        return "", prev_rows

    prompt += f"\n{basics}"
    
    return (prompt.strip(), prev_rows)

In [None]:
def create_ticker_infos(ticker, data_dir, start_date, end_date, max_weeks=5):
    df = pd.read_csv(f'{data_dir}/{ticker}_{start_date}_{end_date}.csv')

    info_prompt = create_company_profile(ticker)
    prev_rows = []
    all_prompts = [info_prompt]

    for _, row in df.iterrows():
        prompt, prev_rows = build_info(ticker, row, prev_rows, max_weeks)
        if prompt:
            all_prompts.append(prompt)

    return all_prompts

In [None]:
START_DATE = "2023-02-01"
END_DATE = "2024-07-01"

DATA_DIR = f"/kaggle/input/llama-{START_DATE}-{END_DATE}"

In [None]:
prepped_dir = f"/kaggle/working/ticker_news-{START_DATE}-{END_DATE}"
os.makedirs(prepped_dir, exist_ok=True)

for ticker in TICKERS:
    print("Processing ticker:", ticker)

    prepped_file = f'{prepped_dir}/{ticker}_{START_DATE}_{END_DATE}.txt'
    ticker_infos = create_ticker_infos(ticker, DATA_DIR, START_DATE, END_DATE)
    
    with open(prepped_file, 'w') as f:
        f.write("\n\n".join(ticker_infos))

## Data Sctructuring

In [None]:
def parse_company_data(raw_data: str):
    intro, data = raw_data.split("From ", 1)
    company_info = {"periods": []}
    data = f"From {data}"
    
    intro_lines = list(filter(None, intro.split('\n')))
    company_info['name'] = intro_lines[1].split(" is ", 1)[0].strip()
    company_info['introduction'] = "".join(intro_lines[1:])

    ticker = re.search(r", (.*?)'s stock price", data)
    company_info["ticker"] = ticker.group(1)
        
    period_pattern = r'From (\d{4}-\d{2}-\d{2}) to (\d{4}-\d{2}-\d{2}), .*?stock price .*? (\d+\.\d+) to (\d+\.\d+) with final estimation - (up|down) by (\d+-\d+%)\. News during this period are listed below:(.*?)(?=From|\Z)'
    periods = re.finditer(period_pattern, data, re.DOTALL)
    
    for period in periods:
        period_data = {"news": [], "basics": []}
        
        start_date, end_date = period[1], period[2]
        start_price, end_price = period[3], period[4]
        direction, percentage_range = period[5], period[6]

        period_data['start_date'] = start_date
        period_data['end_date'] = end_date
        period_data['start_price'] = float(start_price)
        period_data['end_price'] = float(end_price)
        period_data['direction'] = direction
        period_data['percentage_range'] = percentage_range
        
        news_with_basics = period[7]
        
        news_pattern = r'\[Headline\]: (.*?)(?:\r?\n|\r)\[Summary\]: ((?:(?!\[Headline\]:)[\s\S])*?)(?=\r?\n\[Headline\]|\r?\n\[Basic Financials\]|\r?\nSome recent basic financials of|\Z)'
        news = re.findall(news_pattern, news_with_basics, re.DOTALL)
        
        period_data['news'] += [{'headline': headline.strip(), 'summary': summary.strip()} for headline, summary in news]
        
        basics = re.findall(r'\[Basic Financials\]:(.*?)(?=From|\Z)', news_with_basics, re.DOTALL)
        
        if basics and "No basic financial reported" in basics[0]:
            basics = []
            
        period_data['basics'] = basics

        company_info['periods'].append(period_data)

    return company_info

In [None]:
def remove_duplicates(data):
    seen = set()
    i = 0
    while i < len(data['periods']):
        start_date = data['periods'][i]['start_date']
        if start_date in seen:
            data['periods'].pop(i)
        else:
            seen.add(start_date)
            i += 1
    return data

In [None]:
ticker_news_path = '/kaggle/input/ticker-news-2023-02-01-2024-07-01'
prepped_ticker_news_path = '/kaggle/working/prepped_ticker_news-2023-02-01-2024-07-01'
os.makedirs(prepped_ticker_news_path, exist_ok=True)

In [None]:
for filename in os.listdir(ticker_news_path):
    print("Processing", filename)
    if filename.endswith('.txt'):
        with open(os.path.join(ticker_news_path, filename), 'r') as f:
            raw_data = f.read()

        company_data = parse_company_data(raw_data)
        company_data = remove_duplicates(company_data)
        
        filename = filename.replace('.txt', '.json')
        with open(os.path.join(prepped_ticker_news_path, filename), 'w') as f:
            json.dump(company_data, f)

## ChromaDB

In [6]:
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import chromadb
from chromadb.utils import embedding_functions

In [7]:
chroma_client = chromadb.PersistentClient(path="/kaggle/working/chroma_db")
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

  from tqdm.autonotebook import tqdm, trange
2024-08-21 15:31:12.023929: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-21 15:31:12.024035: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-21 15:31:12.122965: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [8]:
collections = {
    "companies": chroma_client.get_or_create_collection(
        name="companies",
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": 'cosine'}
    ),
    "financials": chroma_client.get_or_create_collection(
        name="financials",
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": 'cosine'}
    ),
    "news": chroma_client.get_or_create_collection(
        name="news",
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": 'cosine'}
    )
}

print(f"Total documents in 'companies' collection: {collections['companies'].count()}")
print(f"Total documents in 'financials' collection: {collections['financials'].count()}")
print(f"Total documents in 'news' collection: {collections['news'].count()}")

Total documents in 'companies' collection: 30
Total documents in 'financials' collection: 2399
Total documents in 'news' collection: 6286


## Embedding

In [None]:
def prepare_company_documents(company_data):
    documents = {
        'companies': [],
        'financials': [],
        'news': []
    }
    
    documents['companies'].append({
        "text": f"{company_data['name']} ({company_data['ticker']})\n\n[Company Introduction]:\n\n{company_data['introduction']}",
        "metadata": {
            "type": "company_overview",
            "company": company_data['name'],
            "ticker": company_data['ticker'],
        }
    })
    
   
    for period in company_data['periods']:
        period_text = f"From {period['start_date']} to {period['end_date']}, {company_data['name']} ({company_data['ticker']}) " \
                      f"stock price went {period['direction']} by {period['percentage_range']} " \
                      f"from {period['start_price']} to {period['end_price']}."
        
        prediction_bin = 1 if period['direction'] == "up" else -1
        percentage_range = prediction_bin * (int(period['percentage_range'].split('-')[0]) + 0.5)
        
        metadata = {
            "company": company_data['name'],
            "ticker": company_data['ticker'],
            "start_date": date2time(period['start_date']),
            "end_date": date2time(period['end_date']),
            "direction": period['direction'],
            "percentage_range": percentage_range
        }
        
        documents['financials'].append({
            "text": period_text,
            "metadata": {
                "type": "period_summary",
                **metadata
            }
        })
        
        for news in period['news']:
            news_text = f"News about {company_data['name']} ({company_data['ticker']}) " \
                        f"during the period {period['start_date']} to {period['end_date']}:\n\n" \
                        f"[Headline]: {news['headline']}\n[Summary]: {news['summary']}"
            documents['news'].append({
                "text": news_text,
                "metadata": {
                    "type": "news",
                    **metadata
                }
            })
        
        if period['basics']:      
            basics_texts = f"Basic financials of {company_data['name']} ({company_data['ticker']}) " \
                           f"reported at {period['end_date']}:\n\n[Basic Financials]:{period['basics'][0]}"
            
            documents['financials'].append({
                "text": basics_texts,
                "metadata": {
                    "type": "basic_financials",
                    **metadata
                }
            })
    
    return documents

In [None]:
company_data_path = '/kaggle/input/prepped-ticker-news-2023-02-01-2024-07-01'

In [None]:
for filename in os.listdir(company_data_path):
    print("-"*80)
    print("Processing", filename)
    if filename.endswith('.json'):
        with open(os.path.join(company_data_path, filename), 'r') as f:
            company_data = json.load(f)

        company_documents = prepare_company_documents(company_data)
        
        for collection, documents in company_documents.items():   
            ids = [
                f"""{company_data['ticker']}_{doc['metadata']['type']}{f"_{doc['metadata'].get('start_date', '$')}".rstrip("_$")}_{i}"""
                for i, doc in enumerate(documents)
            ]
            
            collections[collection].upsert(
                ids=ids,
                documents=[doc["text"] for doc in documents],
                metadatas=[doc["metadata"] for doc in documents],
            )
              
        
        print(f"Processed {sum(len(lst) for lst in company_documents.values())} documents for {company_data['name']} ({company_data['ticker']})")
        
        
print(f"Total documents in 'companies' collection: {collections['companies'].count()}")
print(f"Total documents in 'financials' collection: {collections['financials'].count()}")
print(f"Total documents in 'news' collection: {collections['news'].count()}")

## Querying

### Companies

Let's query 'What is Apple Inc?'.  
By the results we see that the first and the closest result is about Apple Inc (AAPL).  
Even though we use cosine and it goes from -1 to 1, the distance is taken from 0 (the closest) to 2.  

In [None]:
query_results = collections['companies'].query(
    query_texts=["What is Apple Inc?"],
    n_results=5
)
query_results

In [None]:
query_results = collections['companies'].query(
    query_texts=["A company in the communications sector"],
    n_results=3
)
query_results

### Financials

In [None]:
query_results = collections['financials'].query(
    query_texts=["Apple Inc's stock price change for a period from 2024-05-05 to 2024-05-26"],
    n_results=10
)
query_results

Not quite a period that I specified

In [None]:
query_result_dates = [datetime.fromtimestamp(res['start_date']).strftime("%Y-%m-%d") for res in query_results['metadatas'][0]]
print("\n".join(query_result_dates))

In [None]:
query_results = collections['financials'].query(
    query_texts=["Basic financials of Apple Inc (AAPL) from 2024-03-17 to 2024-04-07"],
    n_results=10,
    where={
        "$and": [{"start_date": {"$gte": date2time("2024-03-17")}}, {"start_date": {"$lt": date2time("2024-04-07")}}]
    }
)
query_results

Filters help a lot and, although we got some other companies in the results, we have the distance that has a big gap from Apple's data and helps us as well

In [None]:
query_result_dates = [datetime.fromtimestamp(res['start_date']).strftime("%Y-%m-%d") for res in query_results['metadatas'][0]]
print("\n".join(query_result_dates))

### News

In [None]:
query_results = collections['news'].query(
    query_texts=["News about Apple Inc with stock price change"],
    n_results=10
)
query_results

## Parse user query

In [9]:
def parse_user_query(query):
    max_distance_threshold = 0.54
    
    query = query.lower()
    parsed = {
        "company": None,
        "ticker": None,
        "start_date": None,
        "end_date": None,
    }
    
    company_query_res = collections['companies'].query(
        query_texts=[query],
        n_results=1,
        include=["metadatas", "distances"]
    )
  
    # Company name and ticker
    if company_query_res and (metadatas := company_query_res['metadatas']) and metadatas[0]:
#         print(f"{company_query_res['distances']=}")
        if company_query_res['distances'][0][0] > max_distance_threshold:
            raise Exception("The provided query seems a bit vague, try clarifying what you are looking for")
        parsed['company'] = metadatas[0][0]['company']
        parsed['ticker'] = metadatas[0][0]['ticker']
    
    # Date ranges
    date_ranges = re.findall(r'\d{4}-\d{2}-\d{2}', query)
    weeks_match = re.search(r'\b(\d{1,2}) week', query)
    try:
        date_ranges = sorted(date_ranges, key=lambda date: -date2time(date))[:2]
    except ValueError:
        pass
    
    weeks = 4
    if date_ranges:
        parsed["end_date"] = date_ranges[0]
        
        if len(date_ranges) == 2:
            parsed["start_date"] = date_ranges[1]
            return parsed
        
        if weeks_match:
            weeks = int(weeks_match.group(1))
            
        parsed["start_date"] = time2date(time_before_days(date2time(date_ranges[0]), 7*weeks))
        
    else:
        end_date = datetime.now()
        
        if weeks_match:
            weeks = int(weeks_match.group(1))
            
        start_date = end_date - timedelta(weeks=weeks)
        parsed["start_date"] = time2date(start_date.timestamp())
        parsed["end_date"] = time2date(end_date.timestamp())
    
    return parsed

### Testing

In [None]:
print(parse_user_query("What are the potential concerns about Apple?"))

In [None]:
print(parse_user_query("What are the predictions for the next week from today and 4 weeks before for the company Apple?"))

In [None]:
parse_user_query("What are the potential concerns about the leading company in Technology?")

## Llama with RAG

In [10]:
import torch
import transformers
from torch.optim import AdamW
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
    GenerationConfig, pipeline, Trainer, TrainingArguments, DataCollatorForSeq2Seq
)

from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

In [11]:
os.environ["HF_TOKEN"] = user_secrets.get_secret("HF_TOKEN")

### Model

In [16]:
finance_llama3_8b = "instruction-pretrain/finance-Llama3-8B"
finance_peft_adapter_path = "/kaggle/input/fin-prediction-llama3/transformers/peft_adapter/1"
base_llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"

In [13]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

In [14]:
base_model = AutoModelForCausalLM.from_pretrained(base_llama3_8b, quantization_config=bnb_config, device_map='auto', low_cpu_mem_usage=True)
base_model = base_model.eval()

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

In [None]:
finance_model = AutoModelForCausalLM.from_pretrained(
    finance_llama3_8b, 
    return_dict=True, 
    quantization_config=bnb_config, 
    low_cpu_mem_usage=True,
    device_map="cuda:0"
)
PeftModel.from_pretrained(finance_model, finance_peft_adapter_path)
finetuned_model = finance_model.eval()

In [18]:
finance_tokenizer = AutoTokenizer.from_pretrained(finance_llama3_8b)
base_tokenizer = AutoTokenizer.from_pretrained(base_llama3_8b)

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [23]:
generation_config = GenerationConfig.from_pretrained(base_llama3_8b)
generation_config.temperature = 0.001

base_pipeline = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=base_tokenizer,
    generation_config=generation_config,
)

In [29]:
generation_config = GenerationConfig.from_pretrained(finance_llama3_8b)
generation_config.temperature = 0.001

terminators = [
    finance_tokenizer.eos_token_id,
    finance_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
finance_pipeline = pipeline(
    "text-generation",
    model=finance_llama3_8b_model,
    tokenizer=finance_tokenizer,
    generation_config=generation_config,
    eos_token_id=terminators
)

In [19]:
finetuned_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=16, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=16, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(
   

### RAG

In [85]:
finetuned_model.base_model_prefix

'model'

In [91]:
base_model.config._name_or_path

'meta-llama/Meta-Llama-3-8B-Instruct'

In [111]:
def get_assistant_response(output):
    gen_text = output[0]['generated_text']
    assistant_text = gen_text.rsplit('<|im_start|>assistant\n', 1)[1]
    return assistant_text

def _get_completion(gen_pipeline, tokenizer, messages):
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    outputs = gen_pipeline(prompt, add_special_tokens=True)
    return outputs


def decode(tokenizer, output):
    return tokenizer.decode(output[0]).rsplit("|>assistant\n", 1)[-1].split('<|e')[0]

def decode_base(tokenizer, output):
    return tokenizer.decode(output[0]).rsplit("|>assistant<|end_header_id|>\n", 1)[-1].split('<|e')[0]

def get_completion(model, tokenizer, messages):
    is_base_model = model.config._name_or_path == 'meta-llama/Meta-Llama-3-8B-Instruct'
    
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(
        prompt, return_tensors='pt',
        add_special_tokens=not is_base_model
    ).to(model.device)
    
    res = model.generate(
        **inputs, 
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id
    )
      
    return decode(tokenizer, res) if not is_base_model else decode_base(tokenizer, res)

#### RAG Market Analyst and Assistant without query parsing/understanding

In [99]:
def simple_rag(model, tokenizer, query):
    prompt = ""

    company_intro_res = collections['companies'].query(
        query_texts=[query],
        n_results=1
    )
    company_intro = company_intro_res['documents'][0][0].split("\n\n", 1)[1]
    print(f"company intro query distance: {company_intro_res['distances'][0][0]}")
    prompt += company_intro
    
    news_res = collections['news'].query(
        query_texts=[query],
        n_results=10
    )
    print(f"news query distances: {', '.join(map(str, news_res['distances'][0]))}")
    prompt += "\n\nNews:\n" + "\n\n".join([doc for doc in news_res['documents'][0]])

    
    period_summaries_res = collections['financials'].query(
        query_texts=[query],
        n_results=10,
        where={ 'type': 'period_summary' }
    )
    print(f"period summaries query distances: {', '.join(map(str, period_summaries_res['distances'][0]))}")
    prompt += "\n\nPeriod summaries:\n" + "\n\n".join([doc for doc in period_summaries_res['documents'][0]])
    
    
    basic_financials_res = collections['financials'].query(
        query_texts=[query],
        n_results=3,
        where={ 'type': 'basic_financials' }
    )
    print(f"basic financials query distances: {', '.join(map(str, basic_financials_res['distances'][0]))}")
    prompt += "\n\nBasic Financials:\n" + "\n".join([doc.replace("\n[Basic Financials]:\n", "") for doc in basic_financials_res['documents'][0]])

    system_prompt = "You are a seasoned stock market analyst. " \
                    "Your task is to list the positive developments and potential " \
                    "concerns for companies based on the context of relevant news, period summaries, and basic financials from the past weeks, " \
                    "then answer the query by only using the provided context."
    
    prompt = f"Context:\n{prompt}\nQuery: {query}"
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    
    completion = get_completion(finetuned_model, finance_tokenizer, messages)
    return completion, prompt

Fine-tuned

In [24]:
response, prompt = simple_rag(
    finetuned_model, 
    finance_tokenizer,
    "Tell me what are the concerns and potential developments of Apple Inc for the next week if today is 2024-03-17?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

company intro query distance: 0.44003707333224984


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

news query distances: 0.2735745906829834, 0.2744227647781372, 0.27813631296157837, 0.2817426919937134, 0.2894265651702881, 0.2932044267654419, 0.29453492164611816, 0.2991286516189575, 0.30116337537765503, 0.3043619394302368


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

period summaries query distances: 0.40203988552093506, 0.40408027172088623, 0.4109033942222595, 0.41113418340682983, 0.4112357497215271, 0.41276049613952637, 0.4141501188278198, 0.4149589538574219, 0.4150373935699463, 0.41513651609420776


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


basic financials query distances: 0.4599674940109253, 0.4614250063896179, 0.4684370756149292


In [25]:
print(prompt)

Context:
[Company Introduction]:

Apple Inc is a leading entity in the Technology sector. Incorporated and publicly traded since 1980-12-12, the company has established its reputation as one of the key players in the market. Apple Inc operates primarily in the US, trading under the ticker AAPL on the NASDAQ NMS - GLOBAL MARKET. As a dominant force in the Technology space, the company continues to innovate and drive progress within the industry.

News:
News about Apple Inc (AAPL) during the period 2024-03-17 to 2024-03-24:

[Headline]: Apple Stock Rises. Analysts Deliver Their Verdict on the Legal Threat.
[Summary]: Analysts led by Wedbush’s Dan Ives write, ‘Ultimately, we do not expect any business model changes for now.’

News about Apple Inc (AAPL) during the period 2023-10-01 to 2023-10-08:

[Headline]: Apple Stock Needs a Win. Here Are 2 Big Ideas.
[Summary]: Apple’s next earnings report is just a few weeks away. It’s likely to post a fourth consecutive quarter of year-over-year re

In [26]:
print(response)

Based on the context provided, the concerns and potential developments of Apple Inc for the next week are as follows:

Concerns:
1. Apple's upcoming earnings report may reveal disappointing holiday quarter sales, particularly in China.
2. The company's stock price has been volatile, with a recent decline of 2-3% from 197.05 to 193.09.
3. There are concerns about the company's ability to innovate and drive progress in the technology space.
4. The potential integration of AI into Apple's product lineup may be a significant development, but it is unclear how it will impact the company's business model.
5. The company's stock price may be affected by the legal threat mentioned in the news article.

Potential developments:
1. Apple's upcoming earnings report may reveal a turn in the company's fortunes, with a potential increase in stock price.
2. The company's integration of AI into its product lineup may lead to new innovations and enhancements, driving progress in the technology space.
3.

Base

In [100]:
response, prompt = simple_rag(
    base_model, 
    base_tokenizer,
    "Tell me what are the concerns and potential developments of Apple Inc for the next week if today is 2024-03-17?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

company intro query distance: 0.44003707333224984


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

news query distances: 0.2735745906829834, 0.2744227647781372, 0.27813631296157837, 0.2817426919937134, 0.2894265651702881, 0.2932044267654419, 0.29453492164611816, 0.2991286516189575, 0.30116337537765503, 0.3043619394302368


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

period summaries query distances: 0.40203988552093506, 0.40408027172088623, 0.4109033942222595, 0.41113418340682983, 0.4112357497215271, 0.41276049613952637, 0.4141501188278198, 0.4149589538574219, 0.4150373935699463, 0.41513651609420776


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


basic financials query distances: 0.4599674940109253, 0.4614250063896179, 0.4684370756149292


In [101]:
print(response)

Based on the context provided, the concerns and potential developments for Apple Inc for the next week are as follows:

Concerns:
- Apple's upcoming earnings report may reveal disappointing holiday quarter sales, particularly in China.
- The stock price has been down by 2-3% from 197.05 to 193.09, indicating a potential decline in investor confidence.
- The company's legal issues and potential business model changes may affect its stock price.

Potential developments:
- Apple's upcoming earnings report may reveal a rebound in sales, which could lead to a rise in stock price.
- The company's new products and features, such as the integration of artificial intelligence (AI), may drive innovation and growth.
- The company's legal issues may be resolved, which could lead to a rise in stock price.

Overall, the concerns and potential developments for Apple Inc for the next week are related to its upcoming earnings report, stock price, and potential legal issues.


#### RAG Stock Price prediction without query understanding

In [108]:
def prediction_rag_no_prep(model, tokenizer, query):
    prompt = ""
    parsed_query = parse_user_query(query) # Only used for Llama 3 instruction prompt
    
    company_intro_res = collections['companies'].query(
        query_texts=[query],
        n_results=1
    )
    company_intro = company_intro_res['documents'][0][0].split("\n\n", 1)[1]
    print(f"company intro query distance: {company_intro_res['distances'][0][0]}")
    prompt += company_intro
    
    news_res = collections['news'].query(
        query_texts=[query],
        n_results=15
    )
    print(f"news query distances: {', '.join(map(str, news_res['distances'][0]))}")
    prompt += "\n\nNews:\n" + "\n\n".join([doc for doc in news_res['documents'][0]])

    
    period_summaries_res = collections['financials'].query(
        query_texts=[query],
        n_results=10,
        where={ 'type': 'period_summary' }
    )
    print(f"period summaries query distances: {', '.join(map(str, period_summaries_res['distances'][0]))}")
    prompt += "\n\nPeriod summaries:\n" + "\n\n".join([doc for doc in period_summaries_res['documents'][0]])
    
    
    basic_financials_res = collections['financials'].query(
        query_texts=[query],
        n_results=3,
        where={ 'type': 'basic_financials' }
    )
    print(f"basic financials query distances: {', '.join(map(str, basic_financials_res['distances'][0]))}")
    prompt += "\n\nBasic Financials:\n" + "\n".join([doc.replace("\n[Basic Financials]:\n", "") for doc in basic_financials_res['documents'][0]])
    
    system_prompt = "You are a seasoned stock market analyst. " \
                    "Your task is to list the positive developments and potential " \
                    "concerns for companies based on the context of relevant news, period summaries, and basic financials from the past weeks, " \
                    "then provide an analysis and prediction of the companies' stock price movements for the upcoming week. " \
                    "Your answer format should be as follows: " \
                    "\n\n[Positive Developments]:" \
                    "\n1. ..." \
                    "\n\n[Potential Concerns]:" \
                    "\n1. ..." \
                    "\n\n[Prediction & Analysis]:" \
                    "\nPrediction: Up|Down by X-Y%" \
                    "\nAnalysis: ...\n"

    instruction_prompt = f"Based on all the information before {parsed_query['end_date']}, let's first analyze the positive developments " \
                         f"and potential concerns for {parsed_query['ticker']}. Come up with 2-4 most important factors respectively " \
                         f"and keep them concise. Most factors should be inferred from company related news. " \
                         f"Then make your prediction of the {parsed_query['ticker']} stock price movement for next week ({parsed_query['end_date']} to {time2date(time_after_days(date2time(parsed_query['end_date']), 7))}). " \
                         f"Provide a summary analysis to support your prediction."

    prompt += "\n" + instruction_prompt
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    
    completion = get_completion(model, tokenizer, messages)
    return completion, prompt

Fine-tuned

In [32]:
response, prompt = prediction_rag_no_prep(
    finetuned_model, 
    finance_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

company intro query distance: 0.5366427233333579


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

news query distances: 0.3156413435935974, 0.32973307371139526, 0.33078426122665405, 0.33690977096557617, 0.3371772766113281, 0.34524381160736084, 0.3458632230758667, 0.3460145592689514, 0.3510153889656067, 0.3520105481147766, 0.3540758490562439, 0.35786187648773193, 0.35936248302459717, 0.3615577816963196, 0.3642909526824951


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

period summaries query distances: 0.33930671215057373, 0.3405546545982361, 0.34656381607055664, 0.3475383520126343, 0.3504473567008972, 0.3509594202041626, 0.3516607880592346, 0.3520732522010803, 0.352130651473999, 0.35292428731918335


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


basic financials query distances: 0.471174418926239, 0.47252196073532104, 0.47502416372299194


In [33]:
print(response)

[Positive Developments]:
1. Apple Inc's stock price has been relatively stable, with only minor fluctuations, indicating a strong foundation for the company's performance.
2. The company's financials, as reported in the past few weeks, show a consistent trend of growth and profitability, suggesting a healthy business model.
3. Apple Inc's ability to innovate and adapt to changing market conditions has been highlighted in recent news, indicating a strong competitive position.
4. The company's stock has been included in various lists of profitable and trending stocks, highlighting its potential for growth.

[Potential Concerns]:
1. The recent sell-off in the stock market has raised concerns about Apple Inc's stock price, as it may be affected by broader market trends.
2. The company's earnings report for the fourth quarter is expected to show a decline in revenue, which may lead to a further decrease in stock price.
3. Apple Inc's reliance on a few key products, such as the iPhone, may e

Base

In [112]:
response, prompt = prediction_rag_no_prep(
    base_model, 
    base_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

company intro query distance: 0.5366427233333579


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

news query distances: 0.3156413435935974, 0.32973307371139526, 0.33078426122665405, 0.33690977096557617, 0.3371772766113281, 0.34524381160736084, 0.3458632230758667, 0.3460145592689514, 0.3510153889656067, 0.3520105481147766, 0.3540758490562439, 0.35786187648773193, 0.35936248302459717, 0.3615577816963196, 0.3642909526824951


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

period summaries query distances: 0.33930671215057373, 0.3405546545982361, 0.34656381607055664, 0.3475383520126343, 0.3504473567008972, 0.3509594202041626, 0.3516607880592346, 0.3520732522010803, 0.352130651473999, 0.35292428731918335


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


basic financials query distances: 0.471174418926239, 0.47252196073532104, 0.47502416372299194


In [113]:
print(response)


**Positive Developments:**

1. **Strong Financials**: Apple's financials have been consistently strong, with a high operating margin, decent cash flow, and a manageable debt-to-equity ratio.
2. **Innovation and Diversification**: Apple's efforts to innovate and diversify its product portfolio, such as the Apple Watch and Apple TV+, have been well-received and may drive future growth.
3. **Stable Earnings**: Apple's earnings have been stable, with a consistent dividend payout and a relatively low payout ratio.

**Potential Concerns:**

1. **Market Volatility**: The stock market has been volatile, and Apple's stock price may be affected by broader market trends.
2. **Competition**: Apple faces intense competition in the tech industry, particularly from other smartphone manufacturers and emerging players.
3. **Valuation**: Apple's stock price has been relatively high, and some investors may be concerned about its valuation.

**Prediction and Analysis:**

Based on the positive development

In [30]:
with open('/kaggle/working/prepped_ticker_news-2023-02-01-2024-07-01/AAPL.json', 'r') as f:
    aapl_prepped_data = json.load(f)

In [31]:
from operator import itemgetter
[
    itemgetter('start_date', 'end_date', 'direction', 'percentage_range')(period) 
    for period in aapl_prepped_data['periods'] 
    if date2time("2024-03-14") <= date2time(period['start_date']) <= date2time("2024-03-25")
]

[('2024-03-17', '2024-03-24', 'down', '0-1%'),
 ('2024-03-24', '2024-03-31', 'down', '0-1%')]

#### RAG Stock Price Prediction With query understanding

In [117]:
def prediction_rag_with_prep(model, tokenizer, query):
    prompt = ""
    parsed_query = parse_user_query(query)
    print(parsed_query)

    company_intro_res = collections['companies'].query(
        query_texts=[f"Company {parsed_query['company']} ({parsed_query['ticker']})"],
        n_results=1
    )
    company_intro = company_intro_res['documents'][0][0].split("\n\n", 1)[1]
    prompt += company_intro + "\n\n"
    
    period_summaries_res = collections['financials'].query(
        query_texts=[query],
        n_results=10,
        include=['documents', 'metadatas'],
        where={ 
            '$and': [
                { 'type': 'period_summary' }, 
                { 'ticker': parsed_query['ticker'] },
                { 'start_date': { '$gte': date2time(parsed_query['start_date']) }},
                { 'end_date': { '$lte': date2time(parsed_query['end_date']) }}
            ] 
        }
    )
    
    sorted_period_summaries = sort_chroma_by_date(period_summaries_res)
    for doc, metadata in zip(sorted_period_summaries['documents'][0], sorted_period_summaries['metadatas'][0]):
        search_filters = [
            { 'ticker': parsed_query['ticker'] },
            { 'start_date': { '$gte': metadata['start_date'] }},
            { 'end_date': { '$lte': metadata['end_date'] }}
        ]
        
        
        news_res = collections['news'].query(
            query_texts=[query],
            where={ '$and': search_filters }
        )
        
        basic_financials_res = collections['financials'].query(
            query_texts=[query],
            where={ '$and': [ { 'type': 'basic_financials' }, *search_filters ]
            }
        )
        
        
        prompt += f"{doc} News during this period are listed below:\n\n"
        
        if news_res['documents'][0]:
            remove_before_news = lambda t: re.sub(r'.*(?=\[Headline\])', '', t, flags=re.DOTALL)
            prompt += "\n\n".join(list(map(remove_before_news, news_res['documents'][0])))
        else:
            prompt += "\nNo relative news reported."
        
        if basic_financials_res['documents'][0]:
            remove_before_basics = lambda t: re.sub(r'.*(?=\[Basic Financials\])', '', t, flags=re.DOTALL)
            prompt += f"\n\nSome recent basic financials of {parsed_query['ticker']}, reported at {time2date(metadata['end_date'])}, are presented below:\n\n"
            prompt += "".join(list(map(remove_before_basics, basic_financials_res['documents'][0])))
        else: 
            prompt += "\n\n[Basic Financials]:\n\nNo basic financial reported.\n\n"

        
    next_week_timestamp = time_after_days(date2time(parsed_query['end_date']), 7)
    
    system_prompt = "You are a seasoned stock market analyst. " \
                    "Your task is to list the positive developments and potential " \
                    "concerns for companies based on the context of relevant news, period summaries, and basic financials from the past weeks, " \
                    "then provide an analysis and prediction of the companies' stock price movements for the upcoming week. " \
                    "Your answer format should be as follows: " \
                    "\n\n[Positive Developments]:" \
                    "\n1. ..." \
                    "\n\n[Potential Concerns]:" \
                    "\n1. ..." \
                    "\n\n[Prediction & Analysis]:" \
                    "\nPrediction: Up|Down by X-Y%" \
                    "\nAnalysis: ...\n"
    
    instruction_prompt = f"Based on all the information before {parsed_query['end_date']}, let's first analyze the positive developments " \
                         f"and potential concerns for {parsed_query['ticker']}. Come up with 2-4 most important factors respectively " \
                         f"and keep them concise. Most factors should be inferred from company related news. " \
                         f"Then make your prediction of the {parsed_query['ticker']} stock price movement for next week ({parsed_query['end_date']} to {time2date(next_week_timestamp)}). " \
                         f"Provide a summary analysis to support your prediction."

    prompt += "\n" + instruction_prompt
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    
    completion = get_completion(model, tokenizer, messages)
    return completion, prompt

Fine-tuned

In [118]:
response, prompt = prediction_rag_with_prep(
    finetuned_model, 
    finance_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

{'company': 'Apple Inc', 'ticker': 'AAPL', 'start_date': '2024-02-15', 'end_date': '2024-03-14'}


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [38]:
print(response)

[Positive Developments]:
1. Apple's new sports app launch, which could lead to increased user engagement and content prospects.
2. The company's potential AI tools unveiling at its June developer conference, which could drive innovation and growth.
3. The Magnificent 7 tech stocks' Q4 results, which could provide insights into the industry's performance and Apple's competitive position.
4. The EU's investigation into Apple's decision to close Epic Games' developer account, which could have potential implications for the company's app store policies.

[Potential Concerns]:
1. The stock's underperformance compared to competitors and its potential impact on investor confidence.
2. The company's lagging behind tech peers with a clearer AI strategy, which could affect its competitive position.
3. The EU's investigation into Apple's app store policies, which could lead to regulatory changes and potential financial implications.
4. The company's decision to end its self-driving EV plans, whic

Base

In [115]:
response, prompt = prediction_rag_with_prep(
    base_model, 
    base_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

{'company': 'Apple Inc', 'ticker': 'AAPL', 'start_date': '2024-02-15', 'end_date': '2024-03-14'}


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


In [116]:
print(response)


[Positive Developments]:

1. Apple's new Sports app, which offers real-time scores and stats, could be a significant addition to their ecosystem, potentially driving user engagement and increasing the company's presence in the sports streaming market.
2. The company's AI strategy is expected to be unveiled at its June developer conference, which could help Apple catch up with its tech peers and potentially drive growth in the future.
3. Apple's decision to offer discounts on the latest iPhone series in China could help revive sales in the region, where the company faces stiff competition from domestic brands.

[Potential Concerns]:

1. Apple's underperformance compared to its competitors, as seen in the recent stock price decline, could be a concern for investors, particularly if the company fails to address the issue.
2. The European Union's investigation into Apple's decision to close Epic Games' developer account could lead to regulatory scrutiny and potentially impact the company'

### RAG Market Analyst and Assistant with query understanding

In [39]:
query_classification = {
    'explanation': {
        'regex': r'\b(explain|explanation|explaining|what is|how does|how do)\b',
        "instruction": lambda parsed_query: cleandoc(
            f"Explain the relevant financial concepts or terms for {parsed_query['company']} and its current situation. "
            "Provide context on how these factors affect the company's performance and stock price. "
            "Ensure the explanation directly answers the query."
        )
    },
    'analysis': {
        'regex': r'\b(analyze|analysis|analyzing|evaluate|evaluates|evaluation)\b',
        "instruction": lambda parsed_query: cleandoc(
            f"Provide a comprehensive analysis of {parsed_query['company']}'s current situation, "
            f"including financial health, recent performance, and market position. "
            f"Highlight key strengths and potential risks for {parsed_query['ticker']}. "
            "Ensure the analysis directly addresses the query."
        )
    },
    'prediction': {
        'regex': r'\b(predict|prediction|predictions|forecast|predicting|predicts|forecasting|forecasts)\b',
        "instruction": lambda parsed_query: cleandoc(
            f"First, analyze the positive developments and potential concerns for {parsed_query['ticker']}. "
            f"Come up with the 2-4 most important factors, respectively, and keep them concise. "
            f"Most factors should be inferred from company-related news. "
            f"Then make your prediction of the {parsed_query['ticker']} stock price movement for next week "
            f"({parsed_query['end_date']} to {time2date(time_after_days(date2time(parsed_query['end_date']), 7))}). "
            "Provide a summary analysis to support your prediction."
        )
    },
    'general': {
        'regex': None,
        "instruction": lambda parsed_query: cleandoc(
            f"Address the query: '{parsed_query['query']}'. Provide a thorough and informative response based on the available data."
        )
    }
}

def classify_query(query):
    query = query.lower()
    for query_type, data in query_classification.items():
        if data['regex'] and re.search(data['regex'], query, flags=re.IGNORECASE):
            return query_type
    return "general"

def get_instructions(query, parsed_query):
    query_type = classify_query(query)
    parsed_query['query'] = query
    base_prompt = f"Based on the following information about {parsed_query['company']} and market data up to {parsed_query['end_date']}: "
    return base_prompt + query_classification[query_type]['instruction'](parsed_query)

In [137]:
def rag_with_prep(model, tokenizer, query):
    prompt = ""
    parsed_query = parse_user_query(query)

    company_intro_res = collections['companies'].query(
        query_texts=[f"Company {parsed_query['company']} ({parsed_query['ticker']})"],
        n_results=1
    )
    company_intro = company_intro_res['documents'][0][0].split("\n\n", 1)[1]
    prompt += company_intro + "\n\n"
    
    period_summaries_res = collections['financials'].query(
        query_texts=[query],
        n_results=10,
        include=['documents', 'metadatas'],
        where={ 
            '$and': [
                { 'type': 'period_summary' }, 
                { 'ticker': parsed_query['ticker'] },
                { 'start_date': { '$gte': date2time(parsed_query['start_date']) }},
                { 'end_date': { '$lte': date2time(parsed_query['end_date']) }}
            ] 
        }
    )
    
    sorted_period_summaries = sort_chroma_by_date(period_summaries_res)
    for doc, metadata in zip(sorted_period_summaries['documents'][0], sorted_period_summaries['metadatas'][0]):
        search_filters = [
            { 'ticker': parsed_query['ticker'] },
            { 'start_date': { '$gte': metadata['start_date'] }},
            { 'end_date': { '$lte': metadata['end_date'] }}
        ]
        
        
        news_res = collections['news'].query(
            query_texts=[query],
            where={ '$and': search_filters }
        )
        
        basic_financials_res = collections['financials'].query(
            query_texts=[query],
            where={ '$and': [ { 'type': 'basic_financials' }, *search_filters ]
            }
        )
        
        
        prompt += f"{doc} News during this period are listed below:\n\n"
        
        if news_res['documents'][0]:
            remove_before_news = lambda t: re.sub(r'.*(?=\[Headline\])', '', t, flags=re.DOTALL)
            prompt += "\n\n".join(list(map(remove_before_news, news_res['documents'][0])))
        else:
            prompt += "\nNo relative news reported."
        
        if basic_financials_res['documents'][0]:
            remove_before_basics = lambda t: re.sub(r'.*(?=\[Basic Financials\])', '', t, flags=re.DOTALL)
            prompt += f"\n\nSome recent basic financials of {parsed_query['ticker']}, reported at {time2date(metadata['end_date'])}, are presented below:\n\n"
            prompt += "".join(list(map(remove_before_basics, basic_financials_res['documents'][0])))
        else: 
            prompt += "\n\n[Basic Financials]:\n\nNo basic financial reported.\n\n"

        
    next_week_timestamp = time_after_days(date2time(parsed_query['end_date']), 7)
    
#     system_prompt = cleandoc(
#         "You are a seasoned stock market analyst. "
#         "Your task is to list the positive developments and potential "
#         "concerns for companies based on the context of relevant news, period summaries, and basic financials from the past weeks, "
#         "then answer the query by only using the provided context."
#     )

    system_prompt = cleandoc(
        "You are an advanced AI market analyst and financial assistant with access to a vast database of financial information, news, and market data. "
        "Your role is to assist users with various financial queries, provide market insights, and offer analysis on stocks and market trends "
        "based on the context of relevant news, period summaries, and basic financials from the past weeks.\n"
        "\n"
        "Your capabilities include:\n"
        "1. Analyzing company performance and financials\n"
        "2. Providing market trend analysis\n"
        "3. Offering stock price predictions with careful disclaimers\n"
        "4. Explaining financial concepts and terms\n"
        "5. Identifying potential investment opportunities and risks\n"
        "\n"
        "Always base your responses on the provided context from the database. "
        "If you're unsure or don't have enough information, say so and suggest what additional data might be helpful."
        "\n"
        "\n"
        "Remember:\n"
        "- Always provide balanced views, mentioning both potential upsides and risks.\n"
        "- When making predictions, clearly state that these are opinions based on available data and not guaranteed outcomes.\n"
        "- Use data to support your points, but explain it in a way that's easy for non-experts to understand.\n"
        "- If asked about specific financial advice, remind the user that you're an AI and"
        "recommend consulting with a qualified financial advisor for personalized guidance."
    )
    
    prompt = f"Context:\n{prompt}\nQuery: {query}"
#     prompt += "\n\n" + get_instructions(query, parsed_query)
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": f"Certainly! I'd be happy to help with your query. Let me analyze the relevant information from our financial database."},
        {"role": "user", "content": get_instructions(query, parsed_query)}
    ]
    
#     messages = [
#         {"role": "system", "content": system_prompt},
#         {"role": "user", "content": prompt},
# #         {"role": "assistant", "content": f"Certainly! I'd be happy to help with your query. Let me analyze the relevant information from our financial database."},
# #         {"role": "user", "content": get_instructions(query, parsed_query)}
#     ]
    
    completion = get_completion(model, tokenizer, messages)
    return completion, prompt

Fine-tuned

In [132]:
response, prompt = rag_with_prep(
    finetuned_model, 
    finance_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [133]:
print(response) # without assistant in messages

Based on the news and market data up to 2024-03-14, the following are the 2-4 most important factors for AAPL:

1. Apple's new Sports app: The launch of Apple's Sports app could lead to increased engagement and user base, which could positively impact the company's stock price.
2. AI strategy: Apple's lagging behind tech peers with a clearer AI strategy could be a concern for investors, as the company is expected to unveil AI tools at its June developer conference.
3. Competition: Apple's competition in the tech space, particularly from Huawei, could impact its stock price.
4. EU investigation: The European Union's investigation into Apple's decision to close Epic Games' developer account could have implications for the company's stock price.

Based on these factors, I predict that AAPL's stock price will increase by 1-2% from 170.5 to 172.5 by the end of next week (2024-03-21). This prediction is based on the assumption that the positive developments related to Apple's new Sports app 

In [125]:
print(response) # assistant in messages

Based on the available data, the two most important factors affecting Apple Inc's stock price are the company's AI strategy and its position in the market. The article "Apple Investors Grow Impatient on Artificial Intelligence" suggests that the company's lack of a clear AI strategy is a concern for investors. Additionally, the article "EU looking into Apple's decision to kill Epic Games' developer account" highlights the company's regulatory issues.

Considering these factors, I predict that Apple Inc's stock price will remain relatively stable or decline slightly over the next week. The company's AI strategy and regulatory issues may continue to weigh on investor sentiment, leading to a cautious approach to the stock. However, the company's strong market position and potential for innovation may provide some support for the stock.

To summarize, based on the available data, I predict that Apple Inc's stock price will remain relatively stable or decline slightly over the next week, wi

Base

In [138]:
response, prompt = rag_with_prep(
    base_model, 
    base_tokenizer,
    "What are the predictions of Apple Inc's stock price for the next week if today is 2024-03-14?"
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


In [136]:
print(response) # without assistant in messages


Based on the provided context, I've analyzed the positive developments and potential concerns for Apple Inc's stock price.

**Positive Developments:**

1. **Innovative Sports App**: Apple's launch of its free sports app, offering real-time scores and stats, could attract more users and create a new revenue stream.
2. **Potential AI Advancements**: The expected unveiling of AI tools at Apple's June developer conference could lead to increased investor confidence and boost the stock price.
3. **Discounted iPhone Sales**: Apple's authorized retailers in China are offering significant discounts on the latest iPhone series, which could revitalize sales and increase demand.

**Potential Concerns:**

1. **Competition from Huawei**: Apple faces stiff competition from domestic brands like Huawei Technologies, which could erode its market share in China.
2. **Lack of Clear AI Strategy**: Apple's lagging behind in AI strategy compared to its tech peers could raise concerns among investors and af

In [139]:
print(response) # assistant in messages


Based on the provided information, here are the key positive developments and potential concerns for AAPL:

**Positive Developments:**

1. **Innovation in Sports App**: Apple's launch of its free sports app, offering real-time scores and stats, could be a significant growth driver, especially with the increasing popularity of sports streaming.
2. **AI Strategy**: Apple's expected unveiling of AI tools at its June developer conference might help the company catch up with its peers and drive innovation in the tech space.
3. **Potential for Partnerships**: The news about media companies teaming up for streaming bundles could lead to new opportunities for Apple to collaborate and expand its services.

**Potential Concerns:**

1. **Competition from Huawei**: Apple's struggles in China, where Huawei is gaining traction, could impact its sales and revenue.
2. **Lack of AI Strategy**: Apple's lag behind its peers in AI strategy might continue to be a concern for investors, potentially affecti

### For prompting:
1. It's better to use periods of dates in YYYY-MM-DD format, such as 'from 2024-01-01 to 2024-02-28'.
2. You can also provide a date and a number of weeks BEFORE that date, such as '5 weeks before 2024-01-01'

Otherwise, the data will be taken 4 weeks before today.