In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Open the training split of NSynth through TFDS and keep the accompanying
# metadata object so it can be printed for inspection.
dataset, info = tfds.load(name='nsynth', with_info=True, split='train')
print(info)

In [None]:
# Show which features a single example carries: a header line followed by
# one key per line.
for sample in dataset.take(1):
    print("Available keys:", *sample.keys(), sep="\n")

In [None]:
from IPython.display import Audio

# Fetch one example and pull out its waveform. The previous version read an
# `audio` variable that no cell ever defined, so the notebook failed with a
# NameError on a fresh kernel; taking the example here keeps the cell
# self-contained.
example = next(iter(dataset.take(1)))
audio_np = example['audio'].numpy()

# Render an inline audio player (last expression of the cell is displayed).
Audio(audio_np, rate=16000)  # Assuming a sample rate of 16kHz

In [None]:
import plotly.graph_objects as go

# Draw the raw samples of the clip as a single black line.
waveform_trace = go.Scatter(
    y=audio_np,
    name="Waveform",
    mode='lines',
    line={'color': 'black'},
)
fig = go.Figure(data=[waveform_trace])

# Figure-level cosmetics: labels, theme and a fixed 800x400 canvas.
layout_options = {
    'title': "Waveform",
    'xaxis_title': "Time (samples)",
    'yaxis_title': "Amplitude",
    'template': "plotly_white",
    'width': 800,
    'height': 400,
}
fig.update_layout(**layout_options)

fig.show()

In [None]:
import librosa
import numpy as np

# Short-time Fourier transform of the clip (512-point FFT, hop of 256
# samples), with the magnitudes converted to decibels.
spectrogram = librosa.stft(audio_np, hop_length=256, n_fft=512)
spectrogram_db = librosa.amplitude_to_db(np.abs(spectrogram))

# Axis coordinates: one time value per STFT frame spanning the clip length,
# and one frequency value per FFT bin spanning 0 Hz to half the sample rate.
clip_seconds = len(audio_np) / 16000
half_sample_rate = 16000 / 2
time = np.linspace(0, clip_seconds, spectrogram_db.shape[1])
frequencies = np.linspace(0, half_sample_rate, spectrogram_db.shape[0])

stft_heatmap = go.Heatmap(
    x=time,
    y=frequencies,
    z=spectrogram_db,
    colorscale='Viridis',
    colorbar={'title': 'Amplitude (dB)'},
)
fig = go.Figure(data=stft_heatmap)

# Label the figure and put the frequency axis on a logarithmic scale.
fig.update_layout(
    template="plotly",
    title="Spectrogram",
    xaxis_title="Time (seconds)",
    yaxis_title="Frequency (Hz)",
    yaxis={'type': "log"},
)

fig.show()

In [None]:
from collections import Counter

import plotly.express as px

# Tally the instrument-family id of the first 1000 examples. The ids are
# converted to plain ints so they can be used directly as list indices.
instrument_counts = Counter()
for sample in dataset.take(1000):
    instrument = int(sample['instrument']['family'].numpy())
    instrument_counts[instrument] += 1

# Map numeric ids to instrument family names. NSynth defines 11 families
# (ids 0-10) and has no "Synth Pad" family; the extra entry that used to sit
# at index 10 caused every vocal example (id 10) to be labelled "Synth Pad".
instrument_families = ["Bass", "Brass", "Flute", "Guitar", "Keyboard", "Mallet", "Organ", "Reed", "String", "Synth Lead", "Vocal"]
mapped_family_counts = {instrument_families[family_id]: count for family_id, count in instrument_counts.items()}

# Bar chart of how many of the sampled examples fall into each family.
fig = px.bar(
    x=list(mapped_family_counts.keys()),
    y=list(mapped_family_counts.values()),
    labels={'x': 'Instrument Family', 'y': 'Count'},
    title="Distribution of Instrument Families",
    template="plotly"
)
fig.show()

In [None]:
# Mel-scaled power spectrogram of the clip, in dB relative to its peak.
# hop_length is spelled out (512 is librosa's default) because the time axis
# below has to be derived from the same value.
mel_spectrogram = librosa.feature.melspectrogram(y=audio_np, sr=16000, n_mels=128, hop_length=512)
mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

# Axis coordinates computed for THIS transform. The STFT cell's `time` array
# was built for a hop of 256 samples and therefore has a different number of
# frames; reusing it mislabelled the time axis. Likewise the mel bands are
# spaced on the mel scale, not linearly in Hz, so the band frequencies come
# from librosa rather than np.linspace.
mel_times = librosa.frames_to_time(
    np.arange(mel_spectrogram_db.shape[1]), sr=16000, hop_length=512
)
mel_band_frequencies = librosa.mel_frequencies(
    n_mels=mel_spectrogram_db.shape[0], fmin=0.0, fmax=16000 / 2
)

fig = go.Figure(data=go.Heatmap(
    z=mel_spectrogram_db,
    x=mel_times,
    y=mel_band_frequencies,
    colorscale='Viridis',
    colorbar=dict(title="Amplitude (dB)")
))

# A figure should stand on its own when the notebook is skimmed.
fig.update_layout(
    title="Mel Spectrogram",
    xaxis_title="Time (seconds)",
    yaxis_title="Frequency (Hz)"
)
fig.show()

In [None]:
# 13 mel-frequency cepstral coefficients per frame. hop_length is spelled out
# (512 is librosa's default) because the time axis below is derived from it.
mfccs = librosa.feature.mfcc(y=audio_np, sr=16000, n_mfcc=13, hop_length=512)

# One timestamp per MFCC frame. The STFT cell's `time` array was built for a
# hop of 256 samples and has a different frame count, so reusing it
# mislabelled the time axis of this plot.
mfcc_times = librosa.frames_to_time(
    np.arange(mfccs.shape[1]), sr=16000, hop_length=512
)

fig = go.Figure(data=go.Heatmap(
    z=mfccs,
    x=mfcc_times,
    y=np.arange(1, mfccs.shape[0] + 1),  # coefficient index, starting at 1
    colorscale='Viridis',
    colorbar=dict(title="MFCC Value")
))

# A figure should stand on its own when the notebook is skimmed.
fig.update_layout(
    title="MFCCs",
    xaxis_title="Time (seconds)",
    yaxis_title="MFCC Coefficient"
)
fig.show()

In [None]:
# Two simple augmentations of the clip:
#   - pitch shift by +2 semitones
#   - time stretch at a rate of 1.5 (speed up by 1.5x)
audio_pitch_shifted = librosa.effects.pitch_shift(audio_np, n_steps=2, sr=16000)
audio_time_stretched = librosa.effects.time_stretch(audio_np, rate=1.5)

# Overlay the original and both augmented waveforms in one figure.
fig = go.Figure(data=[
    go.Scatter(y=audio_np, name='Original', mode='lines'),
    go.Scatter(y=audio_pitch_shifted, name='Pitch Shifted', mode='lines'),
    go.Scatter(y=audio_time_stretched, name='Time Stretched', mode='lines'),
])
fig.show()