In [1]:
import galsim
import piff
import numpy as np
import matplotlib.pyplot as plt
import pickle
import copy
import time
import cProfile
import pstats
import jax
from jax import config
# JAX uses float32 by default, which is faster. To run JAX in float64
# instead, uncomment the line below.
# config.update("jax_enable_x64", True)

In [2]:
# Load the data for a single CCD from Subaru. This is exactly the input that
# Piff receives in the LSST stack (in meas_extensions_piff).
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# pickle files that come from a trusted source.

# Use a context manager so the file handle is closed as soon as loading is done.
with open('piff_input_2024-06-24_13-43-22.pkl', 'rb') as pickle_file:
    dic = pickle.load(pickle_file)

# Unpack the entries used by the fitting cells below.
stars = dic['stars']
wcs = dic['wcs']
pointing = dic['pointing']
piffConfig = dic['piffConfig']

In [3]:
# Run Piff with the current LSST-stack config (no JAX) and profile the fit.

print(piffConfig)

# Set up the logger first, then build the PSF object from the config once.
logger = piff.config.setup_logger(verbose=2)
piffResult = piff.PSF.process(piffConfig)

def my_function():
    """Run the PSF fit on the loaded stars.

    Defined at module level so that cProfile.run can invoke it by name.
    """
    piffResult.fit(stars, wcs, pointing, logger=logger)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time can jump if the system clock is adjusted.
start = time.perf_counter()
cProfile.run('my_function()', 'profile_data')
end = time.perf_counter()

print('Total run time: ' + str(end - start))

# Show the 20 most expensive calls by cumulative time. The trailing semicolon
# keeps the notebook from echoing the returned pstats.Stats object.
p = pstats.Stats('profile_data')
p.strip_dirs().sort_stats('cumulative').print_stats(20);

Iteration 1: Fitting 118 stars


{'type': 'Simple', 'model': {'type': 'PixelGrid', 'scale': 0.16839226517851116, 'size': 25, 'interp': 'Lanczos(11)'}, 'interp': {'type': 'BasisPolynomial', 'order': 2}, 'outliers': {'type': 'Chisq', 'nsigma': 4.0, 'max_remove': 0.05}}


nq = 3750
PF time to compute ATb and ATA: 3.094912 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.354791
             Total chisq = 63222.52 / 72029 dof
Iteration 2: Fitting 118 stars
nq = 3750
PF time to compute ATb and ATA: 3.048380 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.373044
Found 9 stars with chisq > thresh
             Removed 6 outliers
             Total chisq = 44289.31 / 68419 dof
Iteration 3: Fitting 112 stars
nq = 3750
PF time to compute ATb and ATA: 3.140364 | use jax: False
Beginning solution of matrix size (3750, 3750)
PF: Not using JAX to solve the linear system | Time: 0.351064
Found 3 stars with chisq > thresh
             Removed 3 outliers
             Total chisq = 41842.70 / 66603 dof
Iteration 4: Fitting 109 stars
nq = 3750
PF time to compute ATb and ATA: 2.970835 | use jax: False
Beginning solution of matrix siz

Total run time: 15.903011798858643
Thu Jun 27 01:01:37 2024    profile_data

         1893498 function calls (1881032 primitive calls) in 15.900 seconds

   Ordered by: cumulative time
   List reduced from 490 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   15.900   15.900 {built-in method builtins.exec}
        1    0.000    0.000   15.900   15.900 <string>:1(<module>)
        1    0.018    0.018   15.900   15.900 2749334822.py:8(my_function)
        1    0.058    0.058   15.882   15.882 simplepsf.py:103(fit)
        4    0.017    0.004   13.745    3.436 basis_interp.py:130(solve)
        4   10.933    2.733   13.728    3.432 basis_interp.py:280(_solve_direct)
        4    1.420    0.355    1.420    0.355 _basic.py:52(solve)
      457    0.803    0.002    1.363    0.003 pixelgrid.py:195(chisq)
     2285    1.290    0.001    1.290    0.001 {method 'dot' of 'numpy.ndarray' objects}
      457    0.001   

<pstats.Stats at 0x31dc8f640>

In [4]:
# Enable JAX only in some parts of the basis interpolation (mostly in
# _solve_direct) and profile the same fit again.
# NOTE(review): this appears to be the first JAX-enabled fit in the notebook,
# so the timing below presumably includes one-time JAX tracing/compilation
# overhead -- confirm with a warm-up run if steady-state numbers are needed.

# Work on a deep copy so the original piffConfig stays untouched.
jaxPiffConfig = copy.deepcopy(piffConfig)
jaxPiffConfig['interp']['use_jax'] = True
print(jaxPiffConfig)

logger = piff.config.setup_logger(verbose=2)
piffResult = piff.PSF.process(jaxPiffConfig)

def my_function():
    """Run the JAX-enabled PSF fit on the loaded stars.

    Defined at module level so that cProfile.run can invoke it by name.
    """
    piffResult.fit(stars, wcs, pointing, logger=logger)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time can jump if the system clock is adjusted.
start = time.perf_counter()
cProfile.run('my_function()', 'profile_data')
end = time.perf_counter()

print('Total run time: ' + str(end - start))

# Show the 20 most expensive calls by cumulative time. The trailing semicolon
# keeps the notebook from echoing the returned pstats.Stats object.
p = pstats.Stats('profile_data')
p.strip_dirs().sort_stats('cumulative').print_stats(20);

Iteration 1: Fitting 118 stars


{'type': 'Simple', 'model': {'type': 'PixelGrid', 'scale': 0.16839226517851116, 'size': 25, 'interp': 'Lanczos(11)'}, 'interp': {'type': 'BasisPolynomial', 'order': 2, 'use_jax': True}, 'outliers': {'type': 'Chisq', 'nsigma': 4.0, 'max_remove': 0.05}}


nq = 3750
PF time to compute ATb and ATA: 1.143394 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.118327
             Total chisq = 63222.52 / 72029 dof
Iteration 2: Fitting 118 stars
nq = 3750
PF time to compute ATb and ATA: 1.043012 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.085137
Found 9 stars with chisq > thresh
             Removed 6 outliers
             Total chisq = 44289.31 / 68419 dof
Iteration 3: Fitting 112 stars
nq = 3750
PF time to compute ATb and ATA: 1.029485 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use JAX to solve the linear system | Time: 0.087024
Found 3 stars with chisq > thresh
             Removed 3 outliers
             Total chisq = 41842.70 / 66603 dof
Iteration 4: Fitting 109 stars
nq = 3750
PF time to compute ATb and ATA: 1.017584 | use jax: True
Beginning solution of matrix size (3750, 3750)
PF: Use

Total run time: 6.954502105712891
Thu Jun 27 01:01:54 2024    profile_data

         2016753 function calls (2003080 primitive calls) in 6.942 seconds

   Ordered by: cumulative time
   List reduced from 2146 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      9/1    0.000    0.000    6.943    6.943 {built-in method builtins.exec}
        1    0.000    0.000    6.943    6.943 <string>:1(<module>)
        1    0.022    0.022    6.943    6.943 3036657732.py:10(my_function)
        1    0.096    0.096    6.922    6.922 simplepsf.py:103(fit)
        4    0.039    0.010    4.740    1.185 basis_interp.py:130(solve)
        4    1.069    0.267    4.701    1.175 basis_interp.py:280(_solve_direct)
     34/8    0.000    0.000    2.406    0.301 traceback_util.py:175(reraise_with_filtered_traceback)
     30/8    0.000    0.000    2.406    0.301 pjit.py:302(cache_miss)
     30/8    0.000    0.000    2.405    0.301 pjit.py:169(_python_pjit_hel

<pstats.Stats at 0x123755d80>