In [1]:
import ztfimg
from ztfimg import catalog as catalog
import ztfin2p3
from ztfin2p3 import catalog
import pandas
import numpy as np
import jax
import optax
import jax.numpy as jnp
from jax.scipy import stats as jstats
from jax import vmap
import matplotlib.pyplot as plt
from astropy.modeling.models import Moffat2D

In [2]:
# pkg_resources is deprecated (setuptools); importlib.metadata is the stdlib replacement
from importlib.metadata import version
print(version('jax'))

0.4.26


In [3]:
# jax.lib.xla_bridge is deprecated; jax.default_backend() is the public API
# returning the platform name of the default backend (e.g. "cpu" or "gpu").
print(jax.default_backend())

gpu


In [4]:
# Shell command: show the GPU model, driver/CUDA versions and memory usage.
! nvidia-smi

  pid, fd = os.forkpty()


Tue Jul 16 14:42:54 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-PCIE-32GB           On  | 00000000:86:00.0 Off |                    0 |
| N/A   31C    P0              44W / 250W |    312MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Image selection

In [5]:
# importing image

liste_file = ["/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c13_o_q3_sciimg.fits", "/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c01_o_q1_sciimg.fits",
             "/sps/ztf/data/sci/2020/0924//278681/ztf_20200924278681_000682_zg_c01_o_q1_sciimg.fits", "/sps/ztf/data/sci/2020/0924/352269/ztf_20200924352269_000650_zr_c06_o_q2_sciimg.fits",
             "/sps/ztf/data/sci/2020/0924/509537/ztf_20200924509537_000700_zg_c03_o_q2_sciimg.fits", "/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c09_o_q1_sciimg.fits"]

img1 = ztfimg.ScienceQuadrant.from_filename(liste_file[0])

# NOTE(review): liste_file[0] is a q3 file but quadrant 1 of its CCD is selected here — confirm intended.
q1 = img1.get_ccd().get_quadrant(1)  # selects quadrant 1
qimg1 = q1.get_data()  # pixel array, indexed [row (y), column (x)] as in get_stamps

# importing data into a pandas.DataFrame
qimg1_catalog = ztfin2p3.catalog.get_img_refcatalog(q1, which="gaia_dr2")  # Gaia DR2 stars falling on the quadrant
# flag stars with no neighbour closer than seplimit=20 (unit set by get_isolated, presumably arcsec)
qimg1_catalog['isolated'] = ztfimg.catalog.get_isolated(qimg1_catalog, seplimit=20)
qimg_catalog_isolated = qimg1_catalog.loc[qimg1_catalog['isolated'] == True]  # keep isolated stars only
qimg_catalog_isolated = qimg_catalog_isolated.drop('isolated', axis=1)  # helper column no longer needed

# magnitude selection: keep 14 < G < 18
mag_inf = qimg_catalog_isolated.phot_g_mean_mag > 14
mag_sup = qimg_catalog_isolated.phot_g_mean_mag < 18
qimg_catalog_isolated_mag = qimg_catalog_isolated.loc[mag_inf & mag_sup]

# location selection: drop stars closer than EDGE_MARGIN pixels to the border so
# that the stamps cut later (17x17 by default) lie entirely inside the image.
# The array is indexed [row, column] = [y, x]: x is bounded by the number of
# columns (shape[1]) and y by the number of rows (shape[0]).
EDGE_MARGIN = 15
n_rows, n_cols = qimg1.shape
mag_bord_left = qimg_catalog_isolated_mag.x > EDGE_MARGIN
mag_bord_right = qimg_catalog_isolated_mag.x < (n_cols - EDGE_MARGIN)
mag_bord_top = qimg_catalog_isolated_mag.y > EDGE_MARGIN
mag_bord_bottom = qimg_catalog_isolated_mag.y < (n_rows - EDGE_MARGIN)
mag_bord_combined = np.logical_and.reduce((mag_bord_left, mag_bord_right, mag_bord_top, mag_bord_bottom))  # edge selection
qimg_catalog_isolated_mag_bord = qimg_catalog_isolated_mag[mag_bord_combined]  # application of edge mask
#qimg_catalog_isolated_mag_bord.to_csv('data_ztf.csv', index=False) #to save the dataframe in csv format
qimg_catalog_isolated_mag_bord

Unnamed: 0,id,coord_ra,coord_dec,phot_g_mean_flux,phot_bp_mean_flux,phot_rp_mean_flux,phot_g_mean_fluxErr,phot_bp_mean_fluxErr,phot_rp_mean_fluxErr,coord_raErr,...,ra,dec,phot_g_mean_mag,phot_bp_mean_mag,phot_rp_mean_mag,phot_g_mean_magErr,phot_bp_mean_magErr,phot_rp_mean_magErr,x,y
3292,220210914664501376,1.026161,0.637330,3.411760e+05,2.496884e+05,4.863944e+05,333.605022,4307.633356,1871.193552,0.002145,...,58.794687,36.516333,17.567620,17.906570,17.182594,0.001062,0.018731,0.004177,665.035175,1267.744563
3314,220210949024241920,1.025702,0.637414,2.846405e+05,1.552300e+05,4.988820e+05,263.706599,2006.851979,3996.709121,0.001880,...,58.768417,36.521120,17.764324,18.422627,17.155071,0.001006,0.014037,0.008698,589.417449,1282.097111
3343,220214655577103232,1.022990,0.637796,6.637216e+05,4.199027e+05,1.086072e+06,553.114736,1744.992099,2613.982357,0.001542,...,58.612995,36.543040,16.845101,17.342194,16.310420,0.000905,0.004512,0.002613,142.980478,1344.672348
3494,220216407923781760,1.023645,0.639737,3.310731e+05,2.083656e+05,5.224247e+05,297.967950,1487.596690,3983.500840,0.001617,...,58.650511,36.654212,17.600256,18.103001,17.105006,0.000977,0.007751,0.008279,236.507139,1743.317737
3524,220213216767321088,1.026577,0.638661,7.455134e+05,4.821637e+05,1.142259e+06,524.353589,3058.071864,5102.794374,0.001253,...,58.818547,36.592557,16.718927,17.192079,16.255654,0.000764,0.006886,0.004850,723.407453,1541.004010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
45776,220167793192842368,1.033080,0.634488,4.161678e+06,3.407272e+06,5.055197e+06,1242.641362,5762.443728,7624.821103,0.000454,...,59.191138,36.353509,14.851894,15.069049,14.640720,0.000324,0.001836,0.001638,1820.108284,731.855584
45799,220175077457571712,1.036926,0.636100,1.675470e+06,1.183528e+06,2.416099e+06,681.649828,2936.846731,3462.371170,0.000738,...,59.411462,36.445829,15.839724,16.217119,15.442279,0.000442,0.002694,0.001556,2436.769333,1085.493345
45891,220225586272711168,1.030969,0.639192,6.834436e+05,4.806962e+05,9.924018e+05,407.310306,1946.126256,4175.423717,0.001170,...,59.070190,36.622983,16.813309,17.195389,16.408347,0.000647,0.004396,0.004568,1437.085775,1675.756830
45903,220222596975471872,1.031456,0.639021,6.083751e+05,5.244016e+05,7.197826e+05,425.536839,2348.414058,4584.703300,0.001287,...,59.098059,36.613195,16.939637,17.100906,16.757062,0.000759,0.004862,0.006916,1517.911040,1644.046424


# Functions

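The Moffat profile used as the PSF model is:
$$
f(x, y) = A \left[1 + \frac{(x-x_0)^2 + (y-y_0)^2}{\gamma^2}\right]^{-\alpha}
$$
where $(x_0, y_0)$ is the centroid, $\gamma$ the core width and $\alpha$ the power-law index of the wings. For a profile normalized to unit integral (which requires $\alpha > 1$),
$$
A = \frac{\alpha - 1}{\pi \gamma^2}
$$
In the fit below, $A$ is instead a free amplitude for each star, and a constant background $b$ is added.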

In [6]:
def get_stamps(dataframe, size=17, image=None):
    """
    Cut square stamps centred on each star of a catalog.

    Parameters:
    ----------
    dataframe: pandas.DataFrame
        star catalog; must provide the "x" (column) and "y" (row) pixel positions
    size: int
        side of each (size, size) stamp, in pixels
    image: 2D array, optional
        image the stamps are cut from. Defaults to the global ``qimg1``.

    Returns:
    --------
    stamps: np.ndarray
        array of shape (len(dataframe), size, size). The star's rounded
        position sits at index (size // 2, size // 2) of its stamp.

    Raises:
    -------
    ValueError
        if a stamp would extend beyond the image boundaries
    """
    if image is None:
        image = qimg1  # backward-compatible default: the quadrant loaded above
    image = np.asarray(image)
    n_rows, n_cols = image.shape
    half = size // 2

    stamps = []
    # iterate the catalog actually passed in (not a global one)
    for x, y in zip(dataframe["x"].to_numpy(), dataframe["y"].to_numpy()):
        left = int(round(x)) - half
        top = int(round(y)) - half
        right = left + size
        bottom = top + size
        # a negative start would wrap around and an overflowing end would be
        # truncated: both silently give a wrong or ragged stamp
        if left < 0 or top < 0 or right > n_cols or bottom > n_rows:
            raise ValueError(
                f"stamp centred on (x={x}, y={y}) falls outside the {n_rows}x{n_cols} image"
            )
        stamps.append(image[top:bottom, left:right])

    if not stamps:
        return np.empty((0, size, size), dtype=image.dtype)
    return np.stack(stamps)


def moffat(x, y, x0, y0, A, alpha, gamma):
    """Evaluate A * (1 + r**2 / gamma**2) ** (-alpha), r being the distance to (x0, y0)."""
    dx = x - x0
    dy = y - y0
    scaled = (dx**2 + dy**2) / gamma**2
    return A * (1 + scaled)**(-alpha)


@jax.jit
def get_model(params, positions=None):
    """
    Evaluate the Moffat model plus a flat background for every star.

    Parameters:
    -----------
    params: sequence of 5 elements, unpacked as (mu, A, b, alpha, gamma)
        mu: (N, 2) centroids; column 0 is x0, column 1 is y0
        A: (N,) amplitudes
        b: (N,) backgrounds
        alpha: scalar Moffat exponent, shared by all stars
        gamma: scalar Moffat core width, shared by all stars
    positions: (M, 2) array, optional
        (x, y) coordinates where the model is evaluated. Defaults to the
        global ``pos`` grid.

    Returns:
    -------
    model : (N, M) array
        model value of star i at position j
    """
    mu, A, b, alpha, gamma = params
    grid = pos if positions is None else positions
    x0, y0 = mu[:, 0], mu[:, 1]

    # vectorize over stars (centroid, amplitude); grid, alpha and gamma are shared
    vectorized_moffat = vmap(moffat, in_axes=(None, None, 0, 0, 0, None, None))
    profiles = vectorized_moffat(grid[:, 0], grid[:, 1], x0, y0, A, alpha, gamma)
    return profiles + b[:, None]
    
@jax.jit
def get_likelihood(params):
    """
    Unweighted chi-square: sum of squared residuals between model and data.

    Parameters:
    ----------
    params: list
        model parameters (mu, A, b, alpha, gamma), as consumed by ``get_model``

    Returns:
    --------
    summ: float
        sum over all stars and pixels of (model - data)**2

    Notes:
    ------
    The observed stamps come from the global ``data`` (flattened stamps, one
    row per star), and the evaluation grid from the globals used by
    ``get_model``. Neither is an argument of this function. No per-pixel
    uncertainties are applied.
    """
    squared_residuals = (get_model(params) - data)**2
    return jnp.sum(squared_residuals)


@jax.jit
def get_logprior(params):
    """
    Log-prior of the shared Moffat shape parameters.

    alpha and gamma each receive an independent Gaussian prior with mean 1.0
    and standard deviation 0.5. Centroids, amplitudes and backgrounds carry
    no prior term.

    Parameters:
    -----------
    params: list
        (mu, A, b, alpha, gamma); only alpha and gamma are used

    Returns:
    -------
    logprior: float
        log N(alpha | 1, 0.5) + log N(gamma | 1, 0.5)
    """
    _mu, _amplitudes, _backgrounds, alpha, gamma = params
    shape_params = (alpha, gamma)
    log_alpha, log_gamma = (jstats.norm.logpdf(p, loc=1.0, scale=0.5) for p in shape_params)
    return log_alpha + log_gamma

@jax.jit
def get_logprob(params):
    """
    Objective minimized by the fits: chi-square minus the log-prior.

    Parameters:
    -----------
    params: list
        model parameters (mu, A, b, alpha, gamma)

    Returns:
    --------
    logprob: float
        ``get_likelihood(params) - get_logprior(params)``; lower is better.
        NOTE(review): despite its name this is a loss, not a log-probability.
    """
    neg_logprior = -get_logprior(params)  # negated so that a more probable prior lowers the loss
    return neg_logprior + get_likelihood(params)

In [7]:
def fit_tncg(func, init_param, 
             niter=10, tol=5e-3, 
             lmbda=1e2, 
             **kwargs):
    """ Hessian-free second order optimization algorithm

    The following implementation of TN-CG is largely based on
    recommendations given in Martens, James (2010, Deep learning via
    Hessian-free optimization, Proc. International  Conference on
    Machine Learning).

    Each iteration solves (H + lmbda * I) u = g with conjugate gradients
    (H: Hessian, g: gradient of func at x) and proposes x - u. The step is
    accepted only if it lowers func. lmbda is increased when the actual
    change is poor compared to the linear prediction -g.u (rho < 0.25,
    which includes rejected steps) and decreased when it is good
    (rho > 0.75).

    Parameters
    ----------
    func: function
        function to minimize. Should return a float.

    init_param: 
        entry parameter of the input func

    niter: int
        maximum number of iterations

    tol: float
        the iteration stops once an iteration changes func by an amount in
        (-tol, 0]. A rejected step (func increased) does not stop the
        iteration: it is retried with a larger lmbda. None disables early
        stopping.

    lmbda: float
        initial damping (lambda) of the tncg algorithm. (optstate)

    **kwargs other func entries 

    Returns
    -------
    list
        - best parameters (last accepted point)
        - loss (list): func value at the start of each iteration

    Example
    -------
    ```python
    import jax
    from edris import simulation, minimize
    key = jax.random.PRNGKey(1234)
    truth, simu = simulation.get_simple_simulation(key, size=1_000)

    def get_total_chi2(param, data):
        # model for a line with error on both axes but no intrinsic scatter.
        x_model = param["x_model"]
        y_model = x_model * param["a"] + param["b"]
    
        chi2_y = jnp.sum( ((data["x_obs"] - x_model)/data["x_err"])**2 )
        chi2_x = jnp.sum( ((data["y_obs"] - y_model)/data["y_err"])**2 )
    
        return chi2_y + chi2_x

    init_param = {"a": 8., "b":0., "x_model": simu["x_obs"]} # careful, must be float
    best_params, loss = minimize.fit_tncg(get_total_chi2, init_param, data=simu)
    ```
    
    """
    # handle kwargs more easily
    func_ = lambda x: func(x, **kwargs)
    fg = jax.value_and_grad(func_)
    grad_func = jax.grad(func_)

    # - internal function --- #
    def hessian_vector_product(g, x, v):
        # forward-over-reverse: derivative of the gradient g along v, i.e. H v
        return jax.jvp(g, (x,), (v,))[1]

    def step_tncg(x, optstate):
        loss, grads = fg(x)
        lmbda = optstate['lmbda']
        # damped Hessian-vector product (H + lmbda * I) v
        fvp = lambda v: jax.tree_util.tree_map(lambda d1, d2: d1 + lmbda * d2,
                                               hessian_vector_product(grad_func, x, v), v)
        updates, _ = jax.scipy.sparse.linalg.cg(fvp, grads, maxiter=50)
        # linear prediction of the change of func for the step x -> x - updates
        coco = jax.tree_util.tree_reduce(
            lambda a, b: a + b,
            jax.tree_util.tree_map(lambda g, u: (-g * u).sum(), grads, updates))
        return updates, loss, optstate, coco

    step_tncg = jax.jit(step_tncg)
    # ----------------------- #

    x = init_param
    optstate = {'lmbda': lmbda}
    losses = []

    for _ in range(niter):
        updates, loss, optstate, coco = step_tncg(x, optstate)
        x1 = jax.tree_util.tree_map(lambda a, b: a - b, x, updates)
        dloss = func_(x1) - loss
        losses.append(loss)
        rho = dloss / coco  # actual over predicted change

        if rho < 0.25:
            optstate['lmbda'] = optstate['lmbda'] * 1.5
        elif rho > 0.75:
            optstate['lmbda'] = optstate['lmbda'] * 0.3

        if dloss < 0:  # accept the step
            x = x1

        # Converged when the change is small and not an increase. A rejected
        # step (dloss > 0) must not stop the loop: the damping was just
        # increased and the next iteration retries with a shorter step.
        if tol is not None and -tol < dloss <= 0:
            break

    return x, losses


def fit_adam(func, init_params,
             learning_rate=5e-3, niter=200, 
             tol=1e-3,
             **kwargs):
    """ simple Adam gradient descent using optax.adam

    Parameters
    ----------
    func: function
        function to minimize. Should return a float.

    init_params: 
        entry parameter of the input func (pytree of float arrays)

    learning_rate: float
        learning rate of the gradient descent.
        (careful, results can be sensitive to this parameter)

    niter: int
        maximum number of iterations

    tol: float
        early stopping: from the 4th iteration on, the loop stops as soon as
        the loss decreases by less than tol between two consecutive
        iterations (an increase also stops it). None disables it.

    **kwargs other func entries 

    Returns
    -------
    list
        - parameters after the last update
        - loss (list): func evaluated after each update
    """
    # handle kwargs more easily
    func_ = lambda x: func(x, **kwargs)
    optimizer = optax.adam(learning_rate)

    @jax.jit
    def step(params, opt_state):
        # One full Adam iteration compiled as a single program: gradient,
        # optimizer update, and loss at the updated parameters (same meaning
        # of `losses` as before). Running optimizer.update / apply_updates
        # eagerly dispatches many small ops per iteration, which can dominate
        # the runtime over thousands of iterations.
        grads = jax.grad(func_)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, func_(params)

    params = init_params
    # Obtain the `opt_state` that contains statistics for the optimizer.
    opt_state = optimizer.init(params)

    losses = []
    for i in range(niter):
        params, opt_state, loss = step(params, opt_state)
        losses.append(loss)  # store the loss function
        if tol is not None and i > 2 and (losses[-2] - losses[-1]) < tol:
            break

    return params, losses

In [8]:
stamps = get_stamps(qimg_catalog_isolated_mag_bord)
# total flux of each stamp, used to rescale every stamp to unit sum
coefs = stamps.sum(axis=(1, 2))
stamps /= coefs[:, np.newaxis, np.newaxis]

In [9]:
nstars = len(stamps)
size = stamps.shape[-1]  # stamp side in pixels (17), taken from the stamps themselves

# Pixel offsets from the stamp centre: -(size//2), ..., one pixel apart, matching
# get_stamps (the star's rounded position is at index size//2).
# jnp.linspace(-size/2, size/2, size) would space samples size/(size-1) px apart
# (1.0625 px for 17), so the fitted centroids and gamma would not be in pixels.
offsets = jnp.arange(size, dtype="float32") - size // 2
X, Y = jnp.meshgrid(offsets, offsets)
# (size*size, 2) array of (x, y), row-major like the flattened stamps below
pos = jnp.vstack((X.ravel(), Y.ravel())).T

# initial guess: centred stars, unit amplitude, zero background, alpha = gamma = 1
x0 = jnp.zeros((nstars,), dtype="float32")
y0 = jnp.zeros((nstars,), dtype="float32")
mu = jnp.vstack([x0, y0]).T
A = jnp.ones((nstars,), dtype="float32")
b = jnp.zeros((nstars,), dtype="float32")
alpha = jnp.array(1., dtype="float32")
gamma = jnp.array(1., dtype="float32")

data = stamps.reshape(len(stamps), -1)  # (nstars, size*size), same pixel order as `pos`
grad_func = jax.jit(jax.grad(get_logprob))  # get the derivative

2024-07-16 14:42:57.172529: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [None]:
%%time

guess = [mu, A, b, alpha, gamma]
adam_params, adam_loss = fit_adam(
    get_logprob,
    guess,
    niter=30000,
    tol=1e-5,
    learning_rate=1e-4,
)

In [None]:
%%time

guess = [mu, A, b, alpha, gamma]
tncg_params, tncg_loss = fit_tncg(
    get_logprob,
    guess,
    lmbda=10000,
    niter=50,
    tol=1e-5,
)

In [None]:
#print(adam_params[0])

In [None]:
# Loss curves of both optimizers. Iterations are numbered from 1 so that the
# first point is not dropped by the logarithmic x axis (log(0) is undefined).
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(adam_loss) + 1), adam_loss, label='adam_loss')
ax.plot(np.arange(1, len(tncg_loss) + 1), tncg_loss, label='tncg_loss')
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("Iterations")
ax.set_ylabel("loss")
ax.set_title(f"Loss function (Moffat) for image : {liste_file[0]} \n Number of stars: {len(qimg_catalog_isolated_mag_bord)}", fontsize=8)
ax.legend();

In [None]:
def plots(models, stamps, n_show=None):
    """
    Compare observed and model stamps, one row per star.

    Each row has four panels: data, Moffat model, residuals (data - model)
    and a 3D view overlaying the data surface and the model contours.

    Parameters
    ----------
    models: array-like
        model of each star with size*size values per star, e.g. ``get_model(params)``
    stamps: array-like
        observed stamps with size*size values per star, same order as ``models``
    n_show: int, optional
        number of stars to display. Defaults to ``len(end)`` (the global
        catalog slice set before calling) for backward compatibility, and is
        clipped to the number of available stars.

    Notes
    -----
    Uses the globals ``size``, ``X`` and ``Y`` (stamp side and pixel grid).
    """
    if n_show is None:
        n_show = len(end)
    n_show = min(n_show, len(models), len(stamps))
    if n_show <= 0:
        return

    num_cols = 4
    # squeeze=False keeps a 2D axes array even when only one row is drawn
    fig, axes = plt.subplots(n_show, num_cols, figsize=(num_cols * 6, n_show * 6), squeeze=False)

    for i in range(n_show):
        data_img = np.asarray(stamps[i]).reshape(size, size)
        model_img = np.asarray(models[i]).reshape(size, size)

        # 2D panels: data, model, residuals
        panels = ((data_img, 'Real data'),
                  (model_img, 'Moffat model'),
                  (data_img - model_img, 'Residuals: real data - model'))
        for ax, (img, title) in zip(axes[i, :3], panels):
            im = ax.imshow(img)
            ax.set_title(title)
            plt.colorbar(im, ax=ax, location='left', orientation='vertical', shrink=0.75, pad=0.07)

        # 4th panel: replace the 2D axis by a 3D one in the same grid slot
        spec = axes[i, 3].get_subplotspec()
        fig.delaxes(axes[i, 3])
        ax3d = fig.add_subplot(spec, projection='3d')
        ax3d.plot_surface(X, Y, data_img, cmap='Reds', alpha=0.6, label='real data')
        ax3d.contour3D(X, Y, model_img, levels=50, cmap='Purples', alpha=0.6)
        ax3d.set_xlabel('x', fontdict=dict(weight='bold'))
        ax3d.set_ylabel('y', fontdict=dict(weight='bold'))
        ax3d.xaxis.set_tick_params(rotation=45)
        ax3d.yaxis.set_tick_params(rotation=-45)
        # contour projections of model and data onto the planes y = max(Y) and x = min(X)
        c_model = ax3d.contour(X, Y, model_img, 8, zdir='y', offset=np.max(Y), colors='blue', alpha=0.5)
        c_data = ax3d.contour(X, Y, data_img, 8, zdir='y', offset=np.max(Y), colors='red', alpha=0.5)
        ax3d.contour(X, Y, model_img, 8, zdir='x', offset=np.min(X), colors='blue', alpha=0.5)
        ax3d.contour(X, Y, data_img, 8, zdir='x', offset=np.min(X), colors='red', alpha=0.5)
        h_model, _ = c_model.legend_elements()
        h_data, _ = c_data.legend_elements()
        ax3d.legend([h_model[0], h_data[0]], ['Model', 'real data'], fontsize=13)

    plt.tight_layout()
    plt.show()

In [None]:
total_images = 5
end = qimg_catalog_isolated_mag_bord.head(total_images)  # catalog slice read by plots()
models = get_model(adam_params)
plots(models, stamps)
print("chi squared:", get_likelihood(adam_params))

In [None]:
total_images = 10
end = qimg_catalog_isolated_mag_bord.head(total_images)  # catalog slice read by plots()
models = get_model(tncg_params)
plots(models, stamps)
print("chi squared:", get_likelihood(tncg_params))