# 02 Rhythm Interpolation and Resynthesis

In [None]:
import os
import sys
print(sys.version)
sys.path.extend(['./COMMON_UTILS/'])

In [None]:
import warnings
import pickle as pkl
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks

import torch
import librosa

from tqdm.auto import tqdm
from IPython.display import display, Audio

In [None]:
from utils import plot_audio, play

from wasserstein_transformations import SmoothTransition

from tempo_align import warpAudio, quantiseAudio

from drum_processor import getDownbeats
from drum_decomposition import getDecomposition, plotDecomposition, isolateSources, getSamples

In [None]:
# Silence all warnings for the rest of the notebook session.
warnings.simplefilter('ignore')

# Default figure size for every matplotlib plot below.
plt.rcParams['figure.figsize'] = (15, 5)

# Use the GPU when torch can see one, otherwise the CPU; `device` is passed to
# the decomposition and reconstruction calls further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'computing on {device}')

In [None]:
def findDownbeats(p):
    """Return the downbeats for the track stored in directory ``p``.

    Results are cached in ``<p>/downbeats.pkl``. If that file exists, its
    unpickled contents are returned. Otherwise downbeats are computed from
    ``<p>/source.wav`` with ``getDownbeats(..., transition_lambda=64)``,
    written to the cache, and returned. A progress message is printed on
    either path.
    """
    cache_file = os.path.join(p, 'downbeats.pkl')
    try:
        # NOTE: unpickling can execute arbitrary code -- only point this at
        # directories produced by this pipeline.
        with open(cache_file, 'rb') as fh:
            downbeats = pkl.load(fh)
    except FileNotFoundError:
        print('computing downbeats')
        audio_path = os.path.join(p, 'source.wav')
        downbeats = getDownbeats(audio_path, transition_lambda=64)
        with open(cache_file, 'wb') as fh:
            pkl.dump(downbeats, fh)
    else:
        print('found downbeats.pkl')
    return downbeats

In [None]:
def normalise(y):
    """Scale ``y`` into [-1, 1] if, and only if, its peak magnitude exceeds 1.

    Args:
        y: array-like signal.

    Returns:
        ``y / max(|y|)`` as a new array when the peak magnitude is above 1;
        otherwise ``y`` itself (as an ndarray), unchanged. Empty input is
        returned as-is. The caller's array is never modified.
    """
    y = np.asarray(y)
    if y.size == 0:
        # .max() on an empty array raises; there is nothing to scale anyway.
        return y
    peak = np.abs(y).max()
    if peak > 1:
        # Out-of-place division: an in-place `y /= peak` would rescale the
        # caller's array as a side effect and fails on integer dtypes.
        y = y / peak
    return y

### Load Audio 

In [None]:
# List the genre folders available under ./PROCESSED/ (sorted, because
# os.listdir returns entries in arbitrary order).
print('Available genres:\n')
print('\n'.join(sorted(os.listdir("./PROCESSED/"))))

In [None]:
# SRC: track whose drum activations (rhythm) are transferred.
# DST: track whose drum samples and instrumental stems are kept.
SRC = './PROCESSED/salsa/'
DST = './PROCESSED/techno/'

In [None]:
# Load both drum stems and the full DST mix. The SRC load fixes sr=44100 and
# the DST loads reuse `sr`, so every signal shares one sample rate.
# NOTE(review): y_dst_full is not used again in the visible notebook.
y_src, sr = librosa.load(os.path.join(SRC, 'DRUMS/drums.wav'), sr=44100)
y_dst, _ = librosa.load(os.path.join(DST, 'DRUMS/drums.wav'), sr=sr)
y_dst_full, _ = librosa.load(os.path.join(DST, 'source.wav'), sr=sr)

In [None]:
play(y_src, sr)
play(y_dst, sr)

In [None]:
# Rebuild the DST instrumental by summing its non-drum stems
# (vocals, bass, other). np.sum over axis 0 assumes equal stem lengths.
stems = []
for s in ['vocals', 'bass', 'other']:
    y_ins, _ = librosa.load(os.path.join(DST, s.upper(), f'{s}.wav'), sr=sr)
    stems.append(y_ins)
y_dst_ins = np.sum(stems, axis=0)

In [None]:
play(y_dst_ins, sr)

In [None]:
# Global tempo estimate of each drum stem; only printed here.
tempo_src = librosa.beat.tempo(y=y_src, sr=sr)
tempo_dst = librosa.beat.tempo(y=y_dst, sr=sr)
print(tempo_src, tempo_dst)

In [None]:
# Downbeat times for each track, read from / cached in downbeats.pkl.
db_src = findDownbeats(SRC)
db_dst = findDownbeats(DST)

In [None]:
# Quantise the DST instrumental with the DST downbeats -- the same grid y_dst is
# quantised with in the next cell -- so it stays aligned with y_dst and can be
# sliced with idxs_dst later. Passing db_src here warped the DST instrumental
# onto the SRC track's beat grid. db_dst is still the unquantised grid at this
# point; it is only reassigned in the next cell.
y_dst_ins, _ = quantiseAudio(y_dst_ins, sr, db_dst, hq=True)
y_src, db_src = quantiseAudio(y_src, sr, db_src, hq=True)

In [None]:
# quantise beats
y_dst, db_dst = quantiseAudio(y_dst, sr, db_dst, hq=True)

In [None]:
y_src_metro = librosa.clicks(times=db_src, sr=sr, length=len(y_src))
y_dst_metro = librosa.clicks(times=db_dst, sr=sr, length=len(y_dst))

In [None]:
play(y_src + y_src_metro, sr, normalize=True)
play(y_dst + y_dst_metro, sr, normalize=True)

In [None]:
y_src_warped, y_dst_synced, downbeats, idxs_src, idxs_dst = warpAudio(
    y_src, y_dst, db_src, db_dst, sr
)

In [None]:
assert y_src_warped.shape == y_dst_synced.shape

In [None]:
y_ins_segment = y_dst_ins[idxs_dst[0]: idxs_dst[1]]

In [None]:
play(y_src_warped, sr, autoplay=False)
play(y_dst_synced, sr, autoplay=False)

In [None]:
ax = plot_audio(y_src_warped, sr)
plot_audio(y_dst_synced, sr, ax=ax)
for db in downbeats:
    ax.axvline(db, color='red')

In [None]:
y_metronome = librosa.clicks(times=downbeats, sr=sr, length=len(y_dst_synced))

In [None]:
play(y_metronome + y_dst_synced, sr, normalize=True)

### Decompose Drums

In [None]:
# Load the drum spectral templates (presumably kick / snare / hi-hat, going by
# the kd/sd/hh keys) and stack them as the columns of W_0 in that order.
# NOTE: unpickling can execute code -- only load a trusted template file.
with open('./drum_templates.pkl', 'rb') as f:
    templates = pkl.load(f)
    
kd_temp = templates['kd_temp']
sd_temp = templates['sd_temp']
hh_temp = templates['hh_temp']

W_0 = np.stack([kd_temp, sd_temp, hh_temp], axis=1)
W_0 = torch.from_numpy(W_0).to(device)

In [None]:
# Decompose the warped SRC drums into R=3 components starting from W_0,
# with trainable_W=True so the templates may adapt to this track.
W_src, H_src, V_src, _, net_src = getDecomposition(
    y_src_warped, R=3, trainable_W=True, W=W_0, device=device,
)

In [None]:
plotDecomposition(W_src, H_src, V_src, sr=sr)

In [None]:
# Same for the synced DST drums; trainable_W is not passed, so
# getDecomposition's default applies.
W_dst, H_dst, V_dst, _, net_dst = getDecomposition(y_dst_synced, R=3, W=W_0, device=device,)

In [None]:
plotDecomposition(W_dst, H_dst, V_dst, sr=sr)

In [None]:
# Listen to each DST component isolated by the trained model.
ys = isolateSources(net_dst, phi=net_dst.phi, device=device)
for y in ys:
    play(y, sr, normalize=True)

In [None]:
# For each of the 3 SRC components keep the longest extracted sample
# (np.argmax picks the first one on ties).
src_samples = []
for i in range(3):
    samples = getSamples(net_src, i, sr=sr)
    sample = samples[np.argmax(list(map(len, samples)))]
    src_samples.append(sample)

In [None]:
# Same selection for the DST components.
dst_samples = []
for i in range(3):
    samples = getSamples(net_dst, i, sr=sr)
    sample = samples[np.argmax(list(map(len, samples)))]
    dst_samples.append(sample)

In [None]:
for s in dst_samples:
    play(s, sr)

In [None]:
H_src.shape

In [None]:
# Peak-pick DST activation row 2 (third template column of W_0). The leading
# zero lets find_peaks report a peak in frame 0; `-= 1` undoes the shift.
peaks, _ = find_peaks(np.insert(H_dst[2], 0, 0), prominence=3)
peaks -= 1

In [None]:
plt.plot(H_dst[2])
plt.scatter(peaks, H_dst[2][peaks], c='red');

In [None]:
def reconstructDrums(H, samples, length, prominence=3, hop_length=512):
    """Resynthesise a drum track by placing one-shot samples at activation peaks.

    For each component j, peaks of the activation row ``H[j]`` are found with
    ``scipy.signal.find_peaks``. At every peak, ``samples[j]`` scaled by the
    activation value is mixed into the output, starting at the peak frame
    converted to a sample index. Copies that would run past ``length`` are
    truncated; onsets at or beyond ``length`` are skipped.

    Args:
        H: activation matrix indexed as ``H[component, frame]``, one row per
            entry of ``samples``.
        samples: sequence of 1-D sample arrays, one per row of ``H``.
        length: length, in samples, of the returned signal.
        prominence: minimum peak prominence for ``find_peaks`` (default 3, the
            previously hard-coded value).
        hop_length: frame hop in samples used to map frame indices to sample
            positions (default 512, the value used before).

    Returns:
        np.ndarray of shape ``(length,)``, passed through ``normalise``.

    Raises:
        ValueError: if the number of rows of ``H`` differs from ``len(samples)``.
    """
    if len(H) != len(samples):
        raise ValueError(
            f'H shape ({H.shape}) does not match number of samples ({len(samples)})'
        )

    y_rec = np.zeros(length)

    for act, sample in zip(H, samples):
        # Prepend a zero so a peak in the very first frame can be detected
        # (find_peaks never reports an endpoint), then undo the index shift.
        peaks, _ = find_peaks(np.insert(act, 0, 0), prominence=prominence)
        peaks -= 1
        starts = librosa.frames_to_samples(peaks, hop_length=hop_length)

        for p, start in zip(peaks, starts):
            # Truncate the copy at the end of the output. An onset at or past
            # the end gives n <= 0; a negative n would make sample[:n]
            # non-empty while the output slice is empty (broadcast error).
            n = min(len(sample), length - start)
            if n <= 0:
                continue
            y_rec[start:start + n] += act[p] * sample[:n]

    return normalise(y_rec)

In [None]:
# Resynthesise each drum track from its own activations and its own samples.
y_src_rec = reconstructDrums(H_src, src_samples, len(y_src_warped))
y_dst_rec = reconstructDrums(H_dst, dst_samples, len(y_dst_synced))

In [None]:
# Compare each aligned drum track with its sample-based reconstruction.
play(y_src_warped, sr)
play(y_src_rec, sr)
print('---')
play(y_dst_synced, sr)
play(y_dst_rec, sr)

In [None]:
# NOTE(review): other plot_audio calls pass `sr`; confirm the default
# plot_audio uses when it is omitted here.
plot_audio(y_src_rec);

### Morph Activations 

#### Full Replacement

In [None]:
# Copy before converting: torch.from_numpy shares memory with H_src, and
# .to(device) returns the same tensor when it is already on that device (CPU).
# Without the copy, the in-place scaling below silently rescales H_src, which
# later cells still use.
H_hat = torch.from_numpy(H_src.copy()).to(device)
for i in range(H_src.shape[0]):
    # Match each SRC activation row's peak to the DST model's row peak.
    scale = net_dst.H[0, i].max() / H_src[i].max()
    H_hat[i] *= scale

In [None]:
# Rebuild a magnitude spectrogram from the DST model's templates and the
# rescaled SRC activations (unsqueeze adds the leading batch axis).
V_hat = net_dst.reconstruct(
    net_dst.W, H_hat.unsqueeze(0)
).detach().cpu().squeeze().numpy()

In [None]:
# Estimate phase and invert with Griffin-Lim.
y_r = librosa.griffinlim(V_hat)

In [None]:
# Alternative: place the DST drum samples at the SRC activation peaks.
y_r_rec = reconstructDrums(H_src, dst_samples, len(y_dst_rec))

In [None]:
# Mix each drum rendering with the DST instrumental segment truncated to the
# drum length. NOTE(review): fails to broadcast if y_ins_segment is shorter.
# y_final = np.sum([y_r, y_ins_segment[:len(y_r)]], axis=0)
y_final = y_r + y_ins_segment[:len(y_r)]
y_final_rec = y_r_rec + y_ins_segment[:len(y_r_rec)]

In [None]:
print('Convolution reconstruction:')
play(y_r, sr, normalize=True)
play(y_final, sr, normalize=True)

print('Sample reconstruction:')
play(y_r_rec, sr)
play(y_final_rec, sr, normalize=True)

#### Threshold Transition 

In [None]:
# sms = [SmoothTransition(H_dst[i], H_src[i]) for i in range(3)]
# Interpolator between the DST and SRC activation matrices.
sm = SmoothTransition(H_dst, H_src)

In [None]:
# Render 5 evenly spaced weights in [0, 1]; the weight is printed below as
# the fraction "SRC transformed".
y_transformed = []
weights = np.linspace(0, 1, 5)

for t in tqdm(weights):
#     H_w = torch.from_numpy(sm(t)).type(torch.float)
    H_w = sm(t)
    
    # Rescale each row in place to the DST row peak.
    # NOTE(review): divides by zero if a row of H_w is all zeros.
    for i in range(len(H_dst)):
        H_w[i] *= (H_dst[i].max() / H_w[i].max())
    
#     V_w = net_dst.reconstruct(
#         net_dst.W, H_w.unsqueeze(0)
#     ).detach().cpu().squeeze().numpy()
    
#     y_w = librosa.griffinlim(V_w)
    # Place the SRC samples at the interpolated activations, add the DST
    # instrumental segment and normalise.
    y_w = reconstructDrums(H_w, src_samples, len(y_dst_synced))
    y_w_final = y_w + np.copy(y_ins_segment[:len(y_w)])
    
    y_w_final = normalise(y_w_final)
    y_transformed.append(y_w_final)

In [None]:
for i, yt in enumerate(y_transformed):
    print(f'SRC transformed {weights[i]:.0%}')
    play(yt, sr)

#### Transition 

In [None]:
# Downbeat times as frame and sample indices, first downbeat dropped; these
# are the split points for cutting the tracks into bars below.
db_frames = librosa.time_to_frames(downbeats, sr=sr)[1:]
db_samples = librosa.time_to_samples(downbeats, sr=sr)[1:]

In [None]:
# np.split returns a list of chunks; bars may differ in length.
H_src_bars = np.split(H_src, db_frames, axis=1)   # list of (ins, frame) arrays, one per bar
H_dst_bars = np.split(H_dst, db_frames, axis=1)   # list of (ins, frame) arrays, one per bar
y_src_bars = np.split(y_src_warped, db_samples)   # list of 1-D sample arrays, one per bar
y_dst_bars = np.split(y_dst_synced, db_samples)   # list of 1-D sample arrays, one per bar

In [None]:
# Transition from DST bar `idx` to SRC bar `idx`.
idx = 0
H_start = H_dst_bars[idx]
H_end = H_src_bars[idx]

In [None]:
# Interpolator between the two bars; the extra arguments are each track's
# per-row activation maxima reshaped to column vectors (n_rows, 1).
sm = SmoothTransition(
    H_start, 
    H_end, 
    H_dst.max(axis=1).reshape((-1, 1)), 
    H_src.max(axis=1).reshape((-1, 1))
)

In [None]:
play(y_src_bars[idx], sr)
play(y_dst_bars[idx], sr)

In [None]:
# Sample the interpolation at 7 evenly spaced weights and concatenate the
# resulting activation chunks along the frame axis.
transition = []
weights = np.linspace(0, 1, 7)
for t in weights:
    H_w = sm(t, power=2, score_threshold=0.1)
    transition.append(H_w)
H_t = np.concatenate(transition, axis=1)

In [None]:
# Row 2 of the transition over row 2 of the starting bar, with dotted lines
# at the first len(transition) downbeat frames.
plt.plot(H_t[2]);
plt.plot(H_start[2], lw=1, alpha=0.5)
for db in db_frames[:len(transition)]:
    plt.axvline(db, c='k', ls=':')

In [None]:
# Peek at H_t's shape with a leading batch axis. H_t is still a numpy array
# here (it becomes a tensor only later), and numpy has no `.unsqueeze`;
# `[None]` adds the axis for both numpy arrays and torch tensors.
H_t[None].shape

In [None]:
# NOTE(review): leftover scratch cell -- no top-level `H` is assigned in the
# visible notebook (`H` only appears as a function parameter), so running this
# cell in order raises NameError.
H

In [None]:
# Scratch check: np.sum over axis 0 adds the two arrays element-wise.
np.sum([np.ones(2), np.ones(2)], axis=0)

In [None]:
def reconstructCombined(H, net, device='cpu', momentum=0.99):
    """Resynthesise audio from activations ``H`` one component at a time.

    For each row ``c`` of ``H``, all other rows are zeroed, a magnitude
    spectrogram is rebuilt with ``net.reconstruct(net.W, ...)`` and inverted
    with Griffin-Lim. The per-component signals are then summed.

    Args:
        H: torch tensor of activations indexed as ``H[component, frame]``.
        net: decomposition model exposing ``W`` and ``reconstruct``.
        device: device the per-component activation matrix is moved to.
        momentum: Griffin-Lim momentum (default 0.99, as before).

    Returns:
        np.ndarray: sum of the per-component Griffin-Lim reconstructions.
    """
    ys = []
    for c in range(H.shape[0]):
        # Keep only component c so Griffin-Lim runs on each one separately.
        H_c = torch.zeros_like(H).to(device)
        # Read from the argument -- this previously copied from the global H_t,
        # ignoring the H that was passed in.
        H_c[c, :] = H[c, :]
        V_c = net.reconstruct(
            net.W, H_c.unsqueeze(0)
        ).detach().cpu().squeeze().numpy()

        ys.append(librosa.griffinlim(V_c, momentum=momentum))

    return np.sum(ys, axis=0)

In [None]:
# Convert the transition activations to a float32 tensor on `device`.
H_t = torch.tensor(H_t).type(torch.float).to(device)

# Variant 1: Griffin-Lim per component with the DST model, summed.
y_transitioning = reconstructCombined(H_t, net_dst, device)

In [None]:
# Variant 2: all components reconstructed together, one Griffin-Lim pass.
V_t = net_dst.reconstruct(
    net_dst.W, H_t.unsqueeze(0)
).detach().cpu().squeeze().numpy()
y_transitioning2 = librosa.griffinlim(V_t, momentum=0.99)

In [None]:
# Variant 3: place the DST drum samples at the activation peaks; the output
# length is the frame count converted to samples.
y_transitioning3 = reconstructDrums(
    H_t.cpu().numpy(), dst_samples, librosa.frames_to_samples(H_t.shape[1])
)

In [None]:
# Downbeat clicks at 1/4 amplitude to overlay on the renderings.
transition_beeps = librosa.clicks(
    times=downbeats, sr=sr, length=len(y_transitioning)
) * 0.25

In [None]:
play(y_transitioning + transition_beeps, sr, normalize=True)

In [None]:
play(y_transitioning2 + transition_beeps, sr, normalize=True)

In [None]:
play(y_transitioning3, sr, normalize=True)

## Full Track Transition 

In [None]:
# Full activation sequence along the frame axis: the starting DST bar twice,
# the interpolated transition, then the ending SRC bar twice.
H_transition_full = torch.from_numpy(np.concatenate([
#     H_start,
    H_start,
    H_start,
    H_t.cpu().numpy(),
    H_end,
    H_end,
#     H_end,
#     H_end,
#     H_end,
], axis=1)).type(torch.float)

In [None]:
# Griffin-Lim path via net_src kept commented out; the active path places the
# DST drum samples at the activation peaks.
# V_transition_full = net_src.reconstruct(
#         net_src.W, H_transition_full.unsqueeze(0).to(device)
#     ).detach().cpu().squeeze().numpy()
# y_transitioning_full = normalise(librosa.griffinlim(V_transition_full))
y_transitioning_full = reconstructDrums(
    H_transition_full.numpy(), dst_samples, librosa.frames_to_samples(H_transition_full.shape[1])
)

In [None]:
plot_audio(y_transitioning_full, sr)
play(y_transitioning_full, sr)

In [None]:
# DST instrumental from the start index of the aligned region onward.
y_ins_synced = y_dst_ins[idxs_dst[0]:]

In [None]:
# Mix instrumental and drums over their common length, then normalise.
full_length = min(len(y_ins_synced), len(y_transitioning_full))
y_full_transition = y_ins_synced[:full_length] + y_transitioning_full[:full_length]
y_full_transition = normalise(y_full_transition)

In [None]:
play(y_full_transition, sr)

#### Percent Transformed 

In [None]:
# Reload the module so source edits are picked up without restarting the
# kernel, then rebind the imported names to the reloaded versions.
import wasserstein_transformations
reload(wasserstein_transformations)
from wasserstein_transformations import SmoothTransform, SmoothTransition

In [None]:
# `steps` presumably sets the number of interpolation weights (inspect
# sm.weights below) -- TODO confirm against SmoothTransform.
sm = SmoothTransform(steps=5)

In [None]:
sm.weights

In [None]:
# For each component, scale the SRC and DST activation rows to a peak of 1
# and compute the transformation between them.
transformations = []
for i in range(len(H_dst)):
    ps, qs = H_src[i], H_dst[i]
    ps, qs = ps/ps.max(), qs/qs.max()
    
    trans = sm.transform(ps, qs)
    transformations.append(trans)

In [None]:
sm.plot_transform(transformations[1])
plt.show()

In [None]:
y_transformed = []

for w in tqdm(range(len(transformations[0]))):
    # Stack step w of every component's transformation -> (component, frame).
    H_w = torch.from_numpy(
        np.stack([trans[w] for trans in transformations])
    ).type(torch.float)

    # rescale activations from probability dist back to the DST row peaks
    for i in range(len(H_dst)):
        H_w[i] *= (H_dst[i].max() / H_w[i].max())

    # H_w was built on the CPU. net_dst was created with device=device, so
    # move the activations there before reconstructing; otherwise a CUDA run
    # mixes devices.
    V_w = net_dst.reconstruct(
        net_dst.W, H_w.unsqueeze(0).to(device)
    ).detach().cpu().squeeze().numpy()

    # Griffin-Lim inversion, then mix with the DST instrumental segment
    # (slicing + addition already yields a new array; no copy needed).
    y_w = librosa.griffinlim(V_w)
    y_w_final = y_w + y_ins_segment[:len(y_w)]
    y_transformed.append(y_w_final)

In [None]:
# Listen to each rendering, labelled with its interpolation weight.
for i, yt in enumerate(y_transformed):
    print(f'SRC transformed {sm.weights[i]:.0%}')
    play(yt, sr)