In [1]:
# Import necessary libraries
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv
from yfiles_jupyter_graphs import GraphWidget
from typing import List

In [2]:
# Load environment variables
load_dotenv()

# Set Neo4j credentials
NEO4J_URI = os.getenv("NEO4J_URI", "neo4j+s://53d608e5.databases.neo4j.io")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "NSdlyOMl34oT6HilZL1wRmBP4nQde0lWopDwotBl9Q0")

In [3]:
# Establish a connection to the Neo4j database
driver = GraphDatabase.driver(
    uri=NEO4J_URI,
    auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)

In [4]:
# Utility function to execute a Cypher query
def execute_cypher_query(query: str, params: dict = None):
    """
    Executes a Cypher query on the Neo4j database and returns the results.

    Args:
        query (str): The Cypher query string.
        params (dict): Query parameters (optional).

    Returns:
        list: Query results as a list of dictionaries.
    """
    with driver.session() as session:
        result = session.run(query, params)
        return [record.data() for record in result]

In [11]:
def visualize_graph(cypher_query: str, params: dict = None):
    """
    Visualizes the result of a Cypher query using yFiles GraphWidget.

    Args:
        cypher_query (str): The Cypher query string.
        params (dict): Query parameters.

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    with driver.session() as session:
        result = session.run(cypher_query, params)
        widget = GraphWidget(graph=result.graph())
        widget.node_label_mapping = 'name'  # Customize based on node label property
        display(widget)
    return widget

In [6]:
# Example Queries

# 1. Find top stocks by market capitalization in a specific sector
def get_top_stocks_by_sector(sector: str, top_n: int = 5):
    """
    Retrieves the top N stocks by market capitalization within a specified sector.

    Args:
        sector (str): The name of the sector.
        top_n (int): The number of top stocks to return.

    Returns:
        list: Query results as a list of dictionaries.
    """
    query = f"""
    MATCH (s:Stock)-[:BELONGS_TO]->(sec:Sector {{name: $sector}})
    RETURN s.name AS Stock, s.marketCap AS MarketCap
    ORDER BY s.marketCap DESC
    LIMIT $top_n
    """
    return execute_cypher_query(query, {"sector": sector, "top_n": top_n})

# Example usage:
print("Top Tech Stocks:")
top_tech_stocks = get_top_stocks_by_sector("Information Technology")
print(top_tech_stocks)

Top Tech Stocks:
[{'Stock': 'AAPL', 'MarketCap': 5.035290923488713}, {'Stock': 'MSFT', 'MarketCap': 4.335693773804269}, {'Stock': 'NVDA', 'MarketCap': 4.297988715526784}, {'Stock': 'AVGO', 'MarketCap': 1.0230716507129027}, {'Stock': 'V', 'MarketCap': 0.3416377234504017}]


In [7]:
# 2. Find correlations between two stocks
def get_stock_correlations(ticker: str, correlation_threshold: float = 0.8):
    """
    Retrieves stocks highly correlated with the given ticker.

    Args:
        ticker (str): The ticker of the stock.
        correlation_threshold (float): Minimum correlation threshold.

    Returns:
        list: Query results as a list of dictionaries.
    """
    query = f"""
    MATCH (s1:Stock {{name: $ticker}})-[r:CORRELATES_WITH]->(s2:Stock)
    WHERE r.correlation >= $threshold
    RETURN s2.name AS CorrelatedStock, r.correlation AS Correlation
    ORDER BY r.correlation DESC
    """
    return execute_cypher_query(query, {"ticker": ticker, "threshold": correlation_threshold})

In [8]:
# Example usage:
print("Highly Correlated Stocks for AAPL:")
aapl_correlations = get_stock_correlations("AAPL", 0.85)
print(aapl_correlations)

Highly Correlated Stocks for AAPL:
[{'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000215}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.000000000000021}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000209}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000198}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000195}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.000000000000019}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000182}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000164}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000144}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000142}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.000000000000014}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000138}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000135}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.000000000000013}, {'CorrelatedStock': 'AAPL', 'Correlation': 1.0000000000000129}, {'Correl

In [12]:
def visualize_stock_correlations(ticker: str, correlation_threshold: float = 0.8):
    """
    Visualizes correlations for a specific stock.

    Args:
        ticker (str): The ticker of the stock.
        correlation_threshold (float): Minimum correlation threshold.

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    query = """
    MATCH (s1:Stock {name: $ticker})-[r:CORRELATES_WITH]->(s2:Stock)
    WHERE r.correlation >= $threshold
    RETURN s1, r, s2
    """
    params = {"ticker": ticker, "threshold": correlation_threshold}
    return visualize_graph(query, params)

In [13]:
# Example usage:
print("Visualizing Correlations for AAPL:")
visualize_stock_correlations("AAPL", 0.85)

Visualizing Correlations for AAPL:


GraphWidget(layout=Layout(height='800px', width='100%'))

GraphWidget(layout=Layout(height='800px', width='100%'))

In [16]:
def visualize_sector_correlations(correlation_threshold: float = 0.85):
    """
    Visualizes correlations between sectors.

    Args:
        correlation_threshold (float): Minimum correlation threshold.

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    query = """
    MATCH (sec1:Sector)-[r:CORRELATES_WITH]->(sec2:Sector)
    WHERE r.correlation >= $threshold
    RETURN sec1, r, sec2
    """
    params = {"threshold": correlation_threshold}
    return visualize_graph(query, params)

In [17]:
# Example usage:
print("Visualizing Sector Correlations:")
visualize_sector_correlations(0.85)

Visualizing Sector Correlations:


GraphWidget(layout=Layout(height='610px', width='100%'))

GraphWidget(layout=Layout(height='610px', width='100%'))

In [18]:
# Close the Neo4j driver after use
def close_driver():
    """
    Closes the Neo4j driver connection.
    """
    driver.close()

In [19]:
# Ensure proper cleanup
import atexit
atexit.register(close_driver)

<function __main__.close_driver()>