# RAVE: Latent Space Exploration

In [1]:
import plotly.graph_objs as go
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import base64
import io
import cv2
import numpy as np
import fiftyone.brain as fob
import pandas as pd
import fiftyone as fo
import librosa
import os
import glob
import torch
import torchaudio
import IPython.display as ipd
import tempfile
import soundfile as sf
import rave
from pydub import AudioSegment
from pydub.playback import play

Migrating database to v0.23.2




## Common functions

In [18]:
def get_relative_path(absolute_path):
    """Return `absolute_path` expressed relative to the current working directory."""
    return os.path.relpath(absolute_path, start=os.getcwd())

def find_wav_files(*folder_paths):
    """Recursively collect the paths of all `*.wav` files under the given folders.

    Args:
        *folder_paths: Any number of directory paths. Falsy entries (``None``,
            ``''``) are skipped.

    Returns:
        list[str]: Matching file paths, in the order the folders are walked.
    """
    wav_files = []

    for folder_path in folder_paths:
        if not folder_path:
            continue
        for root, _, _ in os.walk(folder_path):
            # Escape the directory part so glob metacharacters in folder names
            # (e.g. "Kit [v1]") are matched literally instead of being treated
            # as a pattern, which would silently drop those folders.
            pattern = os.path.join(glob.escape(root), '*.wav')
            wav_files.extend(glob.glob(pattern))

    return wav_files

def trim_audio(audio, sr, ti, tf, mono=True):
    """Cut `audio` to the interval [ti, tf), where both bounds are scaled by `sr`.

    Args:
        audio: Array of samples. Sliced along the first axis when `mono` is
            true, otherwise along the second axis.
        sr: Factor that converts `ti`/`tf` into sample indices (samples per
            time unit).
        ti: Start of the interval.
        tf: End of the interval.
        mono: Selects which axis holds the samples (see `audio`).

    Returns:
        The sliced view of `audio`.
    """
    # Slice indices must be integers; fractional bounds such as 0.5 would
    # otherwise produce float indices and raise a TypeError.
    start = int(round(ti * sr))
    stop = int(round(tf * sr))
    if mono:
        return audio[start:stop]
    return audio[:, start:stop]

def read_audio(file_path, trim_interval=None, mono=True, print_it=False):
    """Load an audio file with librosa, optionally trimming it.

    Args:
        file_path: Path of the file handed to `librosa.load`.
        trim_interval: Optional pair; its first two items are passed to
            `trim_audio` as the start and end of the interval to keep.
        mono: Forwarded to `librosa.load`. When false and the loaded signal is
            one-dimensional, the signal is duplicated into two identical rows.
        print_it: When true, print the resulting array shape and sample rate.

    Returns:
        tuple: `(audio, sr)` as returned by librosa, after the optional
        duplication and trimming.
    """
    audio, sr = librosa.load(file_path, mono=mono)

    # Stereo was requested but the file only yielded a single channel
    wants_stereo = not mono
    if wants_stereo and audio.ndim == 1:
        audio = np.asarray((audio, audio))

    if trim_interval is not None:
        audio = trim_audio(audio, sr, trim_interval[0], trim_interval[1], mono)

    if print_it:
        print(audio.shape)
        print(sr)

    return audio, sr

def remix_audio(left_audio_array, right_audio_array, sample_rate=44100):
    """Combine two channel arrays into one two-column array.

    Both inputs are cut to the length of the shorter one, then stacked as
    columns (left first, right second). `sample_rate` is accepted but not used.
    """
    n_samples = min(len(left_audio_array), len(right_audio_array))
    channels = (left_audio_array[:n_samples], right_audio_array[:n_samples])
    return np.column_stack(channels)

def remove_common_part(file_names):
    """Strip the longest shared leading string from every name.

    The prefix is computed character by character with `os.path.commonprefix`,
    so it may end in the middle of a path component.
    """
    prefix_length = len(os.path.commonprefix(file_names))
    stripped_names = []
    for name in file_names:
        stripped_names.append(name[prefix_length:])
    return stripped_names

## Load the TorchScript model

In [54]:
# Name of the TorchScript export to load from ../models/<model_name>.ts.
# Other exports previously selected in this cell (swap the name to use one):
#   'drumkit_v1', 'GMDrums_v3_29-09_3M_streaming', 'percussion'
model_name = 'darbouka_onnx'
model = torch.jit.load(f'../models/{model_name}.ts')

## Process sounds

In [55]:
# Folders searched for audio material
drumhits_folder = '../data/WAV/Individual Hits'
taps_folder = '../data/finger_tapping'

# Gather every .wav file found below either folder
sample_folders = (drumhits_folder, taps_folder)
samples = find_wav_files(*sample_folders)

print(f'Found a total of {len(samples)} samples:')

Found a total of 1611 samples:


In [56]:
# Length thresholds, in samples. read_audio() calls librosa.load() without an
# explicit `sr`, so (with librosa's default of 22050 Hz) 4410 samples is about
# 200 ms and 88200 samples about 4 s.
# NOTE(review): the original comment said "100ms", which only holds at 44.1 kHz — confirm the intended rate.
min_length = 4410
desired_length = 88200

sample_latents = []
samples_filtered = []

for sample_path in samples:
    sample_audio, wav_sr = read_audio(sample_path)
    if sample_audio.shape[0] < min_length:
        continue  # ignore files that are too short
    samples_filtered.append(sample_path)  # keep paths aligned with sample_latents

    # Zero-pad short files up to desired_length; longer files are left untouched
    padding_width = max(0, desired_length - len(sample_audio))
    sample_audio = np.pad(sample_audio, (0, padding_width), mode='constant', constant_values=0)

    with torch.no_grad():
        x = torch.from_numpy(sample_audio).reshape(1, 1, -1)
        z = model.encode(x)  # encode the audio into the model's latent space
        latent_space_matrix = torch.squeeze(z, 0)
    sample_latents.append(latent_space_matrix)

if not sample_latents:
    raise ValueError('No sample passed the minimum-length filter; nothing to encode.')

# Make sure all latents share the shape of the first one. A keep-mask is built
# and both lists are filtered together: deleting by index while iterating
# (as done previously) shifts the remaining indices and can remove the wrong
# entries or run past the end of the lists.
reference_shape = sample_latents[0].shape
keep_mask = [tensor.shape == reference_shape for tensor in sample_latents]

if all(keep_mask):
    print('All tensors have same shape.')
else:
    print('There were some tensors that happen to have different shape:')
    for sample_path, keep in zip(samples_filtered, keep_mask):
        if not keep:
            print(f'Removing sample {sample_path} and its corresponding latent space matrix.')
    samples_filtered = [path for path, keep in zip(samples_filtered, keep_mask) if keep]
    sample_latents = [tensor for tensor, keep in zip(sample_latents, keep_mask) if keep]

tensor_shapes = [tensor.shape for tensor in sample_latents]

sample_latents_np = np.array(sample_latents)  # convert sample_latents to a np array

All tensors have same shape.


In [57]:
# One row per sample: each latent matrix is flattened into a 1-D vector
flattened_tensors = []
for latent in sample_latents:
    flattened_tensors.append(latent.flatten())

embeddings = np.vstack(flattened_tensors)
print(embeddings.shape)

(1236, 176)


## Use t-SNE to reduce the embeddings to 2 dimensions

In [58]:
audio_paths = samples_filtered
audio_names = remove_common_part(samples_filtered)

dataset = fo.Dataset()  # FiftyOne dataset that will hold one sample per audio file

# (substring, label) pairs checked in this order; the first substring found in
# the shortened file name decides the label, anything else becomes 'other'
label_markers = (('BD', 'BD'), ('SD', 'SD'), ('onset', 'Tap'))

for audio_path, audio_name in zip(audio_paths, audio_names):
    # sr=None keeps the file's own sample rate
    audio_data, sample_rate = librosa.load(audio_path, sr=None)
    label = next((tag for marker, tag in label_markers if marker in audio_name), 'other')
    sample = fo.Sample(
        filepath=audio_path,
        audio=audio_data,
        sample_rate=sample_rate,
        label=label,
        audio_name=audio_name,
    )
    dataset.add_sample(sample)

# dataset.save('../data') # save the dataset

In [59]:
# compute a 2D representation using t-SNE
# Settings for the 2-D t-SNE projection of the embeddings (seed fixed at 51)
# NOTE(review): the brain key 'mnist_test' looks like a leftover from an example — confirm before renaming.
tsne_options = dict(num_dims=2, method='tsne', verbose=True, seed=51)
results = fob.compute_visualization(
    dataset,
    embeddings=embeddings,
    brain_key='mnist_test',
    **tsne_options,
)

Generating visualization...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 1236 samples in 0.001s...
[t-SNE] Computed neighbors for 1236 samples in 0.027s...
[t-SNE] Computed conditional probabilities for sample 1000 / 1236
[t-SNE] Computed conditional probabilities for sample 1236 / 1236
[t-SNE] Mean sigma: 2.715692
[t-SNE] Computed conditional probabilities in 0.028s
[t-SNE] Iteration 50: error = 81.1890564, gradient norm = 0.1522968 (50 iterations in 0.134s)
[t-SNE] Iteration 100: error = 82.5773010, gradient norm = 0.2936852 (50 iterations in 0.141s)
[t-SNE] Iteration 150: error = 84.1498795, gradient norm = 0.2638909 (50 iterations in 0.117s)
[t-SNE] Iteration 200: error = 85.7504425, gradient norm = 0.2561235 (50 iterations in 0.125s)
[t-SNE] Iteration 250: error = 86.0628586, gradient norm = 0.2616243 (50 iterations in 0.138s)
[t-SNE] KL divergence after 250 iterations with early exaggeration: 86.062859
[t-SNE] Iteration 300: error = 2.5489414, gradient norm = 0.01126

## Interactive latent space exploration

In [60]:
# 2-D coordinates from the visualization plus the per-sample metadata stored in the dataset
df = pd.DataFrame(results.points, columns=['x', 'y']).assign(
    label=dataset.values('label'),
    audio_name=dataset.values('audio_name'),
    filepath=dataset.values('filepath'),
)

df.head()

Unnamed: 0,x,y,label,audio_name,filepath
0,-6.999027,3.540467,BD,WAV/Individual Hits\01. Bass Drum\01. Clean\02...,C:\Users\asantos6\Documents\RAVERs\data\WAV\In...
1,-8.95024,14.47272,BD,WAV/Individual Hits\01. Bass Drum\01. Clean\02...,C:\Users\asantos6\Documents\RAVERs\data\WAV\In...
2,1.673388,9.086163,BD,WAV/Individual Hits\01. Bass Drum\01. Clean\02...,C:\Users\asantos6\Documents\RAVERs\data\WAV\In...
3,-7.443605,9.83209,BD,WAV/Individual Hits\01. Bass Drum\01. Clean\02...,C:\Users\asantos6\Documents\RAVERs\data\WAV\In...
4,-8.109398,-2.461702,BD,WAV/Individual Hits\01. Bass Drum\01. Clean\02...,C:\Users\asantos6\Documents\RAVERs\data\WAV\In...


In [61]:

# Create a scatter plot of the latent space: one trace per label so that each
# label gets its own colour and legend entry.
# Clicks are handled by the Dash callbacks through the graph's `clickData`;
# the former `scatter.on_click(...)` hook was removed because plain
# `go.Scatter` traces only dispatch such callbacks inside a FigureWidget, and
# its one-argument lambda did not match the callback signature.
data = []
for label in df['label'].unique():
    df_label = df[df['label'] == label]
    data.append(
        go.Scatter(
            x=df_label['x'],
            y=df_label['y'],
            mode='markers',
            text=df_label['audio_name'],  # audio name shown on hover
            hovertemplate='%{text}<extra></extra>',  # show only the text, no trace box
            name=label,  # legend entry
            customdata=df_label['filepath'],  # file path read back by the click callbacks
        )
    )

fig = go.Figure(data=data)

# Fixed figure size with tight margins (l, r, t, b stand for left, right, top, and bottom)
fig.update_layout(
    autosize=False,
    width=1000,
    height=400,
    margin=dict(l=24, r=24, b=24, t=24, pad=0)
)

app = dash.Dash(__name__)

# Page layout, top to bottom: the scatter plot, a text area for the clicked
# file's path, a player for the original sample (auto-plays) and a player for
# the model's reconstruction.
app.layout = html.Div(
    children=[
        dcc.Graph(id='scatter-plot', figure=fig),
        html.Pre(id='click-data', style=dict(padding='10px', color='white')),
        html.Audio(id='input-player', controls=True, autoPlay=True),
        html.Audio(id='output-player', controls=True, autoPlay=False),
    ]
)

@app.callback(Output('click-data', 'children'), [Input('scatter-plot', 'clickData')])
def display_click_data(clickData):
    """Show the clicked sample's path relative to the working directory.

    Returns the string 'None' while nothing has been clicked yet.
    """
    if clickData is not None:
        clicked_point = clickData["points"][0]
        # customdata carries the file path attached to each marker
        return get_relative_path(clicked_point["customdata"])
    return 'None'

@app.callback(Output('input-player', 'src'), [Input('scatter-plot', 'clickData')])
def play_input(clickData):
    """Return a base64 data URI with the clicked sample's original audio.

    Args:
        clickData: Click payload of the 'scatter-plot' graph; the first point's
            `customdata` holds the sample's file path.

    Returns:
        str: '' while nothing has been clicked, otherwise a `data:` URI that
        embeds the file's bytes for the 'input-player' audio element.
    """
    if clickData is None:
        return ''

    relative_path = get_relative_path(clickData["points"][0]["customdata"])
    with open(relative_path, 'rb') as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode('ascii')

    # The samples are .wav files, so declare the payload as WAV rather than MP3
    return f'data:audio/wav;base64,{encoded_audio}'

@app.callback(Output('output-player', 'src'), [Input('scatter-plot', 'clickData')])
def play_output(clickData):
    """Return a base64 data URI with the model's reconstruction of the clicked sample.

    The sample is loaded with `read_audio`, passed through `model.encode` and
    `model.decode`, written to '../output/output.wav' and then embedded in a
    `data:` URI for the 'output-player' audio element.

    Args:
        clickData: Click payload of the 'scatter-plot' graph; the first point's
            `customdata` holds the sample's file path.

    Returns:
        str: '' while nothing has been clicked, otherwise the data URI.
    """
    if clickData is None:
        return ''

    relative_path = get_relative_path(clickData["points"][0]["customdata"])

    # NOTE(review): read_audio() uses librosa's default sample rate — confirm it matches the rate the model expects.
    audio, sr = read_audio(relative_path)
    with torch.no_grad():
        x = torch.from_numpy(audio).reshape(1, 1, -1)
        z = model.encode(x)
        x_hat = model.decode(z)
    waveform_tensor = torch.squeeze(x_hat, 0)  # drop the batch dimension before saving

    output_path = '../output/output.wav'
    # Create the output folder on first use instead of failing when it is missing
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torchaudio.save(output_path, waveform_tensor, sr)

    with open(output_path, 'rb') as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode('ascii')

    # The reconstruction was just written as a .wav file, so declare it as WAV rather than MP3
    return f'data:audio/wav;base64,{encoded_audio}'

# Launch the Dash development server only when this code runs as the main program
if __name__ == '__main__':
    app.run_server(debug=True)