# Processing the data in Spark because it is massive

In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, FloatType
import config

# Caution: running this pipeline against a local master will probably crash
# on the full dataset.
# NOTE(review): the app name and the "spark.some.config.option" setting look
# like quick-start placeholders -- confirm whether they are intentional.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Word Count")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

In [None]:
# Load the raw CSV. Quoted fields may span several lines (multiLine), '"' is
# used as the escape character, and rows that fail to parse are dropped.
df = spark.read.csv(
    config.DATA_PATH,
    header=True,
    multiLine=True,
    escape='"',
    mode="DROPMALFORMED",
)

# Clean column names: strip surrounding whitespace and stray double quotes.
# toDF renames every column positionally in a single projection, instead of
# stacking one withColumnRenamed projection per column.
df = df.toDF(*[name.strip().replace('"', "") for name in df.columns])

In [None]:
# Keep only the three columns used downstream. All of them are trimmed;
# track and artist names are additionally lower-cased.
df = df.select(
    F.trim(F.col("playlistname")).alias("playlistname"),
    F.lower(F.trim(F.col("trackname"))).alias("trackname"),
    F.lower(F.trim(F.col("artistname"))).alias("artistname"),
)

# Song key of the form "<track> by <artist>"
df = df.withColumn(
    "track_artist",
    F.concat(F.col("trackname"), F.lit(" by "), F.col("artistname")),
)

# One row per playlist with the list of its song keys; only playlists with
# more than 40 collected entries are kept.
playlist_df = (
    df.groupBy("playlistname")
    .agg(F.collect_list("track_artist").alias("tracklist"))
    .filter(F.size("tracklist") > 40)
)
playlist_df.cache()

In [None]:
from typing import Iterator

# Get unique songs and compute embeddings
unique_songs = playlist_df.select(F.explode("tracklist").alias("song")).distinct()

# Each embedding is returned to Spark as an array of floats
embedding_schema = ArrayType(FloatType())


@pandas_udf(embedding_schema)
def encode_songs_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Encode song strings into embedding vectors.

    This is an Iterator-of-Series pandas UDF. A plain Series-to-Series UDF is
    called once per Arrow batch, which would reload the SentenceTransformer
    model for every batch. With the iterator form the model is constructed
    once per UDF invocation and reused for every batch that invocation
    receives.

    Args:
        batches: Iterator over batches of song strings.

    Yields:
        For each input batch, a Series of the same length holding one
        embedding vector per song.
    """
    # Expensive set-up: done once, before any batch is consumed
    encoder = SentenceTransformer(config.EMBEDDING_MODEL, device="cpu")

    for songs in batches:
        embeddings = encoder.encode(songs.tolist(), convert_to_numpy=True)
        # One 1-D vector per row of the 2-D result
        yield pd.Series(list(embeddings))


# Spread the songs over 64 partitions so encoding can run in parallel
unique_songs = unique_songs.repartition(64)  # Adjust based on cluster size
embeddings_df = unique_songs.withColumn("embedding", encode_songs_udf("song"))
embeddings_df.cache()
# Optional: persist the embeddings and reload them instead of recomputing.
# embeddings_df.write.parquet(config.EMBEDDINGS_FILE_PATH, mode="overwrite")
# embeddings_df = spark.read.parquet(config.EMBEDDINGS_FILE_PATH)

In [None]:
# Split every playlist into consecutive, non-overlapping chunks of
# config.CONTEXT_SIZE songs; a trailing partial chunk is discarded.

# Number of complete chunks per playlist (cast so sequence() gets int bounds)
n_full_chunks = F.floor(F.size("tracklist") / config.CONTEXT_SIZE).cast("int")

processed_playlists = (
    playlist_df
    # Guard: for a playlist shorter than one chunk the sequence below would
    # be sequence(0, -1), which counts DOWN to [0, -1] rather than being
    # empty, and would yield partial / wrongly-anchored chunks.
    .filter(F.size("tracklist") >= config.CONTEXT_SIZE)
    .withColumn(
        "chunks",
        F.transform(
            F.sequence(F.lit(0), n_full_chunks - 1),
            # slice() positions are 1-based, hence the "+ 1"
            lambda i: F.slice(
                "tracklist", i * config.CONTEXT_SIZE + 1, config.CONTEXT_SIZE
            ),
        ),
    )
    # Chunks before split_index go to train, [split_index, val_index) to
    # validation, and the rest to test.
    .withColumn(
        "split_index",
        F.floor(F.size("chunks") * F.lit(config.TRAIN_RATIO)).cast("int"),
    )
    .withColumn(
        "val_index",
        F.floor(
            F.size("chunks") * F.lit(config.TRAIN_RATIO + config.VAL_RATIO)
        ).cast("int"),
    )
    # Unique (not consecutive) id per playlist row
    .withColumn("playlist_id", F.monotonically_increasing_id())
)


# Split into train/val/test and add chunk IDs
def _explode_split(column_name, first, count):
    """Return one row per chunk for a sub-range of each playlist's chunks.

    ``first`` is the 1-based start position inside ``chunks`` and ``count``
    the number of chunks to take; both may be ints or Columns. The result has
    the columns ``playlist_id``, ``chunk_id`` (generated with
    ``monotonically_increasing_id``) and ``chunk``.
    """
    with_range = processed_playlists.withColumn(
        column_name, F.slice("chunks", first, count)
    )
    return with_range.select(
        "playlist_id",
        F.monotonically_increasing_id().alias("chunk_id"),
        F.explode(column_name).alias("chunk"),
    )


# First split_index chunks
train_df = _explode_split("train_chunks", 1, F.col("split_index"))

# Chunks from split_index up to (not including) val_index
val_df = _explode_split(
    "val_chunks",
    F.col("split_index") + 1,
    F.col("val_index") - F.col("split_index"),
)

# Everything from val_index to the end
test_df = _explode_split(
    "test_chunks",
    F.col("val_index") + 1,
    F.size("chunks") - F.col("val_index"),
)

In [None]:
# Attach embeddings to chunks with a join rather than a broadcast lookup
def process_chunks_with_join(chunk_df, dataset_name):
    """Replace the songs of every chunk by their embeddings and write Parquet.

    Args:
        chunk_df: DataFrame with ``chunk_id``, ``playlist_id`` and an array
            column ``chunk`` holding the songs of the chunk.
        dataset_name: Appended directly to ``config.OUTPUT_PATH`` to form the
            output location.

    Returns:
        The path the result was written to.

    Reads the module-level ``embeddings_df`` (columns ``song``, ``embedding``).
    """
    # One row per song, remembering its position inside the chunk
    songs_by_position = chunk_df.select(
        "chunk_id",
        "playlist_id",
        F.posexplode("chunk").alias("pos", "song"),
    )

    # Inner join on the song text; the join key appears once in the output
    with_vectors = songs_by_position.join(
        embeddings_df, on="song", how="inner"
    ).select("chunk_id", "pos", "song", "playlist_id", "embedding")

    # Collect (pos, embedding) pairs per chunk and sort them, so the vectors
    # come out in their original position order
    ordered_pairs = F.sort_array(F.collect_list(F.struct("pos", "embedding")))
    per_chunk = (
        with_vectors.groupBy("chunk_id", "playlist_id")
        .agg(ordered_pairs.alias("sorted_embeddings"))
        .select(
            "chunk_id",
            "playlist_id",
            F.expr("transform(sorted_embeddings, x -> x.embedding)").alias(
                "embeddings"
            ),
        )
    )

    # Persist the result, replacing any previous output at that location
    output_path = f"{config.OUTPUT_PATH}{dataset_name}"
    per_chunk.write.parquet(output_path, mode="overwrite")
    return output_path

# Write the three splits in order: train, then val, then test
train_path, val_path, test_path = (
    process_chunks_with_join(frame, split_name)
    for frame, split_name in (
        (train_df, "train"),
        (val_df, "val"),
        (test_df, "test"),
    )
)

# Report where each split ended up
for label, written_path in (
    ("Train", train_path),
    ("Validation", val_path),
    ("Test", test_path),
):
    print(f"{label} data written to: {written_path}")

In [1]:
from playgenie.data.dataset import DataSet
import pandas as pd

In [2]:
# Build a DataSet over the ../data/train folder
dataset = DataSet(
    folder_path="../data/train",
)

In [7]:
# Inspect the first item; subscription is the idiomatic way to invoke
# __getitem__
dataset[0]

(tensor([[[-0.0772,  0.0288,  0.0500,  ..., -0.0355,  0.0366, -0.1130],
          [-0.0372, -0.0300,  0.0429,  ...,  0.0395,  0.0329, -0.0317],
          [-0.0850, -0.0648,  0.0652,  ...,  0.1156, -0.0064, -0.1484],
          ...,
          [-0.0985, -0.0101,  0.0528,  ..., -0.0148,  0.0127, -0.0121],
          [-0.1077,  0.0797, -0.0016,  ...,  0.0873,  0.0091,  0.0331],
          [-0.0233, -0.0761,  0.0458,  ...,  0.0778,  0.0160, -0.0382]],
 
         [[ 0.0056,  0.0249,  0.0208,  ..., -0.0228, -0.0080, -0.0405],
          [ 0.0704, -0.0093, -0.0020,  ...,  0.0013, -0.0194, -0.0466],
          [ 0.0622, -0.0403,  0.0346,  ..., -0.0404, -0.0292, -0.0221],
          ...,
          [ 0.0125,  0.0291, -0.0080,  ...,  0.0115,  0.0351, -0.0085],
          [ 0.0200,  0.0072,  0.0095,  ..., -0.0159,  0.0245, -0.0097],
          [ 0.0778, -0.0394,  0.0447,  ..., -0.0478,  0.0475, -0.0673]],
 
         [[-0.0039,  0.0017,  0.0329,  ..., -0.0278, -0.0586,  0.0134],
          [-0.0625,  0.0518,