In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import col
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col, monotonically_increasing_id, explode

In [2]:
# Initialize Spark session
spark = SparkSession.builder.appName("SpotifyRecommendationModel").getOrCreate()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/06 14:08:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
# Load the dataset
data_path = "./tracks_features.csv"
df = spark.read.csv(data_path, header=True, inferSchema=True)

                                                                                

In [4]:
# Select relevant columns
data = df.select("artists", "danceability", "energy", "liveness", "valence", "tempo", "id", "name", "album")

# Assign unique numeric IDs to tracks with integer values
data = data.withColumn("trackId", monotonically_increasing_id().cast("int"))


In [5]:
data.show()

+--------------------+-------------------+------------------+-------------------+-------------------+------------------+--------------------+--------------------+--------------------+-------+
|             artists|       danceability|            energy|           liveness|            valence|             tempo|                  id|                name|               album|trackId|
+--------------------+-------------------+------------------+-------------------+-------------------+------------------+--------------------+--------------------+--------------------+-------+
|['Rage Against Th...|               0.47|             0.978|0.35600000000000004|              0.503|           117.906|7lmeHLHBe4nmXzuXc...|             Testify|The Battle Of Los...|      0|
|['Rage Against Th...|              0.599|0.9570000000000001|              0.155|              0.489|            103.68|1wsRitfRRtWyEapl0...|     Guerrilla Radio|The Battle Of Los...|      1|
|['Rage Against Th...|              0.31

In [6]:
# Convert necessary columns to numeric types
data = data.withColumn("danceability", col("danceability").cast("float"))
data = data.withColumn("energy", col("energy").cast("float"))
data = data.withColumn("liveness", col("liveness").cast("float"))
data = data.withColumn("valence", col("valence").cast("float"))
data = data.withColumn("tempo", col("tempo").cast("float"))

In [7]:
data.show()

+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+
|             artists|danceability|energy|liveness|valence|  tempo|                  id|                name|               album|trackId|
+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+
|['Rage Against Th...|        0.47| 0.978|   0.356|  0.503|117.906|7lmeHLHBe4nmXzuXc...|             Testify|The Battle Of Los...|      0|
|['Rage Against Th...|       0.599| 0.957|   0.155|  0.489| 103.68|1wsRitfRRtWyEapl0...|     Guerrilla Radio|The Battle Of Los...|      1|
|['Rage Against Th...|       0.315|  0.97|   0.122|   0.37|149.749|1hR0fIFK2qRG3f3RF...|    Calm Like a Bomb|The Battle Of Los...|      2|
|['Rage Against Th...|        0.44| 0.967|   0.121|  0.574| 96.752|2lbASgTSoDO7MTuLA...|           Mic Check|The Battle Of Los...|      3|
|['Rage Against Th...|     

In [8]:
# Handle string columns (artists) with StringIndexer
artist_indexer = StringIndexer(inputCol="artists", outputCol="artistIndex")
data = artist_indexer.fit(data).transform(data)

                                                                                

In [9]:
# Check and drop rows with null values in critical columns
data = data.filter(
    (col("danceability").isNotNull()) &
    (col("energy").isNotNull()) &
    (col("liveness").isNotNull()) &
    (col("valence").isNotNull()) &
    (col("tempo").isNotNull()) &
    (col("artistIndex").isNotNull()) &
    (col("trackId").isNotNull())
)

In [10]:
# Feature engineering
feature_columns = ["danceability", "energy", "liveness", "valence", "tempo"]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)

In [11]:
# Generate synthetic ratings for ALS (e.g., based on energy and valence sum)
data = data.withColumn("rating", (col("energy") + col("valence")) / 2)


In [12]:
data.show()

24/12/06 14:08:39 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB


+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+-----------+--------------------+------------------+
|             artists|danceability|energy|liveness|valence|  tempo|                  id|                name|               album|trackId|artistIndex|            features|            rating|
+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+-----------+--------------------+------------------+
|['Rage Against Th...|        0.47| 0.978|   0.356|  0.503|117.906|7lmeHLHBe4nmXzuXc...|             Testify|The Battle Of Los...|      0|      724.0|[0.46999999880790...|0.7404999732971191|
|['Rage Against Th...|       0.599| 0.957|   0.155|  0.489| 103.68|1wsRitfRRtWyEapl0...|     Guerrilla Radio|The Battle Of Los...|      1|      724.0|[0.59899997711181...|0.7229999899864197|
|['Rage Against Th...|       0.315|  0.97|   

In [13]:
# Split the data into training and testing sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

In [14]:
# Initialize the ALS model

als = ALS(
    userCol="artistIndex",
    itemCol="trackId",
    ratingCol="rating",
    maxIter=10,
    regParam=0.1,
    coldStartStrategy="drop",
    numUserBlocks=10,  # Reduce if memory is constrained
    numItemBlocks=10   # Reduce if memory is constrained
)

In [15]:
# Train the ALS model
model = als.fit(train_data)

24/12/06 14:08:40 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
24/12/06 14:08:41 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:42 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:43 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
24/12/06 14:08:44 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:45 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:46 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:46 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
24/12/06 14:08:47 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/12/06 14:08:47 WARN

In [16]:
data.show()

+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+-----------+--------------------+------------------+
|             artists|danceability|energy|liveness|valence|  tempo|                  id|                name|               album|trackId|artistIndex|            features|            rating|
+--------------------+------------+------+--------+-------+-------+--------------------+--------------------+--------------------+-------+-----------+--------------------+------------------+
|['Rage Against Th...|        0.47| 0.978|   0.356|  0.503|117.906|7lmeHLHBe4nmXzuXc...|             Testify|The Battle Of Los...|      0|      724.0|[0.46999999880790...|0.7404999732971191|
|['Rage Against Th...|       0.599| 0.957|   0.155|  0.489| 103.68|1wsRitfRRtWyEapl0...|     Guerrilla Radio|The Battle Of Los...|      1|      724.0|[0.59899997711181...|0.7229999899864197|
|['Rage Against Th...|       0.315|  0.97|   

24/12/06 14:08:55 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB


In [17]:
song_data = data.filter(data.id == "3FUS56gKr9mVBmzvlnodlh").select("artistIndex", "trackId").distinct()
song_data.show()

24/12/06 14:08:55 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
[Stage 63:=====>                                                   (1 + 9) / 10]

+-----------+-------+
|artistIndex|trackId|
+-----------+-------+
|      724.0|     13|
+-----------+-------+



24/12/06 14:08:56 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
                                                                                

In [None]:
# Function to get recommendations for a given song
def get_recommendations(song_id, num_recommendations=5):
    # Map the alphanumeric song_id to trackId
    song_data = data.filter(data.id == song_id).select("artistIndex", "trackId", "name", "artists", "album").distinct()

    # Check if the song exists in the data
    if song_data.count() == 0:
        print(f"Song ID {song_id} not found in the dataset.")
        return

    # Get the artistIndex and trackId for the song
    song_info = song_data.collect()[0]
    artist_index = song_info["artistIndex"]
    track_id = song_info["trackId"]
    track_name = song_info["name"]
    artist_name = song_info["artists"]
    album_name = song_info["album"]

    print(f"Recommendations based on the song: {track_name} by {artist_name} from the album {album_name}")

    # Generate recommendations for the given trackId
    recommendations = model.recommendForAllItems(num_recommendations)

    # Explode the recommendations to extract track IDs
    exploded_recommendations = recommendations.select(col("trackId"), explode(col("recommendations")).alias("recommendation"))
    exploded_recommendations = exploded_recommendations.select(col("trackId"), col("recommendation.trackId").alias("recommendedTrackId"))

    # Filter recommendations for the specific trackId
    recommendations_for_track = exploded_recommendations.filter(col("trackId") == track_id)

    # Join recommendations with the original dataset to get track details
    recommended_track_ids = recommendations_for_track.select("recommendedTrackId").rdd.flatMap(lambda x: x).collect()
    recommended_tracks = data.filter(col("trackId").isin(recommended_track_ids)).select("name", "artists", "album")

    print("Recommended tracks:")
    recommended_tracks.show(truncate=False)

In [21]:
# Example usage
song_id_input = "3FUS56gKr9mVBmzvlnodlh"  # Replace with actual alphanumeric song ID
get_recommendations(song_id_input, num_recommendations=5)

24/12/06 14:10:34 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
24/12/06 14:10:35 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
24/12/06 14:10:35 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB
24/12/06 14:10:35 WARN DAGScheduler: Broadcasting large task binary with size 9.2 MiB


Recommendations based on the song: Killing In the Name by ['Rage Against The Machine'] from the album Rage Against The Machine


AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `id` cannot be resolved. Did you mean one of the following? [`trackId`, `recommendations`].;
'Filter ('id = 13)
+- Project [trackId#768, cast(recommendations#769 as array<struct<artistIndex:int,rating:float>>) AS recommendations#772]
   +- Project [_1#763 AS trackId#768, _2#764 AS recommendations#769]
      +- SerializeFromObject [knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1 AS _1#763, mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 50), if (isnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 50), StructField(_1,IntegerType,false), StructField(_2,FloatType,false), ObjectType(class scala.Tuple2)))) null else named_struct(_1, knownnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 50), StructField(_1,IntegerType,false), StructField(_2,FloatType,false), ObjectType(class scala.Tuple2)))._1, _2, knownnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, 50), StructField(_1,IntegerType,false), StructField(_2,FloatType,false), ObjectType(class scala.Tuple2)))._2), knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, None) AS _2#764]
         +- MapElements org.apache.spark.ml.recommendation.ALSModel$$Lambda/0x0000000f022a35a0@111dd7b5, class scala.Tuple2, [StructField(_1,IntegerType,false), StructField(_2,ArrayType(StructType(StructField(_1,FloatType,false),StructField(_2,IntegerType,false)),true),true)], obj#762: scala.Tuple2
            +- DeserializeToObject newInstance(class scala.Tuple2), obj#761: scala.Tuple2
               +- Aggregate [trackId#741], [trackId#741, collect_top_k(struct(rating, rating#743, artistIndex, artistIndex#742), 5, false, 0, 0) AS collect_top_k(struct(rating, artistIndex))#751]
                  +- Project [_1#738 AS trackId#741, _2#739 AS artistIndex#742, _3#740 AS rating#743]
                     +- SerializeFromObject [knownnotnull(assertnotnull(input[0, scala.Tuple3, true]))._1 AS _1#738, knownnotnull(assertnotnull(input[0, scala.Tuple3, true]))._2 AS _2#739, knownnotnull(assertnotnull(input[0, scala.Tuple3, true]))._3 AS _3#740]
                        +- MapPartitions org.apache.spark.ml.recommendation.ALSModel$$Lambda/0x0000000f0229d888@4e8d91e6, obj#737: scala.Tuple3
                           +- DeserializeToObject newInstance(class scala.Tuple4), obj#736: scala.Tuple4
                              +- Join Cross
                                 :- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(IntegerType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false, true) AS _1#707, staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(FloatType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false, true) AS _2#708]
                                 :  +- MapPartitions org.apache.spark.ml.recommendation.ALSModel$$Lambda/0x0000000f0229b5a0@5668b881, obj#706: scala.Tuple2
                                 :     +- DeserializeToObject newInstance(class scala.Tuple2), obj#705: scala.Tuple2
                                 :        +- Project [_1#458 AS id#463, _2#459 AS features#464]
                                 :           +- SerializeFromObject [knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1 AS _1#458, staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(FloatType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false, true) AS _2#459]
                                 :              +- ExternalRDD [obj#457]
                                 +- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(IntegerType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false, true) AS _1#718, staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(FloatType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false, true) AS _2#719]
                                    +- MapPartitions org.apache.spark.ml.recommendation.ALSModel$$Lambda/0x0000000f0229b5a0@29376d86, obj#717: scala.Tuple2
                                       +- DeserializeToObject newInstance(class scala.Tuple2), obj#716: scala.Tuple2
                                          +- Project [_1#446 AS id#451, _2#447 AS features#452]
                                             +- SerializeFromObject [knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1 AS _1#446, staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(FloatType,false), fromPrimitiveArray, knownnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false, true) AS _2#447]
                                                +- ExternalRDD [obj#445]


24/12/06 14:35:26 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 531638 ms exceeds timeout 120000 ms
24/12/06 14:35:26 WARN SparkContext: Killing executors is not supported by current scheduler.
24/12/06 14:35:27 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$