In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pprint 
import re
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors

In [None]:
# Load the processed track metadata and the per-track lyrics table; both carry a
# 'track_id' column used to pair them up later.
track_data, track_lyrics = (
    pd.read_csv(csv_path)
    for csv_path in ('misc/processed_music_info.csv', 'misc/track_lyrics.csv')
)

In [None]:
# Column dtypes and non-null counts of the metadata table (buf=None prints to stdout).
track_data.info(buf=None)

In [None]:
# Column dtypes and non-null counts of the lyrics table (buf=None prints to stdout).
track_lyrics.info(buf=None)

In [None]:
# Cross-check the two tables on track_id before they are joined further down.
# Metadata rows without lyrics have no embedding.
track_ids_exist = track_data['track_id'].isin(track_lyrics['track_id'])
all_track_ids_exist = track_ids_exist.all()
if not all_track_ids_exist:
    print('Some track ids do not exist in track_lyrics')
    print(track_data[~track_ids_exist])

# Reverse direction: lyrics rows whose track_id has no metadata row are skipped by the
# cell that prepends title/artist/genres, so surface them here instead of silently.
lyrics_ids_exist = track_lyrics['track_id'].isin(track_data['track_id'])
if not lyrics_ids_exist.all():
    print('Some track ids do not exist in track_data')
    print(track_lyrics.loc[~lyrics_ids_exist, 'track_id'])

# track_data is indexed by track_id and queried with .loc later; duplicate ids would make
# those lookups return several rows per id.
if not track_data['track_id'].is_unique:
    print('Duplicate track ids in track_data')
    print(track_data.loc[track_data['track_id'].duplicated(keep=False), 'track_id'])

In [None]:
def clean_text(text):
    """Normalise raw lyrics into space-separated ASCII letters and digits.

    Args:
        text: Raw lyrics. Non-string values (e.g. NaN for a missing lyrics cell)
            are treated as having no lyrics.

    Returns:
        The cleaned text with single spaces and no leading/trailing whitespace, or ""
        for missing lyrics and for the "This song is instrumental." placeholder.
    """
    if not isinstance(text, str) or text == "This song is instrumental.":
        return ""
    # Every character that is not an ASCII letter, digit or space becomes a space. This
    # also splits "|||" separators, so no separate replace() is needed for them.
    # NOTE(review): accented / non-Latin letters are replaced too, so non-English lyrics
    # lose characters — confirm that is intended.
    text = re.sub(r'[^A-Za-z0-9 ]+', ' ', text)
    # Collapse whitespace *after* the filter above, otherwise it leaves double spaces
    # (e.g. "Hello, world" -> "Hello  world").
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

track_lyrics['lyrics'] = track_lyrics['lyrics'].apply(clean_text)
# Index the metadata by track_id for the .loc lookups in the next cell. The guard keeps a
# re-run of this cell from raising KeyError once 'track_id' has moved into the index.
# (Re-running after the lyrics-augmentation cell is still hidden state: restart and run all.)
if 'track_id' in track_data.columns:
    track_data = track_data.set_index('track_id')

In [None]:
# Rewrite each lyrics row as "song lyrics: ..., song title: ..., song artist: ...,
# song genres: ..." using the track's metadata. Rows whose track_id is not in track_data
# keep their lyrics unchanged.
# clean_text strips ':' from lyrics, so the "song lyrics: " prefix can only come from this
# cell. Skipping rows that already have it makes re-running the cell a no-op instead of
# nesting the prefix.
# NOTE(review): the model cell uses max_seq_length=150, and the metadata is appended
# *after* the full lyrics. For long lyrics it is likely truncated before encoding;
# consider placing title/artist/genres first.
needs_prefix = (
    track_lyrics['track_id'].isin(track_data.index)
    & ~track_lyrics['lyrics'].str.startswith('song lyrics: ', na=False)
)
if needs_prefix.any():
    to_update = track_lyrics.loc[needs_prefix]
    meta = track_data.loc[to_update['track_id'], ['name', 'artist', 'tags']]
    # Duplicate track_ids in track_data would return extra rows and silently shift the
    # zip below out of alignment.
    assert len(meta) == len(to_update), 'track_data has duplicate track_ids'
    track_lyrics.loc[needs_prefix, 'lyrics'] = [
        f"song lyrics: {lyrics}, song title: {name}, song artist: {artist}, song genres: {tags}"
        for lyrics, name, artist, tags in zip(
            to_update['lyrics'], meta['name'], meta['artist'], meta['tags']
        )
    ]

In [None]:
# First five rows of the combined lyrics/metadata text.
track_lyrics.head(n=5)

In [None]:
import os

import torch
from sentence_transformers import SentenceTransformer, models
from torch import nn

saved_transformer_path = 'misc/sentence_transformer'
saved_embeddings_path = 'misc/lyrics_embeddings.npy'
saved_embeddings_3d_path = 'misc/lyrics_embeddings_3d.npy'

# Build the encoder (bert-base-uncased -> Pooling with library defaults -> 150-d Dense
# with Tanh) once, save it, and load the saved copy on later runs. The Dense layer is
# randomly initialised and not trained anywhere in this notebook, so reusing the saved
# copy keeps embeddings consistent between sessions.
if os.path.exists(saved_transformer_path):
    model = SentenceTransformer(saved_transformer_path)
else:
    # Seed so that a fresh build (e.g. after deleting the saved model) is reproducible.
    torch.manual_seed(42)
    word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=150)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    dense_model = models.Dense(
        in_features=pooling_model.get_sentence_embedding_dimension(),
        out_features=150,
        activation_function=nn.Tanh(),
    )
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])
    model.save(saved_transformer_path)
# `model` is ready in both branches; the former extra SentenceTransformer.load of the
# hard-coded path only re-read the same files a second time.

In [None]:
# Encode every lyrics string once and cache the array on disk. Delete the .npy file to
# force re-encoding (e.g. after changing clean_text or the lyrics template).
if os.path.exists(saved_embeddings_path):
    lyrics_embeddings = np.load(saved_embeddings_path)
else:
    # Pass a plain list so row i of the result corresponds to row i of track_lyrics,
    # regardless of track_lyrics' index labels.
    lyrics_embeddings = model.encode(track_lyrics['lyrics'].tolist(), show_progress_bar=True)
    np.save(saved_embeddings_path, lyrics_embeddings)

# A cache built from a different lyrics table would silently misalign every later lookup.
assert len(lyrics_embeddings) == len(track_lyrics), (
    f'{saved_embeddings_path} has {len(lyrics_embeddings)} rows but track_lyrics has '
    f'{len(track_lyrics)}; delete the cache and re-run'
)

In [None]:
# (number of tracks, embedding dimension)
np.shape(lyrics_embeddings)

In [None]:
# Eyeball one embedding vector (row 3) as a sanity check.
sample_embedding = lyrics_embeddings[3]
pprint.pprint(sample_embedding)

In [None]:
# Project the lyrics embeddings to 3-D with t-SNE (cached on disk) and show them as an
# interactive scatter, with each track's genre tags as hover text.
if os.path.exists(saved_embeddings_3d_path):
    embeddings_3d = np.load(saved_embeddings_3d_path)
else:
    tsne = TSNE(n_components=3, random_state=42)
    embeddings_3d = tsne.fit_transform(lyrics_embeddings)
    np.save(saved_embeddings_3d_path, embeddings_3d)

# Guard against a stale projection computed from a different embedding run.
assert len(embeddings_3d) == len(lyrics_embeddings), (
    f'{saved_embeddings_3d_path} does not match lyrics_embeddings; delete it and re-run'
)

import plotly.graph_objects as go

# Row i of embeddings_3d comes from row i of track_lyrics, so fetch tags by that row's
# track_id (track_data is indexed by track_id). Positional track_data['tags'] would only
# line up if both tables had exactly the same rows in the same order.
hover_tags = track_data['tags'].reindex(track_lyrics['track_id']).to_numpy()

fig = go.Figure()
fig.add_trace(go.Scatter3d(
    x=embeddings_3d[:, 0],
    y=embeddings_3d[:, 1],
    z=embeddings_3d[:, 2],
    text=hover_tags,
    mode='markers',
    marker=dict(
        size=2,
        color=embeddings_3d[:, 2],
        colorscale='Viridis',
        opacity=0.8,
    ),
))
fig.update_layout(
    title='t-SNE projection of lyrics embeddings',
    scene=dict(
        xaxis=dict(title='x'),
        yaxis=dict(title='y'),
        zaxis=dict(title='z'),
    ),
    width=1000,
    height=800,
)
fig.show()


In [None]:
# Persist the 3-D projection to the same path the t-SNE cell loads from, so the two
# cannot drift apart if the path constant changes.
np.save(saved_embeddings_3d_path, embeddings_3d)

In [None]:
# Pick a random track and list it together with its nearest neighbours in the 3-D t-SNE
# space.
# NOTE(review): t-SNE distorts distances; neighbours in lyrics_embeddings (the original
# space) may be more meaningful.
n_similar = 5
random_index = np.random.randint(0, len(embeddings_3d))

nn_model = NearestNeighbors(n_neighbors=n_similar + 1)
nn_model.fit(embeddings_3d)

distances, indices = nn_model.kneighbors(embeddings_3d[random_index].reshape(1, -1))

# Points identical to the query can tie with it, so the query is not guaranteed to be
# returned first. Drop it by value, keep n_similar others, and put it at the top.
neighbours = [i for i in indices[0] if i != random_index][:n_similar]
nearest_indices = np.array([random_index] + neighbours)

# Embedding rows follow track_lyrics order, so look metadata up via those rows' track_ids
# rather than by position in track_data. Leave track_data's track_id index intact.
result_track_ids = track_lyrics['track_id'].iloc[nearest_indices]

print("target song:", random_index)
track_data.reindex(result_track_ids)[['name', 'artist', 'year', 'loudness', 'danceability', 'liveness', 'tags']]