In [1]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import os

In [2]:
def load_sequence(path, grid_size=128):
    """Read every file in ``path`` as a cubic float32 grid and stack them.

    Args:
        path: Directory whose entries are all read; each entry is
            interpreted as a flat buffer of float32 values.
        grid_size: Edge length of the cube each file is reshaped to, so a
            file must hold exactly ``grid_size ** 3`` values.

    Returns:
        Array of shape ``(n_files, grid_size, grid_size, grid_size)`` with
        the files ordered by lexicographically sorted file name.

    Side effects:
        Prints the sorted list of file names.
    """
    # Directory listing in lexicographic order; this fixes the frame order.
    files = sorted(os.listdir(path))
    print(files)

    cube_shape = (grid_size, grid_size, grid_size)
    frames = []
    for name in files:
        # Pull the whole file into memory, then view the bytes as float32.
        with open(os.path.join(path, name), 'rb') as handle:
            raw = handle.read()
        frames.append(jnp.frombuffer(raw, dtype=jnp.float32).reshape(cube_shape))

    return jnp.array(frames)

# Load every snapshot found under data/grid/001, using load_sequence's
# default grid_size of 128.
density_sequence = load_sequence('data/grid/001')


# Sanity check on the stacked array: (n_files, 128, 128, 128) is expected.
print(density_sequence.shape)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


['grid.00000.0', 'grid.00010.0', 'grid.00020.0', 'grid.00030.0', 'grid.00040.0', 'grid.00050.0', 'grid.00060.0', 'grid.00070.0', 'grid.00080.0', 'grid.00090.0', 'grid.00100.0']
(11, 128, 128, 128)


In [3]:
def overdensity(density):
    """Return the density contrast ``(density - mean) / mean``.

    Args:
        density: Array of density values; the mean is taken over all of
            its elements.

    Returns:
        Array with the same shape as ``density``. No guard is applied for
        a zero mean.
    """
    mean_density = density.mean()
    fluctuation = density - mean_density
    return fluctuation / mean_density

def resize_sequence(density_sequence, size):
    """Resample every frame of a sequence to a ``size``-cubed grid.

    Args:
        density_sequence: Array whose leading axis indexes frames; each
            frame is resized independently to three spatial dimensions.
        size: Target edge length applied to all three axes.

    Returns:
        Array of shape ``(n_frames, size, size, size)`` produced with
        ``jax.image.resize`` using ``method='linear'``.
    """
    target_shape = (size, size, size)

    def _resize_frame(frame):
        # Linear interpolation of one 3-D frame onto the target grid.
        return jax.image.resize(frame, target_shape, method='linear')

    # Map the single-frame resize over the leading (frame) axis.
    return jax.vmap(_resize_frame)(density_sequence)
    
def power(density_sequence, bins = 30, binning='linear'):
    """Spherically binned power spectrum of every frame in a sequence.

    For each frame the density contrast ``(rho - mean) / mean`` is Fourier
    transformed, ``|FFT|**2`` is accumulated into bins of the wave-number
    magnitude ``|k|``, and each bin is divided by the number of modes that
    fall into it. No FFT normalisation factor is applied.

    Args:
        density_sequence: Array of shape ``(n_frames, size, size, size)``.
            Frames must be cubic: the wave-number grid is built from
            ``density_sequence.shape[1]`` alone.
        bins: Forwarded to ``jnp.histogram`` (number of equal-width bins
            or explicit edges).
        binning: ``'linear'`` bins ``floor(|k|)``; ``'log'`` bins
            ``log10(|k|)`` with the ``k = 0`` mode excluded, since its
            logarithm is undefined. Any other value bins the raw ``|k|``.

    Returns:
        Array of shape ``(n_frames, n_bins)`` holding the mean power per
        mode in each bin. Bins that contain no mode are NaN.
    """
    size = density_sequence.shape[1]

    # Per-frame density contrast; the same quantity `overdensity` returns
    # for a single frame, written with axis reductions over the frame axes.
    frame_mean = density_sequence.mean(axis=(1, 2, 3), keepdims=True)
    overdensity_sequence = (density_sequence - frame_mean) / frame_mean

    # Leave the FFT in its native ordering. `fftfreq` below returns wave
    # numbers in that same ordering; applying `fftshift` to the amplitudes
    # only would pair every mode with the wrong |k|.
    fft_sequence = jnp.fft.fftn(overdensity_sequence, axes=(1, 2, 3))

    # Integer wave numbers in units of the fundamental mode of the box.
    freqs = jnp.fft.fftfreq(size, d=1/size)
    kx, ky, kz = jnp.meshgrid(freqs, freqs, freqs, indexing='ij')

    # wave number magnitude
    k = jnp.sqrt(kx**2 + ky**2 + kz**2)

    # Weight of each mode in the histograms (1 = counted, 0 = ignored).
    mode_weight = jnp.ones_like(k)

    if binning == 'log':
        # log10(0) is -inf and would make the histogram edges non-finite,
        # so give the k = 0 mode a placeholder value and zero weight.
        has_k = k > 0
        mode_weight = has_k.astype(k.dtype)
        k = jnp.log10(jnp.where(has_k, k, 1.0))
    if binning == 'linear':
        k = jnp.floor(k)

    # Fourier-space power of every mode
    power_spectrum = jnp.abs(fft_sequence)**2

    # Sum of power per bin, one histogram per frame.
    hist_func_power = lambda x: jnp.histogram(k, bins=bins, weights=x * mode_weight)[0]
    f_power = jax.vmap(hist_func_power)(power_spectrum)

    # Number of contributing modes per bin (shared by all frames).
    n_power = jnp.histogram(k, bins=bins, weights=mode_weight)[0]

    # Mean power per mode; empty bins are reported as NaN explicitly rather
    # than through a 0 / 0 division.
    populated = n_power > 0
    normalized = jnp.where(populated, f_power / jnp.where(populated, n_power, 1), jnp.nan)
    return normalized

# Downsample every frame to a 32^3 grid, then bin the power spectrum of each
# downsampled frame into 20 bins using the 'linear' binning mode.
resized_density_sequence = resize_sequence(density_sequence, 32)
power_spectrum_sequence = power(resized_density_sequence, bins=20, binning='linear')
