In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
import numpy as np
import galsim
import batsim
import time
import matplotlib.pyplot as plt

In [None]:
# Lensing distortion: shear components (gamma1, gamma2) and convergence (kappa).
gamma1 = 0.00
gamma2 = -0.00
kappa = 0.0
# Reduced shear and lensing magnification derived from the values above.
g1 = gamma1 / (1 - kappa)
g2 = gamma2 / (1 - kappa)
mu = 1 / ((1 - kappa) ** 2 - gamma1**2 - gamma2**2)

# Path of the COSMOS training catalogue. It is only read when
# USE_COSMOS_GALAXY is True, so the notebook still runs on machines where
# this file is absent.
# NOTE(review): absolute, machine-specific path -- adjust before enabling the flag.
galsim_fname = "/work/xiangchong.li/work/COSMOS/galsim_train/COSMOS_25.2_training_sample/real_galaxy_catalog_25.2.fits"
gsparams = galsim.GSParams()

# Choose the galaxy profile: a parametric COSMOS galaxy (True) or an
# analytic Sersic profile (False).
USE_COSMOS_GALAXY = False
if USE_COSMOS_GALAXY:
    # Load the catalogue lazily: it is needed by this branch only.
    cosmos = galsim.COSMOSCatalog(galsim_fname)
    sersic = cosmos.makeGalaxy(
        10, gal_type="parametric", noise_pad_size=0, gsparams=gsparams
    )
else:
    cosmos = None  # catalogue not loaded in this configuration
    sersic = galsim.Sersic(n=1.0, half_light_radius=1.4, flux=20)
# GalSim's own lensed version of the profile, used as the reference.
sersic2 = sersic.lens(g1=g1, g2=g2, mu=mu)
psf = galsim.Moffat(beta=3.5, fwhm=0.8, flux=1.0)

# Pixel scale: the smaller of 1/8 of the galaxy Nyquist scale and 1/4 of the
# PSF Nyquist scale.
scale = min(sersic.nyquist_scale / 8, psf.nyquist_scale / 4.0)


def next_power_of_2(v):
    """Return the smallest power of two that is >= ``v`` (for positive ``v``).

    The exponent is found by rounding ``log2(v)`` up to the next integer;
    the result is converted to a Python ``int``.
    """
    exponent = np.ceil(np.log2(v))
    power = 2**exponent
    return int(power)


# Stamp size in pixels: 20 times the truncated ratio of the radius returned
# by calculateMomentRadius() to the pixel scale.
radius_in_pixels = int(sersic.calculateMomentRadius() / scale)
nn = radius_in_pixels * 20

# Reference image: GalSim rendering of the lensed profile, offset by half a
# pixel along both axes and drawn without pixel convolution.
half_pixel = 0.5 * scale
shifted_reference = sersic2.shift(half_pixel, half_pixel)
image = shifted_reference.drawImage(
    nx=nn, ny=nn, scale=scale, method="no_pixel"
).array

# Same galaxy rendered by batsim with the lensing transform (no PSF), timed.
Lens = batsim.LensTransform(gamma1=gamma1, gamma2=gamma2, kappa=kappa)

t0 = time.time()
galfluxes = batsim.simulate_galaxy(
    nn=nn,
    scale=scale,
    gal_obj=sersic,
    transform_obj=Lens,
)
t1 = time.time()
print("Time taken:", t1 - t0)

# Residual map between the batsim and GalSim renderings.
residual = galfluxes - image
plt.close()
plt.imshow(residual)
plt.colorbar()
print(scale)

In [None]:
# PSF-convolved comparison: batsim rendering versus GalSim's Convolve of the
# lensed profile with the same PSF.
Lens = batsim.LensTransform(gamma1=gamma1, gamma2=gamma2, kappa=kappa)
t0 = time.time()
image_conv = batsim.simulate_galaxy(
    nn=nn, scale=scale, gal_obj=sersic, transform_obj=Lens, psf_obj=psf
)
t1 = time.time()
print("Time taken:", t1 - t0)
image_conv_galsim = (
    galsim.Convolve([sersic2, psf])
    .shift(0.5 * scale, 0.5 * scale)
    .drawImage(nx=nn, ny=nn, scale=scale, method="no_pixel")
    .array
)

# Residual map between the two renderings.
plt.close()
plt.imshow(image_conv - image_conv_galsim)
plt.title("batsim - GalSim (PSF-convolved)")
plt.colorbar()

# Peak value and total flux of each rendering: batsim first, GalSim second.
# (Previously each line printed the same array twice, so nothing was compared.)
print(np.max(image_conv), np.max(image_conv_galsim))
print(np.sum(image_conv), np.sum(image_conv_galsim))
# Largest absolute pixel difference, normalised by the GalSim peak value.
print(np.max(np.abs(image_conv - image_conv_galsim) / image_conv_galsim.max()))

In [None]:
# Display the GalSim rendering of the PSF-convolved galaxy. The title
# distinguishes this figure from the batsim one.
plt.close()
plt.imshow(image_conv_galsim)
plt.title("GalSim: PSF-convolved image")
plt.colorbar();  # trailing ';' hides the Colorbar repr in the cell output

In [None]:
# Display the batsim rendering of the PSF-convolved galaxy. The title
# distinguishes this figure from the GalSim one.
plt.close()
plt.imshow(image_conv)
plt.title("batsim: PSF-convolved image")
plt.colorbar();  # trailing ';' hides the Colorbar repr in the cell output

In [None]:
# Spot check of a single Fourier mode: print the FFT of the drawn image next
# to the Fourier-space value reported by the GalSim profile, for manual
# comparison.

# Optional override with a small, strongly sheared test profile.
# NOTE: enabling this rebinds `sersic` and `image` defined by earlier cells.
USE_SHEARED_TEST_PROFILE = False
if USE_SHEARED_TEST_PROFILE:
    sersic = galsim.Sersic(n=1.0, half_light_radius=0.1, flux=20).shear(g1=0.2, g2=-0.5)
    image = (
        sersic.shift(0.5 * scale, 0.5 * scale)
        .drawImage(nx=nn, ny=nn, scale=scale, method="no_pixel")
        .array
    )

# A sample spacing of scale / (2*pi) makes the frequencies 2*pi*m / (nn*scale).
x = np.fft.rfftfreq(nn, scale / np.pi / 2.0)
y = np.fft.fftfreq(nn, scale / np.pi / 2.0)
# Build the (ky, kx) grids once and reuse them; axis 0 follows y, axis 1 follows x.
inds = np.meshgrid(y, x, indexing="ij")
kygrid, kxgrid = inds
# Flattened wavenumbers stacked as rows [kx, ky].
coords = np.vstack([np.ravel(grid) for grid in inds[::-1]])

# Move the image centre to the array origin before transforming.
image_rfft = np.fft.rfft2(np.fft.ifftshift(image))

# Mode to inspect (row index along ky, column index along kx).
ix = 103
iy = 2000
# The grid size depends on nn, so fail with a clear message when the
# hard-coded indices fall outside it.
n_ky, n_kx = image_rfft.shape
if not (0 <= iy < n_ky and 0 <= ix < n_kx):
    raise IndexError(
        f"mode (iy={iy}, ix={ix}) is outside the rfft2 grid of shape {image_rfft.shape}"
    )
print(image_rfft[iy, ix])
kx = kxgrid[iy, ix]
ky = kygrid[iy, ix]
print(kx, ky)
print(sersic.kValue(galsim.PositionD(kx, ky)))