# In [None]:
# Import necessary libraries
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv
from yfiles_jupyter_graphs import GraphWidget
from typing import List

# Load environment variables (python-dotenv reads a local .env file, if present)
load_dotenv()

# Set Neo4j connection settings.
# The password must come from the environment (or a .env file). It is
# deliberately NOT hard-coded as a fallback: a secret committed to source is
# readable by anyone with access to the notebook. Any password that was
# previously committed here should be rotated.
NEO4J_URI = os.getenv("NEO4J_URI", "neo4j+s://53d608e5.databases.neo4j.io")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
if not NEO4J_PASSWORD:
    # Fail fast with an actionable message instead of an opaque auth error
    # from the driver later on.
    raise RuntimeError(
        "NEO4J_PASSWORD is not set; define it in the environment or in a .env file."
    )

# Establish a connection to the Neo4j database
driver = GraphDatabase.driver(
    uri=NEO4J_URI,
    auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)

# Utility function to execute a Cypher query
def execute_cypher_query(query: str, params: dict = None):
    """
    Run a Cypher statement against Neo4j and collect every returned record.

    A fresh session is opened for the call and closed afterwards. All records
    are materialized before the session closes.

    Args:
        query (str): Cypher statement to execute.
        params (dict): Values for the statement's ``$name`` placeholders;
            may be omitted when the statement has none.

    Returns:
        list: One dictionary per record, mapping each returned column
        name to its value.
    """
    with driver.session() as neo_session:
        cursor = neo_session.run(query, params)
        rows = [row.data() for row in cursor]
    return rows

# Utility function to visualize the graph
def visualize_graph(cypher_query: str, params: dict = None):
    """
    Visualizes the result of a Cypher query using yFiles GraphWidget.

    The query runs in its own session. The graph it returns (the nodes and
    relationships in its result) is rendered with ``display`` and the widget
    is returned.

    Args:
        cypher_query (str): The Cypher query string.
        params (dict): Query parameters (optional). Required whenever the
            query references ``$name`` placeholders, otherwise Neo4j rejects
            the query for missing parameters.

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    with driver.session() as session:
        result = session.run(cypher_query, params)
        # Consume the result into a graph while the session is still open.
        widget = GraphWidget(graph=result.graph())
        # Label nodes by their ``name`` property.
        widget.node_label_mapping = 'name'
        # NOTE(review): ``display`` is not imported in this file; it relies on
        # the IPython/Jupyter builtin and raises NameError outside a notebook.
        display(widget)
    return widget

# Example Queries

# 1. Find top stocks by market capitalization in a specific sector
def get_top_stocks_by_sector(sector: str, top_n: int = 5):
    """
    Retrieves the top N stocks by market capitalization within a specified sector.

    Stocks without a ``marketCap`` value are excluded. Cypher sorts ``null``
    first under ``ORDER BY ... DESC``, so such stocks would otherwise crowd
    out the genuinely largest ones.

    Args:
        sector (str): The name of the sector (matched against ``Sector.name``).
        top_n (int): The number of top stocks to return. Neo4j requires a
            non-negative integer for ``LIMIT``.

    Returns:
        list: Query results as a list of dictionaries with keys
        ``Stock`` and ``MarketCap``.
    """
    query = """
    MATCH (s:Stock)-[:BELONGS_TO]->(sec:Sector {name: $sector})
    WHERE s.marketCap IS NOT NULL
    RETURN s.name AS Stock, s.marketCap AS MarketCap
    ORDER BY s.marketCap DESC
    LIMIT $top_n
    """
    return execute_cypher_query(query, {"sector": sector, "top_n": top_n})

# Example usage:
print("Top Tech Stocks:")
top_tech_stocks = get_top_stocks_by_sector("Information Technology")
print(top_tech_stocks)

# 2. Find stocks highly correlated with a given stock
def get_stock_correlations(ticker: str, correlation_threshold: float = 0.8):
    """
    Retrieves stocks highly correlated with the given ticker.

    Only outgoing ``CORRELATES_WITH`` relationships from the given stock are
    followed. Relationships without a ``correlation`` value never pass the
    threshold filter.
    NOTE(review): if correlations are stored as a single directed edge per
    pair, partners on the incoming side are missed; confirm the data model.

    Args:
        ticker (str): The ticker of the stock. It is matched against the
            ``Stock.name`` property; presumably names hold tickers, so verify
            against the data.
        correlation_threshold (float): Minimum correlation (inclusive).

    Returns:
        list: Query results as a list of dictionaries with keys
        ``CorrelatedStock`` and ``Correlation``, strongest first.
    """
    query = """
    MATCH (s1:Stock {name: $ticker})-[r:CORRELATES_WITH]->(s2:Stock)
    WHERE r.correlation >= $threshold
    RETURN s2.name AS CorrelatedStock, r.correlation AS Correlation
    ORDER BY r.correlation DESC
    """
    return execute_cypher_query(query, {"ticker": ticker, "threshold": correlation_threshold})

# Example usage:
print("Highly Correlated Stocks for AAPL:")
aapl_correlations = get_stock_correlations("AAPL", 0.85)
print(aapl_correlations)

# 3. Visualize correlations for a stock
def visualize_stock_correlations(ticker: str, correlation_threshold: float = 0.8):
    """
    Visualizes correlations for a specific stock.

    Draws the given stock, every stock it has an outgoing
    ``CORRELATES_WITH`` relationship to with ``correlation`` at or above the
    threshold, and those relationships.

    Args:
        ticker (str): The ticker of the stock (matched against ``Stock.name``).
        correlation_threshold (float): Minimum correlation threshold (inclusive).

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    query = """
    MATCH (s1:Stock {name: $ticker})-[r:CORRELATES_WITH]->(s2:Stock)
    WHERE r.correlation >= $threshold
    RETURN s1, r, s2
    """
    # The query uses $ticker and $threshold, so the parameters must be passed.
    return visualize_graph(query, {"ticker": ticker, "threshold": correlation_threshold})

# Example usage:
print("Visualizing Correlations for AAPL:")
visualize_stock_correlations("AAPL", 0.85)

# 4. Visualize sector-level correlations
def visualize_sector_correlations(correlation_threshold: float = 0.85):
    """
    Visualizes correlations between sectors.

    Draws every pair of sectors joined by a ``CORRELATES_WITH`` relationship
    whose ``correlation`` is at or above the threshold.

    Args:
        correlation_threshold (float): Minimum correlation threshold (inclusive).

    Returns:
        GraphWidget: Interactive graph visualization widget.
    """
    query = """
    MATCH (sec1:Sector)-[r:CORRELATES_WITH]->(sec2:Sector)
    WHERE r.correlation >= $threshold
    RETURN sec1, r, sec2
    """
    # The query uses $threshold, so the parameter must be passed.
    return visualize_graph(query, {"threshold": correlation_threshold})

# Example usage:
print("Visualizing Sector Correlations:")
visualize_sector_correlations(0.85)

# Close the Neo4j driver after use
import atexit


@atexit.register
def close_driver():
    """
    Shut down the module-level Neo4j driver.

    Registered with ``atexit`` (via the decorator, which returns the function
    unchanged) so cleanup happens automatically at interpreter exit. It may
    also be called manually.
    """
    driver.close()