In [None]:
# Set USE_GPU = True to run on the GPU backend; otherwise JAX runs on the CPU.
USE_GPU = False
# Number of virtual CPU devices XLA exposes, so sharding can be exercised on a single host.
N_HOST_DEVICES = 8

import os

# XLA reads XLA_FLAGS when the backend is first initialised, so this must happen
# before any JAX computation. Keep any flags the user already exported and only
# replace a pre-existing host-device-count flag.
_host_flag = f"--xla_force_host_platform_device_count={N_HOST_DEVICES}"
_other_flags = [
    flag for flag in os.environ.get("XLA_FLAGS", "").split()
    if not flag.startswith("--xla_force_host_platform_device_count")
]
os.environ["XLA_FLAGS"] = " ".join(_other_flags + [_host_flag])

########################################
# Do not modify what follows.

import jax

# Use 64-bit floats throughout (JAX defaults to 32-bit).
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "gpu" if USE_GPU else "cpu")

import jax.numpy as jnp

print(jax.__version__)
# jax.lib.xla_bridge.get_backend() is deprecated; default_backend() returns the
# same platform name ("cpu", "gpu", ...).
print(jax.default_backend())

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.sharding import PartitionSpec as P

def get_sharding_details(sharded_data):
    """Print how a JAX array is laid out across devices.

    Draws JAX's sharding diagram (only supported for 1-D and 2-D arrays) and
    then, for every shard, prints the owning device, the shard's shape and the
    global index slices it covers.

    Args:
        sharded_data: a ``jax.Array``, e.g. the result of ``jax.device_put``
            with a sharding.

    Returns:
        None. Everything is printed to stdout.
    """
    print("\nSharding Layout:")

    # jax.debug.visualize_array_sharding raises for arrays that are not 1-D/2-D;
    # skip the diagram rather than aborting the whole report.
    if sharded_data.ndim in (1, 2):
        jax.debug.visualize_array_sharding(sharded_data)
    else:
        print(f"(diagram unavailable for a {sharded_data.ndim}-D array)")

    print("\nSharding Layout details:")

    # global_shards lists every shard, including those owned by other processes.
    for i, shard in enumerate(sharded_data.global_shards):
        print(f"\nShard no: {i:>5}")
        print(f"Device: {str(shard.device):>32}")
        # Shards that are not addressable from this process have data=None.
        shape = shard.data.shape if shard.data is not None else "n/a (not addressable)"
        print(f"Data shape: {str(shape):>8}")
        print(f"Data slices: {str(shard.index):>22}\n")
        print("=" * 75)
        print("")
        
# Build a 1-D device mesh over every visible device.
# mesh_utils.create_device_mesh lays out jax.devices() by default, so count that
# same set (jax.local_devices() differs from it in multi-process runs).
num_devices = len(jax.devices())
devices     = mesh_utils.create_device_mesh((num_devices,))

# Create a mesh from the device array. The trailing comma matters: ("ax") is just
# the string "ax", whereas Mesh expects a tuple of axis names.
mesh = Mesh(devices, axis_names=("ax",))

# Define sharding with a partition spec: the leading array axis is split across
# the "ax" mesh axis; the remaining axes are not partitioned.
sharding = NamedSharding(mesh, P("ax"))

print(f"Number of logical devices: {len(devices)}")
print(f"Shape of device array    : {devices.shape}")
print(f"\nMesh     : {mesh}")
print(f"Sharding : {sharding}\n\n")

In [None]:
import os, sys

# Make the CHIMERA package in the parent directory importable. The membership
# check keeps re-runs of this cell from appending the same entry repeatedly.
_chimera_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if _chimera_root not in sys.path:
    sys.path.append(_chimera_root)
from CHIMERA.data import load_data_gw, load_data_injection

# Data locations. The defaults are machine-specific absolute paths; set the
# environment variables below to run elsewhere without editing the notebook.
dir_data = os.environ.get("CHIMERA_DATA_DIR", "/home/mt/softwares/CHIMERA_private/data/")
file_inj = os.path.join(
    dir_data,
    "O5_v3/injections_20M_sources_PLP_v9s2_H1-L1-Virgo-KAGRA-LIGOI_IMRPhenomHM_snr_th-20_dutyfac-1_fmin-10_noiseless.h5",
)

file_ev = os.environ.get(
    "CHIMERA_EVENTS_FILE",
    "/home/mt/university/projects/gw_projects/full_spectral_sirens/catalog_generation/testSNR20/O5_snrthr20_MVN+MaskMassDLAngles+Resampling2bis_inv_detJ_m1m2_to_Mceta.h5",
)
# Alternative event catalogue:
# file_ev = os.path.join(dir_data, "O5_v3/O5_samples_from_fisher_allpars_snrth-25_ieth-0.05_DelOmTh-inf.h5")

inj_data, inj_prior = load_data_injection(file_inj)
events_pe = load_data_gw(file_ev)

# Prior weight of each PE sample, proportional to dL**2
# (presumably the PE was run with a uniform-in-volume dL prior — TODO confirm).
pe_priors = events_pe['dL']**2


In [None]:
# Luminosity-distance samples of the loaded event catalogue (real data, not a dummy).
data = events_pe['dL']

# NamedSharding splits the leading axis evenly across all devices, so its length
# must be divisible by the device count. Fail early with an actionable message
# instead of the opaque error jax.device_put would raise.
n_shards = len(sharding.device_set)
if data.shape[0] % n_shards != 0:
    raise ValueError(
        f"Leading axis of data has length {data.shape[0]}, which is not divisible "
        f"by the {n_shards} devices of the mesh; pad or trim the events first."
    )

# Shard the data
sharded_data = jax.device_put(data, sharding)

print(f"Data  shape: {data.shape}")
print(f"Shard shape: {sharding.shard_shape(data.shape)}")
get_sharding_details(sharded_data)

In [None]:
from CHIMERA.cosmo import flrw

# Grid of Hubble-constant values used to evaluate the redshifts
# (presumably in km/s/Mpc — confirm against CHIMERA.cosmo).
H0_MIN, H0_MAX, N_H0 = 40, 100, 25

cosmo_model = flrw()
cosmo_model = cosmo_model.update_params({'H0': jnp.linspace(H0_MIN, H0_MAX, N_H0)})
# Redshifts of the sharded dL samples over the H0 grid
# (NOTE(review): output shape depends on flrw.z_from_dGW — confirm).
z1 = cosmo_model.z_from_dGW(sharded_data)

In [None]:
# Divide both detector-frame component masses by (1 + z). Presumably this gives
# source-frame masses (m_src = m_det / (1 + z)) despite the "d" suffix of the
# resulting names — TODO confirm naming.
m1d, m2d = (events_pe[key] / (1 + z1) for key in ('m1det', 'm2det'))