# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
from IPython.display import Audio, display
import ruptures as rpt
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks

def fig_ax(figsize=(15, 5), dpi=150):
    """Create a figure holding a single axes.

    Args:
        figsize: Figure size forwarded to ``plt.subplots``.
        dpi: Figure resolution forwarded to ``plt.subplots``.

    Returns:
        The ``(figure, axes)`` pair produced by ``plt.subplots``.
    """
    figure, axes = plt.subplots(figsize=figsize, dpi=dpi)
    return figure, axes

# =============================
# --- 1. Load the audio file ---
# Point file_path at the classical instrumental recording to analyse
# (an mp3 file is preferred).
file_path = "FILE_PATH"
# sr=None is passed so librosa does not impose its own default sampling rate.
signal, sr = librosa.load(file_path, sr=None)
# Show a player for the complete recording in the notebook output.
display(Audio(signal, rate=sr))

# =============================
# --- 2. Feature Extraction ---
# Every feature below is computed with the same hop_length so that column k of
# each matrix describes the same instant of the recording.
hop_length = 512

# 2.1 MFCCs and their deltas (timbre and dynamics)
n_mfcc = 20
mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=n_mfcc, hop_length=hop_length)
mfcc_delta = librosa.feature.delta(mfcc)
mfcc_delta2 = librosa.feature.delta(mfcc, order=2)

# 2.2 High-resolution Chroma (harmonic content)
chroma = librosa.feature.chroma_stft(y=signal, sr=sr, hop_length=hop_length, n_chroma=24)

# 2.3 Tonnetz features (additional harmonic descriptors)
# The CQT chroma that tonnetz is derived from is built explicitly so that it
# follows hop_length; calling tonnetz(y=...) alone would silently use the
# library's default hop and stop lining up with the other features as soon as
# hop_length is changed.
tonnetz_chroma = librosa.feature.chroma_cqt(y=signal, sr=sr, hop_length=hop_length)
tonnetz = librosa.feature.tonnetz(chroma=tonnetz_chroma, sr=sr)

# 2.4 Spectral contrast (for dynamics and timbral shifts)
spectral_contrast = librosa.feature.spectral_contrast(y=signal, sr=sr, hop_length=hop_length)

# Stack the features into a single feature matrix.
# Each entry is (weight, feature matrix); the weights set the relative
# emphasis of each descriptor for classical instrumental music.
weighted_blocks = [
    (0.4, mfcc),               # base MFCCs
    (0.2, mfcc_delta),         # first-order delta
    (0.2, mfcc_delta2),        # second-order delta
    (0.4, chroma),             # high-res chroma for harmony
    (0.4, tonnetz),            # tonnetz features
    (0.3, spectral_contrast),  # spectral contrast for dynamics
]
# STFT- and CQT-based features may differ by a frame at the end of the signal;
# trim everything to the shortest one so np.vstack cannot fail on a mismatch.
n_frames = min(block.shape[1] for _, block in weighted_blocks)
features = np.vstack([weight * block[:, :n_frames] for weight, block in weighted_blocks])

# Smooth the feature trajectories along the time axis (to reduce minor fluctuations)
sigma = 1.5  # Smoothing parameter; adjust if necessary.
features_smooth = gaussian_filter1d(features, sigma=sigma, axis=1)

# =============================
# --- 3. Change Point Detection with PELT ---
# The penalty scales with the logarithm of the number of frames (columns of
# features_smooth); the factor 10 is a multiplier to adjust as needed.
n_observations = features_smooth.shape[1]
penalty = 10 * np.log(n_observations)

# Fit PELT with an RBF cost on the transposed matrix, i.e. one row per frame.
algo = rpt.Pelt(model="rbf")
algo.fit(features_smooth.T)
bkps = algo.predict(pen=penalty)

# bkps holds the frame indices of the detected changes; its last entry equals
# the number of frames. Convert those indices to seconds.
bkps_times = librosa.frames_to_time(bkps, sr=sr, hop_length=hop_length)

# =============================
# --- 4. Adjust Boundaries to Onsets ---
# Compute an onset strength envelope to find strong musical onsets.
onset_env = librosa.onset.onset_strength(y=signal, sr=sr, hop_length=hop_length)
# Keep peaks of the envelope that reach at least its median and are at least
# 3 frames apart – both parameters may be tuned.
peaks, _ = find_peaks(onset_env, height=np.median(onset_env), distance=3)
# NOTE: despite the name these are onset-peak times, not tracked beats.
beat_times = librosa.frames_to_time(peaks, sr=sr, hop_length=hop_length)

# End of the recording in seconds; always used as the final boundary.
end_time = signal.shape[0] / sr

# Every breakpoint except the last one (which marks the end of the features).
interior_times = np.asarray(bkps_times[:-1], dtype=float)
if beat_times.size and interior_times.size:
    # Snap each breakpoint to its nearest onset (ties go to the earlier onset).
    nearest = np.abs(beat_times[np.newaxis, :] - interior_times[:, np.newaxis]).argmin(axis=1)
    snapped_times = beat_times[nearest]
else:
    # No onset peaks were found: keep the PELT breakpoints as they are rather
    # than discarding the whole segmentation.
    snapped_times = interior_times

# Several breakpoints can snap to the same onset; np.unique removes the
# duplicates (and sorts) so no zero-length segment is created downstream.
snapped_times = np.unique(snapped_times)
# Only boundaries strictly inside the recording are meaningful.
snapped_times = snapped_times[(snapped_times > 0) & (snapped_times < end_time)]

# Append the end time of the signal as the final boundary.
adjusted_bkps_times = np.append(snapped_times, end_time)

# =============================
# --- 5. Plot the Segmentation ---
# Draw the MFCC matrix on a time axis and overlay one dashed red line per
# boundary; the last boundary (end of the signal) is not drawn.
fig, ax = fig_ax()
img = librosa.display.specshow(mfcc, x_axis='time', sr=sr, hop_length=hop_length, ax=ax)
ax.set_title("MFCCs with Detected Segmentation Boundaries")

boundary_style = dict(color='red', linestyle='--', linewidth=2)
interior_boundaries = adjusted_bkps_times[:-1]
for boundary_time in interior_boundaries:
    ax.axvline(boundary_time, **boundary_style)

fig.colorbar(img, ax=ax)
plt.show()

# =============================
# --- 6. Playback Segments ---
# Convert adjusted times to sample indices. Round to the nearest sample
# (plain truncation can lose the last sample to floating-point round-off) and
# clip to the valid range of the signal.
bkps_sample_indices = np.rint(adjusted_bkps_times * sr).astype(int)
bkps_sample_indices = np.clip(bkps_sample_indices, 0, signal.shape[0]).tolist()
# Create boundaries list (including start and end of signal).
boundaries = [0] + bkps_sample_indices

print("Playing segmented audio clips:")
segment_number = 0
for start, end in zip(boundaries[:-1], boundaries[1:]):
    # Coinciding boundaries give an empty slice, which cannot be rendered as
    # an audio widget; skip it instead of failing.
    if end <= start:
        continue
    segment_number += 1
    segment = signal[start:end]
    duration = (end - start) / sr
    print(f"Segment {segment_number}: {duration:.2f} seconds")
    display(Audio(data=segment, rate=sr))
