In [1]:
import ztfimg
from ztfimg import catalog as catalog
import ztfin2p3
from ztfin2p3 import catalog
import pandas
import numpy as np
import jax
import optax
import jax.numpy as jnp
from jax.scipy import stats as jstats
from jax import vmap
import matplotlib.pyplot as plt
from astropy.modeling.models import Moffat2D
import jax.profiler as jprof
import os
from IPython.display import display, HTML
from IPython import get_ipython
import time

# Directory that will receive the JAX profiler trace files
trace_dir = '/pbs/home/s/svoisin/stage_m2/tmp/jax_moffat'
os.makedirs(trace_dir, exist_ok=True)

try:
    # Start the profiler trace; it is stopped in the `finally` clause of this cell
    jprof.start_trace(trace_dir, create_perfetto_link=False)
    print(f"Profilage démarré dans {trace_dir}")


    # Science images considered for this study; only the first entry is opened below.
    liste_file = [
        "/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c13_o_q3_sciimg.fits",
        "/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c01_o_q1_sciimg.fits",
        "/sps/ztf/data/sci/2020/0924//278681/ztf_20200924278681_000682_zg_c01_o_q1_sciimg.fits",
        "/sps/ztf/data/sci/2020/0924/352269/ztf_20200924352269_000650_zr_c06_o_q2_sciimg.fits",
        "/sps/ztf/data/sci/2020/0924/509537/ztf_20200924509537_000700_zg_c03_o_q2_sciimg.fits",
        "/sps/ztf/data/sci/2020/0924/431759/ztf_20200924431759_000655_zr_c09_o_q1_sciimg.fits",
    ]

    img1 = ztfimg.ScienceQuadrant.from_filename(liste_file[0])

    # Go through the parent CCD object and pick its quadrant number 1.
    # NOTE(review): the opened file name ends in `q3` while quadrant 1 is
    # requested here -- confirm that working on a different quadrant is intended.
    q1 = img1.get_ccd().get_quadrant(1)
    # Pixel values of the quadrant; indexed as qimg1[y, x] by get_stamps.
    qimg1 = q1.get_data()
    
    # Reference catalogue (Gaia DR2) for the sources associated with quadrant q1
    qimg1_catalog = ztfin2p3.catalog.get_img_refcatalog(q1, which="gaia_dr2")

    # Flag isolated stars in a Boolean column (separation limit set to 20;
    # units presumably arcsec -- confirm against ztfimg.catalog.get_isolated)
    qimg1_catalog['isolated'] = ztfimg.catalog.get_isolated(qimg1_catalog, seplimit=20)

    # Keep only the rows flagged as isolated, then drop the helper column
    qimg_catalog_isolated = (
        qimg1_catalog
        .loc[qimg1_catalog['isolated'] == True]
        .drop('isolated', axis=1)
    )

    # Magnitude selection: strictly between 14 and 18 in phot_g_mean_mag
    mag_inf = qimg_catalog_isolated["phot_g_mean_mag"] > 14
    mag_sup = qimg_catalog_isolated["phot_g_mean_mag"] < 18
    qimg_catalog_isolated_mag = qimg_catalog_isolated.loc[mag_inf & mag_sup]
    
    # Location selection: discard stars closer than 15 pixels to any image border.
    # The pixel array is indexed as qimg1[y, x] (rows first, see get_stamps), so the
    # x coordinate must be compared with the number of columns (axis 1) and the
    # y coordinate with the number of rows (axis 0). Comparing x with axis 0 lets
    # stars sit too close to the right border whenever the image is not square,
    # which yields truncated stamps.
    mag_bord_left = qimg_catalog_isolated_mag.x > 15  # far enough from the x = 0 border
    mag_bord_right = qimg_catalog_isolated_mag.x < (qimg1.shape[1] - 15)  # far enough from the last column
    mag_bord_top = qimg_catalog_isolated_mag.y > 15  # far enough from the y = 0 border
    mag_bord_bottom = qimg_catalog_isolated_mag.y < (qimg1.shape[0] - 15)  # far enough from the last row
    # A star is kept only if it passes the four border tests
    mag_bord_combined = np.logical_and.reduce((mag_bord_left, mag_bord_right, mag_bord_top, mag_bord_bottom))
    qimg_catalog_isolated_mag_bord = qimg_catalog_isolated_mag[mag_bord_combined]
    
    def get_stamps(dataframe, size=17, image=None):
        """
        Cut a square stamp out of the image around every star of `dataframe`.

        Parameters:
        ----------
        dataframe: pandas.DataFrame
            star catalogue; only its "x" and "y" columns (pixel coordinates)
            are read.
        size: int
            side of the square stamps, in pixels (default 17 -> 17x17 stamps).
        image: 2D array or None
            image indexed as image[y, x]. When None (default), the
            notebook-level `qimg1` array is used.

        Returns:
        --------
        stamps: np.ndarray
            array of shape (len(dataframe), size, size); an empty
            (0, size, size) array when `dataframe` has no row.

        Raises:
        -------
        ValueError
            if a stamp would extend beyond the image borders (it would
            otherwise be silently truncated or wrap around).
        """
        if image is None:
            image = qimg1  # backward-compatible default: the quadrant loaded by the notebook

        n_rows, n_cols = image.shape[0], image.shape[1]
        half = size // 2

        stamps = []
        # Iterate over the catalogue that was actually passed in.
        for x_star, y_star in zip(dataframe["x"], dataframe["y"]):
            x0 = int(round(x_star))
            y0 = int(round(y_star))
            left = x0 - half
            top = y0 - half
            right = left + size
            bottom = top + size
            if left < 0 or top < 0 or right > n_cols or bottom > n_rows:
                raise ValueError(
                    f"stamp of size {size} centred on (x={x0}, y={y0}) does not fit "
                    f"in an image of shape ({n_rows}, {n_cols})"
                )
            stamps.append(image[top:bottom, left:right])

        if not stamps:
            # keep a well-defined 3D shape so that downstream reductions still work
            return np.empty((0, size, size), dtype=image.dtype)
        return np.asarray(stamps)
    
    
    def moffat(x, y, x0, y0, A, alpha, gamma):
        """
        Circular Moffat profile evaluated at (x, y).

        Parameters:
        -----------
        x, y: positions where the profile is evaluated
        x0, y0: centre of the profile
        A: amplitude (value of the profile at the centre)
        alpha: exponent of the profile
        gamma: width of the profile

        Returns:
        --------
        A * (1 + r^2 / gamma^2)^(-alpha), with r the distance to (x0, y0)
        """
        dist2 = (x - x0)**2 + (y - y0)**2
        base = 1 + (dist2 / gamma**2)
        return A * base**(-alpha)
    
    
    @jax.jit
    def get_model(params):
        """
        Evaluate one Moffat profile plus a constant background per star.

        The evaluation positions are read from the `pos` array of the
        enclosing scope (column 0: x, column 1: y); `pos` is not an argument.

        Parameters:
        -----------
        params: pytree
            [mu: centroids, one row per star (column 0: x0, column 1: y0)
             A: per-star amplitudes
             b: per-star backgrounds
             alpha: Moffat exponent, shared by all stars
             gamma: Moffat width, shared by all stars]

        Returns:
        -------
        model : arraylike
            one row per star, one column per row of `pos`
        """
        centroids, amplitudes, backgrounds, alpha, gamma = params
        grid_x, grid_y = pos[:, 0], pos[:, 1]

        def one_star(cx, cy, amp):
            # the grid and the shape parameters are shared; only the centroid
            # and the amplitude change from one star to the next
            return moffat(grid_x, grid_y, cx, cy, amp, alpha, gamma)

        profiles = vmap(one_star)(centroids[:, 0], centroids[:, 1], amplitudes)
        # add each star's background to every pixel of its row
        return profiles + backgrounds[:, None]
        
    @jax.jit
    def get_likelihood(params):
        """
        Sum of squared residuals between the model and the data.

        The observed values are read from the `data` array of the enclosing
        scope; it is not an argument of this function.
        NOTE(review): since the function is wrapped in jax.jit, `data` is
        presumably captured when the function is first traced -- confirm
        before re-using this function after `data` has been changed.

        Parameters:
        ----------
        params: list
            model parameters, forwarded unchanged to get_model

        Returns:
        --------
        summ: float
            sum over all stars and pixels of (model - data)**2
            (no weighting by a variance is applied)
        """
        residuals = get_model(params) - data
        return jnp.sum(residuals**2)
    
    
    @jax.jit
    def get_logprior(params):
        """
        Log-prior on the two shape parameters alpha and gamma.

        Both parameters get the same independent Gaussian prior
        (mean 1.0, standard deviation 0.5).

        Parameters:
        -----------
        params: list
            [mu, A, b, alpha, gamma]; only alpha and gamma are used

        Returns:
        -------
        logprior: float
            sum of the two Gaussian log-densities
        """
        _mu, _amp, _bkg, alpha, gamma = params

        def gaussian_logpdf(value):
            return jstats.norm.logpdf(value, loc=1.0, scale=0.5)

        return gaussian_logpdf(alpha) + gaussian_logpdf(gamma)
    
    @jax.jit
    def get_logprob(params):
        """
        Objective to minimise: minus the log-prior plus the sum of squared residuals.

        Parameters:
        -----------
        params: list
            model parameters, forwarded to get_logprior and get_likelihood

        Returns:
        --------
        logprob: float
            -get_logprior(params) + get_likelihood(params).
            Despite its name this is a quantity to be minimised.
        """
        neg_logprior = -1 * get_logprior(params)  # sign flipped because the total is minimised
        return neg_logprior + get_likelihood(params)
    
    def fit_tncg(func, init_param,
                 niter=10, tol=5e-3,
                 lmbda=1e2,
                 **kwargs):
        """ Hessian-free second order optimization algorithm (truncated Newton / conjugate gradient)

        The implementation is largely based on the recommendations given in
        Martens, James (2010, Deep learning via Hessian-free optimization,
        Proc. International Conference on Machine Learning).

        At each iteration the damped Newton system (H + lmbda * I) u = grad is
        solved with at most 50 conjugate-gradient iterations, using
        Hessian-vector products only. The candidate `x - u` is accepted when
        it lowers the function, and `lmbda` is multiplied by 1.5 (ratio
        below 0.25) or by 0.3 (ratio above 0.75) depending on the ratio
        between the observed variation and `coco` (minus the dot product of
        the gradient and the update).

        Parameters
        ----------
        func: function
            function to minimize. Should return a float.

        init_param:
            entry parameter of the input func (pytree)

        niter: int
            maximum number of iterations

        tol: float or None
            the iterations stop as soon as the variation of func is larger
            than -tol; this includes a step that increases func (such a
            step is not accepted). None disables this test.

        lmbda: float
            initial value of the damping parameter lambda

        **kwargs other func entries

        Returns
        -------
        list
            - best parameters
            - list of the func values measured at the start of each iteration

        Example
        -------
        ```python
        import jax
        from edris import simulation, minimize
        key = jax.random.PRNGKey(1234)
        truth, simu = simulation.get_simple_simulation(key, size=1_000)

        def get_total_chi2(param, data):
            # model for a line with error on both axes but no intrinsic scatter.
            x_model = param["x_model"]
            y_model = x_model * param["a"] + param["b"]

            chi2_y = jnp.sum( ((data["x_obs"] - x_model)/data["x_err"])**2 )
            chi2_x = jnp.sum( ((data["y_obs"] - y_model)/data["y_err"])**2 )

            return chi2_y + chi2_x

        init_param = {"a": 8., "b":0., "x_model": simu["x_obs"]} # careful, must be float
        best_params, loss = minimize.fit_tncg(get_total_chi2, init_param, data=simu)
        ```

        """
        # bind the extra keyword arguments once and for all
        def func_(x):
            return func(x, **kwargs)

        value_and_grad = jax.value_and_grad(func_)
        grad_fn = jax.grad(func_)

        # - internal functions --- #
        def damped_hvp(x, v, damping):
            # Hessian-vector product obtained by differentiating the gradient
            # along v, to which damping * v is added
            hess_v = jax.jvp(grad_fn, (x,), (v,))[1]
            return jax.tree_util.tree_map(lambda h, d: h + damping * d, hess_v, v)

        def step_tncg(x, optstate):
            loss, grads = value_and_grad(x)
            damping = optstate['lmbda']
            updates, _ = jax.scipy.sparse.linalg.cg(
                lambda v: damped_hvp(x, v, damping), grads, maxiter=50
            )
            # minus the dot product between gradient and update, summed over all leaves
            per_leaf = jax.tree_util.tree_map(lambda g, u: (-g * u).sum(), grads, updates)
            coco = jax.tree_util.tree_reduce(lambda a, b: a + b, per_leaf)
            return updates, loss, optstate, coco

        step_tncg = jax.jit(step_tncg)
        # ------------------------ #

        x = init_param
        optstate = {'lmbda': lmbda}
        losses = []

        for _ in range(niter):
            updates, loss, optstate, coco = step_tncg(x, optstate)
            candidate = jax.tree_util.tree_map(lambda p, u: p - u, x, updates)
            dloss = func_(candidate) - loss
            losses.append(loss)
            rho = dloss / coco

            # adapt the damping to how well the prediction matched the variation
            if rho < 0.25:
                optstate['lmbda'] = optstate['lmbda'] * 1.5
            elif rho > 0.75:
                optstate['lmbda'] = optstate['lmbda'] * 0.3

            if dloss < 0:  # accept the step
                x = candidate

            if tol is not None and dloss > -tol:
                break

        return x, losses
    
    
    def fit_adam(func, init_params,
                 learning_rate=5e-3, niter=200,
                 tol=1e-3,
                 **kwargs):
        """ simple Adam gradient descent using optax.adam

        Parameters
        ----------
        func: function
            function to minimize. Should return a float.

        init_params:
            entry parameter of the input func (starting point of the descent)

        learning_rate: float
            learning rate of the gradient descent.
            (careful, results can be sensitive to this parameter)

        niter: int
            maximum number of iterations

        tol: float or None
            after the first three iterations, the descent stops as soon as
            the decrease of func between two consecutive iterations is
            smaller than tol (an increase of func also stops it).
            None disables this test.

        **kwargs other func entries

        Returns
        -------
        list
            - best parameters
            - list of the func values measured after each update
        """
        # bind the extra keyword arguments once and for all
        def func_(x):
            return func(x, **kwargs)

        # optimizer and its internal statistics
        optimizer = optax.adam(learning_rate)
        params = init_params
        opt_state = optimizer.init(params)

        grad_func = jax.jit(jax.grad(func_))  # compiled derivative of the function

        losses = []
        for step in range(niter):
            updates, opt_state = optimizer.update(grad_func(params), opt_state)
            params = optax.apply_updates(params, updates)
            losses.append(func_(params))  # value of the function after the update

            if tol is None or step <= 2:
                continue
            if (losses[-2] - losses[-1]) < tol:
                break

        return params, losses
    
    # Stamp extraction; `size` is defined once and shared with the model grid
    size = 17
    stamps = get_stamps(qimg_catalog_isolated_mag_bord, size=size)

    # Normalise every stamp by its total flux (out of place, so that the
    # division neither depends on the stamp dtype nor mutates the
    # array returned by get_stamps)
    coefs = np.sum(stamps, axis=(1,2))
    stamps = stamps / coefs[:, None, None]

    nstars = len(stamps)

    # Pixel grid of a stamp, expressed as offsets from its central pixel.
    # get_stamps cuts the pixels [c - size//2, c - size//2 + size), so pixel i of
    # a stamp lies at offset i - size//2 with a step of exactly one pixel.
    # (jnp.linspace(-size/2, size/2, size) has a step of size/(size-1) pixel and
    # would rescale every fitted length.)
    X = jnp.arange(size, dtype="float32") - size // 2
    Y = jnp.arange(size, dtype="float32") - size // 2
    X, Y = jnp.meshgrid(X, Y)
    pos = jnp.vstack((X.ravel(), Y.ravel())).T

    # guess
    x0 = jnp.zeros((nstars,), dtype="float32")
    y0 = jnp.zeros((nstars,), dtype="float32")
    mu = jnp.vstack([x0,y0]).T
    A = jnp.ones((nstars,), dtype="float32")
    b = jnp.zeros((nstars,), dtype="float32")
    alpha = jnp.array(1., dtype="float32")
    gamma = jnp.array(1., dtype="float32")

    # one flattened stamp per row, matching the layout returned by get_model
    data = stamps.reshape(len(stamps), -1)
    grad_func = jax.jit(jax.grad(get_logprob)) # derivative of the objective (not used by the fits below)

    guess = [mu, A, b, alpha, gamma]

    # NOTE(review): the recorded output of this cell reports a profiler trace of
    # 4567007552 bytes, above the 2GB protobuf limit, with these iteration
    # counts -- consider profiling a much shorter run.
    adam_params, adam_loss = fit_adam(get_logprob, guess, learning_rate=1e-4, tol=1e-5, niter=30000)
    print("Exécution de fit_adam terminée.")
    tncg_params, tncg_loss = fit_tncg(get_logprob, guess, tol=1e-5, niter=50, lmbda=10000)
    print("Exécution de fit_tncg terminée.")

finally:
    try:
        # Stop the trace; whatever jprof.stop_trace() returns is kept in trace_file
        # NOTE(review): confirm that stop_trace() actually returns the trace path
        trace_file = jprof.stop_trace()
        print("Profilage arrêté.")

        # Short pause, intended to let the pending messages get displayed
        time.sleep(1)

    except Exception as e:
        print(f"Erreur en arrêtant la trace: {e}")




Profilage démarré dans /pbs/home/s/svoisin/stage_m2/tmp/jax_moffat


2024-07-23 10:20:33.997413: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Exécution de fit_adam terminée.
Exécution de fit_tncg terminée.


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/message_lite.cc:449] tensorflow.profiler.XSpace exceeded maximum protobuf size of 2GB: 4567007552


Profilage arrêté.
