In [1]:
import galsim
import piff
import numpy as np
import matplotlib.pyplot as plt
import pickle
import copy
import time
import cProfile
import pstats
import jax
from jax import config
# if want to use float64 for jax, default will be float32 with jax
# faster with float32...
# config.update("jax_enable_x64", True)

In [2]:
# Load the inputs of a single Subaru CCD, captured exactly as they are handed
# to Piff by the LSST stack (meas_extensions_piff).
# NOTE(review): pickle.load can execute arbitrary code -- only open pickle
# files that come from a trusted source.
PIFF_INPUT_PATH = 'piff_input_2024-06-24_13-43-22.pkl'

# Context manager so the file handle is closed as soon as the pickle is read.
with open(PIFF_INPUT_PATH, 'rb') as piff_input_file:
    dic = pickle.load(piff_input_file)

stars = dic['stars']
wcs = dic['wcs']
pointing = dic['pointing']
piffConfig = dic['piffConfig']

In [3]:
# Run Piff with the configuration currently used in the LSST stack (no JAX)
# and profile the full fit to establish a baseline.

print(piffConfig)

# Set up logging first, then build the PSF object exactly once from the config.
logger = piff.config.setup_logger(verbose=2)
piffResult = piff.PSF.process(piffConfig)

def my_function():
    """Fit ``piffResult`` to the loaded stars; this is the call profiled below."""
    piffResult.fit(stars, wcs, pointing, logger=logger)

# perf_counter is the clock intended for measuring elapsed intervals.
# The measured time includes cProfile's own overhead.
start = time.perf_counter()
cProfile.run('my_function()', 'profile_data')
end = time.perf_counter()

print(f'Total run time: {end - start}')

# Show the 20 most expensive calls by cumulative time; the trailing ';'
# suppresses the echo of the returned Stats object.
p = pstats.Stats('profile_data')
p.strip_dirs().sort_stats('cumulative').print_stats(20);

Iteration 1: Fitting 118 stars


{'type': 'Simple', 'model': {'type': 'PixelGrid', 'scale': 0.16839226517851116, 'size': 25, 'interp': 'Lanczos(11)'}, 'interp': {'type': 'BasisPolynomial', 'order': 2}, 'outliers': {'type': 'Chisq', 'nsigma': 4.0, 'max_remove': 0.05}}


nq = 3750
PF time to compute ATb and ATA: 3.077662 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.357351
             Total chisq = 63222.52 / 72029 dof
Iteration 2: Fitting 118 stars
nq = 3750
PF time to compute ATb and ATA: 3.060766 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.352146
Found 9 stars with chisq > thresh
             Removed 6 outliers
             Total chisq = 44289.31 / 68419 dof
Iteration 3: Fitting 112 stars
nq = 3750
PF time to compute ATb and ATA: 3.011620 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.363533
Found 3 stars with chisq > thresh
             Removed 3 outliers
             Total chisq = 41842.70 / 66603 dof
Iteration 4: Fitting 109 stars
nq = 3750
PF time to compute ATb and ATA: 2.818668 | use jax: False
Beginning solution of matrix siz

Total run time: 15.556748867034912
Thu Jun 27 01:33:24 2024    profile_data

         1893509 function calls (1881043 primitive calls) in 15.552 seconds

   Ordered by: cumulative time
   List reduced from 490 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   15.552   15.552 {built-in method builtins.exec}
        1    0.000    0.000   15.552   15.552 <string>:1(<module>)
        1    0.021    0.021   15.552   15.552 2749334822.py:8(my_function)
        1    0.072    0.072   15.531   15.531 simplepsf.py:103(fit)
        4    0.017    0.004   13.454    3.363 basis_interp.py:130(solve)
        4   10.832    2.708   13.436    3.359 basis_interp.py:280(_solve_direct)
        4    1.425    0.356    1.425    0.356 _basic.py:52(solve)
      457    0.770    0.002    1.317    0.003 pixelgrid.py:195(chisq)
     2285    1.107    0.000    1.107    0.000 {method 'dot' of 'numpy.ndarray' objects}
      457    0.001   

<pstats.Stats at 0x35590c640>

In [4]:
# Same fit and profile as the baseline, but with JAX enabled for part of the
# basis interpolation (mostly in _solve_direct).

# Deep copy so the baseline piffConfig is left untouched.
jaxPiffConfig = copy.deepcopy(piffConfig)
jaxPiffConfig['interp']['use_jax'] = True
print(jaxPiffConfig)

logger = piff.config.setup_logger(verbose=2)
# NOTE: this rebinds piffResult, replacing the baseline result built earlier.
piffResult = piff.PSF.process(jaxPiffConfig)

def my_function():
    """Fit the JAX-enabled ``piffResult`` to the loaded stars; profiled below."""
    piffResult.fit(stars, wcs, pointing, logger=logger)

# NOTE(review): this is presumably the first JAX work in the kernel, so the
# timing likely includes one-off JIT tracing/compilation (see the pjit
# cache_miss entries in the profile). It is not a steady-state comparison
# with the baseline. The measured time also includes cProfile's overhead.
start = time.perf_counter()
cProfile.run('my_function()', 'profile_data')
end = time.perf_counter()

print(f'Total run time: {end - start}')

# Show the 20 most expensive calls by cumulative time; the trailing ';'
# suppresses the echo of the returned Stats object.
p = pstats.Stats('profile_data')
p.strip_dirs().sort_stats('cumulative').print_stats(20);

Iteration 1: Fitting 118 stars


{'type': 'Simple', 'model': {'type': 'PixelGrid', 'scale': 0.16839226517851116, 'size': 25, 'interp': 'Lanczos(11)'}, 'interp': {'type': 'BasisPolynomial', 'order': 2, 'use_jax': True}, 'outliers': {'type': 'Chisq', 'nsigma': 4.0, 'max_remove': 0.05}}


nq = 3750
PF time to compute ATb and ATA: 1.119895 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.124662
             Total chisq = 63222.52 / 72029 dof
Iteration 2: Fitting 118 stars
nq = 3750
PF time to compute ATb and ATA: 1.036825 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.093753
Found 9 stars with chisq > thresh
             Removed 6 outliers
             Total chisq = 44289.31 / 68419 dof
Iteration 3: Fitting 112 stars
nq = 3750
PF time to compute ATb and ATA: 1.024430 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.098193
Found 3 stars with chisq > thresh
             Removed 3 outliers
             Total chisq = 41842.70 / 66603 dof
Iteration 4: Fitting 109 stars
nq = 3750
PF time to compute ATb and ATA: 1.137896 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use

Total run time: 7.0860679149627686
Thu Jun 27 01:33:34 2024    profile_data

         2016768 function calls (2003093 primitive calls) in 7.078 seconds

   Ordered by: cumulative time
   List reduced from 2146 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      9/1    0.000    0.000    7.079    7.079 {built-in method builtins.exec}
        1    0.000    0.000    7.079    7.079 <string>:1(<module>)
        1    0.022    0.022    7.079    7.079 3036657732.py:10(my_function)
        1    0.153    0.153    7.056    7.056 simplepsf.py:103(fit)
        4    0.041    0.010    4.820    1.205 basis_interp.py:130(solve)
        4    1.102    0.276    4.779    1.195 basis_interp.py:280(_solve_direct)
   144/76    0.000    0.000    2.386    0.031 core.py:389(bind_with_trace)
     34/8    0.000    0.000    2.385    0.298 traceback_util.py:175(reraise_with_filtered_traceback)
     30/8    0.000    0.000    2.385    0.298 pjit.py:302(cache_miss

<pstats.Stats at 0x107a7df60>