In [1]:
import os

import IPython.display as ipd
import librosa
import numpy as np
import torch
import torchaudio

In [2]:
def trim_audio(audio, sr, ti, tf, mono=True):
    """Cut the span between two time stamps out of an audio array.

    Args:
        audio: Indexable sample array. Sliced along its only axis when
            ``mono`` is true, otherwise along its second axis (the first
            axis is kept whole).
        sr: Number of samples per time unit; ``ti`` and ``tf`` are multiplied
            by it to obtain sample indices.
        ti: Start time of the span. May be fractional.
        tf: End time of the span (exclusive). May be fractional.
        mono: Selects which axis is sliced, see ``audio``.

    Returns:
        The slice ``audio[start:stop]`` (mono) or ``audio[:, start:stop]``.
    """
    # Slice bounds must be integers: a fractional time such as 0.5 would
    # otherwise produce a float index and raise TypeError.
    start = int(round(ti * sr))
    stop = int(round(tf * sr))
    if mono:
        return audio[start:stop]
    return audio[:, start:stop]

def read_audio(file_path, trim_interval=None, mono=True, print_it=False):
    """Load an audio file with librosa and optionally trim it.

    Args:
        file_path: Path handed to ``librosa.load``.
        trim_interval: Optional indexable whose items 0 and 1 are the start
            and end times passed to ``trim_audio``. ``None`` keeps the whole
            signal.
        mono: Forwarded to ``librosa.load``. When false and the loaded array
            is one-dimensional, the signal is stacked with itself so the
            result has two identical rows.
        print_it: When true, print the resulting array shape and the sample
            rate.

    Returns:
        Tuple ``(audio, sr)`` as returned by ``librosa.load``, after the
        optional channel duplication and trimming.

    NOTE(review): no ``sr`` is passed to ``librosa.load``, so the file is
    presumably resampled to librosa's default rate -- confirm this matches
    the rate the model expects.
    """
    samples, rate = librosa.load(file_path, mono=mono)

    # A single-channel file requested as non-mono: duplicate the channel.
    if samples.ndim == 1 and not mono:
        samples = np.asarray((samples, samples))

    if trim_interval is not None:
        start_time, end_time = trim_interval[0], trim_interval[1]
        samples = trim_audio(samples, rate, start_time, end_time, mono)

    if print_it:
        print(samples.shape)
        print(rate)

    return samples, rate

In [5]:
# Folder whose .wav files are encoded into the latent space below.
audio_folder_path = '../data/WAV/Individual Hits/02. Snare Drum/01. Clean'

# TorchScript archive of the model used for encode/decode.
model_path = '../models/drumkit_v1.ts'
# Every input in this notebook is built with torch.from_numpy (CPU tensors)
# and every output is converted with .numpy(), so pin the model to the CPU
# even if the archive was saved from a GPU.
model = torch.jit.load(model_path, map_location='cpu')

In [6]:

# One latent array per .wav file, in sorted file-name order.
latent_spaces = []

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# latent_spaces reproducible across runs and machines.
for audio_file in sorted(os.listdir(audio_folder_path)):
    # Case-insensitive so that files named "*.WAV" are not silently skipped.
    if not audio_file.lower().endswith('.wav'):
        continue
    audio, sr = read_audio(os.path.join(audio_folder_path, audio_file))
    # An Audio object is only rendered automatically when it is the last
    # expression of a cell; inside a loop it has to be displayed explicitly.
    ipd.display(ipd.Audio(audio, rate=sr))
    with torch.no_grad():
        # Reshape to three dimensions with the first two of size 1 before encoding.
        x = torch.from_numpy(audio).reshape(1, 1, -1)
        z = model.encode(x)
        # Drop the leading size-1 dimension again.
        latent_space = torch.squeeze(z, 0)
        latent_space_np = latent_space.numpy()
        latent_spaces.append(latent_space_np)

In [None]:
import pandas as pd

# NOTE(review): `results` and `dataset` are not defined in any cell visible
# in this notebook -- this cell only runs if they already exist in the kernel.
# 2-D coordinates of every point, one row per sample.
df = pd.DataFrame(results.points, columns=['x', 'y'])

# Attach the per-sample metadata columns used by the plot and the Dash callbacks.
for field_name in ('label', 'audio_name', 'filepath'):
    df[field_name] = dataset.values(field_name)

df.head()

In [None]:
import plotly.graph_objs as go

# Create a scatter plot of the latent space, one trace per label so that each
# label gets its own legend entry.
data = []
for label in df['label'].unique():
    df_label = df[df['label'] == label]
    scatter = go.Scatter(
        x=df_label['x'],
        y=df_label['y'],
        mode='markers',
        text=df_label['audio_name'],  # Add the audio name as text
        hovertemplate='%{text}<extra></extra>',  # Show only the audio name on hover
        name=label,  # Use the label as the name of the trace
        customdata=df_label['filepath']  # File path, read back by the Dash callbacks on click
    )
    # No trace-level on_click handler is registered here: Plotly only fires
    # those for traces owned by a displayed FigureWidget (and calls them with
    # three arguments), whereas this figure is a plain go.Figure served by
    # Dash. Clicks are handled by the 'clickData' callbacks of the Dash app.
    data.append(scatter)

fig = go.Figure(data=data)

# Adjust the size and margins (l, r, t, b stand for left, right, top, and bottom)
fig.update_layout(
    autosize=False,
    width=1000,
    height=400,
    margin=dict(l=24, r=24, b=24, t=24, pad=0)
)

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import base64
import io

def get_relative_path(absolute_path):
    """Return ``absolute_path`` expressed relative to the current working directory."""
    return os.path.relpath(absolute_path, start=os.getcwd())

# Dash application: the scatter plot, a text line for the clicked file and two
# audio players whose 'src' is filled in by the callbacks below.
app = dash.Dash(__name__)

app.layout = html.Div(
    children=[
        dcc.Graph(id='scatter-plot', figure=fig),
        html.Pre(id='click-data', style=dict(padding='10px', color='white')),
        html.Audio(id='input-player', controls=True, autoPlay=True),
        html.Audio(id='output-player', controls=True, autoPlay=False),
    ]
)

@app.callback(Output('click-data', 'children'), [Input('scatter-plot', 'clickData')])
def display_click_data(clickData):
    """Return the clicked point's file path relative to the working directory.

    Args:
        clickData: Value of the graph's ``clickData`` property; ``None``
            until a point has been clicked.

    Returns:
        The string ``'None'`` when nothing has been clicked, otherwise the
        ``customdata`` of the first clicked point as a relative path.
    """
    if clickData is not None:
        clicked_point = clickData["points"][0]
        return get_relative_path(clicked_point["customdata"])
    return 'None'

@app.callback(Output('input-player', 'src'), [Input('scatter-plot', 'clickData')])
def play_input(clickData):
    """Build a base64 data URI with the audio file of the clicked point.

    Args:
        clickData: Value of the graph's ``clickData`` property; ``None``
            until a point has been clicked.

    Returns:
        An empty string when nothing has been clicked, otherwise a
        ``data:<mime>;base64,...`` URI holding the raw bytes of the file
        named by the first clicked point's ``customdata``.
    """
    import mimetypes  # stdlib; imported here so the callback is self-contained

    if clickData is None:
        return ''

    relative_path = get_relative_path(clickData["points"][0]["customdata"])

    # Label the payload with the type matching the file extension instead of
    # always claiming MP3.
    mime_type, _encoding = mimetypes.guess_type(relative_path)
    if mime_type is None or not mime_type.startswith('audio/'):
        # Unknown extension: fall back to WAV -- presumably the dataset
        # format, given the '.wav' filter used when encoding; confirm.
        mime_type = 'audio/wav'

    with open(relative_path, 'rb') as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode('ascii')

    return f'data:{mime_type};base64,{encoded_audio}'

@app.callback(Output('output-player', 'src'), [Input('scatter-plot', 'clickData')])
def play_output(clickData):
    """Encode and decode the clicked audio with the model and return it as a data URI.

    The reconstruction is also written to ``../output/output.wav`` (overwritten
    on every click).

    Args:
        clickData: Value of the graph's ``clickData`` property; ``None``
            until a point has been clicked.

    Returns:
        An empty string when nothing has been clicked, otherwise a
        ``data:audio/wav;base64,...`` URI holding the saved WAV file.
    """
    if clickData is None:
        return ''

    relative_path = get_relative_path(clickData["points"][0]["customdata"])

    audio, sr = read_audio(relative_path)
    with torch.no_grad():
        x = torch.from_numpy(audio).reshape(1, 1, -1)
        z = model.encode(x)
        x_hat = model.decode(z)
    # Drop the leading size-1 dimension before saving.
    waveform_tensor = torch.squeeze(x_hat, 0)

    output_path = '../output/output.wav'
    # Saving fails if the target directory does not exist yet.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torchaudio.save(output_path, waveform_tensor, sr)

    with open(output_path, 'rb') as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode('ascii')

    # The file just written is a WAV file, so label it as such (not MP3).
    return f'data:audio/wav;base64,{encoded_audio}'

# debug=True also enables the auto-reloader, which restarts the launching
# process; that is known not to work from a notebook kernel, so it is turned
# off here while the rest of the debug tooling stays enabled.
if __name__ == '__main__':
    app.run_server(debug=True, use_reloader=False)