In [15]:
import numpy as np
from scipy.optimize import curve_fit
from scipy.fft import fft,fft2, fftfreq
from scipy.signal import oaconvolve
from PIL import Image, ImageFilter
import joblib as jb
from IPython.display import clear_output
import numba

In [22]:
def generate_stepfunc(width,padding):
    """Sampled boxcar (top-hat) kernel of length ``width`` with ``padding`` of zeros on each side.

    The kernel is sampled with ``np.linspace(0, width + 2*padding, width/padding + 1)``, so
    the actual sample spacing is (width + 2*padding)*padding/width, not exactly ``padding``.
    NOTE(review): confirm that spacing is intended when the kernel is convolved with
    intensity sampled on a different grid.

    Parameters
    ----------
    width : float
        Length over which the kernel equals 1.
    padding : float
        Zero margin on each side; ``width/padding`` also sets the number of samples.

    Returns
    -------
    ndarray
        1D array of 0/1 values (heaviside(.,1) is 1 at exactly-zero arguments).
    """
    ratio = width/padding
    nearest = np.round(ratio)
    # int() truncates, so a float ratio like 0.3/0.1 == 2.9999999999999996 would lose a
    # sample; snap ratios that are integers up to rounding error, otherwise keep truncation.
    if np.isclose(ratio, nearest, rtol=1e-9, atol=0):
        n_samples = int(nearest + 1)
    else:
        n_samples = int(ratio + 1)
    stepfunc_x = np.linspace(0,width+2*padding,n_samples)
    return np.heaviside(stepfunc_x-padding,1) - np.heaviside(stepfunc_x-width-padding,1)

def invert(pxl,fitparams):
    """Solve a*x**2 + b*x + c = pxl for x, taking the +sqrt branch of the quadratic formula.

    ``fitparams`` holds (a, b, c) -- presumably a quadratic calibration fit; confirm.
    ``pxl`` may be a scalar or an array (elementwise).
    """
    a, b = fitparams[0], fitparams[1]
    shifted_c = fitparams[2] - pxl
    discriminant = b**2 - 4*a*shifted_c
    return (np.sqrt(discriminant) - b)/(2*a)

def dbcosine_func(x,A,B,p1,phi,p2):
    """Offset cosine of period ``p1`` (phase ``phi``) multiplied by a cosine envelope of period ``p2``.

    Returns A + B*cos(2*pi*x/p1 + phi)*cos(2*pi*x/p2), elementwise in ``x``.
    """
    carrier = np.cos(2*np.pi*x/p1 + phi)
    envelope = np.cos(2*np.pi*x/p2)
    return A + B*carrier*envelope

def cosine_func(x,A,B,p,phi):
    """Offset cosine fringe model A + B*cos(2*pi*x/p + phi), elementwise in ``x``."""
    argument = 2*np.pi*x/p + phi
    return A + B*np.cos(argument)

def cosine_func_jac(x,A,B,p,phi):
    """Jacobian of A + B*cos(2*pi*x/p + phi) with respect to (A, B, p, phi).

    Returns
    -------
    ndarray of shape (len(x), 4)
        Column j is the partial derivative w.r.t. the j-th parameter, which is the
        layout ``scipy.optimize.curve_fit`` expects from its ``jac`` argument.
    """
    argument = 2*np.pi*x/p + phi
    sin_arg = np.sin(argument)
    return np.array([
        np.ones_like(x),                    # d/dA
        np.cos(argument),                   # d/dB
        (B*2*np.pi*x/(p**2))*sin_arg,       # d/dp
        -B*sin_arg,                         # d/dphi (previously omitted the +phi)
    ]).T

def max_index(ndarr):
    """Return the N-d index tuple of the (first) maximum element of ``ndarr``."""
    flat_position = ndarr.argmax()
    return np.unravel_index(flat_position, ndarr.shape)

def lowpass(data, minperiod, pxlsize, order):
    """Zero-phase Butterworth low-pass filter along the last axis of ``data``.

    Parameters
    ----------
    data : array_like
        Signal, presumably sampled every ``pxlsize``; filtered along the last axis.
    minperiod : float
        Period that sets the cutoff (same length unit as ``pxlsize``).
    pxlsize : float
        Sample spacing.
    order : int
        Butterworth filter order.

    Returns
    -------
    ndarray
        ``filtfilt`` output (forward and backward pass, so no phase shift).
    """
    # butter/filtfilt were previously referenced without being imported anywhere.
    from scipy.signal import butter, filtfilt

    # NOTE(review): butter's digital Wn is relative to the Nyquist frequency 1/(2*pxlsize),
    # so pxlsize/minperiod places the cutoff at period 2*minperiod, not minperiod.
    # Confirm the intent before changing it (2*pxlsize/minperiod would give minperiod).
    normal_cutoff = pxlsize/minperiod
    # Get the filter coefficients
    b, a = butter(order, normal_cutoff, btype='low', analog=False)

    return filtfilt(b, a, data)


def grating_equation(xarr,p, phase_offset):
    """Ideal binary (square-wave) grating profile: sign(sin(2*pi*x/p + phase_offset)).

    Values are +1/-1, and 0 only where the sine is exactly zero.
    """
    phase = 2*np.pi*xarr/p + phase_offset
    return np.sign(np.sin(phase))

# def avgminregion(profile):
    
#     indmin = np.nonzero(abs(profile - np.amin(profile))/np.amin(profile) < 1e-2)[0]
#     indmin = indmin[len(indmin)//2]
    
#     return np.mean(profile[indmin - len(profile)//20 : indmin + len(profile)//20])


def maxmincont(arr):
    """Michelson contrast (max - min)/(max + min) taken along the last axis of ``arr``."""
    peak = np.amax(arr, axis=-1)
    trough = np.amin(arr, axis=-1)
    return (peak - trough)/(peak + trough)

def get_freqs(ft_like, real_signal_length):
    """Angular spatial frequencies (2*pi*fftfreq) for the last axis of ``ft_like``.

    ``real_signal_length`` is the physical length spanned by that axis, so the sample
    spacing is real_signal_length / n. Always returns a 1D array of length n.
    """
    n_samples = ft_like.shape[-1]
    spacing = real_signal_length/n_samples
    return 2*np.pi*fftfreq(n_samples, spacing)

def rect_spectrum(x,xsize,p,phi,mt, image_profile = None, phase_offset = None, plot_index = None):
    """Diffraction-order amplitudes of a 1D phase grating.

    The transmitted phase is phi/2 times either ``image_profile`` or the ideal square
    wave from ``grating_equation`` (values +-1, so the phase step between levels is phi).
    The transmission exp(i*phase) is Fourier transformed along the last axis and the
    FFT bins nearest to the order frequencies 2*pi*m/p, m = -mt..mt, are returned.

    Parameters
    ----------
    x : ndarray
        Sample positions; the last axis is the transformed spatial axis.
    xsize : float
        Physical length spanned by the samples (sets the frequency grid).
    p : float
        Grating period.
    phi : float or ndarray
        Phase step; must broadcast against ``x`` / ``image_profile``.
    mt : int
        Highest diffraction order kept.
    image_profile : ndarray, optional
        Profile used instead of the ideal square wave.
    phase_offset : float, optional
        Passed to ``grating_equation`` (ignored when ``image_profile`` is given).
    plot_index : int, optional
        Debug aid: when given, plot |FT| of ``ft[plot_index]`` with the order positions
        marked. Needs matplotlib's ``plt`` in the notebook namespace. Off by default.

    Returns
    -------
    ndarray
        Squeezed stack of [order frequencies (tiled over the leading axes of the FT),
        complex order amplitudes]; complex dtype because of the stacking.
    """
    if image_profile is not None:
        realspace = phi/2 * image_profile
#         realspace = phi/2 * image_profile * len(image_profile)/(xsize)
    else:
        realspace = phi/2 * grating_equation(x,p,0 if phase_offset is None else phase_offset)

    allfreqs = get_freqs(x.squeeze(),xsize) # always returns 1D array with shape x.shape[-1]

    # norm="forward" divides by N, so a uniform phase gives |ft[..., 0]| == 1.
    ft = fft(np.exp(1j*realspace), axis = -1, norm = "forward")

    if plot_index is not None:
        # Debug view; previously always drawn and hard-wired to ft[19], which broke
        # for 1D input or fewer than 20 rows.
        fig,ax = plt.subplots(figsize = (12,8))
        ax.plot(allfreqs,abs(ft[plot_index]))
        for m in range(-mt,mt+1):
            ax.axvline(2*np.pi/p * m, color = "r" if m%2==0 else "k", ls = "--")
        plt.xlim(-3e7,3e7)
        plt.show()
        plt.close(fig)

    ords = 2*np.pi/p * np.arange(-mt,mt+1,1)

    # Nearest FFT bin to each diffraction order's angular frequency.
    locs = np.argmin(abs(allfreqs[:,None] - ords[None,:]),axis = 0)
    return np.array([np.tile(allfreqs[locs], (*ft.shape[:-1],1)),ft[...,locs]]).squeeze()
    
def rect_spectrum2d(x,xsize,y,ysize,px,py,phi,mt):
    """Diffraction orders of a 2D phase grating built from two crossed square waves.

    The phase map is phi/2*(sx + sy) with negative values set to 0, where sx and sy are
    the signs of cos(2*pi*x/px) and cos(2*pi*y/py) (rounded to 10 decimals so near-zero
    cosines count as 0). exp(i*phase) is 2D-FFT'd and normalised by len(x)*len(y).

    Returns
    -------
    list
        [kx of the selected orders, ky of the selected orders, complex amplitudes at all
        (kx, ky) pairs flattened row-major]; orders appear in FFT bin order.

    NOTE(review): orders are matched with an absolute tolerance of 1e-6 in angular
    frequency, so an order that does not fall exactly on an FFT bin is silently dropped.
    """
    n_x, n_y = x.shape[-1], y.shape[-1]
    x_wave = np.sign(np.cos(2*np.pi*x[:,None]/px).round(10))
    y_wave = np.sign(np.cos(2*np.pi*y[None,:]/py).round(10))
    phase = phi/2 * (x_wave + y_wave)
    phase[phase < 0] = 0
    spectrum = fft2(np.exp(1j*phase))/n_x/n_y

    kx_axis = 2*np.pi*fftfreq(spectrum.shape[0], xsize/spectrum.shape[0])
    ky_axis = 2*np.pi*fftfreq(spectrum.shape[1], ysize/spectrum.shape[1])

    order_numbers = np.arange(-mt,mt+1,1)
    kx_orders = 2*np.pi/px * order_numbers
    ky_orders = 2*np.pi/py * order_numbers

    on_x_order = np.any(np.abs(kx_axis[:,None] - kx_orders[None,:]) < 1e-6, axis = -1)
    on_y_order = np.any(np.abs(ky_axis[:,None] - ky_orders[None,:]) < 1e-6, axis = -1)

    keep = np.logical_and.outer(on_x_order, on_y_order)

    return [kx_axis[on_x_order], ky_axis[on_y_order], spectrum[keep]]
    
def best_fit_moire_period(func,xdata,ydata,pg,L,d):
    """Fit ``func`` to a fringe profile, scanning starting guesses for the moire period.

    ``curve_fit`` is a local optimiser, so the fit is restarted from 21 period guesses
    spread over 0.5x-1.5x of L*pg/d. The fit whose period has the smallest 1-sigma
    uncertainty (sqrt of the covariance diagonal) is returned.

    Parameters
    ----------
    func : callable
        Model ``func(x, A, B, period, phase)``, e.g. ``cosine_func``.
    xdata, ydata : ndarray
        Profile positions and intensities.
    pg, L, d : float
        Combined as L*pg/d to centre the period scan.

    Returns
    -------
    ndarray of shape (4,)
        Best-fit [A, B, period, phase].

    Raises
    ------
    RuntimeError
        If ``curve_fit`` failed for every starting guess.
    """
    pmoire_dist = np.linspace(0.5*L * pg/d, 1.5*L*pg/d, 21)

    allparams, allerror = [],[]
    A,B,phi = np.mean(ydata), np.amax(ydata) - np.mean(ydata), 0

    for pmoire in pmoire_dist.squeeze():
        try:
            params, cov = curve_fit(func,xdata,ydata, p0 = [A, B, pmoire, phi], maxfev = 5000)
        except RuntimeError:
            # This start did not converge (e.g. maxfev reached); try the next guess
            # instead of aborting the whole scan.
            continue
        allparams.append(params)
        allerror.append(np.sqrt(np.diag(cov))[2])

    if not allparams:
        raise RuntimeError("curve_fit did not converge for any initial moire period guess")

    # np.argmin would pick a NaN (e.g. sqrt of a negative variance) first; rank NaN last.
    errors = np.asarray(allerror, dtype=float)
    errors = np.where(np.isnan(errors), np.inf, errors)
    return np.array(allparams[int(np.argmin(errors))])


# @numba.njit(fastmath = True, parallel = True)
def get_modsquared_psi(k,a,xarr):
    """|psi|^2 for psi = sum over the last axis of a*exp(-1j*k*xarr).

    ``k``, ``a`` and ``xarr`` must broadcast together; the plane-wave sum runs over the
    last axis of the broadcast result.
    """
    phasors = np.exp(-1j*k*xarr)
    magnitude = np.abs(np.sum(phasors*a, axis = -1))
    return magnitude*magnitude

# def fit_recurs(fit_routine, func, xdata, ydata, params):
#     if ydata.ndim == 1:
#         return fit_routine(func,xdata,ydata, params)
#     else:
#         return fit_recurs(fit_routine, func, xdata, ydata[], params[])
    
def force_broadcast(*arrs, nonunique_lengths = None, nonunique_occurences = None):
    """Insert singleton axes so arrays broadcast by matching axis *lengths*.

    Every input is squeezed. Each distinct axis length (in order of first appearance
    across the inputs) becomes one axis of a common layout, and every array gets a
    size-1 axis at each layout position whose length it lacks. For example, shapes
    (3,) and (4,) become (3, 1) and (1, 4).

    Lengths listed in ``nonunique_lengths`` may occupy several layout axes: the i-th
    such length gets at most ``nonunique_occurences[i]`` axes. Those axes are handed out
    in input order: the first array with that length takes the first slot, the next
    array the next slot, and so on.

    Caveats: alignment is by length only, so unrelated quantities that share a length
    share an axis. Each array's axes are assumed to already be in the same relative
    order as the common layout.

    Parameters
    ----------
    *arrs : array_like
        Arrays or scalars to align; scalars/0-d inputs come back as 0-d arrays.
    nonunique_lengths : list of int, optional
        Axis lengths allowed on more than one layout axis.
    nonunique_occurences : list of int, optional
        Maximum number of layout axes for each entry of ``nonunique_lengths``.

    Returns
    -------
    list of ndarray
        The reshaped arrays, in input order.
    """
    # None sentinels avoid shared mutable defaults; behaviour matches the old [] defaults.
    if nonunique_lengths is None:
        nonunique_lengths = []
    if nonunique_occurences is None:
        nonunique_occurences = []

    # asanyarray lets plain Python scalars through (they have no .squeeze()).
    arrs = [np.asanyarray(arr).squeeze() for arr in arrs]

    # Common layout: distinct lengths in order of first appearance; a non-unique
    # length is repeated up to its allowed number of occurrences.
    output_lengths = []
    for arr in arrs:
        for length in arr.shape:
            if length in nonunique_lengths:
                allowed = nonunique_occurences[nonunique_lengths.index(length)]
                if output_lengths.count(length) < allowed:
                    output_lengths.append(length)
            elif length not in output_lengths:
                output_lengths.append(length)

    axes = np.arange(len(output_lengths))
    # Non-unique slots are set to -1 once claimed, so the next array with that length
    # lands on the next free slot.
    lengths_copy = output_lengths.copy()

    for i in range(len(arrs)):
        if arrs[i].ndim == 0:
            # 0-d values already broadcast against anything.
            continue

        slots = np.array(lengths_copy)
        occupied = np.zeros(axes.shape, dtype=bool)

        # Unique lengths map to the single layout axis with that length.
        for length in arrs[i].shape:
            if length not in nonunique_lengths:
                occupied |= slots == length

        # Non-unique lengths claim the next still-free layout axis with that length.
        for length in arrs[i].shape:
            if length in nonunique_lengths:
                slot = lengths_copy.index(length)
                occupied[slot] = True
                lengths_copy[slot] = -1

        for axis in axes[~occupied]:
            arrs[i] = np.expand_dims(arrs[i],int(axis))

    return arrs

def force_equal_dims(top, btm):
    """Pad two arrays to the same ndim: trailing size-1 axes on ``top``, leading ones on ``btm``.

    Returns [top, btm] as views with the extra singleton axes; the array that already
    has the larger ndim is returned unchanged in shape.
    """
    target_ndim = max(top.ndim,btm.ndim)
    top = top.reshape(top.shape + (1,)*(target_ndim - top.ndim))
    btm = btm.reshape((1,)*(target_ndim - btm.ndim) + btm.shape)
    return [top,btm]
    
    
    
class PGMI:
    """Phase-grating moire interferometer simulation: beam state plus the apparatus it meets.

    The beam is kept in Fourier space as ``self.pos``; the grating ``propagate`` methods
    read ``pos[0]`` as wave numbers k and ``pos[1]`` as the matching complex amplitudes a.
    ``apts_dict`` maps user-chosen keys to apparatus objects exposing ``is2d``,
    ``propagate`` and ``get_rawintensity`` (e.g. RectGrating or GratingFromImage).
    """
#     !to-do eventually: replace integer references to axes with string variable names, i.e. create dict that maps input strings to array axes
#     and define class methods that get values/do something along input axes

    def __init__(self, apts_dict,init_values):
        """Store the apparatus dictionary and read simulation settings from ``init_values``.

        Required keys: "pos", "res", "xbin", "ybin", "mt", "lam", "plam", "L", "d",
        "convmode", "camsize", "x", "x0", "batches".
        Optional keys: "slitx" / "slity" (build slit step functions with "xbin" / "ybin"
        and enable the slit convolution), "y", "Ioffset".
        A non-ndarray "x0" is wrapped into a 1-element array.
        """
#         !can accomplish same idea more efficiently with __slots__ I think

        self.apts_dict = apts_dict

        self.pos = init_values["pos"]

        self.slitxflag = False
        self.slityflag = False

        if init_values.get("slitx") is not None:
            self.slitxflag = True
            self.stepslitx = generate_stepfunc(init_values["slitx"],init_values["xbin"])
            print(init_values["slitx"])
        if init_values.get("slity") is not None:
            self.slityflag = True
            self.stepslity = generate_stepfunc(init_values["slity"],init_values["ybin"])

        self.stepresx = generate_stepfunc(init_values["res"],init_values["xbin"])
        self.stepresy = generate_stepfunc(init_values["res"],init_values["ybin"])
        self.mt = init_values["mt"]
        self.lam = init_values["lam"]
        self.k_l = 2*np.pi/init_values["lam"]
        self.plam = init_values["plam"]
        self.L = init_values["L"]
        self.d = init_values["d"]
        self.convmode = init_values["convmode"]
        self.camsize = init_values["camsize"]
        self.x = init_values["x"]
        self.x0 = init_values["x0"] if type(init_values["x0"]) == np.ndarray else np.array([init_values["x0"]]) # !should probably be grating attribute to allow for mutiple phase stepping
        self.batches = init_values["batches"]
        if init_values.get("y") is not None:
            self.y = init_values["y"]
        self.Ioffset = None
        if init_values.get("Ioffset") is not None:
            self.Ioffset = init_values["Ioffset"]

    def get_apts(self):
        """Return the apparatus dictionary given at construction."""
        return self.apts_dict

    def get_values(self):
        """Return the instance ``__dict__`` (live: mutating it mutates the instance)."""
        return vars(self)

    def get_value(self,key):
        """Return the instance attribute named ``key``."""
#         !not sure if this will work when using __slots__
        return vars(self)[key]

    def propagate_to(self, key):
        """Send the beam through apparatus ``key``, replacing ``self.pos`` with its output."""
        Apparatus = self.get_apts()[key]

        pos,mt,k_l = self.pos, self.mt, self.k_l


        if Apparatus.is2d:
            px,py,phi,L1,L2 = Apparatus.px, Apparatus.py, Apparatus.phi, Apparatus.L1,Apparatus.L2
            self.pos = Apparatus.propagate(pos,px,py,phi,mt,k_l, L1,L2, self.x0)
        else:
            p,phi,L1,L2 = Apparatus.p,Apparatus.phi, Apparatus.L1,Apparatus.L2
            self.pos = Apparatus.propagate(pos,p,phi,mt,k_l, L1,L2, self.x0)

    def generate_after(self, key):
        """Compute the camera intensity after apparatus ``key`` and fit the moire fringes.

        Steps (1D apparatus only):
          1. Trim ``self.x`` to a multiple of ``self.batches`` points and evaluate
             ``Apparatus.get_rawintensity`` on each chunk in parallel threads.
          2. If ``self.plam`` has more than one element, weight by it and sum over the
             second-to-last axis (presumably a wavelength spectrum -- confirm).
          3. Convolve with the slit (if configured) and resolution step functions.
          4. Fit ``cosine_func`` to every profile with ``best_fit_moire_period`` and set
             the contrast |B/A| (A shifted by ``self.Ioffset`` when it is set).

        Sets ``self.x``, ``self.raw``, ``self.intensity``, ``self.FT``, ``self.absFT``,
        ``self.fitparams`` and ``self.contrast``; also reshapes ``self.stepslitx``,
        ``self.stepresx`` and ``self.Ioffset`` in place for broadcasting.

        Raises
        ------
        NotImplementedError
            For 2D apparatus (``Apparatus.is2d``), which is not implemented yet.
        """
        Apparatus = self.get_apts()[key]
        L,d,pos,camsize = self.L, self.d, self.pos, self.camsize

        batches = self.batches
        # Drop trailing points so x splits into `batches` equal chunks.
        self.x = np.linspace(self.x[0],self.x[-1], len(self.x) - len(self.x) % batches)

        xbatches = np.array([self.x[i*len(self.x)//batches:(i+1)*len(self.x)//batches] for i in range(batches)])

        if Apparatus.is2d:
            # The 2D path below was never finished: it left `p` and `self.intensity`
            # undefined, so the fitting step always crashed. Fail loudly instead.
            raise NotImplementedError("generate_after is not implemented for 2D apparatus")

#                 currently not working without nested loops

#             self.y = np.linspace(self.y[0],self.y[-1], len(self.y) - len(self.y) % batches )
#             ybatches = np.array([self.y[i*len(self.y)//batches:(i+1)*len(self.y)//batches] for i in range(batches)])

#             xyspace = np.array(np.meshgrid(self.x,self.y))
#             xybatches = np.moveaxis(xyspace.reshape(2, int(len(self.x)/np.sqrt(batches)), int(len(self.y)/np.sqrt(batches)), batches), -1, 0)

#             self.raw = np.array(jb.Parallel(n_jobs = -1, prefer = "threads")(jb.delayed(Apparatus.get_rawintensity)(pos,xybatches[i]) for i in range(batches)))

#             self.raw = np.moveaxis(self.raw, 0, -3)
#             self.raw = np.reshape(self.raw,(*self.raw.shape[:-3],self.x.shape[0], self.y.shape[0]))

#                 needs to incorporate force_broadcast

#             for _ in range(self.raw.ndim - self.stepresx.ndim-1):
#                 self.stepresx = np.expand_dims(self.stepresx, 0)
#                 self.stepresy = np.expand_dims(self.stepresy, 0)

#             self.stepresx = np.expand_dims(self.stepresx, -1)
#             self.stepresy = np.expand_dims(self.stepresy, 0)

#             if self.slityflag:
#                 for _ in range(self.raw.ndim - self.stepslity.ndim):
#                     self.stepslity = np.expand_dims(self.stepslity, 0)
#                 self.intensity = oaconvolve(self.raw,self.stepslity, mode = self.convmode, axes = -1)

#             if self.slitxflag:
#                 for _ in range(self.raw.ndim - self.stepslitx.ndim - 1):
#                     self.stepslitx = np.expand_dims(self.stepslitx, 0)
#                 self.stepslitx = np.expand_dims(self.stepslitx, -1)

#                 self.intensity = oaconvolve(self.intensity,self.stepslitx, mode = self.convmode, axes = -2)

#             self.slitconv = self.intensity
#             self.intensity = oaconvolve(oaconvolve(self.intensity, self.stepresy, mode = self.convmode, axes = -1), self.stepresx, mode = self.convmode, axes = -2)
#             self.x, self.y = np.linspace(0,camsize, self.intensity.shape[-2]), np.linspace(0,camsize, self.intensity.shape[-1])

        else:
            p = Apparatus.p
            self.raw = np.array(jb.Parallel(n_jobs = -1, prefer = "threads")(jb.delayed(Apparatus.get_rawintensity)(pos,xi) for xi in xbatches))
            print(self.raw.shape, "after parallel")
            # Batch axis was stacked first; move it next to the x axis and merge them.
            self.raw = np.moveaxis(self.raw, 0,-2)

            self.raw = np.reshape(self.raw,(*self.raw.shape[:-2], self.raw.shape[-2] * self.raw.shape[-1]  ))
            print(self.raw.shape, "after reshape")


            if self.plam.flatten().shape[0] > 1:
                self.raw, self.plam = force_broadcast(self.raw,self.plam)
                self.raw = np.sum(self.raw * self.plam, axis = -2)

            if self.slitxflag:
                self.raw, self.stepslitx = force_equal_dims(self.raw,self.stepslitx)

            self.raw, self.stepresx = force_equal_dims(self.raw,self.stepresx)
            print(self.raw.shape, "before convolution")

            if self.slitxflag:
                self.intensity = oaconvolve(oaconvolve(self.raw,self.stepslitx, mode = self.convmode, axes = -1), self.stepresx, mode = self.convmode, axes = -1)
            else:
                self.intensity = oaconvolve(self.raw, self.stepresx, mode = self.convmode, axes = -1)

            self.x = np.linspace(0,camsize, self.intensity.shape[-1])

        self.FT = fft(self.intensity, axis = -1)
        self.absFT = abs(self.FT)

        print("Finding contrast")

        if self.intensity.ndim < 3 :
            # One profile per entry of d (rows of the intensity array).
            fitparams = jb.Parallel(n_jobs = -1)(jb.delayed(best_fit_moire_period)(cosine_func,self.x,self.intensity[i],p, L, di) for i, di in enumerate(d))

            self.fitparams = np.array(fitparams)

        else:
            # Fit every profile along the last axis; the leading axis indexes d.
            # np.ndindex visits (i, j, ...) in the same order as nested loops and also
            # handles intensity arrays with more than three dimensions.
            fitparams = np.empty((*self.intensity.shape[:-1],4))
            for idx in np.ndindex(*self.intensity.shape[:-1]):
                fitparams[idx] = best_fit_moire_period(cosine_func,self.x,self.intensity[idx],p,L,d[idx[0]])

            self.fitparams = fitparams

        if self.Ioffset is not None:
            # Contrast with a constant intensity offset added to the fitted mean A.
            self.Ioffset, fitparams_offset = force_broadcast(self.Ioffset,self.fitparams[...,0])
            fitparams_offset = fitparams_offset + self.Ioffset
            fitparams_offset, unmodified_B = force_broadcast(fitparams_offset, self.fitparams[...,1])
            self.contrast = abs(unmodified_B/fitparams_offset)

        else:
            self.contrast = abs(self.fitparams[...,1]/self.fitparams[...,0])
                    
                    
# #     def cont_density(self):
# #         pos,mt,k_l = self.pos, self.mt, self.k_l
# #            need to replicate 2PGMI simulations density plots
        

    
class RectGrating(PGMI):
    """1D binary (square-wave) phase grating, used as an apparatus in ``PGMI.apts_dict``.

    Subclasses PGMI by name only: ``__init__`` does not call ``PGMI.__init__``.
    init_values keys read: "p", "phi", "L1", "L2" and optionally "phase_offset".
    """

    def __init__(self,init_values):
        self.is2d = False
        self.p = init_values["p"]
        self.phi = init_values["phi"]

        # Align L2/L1 by axis length so their ratio broadcasts.
        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
#         print(self.L2.shape,self.L1.shape)
        # M = L1/(L1 + L2); propagate() multiplies wave numbers by it (presumably the
        # inverse of the point-source geometric magnification -- confirm).
        self.M = 1/(1 + self.L2/self.L1).squeeze()
        # NOTE(review): stored as self.phase, but propagate() never passes it to
        # rect_spectrum's phase_offset, so this setting currently has no effect.
        if init_values.get("phase_offset") is not None:
            self.phase = init_values["phase_offset"]



    def propagate(self,pos,p,phi,mt,k_l,L1,L2, x0):
        """Diffract the beam ``pos`` on this grating, then propagate it a distance L2.

        pos : ``pos[0]`` wave numbers k, ``pos[1]`` complex amplitudes a.
        p, phi, mt : grating period, phase step, highest order kept (-mt..mt).
        k_l : wave number 2*pi/lambda (PGMI.propagate_to passes ``self.k_l``).
        L1 : not used here; the scaling comes from ``self.M`` built at construction.
        L2 : distance after the grating.
        x0 : lateral grating shift; also stored on ``self.x0``.

        Returns [nextk, nexta]: every incoming k combined with every grating order
        (k + k_g, flattened into the last axis) and scaled by M, with amplitudes
        multiplied by the order amplitude and the propagation phase.
        Side effect: ``self.M`` is replaced by its broadcast form.
        """
        def B(k_l,k):
            # Looks like the paraxial (Fresnel) propagation phase per unit distance.
            return -k**2/(2*k_l)
        self.x0 = x0
#         !change this when x0 becomes grating attribute and not PGMI attribute; should also specify spectrum during __init__

        # Grating spectrum computed over 10 periods sampled at 1001 points.
        xsize = 10*p
        xft = np.linspace(0,xsize,1001)
        phi,x0, xft = force_broadcast(phi,x0,xft)
        # kg is complex-typed (rect_spectrum stacks it with the complex amplitudes)
        # but its imaginary part is zero.
        kg, ag = rect_spectrum(xft - x0,xsize,p, phi, mt).squeeze()

#         specifies the spectrum of the grating

        k = pos[0]
        a = pos[1]

#         get current position in Fourier space

#         print(k.shape,a.shape)

        # An axis of k whose length is a multiple of the number of grating orders is
        # taken to be an existing order axis (heuristic: any axis of such a length
        # triggers this). The new orders then go on their own axis and the last two
        # axes are merged below; otherwise everything is aligned by length.
        if np.any(np.array(k.shape) % kg.shape[-1] == 0):
            mt_ind = -2
            k, kg = force_broadcast(k,kg, nonunique_lengths = [kg.shape[-1]], nonunique_occurences = [2])
            a, ag = force_broadcast(a,ag, nonunique_lengths = [ag.shape[-1]], nonunique_occurences = [2])

        else:
            mt_ind = -1
            k, a, kg, ag = force_broadcast(k,a,kg,ag)

#         print(k.shape,a.shape)

        k = k + kg
        a = a * ag
#         array assignment is averse to += and *=

        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))
#         ensure that last dimension gets larger in length without adding more dimensions


#         print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)
        self.M, L2, k, a, k_l = force_broadcast(self.M, L2, k, a, k_l)
#         print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)


        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)
#         scale by magnification factor

#         print(nextk.shape,nexta.shape)

        return [nextk,nexta]

    def get_rawintensity(self,pos,xarr):
        """Intensity |sum a*exp(-i k x)|^2 of the beam ``pos`` at the positions ``xarr``."""
        k = pos[0]
        a = pos[1]

#         print("%.2f %% done" % (xarr[...,0]/2.5e-2*100))
#         clear_output(wait = True)

        # Presumably moves the order axis last (force_broadcast appends xarr's axis at
        # the end) so get_modsquared_psi sums over orders -- confirm.
        k,a,xarr = [np.swapaxes(arr,-1,-2) for arr in force_broadcast(k,a,xarr)]
#         !is this necessary?

        return get_modsquared_psi(k,a,xarr)
    

# class RectGrating2D(PGMI):
#    needs to be updated with force_broadcast
#     def __init__(self,init_values):
#         self.px, self.py = init_values["px"],init_values["py"]
#         self.is2d = True
        
#         self.phi = init_values["phi"]
#         self.L1 = init_values["L1"][:,None]
#         self.L2 = init_values["L2"][:,None]
#         self.M = 1/(1 + self.L2/self.L1)

    
#     def propagate(self,pos,px, py, phi,mt,k_l,L1,L2, M_acc, x0):
#         def B(k_l,k):
#             return -k**2/(2*k_l)
        
#         xlen, ylen = 10*px, 10*py
#         kgx, kgy, ag = rect_spectrum2d(np.linspace(0,xlen,1001),xlen,np.linspace(0,xlen,1001),ylen,px,py,phi,mt)
#         ag = ag.reshape((kgx.shape[-1], kgy.shape[-1]))
        
# #         specifies the spectrum of the grating

#         kx = pos[0]
#         ky = pos[1]
#         a = pos[2]

# #         get current position in Fourier space

    
    
#         if L2.flatten().shape[0] in kx.shape:
#             for _ in range(kx.ndim - kgx.ndim):
#                 kgx = np.expand_dims(kgx,0) 
            
#             for _ in range(ky.ndim - kgy.ndim):
#                 kgy = np.expand_dims(kgy,0) 
                
#             for _ in range(a.ndim - ag.ndim):
#                 ag = np.expand_dims(ag,0)
                
#             kx = kx[...,None] + kgx[None,:]
#             ky = ky[...,None] + kgy[None,:]

#             a = a[...,None,None] * ag
            
#         else:

#             kx = kx[...,None] + kgx[None,:]
#             ky = ky[...,None] + kgy[None,:]
#             a = a[...,None,None] * ag[None,:]


            
#         for _ in range (kx.ndim - self.M.ndim):
#             self.M = np.expand_dims(self.M,-1)
            
#         for _ in range (kx.ndim - L2.ndim):
#             L2 = np.expand_dims(L2,-1)
        
#         nextkx = kx*self.M
#         nextky = ky*self.M
        
#         dimdiff = a.ndim - kx.ndim
        
#         for _ in range(dimdiff):
#             kx = np.expand_dims(kx,-1)
#             ky = np.expand_dims(ky, -1 - dimdiff)
#             self.M = np.expand_dims(self.M,-1)
#             L2 = np.expand_dims(L2,-1)
        
#         nexta = a*np.exp(-1j*B(k_l,kx + ky)*self.M*L2*M_acc) 
# #         scale by magnification factor

#         return [nextkx,nextky, nexta]
    
#     def get_rawintensity(self,pos,xypair):
#         kx = pos[0]
#         ky = pos[1]
#         a = pos[2]
#         a = np.swapaxes(a,-2,-3)
#         xarr, yarr = xypair
        
        
# #         print("done %",round(yarr[0]/1e-2 * 100,3))
        
#         clear_output(wait = True)
        
#         for _ in range(2):
#             kx,ky,a = np.expand_dims(kx,1),np.expand_dims(ky,1),np.expand_dims(a,1)
# #         makes space for x and y
#         dimdiff = a.ndim - kx.ndim
#         for _ in range(dimdiff):
#             kx, ky = np.expand_dims(kx,-1), np.expand_dims(ky,-1 - dimdiff)
        
#         for _ in range(a.ndim - xarr.ndim - 2):
#             xarr,yarr = np.expand_dims(xarr,-1), np.expand_dims(yarr,-1)
        
#         xarr, yarr = np.expand_dims(xarr,0), np.expand_dims(yarr,0)
#         xarr, yarr = np.expand_dims(xarr,-1), np.expand_dims(yarr,0)
        

        
#         psi = np.sum(np.exp(-1j*(kx*xarr + ky*yarr))*a, axis = (-1,-2,-3,-4))
#         return abs(psi)*abs(psi)    
    
    

    
class GratingFromImage(PGMI):
    """1D phase grating whose phase profile is supplied as a sampled array.

    Subclasses PGMI by name only: ``__init__`` does not call ``PGMI.__init__``.
    init_values keys read: "p", "phi", "L1", "L2".
    """

    def __init__(self,imgarr, init_values, profile_size,  profile):
        """Store grating parameters, the image and the sampled phase profile.

        imgarr : image array; if 3D, only channel 0 of the last axis is kept (as float64).
            NOTE(review): imgarr is stored but not used by propagate().
        profile_size : physical length spanned by ``profile`` (presumably same unit as p).
        profile : sampled profile passed to rect_spectrum as ``image_profile``.
        """


        self.is2d = False
        self.p = init_values["p"]
        self.phi = init_values["phi"]

        # Align L2/L1 by axis length so their ratio broadcasts.
        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
#         print(self.L2.shape,self.L1.shape)
        # M = L1/(L1 + L2); propagate() multiplies wave numbers by it.
        self.M = 1/(1 + self.L2/self.L1).squeeze()

        self.imgarr = imgarr if type(imgarr) == np.ndarray else np.array(imgarr)
        self.profile = profile
        self.profile_size = profile_size

        # Keep only the first channel; np.empty gives a float64 array.
        if self.imgarr.ndim == 3:
            gray = np.empty((self.imgarr.shape[:-1]))
            gray[...,:] = self.imgarr[...,0]
            self.imgarr = gray
#        !add image processing



    def propagate(self,pos,p,phi,mt,k_l,L1,L2,x0):
        """Diffract the beam ``pos`` on the sampled profile, then propagate a distance L2.

        Same contract as RectGrating.propagate, except the grating spectrum comes from
        ``self.profile`` over ``self.profile_size``; ``self.x0`` is not updated.
        L1 is not used; the scaling comes from ``self.M``.
        Side effect: ``self.M`` is replaced by its broadcast form.
        """
        def B(k_l,k):
            # Looks like the paraxial (Fresnel) propagation phase per unit distance.
            return -k**2/(2*k_l)


        k = pos[0]
        a = pos[1]
#         get current position in Fourier space

        xft = np.linspace(0,self.profile_size,len(self.profile))
        phi,x0, xft = force_broadcast(phi,x0,xft)

        # kg is complex-typed (stacked with the amplitudes) with zero imaginary part.
        kg, ag = rect_spectrum(xft - x0,self.profile_size, p, phi, mt, image_profile=self.profile).squeeze()

        phi,kg,ag = force_broadcast(phi,kg,ag)


#         print(k.shape,a.shape)

        # Heuristic detection of an existing diffraction-order axis on k (see comment
        # in RectGrating.propagate). NOTE(review): the (2*mt + 1) test looks redundant
        # when kg.shape[-1] == 2*mt + 1 -- confirm.
        if np.any((np.array(k.shape) % kg.shape[-1] == 0) | (np.array(k.shape) % (2*mt + 1) == 0)):
            mt_ind = -2
            k, kg = force_broadcast(k,kg, nonunique_lengths = [kg.shape[-1]], nonunique_occurences = [2])
            a, ag = force_broadcast(a,ag, nonunique_lengths = [ag.shape[-1]], nonunique_occurences = [2])

        else:
            mt_ind = -1
            k, a, kg, ag = force_broadcast(k,a,kg,ag)

#         print(k.shape,a.shape)

        k = k + kg
        a = a * ag
#         array assignment is averse to += and *=

        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))
#         ensure that last dimension gets larger in length without adding more dimensions


#         print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)
        self.M, L2, k, a, k_l = force_broadcast(self.M, L2, k, a, k_l)
#         print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)


        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)
#         scale by magnification factor

#         print(nextk.shape,nexta.shape)

        return [nextk,nexta]


    def get_rawintensity(self,pos,xarr):
        """Intensity |sum a*exp(-i k x)|^2 of the beam ``pos`` at the positions ``xarr``."""
        k = pos[0]
        a = pos[1]

#         print("%.2f %% done" % (xarr[...,0]/2.5e-2*100))
#         clear_output(wait = True)

        # Presumably moves the order axis last so get_modsquared_psi sums over orders -- confirm.
        k,a,xarr = [np.swapaxes(arr,-1,-2) for arr in force_broadcast(k,a,xarr)]
#         !is this necessary?


        return get_modsquared_psi(k,a,xarr)
    
    
# class ForkGrating(PGMI):
#     def __init__(self):

# class Sample(PGMI):
#     def __init__(self):
    