In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pprint 
import re

In [2]:
# Load the per-track metadata/audio-feature table and the lyrics table.
# Both tables carry a `track_id` column that links a track to its lyrics.
track_data = pd.read_csv('misc/processed_music_info.csv')
track_lyrics = pd.read_csv('misc/track_lyrics.csv')

In [3]:
# Inspect columns, dtypes and non-null counts of the track metadata
# (in the recorded output `tags` is the only column shown with missing values).
track_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23584 entries, 0 to 23583
Data columns (total 19 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   track_id          23584 non-null  object 
 1   name              23584 non-null  object 
 2   artist            23584 non-null  object 
 3   spotify_id        23584 non-null  object 
 4   tags              23083 non-null  object 
 5   year              23584 non-null  int64  
 6   duration_ms       23584 non-null  int64  
 7   danceability      23584 non-null  float64
 8   energy            23584 non-null  float64
 9   key               23584 non-null  int64  
 10  loudness          23584 non-null  float64
 11  mode              23584 non-null  int64  
 12  speechiness       23584 non-null  float64
 13  acousticness      23584 non-null  float64
 14  instrumentalness  23584 non-null  float64
 15  liveness          23584 non-null  float64
 16  valence           23584 non-null  float6

In [4]:
# Inspect the lyrics table: the row count should match track_data and
# `lyrics` should have no nulls before text cleaning is applied.
track_lyrics.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23584 entries, 0 to 23583
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   track_id  23584 non-null  object
 1   lyrics    23584 non-null  object
dtypes: object(2)
memory usage: 368.6+ KB


In [5]:
# Sanity check: every track_id in the metadata should also appear in the lyrics table.
track_ids_exist = track_data['track_id'].isin(track_lyrics['track_id'])
all_track_ids_exist = track_ids_exist.all()
if not all_track_ids_exist:
    # Report the metadata rows that have no counterpart in track_lyrics.
    print(
        'Some track ids do not exist in track_lyrics',
        track_data[~track_ids_exist],
        sep='\n',
    )

In [6]:
def clean_text(text):
    """Normalize raw lyrics into single-space-separated alphanumeric tokens.

    Every run of characters other than ASCII letters and digits (whitespace,
    newlines, punctuation, brackets, the ``|||`` separator, ...) is replaced
    by exactly one space, and leading/trailing spaces are removed. Note that
    non-ASCII letters are treated as separators as well.

    Args:
        text: The raw lyrics string.

    Returns:
        The cleaned string; empty if ``text`` contains no ASCII letters or digits.
    """
    # A single pass over "anything that is not a letter or digit" handles
    # whitespace and symbols together. Collapsing whitespace first and
    # replacing symbols afterwards re-introduces double spaces
    # (e.g. "Comin' out" -> "Comin  out"), and it makes a separate
    # replacement of "|||" unreachable because the pipes are already gone.
    return re.sub(r'[^A-Za-z0-9]+', ' ', text).strip()

# Run clean_text over every lyric; the `lyrics` column is overwritten in place.
track_lyrics['lyrics'] = track_lyrics['lyrics'].map(clean_text)

In [7]:
# Preview the first rows of the lyrics table after cleaning.
track_lyrics.head()

Unnamed: 0,track_id,lyrics
0,TRIOREW128F424EAF0,Verse 1 Comin out of my cage and I ve been do...
1,TRRIVDJ128F429B0E8,Verse 1 Today is gonna be the day that they re...
2,TRXOGZT128F424AD74,Verse 1 Karma police arrest this man He talks...
3,TRUJIIV12903CA8848,Verse 1 The lights go out and I can t be save...
4,TRIODZU128E078F3E2,Verse 1 Sometimes I feel like I don t have a p...


In [8]:
import os
from sentence_transformers import SentenceTransformer, models
from torch import nn

saved_transformer_path = 'misc/sentence_transformer'

# Reuse the model cached on disk when it exists; otherwise assemble it once
# and save it so later runs load the same weights instead of rebuilding.
if os.path.exists(saved_transformer_path):
    model = SentenceTransformer(saved_transformer_path)
else:
    # Pipeline: BERT token embeddings (max_seq_length=150)
    #   -> pooling into one sentence vector
    #   -> dense projection to 150 dimensions with a tanh activation.
    word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=150)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    dense_model = models.Dense(
        in_features=pooling_model.get_sentence_embedding_dimension(),
        out_features=150,
        activation_function=nn.Tanh(),
    )
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

    # NOTE(review): the Dense layer is created here rather than loaded from a
    # checkpoint, so its weights are presumably untrained; persisting the model
    # keeps the embeddings consistent across notebook runs. Confirm.
    model.save(saved_transformer_path)

# `model` is ready in both branches, so no second load from disk is needed.

In [12]:
# Smoke test: embed only the first track's lyrics (row label 0) to check the model output.
first_lyric = track_lyrics.loc[0, 'lyrics']
lyrics_embeddings = model.encode(first_lyric, show_progress_bar=True)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [14]:
# A single encoded string yields one vector; its length should equal the
# Dense layer's out_features (150).
lyrics_embeddings.shape

(150,)

In [15]:
# Print the full embedding vector of the first lyric for a visual check.
pprint.pprint(lyrics_embeddings)

array([ 0.21813162, -0.05294361, -0.02294298,  0.20409702,  0.11121859,
       -0.07217646,  0.19420993,  0.08826566, -0.02758382, -0.02307406,
       -0.11153927, -0.06614064, -0.13606685,  0.08941007,  0.00794795,
       -0.00270648, -0.24280885, -0.15496674, -0.02099121, -0.28822514,
       -0.1263215 , -0.19553095,  0.2827556 ,  0.27845988,  0.19658986,
        0.29475227,  0.06451423, -0.26581156, -0.1363212 , -0.15106559,
       -0.24017142,  0.36102647, -0.14633602,  0.21057779, -0.13753334,
        0.35417604, -0.15312333, -0.06277979, -0.03696794, -0.33209848,
        0.42374694,  0.12091794,  0.23108965, -0.10289379,  0.38343132,
        0.02219431, -0.10446766,  0.1694062 , -0.06381396,  0.45691526,
       -0.19027442,  0.21414563,  0.091069  ,  0.27247635, -0.11386531,
       -0.1355532 ,  0.03778983, -0.16620581, -0.244795  , -0.1273817 ,
        0.06870184,  0.3158253 , -0.11028661,  0.05900016, -0.33063248,
       -0.04446921,  0.41326374, -0.05272987, -0.23222125, -0.24