In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
from pyspark.sql.functions import col, when
from pyspark.ml.clustering import KMeansModel

In [2]:
spark = SparkSession.builder.appName("SongRecommendationWithArtists").config("spark.executor.memory", "4g").config("spark.driver.memory", "4g").config("spark.executor.cores", "2").config("spark.driver.maxResultSize", "2g").config("spark.sql.shuffle.partitions", "200").getOrCreate()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/08 13:53:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Define paths for saving model and data
model_path = "./model/kmeans_model"
preprocessed_data_path = "./model/preprocessed_data"
file_path = './tracks_features.csv'

In [4]:
# Function to preprocess data
def preprocess_data(file_path):
    df = spark.read.csv(file_path, header=True, inferSchema=True)

    # Attributes for recommendation
    attributes = ["danceability", "energy", "liveness", "valence", "tempo"]

    # Replace null values in all columns with 0
    #df = df.select([when(col(c).isNull(), 0).otherwise(col(c)).alias(c) for c in df.columns])

    # Replace nulls or non-numeric values and cast to double
    for attr in attributes:
        df = df.withColumn(attr, when(col(attr).isNull(), 0.0).otherwise(col(attr).cast("double")))

    df = df.fillna(0)

    # Encode 'artists' using StringIndexer and OneHotEncoder
    string_indexer = StringIndexer(inputCol="artists", outputCol="artist_index")
    one_hot_encoder = OneHotEncoder(inputCol="artist_index", outputCol="artist_vector")

    # Combine numerical attributes and the artist vector into a single feature vector
    vector_assembler = VectorAssembler(inputCols=attributes + ["artist_vector"], outputCol="features")

    df = string_indexer.fit(df).transform(df)
    df = one_hot_encoder.fit(df).transform(df)
    df = vector_assembler.transform(df)


    return df

In [5]:
# Load the trained model
model = KMeansModel.load(model_path)

                                                                                

In [6]:

# Preprocess the raw data
df = preprocess_data(file_path)

                                                                                

In [7]:
df = model.transform(df) 

In [8]:
df.show()

24/12/08 13:53:44 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/12/08 13:53:44 WARN DAGScheduler: Broadcasting large task binary with size 135.3 MiB


+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------+-----------+--------+-------------------+------------------+---+-------------------+----+--------------------+--------------------+--------------------+-------------------+-------------------+------------------+-----------+--------------+----+------------+------------+--------------------+--------------------+-------+
|                  id|                name|               album|            album_id|             artists|          artist_ids|track_number|disc_number|explicit|       danceability|            energy|key|           loudness|mode|         speechiness|        acousticness|    instrumentalness|           liveness|            valence|             tempo|duration_ms|time_signature|year|release_date|artist_index|       artist_vector|            features|cluster|
+--------------------+--------------------+--------------------+----------------

In [9]:

# Define the 5 given songs by their IDs
given_song_ids = ["12Cbou8Hl4yGGuTZlkLl60"]

# Get clusters for the given songs
given_songs = df.filter(col("id").isin(given_song_ids)).select("id", "cluster")

# Recommend songs from the same cluster for each given song
recommendations = {}
for row in given_songs.collect():
    song_id = row["id"]
    cluster_id = row["cluster"]
    
    # Get 5 songs from the same cluster, excluding the given song itself
    similar_songs = df.filter((col("cluster") == cluster_id) & (col("id") != song_id)) \
                      .select("name", "artists", "danceability", "energy", "liveness", "valence", "tempo") \
                      .limit(5)
    
    recommendations[song_id] = similar_songs.collect()

# Display recommendations
for song_id, recs in recommendations.items():
    print(f"\nRecommendations for Song ID {song_id}:")
    for row in recs:
        print(f"  Name: {row['name']}, Artists: {row['artists']}, "
              f"Danceability: {row['danceability']}, Energy: {row['energy']}, "
              f"Liveness: {row['liveness']}, Valence: {row['valence']}, Tempo: {row['tempo']}")

24/12/08 13:53:45 WARN DAGScheduler: Broadcasting large task binary with size 135.3 MiB
24/12/08 13:53:47 WARN DAGScheduler: Broadcasting large task binary with size 135.3 MiB



Recommendations for Song ID 12Cbou8Hl4yGGuTZlkLl60:
  Name: Man on a Mission, Artists: ['Daryl Hall & John Oates'], Danceability: 0.787, Energy: 0.903, Liveness: 0.10099999999999999, Valence: 0.9620000000000001, Tempo: 119.946
  Name: strange, Artists: ['Tori Amos'], Danceability: 0.5329999999999999, Energy: 0.319, Liveness: 0.11800000000000001, Valence: 0.19699999999999998, Tempo: 119.475
  Name: Waste of Mind, Artists: ['zebrahead'], Danceability: 0.602, Energy: 0.9109999999999999, Liveness: 0.0514, Valence: 0.848, Tempo: 120.178
  Name: Wish I May, Artists: ['Ani DiFranco'], Danceability: 0.804, Energy: 0.47200000000000003, Liveness: 0.369, Valence: 0.5870000000000001, Tempo: 120.376
  Name: Done Wrong, Artists: ['Ani DiFranco'], Danceability: 0.5589999999999999, Energy: 0.33899999999999997, Liveness: 0.27399999999999997, Valence: 0.0805, Tempo: 119.661


24/12/08 13:53:51 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
