In [1]:
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Set random seed for everything (except sklearn)
RND_SEED: int = 12345
np.random.seed(RND_SEED) # for numpy, scipy
pd.core.common.random_state(RND_SEED) # for pandas
torch.set_default_device("cuda")
torch.manual_seed(RND_SEED)

# Resolution for graph images
WIDTH: int = 1366
HEIGHT: int = 768

In [4]:
df = pd.read_csv("./../data/Combined-2023.csv", encoding="utf-8", index_col=[0])
df["track_name"] = df["track_name"].astype("string")
df["streams"] = df["streams"].astype(str).apply(lambda x: float(x) / 1e6 if x.isdigit() else np.nan)
df["in_deezer_playlists"] = df["in_deezer_playlists"].astype(str).apply(lambda x: float(x.replace(",", "")) / 1000)
df["in_shazam_charts"] = df["in_shazam_charts"].astype(str).apply(
    lambda x: int(float(x.replace(",", ""))) if x.isdigit() else pd.NA
).astype("Int64")
df["key"] = df["key"].astype("category")
df["mode"] = df["mode"].astype("category")

In [6]:
model = SentenceTransformer("sentence-transformers/LaBSE", device="cuda")
res = model.encode(df.loc[0, "track_name"])
print(res)



[-3.63083854e-02 -6.22163005e-02  2.67840363e-02  4.80489284e-02
 -1.38645377e-02 -7.59204030e-02 -1.12822847e-02  3.01432591e-02
 -3.58952843e-02  2.07574312e-02 -4.19806279e-02 -2.14882605e-02
 -2.70952228e-02  3.74145098e-02 -2.15717256e-02 -7.61645734e-02
 -2.75994521e-02 -2.76782475e-02  8.41530913e-04  5.19787669e-02
 -1.07569052e-02  1.01270704e-02 -3.42174023e-02 -3.78496386e-02
 -2.09101252e-02 -6.53083399e-02 -2.84665152e-02  5.44525646e-02
  1.78288352e-02  4.18733433e-03 -1.64238364e-02  3.80505584e-02
  6.42888062e-03  7.62426630e-02  1.29185466e-03 -3.71769853e-02
 -3.12337205e-02  1.61463376e-02  8.85865092e-02 -9.84560177e-02
 -4.82909083e-02  1.88684445e-02 -3.31109986e-02  4.68229875e-02
  5.17229624e-02 -1.98363420e-02  3.71056162e-02 -5.66946343e-02
  1.60592683e-02  3.61988991e-02  5.67144640e-02 -2.58656330e-02
 -6.26120344e-02 -6.63500354e-02 -2.92838458e-02  5.72132766e-02
 -6.14085719e-02 -4.18855958e-02 -4.30946127e-02 -3.79403643e-02
 -1.35353217e-02 -5.25105

In [9]:
res = model.encode(df.loc[2, "track_name"])