# In [None]:
import numpy as np
import pandas as pd
from scipy.fft import fft,fft2, fftfreq, fftshift, ifft
from scipy.signal import oaconvolve

from PIL import Image, ImageFilter

from IPython.display import clear_output
from scipy.ndimage import rotate
import matplotlib.pyplot as plt



# In [1]:
def assert_ndarr(a):
    """Coerce ``a`` to an ndarray with at least one dimension.

    A non-scalar ndarray is returned unchanged (the very same object).
    ``None`` is rejected with ``TypeError``. Anything else (Python scalars,
    lists, 0-d arrays) is wrapped as ``np.array([a])``, which adds a leading
    axis of length 1.
    """
    if a is None:
        raise TypeError("Input is None")
    if type(a) is np.ndarray and a.ndim > 0:
        return a
    return np.array([a])

def normalize_profile(profile, axis = -1, bounds = (-1, 1)):
    """Linearly rescale ``profile`` so its extrema along ``axis`` map onto ``bounds``.

    Parameters
    ----------
    profile : array_like
        Values to rescale.
    axis : int
        Axis along which the minimum and maximum are taken.
    bounds : sequence of two numbers
        Target ``(lower, upper)`` range. The default is an immutable tuple
        rather than a shared list.

    Returns
    -------
    ndarray
        The rescaled profile, laid out by ``force_broadcast`` (which squeezes
        singleton axes of its inputs).

    Notes
    -----
    A profile that is constant along ``axis`` has ``max == min``, so the
    division yields NaN/inf.
    """
    profile, profile_min, profile_max = force_broadcast(
        profile, np.amin(profile, axis=axis), np.amax(profile, axis=axis))
    return (bounds[1] - bounds[0]) * (profile - profile_min)/(profile_max - profile_min) + bounds[0]

def get_param_axis(ndarr, a):
    """Return the index of the first axis of ``ndarr`` whose length matches ``a``.

    ``a`` may be an ndarray, in which case its squeezed length is used, or
    an integer length. If no axis matches, ``tuple.index`` raises ``ValueError``.
    """
    if type(a) == np.ndarray:
        wanted_length = len(a.squeeze())
    else:
        wanted_length = a
    return ndarr.shape.index(wanted_length)

def flatten_after_axis(*arrs, axis = -2):
    """Merge every axis from ``axis`` onward into one trailing axis.

    Each array is reshaped on its own. The result is a list in input order.
    """
    flattened = []
    for arr in arrs:
        lead_shape = arr.shape[:axis]
        tail_size = np.prod(arr.shape[axis:])
        flattened.append(np.reshape(arr, (*lead_shape, tail_size)))
    return flattened

def retain_shape_after_index(arr,locs,only_last_axis = True):
    """Index ``arr`` with ``locs`` and keep all of its axes except the last.

    When ``only_last_axis`` is true (the default) the selection is
    ``arr[..., locs]``; otherwise it is ``arr[locs]``. The result is always
    reshaped to ``arr.shape[:-1] + (-1,)``, so a scalar ``locs`` still leaves
    a trailing axis of length 1. Input is first passed through ``assert_ndarr``.
    """
    arr = assert_ndarr(arr)
    selected = arr[..., locs] if only_last_axis else arr[locs]
    return selected.reshape(*arr.shape[:-1], -1)


def generate_stepfunc(width,xbin = None, n_pts = None):
    """Sample a top-hat window of width ``width`` centred on ``[0, 2*width]``.

    The result is 1 for ``width/2 <= x < 3*width/2`` and 0 elsewhere
    (``np.heaviside`` returns 1 at 0). The sample count is ``n_pts`` if it
    is given, otherwise ``int(2*width/xbin)``, so one of the two must be
    supplied.
    """
    if n_pts is None:
        n_pts = int(2*width/xbin)
    axis_pts = np.linspace(0, 2*width, n_pts)
    rising_edge = np.heaviside(axis_pts - width/2, 1)
    falling_edge = np.heaviside(axis_pts - 3*width/2, 1)
    return rising_edge - falling_edge

def generate_source_grating(period, duty_cycle, binsize, n_periods = 50):
    """Build a binary source grating (1 = open, 0 = opaque) of ``n_periods`` periods.

    One period is sampled with ``int(period/binsize)`` points on the
    half-open interval ``[0, period)`` and then tiled. The right endpoint is
    excluded because the next tile supplies that sample. Including it would
    add a duplicate sample to every period, making the tiled grating's
    period one bin too long. Within each period the first ``duty_cycle``
    fraction is open.

    Returns
    -------
    ndarray
        For scalar inputs, a 1-D array of length
        ``n_periods * int(period/binsize)``.
    """
    pts_per_period = int(period/binsize)
    single_source_x = np.linspace(0, period, pts_per_period, endpoint=False)
    # heaviside(x, 1) is 1 for every x >= 0; subtracting the shifted step
    # closes the opening at duty_cycle * period.
    single_source = np.heaviside(single_source_x, 1) - np.heaviside(single_source_x - duty_cycle*period, 1)
    return np.tile(single_source, n_periods)

def generate_single_source(period, duty_cycle, binsize, n_periods, current_period):
    """Return a source profile in which only slit number ``current_period`` is open.

    The axis covers ``[0, n_periods*period]`` with
    ``int(n_periods*period/binsize) + 1`` samples (both endpoints included).
    The output is 1 on ``[current_period*period, current_period*period +
    duty_cycle*period)`` and 0 elsewhere.
    """
    total_length = (n_periods*period)
    axis_pts = np.linspace(0, total_length, int(total_length/binsize) + 1)
    slit_start = current_period*period
    slit_stop = current_period*period + duty_cycle*period
    return np.heaviside(axis_pts - slit_start, 1) - np.heaviside(axis_pts - slit_stop, 1)

def sinc(x):
    """Unnormalized sinc ``sin(x)/x``, with the removable singularity at 0 set to 1.

    This differs from NumPy's normalized ``np.sinc``, which computes
    ``sin(pi*x)/(pi*x)``.

    The input is never modified; zeros are replaced only in a temporary
    denominator. Scalars and array-likes are accepted.

    Returns
    -------
    ndarray or numpy scalar
        Same shape as ``x``; a numpy scalar for scalar input.
    """
    x = np.asarray(x)
    is_zero = x == 0
    # Use a harmless denominator at the zeros so 0/0 is never evaluated.
    safe_x = np.where(is_zero, 1, x)
    out = np.where(is_zero, 1, np.sin(x) / safe_x)
    return out[()]

def squareFT(width,k):
    """Return ``sinc(width*k/2)`` using this file's unnormalized ``sinc`` (``sin(x)/x``).

    This is the Fourier transform of a top-hat of width ``width`` divided by
    its area, so it equals 1 at ``k = 0``.
    """
    half_argument = width*k/2
    return sinc(half_argument)

def cosine_func(x,A,B,p,phi):
    """Sinusoid ``A + B*cos(2*pi*x/p + phi)``: offset ``A``, amplitude ``B``, period ``p``, phase ``phi``."""
    phase = 2*np.pi*x/p + phi
    return A + B*np.cos(phase)


def max_index(ndarr):
    """Return the N-d index tuple of the first occurrence of the largest element."""
    flat_position = np.argmax(ndarr)
    return np.unravel_index(flat_position, ndarr.shape)

def min_index(ndarr):
    """Return the N-d index tuple of the first occurrence of the smallest element."""
    flat_position = np.argmin(ndarr)
    return np.unravel_index(flat_position, ndarr.shape)

def lowpass(data, minperiod, pxlsize, order):
    """Zero-phase Butterworth low-pass filter (``scipy.signal.filtfilt``).

    Parameters
    ----------
    data : array_like
        Signal to filter. ``filtfilt`` filters along its last axis by default.
    minperiod : float
        Period, in the same units as ``pxlsize``, used to set the cutoff
        (see the note below).
    pxlsize : float
        Sample spacing.
    order : int
        Butterworth filter order.

    Returns
    -------
    ndarray
        Filtered signal with the same shape as ``data``.
    """
    # butter/filtfilt are not imported in the file's import cell, so calling
    # this function raised NameError. Import them locally.
    from scipy.signal import butter, filtfilt

    # NOTE(review): for digital filters scipy normalizes Wn to the Nyquist
    # frequency 1/(2*pxlsize). This cutoff therefore corresponds to a period
    # of 2*minperiod rather than minperiod. Confirm that is intended.
    normal_cutoff = pxlsize/minperiod
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return filtfilt(b, a, data)


def grating_equation(xarr,p, phase_offset):
    """Binary (+1/-1) square-wave grating profile of period ``p``.

    Computes ``sign(sin(2*pi*x/p + phase_offset))``. ``np.sign`` returns 0
    exactly where the sine is zero.
    """
    argument = 2*np.pi*xarr/p + phase_offset
    return np.sign(np.sin(argument))

def grating_equation_2d(xarr,yarr,px,py):
    """Binary (+1/-1) 2-D grating profile on the grid spanned by ``xarr`` and ``yarr``.

    Computes ``sign(cos(2*pi*x/px) + cos(2*pi*y/py) - 1)``.

    If the coordinate arrays have different shapes, ``force_broadcast``
    places them on separate axes. If their shapes are equal, an explicit
    outer grid ``xarr[:, None]``, ``yarr[None, :]`` is built instead, because
    ``force_broadcast`` would merge two equal-length axes into one.
    """
    # Explicit if/else: the former one-line conditional lacked parentheses, so
    # for differing shapes it bound the whole force_broadcast list to xarr.
    if xarr.shape != yarr.shape:
        xarr, yarr = force_broadcast(xarr, yarr)
    else:
        xarr, yarr = xarr[:, None], yarr[None, :]
    return np.sign(np.cos(2*np.pi*xarr/px) + np.cos(2*np.pi*yarr/py) - 1)

def maxmincont(arr):
    """Michelson contrast ``(max - min) / (max + min)`` along the last axis."""
    peak = np.amax(arr, axis=-1)
    trough = np.amin(arr, axis=-1)
    return (peak - trough)/(peak + trough)

def get_freqs(ft_like, param_ind, real_signal_length):
    """Angular FFT frequencies for axis ``param_ind`` of ``ft_like``.

    ``real_signal_length`` is coerced with ``assert_ndarr``. For each
    physical length ``L`` in it, this computes ``2*pi*fftfreq(n, L/n)``,
    where ``n`` is the number of samples along ``param_ind``. The rows are
    stacked and squeezed, so a single length gives a 1-D array of ``n``
    frequencies.
    """
    lengths = assert_ndarr(real_signal_length)
    rows = []
    for length in lengths:
        n_samples = ft_like.shape[param_ind]
        rows.append(fftfreq(n_samples, length/n_samples))
    return 2*np.pi*np.squeeze(rows)

def lookup_inds(vals_to_find, available_vals):
    """Return the positions in ``available_vals`` whose value occurs in ``vals_to_find``.

    Matching is exact, so floating-point values should be rounded
    consistently beforehand. The result is passed through ``np.squeeze``:
    a single match gives a 0-d array and no match gives an empty array.

    The previous version pre-filtered with
    ``np.intersect1d(..., assume_unique=True)``. When either input contains
    duplicates, that flag makes ``intersect1d`` report values repeated within
    one input, so positions of values absent from ``vals_to_find`` could be
    returned. ``np.isin`` alone is the correct membership test.
    """
    matches = np.isin(available_vals, vals_to_find)
    return np.squeeze(np.nonzero(matches))


# def get_complement_inds(arr,arr_complement,target,decim_cutoff,N_objects,mt_for_roll = None):
#     # perhaps sorting along self.d axis would work
    
    
#     diff = target - arr

#     major_complement_inds = lookup_inds(np.round(diff, decim_cutoff),np.round(arr_complement, decim_cutoff))
    
#     # print("target:", np.round(target, decim_cutoff),"\n")
#     # print("arr:", np.round(arr, decim_cutoff),"\n")
#     # print("arr_complement:", np.round(arr_complement, decim_cutoff),"\n")
#     # print("diff:", np.round(diff, decim_cutoff),"\n")

    
    
#     if mt_for_roll is not None:
#         major_complement_inds = np.roll(major_complement_inds, -(2*mt_for_roll - 1))
#         print(major_complement_inds.shape)
#         # print("major_complement_inds",major_complement_inds,"\n")

#         replace_locs = np.arange(0,(2*mt_for_roll)**2, 2*mt_for_roll)

#         minor_complement_inds = np.roll(major_complement_inds, -2*mt_for_roll)
        
#         major_complement_inds[replace_locs] = minor_complement_inds[replace_locs]
        
#         # print("major_complement_inds",major_complement_inds,"\n")

        
#         major_inds = lookup_inds(np.round(target - arr_complement, decim_cutoff), np.round(arr, decim_cutoff))

#     return [major_inds,major_complement_inds]

def get_complement_inds(arr,arr_complement,target):
    """Find index pairs ``(i, j)`` with ``arr[i] + arr_complement[j]`` close to ``target``.

    Every cyclic offset between the two arrays is tried, and ``np.isclose``
    does the comparison. Pairs are returned grouped by offset.

    Parameters
    ----------
    arr, arr_complement : ndarray
        1-D arrays of equal length. The offset loop runs over ``arr.shape[-1]``.
    target : scalar

    Returns
    -------
    list of two int ndarrays
        ``[arr_inds, arr_complement_inds]``, each passed through ``np.squeeze``.
        A single match gives 0-d arrays; no match gives empty arrays.
    """
    n = arr.shape[-1]
    arr_hits, comp_hits = [], []

    for offset in range(n):
        # rolled[j] == arr_complement[(j + offset) % n]
        rolled_sum = arr + np.roll(arr_complement, -offset, axis=-1)
        # flatnonzero always returns a 1-D array. The former
        # squeeze(nonzero(...)) gave a 0-d array for a single hit, which
        # np.concatenate rejects.
        hits = np.flatnonzero(np.isclose(rolled_sum, target))
        if hits.size:
            arr_hits.append(hits)
            comp_hits.append((hits + offset) % n)

    if not arr_hits:
        return [np.array([], dtype=int), np.array([], dtype=int)]

    # Concatenate once at the end instead of growing arrays inside the loop.
    return [np.squeeze(np.concatenate(arr_hits)).astype(int),
            np.squeeze(np.concatenate(comp_hits)).astype(int)]

def get_a_k_final(a,k,target):
    """Sum ``a[i] * conj(a[j])`` over index pairs whose wavenumbers satisfy ``k[i] - k[j] == target``.

    The pairs are found with ``get_complement_inds(k, -k, target)``. For
    ``target == 0`` every index is simply paired with itself.

    Returns
    -------
    [summed_amplitude, k_difference]
        The amplitude sum over the last axis, and the wavenumber difference
        taken from the first matched pair.
    """
    neg_k = -k
    conj_a = np.conj(a)

    if np.squeeze(target) == 0:
        # k[i] + (-k[i]) is trivially zero.
        idx = np.arange(k.shape[-1])
        conj_idx = idx
    else:
        idx, conj_idx = get_complement_inds(k, neg_k, target)

    products = retain_shape_after_index(a, idx) * retain_shape_after_index(conj_a, conj_idx)
    summed_k = retain_shape_after_index(k, idx) + retain_shape_after_index(neg_k, conj_idx)
    return [np.sum(products, axis=-1), summed_k[..., 0]]

def fork_get_regions(nd_intens, xcam, ycam, pg, L, di):
    """Trace two bright features through an image and label two regions by their separation.

    Two walkers start at the brightest pixel of column ``starty``, one inside
    the first ``pm_in_pts`` points of axis -2 and one inside the last
    ``pm_in_pts`` points. Each walker then follows the rounded intensity
    gradient for ``imgshp[1] - 2*starty`` steps. Whichever end of the trace
    (start or finish) has the larger horizontal separation is returned as
    the "fork" region.

    Parameters
    ----------
    nd_intens : ndarray
        Intensity image(s). The last two axes form the image; the indexing
        treats axis -2 as x and axis -1 as y.
    xcam : float
        Presumably the camera extent along axis -2, in the units of
        ``L*pg/di``. TODO confirm.
    ycam : float
        Unused.
    pg, L, di : float
        Combined as the length ``L*pg/di`` (presumably the moire period),
        which sets the width of the two search windows.

    Returns
    -------
    [fork_region, nofork_region]
        Each is ``[x_left, x_right, y_start, y_end]``.

    NOTE(review): the following visible issues are left unchanged here:
      * ``right_startx`` is an argmax within the slice that starts at
        ``imgshp[0] - pm_in_pts`` and is never shifted back by that offset.
      * ``np.rint`` returns floats, so after the first gradient step the
        positions become floats and the next indexing raises ``IndexError``.
      * With extra leading axes the positions are arrays. ``xl += ...`` then
        also mutates ``left_startx``/``right_startx`` (same objects), and
        ``startx_dist > endx_dist`` is ambiguous.
      * Nothing keeps the walkers inside the image bounds.
    """
    imgshp = nd_intens.shape[-2:]
    # Search-window width in points: the length L*pg/di as a fraction of xcam,
    # times the number of points along axis -2.
    pm_in_pts = int((L*pg/di)/xcam * imgshp[0])
    # Start 1/20 of the way along axis -1.
    starty = int(imgshp[1]/20)
    
    left_startx = np.argmax(nd_intens[..., :pm_in_pts,starty], axis = -1)
    right_startx = np.argmax(nd_intens[..., imgshp[0] - pm_in_pts:,starty], axis = -1)
    
    startx_dist = np.abs(right_startx - left_startx)
    
    xl, yl = left_startx, starty
    xr, yr = right_startx, starty
    
    for step in range(imgshp[1] - 2*starty):
        # Central-difference gradients at the current left/right positions.
        gradxl = nd_intens[...,xl + 1,yl] - nd_intens[...,xl - 1, yl]
        gradyl = nd_intens[...,xl,yl + 1] - nd_intens[...,xl,yl - 1]
        gradl_mag = (gradxl**2 + gradyl**2)**0.5
        
        gradxr = nd_intens[...,xr + 1,yr] - nd_intens[...,xr - 1, yr]
        gradyr = nd_intens[...,xr,yr + 1] - nd_intens[...,xr,yr - 1]
        gradr_mag = (gradxr**2 + gradyr**2)**0.5
        
        # Step one pixel along the rounded unit gradient (uphill). If the
        # gradient is zero or rounds to (0, 0), use a fixed diagonal step:
        # (+1, +1) on the left, (-1, +1) on the right.
        if np.all(gradl_mag == 0) or np.all((np.rint(gradxl/gradl_mag) == 0) & (np.rint(gradyl/gradl_mag) == 0)):
            stepxl,stepyl = 1, 1
        else:
            stepxl, stepyl = np.rint(gradxl/gradl_mag), np.rint(gradyl/gradl_mag)
            
        if np.all(gradr_mag == 0) or np.all((np.rint(gradxr/gradr_mag) == 0) & (np.rint(gradyr/gradr_mag) == 0)):
            stepxr,stepyr = -1, 1
        else:
            stepxr, stepyr = np.rint(gradxr/gradr_mag), np.rint(gradyr/gradr_mag)
        
        xl += stepxl
        yl += stepyl
        xr += stepxr
        yr += stepyr
    
    endx_dist = np.abs(xr - xl)
    
    # The end with the larger left/right separation is labelled the fork region.
    # The start region spans y in [starty, 5*starty]; the end region spans
    # y from the final walker row to the image edge.
    if startx_dist > endx_dist:
        fork_region = [left_startx,right_startx,starty, starty*5]
        nofork_region = [xl,xr,yl,imgshp[1]]
    else:
        fork_region = [xl,xr,yl,imgshp[1]]
        nofork_region = [left_startx,right_startx,starty, starty*5]
    
    return [fork_region, nofork_region]

def best_fit_moire_period(func,xdata,intens,pg,L,d):
    """Fit ``func`` from a sweep of initial period guesses and keep the most certain fit.

    Initial period guesses cover ``L*pg/d`` +/-12.5 % in 51 steps. The
    other initial parameters are the mean (offset), ``max - mean``
    (amplitude) and 0 (phase). The fit with the smallest one-sigma
    uncertainty on the period (``sqrt(diag(pcov))[2]``) is returned.

    Parameters
    ----------
    func : callable
        Model ``func(x, A, B, p, phi)`` suitable for ``scipy.optimize.curve_fit``.
        The period must be the third parameter.
    xdata, intens : array_like
        Sample positions and measured intensities.
    pg, L, d : float
        Combined as ``L*pg/d``, the expected period.

    Returns
    -------
    ndarray
        Best-fit parameters ``[A, B, p, phi]``.

    Raises
    ------
    RuntimeError
        From ``curve_fit`` if any fit fails to converge within 20000 evaluations.
    """
    # curve_fit and joblib's ``jb`` were used without being imported. Import
    # curve_fit locally and run the fits in a plain loop, which avoids adding
    # a third-party dependency.
    from scipy.optimize import curve_fit

    pmoire_dist = np.linspace(0.875*L * pg/d, 1.125*L*pg/d, 51).squeeze()

    A, B, phi = np.mean(intens), np.amax(intens) - np.mean(intens), 0

    allparams, allerror = [], []
    for pmoire in pmoire_dist:
        params, cov = curve_fit(func, xdata, intens, p0=[A, B, pmoire, phi], maxfev=20000)
        allparams.append(params)
        allerror.append(np.sqrt(np.diag(cov)))
    allparams, allerror = np.array(allparams), np.array(allerror)

    # Column 2 holds the period.
    return np.array(allparams[np.argmin(allerror[..., 2])])

def force_broadcast(*arrs, nonunique_lengths = None, nonunique_occurences = None, desired_nonunique_arr_ax = None):
    """Reshape arrays so that axes of equal length line up for broadcasting.

    Each input is coerced with ``assert_ndarr`` and squeezed. Every distinct
    axis length found across the inputs, in order of first appearance,
    becomes one axis of a common layout. Each array is reshaped to have its
    own lengths on the matching axes and 1 elsewhere. As a result, two
    different quantities that happen to share a length end up on the same
    axis.

    To keep such axes apart, list the length in ``nonunique_lengths`` and the
    number of axes it may occupy in ``nonunique_occurences`` (same order).
    Those axes are placed after all unique-length axes. Every
    ``(array_index, axis)`` pair in ``desired_nonunique_arr_ax`` is then set
    to the nonunique length. With several nonunique lengths the last one
    wins. NOTE(review): confirm that this is intended.

    Returns
    -------
    list of ndarray
        The reshaped arrays, in input order.

    Raises
    ------
    ValueError
        If ``nonunique_lengths`` is non-empty but ``desired_nonunique_arr_ax``
        is omitted. The old ``[[]]`` default failed here with an opaque unpack
        error. Also raised if an array cannot be reshaped to its target shape.
    """
    # None sentinels instead of shared mutable default lists.
    nonunique_lengths = [] if nonunique_lengths is None else nonunique_lengths
    nonunique_occurences = [] if nonunique_occurences is None else nonunique_occurences

    arrs = [assert_ndarr(arr).squeeze() for arr in arrs]

    # Axis lengths in order of first appearance. A nonunique length may be
    # recorded up to its allowed number of occurrences.
    output_lengths = []
    for arr in arrs:
        for length in arr.shape:
            if length in nonunique_lengths:
                if output_lengths.count(length) < nonunique_occurences[nonunique_lengths.index(length)]:
                    output_lengths.append(length)
            elif length not in output_lengths:
                output_lengths.append(length)

    # Unique-length axes first, nonunique-length axes last.
    output_lengths = ([n for n in output_lengths if n not in nonunique_lengths]
                      + [n for n in output_lengths if n in nonunique_lengths])

    newshapes = np.ones((len(arrs), len(output_lengths)), dtype=int)

    if len(nonunique_lengths) > 0:
        if desired_nonunique_arr_ax is None:
            raise ValueError("desired_nonunique_arr_ax is required when nonunique_lengths is given")
        for length in nonunique_lengths:
            for ind, ax in desired_nonunique_arr_ax:
                newshapes[ind, ax] = length

    for i, arr in enumerate(arrs):
        for j, length in enumerate(output_lengths):
            if length in arr.shape and length not in nonunique_lengths:
                newshapes[i, j] = length
        arrs[i] = np.reshape(arr, newshapes[i].astype(int))

    return [assert_ndarr(arr) for arr in arrs]

def force_equal_dims(top, btm):
    """Pad two arrays with singleton axes until they have the same ``ndim``.

    ``top`` gains trailing axes and ``btm`` gains leading axes. Returns
    ``[top, btm]``.
    """
    target_ndim = max(top.ndim, btm.ndim)
    top = top.reshape(top.shape + (1,) * (target_ndim - top.ndim))
    btm = btm.reshape((1,) * (target_ndim - btm.ndim) + btm.shape)
    return [top, btm]


# sd stands for setup dictionary

class Grating:
    """Base class for elements configured from a setup dictionary ``sd``.

    Each entry of ``sd`` becomes an attribute coerced with ``assert_ndarr``,
    so ``None`` values raise ``TypeError``. ``sd`` must provide ``L1`` and
    ``L2``. They are broadcast together and ``M = 1/(1 + L2/L1)`` is
    derived from them. ``x0`` keeps the raw (uncoerced) value from ``sd``
    and defaults to 0.
    """

    def __init__(self,sd):
        for attr_name, raw_value in sd.items():
            setattr(self, attr_name, assert_ndarr(raw_value))

        self.L2, self.L1 = force_broadcast(self.L2, self.L1)
        self.M = 1/(1 + self.L2/self.L1).squeeze()

        offset = sd.get("x0")
        self.x0 = 0 if offset is None else offset


class RectGrating(Grating):
    """1-D binary phase grating and its sampled Fourier spectrum.

    Besides the ``Grating`` keys, this reads ``p``, ``phi``, ``mt`` and
    ``spectrum_spacing`` (as attributes set from ``sd``).

    * If ``sd["image_profile"]`` is given, that profile is rescaled to
      [-1, 1] with ``normalize_profile`` and multiplied by ``phi/2``. In this
      case ``sd`` must also provide ``real_x_length``.
    * Otherwise a square wave ``phi/2 * grating_equation(x, p, phase_offset)``
      is sampled at 1001 points over ``10*p``. ``phase_offset`` defaults to 0.

    The transmission ``exp(1j*profile)`` is FFT'd with ``norm="forward"``.
    For every order ``m`` in ``arange(-mt, mt + spectrum_spacing,
    spectrum_spacing)``, the FFT bin nearest ``2*pi*m/p`` is selected, and
    ``self.spectrum`` holds ``[amplitudes, wavenumbers]`` at those bins.

    ``propagate`` with an incoming spectrum also needs ``apts_num``, which
    ``PGMI.__init__`` sets on each element.
    """
    
    def __init__(self,sd):
        super().__init__(sd)

        if sd.get("image_profile") is not None:
            # Supplied profile, sampled over real_x_length (read from sd).
            self.profile = sd.get("image_profile")

            self.xpts = self.profile.shape[-1]
            self.x0, self.x = force_broadcast(self.x0, np.linspace(0,self.real_x_length, self.xpts))
            self.x = np.squeeze(self.x - self.x0)
            # Rescale to [-1, 1], then to a phase swing of +/- phi/2.
            self.profile = normalize_profile(self.profile)
            self.phi, self.profile = force_broadcast(self.phi, self.profile)
            self.profile = self.phi / 2 * self.profile
            
        else:
            # Ideal square-wave grating spanning ten periods.
            self.real_x_length = 10*self.p

            self.xpts = 1001
            # arbitrary number of points
            self.x0, self.x = force_broadcast(self.x0, np.linspace(0,self.real_x_length, self.xpts).T)
            self.x = self.x - self.x0
            self.p, self.phi, self.x = force_broadcast(self.p, self.phi, self.x)
            self.profile = self.phi / 2 * grating_equation(self.x,self.p, 0 if sd.get("phase_offset") is None else sd.get("phase_offset"))
            self.x = self.x.squeeze()
            
        
        # Fourier coefficients of the transmission exp(i*profile) along the sample axis.
        self.FT = fft(np.exp(1j*self.profile), axis = get_param_axis(self.profile,self.xpts), norm = "forward")
        self.absFT = np.abs(self.FT)
        # Angular frequencies (2*pi*fftfreq) of those coefficients.
        self.freqs = get_freqs(self.FT, get_param_axis(self.FT, self.x), self.real_x_length)
        ordrange = np.arange(-self.mt,self.mt+self.spectrum_spacing, self.spectrum_spacing)
                
        # Order wavenumbers 2*pi*m/p, one row per value of p.
        self.ords = 2*np.pi/self.p[:,None] * ordrange[None,:]
        self.freqs, self.ords = force_broadcast(self.freqs, self.ords)

        # For each order, the index of the nearest FFT frequency.
        self.locs = np.argmin(np.abs(self.freqs - self.ords),axis = get_param_axis(self.freqs, self.xpts))
        
        self.FT, self.freqs = self.FT.squeeze(), self.freqs.squeeze()
        dimdiff = self.FT.ndim - self.freqs.ndim
        
        # spectrum = [FT values at the order bins, matching frequencies]. The
        # frequencies are tiled over FT's extra leading axes so both entries
        # have the same shape.
        self.spectrum = np.squeeze([self.FT[...,np.unique(self.locs, axis = 0)],\
                        (np.tile(self.freqs,(*self.FT.shape[:dimdiff],*((1,)*(self.freqs.ndim)))))[...,np.unique(self.locs, axis = 0)]])
        self.num_ords = len(ordrange)
    
    
    
    def propagate(self,k0, neu_spec = None):
        """Apply this grating to a spectrum and propagate the result over ``L2``.

        Parameters
        ----------
        k0 : array_like
            Wavenumber in the propagation phase ``B(k) = -k**2/(2*k0)``.
        neu_spec : [a, k] or None
            Incoming amplitudes and wavenumbers. With ``None`` the grating's
            own spectrum is used unchanged.

        Returns
        -------
        [a_prime, k_prime]
            ``a * exp(-1j*B(k, k0)*M*L2)`` and ``k*M``. Both are also stored on
            ``self``.

        Side effects: ``self.M`` and ``self.L2`` are overwritten with their
        broadcast versions on every call.
        """
        def B(k,k0):
            # Propagation phase per unit distance, -k**2/(2*k0).
            return -k**2/(2*k0)
        
        ag,kg = self.spectrum.copy()
        
        if neu_spec is None:
            
            self.M, self.L2, k0, a, k = force_broadcast(self.M, self.L2, k0, ag, kg)
            # equivalent to convolving grating spectrum with identity

        else:
            a, k = neu_spec
            
            # If the incoming spectrum already has an order axis of length
            # num_ords, allow that length on up to apts_num separate axes so the
            # new orders are not merged with the existing ones.
            if np.any(np.array(a.shape) == self.num_ords):

                a, k, ag, kg =  force_broadcast(a, k, ag, kg, nonunique_lengths=[self.num_ords],\
                                nonunique_occurences=[self.apts_num],desired_nonunique_arr_ax=\
                                np.concatenate((np.vstack((np.tile([0,1],self.apts_num-1),\
                                np.repeat(np.arange(-self.apts_num,-1,1),2))).T, [[-2,-1],[-1,-1]])))

            else:
                a, k, ag, kg  = force_broadcast(a, k, ag, kg)
            
            
            # Convolution with the grating spectrum: multiply amplitudes, add wavenumbers.
            a = a * ag
            k = k + kg
            
            # print(a.shape,k.shape,self.M.shape, self.L2.shape, k0.shape)
            self.M, self.L2, k0, a, k = force_broadcast(self.M, self.L2, k0, a, k, nonunique_lengths=[self.num_ords],\
                                        nonunique_occurences=[self.apts_num],desired_nonunique_arr_ax= \
                                        np.vstack((np.tile([-2,-1],self.apts_num), np.repeat(np.arange(-self.apts_num,0,1),2))).T)
            
            # print(a.shape,k.shape,self.M.shape, self.L2.shape, k0.shape)
            
        self.a_prime = a*np.exp(-1j*B(k,k0)*self.M*self.L2)
        self.k_prime = k*self.M
        
        return [self.a_prime, self.k_prime]
            
            
      
            
class RectGrating2D(Grating):
    """2-D (crossed) binary phase grating built on ``Grating``.

    Reads ``px``, ``py``, ``phi``, ``mt`` and ``spectrum_spacing`` from the
    keyword arguments (passed on to ``Grating`` as the setup dictionary).

    NOTE(review): the name ``RectGrating2D`` is rebound further down the file
    by ``class RectGrating2D(PGMI)``, so this definition is shadowed after
    the module loads. Its ``propagate`` is also unfinished.
    """
    
        
    def __init__(self, **sd):
        super().__init__(sd)
        
        self.real_x_length = 10*self.px
        self.real_y_length = 10*self.py
        self.xpts = self.ypts = 1001
        # NOTE(review): px/py are ndarrays after Grating.__init__ (assert_ndarr),
        # so the stop list below has shape (2, 1). Check that unpacking
        # linspace(...).T really yields two 1-D arrays.
        self.x, self.y = np.linspace(0,[self.real_x_length,self.real_y_length],self.xpts).T

        # x and y have equal length, so mark that length as nonunique to keep them on separate axes.
        self.phi, self.x, self.y = force_broadcast(self.phi, self.x, self.y, nonunique_lengths= [len(self.x)], nonunique_occurences=  [2])
        self.profile = self.phi/2 * grating_equation_2d(self.x,self.y,self.px,self.py)
        
        self.FT = fft2(np.exp(1j*self.profile), axes = (get_param_axis(self.profile,self.x), get_param_axis(self.profile, self.y)), norm = "forward")
        self.absFT = np.abs(self.FT)
        self.xfreqs = get_freqs(self.FT, get_param_axis(self.FT,self.x), self.real_x_length)
        self.yfreqs = get_freqs(self.FT, get_param_axis(self.FT,self.y), self.real_y_length)
        # Order wavenumbers 2*pi*m/px and 2*pi*m/py.
        self.xords = 2*np.pi/self.px * np.arange(-self.mt,self.mt+self.spectrum_spacing,self.spectrum_spacing)
        self.yords = 2*np.pi/self.py * np.arange(-self.mt,self.mt+self.spectrum_spacing,self.spectrum_spacing)
        # Nearest FFT bin for each order, per axis.
        self.xlocs = np.argmin(np.abs(self.xfreqs[:,None] - self.xords[None,:]),axis = 0)
        self.ylocs = np.argmin(np.abs(self.yfreqs[:,None] - self.yords[None,:]),axis = 0)
        # All (x, y) order-bin combinations as a fancy-index tuple into FT.
        self.xylocs = tuple(np.split(np.array(np.meshgrid(self.xlocs,self.ylocs)).ravel(),2))
        self.spectrum = [self.xfreqs[self.xlocs], self.yfreqs[self.ylocs], self.FT[self.xylocs]]
    
    def propagate(self,k0,neu_spec = None):
        # NOTE(review): unfinished. It only defines the phase helper B and returns None.
        def B(kx,ky,k0):
            return -(kx**2 + ky**2)/(2*k0)

        
class Sample(RectGrating):
    """Sample element. It behaves exactly like ``RectGrating``.

    It exists as a separate type so that ``PGMI`` can skip samples when
    picking the reference grating period.
    """

class ForkGrating(Grating):
    """Placeholder fork-grating element built on ``Grating``. It adds no behavior.

    NOTE(review): this name is rebound further down the file by
    ``class ForkGrating(PGMI)``, so this definition is shadowed.
    """

class PGMI:
    """Phase-grating moire interferometer assembled from a collection of elements.

    Parameters
    ----------
    apts_dict : dict
        Maps names to element objects (gratings, samples, ...). Each element
        receives a 1-based ``apts_num`` attribute in the dict's iteration
        order, presumably the beam order (confirm with callers).
    sd : dict
        Setup dictionary. Every entry becomes an attribute coerced with
        ``assert_ndarr``, so a ``None`` value raises ``TypeError``.
        ``__init__`` requires ``resx``, ``resy``, ``xbin``, ``ybin`` and
        ``lam``. ``slitx``, ``slity`` and ``source_period`` are optional
        (``source_period`` also needs ``duty_cycle`` and ``camsize``).
        ``generate_after`` additionally reads ``L``, ``d``, ``plam``,
        ``sum_lam`` and ``slitx``.
    """
#     !to-do eventually: replace integer references to axes with string variable names (using numpy record arrays or Pandas Index ?) i.e. map input
#     strings to array axes and define class methods that get values/do something along input axes (np.apply_along_axis, np.meshgrid might be useful)

    def __init__(self, apts_dict, sd):

        self.apts_dict = apts_dict
        self.N_objects = len(apts_dict)
        # Tag each element with its 1-based position; RectGrating.propagate
        # uses apts_num to size its order axes.
        for n, Apparatus in enumerate(self.apts_dict.values()):
            setattr(Apparatus,"apts_num", n+1)
        
        for key, value in sd.items():
            setattr(self, key, assert_ndarr(value))
        
        self.slitxflag = False
        self.slityflag = False

        # Optional slits: keep the raw width and a sampled top-hat of it.
        if sd.get("slitx") is not None:
            self.slitx = sd.get("slitx")
            self.slitxflag = True
            self.stepslitx = generate_stepfunc(sd.get("slitx"),sd.get("xbin"))


        if sd.get("slity") is not None:
            self.slity = sd.get("slity")

            self.slityflag = True
            self.stepslity = generate_stepfunc(sd.get("slity"),sd.get("ybin"))
            
        

        # Top-hat windows with widths resx/resy (presumably the camera pixel response).
        self.stepresx = generate_stepfunc(self.resx,self.xbin)
        self.stepresy = generate_stepfunc(self.resy,self.ybin)

        self.k0 = 2*np.pi/self.lam
        self.neu_spec = None
        
        # Reference period: take p from the first element that is neither a
        # Sample nor a RectGrating2D. RectGrating2D elements contribute px/py
        # and the search continues.
        # NOTE(review): ``RectGrating2D`` is looked up when this runs, so it
        # refers to the last class of that name in the module (the PGMI
        # subclass), not the Grating subclass.
        for Apparatus in self.apts_dict.values():
            if type(Apparatus) == RectGrating2D:
                self.px,self.py = Apparatus.px, Apparatus.py
            elif type(Apparatus) != Sample:
                self.p = Apparatus.p
                break
        

        # A periodic source replaces the single-slit step function. It spans
        # about 3*camsize (int(3*camsize/source_period) + 1 periods).
        if sd.get("source_period") is not None:
            self.slitxflag = True
            self.stepslitx = generate_source_grating(self.source_period,self.duty_cycle,\
                            self.xbin, n_periods= int(3*self.camsize/self.source_period) + 1)
            
        
    def get_apts(self):
        """Return the element dictionary."""
        return self.apts_dict
    
    def get_values(self):
        """Return the instance attribute dict (live, not a copy)."""
        return vars(self)
    
    def get_value(self,key):
        """Return the instance attribute ``key`` (KeyError if it is missing)."""
        return vars(self)[key]

    def propagate_to(self, key):
        """Propagate the running spectrum ``self.neu_spec`` through element ``key``."""
        Apparatus = self.get_apts()[key]
        
        self.neu_spec = Apparatus.propagate(self.k0, self.neu_spec)


    def generate_after(self, key):
        """Compute the moire fringe contrast from ``self.neu_spec`` after element ``key``.

        Two pair sums of ``a_i * conj(a_j)`` are formed with ``get_a_k_final``.
        One covers pairs whose wavenumbers differ by 0 (the mean term); the
        other covers pairs differing by the moire wavenumber
        ``2*pi/(L*p/d)``. If ``sum_lam`` is set, the amplitudes are summed
        with weights ``plam``. They are then multiplied by ``squareFT``
        envelopes of widths ``slitx`` and ``resx``. The result is stored as
        ``self.contrast = |2*a_moire/a_zero|``.

        NOTE(review): ``k_moire_vals`` is indexed with ``[0]`` and its length
        is used to locate an axis, so it must remain at least 1-D after
        ``squeeze``. Also, ``self.slitx`` is required even when only
        ``source_period`` was given. Confirm both.
        """
        Apparatus = self.get_apts()[key]
        
        
        k_moire_vals = (2*np.pi/(self.L*self.p/self.d)).squeeze()

        # Merge the per-element order axes into one trailing axis.
        a, k = flatten_after_axis(*self.neu_spec,axis = -Apparatus.apts_num)
        
        k = np.real_if_close(k)  
        

        if self.plam.squeeze().ndim > 0:
            k = np.moveaxis(k,get_param_axis(k,self.plam),0)[0]
            # moire frequency is independent of wavelength, makes indexing much simpler
        

        a, k = np.moveaxis(a,get_param_axis(a,k_moire_vals), 0), np.moveaxis(k,get_param_axis(k,k_moire_vals), 0)

        a_moire, k_moire = get_a_k_final(a,k[0],k_moire_vals[0])
        # indices are the same for all d, so just pick k corresponding to the first one for speed
        
        a_zero, k_zero = get_a_k_final(a,k[0],0)

        a, k = np.array([a_zero,a_moire]), np.array([k_zero,k_moire])
        
        # Optional wavelength average, weighted by plam.
        if self.sum_lam:
            a, self.plam = force_broadcast(a, self.plam)

            a = np.sum(a*self.plam, axis = get_param_axis(a,self.plam))
       
    
        # Finite slit width and pixel size multiply each term by a sinc envelope.
        slitxft = squareFT(self.slitx, k)
        pxlft = squareFT(self.resx, k)
        
        a, slitxft, pxlft = force_broadcast(a,slitxft, pxlft)
        a *= slitxft * pxlft

        self.contrast = np.abs(2*a[1]/a[0])


    

class RectGrating2D(PGMI):
    """Crossed (2-D) rectangular phase grating using the positional ``propagate`` interface.

    NOTE(review): this rebinds the name ``RectGrating2D`` that is also
    defined earlier in the file as a ``Grating`` subclass, and it never calls
    ``PGMI.__init__``. ``propagate`` depends on ``rect_spectrum2d``, which is
    not defined in the part of the file seen here. Its ``force_broadcast``
    calls also pass ``nonunique_lengths`` without ``desired_nonunique_arr_ax``.
    Confirm both before relying on this class.
    """

    def __init__(self,init_values):
        """Read ``px``, ``py``, ``phi``, ``L1`` and ``L2`` from ``init_values``.

        ``L1`` and ``L2`` are broadcast together and ``M = 1/(1 + L2/L1)`` is stored.
        """
        self.px, self.py = init_values["px"],init_values["py"]
        self.is2d = True
        self.isFork = False
        self.phi = init_values["phi"]
        self.L1 = init_values["L1"]
        self.L2 = init_values["L2"]
        self.L2, self.L1 = force_broadcast(self.L2,self.L1)

        self.M = 1/(1 + self.L2/self.L1)

    def propagate(self,pos,px, py, phi,mt,k0,L1,L2):
        """Apply the grating to ``pos = [kx, ky, a]`` and propagate over ``L2``.

        The grating spectrum (from ``rect_spectrum2d``) is sampled over 10
        periods with 1001 points per axis. The incoming wavenumbers are
        shifted by the grating's and the amplitudes are multiplied by it.
        Wavenumbers are then scaled by ``M``, and amplitudes get the phase
        ``exp(-1j*B*M*L2)`` with ``B = -(kx**2 + ky**2)/(2*k0)``.
        ``L1`` is unused.

        Returns
        -------
        [nextkx, nextky, nexta]
        """
        # Only works for two gratings: the axis counts passed to
        # force_broadcast below are hard-coded for that case.

        def B(kx, ky, k0):
            # Propagation phase per unit distance. The signature now matches
            # the call sites; it used to be B(k_l, kx, ky) while being called
            # as B(kx, ky, k0), which silently dropped kx.
            return -(kx**2 + ky**2)/(2*k0)

        xlen, ylen = 10*px, 10*py
        x, y = np.linspace(0,xlen,1001), np.linspace(0,ylen,1001)
        phi, x, y = force_broadcast(phi,x,y, nonunique_lengths= [x.shape[-1]], nonunique_occurences= [2])

        # Grating spectrum: order wavenumbers along x and y, plus the 2-D
        # amplitude grid reshaped to (n_x_orders, n_y_orders).
        kgx, kgy, ag = rect_spectrum2d(x,xlen,y,ylen,px,py,phi,mt)
        ag = ag.reshape((kgx.shape[-1], kgy.shape[-1]))

        # Current position in Fourier space.
        kx, ky, a = pos

        if np.any(np.array(kx.shape) % kgx.shape[-1] == 0):
            # The incoming spectrum already carries order axes, so keep the
            # new order axes separate from them.
            kx, kgx = force_broadcast(kx, kgx, nonunique_lengths= [kgx.shape[-1]], nonunique_occurences= [2])
            ky, kgy = force_broadcast(ky, kgy, nonunique_lengths= [kgy.shape[-1]], nonunique_occurences= [2])

            a, ag = force_broadcast(a, ag, nonunique_lengths= [ag.shape[-1]], nonunique_occurences= [4])
            kx = kx + kgx
            ky = ky + kgy
            a = a * ag
            a = np.swapaxes(a, -2,-3)

            self.M, L2, kx, ky, k0 = force_broadcast(self.M, L2, kx, ky, k0, nonunique_lengths= [kx.shape[-1],ky.shape[-1]],\
                                        nonunique_occurences = [4,4])

            nextkx = kx*self.M
            nextky = ky*self.M

            self.M, L2, a, k0 = force_broadcast(self.M, L2, a, k0, nonunique_lengths= [a.shape[-1]],\
                                                    nonunique_occurences = [4])

            nexta = a*np.exp(-1j*B(kx,ky,k0)*self.M*L2)

        else:
            kx, kgx = force_broadcast(kx, kgx)
            ky, kgy = force_broadcast(ky, kgy)

            a, ag = force_broadcast(a, ag, nonunique_lengths= [ag.shape[-1]], nonunique_occurences= [2])
            kx = kx + kgx
            ky = ky + kgy
            a = a * ag
            kx,ky = force_broadcast(kx,ky, nonunique_lengths= [kgx.shape[-1]], nonunique_occurences = [2])

            # These calls previously used ``k_l``, a local read before
            # assignment (UnboundLocalError). The wavenumber here is k0, as
            # in the branch above.
            self.M, L2, kx,ky, k0 = force_broadcast(self.M, L2, kx,ky, k0, nonunique_lengths= [kx.shape[-1], ky.shape[-1]],\
                                        nonunique_occurences = [2,2])
            nextkx = kx*self.M
            nextky = ky*self.M

            self.M, L2, a, k0 = force_broadcast(self.M, L2, a, k0, nonunique_lengths= [a.shape[-1]],\
                                                    nonunique_occurences = [2])

            nexta = a*np.exp(-1j*B(kx,ky,k0)*self.M*L2)

        return [nextkx,nextky, nexta]

    def get_rawintensity(self,pos,xarr,yarr,camsize):
        """Return the camera intensity ``|psi|**2`` at ``xarr``/``yarr``.

        ``psi`` is the sum over the last four axes of
        ``a * exp(-1j*(kx*x + ky*y))`` for ``pos = [kx, ky, a]``. The method
        prints a progress percentage from ``xarr[..., 0]/camsize`` and then
        calls IPython's ``clear_output(wait=True)``.
        """
        kx = pos[0]
        ky = pos[1]
        a = pos[2]

        print("%.2f %% done" % (xarr[...,0]/camsize*100))
        clear_output(wait = True)
        xarr,yarr,kx,ky = force_broadcast(xarr,yarr,kx,ky,\
                        nonunique_lengths= [xarr.shape[-1], kx.shape[-1], ky.shape[-1]], nonunique_occurences=[2, 4, 4])

        kxarr, kyarr = kx*xarr, ky*yarr

        psi = np.sum(a*np.exp(-1j*(kxarr + kyarr)), axis = (-1,-2,-3,-4))
        return np.abs(psi)*np.abs(psi)
    
    

    

    
class ForkGrating(PGMI):
    """Fork grating: a 1-D grating whose diffraction orders carry orbital angular momentum.

    Reads ``p``, ``phi``, ``L1`` and ``L2`` from ``init_values``, and
    optionally ``OAM`` (default 0) and ``mt``. ``M = 1/(1 + L2/L1)``.
    ``PGMI.__init__`` is not called.

    NOTE(review): this rebinds the name ``ForkGrating`` defined earlier in
    the file. ``get_rawintensity`` needs ``self.mt``, which is only set when
    ``init_values`` contains ``mt``.
    """
   
    def __init__(self,init_values):
        self.is2d = True
        self.isFork = True
        self.p = init_values["p"]
        self.phi = init_values["phi"]
         
        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
        self.M = 1/(1 + self.L2/self.L1).squeeze()
        self.OAM = 0
        
        if init_values.get("OAM") is not None:
            self.OAM = init_values["OAM"]
            
        if init_values.get("mt") is not None:
            self.mt = init_values["mt"]
        
        
    
    
    def propagate(self,pos,p,phi,mt,k_l,L1,L2, x0):
        """Apply the grating to ``pos = [k, a]`` and propagate over ``L2``.

        The grating spectrum comes from ``rect_spectrum`` over 10 periods
        (1001 points), shifted by ``x0``. ``rect_spectrum`` is not defined in
        the part of the file seen here (NOTE(review)). Wavenumbers are
        shifted and amplitudes multiplied. The order axis is then merged into
        the last axis, wavenumbers are scaled by ``M``, and amplitudes get
        ``exp(-1j*B*M*L2)`` with ``B = -k**2/(2*k_l)``. ``L1`` is unused.

        NOTE(review): the first branch calls ``force_broadcast`` with
        ``nonunique_lengths`` but without ``desired_nonunique_arr_ax``.
        Confirm that path works.
        """
        def B(k_l,k):
            return -k**2/(2*k_l)
        self.x0 = x0
#         !change this when x0 becomes grating attribute and not PGMI attribute; should probably also specify spectrum during __init__

        xsize = 10*p
        xft = np.linspace(0,xsize,1001)
        phi,x0, xft = force_broadcast(phi,x0,xft)
        kg, ag = rect_spectrum(xft - x0,xsize,p, phi, mt).squeeze()

#         specifies the spectrum of the grating
        
        k = pos[0]
        a = pos[1]

#         get current position in Fourier space


        # If the incoming k already has an axis whose length is a multiple of
        # the order count, keep the new order axis separate and merge the last
        # two axes below (mt_ind = -2). Otherwise merge only the last axis.
        if np.any(np.array(k.shape) % kg.shape[-1] == 0):
            mt_ind = -2
            k, kg = force_broadcast(k,kg, nonunique_lengths = [kg.shape[-1]], nonunique_occurences = [2])    
            a, ag = force_broadcast(a,ag, nonunique_lengths = [ag.shape[-1]], nonunique_occurences = [2])    
            
        else:
            mt_ind = -1
            k, kg, a, ag = force_broadcast(k,kg,a,ag)    
            

        k = k + kg
        a = a * ag
#         array assignment is averse to += and *=
    
        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))
        
#         ensure that last dimension gets larger in length without adding more dimensions


        # print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)
        self.M, L2, k, a, k_l = force_broadcast(self.M, L2, k, a, k_l)
        # print(self.M.shape,L2.shape,k.shape,a.shape, k_l.shape)
        
            
        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)
#         scale by magnification factor

#         print(nextk.shape,nexta.shape)
        
        return [nextk,nexta]
    
    def get_rawintensity(self,pos,xarr,yarr,camsize):
        """Return the camera intensity ``|psi|**2`` on the ``xarr`` x ``yarr`` grid.

        The flattened order axis of ``pos = [k, a]`` is split back into groups
        of ``2*mt + 1``. Order ``m`` receives the extra azimuthal phase
        ``m*OAM*arctan2(y, x)``. ``psi`` is the sum over the last two axes of
        ``a * exp(-1j*k*x) * exp(-1j*fork_phase)``. A progress percentage
        (offset by 50, presumably as the second half of a run; confirm) is
        printed before IPython's ``clear_output(wait=True)`` is called.
        """
        k = pos[0]
        a = pos[1]
        
        print("%.2f %% done" % (xarr[...,0]/camsize*100  + 50))
        clear_output(wait = True)
        
        
        
        # print(k.shape, "before n-reshape")
        k = np.reshape(k,(*k.shape[:-1], k.shape[-1] // (2*self.mt + 1), 2*self.mt + 1))
        a = np.reshape(a,(*a.shape[:-1], a.shape[-1] // (2*self.mt + 1), 2*self.mt + 1))
        
        # OAM charge of each order and the azimuthal angle on the (x, y) grid.
        OAMorders = np.arange(-self.mt,self.mt+1,1) * self.OAM
        phase = np.arctan2(yarr[None,...],xarr[...,None])
        
        # print(k.shape, "after n-reshape", "OAMorders:", OAMorders.shape, "phase:", phase.shape)
        
        # When x and y have equal length, mark that length as nonunique so
        # force_broadcast keeps the two grid axes separate.
        # NOTE(review): that branch omits desired_nonunique_arr_ax.
        if len(xarr) != len(yarr):
            phase, OAMorders = force_broadcast(phase, OAMorders)
            fork_phase = phase * OAMorders
            k, fork_phase = force_broadcast(k,fork_phase)
            a, fork_phase = force_broadcast(a,fork_phase)

        else:
            phase, OAMorders = force_broadcast(phase, OAMorders, nonunique_lengths = [len(xarr)], nonunique_occurences=[2]) 
            fork_phase = phase * OAMorders
            k, fork_phase = force_broadcast(k,fork_phase, nonunique_lengths = [len(xarr)], nonunique_occurences=[2])
            a, fork_phase = force_broadcast(a,fork_phase, nonunique_lengths = [len(xarr)], nonunique_occurences=[2])

        # print("k:", k.shape,"fork_phase:", fork_phase.shape, "a:", a.shape, "xarr:", xarr.shape)
        # Move the order axis to position -2 so the final sum covers orders and order groups.
        k, a = np.moveaxis(k, k.shape.index(2*self.mt+1),-2), np.moveaxis(a, a.shape.index(2*self.mt+1),-2)
        fork_phase = np.expand_dims(fork_phase, -2)
        xarr = np.moveaxis(force_equal_dims(xarr,k)[0], 0,k.shape.index(1))
        
        # print("k:", k.shape,"fork_phase:", fork_phase.shape, "a:", a.shape, "xarr:", xarr.shape)
      
        psi = np.sum(a * np.exp(-1j * k * xarr) * np.exp(-1j*fork_phase) , axis = (-1,-2))
        return np.abs(psi) * np.abs(psi)
        
        
        



class RotatedProfile:
    """Projections of a 1-D surface profile after rotating it by one or more angles.

    The profile is tiled three times and rasterized into a binary int8 image
    ``background``: 1 from the bottom up to the scaled profile height plus
    two pixels, 0 above. The image is rotated with ``scipy.ndimage.rotate``
    by ``-deg`` for each angle and summed along the height axis. Each
    projection is rescaled to [-1, 1], and the slice ``[n//3 : 2*n//3]``
    (``n`` = tiled length) is kept in ``rotated_profiles``.

    Parameters
    ----------
    deg : float or array_like
        Rotation angle(s) in degrees.
    profile : 1-D ndarray
        Sampled profile heights.
    profile_size : float
        Physical length spanned by ``profile``.
    profile_height : float
        Physical height of the raster. Together with ``profile_size`` it sets
        the number of height pixels.
    profile_period :
        Unused.
    """

    def __init__(self,deg, profile, profile_size,profile_height, profile_period):
        self.deg = assert_ndarr(deg)

        tiled = np.tile(profile, 3)
        self.extended_profile = tiled
        n_columns = tiled.shape[0]
        n_rows = round(profile.shape[0] * profile_height // (profile_size))
        self.background = np.zeros((n_columns, n_rows), dtype=np.int8)

        floor_val = np.min(tiled)
        span = np.max(tiled) - floor_val
        self.scaled_profile = (tiled - floor_val - 1/n_rows) / span

        # Fill each column from index 0 up to its scaled height plus two pixels.
        fill_heights = [int(level*n_rows) + 2 for level in self.scaled_profile]
        for column, fill in zip(self.background, fill_heights):
            column[:fill] = 1

        # NOTE(review): rotate() defaults to reshape=True, so for nonzero
        # angles the projections are longer than n_columns and the slice below
        # is not centred on the untiled profile. Confirm this is intended.
        self.rotated = [rotate(self.background, -angle) for angle in self.deg]
        self.summed = [np.sum(image, axis=-1) for image in self.rotated]

        def _to_unit_range(s):
            return 2*(s - np.min(s))/(np.max(s) - np.min(s)) - 1

        self.scaled_summed = [_to_unit_range(s) for s in self.summed]

        start, stop = n_columns//3, 2*n_columns//3
        self.rotated_profiles = [s[start:stop] for s in self.scaled_summed]
