<a href="https://colab.research.google.com/github/DomizianoScarcelli/big-data-project/blob/main/PlaylistReccomender.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Config

# Configuration

In [79]:
#@title Download necessary libraries
!pip install pyspark
!pip install -U -q PyDrive 
!apt install openjdk-8-jdk-headless -qq

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
openjdk-8-jdk-headless is already the newest version (8u362-ga-0ubuntu1~20.04.1).
0 upgraded, 0 newly installed, 0 to remove and 24 not upgraded.


In [80]:
#@title Imports
import os
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import plotly

import pyspark
from pyspark.sql import *
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType, FloatType, LongType
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf
from pyspark.ml.linalg import SparseVector, DenseVector

from tqdm.notebook import tqdm
import time
import gc

from google.colab import drive

In [81]:
#@title Set up variables
JAVA_HOME = "/usr/lib/jvm/java-8-openjdk-amd64"
GDRIVE_DIR = "/content/drive"
GDRIVE_HOME_DIR = GDRIVE_DIR + "/MyDrive"
GDRIVE_DATA_DIR = GDRIVE_HOME_DIR + "/Big Data/datasets"
DATASET_FILE = os.path.join(GDRIVE_DATA_DIR, "pyspark_friendly_spotify_playlist_dataset")
AUDIO_FEATURES_FILE = os.path.join(GDRIVE_DATA_DIR, "pyspark_track_features")
LITTLE_SLICE_FILE = os.path.join(GDRIVE_DATA_DIR, "little_slice")
SMALL_SLICE_FLIE = os.path.join(GDRIVE_DATA_DIR, "small_slice")
LITTLE_SLICE_AUDIO_FEATURES = os.path.join(GDRIVE_DATA_DIR, "little_slice_audio_features")
MICRO_SLICE_AUDIO_FEATURES = os.path.join(GDRIVE_DATA_DIR, "micro_slice_audio_features")
SPLITTED_SLICE_AUDIO_FEATURES = os.path.join(GDRIVE_DATA_DIR, "splitted_pyspark_track_features")
SAVED_DFS_PATH = os.path.join(GDRIVE_DATA_DIR, "saved_dfs")
RANDOM_SEED = 42 # for reproducibility
os.environ["JAVA_HOME"] = JAVA_HOME
os.environ["PYSPARK_PYTHON"]="python"

In [82]:
#@title Create the session
conf = SparkConf().\
                set('spark.ui.port', "4050").\
                set('spark.executor.memory', '12G').\
                set('spark.driver.memory', '12G').\
                set('spark.driver.maxResultSize', '100G').\
                set("spark.executor.extraJavaOptions", "-XX:+UseG1GC").\
                setAppName("PySparkTutorial").\
                setMaster("local[*]")

# Create the context
sc = pyspark.SparkContext(conf=conf)
spark = SparkSession.builder.getOrCreate()

ValueError: ignored

In [84]:
drive.mount(GDRIVE_DIR, force_remount=True)

Mounted at /content/drive


## Setup ngrok

In [85]:
!pip install pyngrok

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [86]:
!ngrok authtoken 2NVN8kdoOnMVtlDGGWtwsbT5M3Q_2EJv2HE77FEXkz978Qtnq

Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml


In [87]:
from pyngrok import ngrok

# Open a ngrok tunnel on the port 4050 where Spark is running
port = '4050'
public_url = ngrok.connect(port).public_url



In [88]:
print("To access the Spark Web UI console, please click on the following link to the ngrok tunnel \"{}\" -> \"http://127.0.0.1:{}\"".format(public_url, port))

To access the Spark Web UI console, please click on the following link to the ngrok tunnel "https://4d9d-34-141-140-215.ngrok-free.app" -> "http://127.0.0.1:4050"


In [89]:
#@title Check if everything is ok
spark, sc._conf.getAll()


(<pyspark.sql.session.SparkSession at 0x7f44e3c115a0>,
 [('spark.executor.extraJavaOptions',
   '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false -XX:+UseG1GC'),
  ('spark.app.name', 'PySparkTutorial'),
  ('spark.driver.port', '42775'),
  ('spa

# Data acquisition

In [90]:

song_schema = StructType([
    StructField("pos", IntegerType(), True),
    StructField("artist_name", StringType(), True),
    StructField("track_uri", StringType(), True),
    StructField("artist_uri", StringType(), True),
    StructField("track_name", StringType(), True),
    StructField("album_uri", StringType(), True),
    StructField("duration_ms", LongType(), True),
    StructField("album_name", StringType(), True)
])

playlist_schema = StructType([
    StructField("name", StringType(), True),
    StructField("collaborative", StringType(), True),
    StructField("pid", IntegerType(), True),
    StructField("modified_at", IntegerType(), True),
    StructField("num_tracks", IntegerType(), True),
    StructField("num_albums", IntegerType(), True),
    StructField("num_followers", IntegerType(), True),
    StructField("tracks", ArrayType(song_schema), True),
    StructField("num_edits", IntegerType(), True),
    StructField("duration_ms", IntegerType(), True),
    StructField("num_artists", IntegerType(), True),
])

audio_features_schema = StructType([
    StructField("danceability", FloatType(), True),
    StructField("energy", FloatType(), True),
    StructField("key", IntegerType(), True),
    StructField("loudness", FloatType(), True),
    StructField("mode", IntegerType(), True),
    StructField("speechiness", FloatType(), True),
    StructField("acousticness", FloatType(), True),
    StructField("instrumentalness", FloatType(), True),
    StructField("liveness", FloatType(), True),
    StructField("valence", FloatType(), True),
    StructField("tempo", FloatType(), True),
    StructField("type", StringType(), True),
    StructField("id", StringType(), True),
    StructField("uri", StringType(), True),
    StructField("track_href", StringType(), True),
    StructField("analysis_url", StringType(), True),
    StructField("duration_ms", LongType(), True),
    StructField("time_signature", IntegerType(), True)
])


In [91]:
playlist_df = spark.read.schema(playlist_schema).json(DATASET_FILE, multiLine=True)
slice_df = spark.read.schema(playlist_schema).json(SMALL_SLICE_FLIE, multiLine=True)
# slice_df = spark.read.schema(playlist_schema).json(LITTLE_SLICE_FILE, multiLine=True)
audio_df = spark.read.schema(audio_features_schema).json(SPLITTED_SLICE_AUDIO_FEATURES, multiLine=True) #has less songs than expected

In [92]:
# slice_df.select("tracks").first()

In [93]:
slice_df.show()

+--------------+-------------+-----+-----------+----------+----------+-------------+--------------------+---------+-----------+-----------+
|          name|collaborative|  pid|modified_at|num_tracks|num_albums|num_followers|              tracks|num_edits|duration_ms|num_artists|
+--------------+-------------+-----+-----------+----------+----------+-------------+--------------------+---------+-----------+-----------+
|         Ratch|        false|45000| 1508976000|        88|        70|            1|[{0, Beyoncé, spo...|       50|   20047039|         48|
|  slow it down|        false|45001| 1505952000|        80|        77|            1|[{0, Twinbed, spo...|       20|   20365984|         65|
|    Phat Beats|        false|45002| 1466640000|        24|        15|            5|[{0, Baths, spoti...|       16|    5127143|         14|
|           ✌🏽|        false|45003| 1509148800|        77|        63|            3|[{0, Owl City, sp...|       50|   17201663|         54|
|          💘💘|       

# User-Based Collaborative Filtering
Note: The users are the playlists, the items are the songs and the ratings are 0 if the song is not in the playlist, 1 otherwise.

We have to define a function $sim(u,v)$ that defines the similarity between two users based on their ratings.

We represent the ratings $r_u \in \mathbb{R}^n$ as the $n$ dimensional vector that represents the ratings of the user $u$, where $n$ is the number of total songs in the dataset.

As the similarity function we can use Jaccard similarity.
\begin{equation}
sim(u,v) = J(r_u, r_v) = \frac{|r_u \cap r_v|}{|r_u \cup r_v|}
\end{equation}

Jaccard similarity ignores rating values, but we don't care here since the ratings are binary. In case of discrete value ratings we can use cosine similarity, or better pearson's correlation.

Done that, and defined as ${U^k}$ the neighborhood of $u$ ($k$ most similar users to $u$), we define the set of items rated by $u$'s neighborhood as

\begin{equation}
I^k = \{i \in I : \mathbf{r_{u,i}} \downarrow \land u \in U^k\}
\end{equation}

The rating for the item $i$ to the user $u$ will just be $\mathbf{r_u[i]}$.

In [94]:
RATING_VECTOR_FILE_PATH = os.path.join(SAVED_DFS_PATH, "fake_parquet.parquet")

In [95]:
def dense_to_sparse(dense: DenseVector) -> SparseVector:
  nonzero_indices = np.nonzero(np.array(dense))[0]
  nonzero_values = np.array(dense)[nonzero_indices]
  sparse_vector = SparseVector(len(dense), nonzero_indices.tolist(), nonzero_values.tolist())
  return sparse_vector

In [96]:
def get_all_songs(playlist_df: DataFrame, set_in_playlist: bool = False) -> DataFrame:
   all_songs = playlist_df.select(explode("tracks.track_uri").alias("track_uri")).distinct()
   if set_in_playlist:
     all_songs = all_songs.withColumn("in_playlist", lit(1))
   return all_songs
  
def get_songs_info(playlist_df: DataFrame, set_in_playlist: bool = False) -> DataFrame:
   all_songs = playlist_df.select(explode("tracks")).select("col.*").drop("pos").distinct()
   if set_in_playlist:
     all_songs = all_songs.withColumn("in_playlist", lit(1))
   return all_songs

songs_df = get_all_songs(slice_df)
songs_df.createOrReplaceTempView("SONGS")

songs_df = spark.sql("""
SELECT 
    row_number() OVER (
        PARTITION BY '' 
        ORDER BY '' 
    ) as pos,
    *
FROM 
    SONGS
""")

songs_df = songs_df.sort("track_uri")

songs_info_df = get_songs_info(slice_df)

songs_info_df = songs_info_df.join(songs_df, "track_uri", "left")

RATING_VECTOR_LENGTH = songs_df.count()

In [97]:
songs_df.show()

+---+--------------------+
|pos|           track_uri|
+---+--------------------+
|  1|spotify:track:1mr...|
|  2|spotify:track:1Uv...|
|  3|spotify:track:4WR...|
|  4|spotify:track:7B6...|
|  5|spotify:track:2Gy...|
|  6|spotify:track:7AO...|
|  7|spotify:track:48Z...|
|  8|spotify:track:1Um...|
|  9|spotify:track:7MO...|
| 10|spotify:track:27P...|
| 11|spotify:track:6lt...|
| 12|spotify:track:1yz...|
| 13|spotify:track:5Mz...|
| 14|spotify:track:3BU...|
| 15|spotify:track:4Cl...|
| 16|spotify:track:2dN...|
| 17|spotify:track:341...|
| 18|spotify:track:7ja...|
| 19|spotify:track:4eQ...|
| 20|spotify:track:6fy...|
+---+--------------------+
only showing top 20 rows



Preprocessing the dataframe in order to associate to each track_uri an integer, that will represent the position of the track in the rating_vector. This is useful in order to avoid doing a lot of joins when generating the rating_vectors.

In [98]:
# def map_track_df_to_pos(playlist_df: DataFrame, mapping: DataFrame) -> List[DataFrame]:
#   songs_df_list = [get_all_songs(spark.createDataFrame([row])) for row in tqdm(slice_df.collect(), desc="Creating list of dataframes")]
#   track_uri_to_id = songs_df.select('track_uri', 'pos').rdd.collectAsMap()
#   track_uri_to_id_udf = udf(lambda x: track_uri_to_id.get(x), IntegerType())
#   songs_df_mapped_list = []

#   for df in tqdm(songs_df_list, desc="Mapping uris to pos"):
#       df = df.withColumn('pos', track_uri_to_id_udf(col('track_uri')))
#       songs_df_mapped_list.append(df)
  
#   return songs_df_mapped_list

# songs_df_mapped_list = map_track_df_to_pos(slice_df, songs_df)

In [99]:
from pyspark.ml.linalg import VectorUDT
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType, ArrayType

def map_track_df_to_pos_2(playlist_df: DataFrame, mapping: DataFrame) -> DataFrame:
    track_uri_to_id = songs_df.select('track_uri', 'pos').rdd.collectAsMap()

    track_uri_to_id_udf = udf(lambda x: track_uri_to_id.get(x), IntegerType())

    # Define a UDF to map track URIs to integer IDs within each element of the list
    
    def extract_vector(tracks):
      pos_list = set()
      for row in tracks:
        uri = track_uri_to_id.get(row.track_uri)
        pos_list.add(uri)

      return SparseVector(RATING_VECTOR_LENGTH, sorted(list(pos_list)), [1 for _ in pos_list])

    map_track_uri_udf = udf(lambda tracks: extract_vector(tracks), returnType=VectorUDT())

    # Apply the mapping UDF on the "tracks" column of the slice_df dataframe
    mapped_df = playlist_df.withColumn('tracks', map_track_uri_udf(col('tracks')))

    return mapped_df

mapped_slice_df = map_track_df_to_pos_2(slice_df, songs_df)

In [104]:
mapped_slice_df.show()

PythonException: ignored

In [101]:
def _create_rating_df(playlist_row: Row, songs_df: DataFrame) -> DataFrame:
  """
  Creates a dataframe that represents the "ratings" for a playlist in the dataframe
  """
  playlist_row = spark.createDataFrame([playlist_row], playlist_schema)
  playlist_uris = get_all_songs(playlist_row)

  joined = songs_df.join(playlist_uris, on="track_uri", how="right")
  return joined


def _check_songs_ordering(playlist_row: DataFrame, songs_df: DataFrame) -> bool:
  """
  Returns a boolean that indicates if the ordering in the songs_df and rating_df is the same
  """
  playlist_row = spark.createDataFrame([playlist_row], playlist_schema)
  playlist_uris = get_all_songs(playlist_row, True).withColumnRenamed("in_playlist", "isin")

  joined = songs_df.join(playlist_uris, on="track_uri", how="right")
  joined_left = songs_df.join(playlist_uris, on="track_uri", how="left").filter("isin == 1")
  assert joined.collect() == joined_left.collect(), f"The order of songs_df is different from the order of rating_df!"

# def _extract_rating_vector(rating_df: DataFrame) -> SparseVector:
#   """
#   Extracts the rating vectors for each playlist 
#   """
#   dense_vector = DenseVector([row.isin for row in rating_df.select("isin").collect()])
#   return dense_to_sparse(dense_vector)

def _extrac_sparse_rating_vector(rating_df: DataFrame) -> SparseVector:
  indices = np.sort([row.pos for row in rating_df.collect()])
  return SparseVector(RATING_VECTOR_LENGTH, indices, np.ones(indices.shape[0]) )

def rating_vector_from_row(playlist_row: Row, songs_df: DataFrame):
  """
  Pipelines togheter create_rating_df and extract_rating_vector.
  """
  rating_df_1 = _create_rating_df(playlist_row, songs_df)
  rating_vector_1 = _extrac_sparse_rating_vector(rating_df_1)
  return rating_vector_1

# t1 = time.time() 

# rating_vector_1 = rating_vector_from_row(slice_df.first(), songs_df)

# t2 = time.time()

# t2 - t1, rating_vector_1, type(rating_vector_1)

In [102]:
def jaccard_similarity(vector_1: SparseVector, vector_2: SparseVector) -> float:
  # Convert SparseVectors to sets
  set1 = set(vector_1.indices)
  set2 = set(vector_2.indices)

  # Calculate the intersection and union of the sets
  intersection = len(set1.intersection(set2))
  union = len(set1.union(set2))

  # Calculate the similarity
  similarity = intersection / union

  return similarity

In [103]:
def create_rating_vectors_df(playlists_df: DataFrame) -> DataFrame:
  rating_vectors = []

  for playlist_row in tqdm(playlists_df.collect(), desc="Creating rating vectors"):
    rating_vector = rating_vector_from_row(playlist_row, songs_df)
    new_row = Row(playlist_id=playlist_row.pid, rating_vector=rating_vector)
    rating_vectors.append([new_row])
  return spark.createDataFrame(rating_vectors)

if os.path.exists(RATING_VECTOR_FILE_PATH):
  # rv_schema = StructType([StructField('playlist_id', LongType(), True), StructField('rating_vector', pyspark.ml.linalg.VectorUDT(), True)])
  rating_vectors_df = spark.read.parquet(RATING_VECTOR_FILE_PATH)
  rv_df = rating_vectors_df.select(col("_1.playlist_id").alias("playlist_id"), col("_1.rating_vector").alias("rating_vector"))
else:
  rating_vectors_df = create_rating_vectors_df(slice_df)
  rating_vectors_df.write.parquet(RATING_VECTOR_FILE_PATH)

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/usr/local/lib/python3.10/dist-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/usr/lib/python3.10/socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: ignored

In [None]:
def create_similarity_df(input_vector: DataFrame, rating_vectors_df: DataFrame, similarityFunction: Callable) -> DataFrame:
  rv_df_input = rating_vectors_df.crossJoin(input_vector)
  similarity_udf = udf(similarityFunction, returnType='double')
  result_df = rv_df_input.withColumn("similarity", similarity_udf(rv_df_input["input_vector"], rv_df_input["rating_vector"]))
  return result_df

first_playlist_vector = rv_df.limit(1).select("rating_vector").withColumnRenamed("rating_vector","input_vector")
result_df = create_similarity_df(first_playlist_vector, rv_df, jaccard_similarity)

Curse of dimensionality! We can see that each playlist is very dissimilar from each other playlist.

In [None]:
rv_df.show()

In [None]:
first_playlist_pid = rv_df.limit(1).select("playlist_id")
K = 20
top_k_results = result_df.filter( (col('playlist_id') != 1000)).orderBy(col("similarity").desc()).limit(K)
top_k_results.show()

Making a prediction

In [None]:
# from pyspark.ml.linalg import VectorUDT

# def add_sparse_vectors(accumulator: SparseVector, vector: SparseVector, weight: float) -> SparseVector:
#     accumulator_vec = accumulator.toArray() 
#     array_2 = vector.toArray() * weight

#     summed_array = accumulator_vec + array_2

#     values = [value for value in summed_array if value != 0]
#     sorted_indices = [index for index, value in enumerate(summed_array) if value != 0]
#     return SparseVector(accumulator_vec.size, sorted_indices, values)

# @udf(returnType=VectorUDT())
# def accumulate_sparse_vectors(accumulator, rating_vector, similarity):
#     summed_vector = add_sparse_vectors(accumulator, rating_vector, similarity)
#     return summed_vector

# df = top_k_results.withColumn('accumulated_vector', accumulate_sparse_vectors(top_k_results["rating_vector"], top_k_results["rating_vector"], top_k_results['similarity']))

In [None]:
def accumulate_top_k_results(top_k_results: DataFrame):
  accumulator = np.zeros(RATING_VECTOR_LENGTH)
  top_k = top_k_results.collect()
  for row in top_k:
    accumulator += row.rating_vector.toArray() * row.similarity

  values = [value for value in accumulator if value != 0]
  sorted_indices = [index for index, value in enumerate(accumulator) if value != 0]
  return SparseVector(accumulator.size, sorted_indices, values)

accumulated_vector = accumulate_top_k_results(top_k_results)

In [None]:
# @udf(returnType="int")
# def compute_vector_len(rating_vector):
#     return len(rating_vector.indices)
# final_df = df.withColumn('accumulated_vector_len', compute_vector_len(df["rating_vector"]))

In [None]:
# accumulated_vector = final_df.orderBy(col("accumulated_vector_len").desc()).first().accumulated_vector

In [None]:
def get_top_n_values(vector: SparseVector, n:int) -> List[int]:
  elements = list(enumerate(vector.toArray()))
  sorted_elements = sorted(elements, key=lambda x: x[1], reverse=True)
  top_n_indices = [(index, confidence) for index, confidence in sorted_elements[:n]]
  return top_n_indices

top_n_reccomendations = get_top_n_values(accumulated_vector, 10)

In [None]:
top_n_reccomendations

In [None]:
def song_info_from_index(index: int, confidence: float) -> Row:
  song_info = songs_info_df.filter(f"pos == {index}").withColumn("confidence", lit(confidence)).first()
  return song_info

songs_info = [song_info_from_index(index, confidence) for index, confidence in top_n_reccomendations]

In [None]:
reccomendations = spark.createDataFrame(songs_info)

In [None]:
slice_df.filter(f"pid == {first_playlist_pid.first().playlist_id}").show()

In [None]:
reccomendations.show()

#Item-Based Collaborative Filtering


# Fighting against the curse of dimensionality: Matrix Factorization

We want to define $\mathbf{x}_u \in \mathbb{R}^d$ $d$-dimensional vector that represents the user $u$, and $\mathbf{w}_i \in \mathbb{R}^d$ vector that represent the item $i$.

We then can estimate the rating of user $u$ for the item $i$ by computing
\begin{equation}
\hat{r}_{u, i}=\mathbf{x}_u^T \cdot \mathbf{w}_i=\sum_{j=1}^d x_{u, j} w_{j, i}
\end{equation}
Or, in matrix notation,

\begin{equation}
\underbrace{R}_{m \times n} =
\underbrace{X}_{m \times d}
\underbrace{W^T}_{d \times n}
\end{equation}

### How to learn $X$ and $W$
The matrix $R$ is partially known and filled with the observations inside the dataset $\mathcal{D}$. In order to learn the latent factor representations $X$ and $W$, we minimize the following loss function:
\begin{equation}
L(X, W)=\sum_{(u, i) \in \mathcal{D}}\underbrace{\left(r_{u, i}-\mathbf{x}_u^T \cdot \mathbf{w}_i\right)^2}_{\text{squared error term}}+\underbrace{\lambda\left(\sum_{u \in \mathcal{D}}\left\|\mathbf{x}_u\right\|^2+\sum_{i \in \mathcal{D}}\left\|\mathbf{w}_i\right\|^2\right)}_{\text{regularization term}}
\end{equation}

We can then minimize the loss using Stochastic Gradient Descent or Alternating Least Squares.

# Matrix Factorization
Generate a matrix Y where each column represent a playlist and each row represent a song, the (i,j) entry will be 1 if the playlist contains the song, 0 otherwise.

In [None]:
# Throw error in order to not execute the following code
raise ValueError()

In [None]:
import pyspark.sql.functions as f
from pyspark.sql.functions import explode
spark.conf.set("spark.sql.pivotMaxValues", 1000000)

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row
# Get a DataFrame of only the relevant columns from the playlist schema
# playlists = slice_df.select("pid", "tracks.track_uri")
# playlists = playlists.select("pid", explode("track_uri").alias("song_uri"))
# playlists = playlists.withColumn("song_id", dense_rank().over(Window.orderBy("song_uri")))
# plaulists = playlists.withColumn("rating", lit(1))
# playlists.count()

from pyspark.sql.functions import expr


In [None]:
from pyspark.sql.functions import explode
import random
tracks_df = slice_df.select("pid", explode("tracks").alias("track")).select("pid", "track.track_uri")
tracks_df = tracks_df.withColumn("rating", lit(1))
# tracks_df = tracks_df.withColumn("rating", (rand() * 10 + 1).cast("integer"))

In [None]:
tracks_df.show()

In [None]:
# # Explode the tracks array column into multiple rows
# # tracks_df = slice_df.select("pid", explode("tracks").alias("track"))
# # tracks_df = slice_df.select("pid", "tracks", "tracks")
# tracks_df = slice_df.select("pid", explode("tracks").alias("track")).select("pid", "track.track_uri", "track.pos")

# # Select relevant columns and add a rating column with value 1
# playlist_track_df = tracks_df.withColumn("rating", lit(1))

# # Get distinct track_uri values and join with playlist_track_df
# all_tracks_df = slice_df.select(explode("tracks").alias("track")).select("track.track_uri").distinct()
# all_playlists_df = slice_df.select("pid").distinct()

# all_against_all = all_tracks_df.join(all_playlists_df).distinct()

# from pyspark.sql.functions import when, col

# # playlist_track_rating_df = playlist_track_df.join(all_against_all, ["pid", "track_uri"], "left_outer") \
# #     .withColumn("rating", when(col("pos").isNull(), 0).otherwise(1))

# playlist_track_rating_df = all_against_all.join(playlist_track_df, ["pid", "track_uri"], "left_outer") \
#     .withColumn("rating", when(col("pos").isNull(), 0).otherwise(1)) \
#     .drop("pos")


In [None]:
playlist_track_rating_df = tracks_df.withColumn("song_id", dense_rank().over(Window.orderBy("track_uri")))

In [None]:
playlist_track_rating_df.show(truncate=False)

In [None]:
als = ALS(userCol="pid", itemCol="song_id", ratingCol="rating", nonnegative=True, coldStartStrategy="drop")

In [None]:
from typing import Tuple
import random

def train_test_split(df: DataFrame, split_ratio: float, seed: Optional[int] = None) -> Tuple[DataFrame, DataFrame]:
  random.seed(seed)
  distinct_pids = df.select("pid").distinct().rdd.map(lambda x: x[0]).collect()
  random.shuffle(distinct_pids)
  split_index = int(len(distinct_pids) * split_ratio)
  train_pids = distinct_pids[:split_index]
  test_pids = distinct_pids[split_index:]
  train_df = df.filter(col("pid").isin(train_pids))
  test_df = df.filter(col("pid").isin(test_pids))
  return train_df, test_df



In [None]:
training, test = playlist_track_rating_df.randomSplit([0.8, 0.2], seed=42)

In [None]:
model = als.fit(training)

In [None]:
predictions = model.transform(test)

In [None]:
predictions.show()

In [None]:
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                predictionCol="prediction")
rmse = evaluator.evaluate(predictions)

In [None]:
predictions.filter(col("prediction") != "NaN").count(), predictions.filter(col("prediction") == "NaN").count()

In [None]:
rmse

In [None]:
subset = playlist_track_rating_df.select("pid").distinct().limit(1)
subUserRecs = model.recommendForUserSubset(subset, 10)

In [None]:
subset.show()

In [None]:
subUserRecs.show(truncate=False)

In [None]:
def song_name_from_id(song_id: int, reverse_lookup: DataFrame) -> str:
  return 
  
def interpretRecommendation(recommended_result: DataFrame) -> str:
  return

In [None]:
userRecs = model.recommendForAllUsers(1).orderBy("recommendations")
userRecs.show(truncate=False)
userRecs.count()

In [None]:
slice_df.filter(col("pid") == 1710).select(explode("tracks.track_name")).show()

In [None]:
track_uris = playlist_track_rating_df.filter(col("song_id") == 588).select("track_uri")
track_uris.first()