In [None]:
%matplotlib inline

import os

import librosa
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf

In [None]:
# Experiment configuration.
feature = 'cqt'       # AUDIO representation: 'stft', 'cqt' or 'mel' (see load_wav)
filename = 'mary2_s'  # AUDIO input is read from data/<filename>.wav

w_weight = 100 # nonzero for faces
w_reg = 'exp'
h_weight = 1e-6 # nonzero for faces
h_reg = 'l2'
r = 16 # 20 for audio, 16 for faces
iterations = 5000 # 10000 for audio, 5000 for faces
# NOTE(review): r / iterations hold the "faces" values while FACES is False —
# confirm this is intended for the audio run.

FACES = False
AUDIO = not FACES

# True skips the data-loading cell (reusing V from an earlier run); set False to reload.
loaded = True

# multiplicative() initialises W and H from NumPy's legacy global RNG, so seed
# it here to make every Restart & Run All reproducible.
SEED = 0
np.random.seed(SEED)

# Tag output files with the mode and the regulariser that matters for it
# (H's for audio, W's for faces).
settings = '%s_%s_%s' % ('faces' if FACES else 'audio', h_reg if AUDIO else w_reg, h_weight if AUDIO else w_weight)

# Every figure / result below is written under out/; make sure it exists.
os.makedirs('out', exist_ok=True)
out_w = 'out/w_%s.png' % settings
out_h = 'out/h_%s.png' % settings
out_faces = 'out/faces_%s.png' % settings
out_results = 'out/results_%s.txt' % settings

In [None]:
# Data loading/processing

def load_wav(filename, feature='stft'):
    """Load an audio file and return a magnitude feature matrix and its sample rate.

    Args:
        filename: path of the audio file passed to librosa.load.
        feature: 'stft' (|STFT| with n_fft=512), 'cqt' (|CQT|) or
            'mel' (mel power spectrogram).

    Returns:
        (s, sr): the feature matrix and the sample rate librosa loaded at.

    Raises:
        ValueError: if `feature` is not one of the supported names. This is
            checked before the file is read.
    """
    if feature not in ('stft', 'cqt', 'mel'):
        raise ValueError("unknown feature %r; expected 'stft', 'cqt' or 'mel'" % (feature,))
    y, sr = librosa.load(filename)
    if feature == 'stft':
        # n_fft is keyword-only in librosa >= 0.10; passing it positionally fails.
        s = np.abs(librosa.stft(y, n_fft=512))
    elif feature == 'cqt':
        s = np.abs(librosa.cqt(y, sr=sr))
    else:  # 'mel'
        s = librosa.feature.melspectrogram(y=y, sr=sr)
    return s, sr

def save_wav(filename, s, sr, feature='stft'):
    """Invert a magnitude feature matrix back to audio and write it to `filename`.

    Args:
        filename: output path passed to soundfile.write.
        s: feature matrix of the kind produced by load_wav for the same `feature`.
        sr: sample rate used for the inversion (CQT/mel) and for the file.
        feature: 'stft', 'cqt' or 'mel'; selects the phase-reconstruction routine.

    Raises:
        ValueError: if `feature` is not one of the supported names.
    """
    if feature not in ('stft', 'cqt', 'mel'):
        raise ValueError("unknown feature %r; expected 'stft', 'cqt' or 'mel'" % (feature,))
    if feature == 'stft':
        # librosa.griffinlim takes no `sr` argument; n_fft (and hence the
        # default hop) is inferred from the number of frequency rows in `s`.
        y = librosa.griffinlim(s)
    elif feature == 'cqt':
        y = librosa.griffinlim_cqt(s, sr=sr)
    else:  # 'mel'
        y = librosa.feature.inverse.mel_to_audio(s, sr=sr)
    sf.write(filename, y, sr)

def log_norm(feature, clip=1e-6):
    """Min-max normalise `feature` to [0, 1] and return the bounds used.

    Despite the name, no logarithm is applied: the log10 compression step
    (np.log10(np.maximum(feature, clip))) is currently disabled, so `clip`
    is unused and kept only for interface compatibility.

    Returns:
        (normalised, (min, max)). Pass the bounds to inv_log_norm to undo
        the scaling. A constant input maps to all zeros instead of 0/0 NaNs,
        which still round-trips through inv_log_norm.
    """
    lower, upper = np.min(feature), np.max(feature)
    bounds = (lower, upper)
    span = upper - lower
    if span == 0:
        return np.zeros_like(feature, dtype=float), bounds
    return (feature - lower) / span, bounds
    #return feature, bounds

def inv_log_norm(feature, bounds):
    """Undo log_norm's min-max scaling.

    `bounds` is the (min, max) pair returned by log_norm. Values in [0, 1]
    are mapped linearly back onto [min, max]. As in log_norm, no log/exp
    step is applied.
    """
    lower, upper = bounds[0], bounds[1]
    span = upper - lower
    return lower + feature * span
    #return feature

def show_feature(feature):
    """Display a 2-D feature matrix (e.g. a spectrogram) as an image in a new figure.

    Row 0 is drawn at the bottom (origin='lower'). The fixed extent stretches
    the image to a wide 8x2 box regardless of the array's shape.
    """
    plt.figure()
    # Plot the argument itself, not whatever the global `V` happens to hold.
    plt.imshow(feature, origin='lower', extent=[-4, 4, -1, 1])
    plt.tight_layout()
    plt.pause(0.1)

def plot_rows(X, normalized=False, filename=None):
    """Plot each row of the 2-D array X as its own line plot, stacked vertically.

    Args:
        X: 2-D array; each of its r rows is plotted against a [0, 1] x-axis.
        normalized: if True, every subplot shares the y-limits (global min/max
            of X), so row magnitudes are comparable.
        filename: optional path; the figure is saved there when given.
    """
    r, n = X.shape
    bounds = (np.min(X), np.max(X))
    # squeeze=False keeps `ax` 2-D even when r == 1; otherwise subplots()
    # returns a single Axes and the loop below would fail.
    fig, ax = plt.subplots(r, 1, figsize=(10, r / 2), squeeze=False)
    t = np.linspace(0, 1, n)
    for i, a in enumerate(ax[:, 0]):
        a.plot(t, X[i, :])
        if normalized:
            a.set_ylim(bounds)
        a.axis('off')
    plt.tight_layout()
    if filename is not None:
        plt.savefig(filename)
    plt.pause(0.1)

def load_faces():
    """Return the `.data` array of scikit-learn's Olivetti faces dataset.

    sklearn is imported inside the function, so it is only required when the
    faces experiment is run. fetch_olivetti_faces downloads/caches the data.
    """
    from sklearn import datasets
    olivetti = datasets.fetch_olivetti_faces()
    return olivetti.data

def plot_gallery(title, images, n_col, n_row, normalized=False, filename=None, image_shape=(64, 64)):
    """Show each row of `images` as a grayscale tile in an n_row x n_col grid.

    Args:
        title: figure super-title.
        images: 2-D array; each row is one flattened image of `image_shape`.
            It must have at most n_row * n_col rows.
        normalized: if True, each tile uses its own colour limit; if False,
            all tiles share one limit. Note this is the opposite meaning of
            plot_rows' `normalized`.
        filename: optional path; the figure is saved there when given.
        image_shape: shape each row is reshaped to (default 64x64).
    """
    plt.figure(figsize=(2. * n_col, 2.26 * n_row))
    plt.suptitle(title, size=16)
    # The shared colour limit is computed once outside the loop.
    shared_vmax = None
    if not normalized and len(images):
        shared_vmax = max(images.max(), -images.min())
    for i, comp in enumerate(images):
        plt.subplot(n_row, n_col, i + 1)
        vmax = max(comp.max(), -comp.min()) if normalized else shared_vmax
        plt.imshow(comp.reshape(image_shape), cmap=plt.cm.gray,
                   interpolation='nearest',
                   vmin=0, vmax=vmax)
        plt.xticks(())
        plt.yticks(())
    plt.subplots_adjust(0.01, 0.05, 0.99, 0.93, 0.04, 0.)
    if filename is not None:
        plt.savefig(filename)

In [None]:
# Load and preprocess the input data. `loaded` lets a re-run skip this step,
# but only when V really exists. On a fresh kernel it does not, and skipping
# would make every later cell fail with NameError.
if not loaded or 'V' not in globals():
    if AUDIO:
        s, sr = load_wav('data/%s.wav' % filename, feature=feature)
        V, bounds = log_norm(s)
        print('Bounds: ', bounds)
        show_feature(V)
        reconstruction = inv_log_norm(V, bounds)
        SAVE_ORIGINAL = False  # set True to write the round-tripped input as audio
        if SAVE_ORIGINAL:
            os.makedirs('out/%s' % filename, exist_ok=True)
            save_wav('out/%s/original.wav' % filename, reconstruction, sr, feature=feature)
    if FACES:
        V = load_faces()
        plot_gallery('dataset', V, 20, 20)

In [None]:
def euclidean_loss(V, W, H):
    """Squared Frobenius norm of the reconstruction residual V - W @ H."""
    residual = V - W @ H
    return np.linalg.norm(residual) ** 2

def l1_loss(W):
    """Entrywise L1 norm of W (sum of absolute values).

    For the non-negative NMF factors in this notebook this equals np.sum(W).
    The np.abs makes it a true L1 norm for signed input too.
    """
    return np.sum(np.abs(W))

def vector_l1_loss(W, axis=0):
    """Sum of squared L1 norms of the vectors obtained by reducing W along `axis`.

    axis=0 gives one L1 norm per column; axis=1 gives one per row. Absolute
    values are taken so the result is a true squared-L1 penalty for signed
    input. For non-negative W it is unchanged.
    """
    return np.sum(np.square(np.sum(np.abs(W), axis=axis)))

def l2_loss(W):
    """Squared entrywise L2 (Frobenius) norm of W: the sum of squared entries."""
    squared = np.square(W)
    return squared.sum()

def l21_loss(W, axis=0):
    """L2,1 norm: the L2 norm of each vector along `axis`, summed.

    axis=0 gives column norms; axis=1 gives row norms.
    """
    vector_norms = np.linalg.norm(W, axis=axis)
    return vector_norms.sum()

def rand_small(shape, scale=1e-6):
    """Uniform random array in [0, scale) with the given shape.

    Uses np.random.random_sample, the documented name of the legacy
    global-RNG sampler (np.random.ranf is only an alias of it). The values
    therefore still follow np.random.seed.
    """
    return scale * np.random.random_sample(shape)

def percent_zero(W, threshold=1e-3):
    """Fraction (in [0, 1], not a percentage) of effectively-zero entries of W.

    An entry counts as zero when it is below `threshold` times the largest
    entry of W. Works for arrays of any shape. Empty input returns 0.0. When
    the maximum is 0, the relative test is undefined (0/0), so exact zeros
    are counted instead of producing NaN.
    """
    W = np.asarray(W)
    if W.size == 0:
        return 0.0
    peak = np.max(W)
    if peak == 0:
        return float(np.mean(W == 0))
    return float(np.mean(W / peak < threshold))

# Regulariser names accepted by multiplicative() for w_reg / h_reg.
_MU_PENALTIES = (None, 'l2', 'l1', 'l1_col', 'l1_row', 'l21_col', 'l21_row', 'log', 'exp')


def _mu_penalty(reg, X, weight, eps):
    """Term that regulariser `reg` adds to the multiplicative-update denominator for X.

    Each term is the gradient of the corresponding penalty on X (up to a
    constant factor). `eps` guards the column/row norms of the L2,1 terms
    against 0/0.
    """
    if reg is None:
        return 0.0
    if reg == 'l2':
        return weight * X
    if reg == 'l1':
        return weight
    if reg == 'l1_col':
        return weight * np.sum(X, axis=0)[np.newaxis, :]
    if reg == 'l1_row':
        return weight * np.sum(X, axis=1)[:, np.newaxis]
    if reg == 'l21_col':
        return weight * np.divide(X, np.linalg.norm(X, axis=0)[np.newaxis, :] + eps)
    if reg == 'l21_row':
        return weight * np.divide(X, np.linalg.norm(X, axis=1)[:, np.newaxis] + eps)
    if reg == 'log':
        return weight * np.divide(X, np.square(X) + 1)
    if reg == 'exp':
        return weight * np.multiply(X, np.exp(-np.square(X)))
    raise ValueError('unknown regulariser %r; expected one of %r' % (reg, _MU_PENALTIES))


def multiplicative(V, w=0, h=0, w_reg=None, h_reg=None, iterations=1000, r=10, eps=1e-12, log_every=100):
    """Factorise V ~= W @ H with regularised multiplicative (Lee-Seung style) updates.

    Args:
        V: 2-D array of shape (n, m) to factorise.
        w, h: regularisation weights for W and H.
        w_reg, h_reg: regulariser names from _MU_PENALTIES (None = unregularised).
            They are compared by value, so runtime-built strings work too.
        iterations: number of update sweeps. Each sweep updates W, then H.
        r: rank (number of components).
        eps: added to every update denominator so an all-zero row or column
            does not produce 0/0 NaNs.
        log_every: print progress every `log_every` iterations (0 disables).

    Returns:
        (W, H) with shapes (n, r) and (r, m). Both are initialised from
        |randn| drawn from NumPy's global RNG (W first, then H).

    Raises:
        ValueError: if w_reg or h_reg is not a known regulariser name.
    """
    for arg_name, reg in (('w_reg', w_reg), ('h_reg', h_reg)):
        if reg not in _MU_PENALTIES:
            raise ValueError('%s must be one of %r, got %r' % (arg_name, _MU_PENALTIES, reg))
    n, m = V.shape
    W = np.abs(np.random.randn(n, r))
    H = np.abs(np.random.randn(r, m))
    for iteration in range(iterations):
        W = W * ((V @ H.T) / ((W @ H) @ H.T + _mu_penalty(w_reg, W, w, eps) + eps))
        H = H * ((W.T @ V) / ((W.T @ W) @ H + _mu_penalty(h_reg, H, h, eps) + eps))
        if log_every and iteration % log_every == 0:
            print('Iteration ', iteration, '- Reconstruction Loss: ', euclidean_loss(V, W, H), '\tW: ', percent_zero(W), '\tH: ', percent_zero(H))
    return W, H

In [None]:
# Factorise V ~= W @ H with the rank, iteration count and regularisers set in the config cell.
W, H = multiplicative(
    V,
    r=r,
    iterations=iterations,
    w=w_weight,
    w_reg=w_reg,
    h=h_weight,
    h_reg=h_reg,
)

In [None]:
# Plot the learned factors: W's columns and H's rows as curves for audio,
# H's rows as image tiles for faces.
if AUDIO:
    plot_rows(W.T, normalized=True, filename=out_w)
    plot_rows(H, normalized=True, filename=out_h)
if FACES:
    # Derive a near-square grid from the number of components (4x4 for r=16),
    # so a different rank cannot overflow a fixed-size grid.
    n_components = H.shape[0]
    n_grid_cols = int(np.ceil(np.sqrt(n_components)))
    n_grid_rows = int(np.ceil(n_components / n_grid_cols))
    plot_gallery('', H, n_grid_cols, n_grid_rows, normalized=False, filename=out_faces)

In [None]:
# Summarise the factorisation, show it, and save the same text to out_results.
# The summary covers the fraction of near-zero entries in each factor and the
# squared reconstruction error.
# The text is written directly rather than via `%%capture cap`: IPython binds
# `cap` only after the captured cell finishes, so reading cap.stdout inside that
# same cell fails on a fresh kernel (NameError) and otherwise saves the previous
# run's output.
summary = [
    ('H %0: ', percent_zero(H)),
    ('W %0: ', percent_zero(W)),
    ('Reconstruction: ', euclidean_loss(V, W, H)),
]
with open(out_results, 'w') as results_file:
    for label, value in summary:
        print(label, value)
        print(label, value, file=results_file)

In [None]:
# Optional audio export of the full reconstruction and of each rank-1 component.
# Only meaningful in AUDIO mode, where the load cell defined `bounds` and `sr`.
SAVE_MIX = False      # write out/<filename>/out.wav
SAVE_SOURCES = False  # write out/<filename>/out_<i>.wav per component
if AUDIO and (SAVE_MIX or SAVE_SOURCES):
    os.makedirs('out/%s' % filename, exist_ok=True)
if AUDIO and SAVE_MIX:
    save_wav("out/%s/out.wav" % filename, inv_log_norm(W @ H, bounds), sr, feature=feature)
if AUDIO and SAVE_SOURCES:
    print('Generating each source...')
    for i in range(W.shape[1]):
        component = np.outer(W[:, i], H[i, :])
        save_wav("out/%s/out_%d.wav" % (filename, i), inv_log_norm(component, bounds), sr, feature=feature)