In [18]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from ast import literal_eval

# Source CSV that is loaded and filtered below; fail fast if it is missing.
PATH_DATASET_CSV = "../dataset/Dataset40k/spotify_with_youtube_and_prompts.csv"
assert os.path.exists(PATH_DATASET_CSV), (
    f"Received bad path `{PATH_DATASET_CSV}`"
)

# Destination of the final split-annotated CSV (written at the end of the
# notebook). Refuse to run if it already exists so an earlier export is never
# overwritten.
PATH_SAVE_FINAL_CSV = "../dataset/Dataset40k/dataset.csv"
assert not os.path.exists(PATH_SAVE_FINAL_CSV), (
    f"`{PATH_SAVE_FINAL_CSV}` already exists."
)

# Remove rows that have more than one genre
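NOTE: The way I have designed the train/valid split requires single-genre songs, because the stratification key below is built from one `spotify_genre` value per track.<br>
You could do something more intelligent here, but there are so few such rows (307 of 36,770) that I didn't bother.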

In [19]:
# Load the raw dataset. `spotify_genre` is stored as a stringified Python list,
# so parse it back into a real list before counting genres.
df = pd.read_csv(PATH_DATASET_CSV)
df["spotify_genre"] = df["spotify_genre"].apply(literal_eval)

genre_counts = df["spotify_genre"].apply(len)
assert (genre_counts >= 1).all(), "Found tracks with an empty `spotify_genre` list."
has_more_than_one_genre = genre_counts != 1
print(f"Rows with 2+ genres: {has_more_than_one_genre.sum()}/{len(df)}")

# `.copy()` makes the filtered frame independent of the unfiltered one, so the
# column overwrite on the next line does not raise SettingWithCopyWarning.
df = df[~has_more_than_one_genre].copy()
# Every remaining list has exactly one element: unwrap it to a plain genre value.
df["spotify_genre"] = df["spotify_genre"].apply(lambda genres: genres[0])
print(f"Rows remaining: {len(df)}")

Rows with 2+ genres: 307/36770
Rows remaining: 36463


# Splits

In [20]:
# Number of pd.cut bins used to discretise valence and energy for stratification.
n_bins = 4
# Smallest allowed stratum size; a stratified sklearn train_test_split rejects
# classes with fewer than 2 members.
min_group_size = 2
# Tracks credited to more than this many artists are dropped.
max_artists = 5

# Remove rows with more than `max_artists` artists (6+ with the value above).
# `spotify_artists` is a stringified list, so parse it before counting.
artist_count = df["spotify_artists"].apply(literal_eval).apply(len)
print(f"Tracks with more than {max_artists} artists: ", (artist_count > max_artists).sum())
# `.copy()` so the bin/key columns added next are written to an independent
# frame instead of a slice (avoids SettingWithCopyWarning).
df = df[artist_count <= max_artists].copy()

# Discretise 'spotify_valence' and 'spotify_energy' into `n_bins` bins each
# (labels=False yields integer bin indices). `assign` returns a new frame, so
# this never writes into a view of the filtered frame from earlier steps.
df = df.assign(
    valence_bin=pd.cut(df['spotify_valence'], bins=n_bins, labels=False),
    energy_bin=pd.cut(df['spotify_energy'], bins=n_bins, labels=False),
)

# Stratification key = genre + valence bin + energy bin, e.g. "salsa_3_3".
df['stratify_key'] = (
        df['spotify_genre'].astype(str) + '_' +
        df['valence_bin'].astype(str) + '_' +
        df['energy_bin'].astype(str)
)

# Collapse keys with fewer than `min_group_size` rows into one "rare" stratum,
# so the stratified split below has enough members in every class.
strat_counts = df['stratify_key'].value_counts()
rare_stratify_keys = list(strat_counts[strat_counts < min_group_size].keys())
df.loc[df['stratify_key'].isin(rare_stratify_keys), 'stratify_key'] = "rare"
strat_counts = df['stratify_key'].value_counts()
# NOTE: this can still fail when exactly one key was rare and it had a single
# row, because the merged "rare" bucket then also has fewer than
# `min_group_size` rows.
assert strat_counts.min() >= min_group_size, (
    f"Smallest stratum has {strat_counts.min()} rows (< {min_group_size}); "
    f"rare keys merged into 'rare': {rare_stratify_keys}"
)
print(strat_counts)

# 80/20 train/valid split of the track IDs, stratified on `stratify_key` so the
# key's class proportions are approximately preserved in both sets. A fixed
# seed keeps the split reproducible.
train_track_ids, valid_track_ids = train_test_split(
    df["spotify_track_id"],
    stratify=df['stratify_key'],
    test_size=0.2,
    random_state=42,
)

# Track IDs are unique, and every row lands in exactly one of the two sets.
assert (
    len(set(train_track_ids.values)) + len(set(valid_track_ids.values))
    == len(df)
    == len(set(df["spotify_track_id"].unique()))
)
# The two ID sets must be disjoint (checked in both directions).
assert not train_track_ids.isin(valid_track_ids).any()
assert not valid_track_ids.isin(train_track_ids).any()

# These are the 32 IDs used for experiment 2.
# They will become the test dataset.
# Stored as a whitespace-separated block; `.split()` turns it into a plain
# list of ID strings in the order written here.
test_track_ids = """
    5vjLSffimiIP26QG5WcN2K 2QjOHCTQ1Jl3zawyYOpxh6 60a0Rd6pjrkxjPbaKzXjfq 0yc6Gst2xkRu0eMLeRMGCX
    5ygDXis42ncn6kYG14lEVG 3WMj8moIAXJhHsyLaqIIHI 5YbPxJwPfrj7uswNwoF1pJ 3XucsgiwXb8KPn9Csf9Zmu
    2grjqo0Frpf2okIBiifQKs 3WBRfkOozHEsG0hbrBzwlm 1WCEAGGRD066z2Q89ObXTq 3ZCTVFBt2Brf31RLEnCkWJ
    3S7A85bAWOd6ltk6r2ANOI 1Fid2jjqsHViMX6xNH70hE 5XeFesFbtLpXzIVDNQP22n 0lP4HYLmvowOKdsQ7CVkuq
    2HZLXBOnaSRhXStMLrq9fD 2tTmW7RDtMQtBk7m2rYeSw 5KTBaWu8IOczQ0sPWzZ7MY 0o9zmvc5f3EFApU52PPIyW
    2gYj9lubBorOPIVWsTXugG 37ZJ0p5Jm13JPevGcx4SkF 3zb8S65LhiPPPH4vov8yV2 4h9wh7iOZ0GGn8QVp4RAOB
    4LRPiXqCikLlN15c3yImP7 5itOtNx0WxtJmi1TQ3RuRd 6mFkJmJqdDVQ1REhVfGgd1 39shmbIHICJ2Wxnk1fPSdz
    44AyOl4qVkzS48vBsbNXaC 7eJMfftS33KTjuF7lTsMCx 2TktkzfozZifbQhXjT6I33 4RvWPyQ5RL0ao9LPZeSouE
""".split()

# Every experiment-2 test track must survive the genre/artist filtering above.
# Otherwise the test split would silently contain fewer than 32 tracks.
missing_test_ids = set(test_track_ids) - set(df["spotify_track_id"])
assert not missing_test_ids, (
    f"Test track IDs missing from the filtered dataset: {missing_test_ids}"
)

# Assign splits. Test IDs are labelled last, so they are carved out of
# whichever of train/valid the stratified split had placed them in.
df.loc[df["spotify_track_id"].isin(train_track_ids), "train_split"] = "train"
df.loc[df["spotify_track_id"].isin(valid_track_ids), "train_split"] = "valid"
df.loc[df["spotify_track_id"].isin(test_track_ids),  "train_split"] = "test"
# Prefix the helper columns with `train_` in the exported CSV.
renamer = {'valence_bin':'train_valence_bin', 'energy_bin':'train_energy_bin', 'stratify_key':'train_stratify_key'}
df = df.rename(columns=renamer)

# Validate that there's no data leakage between splits and that every row was
# assigned a split. These checks assert rather than print, so a broken split
# fails the cell before the CSV below is written.
train_df = df[df['train_split'] == 'train']
valid_df = df[df['train_split'] == 'valid']
test_df = df[df['train_split'] == 'test']
train_ids = set(train_df['spotify_track_id'])
valid_ids = set(valid_df['spotify_track_id'])
test_ids = set(test_df['spotify_track_id'])

train_valid_overlap = train_ids.intersection(valid_ids)
assert not train_valid_overlap, (
    "Data leakage detected between train and validation sets! "
    f"Overlapping 'spotify_track_id's: {train_valid_overlap}"
)

train_test_overlap = train_ids.intersection(test_ids)
assert not train_test_overlap, (
    "Data leakage detected between train and test sets! "
    f"Overlapping 'spotify_track_id's: {train_test_overlap}"
)

valid_test_overlap = valid_ids.intersection(test_ids)
assert not valid_test_overlap, (
    "Data leakage detected between validation and test sets! "
    f"Overlapping 'spotify_track_id's: {valid_test_overlap}"
)

# Rows whose `train_split` was never set match none of the three filters above,
# so their IDs show up here as missing.
total_unique_ids = set(df['spotify_track_id'])
assigned_ids = train_ids.union(valid_ids).union(test_ids)
missing_ids = total_unique_ids - assigned_ids
assert not missing_ids, (
    "Some 'spotify_track_id's are not assigned to any split! "
    f"Missing 'spotify_track_id's: {missing_ids}"
)
print(df["train_split"].value_counts())

# Export with a clean 0..n-1 index; the index itself is not written.
df = df.reset_index(drop=True)
df.to_csv(PATH_SAVE_FINAL_CSV, index=False)

Tracks with a more than 5 artists:  121
stratify_key
forro_3_3          376
salsa_3_3          241
ambient_0_0        236
grunge_1_3         219
salsa_3_2          211
                  ... 
world-music_2_0      2
industrial_0_0       2
piano_2_0            2
breakbeat_3_2        2
show-tunes_3_3       2
Name: count, Length: 1176, dtype: int64
train_split
train    29053
valid     7257
test        32
Name: count, dtype: int64
