# 05 All Combined (Drum Resynthesis + Swing Aligned) 

In [None]:
import os
import sys
print(sys.version)

In [None]:
import pickle as pkl

import numpy as np
import matplotlib.pyplot as plt

import torch
import librosa
import pyrubberband

from tqdm.auto import tqdm
from IPython.display import display, Audio

In [None]:
sys.path.append('COMMON_UTILS/')

In [None]:
from utils import plot_audio, play, normalise

from tempo_align import warpAudio, quantiseAudio

from drum_processor import getDownbeats
from drum_decomposition import (
    getDecomposition, plotDecomposition, isolateSources, reconstructDrums, getSamples
)

from wasserstein_transformations import SmoothTransition

from swing_align import alignSwing, divideTimes

In [None]:
# Default size for every matplotlib figure drawn in this notebook.
plt.rcParams['figure.figsize'] = (15, 5)

# Run torch on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'computing on {device}')

In [None]:
def findDownbeats(p):
    """Return the downbeats of the song stored in directory `p`.

    The result is cached in `<p>/downbeats.pkl`. When that file exists it is
    unpickled and returned as is; otherwise the downbeats are computed from
    `<p>/source.wav` with `getDownbeats` and written to the cache file.
    """
    cache_path = os.path.join(p, 'downbeats.pkl')

    try:
        with open(cache_path, 'rb') as cache_file:
            db = pkl.load(cache_file)
    except FileNotFoundError:
        # No cache yet: run the detector once and store its output.
        print('computing downbeats')
        audio_path = os.path.join(p, 'source.wav')
        db = getDownbeats(audio_path, transition_lambda=64)

        with open(cache_path, 'wb') as cache_file:
            pkl.dump(db, cache_file)
    else:
        print('found downbeats.pkl')

    return db

#### 01 Select Targets

In [None]:
print("avaliable songs names:\n")
print('\n'.join(sorted(os.listdir("./PROCESSED/"))))

In [None]:
# TARGET song directory; its mix is loaded at 44.1 kHz and fixes the sample
# rate `sr` used for every other load in this notebook.
TRG = './PROCESSED/jazz3/'
y_trg, sr = librosa.load(os.path.join(TRG, 'source.wav'), sr=44100)

# ORIGINAL song directory; its mix is loaded at the same sample rate.
ORG = './PROCESSED/rock2/'
y_org, _ = librosa.load(os.path.join(ORG, 'source.wav'), sr=sr)

# Audition both full mixes.
print('TARGET:')
play(y_trg, sr)
print('ORIGINAL:')
play(y_org, sr)

#### 02 Get Stems 

In [None]:
# Drum stems of both songs (DRUMS/drums.wav inside each song directory),
# loaded at the shared sample rate `sr`.
y_trg_drums, sr = librosa.load(os.path.join(TRG, 'DRUMS/drums.wav'), sr=sr)
y_org_drums, _ = librosa.load(os.path.join(ORG, 'DRUMS/drums.wav'), sr=sr)
# Full mix of the ORIGINAL song, read from ORG's source.wav.
y_org_full, _ = librosa.load(os.path.join(ORG, 'source.wav'), sr=sr)

In [None]:
play(y_trg_drums, sr)
play(y_org_drums, sr)
play(y_org_full, sr)

In [None]:
# Build the drum-less mix of the ORIGINAL by summing its vocals, bass and
# "other" stems sample by sample.
# NOTE(review): np.sum over axis 0 assumes all three stems have the same
# length — confirm for the separated files.
stems = []
for s in ['vocals', 'bass', 'other']:
    # each stem is stored as <ORG>/<NAME IN UPPER CASE>/<name>.wav
    y_ins, _ = librosa.load(os.path.join(ORG, s.upper(), f'{s}.wav'), sr=sr)
    stems.append(y_ins)
y_org_ins = np.sum(stems, axis=0)
play(y_org_ins, sr)

#### 03 Get Downbeats 

In [None]:
db_org = findDownbeats(ORG)
db_trg = findDownbeats(TRG)

In [None]:
# tempo_trg = librosa.beat.tempo(y=y_trg, sr=sr)
# tempo_org = librosa.beat.tempo(y=y_org, sr=sr)
# Tempo in BPM derived from the downbeats instead: the mean spacing between
# consecutive downbeats divided by 4 is taken as one beat period (i.e. four
# beats per bar are assumed), and 60 / beat period gives beats per minute.
# `.mean()` is called without an axis, so each tempo is a single value.
# NOTE(review): assumes the downbeats are times in seconds — confirm against
# getDownbeats.
tempo_trg = 60/(np.diff(db_trg).mean()/4)
tempo_org = 60/(np.diff(db_org).mean()/4)
print(tempo_trg, tempo_org)

In [None]:
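# fix tempo: if the detected ORIGINAL downbeats are twice as dense as they
# should be, uncomment the next line to keep only every second one (this
# halves the tempo derived from db_org)
# db_org = db_org[::2]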

#### 04 Quantise Audio 

In [None]:
# Quantise every signal against the downbeats detected for its own song.
# quantiseAudio returns the processed audio plus a second value which, judging
# by the rebinding of db_trg / db_org below, is the downbeat grid after
# quantisation (NOTE(review): confirm against tempo_align.quantiseAudio).
#
# The instrumental mix is processed BEFORE db_org is rebound: it has to be
# quantised with the same detected downbeats as the ORIGINAL mix and drums.
# Previously it came last and therefore received the already-quantised grid,
# which leaves it out of step with the quantised drums.
y_trg, _ = quantiseAudio(y_trg, sr, db_trg, hq=True)
y_org, _ = quantiseAudio(y_org, sr, db_org, hq=True)
y_org_ins, _ = quantiseAudio(y_org_ins, sr, db_org, hq=True)

# Rebind the downbeats last so that the following cells use the new grid.
# Re-running this cell without re-running the downbeat cell would feed the
# rebound downbeats back in, so always re-run from "03 Get Downbeats".
y_trg_drums, db_trg = quantiseAudio(y_trg_drums, sr, db_trg, hq=True)
y_org_drums, db_org = quantiseAudio(y_org_drums, sr, db_org, hq=True)

#### 05 Sync Audio 

In [None]:
# Bring the two quantised drum stems onto a common timeline using their
# downbeats. The call returns both processed signals, a downbeat array and a
# pair of index bounds per song.
# NOTE(review): the index bounds are used below as sample offsets into the
# quantised ORIGINAL audio — confirm against tempo_align.warpAudio.
y_trg_warped, y_org_synced, downbeats, idxs_trg, idxs_org = warpAudio(
    y_trg_drums, y_org_drums, db_trg, db_org, sr
)

In [None]:
assert y_trg_warped.shape == y_org_synced.shape

In [None]:
# Cut the ORIGINAL instrumental mix to the sample range returned by warpAudio
# for the ORIGINAL song, so it covers the same excerpt as the synced drums.
y_ins_segment = y_org_ins[idxs_org[0]: idxs_org[1]]

#### 06 Decompose Drum Audio 

In [None]:
# Load the pickled drum templates (a mapping with the three keys read below).
with open('./drum_templates.pkl', 'rb') as f:
    templates = pkl.load(f)
    
# presumably kick drum / snare drum / hi-hat templates — confirm naming
kd_temp = templates['kd_temp']
sd_temp = templates['sd_temp']
hh_temp = templates['hh_temp']

# Stack the three templates along axis 1 (one template per column, assuming
# each template is 1-D) and move the result to the compute device; it is
# passed to getDecomposition as the initial W.
W_0 = np.stack([kd_temp, sd_temp, hh_temp], axis=1)
W_0 = torch.from_numpy(W_0).to(device)

In [None]:
# Decompose the warped TARGET drums into R=3 components, starting from the
# templates in W_0 and allowing W to be updated (trainable_W=True).
# The fourth return value is discarded; net_trg is used later by getSamples.
W_trg, H_trg, V_trg, _, net_trg = getDecomposition(
    y_trg_warped, R=3, W=W_0, trainable_W=True, device=device,
)

In [None]:
# Decompose the synced ORIGINAL drums with the same settings as the TARGET:
# R=3 components, initial W taken from W_0, W trainable.
# The fourth return value is discarded; net_org is used later for sample
# extraction and for reconstructing the morphed spectrogram.
W_org, H_org, V_org, _, net_org = getDecomposition(
    y_org_synced, R=3, W=W_0, trainable_W=True, device=device,
)

In [None]:
# For each of the 3 components keep the longest sample that getSamples
# extracts from the TARGET decomposition.
samples_trg = []
for i in range(3):
    samples = getSamples(net_trg, i, sr)
    try:
        # first longest candidate, same pick as an argmax over the lengths
        sample = max(samples, key=len)
    except ValueError:
        # Nothing was extracted for this component. Fall back to a short
        # block of silence, as the ORIGINAL loop does, so the play/plot cells
        # below still get one entry per component instead of this cell failing.
        print(f'no samples found for {i}')
        sample = np.zeros(256)
    samples_trg.append(sample)

In [None]:
for s in samples_trg:
    play(s, sr)

In [None]:
plot_audio(np.concatenate(samples_trg), sr);

In [None]:
# One sample per component from the ORIGINAL decomposition: the longest one
# getSamples returns, or 256 samples of silence when it returns none.
samples_org = []
for i in range(3):
    samples = getSamples(net_org, i, sr)
    try:
        # max() keeps the first of the longest candidates and raises
        # ValueError on an empty collection
        sample = max(samples, key=len)
    except ValueError:
        print(f'not samples found for {i}')
        sample = np.zeros(256)
    samples_org.append(sample)

In [None]:
for s in samples_org:
    play(s, sr)

In [None]:
plot_audio(np.concatenate(samples_org), sr);

#### 07 Morph Activations 

In [None]:
sm = SmoothTransition(H_org, H_trg)

In [None]:
weight = 1.0

In [None]:
# NOTE(review): this first result is overwritten by the next assignment, which
# calls sm(weight) WITHOUT power=0.5 — so the power setting never reaches the
# reconstruction. Confirm which of the two calls is intended and drop the other.
H_w = sm(weight, power=0.5)

# Morphed activations as a float tensor on the compute device.
H_w = torch.from_numpy(sm(weight)).type(torch.float).to(device)
# Rescale each row of the morphed activations so that its peak equals the
# peak of the corresponding ORIGINAL activation row.
# NOTE(review): a row of H_w whose maximum is 0 makes this a division by zero.
for i in range(len(H_org)):
    H_w[i] *= (H_org[i].max() / H_w[i].max())
    
# Rebuild a spectrogram from the ORIGINAL network's W and the morphed
# activations (with a leading axis added), then bring it back to a NumPy
# array on the CPU with singleton axes removed.
V_w = net_org.reconstruct(
        net_org.W, H_w.unsqueeze(0)
    ).detach().cpu().squeeze().numpy()

#### 08 Reconstruct Audio 

In [None]:
# Turn the reconstructed spectrogram into audio with Griffin-Lim (V_w is
# treated as a magnitude spectrogram; default STFT parameters are used).
y_w_drums = librosa.griffinlim(V_w)
# if y_w_drums.max() > 1 or y_w_drums.min() < -1:
#     y_w_drums /= max(y_w_drums.max(), -y_w_drums.min())

# Alternative kept for reference: rebuild the drums from the extracted samples.
# y_w_drums = reconstructDrums(H_w, samples_org, len(y_org_synced))
# Add the instrumental excerpt (cut to the length of the drums), normalise the
# mix, then scale it by the peak of the quantised ORIGINAL drum stem.
# NOTE(review): the sum fails if y_ins_segment is shorter than y_w_drums.
y_w_full = normalise(y_w_drums + np.copy(y_ins_segment[:len(y_w_drums)]))
y_w_full *= y_org_drums.max()

In [None]:
plot_audio(y_w_drums);

In [None]:
print(f'ORIGINAL with {weight:.0%} morphing:')
play(y_w_full, sr, normalize=True)
print(f'isolated resynthesised drums:')
play(y_w_drums, sr, normalize=True)

#### 09 Apply Swing 

In [None]:
from swing_align import getSwingPoints, getSwingMap, getSwingTimings
from tempo_align import matchAudioEvents

In [None]:
# Swing points of the quantised ORIGINAL and TARGET mixes, each computed with
# the downbeats of its own song and a hop length of 256.
points_org = getSwingPoints(y_org, sr, db_org, hop_length=256)
points_trg = getSwingPoints(y_trg, sr, db_trg, hop_length=256)

# Map the ORIGINAL's swing points onto the TARGET's, then expand that map over
# the ORIGINAL downbeats into source/destination timings.
# NOTE(review): exact semantics of map_ and the timings live in swing_align.
map_ = getSwingMap(points_org, points_trg)
points_from, points_to = getSwingTimings(db_org, map_)

# Warp the ORIGINAL instrumental mix so events at points_from land on
# points_to, then keep the same excerpt (idxs_org) as the synced drums.
y_ins_warped = matchAudioEvents(y_org_ins, sr, points_from, points_to, hq=True)
y_ins_warped = y_ins_warped[idxs_org[0]:idxs_org[1]]

In [None]:
map_

In [None]:
print(f'ORIGINAL INSTRUMENTATION:')
play(y_ins_segment, sr)
print(f'with swing applied:')
play(y_ins_warped, sr)

#### 10 Complete Style Transfer 

In [None]:
# Mix the resynthesised drums (at half gain) with the swing-warped
# instrumentation. Both are cut to the shorter of the two so that a length
# mismatch between the drums and the instrumental excerpt cannot break the sum.
n_styled = min(len(y_w_drums), len(y_ins_warped))
y_styled = y_w_drums[:n_styled] * 0.5 + y_ins_warped[:n_styled]

# Time-stretch from the ORIGINAL tempo to the TARGET tempo (rate > 1 speeds
# the audio up). The tempos are derived with `.mean()` and are therefore
# scalars, which cannot be indexed with [0]; np.ravel accepts both a scalar and
# a one-element array, so the ratio is extracted safely either way.
tempo_ratio = float(np.ravel(tempo_trg / tempo_org)[0])
y_styled_bpm = pyrubberband.pyrb.time_stretch(y_styled, sr, rate=tempo_ratio)

In [None]:
# The tempos are scalars when derived from the downbeats (and one-element
# arrays if a beat tracker is used instead); np.ravel handles both, whereas
# indexing a scalar with [0] raises IndexError.
bpm_trg = float(np.ravel(tempo_trg)[0])
bpm_org = float(np.ravel(tempo_org)[0])

# Audition: the aligned ORIGINAL excerpt, the styled result at the ORIGINAL
# tempo, and the styled result stretched to the TARGET tempo.
print('ORIGINAL (aligned):')
play(y_org[idxs_org[0]:idxs_org[1]], sr)
print('Drum resynthesis + swing applied')
play(y_styled, sr, normalize=True)
print(f'at TRGs tempo {bpm_trg:.0f}bpm (from {bpm_org:.0f}bpm)')
play(y_styled_bpm, sr, normalize=True)

In [None]:
def fadeOut(y, sr, length=1):
    """Return a copy of `y` with a linear fade-out applied to its end.

    Parameters
    ----------
    y : np.ndarray
        1-D array of audio samples. It is not modified.
    sr : int
        Sample rate of `y` in Hz.
    length : float, optional
        Duration of the fade in seconds (default 1).

    Returns
    -------
    np.ndarray
        `y` multiplied by an envelope that is 1 everywhere except for the last
        `length` seconds, where it falls linearly from 1 to 0. If the fade is
        longer than the signal, the whole signal is faded from 1 to 0. A fade
        of zero (or negative) duration, or an empty signal, yields an
        unmodified copy.
    """
    env = np.ones_like(y)

    # Seconds -> samples, truncated towards zero (the same conversion as
    # librosa.time_to_samples), then clamped to the signal length. Without the
    # clamp the slice below is shorter than the ramp and the assignment fails.
    slope_length = min(int(length * sr), len(y))

    # Guard the zero case explicitly: env[-0:] would select the WHOLE array.
    if slope_length > 0:
        env[-slope_length:] = np.linspace(1, 0, slope_length)

    return y * env

In [None]:
# Before/after demo: the aligned ORIGINAL excerpt with a 1.5 s fade-out, then
# `sr` zero samples (one second of silence), then the styled result that was
# stretched to the TARGET tempo.
y_before_after = np.concatenate(
    [fadeOut(y_org[idxs_org[0]:idxs_org[1]], sr, 1.5), np.zeros(sr), y_styled_bpm]
)

In [None]:
play(y_before_after, sr)