# Environment Setup
Get necessary packages and setup google colab.

In [1]:
# download spark in google colab
!pip install pyspark findspark recommenders



In [2]:
# for hashing uri
from hashlib import sha256

# connect google drive for data
from google.colab import drive

# mount google drive
drive.mount('/content/drive')

# get spark
import findspark
from pyspark.ml.recommendation import ALS
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import explode, col, udf, countDistinct

# initialize the spark environment
findspark.init()
spark = SparkSession.builder.appName("SongRecommender").getOrCreate()

# variable to select test or real training environment
IS_TESTING = False

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Extract
Load in the raw data from the dataset.

In [None]:
# select dataset
if IS_TESTING:
    # one dataset
    data_path = r"path/data/mpd.slice.0-999.json"
else:
    # select all datasets
    data_path = r"path/data/*.json"

# extract json from each file
raw_json = spark.read.option("multiline", "true") \
    .option("mode", "PERMISSIVE") \
    .json(data_path)

# Inspect
View the raw data collected and select the features important to the analysis.



In [4]:
# display schema
raw_json.printSchema()

root
 |-- info: struct (nullable = true)
 |    |-- generated_on: string (nullable = true)
 |    |-- slice: string (nullable = true)
 |    |-- version: string (nullable = true)
 |-- playlists: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- collaborative: string (nullable = true)
 |    |    |-- description: string (nullable = true)
 |    |    |-- duration_ms: long (nullable = true)
 |    |    |-- modified_at: long (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- num_albums: long (nullable = true)
 |    |    |-- num_artists: long (nullable = true)
 |    |    |-- num_edits: long (nullable = true)
 |    |    |-- num_followers: long (nullable = true)
 |    |    |-- num_tracks: long (nullable = true)
 |    |    |-- pid: long (nullable = true)
 |    |    |-- tracks: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- album_name: string (nullable = true)
 |    |    |    |

What data is important?

1. pid (playlist ID): useful to uniquely identify playlists
2. name (playlist name): for readability
3. track_uri (track resource identifier): useful to uniquely identify songs.
4. track_name: for readability
5. artist_uri (artist resource identifier): useful for uniqely identifying artist. Beneficial to prevent over-recommneding popular artists.
6. artist_name: for readability


# Transform
Tabularize the non-relational data. Perform some cleaning to ensure that the data is good for training. Remove features deemed unnecessary to speed up the training.

In [5]:
# Setup helper to decode the uri to an integer

# convert the uri into a hash
def map_uri(string: str) -> int:
    '''
    Context:
    Since the uris are 23 base-64 characters they are not possible to be mapped
    to integers in spark. The next best solution is to use sha256 which has an
    even distribution and binning them to integer range supported in ALS models.
    The base two modulo allows for fast binning operation.
    '''
    return int(sha256(string.encode()).hexdigest(), 16) % (2**31)

# make decode a user defined
decode_uri = udf(lambda x: map_uri(x.split(":")[-1]))

In [6]:
from pyspark.sql.functions import concat_ws

# Transform JSON -> table with only selected fields

# 1. tabularize the playlist key
raw_df = raw_json.select(explode("playlists").alias("playlist"))

# 2. filter to only keep relevant columns
raw_df = raw_df.select(
    col("playlist.pid").alias("p_id"),
    col("playlist.name").alias("p_name"),
    col("playlist.modified_at").alias("modified_at"),
    explode(col("playlist.tracks")).alias("tracks"))\
    .select(
        col("p_id"),
        col("p_name"),
        col("modified_at"),
        col("tracks.track_uri").alias("track_uri"),
        col("tracks.track_name").alias("track_name"),
        col("tracks.artist_uri").alias("artist_uri"),
        col("tracks.artist_name").alias("artist_name")
    )

# 3. decode the uri
raw_df = raw_df.withColumn("track_id", decode_uri(col("track_uri")))

# NOTE: the data is clean - no nulls, unique playlist-track pairs
# cleaning is not necessary

# Inspect
View the cleaned transformed, clean data.


In [7]:
# describe the dataframe
raw_df.describe().show()

+-------+------------------+------------------+-------------------+--------------------+--------------------+--------------------+-----------+--------------------+
|summary|              p_id|            p_name|        modified_at|           track_uri|          track_name|          artist_uri|artist_name|            track_id|
+-------+------------------+------------------+-------------------+--------------------+--------------------+--------------------+-----------+--------------------+
|  count|           3344374|           3344374|            3344374|             3344374|             3344374|             3344374|    3344374|             3344374|
|   mean|25029.666755871203|1316.9330318329428|1.480902844315857E9|                NULL|                 NaN|                NULL|   Infinity|1.0794762524261513E9|
| stddev| 14432.17233952735| 933.5519478170474|3.495759582354033E7|                NULL|                 NaN|                NULL|        NaN| 6.200548846118802E8|
|    min|       

In [8]:
# show example dataframe
raw_df.show()

+-----+------+-----------+--------------------+--------------------+--------------------+--------------+----------+
| p_id|p_name|modified_at|           track_uri|          track_name|          artist_uri|   artist_name|  track_id|
+-----+------+-----------+--------------------+--------------------+--------------------+--------------+----------+
|45000| Ratch| 1508976000|spotify:track:7te...|         ***Flawless|spotify:artist:6v...|       Beyoncé|1528394279|
|45000| Ratch| 1508976000|spotify:track:7DT...|           Lifestyle|spotify:artist:6h...|     Rich Gang|1739692796|
|45000| Ratch| 1508976000|spotify:track:5NQ...|679 (feat. Remy B...|spotify:artist:6P...|     Fetty Wap|1942972086|
|45000| Ratch| 1508976000|spotify:track:6lb...|Up Down (Do This ...|spotify:artist:3a...|        T-Pain|1043994298|
|45000| Ratch| 1508976000|spotify:track:3G7...|         Classic Man|spotify:artist:4T...|       Jidenna| 874172154|
|45000| Ratch| 1508976000|spotify:track:27G...|             Jumpman|spot

In [9]:
# show the number of unique values in each column
raw_df.agg(*(
    [countDistinct(c).alias(c) for c in raw_df.columns]
)).show()

+-----+------+-----------+---------+----------+----------+-----------+--------+
| p_id|p_name|modified_at|track_uri|track_name|artist_uri|artist_name|track_id|
+-----+------+-----------+---------+----------+----------+-----------+--------+
|50000| 19156|       1846|   461880|    336535|     80483|      79327|  461832|
+-----+------+-----------+---------+----------+----------+-----------+--------+



In [10]:
# show data types
raw_df.dtypes

[('p_id', 'bigint'),
 ('p_name', 'string'),
 ('modified_at', 'bigint'),
 ('track_uri', 'string'),
 ('track_name', 'string'),
 ('artist_uri', 'string'),
 ('artist_name', 'string'),
 ('track_id', 'string')]

# Create Lookup tables

In [11]:
# track look up table
track_lut = raw_df.select(
    col("track_id"),
    col("track_uri"),
    col("track_name"),
    col("artist_uri"),
    col("artist_name"),
    col("modified_at")
)

# sort by modified at
track_lut = track_lut.sort(["track_id", "modified_at"], ascending=False)

# drop duplicated
track_lut = track_lut.dropDuplicates(["track_id"])

# drop modified_at
track_lut = track_lut.drop("modified_at")

track_lut.show()

+----------+--------------------+--------------------+--------------------+-------------------+
|  track_id|           track_uri|          track_name|          artist_uri|        artist_name|
+----------+--------------------+--------------------+--------------------+-------------------+
|1000018913|spotify:track:3vQ...|    Suspicious Minds|spotify:artist:3K...|    Dee Dee Warwick|
|1000034681|spotify:track:3KG...|Phantom Of The Op...|spotify:artist:7e...|              ERock|
| 100004043|spotify:track:339...|Your Love Gets Sw...|spotify:artist:1p...|       Finley Quaye|
|1000042406|spotify:track:5zh...|   Isolate - Ambient|spotify:artist:3O...|               Moby|
|1000049331|spotify:track:45o...|Your Love Is Extr...|spotify:artist:6e...|     Casting Crowns|
|1000053224|spotify:track:7Bm...|             Tool Up|spotify:artist:6z...|            Blueboy|
| 100005557|spotify:track:497...|California - Tcha...|spotify:artist:0L...|     Phantom Planet|
|1000056024|spotify:track:6nP...|       

# Machine Learning

### Training Dataframe

In [None]:
from pyspark.sql.functions import lit

# create a rating column with 1
track_popularity = raw_df.groupBy("artist_uri").count()
training_df = raw_df.join(track_popularity, on="artist_uri")

# rating column as 1
training_df = training_df.withColumn("rating", lit(1))

# # convert track_id to integer
training_df = training_df.withColumn("track_id", training_df["track_id"].cast('integer'))

training_df.show()

+--------------------+-----+-----------+-----------+--------------------+--------------------+-------------------+----------+-----+------+
|          artist_uri| p_id|     p_name|modified_at|           track_uri|          track_name|        artist_name|  track_id|count|rating|
+--------------------+-----+-----------+-----------+--------------------+--------------------+-------------------+----------+-----+------+
|spotify:artist:00...|44663|    Hmmm...| 1507593600|spotify:track:2xN...|The Third Untitle...|            Kosmose| 507534104|    2|     1|
|spotify:artist:00...|44663|    Hmmm...| 1507593600|spotify:track:7BO...|The Tenth Untitle...|            Kosmose|1580668254|    2|     1|
|spotify:artist:00...|10420|Study Music| 1474329600|spotify:track:4og...|        Tango en Sky|       Roland Dyens| 240362849|    1|     1|
|spotify:artist:00...|36125|     Bandas| 1470096000|spotify:track:2hj...|        Como La Luna|         Banda Boom|1803040577|    3|     1|
|spotify:artist:00...|36125

In [13]:
# show number of nans in each column
training_df.select(
    [F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) for c in training_df.columns]
).show()

+----------+----+------+-----------+---------+----------+-----------+--------+-----+------+
|artist_uri|p_id|p_name|modified_at|track_uri|track_name|artist_name|track_id|count|rating|
+----------+----+------+-----------+---------+----------+-----------+--------+-----+------+
|         0|   0|     0|          0|        0|         2|          0|       0|    0|     0|
+----------+----+------+-----------+---------+----------+-----------+--------+-----+------+



### Train Model

In [14]:
# Initialize the ALS model with implicit feedback
als = ALS(
    userCol="p_id",
    itemCol="track_id",
    ratingCol="rating",
    implicitPrefs=True,
    coldStartStrategy="drop",
    maxIter=10,
    regParam=0.1,
    alpha=1.0
)

# Fit the model
model = als.fit(training_df)

### Get Recommendations

In [15]:
# Generate top 10 recommendations for each playlist
rec_json = model.recommendForAllUsers(10)

In [16]:
# Explode the recommendations array to have one row per recommended track
rec_df = rec_json.withColumn("rec", explode("recommendations"))\
    .select(
        col("p_id"),
        col("rec.track_id").alias("track_id"),
        col("rec.rating").alias("rating")
    )

# Join with the track lookup table to get human-readable details
rec_mapped = rec_df.join(track_lut, on="track_id", how="left")

# Display the results
rec_mapped.show(truncate=False)


+----------+----+------------+------------------------------------+-----------------------------+-------------------------------------+----------------------------+
|track_id  |p_id|rating      |track_uri                           |track_name                   |artist_uri                           |artist_name                 |
+----------+----+------------+------------------------------------+-----------------------------+-------------------------------------+----------------------------+
|64861707  |2   |0.00195412  |spotify:track:6D0b04NJIKfEMg040WioJQ|Issues                       |spotify:artist:0ZED1XzwlLHW4ZaG4lOT6m|Julia Michaels              |
|107761059 |3   |0.018479995 |spotify:track:5JuA3wlm0kn7IHfbeHV0i6|All I Want                   |spotify:artist:4BxCuXFJrSWGi1KHcVqaU4|Kodaline                    |
|159134787 |0   |0.3175912   |spotify:track:5i66xrvSh1MjjyDd6zcwgj|Umbrella                     |spotify:artist:5pKCCKE2ajJHZ9KAiaK11H|Rihanna                     |
|166464816

# Extension: Recommending to New Playlists
Make recommendations on custom user playlists.

## Spotify API parsing tool

In [17]:
import pandas as pd
import requests
import json

# Class to get spotify playlists
class Spotify:
    # gets the access token for spotify
    def _getAccessToken(self) -> str:
        # send request for token
        response = requests.post(
            url="https://accounts.spotify.com/api/token",
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "grant_type": "client_credentials",
                "client_id": self.client_id,
                "client_secret": self.client_secret
            }
        )

        # return access token
        if response.status_code == 200:
            return response.json()["access_token"]

        # raise error
        raise Exception(f"Error: {response.status_code} - {response.json()}")

    # gets information related to playlist
    def _getPlMetadata(self, pid: str) -> dict:
        # call the api
        response = requests.get(
            url=f"https://api.spotify.com/v1/playlists/{pid}",
            headers={"Authorization": f"Bearer {self.access_token}"}
        )

        if response.status_code == 200:
            playlist = response.json()

            return {
                "p_name": playlist["name"]
            }

        raise Exception(f"Error: {response.status_code} - {response.json()}")

    # get all the tracks in the playlist
    def _getPlTracks(self, pid: str) -> dict:
        # playlist api call helpers
        url = f"https://api.spotify.com/v1/playlists/{pid}/tracks"
        headers = {"Authorization": f"Bearer {self.access_token}"}

        # keep all the tracks
        all_tracks = []

        # each call provides up to 100 values
        while url:
            # call api
            response = requests.get(url, headers=headers)

            if response.status_code == 200:
                data = response.json()

                # go through all the tracks
                for idx, item in enumerate(data["items"]):
                    # validate if its a track
                    if item["track"]["track"]:
                        # get track data
                        track_data = {
                            "added_at": item["added_at"],
                            "track_name": item["track"]["name"],
                            "track_uri": item["track"]["uri"],
                            "artist_name": item["track"]["artists"][0]["name"],
                            "artist_uri": item["track"]["artists"][0]["uri"]
                        }

                        all_tracks.append(track_data)

                url = data["next"]
            else:
                raise Exception(
                    f"Error: Failed to fetch tracks: {response.status_code}"
                )

        return all_tracks

    def __init__(self, client_id, client_secret):
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token = self._getAccessToken()

    def getPlaylist(self, pid:str):
        metadata = self._getPlMetadata(pid)
        tracks = self._getPlTracks(pid)

        frame = pd.DataFrame.from_dict(tracks)

        for key, value in metadata.items():
            frame[key] = value

        frame["modified_at"] = pd.to_datetime(frame["added_at"])\
            .astype("int64").max()//1000000000

        frame.drop(columns=["added_at"], inplace=True)

        return frame


## Playlist Retrieval and Tabularization


In [18]:
# get max playlist id
MAX_PID = raw_df.agg({"p_id": "max"}).collect()[0][0]

In [None]:
playlist_ids = [
    "2uByAv6dvHWgmSlhtbRAoA",
    "4l2aGNlWpQmPsqVC68xnjv",
    "4rQryQheY6YvdOGB2OLYg6",
    "0qTRBbSEEVCgap2XtnBPVL",
]

# Can generate your own the spotify developers page
spotify = Spotify(
    client_id="SECRET",
    client_secret="SECRET"
)

# create list of dataframes
user_playlists = []

# add new dataframes
for idx, playlist_id in enumerate(playlist_ids):
    user_playlists.append(spotify.getPlaylist(playlist_id))
    user_playlists[idx]["p_id"] = MAX_PID + idx + 1

# combine the dataframes
user_pd_df = pd.concat(user_playlists)

# order the columns
user_pd_df = user_pd_df[[
    "p_id", "p_name", "modified_at",
    "track_uri", "track_name", "artist_name", "artist_uri"
]]

# convert to spark
user_pl_df = spark.createDataFrame(user_pd_df)
user_pl_df = user_pl_df.withColumn("track_id", decode_uri(col("track_uri")))

# show
user_pl_df.show(10, truncate=False)

+-----+-------+-----------+------------------------------------+-------------------------+--------------+-------------------------------------+----------+
|p_id |p_name |modified_at|track_uri                           |track_name               |artist_name   |artist_uri                           |track_id  |
+-----+-------+-----------+------------------------------------+-------------------------+--------------+-------------------------------------+----------+
|50000|up high|1739089625 |spotify:track:4Pwjz3DfvfQWV0rO2V8jyh|Bitch, Don’t Kill My Vibe|Kendrick Lamar|spotify:artist:2YZyLoL8N0Wb9xBt1NhZWg|1649569792|
|50000|up high|1739089625 |spotify:track:5hM5arv9KDbCHS0k9uqwjr|Borderline               |Tame Impala   |spotify:artist:5INjqkS1o8h1imAzPqGZBb|524024734 |
|50000|up high|1739089625 |spotify:track:5M4yti0QxgqJieUYaEXcpw|Eventually               |Tame Impala   |spotify:artist:5INjqkS1o8h1imAzPqGZBb|139391466 |
|50000|up high|1739089625 |spotify:track:1fOkmYW3ZFkkjIdOZSf596|Pink M

# Recommend Songs to New Playlists

In [20]:
from pyspark.sql.functions import col, collect_list, explode
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql import Window
import pyspark.sql.functions as F
import numpy as np

# get latent factors
item_factors_df = model.itemFactors.withColumnRenamed("id", "track_id")

# get user playlist id and track id
playlist_tracks = user_pl_df.select("p_id", "track_id").distinct()

# merge tracks with factors
playlist_with_factors = playlist_tracks.join(item_factors_df, on="track_id", how="inner")

# average latent factors for each vector
def average_vectors(vectors):
    return np.mean(np.array(vectors), axis=0).tolist()

# to user defined function
avg_udf = F.udf(average_vectors, ArrayType(FloatType()))

# compute average vectors
playlist_latent = playlist_with_factors.groupBy("p_id") \
    .agg(collect_list("features").alias("features_list")) \
    .withColumn("user_features", avg_udf(col("features_list")))

# dot two vectors: shows you how aligned they are
def dot_product(v1, v2):
    return float(np.dot(np.array(v1), np.array(v2)))

# to udf
dot_udf = F.udf(dot_product, FloatType())

# get score: dot of user features and actual features
recommendations_all = playlist_latent.crossJoin(item_factors_df) \
    .withColumn("score", dot_udf(col("user_features"), col("features")))

# remove existing tracks from being recommended
recommendations_filtered = recommendations_all.join(
    playlist_tracks,
    on=["p_id", "track_id"],
    how="left_anti"
)

# recommenend top 10 for each playlist
windowSpec = Window.partitionBy("p_id").orderBy(F.col("score").desc())
recommendations_top = recommendations_filtered.withColumn("rank", F.row_number().over(windowSpec)) \
    .filter(F.col("rank") <= 10)

# join with lookup table
final_recommendations = recommendations_top.join(track_lut, on="track_id", how="left")

# Show the final recommendations for each new playlist
final_recommendations.select("p_id", "track_id", "score", "track_name", "track_uri", "artist_name", "artist_uri") \
    .show(truncate=False)


+-----+----------+------------+------------------------------+------------------------------------+----------------+-------------------------------------+
|p_id |track_id  |score       |track_name                    |track_uri                           |artist_name     |artist_uri                           |
+-----+----------+------------+------------------------------+------------------------------------+----------------+-------------------------------------+
|50000|394565429 |0.11989869  |HUMBLE.                       |spotify:track:7KXjTSCq5nL1LoYtL7XAwS|Kendrick Lamar  |spotify:artist:2YZyLoL8N0Wb9xBt1NhZWg|
|50000|447907122 |0.13727513  |White Iverson                 |spotify:track:6eT7xZZlB2mwyzJ2sUKG6w|Post Malone     |spotify:artist:246dkjvS1zLTtiykXe5h60|
|50001|484207603 |0.010996327 |Shut Up and Dance             |spotify:track:4kbj5MwxO1bq9wjT5g9HaA|WALK THE MOON   |spotify:artist:6DIS6PRrLS3wbnZsf7vYic|
|50001|644352059 |0.011401648 |Hey Ya! - Radio Mix / Club Mix|spotify: