In [None]:
# Move the kernel's working directory one level up.
# NOTE(review): presumably the notebook lives in a subfolder of the repository
# and this makes the project-local packages imported below (config, models,
# utils) resolvable — confirm. Re-running this cell moves up one more level.
%cd ..

In [None]:
import pickle
from pathlib import Path

from config.config import get_constants_dict
from models.processing import get_encoding_for_genre
from utils.magenta_models_utils import get_model
from utils.magenta_note_seq_utils import get_sec_for_num_bars
from utils.data_utils import get_all_songs_from_genres_of_size
from utils.lakh_utils import get_midi_path, get_matched_midi_md5, get_msd_score_matches
from utils.msd_utils import get_artist

import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.manifold import TSNE

import note_seq as ns
from magenta.models.music_vae.trained_model import NoExtractedExamplesError

In [None]:
def load_encodings(genre="all", encodings_path=None):
    """Load pickled genre labels, encodings and MSD ids from disk.

    Every matching ``*.pkl`` file is classified by its file stem: stems
    containing ``"genre"`` feed the label list, stems containing
    ``"encoding"`` feed the encoding list and stems containing ``"msd"`` feed
    the id list (checked in that order). Files matching none of these are
    ignored.

    Args:
        genre: ``"all"`` to read every pickle in the directory, otherwise a
            substring that the file name must contain.
        encodings_path: Directory holding the pickles. Defaults to
            ``<constants["DATA_PATH"]>/encodings``.

    Returns:
        A tuple ``(genre_labels, encodings, msd_ids)`` of three lists of equal
        length.

    Raises:
        AssertionError: If the three lists end up with different lengths.
    """
    if encodings_path is None:
        encodings_path = Path(constants["DATA_PATH"], "encodings")
    else:
        encodings_path = Path(encodings_path)

    pattern = "*.pkl" if genre == "all" else f"*{genre}*.pkl"
    # glob() yields files in arbitrary, file-system dependent order. Sorting
    # makes the result reproducible and keeps the three lists aligned across
    # genres (assumes the three files of a genre share one naming scheme —
    # TODO confirm against the code that writes them).
    encodings_files = sorted(encodings_path.glob(pattern))

    genre_labels, encodings, msd_ids = [], [], []
    for file in encodings_files:
        stem = file.stem
        if "genre" in stem:
            target = genre_labels
        elif "encoding" in stem:
            target = encodings
        elif "msd" in stem:
            target = msd_ids
        else:
            continue

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files produced by this project.
        with file.open('rb') as f:
            target.extend(pickle.load(f))

    assert len(genre_labels) == len(encodings) == len(msd_ids), (
        "genre labels, encodings and MSD ids differ in length: "
        f"{len(genre_labels)}, {len(encodings)}, {len(msd_ids)}"
    )
    return genre_labels, encodings, msd_ids

        
def play_msd(msd_id):
    """Play the MIDI file matched to the given MSD track id.

    Looks up the matched MIDI md5 for ``msd_id``, resolves the MIDI path,
    converts the file into a NoteSequence and hands it to
    ``ns.play_sequence``, whose result is returned.
    """
    score_matches = get_msd_score_matches()
    midi_md5 = get_matched_midi_md5(msd_id, score_matches)
    midi_path = get_midi_path(msd_id, midi_md5)
    sequence = ns.midi_file_to_note_sequence(midi_path)
    return ns.play_sequence(sequence)


def tsne_embedd(
    encodings_arr,
    n_components=2,
    perplexity=30,
    early_exaggeration=12.0,
    learning_rate=200.0,
    n_iter=1000,
    n_iter_without_progress=300,
    metric='euclidean',
    init='random',
    random_state=42
):
    """Embed encodings into a low-dimensional space with t-SNE.

    Args:
        encodings_arr: Either a list/tuple of arrays, which is stacked
            row-wise with ``np.vstack``, or an array-like that is handed to
            t-SNE unchanged.
        n_components, perplexity, early_exaggeration, learning_rate,
        n_iter_without_progress, metric, init, random_state: Forwarded
            unchanged to ``sklearn.manifold.TSNE``.
        n_iter: Maximum number of optimisation iterations. Forwarded as
            ``max_iter`` when the installed scikit-learn accepts that name,
            otherwise as ``n_iter``.

    Returns:
        The result of ``TSNE(...).fit_transform`` on the stacked encodings.
    """
    import inspect

    if isinstance(encodings_arr, (list, tuple)):
        enc = np.vstack(encodings_arr)
    else:
        enc = encodings_arr

    tsne_kwargs = dict(
        n_components=n_components,
        perplexity=perplexity,
        early_exaggeration=early_exaggeration,
        learning_rate=learning_rate,
        n_iter_without_progress=n_iter_without_progress,
        metric=metric,
        init=init,
        random_state=random_state,
    )
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` and later removed
    # the old name; pick whichever keyword the installed version understands.
    accepted = inspect.signature(TSNE).parameters
    iter_keyword = "max_iter" if "max_iter" in accepted else "n_iter"
    tsne_kwargs[iter_keyword] = n_iter

    return TSNE(**tsne_kwargs).fit_transform(enc)
    
    
def plot_embedding(embedded, labels, ids, df=None):
    """Scatter-plot a 2D or 3D embedding, coloured by label.

    Args:
        embedded: 2-D array with one row per point; the number of columns
            selects the plot type (2 -> ``px.scatter``, 3 -> ``px.scatter_3d``).
        labels: One label per point, used for the colour (legend title
            "genre").
        ids: One id per point, used as hover name.
        df: Optional DataFrame with ``artist`` and ``track`` columns, one row
            per point in the same order as ``embedded``. When omitted, the
            metadata CSV configured under ``LMD_METADATA_CSV_FILE`` is loaded
            and aligned to ``ids``.

    Returns:
        The plotly figure, or ``None`` when ``embedded`` has neither 2 nor 3
        columns.
    """
    # `not df` raises "truth value of a DataFrame is ambiguous" for any
    # DataFrame the caller passes in, so test for None explicitly.
    if df is None:
        metadata = pd.read_csv(constants["LMD_METADATA_CSV_FILE"])
        # plotly matches data_frame rows to the x/y arrays by position. A left
        # merge on the ids yields exactly one row per point, in the order of
        # `ids`; ids missing from the metadata get NaN hover values.
        df = pd.DataFrame({"msdID": list(ids)}).merge(
            metadata.drop_duplicates(subset="msdID"),
            on="msdID",
            how="left",
        )

    n_dims = embedded.shape[1]
    if n_dims not in (2, 3):
        # Unsupported dimensionality: nothing to plot.
        return None

    shared_kwargs = dict(
        data_frame=df,
        x=embedded[:, 0],
        y=embedded[:, 1],
        color=labels,
        hover_name=ids,
        hover_data=["artist", "track"],
        labels={"color": "genre"},
    )
    if n_dims == 2:
        return px.scatter(**shared_kwargs)
    return px.scatter_3d(z=embedded[:, 2], **shared_kwargs)
      

In [None]:
# Project configuration; load_encodings() and plot_embedding() read this
# global at call time, so this cell must run before they are called.
constants = get_constants_dict()
# Metadata table from the CSV configured under LMD_METADATA_CSV_FILE; later
# cells look up its msdID, artist and track columns.
meta = pd.read_csv(constants["LMD_METADATA_CSV_FILE"])

## Load data

In [None]:
# Each call returns (genre labels, encodings, MSD ids) collected from the
# pickle files whose name contains the given genre string.
hh_g, hh_e, hh_m = load_encodings("hip hop rnb and dance hall")
country_g, country_e, country_m = load_encodings("country")

## Latent Space

In [None]:
# Model selected by the configured multitrack MusicVAE name; its encode() and
# decode() methods are used in the cells below.
model = get_model(constants["NAME_MUSICVAE_MULTITRACK"])

## Splitting

In [None]:
# Print the metadata of, and play, every country track whose MIDI file has at
# most 8 instruments.
# NOTE(review): the limit of 8 presumably mirrors what the multitrack model
# can handle — confirm.
score_matches = get_msd_score_matches()  # loop-invariant: fetch once, not once per track
for msd_id in country_m:
    s = ns.midi_file_to_note_sequence(
        get_midi_path(msd_id, get_matched_midi_md5(msd_id, score_matches))
    )
    pm = ns.sequence_proto_to_pretty_midi(s)
    if len(pm.instruments) <= 8:
        print(meta[meta.msdID == msd_id].loc[:, ["msdID", "artist", "track"]])
        ns.play_sequence(s, synth=ns.fluidsynth)

In [None]:
# Hard-coded track ids; the `hh` prefix suggests hip hop examples — TODO confirm.
hh1 = "TRZDMWV128E0796976"
hh2 = "TRKWKER128F1482409"
hh3 = "TRYFUNE12903CCDCD5"

In [None]:
# Hard-coded MSD track ids (c2 is loaded below); the `c` prefix suggests
# country examples — TODO confirm.
c1 = "TRULTSR12903CBD13D"
c2 = "TRCRKXS12903CD21F9"
c3 = "TRQHQLN12903CBD237"

In [None]:
# Load the MIDI file matched to track c2 as a NoteSequence.
c = ns.midi_file_to_note_sequence(
    get_midi_path(c2, get_matched_midi_md5(c2, get_msd_score_matches()))
)

In [None]:
# Cut the sequence into consecutive segments. The segment length in seconds
# comes from get_sec_for_num_bars — presumably the duration of one bar of `c`;
# confirm against the helper.
splits = ns.split_note_sequence(
    c,
    get_sec_for_num_bars(c, n_bars=1)
)

## Encoding

In [None]:
# Encode every segment with the model. Segments from which the model cannot
# extract an example are stored unencoded, so `encs` keeps one entry per
# segment in `splits`.
encs = []
for split in splits:
    try:
        encoded = model.encode([split])
    except NoExtractedExamplesError:
        encoded = split  # placeholder: the raw segment itself
    encs.append(encoded)

In [None]:
# Show which entries are encode() results and which are raw segments that
# could not be encoded.
[type(e) for e in encs]

## Decoding

In [None]:
# Decode every encoded segment back into a NoteSequence.
# model.encode() returns the triple (z, mu, sigma) while model.decode() expects
# only the batch of latent vectors z and returns a list of NoteSequences, so z
# is unpacked before decoding and the single decoded sequence is taken out of
# the returned list. Entries that are not encodings (raw segments kept when
# encoding failed) are passed through unchanged, so `decs` holds one
# NoteSequence per segment.
decs = []
for enc in encs:
    if isinstance(enc, (list, tuple)):
        z = enc[0]
    elif isinstance(enc, np.ndarray):
        z = enc
    else:
        decs.append(enc)
        continue
    decs.append(model.decode(z)[0])

In [None]:
# Play the first entry of `decs` through the fluidsynth synthesizer.
ns.play_sequence(decs[0], synth=ns.fluidsynth)