Dependencies

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, lit, rand, explode, array, concat_ws
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml import Pipeline, PipelineModel
import random
import string

Initialization of the Spark session

In [3]:
# Initialize Spark session
spark = SparkSession.builder.master("local[*]").appName("DummyDataGenerator").getOrCreate()

your 131072x1 screen size is bogus. expect trouble
25/01/01 15:57:36 WARN Utils: Your hostname, emspa resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/01/01 15:57:36 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/01 15:57:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Schema for data structure

In [4]:
schema = StructType([
    StructField("name", StringType(), True),
    StructField("collaborative", StringType(), True),
    StructField("pid", IntegerType(), True),
    StructField("modified_at", LongType(), True),
    StructField("num_tracks", IntegerType(), True),
    StructField("num_albums", IntegerType(), True),
    StructField("num_followers", IntegerType(), True),
    StructField("tracks", ArrayType(
        StructType([
            StructField("pos", IntegerType(), True),
            StructField("artist_name", StringType(), True),
            StructField("track_uri", StringType(), True),
            StructField("artist_uri", StringType(), True),
            StructField("track_name", StringType(), True),
            StructField("album_uri", StringType(), True),
            StructField("duration_ms", IntegerType(), True),
            StructField("album_name", StringType(), True),
            StructField("artist_data", StructType([
                StructField("genres", ArrayType(StringType()), True),
                StructField("name", StringType(), True),
                StructField("popularity", IntegerType(), True),
                StructField("uri", StringType(), True)
            ]))
        ])
    ), True),
    StructField("num_edits", IntegerType(), True),
    StructField("duration_ms", LongType(), True),
    StructField("num_artists", IntegerType(), True),
    StructField("origin", StringType(), True),
    StructField("genre_counts", MapType(StringType(), IntegerType()), True),
])

## Random Data generation for testing:

In [5]:
def random_string(length=8):
    return ''.join(random.choices(string.ascii_letters, k=length))

def random_genre_counts():
    genres = ["pop", "rock", "hip-hop", "jazz", "classical"]
    return {genre: random.randint(0, 50) for genre in random.sample(genres, random.randint(1, len(genres)))}

def random_track():
    return {
        "pos": random.randint(0, 100),
        "artist_name": random_string(10),
        "track_uri": f"spotify:track:{random_string(22)}",
        "artist_uri": f"spotify:artist:{random_string(22)}",
        "track_name": random_string(15),
        "album_uri": f"spotify:album:{random_string(22)}",
        "duration_ms": random.randint(180000, 300000),
        "album_name": random_string(12),
        "artist_data": {
            "genres": random.sample(["pop", "rock", "hip-hop", "jazz", "classical"], random.randint(1, 3)),
            "name": random_string(10),
            "popularity": random.randint(0, 100),
            "uri": f"spotify:artist:{random_string(22)}"
        }
    }

def random_playlist():
    return {
        "name": random_string(12),
        "collaborative": random.choice(["true", "false"]),
        "pid": random.randint(1, 10000),
        "modified_at": random.randint(1609459200, 1672444800),  # Timestamps from 2021 to 2023
        "num_tracks": random.randint(1, 50),
        "num_albums": random.randint(1, 20),
        "num_followers": random.randint(0, 10000),
        "tracks": [random_track() for _ in range(random.randint(1, 20))],
        "num_edits": random.randint(0, 50),
        "duration_ms": random.randint(3600000, 7200000),
        "num_artists": random.randint(1, 30),
        "origin": random.choice(["US", "UK", "IN", "DE", "FR"]),
        "genre_counts": random_genre_counts()
    }


In [6]:
# Generate dummy data
data = [random_playlist() for _ in range(10)]

In [7]:
df = spark.createDataFrame(data, schema=schema)

df.show(truncate=False)
df.printSchema()

                                                                                

+------------+-------------+----+-----------+----------+----------+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

## Data Preprocessing:

In [8]:
# creates a flat structure for the data
exploded_tracks = df.select(
    explode(col("tracks")).alias("track"),
    col("num_followers")
).select(
    col("track.artist_name"),
    col("track.track_name"),
    col("track.artist_data.genres").alias("genres"),
    col("track.artist_data.popularity").alias("popularity"),
    col("num_followers"),
)

# converts genres array into a string
exploded_tracks = exploded_tracks.withColumn("genres_str", concat_ws(",", col("genres")))

# index the genres
indexer = StringIndexer(inputCol="genres_str", outputCol="genre_index")
indexed_tracks = indexer.fit(exploded_tracks).transform(exploded_tracks)

# creates a feature vector of the inputs
assembler = VectorAssembler(
    inputCols=["genre_index", "popularity", "num_followers"],
    outputCol="features"
)

final_data = assembler.transform(indexed_tracks).select("features", "popularity")

                                                                                

### Split data into training and test sets:

In [9]:
train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=1234)

### Model Training:

In [10]:
lr = LinearRegression(featuresCol="features", labelCol="popularity", regParam=0.1)
lr_model = lr.fit(train_data)

25/01/01 15:58:06 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
25/01/01 15:58:06 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


### Model Evaluation:

In [11]:
test_results = lr_model.evaluate(test_data)
print(f"Root Mean Squared Error (RMSE): {test_results.rootMeanSquaredError}")
print(f"R2: {test_results.r2}")

Root Mean Squared Error (RMSE): 0.1300557800347414
R2: 0.9999789883458495


In [13]:
# Step 8: Use the model to make predictions (example)
predictions = lr_model.transform(test_data)
predictions.select("features", "popularity", "prediction").show(truncate=False)

+------------------+----------+------------------+
|features          |popularity|prediction        |
+------------------+----------+------------------+
|[30.0,28.0,6986.0]|28        |28.119530774333327|
|[5.0,61.0,8589.0] |61        |60.91281567261692 |
|[8.0,75.0,8589.0] |75        |74.86336695858968 |
|[33.0,27.0,8589.0]|27        |27.12155755993408 |
|[1.0,40.0,2811.0] |40        |40.022779100531096|
|[3.0,95.0,2811.0] |95        |94.80393057680993 |
|[27.0,79.0,2811.0]|79        |78.92935518808069 |
|[34.0,23.0,2811.0]|23        |23.174885819017128|
|[0.0,51.0,7748.0] |51        |50.94598438881523 |
|[1.0,17.0,7748.0] |17        |17.08689122769338 |
|[2.0,99.0,7748.0] |99        |98.75562883590985 |
|[12.0,11.0,7748.0]|11        |11.138950275322804|
|[14.0,52.0,7748.0]|52        |51.97708769323321 |
+------------------+----------+------------------+



## Model Usage

### User Input:

In [44]:
user_genres = ["hip-hop", "classical"]

### Model Suggestion:

In [84]:
def filter_by_user_genres(df, genres):
    return df.filter(
        (col("genres_str").contains(genres[0])) | (col("genres_str").contains(genres[1]))
    )

def preprocess_genres(df):
    return df.withColumn("genres_str", concat_ws(",", col("genres")))

preprocessed_data = preprocess_genres(exploded_tracks)
filtered_tracks = filter_by_user_genres(preprocessed_data, user_genres)
indexed_tracks = indexer.fit(filtered_tracks).transform(filtered_tracks)
assembled_data = assembler.transform(indexed_tracks).select("features", "track_name", "artist_name")

predictions = lr_model.transform(assembled_data)

predictions.select("track_name", "artist_name", "features", "prediction").show(truncate=False)

+---------------+-----------+------------------+------------------+
|track_name     |artist_name|features          |prediction        |
+---------------+-----------+------------------+------------------+
|dEBssPVaWoaRFAn|cEGCYPifKL |[2.0,94.0,2716.0] |93.85365259503689 |
|EEJuyCdXNSSJPgM|owSraOYtkW |[1.0,82.0,2716.0] |81.89648625028266 |
|EEdLvwCOELEsdxh|LNtsWaKRJD |[0.0,2.0,2716.0]  |2.1931829409831605|
|tYqIoGbpnAbEOgd|zINbBfAweh |[0.0,12.0,2716.0] |12.155850141651584|
|cwWYPfcMytFWaTP|WGXltBfIyy |[16.0,68.0,7716.0]|67.95707320528395 |
|uvjLmNVkimOIDXk|gkizhygllk |[0.0,66.0,7716.0] |65.93308850191627 |
|OumqdZxNmSesjRR|gldYmCciLO |[8.0,50.0,7716.0] |50.00854661246379 |
|GTDFkWDUPrnNZxK|QdaJcAlbiE |[32.0,54.0,7716.0]|54.04079038758214 |
|zcWvCSRwaeGUfWS|VxbrUMLLSe |[1.0,2.0,7716.0]  |2.173984121590479 |
|PBDGPPhQjvGKcZr|iTBCFKYOZp |[7.0,47.0,7716.0] |47.01778074831113 |
|MVJOfdKjeRqNdbe|tJeeHQKGCO |[34.0,93.0,7716.0]|92.89912387809325 |
|jbJqNryAnnsgoCv|AoodCXZxDj |[24.0,81.0,7716.0]|

In [55]:
# Step 8: Display the top 15 tracks with the highest predicted popularity
top_tracks = predictions.orderBy(col("prediction").desc()).limit(15)

# Show the result
top_tracks.select("track_name", "artist_name", "prediction").show(truncate=False)

+---------------+-----------+-----------------+
|track_name     |artist_name|prediction       |
+---------------+-----------+-----------------+
|RYjhkDnieDcNXzx|aRZZItPOhz |96.99999999999997|
|QaPqPJyjPxGKycP|UDeEmsfQiN |92.99999999999999|
|LTIOArFTroCLJmf|sqPlAKDHDt |92.99999999999997|
|zncyDpqDxOExWIW|zaFcnwMRIr |88.99999999999999|
|XvICoPJcEjhlfrb|FmafEYgebt |87.99999999999997|
|ifNrarAkApjPcRx|SADDOJbmeX |84.99999999999999|
|frjaQKPaJnTlrTu|JQyPnfpNhx |83.0             |
|ywjxrWRmHpioybV|BLQbnLczto |82.99999999999999|
|FQgWrQYaKfiRgXO|pzaKHdehYC |82.99999999999997|
|mgfIjNVtWheBCBD|puabXLQDRK |79.99999999999999|
|LzDvIpGmFvOGdIE|YsShxMtnrZ |77.99999999999999|
|fITlrrWwiCzFJSc|LdkbBeAXmt |76.99999999999999|
|FYsdMBDftFQckqd|qqQhLWABzx |75.99999999999999|
|siwtKSjRqSQhszv|aZfZjaefQL |75.99999999999999|
|ewvKtiqhGNYPhYo|gTcZWeInIH |75.0             |
+---------------+-----------+-----------------+

