In [117]:
import numpy as np
import soundfile as sf
import os

In [118]:
# Folder that holds the source tracks (and, later, the generated audio).
base = "./Party Music"

# One path per source track, available both individually and as an ordered list.
s1, s2, s3 = (
    os.path.join(base, track_name)
    for track_name in ("Adele.mp3", "Beethoven.mp3", "DP.mp3")
)
files = [s1, s2, s3]

In [119]:
import librosa

# Decode every track through librosa, asking for a 16 kHz mono signal.
signals = []
sr = None  # remains None only if `files` is empty
for track_path in files:
    # librosa.load hands back the samples together with the sample rate.
    samples, sr = librosa.load(track_path, mono=True, sr=16000)
    signals.append(samples)

In [120]:
# Trim every track to the length of the shortest one so they stack into a matrix
# with one source per row.
mnl = min(map(len, signals))
signals = [track[:mnl] for track in signals]
X = np.stack(signals, axis=0)

np.random.seed(42)
# Hand-picked mixing matrix (used here instead of a random
# np.random.randn(3, 3) draw); row i holds the weights of mixture i.
A = np.array([
    [1.0, 0.5, 0.8],
    [0.3, 1.0, 0.5],
    [0.2, 0.3, 1.0],
])
mixed = A @ X

# Additive Gaussian noise with a standard deviation equal to 5% of the
# standard deviation taken over all mixtures.
noise_level = 0.05 * np.std(mixed)
noise = np.random.randn(*mixed.shape) * noise_level
mixed_noisy = mixed + noise

# Rescale every mixture to a peak magnitude of 0.9 to avoid clipping.
channel_peak = np.abs(mixed_noisy).max(axis=1, keepdims=True)
mixed_noisy = mixed_noisy / channel_peak * 0.9

# Write one WAV file per mixture, numbered from 1.
os.makedirs(f"{base}/mixed", exist_ok=True)
for song_no, channel in enumerate(mixed_noisy, start=1):
    sf.write(f"{base}/mixed/song_{song_no}.wav", channel, sr)

In [121]:
def center(X):
    """Return ``X`` with the mean of each row subtracted.

    Means are taken along axis 1 (with ``keepdims=True`` so they broadcast
    back over the columns); the input itself is not modified.
    """
    row_means = np.mean(X, axis=1, keepdims=True)
    return X - row_means

def whiten(X, eps=1e-6):
    """Whiten the rows of ``X`` so that their sample covariance is close to identity.

    The transform applied is ``E @ diag(1 / sqrt(d + eps)) @ E.T @ X`` where
    ``E`` and ``d`` are the eigenvectors and eigenvalues of ``np.cov(X)``
    (rows are variables, columns are observations).

    Parameters
    ----------
    X : array_like, shape (channels, samples)
        Data with one channel per row. A 1-D input is treated as a single
        channel.
    eps : float, optional
        Small constant added to every eigenvalue before the inverse square
        root so that near-zero variances do not blow up. Defaults to ``1e-6``.

    Returns
    -------
    numpy.ndarray, shape (channels, samples)
        The whitened data.
    """
    X = np.atleast_2d(X)

    # np.cov squeezes a single-channel covariance down to a 0-d array, which
    # the decomposition below cannot handle; restore the (1, 1) matrix.
    cov = np.atleast_2d(np.cov(X))

    # The covariance is symmetric, so eigh is the matching decomposition.
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Rounding can push a zero eigenvalue slightly negative; clamp before sqrt.
    eigvals = np.clip(eigvals, 0.0, None)
    inv_sqrt = 1.0 / np.sqrt(eigvals + eps)

    # Scaling column j of E by inv_sqrt[j] is E @ diag(inv_sqrt).
    whitening = (eigvecs * inv_sqrt) @ eigvecs.T
    return whitening @ X

# Pre-process the noisy mixtures: subtract each row's mean, then whiten
# the centred rows using their covariance.
Xc = center(mixed_noisy)
Xw = whiten(Xc)

def sigmoid(u):
    """Return the logistic function ``1 / (1 + exp(-u))`` element-wise.

    The value is computed through the identity
    ``sigmoid(u) == 0.5 * (1 + tanh(u / 2))``. ``tanh`` saturates at +/-1, so
    large-magnitude inputs do not trigger the floating-point overflow that
    ``exp(-u)`` produces for very negative ``u``.

    Parameters
    ----------
    u : float or numpy.ndarray
        Input value(s).

    Returns
    -------
    float or numpy.ndarray
        Values in ``[0, 1]`` with the same shape as ``u``.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * u))

In [122]:
from scipy.linalg import sqrtm

def train(X, epochs=100, lr=0.01, batch=20000):
    """Learn a square unmixing matrix for ``X`` with per-sample gradient steps.

    ``W`` starts as a random ``(m, m)`` matrix drawn from NumPy's global
    random state. In every epoch a set of distinct column indices is drawn
    and, for each drawn column ``x`` (kept as an ``(m, 1)`` column vector),
    the logistic-infomax style update

        W += lr * ((1 - 2 * sigmoid(W @ x)) @ x.T + inv(W.T))

    is applied. At the end of the epoch the rows of ``W`` are symmetrically
    decorrelated, i.e. ``W`` is replaced by ``(W @ W.T)^(-1/2) @ W``.

    Parameters
    ----------
    X : numpy.ndarray, shape (m, n)
        Input mixtures, one channel per row and one sample per column.
    epochs : int, optional
        Number of passes; progress is printed about ten times over the run.
    lr : float, optional
        Step size of each per-sample update.
    batch : int, optional
        Number of columns visited per epoch. Values larger than ``n`` are
        reduced to ``n`` because columns are sampled without replacement.

    Returns
    -------
    numpy.ndarray, shape (m, m)
        The learned unmixing matrix.
    """
    m, n = X.shape
    W = np.random.randn(m, m)

    # Sampling without replacement cannot draw more columns than exist.
    batch_size = min(batch, n)
    # Report about ten times per run, but at least every epoch so that
    # epochs < 10 does not end in a modulo-by-zero.
    report_every = max(1, epochs // 10)

    for epoch in range(1, epochs + 1):
        for i in np.random.choice(n, batch_size, replace=False):
            x = X[:, i:i+1]  # slice keeps the (m, 1) column shape
            g = sigmoid(W @ x)
            W += lr * ((1 - 2 * g) @ x.T + np.linalg.inv(W.T))

        # decorrelate rows: with W = U S Vt, (W W^T)^(-1/2) W equals U Vt.
        # The SVD form stays real-valued and needs no explicit matrix inverse.
        U, _, Vt = np.linalg.svd(W)
        W = U @ Vt

        if epoch % report_every == 0:
            print(f"Epoch {epoch}/{epochs} done")

    return W

In [123]:
# Learn the unmixing matrix on the whitened mixtures and apply it to them.
W = train(Xw)
recovered = W @ Xw

# Write one WAV file per recovered row, each divided by its own peak
# magnitude, numbered from 1.
os.makedirs(f"{base}/recovered", exist_ok=True)
for song_no, estimate in enumerate(recovered, start=1):
    estimate_peak = np.max(np.abs(estimate))
    sf.write(f"{base}/recovered/song_{song_no}.wav", estimate / estimate_peak, sr)

Epoch 10/100 done
Epoch 20/100 done
Epoch 30/100 done
Epoch 40/100 done
Epoch 50/100 done
Epoch 60/100 done
Epoch 70/100 done
Epoch 80/100 done
Epoch 90/100 done
Epoch 100/100 done
