In [1]:
# automatically re-import edited modules (e.g. pmwd, nvix) before each cell executes
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import jax.numpy as jnp
from tqdm.notebook import tqdm

import imageio

import jax
import matplotlib.pyplot as plt

from pmwd import (
    Configuration,
    Cosmology, SimpleLCDM,
    boltzmann,
    white_noise, linear_modes,
    lpt,
    nbody,
    scatter,
)
from pmwd.particles import ptcl_pos
from pmwd.vis_util import plt_2d, CosmicWebNorm
from pmwd.nbody import nbody_init, nbody_step

import nvix.camera as nvixc
import nvix.utils as nvixu

from gaepsi2.painter import paint

In [3]:
# --- simulation setup ---
# cubic particle lattice: 128 particles per side, spacing 4
ptcl_spacing = 4.
ptcl_grid_shape = (128,) * 3
seed = 0

conf = Configuration(
    ptcl_spacing,
    ptcl_grid_shape,
    mesh_shape=2,
    a_nbody_maxstep=1/128,
)
# scale factors between which the N-body steps below are taken
a_nbody = conf.a_nbody

# --- initial conditions ---
noise = white_noise(seed, conf)
cosmo = boltzmann(SimpleLCDM(conf), conf)
modes = linear_modes(noise, cosmo, conf)
ptcl, obsvbl = lpt(modes, cosmo, conf)

In [4]:
# range of the cubic box, assumed to be the same for x, y, z
lim = np.array([0, ptcl_spacing * ptcl_grid_shape[0]])
box_size = lim[1] - lim[0]

# the model matrix to shift the box center to the origin of the world space
M_model = np.eye(4)
M_model[:3, 3] = np.full(3, -lim.mean())

# the camera is at positive z, 5 box sizes from the origin, facing the origin of the world space
eye = np.array([0, 0, box_size*5])
target = np.array([0, 0, 0])

# field of view angle: twice the half-angle subtended by half the box at the camera
# distance, widened by a factor 1.8 so the box (and its rotated corners) fit with margin
fovy = np.arctan2(box_size/2, eye[-1]) * 1.8 * 2

# distance of camera to near and far planes in world space.
# Before rotation the centered box spans eye[-1] -/+ box_size/2 in depth; the planes are
# placed with extra margin (1 box size in front, 2 behind) so the cube stays between
# them as the camera orbits around it.
near = eye[-1] - box_size
far = eye[-1] + box_size * 2

# output image size, used both as the painting grid and the plot extent
window = (512, 512)

def project(ptcl, angle, res, a, sml=3.):
    """Render the particles from a camera orbiting the box and record the frame.

    Parameters
    ----------
    ptcl :
        pmwd particle state to render; positions are taken via ``ptcl_pos``.
    angle : float
        Camera rotation about the y axis as a fraction of a full turn
        (it is multiplied by ``2 * pi`` before building the rotation).
    res : list
        Output list; ``[a, vs, img]`` is appended to it in place.
    a : float
        Scale factor of this frame, stored with the image for labelling.
    sml : float, optional
        Smoothing length given to ``paint`` for every particle (default 3.).

    Notes
    -----
    Reads the module-level camera setup (``eye``, ``target``, ``fovy``,
    ``near``, ``far``, ``M_model``, ``window``, ``lim``) and ``conf``.
    """
    # orbit the camera around the y axis
    rot = nvixu.rotation('y', angle * 2 * np.pi)
    eye_n = rot @ eye

    # project the 8 vertices of the cubic box, stacked as a (3, 8) array
    xs, ys, zs = np.meshgrid(lim, lim, lim)
    vs = np.vstack([c.ravel() for c in (xs, ys, zs)])
    vs = nvixc.shutter(vs, eye_n, target, fovy, near, far, M_model=M_model, window=window)

    # project the particle positions, transposed to match the layout of vs
    # (presumably ptcl_pos returns (n, 3) -- the projected axis 1 indexes particles)
    pos = ptcl_pos(ptcl, conf).T
    pos = nvixc.shutter(pos, eye_n, target, fovy, near, far, M_model=M_model, window=window)
    n_ptcl = pos.shape[1]

    # deposit unit-weight particles with a fixed smoothing length onto the image;
    # np.array makes a NumPy copy before handing the data to gaepsi2
    img = paint(np.array(pos.T), np.full(n_ptcl, sml), np.ones(n_ptcl), window)[0].T

    res.append([a, vs, img])

In [5]:
# forward simulation: evolve from the initial conditions, rendering every step
# while the camera makes one full revolution around the box
fw_res = []
n_steps = len(a_nbody) - 1

ptcl, obsvbl = nbody_init(a_nbody[0], ptcl, obsvbl, cosmo, conf)
project(ptcl, 0, fw_res, a_nbody[0])

for step in tqdm(range(1, n_steps + 1)):
    a_prev, a_next = a_nbody[step - 1], a_nbody[step]
    ptcl, obsvbl = nbody_step(a_prev, a_next, ptcl, obsvbl, cosmo, conf)
    project(ptcl, step / n_steps, fw_res, a_next)

  0%|          | 0/126 [00:00<?, ?it/s]

In [6]:
# reverse simulation: step backwards in scale factor starting from the state the
# forward cell left in ptcl/obsvbl (so this cell must run right after it, once)
bw_res = []

a_nbody_r = a_nbody[::-1]
n_steps = len(a_nbody_r) - 1

for step in tqdm(range(1, n_steps + 1)):
    a_prev, a_next = a_nbody_r[step - 1], a_nbody_r[step]
    ptcl, obsvbl = nbody_step(a_prev, a_next, ptcl, obsvbl, cosmo, conf)
    # the angle restarts at 1/n_steps: one full turn after the last forward frame,
    # so the camera keeps orbiting smoothly
    project(ptcl, step / n_steps, bw_res, a_next)

  0%|          | 0/126 [00:00<?, ?it/s]

In [7]:
# shared colour-scale ceiling: the brightest pixel over every rendered frame
# (never below 0), so brightness is comparable across the whole movie
vmax = max([0] + [frame[2].max() for frame in fw_res + bw_res])
print(vmax)

121.267365


In [18]:
def plot_image(a, vs, X, title, prefix, i):
    """Draw one projected frame with its box outline and save it as a PNG.

    Parameters
    ----------
    a : float
        Scale factor of the frame; the title shows the redshift ``1/a - 1``.
    vs : array
        Projected box vertices, drawn with ``nvixu.draw_box``.
    X : 2D array
        Painted particle image, shown on the fixed colour range ``[0, vmax]``.
    title : str
        Title prefix.
    prefix : str
        File-name prefix; the figure is written to ``tmp/{prefix}_{i}.png``.
    i : int
        Frame index used in the file name.

    Uses the module-level ``window`` and ``vmax``.
    """
    # savefig does not create directories; make sure tmp/ exists on a fresh checkout
    os.makedirs('tmp', exist_ok=True)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.set_xlim(0, window[0]); ax.set_ylim(0, window[1])
    ax.set_title(f'{title}, redshift : {1/a-1:g}')

    nvixu.draw_box(vs, ax)
    # fixed vmin/vmax so brightness is comparable from frame to frame
    ax.imshow(X, extent=(0, window[0], 0, window[1]), cmap='inferno', vmin=0, vmax=vmax,
              interpolation='lanczos', interpolation_stage='rgba')
    fig.savefig(f'tmp/{prefix}_{i}.png', dpi=100, bbox_inches='tight')
    # close this specific figure so hundreds of frames don't accumulate in memory
    plt.close(fig)

In [19]:
# render every forward-evolution frame to tmp/fw_<index>.png
for frame_idx, (a_frame, verts, image) in enumerate(tqdm(fw_res)):
    plot_image(a_frame, verts, image, 'time forward evolution', 'fw', frame_idx)

  0%|          | 0/127 [00:00<?, ?it/s]

In [20]:
# render every time-reversed frame to tmp/bw_<index>.png
for frame_idx, (a_frame, verts, image) in enumerate(tqdm(bw_res)):
    plot_image(a_frame, verts, image, 'time reversal evolution', 'bw', frame_idx)

  0%|          | 0/126 [00:00<?, ?it/s]

In [21]:
# Stitch the saved frames (forward evolution, then time reversal) into one video.
# Frame counts come from the result lists, so they always match the PNGs written
# by the plotting cells (previously hard-coded as 127/126, which broke whenever
# the number of time steps changed). The context manager closes the writer even
# if reading a frame fails.
with imageio.get_writer('time_evo.mp4', fps=10) as writer:
    for prefix, n_frames in (('fw', len(fw_res)), ('bw', len(bw_res))):
        for i in tqdm(range(n_frames)):
            writer.append_data(imageio.imread(f'tmp/{prefix}_{i}.png'))

  0%|          | 0/127 [00:00<?, ?it/s]



  0%|          | 0/126 [00:00<?, ?it/s]