In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("Spotify EDA") \
    .config("spark.memory.offHeap.enabled", "true") \
    .config("spark.memory.offHeap.size", "2g") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/05 19:28:36 INFO SparkEnv: Registering MapOutputTracker
25/04/05 19:28:36 INFO SparkEnv: Registering BlockManagerMaster
25/04/05 19:28:36 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/05 19:28:36 INFO SparkEnv: Registering OutputCommitCoordinator


In [2]:
# Read the deduplicated data
deduplicated_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_deduplicated_data.csv"
print("Reading deduplicated data...")
spotify_df = spark.read.option("header", "true").csv(deduplicated_path, inferSchema=True)

# Verify data was loaded correctly
print(f"Loaded {spotify_df.count()} tracks with {len(spotify_df.columns)} features")
print("Data schema:")
spotify_df.printSchema()

Reading deduplicated data...




Loaded 200290 tracks with 20 features
Data schema:
root
 |-- title: string (nullable = true)
 |-- artist: string (nullable = true)
 |-- region: string (nullable = true)
 |-- track_id: string (nullable = true)
 |-- album: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- duration_ms: double (nullable = true)
 |-- explicit: boolean (nullable = true)
 |-- af_danceability: double (nullable = true)
 |-- af_energy: double (nullable = true)
 |-- af_key: double (nullable = true)
 |-- af_loudness: double (nullable = true)
 |-- af_mode: double (nullable = true)
 |-- af_speechiness: double (nullable = true)
 |-- af_acousticness: double (nullable = true)
 |-- af_instrumentalness: double (nullable = true)
 |-- af_liveness: double (nullable = true)
 |-- af_valence: double (nullable = true)
 |-- af_tempo: double (nullable = true)
 |-- af_time_signature: double (nullable = true)



                                                                                

# EDA

In [3]:
# Exploratory Data Analysis (EDA)
print("\n========== EXPLORATORY DATA ANALYSIS ==========\n")

# Basic dataset info
print(f"Number of tracks: {spotify_df.count()}")
print(f"Number of features: {len(spotify_df.columns)}")

# Compute basic statistics for numerical features
print("\nBasic statistics for numerical features:")
spotify_df.describe().show()



Number of tracks: 200290
Number of features: 20

Basic statistics for numerical features:


25/04/05 19:29:06 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 10:>                                                         (0 + 1) / 1]

+-------+--------+-----------------------+--------------------+--------------------+----------+------------------+-----------------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+------------------+------------------+
|summary|   title|                 artist|              region|            track_id|     album|        popularity|      duration_ms|   af_danceability|         af_energy|           af_key|       af_loudness|            af_mode|     af_speechiness|    af_acousticness| af_instrumentalness|        af_liveness|         af_valence|          af_tempo| af_time_signature|
+-------+--------+-----------------------+--------------------+--------------------+----------+------------------+-----------------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-------------------

                                                                                

In [4]:
# Show distribution of explicit flag if present
if "explicit" in spotify_df.columns:
    print("\nDistribution of explicit content:")
    spotify_df.groupBy("explicit").count().orderBy("count", ascending=False).show()

# Identify audio features (columns starting with 'af_')
audio_features = [col for col in spotify_df.columns if col.startswith('af_')]
print(f"\nIdentified audio features: {audio_features}")


Distribution of explicit content:


[Stage 11:>                                                         (0 + 2) / 2]

+--------+------+
|explicit| count|
+--------+------+
|   false|148553|
|    true| 49387|
|    NULL|  2350|
+--------+------+


Identified audio features: ['af_danceability', 'af_energy', 'af_key', 'af_loudness', 'af_mode', 'af_speechiness', 'af_acousticness', 'af_instrumentalness', 'af_liveness', 'af_valence', 'af_tempo', 'af_time_signature']


                                                                                

In [5]:
# Compute correlation matrix for audio features
if len(audio_features) > 0:
    print("\nComputing correlation matrix for audio features...")
    # Create feature vector
    assembler = VectorAssembler(inputCols=audio_features, outputCol="features")
    
    # Check for null values in audio features
    null_counts = {}
    for feature in audio_features:
        null_count = spotify_df.filter(col(feature).isNull()).count()
        if null_count > 0:
            null_counts[feature] = null_count
    
    if null_counts:
        print(f"Warning: Found null values in these features: {null_counts}")
        print("Dropping rows with null values in audio features for correlation calculation")
        df_for_corr = spotify_df.dropna(subset=audio_features)
    else:
        df_for_corr = spotify_df
    
    feature_vector = assembler.transform(df_for_corr.select(audio_features))
    
    # Compute correlation matrix
    correlation_matrix = Correlation.corr(feature_vector, "features").collect()[0][0]
    correlation_matrix_np = correlation_matrix.toArray()
    
    # Print correlation matrix with feature names
    print("\nCorrelation Matrix:")
    print("Features: " + ", ".join(audio_features))
    for i, row in enumerate(correlation_matrix_np):
        print(f"{audio_features[i]}: {[round(x, 2) for x in row]}")


Computing correlation matrix for audio features...


                                                                                

Dropping rows with null values in audio features for correlation calculation


[Stage 53:>                                                         (0 + 2) / 2]


Correlation Matrix:
Features: af_danceability, af_energy, af_key, af_loudness, af_mode, af_speechiness, af_acousticness, af_instrumentalness, af_liveness, af_valence, af_tempo, af_time_signature
af_danceability: [1.0, 0.15, 0.01, 0.17, -0.09, 0.23, -0.22, -0.13, -0.09, 0.38, -0.09, 0.15]
af_energy: [0.15, 1.0, 0.03, 0.74, -0.06, 0.06, -0.61, -0.12, 0.17, 0.37, 0.12, 0.13]
af_key: [0.01, 0.03, 1.0, 0.01, -0.16, 0.02, -0.01, -0.0, 0.0, 0.03, -0.0, 0.0]
af_loudness: [0.17, 0.74, 0.01, 1.0, -0.03, -0.01, -0.46, -0.33, 0.09, 0.27, 0.09, 0.11]
af_mode: [-0.09, -0.06, -0.16, -0.03, 1.0, -0.07, 0.05, -0.0, 0.01, -0.03, 0.01, -0.03]
af_speechiness: [0.23, 0.06, 0.02, -0.01, -0.07, 1.0, -0.06, -0.12, 0.05, 0.07, 0.04, 0.04]
af_acousticness: [-0.22, -0.61, -0.01, -0.46, 0.05, -0.06, 1.0, 0.08, -0.07, -0.16, -0.1, -0.13]
af_instrumentalness: [-0.13, -0.12, -0.0, -0.33, -0.0, -0.12, 0.08, 1.0, -0.03, -0.15, -0.01, -0.05]
af_liveness: [-0.09, 0.17, 0.0, 0.09, 0.01, 0.05, -0.07, -0.03, 1.0, 0.05, 0.

                                                                                

In [6]:
# Find top 10 most popular tracks if popularity column exists
if "popularity" in spotify_df.columns:
    print("\nTop 10 tracks by popularity:")
    spotify_df.orderBy("popularity", ascending=False).select("title", "artist", "popularity").show(10)

# Calculate distribution of audio features
print("\nDistribution of audio features:")
for feature in audio_features:
    try:
        quartiles = spotify_df.approxQuantile(feature, [0.25, 0.5, 0.75], 0.01)
        print(f"{feature}:")
        print(f"  25th percentile: {quartiles[0]}")
        print(f"  Median: {quartiles[1]}")
        print(f"  75th percentile: {quartiles[2]}")
        print()
    except Exception as e:
        print(f"Could not calculate quartiles for {feature}: {e}")


Top 10 tracks by popularity:


                                                                                

+----------------+--------------------+----------+
|           title|              artist|popularity|
+----------------+--------------------+----------+
| Versace - Remix|"El Mayor Clasico...|    2912.0|
|    Cruel Summer|        Taylor Swift|      96.0|
|       Unwritten| Natasha Bedingfield|      92.0|
|         Starboy|The Weeknd, Daft ...|      91.0|
|           Lover|        Taylor Swift|      90.0|
|    Another Love|           Tom Odell|      90.0|
| Blinding Lights|          The Weeknd|      90.0|
|          Yellow|            Coldplay|      90.0|
|The Night We Met|          Lord Huron|      90.0|
|    Pink + White|         Frank Ocean|      89.0|
+----------------+--------------------+----------+
only showing top 10 rows


Distribution of audio features:


                                                                                

af_danceability:
  25th percentile: 0.547
  Median: 0.666
  75th percentile: 0.763



                                                                                

af_energy:
  25th percentile: 0.517
  Median: 0.652
  75th percentile: 0.776



                                                                                

af_key:
  25th percentile: 2.0
  Median: 5.0
  75th percentile: 8.0



                                                                                

af_loudness:
  25th percentile: -8.644
  Median: -6.75
  75th percentile: -5.21

af_mode:
  25th percentile: 0.0
  Median: 1.0
  75th percentile: 1.0



                                                                                

af_speechiness:
  25th percentile: 0.0393
  Median: 0.0635
  75th percentile: 0.158



                                                                                

af_acousticness:
  25th percentile: 0.0476
  Median: 0.179
  75th percentile: 0.44

af_instrumentalness:
  25th percentile: 0.0
  Median: 1.55e-06
  75th percentile: 0.000302



                                                                                

af_liveness:
  25th percentile: 0.0954
  Median: 0.121
  75th percentile: 0.21



                                                                                

af_valence:
  25th percentile: 0.316
  Median: 0.488
  75th percentile: 0.676



                                                                                

af_tempo:
  25th percentile: 97.595
  Median: 120.015
  75th percentile: 139.946

af_time_signature:
  25th percentile: 4.0
  Median: 4.0
  75th percentile: 4.0



                                                                                

In [7]:
# Print sample data
print("\nSample of dataset:")
spotify_df.show(5)

# Stop SparkSession to release resources
spark.stop()


Sample of dataset:
+--------------------+------------+---------+--------------------+--------------------+----------+-----------+--------+---------------+---------+------+-----------+-------+--------------+---------------+-------------------+-----------+----------+--------+-----------------+
|               title|      artist|   region|            track_id|               album|popularity|duration_ms|explicit|af_danceability|af_energy|af_key|af_loudness|af_mode|af_speechiness|af_acousticness|af_instrumentalness|af_liveness|af_valence|af_tempo|af_time_signature|
+--------------------+------------+---------+--------------------+--------------------+----------+-----------+--------+---------------+---------+------+-----------+-------+--------------+---------------+-------------------+-----------+----------+--------+-----------------+
|Still Got Time (f...|        ZAYN|Australia|000xQL6tZNLJzIrtI...|Still Got Time (f...|      56.0|   188490.0|   false|          0.748|    0.627|   7.0|     -

# EDA Plots

In [9]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("Spotify Visualization") \
    .config("spark.memory.offHeap.enabled", "true") \
    .config("spark.memory.offHeap.size", "2g") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# Read the deduplicated data
deduplicated_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_deduplicated_data.csv"
print("Reading deduplicated data from GCP...")
spotify_df = spark.read.option("header", "true").csv(deduplicated_path, inferSchema=True)

# Verify data was loaded correctly
print(f"Loaded {spotify_df.count()} tracks with {len(spotify_df.columns)} features")

# Identify audio features (columns starting with 'af_')
audio_features = [col for col in spotify_df.columns if col.startswith('af_')]
print(f"Identified audio features: {audio_features}")

# Convert to Pandas for visualization
print("Converting to pandas for visualization...")
pdf = spotify_df.toPandas()
print(f"Converted {len(pdf)} records to pandas DataFrame")

# Set up plotting environment
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (12, 8)

25/04/05 19:54:55 INFO SparkEnv: Registering MapOutputTracker
25/04/05 19:54:55 INFO SparkEnv: Registering BlockManagerMaster
25/04/05 19:54:55 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/05 19:54:55 INFO SparkEnv: Registering OutputCommitCoordinator


Reading deduplicated data from GCP...


                                                                                

Loaded 200290 tracks with 20 features
Identified audio features: ['af_danceability', 'af_energy', 'af_key', 'af_loudness', 'af_mode', 'af_speechiness', 'af_acousticness', 'af_instrumentalness', 'af_liveness', 'af_valence', 'af_tempo', 'af_time_signature']
Converting to pandas for visualization...


                                                                                

Converted 200290 records to pandas DataFrame


In [None]:
# 1. Distribution of audio features
print("Plotting audio feature distributions...")
fig, axes = plt.subplots(4, 3, figsize=(18, 16))
axes = axes.flatten()

for i, feature in enumerate(audio_features):
    if i < len(axes):
        sns.histplot(pdf[feature].dropna(), kde=True, ax=axes[i])
        axes[i].set_title(f'Distribution of {feature}')
        axes[i].set_xlabel(feature)
        axes[i].set_ylabel('Count')

plt.tight_layout()
plt.savefig('/tmp/audio_feature_distributions.png')
print("Saved audio feature distributions to /tmp/audio_feature_distributions.png")

Plotting audio feature distributions...


In [None]:
# 2. Correlation heatmap
print("Creating correlation heatmap...")
plt.figure(figsize=(14, 12))
corr_matrix = pdf[audio_features].corr()
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm', mask=mask, 
            square=True, linewidths=.5, cbar_kws={"shrink": .7})
plt.title('Correlation between Audio Features', fontsize=18)
plt.tight_layout()
plt.savefig('/tmp/correlation_heatmap.png')
print("Saved correlation heatmap to /tmp/correlation_heatmap.png")

In [None]:
# 3. Popularity vs audio features (if popularity column exists)
if "popularity" in pdf.columns:
    print("Plotting popularity vs audio features...")
    fig, axes = plt.subplots(4, 3, figsize=(18, 16))
    axes = axes.flatten()
    
    for i, feature in enumerate(audio_features):
        if i < len(axes):
            axes[i].scatter(pdf[feature], pdf['popularity'], alpha=0.5)
            axes[i].set_title(f'{feature} vs Popularity')
            axes[i].set_xlabel(feature)
            axes[i].set_ylabel('Popularity')
    
    plt.tight_layout()
    plt.savefig('/tmp/popularity_vs_features.png')
    print("Saved popularity vs features plots to /tmp/popularity_vs_features.png")

In [None]:
# 4. Boxplot of audio features by explicit content (if explicit column exists)
if "explicit" in pdf.columns:
    print("Creating boxplots by explicit content...")
    fig, axes = plt.subplots(4, 3, figsize=(18, 16))
    axes = axes.flatten()
    
    for i, feature in enumerate(audio_features):
        if i < len(axes):
            sns.boxplot(x='explicit', y=feature, data=pdf, ax=axes[i])
            axes[i].set_title(f'{feature} by Explicit Content')
    
    plt.tight_layout()
    plt.savefig('/tmp/features_by_explicit.png')
    print("Saved features by explicit content boxplots to /tmp/features_by_explicit.png")

In [None]:
# 5. PCA visualization
print("Performing PCA analysis...")
# Prepare data for PCA
features_for_pca = pdf[audio_features].dropna()

if len(features_for_pca) > 0:
    # Standardize the features
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(features_for_pca)
    
    # Apply PCA
    pca = PCA(n_components=2)
    principal_components = pca.fit_transform(scaled_features)
    
    # Create a DataFrame with the principal components
    pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2'])
    
    # Add track information
    pca_df['title'] = pdf.loc[features_for_pca.index, 'title'].values
    pca_df['artist'] = pdf.loc[features_for_pca.index, 'artist'].values
    if 'popularity' in pdf.columns:
        pca_df['popularity'] = pdf.loc[features_for_pca.index, 'popularity'].values
    
    # Plot PCA
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(pca_df['PC1'], pca_df['PC2'], 
                         c=pca_df['popularity'] if 'popularity' in pca_df.columns else None, 
                         alpha=0.5, cmap='viridis')
    plt.title('PCA of Audio Features', fontsize=18)
    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.2f}%)')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.2f}%)')
    
    if 'popularity' in pca_df.columns:
        plt.colorbar(scatter, label='Popularity')
    
    plt.savefig('/tmp/pca_visualization.png')
    print("Saved PCA visualization to /tmp/pca_visualization.png")
    print(f"PCA explained variance: {pca.explained_variance_ratio_}")

In [None]:
# 6. Top artists by track count
print("Analyzing top artists...")
top_artists = pdf['artist'].value_counts().head(20)
plt.figure(figsize=(14, 10))
sns.barplot(x=top_artists.values, y=top_artists.index)
plt.title('Top 20 Artists by Number of Tracks', fontsize=18)
plt.xlabel('Number of Tracks')
plt.tight_layout()
plt.savefig('/tmp/top_artists.png')
print("Saved top artists chart to /tmp/top_artists.png")

print("All visualizations complete!")

In [1]:
spark.stop()

NameError: name 'spark' is not defined

# data manipulation

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Spotify Classification Analysis") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/05 21:17:17 INFO SparkEnv: Registering MapOutputTracker
25/04/05 21:17:17 INFO SparkEnv: Registering BlockManagerMaster
25/04/05 21:17:17 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/05 21:17:18 INFO SparkEnv: Registering OutputCommitCoordinator


In [3]:
# Step 1: Load datasets
print("\n===== LOADING DATASETS =====")

# Load the deduplicated data from GCS
deduplicated_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_deduplicated_data.csv"
print("Reading deduplicated data from GCP...")
spotify_df = spark.read.option("header", "true").csv(deduplicated_path, inferSchema=True)

# Check if the data was loaded successfully
print(f"Loaded deduplicated dataset with {spotify_df.count()} rows and {len(spotify_df.columns)} columns")
print("Sample of deduplicated data:")
spotify_df.show(5)

# Load the tracks features dataset
tracks_features_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/tracks_features.csv"
print("\nReading tracks features data...")
tracks_features_df = spark.read.option("header", "true").csv(tracks_features_path, inferSchema=True)

# Check if the data was loaded successfully
print(f"Loaded tracks features dataset with {tracks_features_df.count()} rows and {len(tracks_features_df.columns)} columns")
print("Sample of tracks features data:")
tracks_features_df.show(5)

# Print the schema of both datasets to understand their structure
print("\nSchema of deduplicated dataset:")
spotify_df.printSchema()

print("\nSchema of tracks features dataset:")
tracks_features_df.printSchema()


===== LOADING DATASETS =====
Reading deduplicated data from GCP...


                                                                                

Loaded deduplicated dataset with 200290 rows and 20 columns
Sample of deduplicated data:
+--------------------+------------+---------+--------------------+--------------------+----------+-----------+--------+---------------+---------+------+-----------+-------+--------------+---------------+-------------------+-----------+----------+--------+-----------------+
|               title|      artist|   region|            track_id|               album|popularity|duration_ms|explicit|af_danceability|af_energy|af_key|af_loudness|af_mode|af_speechiness|af_acousticness|af_instrumentalness|af_liveness|af_valence|af_tempo|af_time_signature|
+--------------------+------------+---------+--------------------+--------------------+----------+-----------+--------+---------------+---------+------+-----------+-------+--------------+---------------+-------------------+-----------+----------+--------+-----------------+
|Still Got Time (f...|        ZAYN|Australia|000xQL6tZNLJzIrtI...|Still Got Time (f...|  

                                                                                

Loaded tracks features dataset with 1204025 rows and 24 columns
Sample of tracks features data:
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------+-----------+--------+------------+------------------+---+-------------------+----+-----------+------------+----------------+-------------------+-------+-----------------+-----------+--------------+----+------------+
|                  id|                name|               album|            album_id|             artists|          artist_ids|track_number|disc_number|explicit|danceability|            energy|key|           loudness|mode|speechiness|acousticness|instrumentalness|           liveness|valence|            tempo|duration_ms|time_signature|year|release_date|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------+-----------+--------+------------+------------------

In [5]:
from pyspark.sql.functions import col

# Step 2: Filter songs from 2015-2021 in the tracks features dataset
print("\n===== FILTERING SONGS BY YEAR =====")

# Convert year column to integer for comparison
tracks_features_df = tracks_features_df.withColumn("year", col("year").cast("integer"))

# Filter songs between 2015 and 2021
filtered_tracks_df = tracks_features_df.filter((col("year") >= 2015) & (col("year") <= 2021))

# Count the number of songs after filtering
filtered_count = filtered_tracks_df.count()
print(f"Songs from 2015-2021: {filtered_count}")


===== FILTERING SONGS BY YEAR =====




Songs from 2015-2021: 336181


                                                                                

In [6]:
from pyspark.sql.functions import explode, split, trim, collect_set

# Step 3: Extract unique artists from the rank dataset
print("\n===== EXTRACTING UNIQUE ARTISTS FROM RANK DATASET =====")

# Collect all unique artists from the artist column
# The artists column might contain multiple artists separated by commas
# We'll collect them all as a single set
unique_artists = spotify_df.select("artist").distinct()

# Count the number of unique artists
artist_count = unique_artists.count()
print(f"Extracted {artist_count} unique artists from the rank dataset")

# Alternatively, we can persist this information without collecting
unique_artists_df = unique_artists.cache()


===== EXTRACTING UNIQUE ARTISTS FROM RANK DATASET =====




Extracted 86341 unique artists from the rank dataset


                                                                                

In [14]:
from pyspark.sql.functions import col, lower, lit
from pyspark.sql.types import BooleanType

# Step 4 (Re-revised): Filter songs by ranked artists using substring matching
print("\n===== FILTERING SONGS BY RANKED ARTISTS (SUBSTRING MATCHING) =====")

# Get unique artists from the rank dataset
unique_artists_df = spotify_df.select("artist").distinct().filter(col("artist").isNotNull())

# Collect artist names as a list
ranked_artists_list = [row.artist for row in unique_artists_df.collect()]
broadcast_artists = spark.sparkContext.broadcast(ranked_artists_list)

# Function to check if any ranked artist appears within the song's artists string
def artist_substring_match(artists_str, artist_list_broadcast):
    if artists_str is None:
        return False
    
    for artist in artist_list_broadcast.value:
        if artist is not None and artist in artists_str:
            return True
    
    return False

# Register as UDF
artist_match_udf = udf(
    lambda artists_str: artist_substring_match(artists_str, broadcast_artists),
    BooleanType()
)

# Apply the filter using substring matching (as in the original pandas code)
filtered_songs_df = filtered_tracks_df.filter(
    col("artists").isNotNull() & 
    artist_match_udf(col("artists"))
)

# Count the number of filtered songs
filtered_songs_count = filtered_songs_df.count()
print(f"Songs by ranked artists (substring matching): {filtered_songs_count}")

# Cache the filtered data for subsequent operations
filtered_songs_df = filtered_songs_df.cache()


===== FILTERING SONGS BY RANKED ARTISTS (SUBSTRING MATCHING) =====




Songs by ranked artists (substring matching): 278731


                                                                                

In [15]:
from pyspark.sql.functions import when, lit

# Step 5: Create labels - check if songs are in the rank dataset
print("\n===== CREATING BINARY LABELS =====")

# Extract track IDs from the rank dataset (spotify_df)
ranked_track_ids_df = spotify_df.select("track_id").distinct()

# Convert to a list for broadcast
ranked_track_ids = [row.track_id for row in ranked_track_ids_df.collect() if row.track_id is not None]
ranked_track_ids_broadcast = spark.sparkContext.broadcast(set(ranked_track_ids))

# Define a UDF to check if a track ID is in the ranked set
def is_in_rankings(track_id, ranked_ids_broadcast):
    if track_id is None:
        return False
    return track_id in ranked_ids_broadcast.value

# Register the UDF
from pyspark.sql.types import BooleanType
is_ranked_udf = udf(lambda x: is_in_rankings(x, ranked_track_ids_broadcast), BooleanType())

# Add label column (1 if song is in rank_df, 0 otherwise)
labeled_songs_df = filtered_songs_df.withColumn(
    "is_ranked", 
    when(is_ranked_udf(col("id")), 1).otherwise(0)
)

# Count ranked vs. non-ranked songs
ranked_count = labeled_songs_df.filter(col("is_ranked") == 1).count()
total_count = labeled_songs_df.count()

print(f"Songs that appeared in rankings: {ranked_count}")
print(f"Songs that did not appear in rankings: {total_count - ranked_count}")
print(f"Percentage of songs in rankings: {(ranked_count / total_count * 100):.2f}%")

# Cache the labeled data for subsequent operations
labeled_songs_df = labeled_songs_df.cache()


===== CREATING BINARY LABELS =====




Songs that appeared in rankings: 7510
Songs that did not appear in rankings: 271221
Percentage of songs in rankings: 2.69%


                                                                                

In [16]:
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.sql.functions import col, isnan, when

# Step 6: Prepare features for classification
print("\n===== PREPARING FEATURES FOR CLASSIFICATION =====")

# Feature mapping from the original code
feature_mapping = {
    'danceability': 'af_danceability',
    'energy': 'af_energy',
    'key': 'af_key',
    'loudness': 'af_loudness',
    'mode': 'af_mode',
    'speechiness': 'af_speechiness',
    'acousticness': 'af_acousticness',
    'instrumentalness': 'af_instrumentalness',
    'liveness': 'af_liveness',
    'valence': 'af_valence',
    'tempo': 'af_tempo',
    'time_signature': 'af_time_signature'
}

# Cast all feature columns to double
for orig_col, target_col in feature_mapping.items():
    filtered_songs_df = filtered_songs_df.withColumn(
        orig_col, 
        col(orig_col).cast("double")
    )

# List of feature columns to use
feature_cols = list(feature_mapping.keys())

# Create a feature vector column using VectorAssembler
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features"
)

# Apply the vector assembler
songs_with_features = assembler.transform(filtered_songs_df)

# Normalize/Standardize the features
scaler = StandardScaler(
    inputCol="features",
    outputCol="scaled_features",
    withStd=True,
    withMean=True
)

# Fit the scaler on the data
scaler_model = scaler.fit(songs_with_features)

# Apply the scaler to the data
scaled_songs = scaler_model.transform(songs_with_features)

# Cache the prepared data for model training
scaled_songs = scaled_songs.cache()


===== PREPARING FEATURES FOR CLASSIFICATION =====


                                                                                

In [17]:
# Save the prepared dataset to GCS
output_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data"

# Save as parquet format (efficient columnar storage)
scaled_songs.write.mode("overwrite").parquet(output_path)

print(f"Saved prepared dataset to: {output_path}")

# Alternatively, if you need CSV format:
# scaled_songs.write.mode("overwrite").option("header", "true").csv(output_path + "_csv")

                                                                                

Saved prepared dataset to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data


In [25]:
from pyspark.sql.functions import col, lit, when

# First, let's load the original deduplicated dataset to get the track IDs
deduplicated_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_deduplicated_data.csv"
spotify_df = spark.read.option("header", "true").csv(deduplicated_path, inferSchema=True)

# Extract ranked track IDs from the original dataset
ranked_track_ids = spotify_df.select("track_id").distinct()

# Add is_ranked column to our current dataset using left join
completed_songs = scaled_songs.join(
    ranked_track_ids,
    scaled_songs.id == ranked_track_ids.track_id,
    "left"
).withColumn(
    "is_ranked",
    when(col("track_id").isNotNull(), 1).otherwise(0)
).drop("track_id")  # Drop the redundant track_id column from the join

# Cache the updated dataset
completed_songs = completed_songs.cache()

# Check the new schema
print("\nUpdated dataset schema:")
completed_songs.printSchema()

# Verify the is_ranked column was added
print("\nDistribution of the target variable:")
completed_songs.groupBy("is_ranked").count().orderBy("is_ranked").show()

# Save the updated dataset with the is_ranked column
updated_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data_complete"
completed_songs.write.mode("overwrite").parquet(updated_path)

print(f"\nUpdated dataset saved to: {updated_path}")

                                                                                


Updated dataset schema:
root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- album: string (nullable = true)
 |-- album_id: string (nullable = true)
 |-- artists: string (nullable = true)
 |-- artist_ids: string (nullable = true)
 |-- track_number: string (nullable = true)
 |-- disc_number: string (nullable = true)
 |-- explicit: string (nullable = true)
 |-- danceability: double (nullable = true)
 |-- energy: double (nullable = true)
 |-- key: double (nullable = true)
 |-- loudness: double (nullable = true)
 |-- mode: double (nullable = true)
 |-- speechiness: double (nullable = true)
 |-- acousticness: double (nullable = true)
 |-- instrumentalness: double (nullable = true)
 |-- liveness: double (nullable = true)
 |-- valence: double (nullable = true)
 |-- tempo: double (nullable = true)
 |-- duration_ms: string (nullable = true)
 |-- time_signature: double (nullable = true)
 |-- year: integer (nullable = true)
 |-- release_date: string (nullable = true)
 

                                                                                

+---------+------+
|is_ranked| count|
+---------+------+
|        0|271221|
|        1|  7510|
+---------+------+



                                                                                


Updated dataset saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data_complete


In [26]:
spark.stop()

# model

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import matplotlib.pyplot as plt
import seaborn as sns

# Initialize a new Spark Session
spark = SparkSession.builder \
    .appName("Spotify Classification Modeling") \
    .getOrCreate()

# Define the path to the updated dataset
data_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data_complete"

# Load the dataset
print("Loading prepared dataset from GCS...")
completed_songs = spark.read.parquet(data_path)

# Cache the dataset for faster operations
completed_songs = completed_songs.cache()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/05 23:03:33 INFO SparkEnv: Registering MapOutputTracker
25/04/05 23:03:33 INFO SparkEnv: Registering BlockManagerMaster
25/04/05 23:03:33 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/05 23:03:33 INFO SparkEnv: Registering OutputCommitCoordinator


Loading prepared dataset from GCS...


25/04/05 23:03:53 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [3]:
# Display basic information about the dataset
print(f"\nDataset size: {completed_songs.count()} rows and {len(completed_songs.columns)} columns")

# Show the schema
print("\nDataset schema:")
completed_songs.printSchema()

# Show a sample of the data
print("\nSample of the prepared data with target variable:")
completed_songs.select("id", "name", "artists", "is_ranked", "features", "scaled_features").show(5, truncate=True)

                                                                                


Dataset size: 278731 rows and 28 columns

Dataset schema:
root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- album: string (nullable = true)
 |-- album_id: string (nullable = true)
 |-- artists: string (nullable = true)
 |-- artist_ids: string (nullable = true)
 |-- track_number: string (nullable = true)
 |-- disc_number: string (nullable = true)
 |-- explicit: string (nullable = true)
 |-- danceability: double (nullable = true)
 |-- energy: double (nullable = true)
 |-- key: double (nullable = true)
 |-- loudness: double (nullable = true)
 |-- mode: double (nullable = true)
 |-- speechiness: double (nullable = true)
 |-- acousticness: double (nullable = true)
 |-- instrumentalness: double (nullable = true)
 |-- liveness: double (nullable = true)
 |-- valence: double (nullable = true)
 |-- tempo: double (nullable = true)
 |-- duration_ms: string (nullable = true)
 |-- time_signature: double (nullable = true)
 |-- year: integer (nullable = true)
 |-- releas

In [5]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql.functions import col

completed_songs = completed_songs.withColumn("label", col("is_ranked").cast("double"))

print("\n===== SPLITTING DATA INTO TRAIN AND TEST SETS =====")
train_data, test_data = completed_songs.randomSplit([0.8, 0.2], seed=42)

print(f"Training set size: {train_data.count()} rows")
print(f"Test set size: {test_data.count()} rows")

train_data = train_data.cache()
test_data = test_data.cache()

try:
    from sparkxgb import XGBoostClassifier
    
    xgb = XGBoostClassifier(
        featuresCol="scaled_features", 
        labelCol="label",
        numRound=100,
        maxDepth=6,
        eta=0.1,
        objective="binary:logistic",
        evalMetric="auc"
    )
    
    print("\n===== TRAINING XGBOOST MODEL =====")
    xgb_model = xgb.fit(train_data)
    
    predictions = xgb_model.transform(test_data)
    
    evaluator = BinaryClassificationEvaluator(
        labelCol="label", 
        rawPredictionCol="probabilities", 
        metricName="areaUnderROC"
    )
    
    auc = evaluator.evaluate(predictions)
    print(f"XGBoost AUC: {auc:.4f}")
    
    accuracy_evaluator = MulticlassClassificationEvaluator(
        labelCol="label", 
        predictionCol="prediction", 
        metricName="accuracy"
    )
    
    accuracy = accuracy_evaluator.evaluate(predictions)
    print(f"XGBoost Accuracy: {accuracy:.4f}")
    
    tp = predictions.filter((col("prediction") == 1) & (col("label") == 1)).count()
    fp = predictions.filter((col("prediction") == 1) & (col("label") == 0)).count()
    tn = predictions.filter((col("prediction") == 0) & (col("label") == 0)).count()
    fn = predictions.filter((col("prediction") == 0) & (col("label") == 1)).count()
    
    print("\nConfusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"True Negatives: {tn}")
    print(f"False Negatives: {fn}")
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\nPrecision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    
    model_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_xgboost_model"
    xgb_model.write().overwrite().save(model_path)
    print(f"\nXGBoost model saved to: {model_path}")
    
except ImportError:
    print("\nXGBoost for Spark not found. Using LogisticRegression and RandomForest instead.")
    
    from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
    
    lr = LogisticRegression(
        featuresCol="scaled_features", 
        labelCol="label",
        maxIter=10,
        regParam=0.01
    )
    
    print("\n===== TRAINING LOGISTIC REGRESSION MODEL =====")
    lr_model = lr.fit(train_data)
    
    lr_predictions = lr_model.transform(test_data)
    
    evaluator = BinaryClassificationEvaluator(
        labelCol="label", 
        rawPredictionCol="rawPrediction", 
        metricName="areaUnderROC"
    )
    
    lr_auc = evaluator.evaluate(lr_predictions)
    print(f"Logistic Regression AUC: {lr_auc:.4f}")
    
    rf = RandomForestClassifier(
        featuresCol="scaled_features", 
        labelCol="label",
        numTrees=100,
        maxDepth=10
    )
    
    print("\n===== TRAINING RANDOM FOREST MODEL =====")
    rf_model = rf.fit(train_data)
    
    rf_predictions = rf_model.transform(test_data)
    
    rf_auc = evaluator.evaluate(rf_predictions)
    print(f"Random Forest AUC: {rf_auc:.4f}")
    
    if rf_auc > lr_auc:
        print("\nRandom Forest performed better!")
        best_model = rf_model
        best_predictions = rf_predictions
        best_model_name = "RandomForest"
    else:
        print("\nLogistic Regression performed better!")
        best_model = lr_model
        best_predictions = lr_predictions
        best_model_name = "LogisticRegression"

    accuracy_evaluator = MulticlassClassificationEvaluator(
        labelCol="label", 
        predictionCol="prediction", 
        metricName="accuracy"
    )
    
    best_accuracy = accuracy_evaluator.evaluate(best_predictions)
    print(f"Best Model Accuracy: {best_accuracy:.4f}")
    
    tp = best_predictions.filter((col("prediction") == 1) & (col("label") == 1)).count()
    fp = best_predictions.filter((col("prediction") == 1) & (col("label") == 0)).count()
    tn = best_predictions.filter((col("prediction") == 0) & (col("label") == 0)).count()
    fn = best_predictions.filter((col("prediction") == 0) & (col("label") == 1)).count()
    
    print("\nConfusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"True Negatives: {tn}")
    print(f"False Negatives: {fn}")
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\nPrecision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    
    model_path = f"gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_{best_model_name.lower()}_model"
    best_model.write().overwrite().save(model_path)
    print(f"\n{best_model_name} model saved to: {model_path}")


===== SPLITTING DATA INTO TRAIN AND TEST SETS =====


                                                                                

Training set size: 223011 rows


[Stage 9:>                                                          (0 + 2) / 2]

Test set size: 55720 rows

XGBoost for Spark not found. Using LogisticRegression and RandomForest instead.


                                                                                


===== TRAINING LOGISTIC REGRESSION MODEL =====


                                                                                

Logistic Regression AUC: 0.7940

===== TRAINING RANDOM FOREST MODEL =====


25/04/05 23:05:20 WARN DAGScheduler: Broadcasting large task binary with size 1004.7 KiB
25/04/05 23:05:25 WARN DAGScheduler: Broadcasting large task binary with size 1851.2 KiB
25/04/05 23:05:32 WARN DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/04/05 23:05:40 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/05 23:05:48 WARN DAGScheduler: Broadcasting large task binary with size 1580.2 KiB
25/04/05 23:05:50 WARN DAGScheduler: Broadcasting large task binary with size 10.5 MiB
25/04/05 23:06:00 WARN DAGScheduler: Broadcasting large task binary with size 2.4 MiB
25/04/05 23:06:02 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
                                                                                

Random Forest AUC: 0.8138

Random Forest performed better!


25/04/05 23:06:04 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
                                                                                

Best Model Accuracy: 0.9741


25/04/05 23:06:05 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
25/04/05 23:06:07 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
25/04/05 23:06:08 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
25/04/05 23:06:09 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
                                                                                


Confusion Matrix:
True Positives: 0
False Positives: 0
True Negatives: 54275
False Negatives: 1445

Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000


25/04/05 23:06:14 WARN TaskSetManager: Stage 94 contains a task of very large size (2075 KiB). The maximum recommended task size is 1000 KiB.
                                                                                


RandomForest model saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_randomforest_model


In [6]:
from pyspark.ml.classification import GBTClassifier, DecisionTreeClassifier, LinearSVC, NaiveBayes, MultilayerPerceptronClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

print("\n===== TRAINING GRADIENT BOOSTED TREES MODEL =====")
gbt = GBTClassifier(
    featuresCol="scaled_features", 
    labelCol="label",
    maxIter=10,
    maxDepth=5
)
gbt_model = gbt.fit(train_data)
gbt_predictions = gbt_model.transform(test_data)
gbt_auc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction").evaluate(gbt_predictions)
print(f"GBT AUC: {gbt_auc:.4f}")

print("\n===== TRAINING DECISION TREE MODEL =====")
dt = DecisionTreeClassifier(
    featuresCol="scaled_features", 
    labelCol="label",
    maxDepth=5
)
dt_model = dt.fit(train_data)
dt_predictions = dt_model.transform(test_data)
dt_auc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction").evaluate(dt_predictions)
print(f"Decision Tree AUC: {dt_auc:.4f}")

print("\n===== TRAINING LINEAR SVM MODEL =====")
svm = LinearSVC(
    featuresCol="scaled_features", 
    labelCol="label",
    maxIter=10,
    regParam=0.1
)
svm_model = svm.fit(train_data)
svm_predictions = svm_model.transform(test_data)
svm_auc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction").evaluate(svm_predictions)
print(f"Linear SVM AUC: {svm_auc:.4f}")

print("\n===== TRAINING NAIVE BAYES MODEL =====")
from pyspark.ml.feature import MinMaxScaler
minmax_scaler = MinMaxScaler(inputCol="features", outputCol="minmax_features")
minmax_model = minmax_scaler.fit(train_data)
train_data_minmax = minmax_model.transform(train_data)
test_data_minmax = minmax_model.transform(test_data)

nb = NaiveBayes(
    featuresCol="minmax_features", 
    labelCol="label",
    modelType="multinomial"
)
nb_model = nb.fit(train_data_minmax)
nb_predictions = nb_model.transform(test_data_minmax)
nb_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
nb_accuracy = nb_evaluator.evaluate(nb_predictions)
print(f"Naive Bayes Accuracy: {nb_accuracy:.4f}")

print("\n===== TRAINING NEURAL NETWORK MODEL =====")
feature_size = len(train_data.select("scaled_features").first()[0])
layers = [feature_size, 10, 5, 2]

mlp = MultilayerPerceptronClassifier(
    featuresCol="scaled_features", 
    labelCol="label",
    layers=layers,
    blockSize=128,
    seed=42,
    maxIter=100
)
mlp_model = mlp.fit(train_data)
mlp_predictions = mlp_model.transform(test_data)
mlp_accuracy = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction").evaluate(mlp_predictions)
print(f"Neural Network Accuracy: {mlp_accuracy:.4f}")

print("\n===== ENSEMBLE MODEL VOTING =====")
from pyspark.sql.functions import col, when, greatest

ensemble_df = test_data.select("id", "label")
ensemble_df = ensemble_df.join(gbt_predictions.select("id", col("prediction").alias("gbt_pred")), "id")
ensemble_df = ensemble_df.join(dt_predictions.select("id", col("prediction").alias("dt_pred")), "id")
ensemble_df = ensemble_df.join(svm_predictions.select("id", col("prediction").alias("svm_pred")), "id")
ensemble_df = ensemble_df.join(mlp_predictions.select("id", col("prediction").alias("mlp_pred")), "id")

ensemble_df = ensemble_df.withColumn(
    "vote_sum", 
    col("gbt_pred") + col("dt_pred") + col("svm_pred") + col("mlp_pred")
)
ensemble_df = ensemble_df.withColumn(
    "ensemble_pred", 
    when(col("vote_sum") >= 2, 1.0).otherwise(0.0)
)

ensemble_accuracy = ensemble_df.filter(col("ensemble_pred") == col("label")).count() / ensemble_df.count()
print(f"Ensemble Model Accuracy: {ensemble_accuracy:.4f}")


===== TRAINING GRADIENT BOOSTED TREES MODEL =====
GBT AUC: 0.8076

===== TRAINING DECISION TREE MODEL =====
Decision Tree AUC: 0.5000

===== TRAINING LINEAR SVM MODEL =====
Linear SVM AUC: 0.7744

===== TRAINING NAIVE BAYES MODEL =====


                                                                                

Naive Bayes Accuracy: 0.9741

===== TRAINING NEURAL NETWORK MODEL =====


                                                                                

Neural Network Accuracy: 0.9613

===== ENSEMBLE MODEL VOTING =====
Ensemble Model Accuracy: 0.9741


In [8]:
from pyspark.sql.functions import col, when, expr
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

# Create a function to calculate and display comprehensive metrics
def evaluate_model(predictions, label_col="label", prediction_col="prediction", model_name="Model", has_raw_prediction=True):
    # Calculate basic metrics
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col)
    accuracy = evaluator.evaluate(predictions)
    
    # Area under ROC - only if the model has rawPrediction column
    if has_raw_prediction:
        binary_evaluator = BinaryClassificationEvaluator(labelCol=label_col, rawPredictionCol="rawPrediction")
        auc = binary_evaluator.evaluate(predictions)
    else:
        auc = None  # Mark as unavailable
    
    # Calculate confusion matrix elements
    tp = predictions.filter((col(prediction_col) == 1.0) & (col(label_col) == 1.0)).count()
    fp = predictions.filter((col(prediction_col) == 1.0) & (col(label_col) == 0.0)).count()
    tn = predictions.filter((col(prediction_col) == 0.0) & (col(label_col) == 0.0)).count()
    fn = predictions.filter((col(prediction_col) == 0.0) & (col(label_col) == 1.0)).count()
    
    # Calculate additional metrics
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    # Print metrics
    print(f"\n===== {model_name} EVALUATION =====")
    print(f"Accuracy: {accuracy:.4f}")
    if auc is not None:
        print(f"AUC: {auc:.4f}")
    else:
        print("AUC: Not available (no rawPrediction column)")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Specificity: {specificity:.4f}")
    
    # Print confusion matrix
    print("\nConfusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"True Negatives: {tn}")
    print(f"False Negatives: {fn}")
    
    # Return metrics as a dictionary
    return {
        "model": model_name,
        "accuracy": accuracy,
        "auc": auc if auc is not None else 0,  # Use 0 for sorting purposes
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "specificity": specificity,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn
    }

# Evaluate all models
print("\n===== EVALUATING ALL MODELS ON TEST DATA =====")

# Evaluate GBT
gbt_metrics = evaluate_model(gbt_predictions, model_name="Gradient Boosted Trees")

# Evaluate Decision Tree
dt_metrics = evaluate_model(dt_predictions, model_name="Decision Tree")

# Evaluate SVM
svm_metrics = evaluate_model(svm_predictions, model_name="Linear SVM")

# Evaluate Naive Bayes
nb_metrics = evaluate_model(nb_predictions, model_name="Naive Bayes")

# Evaluate Neural Network
mlp_metrics = evaluate_model(mlp_predictions, model_name="Neural Network")

# Evaluate Ensemble - specify that it doesn't have rawPrediction column
ensemble_metrics = evaluate_model(
    ensemble_df, 
    label_col="label", 
    prediction_col="ensemble_pred", 
    model_name="Ensemble Model",
    has_raw_prediction=False
)

# Compare all models
metrics_list = [gbt_metrics, dt_metrics, svm_metrics, nb_metrics, mlp_metrics, ensemble_metrics]
metrics_list.sort(key=lambda x: x["f1"], reverse=True)

print("\n===== MODEL COMPARISON (SORTED BY F1 SCORE) =====")
print(f"{'Model':<20} {'Accuracy':<10} {'AUC':<10} {'Precision':<10} {'Recall':<10} {'F1 Score':<10} {'Specificity':<10}")
print("-" * 80)
for m in metrics_list:
    auc_str = f"{m['auc']:.4f}" if m['auc'] != 0 else "N/A"
    print(f"{m['model']:<20} {m['accuracy']:<10.4f} {auc_str:<10} {m['precision']:<10.4f} {m['recall']:<10.4f} {m['f1']:<10.4f} {m['specificity']:<10.4f}")

# Identify best model by F1 score
best_model = metrics_list[0]["model"]
print(f"\nBest model by F1 score: {best_model}")

# Map model names to dataframes for error analysis
model_predictions = {
    "Gradient Boosted Trees": gbt_predictions,
    "Decision Tree": dt_predictions,
    "Linear SVM": svm_predictions,
    "Naive Bayes": nb_predictions,
    "Neural Network": mlp_predictions,
    "Ensemble Model": ensemble_df
}

# Get the predictions for the best model
predictions = model_predictions[best_model]
prediction_col = "ensemble_pred" if best_model == "Ensemble Model" else "prediction"

# Get false positives (songs predicted to rank but didn't)
false_positives = predictions.filter((col(prediction_col) == 1.0) & (col("label") == 0.0))
print(f"\nFalse Positive Examples (songs predicted to rank but didn't): {false_positives.count()}")
if "name" in predictions.columns and "artists" in predictions.columns:
    false_positives.select("id", "name", "artists").show(5, truncate=False)
else:
    false_positives.select("id").show(5)

# Get false negatives (songs that ranked but predicted not to)
false_negatives = predictions.filter((col(prediction_col) == 0.0) & (col("label") == 1.0))
print(f"\nFalse Negative Examples (songs that ranked but predicted not to): {false_negatives.count()}")
if "name" in predictions.columns and "artists" in predictions.columns:
    false_negatives.select("id", "name", "artists").show(5, truncate=False)
else:
    false_negatives.select("id").show(5)

# Check prediction distribution
prediction_dist = predictions.groupBy(prediction_col).count().orderBy(prediction_col)
print("\nPrediction Distribution:")
prediction_dist.show()

# Check feature distributions if features exist in the predictions dataframe
if all(col in predictions.columns for col in ["danceability", "energy", "acousticness", "valence", "tempo"]):
    print("\nFeature analysis for correct vs. incorrect predictions:")
    feature_cols = ["danceability", "energy", "acousticness", "valence", "tempo"]

    for feature in feature_cols:
        correct_predictions = predictions.filter(col(prediction_col) == col("label"))
        incorrect_predictions = predictions.filter(col(prediction_col) != col("label"))
        
        correct_avg = correct_predictions.select(expr(f"avg({feature})")).collect()[0][0]
        incorrect_avg = incorrect_predictions.select(expr(f"avg({feature})")).collect()[0][0]
        
        print(f"{feature}: Correct avg = {correct_avg:.4f}, Incorrect avg = {incorrect_avg:.4f}, Difference = {abs(correct_avg - incorrect_avg):.4f}")


===== EVALUATING ALL MODELS ON TEST DATA =====

===== Gradient Boosted Trees EVALUATION =====
Accuracy: 0.9613
AUC: 0.8076
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Specificity: 1.0000

Confusion Matrix:
True Positives: 0
False Positives: 1
True Negatives: 54274
False Negatives: 1445

===== Decision Tree EVALUATION =====
Accuracy: 0.9613
AUC: 0.5000
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Specificity: 1.0000

Confusion Matrix:
True Positives: 0
False Positives: 0
True Negatives: 54275
False Negatives: 1445

===== Linear SVM EVALUATION =====
Accuracy: 0.9613
AUC: 0.7744
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Specificity: 1.0000

Confusion Matrix:
True Positives: 0
False Positives: 0
True Negatives: 54275
False Negatives: 1445

===== Naive Bayes EVALUATION =====
Accuracy: 0.9613
AUC: 0.6118
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Specificity: 1.0000

Confusion Matrix:
True Positives: 0
False Positives: 0
True Negatives: 54275
False Negatives: 1445


In [9]:
spark.stop()