In [None]:
import jax
print('GPU connected')
#jax.config.update('jax_enable_x64', True)
import numpyro
#numpyro.set_platform('gpu')
#numpyro.set_host_device_count(4)
#numpyro.enable_x64()
from SLCOSMO_imports import *
from SLCOSMO import tool

jam_obj = jam_sph_proj.jam_sph_proj()

# The mass (EPL) and light (Sersic) MGE decompositions share these settings;
# they differ only in the profile function, its scale-radius keyword and the
# outer width multiplier.
_mge_shared = dict(n_gauss=10, n_terms=28, sigma_start_mult=1/100)
mge_epl = MGE(tool.EPL_msunmpc, 'thetaE', sigma_end_mult=50, **_mge_shared)
mge_sersic = MGE(tool.sersic_fn, 'effective_radius', sigma_end_mult=10, **_mge_shared)

N = 1000  # number of mock lenses

# 20 evenly spaced radii on [0, 0.725], shifted outward by one grid step so
# that the innermost sample is not at r = 0.
rad = jnp.linspace(0, 0.725, 20)
dr = rad[2] - rad[1]
rad = rad + dr
@jax.jit
def jaxpersion(kwargs_sersic, kwargs_epl, cosmology, zl, beta):
    """Return the area-weighted line-of-sight velocity dispersion for one parameter set.

    The Sersic light profile and the EPL mass profile are each expanded into
    Gaussians via the module-level MGE decomposers. The lens distance comes
    from ``tool.angular_diameter_distance`` at ``zl``. The spherical JAM model
    is then evaluated on the module-level ``rad`` grid, and the per-radius
    model is averaged with annulus weights ``2*pi*r*dr``.

    Args:
        kwargs_sersic: keyword arguments for ``mge_sersic.decompose``.
        kwargs_epl: keyword arguments for ``mge_epl.decompose``.
        cosmology: cosmology dict forwarded to ``tool.angular_diameter_distance``.
        zl: lens redshift.
        beta: velocity anisotropy, packed as ``[1, beta, beta, 1]`` for JAM.
            NOTE(review): with ``logistic=True`` this presumably encodes a
            constant-anisotropy logistic profile; confirm against jam_sph_proj.
    """
    lum_amp, lum_width = mge_sersic.decompose(**kwargs_sersic)
    pot_amp, pot_width = mge_epl.decompose(**kwargs_epl)
    d_lens = tool.angular_diameter_distance(zl, cosmology)[0]

    jam_inputs = dict(
        surf_lum=jnp.array(lum_amp),
        distance=d_lens,
        beta=jnp.array([1, beta, beta, 1]),
        sigma_lum=jnp.array(lum_width),
        surf_pot=jnp.array(pot_amp),
        sigma_pot=jnp.array(pot_width),
        mbh=0,
        rad=jnp.array(rad),
        logistic=True,
        tensor='los',
        quiet=True,
    )
    # Only the model profile is needed; chi2 and flux are discarded.
    model, _chi2, _flux = jam_obj.get_kinematics(**jam_inputs)

    weighted_total = jnp.sum(model*2*np.pi*rad*dr)
    weight_total = jnp.sum(2*np.pi*rad*dr)
    return weighted_total / weight_total

# vmap axis specifications: every leaf of every input pytree is batched
# along its leading axis (one entry per lens).
in_axes_sersic = dict.fromkeys(
    ('mass_to_light_ratio', 'intensity', 'effective_radius',
     'sersic_index', 'center_x', 'center_y'),
    0,
)
in_axes_cosmology = dict.fromkeys(('Omegam', 'Omegak', 'w0', 'wa', 'h0'), 0)
in_axes_epl = {
    **dict.fromkeys(('thetaE', 'gamma', 'zl', 'zs'), 0),
    'cosmology': in_axes_cosmology,
}

# Positional order matches jaxpersion(kwargs_sersic, kwargs_epl, cosmology, zl, beta).
batch_func = jax.vmap(
    jaxpersion,
    in_axes=(in_axes_sersic, in_axes_epl, in_axes_cosmology, 0, 0),
)
batch_func_jit = jax.jit(batch_func)

def make_mock_data(N, path='Euclid_len.txt', vel_err=10, rng=None):
    """Build a mock velocity-dispersion data set for the first ``N`` catalogue lenses.

    Columns 0, 1 and 2 of the whitespace-delimited text file ``path`` are read
    as lens redshift ``zl``, source redshift ``zs`` and Einstein radius
    ``theta_E``. Every lens gets the same fixed Sersic light parameters and a
    flat cosmology (Omegam=0.3, Omegak=0, w0=-1, wa=0, h0=70). The EPL slope
    ``gamma`` is drawn per lens from N(2, 0.1) and the anisotropy ``beta``
    from N(0, 0.1).

    Args:
        N: number of lenses (catalogue rows) to use.
        path: catalogue file to read (default ``'Euclid_len.txt'``).
        vel_err: standard deviation of the Gaussian noise added to each
            lens's model dispersion (default 10).
        rng: object with a NumPy-style ``normal(loc, scale, size)`` method,
            e.g. ``np.random.default_rng(seed)``, for reproducible mocks.
            Defaults to the global ``np.random`` state.

    Returns:
        Tuple ``(vel_obs, theta_E, zl, zs, kwargs_sersic_all, kwargs_epl_all,
        cosmology_true)``, where ``vel_obs`` holds one noisy dispersion per
        lens (shape ``(N,)``).

    Raises:
        ValueError: if the catalogue has fewer than ``N`` rows.
    """
    rng = np.random if rng is None else rng

    # ndmin=2 keeps column indexing valid even for a single-row catalogue.
    data = np.loadtxt(path, ndmin=2)
    if data.shape[0] < N:
        raise ValueError(
            f"{path} has only {data.shape[0]} rows, but N={N} lenses were requested"
        )
    zl = jnp.array(data[:N, 0])
    zs = jnp.array(data[:N, 1])
    theta_E = jnp.array(data[:N, 2])
    # NOTE(review): column 5 (presumably the effective radius) was previously
    # read but never used; the mock uses a fixed effective_radius of 1.0.

    kwargs_sersic_all = {'mass_to_light_ratio': jnp.full((N,), 1.0),
                         'intensity': jnp.full((N,), 1.0),
                         'sersic_index': jnp.full((N,), 4.0),
                         'center_x': jnp.zeros((N,)),
                         'center_y': jnp.zeros((N,)),
                         'effective_radius': jnp.full((N,), 1.0)}

    cosmology_true = {
        'Omegam': jnp.full((N,), 0.3),
        'Omegak': jnp.zeros((N,)),
        'w0':     jnp.full((N,), -1.0),
        'wa':     jnp.zeros((N,)),
        'h0':     jnp.full((N,), 70.0),
    }
    # Random draws keep their original order (gamma, beta, noise), so a seeded
    # global state reproduces the same per-lens parameters.
    kwargs_epl_all = {
        'thetaE': theta_E,
        'gamma':  rng.normal(2, 0.1, N),
        'zl': zl,
        'zs': zs,
        'cosmology': cosmology_true
    }
    beta_true = rng.normal(0, 0.1, N)
    vel_batch = batch_func_jit(
        kwargs_sersic_all,
        kwargs_epl_all,
        cosmology_true,
        zl,
        beta_true)

    # jaxpersion returns a 0-d dispersion per lens, so vel_batch already has
    # shape (N,). Averaging over axis -1 (as before) collapsed every lens to
    # one shared value; add per-lens noise to the per-lens prediction instead.
    vel_obs = vel_batch + rng.normal(0, vel_err, N)

    return vel_obs, theta_E, zl, zs, kwargs_sersic_all, kwargs_epl_all, cosmology_true


def lens_cosmology_model(vel_obs, kwargs_sersic_all, kwargs_epl_all, zl_array, vel_err, N):
    """NumPyro model jointly constraining cosmology and lens structure from dispersions.

    Global priors: Omegam ~ U(0.1, 0.6) and w0 ~ U(-2, 0), with Omegak = wa = 0
    and h0 = 70 held fixed. The population mean and scatter of the EPL slope
    ``gamma`` and the anisotropy ``beta`` are also sampled. Per-lens
    ``gamma_i`` / ``beta_i`` are drawn from truncated normals around those
    population values. Each lens's predicted dispersion is compared to
    ``vel_obs`` under Gaussian noise ``vel_err``.

    Args:
        vel_obs: observed dispersions, one per lens (length ``N``).
        kwargs_sersic_all: batched Sersic keyword arguments (leading axis ``N``).
        kwargs_epl_all: batched EPL keyword arguments (leading axis ``N``).
            Its ``'gamma'`` and ``'cosmology'`` entries are replaced by the
            sampled values. The dict itself is NOT modified.
        zl_array: lens redshifts (length ``N``).
        vel_err: observational dispersion uncertainty (scalar or length ``N``).
        N: number of lenses.
    """
    # Cosmology
    Omegam = numpyro.sample("Omegam", dist.Uniform(0.10, 0.6))
    w0 = numpyro.sample("w0", dist.Uniform(-2, 0))

    cosmology_test = {
        'Omegam': jnp.full((N,), Omegam),
        'Omegak': jnp.zeros((N,)),
        'w0':     jnp.full((N,), w0),
        'wa':     jnp.zeros((N,)),
        'h0':     jnp.full((N,), 70.0),
    }

    # Population-level density slope and anisotropy
    gamma_mean = numpyro.sample("gamma", dist.Uniform(1.8, 2.2))
    gamma_sigma = numpyro.sample('gamma_sig', dist.TruncatedNormal(0.16, 1.0, low=0.0, high=0.4))
    anisotropy_mean = numpyro.sample("anisotropy", dist.Uniform(-0.2, 0.2))
    anisotropy_sigma = numpyro.sample('anisotropy_sig', dist.TruncatedNormal(0.13, 1.0, low=0.0, high=0.4))

    # Per-lens slope and anisotropy
    with numpyro.plate("lens_kin_data", N):
        y_i = numpyro.sample("gamma_i", dist.TruncatedNormal(gamma_mean, gamma_sigma, low=1.5, high=2.5))
        anisotropy_i = numpyro.sample("beta_i", dist.TruncatedNormal(anisotropy_mean, anisotropy_sigma, low=-0.4, high=0.4))

    # Build a fresh dict instead of writing into the caller's kwargs_epl_all.
    # Mutating it would leave sampled (possibly traced) values in the user's
    # dict after inference.
    kwargs_epl = {**kwargs_epl_all, 'cosmology': cosmology_test, 'gamma': y_i}
    pre_vel = batch_func_jit(kwargs_sersic_all, kwargs_epl, cosmology_test, zl_array, anisotropy_i)

    with numpyro.plate("N_data", N):
        numpyro.sample("obs", dist.Normal(pre_vel, vel_err), obs=vel_obs)




# rng_key = random.PRNGKey(0)
# vel_obs, theta_E, zl, zs, kwargs_sersic_all, kwargs_epl_all,cosmology_true = make_mock_data(N)
# nuts_kernel = NUTS(lens_cosmology_model, max_tree_depth = 8, target_accept_prob = 0.8)
# mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1500, num_chains=4, chain_method='vectorized')

# mcmc.run(rng_key, vel_obs   = vel_obs,
#             kwargs_sersic_all = kwargs_sersic_all,
#             kwargs_epl_all    = kwargs_epl_all,
#             zl_array          = zl,
#             vel_err = 10,
#             N=N)

# posterior_samples = mcmc.get_samples()
# import arviz as az
# inf_data = az.from_numpyro(mcmc)
# inf_data.to_netcdf("posterior_samples_notebook.nc")

In [None]:
import jax
print('GPU connected')
jax.config.update('jax_enable_x64', True)
import numpyro
numpyro.set_platform('gpu')
#numpyro.set_platform('cpu')
#numpyro.set_host_device_count(4)
#numpyro.enable_x64()
from SLCOSMO_imports import *
from SLCOSMO import tool
from tqdm.auto import trange  # 或者 from tqdm import trange


jam_obj = jam_sph_proj.jam_sph_proj()

# MGE expansions of the mass (EPL) and light (Sersic) profiles. Apart from
# the profile, its scale keyword and the outer width factor, they use
# identical settings.
_mge_opts = {'n_gauss': 10, 'n_terms': 28, 'sigma_start_mult': 1/100}
mge_epl = MGE(tool.EPL_msunmpc, 'thetaE', sigma_end_mult=50, **_mge_opts)
mge_sersic = MGE(tool.sersic_fn, 'effective_radius', sigma_end_mult=10, **_mge_opts)

# Radial sampling grid for JAM: 20 equally spaced points on [0, 0.725],
# moved outward by one step so the innermost point is not r = 0.
_rad_raw = jnp.linspace(0, 0.725, 20)
dr = _rad_raw[2] - _rad_raw[1]
rad = _rad_raw + dr

@jax.jit
def jaxpersion(kwargs_sersic, kwargs_epl, cosmology, zl, beta):
    """Compute the annulus-weighted mean JAM line-of-sight dispersion for one parameter set.

    Both profiles are Gaussian-expanded with the module-level MGE objects.
    The model is evaluated at the module-level radii ``rad`` and averaged
    with weights ``2*pi*r*dr``. ``beta`` enters JAM as ``[1, beta, beta, 1]``
    with ``logistic=True`` (NOTE(review): presumably a constant anisotropy;
    confirm against jam_sph_proj).
    """
    def _as_arrays(decomposer, params):
        # Run one MGE decomposition and return (amplitudes, widths) as arrays.
        amplitudes, widths = decomposer.decompose(**params)
        return jnp.array(amplitudes), jnp.array(widths)

    surf_lum, sigma_lum = _as_arrays(mge_sersic, kwargs_sersic)
    surf_pot, sigma_pot = _as_arrays(mge_epl, kwargs_epl)
    distance = tool.angular_diameter_distance(zl, cosmology)[0]
    anisotropy = jnp.array([1, beta, beta, 1])

    model, _, _ = jam_obj.get_kinematics(
        surf_lum=surf_lum, distance=distance, beta=anisotropy,
        sigma_lum=sigma_lum, surf_pot=surf_pot, sigma_pot=sigma_pot,
        mbh=0, rad=jnp.array(rad),
        logistic=True, tensor='los', quiet=True,
    )

    numerator = jnp.sum(model * 2 * jnp.pi * rad * dr)
    denominator = jnp.sum(2 * jnp.pi * rad * dr)
    return numerator / denominator

# ----------------------------
# Discrete parameter grids (adjust the resolution as needed)
# ----------------------------
# Number of grid points for thetaE, gamma, effective_radius and beta.
N1 = N2 = N3 = N4 = 50

thetaE_grid = jnp.linspace(0.5, 3.0, N1)
gamma_grid = jnp.linspace(1.2, 2.8, N2)
Re_grid    = jnp.linspace(0.15, 3, N3)
beta_grid  = jnp.linspace(-0.5, 0.8, N4)

# ----------------------------
# Other fixed parameters
# ----------------------------
zl = 0.5
zs = 1.0
cosmology = {'Omegam': 0.3, 'Omegak': 0.0, 'w0': -1.0, 'wa': 0.0, 'h0': 70.0}

dl, ds, dls = tool.dldsdls(zl, zs, cosmology, n=20)
ds_over_dls_sqrt = jnp.sqrt(ds / dls)
# The table stores sigma / sqrt(Ds/Dls). As in the earlier per-element code
# (which indexed ``[0]`` after dividing), only the first element of the
# distance ratio is used.
# NOTE(review): assumes dldsdls returns 1-D arrays here; confirm.
ds_over_dls_scale = jnp.ravel(ds_over_dls_sqrt)[0]

# ----------------------------
# Batched dispersion model, built here so this cell does not depend on
# ``batch_func_jit`` from a previous cell. Every leaf is mapped over axis 0.
# ----------------------------
_lut_in_axes_sersic = {'mass_to_light_ratio': 0, 'intensity': 0, 'effective_radius': 0,
                       'sersic_index': 0, 'center_x': 0, 'center_y': 0}
_lut_in_axes_cosmology = {'Omegam': 0, 'Omegak': 0, 'w0': 0, 'wa': 0, 'h0': 0}
_lut_in_axes_epl = {'thetaE': 0, 'gamma': 0, 'zl': 0, 'zs': 0,
                    'cosmology': _lut_in_axes_cosmology}
_lut_batch_fn = jax.jit(jax.vmap(
    jaxpersion,
    in_axes=(_lut_in_axes_sersic, _lut_in_axes_epl, _lut_in_axes_cosmology, 0, 0),
))

# ----------------------------
# Loop-invariant inputs. The beta axis (length N4) is the vmapped batch
# axis, so every per-call input is a length-N4 vector.
# ----------------------------
N = N4
kwargs_sersic_base = {'mass_to_light_ratio': jnp.full((N,), 1.0),
                      'intensity': jnp.full((N,), 1.0),
                      'sersic_index': jnp.full((N,), 4.0),
                      'center_x': jnp.zeros((N,)),
                      'center_y': jnp.zeros((N,)),
                      'effective_radius': jnp.full((N,), 1.0)}
cosmology_true = {
    'Omegam': jnp.full((N,), 0.3),
    'Omegak': jnp.zeros((N,)),
    'w0':     jnp.full((N,), -1.0),
    'wa':     jnp.zeros((N,)),
    'h0':     jnp.full((N,), 70.0),
}
zl_batch = jnp.full((N,), zl)
zs_batch = jnp.full((N,), zs)

# Fill a host-side NumPy buffer. An eager ``LUT.at[...].set(...)`` copies the
# entire 4-D array on every call, which made the per-element fill quadratic.
LUT_host = np.zeros((N1, N2, N3, N4))

for i in trange(N1, desc="thetaE dimension"):
    thetaE_batch = jnp.full((N,), thetaE_grid[i])
    for j in range(N2):
        kwargs_epl = {
            'thetaE': thetaE_batch,
            'gamma': jnp.full((N,), gamma_grid[j]),
            'zl': zl_batch,
            'zs': zs_batch,
            'cosmology': cosmology_true,
        }
        for k in range(N3):
            kwargs_sersic = dict(kwargs_sersic_base)
            kwargs_sersic['effective_radius'] = jnp.full((N,), Re_grid[k])

            # One call evaluates all N4 beta values for this (thetaE, gamma, Re).
            vel_disp = _lut_batch_fn(
                kwargs_sersic,
                kwargs_epl,
                cosmology_true,
                zl_batch,
                beta_grid)
            LUT_host[i, j, k, :] = np.asarray(vel_disp / ds_over_dls_scale)

# LUT is the 4-D lookup table:
# LUT[i, j, k, l] = sigma(thetaE_grid[i], gamma_grid[j], Re_grid[k], beta_grid[l]) / sqrt(Ds/Dls)
LUT = jnp.asarray(LUT_host)

# Save the table so it can be loaded later (e.g. inside a NumPyro model).
# Convert explicitly to a NumPy array before writing.
np.save('velocity_disp_table.npy', np.asarray(LUT))
