In [None]:
### functions to keep
def split_inner_product_mx(a_re, a_im, b_re, b_im):
    ''' Element-wise product conj(a) * b, with complex values split into parts.

    a = a_re + i*a_im and b = b_re + i*b_im. The first argument is the one
    that is conjugated, i.e. the <a|b> convention used by inner_product_mx.
    Works on scalars or on any arrays supporting element-wise * + -.

    Returns (c_re, c_im), the real and imaginary parts of conj(a) * b.
    '''
    # conj(a) * b = (a_re - i a_im)(b_re + i b_im)
    #             = (a_re b_re + a_im b_im) + i (a_re b_im - a_im b_re)
    c_re = a_re * b_re + a_im * b_im
    c_im = a_re * b_im - a_im * b_re
    return c_re, c_im

def split_noise_weighted_inner_product_mx(a_re, a_im, b_re, b_im, PSD, duration):
    ''' Noise-weighted inner product of a = a_re + i*a_im and b = b_re + i*b_im,
    with real and imaginary parts carried as separate real arrays.

    From Bilby directly:
        integrand = np.conj(aa) * bb / power_spectral_density
        return 4 / duration * np.sum(integrand)
    This also follows the PyCBC convention: the conjugated vector is given first.

    Returns (result_re, result_im), the real and imaginary parts of
    4 / duration * sum(conj(a) * b / PSD).
    '''
    # conj(a) * b expanded into real/imaginary parts. inner_product_mx takes two
    # complex arrays rather than four split parts, so the expansion is inline.
    numerator_re = a_re * b_re + a_im * b_im
    numerator_im = a_re * b_im - a_im * b_re
    integrand_re = numerator_re / PSD
    integrand_im = numerator_im / PSD
    result_re = 4 / duration * sum(integrand_re)
    result_im = 4 / duration * sum(integrand_im)
    return result_re, result_im

def get_real_mx_DTYPE_HACK(array):
    ''' Return the real part of a 1-D mlx complex64 array as a new mlx array.

    dtype: complex64 -> float32
    Inputs whose dtype is not mlx complex64 are returned unchanged.

    This goes through string conversion: each element is printed, the complex
    literal inside the printed form is parsed with Python's complex(), and its
    real part is kept. Parsing the whole literal, instead of splitting on '+'
    and '-', keeps exponents such as 1e-21 or 1e+05 intact.

    NOTE(review): newer MLX releases may offer a direct real-part op
    (e.g. mx.real) -- confirm for the installed version and prefer it.
    '''
    if str(array.dtype) != 'mlx.core.complex64':
        return array

    reals = []
    for idx in range(array.size):  # assumes a 1-D array (indexed element-wise)
        printed = str(array[idx])
        # assumes the printed scalar looks like "array(<re>+<im>j, ...)" or
        # "array(<re>-<im>j, ...)", the layout the earlier parsing relied on
        inside = printed.split('(', 1)[1]
        literal = inside.split(',', 1)[0].split(')', 1)[0].replace(' ', '')
        reals.append(complex(literal).real)
    # build the mlx array once, after the loop; an empty input now yields an
    # empty array instead of an unbound-variable error
    return mx.array(reals)

def make_dtype_complex64_mx(array):
    ''' Return `array` cast to the mlx complex64 dtype.

    Calls mx.array.astype unbound, with the array passed explicitly.
    '''
    target_dtype = mx.complex64
    return mx.array.astype(array, dtype=target_dtype)

def inner_product_mx(a, b):
    ''' Element-wise <a|b> product: conj(a) * b.

    The first argument is conjugated. No summation is performed here; the
    result has the broadcast shape of a and b.
    '''
    conj_a = mx.conjugate(a)
    return conj_a * b

def complex_sum_mx(array):
    ''' Sum the elements of an array one at a time in a Python loop.

    Workaround: array.sum crashes the kernel if dtype=complex64, so the
    elements are accumulated manually instead. Accumulation starts from
    mx.array(0), which is what an empty input returns. Elements are read
    as array[i] for i < array.size, so a 1-D array is assumed.
    '''
    total = mx.array(0)
    n_elements = array.size
    idx = 0
    while idx < n_elements:
        total += array[idx]
        idx += 1
    return total

def noise_weighted_inner_product_mx(a, b, PSD, duration):
    ''' Noise-weighted inner product 4 / duration * sum(conj(a) * b / PSD).

    From Bilby directly:
        integrand = np.conj(aa) * bb / power_spectral_density
        return 4 / duration * np.sum(integrand)
    This also follows the PyCBC convention: the conjugated vector is given first.
    The sum goes through complex_sum_mx, which avoids array.sum on complex64.
    '''
    weighted = inner_product_mx(a, b) / PSD
    prefactor = 4 / duration
    return prefactor * complex_sum_mx(weighted)
    
def extract_real_xm(z):
    ''' Keep the real part of an mlx complex64 array.

    dtype: complex64 -> complex64 (the imaginary part becomes zero).
    Uses (z + conj(z)) / 2 = Re(z). mx.conjugate is used so the computation
    stays in MLX, matching the real-part formula in get_real_mx_DTYPE_HACK.
    '''
    real = (z + mx.conjugate(z)) / 2
    return real
