# Beat browser demo

In [None]:
import warnings
warnings.filterwarnings("ignore")  # Disable warning messages (Spleeter generates a lot of those :/)

# Music
from autodj.dj import songcollection, tracklister
from autodj.annotation.beat import beattracker
from autodj.dj.annotators import wrappers
import librosa
import numpy as np
from spleeter.separator import Separator as Spleeter

# I/O
import csv
import ipywidgets as w
from ipyfilechooser import FileChooser
from IPython.display import Audio
import matplotlib.pyplot as plt
import os
import soundfile as sf

# Utilities
import demo_features as _f
import demo_util as _u
from util_nmf_experiment import evaluation
from util_nmf_experiment import odf_op
from util_nmf_experiment import spectrogram_op
from util_nmf_experiment import files

# Plotting
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

**Enter the path to your music folder here:**

In [None]:
# Start directory for both FileChooser widgets below (song folder and tracklist).
music_home_dir = 'PATH/TO/MUSIC/HOME/DIR'

Song collection loading and annotation widgets:

In [None]:
# Song collection configured with these annotation wrappers (beats, onset
# curve, downbeats, structural segmentation).
sc = songcollection.SongCollection([
    wrappers.BeatAnnotationWrapper(),
    wrappers.OnsetCurveAnnotationWrapper(),
    wrappers.DownbeatAnnotationWrapper(),
    wrappers.StructuralSegmentationWrapper(),
])

# Layout
text_output_layout = w.Layout(margin='2px 30px 0px 50px')
button_layout_fillwidth = w.Layout(width='90%',)

# Title for the text output widget
text_output_title = w.HTML(value='<b>Log</b>', layout=text_output_layout)

# Text output widget for info messages (the shared "Log" pane the callbacks print to)
text_output = w.Output(layout=text_output_layout)

# Create a FileChooser widget (it is displayed later, inside ui_songcollection_loader)
fc = FileChooser(music_home_dir,
                 select_desc='Load directory', change_desc='Load directory',)
fc.title = '<b>Choose the music folder to load directory</b>'

# Button for loading folder
# NOTE(review): this reaches into the FileChooser's internal child layout
# (children[2].children[0]); verify it still points at the select button
# for the installed ipyfilechooser version.
load_button = fc.children[2].children[0]  # Reuse the load button from the file chooser
load_button.description = 'Load directory'

# Button for loading annotations and features (handled by annotate_button_callback)
annotate_button = w.Button(description='Load song annotations', 
                           button_style='warning', icon='hourglass-half', layout=button_layout_fillwidth)

# Button for clearing song collection
clear_sc_button = w.Button(description='Unload all songs', layout=button_layout_fillwidth)

# Button for clearing the output widget text
clear_text_button = w.Button(description='Clear text output', layout=button_layout_fillwidth)

# Checkbox: when ticked, annotate_button_callback also annotates the songs
# that are not annotated yet
annotate_unannotated_checkbox = w.Checkbox(value=False, 
    description='Analyse new songs (warning: might take long!)',
    layout=button_layout_fillwidth)

# Callback for load button
def load_button_callback(b):
    """Load the directory currently selected in ``fc`` into ``sc``.

    On the first load of a directory, a summary of the collection is printed
    to ``text_output``. If nothing is selected or loading fails, an error
    message is printed there instead.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    selected = fc.selected_path
    if selected is None:
        # Nothing has been chosen in the file chooser yet.
        with text_output:
            print(f'Select a valid directory first! {selected}')
        return
    try:
        is_new_dir = selected not in sc.directories
        sc.load_directory(selected)
    except Exception as err:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are not
        # swallowed; the reason is shown so real failures are not hidden.
        with text_output:
            print(f'Select a valid directory first! {selected} ({err})')
        return
    if is_new_dir:
        with text_output:
            print(f'Loaded {selected}...')
            print(f'Song collection contains {len(sc.get_annotated())} annotated songs ({len(sc.get_unannotated())} unannotated)')
load_button.on_click(load_button_callback)

# Callback for clearing song collection
def clear_sc_button_callback(b):
    """Unload every song from ``sc`` and log the resulting collection size.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    sc.clear()
    with text_output:
        print('Cleared song collection')
        counts = (len(sc.get_annotated()), len(sc.get_unannotated()))
        print('Song collection contains {} annotated songs ({} unannotated)'.format(*counts))


clear_sc_button.on_click(clear_sc_button_callback)

# Callback for clearing text output
def clear_text_button_callback(b):
    """Erase everything printed so far in the ``text_output`` log pane.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    log_pane = text_output
    log_pane.clear_output()


clear_text_button.on_click(clear_text_button_callback)


# Width (in grid columns) of each panel drawn by display_song_grid; also passed
# as ``grid_note`` to _u.odf_to_beat_grid when drum grids are built.
grid_note = 16

# Marker style for songs that are not highlighted in the scatter plot.
DEFAULT_MARKER_COLOR = 'SteelBlue'
DEFAULT_MARKER_SIZE = 10

# Per-song feature extractors; each one is called with an annotated song object.
feature_functions = [
    _f.low_density,
    _f.mid_density,
    _f.hi_density,
    _f.low_sync_pattern,
    _f.mid_sync_pattern,
    _f.hi_sync_pattern,
]

# Feature name -> collected values (a list while loading, a stacked array afterwards).
feature_values = dict((fn.__name__, []) for fn in feature_functions)

def _load_song_drum_data(s, force_annotation=False):
    """Load (or compute and cache) the onset curves and drum grid of song ``s``.

    Cache files live in an ``_annot_auto`` folder next to the audio file. The
    song is opened, and these attributes are set on it: ``filename_odf``,
    ``filename_spleeter``, ``simple_name``, ``drum_grid`` and ``odfs``.

    Args:
        s: An annotated song from ``sc``.
        force_annotation: Recompute even if the cache files already exist.
    """
    s.open()
    filename_short = files.filename_to_compact_string(s.filepath)
    # os.path.join avoids building the absolute path '/_annot_auto/...' when
    # the song path has no directory component.
    annot_dir = os.path.join(os.path.dirname(s.filepath), '_annot_auto')
    os.makedirs(annot_dir, exist_ok=True)
    filename_onsets = os.path.join(annot_dir, f'{filename_short}_onsets_aligned.npy')
    filename_grid = os.path.join(annot_dir, f'{filename_short}_drum_grid.npy')

    if force_annotation or not os.path.exists(filename_onsets):
        # Start at the downbeat 8 positions after tracklister.getHAfter(s, 0).
        # Presumably the helper caches its result at filename_onsets, since it
        # is np.load-ed from there on later runs.
        start = s.downbeats[tracklister.getHAfter(s, 0) + 8]
        odfs = _u.annotate_song_odf(s.filepath, s.tempo, start, filename_onsets)
    else:
        odfs = np.load(filename_onsets)

    if force_annotation or not os.path.exists(filename_grid):
        grid, _ = _u.odf_to_beat_grid(np.vstack(odfs),
                                      tempo=175,  # ODFS have been stretched!
                                      hop_size=512, sr=44100, grid_note=grid_note,)
        np.save(filename_grid, grid)
    else:
        grid = np.load(filename_grid)

    s.filename_odf = filename_onsets
    s.filename_spleeter = filename_onsets + '.wav'
    s.simple_name = filename_short
    s.drum_grid = grid
    s.odfs = odfs
    # NOTE(review): the song is left open (a close() call was commented out);
    # display_song_grid / display_song_audio call open() again anyway.


def annotate_button_callback(b):
    """Optionally annotate new songs, then (re)compute all per-song features.

    Unannotated songs are annotated only when ``annotate_unannotated_checkbox``
    is ticked. Afterwards ``feature_values`` maps each feature function name to
    an array stacked with ``np.vstack``. The array has one entry per song in
    ``sc.get_annotated()`` order, or is an empty list if there are no
    annotated songs.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    # Set to True to ignore the cached ODF / drum-grid files and recompute them.
    force_annotation = False

    # Start from scratch: the loop below visits *every* annotated song, so
    # appending to earlier results would duplicate rows and make the features
    # longer than sc.get_annotated().
    for name in feature_values:
        feature_values[name] = []

    text_output.clear_output()
    with text_output:
        unannotated = sc.get_unannotated()
        if len(unannotated) > 0 and annotate_unannotated_checkbox.value:
            print('Annotating with beats and downbeats...')
            for s in _u.log_progress(unannotated):
                s.annotate()
        elif len(unannotated) > 0:
            print(f'Skipping {len(unannotated)} unannotated songs '
                  '(tick "Analyse new songs" to annotate them).')
        else:
            print('All songs are already annotated!')

        print('Loading features...')
        for s in _u.log_progress(sc.get_annotated()):
            _load_song_drum_data(s, force_annotation)
            for feature in feature_functions:
                feature_values[feature.__name__].append(feature(s))

        # np.vstack raises on an empty list; leave it empty when there are no
        # songs (scatter_plot_songs checks for len() == 0).
        for name in feature_values:
            if len(feature_values[name]) > 0:
                feature_values[name] = np.vstack(feature_values[name])
annotate_button.on_click(annotate_button_callback)

# Left column: file chooser and action widgets; right column: the log pane.
ui_songcollection_loader = w.HBox(children=[
    w.VBox(children=(fc, annotate_button, annotate_unannotated_checkbox,
                     clear_sc_button, clear_text_button)),
    w.VBox(children=(text_output_title, text_output)),
])

Song collection visualization code:

In [None]:
def display_song_grid(song, n_panels=4):
    """Render consecutive ``grid_note``-column slices of ``song.drum_grid``.

    Each slice is drawn with ``imshow`` in its own subplot, side by side, into
    a fresh ``ipywidgets.Output`` widget that is returned to the caller.

    Args:
        song: An annotated song whose ``drum_grid`` attribute was set while
            loading features (see ``annotate_button_callback``).
        n_panels: Number of slices to draw; defaults to 4.

    Returns:
        ipywidgets.Output containing the matplotlib figure.
    """
    output = w.Output()

    # NOTE(review): drum_grid is an in-memory attribute; confirm whether
    # open() is still needed here (kept to preserve the original behavior).
    song.open()

    with output:
        plt.figure(figsize=(18, 16))
        for i in range(n_panels):
            columns = slice(i * grid_note, (i + 1) * grid_note)
            plt.subplot(1, n_panels, i + 1)
            plt.imshow(song.drum_grid[:, columns], cmap='magma', interpolation='none')
            plt.axis('off')
        plt.show()

    return output

def display_song_audio(song):
    """Return two side-by-side audio players for a short excerpt of ``song``.

    The excerpt lasts 16 beats at ``song.tempo``. The left player plays the
    original file from the downbeat 8 positions after
    ``tracklister.getHAfter(song, 0)``. The right player plays the same
    duration from the start of ``song.filename_spleeter``.
    """
    def _player(samples, rate):
        # Wrap an IPython Audio player in its own Output widget.
        box = w.Output()
        with box:
            display(Audio(samples, rate=rate))
        return box

    song.open()
    start_time = song.downbeats[tracklister.getHAfter(song, 0) + 8]
    excerpt_seconds = (60.0 / song.tempo) * 4 * 4

    original, original_sr = librosa.load(song.filepath, sr=44100,
                                         offset=start_time, duration=excerpt_seconds)
    original_player = _player(original, original_sr)

    separated, separated_sr = librosa.load(song.filename_spleeter, sr=44100,
                                           offset=0, duration=excerpt_seconds)
    separated_player = _player(separated, separated_sr)

    return w.HBox((original_player, separated_player))

In [None]:
# Layout
layout_align_bot = w.Layout(align_content='flex-end')

# Output widgets
output_click = w.Output()
output_hover = w.Output()
tracklist_output = w.Output()

# Selection widget for features
features_selector = w.SelectMultiple(
    options=feature_values.keys(),
    description='X-axis',
    disabled=False
)

# Perplexity control for tSNE
# TODO give this a better, more interpretable name (without having to know what tSNE does.)
perplexity_slider = w.IntSlider(
    value=25,
    min=5,
    max=50.0,
    description='tSNE perplexity:',
    continuous_update=False,
)

# Checkbox for scattering: should the plot wiggle all points by a small, random amount?
# This is useful if some datapoints are on exactly the same location (exactly the same feature values).
random_scatter_checkbox = w.Checkbox(
    value=False, 
    description='Random scatter',
)

# Checkbox: perform tSNE or not?
tsne_checkbox = w.Checkbox(
    value=True, 
    description='Perform tSNE',
)

# Button for loading a tracklist and highlighting the tracks
highlight_tracklist_button = w.Button(description='Highlight tracklist', icon='fa-list', layout=layout_align_bot)

# Button for exporting the song names of all selected points
export_tracklist_button = w.Button(description='Export selection to tracklist', icon='fa-download', layout=layout_align_bot)

# File chooser widget to select the tracklist to highlight
fc_tracklist = FileChooser(music_home_dir,
                 select_desc='Load tracklist', change_desc='Load tracklist',)

def _project_features_2d(X, perform_tsne, perplexity):
    """Map the feature matrix ``X`` (one row per song) to 2-D plot coordinates.

    With ``perform_tsne`` (and at least two songs), the standardized features
    are embedded with t-SNE. The perplexity is clamped below the number of
    samples because scikit-learn rejects ``perplexity >= n_samples``. Otherwise
    the first two columns are used directly, and a single column is plotted
    against ``y = 0``.

    Returns:
        Tuple ``(x, y)`` of new 1-D float arrays of length ``X.shape[0]``.
    """
    n_songs = X.shape[0]
    if perform_tsne and n_songs > 1:
        X = StandardScaler().fit_transform(X)
        X = TSNE(n_components=2, perplexity=min(perplexity, n_songs - 1)).fit_transform(X)
        x, y = X[:, 0], X[:, 1]
    elif X.shape[1] >= 2:
        x, y = X[:, 0], X[:, 1]
    else:
        x, y = X[:, 0], np.zeros(n_songs)
    # flatten() returns copies, so the caller's in-place jitter never writes into X;
    # casting to float keeps that jitter valid for integer features.
    return np.asarray(x, dtype=float).flatten(), np.asarray(y, dtype=float).flatten()


def scatter_plot_songs(
    x_axis=None,
    random_scatter=None,
    perform_tsne=None,
    perplexity=None,
):
    """Display an interactive scatter plot with one point per annotated song.

    Clicking a point shows its audio excerpts and drum grid. The tracklist
    buttons are rebound to this figure on every redraw.

    Args:
        x_axis: Names of the ``feature_values`` entries to concatenate.
        random_scatter: Add a tiny random offset to every point.
        perform_tsne: Embed the features with t-SNE instead of plotting
            the first two feature columns.
        perplexity: t-SNE perplexity (clamped below the number of songs).
    """
    all_songs = sc.get_annotated()

    if not x_axis:
        display(w.Label(value="Please select some x-axis features to plot."))
        return
    if feature_values is None or len(feature_values[x_axis[0]]) == 0:
        display(w.Label(value="No songs to be displayed, please load some songs first!"))
        return

    # np.hstack requires a sequence; passing a generator is deprecated/rejected by NumPy.
    # TODO make sure that each feature array is indeed a (N, ?) float numpy array (during feature loading)
    X = np.hstack([feature_values[name] for name in x_axis])
    if X.shape[0] != len(all_songs):
        # Songs were loaded/unloaded after the features were computed; the rows no
        # longer line up with all_songs (labels and click indices would be wrong).
        display(w.Label(value="Song features are out of date, please press 'Load song annotations' again."))
        return

    # TODO scale invariance of TSNE?
    x, y = _project_features_2d(X, perform_tsne, perplexity)

    # Scatter x and y just slightly to avoid overlapping points.
    if random_scatter:
        x += 0.0005 * np.random.rand(len(x))
        y += 0.0005 * np.random.rand(len(y))

    labels = [song.simple_name for song in all_songs]
    fig = go.FigureWidget([go.Scatter(x=x, y=y, mode='markers')])
    scatter = fig.data[0]
    scatter.hovertext = labels
    scatter.marker.color = [DEFAULT_MARKER_COLOR] * len(labels)
    scatter.marker.size = [DEFAULT_MARKER_SIZE] * len(labels)
    fig.update_layout(
        font=dict(
            family="Courier New, monospace",
            size=18,
            color="#7f7f7f"
        )
    )

    def scatter_on_click(trace, points, state):
        if not points.point_inds:  # guard against an empty point list
            return
        output_click.clear_output()
        song = all_songs[points.point_inds[0]]
        with output_click:
            audio_display = display_song_audio(song)
            grid_display = display_song_grid(song)
            display(w.VBox((audio_display, grid_display)))
    scatter.on_click(scatter_on_click)

    def scatter_highlight_tracklist(b):
        # Enlarge and recolor every plotted song listed in the chosen tracklist CSV
        # (first column = song file path).
        tracklist_output.clear_output()
        tracklist_path = fc_tracklist.selected
        if tracklist_path is None:
            with tracklist_output:
                print('Select a tracklist file first!')
            return
        with tracklist_output:
            print(f'Loading songs from tracklist {tracklist_path}')

        with open(tracklist_path, newline='') as csvfile:
            tracklist_songs = {files.filename_to_compact_string(row[0])
                               for row in csv.reader(csvfile) if row}

        selection = [i for i, song in enumerate(all_songs) if song.simple_name in tracklist_songs]
        if not selection:
            return

        colors = list(scatter.marker.color)
        sizes = list(scatter.marker.size)
        for i in selection:
            colors[i] = 'ForestGreen'
            sizes[i] = 20
        with fig.batch_update():
            scatter.marker.color = colors
            scatter.marker.size = sizes

    # Drop handlers bound to figures from earlier redraws (this uses the private
    # _click_handlers attribute of ipywidgets), so the buttons act only on this figure.
    highlight_tracklist_button._click_handlers.callbacks = []
    highlight_tracklist_button.on_click(scatter_highlight_tracklist)

    def scatter_export_highlights(b):
        # Print the file paths of the songs currently selected in the plot.
        selection = scatter.selectedpoints
        if selection is not None and len(selection) > 0:
            tracklist_output.clear_output()
            with tracklist_output:
                for i in selection:
                    print(all_songs[i].filepath)
    export_tracklist_button._click_handlers.callbacks = []
    export_tracklist_button.on_click(scatter_export_highlights)

    display(w.VBox([fig, output_click, output_hover]))
            
    
# Bind each scatter_plot_songs argument to the widget that controls it; the
# plot is re-rendered into ``out`` whenever one of these widgets changes.
out = w.interactive_output(
    scatter_plot_songs,
    dict(
        x_axis=features_selector,
        random_scatter=random_scatter_checkbox,
        perform_tsne=tsne_checkbox,
        perplexity=perplexity_slider,
    ),
)

# Control rows: feature selection, embedding options, tracklist tools.
ui = w.VBox(children=[
    w.HBox(children=[features_selector]),
    w.HBox(children=[perplexity_slider, random_scatter_checkbox, tsne_checkbox]),
    w.HBox(children=[fc_tracklist, highlight_tracklist_button, export_tracklist_button]),
])

ui_scatter_plot = w.VBox(children=[ui, out, tracklist_output])

In [None]:
# Two-tab app: loading/annotating songs first, then the scatter-plot browser.
tab = w.Tab(children=[ui_songcollection_loader, ui_scatter_plot])
for _tab_index, _tab_title in enumerate(('Loading songs', 'Music browser')):
    tab.set_title(_tab_index, _tab_title)
display(tab)