In [1]:
# Move up one directory; the imports in the next cell use project-local
# packages (config, models, utils), presumably resolved from there.
# NOTE(review): a relative `%cd ..` is not idempotent; re-running this cell moves up again.
%cd ..

/home/chris/DSR/plum_music


In [2]:
import pickle
from collections import Counter
from pathlib import Path

from config.config import get_constants_dict
from models.processing import get_encoding_for_genre
from utils.magenta_models_utils import get_model
from utils.magenta_note_seq_utils import get_sec_for_num_bars
from utils.data_utils import get_all_songs_from_genres_of_size
from utils.lakh_utils import get_midi_path, get_matched_midi_md5, get_msd_score_matches
from utils.msd_utils import get_artist

import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.manifold import TSNE
from scipy.spatial import distance

import note_seq as ns
from mido import MidiFile
from music21 import midi

In [3]:
# Project constants (looked up by key throughout the notebook) and the LMD
# metadata table that later cells filter by `msdID`.
constants = get_constants_dict()
meta = pd.read_csv(constants["LMD_METADATA_CSV_FILE"])

In [15]:
def load_encodings(genre="all"):
    """Collect the pickled genre labels, encodings and MSD ids from disk.

    Looks in ``<DATA_PATH>/encodings`` (``DATA_PATH`` is read from the
    module-level ``constants``) and visits the matching ``*.pkl`` files in
    sorted order. Each file is routed by the first keyword found in its stem,
    checked in this order: ``"genre"``, ``"encoding"``, ``"msd"``. Files whose
    stem contains none of them are opened but not read.

    Args:
        genre: ``"all"`` to read every ``*.pkl`` file, otherwise a string that
            is spliced into the glob ``*{genre}*.pkl``. This is a substring
            match on the file name, so a short name also selects every file
            whose name merely contains it.

    Returns:
        Tuple ``(genre_labels, encodings, msd_ids)`` of three lists.

    Raises:
        AssertionError: if the three lists end up with different lengths.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point this at trusted data.
    pattern = "*.pkl" if genre == "all" else f"*{genre}*.pkl"
    pkl_files = list(Path(constants["DATA_PATH"], "encodings").glob(pattern))
    pkl_files.sort()

    genre_labels, encodings, msd_ids = [], [], []
    # Ordered (keyword, target list) pairs; the first keyword found wins.
    routes = (
        ("genre", genre_labels),
        ("encoding", encodings),
        ("msd", msd_ids),
    )
    for pkl_file in pkl_files:
        with pkl_file.open('rb') as handle:
            for keyword, target in routes:
                if keyword in pkl_file.stem:
                    target.extend(pickle.load(handle))
                    break

    assert len(genre_labels) == len(encodings) == len(msd_ids)
    return genre_labels, encodings, msd_ids

        
def play_msd(msd_id):
    """Play the MIDI file associated with an MSD track id.

    Fetches the score matches, asks ``get_matched_midi_md5`` for the md5 that
    belongs to ``msd_id``, resolves the MIDI path, converts that file to a
    NoteSequence and returns the result of ``ns.play_sequence`` on it.
    """
    score_matches = get_msd_score_matches()
    midi_md5 = get_matched_midi_md5(msd_id, score_matches)
    midi_file = get_midi_path(msd_id, midi_md5)
    sequence = ns.midi_file_to_note_sequence(midi_file)
    return ns.play_sequence(sequence)


def tsne_embedd(
    encodings_arr,
    n_components=2,
    perplexity=30,
    early_exaggeration=12.0,
    learning_rate=200.0,
    n_iter=1000,
    n_iter_without_progress=300,
    metric='euclidean',
    init='random',
    random_state=42
):
    """Project encodings to a low-dimensional space with t-SNE.

    Args:
        encodings_arr: either a list/tuple of arrays, which is stacked
            row-wise with ``np.vstack``, or an already stacked array, which is
            handed to t-SNE unchanged.
        n_components, perplexity, early_exaggeration, learning_rate, n_iter,
        n_iter_without_progress, metric, init, random_state: forwarded
            unchanged to ``TSNE``.

    Returns:
        The result of ``TSNE(...).fit_transform`` on the stacked input.
    """
    if isinstance(encodings_arr, (list, tuple)):
        enc = np.vstack(encodings_arr)
    else:
        # This branch previously referenced an undefined name and raised
        # NameError for every non-list input.
        enc = encodings_arr

    # NOTE(review): recent scikit-learn releases rename TSNE's `n_iter` to
    # `max_iter`; confirm against the version pinned for this project.
    return TSNE(
        n_components=n_components,
        perplexity=perplexity,
        early_exaggeration=early_exaggeration,
        learning_rate=learning_rate,
        n_iter=n_iter,
        n_iter_without_progress=n_iter_without_progress,
        metric=metric,
        init=init,
        random_state=random_state
    ).fit_transform(enc)
    
    
def plot_embedding(embedded, labels, ids, df=None):
    """Scatter-plot a 2-D or 3-D embedding, coloured by genre label.

    Args:
        embedded: array whose second dimension is 2 or 3; column 0 is x,
            column 1 is y and (for 3-D) column 2 is z.
        labels: one colour label per embedded row.
        ids: one id per embedded row, used as the hover title and, when
            ``df`` is not given, to select metadata rows by ``msdID``.
        df: optional data frame supplying the ``artist`` and ``track`` hover
            columns. It is used as-is, so its rows must already follow the
            order of ``embedded``. When omitted, the LMD metadata CSV is
            loaded and re-ordered to match ``ids``.

    Returns:
        The figure returned by ``px.scatter`` or ``px.scatter_3d``.

    Raises:
        ValueError: if ``embedded`` has neither 2 nor 3 columns.
    """
    n_dims = embedded.shape[1]
    if n_dims not in (2, 3):
        # Previously this case fell through and returned None silently.
        raise ValueError(
            f"plot_embedding supports 2 or 3 components, got {n_dims}"
        )

    # `if not df` raises ValueError for a real DataFrame (ambiguous truth
    # value), so test explicitly for the "not provided" sentinel.
    if df is None:
        ids_list = list(ids)
        metadata = pd.read_csv(constants["LMD_METADATA_CSV_FILE"])
        # Filtering with isin keeps the CSV's row order; reindex so that row i
        # of the frame describes point i of the embedding. Ids missing from
        # the metadata become NaN rows instead of shifting the others.
        df = (
            metadata[metadata.msdID.isin(ids_list)]
            .drop_duplicates(subset="msdID")
            .set_index("msdID")
            .reindex(pd.Index(ids_list, name="msdID"))
            .reset_index()
        )

    shared_kwargs = dict(
        data_frame=df,
        x=embedded[:, 0],
        y=embedded[:, 1],
        color=labels,
        hover_name=ids,
        hover_data=["artist", "track"],
        labels={"color": "genre"},
    )
    if n_dims == 2:
        return px.scatter(**shared_kwargs)
    return px.scatter_3d(z=embedded[:, 2], **shared_kwargs)
      

## Load the encodings

In [16]:
# No genre argument, so the default "all" reads every pickle in the encodings folder.
genre_labels, encodings, msd_ids = load_encodings()

In [22]:
# Stack the list of encodings row-wise into a single array.
# NOTE(review): `encodings_arr` is not referenced by any later cell shown here.
encodings_arr = np.vstack(encodings)

In [19]:
# Metadata rows of the loaded tracks, reduced to id and genre, for the
# label check in the following cell.
a = meta[meta.msdID.isin(msd_ids)].loc[:, ["msdID", "mb_genre"]]


In [20]:
# Sanity check: print the pickled genre label next to the metadata genre of
# every loaded track so the two can be compared by eye.
# Map each id to its first position once instead of scanning the list with
# list.index for every row.
first_position = {}
for position, msd_id in enumerate(msd_ids):
    first_position.setdefault(msd_id, position)

for msd_id, mb_genre in zip(a.msdID, a.mb_genre):
    print(genre_labels[first_position[msd_id]], mb_genre)

rock rock
classic pop and rock classic pop and rock
italian italian
canadian canadian
rock and indie rock and indie
uk uk
soul and reggae soul and reggae
french french
british british
uk uk
folk folk
classic pop and rock classic pop and rock
pop and chart pop and chart
classic pop and rock classic pop and rock
classic pop and rock classic pop and rock
pop pop
classic pop and rock classic pop and rock
classic pop and rock classic pop and rock
rock and indie rock and indie
british british
country country
soul and reggae soul and reggae
classic pop and rock classic pop and rock
classic pop and rock classic pop and rock
classic pop and rock classic pop and rock
classic pop and rock classic pop and rock
italian italian
country country
country country
uk uk
french french
classical classical
rock rock
american american
classic pop and rock classic pop and rock
rock and indie rock and indie
pop and chart pop and chart
hard rock hard rock
classic pop and rock classic pop and rock
classic pop an

In [23]:
# Count the entries per genre label and list them most frequent first.
files_per_genre = Counter(genre_labels)
files_per_genre.most_common()

[('classic pop and rock', 76),
 ('uk', 61),
 ('british', 53),
 ('american', 26),
 ('folk', 22),
 ('french', 21),
 ('hip hop rnb and dance hall', 21),
 ('rock and indie', 21),
 ('country', 20),
 ('rock', 20),
 ('italian', 19),
 ('pop and chart', 17),
 ('australian', 14),
 ('alternative rock', 13),
 ('new wave', 13),
 ('pop', 13),
 ('irish', 9),
 ('canadian', 8),
 ('soul and reggae', 8),
 ('finnish', 7),
 ('german', 5),
 ('hard rock', 5),
 ('spanish', 5),
 ('european', 4),
 ('production music', 4),
 ('rnb', 3),
 ('dance and electronica', 2),
 ('swedish', 2),
 ('classical', 1),
 ('heavy metal', 1),
 ('progressive rock', 1)]

## Project the latent encodings with t-SNE

In [26]:
# Three t-SNE components, so the plotting helper takes its 3-D branch.
# `perplexity=50` and `init='pca'` override the helper's defaults (30, 'random');
# `metric='euclidean'` restates the default.
embedded = tsne_embedd(
    encodings,
    n_components=3,
    metric='euclidean',
    perplexity=50,
    init='pca'
)

In [27]:
# Scatter of the embedding coloured by genre label, with the MSD id as hover title.
plot_embedding(embedded, genre_labels, msd_ids)

In [None]:
# Resolve the MIDI path of one track id singled out as an outlier
# (presumably spotted in the plot above; TODO confirm).
# NOTE(review): `midi_path` is not used by any later cell shown here.
outlier = "TRICWAP128F42988B7"
midi_path= get_midi_path(
        outlier,
        get_matched_midi_md5(outlier, get_msd_score_matches())
    )

### Only take two distinct genres

In [None]:
# The helper globs `*country*.pkl`, i.e. every pickle whose file name contains "country".
country_g, country_e, country_m = load_encodings("country")

In [None]:
# Same substring glob as for country, here with the full hip hop genre name.
hh_g, hh_e, hh_m = load_encodings("hip hop rnb and dance hall")

In [None]:
# Embed both genres together in 2-D; country encodings come first, then hip hop.
emb = tsne_embedd([*country_e, *hh_e], n_components=2)

In [None]:
# Labels and ids are concatenated country-first, matching the order of the
# encodings that were passed to t-SNE.
plot_embedding(
    emb,
    [*country_g, *hh_g],
    [*country_m, *hh_m]
)

In [None]:
# One hand-picked track id per genre (presumably chosen from the plot above; TODO confirm).
country_sample = "TRACLRS12903CE9386"
hh_sample = "TRCZMLQ128F9307825"

#### Show title and artist for all of them

In [None]:
# Metadata columns to display for every track id of the two genres.
cols = ["artist", "track", "mb_genre", "md5"]
meta[meta.msdID.isin([*country_m, *hh_m])].loc[:, cols]

### Calculate genre means

In [None]:
# Centroid of each genre: mean over axis 0 of that genre's stacked encodings.
country_mean = np.mean(np.vstack(country_e), axis=0)
hh_mean = np.mean(np.vstack(hh_e), axis=0)

In [None]:
# Euclidean distance between the two genre centroids.
distance.euclidean(country_mean, hh_mean)

## Latent space shenanigans

In [None]:
# Offsets between the two centroids; each is the negation of the other.
# NOTE(review): only `country_to_hh` is used in the cells shown here.
country_to_hh = hh_mean - country_mean
hh_to_country = country_mean - hh_mean

In [None]:
# Load the model named by the NAME_MUSICVAE_MULTITRACK entry of the project constants.
model = get_model(constants["NAME_MUSICVAE_MULTITRACK"])

In [None]:
# NoteSequence built from the MIDI file resolved for the country sample id.
country_ns = ns.midi_file_to_note_sequence(
        get_midi_path(country_sample, get_matched_midi_md5(country_sample, get_msd_score_matches()))
    )

In [None]:
# `encode` returns three values; only the first is kept.
# NOTE(review): presumably the latent encoding of the input sequence; confirm against the model API.
country_sample_enc, _, _ = model.encode([country_ns])

In [None]:
# Shift the country sample's encoding by the country-to-hip-hop centroid offset.
hh_country = country_sample_enc + country_to_hh

In [None]:
# Decode the shifted encoding; the result is indexed like a sequence of note sequences below.
seq = model.decode(hh_country)

In [None]:
# Play the first decoded sequence using the fluidsynth synth.
ns.play_sequence(seq[0],synth=ns.fluidsynth)

### Split example tracks

In [None]:
# NOTE(review): looks like the duration in seconds of one bar of the country sample; confirm against the helper.
# `hop_size_country` is not used by any later cell shown here.
hop_size_country = get_sec_for_num_bars(country_ns, n_bars=1)

In [None]:
# Play every country track whose MIDI has at most 8 instruments, printing its
# id, artist and title first.
# NOTE(review): the limit of 8 is presumably what the multitrack model accepts; confirm.
# The score matches are fetched once up front rather than once per track
# (assumes `get_msd_score_matches()` returns the same data on every call).
score_matches = get_msd_score_matches()
for msd_id in country_m:
    sequence = ns.midi_file_to_note_sequence(
        get_midi_path(msd_id, get_matched_midi_md5(msd_id, score_matches))
    )
    n_instruments = len(ns.sequence_proto_to_pretty_midi(sequence).instruments)
    if n_instruments <= 8:
        print(meta[meta.msdID == msd_id].loc[:, ["msdID", "artist", "track"]])
        ns.play_sequence(sequence, synth=ns.fluidsynth)

In [None]:
# Hand-picked track ids (presumably hip hop examples; TODO confirm).
# NOTE(review): none of these is referenced by any later cell shown here.
hh1 = "TRZDMWV128E0796976"
hh2 = "TRKWKER128F1482409"
hh3 = "TRYFUNE12903CCDCD5"

In [None]:
# Hand-picked track ids (presumably country examples; TODO confirm).
# Only `c3` is used in the cells shown here.
c1 = "TRULTSR12903CBD13D"
c2 = "TRCRKXS12903CD21F9"
c3 = "TRQHQLN12903CBD237"

In [None]:
# NoteSequence built from the MIDI file resolved for `c3`.
c = ns.midi_file_to_note_sequence(
    get_midi_path(c3, get_matched_midi_md5(c3, get_msd_score_matches()))
)

In [None]:
# Split `c` using the value of `get_sec_for_num_bars(c, n_bars=1)` as the hop
# size and encode all pieces in one call.
# NOTE(review): `encode` is unpacked into three values earlier in this
# notebook, so `splits` presumably holds that whole tuple; confirm.
# `splits` is not used by any later cell shown here.
splits = model.encode(ns.split_note_sequence(
    c,
    get_sec_for_num_bars(c, n_bars=1)
))

In [None]:
# NOTE(review): `hh_enc` is not defined in any cell shown here, so this cell
# fails on a fresh kernel unless it is created elsewhere; confirm and add the
# missing step.
hh_dec = model.decode(hh_enc[1])

In [None]:
# Number of items returned by the decode call that produced `hh_dec`.
len(hh_dec)

In [None]:
# Play the first decoded item using the fluidsynth synth.
ns.play_sequence(hh_dec[0], synth=ns.fluidsynth)