# CSE 144 Group 3
## Music Recommendation System (MRS)

In this notebook, we build the predictive model for our music recommendation system. Our work leverages modern tools, including recurrent neural networks (RNNs) and BERT sentence transformers.

<br>

Our work leverages this RNN model:

https://github.com/taylorhawks/RNN-music-recommender/blob/master/cloud/model.ipynb


In [1]:
# import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format="retina"
import numpy as np
import random
import torch
import os
# from torch import nn, optim
# import math
# from IPython import display
# import torchvision.datasets as datasets
# import torchvision.transforms as transforms
# from torch.utils.data import TensorDataset
# import torch.nn.functional as F
# from sklearn.preprocessing import MinMaxScaler
# import pdb
import plotly.graph_objects as go
import numpy as np

from skimage.util.shape import view_as_windows as viewW
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.decomposition import PCA

# import tensorflow as tf

# import keras.backend as K
from keras.models import Sequential, load_model
# from keras.optimizers import RMSprop
from keras.layers import Dense, SimpleRNN, Input
from keras.losses import *


### Load the data

In [2]:
# Read the two input tables from the local `misc/` directory (paths are relative
# to the notebook's working directory): per-song features and per-user listening history.
song_features_data = pd.read_csv('misc/processed_music_info.csv')
user_listening_data = pd.read_csv('misc/processed_user_listening_hist.csv')

# Disabled Google Colab alternative: mount Drive and read the CSVs from there instead.
# from google.colab import drive
# drive.mount('/content/drive')
# import pandas as pd
# song_features_data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/music_info.csv')
# user_listening_data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/user_listening_hist.csv')

### Set Random Seed

In [3]:
# Seed every global random number generator the notebook can draw from, not
# just torch: the test playlist is picked with `random.randint`, so without
# seeding `random` a Restart & Run All selects a different sample each time.
# NumPy's global generator is seeded as well for any library code that falls
# back to it.
SEED = 24
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2bc83a45830>

### Read and Display Data

In [4]:
# Report the size of the song table and its number of distinct tracks, then preview it.
song_row_count = len(song_features_data)
unique_song_count = len(song_features_data['track_id'].unique())
print(f'# of rows of Song Data: {song_row_count}')
print(f'# of unique songs: {unique_song_count}')
song_features_data.head()

# of rows of Song Data: 23584
# of unique songs: 23584


Unnamed: 0,track_id,name,artist,spotify_id,tags,year,duration_ms,danceability,energy,key,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,time_signature
0,TRIOREW128F424EAF0,Mr. Brightside,The Killers,09ZQ5TmUG8TSL56n0knqrj,"rock, alternative, indie, alternative_rock, in...",2004,222200,0.355,0.918,1,-4.36,1,0.0746,0.00119,0.0,0.0971,0.24,148.114,4
1,TRRIVDJ128F429B0E8,Wonderwall,Oasis,06UfBBDISthj1ZJAtX4xjj,"rock, alternative, indie, pop, alternative_roc...",2006,258613,0.409,0.892,2,-4.373,1,0.0336,0.000807,0.0,0.207,0.651,174.426,4
2,TRXOGZT128F424AD74,Karma Police,Radiohead,01puceOqImrzSfKDAcd1Ia,"rock, alternative, indie, alternative_rock, in...",1996,264066,0.36,0.505,7,-9.129,1,0.026,0.0626,9.2e-05,0.172,0.317,74.807,4
3,TRUJIIV12903CA8848,Clocks,Coldplay,0BCPKOYdS2jbQ8iyB56Zns,"rock, alternative, indie, pop, alternative_roc...",2002,307879,0.577,0.749,5,-7.215,0,0.0279,0.599,0.0115,0.183,0.255,130.97,4
4,TRIODZU128E078F3E2,Under the Bridge,Red Hot Chili Peppers,06zh28PcYIFvNOAz5Wq2Xb,"rock, alternative, alternative_rock, 90s, funk",2003,265506,0.554,0.49,4,-8.046,1,0.0457,0.0168,0.000534,0.136,0.513,84.275,4


In [5]:
# Report the size of the listening history and its number of distinct users, then preview it.
listening_row_count = len(user_listening_data)
unique_user_count = len(user_listening_data['user_id'].unique())
print(f'# of rows of User Listening Data: {listening_row_count}')
print(f'# of unique users: {unique_user_count}')
user_listening_data.head()

# of rows of User Listening Data: 806745
# of unique users: 25343


Unnamed: 0,track_id,user_id,playcount
0,TRLATHU128F92FC275,5a905f000fc1ff3df7ca807d57edb608863db05d,11
1,TRMKFPN128F42858C3,5a905f000fc1ff3df7ca807d57edb608863db05d,2
2,TRGAOLV128E0789D40,5a905f000fc1ff3df7ca807d57edb608863db05d,2
3,TREAQSX128E07818CA,5a905f000fc1ff3df7ca807d57edb608863db05d,2
4,TRUMDRI128F424FEFC,5a905f000fc1ff3df7ca807d57edb608863db05d,3


### Data Preprocessing


In [6]:
# The join of user_listening_data with song_features_data ON track_id is
# performed in a later cell; the original call is kept here for reference.
# data = pd.merge(song_features_data, user_listening_data, on='track_id')

# Remove the columns that are not needed
unneeded_columns = ['spotify_id', 'year', 'time_signature', 'key']
song_features_data = song_features_data.drop(unneeded_columns, axis=1)

In [7]:
# Artist preprocessing is disabled.
# NOTE(review): the original remark suggests the stringified artist value is
# what Sentence-BERT consumes — confirm before deleting these lines.

# data["artists"] = data["artists"].str.replace("[", "")
# data["artists"] = data["artists"].str.replace("]", "")
# data["artists"] = data["artists"].str.replace("'", "")
# data["artists"] = data["artists"].map(lambda row: row.split(', '))

# Replace the millisecond duration with a duration in minutes: `pop` removes
# `duration_ms` from the frame in place and hands back its values.
song_features_data["duration_mins"] = song_features_data.pop("duration_ms") / 60000

song_features_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23584 entries, 0 to 23583
Data columns (total 15 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   track_id          23584 non-null  object 
 1   name              23584 non-null  object 
 2   artist            23584 non-null  object 
 3   tags              23083 non-null  object 
 4   danceability      23584 non-null  float64
 5   energy            23584 non-null  float64
 6   loudness          23584 non-null  float64
 7   mode              23584 non-null  int64  
 8   speechiness       23584 non-null  float64
 9   acousticness      23584 non-null  float64
 10  instrumentalness  23584 non-null  float64
 11  liveness          23584 non-null  float64
 12  valence           23584 non-null  float64
 13  tempo             23584 non-null  float64
 14  duration_mins     23584 non-null  float64
dtypes: float64(10), int64(1), object(4)
memory usage: 2.7+ MB


In [8]:
# Attach the song features to every listening record. `pd.merge` defaults to an
# inner join, so only track_ids present in both tables are kept.
data = pd.merge(song_features_data, user_listening_data, on='track_id')
data.head()

Unnamed: 0,track_id,name,artist,tags,danceability,energy,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,duration_mins,user_id,playcount
0,TRIOREW128F424EAF0,Mr. Brightside,The Killers,"rock, alternative, indie, alternative_rock, in...",0.355,0.918,-4.36,1,0.0746,0.00119,0.0,0.0971,0.24,148.114,3.703333,fe31db6d197a667d265ff5a35d80d60f3660f729,2
1,TRRIVDJ128F429B0E8,Wonderwall,Oasis,"rock, alternative, indie, pop, alternative_roc...",0.409,0.892,-4.373,1,0.0336,0.000807,0.0,0.207,0.651,174.426,4.310217,67874d1a189c83326c529e554be6f7acf55effae,12
2,TRRIVDJ128F429B0E8,Wonderwall,Oasis,"rock, alternative, indie, pop, alternative_roc...",0.409,0.892,-4.373,1,0.0336,0.000807,0.0,0.207,0.651,174.426,4.310217,e3ee8846c9a5a0916700a9e7abfc1c5b2fcb8e36,5
3,TRRIVDJ128F429B0E8,Wonderwall,Oasis,"rock, alternative, indie, pop, alternative_roc...",0.409,0.892,-4.373,1,0.0336,0.000807,0.0,0.207,0.651,174.426,4.310217,cbb6b8dccf0af0d221dfd4684072c04bb0346f30,2
4,TRRIVDJ128F429B0E8,Wonderwall,Oasis,"rock, alternative, indie, pop, alternative_roc...",0.409,0.892,-4.373,1,0.0336,0.000807,0.0,0.207,0.651,174.426,4.310217,2cdf67cd70a64964cb914835af0043fcc28a8f48,12


### Obtain total number of listens per song

In [9]:
# Sum the play counts over all users for each song title.
# NOTE(review): the grouping key is `name`, so distinct tracks that share a
# title are pooled into a single row — confirm this is intended, or group by
# `track_id` instead.
play_counts = data.groupby('name')['playcount'].sum().reset_index()
play_counts

Unnamed: 0,name,playcount
0,#1 Zero,13
1,#16,110
2,#17,7
3,#24,5
4,$20 for Boban,43
...,...,...
23579,慟哭と去りぬ,134
23580,我、闇とて･･･,7
23581,朔-saku-,51
23582,蜷局,368


### Create playlists for input to RNN

In [10]:
# Order the merged rows by user so that each user's listens sit next to each other.
data = data.sort_values(['user_id'])
data

Unnamed: 0,track_id,name,artist,tags,danceability,energy,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,duration_mins,user_id,playcount
306346,TRTVXIH128F426625A,Come Round Soon,Sara Bareilles,"pop, female_vocalists, singer_songwriter, soul...",0.338,0.819,-4.495,0,0.0776,0.077700,0.000000,0.1590,0.545,74.751,3.552000,0000f88f8d76a238c251450913b0d070e4a77d19,2
417455,TRWUFEW128F14782F3,Forever My Friend,Ray LaMontagne,"folk, singer_songwriter, soul, blues, acoustic...",0.493,0.524,-13.553,1,0.0423,0.334000,0.014100,0.3570,0.379,176.233,5.788883,0000f88f8d76a238c251450913b0d070e4a77d19,2
32466,TRNXEPE128F9339E47,My Name Is Jonas,Weezer,"rock, alternative, indie, alternative_rock, in...",0.261,0.947,-3.031,1,0.0488,0.000197,0.003320,0.3100,0.550,185.942,3.435333,0000f88f8d76a238c251450913b0d070e4a77d19,2
698954,TRMKCCV128F92EB22E,Light On,David Cook,"rock, alternative_rock, male_vocalists",0.448,0.830,-4.156,0,0.0332,0.067300,0.000000,0.1130,0.362,131.991,3.816883,0000f88f8d76a238c251450913b0d070e4a77d19,3
227171,TRJGJTH128F4291A81,"Oh My God, Whatever, Etc.",Ryan Adams,"rock, indie, folk, singer_songwriter, acoustic...",0.572,0.395,-10.630,1,0.0304,0.700000,0.000250,0.1260,0.483,79.552,2.532667,0000f88f8d76a238c251450913b0d070e4a77d19,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
802077,TRSEFCM128F429354D,Set It Up,Xavier Rudd,"acoustic, 00s",0.469,0.385,-11.300,0,0.0270,0.503000,0.001390,0.1150,0.116,130.767,4.141550,fffbab4b8416fc41d05fcbdcf0e6735c4f37cb39,2
417463,TRWUFEW128F14782F3,Forever My Friend,Ray LaMontagne,"folk, singer_songwriter, soul, blues, acoustic...",0.493,0.524,-13.553,1,0.0423,0.334000,0.014100,0.3570,0.379,176.233,5.788883,fffbab4b8416fc41d05fcbdcf0e6735c4f37cb39,8
649627,TRDVGIH128F429353C,Come Let Go,Xavier Rudd,"reggae, male_vocalists",0.547,0.546,-8.634,1,0.0470,0.114000,0.037200,0.3810,0.280,140.477,6.870433,fffbab4b8416fc41d05fcbdcf0e6735c4f37cb39,28
553208,TROXFVJ128F1465265,Bottom Of the Barrel,Amos Lee,"folk, soul, acoustic, guitar",0.609,0.346,-12.703,1,0.1460,0.761000,0.000000,0.1100,0.550,178.137,2.006433,fffbab4b8416fc41d05fcbdcf0e6735c4f37cb39,4


In [11]:
# One playlist per user: the track_ids of that user's first 20 rows
# (playlists are keyed on track_id rather than on song name).
tracks_by_user = data.groupby('user_id')['track_id']
playlists = tracks_by_user.apply(lambda track_ids: track_ids.head(20).tolist())
playlist_dict = playlists.to_dict()
print(playlists)

user_id
0000f88f8d76a238c251450913b0d070e4a77d19    [TRTVXIH128F426625A, TRWUFEW128F14782F3, TRNXE...
0005eb11fd1dad47e6e6719a4db30340073a9e38    [TRGOJNK128F92F2A03, TRQPSHM128F92F29ED, TRTUW...
000d80cd9b58a8f77b33aa613dcfc5cbf1daf5e8    [TRDYYKS128F4275626, TRBHLYP12903D0D107, TRABF...
000e9296161b73a1821aaed3d7f50d95e8665bf6    [TROPEIV128F428F5A8, TRIAZQY128F934D58D, TRMKA...
00100482b3f3074549c751e718c57ed211b35991    [TRSNCIW128F14557BC, TRJKPFL12903CCE490, TRWJN...
                                                                  ...                        
fff7352d8ca192c451ce4fa00d18e33e261ecad3    [TRDRVJA128F4267831, TRCKWGF12903CD2DCD, TRXUW...
fff759a45a3a68de552740e8285a97d5f65d4e58    [TRDJZFF128F92D2627, TRULONW128F9302209, TRBNY...
fff9bd021bf6e07936883b9bb045207fcf372a2c    [TROHXCJ128F935A6AC, TRUMJNK12903CF465A, TRXYM...
fffb0b218640d86e5cb99d41cd3ecad977142da5    [TRZGGHL12903CDBF1F, TRCAUIX128F4277AD0, TRYIK...
fffbab4b8416fc41d05fcbdcf0e6735c4f37cb39    [TRGPCUN

In [12]:
# Lookup table {(user_id, track_id): {feature_name: value}} holding only the
# numeric columns; the descriptive columns and the play count are removed first.
data_dict = (
    data
    .drop(columns=['name', 'artist', 'tags', 'playcount'])
    .set_index(['user_id', 'track_id'])
    .to_dict('index')
)

In [13]:
# Replace every track_id in every playlist with [track_id, feature values...],
# printing a progress line each time another 10000 songs have been handled.
songs_done = 0
updated_playlist_dict = {}
for user_id, songs in playlist_dict.items():
    updated_songs = []
    for song in songs:
        feature_row = data_dict.get((user_id, song))
        if feature_row is None:
            # No features recorded for this (user, track) pair: leave the song out.
            continue
        updated_songs.append([song, *feature_row.values()])
        songs_done += 1
        if songs_done % 10000 == 0:
            print(songs_done)
    updated_playlist_dict[user_id] = updated_songs

playlist_dict = updated_playlist_dict

print(f"Total songs processed: {songs_done}")

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
Total songs processed: 506860


In [14]:
# For every song keep entries 0-3 and 5-9 of its [track_id, features...] list
# (entry 4 and everything from entry 10 on are left out), one inner list per playlist.
arr = [
    [np.concatenate((song[0:4], song[5:10])) for song in playlist]
    for playlist in playlist_dict.values()
]

arr_np = np.array(arr)
print(arr_np)

[[['TRTVXIH128F426625A' '0.338' '0.819' ... '0.0' '0.159' '0.545']
  ['TRWUFEW128F14782F3' '0.493' '0.524' ... '0.0141' '0.357' '0.379']
  ['TRNXEPE128F9339E47' '0.261' '0.947' ... '0.00332' '0.31' '0.55']
  ...
  ['TRTWOCA128F14840B8' '0.634' '0.341' ... '0.14' '0.0861' '0.599']
  ['TRQSEMJ128F4294F24' '0.419' '0.537' ... '1.83e-06' '0.482' '0.737']
  ['TRUNKTP12903CD1EFB' '0.514' '0.507' ... '0.107' '0.185' '0.854']]

 [['TRGOJNK128F92F2A03' '0.645' '0.618' ... '0.0173' '0.149' '0.112']
  ['TRQPSHM128F92F29ED' '0.665' '0.517' ... '0.0' '0.0925' '0.609']
  ['TRTUWMO128F92F2A09' '0.638' '0.691' ... '3.31e-05' '0.161' '0.424']
  ...
  ['TRMIHFS128F92F2A01' '0.441' '0.819' ... '0.0176' '0.28' '0.66']
  ['TRRLGDR128F933A7C9' '0.42' '0.995' ... '0.000429' '0.0931' '0.469']
  ['TRLNFKN128F931BAF2' '0.6' '0.772' ... '0.0141' '0.0489' '0.424']]

 [['TRDYYKS128F4275626' '0.426' '0.995' ... '0.735' '0.56' '0.697']
  ['TRBHLYP12903D0D107' '0.17' '0.988' ... '0.106' '0.145' '0.0832']
  ['TRABFDT1

In [15]:
# Tabular view of the playlists: one row per user_id and one column per playlist
# position, each cell holding that song's [track_id, feature...] list.
# NOTE: this rebinds `playlists`, which previously held the per-user Series of track_ids.
playlists = pd.DataFrame.from_dict(playlist_dict, orient='index')
playlists.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0000f88f8d76a238c251450913b0d070e4a77d19,"[TRTVXIH128F426625A, 0.338, 0.819, -4.495, 0, ...","[TRWUFEW128F14782F3, 0.493, 0.524, -13.553, 1,...","[TRNXEPE128F9339E47, 0.261, 0.947, -3.031, 1, ...","[TRMKCCV128F92EB22E, 0.448, 0.83, -4.156, 0, 0...","[TRJGJTH128F4291A81, 0.572, 0.395, -10.63, 1, ...","[TRFQYFT128F14840BC, 0.524, 0.405, -10.679, 1,...","[TRKSXHR128F1455E4D, 0.341, 0.301, -11.074, 1,...","[TRGJTIY128F4296A0E, 0.372, 0.532, -7.968, 1, ...","[TRZLJOC128F14840BE, 0.513, 0.723, -7.211, 1, ...","[TRSMEUG128F14856D2, 0.487, 0.435, -9.753, 1, ...","[TRZYESA128F148D67F, 0.652, 0.227, -12.395, 1,...","[TRRUNEV128F148D719, 0.526, 0.772, -5.016, 1, ...","[TRDKRLP128F4291A80, 0.275, 0.784, -4.896, 1, ...","[TRFTUIW128E0784B9F, 0.597, 0.532, -10.583, 1,...","[TRKGCIA128F92C315D, 0.542, 0.911, -4.789, 0, ...","[TRFUCYR128F92DC67F, 0.544, 0.869, -4.079, 1, ...","[TROZZNY128F14782F7, 0.241, 0.0721, -22.115, 1...","[TRTWOCA128F14840B8, 0.634, 0.341, -14.214, 1,...","[TRQSEMJ128F4294F24, 0.419, 0.537, -7.301, 1, ...","[TRUNKTP12903CD1EFB, 0.514, 0.507, -12.25, 1, ..."
0005eb11fd1dad47e6e6719a4db30340073a9e38,"[TRGOJNK128F92F2A03, 0.645, 0.618, -8.594, 1, ...","[TRQPSHM128F92F29ED, 0.665, 0.517, -9.264, 1, ...","[TRTUWMO128F92F2A09, 0.638, 0.691, -7.948, 1, ...","[TRRNWAK128F92F29FB, 0.633, 0.686, -8.015, 1, ...","[TRYEGSH12903CD2DCE, 0.48, 0.79, -5.214, 1, 0....","[TRCKWGF12903CD2DCD, 0.729, 0.94, -4.933, 1, 0...","[TRNFVQI128F931BAEA, 0.641, 0.686, -6.727, 1, ...","[TRTKLFX12903CD2DC2, 0.646, 0.52, -7.289, 1, 0...","[TRPGPDK12903CCC651, 0.429, 0.563, -8.006, 0, ...","[TRJDMHS128F92F2A0C, 0.67, 0.583, -11.843, 1, ...","[TRLVQME128F931BAF3, 0.582, 0.523, -8.166, 1, ...","[TRUGOGT128F92F29E9, 0.673, 0.495, -8.299, 1, ...","[TRCXWLU128F92F2A0D, 0.536, 0.377, -13.255, 1,...","[TRRVJCK12903CD2DCB, 0.705, 0.605, -5.252, 1, ...","[TRCJAHJ128E07815B6, 0.776, 0.729, -5.427, 1, ...","[TRPWIGO128F931BAEB, 0.602, 0.69, -5.45, 1, 0....","[TRNEITZ128F92F29EA, 0.522, 0.804, -7.457, 1, ...","[TRMIHFS128F92F2A01, 0.441, 0.819, -8.414, 1, ...","[TRRLGDR128F933A7C9, 0.42, 0.995, -2.984, 0, 0...","[TRLNFKN128F931BAF2, 0.6, 0.772, -5.068, 1, 0...."
000d80cd9b58a8f77b33aa613dcfc5cbf1daf5e8,"[TRDYYKS128F4275626, 0.426, 0.995, -5.546, 0, ...","[TRBHLYP12903D0D107, 0.17, 0.988, -4.087, 1, 0...","[TRABFDT12903CADD73, 0.661, 0.886, -6.248, 0, ...","[TRLNVSC12903CADD67, 0.804, 0.739, -4.699, 1, ...","[TRKOCXI128F9316B54, 0.379, 0.532, -21.419, 1,...","[TRSEFCM128F429354D, 0.469, 0.385, -11.3, 0, 0...","[TRUWANM128F1485EE2, 0.673, 0.607, -7.672, 1, ...","[TRXKEMH128F423381D, 0.827, 0.729, -7.473, 1, ...","[TREMDON128F427C701, 0.751, 0.675, -8.159, 0, ...","[TRHPKWO128F92E01D5, 0.251, 0.642, -7.185, 1, ...","[TRPONOG128F4275608, 0.621, 0.825, -10.436, 1,...","[TRJGDTG128F421CE22, 0.606, 0.935, -4.389, 0, ...","[TROTYPC128E07940AB, 0.741, 0.37, -13.604, 0, ...","[TRPXIWX128F429831F, 0.632, 0.653, -7.384, 1, ...","[TROINZB128F932F740, 0.64, 0.767, -5.829, 1, 0...","[TROUAEG128F429354A, 0.473, 0.595, -10.2, 1, 0...","[TRQEBRP12903CADD6C, 0.564, 0.715, -7.654, 0, ...","[TROTWMO128F42B9238, 0.278, 0.163, -20.454, 0,...","[TRJYECB128F4230F29, 0.376, 0.325, -9.936, 1, ...","[TRJLGXB128F93043EA, 0.717, 0.559, -8.317, 1, ..."
000e9296161b73a1821aaed3d7f50d95e8665bf6,"[TROPEIV128F428F5A8, 0.281, 0.377, -9.827, 0, ...","[TRIAZQY128F934D58D, 0.788, 0.896, -6.863, 1, ...","[TRMKAZB128F92F2F3E, 0.566, 0.598, -7.202, 1, ...","[TRPHDFT128F92C5A75, 0.715, 0.534, -10.054, 1,...","[TRNXBBR128F425ECE3, 0.584, 0.242, -11.772, 0,...","[TRKPWGR128E078EE06, 0.551, 0.673, -7.362, 1, ...","[TRLPOFY128F425ECE8, 0.516, 0.241, -13.284, 1,...","[TRCHYZB128F425ECE1, 0.355, 0.35, -7.739, 0, 0...","[TRXEAZB128E078EDCE, 0.78, 0.884, -6.445, 1, 0...","[TRFVSOZ128F4281933, 0.389, 0.608, -7.531, 1, ...","[TRDMUWU128E078EDDB, 0.682, 0.405, -11.86, 1, ...","[TRDRFVY128F4281937, 0.46, 0.731, -6.866, 1, 0...","[TRIPLBA128F427200F, 0.692, 0.529, -8.459, 0, ...","[TRJSAID128F934D596, 0.518, 0.318, -11.312, 0,...","[TRMYAYJ128F934D0AF, 0.6, 0.432, -9.674, 0, 0....","[TRWGIOT128F425ECDE, 0.419, 0.248, -14.437, 1,...","[TRLRCIA128F425ECD7, 0.564, 0.499, -10.427, 0,...","[TRIAGDA128F4296176, 0.674, 0.723, -7.374, 1, ...","[TRIDPWO128F423DBC6, 0.338, 0.289, -13.479, 1,...","[TRPFLRB128F14A895D, 0.333, 0.75, -6.942, 1, 0..."
00100482b3f3074549c751e718c57ed211b35991,"[TRSNCIW128F14557BC, 0.306, 0.0694, -18.717, 0...","[TRJKPFL12903CCE490, 0.85, 0.369, -10.282, 1, ...","[TRWJNEC128E079654F, 0.616, 0.525, -7.411, 1, ...","[TRACWHF128F14557BB, 0.299, 0.119, -25.491, 1,...","[TRAZCMI128F14557B9, 0.62, 0.357, -11.9, 1, 0....","[TRUEXGL128F14557BD, 0.712, 0.175, -18.975, 0,...","[TREECSZ128F14557BE, 0.616, 0.292, -14.923, 1,...","[TRUAJOJ128F14557B6, 0.687, 0.678, -6.32, 1, 0...","[TRASVEM128E0796553, 0.619, 0.622, -9.251, 0, ...","[TROXRVT128E079650A, 0.399, 0.699, -6.981, 0, ...","[TRZJHGG128E079655A, 0.609, 0.472, -11.42, 1, ...","[TRIXKKQ12903CCE495, 0.631, 0.698, -4.616, 1, ...","[TRORPWW12903CCE48E, 0.755, 0.725, -4.79, 0, 0...","[TRYIASQ128E079650E, 0.576, 0.352, -10.773, 0,...","[TRDNHAW128F429DB9A, 0.463, 0.901, -2.885, 1, ...","[TRXYEKR128E079654C, 0.69, 0.529, -7.555, 0, 0...","[TRHZMPR128F42A52CB, 0.595, 0.394, -9.193, 1, ...","[TRXZMLY128E0796512, 0.608, 0.633, -8.046, 0, ...","[TRJSQQT128F149F9B4, 0.691, 0.923, -5.204, 1, ...","[TRXCZNS128F428A15E, 0.409, 0.928, -11.532, 1,..."


### Train and Test Split

In [16]:
# Next-song setup: X holds every playlist without its last song and Y the same
# playlist without its first song, so Y[:, t] is the song that follows X[:, t].
# Column 0 of the last axis still carries the track_id at this point.
X = arr_np[:,:-1,:]
Y = arr_np[:,1:,:]
# Two-stage split: 10% of the playlists are held out as validation, then 10% of
# the remaining 90% as test (roughly 81% train / 10% val / 9% test overall).
x_train, x_val, y_train, y_val = train_test_split(X,Y,train_size=0.9,random_state=3000)
x_train, x_test, y_train, y_test = train_test_split(x_train,y_train,train_size=0.9,random_state=3000)


In [17]:
# Optional shape checks for the six splits (disabled):
# print(x_train.shape)
# print(y_train.shape)
# print(x_val.shape)
# print(y_val.shape)
# print(x_test.shape)
# print(y_test.shape)

# Show the first training playlist; entries are still strings with the track_id in column 0.
print(x_train[0, :, :])

[['TRMXZMN128F425980B' '0.683' '0.353' '-9.547' '0.0523' '0.953' '0.013'
  '0.0937' '0.575']
 ['TRXXVOG128F92F411D' '0.578' '0.752' '-5.264' '0.033' '0.0275' '0.0'
  '0.125' '0.405']
 ['TRXIQDL128F92F27DC' '0.626' '0.317' '-13.692' '0.0414' '0.931' '0.363'
  '0.0937' '0.568']
 ['TRIVUMW128F425980E' '0.514' '0.141' '-14.38' '0.0375' '0.917'
  '0.00255' '0.132' '0.127']
 ['TRCWYJA128F92F27D6' '0.487' '0.235' '-13.489' '0.0406' '0.929' '0.01'
  '0.157' '0.239']
 ['TRMFXAY128F92F27DF' '0.38' '0.155' '-16.735' '0.0458' '0.963' '0.0394'
  '0.693' '0.0393']
 ['TRIQKEJ128F9307761' '0.21' '0.197' '-11.895' '0.0336' '0.943' '0.0'
  '0.0995' '0.496']
 ['TRVJGDX128F42645CA' '0.581' '0.468' '-15.763' '0.0314' '0.0251'
  '0.0845' '0.0831' '0.49']
 ['TRXJWFS128F92F27DB' '0.188' '0.221' '-12.557' '0.0339' '0.928'
  '0.00227' '0.106' '0.0572']
 ['TRHFGRN128F427EFC3' '0.331' '0.559' '-8.263' '0.0316' '0.224' '0.0'
  '0.297' '0.236']
 ['TRHRFIN128F425DEF1' '0.381' '0.939' '-4.684' '0.125' '0.042' '0.918'

In [18]:
# Original Playlists
def _split_ids_and_features(split):
    """Separate one data split into its track ids and its numeric features.

    Parameters
    ----------
    split : np.ndarray
        Array indexed as [playlist, song, column] whose column 0 is the
        track_id and whose remaining columns are feature values stored as text.

    Returns
    -------
    tuple[list[list[str]], np.ndarray]
        The track ids of every playlist (one inner list per playlist) and the
        remaining columns converted to float64.

    Raises
    ------
    ValueError
        If ``split`` is already a float array, i.e. this cell has been run
        before without re-running the train/test split. Stripping column 0
        again would silently discard a real feature.
    """
    if split.dtype.kind == 'f':
        raise ValueError(
            "split is already numeric; re-run the train/test split cell before this one"
        )
    track_ids = split[:, :, 0].tolist()
    features = split[:, :, 1:].astype(np.float64)
    return track_ids, features


# Each split is processed independently, so the result no longer depends on the
# relative sizes of the train, validation and test sets.
ops_x_train, x_train = _split_ids_and_features(x_train)
ops_y_train, y_train = _split_ids_and_features(y_train)
ops_x_val, x_val = _split_ids_and_features(x_val)
ops_y_val, y_val = _split_ids_and_features(y_val)
ops_x_test, x_test = _split_ids_and_features(x_test)
ops_y_test, y_test = _split_ids_and_features(y_test)

In [19]:
# Same training playlist after the conversion: numeric features only, followed
# by the matching list of track_ids.
print(x_train[0, :, :])
print(ops_x_train[0])

[[ 6.8300e-01  3.5300e-01 -9.5470e+00  5.2300e-02  9.5300e-01  1.3000e-02
   9.3700e-02  5.7500e-01]
 [ 5.7800e-01  7.5200e-01 -5.2640e+00  3.3000e-02  2.7500e-02  0.0000e+00
   1.2500e-01  4.0500e-01]
 [ 6.2600e-01  3.1700e-01 -1.3692e+01  4.1400e-02  9.3100e-01  3.6300e-01
   9.3700e-02  5.6800e-01]
 [ 5.1400e-01  1.4100e-01 -1.4380e+01  3.7500e-02  9.1700e-01  2.5500e-03
   1.3200e-01  1.2700e-01]
 [ 4.8700e-01  2.3500e-01 -1.3489e+01  4.0600e-02  9.2900e-01  1.0000e-02
   1.5700e-01  2.3900e-01]
 [ 3.8000e-01  1.5500e-01 -1.6735e+01  4.5800e-02  9.6300e-01  3.9400e-02
   6.9300e-01  3.9300e-02]
 [ 2.1000e-01  1.9700e-01 -1.1895e+01  3.3600e-02  9.4300e-01  0.0000e+00
   9.9500e-02  4.9600e-01]
 [ 5.8100e-01  4.6800e-01 -1.5763e+01  3.1400e-02  2.5100e-02  8.4500e-02
   8.3100e-02  4.9000e-01]
 [ 1.8800e-01  2.2100e-01 -1.2557e+01  3.3900e-02  9.2800e-01  2.2700e-03
   1.0600e-01  5.7200e-02]
 [ 3.3100e-01  5.5900e-01 -8.2630e+00  3.1600e-02  2.2400e-01  0.0000e+00
   2.9700e-01  2.

### Define the Model

In [20]:
# Reuse the saved model when it exists; otherwise build, train and save it.
MODEL_PATH = 'misc/mae_optimized_model.keras'

if os.path.exists(MODEL_PATH):
    model = load_model(MODEL_PATH)
else:
    # Two stacked SimpleRNN layers that emit one output per time step, followed
    # by a Dense layer mapping each step back to the 8 song features.
    model = Sequential()
    model.add(Input(shape=(None,8)))
    model.add(SimpleRNN(
        16,
        activation='linear',
        return_sequences=True,
        kernel_initializer='random_uniform',
    ))
    model.add(SimpleRNN(
        16,
        activation='linear',
        return_sequences=True,
        kernel_initializer='random_uniform',
    ))
    model.add(Dense(8, activation='linear', kernel_initializer='random_uniform',))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Only query the GPU name when CUDA is present: get_device_name(0) raises on
    # CPU-only machines and would abort the cell before any training happens.
    if device.type == "cuda":
        print(torch.cuda.get_device_name(0))
    else:
        print("CUDA not available")

    model.compile(loss='mae', optimizer='adam')
    model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val))
    model.save(MODEL_PATH)

In [21]:

# Alias for the model loaded or trained above (MAE loss, Adam optimizer).
mae_optimized_model_adam = model

In [22]:
def predict_sample(sample,model):
    """Run ``model`` on a single sequence and return its last-step prediction.

    ``sample`` is wrapped into a batch of one, passed to ``model.predict``, and
    the entry at [0, -1] of the result (first batch item, final position) is returned.
    """
    single_item_batch = np.array([sample])
    batch_predictions = model.predict(single_item_batch)
    return batch_predictions[0, -1]

### Run RNN

In [44]:
# Pick one test playlist at random and ask the model for its next-song feature vector.
print('Selecting a random index in our test dataset: ')
random_index = random.randint(0,len(x_test)-1)  # randint includes both end points
print(random_index)

print('Input: ')
print(x_test[random_index])

print('\n','Output: ')
# `predicted` is the model output for the last position of the chosen playlist.
predicted = predict_sample(x_test[random_index], mae_optimized_model_adam)
print(predicted)

Selecting a random index in our test dataset: 
1721
Input: 
[[ 5.6000e-01  8.7100e-01 -9.0580e+00  1.1200e-01  4.8900e-03  7.3800e-04
   7.7200e-02  5.5300e-01]
 [ 5.2700e-01  7.8900e-01 -8.0490e+00  6.9400e-02  1.4000e-02  1.3300e-02
   3.8100e-01  4.5200e-01]
 [ 1.0300e-01  1.4000e-01 -1.7992e+01  3.4200e-02  8.7000e-01  6.8900e-01
   1.7800e-01  9.5500e-02]
 [ 3.7500e-01  4.0300e-01 -1.2776e+01  2.8600e-02  7.4500e-01  5.4600e-01
   1.8100e-01  2.0700e-01]
 [ 2.8500e-01  9.1500e-01 -5.1580e+00  7.4100e-02  2.6800e-03  2.7900e-02
   1.3700e-01  3.0900e-01]
 [ 1.9300e-01  9.6800e-01 -3.2430e+00  6.4300e-02  1.1100e-03  7.6200e-02
   1.1100e-01  3.1900e-01]
 [ 2.2900e-01  9.3100e-01 -3.4880e+00  8.1700e-02  3.3200e-04  4.0800e-06
   4.6000e-01  4.8800e-01]
 [ 3.6300e-01  9.2000e-01 -3.5220e+00  1.1700e-01  1.2400e-03  0.0000e+00
   8.1300e-02  3.3100e-01]
 [ 2.6700e-01  8.8700e-01 -5.9930e+00  8.1500e-02  1.1100e-05  7.9300e-05
   1.4500e-01  4.4000e-01]
 [ 3.1600e-01  9.1800e-01 -7.33

In [45]:
# np.save('song_embbeding', predicted)

In [46]:
# Keep track_id, name and the remaining feature columns; drop everything else.
columns_to_drop = ['artist', 'tags', 'tempo', 'duration_mins', 'user_id', 'playcount', 'mode']
distance_frame = data.drop(columns=columns_to_drop)
distance_frame.head()

Unnamed: 0,track_id,name,danceability,energy,loudness,speechiness,acousticness,instrumentalness,liveness,valence
306346,TRTVXIH128F426625A,Come Round Soon,0.338,0.819,-4.495,0.0776,0.0777,0.0,0.159,0.545
417455,TRWUFEW128F14782F3,Forever My Friend,0.493,0.524,-13.553,0.0423,0.334,0.0141,0.357,0.379
32466,TRNXEPE128F9339E47,My Name Is Jonas,0.261,0.947,-3.031,0.0488,0.000197,0.00332,0.31,0.55
698954,TRMKCCV128F92EB22E,Light On,0.448,0.83,-4.156,0.0332,0.0673,0.0,0.113,0.362
227171,TRJGJTH128F4291A81,"Oh My God, Whatever, Etc.",0.572,0.395,-10.63,0.0304,0.7,0.00025,0.126,0.483


In [47]:
# `data` has one row per (user, track), so keep only the first row of each track.
distance_frame = distance_frame.drop_duplicates(subset='track_id', keep='first')
distance_frame['track_id'].nunique()

23584

In [48]:
def _manhattan_distances(feature_rows, p_vector):
    """Return the L1 distance (sum of absolute differences) from ``p_vector`` to each row.

    Only the first ``len(p_vector)`` entries of every row are compared.

    Raises
    ------
    IndexError
        If the rows have fewer entries than ``p_vector``.
    """
    target = np.asarray(p_vector, dtype=np.float64)
    if len(feature_rows) == 0:
        return np.zeros(0, dtype=np.float64)
    rows = np.asarray(feature_rows, dtype=np.float64)
    if rows.shape[1] < target.shape[0]:
        raise IndexError("feature rows have fewer entries than the target vector")
    return np.abs(rows[:, :target.shape[0]] - target).sum(axis=1)


def get_distances(data, p_vector):
    """Compute the distance between ``p_vector`` and every song in ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        One row per song with the columns ``track_id`` and ``name``; every
        other column is treated as a numeric feature, in column order.
    p_vector : sequence of float
        Feature vector to compare against, in the same order as the feature columns.

    Returns
    -------
    pd.DataFrame
        Indexed by song name, with the columns ``id`` (the track_id) and
        ``distance``. There is one row per input row: songs that share a name
        stay separate rows instead of overwriting each other.
    """
    feature_matrix = data.drop(columns=['name', 'track_id']).to_numpy(dtype=np.float64)
    distances = _manhattan_distances(feature_matrix, p_vector)
    return pd.DataFrame(
        {'id': data['track_id'].to_numpy(), 'distance': distances},
        index=data['name'].to_numpy(),
    )

def distance_calc(dict, v1, name_list):
    """Return ``{name: (track_id, distance)}`` for every entry of ``dict``.

    Parameters
    ----------
    dict : mapping
        track_id -> list of feature values. The parameter name shadows the
        builtin and is kept only so existing keyword callers keep working.
    v1 : sequence of float
        Feature vector to compare against.
    name_list : sequence of str
        Song names in the same order as the entries of ``dict``.

    Notes
    -----
    The result is keyed by song name, so songs sharing a name collapse into one
    entry (the last one wins). ``get_distances`` no longer goes through this
    function for that reason.

    Raises
    ------
    IndexError
        If ``name_list`` has fewer entries than ``dict``.
    """
    track_ids = list(dict.keys())
    names = list(name_list)
    values = _manhattan_distances([dict[track_id] for track_id in track_ids], v1)
    return {
        names[position]: (track_id, values[position])
        for position, track_id in enumerate(track_ids)
    }

# Distance from the RNN's predicted feature vector to every distinct track.
distance_frame2 = get_distances(distance_frame, predicted)


In [49]:
POTENTIAL_N = 50 # number of closest songs kept as recommendation candidates

# keep='all' retains every row tied with the last kept distance, so the result
# can hold more than POTENTIAL_N rows.
potential_songs = distance_frame2.nsmallest(POTENTIAL_N, columns='distance', keep='all')
print(potential_songs.shape)
potential_songs.head(20)

(50, 2)


Unnamed: 0,id,distance
Crab,TRVSLCD128E079268B,0.158175
Ghost In The Mirror,TRGBAFG12903CC5E15,0.181814
Dominhate,TRDSMFO128F426A81D,0.196451
Send In The Clowns,TRUHOOQ128E079251E,0.201744
Rational Eyes,TRCLZRQ128F422A819,0.20316
Faded Beauty Queens,TRTDSOD128E078115C,0.215894
This Could Be My Moment,TRMBNUB128F426848E,0.216018
Believe In Nothing,TRXOUBQ12903CFEA17,0.236443
Godspeed,TRIOPJP128F14A7703,0.25732
Beast Of Honor,TRTWKHF128F4252945,0.260519


In [50]:
# track_ids of the test playlist that was fed to the model.
ops_x_test[random_index]

['TRXOXHI128F426A36D',
 'TRSDDHV128F426A370',
 'TRRCJEI128F92C23B1',
 'TRKGGMK128F42286FE',
 'TRADCIF128F9338278',
 'TRTBASA128F92D262F',
 'TRJDDAZ128F92D262E',
 'TRKKBJS128F92EF7D2',
 'TRQTXHB128F92E3855',
 'TRGVSMR128F42B58E7',
 'TRXFKAF128E078884E',
 'TRHCIJJ128F9305C8A',
 'TRPZIII128F92D5420',
 'TRULNMQ128F92E1FDA',
 'TRQQORS128F930B14D',
 'TRDJZFF128F92D2627',
 'TRAWRKT128E0788857',
 'TRGUXOJ128E0788852',
 'TRIHPDV128F932B5DC']

In [51]:
# Read the lyric-embedding tables from `misc/`: the full embeddings and a 3D version.
lyrics_embeddings_csv = pd.read_csv('misc/lyrics_embeddings.csv')
lyrics_embeddings_3d_csv = pd.read_csv('misc/lyrics_embeddings_3d.csv')

In [52]:
def _embeddings_by_track(embedding_frame):
    """Map the first column of ``embedding_frame`` to the rest of its row.

    The first column is used as the key and the remaining columns, taken
    together, as one NumPy vector per row. Columns are addressed with ``iloc``
    because label-less positional access such as ``row[0]`` is deprecated in pandas.
    If a key occurs more than once, the last row wins.
    """
    keys = embedding_frame.iloc[:, 0]
    vectors = embedding_frame.iloc[:, 1:].to_numpy()
    return dict(zip(keys, vectors))


lyrics_embeddings = _embeddings_by_track(lyrics_embeddings_csv)
lyrics_embeddings_3d = _embeddings_by_track(lyrics_embeddings_3d_csv)


In [53]:
# Start with the 3D embeddings of the input playlist's tracks ...
candidates = {
    track_id: lyrics_embeddings_3d[track_id]
    for track_id in ops_x_test[random_index]
}

# ... remember how many entries belong to the input playlist ...
cutoff = len(candidates)

# ... then add the 3D embeddings of every potential recommendation.
candidates.update(
    (candidate_id, lyrics_embeddings_3d[candidate_id])
    for candidate_id in potential_songs['id']
)

len(candidates)

69

In [54]:
# Reduce the 768-dimensional lyric embeddings to 150 dimensions with PCA.
track_ids = list(lyrics_embeddings)
flat_embeddings = np.concatenate([lyrics_embeddings[track_id] for track_id in track_ids])
raw_embeddings = flat_embeddings.reshape(len(track_ids), 768)

dim_model = PCA(n_components=150, random_state=42)
reduced_embeddings = dim_model.fit(raw_embeddings).transform(raw_embeddings)
reduced_embeddings_dict = dict(zip(track_ids, reduced_embeddings))

# Reduced embeddings of the input playlist, in playlist order.
og_embeddings = np.array([reduced_embeddings_dict[track_id] for track_id in ops_x_test[random_index]])

At this stage, we compare the embeddings of the predicted candidate songs against those of the original input playlist and keep the best-matching candidates.
### Cosine Similarity

In [55]:
# Score each candidate by its mean cosine similarity to the input-playlist
# embeddings, then keep the ten highest-scoring candidates.
similarities = np.array([
    np.mean(cosine_similarity(reduced_embeddings_dict[track_id].reshape(1, -1), og_embeddings))
    for track_id in potential_songs['id']
])

most_similar_indices = np.argsort(similarities)[::-1]
top_ten_positions = most_similar_indices[:10]
selected_songs_cs = potential_songs.iloc[top_ten_positions]
selected_songs_cs

Unnamed: 0,id,distance
City Noise,TRVWDBV12903CEACE0,0.264726
Beast Of Honor,TRTWKHF128F4252945,0.260519
Explodiert,TRUNCHE128F425BF5D,0.320959
Beyond Within,TRGTXCP12903CF05DF,0.294816
Made of Glass,TRHTLBU128F429327B,0.337567
Loose Nuts On The Veladrome,TRHFZEU128F9329BF9,0.303103
Over-rated,TRWBXCW128F4266098,0.310143
Eiszeit,TRTSQHD12903CE87CA,0.319483
Closet Monster,TRWBPDM128F4262371,0.328078
Rational Eyes,TRCLZRQ128F422A819,0.20316


### Pairwise Distances

In [56]:
# Reduced embeddings of the candidates selected by cosine similarity.
candidate_embeddings = np.array([reduced_embeddings_dict[track_id] for track_id in selected_songs_cs['id']])

# Mean Euclidean distance from each candidate to the input-playlist embeddings.
distances = pairwise_distances(candidate_embeddings, og_embeddings, metric='euclidean')
mean_distances = np.mean(distances, axis=1)
# NOTE(review): `selected_songs_cs` was already cut to at most 10 rows, so `[:10]`
# keeps every candidate; this step only re-orders them from closest to farthest.
closest_candidates_indices = np.argsort(mean_distances)[:10]
selected_songs_pd = selected_songs_cs.iloc[closest_candidates_indices]
selected_songs_pd

Unnamed: 0,id,distance
Loose Nuts On The Veladrome,TRHFZEU128F9329BF9,0.303103
Closet Monster,TRWBPDM128F4262371,0.328078
Rational Eyes,TRCLZRQ128F422A819,0.20316
Over-rated,TRWBXCW128F4266098,0.310143
Made of Glass,TRHTLBU128F429327B,0.337567
Eiszeit,TRTSQHD12903CE87CA,0.319483
Beyond Within,TRGTXCP12903CF05DF,0.294816
Explodiert,TRUNCHE128F425BF5D,0.320959
Beast Of Honor,TRTWKHF128F4252945,0.260519
City Noise,TRVWDBV12903CEACE0,0.264726


In [57]:
# Positions within `selected_songs_cs`, ordered by increasing mean distance.
closest_candidates_indices

array([5, 8, 9, 6, 4, 7, 3, 2, 1, 0], dtype=int64)

In [58]:
# Full feature rows of the recommended songs, listed in the order of `selected_songs_pd`.
song_features_data[song_features_data['track_id'].isin(selected_songs_pd['id'])].set_index('track_id').reindex(selected_songs_pd['id'])

Unnamed: 0_level_0,name,artist,tags,danceability,energy,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,duration_mins
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
TRHFZEU128F9329BF9,Loose Nuts On The Veladrome,Liars,"experimental, noise",0.362,0.868,-4.909,1,0.0842,0.000305,9e-06,0.118,0.528,73.784,2.319767
TRWBPDM128F4262371,Closet Monster,Voodoo Glow Skulls,ska,0.376,0.989,-4.862,1,0.0903,0.00198,0.0119,0.125,0.453,144.14,2.598883
TRCLZRQ128F422A819,Rational Eyes,Threat Signal,"industrial, thrash_metal, metalcore, melodic_d...",0.41,0.92,-4.951,1,0.0589,2e-06,0.0102,0.187,0.377,172.954,3.611417
TRWBXCW128F4266098,Over-rated,Gavin DeGraw,"rock, acoustic, male_vocalists, love, pop_rock",0.387,0.77,-4.937,1,0.0476,0.00633,0.0,0.159,0.341,160.854,4.195333
TRHTLBU128F429327B,Made of Glass,Trapt,"rock, alternative_rock, hard_rock, emo",0.613,0.914,-4.877,1,0.0419,0.000246,0.0,0.146,0.324,96.944,3.493333
TRTSQHD12903CE87CA,Eiszeit,Eisbrecher,"industrial, german",0.472,0.856,-4.762,0,0.0622,0.00189,0.00811,0.0721,0.383,173.185,3.65285
TRGTXCP12903CF05DF,Beyond Within,Nevermore,"thrash_metal, progressive_metal, power_metal",0.363,0.926,-4.919,0,0.0377,3.8e-05,0.0184,0.249,0.39,151.839,5.198883
TRUNCHE128F425BF5D,Explodiert,Bosse,"rock, german",0.493,0.939,-4.93,1,0.0837,0.0033,1.5e-05,0.177,0.515,146.966,4.85
TRTWKHF128F4252945,Beast Of Honor,Auf Der Maur,"rock, alternative, female_vocalists, alternati...",0.482,0.831,-4.879,0,0.0774,0.00238,0.00112,0.0971,0.352,124.207,3.453333
TRVWDBV12903CEACE0,City Noise,Scarling.,"rock, alternative, female_vocalists, alternati...",0.437,0.906,-4.928,0,0.0863,0.000185,0.0014,0.0169,0.374,131.05,3.2411


In [59]:
# Show the full feature rows for the track ids held in ops_x_test[random_index],
# ordered exactly as they appear in that sequence.
(
    song_features_data
    .loc[song_features_data['track_id'].isin(ops_x_test[random_index])]
    .set_index('track_id')
    .reindex(ops_x_test[random_index])
)

Unnamed: 0_level_0,name,artist,tags,danceability,energy,loudness,mode,speechiness,acousticness,instrumentalness,liveness,valence,tempo,duration_mins
track_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
TRXOXHI128F426A36D,Down Rodeo,Rage Against the Machine,"rock, alternative, metal, alternative_rock, ha...",0.56,0.871,-9.058,1,0.112,0.00489,0.000738,0.0772,0.553,84.673,5.343767
TRSDDHV128F426A370,Roll Right,Rage Against the Machine,"rock, alternative, metal, alternative_rock, ha...",0.527,0.789,-8.049,1,0.0694,0.014,0.0133,0.381,0.452,86.137,4.337767
TRRCJEI128F92C23B1,Good Morning Revival,Good Charlotte,"rock, punk, chillout, punk_rock, downtempo, em...",0.103,0.14,-17.992,1,0.0342,0.87,0.689,0.178,0.0955,87.321,0.941767
TRKGGMK128F42286FE,Screenager,Muse,"rock, alternative, indie, alternative_rock, in...",0.375,0.403,-12.776,0,0.0286,0.745,0.546,0.181,0.207,80.661,4.333333
TRADCIF128F9338278,City of Delusion,Muse,"rock, alternative, indie, alternative_rock, in...",0.285,0.915,-5.158,0,0.0741,0.00268,0.0279,0.137,0.309,119.817,4.804433
TRTBASA128F92D262F,Whereabouts Unknown,Rise Against,"rock, punk, hardcore, punk_rock",0.193,0.968,-3.243,0,0.0643,0.00111,0.0762,0.111,0.319,174.331,4.029767
TRJDDAZ128F92D262E,Hairline Fracture,Rise Against,"rock, punk, hardcore, punk_rock, american",0.229,0.931,-3.488,1,0.0817,0.000332,4e-06,0.46,0.488,165.403,4.042667
TRKKBJS128F92EF7D2,Tears Don't Fall,Bullet for My Valentine,"rock, metal, hardcore, metalcore, emo, screamo",0.363,0.92,-3.522,0,0.117,0.00124,0.0,0.0813,0.331,162.167,4.6511
TRQTXHB128F92E3855,Hyper Music,Muse,"rock, alternative, indie, alternative_rock, in...",0.267,0.887,-5.993,1,0.0815,1.1e-05,7.9e-05,0.145,0.44,121.629,3.345983
TRGVSMR128F42B58E7,New Born,Muse,"rock, alternative, indie, alternative_rock, pr...",0.316,0.918,-7.333,1,0.0932,0.00354,0.14,0.112,0.148,152.007,6.091333


In [60]:
# Interactive 3-D scatter of every embedding stored in `candidates`.
# Colour legend:
#   blue   - positions before `cutoff`
#   red    - positions at or after `cutoff`
#   green  - positions listed in closest_candidates_indices
#   purple - the first (best-ranked) entry of closest_candidates_indices
# NOTE(review): closest_candidates_indices was computed as row positions within
# selected_songs_cs; confirm those positions also match the ordering of
# `candidates`, otherwise the green/purple highlights land on the wrong points.
fig = go.Figure()

# Hover labels are the keys of `candidates`; coordinates are its values stacked
# into one row of three numbers per entry.
text_data = [*candidates.keys()]
embeddings_3d = np.concatenate([*candidates.values()]).reshape(len(candidates), 3)

# Base colouring by position relative to the cutoff.
color_data = ['red' if position >= cutoff else 'blue' for position in range(len(candidates))]

# Highlight the re-ranked picks, then single out the top one.
for i in closest_candidates_indices:
    color_data[i] = 'green'
color_data[closest_candidates_indices[0]] = 'purple'

marker_style = {
    'size': 5,
    'color': color_data,
    'colorscale': 'Viridis',
    'opacity': 1,
}

fig.add_trace(
    go.Scatter3d(
        x=embeddings_3d[:, 0],
        y=embeddings_3d[:, 1],
        z=embeddings_3d[:, 2],
        text=text_data,
        mode='markers',
        marker=marker_style,
    )
)

fig.update_layout(
    scene={
        'xaxis': {'title': 'x'},
        'yaxis': {'title': 'y'},
        'zaxis': {'title': 'z'},
    },
    width=1000,
    height=800,
)

fig.show()
