In [31]:
from newsapi import NewsApiClient
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from dotenv import load_dotenv
import os
import sqlite3
from langchain_core.tools import tool
from datetime import timedelta, datetime, date
from typing import List
from langchain_tavily import TavilySearch
from langgraph.prebuilt import create_react_agent
from transformers import pipeline

load_dotenv()

True

In [2]:
model = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

Device set to use mps:0


In [3]:
def init_db():
    conn = sqlite3.connect("news_cache.db")
    c = conn.cursor()
    c.execute('''
        CREATE TABLE IF NOT EXISTS news_articles(
              id INTEGER PRIMARY KEY AUTOINCREMENT,
              company TEXT,
              title TEXT,
              information TEXT,
              published_date TEXT,
              sentiment_label TEXT,
              sentiment_score REAL,
              UNIQUE(company, title, published_date)
              )         
''')
    c.execute('''
        CREATE TABLE IF NOT EXISTS query_log(
              id INTEGER PRIMARY KEY AUTOINCREMENT,
              company TEXT,
              date TEXT,
              UNIQUE(company, date)
              )         
''')
    conn.commit()
    conn.close()
init_db()

In [26]:
@tool
def get_news(company_name: str, begin_date: str, end_date: str) -> List[dict]:
    '''
    This Tool Here checks the sqlite database to see if a query has already been done for a particular date, if it has not been done yet, 
    it then uses NewsAPI to fetch articles from the dates it needs, updateing our databse in the process. Afterwards, our tool returns
    every article from the requested dates
    '''
    def use_news_api(company_name, begin_date, end_date):
        news = NewsApiClient(api_key=os.getenv('NEWS_API_KEY'))
        all_articles = news.get_everything(
            q=company_name,
            from_param=begin_date,
            to=end_date,
            language='en',
            sort_by='relevancy'
        )['articles']

        with sqlite3.connect("news_cache.db", timeout=10) as conn:
            c = conn.cursor()

            query_date = datetime.strptime(begin_date, '%Y-%m-%d')
            end_dt = datetime.strptime(end_date, '%Y-%m-%d')
            while query_date <= end_dt:
                c.execute('''
                    INSERT OR IGNORE INTO query_log (company, date)
                    VALUES (?, ?)
                ''', (company_name, query_date.strftime('%Y-%m-%d')))
                query_date += timedelta(days=1)

            for article in all_articles:
                candidate_labels = ["positive", "negative", "neutral"]
                article_content =  article.get('description') or article.get('content', '')
                result = classifier(article_content, candidate_labels)
                dominant_score = max(result['scores'])
                maxidx = result['scores'].index(dominant_score)
                dominant_label = result['labels'][maxidx]
                c.execute('''
                    INSERT OR IGNORE INTO news_articles (company, title, information, published_date)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (
                    company_name,
                    article['title'],
                    article_content,
                    article['publishedAt'].split("T")[0],
                    dominant_label,
                    dominant_score
                ))

            conn.commit()

    begin_date_dt = datetime.strptime(begin_date, '%Y-%m-%d')
    end_date_dt = datetime.strptime(end_date, '%Y-%m-%d')

    with sqlite3.connect("news_cache.db", timeout=10) as conn:
        c = conn.cursor()

        c.execute('''
            SELECT MIN(date) FROM query_log WHERE company = ?
        ''', (company_name,))
        min_result = c.fetchone()
        min_available = datetime.strptime(min_result[0], '%Y-%m-%d') if min_result and min_result[0] else None

        c.execute('''
            SELECT MAX(date) FROM query_log WHERE company = ?
        ''', (company_name,))
        max_result = c.fetchone()
        max_available = datetime.strptime(max_result[0], '%Y-%m-%d') if max_result and max_result[0] else None

    if min_available is None or max_available is None:
        use_news_api(company_name, begin_date, end_date)
    elif end_date_dt < min_available or begin_date_dt > max_available:
        use_news_api(company_name, begin_date, end_date)
        use_news_api(company_name, (max_available + timedelta(days=1)).strftime('%Y-%m-%d'), end_date)
    elif begin_date_dt < min_available:
        use_news_api(company_name, begin_date, (min_available - timedelta(days=1)).strftime('%Y-%m-%d'))

    with sqlite3.connect("news_cache.db", timeout=10) as conn:
        c = conn.cursor()
        c.execute('''
            SELECT title, information, published_date, score, label
            FROM news_articles
            WHERE company = ?
              AND published_date BETWEEN ? AND ?
        ''', (company_name, begin_date, end_date))
        rows = c.fetchall()

    return [
        {
            "title": row[0],
            "information": row[1],
            "published_date": row[2],
            "score":row[3],
            "label":row[4]
        } for row in rows
    ]


In [27]:
basic_search_tool = TavilySearch(
    max_results=5,
    topic="general",
    # include_answer=False,
    # include_raw_content=False,
    # include_images=False,
    # include_image_descriptions=False,
    # search_depth="basic",
    # time_range="day",
    # include_domains=None,
    # exclude_domains=None
)

In [36]:
today = date.today().strftime("%B %d, %Y")
agent_executor = create_react_agent(
    model=model, 
    tools=[get_news, basic_search_tool],
    prompt=(
        f"You are an expert in public relations (PR) analysis. Today is {today}. \n When given a company to investigate, your job is to "
        "analyze the overall sentiment of recent news coverage about the company.\n\n"
        "You can use the `get_news` tool to retrieve relevant news articles. This tool returns up to 100 articles per call. "
        "You may only call `get_news` once per company unless the date range is very large—in that case, you may call it up to 3 times.\n\n"
        "Based on the articles returned, provide a sentiment score and a qualitative assessment of the company's current PR outlook. "
        "Be sure to cite the main reasons behind the sentiment you assign.\n\n"
        "If you lack information needed to proceed (e.g., missing date ranges, unclear references), use the `basic_search_tool` to clarify "
        "the context using relevant keywords or events mentioned in the prompt.\n\n"
        "ENSURE ANY REQUEST YOU MAKE TO GET_NEWS IS WITHIN ONE MONTH, IF THEY REQUEST SOMETHING BEFORE TELL THEM YOU CANNOT\n\n"
        "Example: If a user asks about 'Tesla's recent earnings report,' but doesn't specify a date, use the context of the message along "
        "with the `basic_search_tool` to identify the relevant time frame before making further decisions.\n"
    )
)

In [37]:
agent_executor.invoke({"messages": [{"role": "user", "content": "How has the public felt about Nvidia this past week"}]})

{'messages': [HumanMessage(content='How has the public felt about Nvidia this past week', additional_kwargs={}, response_metadata={}, id='2586d342-9f66-4da5-8282-522c9572dfd0'),
  AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_news', 'arguments': '{"company_name": "Nvidia", "end_date": "2025-06-12", "begin_date": "2025-06-05"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []}, id='run--97363373-b4e2-4136-93b2-7418105a2fd0-0', tool_calls=[{'name': 'get_news', 'args': {'company_name': 'Nvidia', 'end_date': '2025-06-12', 'begin_date': '2025-06-05'}, 'id': '59ba278d-a0fd-4e1d-866a-de6f0cfdd753', 'type': 'tool_call'}], usage_metadata={'input_tokens': 855, 'output_tokens': 33, 'total_tokens': 888, 'input_token_details': {'cache_read': 0}}),
  ToolMessage(content="Error: OperationalError('6 values for 4 columns')\n Please fix your mistakes.", name='get_n