<a href="https://colab.research.google.com/github/WeiShi78/Cross-Cancellation/blob/main/Crosstalk_Cancellation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
import math
import numpy as np
import scipy.signal
import librosa
import IPython.display as ipd
import soundfile as sf
from scipy.io import wavfile

In [86]:
def rms(sig):
    """Return the root-mean-square level of ``sig`` (square, average, root)."""
    mean_square = np.mean(sig ** 2)
    return np.sqrt(mean_square)

In [87]:
def compute_geometry(spkr_to_spkr, lstnr_to_spkr, ear_to_ear):
    """Derive direct/crosstalk path lengths and source angle for a speaker pair.

    The speakers sit ``spkr_to_spkr / 2`` to either side of the centre line,
    ``lstnr_to_spkr`` in front of the listener; the ears sit ``ear_to_ear / 2``
    either side of the head centre. Units are presumably metres (TODO confirm;
    the rest of the notebook uses c = 343.2).

    Returns:
        (d1, d2, theta):
            d1    -- straight-line distance from a speaker to the near ear,
            d2    -- d1 plus an arc of radius ``ear_to_ear / 2`` around the head,
                     used as the path to the far ear,
            theta -- angle (radians) from the head centre to a speaker, used by
                     the head-shadow filter.
    """
    half_span = spkr_to_spkr / 2
    dist = lstnr_to_spkr
    head_r = ear_to_ear / 2

    # Direct path: speaker to the ear on the same side.
    near = math.sqrt(dist**2 + (half_span - head_r)**2)

    # Crosstalk path: same as the direct path plus the arc that wraps
    # around the head to the opposite ear.
    wrap_angle = math.acos(half_span / (math.sqrt(dist**2 + half_span**2)))
    far = near + head_r * (math.pi - 2 * wrap_angle)

    # Azimuth of the speaker seen from the centre of the head.
    azimuth = math.atan(half_span / dist)
    return near, far, azimuth

In [88]:
def headshadow_filter_coefficients(theta, r, sr):
    """Return ``(b, a)`` coefficients of a first-order head-shadow filter.

    The structure looks like the Brown & Duda one-pole/one-zero head-shadow
    model (NOTE(review): confirm against the intended reference). From the
    coefficients, the filter has unity gain at DC and gain ``alpha`` at
    Nyquist, where ``alpha`` moves between 0.5 and 2 with the source angle.

    Args:
        theta: source angle in radians; shifted by +pi/2 before use.
        r: head radius (this notebook passes ``ear_to_ear / 2``).
        sr: sample rate in Hz.

    Returns:
        (b, a): two-element numerator/denominator lists, usable with
        ``scipy.signal.lfilter`` / ``filtfilt``.
    """
    speed_of_sound = 343.2
    theta_min = 2.618      # about 150 degrees, in radians
    alpha_floor = 0.5

    shifted = theta + math.pi / 2
    # alpha ranges from alpha_floor (cos = -1) up to 2 (cos = 1).
    alpha = 1 + alpha_floor / 2 + (1 - alpha_floor / 2) * math.cos(shifted * math.pi / theta_min)

    # Corner frequency w0 = c / r, normalised by the sample rate.
    k = (speed_of_sound / r) / sr
    norm = 1 + k
    b = [(alpha + k) / norm, (k - alpha) / norm]
    a = [1, -(1 - k) / norm]
    return b, a

In [89]:
def fractional_delay(sig, time, sr):
    """Delay a 1-D signal by ``time * sr`` samples using linear interpolation.

    The delay is split into an integer part ``m_int`` and a fractional part
    ``m_frac``; each output sample is interpolated between
    ``sig[i - m_int]`` and ``sig[i - m_int - 1]``. Samples that would come
    from before the start of ``sig`` are treated as zero, and the output has
    the same length as the input (anything pushed past the end is dropped).

    Args:
        sig: 1-D sequence of samples.
        time: delay in seconds; must be non-negative.
        sr: sample rate in Hz.

    Returns:
        float64 ndarray with the same length as ``sig``.

    Raises:
        ValueError: if the requested delay is negative.
    """
    sig = np.asarray(sig, dtype=float)
    n = sig.size
    m = time * sr
    if m < 0:
        raise ValueError("fractional_delay requires a non-negative delay")
    m_int = int(m)
    m_frac = m - m_int

    # Integer-delayed copies, zero before the signal starts. Explicit slicing
    # (rather than negative indexing) avoids wrapping around to the tail.
    high = np.zeros(n)  # high[i] = sig[i - m_int]
    low = np.zeros(n)   # low[i]  = sig[i - m_int - 1]
    if m_int < n:
        high[m_int:] = sig[:n - m_int]
    if m_int + 1 < n:
        low[m_int + 1:] = sig[:n - m_int - 1]
    return high + (low - high) * m_frac

In [90]:
def sum_signals(signals):
    """
    Mix mono signals by sample-wise addition.

    Every signal is zero-padded at the end to the length of the longest one
    before it is added; an empty list gives an empty array.
    """
    if not signals:
        return np.array([])
    width = max(map(len, signals))

    def _pad(sig):
        # Copy sig into a zero buffer of the common width.
        out = np.zeros(width)
        out[:len(sig)] = sig
        return out

    total = np.zeros(width)
    for padded in map(_pad, signals):
        total += padded
    return total


In [91]:
def recursive_cancel(sig, ref, time, attenuation, headshadow, sr, threshold_db=-70,
                     max_iterations=1000):
    """Yield successive crosstalk-cancellation signals derived from ``sig``.

    Each stage takes the previous stage (``sig`` for the first one), delays it
    by ``time`` seconds, inverts it, applies the head-shadow filter with
    ``scipy.signal.filtfilt`` (forward-backward), and scales it by
    ``attenuation``. Stages are yielded until one whose *peak* level, in dB
    relative to ``ref``, falls below ``threshold_db``; that quiet stage is
    not yielded.

    Despite the name this runs as a loop, so the number of stages is not
    limited by Python's recursion depth.

    Args:
        sig: 1-D input signal.
        ref: reference peak amplitude for the dB threshold.
        time: delay per stage in seconds.
        attenuation: gain applied per stage (should be < 1 to converge).
        headshadow: ``(b, a)`` filter coefficients.
        sr: sample rate in Hz.
        threshold_db: stop once a stage's peak is below this level vs ``ref``.
        max_iterations: safety bound on the number of stages.

    Raises:
        RuntimeError: if ``max_iterations`` stages are produced without
            falling below the threshold (e.g. ``attenuation >= 1``).
    """
    # A silent reference means there is nothing to cancel (and avoids x/0).
    if ref <= 0:
        return
    prev = sig
    for _ in range(max_iterations):
        # delay and invert
        cancel_sig = -1 * fractional_delay(prev, time, sr)
        # apply headshadow filter (lowpass based on theta)
        cancel_sig = scipy.signal.filtfilt(*headshadow, cancel_sig)
        # attenuate
        cancel_sig = cancel_sig * attenuation

        peak = np.max(np.abs(cancel_sig))
        # A fully silent stage is below any threshold; log10(0) would raise.
        if peak == 0:
            return
        if 20 * math.log10(peak / ref) < threshold_db:
            return
        yield cancel_sig
        prev = cancel_sig
    raise RuntimeError(
        f"crosstalk cancellation did not decay below {threshold_db} dB "
        f"within {max_iterations} stages (attenuation={attenuation})"
    )

In [92]:
def cancel_crosstalk(signal, d1, d2, headshadow, sr):
    """Build the cancellation signals for one input channel.

    The crosstalk path is the longer of ``d1``/``d2``; the per-stage delay is
    the path difference divided by the speed of sound, and the per-stage gain
    is shorter/longer distance (always <= 1, so the cascade decays regardless
    of argument order).

    Stages from ``recursive_cancel`` alternate sides: even-indexed stages
    (0, 2, ...) are summed into the contralateral output and odd-indexed
    stages (1, 3, ...) into the ipsilateral output.

    Returns:
        (ipsilateral, contralateral) arrays; both empty if ``signal`` is silent
        or no stage exceeds the threshold.
    """
    c = 343.2  # speed of sound (m/s); distances presumably in metres
    near, far = min(d1, d2), max(d1, d2)
    time_delay = (far - near) / c
    attenuation = near / far
    # Reference max amplitude
    ref = np.max(np.abs(signal))
    if ref == 0:
        # Silent channel: nothing to cancel.
        return np.array([]), np.array([])
    cancel_sigs = list(
        recursive_cancel(signal, ref, time_delay, attenuation, headshadow, sr)
    )
    contralateral = sum_signals(cancel_sigs[0::2])
    ipsilateral = sum_signals(cancel_sigs[1::2])
    return ipsilateral, contralateral

In [93]:
def sum_signals(signals):
    """
    Add mono signals sample-by-sample, zero-padding shorter ones at the end.

    Returns an empty array for an empty list. NOTE: this cell re-defines
    ``sum_signals`` with the same behaviour as the earlier definition in
    this notebook, so re-running it is harmless.
    """
    if not signals:
        return np.array([])
    lengths = [len(s) for s in signals]
    mix = np.zeros(max(lengths))
    for s, n in zip(signals, lengths):
        chunk = np.zeros(mix.shape[0])
        chunk[:n] = s
        mix += chunk
    return mix

In [94]:
# Load the recording at 44.1 kHz, keeping the channels separate (mono=False).
y, sr = librosa.load('/content/lib_2.wav', sr=44100, mono=False)
print(y.shape)
# Use only the first 400000 samples of each channel
# (presumably to keep the processing time short).
left, right = y[0, :400000], y[1, :400000]
ytest = np.stack((left, right))

(2, 11931277)


In [111]:
# Listening geometry; units presumably metres (TODO confirm).
spkr_to_spkr, lstnr_to_spkr, ear_to_ear = 0.5, 1.5, 0.23
d1, d2, theta = compute_geometry(spkr_to_spkr, lstnr_to_spkr, ear_to_ear)

In [112]:
# Head-shadow filter (b, a); the head radius is half the ear-to-ear distance.
headshadow = headshadow_filter_coefficients(theta=theta, r=ear_to_ear / 2, sr=sr)

In [113]:
# cancel_crosstalk returns (ipsilateral, contralateral) for each input channel.
l_left, l_right = cancel_crosstalk(signal=left, d1=d1, d2=d2, headshadow=headshadow, sr=sr)
r_right, r_left = cancel_crosstalk(signal=right, d1=d1, d2=d2, headshadow=headshadow, sr=sr)

In [114]:
# Each output channel = cancellation signals routed to that side + the original.
left_out, right_out = (
    sum_signals([l_left, r_left, left]),
    sum_signals([l_right, r_right, right]),
)
stereo = np.vstack((left_out, right_out))

In [115]:
# Sanity check: expect (2, n_samples).
print(np.shape(stereo))

(2, 400000)


In [116]:
# Play the unprocessed input. NOTE(review): IPython's Audio widget may rescale
# playback (its `normalize` option) — confirm before comparing levels.
ipd.Audio(data=ytest, rate=sr)

In [108]:
# Play the crosstalk-cancelled mix. NOTE(review): playback may be rescaled by
# IPython's Audio widget (`normalize` option) — confirm before comparing levels.
ipd.Audio(data=stereo, rate=sr)

In [109]:
# soundfile expects (frames, channels), so transpose the (channels, frames) array.
sf.write('dummyhead.wav', ytest.T, sr)

In [110]:
# The mix adds cancellation signals on top of the original channels, so it can
# exceed full scale; a WAV written with soundfile's default PCM subtype cannot
# represent that. Peak-normalise only when needed, then transpose to
# (frames, channels) as soundfile expects.
peak = np.max(np.abs(stereo))
stereo_out = stereo / peak if peak > 1.0 else stereo
sf.write('output.wav', np.transpose(stereo_out), sr)