In [2]:
import numpy as np
import matplotlib.pyplot as plt
from numpy import pi,sqrt,e,shape,zeros,fft
import scipy.special as sc
import scipy
import matplotlib.animation as anim

In [4]:
from __future__ import division

## Helper for creating animated GIFs in Jupyter
class AnimatedGif:
    """Collect image frames on a dedicated figure and write them out as an animated GIF.

    Frames are drawn on this object's own axes (not on whatever pyplot axes
    happens to be current), so other plotting between `add` calls does not
    steal frames.
    """

    def __init__(self, size=(300, 300)):
        """Create the backing figure.

        Parameters
        ----------
        size : (width, height)
            Divided by 20 to give the figure size in inches.
        """
        self.fig = plt.figure()
        self.fig.set_size_inches(size[0] / 20, size[1] / 20)
        # Frameless axes filling the whole figure, no ticks: only the image shows.
        self.ax = self.fig.add_axes([0, 0, 1, 1], frameon=False, aspect=1)
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        # One [image_artist, text_artist] list per frame, as ArtistAnimation expects.
        self.images = []

    def add(self, image, label=''):
        """Append one frame showing `image` (hsv colormap) with an optional red `label`."""
        plt_im = self.ax.imshow(image, animated=True, cmap='hsv')
        # NOTE(review): (10, 310) is in data (pixel) coordinates; for images with
        # fewer than ~310 rows this lies outside the figure-filling axes, so the
        # label may be invisible -- confirm against the typical frame size.
        plt_txt = self.ax.text(10, 310, label, color='red')
        self.images.append([plt_im, plt_txt])

    def save(self, filename):
        """Write all collected frames to `filename` at 10 fps via the ImageMagick writer."""
        animation = anim.ArtistAnimation(self.fig, self.images)
        animation.save(filename, writer='imagemagick', fps=10)

## Convenience wrapper for static 3-D surface plots
def three_d(X, Y, z, title):
    """Show a static 3-D surface plot of `z` over the meshgrid (X, Y).

    Parameters
    ----------
    X, Y : meshgrid coordinate arrays.
    z : surface heights, same shape as X and Y.
    title : text used as the plot title.
    """
    figure = plt.figure(figsize=(12, 12))
    axes3d = plt.axes(projection='3d')

    surface = axes3d.plot_surface(X, Y, z, cmap=plt.cm.cividis)

    # Extra label padding keeps the axis names clear of the tick labels.
    axes3d.set_xlabel('x', labelpad=20)
    axes3d.set_ylabel('y', labelpad=20)
    axes3d.set_zlabel('z', labelpad=10)

    plt.title(title)
    figure.colorbar(surface, shrink=0.3, aspect=20)

    # Low elevation (10 deg), azimuth 20 deg.
    axes3d.view_init(10, 20)
    plt.show()

## Convenience wrapper for an interactive (Qt window) 3-D surface plot
def interact_threed(X,Y,Z):
    """Open an interactive 3-D surface plot of Z over the meshgrid (X, Y) in a Qt window.

    Side effect: switches the IPython matplotlib backend to Qt (the equivalent of
    running ``%matplotlib qt``), which persists for later plots in the session.
    Must be called from an IPython/Jupyter kernel.
    """
    # Imported for its side effect of registering the '3d' projection on older matplotlib.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import matplotlib.pyplot as plt
    from matplotlib.ticker import LinearLocator, FormatStrFormatter

    # Plain-Python spelling of the `%matplotlib qt` line magic.
    get_ipython().run_line_magic('matplotlib', 'qt')

    fig = plt.figure(dpi=200)
    # Figure.gca(projection=...) no longer accepts kwargs (matplotlib >= 3.6);
    # add_subplot creates the same single 3-D axes on a fresh figure.
    ax = fig.add_subplot(projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap='plasma',
                           linewidth=1, antialiased=False)
    # Ten evenly spaced z ticks, two decimal places.
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    # The original looped 360 times setting this identical view; one call suffices.
    ax.view_init(30, 40)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()

In [5]:
##fft 1D 
def my_fft(f,L,N,dx):
    """Centered 1-D forward DFT scaled by the grid spacing.

    For k = 0..N-1 and shift = N/2 this computes

        F[k] = dx * sum_{n=0}^{N-1} f[n] * exp(-2*pi*i*(k - shift)*(n - shift)/N)

    so both the sample index and the frequency index are centered on N/2.
    The sum is evaluated as one dense N x N matrix-vector product (O(N**2)
    time and memory) instead of a pure-Python double loop.

    Parameters
    ----------
    f : array-like
        Samples; only the first N entries are read.
    L : float
        Domain length. Not used by the forward transform; kept for signature
        symmetry with `my_ifft`.
    N : int
        Number of samples / output frequencies.
    dx : float
        Grid spacing, applied as an overall factor.

    Returns
    -------
    numpy.ndarray, complex128, shape (N,)
    """
    centered = np.arange(N) - N/2
    # Dividing the array (rather than a scalar) by N keeps N == 0 from raising.
    kernel = np.exp(-2j*np.pi*np.outer(centered, centered)/N)
    return dx*(kernel @ np.asarray(f)[:N])

##ifft 1D
def my_ifft(F,L,N,dx):
    """Centered 1-D inverse DFT, the counterpart of `my_fft`.

    For n = 0..N-1 and shift = N/2 this computes

        f[n] = (1/L) * sum_{k=0}^{N-1} F[k] * exp(+2*pi*i*(k - shift)*(n - shift)/N)

    With L = N*dx this inverts the dx-scaled forward transform.
    Evaluated as a dense N x N matrix-vector product (O(N**2) time and memory)
    instead of a pure-Python double loop.

    Parameters
    ----------
    F : array-like
        Centered spectrum; only the first N entries are read.
    L : float
        Domain length; 1/L is applied as an overall factor.
    N : int
        Number of frequencies / output samples.
    dx : float
        Grid spacing. Not used by the inverse transform; kept for signature
        symmetry with `my_fft`.

    Returns
    -------
    numpy.ndarray, complex128, shape (N,)
    """
    centered = np.arange(N) - N/2
    # Dividing the array (rather than a scalar) by N keeps N == 0 from raising.
    kernel = np.exp(2j*np.pi*np.outer(centered, centered)/N)
    return (1/L)*(kernel @ np.asarray(F)[:N])

##fft 2D
def my_fft_2d(f,L,N,dx):
    """Centered 2-D forward DFT: `my_fft` applied to every row, then every column.

    Parameters
    ----------
    f : 2-D array-like of shape (m, n). Each pass assigns a length-N transform
        into a row or column, so in practice f is expected to be N x N.
    L, N, dx : forwarded unchanged to `my_fft` on both passes (so the dx
        scaling is applied once per axis).

    Returns
    -------
    numpy.ndarray, complex128, shape (m, n)
    """
    m, n = np.shape(f)
    # 'complex' (complex128), matching my_fft; the 'complex_' alias was removed in NumPy 2.0.
    fft2 = np.zeros((m, n), dtype='complex')
    for i in range(m):
        fft2[i] = my_fft(f[i], L, N, dx)            # transform each row (axis 1)
    for j in range(n):
        fft2[:, j] = my_fft(fft2[:, j], L, N, dx)   # then each column (axis 0)
    return fft2

##ifft 2D
def my_ifft_2d(F,L,N,dx):
    """Centered 2-D inverse DFT: `my_ifft` applied to every row, then every column.

    Parameters
    ----------
    F : 2-D array-like of shape (m, n). Each pass assigns a length-N transform
        into a row or column, so in practice F is expected to be N x N.
    L, N, dx : forwarded unchanged to `my_ifft` on both passes (so the 1/L
        scaling is applied once per axis).

    Returns
    -------
    numpy.ndarray, complex128, shape (m, n)
    """
    m, n = np.shape(F)
    # 'complex' (complex128), matching my_ifft; the 'complex_' alias was removed in NumPy 2.0.
    ifft2 = np.zeros((m, n), dtype='complex')
    for i in range(m):
        ifft2[i] = my_ifft(F[i], L, N, dx)              # invert each row (axis 1)
    for j in range(n):
        ifft2[:, j] = my_ifft(ifft2[:, j], L, N, dx)    # then each column (axis 0)
    return ifft2

##derivative in fourier space
def my_ifft_2d_deriv(F,L,N,dx,F_x,F_y):
    """Spectral derivative: inverse 2-D DFT of F after multiplying by (i*kx)**F_x and (i*ky)**F_y.

    Rows of F are multiplied by the x multiplier before `my_ifft`; columns of the
    intermediate result are multiplied by the y multiplier before `my_ifft`.

    Parameters
    ----------
    F : 2-D centered spectrum of shape (m, n); in practice N x N.
    L : float
        Domain length; sets the wavenumber spacing 2*pi/L and is forwarded to `my_ifft`.
    N : int
        Number of grid points per axis.
    dx : float
        Grid spacing, forwarded to `my_ifft`.
    F_x, F_y : derivative orders in x and y (0 leaves that axis unchanged).

    Returns
    -------
    numpy.ndarray, complex128, shape (m, n)
    """
    m, n = np.shape(F)
    # 'complex' (complex128), matching my_ifft; the 'complex_' alias was removed in NumPy 2.0.
    ifft2 = np.zeros((m, n), dtype='complex')
    # Wavenumbers (index - N/2) * 2*pi/L, the same centering as my_ifft's k - N/2.
    freq_x = np.linspace(-N/2, N/2-1, N)*(2*np.pi)/L
    freq_y = np.linspace(-N/2, N/2-1, N)*(2*np.pi)/L
    # Loop-invariant derivative multipliers, computed once instead of per row/column.
    mult_x = (1j*freq_x)**F_x
    mult_y = (1j*freq_y)**F_y
    for i in range(m):
        ifft2[i] = my_ifft(mult_x*F[i], L, N, dx)
    for j in range(n):
        ifft2[:, j] = my_ifft(mult_y*ifft2[:, j], L, N, dx)
    return ifft2

In [1]:
##Fresnel weights as outlined in write-up
def fresnel_weights(L,N,x,y,z,k):
    """Fresnel weight matrices for the x grid and the y grid ("as outlined in write-up").

    For grid coordinates c and every pair (m, j), with s = c[m] - c[j]:

        phi[m, j] = (delta/pi) * sqrt(k/(2z)) * exp(i*k*s**2/(2z)) * sqrt(pi/2)
                    * (C(u2) - C(u1) - i*(S(u2) - S(u1)))

        u1, u2 = -/+ pi*sqrt(2z/k)*W - sqrt(k/(2z))*s   (scaled by sqrt(2/pi)
                 before being passed to scipy.special.fresnel)

    where dx = delta = L/N and W = 1/(2*dx).

    Previously the x weights were returned twice and the y weights were never
    computed. The output was also sized from meshgrid(x, y) but filled using
    len(x), which broke for len(y) < len(x).

    Parameters
    ----------
    L : float
        Domain length; with N gives dx = L/N.
    N : int
        Number of grid points used to form dx.
    x, y : 1-D array-like grid coordinates.
    z : float
        Presumably the propagation distance (must be non-zero).
    k : float
        Presumably the wavenumber.

    Returns
    -------
    (fres_x, fres_y) : complex128 arrays of shape (len(x), len(x)) and (len(y), len(y)).
    """
    dx = L/N
    W = 1/(2*dx)   # 1/(2*dx): the grid's Nyquist frequency
    delta = dx

    def _weights(coords):
        # Pairwise weight matrix phi(coords[m] - coords[j]), vectorized over (m, j).
        coords = np.asarray(coords)
        slide = coords[:, None] - coords[None, :]
        a = np.sqrt(2*z/k)
        b = np.sqrt(k/(2*z))
        u_1 = -np.pi*a*W - b*slide
        u_2 = np.pi*a*W - b*slide
        S_1, C_1 = sc.fresnel(u_1*np.sqrt(2/np.pi))   # fresnel returns (S, C)
        S_2, C_2 = sc.fresnel(u_2*np.sqrt(2/np.pi))
        return ((delta/np.pi)*b*np.exp(1j*slide**2*k/(2*z))*np.sqrt(np.pi/2)
                * (C_2 - C_1 - 1j*(S_2 - S_1)))

    return _weights(x), _weights(y)