In [None]:
!pip install matplotlib
!pip install seaborn

In [2]:
# Importing the necessary libraries
import matplotlib.pyplot as plt
import seaborn as sns
import os
import time
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *
from IPython.display import clear_output
from collections import Counter

In [3]:
# App CONSTANTS
# Each value can be overridden through an environment variable of the same name,
# so the notebook can point at another CSV / topic / broker without code edits.
# When the variable is unset the original hard-coded default is used.
STATIC_DATA_PATH = os.environ.get("STATIC_DATA_PATH", "stocks_static_data.csv")
TRANSACTIONS_TOPIC = os.environ.get("TRANSACTIONS_TOPIC", "transactions")
KAFKA_SERVER = os.environ.get("KAFKA_SERVER", "localhost:9092")

In [4]:
# Running aggregates, one (initially empty) pandas DataFrame per chart.
# update_plots() folds every micro-batch into these frames.
accumulated_data = {
    chart_name: pd.DataFrame()
    for chart_name in (
        'industry_growth',
        'transaction_types',
        'daily_volume',
        'avg_volume',
        'country_growth',
        'tech_growth',
    )
}

In [5]:
# Running state for the textual summary printed by update_summary().

# Counters accumulated over all processed batches.
total_transactions = total_volume = 0

# Distinct values seen so far; each name gets its own independent set.
unique_countries, unique_tickers, unique_industries, unique_sectors = (
    set() for _ in range(4)
)

In [None]:
# Build (or reuse) the Spark session used by the whole notebook.
spark = (
    SparkSession.builder
    .appName("StockMarketAnalysis")
    # Legacy parser policy is requested for timestamp parsing.
    .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .getOrCreate()
)

# Only log errors so Spark's own output does not drown the notebook output.
spark.sparkContext.setLogLevel("ERROR")

In [7]:
# Schema of the static per-ticker reference CSV; every column is nullable.
static_tickers_data_scheme = (
    StructType()
    .add("ticker", StringType(), True)
    .add("company_name", StringType(), True)
    .add("shares_outstanding", FloatType(), True)
    .add("exchange", StringType(), True)
    .add("sector", StringType(), True)
    .add("industry", StringType(), True)
    .add("country", StringType(), True)
)

# Schema of the mock live transaction JSON: all five fields are read as
# nullable strings and cast to their real types when the stream is parsed.
live_transactions_data_scheme = StructType([
    StructField(field_name, StringType(), True)
    for field_name in ("ticker", "timestamp", "price", "volume", "action")
])

In [None]:
# For debug purposes - check if the csv file exists
# Prints whether STATIC_DATA_PATH exists and the absolute path it resolves to
# (relative paths are resolved against the kernel's current working directory).
print(f"CSV file exists: {os.path.exists(STATIC_DATA_PATH)}")
print(f"CSV file path: {os.path.abspath(STATIC_DATA_PATH)}")

In [None]:
# Load the static data.
# The frame is cached because it is reused many times: by the debug actions below
# and by the join in every streaming micro-batch. Without the cache each of those
# actions would read and parse the CSV again.
tickers_df = spark.read.csv(
    STATIC_DATA_PATH, header=True, schema=static_tickers_data_scheme
).cache()

# For debug purposes - Print the schema of the Tickers DataFrame
print("Tickers DataFrame Schema:")
tickers_df.printSchema()
print("\nFirst few rows of the Tickers DataFrame:")
tickers_df.show(5, truncate=False)
print(f"\nNumber of rows in Tickers DataFrame: {tickers_df.count()}")

In [None]:
# Streaming source: read the transactions topic from the beginning,
# at most 1000 offsets per trigger.
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_SERVER)
    .option("subscribe", TRANSACTIONS_TOPIC)
    .option("startingOffsets", "earliest")
    .option("maxOffsetsPerTrigger", 1000)
    .load()
)

# For debug purposes - Print Kafka DataFrame schema
print("Kafka DataFrame Schema:")
kafka_df.printSchema()

# Decode the Kafka value as JSON, then project typed columns.
# The timestamp string is parsed with the pattern "dd/MM/yyyy HH:mm";
# `date` is the calendar day of that same parsed timestamp.
parsed_df = (
    kafka_df
    .select(
        from_json(
            col("value").cast("string"),
            live_transactions_data_scheme,
        ).alias("parsed_data")
    )
    .select(
        col("parsed_data.ticker"),
        to_timestamp(col("parsed_data.timestamp"), "dd/MM/yyyy HH:mm").alias("timestamp"),
        col("parsed_data.price").cast("float").alias("price"),
        col("parsed_data.volume").cast("integer").alias("volume"),
        col("parsed_data.action"),
        to_date(
            to_timestamp(col("parsed_data.timestamp"), "dd/MM/yyyy HH:mm")
        ).alias("date"),
    )
)

In [None]:
# Switch matplotlib to interactive mode and create the initial 3x2 dashboard figure.
# NOTE(review): update_plots() builds its own figure on every call, so this
# figure/axes pair does not appear to be drawn on afterwards - confirm it is needed.
plt.ion()
fig = plt.figure(figsize=(20, 20))
axs = fig.subplots(3, 2)
fig.suptitle('Stock Market Analysis', fontsize=16)

In [18]:
def _merge_batch_totals(name, batch_pdf, keys, value_cols):
    """Add one batch's per-key sums to accumulated_data[name] and return the merged frame."""
    # Skip empty frames: concatenating them adds nothing and only muddies dtypes.
    frames = [frame for frame in (accumulated_data[name], batch_pdf) if not frame.empty]
    if frames:
        merged = pd.concat(frames).groupby(keys)[value_cols].sum().reset_index()
    else:
        # Nothing seen yet: keep an empty frame that still exposes the expected columns.
        merged = pd.DataFrame(columns=[*keys, *value_cols])
    accumulated_data[name] = merged
    return merged


def _draw_volume_heatmap(ax, totals, key, title):
    """Draw total_volume per `key` as a one-column heatmap, or a placeholder if there is no data."""
    ax.set_title(title)
    if totals.empty:
        # sns.heatmap raises on zero-size input (e.g. no Technology rows seen so far).
        ax.text(0.5, 0.5, 'No data yet', ha='center', va='center', transform=ax.transAxes)
        ax.set_axis_off()
        return
    # One row per key value. Pivoting without an index would instead produce an
    # N x N matrix that is NaN everywhere except on its diagonal.
    grid = totals.set_index(key)[['total_volume']].astype(float)
    sns.heatmap(grid, ax=ax, cmap='YlOrRd')


def update_plots(df):
    """Fold one enriched micro-batch into the running aggregates and redraw the dashboard.

    Args:
        df: Spark DataFrame for the batch. The code reads its 'ticker', 'date',
            'volume', 'action', 'industry', 'sector' and 'country' columns.

    Side effects:
        Updates every entry of the module-level ``accumulated_data`` dict, clears
        the current cell output and shows a fresh 3x2 figure, which is closed again
        afterwards so figures do not pile up across batches.
    """
    # 1. Industry growth analysis
    industry_pdf = _merge_batch_totals(
        'industry_growth',
        df.groupBy("industry").agg(sum("volume").alias("total_volume")).toPandas(),
        ['industry'], ['total_volume'],
    )

    # 2. Percentage of each transaction type
    types_pdf = _merge_batch_totals(
        'transaction_types',
        df.groupBy("action").count().toPandas(),
        ['action'], ['count'],
    )

    # 3. Daily volume for each stock
    daily_pdf = _merge_batch_totals(
        'daily_volume',
        df.groupBy("ticker", "date").agg(sum("volume").alias("daily_volume")).toPandas(),
        ['ticker', 'date'], ['daily_volume'],
    )

    # 4. Stock with highest average trading volume.
    # Averages cannot be merged by averaging batch averages, so the running sum and
    # row count are accumulated and the true mean is derived from them.
    avg_pdf = _merge_batch_totals(
        'avg_volume',
        df.groupBy("ticker").agg(
            sum("volume").alias("volume_sum"),
            count("volume").alias("volume_count"),
        ).toPandas(),
        ['ticker'], ['volume_sum', 'volume_count'],
    )
    # avg_pdf is the frame stored in accumulated_data, so the derived column is kept with it.
    avg_pdf['avg_volume'] = avg_pdf['volume_sum'].astype(float) / avg_pdf['volume_count'].astype(float)

    # 5. Locations of growing companies by country
    country_pdf = _merge_batch_totals(
        'country_growth',
        df.groupBy("country").agg(sum("volume").alias("total_volume")).toPandas(),
        ['country'], ['total_volume'],
    )

    # 6. Growing technologies market
    tech_pdf = _merge_batch_totals(
        'tech_growth',
        df.filter(col("sector") == "Technology")
          .groupBy("industry").agg(sum("volume").alias("total_volume")).toPandas(),
        ['industry'], ['total_volume'],
    )

    # Clear the current output and create new plots
    clear_output(wait=True)
    fig, axs = plt.subplots(3, 2, figsize=(20, 20))
    try:
        fig.suptitle('Stock Market Analysis', fontsize=16)

        # Plot accumulated data
        _draw_volume_heatmap(axs[0, 0], industry_pdf, 'industry', 'Industry Growth')

        if not types_pdf.empty:
            sns.barplot(x='action', y='count', data=types_pdf, ax=axs[0, 1])
        axs[0, 1].set_title('Transaction Types')

        if not daily_pdf.empty:
            sns.lineplot(x='date', y='daily_volume', hue='ticker', data=daily_pdf, ax=axs[1, 0])
        axs[1, 0].set_title('Daily Volume by Stock')

        if not avg_pdf.empty:
            top_10_volume = avg_pdf.nlargest(10, 'avg_volume')
            sns.barplot(x='ticker', y='avg_volume', data=top_10_volume, ax=axs[1, 1])
        axs[1, 1].set_title('Top 10 Stocks by Average Trading Volume')

        _draw_volume_heatmap(axs[2, 0], country_pdf, 'country', 'Company Growth by Country')
        _draw_volume_heatmap(axs[2, 1], tech_pdf, 'industry', 'Technology Market Growth')

        plt.tight_layout()
        plt.show()
    finally:
        # A new figure is created for every batch; close it so memory does not grow unbounded.
        plt.close(fig)

In [19]:
def update_summary(df):
    """Fold one enriched micro-batch into the running totals and print a text summary.

    Args:
        df: Spark DataFrame for the batch. The code reads its 'ticker', 'volume',
            'country', 'industry' and 'sector' columns.

    Side effects:
        Updates the module-level counters ``total_transactions`` / ``total_volume``
        and the ``unique_*`` sets, then prints cumulative totals followed by the
        top-5 countries and tickers of *this batch*.
    """
    global total_transactions, unique_countries, unique_tickers, unique_industries, unique_sectors, total_volume

    def distinct_non_null(column):
        """Collect the distinct non-null values of one column to the driver."""
        rows = df.select(column).distinct().collect()
        # Nulls (e.g. rows that found no match in the left outer join) are not real
        # countries/industries/sectors and must not inflate the unique counts.
        return [row[0] for row in rows if row[0] is not None]

    # Update total transactions
    batch_transactions = df.count()
    total_transactions += batch_transactions

    # Update unique sets
    unique_countries.update(distinct_non_null('country'))
    unique_tickers.update(distinct_non_null('ticker'))
    unique_industries.update(distinct_non_null('industry'))
    unique_sectors.update(distinct_non_null('sector'))

    # Update total volume. The aggregate is None when every volume in the batch
    # is null; treat that as 0 instead of failing on `int + None`.
    batch_volume = df.agg(sum('volume')).collect()[0][0]
    total_volume += batch_volume or 0

    # Print summary
    print("\n--- Summary ---")
    print(f"Total transactions processed: {total_transactions}")
    print(f"Number of unique countries: {len(unique_countries)}")
    print(f"Number of unique tickers: {len(unique_tickers)}")
    print(f"Number of unique industries: {len(unique_industries)}")
    print(f"Number of unique sectors: {len(unique_sectors)}")
    print(f"Total market volume: {total_volume}")

    # Additional detailed information (per batch; a null country is reported as "None")
    print("\nTop 5 countries by transaction count:")
    country_counts = Counter(row[0] for row in df.select('country').collect())
    for country, n_rows in country_counts.most_common(5):
        print(f"{country}: {n_rows}")

    print("\nTop 5 tickers by volume:")
    ticker_volumes = (
        df.groupBy('ticker')
        .agg(sum('volume').alias('total_volume'))
        .orderBy('total_volume', ascending=False)
        .limit(5)
        .collect()
    )
    for row in ticker_volumes:
        print(f"{row['ticker']}: {row['total_volume']}")

In [20]:
# Process each batch
def process_batch(df, epoch_id):
    """foreachBatch callback: enrich one micro-batch and refresh summary and plots.

    Args:
        df: Spark DataFrame holding the parsed transactions of this micro-batch.
        epoch_id: Batch identifier supplied by Structured Streaming; only used in messages.

    The batch is joined (left outer, on 'ticker') with the module-level ``tickers_df``
    and passed to ``update_summary`` and ``update_plots``. Any exception is printed
    with its traceback and swallowed so that one bad batch does not stop the stream.
    """
    enriched_df = None
    try:
        # Many actions run on this batch below; cache it so each action does not
        # recompute the batch from the source.
        df.persist()

        print(f"Processing batch {epoch_id}")
        print("Input DataFrame Schema:")
        df.printSchema()

        print("\nInput DataFrame Sample:")
        df.show(5, truncate=False)

        # A single count serves both the emptiness check and the join comparison.
        input_count = df.count()
        if input_count == 0:
            print("Warning: Input DataFrame is empty. Skipping this batch.")
            return

        # Enrich the data with static information
        enriched_df = df.join(tickers_df, "ticker", "left_outer").persist()
        print("\nEnriched DataFrame Schema:")
        enriched_df.printSchema()

        print("\nEnriched DataFrame Sample:")
        enriched_df.show(5, truncate=False)

        print(f"\nUnique tickers in enriched data: {enriched_df.select('ticker').distinct().collect()}")
        print(f"\nUnique industries in enriched data: {enriched_df.select('industry').distinct().collect()}")
        print(f"\nUnique sectors in enriched data: {enriched_df.select('sector').distinct().collect()}")
        print(f"\nUnique countries in enriched data: {enriched_df.select('country').distinct().collect()}")

        # Count rows before and after join
        enriched_count = enriched_df.count()
        print(f"\nRows before join: {input_count}")
        print(f"Rows after join: {enriched_count}")

        if input_count != enriched_count:
            # A left outer join never drops input rows; a different count means rows
            # were multiplied because a ticker matches several static rows.
            print("WARNING: Row count changed after join. The static data likely contains duplicate tickers.")

        # Update summary
        update_summary(enriched_df)

        # Update plots
        update_plots(enriched_df)

    except Exception as e:
        print(f"Error processing batch {epoch_id}: {str(e)}")
        import traceback
        traceback.print_exc()

    finally:
        # Release the cached batch data, including on the early return and on errors.
        for cached in (enriched_df, df):
            if cached is not None:
                cached.unpersist()

In [None]:
# Launch the streaming query: every 5 seconds a micro-batch is handed to
# process_batch; progress is checkpointed under /tmp/checkpoint.
query = (
    parsed_df.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", "/tmp/checkpoint")
    .trigger(processingTime='5 seconds')
    .start()
)

# Report basic facts about the running query.
print(f"Query started: {query.isActive}")
print(f"Query name: {query.name}")
print(f"Query id: {query.id}")

# Block this cell until the query ends. Interrupting the kernel prints a
# message, and the query is stopped in every case.
try:
    query.awaitTermination()
except KeyboardInterrupt:
    print("Stopping the streaming query...")
finally:
    query.stop()