In [1]:
import numpy as np
import mlx.core as mx

# getting the inner product and real extraction to work

In [2]:
a_re = mx.array(3.)
a_im = mx.array(5.)
b_re = mx.array(4.)
b_im = mx.array(7.)

# inner product (dot product)
c_re = a_re * a_re - a_im * a_im
c_im = 2 * a_re * a_im

print(c_re, c_im)

array(-16, dtype=float32) array(30, dtype=float32)


In [12]:
def inner_product(a_re, a_im, b_re, b_im):
    # <a|b>
    c_re = a_re * b_re - a_im * b_im
    c_im = a_re * b_im + b_re * a_im
    return c_re, c_im

def noise_weighted_inner_product(a_re, a_im, b_re, b_im, PSD, duration):
    ''' from Bilby directly: 
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand) 
    # this also follows the PyCBC convention, the conjugated vector is given first
    '''
    numerator_re, numerator_im = inner_product(a_re, a_im, b_re, b_im)
    integrand_re = numerator_re / PSD
    integrand_im = numerator_im / PSD
    result_re = 4 / duration * sum(integrand_re)
    result_im = 4 / duration * sum(integrand_im)
    return result_re, result_im


In [3]:
a_re = mx.array((3., 4., 3., 3.))
a_im = mx.array((5., 7., 5., 5.))
b_re = mx.array((4., 4., 3., 7.))
b_im = mx.array((7., 7., 5., 7.))

c_re, c_im = inner_product(a_re, a_im, a_re, a_im)
print(c_re, c_im)

array([-16, -33, -16, -16], dtype=float32) array([30, 56, 30, 30], dtype=float32)


In [5]:
mock_PSD = mx.array((1e-20, 8e-21, 6e-21, 9e-21))
duration = 8

result_re, result_im = noise_weighted_inner_product(a_re, a_im, b_re, b_im, mock_PSD, duration)
print(result_re, result_im)

array(1.23569e+22, dtype=float32) array(-7.27778e+20, dtype=float32)


In [61]:
# messing around with array.real
def get_real(z):
    real = (z + np.conjugate(z))/2
    real_str_keep = str(real).split('+')[0]
    return float(real_str_keep[1:])

a = complex(3, 5)
print(get_real(a))
print(a.real)

3.0
3.0


In [81]:
def get_real_mx(z):
    real = (z + mx.conjugate(z))/2
    real_str_keep = str(real).split('+')[0].split('-')[0]
    return float(real_str_keep[7:])

print(a_re_fft)
a_re_re = get_real_mx(a_re_fft)
print(a_re_re)

array([13+0j, 0-1j, -1+0j, 0+1j], dtype=complex64)
13.0


In [106]:
array = a_re_fft
print(array.dtype)
str(array.dtype) == 'mlx.core.complex64'

mlx.core.complex64


True

In [66]:
def get_real_mx(array):
    if str(array.dtype) == 'mlx.core.complex64':
        container = []
        for iii in range(array.size):
            z = array[iii]
            real = (z + mx.conjugate(z))/2
            real_str1 = str(real).split('(')[1].split('+')[0]
            if real_str1[0] == '-':
                real_str2 = str('-') + real_str1.split('-')[1]
            else: 
                real_str2 = real_str1.split('-')[0]
            container.append(float(real_str2))
            container_mx = mx.array(container)
        return container_mx
    else: 
        return array

test_array = a_re_fft
# test_array = a_re
print(test_array)
get_real_mx(test_array)    

array([13+0j, 0-1j, -1+0j, 0+1j], dtype=complex64)


array([13, 0, -1, 0], dtype=float32)

In [98]:
iii = 3
print(test_array[iii])
imag = str(test_array[iii])

print(imag.split('(')[1].split(',')[0].split('+')[0])

print(imag.split('j')[0].split('+')[0])
#print(imag.split('j')[0].split('-')[1])


array(0+1j, dtype=complex64)
0
array(0


In [106]:
def imag_mx_DTYPE_HACK(array):
    if str(array.dtype) == 'mlx.core.complex64':
        container = []
        for iii in range(array.size):
            z = array[iii]
            imag = (z - mx.conjugate(z))/2j
            imag_str1 = str(imag).split('(')[1].split(',')[0]
            if imag_str1[-1] == 'j':
                imag_str2 = str(imag_str1).split('j')[0].split('+')[0]
            else: 
                imag_str2 = '-' + str(imag_str1).split('j')[0].split('-')[1]
            container.append(float(imag_str2))
            container_mx = mx.array(container)
        return container_mx
    else: 
        return array

test_array = a_re_fft
# test_array = a_re
print(test_array)
imag_mx_HACK(test_array)    

array([13+0j, 0-1j, -1+0j, 0+1j], dtype=complex64)


array([0, -1, 0, 1], dtype=float32)

In [114]:
z = test_array[1]
print(z)
print(real_mx((real_mx(z)-z)*1j))

array(0-1j, dtype=complex64)
array(-1, dtype=float32)


In [None]:
def imag_mx(array):
    return real_mx((real_mx(array)-array)*1j)

In [104]:
print(test_array)
print(mx.array.astype(test_array, dtype=mx.float32))

array([13+0j, 0-1j, -1+0j, 0+1j], dtype=complex64)
array([13, 0, -1, 0], dtype=float32)


In [2]:
# checking out some things

a_float = mx.array((1., 2., 3., 4.))
a_test = mx.array.astype(a_float, dtype=mx.complex64)
a_test2 = mx.array.astype(a_test, dtype=mx.complex64)
print(a_float)
print(a_test)
#print(mx.sum(a_test2))

array([1, 2, 3, 4], dtype=float32)
array([1+0j, 2+0j, 3+0j, 4+0j], dtype=complex64)


In [13]:
print(mx.sum(a_float))
print(mx.conjugate(a_float))
print(inner_product_mx(a_float,a_float))
print(complex_sum_mx(a_test))

array(10, dtype=float32)
array([1, 2, 3, 4], dtype=float32)
array([1, 4, 9, 16], dtype=float32)
array(10+0j, dtype=complex64)


In [11]:
sum = mx.array(0)
print(sum)
# for iii in range(a_float.size):
for iii in range(a_test.size):
    # element = a_float[iii]
    element = a_test[iii]
    sum += element
    print(sum)

print("done!")
print(sum)

array(0, dtype=int32)
array(1+0j, dtype=complex64)
array(3+0j, dtype=complex64)
array(6+0j, dtype=complex64)
array(10+0j, dtype=complex64)
done!
array(10+0j, dtype=complex64)


In [65]:
# test = mx.array((4., 4., 3., 4.))
test = mx.array((1., 2., 3., 4.))
test_complex = mx.array.astype(test, dtype=mx.complex64)
test_real = mx.array.astype(test_complex, dtype=mx.float32)

print(test)
print(test_complex)
print(test_real)

test_fft = mx.fft.fftn(test)
test_fft_real = mx.array.astype(test_fft, dtype=mx.float32)
print(test_fft)
print(test_fft_real)


array([1, 2, 3, 4], dtype=float32)
array([1+0j, 2+0j, 3+0j, 4+0j], dtype=complex64)
array([1, 2, 3, 4], dtype=float32)
array([10+0j, -2+2j, -2+0j, -2-2j], dtype=complex64)
array([10, -2, -2, -2], dtype=float32)


In [12]:
def make_dtype_complex64_mx(array):
    return mx.array.astype(array, dtype=mx.complex64)

def inner_product_mx(a, b): 
    return mx.conjugate(a) * b

def complex_sum_mx(array):
    sum = mx.array(0)
    for iii in range(array.size):
        element = array[iii]
        sum += element
    return sum

def noise_weighted_inner_product_mx(a, b, PSD, duration): 
    ''' from Bilby directly: 
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand) 
    this also follows the PyCBC convention, the conjugated vector is given first
    '''
    integrand = inner_product_mx(a,b) / PSD
    return 4 / duration * complex_sum_mx(integrand)


# functions to keep

In [109]:
### functions to keep
def split_inner_product_mx(a_re, a_im, b_re, b_im):
    ''' calculates the inner product of two vectors a, b such that <a|b>
    '''
    c_re = a_re * b_re - a_im * b_im
    c_im = a_re * b_im + b_re * a_im
    return c_re, c_im

def split_noise_weighted_inner_product_mx(a_re, a_im, b_re, b_im, PSD, duration):
    ''' from Bilby directly: 
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand) 
    this also follows the PyCBC convention, the conjugated vector is given first
    '''
    numerator_re, numerator_im = inner_product_mx(a_re, a_im, b_re, b_im)
    integrand_re = numerator_re / PSD
    integrand_im = numerator_im / PSD
    result_re = 4 / duration * sum(integrand_re)
    result_im = 4 / duration * sum(integrand_im)
    return result_re, result_im

def real_mx_DTYPE_HACK(array):
    ''' this is a silly way to extract the real part of a complex number
    we calculate the real, transform it into a string, split the string to recover 
    the characters that constitute the real part, and transform that into a float
    dtype: complex64 -> float32
    '''
    if str(array.dtype) == 'mlx.core.complex64':
        container = []
        for iii in range(array.size):
            z = array[iii]
            real = (z + mx.conjugate(z))/2
            real_str1 = str(real).split('(')[1].split('+')[0]
            if real_str1[0] == '-':
                real_str2 = str('-') + real_str1.split('-')[1]
            else: 
                real_str2 = real_str1.split('-')[0]
            container.append(float(real_str2))
            container_mx = mx.array(container)
        return container_mx
    else: 
        return array

def imag_mx_DTYPE_HACK(array):
    if str(array.dtype) == 'mlx.core.complex64':
        container = []
        for iii in range(array.size):
            z = array[iii]
            imag = (z - mx.conjugate(z))/2j
            imag_str1 = str(imag).split('(')[1].split(',')[0]
            if imag_str1[-1] == 'j':
                imag_str2 = str(imag_str1).split('j')[0].split('+')[0]
            else: 
                imag_str2 = '-' + str(imag_str1).split('j')[0].split('-')[1]
            container.append(float(imag_str2))
            container_mx = mx.array(container)
        return container_mx
    else: 
        return array

def make_dtype_complex64_mx(array):
    return mx.array.astype(array, dtype=mx.complex64)

def inner_product_mx(a, b): 
    return mx.conjugate(a) * b

def complex_sum_mx(array):
    sum = mx.array(0)
    for iii in range(array.size):
        element = array[iii]
        sum += element
    return sum

def noise_weighted_inner_product_mx(a, b, power_spectral_density, duration): 
    ''' from Bilby directly: 
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand) 
    this also follows the PyCBC convention, the conjugated vector is given first
    '''
    integrand = inner_product_mx(a,b) / power_spectral_density
    return 4 / duration * complex_sum_mx(integrand)
    
def real_mx(array):
    ''' keeps the real part of a complex64 array by changing type
    dtype: complex64 -> float32
    '''
    return mx.array.astype(array, dtype=mx.float32)

def imag_mx(array):
    ''' keeps the imaginary part of a complex64 array by changing type
    dtype: complex64 -> float32
    '''
    return real_mx((real_mx(array)-array)*1j)

def decompose_complex_array_mx(array):
    ''' saves the real and imaginary parts of a complex64 array as float32 arrays
    '''
    return real_mx(array), imag_mx(array)

def rescale_PSD(power_spectral_density, dynamic_range_factor):
    ''' rescaling the dynamic range of the PSD, forcing use of numpy as MLX
    does not play nicely with the large exponents
    '''
    return np.array(power_spectral_density) * dynamic_range_factor



In [116]:
dir(mx)

['Device',
 'DeviceType',
 'Dtype',
 'DtypeCategory',
 'Inf',
 'Infinity',
 'NAN',
 'NINF',
 'NZERO',
 'NaN',
 'PINF',
 'PZERO',
 'Stream',
 'StreamContext',
 '_ArrayAt',
 '_ArrayIterator',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '__version__',
 'abs',
 'add',
 'addmm',
 'all',
 'allclose',
 'any',
 'arange',
 'arccos',
 'arccosh',
 'arcsin',
 'arcsinh',
 'arctan',
 'arctan2',
 'arctanh',
 'argmax',
 'argmin',
 'argpartition',
 'argsort',
 'array',
 'array_equal',
 'as_strided',
 'async_eval',
 'atleast_1d',
 'atleast_2d',
 'atleast_3d',
 'bfloat16',
 'bitwise_and',
 'bitwise_or',
 'bitwise_xor',
 'block_masked_mm',
 'block_sparse_mm',
 'block_sparse_qmm',
 'bool_',
 'broadcast_to',
 'ceil',
 'checkpoint',
 'clip',
 'compile',
 'complex64',
 'complexfloating',
 'concatenate',
 'conj',
 'conjugate',
 'conv1d',
 'conv2d',
 'conv3d',
 'conv_general',
 'convolve',
 'cos',
 'cosh',
 'cpu',
 'cummax',
 'cummin',
 'cumprod',
 'cumsum',
 'custom_funct

In [117]:
print(mx.__file__)

/opt/anaconda3/envs/mj_pycbc_test/lib/python3.12/site-packages/mlx/core.cpython-312-darwin.so


# messing around with FFTs

In [35]:
a_re = mx.array((3., 4., 3., 3.))
a_im = mx.array((5., 7., 5., 5.))
b_re = mx.array((4., 4., 3., 7.))
b_im = mx.array((7., 7., 5., 7.))

a_re_fft = mx.fft.fftn(a_re)
a_re_rfft = mx.fft.rfftn(a_re)
print(a_re_rfft)
print(a_re_fft)

b_re_rfft = mx.fft.rfftn(b_re)
b_re_fft = mx.fft.fftn(b_re)
print(b_re_rfft)
print(b_re_fft)

array([13+0j, 0-1j, -1+-0j], dtype=complex64)
array([13+0j, 0-1j, -1+0j, 0+1j], dtype=complex64)
array([18+0j, 1+3j, -4+-0j], dtype=complex64)
array([18+0j, 1+3j, -4+0j, 1-3j], dtype=complex64)


In [36]:
a_re_fft_conj = mx.conjugate(a_re_fft)


In [37]:
print(complex_sum_mx(a_re_fft))

array(12+0j, dtype=complex64)


# implementing dynamic range correction

In [42]:
# directly from pycbc
'''For PSDs taken from models or text files, if `dyn_range_factor` is
not None, then the PSD is multiplied by `dyn_range_factor` ** 2.

# Dynamic range factor: a large constant for rescaling
# GW strains.  This is 2**69 rounded to 17 sig.fig.
DYN_RANGE_FAC =  5.9029581035870565e+20
'''

mock_PSD = mx.array((1e-20, 8e-21, 6e-21, 9e-21))
mock_PSD2 = mx.array((1e-40, 8e-41, 6e-41, 9e-41))
duration = 8
DYN_RANGE_FAC = mx.array((5.9029581035870565e+20))
DYN_RANGE_FAC2 = mx.array((5.9029581035870565e+40))

print(mock_PSD)
print(mock_PSD * DYN_RANGE_FAC)
print(mock_PSD2)
print(mock_PSD2 * DYN_RANGE_FAC)
print(mock_PSD2 * DYN_RANGE_FAC2)

array([1e-20, 8e-21, 6e-21, 9e-21], dtype=float32)
array([5.90296, 4.72237, 3.54177, 5.31266], dtype=float32)
array([9.99995e-41, 8.00001e-41, 5.99994e-41, 8.99998e-41], dtype=float32)
array([0, 0, 0, 0], dtype=float32)
array([nan, nan, nan, nan], dtype=float32)


In [133]:
mock_PSD_np = np.array(mock_PSD)
mock_PSD2_np = np.array(mock_PSD2_np)

DYN_RANGE_FAC_np = 5.9029581035870565e+20
DYN_RANGE_FAC2_np = 5.9029581035870565e+40
DYN_RANGE_FAC = mx.array((5.9029581035870565e+20))

new_mock_PSD_np = mock_PSD_np * DYN_RANGE_FAC_np
new_mock_PSD2_np = mock_PSD2_np * DYN_RANGE_FAC2_np

new_mock_PSD2_mx = mx.array(new_mock_PSD2_np)

print(mock_PSD_np)
print(new_mock_PSD_np)
print(mock_PSD2_np)
print(mock_PSD2_np * DYN_RANGE_FAC)
print(new_mock_PSD2_np)
print(new_mock_PSD2_mx)


[1.e-20 8.e-21 6.e-21 9.e-21]
[5.902958  4.7223663 3.5417747 5.312662 ]
[1.e-40 8.e-41 6.e-41 9.e-41]
[5.90295810e-20 4.72236648e-20 3.54177486e-20 5.31266229e-20]
[5.9029581  4.72236648 3.54177486 5.31266229]
array([5.90296, 4.72237, 3.54177, 5.31266], dtype=float32)


In [30]:
mock_PSD_complex = make_dtype_complex64_mx(mock_PSD)
mock_PSD2_complex = make_dtype_complex64_mx(mock_PSD2)

print(mock_PSD2_complex)
print(mock_PSD2_complex * DYN_RANGE_FAC2)


array([9.99995e-41+0j, 8.00001e-41+0j, 5.99994e-41+0j, 8.99998e-41+0j], dtype=complex64)
array([0+0j, 0+0j, 0+0j, 0+0j], dtype=complex64)


In [137]:
def rescale_PSD(power_spectral_density, dynamic_range_factor):
    ''' rescaling the dynamic range of the PSD, forcing use of numpy as MLX
    does not play nicely with the large exponents
    '''
    return np.array(power_spectral_density) * dynamic_range_factor

In [136]:
new_PSD = rescale_PSD(mock_PSD, DYN_RANGE_FAC)
new_PSD2 = rescale_PSD(mock_PSD2, DYN_RANGE_FAC2)

new_PSD_np = rescale_PSD(mock_PSD_np, DYN_RANGE_FAC_np)
new_PSD2_np = rescale_PSD(mock_PSD2_np, DYN_RANGE_FAC2_np)

print(new_PSD)
print(new_PSD2)
print(new_PSD_np)
print(new_PSD2_np)

[5.902958  4.7223663 3.5417747 5.312662 ]
[inf+nanj inf+nanj inf+nanj inf+nanj]
[5.902958  4.7223663 3.5417747 5.312662 ]
[5.9029581  4.72236648 3.54177486 5.31266229]


  return np.array(power_spectral_density) * dynamic_range_factor


In [132]:
print(DYN_RANGE_FAC2_np)

5.902958103587057e+40


# messing around with division

In [5]:
a_float = mx.array((1., 2., 3., 4.))
a_compl = mx.fft.fftn(mx.array.astype(a_float, dtype=mx.complex64))

division = a_float / a_compl
print(division)


array([0.1+0j, -0.5-0.5j, -1.5+-0j, -1+1j], dtype=complex64)
