In [1]:
import jax.numpy as np
import numpy as onp
from jax import grad, jit, vmap
from jax.ops import index_update
from jax import random
import matplotlib.pyplot as plt

import morphine
from morphine.matrixDFT import minimal_dft
import poppy

%matplotlib inline

import matplotlib as mpl
mpl.style.use('seaborn-colorblind')
phasemap = mpl.cm.rainbow
phasemap.set_bad(color='k')


#To make sure we have always the same matplotlib settings
#(the ones in comments are the ipython notebook settings)

mpl.rcParams['figure.figsize']=(12.0,9.0)    #(6.0,4.0)
mpl.rcParams['font.size']=20               #10 
mpl.rcParams['savefig.dpi']= 200             #72 
mpl.rcParams['axes.labelsize'] = 18
mpl.rcParams['axes.labelsize'] = 18
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14
from matplotlib import rc
mpl.rcParams["font.family"] = "Times New Roman"

colours = mpl.rcParams['axes.prop_cycle'].by_key()['color']

from astropy import units as u

shift = np.fft.fftshift
fft   = np.fft.fft2
ifft  = np.fft.ifft2
fftfreq = np.fft.fftfreq

dtor = np.pi/180.0

import warnings
warnings.filterwarnings("ignore")


  phasemap.set_bad(color='k')


In [2]:
D = 2.
wavelen = 1e-6

NPIX = 128
FOV = 8.

def morphine_min(wavel):
    empty = morphine.OpticalSystem(npix=NPIX)
    empty.add_pupil( morphine.CircularAperture(radius=1.),npix=NPIX)
    empty.add_detector( pixelscale=0.025, fov_arcsec=FOV )
    psf, instrument = empty.propagate_mono(wavel,retain_intermediates=False)
    return psf.intensity

jit_morphine = jit(morphine_min)

def poppy_basic(wavel):
    empty = poppy.OpticalSystem(npix=NPIX)
    empty.add_pupil( poppy.CircularAperture(radius=1.),npix=NPIX)
    empty.add_detector( pixelscale=0.025, fov_arcsec=FOV )
    psf, instrument = empty.propagate_mono(wavel)
    return psf[0].data

img_min = morphine_min(wavelen)
img_jit = jit_morphine(wavelen)
img_pop = poppy_basic(wavelen)



IndexError: Array boolean indices must be concrete.

In [None]:
def minimal_dft_prim(plane, nlamD, npix):
    """Perform a matrix discrete Fourier transform with selectable
    output sampling and centering.

    Where parameters can be supplied as either scalars or 2-tuples, the first
    element of the 2-tuple is used for the Y dimension and the second for the
    X dimension. This ordering matches that of numpy.ndarray.shape attributes
    and that of Python indexing.

    To achieve exact correspondence to the FFT set nlamD and npix to the size
    of the input array in pixels and use 'FFTSTYLE' centering. (n.b. When
    using `numpy.fft.fft2` you must `numpy.fft.fftshift` the input pupil both
    before and after applying fft2 or else it will introduce a checkerboard
    pattern in the signs of alternating pixels!)

    Parameters
    ----------
    plane : 2D ndarray
        2D array (either real or complex) representing the input image plane or
        pupil plane to transform.
    nlamD : float or 2-tuple of floats (nlamDY, nlamDX)
        Size of desired output region in lambda / D units, assuming that the
        pupil fills the input array (corresponds to 'm' in
        Soummer et al. 2007 4.2). This is in units of the spatial frequency that
        is just Nyquist sampled by the input array.) If given as a tuple,
        interpreted as (nlamDY, nlamDX).
    npix : int or 2-tuple of ints (npixY, npixX)
        Number of pixels per side side of destination plane array (corresponds
        to 'N_B' in Soummer et al. 2007 4.2). This will be the # of pixels in
        the image plane for a forward transformation, in the pupil plane for an
        inverse. If given as a tuple, interpreted as (npixY, npixX).
    """

    npupY, npupX = plane.shape # 32, be careful

    npixY, npixX = 1.0*npix, 1.0*npix

    nlamDY, nlamDX = 1.0*nlamD, 1.0*nlamD
    
    dU = nlamDX / (npixX)
    dV = nlamDY / (npixY)
    dX = 1.0 / (1.0*npupX)
    dY = 1.0 / (1.0*npupY)


    Xs = (1.0*np.arange(npupX) - (npupX) / 2.0 + 0.5) * dX
    Ys = (1.0*np.arange(npupY) - (npupY) / 2.0 + 0.5) * dY

    Us = (1.0*np.arange(npixX) - (npixX) / 2.0 + 0.5) * dU
    Vs = (1.0*np.arange(npixY) - (npixY) / 2.0 + 0.5) * dV

    XU = np.outer(Xs, Us)
    YV = np.outer(Ys, Vs)

    expXU = np.exp(-2.0 * np.pi * 1j * XU)
    expYV = np.exp(-2.0 * np.pi * 1j * YV).T
    t1 = np.dot(expYV, plane)
    t2 = np.dot(t1, expXU)

    norm_coeff = np.sqrt((nlamDY * nlamDX) / (npupY * npupX * npixY * npixX))
    return norm_coeff * t2

minimal_dft = jit(minimal_dft_prim,static_argnums=2)

In [None]:

@jit
def ex1(x):
  size = onp.prod(onp.array(x.shape))
  return x.reshape((size,))

ex1(onp.ones((3, 4)))


In [None]:

@jit
def ex1(x):
  size = np.prod(np.array(x.shape))
  return x.reshape((size,))

ex1(np.ones((3, 4)))


In [None]:
plane = np.ones((4, 5))
nlamD = 1
npix = 1

test_nojit = minimal_dft_prim(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)
test_jit = minimal_dft(plane, nlamD, npix)
# DeviceArray([[4.472136+0.j]], dtype=complex64)


In [None]:
%%timeit
img_min = morphine_min(wavelen)

In [None]:
%%timeit
img_jit = jit_morphine(wavelen)

In [None]:
%%timeit
img_pop = poppy_basic(wavelen)

In [None]:
fig, axes = plt.subplots(1,3,figsize=(12.0,4.0))
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
axes[0].imshow(img_min**0.25)
axes[1].imshow(img_jit**0.25)
axes[2].imshow(img_pop**0.25)
