In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, when, expr, log, sqrt, pow, abs, udf, lit, round, concat_ws
from pyspark.sql.types import DoubleType, BooleanType, StringType, ArrayType
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StandardScaler, VectorSlicer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

"""
Spotify Charts Prediction
"""

# Spark session: off-heap memory enabled (2g) and a 4g driver memory setting.
spark = (
    SparkSession.builder
    .appName("Spotify Prediction System - Final")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "2g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/13 18:54:01 INFO SparkEnv: Registering MapOutputTracker
25/04/13 18:54:01 INFO SparkEnv: Registering BlockManagerMaster
25/04/13 18:54:01 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/04/13 18:54:01 INFO SparkEnv: Registering OutputCommitCoordinator


In [2]:
#------------------------------------------------------------------------------
# 1. Data Loading and Pre-manipulation
#------------------------------------------------------------------------------

# Data loading
data_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_prepared_data_complete"
completed_songs = spark.read.parquet(data_path)

# Derive the numeric training label from the `is_ranked` column
completed_songs = completed_songs.withColumn("label", col("is_ranked").cast("double"))

# Class statistics: a single aggregation job instead of three separate
# count() scans over the parquet data.
class_counts = completed_songs.agg(
    expr("count(*)").alias("total"),
    expr("count(CASE WHEN label = 1.0 THEN 1 END)").alias("positive"),
    expr("count(CASE WHEN label = 0.0 THEN 1 END)").alias("negative"),
).first()
total_count = class_counts["total"]
positive_count = class_counts["positive"]
negative_count = class_counts["negative"]

# Avoid a ZeroDivisionError in the percentage prints when the dataset is empty.
percent_base = total_count if total_count > 0 else 1

print(f"Total Counts:{total_count}")
print(f"On Charts Counts: {positive_count} ({positive_count/percent_base*100:.2f}%)")
print(f"Not On Charts Counts: {negative_count} ({negative_count/percent_base*100:.2f}%)")

                                                                                

Total Counts:278731
On Charts Counts: 7510 (2.69%)
Not On Charts Counts: 271221 (97.31%)


                                                                                

In [3]:
#------------------------------------------------------------------------------
# 2. Feature Enhancement
#------------------------------------------------------------------------------
# Original audio features
audio_features = ["danceability", "energy", "key", "loudness", "mode",
                  "speechiness", "acousticness", "instrumentalness",
                  "liveness", "valence", "tempo", "time_signature"]

# Cast every audio feature to double in ONE projection. Calling withColumn
# repeatedly (especially in a loop) stacks one projection per call and can
# produce very large query plans; a single select avoids that.
# Column names are backtick-quoted so unusual names are not parsed as paths.
completed_songs = completed_songs.select(*[
    col(f"`{c}`").cast("double").alias(c) if c in audio_features else col(f"`{c}`")
    for c in completed_songs.columns
])

# loudness^2 is reused by the loudness/energy interaction feature below.
loudness_squared_expr = pow(col("loudness"), 2)

# Derived features, keyed by output column name. Insertion order defines the
# order of `new_features` (and therefore of the assembled feature vector).
feature_exprs = {
    #-- Music theory related features --#
    # Energy to acoustic ratio - how electronic vs. acoustic a song is
    "energy_acoustic_ratio": col("energy") / (col("acousticness") + 0.001),
    # Product of danceability and valence - happy dance music vs sad dance music
    "dance_valence_product": col("danceability") * col("valence"),
    # Vocal-instrumental balance - vocal-dominant vs instrument-dominant songs
    "vocal_instrumental_balance":
        (1 - col("instrumentalness")) / (col("instrumentalness") + 0.001),

    #-- Audience perception related features --#
    # Rhythm perception - combining tempo and danceability
    "rhythm_factor": col("tempo") * col("danceability") / 100.0,
    # Emotional intensity - Euclidean norm of (valence, energy)
    "mood_intensity": sqrt(pow(col("valence"), 2) + pow(col("energy"), 2)),
    # Calm vs. excitement factor.
    # NOTE(review): (1 - loudness / -60.0) equals 1 + loudness/60, which grows
    # as loudness increases toward 0 - confirm this is the intended direction
    # for a "calmness" score.
    "calmness_factor":
        (col("acousticness") * (1 - col("energy")) * (1 - col("loudness") / -60.0)) / 3.0,

    #-- Category features (1.0 / 0.0 flags) --#
    "is_instrumental": when(col("instrumentalness") > 0.5, 1.0).otherwise(0.0),
    "is_rhythmic": when(col("danceability") > 0.7, 1.0).otherwise(0.0),
    "is_energetic": when(col("energy") > 0.8, 1.0).otherwise(0.0),
    "is_happy": when(col("valence") > 0.7, 1.0).otherwise(0.0),

    #-- Nonlinear transformations --#
    # Log transforms (offset by 0.001 so that 0 does not map to -infinity)
    "log_acousticness": log(col("acousticness") + 0.001),
    "log_instrumentalness": log(col("instrumentalness") + 0.001),
    # Loudness squared / cubed - emphasize extreme values
    "loudness_squared": loudness_squared_expr,
    "loudness_cubic": pow(col("loudness"), 3),

    #-- Composite indicators --#
    # Mainstream popularity index - weighted mix of chart-related indicators
    "mainstream_index": (
        col("danceability") * 0.25 +
        col("energy") * 0.25 +
        col("valence") * 0.2 +
        (abs(col("loudness")) / 15.0) * 0.15 +
        (1 - col("acousticness")) * 0.15
    ),
    # Experimental index - non-mainstream feature combination
    "experimental_index": (
        col("instrumentalness") * 0.3 +
        col("speechiness") * 0.2 +
        col("liveness") * 0.2 +
        (1 - col("danceability")) * 0.15 +
        (1 - col("valence")) * 0.15
    ),
    # Loudness and energy interaction
    "loudness_energy_interaction": loudness_squared_expr * col("energy"),
    # Vocal clarity
    "vocal_clarity":
        (1 - col("instrumentalness")) * col("speechiness") / (col("acousticness") + 0.001),
}

# Define all new features (same order as the expressions above)
new_features = list(feature_exprs)

# Add every derived feature in one projection. Dropping any previously created
# copies first keeps this cell safe to re-run on an already enhanced DataFrame.
completed_songs = (
    completed_songs
    .drop(*new_features)
    .select("*", *[feature_expr.alias(name) for name, feature_expr in feature_exprs.items()])
)

In [4]:
#------------------------------------------------------------------------------
# 3. Feature vector preparation
#------------------------------------------------------------------------------

# Full feature list: raw audio features followed by the engineered ones.
all_features = [*audio_features, *new_features]

# Pack all feature columns into a single vector column.
vector_assembler = VectorAssembler(inputCols=all_features, outputCol="extended_features")
df_with_extended_features = vector_assembler.transform(completed_songs)

# Standardize the assembled vector (zero mean, unit standard deviation).
# NOTE(review): the scaler is fitted on the complete dataset; the train/test
# split happens later in the notebook, so test rows contribute to the scaling
# statistics - confirm this is acceptable.
scaler = StandardScaler(
    withMean=True,
    withStd=True,
    inputCol=vector_assembler.getOutputCol(),
    outputCol="scaled_extended_features",
)
scaler_model = scaler.fit(df_with_extended_features)
df_with_scaled_features = scaler_model.transform(df_with_extended_features)

25/04/13 18:54:27 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

In [5]:
#------------------------------------------------------------------------------
# 4. Data partitioning and balanced sampling
#------------------------------------------------------------------------------
# Split the data into a training set (80%) and a test set (20%)
train_data, test_data = df_with_scaled_features.randomSplit([0.8, 0.2], seed=42)

# Both splits are scanned many times (counts, sampling, model fitting and
# evaluation); caching avoids recomputing the whole feature pipeline each time.
train_data = train_data.cache()
test_data = test_data.cache()

# Separating the majority and minority classes
majority_class = train_data.filter(col("label") == 0.0)
minority_class = train_data.filter(col("label") == 1.0)

# Row counts - each one computed exactly once
train_count = train_data.count()
test_count = test_data.count()
minority_count = minority_class.count()
majority_count = majority_class.count()

# Fail with a clear message instead of a ZeroDivisionError further down.
if minority_count == 0 or majority_count == 0:
    raise ValueError(
        "Training split must contain both classes "
        f"(label=1.0: {minority_count}, label=0.0: {majority_count})"
    )
class_ratio = majority_count / minority_count

print(f"Training Set: {train_count}")
print(f"Testing Set: {test_count}")
print(f"Charts songs on the training set: {minority_count} ({minority_count/train_count*100:.2f}%)")
print(f"Class imbalance ratio: {class_ratio:.2f}:1")

# Majority-class sampling fractions for three minority:majority target ratios.
# Capped at 1.0 because sampling without replacement cannot exceed the
# population size.
sampling_ratios = {
    "1:2": min(1.0, (minority_count * 2) / majority_count),
    "1:3": min(1.0, (minority_count * 3) / majority_count),
    "1:5": min(1.0, (minority_count * 5) / majority_count)
}

balanced_datasets = {}

for ratio_name, ratio in sampling_ratios.items():
    # Undersample the majority class and append every minority-class row.
    sampled_majority = majority_class.sample(False, ratio, seed=42)
    balanced_data = sampled_majority.union(minority_class).cache()
    balanced_datasets[ratio_name] = balanced_data

    # One count materializes the cache; the sampled-majority size follows
    # from it because the union adds exactly `minority_count` rows.
    balanced_count = balanced_data.count()
    actual_ratio = (balanced_count - minority_count) / minority_count
    print(f"{ratio_name} Balanced Data Set: {balanced_count} tracks (actual ratio {actual_ratio:.2f}:1)")

                                                                                

Training Set: 223011


                                                                                

Testing Set: 55720


                                                                                

Charts songs on the training set: 6065 (2.72%)
Class imbalance ratio: 35.77:1


                                                                                

1:2 Balanced Data Set: 18291 tracks (actual ratio 2.02:1)


                                                                                

1:3 Balanced Data Set: 24364 tracks (actual ratio 3.02:1)




1:5 Balanced Data Set: 36575 tracks (actual ratio 5.03:1)


                                                                                

In [7]:
#------------------------------------------------------------------------------
# 5. Model Training
#------------------------------------------------------------------------------

# Evaluation Function - Calculates classification metrics
def evaluate_model(predictions, label_col="label", prediction_col="prediction", model_name="Model", verbose=True):
    """Compute binary classification metrics for a predictions DataFrame.

    Args:
        predictions: DataFrame holding `label_col`, `prediction_col` and, for
            the AUC metrics, a `rawPrediction` column.
        label_col: Name of the ground-truth column (1.0 = positive class).
        prediction_col: Name of the hard-prediction column (1.0 / 0.0).
        model_name: Label used in the printed report and in the result dict.
        verbose: When True, print the metrics and the confusion matrix.

    Returns:
        dict with keys model, accuracy, precision, recall, f1, f2,
        specificity, auc, pr_auc, tp, fp, tn, fn. `auc` and `pr_auc` are 0
        when they could not be computed.
    """
    # Confusion matrix from a single aggregation (one Spark job instead of
    # four filtered counts). Rows whose prediction/label is not exactly
    # 0.0 or 1.0 are ignored, matching the per-cell filters.
    cell_counts = {
        (row[prediction_col], row[label_col]): row["count"]
        for row in predictions.groupBy(prediction_col, label_col).count().collect()
    }
    tp = cell_counts.get((1.0, 1.0), 0)
    fp = cell_counts.get((1.0, 0.0), 0)
    tn = cell_counts.get((0.0, 0.0), 0)
    fn = cell_counts.get((0.0, 1.0), 0)

    # Derived metrics; every ratio falls back to 0 when its denominator is 0
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    f2 = 5 * precision * recall / (4 * precision + recall) if (precision + recall) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # ROC-AUC and PR-AUC (best effort: reported as unavailable on failure,
    # but the reason is surfaced instead of being silently swallowed)
    auc = None
    pr_auc = None
    try:
        binary_evaluator = BinaryClassificationEvaluator(
            labelCol=label_col,
            rawPredictionCol="rawPrediction",
            metricName="areaUnderROC"
        )
        roc_value = binary_evaluator.evaluate(predictions)

        pr_evaluator = BinaryClassificationEvaluator(
            labelCol=label_col,
            rawPredictionCol="rawPrediction",
            metricName="areaUnderPR"
        )
        pr_value = pr_evaluator.evaluate(predictions)

        # Only publish the pair when both metrics were computed.
        auc, pr_auc = roc_value, pr_value
    except Exception as exc:
        print(f"Warning: ROC-AUC/PR-AUC unavailable for {model_name}: {exc!r}")

    # Print results (if verbose=True)
    if verbose:
        print(f"\n===== {model_name} Evaluation =====")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        print(f"F2: {f2:.4f}")

        if auc is not None:
            print(f"ROC-AUC: {auc:.4f}")
        if pr_auc is not None:
            print(f"PR-AUC: {pr_auc:.4f}")

        print("\nConfusion matrix:")
        print(f"TP: {tp} ")
        print(f"FP: {fp} ")
        print(f"TN: {tn} ")
        print(f"FN: {fn} ")

        if tp + fp > 0:
            print(f"\nThe percentage of songs predicted to be on the charts that actually made it to the charts: {tp/(tp+fp)*100:.2f}%")
        if tn + fn > 0:
            print(f"The percentage of songs predicted not to chart that actually charted: {fn/(tn+fn)*100:.2f}%")

    return {
        "model": model_name,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "specificity": specificity,
        "auc": auc if auc is not None else 0,
        "pr_auc": pr_auc if pr_auc is not None else 0,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn
    }

# Hyper-parameters shared by the three random forests below; the models differ
# only in the balanced training set they are fitted on.
rf_params = dict(
    featuresCol="scaled_extended_features",
    labelCol="label",
    numTrees=100,
    maxDepth=10,
    seed=42,
)

# Random forest fitted on the 1:2 balanced data (high recall)
print("\nTraining a 1:2 balanced random forest model (high recall)")
rf_high_recall = RandomForestClassifier(**rf_params)
rf_high_recall_model = rf_high_recall.fit(balanced_datasets["1:2"])
rf_high_recall_preds = rf_high_recall_model.transform(test_data)
rf_high_recall_metrics = evaluate_model(
    rf_high_recall_preds,
    model_name="Random Forest - 1:2 Balanced (High Recall)",
)

# Random forest fitted on the 1:3 balanced data (balanced)
print("\nTraining 1:3 balanced random forest model (balanced)")
rf_balanced = RandomForestClassifier(**rf_params)
rf_balanced_model = rf_balanced.fit(balanced_datasets["1:3"])
rf_balanced_preds = rf_balanced_model.transform(test_data)
rf_balanced_metrics = evaluate_model(
    rf_balanced_preds,
    model_name="1:3 balanced random forest model (balanced)",
)

# Random forest fitted on the 1:5 balanced data (the high-precision variant)
print("\nTraining a 1:5 balanced random forest model (high accuracy)")
rf_high_precision = RandomForestClassifier(**rf_params)
rf_high_precision_model = rf_high_precision.fit(balanced_datasets["1:5"])
rf_high_precision_preds = rf_high_precision_model.transform(test_data)
rf_high_precision_metrics = evaluate_model(
    rf_high_precision_preds,
    model_name="1:5 balanced random forest model (high accuracy)",
)

# Gradient-boosted trees, fitted on the 1:2 balanced data; its probability
# output is what the threshold search below operates on.
print("\nTraining GBT model for threshold adjustment")
gbt = GBTClassifier(
    seed=42,
    maxDepth=5,
    maxIter=10,
    labelCol="label",
    featuresCol="scaled_extended_features",
)
gbt_model = gbt.fit(balanced_datasets["1:2"])
gbt_preds = gbt_model.transform(test_data)

# Define the threshold adjustment function
def create_threshold_udf(threshold):
    """Build a Spark UDF that turns a probability value into a 1.0 / 0.0 decision.

    The returned UDF yields 1.0 when element 1 of its input is strictly
    greater than `threshold`, and 0.0 otherwise (including for null input).
    """
    def apply_threshold(probability):
        if probability is not None and probability[1] > threshold:
            return 1.0
        return 0.0

    return udf(apply_threshold, DoubleType())

# Testing different thresholds
# NOTE(review): the threshold is selected on predictions for `test_data`, the
# same data the final metrics are reported on, so those metrics are
# optimistic - consider tuning on a separate validation split.
print("\nTesting different thresholds of GBT model")
thresholds = [0.05, 0.1, 0.15, 0.2, 0.3]
threshold_results = []

for threshold in thresholds:
    # Name of the thresholded prediction column, built once per threshold
    pred_col = f"prediction_{int(threshold*100)}"
    threshold_udf = create_threshold_udf(threshold)
    threshold_preds = gbt_preds.withColumn(pred_col, threshold_udf(col("probability")))

    # Whole confusion matrix from a single aggregation, so the Python UDF is
    # evaluated once per threshold instead of once per matrix cell.
    cell_counts = {
        (row[pred_col], row["label"]): row["count"]
        for row in threshold_preds.groupBy(pred_col, "label").count().collect()
    }
    tp = cell_counts.get((1.0, 1.0), 0)
    fp = cell_counts.get((1.0, 0.0), 0)
    tn = cell_counts.get((0.0, 0.0), 0)
    fn = cell_counts.get((0.0, 1.0), 0)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    f2 = 5 * precision * recall / (4 * precision + recall) if (precision + recall) > 0 else 0

    print(f"Threshold: {threshold:.2f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, F2: {f2:.4f}")
    threshold_results.append((threshold, precision, recall, f1, f2, tp, fp, tn, fn))

# Choose the best threshold (by F2 score, since we prefer recall).
# Tuple layout: (threshold, precision, recall, f1, f2, tp, fp, tn, fn)
best_threshold_result = max(threshold_results, key=lambda x: x[4])
best_threshold = best_threshold_result[0]
print(f"\nBest Threshold (F2): {best_threshold:.2f}, F1: {best_threshold_result[3]:.4f}, F2: {best_threshold_result[4]:.4f}")

# Applying the best threshold
best_threshold_udf = create_threshold_udf(best_threshold)
gbt_optimized_preds = gbt_preds.withColumn(
    "optimized_prediction",
    best_threshold_udf(col("probability"))
)

gbt_metrics = evaluate_model(
    gbt_optimized_preds,
    prediction_col="optimized_prediction",
    model_name=f"GBT - Threshold{best_threshold:.2f}"
)

# Aggregate all model results
all_models = [
    rf_high_recall_metrics,
    rf_balanced_metrics,
    rf_high_precision_metrics,
    gbt_metrics
]

# Sort by F1 score (best first)
all_models.sort(key=lambda x: x["f1"], reverse=True)

# Size the first column to the longest model name (never narrower than the
# previous fixed width of 35) so long names no longer push the metric columns
# out of alignment.
name_width = max(35, max(len(m["model"]) for m in all_models) + 1)

print("\n===== Model summary (sorted by F1 score) =====")
header = "{:<{w}} {:<10} {:<10} {:<10} {:<10} {:<10}".format(
    "Model", "Accuracy", "Precision", "Recall", "F1 Score", "F2 Score", w=name_width
)
print(header)
print("-" * len(header))

for m in all_models:
    row = "{:<{w}} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f}".format(
        m["model"],
        m["accuracy"],
        m["precision"],
        m["recall"],
        m["f1"],
        m["f2"],
        w=name_width
    )
    print(row)

# Collect feature importances of the 1:3 balanced random forest, pairing each
# importance with the feature name at the same position in the assembled
# vector (zip stops at the shorter of the two sequences).
feature_importances = list(zip(all_features, rf_balanced_model.featureImportances.toArray()))

feature_importances.sort(key=lambda x: x[1], reverse=True)
print("\nThe 10 most important features:")
for i, (feature, importance) in enumerate(feature_importances[:10]):
    print(f"{i+1}. {feature}: {importance:.6f}")


Training a 1:2 balanced random forest model (high recall)


25/04/13 19:20:28 WARN DAGScheduler: Broadcasting large task binary with size 1289.2 KiB
25/04/13 19:20:29 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
25/04/13 19:20:31 WARN DAGScheduler: Broadcasting large task binary with size 3.5 MiB
25/04/13 19:20:34 WARN DAGScheduler: Broadcasting large task binary with size 5.8 MiB
25/04/13 19:20:35 WARN DAGScheduler: Broadcasting large task binary with size 1167.0 KiB
25/04/13 19:20:37 WARN DAGScheduler: Broadcasting large task binary with size 9.0 MiB
25/04/13 19:20:39 WARN DAGScheduler: Broadcasting large task binary with size 1575.0 KiB
25/04/13 19:20:42 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:20:47 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:20:52 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:20:57 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:21:01 WARN DAGScheduler: Broadcas


===== Random Forest - 1:2 Balanced (High Recall) Evaluation =====
Accuracy: 0.8436
Precision: 0.0981
Recall: 0.6138
F1: 0.1691
F2: 0.2992
ROC-AUC: 0.8264
PR-AUC: 0.1157

Confusion matrix:
TP: 887 
FP: 8157 
TN: 46118 
FN: 558 

The percentage of songs predicted to be on the charts that actually made it to the charts: 9.81%
The percentage of songs predicted not to chart that actually charted: 1.20%

Training 1:3 balanced random forest model (balanced)


25/04/13 19:21:14 WARN DAGScheduler: Broadcasting large task binary with size 1294.1 KiB
25/04/13 19:21:15 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
25/04/13 19:21:17 WARN DAGScheduler: Broadcasting large task binary with size 3.6 MiB
25/04/13 19:21:19 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB
25/04/13 19:21:21 WARN DAGScheduler: Broadcasting large task binary with size 1255.4 KiB
25/04/13 19:21:22 WARN DAGScheduler: Broadcasting large task binary with size 9.4 MiB
25/04/13 19:21:24 WARN DAGScheduler: Broadcasting large task binary with size 1727.9 KiB
25/04/13 19:21:27 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:21:31 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:21:35 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:21:38 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:21:42 WARN DAGScheduler: Broadcas


===== 1:3 balanced random forest model (balanced) Evaluation =====
Accuracy: 0.9086
Precision: 0.1223
Recall: 0.4083
F1: 0.1882
F2: 0.2781
ROC-AUC: 0.8242
PR-AUC: 0.1117

Confusion matrix:
TP: 590 
FP: 4236 
TN: 50039 
FN: 855 

The percentage of songs predicted to be on the charts that actually made it to the charts: 12.23%
The percentage of songs predicted not to chart that actually charted: 1.68%

Training a 1:5 balanced random forest model (high accuracy)


25/04/13 19:21:53 WARN DAGScheduler: Broadcasting large task binary with size 1294.9 KiB
25/04/13 19:21:54 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
25/04/13 19:21:56 WARN DAGScheduler: Broadcasting large task binary with size 3.6 MiB
25/04/13 19:21:59 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:22:01 WARN DAGScheduler: Broadcasting large task binary with size 1324.0 KiB
25/04/13 19:22:03 WARN DAGScheduler: Broadcasting large task binary with size 9.8 MiB
25/04/13 19:22:06 WARN DAGScheduler: Broadcasting large task binary with size 1878.9 KiB
25/04/13 19:22:08 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:22:11 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:22:14 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:22:17 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:22:19 WARN DAGScheduler: Broadcas


===== 1:5 balanced random forest model (high accuracy) Evaluation =====
Accuracy: 0.9599
Precision: 0.1645
Recall: 0.1343
F1: 0.1479
F2: 0.1394
ROC-AUC: 0.8234
PR-AUC: 0.1162

Confusion matrix:
TP: 194 
FP: 985 
TN: 53290 
FN: 1251 

The percentage of songs predicted to be on the charts that actually made it to the charts: 16.45%
The percentage of songs predicted not to chart that actually charted: 2.29%

Training GBT model for threshold adjustment

Testing different thresholds of GBT model


                                                                                

Threshold: 0.05, Precision: 0.0259, Recall: 1.0000, F1: 0.0506, F2: 0.1175


                                                                                

Threshold: 0.10, Precision: 0.0329, Recall: 0.9889, F1: 0.0637, F2: 0.1453


                                                                                

Threshold: 0.15, Precision: 0.0415, Recall: 0.9467, F1: 0.0796, F2: 0.1767


                                                                                

Threshold: 0.20, Precision: 0.0486, Recall: 0.9017, F1: 0.0923, F2: 0.2000


                                                                                

Threshold: 0.30, Precision: 0.0629, Recall: 0.7993, F1: 0.1167, F2: 0.2393

Best Threshold (F2): 0.30, F1: 0.1167, F2: 0.2393


                                                                                


===== GBT - Threshold0.30 Evaluation =====
Accuracy: 0.6861
Precision: 0.0629
Recall: 0.7993
F1: 0.1167
F2: 0.2393
ROC-AUC: 0.8141
PR-AUC: 0.0973

Confusion matrix:
TP: 1155 
FP: 17201 
TN: 37074 
FN: 290 

The percentage of songs predicted to be on the charts that actually made it to the charts: 6.29%
The percentage of songs predicted not to chart that actually charted: 0.78%

===== Model summary (sorted by F1 score) =====
Model                               Accuracy   Precision  Recall     F1 Score   F2 Score  
-------------------------------------------------------------------------------------
1:3 balanced random forest model (balanced) 0.9086     0.1223     0.4083     0.1882     0.2781    
Random Forest - 1:2 Balanced (High Recall) 0.8436     0.0981     0.6138     0.1691     0.2992    
1:5 balanced random forest model (high accuracy) 0.9599     0.1645     0.1343     0.1479     0.1394    
GBT - Threshold0.30                 0.6861     0.0629     0.7993     0.1167     0.2393    

T

In [8]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import (
    col, when, expr, log, sqrt, pow, abs, udf, lit, round, concat_ws, array_position
)
from pyspark.sql.types import DoubleType, BooleanType, StringType, ArrayType
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StandardScaler, VectorSlicer

# Create tiered prediction function (Fixed version)

def create_tiered_predictions(
    test_data,
    high_recall_model,
    balanced_model,
    high_precision_model,
    gbt_model,
    gbt_threshold=0.15
):
    """
    Create a tiered prediction system that combines predictions from multiple models.

    Every model scores the same rows in sequence; after each model only its
    positive-class probability (and, for the three non-GBT models, its hard
    prediction) is kept. No joins are involved, so the result has exactly one
    row per input row, even when `id` values repeat or are null.

    Args:
        test_data: DataFrame with the models' feature column plus the
            `id`, `name`, `artists` and `label` columns.
        high_recall_model: Fitted classifier; weight 0.4 in the vote and hit score.
        balanced_model: Fitted classifier; weight 0.3.
        high_precision_model: Fitted classifier; weight 0.2.
        gbt_model: Fitted classifier; weight 0.1. Its hard prediction is
            derived by comparing its probability with `gbt_threshold`.
        gbt_threshold: Probability above which the GBT model votes 1.0.

    Returns:
        DataFrame with columns id, name, artists, label, the per-model
        `*_prob` / `*_pred` columns, ensemble_vote, ensemble_pred,
        confidence_level, marketing_recommendation, business_logic,
        ranking_reason, hit_score and development_guide.
    """
    tier_a = "Tier A (Highly Likely to Chart)"
    tier_b = "Tier B (Moderately Likely to Chart)"
    tier_c = "Tier C (Low Likelihood to Chart)"
    tier_d = "Tier D (Unlikely to Chart)"

    # UDF extracting the positive-class probability (element 1 of the
    # probability vector); falls back to 0.0 for null or too-short vectors.
    @udf(returnType=DoubleType())
    def extract_probability_1(probability_vec):
        if probability_vec is not None and len(probability_vec) > 1:
            return float(probability_vec[1])
        return 0.0

    def add_model_outputs(scored_so_far, model, prefix, keep_prediction=True):
        """Score with `model`, keep `<prefix>_prob` (and `<prefix>_pred`), drop raw outputs."""
        scored = model.transform(scored_so_far).withColumn(
            f"{prefix}_prob", extract_probability_1(col("probability"))
        )
        if keep_prediction:
            scored = scored.withColumn(f"{prefix}_pred", col("prediction"))
        # Remove the model's own output columns so the next model can add its own.
        return scored.drop("rawPrediction", "probability", "prediction")

    scored = add_model_outputs(test_data, high_recall_model, "high_recall")
    scored = add_model_outputs(scored, balanced_model, "balanced")
    scored = add_model_outputs(scored, high_precision_model, "high_precision")
    # The GBT hard prediction comes from the custom threshold below instead.
    scored = add_model_outputs(scored, gbt_model, "gbt", keep_prediction=False)

    predictions = scored.select(
        "id", "name", "artists", "label",
        "high_recall_prob", "high_recall_pred",
        "balanced_prob", "balanced_pred",
        "high_precision_prob", "high_precision_pred",
        "gbt_prob"
    )

    # Apply the GBT threshold
    predictions = predictions.withColumn(
        "gbt_pred",
        when(col("gbt_prob") > gbt_threshold, 1.0).otherwise(0.0)
    )

    # Compute an ensemble prediction using a weighted vote approach
    predictions = predictions.withColumn(
        "ensemble_vote",
        col("high_recall_pred") * 0.4 +
        col("balanced_pred") * 0.3 +
        col("high_precision_pred") * 0.2 +
        col("gbt_pred") * 0.1
    )

    predictions = predictions.withColumn(
        "ensemble_pred",
        when(col("ensemble_vote") >= 0.5, 1.0).otherwise(0.0)
    )

    # Confidence level: weighted average of the three non-GBT probabilities
    # (highest weight on the high-recall model), bucketed into four tiers.
    # Native column expressions are used instead of Python UDFs.
    weighted_prob = (
        col("high_recall_prob") * 0.5 +
        col("balanced_prob") * 0.3 +
        col("high_precision_prob") * 0.2
    )
    predictions = predictions.withColumn(
        "confidence_level",
        when(weighted_prob >= 0.7, tier_a)
        .when(weighted_prob >= 0.5, tier_b)
        .when(weighted_prob >= 0.3, tier_c)
        .otherwise(tier_d)
    )

    # Marketing recommendation per confidence level
    predictions = predictions.withColumn(
        "marketing_recommendation",
        when(
            col("confidence_level") == tier_a,
            "Full Promotion: Major social media campaigns, platform recommendations, media coverage, artist tours."
        )
        .when(
            col("confidence_level") == tier_b,
            "Moderate Promotion: Social media promotion, partial platform recommendations, limited tour support."
        )
        .when(
            col("confidence_level") == tier_c,
            "Low Promotion: Limited social media campaigns targeting specific audiences."
        )
        .otherwise("Basic Support: Standard release procedures and observation.")
    )

    # Add business logic explanation
    predictions = predictions.withColumn(
        "business_logic",
        expr(
            "CASE "
            "WHEN confidence_level = 'Tier A (Highly Likely to Chart)' THEN 'Highest revenue potential, highest estimated return on investment' "
            "WHEN confidence_level = 'Tier B (Moderately Likely to Chart)' THEN 'Moderate revenue potential, decent ROI with appropriate investment' "
            "WHEN confidence_level = 'Tier C (Low Likelihood to Chart)' THEN 'Limited revenue potential, suitable for small-scale testing' "
            "ELSE 'Lower revenue potential, minimal investment recommended' END"
        )
    )

    # Add a simplified reason for charting (first matching rule wins)
    predictions = predictions.withColumn(
        "ranking_reason",
        when(col("high_recall_prob") > 0.7, "Strong recommendation from the high-recall model")
        .when(col("balanced_prob") > 0.7, "Strong recommendation from the balanced model")
        .when(col("high_precision_prob") > 0.7, "Strong recommendation from the high-precision model")
        .when(col("high_recall_prob") > 0.5, "Multiple models collectively recommend")
        .otherwise("Evaluated based on multiple features")
    )

    # Add a composite hit score (weighted average of all four probabilities)
    predictions = predictions.withColumn(
        "hit_score",
        col("high_recall_prob") * 0.4 +
        col("balanced_prob") * 0.3 +
        col("high_precision_prob") * 0.2 +
        col("gbt_prob") * 0.1
    )

    # Add development guidance
    predictions = predictions.withColumn(
        "development_guide",
        when(
            col("confidence_level").contains("Tier A") | col("confidence_level").contains("Tier B"),
            "Already shows chart potential; maintain current style"
        )
        .when(
            col("confidence_level").contains("Tier C"),
            "Requires minor adjustments; consider features of other successful tracks"
        )
        .otherwise("Requires major adjustments or a new direction")
    )

    return predictions

# Define the threshold adjustment function
def create_threshold_udf(threshold):
    """Build a Spark UDF that turns a class-probability vector into a 0.0/1.0 label.

    Args:
        threshold: Cut-off compared against the probability stored at index 1
            of the input. The comparison is strict (``>``).

    Returns:
        A UDF with return type ``DoubleType`` that yields ``1.0`` when the
        value at index 1 exceeds ``threshold`` and ``0.0`` otherwise. ``None``
        inputs, inputs with fewer than two entries, and inputs that cannot be
        indexed or converted to ``float`` all yield ``0.0``.
    """
    @udf(returnType=DoubleType())
    def apply_threshold(probability):
        if probability is None:
            return 0.0

        # Spark ML vectors (DenseVector / SparseVector) are not Python lists;
        # a plain `isinstance(..., list)` check would silently send every row
        # to the 0.0 fallback. Densify them via toArray() when available.
        values = probability.toArray() if hasattr(probability, "toArray") else probability

        try:
            if len(values) > 1:
                return 1.0 if float(values[1]) > threshold else 0.0
        except (TypeError, ValueError, KeyError):
            # Scalars, mappings and non-numeric entries keep the original
            # "not a usable probability vector" result.
            pass
        return 0.0

    return apply_threshold

# Create the tiered prediction system
# Passes test_data, the rf_high_recall / rf_balanced / rf_high_precision models,
# gbt_model and best_threshold to create_tiered_predictions, and keeps the
# returned DataFrame as `tiered_predictions` for the analysis/export cells below.
# NOTE(review): the arguments are positional, so their order must match the
# signature of create_tiered_predictions — confirm against its definition.
print("Creating the tiered prediction system...")
tiered_predictions = create_tiered_predictions(
    test_data,
    rf_high_recall_model,
    rf_balanced_model,
    rf_high_precision_model,
    gbt_model,
    best_threshold
)

Creating the tiered prediction system...


In [9]:
# ------------------------------------------------------------------------------
# 7. Results Analysis & Marketing Recommendations
# ------------------------------------------------------------------------------

print("\n===== 7. Results Analysis & Marketing Recommendations =====")

# Analyze the distribution of confidence levels
confidence_dist = tiered_predictions.groupBy("confidence_level").count().orderBy("confidence_level")
print("\nConfidence Level Distribution:")

confidence_counts = confidence_dist.collect()

# The per-tier counts partition the whole frame, so their sum is the row total.
# Deriving it here avoids launching an extra tiered_predictions.count() job
# (which re-scores every model) once per printed tier.
# An explicit loop is used instead of sum(): notebooks that import from
# pyspark.sql.functions frequently shadow that builtin.
total = 0
for row in confidence_counts:
    total += row["count"]

for row in confidence_counts:
    print(f"{row['confidence_level']}: {row['count']} songs ({row['count'] / total * 100:.2f}%)")

# Analyze the actual hit rate for each confidence level
print("\nActual Hit Rate by Confidence Level:")

# A single aggregation over (tier, label) replaces the two count() jobs per
# tier that were previously issued; each of those re-evaluated the full
# prediction pipeline.
tier_label_rows = (
    tiered_predictions
    .groupBy("confidence_level", "label")
    .count()
    .collect()
)

tier_totals = {}  # tier name -> number of songs in the tier
tier_hits = {}    # tier name -> number of songs in the tier with label == 1.0
for tier_row in tier_label_rows:
    tier_name = tier_row["confidence_level"]
    tier_totals[tier_name] = tier_totals.get(tier_name, 0) + tier_row["count"]
    if tier_row["label"] == 1.0:
        tier_hits[tier_name] = tier_hits.get(tier_name, 0) + tier_row["count"]

# Report tiers in a fixed A -> D order; tiers with no songs are skipped.
for level in [
    "Tier A (Highly Likely to Chart)",
    "Tier B (Moderately Likely to Chart)",
    "Tier C (Low Likelihood to Chart)",
    "Tier D (Unlikely to Chart)"
]:
    level_count = tier_totals.get(level, 0)

    if level_count > 0:
        actual_hits = tier_hits.get(level, 0)
        hit_rate = actual_hits / level_count
        print(f"{level}: Actual hit rate {hit_rate * 100:.2f}% ({actual_hits}/{level_count})")

# Show the five highest-scoring songs placed in Tier A
print("\nTier A (Highly Likely to Chart) Sample Songs:")
a_level_samples = (
    tiered_predictions
    .where(col("confidence_level") == "Tier A (Highly Likely to Chart)")
    .select(
        [
            "name",
            "artists",
            "hit_score",
            "marketing_recommendation",
            "ranking_reason",
            "development_guide",
        ]
    )
    .sort("hit_score", ascending=False)
)
a_level_samples.show(n=5, truncate=False)

# False positives: songs placed in Tier A whose label is 0.0
print("\nTier A (Highly Likely to Chart) but Actually Not Hits:")
false_positives = (
    tiered_predictions
    .where(col("confidence_level") == "Tier A (Highly Likely to Chart)")
    .where(col("label") == 0.0)
    .select(["name", "artists", "hit_score", "ranking_reason"])
)
false_positives.show(n=5, truncate=False)

# False negatives: songs placed in Tier D whose label is 1.0
print("\nTier D (Unlikely to Chart) but Actually Hits:")
false_negatives = (
    tiered_predictions
    .where(col("confidence_level") == "Tier D (Unlikely to Chart)")
    .where(col("label") == 1.0)
    .select(["name", "artists", "hit_score", "ranking_reason"])
)
false_negatives.show(n=5, truncate=False)


===== 7. Results Analysis & Marketing Recommendations =====

Confidence Level Distribution:


25/04/13 19:23:25 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:23:26 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:23:37 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:23:45 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:23:45 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:23:48 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

Tier B (Moderately Likely to Chart): 5565 songs (9.99%)


                                                                                

Tier C (Low Likelihood to Chart): 9506 songs (17.06%)


                                                                                

Tier D (Unlikely to Chart): 40412 songs (72.53%)

Actual Hit Rate by Confidence Level:


25/04/13 19:24:06 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:24:06 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:24:12 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:24:19 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:24:19 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:24:30 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

Tier A (Highly Likely to Chart): Actual hit rate 26.58% (63/237)


25/04/13 19:24:35 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:24:36 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:24:44 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:24:49 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:24:49 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:02 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

Tier B (Moderately Likely to Chart): Actual hit rate 10.93% (608/5565)


25/04/13 19:25:06 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:25:07 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:13 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:25:17 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:17 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:25:25 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

Tier C (Low Likelihood to Chart): Actual hit rate 4.70% (447/9506)


25/04/13 19:25:30 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:25:30 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:36 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
25/04/13 19:25:40 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:25:40 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:48 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

Tier D (Unlikely to Chart): Actual hit rate 0.81% (327/40412)

Tier A (Highly Likely to Chart) Sample Songs:


25/04/13 19:25:54 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:25:54 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:25:54 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:09 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
                                                                                

+--------------+---------------------------------------+------------------+-----------------------------------------------------------------------------------------------------+------------------------------------------------+-----------------------------------------------------+
|name          |artists                                |hit_score         |marketing_recommendation                                                                             |ranking_reason                                  |development_guide                                    |
+--------------+---------------------------------------+------------------+-----------------------------------------------------------------------------------------------------+------------------------------------------------+-----------------------------------------------------+
|Juicy         |['Doja Cat', 'Tyga']                   |0.7919033608166763|Full Promotion: Major social media campaigns, platform recommendations, media cove

25/04/13 19:26:14 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:14 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:26:14 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:24 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
                                                                                

+-----------------+-------------------+------------------+------------------------------------------------+
|name             |artists            |hit_score         |ranking_reason                                  |
+-----------------+-------------------+------------------+------------------------------------------------+
|Wofür            |['Marie Reim']     |0.7041628287917348|Strong recommendation from the high-recall model|
|Therapy          |['David Archuleta']|0.7086390388549487|Strong recommendation from the high-recall model|
|Only One         |['Johnny Gill']    |0.7044248090877563|Strong recommendation from the high-recall model|
|Stay All Night   |['ALMA']           |0.7055918727861941|Strong recommendation from the high-recall model|
|No Money, No Love|['Etana']          |0.709037127603686 |Strong recommendation from the high-recall model|
+-----------------+-------------------+------------------+------------------------------------------------+
only showing top 5 rows


Ti

25/04/13 19:26:29 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:29 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:29 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:26:39 WARN DAGScheduler: Broadcasting large task binary with size 18.2 MiB
[Stage 535:>                                                        (0 + 1) / 1]

+---------------------------+--------------------+-------------------+------------------------------------+
|name                       |artists             |hit_score          |ranking_reason                      |
+---------------------------+--------------------+-------------------+------------------------------------+
|Reason to Believe          |['Arch Enemy']      |0.22295077851085793|Evaluated based on multiple features|
|Em Day Chang Phai Thuy Kieu|['Hoang Thuy Linh'] |0.2236708457915555 |Evaluated based on multiple features|
|Algorhythm                 |['Childish Gambino']|0.19923464394175383|Evaluated based on multiple features|
|Love Is The Main Thing     |['Fontaines D.C.']  |0.10185291784007908|Evaluated based on multiple features|
|Sjung högre                |['Håkan Hellström'] |0.14091588373019265|Evaluated based on multiple features|
+---------------------------+--------------------+-------------------+------------------------------------+
only showing top 5 rows



                                                                                

In [10]:
# ------------------------------------------------------------------------------
# 8. Export Results for the Marketing Team
# ------------------------------------------------------------------------------

print("\n===== 8. Exporting Results for the Marketing Team =====")

# Columns the marketing export must always contain, in output order
required_columns = [
    "name",
    "artists",
    "confidence_level",
    "hit_score",
    "marketing_recommendation",
    "business_logic",
    "ranking_reason",
    "development_guide"
]

# Snapshot of the columns present before any back-filling
available_columns = tiered_predictions.columns

# Back-fill every missing required column with a typed null. An untyped
# lit(None) produces a NullType column, which the CSV writer cannot serialize;
# hit_score is kept numeric so the sort below stays well-defined.
for column in required_columns:
    if column not in available_columns:
        null_type = "double" if column == "hit_score" else "string"
        tiered_predictions = tiered_predictions.withColumn(column, lit(None).cast(null_type))

# Select ALL required columns, including any that were just back-filled.
# (Filtering against the pre-fill snapshot would silently drop them again.)
marketing_columns = list(required_columns)
marketing_output = tiered_predictions.select(marketing_columns)

# Sort by hit_score in descending order
marketing_output = marketing_output.orderBy(col("hit_score").desc())

# Display the first 10 rows as an example
print("\nRecommended List for the Marketing Team (Top 10):")
marketing_output.show(10, truncate=False)

# Export results to CSV
# The two destinations are independent best-effort steps: a failure of the
# local export (e.g. toPandas() exhausting driver memory) must not prevent the
# GCS export from being attempted, and vice versa.
csv_path = "/tmp/spotify_marketing_recommendations.csv"
gcs_path = "gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_marketing_recommendations.csv"

# Local copy: collect to the driver as a pandas DataFrame and write one file
try:
    marketing_pandas = marketing_output.toPandas()
    marketing_pandas.to_csv(csv_path, index=False)
    print(f"\nMarketing recommendations saved to: {csv_path}")
except Exception as e:
    print(f"Error occurred during export: {e}")

# GCS copy: written by Spark itself. Note that Spark's CSV writer creates a
# directory of part files at gcs_path rather than a single .csv file.
try:
    marketing_output.write.mode("overwrite").option("header", "true").csv(gcs_path)
    print(f"Marketing recommendations saved to GCS: {gcs_path}")
except Exception as e:
    print(f"Error occurred during export: {e}")


===== 8. Exporting Results for the Marketing Team =====

Recommended List for the Marketing Team (Top 10):


25/04/13 19:26:44 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:26:44 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:44 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:26:57 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
                                                                                

+----------------------+---------------------------------------+-------------------------------+------------------+-----------------------------------------------------------------------------------------------------+-----------------------------------------------------------------+------------------------------------------------+-----------------------------------------------------+
|name                  |artists                                |confidence_level               |hit_score         |marketing_recommendation                                                                             |business_logic                                                   |ranking_reason                                  |development_guide                                    |
+----------------------+---------------------------------------+-------------------------------+------------------+-----------------------------------------------------------------------------------------------------+---------

25/04/13 19:27:02 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:27:02 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:27:02 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:27:12 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
25/04/13 19:27:16 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
25/04/13 19:27:20 WARN DAGScheduler: Broadcasting large task binary with size 18.3 MiB
                                                                                


Marketing recommendations saved to: /tmp/spotify_marketing_recommendations.csv


25/04/13 19:27:24 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:27:25 WARN DAGScheduler: Broadcasting large task binary with size 6.2 MiB
25/04/13 19:27:25 WARN DAGScheduler: Broadcasting large task binary with size 6.1 MiB
25/04/13 19:27:35 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
25/04/13 19:27:38 WARN DAGScheduler: Broadcasting large task binary with size 18.4 MiB
25/04/13 19:27:42 WARN DAGScheduler: Broadcasting large task binary with size 18.5 MiB
                                                                                

Marketing recommendations saved to GCS: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_marketing_recommendations.csv


In [11]:
# ------------------------------------------------------------------------------
# 9. Save the Best Model
# ------------------------------------------------------------------------------

print("\n===== 9. Saving the Best Model =====")

# Pick the entry of all_models with the highest F1 score
best_model_info = max(all_models, key=lambda entry: entry["f1"])
best_model_name = best_model_info["model"]

print(f"Based on the F1 score, the best single model is: {best_model_name}")
print(f"F1: {best_model_info['f1']:.4f}, Precision: {best_model_info['precision']:.4f}, Recall: {best_model_info['recall']:.4f}")

# Map the winning model's display name to the fitted model object.
# Markers are checked in order and the first one found in the name wins;
# when none matches, the GBT model (tagged with its threshold) is used.
model_to_save, model_type = next(
    (
        (candidate_model, candidate_type)
        for marker, candidate_model, candidate_type in (
            ("1:2", rf_high_recall_model, "rf_high_recall"),
            ("1:3", rf_balanced_model, "rf_balanced"),
            ("1:5", rf_high_precision_model, "rf_high_precision"),
        )
        if marker in best_model_name
    ),
    (gbt_model, f"gbt_threshold_{best_threshold}"),
)

# Persist the selected model, replacing any previous copy at that path
model_path = f"gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_{model_type}_final_model"
best_model_writer = model_to_save.write().overwrite()
best_model_writer.save(model_path)
print(f"The best model has been saved to: {model_path}")

# Persist every model the tiered prediction system relies on, keyed by the
# short name that is embedded in its storage path
models_to_save = dict(
    rf_high_recall=rf_high_recall_model,
    rf_balanced=rf_balanced_model,
    rf_high_precision=rf_high_precision_model,
    gbt=gbt_model,
)

for name in models_to_save:
    model = models_to_save[name]
    path = f"gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_{name}_final_model"
    # overwrite() replaces any model already stored at this path
    writer = model.write().overwrite()
    writer.save(path)
    print(f"Model {name} has been saved to: {path}")

# Save feature information
import json

# Importances are converted to plain Python floats before serialization
serializable_importances = [
    (feature_name, float(importance))
    for feature_name, importance in feature_importances
]

feature_info = dict(
    audio_features=audio_features,
    new_features=new_features,
    feature_importances=serializable_importances,
    best_threshold=best_threshold,
)

# Best-effort write of the metadata to a local JSON file; a failure is
# reported but does not abort the rest of the cell
try:
    with open("/tmp/spotify_feature_info.json", "w") as f:
        json.dump(feature_info, f, indent=2)
    print("Feature information saved to: /tmp/spotify_feature_info.json")
except Exception as e:
    print(f"Error saving feature information: {e}")

print("\nSpotify Hot Song Prediction System processing complete!")

# Shut down the Spark session now that all outputs have been written
spark.stop()


===== 9. Saving the Best Model =====
Based on the F1 score, the best single model is: 1:3 balanced random forest model (balanced)
F1: 0.1882, Precision: 0.1223, Recall: 0.4083


25/04/13 19:27:47 WARN TaskSetManager: Stage 561 contains a task of very large size (3166 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

The best model has been saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_rf_balanced_final_model


25/04/13 19:27:52 WARN TaskSetManager: Stage 568 contains a task of very large size (3104 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Model rf_high_recall has been saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_rf_high_recall_final_model


25/04/13 19:27:57 WARN TaskSetManager: Stage 575 contains a task of very large size (3166 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Model rf_balanced has been saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_rf_balanced_final_model


25/04/13 19:28:01 WARN TaskSetManager: Stage 582 contains a task of very large size (3129 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Model rf_high_precision has been saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_rf_high_precision_final_model


                                                                                

Model gbt has been saved to: gs://dataproc-staging-us-central1-361128386781-eo9ksqfa/spotify_gbt_final_model
Feature information saved to: /tmp/spotify_feature_info.json

Spotify Hot Song Prediction System processing complete!
