In [1]:
# Packages 
# import warnings
# warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import qmc
import glob
import h5py
import time
from datetime import timedelta

# We change the default level of the logger so that
# we can see what's happening with caching.
import sys, os
import logging
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

import py21cmfast as p21c

# For plotting the cubes, we use the plotting submodule:
# from py21cmfast import plotting

# For interacting with the cache
from py21cmfast import cache_tools

# Parallize
from mpi4py import MPI
import multiprocessing
from multiprocessing import Pool

# Cache for intermediate process
cache_direc = "/storage/home/hcoda1/3/bxia34/scratch/_cache"

if not os.path.exists(cache_direc):
    os.mkdir(cache_direc)

p21c.config['direc'] = cache_direc

string_pad_length = 80

--------------------------------------------------------------------------

  Local host:   atl1-1-02-002-26-2
  Local device: mlx5_0
--------------------------------------------------------------------------


In [2]:
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

In [39]:
if rank == 0:
    sampler = qmc.LatinHypercube(d=2, strength=2)
    sample = sampler.random(n=9)
    # send_data = np.random.normal(size=(10,2))
    send_data = sample[:int(sample.shape[0]//size * size),:]
    send_data = send_data.reshape(size, int(send_data.shape[0]/size), send_data.shape[1])
    print(f"Process {rank} scatters data (shape = {send_data.shape}) to {size} nodes".center(string_pad_length,'-'))
else:
    send_data = None
recv_data = comm.scatter(send_data, root=0)
print(f"Process {rank}/{size} recvs data (shape = {recv_data.shape})".center(string_pad_length, '-'))

-------------Process 0 scatters data (shape = (1, 9, 2)) to 1 nodes-------------
--------------------Process 0/1 recvs data (shape = (9, 2))---------------------


In [4]:
# plt.scatter(recv_data.T[0], recv_data.T[1])

In [5]:
def denormalize(norm, kind):
    if kind == 'ION_Tvir_MIN':
        value = norm*2 + 4
    elif kind == 'HII_EFF_FACTOR':
        value = norm*np.log10(25) + 1
        value = 10**value
    else:
        raise TypeError(f"kind \'{kind}\' is not supported.")
    return value

In [40]:
# What to sample
params_node = dict(
    ION_Tvir_MIN = None,
    HII_EFF_FACTOR = None,
)

for i, kind in enumerate(params_node.keys()):
    params_node[kind] = denormalize(recv_data.T[i], kind)
# params_node.values()

In [7]:
def generate_brightness_temp(params_node_value):
    # All parameters
    generate_brightness_temp_start = time.perf_counter()

    pid_cpu = multiprocessing.current_process().pid
    random_seed = np.random.randint(1,2**32) + pid_cpu
    params_cpu = {key: params_node_value[i] for (i, key) in enumerate(params_node.keys())}

    redshift = 11.93 
    user_params = {
        "HII_DIM":60, 
        "BOX_LEN":150, 
        # "USE_INTERPOLATION_TABLE":True
        }
    cosmo_params = dict(
        SIGMA_8 = 0.810,
        hlittle = 0.677,
        OMm = 0.310,
        OMb = 0.0490,
        POWER_INDEX = 0.967,
    )
    astro_params = dict(
        ION_Tvir_MIN = params_cpu['ION_Tvir_MIN'],
        HII_EFF_FACTOR = params_cpu['HII_EFF_FACTOR'],
    )
    # Simulation
    coeval = p21c.run_coeval(
        redshift = redshift,
        user_params = user_params,
        cosmo_params = p21c.CosmoParams(cosmo_params),
        astro_params = p21c.AstroParams(astro_params),
        random_seed = random_seed
    )

    cache_pattern = os.path.join(cache_direc, f"*{coeval.random_seed}*")
    for filename in glob.glob(cache_pattern):
        # print(filename)
        os.remove(filename)

    generate_brightness_temp_end = time.perf_counter()
    time_elapsed = generate_brightness_temp_end - generate_brightness_temp_start
    print(f'cpu {pid_cpu} in {pid_node}, seed {random_seed}, {params_cpu}, cost {timedelta(seconds=time_elapsed)}')
    
    return coeval.brightness_temp

pid_node = os.getpid()
CPU_num = len(os.sched_getaffinity(pid_node))
print(f"node {pid_node}: {CPU_num} CPUs are working on {np.shape(list(params_node.values()))[-1]} groups of params".center(string_pad_length,'-'))

with Pool(CPU_num) as p:
    Pool_start = time.perf_counter()
    images_node = np.array(p.map(generate_brightness_temp, np.array(list(params_node.values())).T))
    Pool_end = time.perf_counter()
    time_elapsed = Pool_end - Pool_start
    print(f"images {images_node.shape} generated by node {pid_node} with {timedelta(seconds=time_elapsed)}".center(string_pad_length,'-'))

-------------node 167085: 2 CPUs are working on 9 groups of params--------------




cpu 167155 in 167085, seed 2352081998, {'ION_Tvir_MIN': 4.565850736820181, 'HII_EFF_FACTOR': 24.860662168060824}, cost 0:00:09.851730
cpu 167156 in 167085, seed 2352081999, {'ION_Tvir_MIN': 5.899184070153515, 'HII_EFF_FACTOR': 12.965212356596512}, cost 0:00:10.210185
cpu 167155 in 167085, seed 2720883129, {'ION_Tvir_MIN': 5.232517403486848, 'HII_EFF_FACTOR': 20.226310871581912}, cost 0:00:09.195817
cpu 167156 in 167085, seed 2720883130, {'ION_Tvir_MIN': 4.437667819817743, 'HII_EFF_FACTOR': 72.69301716312737}, cost 0:00:09.605201
cpu 167156 in 167085, seed 1629976822, {'ION_Tvir_MIN': 4.161351179238974, 'HII_EFF_FACTOR': 212.55567162919678}, cost 0:00:09.161497
cpu 167155 in 167085, seed 1629976821, {'ION_Tvir_MIN': 5.104334486484409, 'HII_EFF_FACTOR': 59.14209176711326}, cost 0:00:12.338349
cpu 167156 in 167085, seed 309723227, {'ION_Tvir_MIN': 4.828017845905641, 'HII_EFF_FACTOR': 172.93252540205214}, cost 0:00:09.366627
cpu 167155 in 167085, seed 309723226, {'ION_Tvir_MIN': 5.77100115

In [35]:
# Save as hdf5
save_direc_name = "./test.h5"

# if os.path.exists(save_direc_name):
#     os.remove(save_direc_name)
# images = np.expand_dims(images_node, axis=0)
images = images_node
params = np.array(list(params_node.values())).T

HII_DIM = images.shape[-1]
# params = np.expand_dims(params, axis=0)
# random_seed = np.array(coeval.random_seed).reshape(1,1)

with h5py.File(save_direc_name, 'a') as f:
    if 'images' not in f.keys():
        f.create_dataset(
            'images', 
            data=images, 
            maxshape=(None, HII_DIM, HII_DIM, HII_DIM)
        )
        f.create_dataset(
            'params',
            data = params,
            maxshape = (None, params.shape[-1]))
        # f.create_dataset(
        #     'random_seed',
        #     data=random_seed,
        #     maxshape=(None,1)
        # )
    else:
        # print(image.shape)
        new_size = f['images'].shape[0] + images.shape[0]
        f['images'].resize(new_size, axis=0)
        f['images'][-images.shape[0]:] = images
        f['params'].resize(new_size, axis=0)
        f['params'][-images.shape[0]:] = params
        # f['random_seed'].resize(new_size, axis=0)
        # f['random_seed'][-1] = random_seed

In [43]:
# fig, axes = plt.subplots(3,3, figsize=(10,10))
# for i, ax in enumerate(axes.flat):
#     ax.imshow(images_node[i][0])
#     ax.set_xlabel('')
#     ax.set_ylabel('')
#     ax.axis('off')
#     ax.set_title(np.array(list(params_node.values())).T[i])
# # plt.margins(x=0, y=0)
# plt.suptitle(params_node.keys())
# plt.tight_layout(w_pad=0.2)
# # plt.close()

In [None]:
# cache_tools.clear_cache(direc=cache_direc)

In [44]:
# # How to read images, params, and random_seed
# with h5py.File(save_direc_name, 'a') as f:
#     images = np.array(f['images'])
#     labels = np.array(f['params'])
#     # seed = np.array(f['random_seed'])
# print(images.shape)
# print(labels.shape)
# # print(seed)