In [1]:
import os
import re
from typing import Callable, Any
from functools import wraps
from time import time
import json

import random
from datetime import datetime

import pandas as pd
import torch
from torch import nn
from torch.nn import (
    Module,
    Linear,
    ReLU,
    TripletMarginLoss,
    TripletMarginWithDistanceLoss,
    PairwiseDistance,
    CosineSimilarity
)
from torch.nn.functional import normalize
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import pickle

In [2]:
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
bert_model = BertModel.from_pretrained("bert-base-multilingual-uncased")

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
bert_model.save_pretrained("embeddings/pretrained/BertModel")
bert_tokenizer.save_pretrained("embeddings/pretrained/BertTokenizer")

('embeddings/pretrained/BertTokenizer/tokenizer_config.json',
 'embeddings/pretrained/BertTokenizer/special_tokens_map.json',
 'embeddings/pretrained/BertTokenizer/vocab.txt',
 'embeddings/pretrained/BertTokenizer/added_tokens.json')

In [38]:
bert_model(
    **bert_tokenizer(
        "przykładowy tekst w języku polskim", return_tensors="pt", padding=True
    )
).last_hidden_state.mean(dim=1).shape

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

# Load data

### Create dataset:

- Positive example is a song paired with mean vector of a playlist songs, the song belongs to
- Negative example is a song paired with mean vector of a playlist songs, the song doesn't belong to


In [5]:
with open("data/all_in_one_playlist_dataset.json", "r") as file:
    data = json.load(file)
df = pd.DataFrame(data)

df = df.dropna()

In [6]:
numeric_features: dict = {
    "duration_ms": {
        "min_val": 0,
        "max_val": 6950000,
        "desc": "duration of a song in ms",
    },
    "danceability": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "Danceability describes how suitable a track is for dancing based on a combination of musical elements including tempo, rhythm stability, beat strength, and overall regularity. A value of 0.0 is least danceable and 1.0 is most danceable.",
    },
    "energy": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "Energy is a measure from 0.0 to 1.0 and represents a perceptual measure of intensity and activity. Typically, energetic tracks feel fast, loud, and noisy. For example, death metal has high energy, while a Bach prelude scores low on the scale. Perceptual features contributing to this attribute include dynamic range, perceived loudness, timbre, onset rate, and general entropy.",
    },
    "key": {
        "min_val": -1,
        "max_val": 11,
        "desc": "The key the track is in. Integers map to pitches using standard Pitch Class notation. E.g. 0 = C, 1 = C♯/D♭, 2 = D, and so on. If no key was detected, the value is -1.",
    },
    "loudness": {
        "min_val": -60.0,
        "max_val": 0.0,
        "desc": "The overall loudness of a track in decibels (dB). Loudness values are averaged across the entire track and are useful for comparing relative loudness of tracks. Loudness is the quality of a sound that is the primary psychological correlate of physical strength (amplitude). Values typically range between -60 and 0 db.",
    },
    "mode": {
        "min_val": 0,
        "max_val": 1,
        "desc": "Mode indicates the modality (major or minor) of a track, the type of scale from which its melodic content is derived. Major is represented by 1 and minor is 0.",
    },
    "speechiness": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "Speechiness detects the presence of spoken words in a track. The more exclusively speech-like the recording (e.g. talk show, audio book, poetry), the closer to 1.0 the attribute value. Values above 0.66 describe tracks that are probably made entirely of spoken words. Values between 0.33 and 0.66 describe tracks that may contain both music and speech, either in sections or layered, including such cases as rap music. Values below 0.33 most likely represent music and other non-speech-like tracks.",
    },
    "acousticness": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "A confidence measure from 0.0 to 1.0 of whether the track is acoustic. 1.0 represents high confidence the track is acoustic.",
    },
    "instrumentalness": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "Predicts whether a track contains no vocals. 'Ooh' and 'aah' sounds are treated as instrumental in this context. Rap or spoken word tracks are clearly 'vocal'. The closer the instrumentalness value is to 1.0, the greater likelihood the track contains no vocal content. Values above 0.5 are intended to represent instrumental tracks, but confidence is higher as the value approaches 1.0.",
    },
    "liveness": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "Detects the presence of an audience in the recording. Higher liveness values represent an increased probability that the track was performed live. A value above 0.8 provides strong likelihood that the track is live.",
    },
    "valence": {
        "min_val": 0.0,
        "max_val": 1.0,
        "desc": "A measure from 0.0 to 1.0 describing the musical positiveness conveyed by a track. Tracks with high valence sound more positive (e.g. happy, cheerful, euphoric), while tracks with low valence sound more negative (e.g. sad, depressed, angry).",
    },
}

text_features = ["track_name", "artist_name", "album_name"]


def preprocess_features(
    df: pd.DataFrame, numeric_features: dict[str, dict[str, Any]], section_feature: str
) -> pd.DataFrame:
    stand_data = {}

    # Process numeric features
    for key in numeric_features:
        # normalize data
        # df[key] = df[key].fillna(0)
        df[key] = (df[key] - numeric_features[key]["min_val"]) / (
            numeric_features[key]["max_val"] - numeric_features[key]["min_val"]
        )
        stand_data[key] = {
            "mean": df[key].mean(),
            "std": df[key].std(),
            "max": numeric_features[key]["max_val"],
            "min": numeric_features[key]["min_val"],
        }
        # # standarize data
        # df[key] = (df[key] - stand_data[key]["mean"]) / stand_data[key]["std"]

    # Process sections data
    seq_col = f"{section_feature}_seq"
    df[seq_col] = df.analysis_sections.apply(
        lambda section: torch.tensor(
            [item[section_feature] for item in (section if section is not None else [])]
        )
    )
    # Normalize/standarize data
    all_tempo = torch.tensor([item for row in df[seq_col] for item in row])

    max_tempo = all_tempo.max()
    min_tempo = all_tempo.min()
    all_tempo = (all_tempo - min_tempo) / (max_tempo - min_tempo)
    stand_data[seq_col] = {
        "mean": all_tempo.mean(),
        "std": all_tempo.std(),
        "max": max_tempo,
        "min": min_tempo,
    }

    df[seq_col] = df[seq_col].apply(
        lambda row: torch.tensor(
            [((x - min_tempo) / (max_tempo - min_tempo)) * 2 - 1 for x in row]
        )
    )

    # pad to 14
    df[seq_col] = df[seq_col].apply(lambda x: x[:14])
    df[seq_col] = df[seq_col].apply(
        lambda x: torch.cat(
            (
                x,
                torch.tensor([0] * (14 - len(x))),
            ),
            dim=0,
        )
    )

    return stand_data


stand_data = preprocess_features(df, numeric_features, "tempo")

In [7]:
stand_data

{'duration_ms': {'mean': 0.034333916002505675,
  'std': 0.011708760542764075,
  'max': 6950000,
  'min': 0},
 'danceability': {'mean': 0.6008363223754211,
  'std': 0.1676360658383639,
  'max': 1.0,
  'min': 0.0},
 'energy': {'mean': 0.6200577285783162,
  'std': 0.21908584763805444,
  'max': 1.0,
  'min': 0.0},
 'key': {'mean': 0.515858677661524,
  'std': 0.301661135622213,
  'max': 11,
  'min': -1},
 'loudness': {'mean': 0.8719418800245317,
  'std': 0.06646245390106514,
  'max': 0.0,
  'min': -60.0},
 'mode': {'mean': 0.6513954158538748,
  'std': 0.47668099722459145,
  'max': 1,
  'min': 0},
 'speechiness': {'mean': 0.09108276013034222,
  'std': 0.09830546862857148,
  'max': 1.0,
  'min': 0.0},
 'acousticness': {'mean': 0.2618604989430196,
  'std': 0.294296959722126,
  'max': 1.0,
  'min': 0.0},
 'instrumentalness': {'mean': 0.07918648350145194,
  'std': 0.22214266693042017,
  'max': 1.0,
  'min': 0.0},
 'liveness': {'mean': 0.18713169095805993,
  'std': 0.15585710596509122,
  'max': 1

# Train test split

In [8]:
random.seed(42)
train_split = 0.8
playlists = list(df.id_playlist.unique())
total_len = len(playlists)
random.shuffle(playlists)
train_playlists = set(playlists[: int(0.8 * total_len)])
test_playlists = set(playlists[len(train_playlists) :])

df_train = df.query("id_playlist in @train_playlists")
df_test = df.query("id_playlist in @test_playlists")

In [9]:
# playlist = df[df.id_playlist == 2]
# torch.tensor(playlist[list(numeric_features.keys())].mean())

In [10]:
df.columns

Index(['id_playlist', 'name', 'collaborative', 'pid', 'modified_at',
       'num_tracks', 'num_albums', 'num_followers', 'num_edits', 'num_artists',
       'pos', 'artist_name', 'track_uri', 'artist_uri', 'track_name',
       'album_uri', 'duration_ms', 'album_name', 'danceability', 'energy',
       'key', 'loudness', 'mode', 'speechiness', 'acousticness',
       'instrumentalness', 'liveness', 'valence', 'tempo', 'time_signature',
       'num_samples', 'duration', 'offset_seconds', 'window_seconds',
       'analysis_sample_rate', 'analysis_channels', 'end_of_fade_in',
       'start_of_fade_out', 'tempo_confidence', 'time_signature_confidence',
       'key_confidence', 'mode_confidence', 'analysis_sections', 'tempo_seq'],
      dtype='object')

In [11]:
def _ensure_exists(path_out: str) -> None:
    if os.path.exists(path_out):
        return
    os.makedirs(path_out)


class SpotifyDataset(Dataset):
    """Wrapper dataset that draws triplets from the Spotify dataset.

    size: arbitrary 'length' of dataset, number of triplets to draw in one epoch
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        model: BertModel,
        tokenizer: BertTokenizer,
        seq_feature: str = "",
        size: int = 10000,
        numeric_features: dict = [],
        text_features: list = [],
        data_path: str = "data/embedds",
        prefix: str = "spotify",
        device: str = "cuda",
        prepare: bool = False,
    ):
        super(SpotifyDataset, self).__init__()
        self.df = frame
        self.playlists = list(frame.id_playlist.unique())
        self.numeric_features = numeric_features
        self.text_features = text_features
        self.seq_feature = seq_feature
        self.size = size
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.numeric_keys = list(numeric_features.keys())
        self.data_path = data_path
        self.prefix = prefix
        self.device = device
        if prepare:
            self._prepare()
        self.playlist_bags = self.prepare_bags()

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        return self.get_random_triplet()

    def _prepare(self):
        _ensure_exists(self.data_path)
        for _, row in tqdm(self.df.iterrows(), total=len(self.df)):
            with open(
                os.path.join(
                    self.data_path, f"{self.prefix}_{row['track_uri']}.pckl"
                ).replace(":", "_"),
                "wb",
            ) as file:
                pickle.dump(
                    self.model(
                        **self.tokenizer(
                            list(row[self.text_features]),
                            return_tensors="pt",
                            padding=True,
                        ).to(self.device)
                    )
                    .last_hidden_state.mean(dim=1)
                    .flatten(),
                    file,
                )

    def read_uri(self, uri):
        with open(
            os.path.join(
                self.data_path,
                f"{self.prefix}_{uri}.pckl".replace(":", "_"),
            ),
            "rb",
        ) as file:
            return pickle.load(file)

    def prepare_bags(self):
        bags = {}
        print("Prepairing playlists")
        for _id in tqdm(self.playlists):
            playlist = self.df[self.df.id_playlist == _id]
            playlist_embedd = torch.stack(
                [self.read_uri(row["track_uri"]) for i, row in playlist.iterrows()]
            ).mean(axis=0)

            bags[_id] = {
                "data": playlist,
                "features": torch.tensor(playlist[self.numeric_keys].mean()).float(),
                "embedds": playlist_embedd.float(),
                "seq": torch.stack(list(playlist[self.seq_feature]))
                .float()
                .mean(axis=0)
                .reshape(-1, 1),
            }
        return bags

    def get_random_triplet(self, track_uri: str = None, to_cpu: bool = False):
        if track_uri:
            try:
                anchor_song = self.df[self.df["track_uri"] == track_uri].iloc[0]
            except:
                print(track_uri)
                raise
        else:
            # get a random song and its playlist
            anchor_song = self.df.sample(1).iloc[0]
        anchor_playlist_id: int = anchor_song.id_playlist

        # positive eanchor_songxample
        positive_playlist_id: int = anchor_playlist_id

        # negative example -> get a playlist song does't belong to. Stupid solution
        while True:
            negative_playlist_id = random.choice(self.playlists)
            # check if it is a different playlist
            if negative_playlist_id == anchor_playlist_id:
                continue  # find another oneds
            # check if it doesn't have our song
            if len(
                self.df[
                    (self.df.id_playlist == negative_playlist_id)
                    & (self.df.track_uri == anchor_song.track_uri)
                ]
            ):
                continue  # find another one
            break

        # positive and negative playlist data
        positive_playlist = self.playlist_bags[positive_playlist_id]["data"]
        negative_playlist = self.playlist_bags[negative_playlist_id]["data"]

        # get anchor numeric features
        anchor_song_features = torch.tensor(anchor_song[self.numeric_keys]).float()

        # get positive and negative playlist numeric features
        positive_playlist_features = self.playlist_bags[positive_playlist_id][
            "features"
        ]
        negative_playlist_features = self.playlist_bags[negative_playlist_id][
            "features"
        ]

        # get anchor text features
        anchor_song_embedds = self.read_uri(anchor_song["track_uri"]).float()

        # get positive and negative playlist text features
        positive_playlist_embedds = self.playlist_bags[positive_playlist_id]["embedds"]
        negative_playlist_embedds = self.playlist_bags[negative_playlist_id]["embedds"]

        # get anchor sequence features
        anchor_song_seq = anchor_song[self.seq_feature].reshape(-1, 1).float()

        # get positive/negative playlist seq features
        positive_playlist_seq = self.playlist_bags[positive_playlist_id]["seq"]
        negative_playlist_seq = self.playlist_bags[negative_playlist_id]["seq"]

        if to_cpu is True:
            device = "cpu"
        else:
            device = self.device

        return (
            (
                anchor_song_features.to(device),
                anchor_song_embedds.to(device),
                anchor_song_seq.to(device),
            ),
            (
                positive_playlist_features.to(device),
                positive_playlist_embedds.to(device),
                positive_playlist_seq.to(device),
            ),
            (
                negative_playlist_features.to(device),
                negative_playlist_embedds.to(device),
                negative_playlist_seq.to(device),
            ),
        )
    
    def get_anchor(self, track_uri: str = None, to_cpu: bool = False):
        anchor_song = self.df[self.df["track_uri"] == track_uri].iloc[0]
        anchor_playlist_id: int = anchor_song.id_playlist

        # get anchor numeric features
        anchor_song_features = torch.tensor(anchor_song[self.numeric_keys]).float()


        # get anchor text features
        anchor_song_embedds = self.read_uri(anchor_song["track_uri"]).float()

        # get anchor sequence features
        anchor_song_seq = anchor_song[self.seq_feature].reshape(-1, 1).float()

        if to_cpu is True:
            device = "cpu"
        else:
            device = self.device

        return (
            (
                anchor_song_features.to(device),
                anchor_song_embedds.to(device),
                anchor_song_seq.to(device),
            )
        ), (anchor_song)

In [12]:
dataset_train = SpotifyDataset(
    frame=df_train,
    model=bert_model,
    tokenizer=bert_tokenizer,
    seq_feature="tempo_seq",
    numeric_features=numeric_features,
    text_features=text_features,
    size=10000,
    prefix="train",
    prepare=False,
)

dataset_test = SpotifyDataset(
    frame=df_test,
    model=bert_model,
    tokenizer=bert_tokenizer,
    seq_feature="tempo_seq",
    numeric_features=numeric_features,
    text_features=text_features,
    size=100,
    prefix="test",
    prepare=False,
)

Prepairing playlists


100%|██████████| 2674/2674 [00:23<00:00, 113.29it/s]


Prepairing playlists


100%|██████████| 669/669 [00:05<00:00, 115.93it/s]


# Neural network

In [18]:
def measure_time(fcn):
    @wraps(fcn)
    def wrapped(*args, **kwargs):
        start_time = time()
        res = fcn(*args, **kwargs)
        end_time = time()
        print("Execution time: {:.4f}".format(end_time - start_time))
        return res

    return wrapped


# @measure_time
def run_epoch(
    model,
    data_loader,
    loss_function: Callable,
    optimizer,
    device: str = "cuda",
) -> float:
    average_loss = 0.0
    n_batches = 0
    for (anchor, positive, negative) in data_loader:
        # klasyfikator i funkcja kosztu
        anchor_embedd = model(
            anchor[0].to(device), anchor[1].to(device), anchor[2].to(device)
        )
        positive_embedd = model(
            positive[0].to(device), positive[1].to(device), positive[2].to(device)
        )
        negative_embedd = model(
            negative[0].to(device), negative[1].to(device), negative[2].to(device)
        )
        l = loss_function(anchor_embedd, positive_embedd, negative_embedd)

        l.backward()

        # optymalizacja
        optimizer.step()
        optimizer.zero_grad()
        average_loss += l.item()
        n_batches += 1
    return average_loss / n_batches


def run_validate(model, data_loader, loss_function, device="cuda"):
    total_loss = 0.0
    i = 0
    with torch.no_grad():
        for (anchor, positive, negative) in data_loader:
            i += 1
            anchor_embedd = model(
                anchor[0].to(device), anchor[1].to(device), anchor[2].to(device)
            )
            positive_embedd = model(
                positive[0].to(device), positive[1].to(device), positive[2].to(device)
            )
            negative_embedd = model(
                negative[0].to(device), negative[1].to(device), negative[2].to(device)
            )
            l = loss_function(anchor_embedd, positive_embedd, negative_embedd)
            total_loss += l
    return {"loss": total_loss / i}


# @measure_time
def fit(
    model: Module,
    train_loader,
    test_loader,
    loss_function,
    optimizer,
    epochs: int,
    writer: SummaryWriter,
    device: str = "cuda",
    patience: int = 10,
    output_path: str = "torch_logs/checkpoints/best",
    run_prefix: str = "test",
    print_metrics: bool = True,
):
    min_val_loss = 1e10
    current_patience = 0
    for epoch in tqdm(range(epochs)):
        model.train()
        train_loss = run_epoch(
            model=model,
            data_loader=train_loader,
            loss_function=loss_function,
            optimizer=optimizer,
            device=device,
        )
        model.eval()
        val_results = run_validate(
            model=model,
            data_loader=test_loader,
            loss_function=loss_function,
            device=device,
        )
        val_loss = val_results["loss"]
        writer.add_scalars(
            main_tag=f"{run_prefix} loss",
            tag_scalar_dict={"train": train_loss, "val": val_loss},
            global_step=epoch + 1,
        )
        if print_metrics:
            print(f"Train loss: {train_loss} Val loss: {val_loss}")
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            current_patience = 0
            _ensure_exists(os.path.split(output_path)[0])
            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f=output_path + "_" + run_prefix,
            )
        else:
            current_patience += 1

        if current_patience >= patience:
            break
    model.eval()

In [40]:
log_dir = "tensorboard_logs"
_ensure_exists(log_dir)

writer_tensorboard = SummaryWriter(log_dir)

%reload_ext tensorboard
%tensorboard --logdir $log_dir --port=6011

In [20]:
class EmbeddingModel(Module):
    def __init__(self, kwargs):
        Module.__init__(self)
        self.input_features_dim = kwargs.get("input_features_dim", 11)
        self.input_text_dim = kwargs.get("input_text_dim", 2304)
        self.hidden_text_dim = kwargs.get("hidden_text_dim", 16)
        self.input_lstm_dim = kwargs.get("input_lstm_dim", 1)
        self.hidden_lstm_dim = kwargs.get("hidden_lstm_dim", 16)
        self.num_layers_lstm = kwargs.get("num_layers_lstm", 2)
        self.hidden_dense_dim = kwargs.get("num_layers_lstm", 32)
        self.output_dim = kwargs.get("output_dim", 16)
        self.lstmSections = nn.LSTM(
            input_size=self.input_lstm_dim,
            hidden_size=self.hidden_lstm_dim,
            num_layers=self.num_layers_lstm,
            batch_first=True,
        )  # lstm
        self.fc_text = Linear(self.input_text_dim, 16)
        merge_size = (
            self.input_features_dim + self.hidden_text_dim + self.hidden_lstm_dim
        )
        self.fc1 = Linear(merge_size, self.hidden_dense_dim)
        self.fc2 = Linear(self.hidden_dense_dim, self.output_dim)
        self.relu = ReLU()

    def forward(
        self,
        features: torch.Tensor,
        embedds: torch.Tensor,
        sections: torch.Tensor,
    ) -> torch.Tensor:
        # Output Text
        embedds = self.relu(self.fc_text(embedds.clone().detach().requires_grad_(True)))

        # Output LSTM
        outputSections, (hnSections, cnSections) = self.lstmSections(
            sections.clone().detach().requires_grad_(True)
        )
        hnSections = hnSections[-1].view(-1, self.hidden_lstm_dim)

        # Output cat
        output = self.relu(
            self.fc1(
                torch.cat(
                    (
                        embedds,
                        hnSections,
                        features.clone().detach().requires_grad_(True),
                    ),
                    axis=1,
                ).float()
            )
        )
        output = self.fc2(output)
        # Now process together
        output = normalize(output, 2)
        return output

In [21]:
model_embedding = EmbeddingModel({})

models = {"EmbeddingModel": model_embedding, "EmbeddingModelCosine": model_embedding}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {k: v.to(device) for (k, v) in models.items()}

print(f"Starting with {device}")

EPOCHS = 200
BATCH_SIZE = 512
LR = {"EmbeddingModel": 0.0001, "EmbeddingModelCosine": 0.0001}

train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE)

optimizers = {
    "EmbeddingModel": torch.optim.Adam(
        models["EmbeddingModel"].parameters(), lr=LR["EmbeddingModel"]
    ),
    "EmbeddingModelCosine": torch.optim.Adam(
        models["EmbeddingModelCosine"].parameters(), lr=LR["EmbeddingModelCosine"]
    )
}

loss_fcn = {
    "EmbeddingModel": TripletMarginWithDistanceLoss(
        margin=0.1, distance_function=PairwiseDistance()
    ),
    "EmbeddingModelCosine": TripletMarginWithDistanceLoss(
        margin=0.1, distance_function=CosineSimilarity()
    )
}


to_train = ["EmbeddingModelCosine"]

for model_name in to_train:
    time_stamp = datetime.now().strftime("%d_%m_%y_%H_%M_%S")
    print("------------------------------------------")
    print(f"Starting training for model: {model_name}")
    fit(
        model=models[model_name],
        train_loader=train_loader,
        test_loader=val_loader,
        loss_function=loss_fcn[model_name],
        device=device,
        optimizer=optimizers[model_name],
        epochs=EPOCHS,
        writer=writer_tensorboard,
        output_path="torch_logs/checkpoints/best",
        patience=10,
        run_prefix=model_name + "_" + time_stamp,
        print_metrics=True,
    )
    checkpoint = torch.load(f"torch_logs/checkpoints/best_{model_name}_{time_stamp}")
    models[model_name].load_state_dict(checkpoint["model_state_dict"])
    optimizers[model_name].load_state_dict(checkpoint["optimizer_state_dict"]),

Starting with cuda
------------------------------------------
Starting training for model: EmbeddingModelCosine


  0%|          | 1/200 [01:13<4:02:40, 73.17s/it]

Train loss: 0.10409471727907657 Val loss: 0.10291509330272675


  1%|          | 2/200 [02:26<4:00:58, 73.02s/it]

Train loss: 0.10190884098410606 Val loss: 0.10195048898458481


  2%|▏         | 3/200 [03:39<3:59:54, 73.07s/it]

Train loss: 0.10110696367919444 Val loss: 0.10106384754180908


  2%|▏         | 4/200 [04:53<4:00:18, 73.57s/it]

Train loss: 0.10075028762221336 Val loss: 0.10087358206510544


  2%|▎         | 5/200 [06:05<3:57:45, 73.16s/it]

Train loss: 0.10057304874062538 Val loss: 0.10065126419067383


  3%|▎         | 6/200 [07:21<3:58:41, 73.82s/it]

Train loss: 0.10045204348862172 Val loss: 0.10045995563268661


  4%|▎         | 7/200 [08:34<3:56:54, 73.65s/it]

Train loss: 0.10039854645729065 Val loss: 0.10049018263816833


  4%|▍         | 8/200 [09:47<3:55:26, 73.57s/it]

Train loss: 0.10033045969903469 Val loss: 0.10025554150342941


  4%|▍         | 9/200 [11:01<3:54:42, 73.73s/it]

Train loss: 0.10027565211057662 Val loss: 0.10027941316366196


  5%|▌         | 10/200 [12:16<3:54:38, 74.10s/it]

Train loss: 0.10023749209940433 Val loss: 0.10019498318433762


  6%|▌         | 11/200 [13:29<3:52:24, 73.78s/it]

Train loss: 0.10022377409040928 Val loss: 0.10028062760829926


  6%|▌         | 12/200 [14:43<3:51:08, 73.77s/it]

Train loss: 0.10018972158432007 Val loss: 0.10018149763345718


  6%|▋         | 13/200 [15:57<3:49:44, 73.72s/it]

Train loss: 0.10016975998878479 Val loss: 0.10014872997999191


  7%|▋         | 14/200 [17:09<3:47:01, 73.24s/it]

Train loss: 0.100166841968894 Val loss: 0.10017968714237213


  8%|▊         | 15/200 [18:20<3:44:13, 72.72s/it]

Train loss: 0.10014724619686603 Val loss: 0.1001376286149025


  8%|▊         | 16/200 [19:31<3:41:24, 72.20s/it]

Train loss: 0.10013211853802204 Val loss: 0.10012941062450409


  8%|▊         | 17/200 [20:43<3:39:48, 72.07s/it]

Train loss: 0.10012853741645814 Val loss: 0.10019101947546005


  9%|▉         | 18/200 [21:53<3:36:49, 71.48s/it]

Train loss: 0.10011186935007572 Val loss: 0.10011507570743561


 10%|▉         | 19/200 [23:04<3:34:37, 71.15s/it]

Train loss: 0.10010519064962864 Val loss: 0.10010001808404922


 10%|█         | 20/200 [24:14<3:32:33, 70.85s/it]

Train loss: 0.10009394735097885 Val loss: 0.10006275027990341


 10%|█         | 21/200 [25:26<3:32:37, 71.27s/it]

Train loss: 0.10008501745760441 Val loss: 0.1001538410782814


 11%|█         | 22/200 [26:41<3:34:44, 72.38s/it]

Train loss: 0.10008628517389298 Val loss: 0.10011796653270721


 12%|█▏        | 23/200 [27:53<3:33:21, 72.33s/it]

Train loss: 0.10007667355239391 Val loss: 0.10008400678634644


 12%|█▏        | 24/200 [29:04<3:30:57, 71.92s/it]

Train loss: 0.10006949342787266 Val loss: 0.10013096779584885


 12%|█▎        | 25/200 [30:15<3:28:33, 71.51s/it]

Train loss: 0.10006714016199111 Val loss: 0.10014507174491882


 13%|█▎        | 26/200 [31:25<3:26:35, 71.24s/it]

Train loss: 0.10006060525774955 Val loss: 0.10005339980125427


 14%|█▎        | 27/200 [32:36<3:24:37, 70.97s/it]

Train loss: 0.10005681216716766 Val loss: 0.10005435347557068


 14%|█▍        | 28/200 [33:46<3:22:58, 70.81s/it]

Train loss: 0.10005283877253532 Val loss: 0.10013274103403091


 14%|█▍        | 29/200 [34:57<3:21:43, 70.78s/it]

Train loss: 0.10005181729793548 Val loss: 0.10008374601602554


 15%|█▌        | 30/200 [36:07<3:20:15, 70.68s/it]

Train loss: 0.10005009546875954 Val loss: 0.1000744104385376


 16%|█▌        | 31/200 [37:18<3:19:08, 70.70s/it]

Train loss: 0.10005112625658512 Val loss: 0.10011939704418182


 16%|█▌        | 32/200 [38:29<3:18:01, 70.72s/it]

Train loss: 0.1000465627759695 Val loss: 0.10003972798585892


 16%|█▋        | 33/200 [39:39<3:16:51, 70.72s/it]

Train loss: 0.1000414527952671 Val loss: 0.10007276386022568


 17%|█▋        | 34/200 [40:50<3:15:37, 70.71s/it]

Train loss: 0.10004152581095696 Val loss: 0.10004635900259018


 18%|█▊        | 35/200 [42:01<3:14:25, 70.70s/it]

Train loss: 0.1000351656228304 Val loss: 0.10005035996437073


 18%|█▊        | 36/200 [43:14<3:15:24, 71.49s/it]

Train loss: 0.1000314425677061 Val loss: 0.10009683668613434


 18%|█▊        | 37/200 [44:28<3:16:16, 72.25s/it]

Train loss: 0.1000351656228304 Val loss: 0.10005872696638107


 19%|█▉        | 38/200 [45:42<3:16:26, 72.75s/it]

Train loss: 0.10002976693212987 Val loss: 0.10001841932535172


 20%|█▉        | 39/200 [46:54<3:14:54, 72.64s/it]

Train loss: 0.10002956837415695 Val loss: 0.10002893954515457


 20%|██        | 40/200 [48:06<3:12:40, 72.26s/it]

Train loss: 0.10002698712050914 Val loss: 0.10002604871988297


 20%|██        | 41/200 [49:18<3:11:09, 72.14s/it]

Train loss: 0.10002655759453774 Val loss: 0.1000305563211441


 21%|██        | 42/200 [50:30<3:09:46, 72.06s/it]

Train loss: 0.10002415105700493 Val loss: 0.10003393888473511


 22%|██▏       | 43/200 [51:42<3:08:35, 72.07s/it]

Train loss: 0.10002022609114647 Val loss: 0.10001363605260849


 22%|██▏       | 44/200 [52:54<3:07:22, 72.06s/it]

Train loss: 0.10001758113503456 Val loss: 0.10001439601182938


 22%|██▎       | 45/200 [54:06<3:06:03, 72.02s/it]

Train loss: 0.10001654922962189 Val loss: 0.10003311187028885


 23%|██▎       | 46/200 [55:18<3:05:06, 72.12s/it]

Train loss: 0.10000997334718705 Val loss: 0.10004004836082458


 24%|██▎       | 47/200 [56:31<3:04:12, 72.24s/it]

Train loss: 0.10000836662948132 Val loss: 0.09999635815620422


 24%|██▍       | 48/200 [57:43<3:03:05, 72.27s/it]

Train loss: 0.09999947845935822 Val loss: 0.09999372065067291


 24%|██▍       | 49/200 [58:57<3:03:19, 72.84s/it]

Train loss: 0.09999760612845421 Val loss: 0.09999609738588333


 25%|██▌       | 50/200 [1:00:10<3:02:28, 72.99s/it]

Train loss: 0.09998940899968148 Val loss: 0.09989220649003983


 26%|██▌       | 51/200 [1:01:24<3:02:05, 73.33s/it]

Train loss: 0.09996733851730824 Val loss: 0.09994678944349289


 26%|██▌       | 52/200 [1:02:38<3:01:21, 73.52s/it]

Train loss: 0.09992575794458389 Val loss: 0.10010343044996262


 26%|██▋       | 53/200 [1:03:51<2:59:24, 73.23s/it]

Train loss: 0.09986348487436772 Val loss: 0.09951908141374588


 27%|██▋       | 54/200 [1:05:04<2:57:46, 73.06s/it]

Train loss: 0.09970543757081032 Val loss: 0.09872541576623917


 28%|██▊       | 55/200 [1:06:17<2:56:34, 73.07s/it]

Train loss: 0.0991179183125496 Val loss: 0.0946933701634407


 28%|██▊       | 56/200 [1:07:31<2:56:17, 73.45s/it]

Train loss: 0.0979144211858511 Val loss: 0.09964293986558914


 28%|██▊       | 57/200 [1:08:46<2:56:24, 74.02s/it]

Train loss: 0.09735613763332367 Val loss: 0.09608739614486694


 29%|██▉       | 58/200 [1:10:03<2:56:48, 74.71s/it]

Train loss: 0.09722314961254597 Val loss: 0.09612216800451279


 30%|██▉       | 59/200 [1:11:18<2:55:48, 74.81s/it]

Train loss: 0.09726430699229241 Val loss: 0.094892717897892


 30%|███       | 60/200 [1:12:33<2:54:32, 74.80s/it]

Train loss: 0.09646637588739396 Val loss: 0.09675885736942291


 30%|███       | 61/200 [1:13:46<2:52:23, 74.41s/it]

Train loss: 0.09713325276970863 Val loss: 0.09723033756017685


 31%|███       | 62/200 [1:15:00<2:50:55, 74.32s/it]

Train loss: 0.09590908214449882 Val loss: 0.10244843363761902


 32%|███▏      | 63/200 [1:16:13<2:48:31, 73.80s/it]

Train loss: 0.09609788618981838 Val loss: 0.09650662541389465


 32%|███▏      | 64/200 [1:17:28<2:47:54, 74.08s/it]

Train loss: 0.09571351185441017 Val loss: 0.09345302730798721


 32%|███▎      | 65/200 [1:18:42<2:46:42, 74.09s/it]

Train loss: 0.0948974385857582 Val loss: 0.09237977117300034


 33%|███▎      | 66/200 [1:19:56<2:45:26, 74.08s/it]

Train loss: 0.09357935115695 Val loss: 0.09491458535194397


 34%|███▎      | 67/200 [1:21:08<2:42:49, 73.46s/it]

Train loss: 0.0920678809285164 Val loss: 0.09172277897596359


 34%|███▍      | 68/200 [1:22:21<2:41:11, 73.27s/it]

Train loss: 0.0910867266356945 Val loss: 0.08962676674127579


 34%|███▍      | 69/200 [1:23:35<2:40:34, 73.54s/it]

Train loss: 0.0888154398649931 Val loss: 0.08992360532283783


 35%|███▌      | 70/200 [1:24:48<2:38:56, 73.35s/it]

Train loss: 0.08756212592124939 Val loss: 0.08599342405796051


 36%|███▌      | 71/200 [1:25:59<2:36:12, 72.65s/it]

Train loss: 0.08582325428724288 Val loss: 0.08141154795885086


 36%|███▌      | 72/200 [1:27:10<2:34:13, 72.29s/it]

Train loss: 0.08348294384777546 Val loss: 0.09379773586988449


 36%|███▋      | 73/200 [1:28:22<2:32:58, 72.27s/it]

Train loss: 0.08317503184080124 Val loss: 0.08297254145145416


 37%|███▋      | 74/200 [1:29:37<2:33:02, 72.88s/it]

Train loss: 0.08270604126155376 Val loss: 0.08259501308202744


 38%|███▊      | 75/200 [1:30:51<2:32:51, 73.37s/it]

Train loss: 0.08181416504085064 Val loss: 0.08992074429988861


 38%|███▊      | 76/200 [1:32:07<2:32:52, 73.97s/it]

Train loss: 0.08022334016859531 Val loss: 0.07657494395971298


 38%|███▊      | 77/200 [1:33:20<2:31:34, 73.94s/it]

Train loss: 0.08014401085674763 Val loss: 0.0876498892903328


 39%|███▉      | 78/200 [1:34:34<2:30:15, 73.90s/it]

Train loss: 0.0791904978454113 Val loss: 0.08292856067419052


 40%|███▉      | 79/200 [1:35:47<2:28:27, 73.61s/it]

Train loss: 0.07857236750423909 Val loss: 0.07809840887784958


 40%|████      | 80/200 [1:37:01<2:27:14, 73.62s/it]

Train loss: 0.07738675251603126 Val loss: 0.08303812146186829


 40%|████      | 81/200 [1:38:13<2:25:18, 73.26s/it]

Train loss: 0.0769883755594492 Val loss: 0.0825415626168251


 41%|████      | 82/200 [1:39:27<2:24:16, 73.36s/it]

Train loss: 0.07723356187343597 Val loss: 0.06980647146701813


 42%|████▏     | 83/200 [1:40:41<2:23:44, 73.72s/it]

Train loss: 0.07580381855368615 Val loss: 0.07926183193922043


 42%|████▏     | 84/200 [1:41:54<2:21:37, 73.26s/it]

Train loss: 0.07450390122830867 Val loss: 0.07782445847988129


 42%|████▎     | 85/200 [1:43:06<2:19:46, 72.93s/it]

Train loss: 0.07372906282544137 Val loss: 0.05950700491666794


 43%|████▎     | 86/200 [1:44:18<2:18:05, 72.68s/it]

Train loss: 0.07428744472563267 Val loss: 0.09563163667917252


 44%|████▎     | 87/200 [1:45:32<2:17:33, 73.04s/it]

Train loss: 0.07445858046412468 Val loss: 0.07133326679468155


 44%|████▍     | 88/200 [1:46:46<2:17:06, 73.45s/it]

Train loss: 0.07272207364439964 Val loss: 0.09172207862138748


 44%|████▍     | 89/200 [1:48:00<2:16:06, 73.57s/it]

Train loss: 0.07364913485944272 Val loss: 0.07509603351354599


 45%|████▌     | 90/200 [1:49:15<2:15:36, 73.97s/it]

Train loss: 0.0721798524260521 Val loss: 0.08909673243761063


 46%|████▌     | 91/200 [1:50:27<2:13:29, 73.48s/it]

Train loss: 0.07281517125666141 Val loss: 0.08864570409059525


 46%|████▌     | 92/200 [1:51:43<2:13:48, 74.34s/it]

Train loss: 0.07085649892687798 Val loss: 0.06880342960357666


 46%|████▋     | 93/200 [1:52:58<2:12:40, 74.40s/it]

Train loss: 0.07094623260200024 Val loss: 0.09253158420324326


 47%|████▋     | 94/200 [1:54:11<2:10:56, 74.11s/it]

Train loss: 0.0702742662280798 Val loss: 0.08815746009349823


 47%|████▋     | 94/200 [1:55:26<2:10:10, 73.68s/it]

Train loss: 0.06986254192888737 Val loss: 0.08472307026386261





In [23]:
checkpoint = torch.load(f"embeddings/pretrained/EmbeddingModel")
models["EmbeddingModel"].load_state_dict(checkpoint["model_state_dict"])
optimizers["EmbeddingModel"].load_state_dict(checkpoint["optimizer_state_dict"]),

(None,)

In [24]:
checkpoint = torch.load(f"embeddings/pretrained/EmbeddingModelCosine")
models["EmbeddingModelCosine"].load_state_dict(checkpoint["model_state_dict"])
optimizers["EmbeddingModelCosine"].load_state_dict(checkpoint["optimizer_state_dict"]),

(None,)

In [25]:
random.seed(42)

In [None]:
# uris_train = df_train["track_uri"]
# songs_train = [dataset_train.get_anchor(uri, True) for uri in tqdm(uris_train)]

# with open("songs_train.pkl", "wb") as file:
#     pickle.dump(songs_train, file)

In [27]:
with open("songs_train.pkl", "rb") as file:
    songs_train = pickle.load(file)

In [None]:
# uris_test = df_test["track_uri"]
# songs_test = [dataset_test.get_anchor(uri, True) for uri in tqdm(uris_test)]

# with open("songs_test.pkl", "wb") as file:
#     pickle.dump(songs_test, file)

In [28]:
with open("songs_test.pkl", "rb") as file:
    songs_test = pickle.load(file)

In [29]:
songs_train_meta = [
    {
        "title": song[1]["track_name"],
        "artist_name": song[1]["artist_name"],
        "track_uri": song[1]["track_uri"],
        "data": song[0],
    }
    for song in tqdm(songs_train)
]


with open("songs_train_meta.pkl", "wb") as file:
    pickle.dump(songs_train_meta, file)

100%|██████████| 144052/144052 [00:01<00:00, 128060.52it/s]


In [30]:
with open("songs_train_meta.pkl", "rb") as file:
    songs_train_meta = pickle.load(file)

In [31]:
songs_test_meta = [
    {
        "title": song[1]["track_name"],
        "artist_name": song[1]["artist_name"],
        "track_uri": song[1]["track_uri"],
        "data": song[0],
    }
    for song in tqdm(songs_test)
]


with open("songs_test_meta.pkl", "wb") as file:
    pickle.dump(songs_test_meta, file)

100%|██████████| 36396/36396 [00:00<00:00, 112510.30it/s]


In [32]:
with open("songs_test_meta.pkl", "rb") as file:
    songs_test_meta = pickle.load(file)

In [33]:
for i, song in enumerate(tqdm(songs_train_meta)):
    with torch.no_grad():
        songs_train_meta[i]["embedding"] = models["EmbeddingModelCosine"](
            song["data"][0].to("cuda").unsqueeze(dim=0),
            song["data"][1].to("cuda").unsqueeze(dim=0),
            song["data"][2].to("cuda").unsqueeze(dim=0),
        )
        song["data"][0].to("cpu")
        song["data"][1].to("cpu")
        song["data"][2].to("cpu")
    
    
with open("songs_train_meta_embedding_cosine.pkl", "wb") as file:
    pickle.dump(songs_train_meta, file)

100%|██████████| 144052/144052 [01:25<00:00, 1694.58it/s]


In [34]:
for i, song in enumerate(tqdm(songs_test_meta)):
    with torch.no_grad():
        songs_test_meta[i]["embedding"] = models["EmbeddingModelCosine"](
            song["data"][0].to("cuda").unsqueeze(dim=0),
            song["data"][1].to("cuda").unsqueeze(dim=0),
            song["data"][2].to("cuda").unsqueeze(dim=0),
        )
        song["data"][0].to("cpu")
        song["data"][1].to("cpu")
        song["data"][2].to("cpu")

with open("songs_test_meta_embedding_cosine.pkl", "wb") as file:
    pickle.dump(songs_test_meta, file)

100%|██████████| 36396/36396 [00:22<00:00, 1650.28it/s]


In [37]:
songs_all_meta_embedding_cosine = songs_train_meta + songs_test_meta
with open("songs_all_meta_embedding_cosine.pkl", "wb") as file:
    pickle.dump(songs_all_meta_embedding_cosine, file)

In [None]:
df[df["track_uri"] == "spotify:track:45poGZbUrcBgMEed0xtV5W"].iloc[0]

In [None]:
songs_meta

In [None]:
song_to_test_1[2].shape

In [None]:
df.sample(10)[["artist_name", "track_name", "track_uri"]]