In [None]:
import numpy as np
import pandas as pd
from scipy.fft import fft,fft2, fftfreq, fftshift, ifft
from scipy.signal import oaconvolve

from PIL import Image, ImageFilter

from IPython.display import clear_output
from scipy.ndimage import rotate
import matplotlib.pyplot as plt



In [None]:
def assert_ndarr(container):
    """
    Coerce the input to a numpy array with at least one axis where possible.

    Args:
        container: a numpy array, a list, or any other non-None value
            (scalar, 0-d array, tuple, ...).
    Returns:
        * ``container`` itself (no copy) when it is an ndarray with ndim > 0;
        * ``np.squeeze(container)`` when it is a list, i.e. all length-1 axes
          are dropped (a one-element list therefore becomes a 0-d array);
        * ``np.array([container])`` otherwise, i.e. the value wrapped in one
          extra leading axis.
    Raises:
        TypeError: if ``container`` is None.
    """
    if container is None:
        raise TypeError("Input is None")

    # Exact type checks (not isinstance) are kept on purpose so that ndarray
    # and list subclasses still fall through to the wrapping branch.
    if type(container) is np.ndarray and container.ndim > 0:
        return container
    if type(container) is list:
        return np.squeeze(container)
    return np.array([container])

def get_chi_sq(data, modelled, uncert):
    """
    Compute the chi-squared statistic between data and a model.

    Args:
        data: measured values.
        modelled: model values, broadcastable against data.
        uncert: uncertainties on data, broadcastable against data.
    Returns:
        The sum over all elements of (data - modelled)**2 / uncert**2.
    """
    residuals = data - modelled
    return np.sum(residuals**2 / (uncert**2))

def print_shapes(*arrs):
    """
    Print the shape of every argument on one line, for debugging.

    Args:
        *arrs: objects exposing a ``shape`` attribute.
    Returns:
        None. Each shape is written followed by a single space, then a blank
        line is emitted.
    """
    shapes_line = "".join(f"{arr.shape} " for arr in arrs)
    print(shapes_line, end = "")
    print("\n")

def normalize_profile(profile, axis = -1,bounds = [-1,1]):
    """
    Linearly rescale a profile so its extrema along ``axis`` map onto bounds.

    Args:
        profile: array of values to rescale.
        axis: axis along which the minimum and maximum are taken.
        bounds: two-element sequence [low, high] giving the target range.
            The default list is only read, never modified.
    Returns:
        The rescaled profile, reshaped by force_broadcast so that it
        broadcasts against its own per-axis minimum and maximum.
    """
    lowest = np.amin(profile, axis=axis)
    highest = np.amax(profile, axis=axis)
    profile, lowest, highest = force_broadcast(profile, lowest, highest)

    low_bound, high_bound = bounds[0], bounds[1]
    stretched = (high_bound - low_bound) * (profile - lowest)
    return stretched/(highest - lowest) + low_bound

def get_param_axis(ndarr, a):
    """
    Locate the axis of ``ndarr`` whose length matches a parameter.

    Args:
        ndarr: array whose shape is searched.
        a: either an ndarray (its squeezed last-axis length is searched for)
            or a plain length value.
    Returns:
        The index of the first axis of ndarr with the matching length, or
        None when ``a`` is an ndarray that squeezes to 0 dimensions.
    Raises:
        ValueError: if no axis of ndarr has the requested length.
    """
    if type(a) is not np.ndarray:
        return ndarr.shape.index(a)

    squeezed = a.squeeze()
    if squeezed.ndim == 0:
        return None
    return ndarr.shape.index(squeezed.shape[-1])

def flatten_after_axis(*arrs, axis = -2):
    """
    Collapse all axes from ``axis`` onwards into a single trailing axis.

    Args:
        *arrs: arrays to reshape.
        axis: first axis to be merged (default -2, i.e. the last two axes).
    Returns:
        A list with one reshaped array per input, in the same order.
    """
    flattened = []
    for arr in arrs:
        leading, trailing = arr.shape[:axis], arr.shape[axis:]
        flattened.append(np.reshape(arr, (*leading, np.prod(trailing))))
    return flattened

def flatten_after_axis_2d(arr, axis = -2):
    """
    Collapse the trailing axes of ``arr`` into two merged axes.

    The axes in ``[2*axis, axis)`` are merged into one axis and the axes from
    ``axis`` onwards into another; any axes before ``2*axis`` are kept.

    Args:
        arr: array to reshape.
        axis: negative axis index marking the start of the last group.
    Returns:
        The reshaped array.
    """
    shape = arr.shape
    kept = shape[:2*axis]
    first_group = np.prod(shape[2*axis:axis])
    second_group = np.prod(shape[axis:])
    return np.reshape(arr, (*kept, first_group, second_group))

def retain_shape_after_index(arr,locs, flat_index = False):
    """
    Index an array and restore its leading axes afterwards.

    Args:
        arr: array-like; converted with assert_ndarr.
        locs: index applied to the last axis (or to the array directly when
            flat_index is True).
        flat_index: if True use ``arr[locs]``, otherwise ``arr[..., locs]``.
    Returns:
        The selected values reshaped to ``(*arr.shape[:-1], -1)``.
    """
    arr = assert_ndarr(arr)
    leading_shape = arr.shape[:-1]
    selected = arr[locs] if flat_index else arr[..., locs]
    return selected.reshape(*leading_shape, -1)


def generate_stepfunc(width,xbin = None, n_pts = None, shift = 0):
    """
    Sample a rectangular (top-hat) function on a regular grid.

    The grid runs from ``-shift`` to ``width - shift + xbin``; the returned
    profile is 1 where the grid value is >= xbin and < width, else 0.

    Args:
        width: width of the step.
        xbin: grid spacing. Despite the None default it is always used, so
            it must be supplied.
        n_pts: number of grid points; defaults to ``int(width/xbin) + 1``.
        shift: offset subtracted from both ends of the grid.
    Returns:
        Array holding the sampled step function.
    """
    # NOTE (from the original author): behaves differently for source
    # gratings, needs further work.
    n_samples = int(width/xbin) + 1 if n_pts is None else n_pts
    grid = np.linspace(-shift, width-shift + xbin, n_samples)
    rising_edge = np.heaviside(grid - xbin, 1)
    falling_edge = np.heaviside(grid - width, 1)
    return rising_edge - falling_edge

def generate_source_grating(period, duty_cycle, binsize, n_periods = 50):
    """
    Build a periodic binary source-grating profile.

    Args:
        period: grating period.
        duty_cycle: fraction of each period that is open (value 1).
        binsize: sampling step; one period holds ``int(period/binsize)``
            samples.
        n_periods: number of times the single-period profile is repeated.
    Returns:
        The tiled profile, passed through assert_ndarr.
    """
    samples_per_period = int(period/binsize)
    xs = np.linspace(0, period, samples_per_period).squeeze()
    opening = np.heaviside(xs, 1)
    closing = np.heaviside(xs - duty_cycle*period, 1)
    one_period = opening - closing
    return assert_ndarr(list(np.tile(one_period, n_periods)))

def generate_single_source(period, duty_cycle, binsize, n_periods, current_period):
    """
    Build a profile spanning n_periods with one open slit.

    Args:
        period: grating period.
        duty_cycle: open fraction of a period.
        binsize: sampling step of the grid.
        n_periods: number of periods covered by the grid.
        current_period: index of the period in which the slit opens.
    Returns:
        Array that is 1 from ``current_period*period`` up to (but excluding)
        ``current_period*period + duty_cycle*period`` and 0 elsewhere.
    """
    total_length = (n_periods*period)
    grid = np.linspace(0, total_length, int(total_length/binsize)+1)
    slit_start = (current_period*period)
    slit_end = (current_period*period + duty_cycle*period)
    return np.heaviside(grid - slit_start, 1) - np.heaviside(grid - slit_end, 1)

def sinc(x):
    """
    Evaluate the unnormalised sinc function sin(x)/x, with sinc(0) = 1.

    Unlike the previous implementation this does not modify the caller's
    array (zeros in ``x`` used to be overwritten with 1 in place), and it
    also accepts scalar input.

    Args:
        x: scalar or array-like of arguments.
    Returns:
        numpy array of the same shape as x (0-d for scalar input).
    """
    x = np.asarray(x)
    is_zero = x == 0
    # Substitute a harmless denominator at the zeros; those entries are
    # overwritten with the limit value 1 below.
    safe_x = np.where(is_zero, np.ones_like(x), x)
    out = np.sin(safe_x) / safe_x
    return np.where(is_zero, np.ones_like(out), out)

def squareFT(width,k,shift = 0):
    """
    Evaluate the sinc-shaped transform factor used for slits and pixels.

    Args:
        width: width of the rectangular aperture.
        k: spatial frequency (scalar or array).
        shift: phase applied through the factor exp(-1j*shift).
    Returns:
        ``-exp(-1j*shift) * sinc(width*k/2)``; note the leading minus sign.
    """
    phase_factor = np.exp(-1j*shift)
    envelope = sinc(width*k/2)
    return -phase_factor*envelope

def source_grating_FT(sg,target_freqs, xbin):
    """
    Sample the Fourier transform of a source grating at given frequencies.

    The grating is zero-padded by ten times its own length on both sides,
    transformed with a forward-normalised FFT, and the coefficients whose
    angular frequencies lie closest to ``target_freqs`` are returned.

    Args:
        sg: sampled source-grating profile.
        target_freqs: angular frequencies at which the transform is wanted.
        xbin: sampling step of sg.
    Returns:
        The FFT coefficients nearest to each target frequency.
    """
    pad_width = 10*len(sg)
    padded = np.pad(sg, pad_width)
    spectrum = fft(padded, axis = -1, norm="forward")

    signal_length = spectrum.shape[-1]*xbin
    freqs = get_freqs(spectrum, get_param_axis(spectrum, padded), signal_length)

    target_freqs, freqs = force_broadcast(target_freqs, freqs)
    nearest = np.argmin(np.abs(freqs - target_freqs), axis = -1)
    return spectrum[nearest]
    
    

def cosine_func(x,A,B,p,phi):
    """
    Offset cosine model ``A + B*cos(2*pi*x/p + phi)``.

    Args:
        x: position(s) at which to evaluate.
        A: constant offset.
        B: oscillation amplitude.
        p: period.
        phi: phase offset in radians.
    Returns:
        The model evaluated at x.
    """
    phase = 2*np.pi*x/p + phi
    return A + B*np.cos(phase)

def sphere_p_s(r_sphere, p_g, const):
    """
    Evaluate ``const * p_g**2.5 / r_sphere**1.5``.

    Args:
        r_sphere: sphere radius.
        p_g: grating period (presumably; the name suggests it).
        const: multiplicative prefactor.
    Returns:
        The value of the expression above.
    """
    numerator = const * (p_g**2.5)
    return numerator / (r_sphere**1.5)

def sphere_autocorrelation_func(x,r_sphere):
    """
    Evaluate a closed-form autocorrelation expression in zeta = x / r_sphere.

    Args:
        x: correlation length(s).
        r_sphere: sphere radius.
    Returns:
        Real part of
        ``sqrt(1 - (zeta/2)**2)*(1 + zeta**2/8)
        + zeta**2/2 * (1 - (zeta/4)**2) * log(zeta / (2 + sqrt(4 - zeta**2)))``.
    """
    zeta = x / r_sphere
    root_term = np.sqrt(1 - (zeta/2)**2)*(1+ zeta**2 / 8)
    log_term = zeta**2/2 * (1 - (zeta/4)**2) * np.log(zeta / (2 + np.sqrt(4 - zeta**2)))
    return np.real(root_term + log_term)
    # return np.real(np.exp(-9/8 * (zeta)**2))

def max_index(ndarr):
    """
    Return the multi-dimensional index of the largest element.

    Args:
        ndarr: array to search.
    Returns:
        Tuple of indices (one per axis) of the first maximum.
    """
    flat_position = np.argmax(ndarr)
    return np.unravel_index(flat_position, ndarr.shape)

def min_index(ndarr):
    """
    Return the multi-dimensional index of the smallest element.

    Args:
        ndarr: array to search.
    Returns:
        Tuple of indices (one per axis) of the first minimum.
    """
    flat_position = np.argmin(ndarr)
    return np.unravel_index(flat_position, ndarr.shape)

def lowpass(data, minperiod, pxlsize, order):
    """
    Apply a zero-phase digital Butterworth low-pass filter to data.

    Args:
        data: signal to filter (filtered along its last axis by filtfilt).
        minperiod: period used to set the cutoff, in the units of pxlsize.
        pxlsize: sampling step of data.
        order: order of the Butterworth filter.
    Returns:
        The filtered signal, same shape as data.
    """
    # butter/filtfilt are not among the file-level imports, so bring them in
    # here; without this the function raised NameError.
    from scipy.signal import butter, filtfilt

    # NOTE(review): butter() expects the cutoff as a fraction of the Nyquist
    # frequency, so pxlsize/minperiod corresponds to a cutoff period of
    # 2*minperiod. Kept unchanged; confirm this is the intended convention.
    normal_cutoff = pxlsize/minperiod
    # Get the filter coefficients
    b, a = butter(order, normal_cutoff, btype='low', analog=False)

    return filtfilt(b, a, data)

def apply_over_bins(func, arr_to_bin, bin_width, axis = None):
    """
    Apply a reduction to consecutive bins taken along the first axis.

    The array is cut into slices of ``bin_width`` rows; a shorter trailing
    slice is included when the length is not a multiple of bin_width. When
    the length is an exact multiple, no empty trailing bin is produced (the
    previous implementation called ``func`` on an empty slice in that case,
    which yields NaN or raises for reductions such as np.mean / np.max).

    Args:
        func: callable accepting ``(array, axis=...)``.
        arr_to_bin: array to be binned along axis 0.
        bin_width: number of rows per bin.
        axis: forwarded to func as its ``axis`` keyword.
    Returns:
        List with one result of func per bin.
    """
    n_rows = arr_to_bin.shape[0]
    return [func(arr_to_bin[start:start + bin_width], axis = axis)
            for start in range(0, n_rows, bin_width)]

def process_tfl(imgarr, rot_deg, medfilt_window = None):
    """
    Clean, rotate and optionally median-filter an image array.

    Args:
        imgarr: image array. Non-finite entries (NaN/inf) are set to 0
            IN PLACE, so the caller's array is modified.
        rot_deg: rotation angle in degrees, passed to PIL's Image.rotate.
        medfilt_window: kernel size for scipy.signal.medfilt2d, or None to
            skip filtering.
    Returns:
        The rotated (and, if requested, median-filtered) image as an array.
        Cropping of the borders introduced by the rotation is currently
        disabled, so the output is not cropped.
    """
    imgarr[~np.isfinite(imgarr)] = 0
    rotated = np.array(Image.fromarray(imgarr).rotate(rot_deg))

    if medfilt_window is None:
        return rotated

    # medfilt2d is not among the file-level imports; without this local
    # import the filtered path raised NameError.
    from scipy.signal import medfilt2d
    return medfilt2d(rotated, medfilt_window)

def lin_func(d,d0,L,L0,p):
    """
    Evaluate ``(d - d0) / ((L - L0) * p)``.

    Args:
        d: distance value(s).
        d0: distance offset.
        L: length value.
        L0: length offset.
        p: period.
    Returns:
        The value of the expression above.
    """
    numerator = d - d0
    denominator = (L-L0)*p
    return numerator/denominator

def grating_equation(xarr,p, phase_offset):
    """
    Square-wave grating profile: the sign of a sine of period p.

    Args:
        xarr: position(s) at which to evaluate.
        p: grating period.
        phase_offset: phase added to the sine argument, in radians.
    Returns:
        ``sign(sin(2*pi*xarr/p + phase_offset))``, i.e. values in {-1, 0, 1}.
    """
    phase = 2*np.pi*xarr/p + phase_offset
    return np.sign(np.sin(phase))

def grating_equation_2d(xarr,yarr,px,py):
    """
    Two-dimensional binary grating profile.

    Args:
        xarr: x position(s).
        yarr: y position(s), broadcastable against xarr.
        px: grating period along x.
        py: grating period along y.
    Returns:
        ``sign(cos(2*pi*xarr/px) + cos(2*pi*yarr/py) - 1)``, i.e. values in
        {-1, 0, 1}.
    """
    # np.pi replaces the bare name `pi`, which is not defined in this file
    # and made every call raise NameError.
    x_term = np.cos(2*np.pi*xarr/px)
    y_term = np.cos(2*np.pi*yarr/py)
    return np.sign(x_term + y_term - 1)

def maxmincont(arr):
    """
    Contrast from the extrema along the last axis.

    Args:
        arr: array of intensities.
    Returns:
        ``(max - min) / (max + min)`` computed along the last axis.
    """
    highest = np.amax(arr, axis = -1)
    lowest = np.amin(arr, axis = -1)
    return (highest - lowest)/(highest + lowest)

def get_freqs(ft_like, axis, real_signal_length):
    """
    Angular-frequency grid(s) matching an FFT output.

    Args:
        ft_like: array whose length along ``axis`` gives the sample count.
        axis: axis of ft_like that holds the transform.
        real_signal_length: physical length(s) of the transformed signal;
            one frequency grid is produced per length.
    Returns:
        ``2*pi*fftfreq(n, length/n)`` for every length, squeezed.
    """
    lengths = assert_ndarr(real_signal_length)
    n_samples = ft_like.shape[axis]
    grids = [fftfreq(n_samples, size/n_samples) for size in lengths]
    return 2*np.pi*np.squeeze(grids)

def lookup_inds(vals_to_find, available_vals):
    """
    Find where requested values occur in an array of available values.

    Args:
        vals_to_find: values to look up (treated as unique).
        available_vals: array searched for those values (treated as unique).
    Returns:
        Squeezed array of positions in available_vals whose value is present
        in both inputs.
    """
    shared_vals = np.intersect1d(vals_to_find, available_vals, assume_unique = True)
    is_shared = np.isin(available_vals, shared_vals)
    return np.squeeze(np.nonzero(is_shared))



def get_complement_inds(arr,arr_complement,target):
    """
    Find index pairs (i, j) such that arr[i] + arr_complement[j] is close to
    target.

    The search is done by cyclically shifting arr_complement along its last
    axis and comparing the element-wise sum against target with np.isclose.

    Args:
        arr: array of values.
        arr_complement: array of values paired against arr.
        target: value the sum of a pair must be close to.
    Returns:
        A two-element list ``[arr_inds, arr_complement_inds]`` of integer
        index arrays; entries at the same position form a matching pair.

    NOTE(review): if no shift yields a match, arr_inds and
    arr_complement_inds are never assigned and the return statement raises
    UnboundLocalError.
    NOTE(review): the index arithmetic looks like it assumes 1-D inputs
    (np.nonzero returns one index array per dimension) -- confirm.
    """
    
    # True until the first match is found; selects between initialising and
    # extending the index arrays.
    start = True

    for i in range(arr.shape[-1]):
#        if i % 10 == 0:
#            clear_output(wait = True)
#            print("%.2f %% done rolling" % (100*i/arr.shape[-1]))
        
        # Rolling by -i moves element (j + i) of arr_complement to position j,
        # so position j of rolled_sum is arr[j] + arr_complement[j + i].
        rolled_sum = arr + np.roll(arr_complement,-i, axis = -1)
        bool_where_target = np.isclose(rolled_sum, target)
        
        if np.any(bool_where_target):
            
            nonzero = assert_ndarr(np.squeeze(np.nonzero(bool_where_target)))
            if start:
                arr_inds, arr_complement_inds = nonzero, (nonzero + i) % arr.shape[-1]
                start = False
            else:
                arr_inds = np.concatenate((arr_inds,nonzero), axis = -1)
                # Undo the roll (wrapping around) to get the original
                # positions in arr_complement.
                arr_complement_inds = np.concatenate((arr_complement_inds,(nonzero + i)),\
                                    axis = -1) % arr.shape[-1]
                
            # print(arr[...,arr_inds] + arr_complement[...,arr_complement_inds])
            
    return [np.squeeze(arr_inds).astype(int),np.squeeze(arr_complement_inds).astype(int)]


def get_a_final(a,k,target):
    """
    Sum products a[i] * conj(a[j]) over all pairs with k[i] - k[j] ~ target.

    Args:
        a: complex amplitudes; the last axis is indexed by the k index.
        k: wavevector values associated with the last axis of a.
        target: wavevector difference to select.
    Returns:
        The sum over the last axis of the selected amplitude products.
    """
    k_inds, k_conj_inds = get_complement_inds(k, -k, target)

    amplitudes = retain_shape_after_index(a, k_inds)
    conj_amplitudes = retain_shape_after_index(np.conj(a), k_conj_inds)

    return np.sum(amplitudes * conj_amplitudes, axis = -1)


def get_a_final_2d(a,kx,ky,targetx,targety):
    """
    Two-dimensional counterpart of get_a_final.

    Sums a[..., ix, iy] * conj(a)[..., jx, jy] over index pairs selected by
    get_complement_inds on kx and ky separately.

    Args:
        a: complex amplitudes whose last two axes are indexed by kx and ky.
        kx: x wavevector values.
        ky: y wavevector values.
        targetx: x wavevector difference to select.
        targety: y wavevector difference to select.
    Returns:
        * when both targets are 0: a single array (the zero-order term);
        * otherwise: a list ``[a_targetx, a_targety]`` where a_targetx pairs
          the targetx indices with the zero-difference y indices, and
          a_targety pairs the zero-difference x indices with the targety
          indices.
    """
    a_conj = np.conj(a)
    kx_conj, ky_conj = -kx, -ky

    def paired_sum(x_inds, y_inds, conj_x_inds, conj_y_inds):
        # Outer-product style indexing over the last two axes, then reduce.
        direct = a[..., x_inds[:,None], y_inds[None,:]]
        conjugated = a_conj[..., conj_x_inds[:,None], conj_y_inds[None,:]]
        return np.sum(direct * conjugated, axis = (-2,-1))

    indicesx, conj_indicesx = get_complement_inds(kx, kx_conj, targetx)
    indicesy, conj_indicesy = get_complement_inds(ky, ky_conj, targety)

    if (np.squeeze(targetx) == 0) & (np.squeeze(targety) == 0):
        return paired_sum(indicesx, indicesy, conj_indicesx, conj_indicesy)

    x_zero_indices, conj_x_zero_indices = get_complement_inds(kx, kx_conj, 0)
    y_zero_indices, conj_y_zero_indices = get_complement_inds(ky, ky_conj, 0)

    a_targetx = paired_sum(indicesx, y_zero_indices, conj_indicesx, conj_y_zero_indices)
    a_targety = paired_sum(x_zero_indices, indicesy, conj_x_zero_indices, conj_indicesy)

    return [a_targetx, a_targety]



def fork_get_regions(nd_intens, xcam, ycam, pg, L, di):
    """
    Heuristically split an intensity image into a "fork" and a "no fork"
    region by following intensity gradients from two start points.

    Args:
        nd_intens: intensity array; its last two axes are treated as the
            image (indexed as [..., x, y] below).
        xcam: camera extent used to convert L*pg/di into a pixel count --
            presumably the physical width along x; TODO confirm.
        ycam: unused in this function.
        pg: grating period (used only through L*pg/di).
        L: length entering L*pg/di.
        di: distance entering L*pg/di.
    Returns:
        ``[fork_region, nofork_region]``, each a 4-element list of the form
        [x_left, x_right, y_start, y_end].

    NOTE(review): the author marked this as a rushed implementation. The
    step values produced by np.rint are floats and are added to the index
    variables, which are then used as array indices; xl/xr also alias
    left_startx/right_startx when those are arrays. Verify before relying
    on this function.
    """
    # NOTE: quick heuristic implementation (flagged as rushed by the author)
    imgshp = nd_intens.shape[-2:]
    # L*pg/di expressed in pixels along the first image axis
    pm_in_pts = int((L*pg/di)/xcam * imgshp[0])
    # start the walk 1/20 of the way along the second image axis
    starty = int(imgshp[1]/20)
    
    # brightest pixel within the first / last pm_in_pts rows at y = starty
    left_startx = np.argmax(nd_intens[..., :pm_in_pts,starty], axis = -1)
    right_startx = np.argmax(nd_intens[..., imgshp[0] - pm_in_pts:,starty], axis = -1)
    
    startx_dist = np.abs(right_startx - left_startx)
    
    xl, yl = left_startx, starty
    xr, yr = right_startx, starty
    
    for step in range(imgshp[1] - 2*starty):
        # central-difference gradient at the current left point
        gradxl = nd_intens[...,xl + 1,yl] - nd_intens[...,xl - 1, yl]
        gradyl = nd_intens[...,xl,yl + 1] - nd_intens[...,xl,yl - 1]
        gradl_mag = (gradxl**2 + gradyl**2)**0.5
        
        # central-difference gradient at the current right point
        gradxr = nd_intens[...,xr + 1,yr] - nd_intens[...,xr - 1, yr]
        gradyr = nd_intens[...,xr,yr + 1] - nd_intens[...,xr,yr - 1]
        gradr_mag = (gradxr**2 + gradyr**2)**0.5
        
        # step along the rounded unit gradient; fall back to a fixed diagonal
        # step when the gradient vanishes or rounds to (0, 0)
        if np.all(gradl_mag == 0) or np.all((np.rint(gradxl/gradl_mag) == 0) & (np.rint(gradyl/gradl_mag) == 0)):
            stepxl,stepyl = 1, 1
        else:
            stepxl, stepyl = np.rint(gradxl/gradl_mag), np.rint(gradyl/gradl_mag)
            
        if np.all(gradr_mag == 0) or np.all((np.rint(gradxr/gradr_mag) == 0) & (np.rint(gradyr/gradr_mag) == 0)):
            stepxr,stepyr = -1, 1
        else:
            stepxr, stepyr = np.rint(gradxr/gradr_mag), np.rint(gradyr/gradr_mag)
        
        xl += stepxl
        yl += stepyl
        xr += stepxr
        yr += stepyr
    
    endx_dist = np.abs(xr - xl)
    
    # the end with the larger left/right separation is labelled the fork
    if startx_dist > endx_dist:
        fork_region = [left_startx,right_startx,starty, starty*5]
        nofork_region = [xl,xr,yl,imgshp[1]]
    else:
        fork_region = [xl,xr,yl,imgshp[1]]
        nofork_region = [left_startx,right_startx,starty, starty*5]
    
    return [fork_region, nofork_region]

def best_fit_moire_period(func,xdata,intens,pg,L,d):
    """
    Fit ``func`` to an intensity profile from many initial period guesses
    and keep the fit whose period has the smallest estimated uncertainty.

    51 initial periods between 0.875 and 1.125 times L*pg/d are tried in
    parallel (through ``jb.Parallel``); each fit uses curve_fit with
    p0 = [A, B, period, phi].

    Args:
        func: model with parameters (x, A, B, p, phi), e.g. cosine_func.
        xdata: positions of the intensity samples.
        intens: intensity samples.
        pg: grating period.
        L: length entering the nominal period L*pg/d.
        d: distance entering the nominal period L*pg/d.
    Returns:
        Array with the fitted parameters of the selected fit.
    """
    def fit_from_guess(func, x, intens, p0):
        popt, pcov = curve_fit(func, x, intens, p0, maxfev = 20000)
        return np.array([popt, np.sqrt(np.diag(pcov))])

    period_guesses = np.linspace(0.875*L * pg/d, 1.125*L*pg/d, 51).squeeze()

    mean_intens = np.mean(intens)
    A, B, phi = mean_intens, np.amax(intens) - mean_intens, 0

    fit_results = jb.Parallel(n_jobs = -1)(
        jb.delayed(fit_from_guess)(func, xdata, intens, p0 = [A, B, pmoire, phi])
        for pmoire in period_guesses
    )
    # (n_guesses, 2, n_params) -> (2, n_guesses, n_params)
    allparams, allerror = np.swapaxes(fit_results, -2, 0)

    # index 2 is the period parameter
    best = np.argmin(allerror[..., 2])
    return np.array(allparams[best])

def best_fit_moire_period_with_error(func,xdata,intens,pg,L,d, use_abs_sigma = True):
    """
    Same search as best_fit_moire_period, but also return the parameter
    uncertainties of the selected fit.

    Args:
        func: model with parameters (x, A, B, p, phi), e.g. cosine_func.
        xdata: positions of the intensity samples.
        intens: intensity samples.
        pg: grating period.
        L: length entering the nominal period L*pg/d.
        d: distance entering the nominal period L*pg/d.
        use_abs_sigma: if True, fit with sigma = sqrt(intens) and
            absolute_sigma = True; otherwise fit unweighted.
    Returns:
        A list ``[params, errors]`` for the fit whose period parameter has
        the smallest estimated uncertainty.
    """
    def fit_from_guess(func, x, intens, p0):
        if use_abs_sigma:
            popt, pcov = curve_fit(func, x, intens, p0, maxfev = 20000,
                                   sigma = np.sqrt(intens), absolute_sigma = True)
        else:
            popt, pcov = curve_fit(func, x, intens, p0, maxfev = 20000)
        return np.array([popt, np.sqrt(np.diag(pcov))])

    period_guesses = np.linspace(0.875*L * pg/d, 1.125*L*pg/d, 51).squeeze()

    mean_intens = np.mean(intens)
    A, B, phi = mean_intens, np.amax(intens) - mean_intens, 0

    fit_results = jb.Parallel(n_jobs = -1)(
        jb.delayed(fit_from_guess)(func, xdata, intens, p0 = [A, B, pmoire, phi])
        for pmoire in period_guesses
    )
    # (n_guesses, 2, n_params) -> (2, n_guesses, n_params)
    allparams, allerror = np.swapaxes(fit_results, -2, 0)

    # index 2 is the period parameter
    best = np.argmin(allerror[..., 2])
    return [np.array(allparams[best]), allerror[best]]

def force_broadcast(*arrs, nonunique_lengths = [], nonunique_occurrences = [], desired_nonunique_arr_ax = [[]]):
    """
    Reshape arrays so that they broadcast against each other, matching axes
    by their LENGTH.

    Every input is squeezed, the distinct axis lengths found across all
    inputs define the axes of a common shape (in order of first appearance),
    and each input is reshaped to carry its own lengths at the matching
    positions and 1 everywhere else.

    Args:
        *arrs: arrays / scalars / lists; each goes through assert_ndarr and
            is squeezed.
        nonunique_lengths: lengths that may legitimately occur on more than
            one output axis.
        nonunique_occurrences: for each entry of nonunique_lengths, how many
            output axes of that length to create.
        desired_nonunique_arr_ax: for each entry of nonunique_lengths, a
            list of (array_index, output_axis) pairs saying which input gets
            that length on which output axis.
    Returns:
        A list with the reshaped arrays, in input order.

    NOTE(review): axes are identified purely by length, so two unrelated
    axes that happen to share a length are treated as the same axis unless
    declared through the nonunique_* arguments. np.reshape (not a
    transpose) is used, so this presumably relies on each input's axes
    already appearing in the same order as in the common shape -- confirm.
    The list-valued defaults are only read here, never mutated.
    """
    
    
    arrs = [assert_ndarr(arr).squeeze() for arr in arrs]

    # distinct axis lengths, in order of first appearance
    output_lengths = []

    for i in range(len(arrs)):
        for j in range(arrs[i].ndim):
            length = arrs[i].shape[j]
            # a declared non-unique length is appended once per requested
            # occurrence, the first time it is encountered
            if length in nonunique_lengths and output_lengths.count(length) <\
                nonunique_occurrences[nonunique_lengths.index(length)]:
                for _ in range(nonunique_occurrences[nonunique_lengths.index(length)]):
                    output_lengths.append(length) 

            elif length not in output_lengths:
                output_lengths.append(length) 
    
    # move the non-unique lengths to the end of the common shape
    output_lengths = [i for i in output_lengths if i not in nonunique_lengths] +\
                    [i for i in output_lengths if i in nonunique_lengths]

    common_ndim = len(output_lengths)
    # NOTE: `axes` is not used below
    axes = np.arange(common_ndim)
    # lengths_copy = output_lengths.copy()
    
    # one target shape per input, initialised to all ones
    newshapes = np.ones((len(arrs),len(output_lengths)),dtype = int)
    
    # place the non-unique lengths where the caller asked for them
    if nonunique_lengths != []:
        for i, length in enumerate(nonunique_lengths):
            for (ind, ax) in desired_nonunique_arr_ax[i]:
                newshapes[ind,ax] = length
    
    
    for i, arr in enumerate(arrs):
        shp = arrs[i].shape
        # an input keeps a (unique) length on the output axis of that length
        for j, length in enumerate(output_lengths):
            if length in shp and length not in nonunique_lengths:
                newshapes[i,j] = length
        
        arrs[i] = np.reshape(arrs[i], newshapes[i].astype(int))
            
    new_arrs = [assert_ndarr(arr) for arr in arrs]
    return new_arrs 

def force_equal_dims(top, btm):
    """
    Pad two arrays with length-1 axes until they have the same ndim.

    Args:
        top: array that receives extra TRAILING axes if it has fewer dims.
        btm: array that receives extra LEADING axes if it has fewer dims.
    Returns:
        A list ``[top, btm]`` with equal numbers of dimensions.
    """
    target_ndim = max(top.ndim, btm.ndim)
    while top.ndim < target_ndim:
        top = np.expand_dims(top, -1)
    while btm.ndim < target_ndim:
        btm = np.expand_dims(btm, 0)
    return [top, btm]




# sd stands for setup dictionary
class PGMI:
#     !to-do eventually: replace integer references to axes with string variable names (using numpy record arrays or Pandas Index ?) i.e. map input
#     strings to array axes and define class methods that get values/do something along input axes (np.apply_along_axis, np.meshgrid might be useful)
#     might also be worthwhile to predefine the entire dimensionality of the simulation if possible

    def __init__(self, apts_dict, sd):
        """
        Set up the simulation from a dictionary of apparatuses and a setup
        dictionary.

        Args:
            apts_dict: dict of apparatus objects. Each one is tagged with a
                1-based ``apts_num`` attribute in iteration order.
            sd: setup dictionary. Every entry is copied onto the instance as
                an attribute after passing through assert_ndarr (so values
                must not be None). Keys read directly here: "slitx",
                "slity", "source_period", "gravity", "bin_lam", "Ioffset".
                Attributes used here and therefore expected among the keys:
                resx, resy, xbin, ybin, lam; with a source grating also
                source_period and duty_cycle; with gravity also d, L and
                cam_rotdeg.
        """

        self.apts_dict = apts_dict
        self.N_objects = len(apts_dict)
        
        # number the apparatuses 1..N in dictionary order
        for n, Apparatus in enumerate(self.apts_dict.values()):
            setattr(Apparatus,"apts_num", n+1)
        
        # expose every setup entry as an attribute holding a numpy array
        for key, value in sd.items():
            setattr(self, key, assert_ndarr(value))
        
        # flags recording whether slit widths were provided
        self.slitxflag = False
        self.slityflag = False

        if sd.get("slitx") is not None:
            self.slitxflag = True


        if sd.get("slity") is not None:
            self.slityflag = True
            
        

        # top-hat kernels built from resx/resy, sampled with spacing xbin/ybin
        self.stepresx = generate_stepfunc(self.resx,self.xbin).squeeze()
        self.stepresy = generate_stepfunc(self.resy,self.ybin).squeeze()
                
        
        self.k0 = 2*np.pi/self.lam
        # spectrum state, filled in by propagate_to
        self.neu_spec = None
        
        # px/py are only defined when a RectGrating2D is present; the last
        # such apparatus wins
        for Apparatus in self.apts_dict.values():
            if type(Apparatus) == RectGrating2D:
                self.px,self.py = Apparatus.px, Apparatus.py
        

        if sd.get("source_period") is not None:
            self.source_grating_flag = True
            self.source_grating = generate_source_grating(self.source_period,self.duty_cycle,self.xbin, n_periods= 10)
        else:
            self.source_grating_flag = False
        
        if sd.get("gravity") is None:
            self.gravity = False
        
        elif self.gravity:
            # self.g = 9.80665
            # scan g over nominal +/- g_spread in steps of g_res
            g_nominal = 9.80665
            g_spread = 0.25
            g_res = 0.01
            self.g = np.arange(g_nominal - g_spread, g_nominal + g_spread + g_res, g_res)
            self.planck_h = 6.62607015e-34
            self.neu_mass = 1.6749275e-27
            # self.grav_yshift = self.g/2 * (self.L*self.neu_mass*self.lam/self.planck_h)**2
            self.C_0_x = 0
            self.C_0_y = 0 # could be setup parameter in future but can be 0 w.l.g I think
            
            # L2 is taken from the last apparatus in apts_dict
            self.d, self.lam, self.g, L2 = force_broadcast(self.d, self.lam, self.g, list(self.apts_dict.values())[-1].L2)
            
            # NOTE(review): uses self.px / self.py, which exist only if a
            # RectGrating2D was found above
            self.grav_xshift = np.pi * self.g * np.sin(np.radians(self.cam_rotdeg)) * self.d / (self.L * self.px) * \
                        (L2 * self.neu_mass * self.lam / self.planck_h)**2 + self.C_0_x
            
            self.grav_yshift = np.pi * self.g * np.cos(np.radians(self.cam_rotdeg)) * self.d / (self.L * self.py) * \
                        (L2 * self.neu_mass * self.lam / self.planck_h)**2 + self.C_0_y
            
            # undo the broadcasting reshape for d and lam
            self.d, self.lam = self.d.squeeze(), self.lam.squeeze()
            
        # defaults for optional setup entries
        if sd.get("bin_lam") is None:
            self.bin_lam = False
            
        if sd.get("Ioffset") is None:
            self.Ioffset = None 
            
        
    def get_apts(self):
        return self.apts_dict
    
    def get_values(self):
        return vars(self)
    
    def get_value(self,key):
        return vars(self)[key]

    def propagate_to(self, key):
        Apparatus = self.get_apts()[key]
        
        self.neu_spec = Apparatus.propagate(self.k0, self.neu_spec)


    def generate_after(self, key):
        """
        Compute the moire contrast (and, in 2D, phase) from the propagated
        spectrum ``self.neu_spec`` after the apparatus named ``key``.

        The zero-order and moire-frequency amplitudes are assembled with
        get_a_final / get_a_final_2d, multiplied by the slit and pixel
        transform factors (squareFT, or source_grating_FT for a source
        grating in 1D), optionally weighted by ``self.plam`` and summed or
        binned over wavelength, and finally converted to contrast.

        Fixes relative to the previous version:
            * ky was flattened from kx instead of ky (copy-paste error);
            * in the 2D ``sum_lam`` branch the amplitudes were multiplied by
              ``self.plam`` twice before summing, unlike the 1D and
              ``bin_lam`` paths, which weight once.

        Args:
            key: key of the apparatus in the apparatus dictionary.

        Side effects:
            2D apparatus: sets all_a, all_kx, all_ky, contrast_x,
            contrast_y, phase_x, phase_y; may replace slitx, slity,
            grav_xshift, grav_yshift and plam by broadcast-reshaped versions.
            1D apparatus: sets contrast; may replace plam by a
            broadcast-reshaped version.
        """
        Apparatus = self.get_apts()[key]

        if Apparatus.is2d:
            kx_moire_vals = (2*np.pi/(self.L*self.px/self.d)).squeeze()
            ky_moire_vals = (2*np.pi/(self.L*self.py/self.d)).squeeze()

            a, kx, ky = self.neu_spec
            kx, ky = np.real_if_close(kx.squeeze()), np.real_if_close(ky.squeeze())

            a = assert_ndarr(flatten_after_axis_2d(a))
            kx = assert_ndarr(flatten_after_axis(kx))
            # Fixed: this used to flatten kx a second time, so ky was a copy
            # of kx.
            ky = assert_ndarr(flatten_after_axis(ky))

            print_shapes(a, kx, ky)
            if self.plam.squeeze().ndim > 0:
                # moire frequency is independent of wavelength, makes indexing much simpler
                kx = np.moveaxis(kx,get_param_axis(kx,self.plam),0)[0]
                ky = np.moveaxis(ky,get_param_axis(ky,self.plam),0)[0]

            if self.d.squeeze().ndim > 0:
                a = np.moveaxis(a,get_param_axis(a,self.d), 0)
                kx = np.moveaxis(kx,get_param_axis(kx,kx_moire_vals), 0)
                ky = np.moveaxis(ky,get_param_axis(ky,ky_moire_vals), 0)

                print_shapes(a, kx, ky)

                # indices are the same for all d, so just pick k corresponding to the first one for speed
                kx_first = kx[tuple(np.zeros(kx.ndim - 1, dtype = int))]
                ky_first = ky[tuple(np.zeros(ky.ndim - 1, dtype = int))]

                a_moirex, a_moirey = get_a_final_2d(a, kx_first, ky_first, kx_moire_vals[0], ky_moire_vals[0])
                a_zero = get_a_final_2d(a, kx_first, ky_first, 0, 0)
            else:
                a_moirex, a_moirey = get_a_final_2d(a,kx,ky,kx_moire_vals, ky_moire_vals)
                a_zero = get_a_final_2d(a,kx,ky,0,0)

            self.all_a, self.all_kx, self.all_ky = a, kx, ky

            k_zero = np.zeros_like(kx_moire_vals)

            # index 0: zero order, index 1: moire order
            kx = np.array([k_zero,kx_moire_vals])
            ky = np.array([k_zero,ky_moire_vals])

            if self.gravity:
                kx, self.slitx, self.grav_xshift = force_broadcast(kx,self.slitx, self.grav_xshift)
                ky, self.slity, self.grav_yshift = force_broadcast(ky,self.slity, self.grav_yshift)
                slitxft = squareFT(self.slitx, kx, shift = self.grav_xshift)
                pxlxft = squareFT(self.resx, kx, shift = self.grav_xshift)
                slityft = squareFT(self.slity, ky, shift = self.grav_yshift)
                pxlyft = squareFT(self.resy, ky, shift = self.grav_yshift)
            else:
                slitxft = squareFT(self.slitx, kx)
                pxlxft = squareFT(self.resx, kx)
                slityft = squareFT(self.slity, ky)
                pxlyft = squareFT(self.resy, ky)

            a_zero, slitxft_zero, pxlxft_zero, slityft_zero, pxlyft_zero = force_broadcast(a_zero, slitxft[0], pxlxft[0],\
                                                                            slityft[0],pxlyft[0])
            a_zero = a_zero * slitxft_zero * pxlxft_zero * slityft_zero * pxlyft_zero

            a_moirex, slitxft_moire, pxlxft_moire = force_broadcast(a_moirex, slitxft[1], pxlxft[1])
            a_moirex = a_moirex * slitxft_moire * pxlxft_moire

            a_moirey, slityft_moire, pxlyft_moire = force_broadcast(a_moirey, slityft[1],pxlyft[1])
            a_moirey = a_moirey * slityft_moire * pxlyft_moire

            if self.bin_lam:
                a_zero, a_moirex, a_moirey, self.plam = force_broadcast(a_zero, a_moirex, a_moirey, self.plam)
                a_zero, a_moirex, a_moirey = a_zero*self.plam,a_moirex*self.plam,a_moirey*self.plam

                # bin edges on a regular grid of spacing lam_binsize covering lam
                bins = np.arange(round(np.min(self.lam/self.lam_binsize)), \
                        round(np.max(self.lam/self.lam_binsize))+1,1) * self.lam_binsize

                print("Wavelength bins:", bins)

                # index of the wavelength sample closest to each bin edge
                bin_ind = np.argmin(np.abs(self.lam[:,None] - bins[None,:]), axis = 0)

                lam_axis = get_param_axis(a_zero, self.lam)

                a_zero = np.moveaxis(a_zero, lam_axis, 0)
                a_moirex = np.moveaxis(a_moirex, lam_axis, 0)
                a_moirey = np.moveaxis(a_moirey, lam_axis, 0)

                n_bins = len(bins) - 1
                a_zero   = np.moveaxis([np.sum(a_zero[bin_ind[i]:bin_ind[i+1]],   axis = 0) for i in range(n_bins)],0,lam_axis)
                a_moirex = np.moveaxis([np.sum(a_moirex[bin_ind[i]:bin_ind[i+1]], axis = 0) for i in range(n_bins)],0,lam_axis)
                a_moirey = np.moveaxis([np.sum(a_moirey[bin_ind[i]:bin_ind[i+1]], axis = 0) for i in range(n_bins)],0,lam_axis)

            elif self.sum_lam:
                a_zero, a_moirex, a_moirey, self.plam = force_broadcast(a_zero, a_moirex, a_moirey, self.plam)
                a_zero, a_moirex, a_moirey = a_zero*self.plam,a_moirex*self.plam,a_moirey*self.plam

                # Fixed: the weighted amplitudes are summed as they are; they
                # used to be multiplied by plam a second time here.
                a_zero = np.sum(a_zero, axis = get_param_axis(a_zero,self.plam))
                a_moirex = np.sum(a_moirex, axis = get_param_axis(a_moirex,self.plam))
                a_moirey = np.sum(a_moirey, axis = get_param_axis(a_moirey,self.plam))

            self.contrast_x = np.abs(2*a_moirex/a_zero)
            self.contrast_y = np.abs(2*a_moirey/a_zero)
            self.phase_x = np.arctan2(np.imag(a_moirex), np.real(a_moirex))
            self.phase_y = np.arctan2(np.imag(a_moirey), np.real(a_moirey))

        else:
            k_moire_vals = (2*np.pi/(self.L*self.p/self.d)).squeeze()

            a, k = flatten_after_axis(*self.neu_spec,axis = -Apparatus.apts_num)

            k = np.real_if_close(k)

            if self.plam.squeeze().ndim > 0:
                # moire frequency is independent of wavelength, makes indexing much simpler
                k = np.moveaxis(k,get_param_axis(k,self.plam),0)[0]

            if self.d.squeeze().ndim > 0:
                a, k = np.moveaxis(a,get_param_axis(a,self.d), 0), np.moveaxis(k,get_param_axis(k,k_moire_vals), 0)

                # indices are the same for all d, so just pick k corresponding to the first one for speed
                k_first = k[tuple(np.zeros(k.ndim - 1, dtype = int))]
                a_moire = get_a_final(a, k_first, k_moire_vals[tuple(np.zeros(k_moire_vals.ndim, dtype = int))])
                a_zero = get_a_final(a, k_first, 0)
            else:
                a_moire = get_a_final(a,k,k_moire_vals)
                a_zero = get_a_final(a,k,0)

            k_zero = np.zeros_like(k_moire_vals)

            # index 0: zero order, index 1: moire order
            a, k = np.array([a_zero,a_moire]), np.array([k_zero,k_moire_vals])

            if self.sum_lam:
                a, self.plam = force_broadcast(a, self.plam)

                a = np.sum(a*self.plam, axis = get_param_axis(a,self.plam))

            if self.source_grating_flag:
                slitxft = source_grating_FT(self.source_grating,k, self.xbin)
            else:
                slitxft = squareFT(self.slitx, k)
            pxlft = squareFT(self.resx, k)

            a, slitxft, pxlft = force_broadcast(a, slitxft, pxlft)
            # technically don't need to multiply a_zero by the slit/pxl sinc functions because sinc(k_zero) = 1
            a *= slitxft * pxlft

            self.contrast = np.abs(2*a[1]/a[0])
    
    def intensity_fit(self, key):
        
        Apparatus = self.get_apts()[key]
        
        self.batches = int(self.batches)

        
        self.x = self.x[:len(self.x) - len(self.x) % self.batches]

        xbatches = np.array([self.x[i*len(self.x)//self.batches:(i+1)*len(self.x)//self.batches] for i in range(self.batches)])
        
        if Apparatus.is2d:
            
            self.y = self.y[:len(self.y) - len(self.y) % self.batches]

            ybatches = np.array([self.y[i*len(self.y)//self.batches:(i+1)*len(self.y)//self.batches] for i in range(self.batches)])

            # self.raw = np.array([[Apparatus.generate_raw_intensity(self.neu_spec,xi,\
            #             yi, i, self.batches) for yi in ybatches] for i, xi in enumerate(xbatches)])
            self.raw = np.array(jb.Parallel(n_jobs = -1, prefer = "threads")(jb.delayed(Apparatus.generate_raw_intensity)(self.neu_spec,xi,\
                        yi, i, self.batches) for i, xi in enumerate(xbatches) for yi in ybatches))
            
            print(self.raw.shape, "after generating raw")
            
            self.raw = np.reshape(self.raw, ((len(xbatches),len(ybatches),*self.raw.shape[1:])))

            if self.plam.squeeze().ndim > 0:
                self.raw = np.moveaxis(self.raw, get_param_axis(self.raw,self.plam.squeeze()), 0)

            
            if self.d.squeeze().ndim > 0:
                self.raw = np.moveaxis(self.raw, get_param_axis(self.raw,self.d.squeeze()), 0)

            self.raw = np.swapaxes(self.raw, -2,-3)
            
            print(self.raw.shape, "before reshape")
            
            self.raw = np.reshape(self.raw,(*(self.raw.shape[:-4]),len(self.x), len(self.y)))
            
            print(self.raw.shape, "after reshape")
            
            if self.sum_lam:
                self.raw, self.plam = force_broadcast(self.raw,self.plam)
                self.raw = np.sum(self.raw*self.plam, axis = get_param_axis(self.raw, self.plam.squeeze()))
                
            
            for _ in range(self.raw.ndim - self.stepresx.ndim-1):
                self.stepresx = np.expand_dims(self.stepresx, 0)
                
            self.stepresx = np.expand_dims(self.stepresx, -1)
            # print(self.stepresx.shape)
            
            self.raw, self.stepresy = force_equal_dims(self.raw,self.stepresy)
            # print(self.stepresy.shape)


            if self.slityflag:
                self.stepslity = generate_stepfunc(self.slity,self.ybin).squeeze()

                self.raw, self.stepslity = force_equal_dims(self.raw,self.stepslity)
                # print(self.stepslity.shape)   

                
                self.intensity = oaconvolve(self.raw,self.stepslity, mode = self.convmode, axes = -1)

            if self.slitxflag:
                self.stepslitx = generate_stepfunc(self.slitx,self.xbin).squeeze()

                for _ in range(self.raw.ndim - self.stepslitx.ndim - 1):
                    self.stepslitx = np.expand_dims(self.stepslitx, 0)
                self.stepslitx = np.expand_dims(self.stepslitx, -1)
                # print(self.stepslitx.shape)   

                self.intensity = oaconvolve(self.intensity,self.stepslitx, mode = self.convmode, axes = -2)
            # print(self.intensity.shape)   

            
            
            if not (self.slitxflag and self.slityflag):
                print("No slits")
                self.intensity = self.raw
                                    
            self.intensity = oaconvolve(oaconvolve(self.intensity, self.stepresy, mode = self.convmode, axes = -1),\
                                self.stepresx, mode = self.convmode, axes = -2)
            # print(self.intensity.shape)
            
            self.intensity_x = np.sum(self.intensity, axis = -1)

            # self.x = self.x[:self.intensity_x.shape[-1]]

            
            if type(Apparatus) == ForkGrating:
               
                self.fitparams_x = np.empty((*(self.intensity_x.shape[:-1]),4))

                
                if self.intensity_x.ndim < 3:
                    for i in range(self.intensity_x.shape[0]):
                        # clear_output(wait = True)
                        # print("Fitting")
                        # print("%.2f %% done" % (i/self.intensity.shape[0] * 1e2))
                        
                        
                        
                        
                        self.fitparams_x[i] = best_fit_moire_period(cosine_func, self.x, self.intensity_x[i],\
                                                self.p,self.L,self.d[i]) if len(self.d) > 1 \
                                                else best_fit_moire_period(cosine_func, self.x, self.intensity_x[i],\
                                                self.p,self.L,self.d)
                        
                
                else:
                    for i in range(self.intensity_x.shape[0]):
                        clear_output(wait = True)
                        print("Fitting")
                        print("%.2f %% done" % (i/self.intensity_x.shape[0] * 1e2))

                        for j in range(self.intensity_x.shape[1]):
                            if self.d.squeeze().ndim > 1:
                            
                                params = best_fit_moire_period(cosine_func,self.x,self.intensity_x[i,j],self.p,self.L,self.d[i,j])
                            else:

                                params = best_fit_moire_period(cosine_func,self.x,self.intensity_x[i,j],self.p,self.L,self.d[i]) 

                            self.fitparams_x[i,j] = params

                self.contrast = np.abs(self.fitparams_x[...,1]/self.fitparams_x[...,0])
                return
            
            else:
                
                self.intensity_y = np.sum(self.intensity, axis = -2)
                
                self.y = self.y[:self.intensity_y.shape[-1]]
                
                
                self.fitparams_x = np.empty((*(self.intensity_x.shape[:-1]),4))
                self.fitparams_y = np.empty((*(self.intensity_y.shape[:-1]),4))
                
                for i, d_i in enumerate(self.d.squeeze()):
                    clear_output(wait = True)
                    print("Fitting")
                    print("%.2f %% done" % (i/self.d.shape[0] * 1e2))
                    
                    self.fitparams_x[i] = best_fit_moire_period(cosine_func, self.x, self.intensity_x[i],self.px,self.L, d_i)
                    self.fitparams_y[i] = best_fit_moire_period(cosine_func, self.y, self.intensity_y[i],self.py,self.L, d_i)

                
                self.contrast_x = np.abs(self.fitparams_x[...,1]/self.fitparams_x[...,0])
                self.contrast_y = np.abs(self.fitparams_y[...,1]/self.fitparams_y[...,0])

                return
        else:
            self.raw = np.array(jb.Parallel(n_jobs = -1, prefer = "threads")(jb.delayed(Apparatus.generate_raw_intensity)(self.neu_spec,\
                                                    xi,i,self.batches) for i, xi in enumerate(xbatches)))

            print(self.raw.shape, "after parallel")
            if self.plam.squeeze().ndim > 0:
                self.raw = np.moveaxis(self.raw, get_param_axis(self.raw,self.plam.squeeze()), 0)

            
            if self.d.squeeze().ndim > 0:
                self.raw = np.moveaxis(self.raw, get_param_axis(self.raw,self.d.squeeze()), 0)
                
                
            self.raw = np.reshape(self.raw,(*self.raw.shape[:-2], self.raw.shape[-2] * self.raw.shape[-1]))
            print(self.raw.shape, "after reshape")
            
            if self.plam.squeeze().ndim > 0:
                self.raw, self.plam = force_broadcast(self.raw,self.plam)
                self.raw = np.sum(self.raw * self.plam, axis = get_param_axis(self.raw, self.plam.squeeze()))
            
#             plt.plot(self.raw[10])
#             plt.show()
            
            if self.slitxflag:
                self.stepslitx = generate_stepfunc(self.slitx,self.xbin).squeeze()

                self.raw, self.stepslitx = force_equal_dims(self.raw,self.stepslitx)
                
            self.raw, self.stepresx = force_equal_dims(self.raw,self.stepresx)
            print(self.raw.shape, "before convolution")
            
            if self.slitxflag: 
   
                self.intensity = oaconvolve(oaconvolve(self.raw,self.stepslitx, mode = self.convmode, axes = -1),\
                                self.stepresx, mode = self.convmode, axes = -1)
    
            elif self.source_grating_flag:
                self.raw, self.source_grating = force_equal_dims(self.raw, self.source_grating)
                self.intensity = oaconvolve(oaconvolve(self.raw,self.source_grating, mode = self.convmode, axes = -1),\
                                self.stepresx, mode = self.convmode, axes = -1)
            else:
                self.intensity = oaconvolve(self.raw,self.stepresx, mode = self.convmode, axes = -1)
            
            

            if self.intensity.shape[-1] > camsize/xbin:
                self.intensity = self.intensity[...,:int(camsize/xbin)]

            self.x = self.x[:self.intensity.shape[-1]]
            self.fitparams = np.empty((*(self.intensity.shape[:-1]),4))

            for i, d_i in enumerate(self.d.squeeze()):
                clear_output(wait = True)
                print("Fitting")
                print("%.2f %% done" % (i/self.d.shape[0] * 1e2))
                    
                self.fitparams[i] = best_fit_moire_period(cosine_func, self.x, self.intensity[i],self.p,self.L, d_i)

            
            if self.Ioffset is not None:  
                self.Ioffset, fitparams_offset = force_broadcast(self.Ioffset,self.fitparams[...,0])
                # print(self.Ioffset.shape, fitparams_offset.shape)
                fitparams_offset = fitparams_offset + self.Ioffset
                fitparams_offset, unmodified_B = force_broadcast(fitparams_offset, self.fitparams[...,1])
                self.contrast = np.abs(unmodified_B/fitparams_offset)

            else:
                self.contrast = np.abs(self.fitparams[...,1]/self.fitparams[...,0])
            
            return

        
class Grating:
    """
    Base class for the optical elements in this file.

    Every entry of the settings mapping ``sd`` becomes an instance attribute
    after being passed through ``assert_ndarr``. The constructor then derives
    ``M = 1 / (1 + L2 / L1)`` from the ``L1`` and ``L2`` entries and supplies
    a default ``x0`` of 0 when the mapping does not provide one.

    Args:
        sd: dict-like settings; must contain ``"L1"`` and ``"L2"``.
    """

    def __init__(self,sd):
        # Mirror the settings onto the instance, one attribute per key.
        for name in sd:
            setattr(self, name, assert_ndarr(sd[name]))

        self.L2, self.L1 = force_broadcast(self.L2, self.L1)

        # M = 1 / (1 + L2/L1); the denominator is squeezed before inverting.
        denominator = (1 + self.L2/self.L1).squeeze()
        self.M = 1/denominator

        x0_missing = sd.get("x0") is None
        if x0_missing:
            self.x0 = 0


class RectGrating(Grating):
    """
    One-dimensional phase grating represented by a few Fourier components.

    The phase profile is either built from ``grating_equation`` or taken from
    a user supplied ``image_profile``. ``exp(1j * profile)`` is Fourier
    transformed and only the components closest to the requested diffraction
    orders ``2*pi/p * m`` are kept in ``self.spectrum``.

    Besides ``L1``/``L2`` (used by ``Grating``), the settings mapping is
    expected to provide ``p``, ``phi``, ``mt`` and ``spectrum_spacing``;
    ``apts_num`` is read by ``propagate``/``generate_raw_intensity``.
    ``real_length`` is required when ``image_profile`` is given; ``n_p_g``
    (default 50) and ``phase_offset`` (default 0) are optional otherwise.
    """

    def __init__(self,sd):
        """
        Build the grating profile and extract its diffraction-order spectrum.

        Args:
            sd: dict-like settings (see the class docstring).
        """
        super().__init__(sd)

        self.is2d = False

        if sd.get("image_profile") is not None:
            self.profile = sd.get("image_profile")

            # The measured profile is repeated 10 times below, hence 10x the samples.
            self.xpts = 10 * self.profile.shape[-1]

            self.x0, self.x = force_broadcast(self.x0, np.linspace(-5*self.real_length,5*self.real_length, self.xpts))
            self.x = np.squeeze(self.x - self.x0)
            self.profile = normalize_profile(self.profile)

            self.profile = np.interp(np.linspace(self.x[0],self.x[-1],self.xpts),self.x,np.tile(self.profile,10))

            self.phi, self.profile = force_broadcast(self.phi, self.profile)
            self.profile = self.phi / 2 * self.profile
            # Rebind instead of `*=`: assert_ndarr hands back the caller's own
            # ndarray, so an in-place multiply would also rescale
            # sd["real_length"] and corrupt later gratings built from `sd`.
            self.real_length = self.real_length * 10

        else:
            if sd.get("n_p_g") is None:
                self.n_p_g = 50
            else:
                self.n_p_g = self.n_p_g.squeeze()

            self.real_length = self.n_p_g * self.p
            # 100 samples per period over n_p_g periods, plus the end point.
            self.xpts = int(100*np.round(self.n_p_g) + 1)

            self.x0, self.x = force_broadcast(self.x0, np.linspace(-self.real_length/2,self.real_length/2, self.xpts).T)
            self.x = self.x - self.x0
            self.p, self.phi, self.x = force_broadcast(self.p, self.phi, self.x)
            self.profile = self.phi / 2 * grating_equation(self.x,self.p, 0 if sd.get("phase_offset") is None else sd.get("phase_offset"))
            self.x = self.x.squeeze()

        # Transmission function exp(i*profile), transformed along the axis of length xpts.
        self.FT = fft(np.exp(1j*self.profile), axis = get_param_axis(self.profile,self.xpts),norm="forward")
        self.absFT = np.abs(self.FT)
        self.freqs = get_freqs(self.FT, get_param_axis(self.FT, self.xpts), self.real_length)
        ordrange = np.arange(-self.mt,self.mt+self.spectrum_spacing, self.spectrum_spacing)

        self.p, ordrange = force_broadcast(self.p, ordrange)

        # Wavenumbers of the requested diffraction orders.
        self.ords = 2*np.pi/self.p * ordrange

        self.ords, self.freqs = force_broadcast(self.ords, self.freqs)

        # Index of the sampled frequency closest to each requested order.
        self.locs = np.argmin(np.abs(self.freqs - self.ords),axis = get_param_axis(self.freqs, self.xpts))

        self.FT, self.freqs = self.FT.squeeze(), self.freqs.squeeze()
        dimdiff = self.FT.ndim - self.freqs.ndim

        order_idx = np.unique(self.locs, axis = 0)
        # freqs is tiled over the leading axes of FT so both halves of the
        # spectrum can be stacked into one array: [amplitudes, wavenumbers].
        tiled_freqs = np.tile(self.freqs,(*self.FT.shape[:dimdiff],*((1,)*(self.freqs.ndim))))
        self.spectrum = np.squeeze([self.FT[...,order_idx], tiled_freqs[...,order_idx]])
        self.num_ords = len(ordrange)

    def propagate(self,k0, neu_spec = None):
        """
        Apply this grating to an incoming spectrum and propagate it by ``L2``.

        Args:
            k0: wavenumber(s) of the incoming wave.
            neu_spec: optional ``(a, k)`` pair of amplitudes and transverse
                wavenumbers arriving at the grating. When ``None`` the
                grating's own spectrum is used unchanged.

        Returns:
            ``[a_prime, k_prime]`` with
            ``a_prime = a * exp(-1j * B(k, k0) * M * L2)``, ``B = -k**2/(2*k0)``
            and ``k_prime = k * M``.

        Side effects:
            ``self.M`` and ``self.L2`` are overwritten by the arrays returned
            from ``force_broadcast``; ``self.a_prime`` and ``self.k_prime``
            are stored on the instance.
        """
        def B(k,k0):
            return -k**2/(2*k0)

        ag,kg = self.spectrum.copy()

        if neu_spec is None:
            # equivalent to convolving grating spectrum with identity
            self.M, self.L2, k0, a, k = force_broadcast(self.M, self.L2, k0, ag, kg)

        else:
            a, k = neu_spec

            # Axes of length 2*mt + 1 are ambiguous for force_broadcast, so
            # the target axes are spelled out explicitly in those cases.
            if np.any(np.array(a.shape) == (2*self.mt + 1)):
                occurences = a.shape.count(2*self.mt + 1)

                if self.spectrum_spacing != 1:
                    a, k, ag, kg = force_broadcast(
                        a, k, ag, kg,
                        nonunique_lengths=[2*self.mt + 1],
                        nonunique_occurrences=[occurences],
                        desired_nonunique_arr_ax=[
                            np.vstack((np.tile([0,1],occurences),
                                       np.repeat(np.arange(-occurences,0,1),2))).T])
                elif occurences != self.apts_num:
                    a, k, ag, kg = force_broadcast(
                        a, k, ag, kg,
                        nonunique_lengths=[self.num_ords],
                        nonunique_occurrences=[occurences + 1],
                        desired_nonunique_arr_ax=[
                            np.concatenate((np.vstack((np.tile([0,1],occurences),
                                                       np.repeat(np.arange(-occurences-1,-1,1),2))).T,
                                            [[-2,-1],[-1,-1]]))])

                else:
                    a, k, ag, kg = force_broadcast(
                        a, k, ag, kg,
                        nonunique_lengths=[self.num_ords],
                        nonunique_occurrences=[occurences+1],
                        desired_nonunique_arr_ax=[
                            np.concatenate((np.vstack((np.tile([0,1],occurences-1),
                                                       np.repeat(np.arange(-occurences,-1,1),2))).T,
                                            [[-2,-1],[-1,-1]]))])

            else:
                a, k, ag, kg  = force_broadcast(a, k, ag, kg)

            # Multiply amplitudes and add wavenumbers of incoming and grating spectra.
            a = a * ag
            k = k + kg

            occurences = a.shape.count(2*self.mt + 1)
            self.M, self.L2, k0, a, k = force_broadcast(
                self.M, self.L2, k0, a, k,
                nonunique_lengths=[2*self.mt +1],
                nonunique_occurrences=[occurences],
                desired_nonunique_arr_ax=[
                    np.vstack((np.tile([-2,-1],occurences),
                               np.repeat(np.arange(-occurences,0,1),2))).T])

        self.a_prime = a*np.exp(-1j*B(k,k0)*self.M*self.L2)
        self.k_prime = k*self.M

        return [self.a_prime, self.k_prime]

    def generate_raw_intensity(self,neu_spec,xarr, iteration, batches):
        """
        Evaluate the intensity ``|sum_n a_n * exp(-1j * k_n * x)|**2``.

        Args:
            neu_spec: ``(a, k)`` pair of amplitudes and wavenumbers.
            xarr: positions at which the intensity is evaluated.
            iteration: index of the current batch (progress message only).
            batches: total number of batches (progress message only).

        Returns:
            The squared magnitude of the field summed over the last axis.
        """
        a, k = neu_spec

        # Collapse the trailing `apts_num` axes into a single axis that is
        # summed over below.
        a, k = flatten_after_axis(a.squeeze(),k.squeeze(), axis = -self.apts_num)

        a, k, xarr = force_broadcast(a, k, xarr, nonunique_lengths= [a.shape[-1]],
                                nonunique_occurrences= [1], desired_nonunique_arr_ax=
                                [np.vstack([[0,1], [-1,-1]]).T])

        psi = a * np.exp(-1j * (k*xarr))
        print("Generating raw intensity, %.2f %% done" % (100 * iteration/batches))
        clear_output(wait = True)

        psi = np.sum(psi, axis = -1)

        magnitude = np.abs(psi)
        return magnitude * magnitude
            
      
            
class RectGrating2D(Grating):
    """
    Two-dimensional phase grating represented by a grid of Fourier components.

    The phase profile comes from ``grating_equation_2d`` with periods ``px``
    and ``py``. ``exp(1j * profile)`` is transformed with ``fft2`` and only
    the components closest to the requested orders ``2*pi/px * m`` and
    ``2*pi/py * m`` are kept in ``self.spectrum``.

    Besides ``L1``/``L2`` (used by ``Grating``), the settings mapping is
    expected to provide ``px``, ``py``, ``phi``, ``mt``, ``spectrum_spacing``
    and ``apts_num``; ``n_p_g`` is optional and defaults to 50.
    """
    
        
    def __init__(self, sd):
        """
        Build the 2-D grating profile and extract its order spectrum.

        Args:
            sd: dict-like settings (see the class docstring).
        """
        super().__init__(sd)
        
        self.is2d = True
        
        if sd.get("n_p_g") is None:
                self.n_p_g = 50
        else:
            self.n_p_g = self.n_p_g.squeeze()

        self.real_x_length = self.n_p_g * self.px
        self.real_y_length = self.n_p_g * self.py

        # 100 samples per period over n_p_g periods, plus the end point; the
        # same count is used along x and y.
        self.xpts = self.ypts = int(100*np.round(self.n_p_g) + 1)

        # A single linspace call with vector end points builds both
        # coordinate axes at once.
        self.x, self.y = (np.linspace([-self.real_x_length/2,-self.real_y_length/2],\
                        [self.real_x_length/2,self.real_y_length/2],self.xpts)).squeeze().T
        

        # x and y have the same length, so their target axes (-2 and -1) are
        # passed explicitly.
        self.phi, self.x, self.y = force_broadcast(self.phi, self.x, self.y, nonunique_lengths= [self.xpts],\
                                    nonunique_occurrences= [2], desired_nonunique_arr_ax= [np.vstack([[-2,-1],[-2,-1]]).T])

        self.profile = self.phi/2 * grating_equation_2d(self.x,self.y,self.px,self.py)
        
        # c = plt.imshow(self.profile, clim = (0,1e-3))
        # plt.colorbar(c)
        # plt.show()
        
        # Transmission function exp(i*profile), transformed over the last two axes.
        self.FT = fft2(np.exp(1j*self.profile), axes = (-2,-1), norm = "forward")
        self.absFT = np.abs(self.FT)
        
        
        # NOTE(review): xpts == ypts, so both get_param_axis calls below
        # return the same (first matching) axis index — confirm get_freqs does
        # not depend on which of the two axes it is given.
        self.xfreqs = get_freqs(self.FT, get_param_axis(self.FT,self.xpts), self.real_x_length)
        self.yfreqs = get_freqs(self.FT, get_param_axis(self.FT,self.ypts), self.real_y_length)
        
        
        # c = plt.imshow(self.absFT, clim = (0,1e-3), extent = (np.min(self.xfreqs), np.max(self.xfreqs), np.min(self.yfreqs), np.max(self.yfreqs)))
        # plt.colorbar(c)
        # plt.vlines((2*pi/self.px,-2*pi/self.px) , ymin = -1e8, ymax = 1e8, lw = 1, color = "r")
        # plt.hlines((2*pi/self.py,-2*pi/self.py) , xmin = -1e8, xmax = 1e8, lw = 1, color = "r")
        # plt.xlim(-8*pi/self.px,8*pi/self.px)
        # plt.ylim(-8*pi/self.py,8*pi/self.py)
        # plt.show()
        
        # Wavenumbers of the requested diffraction orders along x and y.
        self.xords = 2*np.pi/self.px * np.arange(-self.mt,self.mt+self.spectrum_spacing,self.spectrum_spacing)
        self.yords = 2*np.pi/self.py * np.arange(-self.mt,self.mt+self.spectrum_spacing,self.spectrum_spacing)

        # Index of the sampled frequency closest to each requested order.
        self.xlocs = np.argmin(np.abs(self.xfreqs[:,None] - self.xords[None,:]),axis = get_param_axis(self.xfreqs,self.xpts))
        self.ylocs = np.argmin(np.abs(self.yfreqs[:,None] - self.yords[None,:]),axis = get_param_axis(self.yfreqs,self.ypts))
        

        # All (x index, y index) combinations, flattened into two index
        # vectors for fancy indexing of the last two FT axes.
        # NOTE(review): meshgrid defaults to indexing="xy", so the x index
        # varies fastest in the flattened pairs, while propagate() reshapes
        # the result to (len(kgx), len(kgy)) — confirm the intended x/y order.
        self.xylocs = tuple(np.split(np.array(np.meshgrid(self.xlocs,self.ylocs)).ravel(),2))
        # self.spectrum = [self.FT[...,self.xylocs[0],self.xylocs[1]],self.xfreqs[self.xlocs], self.yfreqs[self.ylocs]]
        
        dimdiff = self.FT.ndim - self.xfreqs.ndim
        
        # spectrum = [amplitudes, x wavenumbers, y wavenumbers]; the
        # wavenumber vectors are tiled over leading axes of FT.
        self.spectrum = [self.FT[...,self.xylocs[0],self.xylocs[1]],\
                        np.tile(self.xfreqs[self.xlocs],(*self.FT.shape[:dimdiff - 1],*((1,)*(self.xfreqs.ndim)))),\
                        np.tile(self.yfreqs[self.ylocs],(*self.FT.shape[:dimdiff - 1],*((1,)*(self.yfreqs.ndim))))]

        self.num_ords = len(self.xords)

            
    def propagate(self,k0,neu_spec = None):
        """
        Apply this grating to an incoming spectrum and propagate it by ``L2``.

        Args:
            k0: wavenumber(s) of the incoming wave.
            neu_spec: optional ``(a, kx, ky)`` triple arriving at the grating.
                When ``None`` the grating's own spectrum is used unchanged.

        Returns:
            ``[a_prime, kx_prime, ky_prime]`` with
            ``a_prime = a * exp(-1j * B(kx, ky, k0) * M * L2)``,
            ``B = -(kx**2 + ky**2)/(2*k0)`` and ``k*_prime = k* * M``.

        Side effects:
            ``self.M`` and ``self.L2`` are overwritten by the arrays returned
            from ``force_broadcast``; the three results are also stored on
            the instance.
        """
        def B(kx,ky,k0):
            return -(kx**2 + ky**2)/(2*k0)
        
        # self.spectrum is a list, so .copy() is shallow (arrays are shared).
        ag, kgx, kgy = self.spectrum.copy()
        
        # Unflatten the (x, y) order pairs into two separate trailing axes.
        ag = ag.reshape((*ag.shape[:-1],kgx.shape[-1], kgy.shape[-1]))
        

        if neu_spec is None:
            
            self.M, self.L2, k0, a, kx, ky = force_broadcast(self.M, self.L2, k0, ag, kgx, kgy,\
                                            nonunique_lengths= [self.num_ords], nonunique_occurrences=[2*self.apts_num],\
                                            desired_nonunique_arr_ax=[np.vstack(([-3,-3,-2,-1], np.tile([-2,-1],2))).T])
            # equivalent to convolving grating spectrum with identity
        else:
            
            a, kx, ky = neu_spec
            
            
            # Every order axis has length num_ords, so each array's target
            # axes are listed explicitly for force_broadcast.
            a, kx, ky, ag, kgx, kgy =  force_broadcast(a, kx, ky, ag, kgx, kgy, nonunique_lengths=[self.num_ords],\
                                        nonunique_occurrences=[2*self.apts_num], desired_nonunique_arr_ax = \
                                        [np.vstack([np.concatenate([np.tile([*([0]*(2*(self.apts_num -1))),1,2],self.apts_num-1),\
                                        [*([-3]*(2*(self.apts_num -1))),-2,-1]]), np.concatenate([np.delete(np.tile(np.arange(-2*self.apts_num,0),2),\
                                        np.arange(4)*-self.apts_num -1), np.tile([-self.apts_num -1,-1],2)])]).T])
            # ensures that x and y amplitudes (as opposed to n,m,l,... orders) remain together for later flattening

            # [-4,-2,-4,-2] [-3,-1,-3,-1]
            # [-6,-5,-3,-2,-6,-5,-3,-2] [-4,-1,-4,-1]
            # Multiply amplitudes and add wavenumbers of incoming and grating spectra.
            a = a * ag
            kx = kx + kgx
            ky = ky + kgy
            
            self.M, self.L2, k0, a, kx, ky = force_broadcast(self.M, self.L2, k0, a, kx, ky, nonunique_lengths= [self.num_ords],\
                                            nonunique_occurrences=[2*self.apts_num],desired_nonunique_arr_ax=\
                                            [np.vstack([np.append(np.repeat(-3,2*self.apts_num),np.repeat([-2,-1],self.apts_num)),\
                                            np.tile(np.arange(-2*self.apts_num,0,1),2)]).T])

            
        self.a_prime = a*np.exp(-1j*B(kx,ky,k0)*self.M*self.L2)
        self.kx_prime = kx*self.M
        self.ky_prime = ky*self.M

        return [self.a_prime, self.kx_prime, self.ky_prime]
    
    def generate_raw_intensity(self,neu_spec,xarr,yarr, iteration, batches):
        """
        Evaluate ``|sum a * exp(-1j * (kx*x + ky*y))|**2``.

        Args:
            neu_spec: ``(a, kx, ky)`` triple of amplitudes and wavenumbers.
            xarr: x positions at which the intensity is evaluated.
            yarr: y positions at which the intensity is evaluated.
            iteration: index of the current batch (progress message only).
            batches: total number of batches (progress message only).

        Returns:
            The squared magnitude of the field summed over the last two axes.
        """
        
        
        
        a, kx, ky = neu_spec
        
        # Collapse the trailing `apts_num` axes of each wavenumber array into
        # one axis; `a` keeps two flattened axes (see flatten_after_axis_2d).
        kx, ky = flatten_after_axis(kx.squeeze(),ky.squeeze(), axis = -self.apts_num)
        a = flatten_after_axis_2d(a,axis = -self.apts_num)

        
        a, kx, ky, xarr, yarr = force_broadcast(a, kx, ky, xarr, yarr, nonunique_lengths= [a.shape[-1]],\
                                nonunique_occurrences= [2], desired_nonunique_arr_ax=\
                                [np.vstack([[0,0,1,2], [-2,-1,-2,-1]]).T])
        
        psi = a * np.exp(-1j * (kx*xarr + ky*yarr))
        print("Generating raw intensity, %.2f %% done" % (100 * iteration/batches))    
        clear_output(wait = True)

        # Coherent sum over both flattened order axes.
        psi = np.sum(psi, axis = (-2,-1))

        return np.abs(psi) * np.abs(psi)
    
    

class SphericalSample(Grating):
    """
    Spherical phase object represented by its significant Fourier components.

    The thickness profile of a sphere of radius ``r_sphere`` is sampled on a
    window of length ``n_p_g * p_g``, ``exp(1j * phi * profile)`` is Fourier
    transformed, and only the components that pass the magnitude threshold
    (or the fallback low-frequency window) are kept in ``self.spectrum``.

    Besides ``L1``/``L2`` (used by ``Grating``), the settings mapping is
    expected to provide ``p_g``, ``phi``, ``r_sphere`` and ``apts_num``;
    ``n_p_g`` is optional and defaults to 50.
    """

    def __init__(self,sd):
        """
        Build the sphere profile and extract its cropped spectrum.

        Args:
            sd: dict-like settings (see the class docstring).
        """
        super().__init__(sd)

        self.is2d = False

        def sphere_profile(x,r_sphere):
            # Chord length 2*sqrt(r^2 - x^2) in units of r. x is complex, so
            # outside the sphere the root is imaginary and np.real yields 0.
            return np.real(2*np.sqrt(r_sphere**2 - x**2)) / r_sphere

        if sd.get("n_p_g") is None:
            self.n_p_g = 50
        else:
            self.n_p_g = self.n_p_g.squeeze()

        self.real_length = self.n_p_g * self.p_g
        # 100 samples per period over n_p_g periods, plus the end point.
        self.xpts = int(100*np.round(self.n_p_g) + 1 )

        # Complex dtype lets sphere_profile take square roots of negative values.
        self.x = np.linspace(-self.real_length/2,self.real_length/2, self.xpts).T.astype(complex)

        self.x0, self.x = force_broadcast(self.x0, self.x)
        self.x = self.x - self.x0
        self.phi, self.r_sphere, self.x = force_broadcast(self.phi,self.r_sphere, self.x)

        self.profile = sphere_profile(self.x,self.r_sphere)
        self.x = self.x.squeeze()

        self.FT = fft(np.exp(1j * self.phi * self.profile), axis = get_param_axis(self.profile,self.xpts),norm="forward")

        self.freqs = np.real_if_close(get_freqs(self.FT, get_param_axis(self.FT, self.xpts), self.real_length))

        # Keep only the components whose magnitude exceeds 1e-4.
        crop = (np.abs(self.FT) > 1e-4).squeeze()

        if crop.ndim > 1:
            # NOTE(review): get_param_axis returns None when r_sphere squeezes
            # to a scalar, which makes np.any reduce over every axis — confirm
            # this is intended when only phi is swept.
            crop = np.any(crop, axis = get_param_axis(crop,self.r_sphere))

        # If the threshold leaves an axis of length 1, fall back to a
        # low-frequency window instead.
        if np.any(np.array(retain_shape_after_index(self.freqs, crop).shape) == 1):
            crop = (np.abs(self.freqs) < np.pi/self.r_sphere / self.n_p_g).squeeze()

        self.freqs, self.FT = retain_shape_after_index(self.freqs, crop), retain_shape_after_index(self.FT,crop)

        self.FT, self.freqs = self.FT.squeeze(), self.freqs.squeeze()
        dimdiff = self.FT.ndim - self.freqs.ndim
        # Tile freqs over the leading axes of FT so both can be stacked.
        self.freqs = np.tile(self.freqs, (*self.FT.shape[:dimdiff],*((1,)*(self.freqs.ndim))))

        self.absFT = np.abs(self.FT)

        # spectrum = [amplitudes, wavenumbers]
        self.spectrum = np.squeeze([self.FT,self.freqs])

        self.num_ords = self.FT.shape[-1]

    def propagate(self,k0, neu_spec = None):
        """
        Apply the sample to an incoming spectrum and propagate it by ``L2``.

        Args:
            k0: wavenumber(s) of the incoming wave.
            neu_spec: optional ``(a, k)`` pair of amplitudes and transverse
                wavenumbers arriving at the sample. When ``None`` the
                sample's own spectrum is used unchanged.

        Returns:
            ``[a_prime, k_prime]`` with
            ``a_prime = a * exp(-1j * B(k, k0) * M * L2)``, ``B = -k**2/(2*k0)``
            and ``k_prime = k * M``.

        Side effects:
            ``self.M`` and ``self.L2`` are overwritten by the arrays returned
            from ``force_broadcast``; ``self.a_prime`` and ``self.k_prime``
            are stored on the instance.
        """
        def B(k,k0):
            return -k**2/(2*k0)

        a_sample,k_sample = self.spectrum.copy()

        if neu_spec is None:

            self.M, self.L2, k0, a, k = force_broadcast(self.M, self.L2, k0, a_sample, k_sample)

        else:
            a, k = neu_spec

            unique_shape = np.unique(a.shape)
            # True when at least one axis length occurs more than once in
            # a.shape; only then are explicit axis targets needed.
            has_duplicates = len(a.shape) != len(unique_shape)
            dup_flag = False
            if has_duplicates:

                dup_occurences = len(a.shape) - len(unique_shape) + 1
                occurences = np.array([a.shape.count(x) for x in unique_shape])
                dup_val = np.array(unique_shape)[occurences > 1].squeeze()

                if self.num_ords in a.shape:
                    # The sample's own order axis adds one more occurrence.
                    dup_flag = True
                    a, k, a_sample, k_sample = force_broadcast(
                        a, k, a_sample, k_sample,
                        nonunique_lengths=[dup_val],
                        nonunique_occurrences=[dup_occurences + 1],
                        desired_nonunique_arr_ax=[
                            np.vstack((np.concatenate((np.tile([0,1],self.apts_num-1),[-2,-1])),
                                       np.concatenate((np.repeat(np.arange(-self.apts_num+1,0,1),2),
                                                       [-self.apts_num,-self.apts_num])))).T])
                else:
                    a, k, a_sample, k_sample = force_broadcast(
                        a, k, a_sample, k_sample,
                        nonunique_lengths=[dup_val],
                        nonunique_occurrences=[dup_occurences],
                        desired_nonunique_arr_ax=[
                            np.vstack((np.tile([0,1],self.apts_num-1),
                                       np.repeat(np.arange(-self.apts_num+1,0,1),2))).T])

            else:
                a, k, a_sample, k_sample  = force_broadcast(a, k, a_sample, k_sample)

            # Multiply amplitudes and add wavenumbers of incoming and sample spectra.
            a = a * a_sample
            k = k + k_sample

            if dup_flag:
                self.M, self.L2, k0, a, k = force_broadcast(
                    self.M, self.L2, k0, a, k,
                    nonunique_lengths=[dup_val],
                    nonunique_occurrences=[dup_occurences+1],
                    desired_nonunique_arr_ax=[
                        np.vstack((np.tile([-2,-1],self.apts_num),
                                   np.repeat(np.arange(-self.apts_num,0,1),2))).T])
            elif has_duplicates:
                self.M, self.L2, k0, a, k = force_broadcast(
                    self.M, self.L2, k0, a, k,
                    nonunique_lengths=[dup_val],
                    nonunique_occurrences=[dup_occurences],
                    desired_nonunique_arr_ax=[
                        np.vstack((np.tile([-2,-1],self.apts_num-1),
                                   np.repeat(np.arange(-self.apts_num+1,0,1),2))).T])
            else:
                # No repeated axis lengths: dup_val/dup_occurences were never
                # assigned, so use the plain call (same as the neu_spec-is-None
                # path) rather than referencing unbound names.
                self.M, self.L2, k0, a, k = force_broadcast(self.M, self.L2, k0, a, k)

        self.a_prime = a*np.exp(-1j*B(k,k0)*self.M*self.L2)
        self.k_prime = k*self.M
        return [self.a_prime, self.k_prime]
    


    
class ForkGrating(RectGrating):
    """
    Rectangular grating whose diffraction orders carry an azimuthal phase.

    Construction is identical to ``RectGrating`` except that ``is2d`` is set
    to True. The intensity is evaluated on an (x, y) grid with an extra phase
    ``OAM * order * arctan2(y, x)`` applied to each order.
    """

    def __init__(self,sd):
        super().__init__(sd)
        self.is2d = True

    def generate_raw_intensity(self,neu_spec,xarr,yarr, iteration, batches):
        """
        Evaluate the intensity on the grid spanned by ``xarr`` and ``yarr``.

        Args:
            neu_spec: ``(a, k)`` pair; the indexing below requires both to be 2-D.
            xarr: 1-D x positions.
            yarr: 1-D y positions.
            iteration: index of the current batch (progress message only).
            batches: total number of batches (progress message only).

        Returns:
            Squared magnitude of the field summed over its last two axes.
        """
        amplitudes, wavenumbers = neu_spec

        print("Generating raw intensity, %.2f %% done" % (100 * iteration/batches))
        clear_output(wait = True)

        # Per-order azimuthal charge, reordered with fftshift + roll.
        order_index = np.arange(-self.mt, self.mt + self.spectrum_spacing, self.spectrum_spacing)
        oam_per_order = np.roll(fftshift(order_index * self.OAM), 1)

        # Azimuthal angle on the grid; axis 0 follows xarr, axis 1 follows yarr.
        azimuth = np.arctan2(yarr[None, ...], xarr[..., None])

        # Axes after expansion: (x, y, <first spectrum axis>, order).
        twist = (oam_per_order * azimuth[..., None])[:, :, None, :]
        amplitudes = amplitudes[None, None, :, :]
        wavenumbers = wavenumbers[None, None, :, :]
        x_grid = xarr[:, None, None, None]

        field = amplitudes * np.exp(-1j * wavenumbers * x_grid) * np.exp(-1j * twist)
        psi = field.sum(axis = (-2, -1))

        magnitude = np.abs(psi)
        return magnitude * magnitude
        



class RotatedProfile:
    """
    Projected height profiles of a surface profile rotated by given angles.

    The profile is tiled three times and rasterised into a binary image
    (one image row per profile sample). The image is rotated by ``-deg`` for
    each angle, summed along its last axis, rescaled to [-1, 1], and the
    middle third is kept as the rotated profile.

    Args:
        deg: rotation angle(s), passed to ``scipy.ndimage.rotate`` negated.
        profile: 1-D array of profile heights.
        profile_size: length of the profile; with ``profile_height`` it sets
            the raster height in points.
        profile_height: height of the profile in the same units as
            ``profile_size``.
        profile_period: accepted but not used by this constructor.
    """

    def __init__(self,deg, profile, profile_size,profile_height, profile_period):
        self.deg = assert_ndarr(deg)

        self.extended_profile = np.tile(profile,3)
        n_samples = self.extended_profile.shape[0]
        n_height = round(profile.shape[0] * profile_height // (profile_size))
        self.background = np.zeros((n_samples,n_height), dtype = np.int8)

        lowest = np.min(self.extended_profile)
        highest = np.max(self.extended_profile)
        self.scaled_profile = (self.extended_profile - lowest - 1/n_height) / (highest - lowest)

        # Fill each image row up to the (scaled) profile height at that sample.
        for row in range(n_samples):
            fill_to = int(self.scaled_profile[row]*n_height)+2
            self.background[row,:fill_to] = 1

        # Middle third of the tiled profile.
        start = n_samples//3
        stop = 2*n_samples//3

        rotated, summed, scaled_summed, rotated_profiles = [], [], [], []
        for angle in self.deg:
            image = rotate(self.background,-angle)
            column_sum = np.sum(image,axis = -1)
            # Rescale the projection to the range [-1, 1].
            rescaled = 2*(column_sum - np.min(column_sum))/(np.max(column_sum) - np.min(column_sum)) - 1

            rotated.append(image)
            summed.append(column_sum)
            scaled_summed.append(rescaled)
            rotated_profiles.append(rescaled[start:stop])

        self.rotated = rotated
        self.summed = summed
        self.scaled_summed = scaled_summed
        self.rotated_profiles = rotated_profiles
