In [None]:
# Packages 
# import warnings
# warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import qmc
import glob
import h5py
import time
from datetime import timedelta

# We change the default level of the logger so that
# we can see what's happening with caching.
import sys, os
import logging
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

import py21cmfast as p21c

# For interacting with the cache
from py21cmfast import cache_tools

# Cache for intermediate process
cache_direc = "/storage/home/hcoda1/3/bxia34/scratch/_cache"

if not os.path.exists(cache_direc):
    os.mkdir(cache_direc)

p21c.config['direc'] = cache_direc

string_pad_length = 80

In [None]:
# Parallize
import multiprocessing
from multiprocessing import Pool
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
except ImportError:
    rank = 0
    size =1

In [None]:
if rank == 0:
    sampler = qmc.LatinHypercube(d=2, strength=2)
    sample = sampler.random(n=9)
    # send_data = np.random.normal(size=(10,2))
    send_data = sample[:int(sample.shape[0]//size * size),:]
    send_data = send_data.reshape(size, int(send_data.shape[0]/size), send_data.shape[1])
    print(f"Process {rank} scatters data (shape = {send_data.shape}) to {size} nodes".center(string_pad_length,'-'))
else:
    send_data = None
recv_data = comm.scatter(send_data, root=0)
print(f"Process {rank}/{size} recvs data (shape = {recv_data.shape})".center(string_pad_length, '-'))

In [None]:
# plt.scatter(recv_data.T[0], recv_data.T[1])

In [None]:
def denormalize(norm, kind):
    if kind == 'ION_Tvir_MIN':
        value = norm*2 + 4
    elif kind == 'HII_EFF_FACTOR':
        value = norm*np.log10(25) + 1
        value = 10**value
    else:
        raise TypeError(f"kind \'{kind}\' is not supported.")
    return value

In [None]:
# What to sample
params_node = dict(
    ION_Tvir_MIN = None,
    HII_EFF_FACTOR = None,
)

for i, kind in enumerate(params_node.keys()):
    params_node[kind] = denormalize(recv_data.T[i], kind)
# params_node.values()

In [None]:
def define_params(params: dict) -> None:
    redshift = 11.93 
    user_params = {
        "HII_DIM":60, 
        "BOX_LEN":150, 
        # "USE_INTERPOLATION_TABLE":True
        }
    cosmo_params = dict(
        SIGMA_8 = 0.810,
        hlittle = 0.677,
        OMm = 0.310,
        OMb = 0.0490,
        POWER_INDEX = 0.967,
    )
    astro_params = dict(
        ION_Tvir_MIN = params['ION_Tvir_MIN'],
        HII_EFF_FACTOR = params['HII_EFF_FACTOR'],
    )

In [None]:
# Save as hdf5
# images = images_node
# params = np.array(list(params_node.values())).T

def save(images, params, save_direc_name="./images_params.h5"):
    # if os.path.exists(save_direc_name):
    #     os.remove(save_direc_name)
    HII_DIM = images.shape[-1]
    with h5py.File(save_direc_name, 'a') as f:
        if 'images' not in f.keys():
            f.create_dataset(
                'images', 
                data=images, 
                maxshape=(None, HII_DIM, HII_DIM, HII_DIM)
            )
            f.create_dataset(
                'params',
                data = params,
                maxshape = (None, params.shape[-1]))
            # f.create_dataset(
            #     'random_seed',
            #     data=random_seed,
            #     maxshape=(None,1)
            # )
        else:
            # print(image.shape)
            new_size = f['images'].shape[0] + images.shape[0]
            f['images'].resize(new_size, axis=0)
            f['images'][-images.shape[0]:] = images
            f['params'].resize(new_size, axis=0)
            f['params'][-images.shape[0]:] = params
            # f['random_seed'].resize(new_size, axis=0)
            # f['random_seed'][-1] = random_seed

In [None]:
def generate_brightness_temp(params_node_value):
    # All parameters
    generate_brightness_temp_start = time.perf_counter()

    pid_cpu = multiprocessing.current_process().pid
    random_seed = np.random.randint(1,2**32) + pid_cpu
    params_cpu = {key: params_node_value[i] for (i, key) in enumerate(params_node.keys())}

    define_params(params_cpu)

    # Simulation
    coeval = p21c.run_coeval(
        redshift = redshift,
        user_params = user_params,
        cosmo_params = p21c.CosmoParams(cosmo_params),
        astro_params = p21c.AstroParams(astro_params),
        random_seed = random_seed
    )

    cache_pattern = os.path.join(cache_direc, f"*{coeval.random_seed}*")
    for filename in glob.glob(cache_pattern):
        # print(filename)
        os.remove(filename)

    generate_brightness_temp_end = time.perf_counter()
    time_elapsed = generate_brightness_temp_end - generate_brightness_temp_start
    print(f'cpu {pid_cpu} in {pid_node}, seed {random_seed}, {params_cpu}, cost {timedelta(seconds=time_elapsed)}')
    
    return coeval.brightness_temp

pid_node = os.getpid()
CPU_num = len(os.sched_getaffinity(pid_node))
print(f"node {pid_node}: {CPU_num} CPUs are working on {np.shape(list(params_node.values()))[-1]} groups of params".center(string_pad_length,'-'))

# run p21c.run_coeval in parallel on multi-CPUs
Pool_start = time.perf_counter()
with Pool(CPU_num) as p:
    images_node = np.array(p.map(generate_brightness_temp, np.array(list(params_node.values())).T))
Pool_end = time.perf_counter()
time_elapsed = Pool_end - Pool_start
print(f"images {images_node.shape} generated by node {pid_node} with {timedelta(seconds=time_elapsed)}".center(string_pad_length,'-'))

# save images, params as .h5 file
save(images_node, np.array(list(params_node.values())).T, 'test_save_func.h5')

In [None]:
# fig, axes = plt.subplots(3,3, figsize=(10,10))
# for i, ax in enumerate(axes.flat):
#     ax.imshow(images_node[i][0])
#     ax.set_xlabel('')
#     ax.set_ylabel('')
#     ax.axis('off')
#     ax.set_title(np.array(list(params_node.values())).T[i])
# # plt.margins(x=0, y=0)
# plt.suptitle(params_node.keys())
# plt.tight_layout(w_pad=0.2)
# # plt.close()

In [None]:
# cache_tools.clear_cache(direc=cache_direc)

In [None]:
# # How to read images, params, and random_seed
# with h5py.File(save_direc_name, 'a') as f:
#     images = np.array(f['images'])
#     labels = np.array(f['params'])
#     # seed = np.array(f['random_seed'])
# print(images.shape)
# print(labels.shape)
# # print(seed)

In [21]:
# Packages 
# import warnings
# warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import qmc
import glob
import h5py
import time
from datetime import timedelta

# We change the default level of the logger so that
# we can see what's happening with caching.
import sys, os
import logging
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

import py21cmfast as p21c

# For interacting with the cache
from py21cmfast import cache_tools

# Cache for intermediate process
cache_direc = "/storage/home/hcoda1/3/bxia34/scratch/_cache"

if not os.path.exists(cache_direc):
    os.mkdir(cache_direc)

p21c.config['direc'] = cache_direc

str_pad_len = 80
str_pad_type = '-'

# Parallize
import multiprocessing
from multiprocessing import Pool
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
except ImportError:
    rank = 0
    size =1

class Generate21cmImages():
    def __init__(self, **kwargs):
        """
        Generate dataset by 21cmFAST in parallel.
        Input: kwargs = {'param1': [min, max], 'param2': [min, max], ...}
        Output: hdf5 storing images and params.
        """
        self.kwargs = kwargs
        print(f"kwargs = {self.kwargs}".center(str_pad_len, str_pad_type))

        self.sample_normalized_params(dimension=len(self.kwargs), num_groups=9)
        self.denormalize()
        self.define_default_params()

    def sample_normalized_params(self, dimension=2, num_groups=9):
        """
        sample and scatter to other nodes
        """
        if rank == 0:
            sampler = qmc.LatinHypercube(d=dimension, strength=2)
            sample = sampler.random(n=num_groups)
            send_data = sample[:int(sample.shape[0]//size * size),:]
            send_data = send_data.reshape(size, int(send_data.shape[0]/size), send_data.shape[1])
            print(f"Process {rank} scatters data (shape = {send_data.shape}) to {size} nodes".center(str_pad_len,str_pad_type))
        else:
            send_data = None
        self.recv_data = comm.scatter(send_data, root=0)
        print(f"Process {rank}/{size} recvs data (shape = {self.recv_data.shape})".center(str_pad_len, str_pad_type))

    def denormalize(self):
        """
        denormalize data received, and return self.params_node which stores params for each node.
        """
        self.params_node = {}
        for i, kind in enumerate(self.kwargs):
            x = self.recv_data.T[i]
            k = self.kwargs[kind][1]-self.kwargs[kind][0]
            b = self.kwargs[kind][0]
            self.params_node[kind] = k*x + b

    def define_default_params(self):
        self.redshift = 11.93 
        self.user_params = {
            "HII_DIM":60, 
            "BOX_LEN":150, 
            "USE_INTERPOLATION_TABLE":True
            }
        self.cosmo_params = dict(
            SIGMA_8 = 0.810,
            hlittle = 0.677,
            OMm = 0.310,
            OMb = 0.0490,
            POWER_INDEX = 0.967,
            )
        self.astro_params = dict(
            ION_Tvir_MIN = 5,#params['ION_Tvir_MIN'],
            HII_EFF_FACTOR = 100,#params['HII_EFF_FACTOR'],
            )

    def update_params(self):
        params_list = ["user_params", "cosmo_params", "astro_params"]
        for params in params_list:
            for key in self.params_cpu:
                if key in self.__dict__[params]:
                    self.__dict__[params][key] = self.params_cpu[key]


    def generate_brightness_temp(self, params_node_value):
        # All parameters
        generate_brightness_temp_start = time.perf_counter()

        self.pid_cpu = multiprocessing.current_process().pid
        self.random_seed = np.random.randint(1,2**32) + self.pid_cpu

        self.params_cpu = {key: params_node_value[i] for (i, key) in enumerate(self.params_node.keys())}
        self.update_params()

        # Simulation
        coeval = p21c.run_coeval(
            redshift = self.redshift,
            user_params = self.user_params,
            cosmo_params = p21c.CosmoParams(self.cosmo_params),
            astro_params = p21c.AstroParams(self.astro_params),
            random_seed = self.random_seed
        )

        cache_pattern = os.path.join(cache_direc, f"*{coeval.random_seed}*")
        for filename in glob.glob(cache_pattern):
            # print(filename)
            os.remove(filename)

        generate_brightness_temp_end = time.perf_counter()
        time_elapsed = generate_brightness_temp_end - generate_brightness_temp_start
        print(f'cpu {self.pid_cpu} in {self.pid_node}, seed {self.random_seed}, {self.params_cpu}, cost {timedelta(seconds=time_elapsed)}')
        
        return coeval.brightness_temp

    def run_parallel(self, save_direc_name='images_params.h5'):
        self.pid_node = os.getpid()
        CPU_num = len(os.sched_getaffinity(self.pid_node))
        print(f"node {self.pid_node}: {CPU_num} CPUs are working on {np.shape(list(self.params_node.values()))[-1]} groups of params".center(str_pad_len,str_pad_type))

        # run p21c.run_coeval in parallel on multi-CPUs
        Pool_start = time.perf_counter()

        with Pool(CPU_num) as p:
            images_node = np.array(p.map(self.generate_brightness_temp, np.array(list(self.params_node.values())).T))

        Pool_end = time.perf_counter()
        time_elapsed = Pool_end - Pool_start

        print(f"images {images_node.shape} generated by node {self.pid_node} with {timedelta(seconds=time_elapsed)}".center(str_pad_len,str_pad_type))

        # save images, params as .h5 file
        self.save(images_node, np.array(list(self.params_node.values())).T, save_direc_name=save_direc_name)

        return images_node

    # Save as hdf5
    def save(images, params, save_direc_name):
        # if os.path.exists(save_direc_name):
        #     os.remove(save_direc_name)
        HII_DIM = images.shape[-1]
        with h5py.File(save_direc_name, 'a') as f:
            if 'images' not in f.keys():
                f.create_dataset(
                    'images', 
                    data=images, 
                    maxshape=(None, HII_DIM, HII_DIM, HII_DIM)
                )
                f.create_dataset(
                    'params',
                    data = params,
                    maxshape = (None, params.shape[-1]))
                # f.create_dataset(
                #     'random_seed',
                #     data=random_seed,
                #     maxshape=(None,1)
                # )
            else:
                # print(image.shape)
                new_size = f['images'].shape[0] + images.shape[0]
                f['images'].resize(new_size, axis=0)
                f['images'][-images.shape[0]:] = images
                f['params'].resize(new_size, axis=0)
                f['params'][-images.shape[0]:] = params
                # f['random_seed'].resize(new_size, axis=0)
                # f['random_seed'][-1] = random_seed

In [23]:
if __name__ == '__main__':
    kwargs = dict(
        ION_Tvir_MIN = [4,6],#params['ION_Tvir_MIN'],
        HII_EFF_FACTOR = [np.log10(10), np.log10(250)],#params['HII_EFF_FACTOR']
        )
    generator = Generate21cmImages(**kwargs)
    generator.run_parallel("save_to_here.h5")

-kwargs = {'ION_Tvir_MIN': [4, 6], 'HII_EFF_FACTOR': [1.0, 2.3979400086720375]}-
-------------Process 0 scatters data (shape = (1, 9, 2)) to 1 nodes-------------
--------------------Process 0/1 recvs data (shape = (9, 2))---------------------
-------------node 130965: 4 CPUs are working on 9 groups of params--------------




cpu 137512 in 130965, seed 3432576786, {'ION_Tvir_MIN': 4.6073881089433755, 'HII_EFF_FACTOR': 1.4245460691417975}, cost 0:00:09.120088
cpu 137515 in 130965, seed 3432576789, {'ION_Tvir_MIN': 4.215213179507314, 'HII_EFF_FACTOR': 1.8905260720324768}, cost 0:00:09.342297
cpu 137514 in 130965, seed 3432576788, {'ION_Tvir_MIN': 5.940721442276709, 'HII_EFF_FACTOR': 1.1940208881594168}, cost 0:00:09.814470
cpu 137513 in 130965, seed 3432576787, {'ION_Tvir_MIN': 5.274054775610042, 'HII_EFF_FACTOR': 1.150427557013396}, cost 0:00:09.947502
cpu 137515 in 130965, seed 2544749169, {'ION_Tvir_MIN': 5.548546512840648, 'HII_EFF_FACTOR': 1.660000891050096}, cost 0:00:08.547268
cpu 137513 in 130965, seed 2544749167, {'ION_Tvir_MIN': 4.944247803132924, 'HII_EFF_FACTOR': 2.0823875627947546}, cost 0:00:08.886708
cpu 137512 in 130965, seed 2544749166, {'ION_Tvir_MIN': 4.881879846173981, 'HII_EFF_FACTOR': 1.6164075599040753}, cost 0:00:09.778869
cpu 137514 in 130965, seed 2544749168, {'ION_Tvir_MIN': 4.27758

TypeError: Generate21cmImages.save() got multiple values for argument 'save_direc_name'