In [9]:
import numpy as np
import pylab as plt
from scipy.signal import resample,correlate2d

def fft(A):
    return np.fft.fftshift(np.fft.fft2(A))

def ifft(A):
    return np.fft.ifft2(np.fft.ifftshift(A))

def conv2d(A,B):
    if np.size(A) != np.size(B):
        print("wrong sizes")
        return
    return np.fft.ifft2(np.fft.fft2(A)*np.fft.fft2(B))

def autocorr(A):
    F = fft(A)
    return ifft(F*np.conjugate(F))

def regrid2(A,shape,ax0=None,ax1=None):
    '''Uses fft to regrid ...'''
    if ax0 is not None:
        B,tb = resample(A,shape[0],t=ax0,axis=0)
    else:
        B = resample(A,shape[0],axis=0)
    if ax1 is not None:
        C,tc = resample(B,shape[1],t=ax1,axis=1)
    else:
        C = resample(B,shape[1],axis=1)
    return C
        

'''Deltas'''
M = 1
x0 = np.random.uniform(low=-0.5,high=0.5,size=M)
y0 = np.random.uniform(low=-0.5,high=0.5,size=M)
p0 = np.random.uniform(size=M)

wavelength = 0.01
N=100
print('number:{0}'.format(N))
z = 100.
x = np.linspace(-1,1,N)
dx = np.abs(x[1]-x[0])
X,Y = np.meshgrid(x,x)
Usky = np.zeros([N,N])
for xi,yi,pi in zip(x0,y0,p0):
    Usky += pi*np.exp(-((X-xi)**2 + (Y-yi)**2)/(2*dx)**2)

Usky2 = regrid2(Usky,[2*N,2*N])

Uprop = np.zeros([N,N])*1j
for xi,yi,pi in zip(x0,y0,p0):
    r2 = ((X - xi)**2 + (Y - yi)**2 + z**2)/wavelength**2
    Uprop += (z/wavelength)/r2*np.exp(1j*2*np.pi*np.sqrt(r2) - 1j*np.pi/2)

print(np.sum(autocorr(Usky) - autocorr(Uprop)))
f,(ax1,ax2,ax3) = plt.subplots(3)
ax1.imshow(Usky2,extent=(x[0],x[-1],x[0],x[-1]),origin='lower')
ax2.imshow(np.real(Uprop),extent=(x[0],x[-1],x[0],x[-1]),origin='lower')
ax3.imshow(np.angle(autocorr(Uprop)),extent=(x[0],x[-1],x[0],x[-1]),origin='lower',alpha=0.5)
ax3.imshow(np.angle(autocorr(Usky)),extent=(x[0],x[-1],x[0],x[-1]),origin='lower',alpha=0.5)
plt.show()


w = 1.
kx = np.fft.fftshift(np.fft.fftfreq(N,d=dx))
dk = np.abs(kx[1] - kx[0])#want dk * k * z = 1/4
Kx,Ky = np.meshgrid(kx,kx)


f,(ax1,ax2,ax3) = plt.subplots(3)
ax1.imshow(Usky,extent=(x[0],x[-1],x[0],x[-1]),origin='lower',alpha=0.5)
Asky = fft(Usky)
ax2.imshow(np.abs(Asky),extent=(kx[0],kx[-1],kx[0],kx[-1]),origin='lower')
ax3.imshow(np.angle(Asky),extent=(kx[0],kx[-1],kx[0],kx[-1]),origin='lower')
plt.show()

number:100
(103.453981466+0j)


In [37]:
from scipy.interpolate import griddata
f,(ax4,ax5,ax6) = plt.subplots(3)
w = 10
#want dk * k * z = 1/4
dk = 1./(4.*2.*np.pi*w)
N = int(np.ceil((kx[-1] - kx[0])/dk) + 1)
print('number:{0}'.format(N))
x = np.linspace(-1,1,N)
dx = np.abs(x[1]-x[0])
X,Y = np.meshgrid(x,x)
kx_ = np.fft.fftshift(np.fft.fftfreq(N,d=dx))
Kx_,Ky_ = np.meshgrid(kx_,kx_)
Prop = np.exp(1j*2*np.pi*w*np.sqrt(1j - Kx**2 - Ky**2))
Askyregrid = griddata((Kx.flatten(),Ky.flatten()),Asky.flatten(),(Kx_.flatten(),Ky_.flatten())).reshape(Kx_.shape)
ax4.imshow(np.abs(Prop),extent=(kx_[0],kx_[-1],kx_[0],kx_[-1]),origin='lower')
ax5.imshow(np.angle(Prop),extent=(kx_[0],kx_[-1],kx_[0],kx_[-1]),origin='lower')

Uskyp = ifft(Asky*Prop)
ax6.imshow(np.real(Uskyp),extent=(x[0],x[-1],x[0],x[-1]),origin='lower',alpha=0.5)
plt.show()

number:12318


In [51]:
from scipy.signal import fftconvolve
help(fftconvolve)

Help on function fftconvolve in module scipy.signal.signaltools:

fftconvolve(in1, in2, mode='full')
    Convolve two N-dimensional arrays using FFT.
    
    Convolve `in1` and `in2` using the fast Fourier transform method, with
    the output size determined by the `mode` argument.
    
    This is generally much faster than `convolve` for large arrays (n > ~500),
    but can be slower when only a few output values are needed, and can only
    output float arrays (int or object array inputs will be cast to float).
    
    Parameters
    ----------
    in1 : array_like
        First input.
    in2 : array_like
        Second input. Should have the same number of dimensions as `in1`;
        if sizes of `in1` and `in2` are not equal then `in1` has to be the
        larger array.
    mode : str {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
    
        ``full``
           The output is the full discrete linear convolution
           of the inpu