# 04 Swing Detection 

In [None]:
import os
import sys
print(sys.version)

In [None]:
import warnings
import pickle as pkl
from importlib import reload

import librosa
from librosa import time_to_frames, time_to_samples, frames_to_time, frames_to_samples

import numpy as np
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

In [None]:
# Add the local COMMON_UTILS directory to the import search path
# (presumably where utils, drum_processor and tempo_align live).
sys.path.append('./COMMON_UTILS/')

In [None]:
from utils import play, plot_audio

from drum_processor import getDownbeats

from tempo_align import matchAudioEvents

In [None]:
# Default size (width, height) for every figure created without an explicit figsize.
plt.rcParams['figure.figsize'] = (15, 5)

In [None]:
# Hop size, in samples, shared by the RMS computation and the time/frame conversions below.
hop_length=256

In [None]:
# Source track: its swing feel is measured and later imposed on the destination track.
SRC = './PROCESSED/jazz3/'
# Alternative input: the drum stem instead of the full mix.
# y, sr = librosa.load(os.path.join(SRC, 'DRUMS/drums.wav'), sr=44100)
y, sr = librosa.load(os.path.join(SRC, 'source.wav'), sr=44100)
# Frame-wise RMS envelope; [0] drops the leading axis of the returned 2-D array.
rms_src = librosa.feature.rms(y=y, hop_length=hop_length)[0]
play(y, sr)

In [None]:
# Destination track: the audio that gets warped further down.
DST = './PROCESSED/rock2/'
# Alternative input: the drum stem instead of the full mix.
# y_dst, sr = librosa.load(os.path.join(DST, 'DRUMS/drums.wav'), sr=44100)
# Loaded at the source track's sample rate so both signals share `sr`.
y_dst, _ = librosa.load(os.path.join(DST, 'source.wav'), sr=sr)
# Frame-wise RMS envelope; [0] drops the leading axis of the returned 2-D array.
rms_dst = librosa.feature.rms(y=y_dst, hop_length=hop_length)[0]
play(y_dst, sr)

In [None]:
def findDownbeats(p):
    """Return the downbeats for the track stored in directory `p`.

    The result is cached next to the audio: if `p/downbeats.pkl` exists it
    is unpickled and returned. Otherwise the downbeats are computed from
    `p/source.wav` with `getDownbeats` and written to `p/downbeats.pkl`
    before being returned.

    Args:
        p: directory containing `source.wav` (and possibly `downbeats.pkl`).

    Returns:
        Whatever `getDownbeats` produced (or what was pickled earlier).
    """
    cache_path = os.path.join(p, 'downbeats.pkl')

    try:
        with open(cache_path, 'rb') as cache:
            downbeats = pkl.load(cache)
    except FileNotFoundError:
        # No cache yet: run the detector once and persist its output.
        print('computing downbeats')
        downbeats = getDownbeats(
            os.path.join(p, 'source.wav'),
            transition_lambda=64,
        )
        with open(cache_path, 'wb') as cache:
            pkl.dump(downbeats, cache)
    else:
        print('found downbeats.pkl')

    return downbeats

In [None]:
# Downbeats of both tracks (read from downbeats.pkl when present, computed otherwise).
db_src = findDownbeats(SRC)
db_dst = findDownbeats(DST)

In [None]:
# Source waveform with the bar grid (solid red) and the beat grid (dashed black).
plot_audio(y, sr);
# Optional overlay of the RMS envelope:
# plt.plot(
#     librosa.frames_to_time(np.arange(len(rms_src)) - 1, sr=sr, hop_length=hop_length), 
#     rms_src, 
#     c='k'
# );
for downbeat in db_src:
    plt.axvline(downbeat, color='r')

# Split every bar (downbeat to next downbeat) into 4 equally spaced beats;
# endpoint=False leaves the next downbeat to the following bar.
beats = np.array([
    beat_time
    for bar_start, bar_end in zip(db_src[:-1], db_src[1:])
    for beat_time in np.linspace(bar_start, bar_end, 4, endpoint=False)
])

for beat_time in beats:
    plt.axvline(beat_time, color='k', ls='--', lw=1)

## Development 

In [None]:
# Slice the source RMS envelope into one segment per beat-to-beat interval.
beats_frames = librosa.time_to_frames(beats, sr=sr, hop_length=hop_length)
beats_rms = [
    rms_src[first:last]
    for first, last in zip(beats_frames[:-1], beats_frames[1:])
]

In [None]:
# Truncate every beat segment to the shortest one so they can be stacked,
# then average across beats to get a single energy profile of "the" beat.
length = min(len(segment) for segment in beats_rms)
stacked = np.stack([segment[:length] for segment in beats_rms])
beat_energy = stacked.mean(axis=0)
peaks, _ = find_peaks(beat_energy, height=0.01, prominence=0.005)

In [None]:
peaks

In [None]:
# Normalised position inside the beat (0 = first frame, 1 = last frame) for each frame of beat_energy.
t = np.linspace(0, 1, len(beat_energy))
# plt.plot(t, beat_energy);
# plt.scatter(t[peaks], beat_energy[peaks]);

In [None]:
# Energy profile rotated so the first detected peak sits at position 0;
# the peak markers are shifted by the same number of frames.
plt.plot(t, np.roll(beat_energy, -peaks[0]));
plt.scatter(t[peaks-peaks[0]], beat_energy[peaks]);

In [None]:
# Report every later peak as a fraction of the beat, measured from the first peak.
for peak in peaks[1:]:
    swing = round((peak - peaks[0]) / len(beat_energy), 2)
    # Anything within 5% of the half-beat counts as straight (unswung) eighths.
    if abs(swing - 0.5) < 0.05:
        swing = 0.5

    print(f'swing: {swing:.2%}')

## Pipeline 

In [None]:
def getBeatTimes(db, n=4):
    """Subdivide each bar into `n` evenly spaced beat times.

    Every pair of consecutive downbeats is split with `np.linspace` using
    `endpoint=False`, so a bar contributes its own downbeat plus `n - 1`
    interior beats and the final downbeat of `db` never appears in the
    output.

    Args:
        db: sequence of downbeat times.
        n: number of beats per bar.

    Returns:
        1-D numpy array of beat times; empty when `db` has fewer than two
        entries.
    """
    beat_times = [
        beat
        for bar_start, bar_end in zip(db[:-1], db[1:])
        for beat in np.linspace(bar_start, bar_end, n, endpoint=False)
    ]
    return np.array(beat_times)

In [None]:
def getSwingPoints(y, sr, db, hop_length=256, height=0.03, prominence=0.005):
    """Locate the energy peaks of the average beat as fractions of a beat.

    The downbeats are subdivided into beats with `getBeatTimes`, the
    frame-wise RMS of `y` is sliced per beat, every slice is truncated to
    the shortest one, and the slices are averaged into a single energy
    profile. The peaks of that profile are reported relative to the first
    peak.

    Args:
        y: audio signal handed to `librosa.feature.rms`.
        sr: sample rate used to convert beat times to frame indices.
        db: downbeat times, as accepted by `getBeatTimes`.
        hop_length: hop size in samples for both the RMS and the
            time-to-frame conversion.
        height: minimum peak height forwarded to `scipy.signal.find_peaks`.
        prominence: minimum peak prominence forwarded to
            `scipy.signal.find_peaks`.

    Returns:
        List of peak offsets, each `(peak - first_peak) / beat_length`
        rounded to two decimals, so the first entry is 0.0. Offsets within
        0.05 of 0.5 are snapped to exactly 0.5. When only one peak is found
        a straight 0.5 is appended; when no peak is found the list is empty.

    Raises:
        ValueError: if `db` does not yield at least one beat-to-beat
            interval (i.e. fewer than two downbeats).
    """
    beats = getBeatTimes(db)

    rms = librosa.feature.rms(y=y, hop_length=hop_length)[0]
    beats_frames = time_to_frames(beats, sr=sr, hop_length=hop_length)
    beats_rms = [
        rms[start:end]
        for start, end in zip(beats_frames[:-1], beats_frames[1:])
    ]
    if not beats_rms:
        # Previously surfaced as "min() arg is an empty sequence".
        raise ValueError(
            'not enough beats to measure swing (need at least two downbeats)'
        )

    # Beats differ by a frame or two; cut them all to the shortest so they stack.
    length = min(map(len, beats_rms))
    beat_energy = np.mean(
        np.stack([b[:length] for b in beats_rms]),
        axis=0,
    )

    peaks, _ = find_peaks(beat_energy, height=height, prominence=prominence)

    points = []
    for p in peaks:
        swing = (p - peaks[0]) / length
        # Treat anything close to the half-beat as straight eighths.
        if abs(swing - 0.5) < 0.05:
            swing = 0.5

        points.append(round(swing, 2))

    # A lone on-beat peak carries no swing information: assume a straight feel.
    if len(points) == 1:
        points.append(0.5)

    return points

In [None]:
def plotSwingmap(a, b, map_=None):
    """Draw two rows of swing points and, optionally, arrows linking them.

    `a` is drawn on the upper row (y = 1, tick label 'from') and `b` on the
    lower row (y = 0, tick label 'to'); the x axis is fixed to [0, 1].

    Args:
        a: x positions for the 'from' row.
        b: x positions for the 'to' row.
        map_: optional iterable of pairs; each pair draws a dotted red arrow
            from (pair[0], 1) down to (pair[1], 0).
    """
    _, ax = plt.subplots(figsize=(15, 1))

    dot_style = dict(marker='o', ls='', ms=15, c='k')
    ax.plot(a, np.zeros_like(a) + 1, **dot_style)
    ax.plot(b, np.zeros_like(b), **dot_style)

    ax.set_xlim(0, 1)
    ax.set_ylim(-0.2, 1.2)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['to', 'from'])

    if map_ is None:
        return

    for pair in map_:
        ax.arrow(pair[0], 1, pair[1] - pair[0], -1, color='red', ls=':', lw=1)

In [None]:
def getSwingMap(a, b):
    """Pair the points of `a` with the points of `b`.

    * Same length: points are paired in order.
    * `a` shorter: each point of `a` is paired with its nearest point of `b`.
    * `b` shorter: each point of `b` is paired with its nearest point of `a`.

    Args:
        a: sequence of positions (first element of every pair).
        b: sequence of positions (second element of every pair).

    Returns:
        List of `(a_value, b_value)` tuples, one per point of the shorter
        input (or per point when the lengths match).
    """
    src = np.array(a)
    dst = np.array(b)

    if len(src) == len(dst):
        return list(zip(src, dst))

    if len(src) < len(dst):
        return [(x, dst[np.abs(dst - x).argmin()]) for x in src]

    return [(src[np.abs(src - y).argmin()], y) for y in dst]

In [None]:
# Swing points (peak positions inside the average beat) of each track.
points_src = getSwingPoints(y, sr, db_src, hop_length=hop_length)
points_dst = getSwingPoints(y_dst, sr, db_dst, hop_length=hop_length)

In [None]:
points_src, points_dst

In [None]:
# Pair the destination track's swing points (first of each pair) with the
# source track's (second of each pair) and visualise the pairing.
map_ = getSwingMap(points_dst, points_src)
print(map_)
plotSwingmap(points_dst, points_src, map_)

In [None]:
def getSwingTimings(db, map_):
    """Expand a per-beat swing map into absolute times for every beat.

    For each interval between consecutive beats of `getBeatTimes(db)`, every
    pair in `map_` is scaled by the interval's duration and offset by the
    interval's start.

    Args:
        db: downbeat times, as accepted by `getBeatTimes`.
        map_: iterable of pairs of in-beat fractions; element 0 is the
            position to move from, element 1 the position to move to.

    Returns:
        `(points_from, points_to)`: two equally long lists of times, ordered
        beat by beat and, inside a beat, in the order of `map_`.
    """
    beats = getBeatTimes(db)

    points_from, points_to = [], []
    for beat_start, beat_end in zip(beats[:-1], beats[1:]):
        span = beat_end - beat_start
        for pair in map_:
            points_from.append(beat_start + span * pair[0])
            points_to.append(beat_start + span * pair[1])

    return points_from, points_to

In [None]:
# Absolute event times in the destination track: where they are and where they should go.
points_from, points_to = getSwingTimings(db_dst, map_)

In [None]:
# First 4 units of the time axis: downbeats (black, full height), original event
# times (red, upper half) and target event times (blue, lower half).
fig, ax = plt.subplots()   
ax.vlines(db_dst, 0, 1, color='k')
ax.vlines(points_from, 0.5, 1.0, color='red', ls=':', lw=1)
ax.vlines(points_to, 0.0, 0.5, color='blue', ls=':', lw=1)
    
ax.set_xlim(0, 4)

In [None]:
# Presumably time-warps y_dst so events at points_from land at points_to — confirm in tempo_align.
y_warped = matchAudioEvents(y_dst, sr, points_from, points_to, hq=True)

In [None]:
# Listen in order: source track, untouched destination, warped destination.
play(y, sr)
play(y_dst, sr)
play(y_warped, sr)

In [None]:
# Top: destination track with the original event times (red).
# Bottom: warped result with the target event times (blue).
fig, axes = plt.subplots(2, sharex=True)
plot_audio(y_dst, sr, ax=axes[0])
axes[0].vlines(points_from, -0.5, 0.5, color='red', ls='-', lw=2)
plot_audio(y_warped, sr, ax=axes[1])
axes[1].vlines(points_to, -0.5, 0.5, color='blue', ls='-', lw=2)

# The x axis is shared, so this zooms both panels to the same window.
axes[1].set_xlim(1, 2.5);

In [None]:
def alignSwing(y_org, y_trg, sr, db_org, db_trg, hop_length=256):
    """Warp `y_org` so its swing points follow those of `y_trg`.

    Chains the notebook's pipeline: swing points of both signals are
    detected, paired with `getSwingMap`, expanded over the beats of
    `db_org` with `getSwingTimings`, and the resulting from/to event times
    are handed to `matchAudioEvents`.

    Args:
        y_org: audio to be warped.
        y_trg: audio whose swing is used as the reference.
        sr: sample rate shared by both signals.
        db_org: downbeat times of `y_org`.
        db_trg: downbeat times of `y_trg`.
        hop_length: hop size in samples used for swing detection.

    Returns:
        The output of `matchAudioEvents` for `y_org`.
    """
    swing_org = getSwingPoints(y_org, sr, db_org, hop_length=hop_length)
    swing_trg = getSwingPoints(y_trg, sr, db_trg, hop_length=hop_length)

    events_from, events_to = getSwingTimings(
        db_org,
        getSwingMap(swing_org, swing_trg),
    )

    return matchAudioEvents(y_org, sr, events_from, events_to, hq=True)

## As Import 

In [None]:
from swing_align import alignSwing

In [None]:
# Target track: provides the swing feel.
TRG = './PROCESSED/jazz_beat/'
y_trg, sr = librosa.load(os.path.join(TRG, 'source.wav'), sr=44100)

# Original track: the audio that will be warped; loaded at the same sample rate.
ORG = './PROCESSED/bowie_heroes/'
y_org, _ = librosa.load(os.path.join(ORG, 'source.wav'), sr=sr)

print('TARGET:')
play(y_trg, sr)
print('ORIGINAL:')
play(y_org, sr)

In [None]:
# Downbeats of both tracks (read from downbeats.pkl when present, computed otherwise).
db_trg = findDownbeats(TRG)
db_org = findDownbeats(ORG)

In [None]:
# Uses the alignSwing imported from swing_align above, which shadows the notebook's own definition.
y_warped = alignSwing(y_org, y_trg, sr, db_org, db_trg)

In [None]:
play(y_warped, sr)