#### This notebook is a companion to the `contributing_functions` notebook. 

This will hopefully be a helpful starting place for continuing this project. 

Eventually, we intend to rewrite parts of the gravitational-wave (GW) search analysis so that they take advantage of the unified CPU-GPU memory on Apple silicon hardware. This is done through the MLX package, which provides its own array type that we will use throughout.

So let's start by importing MLX.

In [1]:
# may be needed for '%run contributing_functions.ipynb' to work
%pip install nbformat

In [2]:
import numpy as np
import mlx.core as mx

%run contributing_functions.ipynb

MLX has its own array type; the `float32` and `complex64` dtypes will be the most useful for us.

These arrays support all the arithmetic one would expect: addition, subtraction, multiplication, and division. However, some quality-of-life features are missing.

MLX `complex64` arrays do not provide `array.real`, `array.imag`, or `array.sum`, so we define explicit helper functions for these in `contributing_functions.ipynb`.
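One way to build such a workaround is to sum the real and imaginary parts separately as `float32` arrays and then recombine them, using $\sum_k z_k = \sum_k \operatorname{Re} z_k + i\sum_k \operatorname{Im} z_k$. The NumPy sketch below shows the idea. It is not the actual contents of `contributing_functions.ipynb`, and the name `complex_sum_via_parts` is hypothetical. It uses the same values as `z_array` in the next cell.

```python
import numpy as np

def complex_sum_via_parts(real_part, imag_part):
    """Hypothetical helper: sum a complex array given its real and imaginary
    parts as separate float32 arrays, then recombine into one complex64 value."""
    total_real = np.sum(real_part, dtype=np.float32)
    total_imag = np.sum(imag_part, dtype=np.float32)
    return np.complex64(complex(total_real, total_imag))

# Same values as z_array, stored as separate float32 real/imaginary parts
re_parts = np.array([4, 5, 2, 3], dtype=np.float32)
im_parts = np.array([3, 6, 1, 7], dtype=np.float32)

print(complex_sum_via_parts(re_parts, im_parts))  # (14+17j)
```

With this approach, only real-valued reductions are needed.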

In [3]:
z = mx.array(4 + 3j)
print('z as complex64: {}'.format(z))

z_real, z_imag = decompose_complex_array_mx(z)
print('real part of z: {}'.format(z_real))
print('imag part of z: {}'.format(z_imag))

z_array = mx.array((4 + 3j, 5 + 6j, 2 + 1j, 3 + 7j))
print('sum of a complex array: {}'.format(complex_sum_mx(z_array)))

z as complex64: array(4+3j, dtype=complex64)
real part of z: array(4, dtype=float32)
imag part of z: array(3, dtype=float32)
sum of a complex array: array(14+17j, dtype=complex64)


#### With these, we can move on to the GW application. 

Using MLX arrays, we need to recreate some functions that already exist in PyCBC and BILBY. Let's start with the noise-weighted inner product.

For simplicity, we read in an example PSD file (`KAGRA_design_psd.txt`), but the PSD could equally be taken from PyCBC.

We also need to compress the dynamic range of the PSD values so that they fit in MLX `float32` arrays. Values with very large or very small exponents fall outside the normal `float32` range (about $1.2\times10^{-38}$ to $3.4\times10^{38}$). They overflow to `inf` or underflow to zero, and dividing by zero or combining infinities then produces `inf` or `NaN`.
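To see why, here is a NumPy illustration using the extreme PSD values printed by the next cell. NumPy's `float32`, like MLX's, is IEEE 754 single precision. Casting the smallest PSD value to `float32` underflows to zero. Dividing by it, as the inner product does, gives `inf`, and that `inf` turns into `NaN` once it meets a zero.

```python
import numpy as np

# Extreme values of the KAGRA design PSD (the "Original PSD range" output)
psd_min, psd_max = 1.59473e-47, 2.71941e-28

# Smallest normal and largest finite float32 values
info = np.finfo(np.float32)
print(info.tiny, info.max)

# 1.6e-47 is far below even the smallest subnormal float32 (~1.4e-45),
# so casting it rounds to exactly zero.
psd_min_f32 = np.float32(psd_min)
print(psd_min_f32)  # 0.0

# The inner product divides by the PSD, so that zero becomes inf.
with np.errstate(divide="ignore"):
    weight = np.float32(1.0) / psd_min_f32
print(weight)  # inf

# Multiplying inf by zero (e.g. a zero-valued data bin) then gives NaN.
with np.errstate(invalid="ignore"):
    nan_result = weight * np.float32(0.0)
print(nan_result)  # nan
```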

In [4]:
example_PSD = np.loadtxt('KAGRA_design_psd.txt')

DYN_RANGE_FAC = 5.9029581035870565e+35

frequency = example_PSD[:,0]
rescaled_PSD = rescale_PSD_mx(example_PSD[:,1], DYN_RANGE_FAC)

print('Original PSD range: {:.5e}, {:.5e}'.format(np.min(example_PSD[:,1]), np.max(example_PSD[:,1])))
print('Rescaled PSD range: {:.5e}, {:.5e}'.format(np.min(rescaled_PSD), np.max(rescaled_PSD)))

Original PSD range: 1.59473e-47, 2.71941e-28
Rescaled PSD range: 9.41361e-12, 1.60526e+08
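The printed ranges are consistent with `rescale_PSD_mx` multiplying the PSD by `DYN_RANGE_FAC` before storing it as `float32`. Assuming that is what it does, the NumPy sketch below reproduces both ends of the rescaled range. The name `rescale_psd_np` is a hypothetical stand-in, not the notebook's function.

```python
import numpy as np

DYN_RANGE_FAC = 5.9029581035870565e+35

def rescale_psd_np(psd, factor):
    """Hypothetical stand-in for rescale_PSD_mx: scale in float64 first, where
    the tiny PSD values are still representable, then cast to float32."""
    return (np.asarray(psd, dtype=np.float64) * factor).astype(np.float32)

# Original minimum and maximum PSD values from the output above
extremes = np.array([1.59473e-47, 2.71941e-28])
rescaled = rescale_psd_np(extremes, DYN_RANGE_FAC)
print(rescaled)
```

Under the same assumption, the PSD appears in the denominator of the inner product. Any quantity computed with the rescaled PSD is therefore smaller than its physical value by a factor of `DYN_RANGE_FAC`, unless the data are rescaled to compensate.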


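Now, with the expanded array functionality and a PSD whose values lie comfortably within `float32` range, we can calculate the noise-weighted inner product of any two complex arrays for a given PSD, in the same way that BILBY calculates it:

$$(a|b) = 4\int_0^\infty \frac{\tilde{a}^*(f)\,\tilde{b}(f)}{S_n(f)}\,df \approx \frac{4}{T}\sum_k \frac{\tilde{a}^*(f_k)\,\tilde{b}(f_k)}{S_n(f_k)},$$

where $S_n(f)$ is the one-sided PSD and $T$ is the segment duration (the `duration` argument below), so that $df = 1/T$.

In practice, we usually only need the real part.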
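A NumPy reference for the discretized formula is useful for checking the MLX results on the CPU. This sketch assumes the formula matches BILBY's `noise_weighted_inner_product`, i.e. `4 / duration * sum(conj(aa) * bb / psd)`. The function name `noise_weighted_inner_product_np` and the unit-valued toy PSD are hypothetical.

```python
import numpy as np

def noise_weighted_inner_product_np(aa, bb, psd, duration):
    """NumPy reference for (a|b) ~ (4/T) * sum(conj(a) * b / S_n).
    Assumed to mirror the discretized formula BILBY uses."""
    integrand = np.conj(aa) * bb / psd
    return 4.0 / duration * np.sum(integrand)

z1, z2 = 4 + 3j, 5 + 6j
toy_psd = np.ones(3)  # toy PSD: three unit-valued frequency bins
nwip = noise_weighted_inner_product_np(z1, z2, toy_psd, duration=8)
print(nwip)        # (57+13.5j)
print(nwip.real)   # 57.0
```

In the notebook's MLX cell, the scalar inputs also broadcast across every PSD bin. The ratio of the imaginary to the real part is therefore fixed by $z_1^* z_2 = 38 + 9i$. The MLX output gives $3.23108\times10^{14} / 1.36424\times10^{15} \approx 0.2368 \approx 9/38$, consistent with this formula.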

In [5]:
z1 = mx.array(4 + 3j)
z2 = mx.array(5 + 6j)

duration = 8

NWIP = noise_weighted_inner_product_mx(z1, z2, rescaled_PSD, duration)
print(NWIP)
print(real_mx(NWIP))

array(1.36424e+15+3.23108e+14j, dtype=complex64)
array(1.36424e+15, dtype=float32)


Next, we can use MLX's FFTs in place of PyCBC's.

`mx.fft.fftn` most closely follows the behavior we want to recreate; for a one-dimensional array like the one below, it computes the same transform as a 1-D FFT.
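NumPy provides a CPU cross-check. Its FFT uses the unnormalized forward convention $X_k = \sum_n x_n e^{-2\pi i kn/N}$, and for a 1-D input `fftn` reduces to `fft`. Running both, plus an explicit DFT matrix, on the same `z_array` values should reproduce the MLX output in the next cell.

```python
import numpy as np

z_np = np.array([4 + 3j, 5 + 6j, 2 + 1j, 3 + 7j], dtype=np.complex64)

# For 1-D input, the n-dimensional FFT is just the 1-D FFT.
zt_fftn = np.fft.fftn(z_np)
zt_fft = np.fft.fft(z_np)

# Explicit DFT, X_k = sum_n x_n exp(-2*pi*i*k*n/N), to make the convention visible.
n = np.arange(len(z_np))
dft_matrix = np.exp(-2j * np.pi * np.outer(n, n) / len(z_np))
zt_manual = dft_matrix @ z_np

print(zt_fftn)
print(zt_manual)
```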

In [6]:
zt_array = mx.fft.fftn(z_array)
print('original    : {}'.format(z_array))
print('transformed : {}'.format(zt_array))


original    : array([4+3j, 5+6j, 2+1j, 3+7j], dtype=complex64)
transformed : array([14+17j, 1+0j, -2-9j, 3+4j], dtype=complex64)
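When swapping in a raw FFT for PyCBC's, the normalization needs care. The DFT above is unnormalized. Approximating the continuous Fourier transform that enters the inner product requires multiplying by the sample spacing $\Delta t$. Our understanding is that PyCBC's functional `fft` applies this `delta_t` scaling to `TimeSeries` input, which is worth confirming against its source. This NumPy sketch uses a hypothetical 256 Hz sampling rate and checks the scaling with Parseval's theorem. The 8 s duration matches `duration = 8` above, so $\Delta f = 1/8$ Hz.

```python
import numpy as np

rng = np.random.default_rng(0)
sample_rate = 256.0            # hypothetical sampling rate in Hz
delta_t = 1.0 / sample_rate
x = rng.standard_normal(2048)  # toy time series: 8 s at 256 Hz
delta_f = 1.0 / (len(x) * delta_t)

# Approximate the continuous transform: raw DFT times delta_t.
x_tilde = np.fft.fft(x) * delta_t

# Parseval: sum |x|^2 dt should equal sum |x_tilde|^2 df over all bins.
time_energy = np.sum(np.abs(x) ** 2) * delta_t
freq_energy = np.sum(np.abs(x_tilde) ** 2) * delta_f
print(np.isclose(time_energy, freq_energy))  # True
```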


### Future work

Identify the places in PyCBC that should be ported to MLX (i.e., all the places that use CUDA).

Some promising starting points, which are the most likely to need MLX versions:
- Within `pycbc/waveform`
    - `utils_cuda.py` -> specifically the use of `pycuda.compiler.SourceModule` (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/waveform/utils_cuda.py#L76C7-L76C19)
    - `decompress_cuda.py` -> again for the use of `pycuda.compiler.SourceModule` (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/waveform/decompress_cuda.py#L265)
    - `spa_tmplt_cuda.py` -> for the use of `pycuda.elementwise.ElementwiseKernel` (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/waveform/spa_tmplt_cuda.py#L80)
- Within `pycbc/filter`
    - `matchedfilter_cuda.py` -> uses multiple `pycuda` tools (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/filter/matchedfilter_cuda.py#L26)
        - `pycuda.elementwise.ElementwiseKernel`
        - `pycuda.tools.context_dependent_memoize`
        - `pycuda.tools.dtype_to_ctype`
        - `pycuda.gpuarray._get_common_dtype`
- Within `pycbc/fft`
    - `cuda_pyfft.py` -> uses `pyfft.cuda.Plan` (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/fft/cuda_pyfft.py#L46)
 

Other places to check:
- `pycbc/types/array_cuda.py` (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/pycbc/types/array_cuda.py#L26C8-L26C14)
- A code search for `cuda` across the PyCBC repository (https://github.com/search?q=repo%3Agwastro%2Fpycbc+cuda&type=code)
- Install specifications (https://github.com/gwastro/pycbc/blob/ff8e49916b4a53ab8f8e3a442d983a23d9816b89/INSTALL#L1)