# In [10]:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
from numpy import pi,sqrt,e,shape,zeros,fft
import scipy.special as sc
import matplotlib.animation as anim

# In [11]:
class Diffraction:
    """
    Scalar Fresnel diffraction of a field sampled on a square, uniform grid.

    Takes in an initial condition, grid length, number of points, wavelength and
    propagation distance. The main functionality is to return the FFT and sinc
    based solutions at a plane a distance z away.

    Parameters
    ----------
    initial_condition : array, shape (N, N)
        Field sampled on the grid ``self.x`` (columns) x ``self.y`` (rows).
    L : float
        Side length of the square computational window.
    N : int
        Number of samples per axis.
    lam : float
        Wavelength (presumably in the same length unit as L and z -- confirm).
    z : float
        Propagation distance.
    """

    def __init__(self, initial_condition, L, N, lam, z):
        self.f = initial_condition
        self.L = L
        self.N = N
        self.z = z
        self.lam = lam
        self.dx = L / N
        # Identical N-point grids on both axes covering [-L/2, L/2); the right
        # end point is excluded so the grid is periodic, as the FFT assumes.
        self.x = np.linspace(-L / 2, L / 2 - self.dx, N)
        self.y = np.linspace(-L / 2, L / 2 - self.dx, N)
        self.k = np.pi * 2 / self.lam

    def fresnel_weights(self):
        """
        Return the Fresnel weight matrices used by the sinc based solution.

        Entry ``[m, j]`` is the propagated sinc basis function (band limit
        ``W = 1/(2*dx)``) centred at ``x[j]`` and evaluated at ``x[m]``.
        The x and y grids are identical, so the same matrix serves both axes.

        Returns
        -------
        (fres_x, fres_y) : tuple of complex ndarray, each of shape (N, N)
            The same array object is returned twice.
        """
        dx, z, x, k = self.dx, self.z, self.x, self.k
        W = 1 / (2 * dx)
        delta = dx

        # All pairwise offsets x[m] - x[j] at once (replaces a double Python loop).
        slide = x[:, np.newaxis] - x[np.newaxis, :]
        scale = sqrt(k / (2 * z))
        half_width = pi * sqrt(2 * z / k) * W
        u_1 = -half_width - scale * slide
        u_2 = half_width - scale * slide

        # scipy.special.fresnel integrates sin/cos(pi t^2 / 2) and returns (S, C),
        # hence the sqrt(2/pi) rescaling of the arguments.
        S_1, C_1 = sc.fresnel(u_1 * sqrt(2 / pi))
        S_2, C_2 = sc.fresnel(u_2 * sqrt(2 / pi))

        fres = (
            (delta / pi)
            * scale
            * np.exp(1j * slide**2 * k / (2 * z))
            * sqrt(pi / 2)
            * (C_2 - C_1 - 1j * (S_2 - S_1))
        )
        return fres, fres

    def sinc_solution(self):
        """
        Return the sinc based solution at distance z.

        Follows the write-up and 'Diffraction integral computation using sinc
        approximation' by Cubillos et al. The 2-D propagation is separable, so
        the result is ``exp(i k z) * W_x @ f @ W_y.T``.

        Returns
        -------
        complex ndarray, shape (N, N)
        """
        w_x, w_y = self.fresnel_weights()
        return np.exp(1j * self.k * self.z) * (w_x @ self.f @ w_y.T)

    def fft_solution(self):
        """
        Return the FFT (angular spectrum, Fresnel transfer function) solution.

        The field is treated as periodic on the L x L window. Frequencies are
        generated in numpy's native FFT ordering, so the spectrum is never
        shifted. (Previously the spectrum was fftshift-ed but never shifted
        back, which imprinted a (-1)^(m+n) phase on the output for even N.)

        Returns
        -------
        complex ndarray, shape (N, N)
        """
        N, f, k, z = self.N, self.f, self.k, self.z
        # Angular spatial frequencies 2*pi*n/L in unshifted order; correct for
        # both even and odd N.
        freq_x = 2 * np.pi * fft.fftfreq(N, d=self.dx)
        freq_y = 2 * np.pi * fft.fftfreq(N, d=self.dx)
        freq_x, freq_y = np.meshgrid(freq_x, freq_y)
        spectrum = fft.fft2(f)
        # Paraxial transfer function exp(ikz) * exp(-i z (kx^2 + ky^2) / (2k)).
        transfer = np.exp(1j * k * z) * np.exp(
            -1j * z * (freq_x**2 + freq_y**2) / (2 * k)
        )
        return fft.ifft2(spectrum * transfer)

    def three_d(self, solution, angle=(90, 45), title='3-D Plot'):
        """
        Show a 3-D surface plot of an array over the (x, y) grid.

        Input: solution -- real-valued (N, N) array (or the initial condition);
               for a complex field pass e.g. np.abs(u).
               Optional: angle -- (elev, azim) in degrees, passed to
               ax.view_init; title -- plot title.
        Returns: None; the figure is displayed with plt.show().
        """
        X, Y = np.meshgrid(self.x, self.y)
        fig = plt.figure(figsize=(12, 12))
        ax = fig.add_subplot(projection='3d')

        surf = ax.plot_surface(X, Y, solution, cmap=plt.cm.cividis)
        ax.set_xlabel('x', labelpad=20)
        ax.set_ylabel('y', labelpad=20)
        ax.set_zlabel('z', labelpad=10)
        ax.set_title(title)
        fig.colorbar(surf, shrink=0.3, aspect=20)
        ax.view_init(angle[0], angle[1])
        plt.show()