In [4]:
import numpy as np
from scipy.optimize import curve_fit
from scipy.fft import fft,fft2, fftfreq
from scipy.signal import oaconvolve
from scipy.stats import binned_statistic_dd
from PIL import Image, ImageFilter
import joblib as jb
from IPython.display import clear_output
# import numba
from scipy.ndimage import rotate
import matplotlib.pyplot as plt


In [2]:
def generate_stepfunc(width, padding):
    """Sample a top-hat (rectangular) kernel of the given width.

    The grid spans [0, width + padding] with ``int(width/padding) + 2`` points;
    samples in [padding, width + padding) are 1 and all others 0. Callers pass
    the bin size as ``padding``, which is also the grid step when width/padding
    is an integer.
    """
    n_samples = int(width / padding) + 2
    grid = np.linspace(0, width + padding, n_samples)
    rising_edge = np.heaviside(grid - padding, 1)
    falling_edge = np.heaviside(grid - width - padding, 1)
    return rising_edge - falling_edge

def generate_source_grating(period, duty_cycle, binsize, n_periods = 50):
    """Sample a binary source grating (1 = open, 0 = closed), tiled n_periods times.

    One period is sampled with ``round(period/binsize)`` points on [0, period).
    The endpoint is excluded so consecutive tiles do not repeat the boundary
    sample; including it would stretch every tiled period by one bin. The first
    ``duty_cycle`` fraction of each period is open.

    Args:
        period: grating period.
        duty_cycle: open fraction of a period (0..1).
        binsize: sampling step, same units as ``period``.
        n_periods: number of periods to tile.

    Returns:
        1D array of length ``n_periods * round(period/binsize)``.
    """
    # round() instead of int(): period/binsize like 1e-6/1e-8 may evaluate to
    # 99.999... and int() would silently drop a sample per period.
    n_per_period = max(1, int(round(period / binsize)))
    single_source_x = np.linspace(0, period, n_per_period, endpoint=False)
    single_source = np.heaviside(single_source_x, 1) - np.heaviside(single_source_x - duty_cycle*period, 1)
    return np.tile(single_source, n_periods)
    

def generate_single_source(period, duty_cycle, binsize, n_periods, current_period):
    """Sample a single open slit of a source grating on an n_periods-long grid.

    The grid has ``int(n_periods*period/binsize) + 1`` points on
    [0, n_periods*period]. The result is 1 on
    [current_period*period, current_period*period + duty_cycle*period) and 0 elsewhere.
    """
    span = n_periods*period
    grid = np.linspace(0, span, int(span/binsize) + 1)
    opening = current_period*period
    closing = current_period*period + duty_cycle*period
    return np.heaviside(grid - opening, 1) - np.heaviside(grid - closing, 1)

    

def bin_intensity(current_x, I, bin_width):
    """Average ``I`` in equal-width bins along ``current_x``.

    The number of bins is ``current_x[-1] // bin_width``, with at least 1 bin.
    scipy spreads the bins evenly between min(current_x) and max(current_x).
    The realised bin width therefore equals ``bin_width`` only when current_x
    starts at 0 and ends on a multiple of ``bin_width``.

    Returns:
        [bin_edges, binned_mean]: the n_bins + 1 bin edges and the mean of ``I``
        in each bin (NaN for empty bins).
    """
    # binned_statistic_dd documents `bins` as an int; floor division of floats
    # yields a float, and an x-range shorter than bin_width would give 0 bins.
    n_bins = max(1, int(current_x[-1] // bin_width))
    binned_mean, bin_edges, _ = binned_statistic_dd(current_x, I, bins=n_bins)

    return [np.squeeze(bin_edges), binned_mean]

def invert(pxl, fitparams):
    """Invert the quadratic a*x**2 + b*x + c = pxl for x.

    ``fitparams`` holds (a, b, c). The '+' root (-b + sqrt(b**2 - 4*a*(c - pxl)))/(2*a)
    is returned. ``pxl`` may be a scalar or numpy array.
    """
    a, b = fitparams[0], fitparams[1]
    c = fitparams[2] - pxl
    discriminant = b**2 - 4*a*c
    return (np.sqrt(discriminant) - b)/(2*a)

def dbcosine_func(x, A, B, p1, phi, p2):
    """Two-period fringe model A + B*cos(2*pi*x/p1 + phi)*cos(2*pi*x/p2)."""
    carrier = np.cos(2*np.pi*x/p1 + phi)
    envelope = np.cos(2*np.pi*x/p2)
    return A + B*carrier*envelope

def cosine_func(x, A, B, p, phi):
    """Fringe model A + B*cos(2*pi*x/p + phi): offset, amplitude, period, phase."""
    phase = 2*np.pi*x/p + phi
    return A + B*np.cos(phase)

def cosine_func_jac(x, A, B, p, phi):
    """Jacobian of cosine_func (A + B*cos(2*pi*x/p + phi)) w.r.t. (A, B, p, phi).

    Returns an array of shape (len(x), 4), suitable as ``jac`` for curve_fit.
    """
    phase = 2*np.pi*x/p + phi
    return np.array([np.ones_like(x),
                     np.cos(phase),
                     (B*2*np.pi*x/(p**2))*np.sin(phase),
                     # d/dphi must use the full phase, including phi itself.
                     -B*np.sin(phase)]).T

def max_index(ndarr):
    """Return the N-d index tuple of the first maximum of ``ndarr``."""
    flat_position = np.argmax(ndarr)
    return np.unravel_index(flat_position, ndarr.shape)

def lowpass(data, minperiod, pxlsize, order):
    """Zero-phase Butterworth low-pass filter (filtfilt along the last axis).

    Args:
        data: samples spaced ``pxlsize`` apart.
        minperiod: period setting the cutoff, same units as ``pxlsize``.
        pxlsize: sample spacing.
        order: Butterworth filter order.

    Returns:
        The filtered array, same shape as ``data``.
    """
    # butter/filtfilt are not in the notebook's import cell; import them here so
    # the function works on a fresh kernel.
    from scipy.signal import butter, filtfilt

    # Wn is normalized to the Nyquist frequency (1/(2*pxlsize)), so this cutoff
    # corresponds to a frequency of 1/(2*minperiod).
    # NOTE(review): use 2*pxlsize/minperiod if periods down to `minperiod` should pass.
    normal_cutoff = pxlsize/minperiod
    b, a = butter(order, normal_cutoff, btype='low', analog=False)

    return filtfilt(b, a, data)


def grating_equation(xarr, p, phase_offset):
    """Binary (+1/-1) grating profile sign(sin(2*pi*x/p + phase_offset)).

    np.sign yields 0 where the sine is exactly 0.
    """
    argument = 2*np.pi*xarr/p + phase_offset
    return np.sign(np.sin(argument))

def grating_equation_2d(xarr, yarr, px, py):
    """Binary (+1/-1) 2D grating profile sign(cos(2*pi*x/px) + cos(2*pi*y/py) - 1).

    ``xarr`` and ``yarr`` broadcast like numpy arrays. np.sign gives 0 where the
    argument is exactly 0.
    """
    # np.pi: the bare name `pi` is not defined anywhere in this notebook.
    return np.sign(np.cos(2*np.pi*xarr/px) + np.cos(2*np.pi*yarr/py) - 1)




def maxmincont(arr):
    """Michelson contrast (max - min)/(max + min) along the last axis."""
    peak = np.amax(arr, axis=-1)
    trough = np.amin(arr, axis=-1)
    return (peak - trough)/(peak + trough)

def get_freqs(ft_like, real_signal_length):
    """Angular FFT frequencies 2*pi*fftfreq for the last axis of ``ft_like``.

    ``real_signal_length`` is the physical length of the sampled signal. If it
    is exactly a numpy ndarray (1D), one row of frequencies per length is
    returned, giving shape (len, n). Otherwise a 1D array of length n is
    returned.
    """
    n_points = ft_like.shape[-1]
    if type(real_signal_length) != np.ndarray:
        return 2*np.pi*fftfreq(n_points, real_signal_length/n_points)
    rows = [fftfreq(n_points, length/n_points) for length in real_signal_length]
    return 2*np.pi*np.array(rows)

def rect_spectrum(x,xsize,p,phi,mt, image_profile = None, phase_offset = None):
    """Diffraction-order amplitudes of a thin 1D phase grating.

    The transmission exp(i*phi/2*g) is FFT'd along the last axis with forward
    normalization. g is ``image_profile`` if given, otherwise
    ``grating_equation(x, p, offset)`` with offset = 0 when ``phase_offset`` is
    None. For every order m in [-mt, mt] the FFT bin nearest to 2*pi*m/p is
    selected.

    Args:
        x: sample positions; the last axis of ``x.squeeze()`` sets the FFT length.
        xsize: physical length covered by ``x``, used for the bin frequencies.
        p, phi, mt: grating period, phase scale and maximum order.
        image_profile: optional explicit profile replacing the binary grating.
        phase_offset: phase passed to grating_equation.

    Returns:
        np.array([freqs, amplitudes]).squeeze(). freqs are the selected angular
        frequencies tiled over the FFT's leading axes.
    """
    # TODO: this should probably be a class.
    if image_profile is None:
        offset = 0 if phase_offset is None else phase_offset
        profile = grating_equation(x, p, offset)
    else:
        profile = image_profile
    realspace = phi/2 * profile

    # get_freqs returns a 1D array of length x.shape[-1] for a scalar xsize.
    freqs = get_freqs(x.squeeze(), xsize)
    spectrum = fft(np.exp(1j*realspace), axis=-1, norm="forward")

    order_freqs = 2*np.pi/p * np.arange(-mt, mt+1, 1)
    nearest = np.argmin(np.abs(freqs[:, None] - order_freqs[None, :]), axis=0)

    tiled_freqs = np.tile(freqs[nearest], (*spectrum.shape[:-1], 1))
    return np.array([tiled_freqs, spectrum[..., nearest]]).squeeze()
    
    
    
def rect_spectrum2d(x,xsize,y,ysize,px,py,phi,mt):
    """Diffraction-order amplitudes of a thin 2D binary phase grating.

    The transmission exp(i*phi/2*g), with g = grating_equation_2d(x, y, px, py),
    is FFT'd over the last two axes (forward normalization). x is on axis -2 and
    y on axis -1. For each order m in [-mt, mt] the bin nearest to 2*pi*m/px
    (resp. 2*pi*m/py) is selected. Any leading axes (e.g. from an array ``phi``)
    are kept as batch axes.

    Args:
        x, y: broadcastable sample grids spanning ``xsize`` and ``ysize``.
        px, py: grating periods along x and y.
        phi: phase scale (the phase is phi/2 * g).
        mt: maximum diffraction order.

    Returns:
        [kx_orders, ky_orders, amplitudes]. amplitudes holds the FFT values on
        the grid of selected (x, y) bins, flattened over the last axis in
        y-major order: entry k pairs kx index k % n with ky index k // n,
        where n = 2*mt + 1.
    """
    realspace = phi/2 * grating_equation_2d(x, y, px, py)
    ft2d = fft2(np.exp(1j*realspace), norm="forward")

    # fft2 transforms axes (-2, -1); take the frequencies from those axes so the
    # result is also correct when leading batch axes are present.
    nx, ny = ft2d.shape[-2], ft2d.shape[-1]
    xfreqs = 2*np.pi*fftfreq(nx, xsize/nx)
    yfreqs = 2*np.pi*fftfreq(ny, ysize/ny)

    orders = np.arange(-mt, mt+1, 1)
    ordsx, ordsy = 2*np.pi/px * orders, 2*np.pi/py * orders

    locsx = np.argmin(np.abs(xfreqs[:, None] - ordsx[None, :]), axis=0)
    locsy = np.argmin(np.abs(yfreqs[:, None] - ordsy[None, :]), axis=0)

    # meshgrid (default 'xy' indexing) gives arrays of shape (len(locsy), len(locsx)).
    grid_x, grid_y = np.meshgrid(locsx, locsy)

    return [xfreqs[locsx], yfreqs[locsy], ft2d[..., grid_x.ravel(), grid_y.ravel()]]
    
    
    
    
    
def sample_spectrum(xsize,p,mt, spectrum_spacing,  sample_profile, phi):
    """Select Fourier amplitudes of a phase sample at grating-order frequencies.

    The transmission exp(i*phi/2*sample_profile) is FFT'd along the last axis
    (forward normalization). Target angular frequencies are 2*pi/p * m for
    m in np.arange(-mt, mt+spectrum_spacing, spectrum_spacing); for each target
    the nearest FFT bin is selected.

    Args:
        xsize: physical length covered by the last axis of ``sample_profile``.
        p: period defining the target frequencies; a scalar or a numpy array.
        mt: maximum order.
        spectrum_spacing: step between selected orders.
        sample_profile: profile samples; the last axis is transformed.
        phi: phase scale (the phase is phi/2 * sample_profile).

    Returns:
        np.array([freqs, amplitudes]).squeeze(). freqs are the selected bin
        frequencies tiled over the leading axes of the FFT; amplitudes are the
        matching FFT values.
    """
    
    realspace = phi/2 * sample_profile
    
    allfreqs = get_freqs(sample_profile.squeeze(),xsize) # always returns 1D array with shape x.shape[-1]

    ft = fft(np.exp(1j*realspace), axis = -1, norm = "forward")

    if type(p) == np.ndarray:
        # One row of target frequencies per period.
        ords = 2*np.pi/p[:,None] * np.arange(-mt,mt+spectrum_spacing,spectrum_spacing)[None,:]

        allfreqs, ords = force_broadcast(allfreqs,ords)

        # NOTE(review): argmin is taken over axis 1 of the broadcast difference
        # and only row [0] is kept. Depending on how force_broadcast orders the
        # axes, this may index the period axis rather than the frequency axis.
        # Verify against the expected shapes.
        locs = np.argmin(np.abs(allfreqs - ords), axis = 1)[0]

        allfreqs = allfreqs.squeeze()
        
    
    else:
        ords = 2*np.pi/p * np.arange(-mt,mt+spectrum_spacing,spectrum_spacing)
        # Index of the FFT bin nearest to each target frequency.
        locs = np.argmin(np.abs(allfreqs[...,None] - ords[None,...]),axis = 0)

    
    # Number of leading FFT axes that allfreqs has to be tiled over.
    dimdiff = ft.ndim - allfreqs.ndim

    return np.array([(np.tile(allfreqs,(*ft.shape[:dimdiff],*((1,)*(allfreqs.ndim)))))[...,locs], ft[...,locs]]).squeeze()  


def fork_get_regions(nd_intens, xcam, ycam, pg, L, di):
    """Trace the two outer bright fringes of an image to locate fork / no-fork regions.

    ``pm_in_pts`` is L*pg/di (the moire period used elsewhere in this file)
    converted to pixels via ``xcam``. In row ``starty`` (1/20 of the second
    image axis) the brightest pixel of the left and right ``pm_in_pts``-wide
    strips is used as a start point. Each tracer then takes
    ``imgshp[1] - 2*starty`` single-pixel steps along the rounded intensity
    gradient. When the gradient is zero or rounds to (0, 0), it falls back to a
    diagonal step: (+1, +1) on the left, (-1, +1) on the right.

    Args:
        nd_intens: intensity indexed as [..., x, y]. The tracing and the final
            comparison assume a single 2D image.
        xcam: physical camera width along the first image axis.
        ycam: unused.
        pg, L, di: only used through the combination L*pg/di.

    Returns:
        [fork_region, nofork_region], each as [x_left, x_right, y_start, y_end].
        Of the start and end positions, the pair of tracks that is farther
        apart is labelled the fork region.
    """
    # Float copy: gradient differences on unsigned-int camera images would wrap around.
    intens = np.asarray(nd_intens, dtype=float)
    imgshp = intens.shape[-2:]
    pm_in_pts = int((L*pg/di)/xcam * imgshp[0])
    starty = int(imgshp[1]/20)
    right_offset = imgshp[0] - pm_in_pts

    left_startx = np.argmax(intens[..., :pm_in_pts, starty], axis=-1)
    # argmax on the right strip is relative to the strip start; shift it to image coordinates.
    right_startx = np.argmax(intens[..., right_offset:, starty], axis=-1) + right_offset

    startx_dist = np.abs(right_startx - left_startx)

    def _step(x, y, fallback):
        """Integer unit step along the rounded intensity gradient at (x, y)."""
        gradx = intens[..., x + 1, y] - intens[..., x - 1, y]
        grady = intens[..., x, y + 1] - intens[..., x, y - 1]
        mag = (gradx**2 + grady**2)**0.5
        if np.all(mag == 0):
            return fallback
        stepx, stepy = np.rint(gradx/mag), np.rint(grady/mag)
        if np.all((stepx == 0) & (stepy == 0)):
            return fallback
        # np.rint returns floats; positions must stay integers to index the image.
        return stepx.astype(int), stepy.astype(int)

    xl, yl = left_startx, starty
    xr, yr = right_startx, starty

    for _ in range(imgshp[1] - 2*starty):
        # Both steps are computed from the current positions before either tracer moves.
        stepxl, stepyl = _step(xl, yl, (1, 1))
        stepxr, stepyr = _step(xr, yr, (-1, 1))
        xl, yl = xl + stepxl, yl + stepyl
        xr, yr = xr + stepxr, yr + stepyr

    endx_dist = np.abs(xr - xl)

    if startx_dist > endx_dist:
        fork_region = [left_startx, right_startx, starty, starty*5]
        nofork_region = [xl, xr, yl, imgshp[1]]
    else:
        fork_region = [xl, xr, yl, imgshp[1]]
        nofork_region = [left_startx, right_startx, starty, starty*5]

    return [fork_region, nofork_region]


def best_fit_moire_period(func,xdata,intens,pg,L,d):
    """Fit ``func`` to a fringe profile, scanning the initial guess of the period.

    51 fits are started in parallel (joblib) from initial periods spanning
    0.875..1.125 times L*pg/d. Each uses A = mean(intens),
    B = max(intens) - mean(intens) and phase 0, i.e. p0 = [A, B, period, 0].
    The result is the parameter vector of the start whose fitted period
    (parameter index 2) has the smallest one-sigma uncertainty.

    Starts where ``curve_fit`` does not converge, or whose uncertainty is NaN,
    are ranked last instead of aborting the scan or winning argmin.

    Raises:
        RuntimeError: if no start converged at all.
    """

    def pfit(func, x, intens, p0):
        """Run one fit; return [params, 1-sigma errors] (NaN params / inf errors on failure)."""
        try:
            params, cov = curve_fit(func, x, intens, p0, maxfev=20000)
        except RuntimeError:
            # curve_fit raises RuntimeError when it does not converge within maxfev.
            return np.array([np.full(len(p0), np.nan), np.full(len(p0), np.inf)])
        return np.array([params, np.sqrt(np.diag(cov))])

    pmoire_dist = np.linspace(0.875*L * pg/d, 1.125*L*pg/d, 51).squeeze()

    A, B, phi = np.mean(intens), np.amax(intens) - np.mean(intens), 0

    # Stack of (n_starts, 2, n_params) -> (2, n_starts, n_params).
    all_params_with_error = np.swapaxes(jb.Parallel(n_jobs=-1)(
        jb.delayed(pfit)(func, xdata, intens, p0=[A, B, pmoire, phi]) for pmoire in pmoire_dist), -2, 0)
    allparams, allerror = all_params_with_error

    if np.isnan(allparams).all():
        raise RuntimeError("curve_fit did not converge for any initial moire period")

    # argmin returns the first NaN if present; rank NaN uncertainties last instead.
    period_err = allerror[..., 2]
    period_err = np.where(np.isnan(period_err), np.inf, period_err)

    return np.array(allparams[np.argmin(period_err)])


# @numba.njit(fastmath = True, parallel = True)
def get_modsquared_psi(k, a, xarr):
    """Return |psi|**2, where psi = sum over the last axis of a * exp(-i*k*x)."""
    phases = np.exp(-1j*k*xarr)
    psi = (phases*a).sum(axis=-1)
    magnitude = np.abs(psi)
    return magnitude*magnitude

# def fit_recurs(fit_routine, func, xdata, ydata, params):
#     if ydata.ndim == 1:
#         return fit_routine(func,xdata,ydata, params)
#     else:
#         return fit_recurs(fit_routine, func, xdata, ydata[], params[])
    
def force_broadcast(*arrs, nonunique_lengths = [], nonunique_occurences = []):
    """Insert singleton axes so that arrays broadcast by matching axis *lengths*.

    Each input is squeezed first. A common list of output axis lengths is then
    built in order of first appearance across the inputs. Every array gets new
    size-1 axes (np.expand_dims) at the output positions whose length it does
    not have, so axes of equal length line up regardless of their original
    position.

    Lengths in ``nonunique_lengths`` may occupy several output axes, up to the
    matching count in ``nonunique_occurences``. For such a length, each array's
    axis claims the first still-unclaimed output axis of that length. Claims
    carry over from one array to the next in argument order.

    Args:
        *arrs: numpy arrays. ``.squeeze()`` is called on each, so plain Python
            scalars are not accepted. Arrays that squeeze to 0-d are returned
            without added axes.
        nonunique_lengths: axis lengths allowed to appear on more than one axis.
        nonunique_occurences: for each entry of ``nonunique_lengths``, the
            maximum number of output axes with that length.

    Returns:
        list of the squeezed and expanded arrays, in input order.

    NOTE: the list defaults are shared between calls but are only read here,
    never mutated.
    NOTE(review): matching is by length only. Two different physical axes that
    happen to have equal length are merged unless that length is listed in
    ``nonunique_lengths``.
    """

    
    arrs = [arr.squeeze() for arr in arrs]
    
    # Pass 1: collect the output axis lengths in order of first appearance.
    output_lengths = []
    
    for i in range(len(arrs)):
        for j in range(arrs[i].ndim):
            if arrs[i].shape[j] in nonunique_lengths:
                if output_lengths.count(arrs[i].shape[j]) < nonunique_occurences[nonunique_lengths.index(arrs[i].shape[j])]:
                        output_lengths.append(arrs[i].shape[j]) 

            elif arrs[i].shape[j] not in output_lengths:
                output_lengths.append(arrs[i].shape[j]) 
                
    common_ndim = len(output_lengths) 
    axes = np.arange(common_ndim)
    # Copy whose non-unique entries are overwritten with -1 once an array claims them.
    lengths_copy = output_lengths.copy()

    # Pass 2: expand every array at the output axes it does not already occupy.
    for i in range(len(arrs)):
        arr_ndim = arrs[i].ndim

        if arr_ndim > 0:
                    
            # One boolean row per array axis marking the output axis it occupies
            # (unique lengths only; non-unique lengths are handled below).
            boolkey = np.array([np.array(lengths_copy) == j if j not in nonunique_lengths\
                                else np.full(axes.shape,False) for j in arrs[i].shape]).astype(bool)
            
            # Collapse the per-axis rows into one mask of occupied output axes.
            if boolkey.ndim > 1:
                boolcopy = boolkey.copy()
                boolkey = boolkey[0]
                for b in boolcopy[1:]:
                    boolkey = boolkey | b
            
            # Non-unique lengths claim the first unclaimed output axis of that length.
            for length in arrs[i].shape:
                if length in nonunique_lengths:

                    boolkey[lengths_copy.index(length)] = True
                    lengths_copy[lengths_copy.index(length)] = -1
                    
            axes_to_expand = axes[~boolkey]
                    
            for axis in axes_to_expand:
                arrs[i] = np.expand_dims(arrs[i],int(axis))

    return arrs 

def force_equal_dims(top, btm):
    """Pad two arrays with size-1 axes until they have the same ndim.

    ``top`` gains trailing axes and ``btm`` gains leading axes. For example, a
    1D kernel passed as ``btm`` lines up with the last axis of ``top``.

    Returns:
        [top, btm] with equal ndim.
    """
    target_ndim = max(top.ndim, btm.ndim)
    top = top.reshape(top.shape + (1,)*(target_ndim - top.ndim))
    btm = btm.reshape((1,)*(target_ndim - btm.ndim) + btm.shape)
    return [top, btm]

        
    
class RotatedSpectrum:
    """Projection of a 1D height profile viewed under one or more tilt angles.

    The profile is tiled three times and rasterised into a binary int8 image
    (rows = positions along the profile, columns = height bins). Each image is
    rotated by -deg degrees (scipy.ndimage.rotate), summed over the height axis,
    rescaled to [-1, 1] and cropped back to the middle copy.
    """

    def __init__(self, deg, profile, profile_size, profile_height, profile_period):
        """
        Args:
            deg: rotation angle(s) in degrees; a scalar or a numpy array.
            profile: 1D height profile samples.
            profile_size: physical length covered by ``profile``.
            profile_height: physical height used to size the height axis.
            profile_period: unused.
        """
        self.deg = deg if type(deg) == np.ndarray else np.array([deg])

        self.extended_profile = np.tile(profile, 3)
        n_rows = self.extended_profile.shape[0]
        height_in_pts = round(profile.shape[0] * profile_height // (profile_size))
        self.background = np.zeros((n_rows, height_in_pts), dtype=np.int8)

        lowest, highest = np.min(self.extended_profile), np.max(self.extended_profile)
        self.scaled_profile = (self.extended_profile - lowest - 1/height_in_pts) / (highest - lowest)

        # Fill each row up to its scaled height (plus two bins).
        for row in range(n_rows):
            fill_to = int(self.scaled_profile[row]*height_in_pts) + 2
            self.background[row, :fill_to] = 1

        self.rotated = []
        self.summed = []
        for angle in self.deg:
            rotated_image = rotate(self.background, -angle)
            self.rotated.append(rotated_image)
            self.summed.append(np.sum(rotated_image, axis=-1))

        self.scaled_summed = [2*(s - np.min(s))/(np.max(s) - np.min(s)) - 1 for s in self.summed]

        # Keep only the middle of the three tiled copies.
        self.rotated_profiles = [s[n_rows//3:2*n_rows//3] for s in self.scaled_summed]

    
    
    
class PGMI:
    """Phase-grating moire interferometer (PGMI) simulation.

    ``apts_dict`` maps keys to apparatus objects (gratings/samples defined
    elsewhere in this file). ``propagate_to(key)`` advances the Fourier-space
    state ``self.pos`` through one apparatus. ``generate_after(key)`` turns the
    current state into a camera intensity and convolves it with the slit/source
    and resolution kernels. For 1D apparatus it also fits ``cosine_func`` to
    every fringe profile, producing ``self.fitparams`` and ``self.contrast``.
    """
    # TODO: replace integer references to axes with named axes (e.g. a dict
    # mapping axis names to array axes) and add methods that operate along a
    # named axis (np.apply_along_axis might be useful).

    def __init__(self, apts_dict, init_values):
        """Store the apparatus dict and read the simulation parameters.

        Args:
            apts_dict: dict of apparatus objects. ``self.px``/``self.py`` are
                copied from RectGrating2D entries, and ``self.p`` from the first
                other entry that is not a ``Sample``.
            init_values: dict of parameters.
                Keys always read: "pos", "res", "xbin", "ybin", "mt", "lam",
                "plam", "L", "d", "convmode", "camsize", "x", "x0", "batches".
                Optional keys: "slitx", "slity", "source_period" (with
                "duty_cycle"), "y", "Ioffset".
        """
        # NOTE: could probably be expressed more compactly with __slots__ or a dataclass.
        self.apts_dict = apts_dict

        self.pos = init_values["pos"]

        self.slitxflag = False
        self.slityflag = False

        if init_values.get("slitx") is not None:
            self.slitxflag = True
            self.stepslitx = generate_stepfunc(init_values["slitx"], init_values["xbin"])

        if init_values.get("slity") is not None:
            self.slityflag = True
            self.stepslity = generate_stepfunc(init_values["slity"], init_values["ybin"])

        self.stepresx = generate_stepfunc(init_values["res"], init_values["xbin"])
        self.stepresy = generate_stepfunc(init_values["res"], init_values["ybin"])
        self.resx = init_values["res"]
        self.resy = init_values["res"]
        # x bin size; generate_after uses it to crop the 1D intensity to the camera.
        self.xbin = init_values["xbin"]
        self.mt = init_values["mt"]
        self.lam = init_values["lam"]
        self.k_l = 2*np.pi/init_values["lam"]
        self.plam = init_values["plam"]
        self.L = init_values["L"]
        self.d = init_values["d"]
        self.convmode = init_values["convmode"]
        self.camsize = init_values["camsize"]
        self.x = init_values["x"]

        for Apparatus in self.apts_dict.values():
            if type(Apparatus) == RectGrating2D:
                self.px, self.py = Apparatus.px, Apparatus.py
            elif type(Apparatus) != Sample:
                self.p = Apparatus.p
                break

        # A source grating replaces any single-slit kernel along x.
        if init_values.get("source_period") is not None:
            self.slitxflag = True
            self.stepslitx = generate_source_grating(init_values["source_period"], init_values["duty_cycle"],
                                                     init_values["xbin"],
                                                     n_periods=int(3*self.camsize/init_values["source_period"]) + 1)

        # TODO: x0 should probably be a grating attribute to allow multiple phase stepping.
        self.x0 = init_values["x0"] if type(init_values["x0"]) == np.ndarray else np.array([init_values["x0"]])
        self.batches = init_values["batches"]
        if init_values.get("y") is not None:
            self.y = init_values["y"]
        self.Ioffset = None
        if init_values.get("Ioffset") is not None:
            self.Ioffset = init_values["Ioffset"]

    def get_apts(self):
        """Return the apparatus dict."""
        return self.apts_dict

    def get_values(self):
        """Return the instance attribute dict (``vars(self)``)."""
        return vars(self)

    def get_value(self, key):
        """Return a single instance attribute by name."""
        # NOTE: vars() will not see attributes if the class is moved to __slots__.
        return vars(self)[key]

    def propagate_to(self, key):
        """Propagate the Fourier-space state ``self.pos`` through ``apts_dict[key]``."""
        Apparatus = self.get_apts()[key]

        pos, mt, k_l = self.pos, self.mt, self.k_l

        if Apparatus.is2d and not Apparatus.isFork:
            px, py, phi, L1, L2 = Apparatus.px, Apparatus.py, Apparatus.phi, Apparatus.L1, Apparatus.L2
            self.pos = Apparatus.propagate(pos, px, py, phi, mt, k_l, L1, L2)
        else:
            p, phi, L1, L2 = Apparatus.p, Apparatus.phi, Apparatus.L1, Apparatus.L2
            self.pos = Apparatus.propagate(pos, p, phi, mt, k_l, L1, L2, self.x0)

    def generate_after(self, key):
        """Compute the camera intensity after ``apts_dict[key]`` and, in 1D, fit it.

        Sets ``self.raw`` (unconvolved intensity) and ``self.intensity`` (after
        the slit/source and resolution convolutions). It also rescales
        ``self.x`` (and ``self.y`` in 2D) to camera coordinates [0, camsize].

        The 2D branch returns at that point. The 1D branch additionally sets
        ``self.FT``/``self.absFT`` and ``self.fitparams``. fitparams holds the
        cosine fit (A, B, p, phi) along the last axis. It also sets
        ``self.contrast`` = |B/A|, with ``Ioffset`` added to A when given.
        """
        Apparatus = self.get_apts()[key]
        L, d, pos, camsize = self.L, self.d, self.pos, self.camsize

        batches = self.batches
        # Trim x so its length divides evenly into `batches` chunks.
        self.x = np.linspace(self.x[0], self.x[-1], len(self.x) - len(self.x) % batches)

        xbatches = np.array([self.x[i*len(self.x)//batches:(i+1)*len(self.x)//batches] for i in range(batches)])

        if Apparatus.is2d:
            self.y = np.linspace(self.y[0], self.y[-1], len(self.y) - len(self.y) % batches)
            ybatches = np.array([self.y[i*len(self.y)//batches:(i+1)*len(self.y)//batches] for i in range(batches)])

            self.raw = np.array([[Apparatus.get_rawintensity(pos, xi, yi, self.camsize) for yi in ybatches]
                                 for xi in xbatches])
            print(self.raw.shape, "after parallel")

            if len(self.plam.flatten()) > 1:
                self.raw = np.moveaxis(self.raw, self.raw.shape.index(len(self.plam.flatten())), 0)

            self.raw = np.swapaxes(np.moveaxis(self.raw, self.raw.shape.index(len(d.squeeze())), 0), -2, -3)
            print(self.raw.shape, "before reshape")

            self.raw = np.reshape(self.raw, (*self.raw.shape[:-4], len(self.x), len(self.y)))
            print(self.raw.shape, "after reshape")

            # Shape the x resolution kernel as (1, ..., 1, n, 1) so it acts on axis -2.
            # NOTE: this mutates self.stepresx (and self.stepslitx below), so a second
            # call of generate_after for a 2D apparatus adds another trailing axis.
            for _ in range(self.raw.ndim - self.stepresx.ndim - 1):
                self.stepresx = np.expand_dims(self.stepresx, 0)

            self.stepresx = np.expand_dims(self.stepresx, -1)

            self.raw, self.stepresy = force_equal_dims(self.raw, self.stepresy)

            # Optional slit/source convolutions: y along axis -1, then x along axis -2.
            # Each is applied on top of the previous result; with neither, intensity == raw.
            if self.slityflag:
                self.raw, self.stepslity = force_equal_dims(self.raw, self.stepslity)
                self.intensity = oaconvolve(self.raw, self.stepslity, mode=self.convmode, axes=-1)
            else:
                self.intensity = self.raw

            if self.slitxflag:
                for _ in range(self.raw.ndim - self.stepslitx.ndim - 1):
                    self.stepslitx = np.expand_dims(self.stepslitx, 0)
                self.stepslitx = np.expand_dims(self.stepslitx, -1)

                self.intensity = oaconvolve(self.intensity, self.stepslitx, mode=self.convmode, axes=-2)

            self.intensity = oaconvolve(oaconvolve(self.intensity, self.stepresy, mode=self.convmode, axes=-1),
                                        self.stepresx, mode=self.convmode, axes=-2)
            self.x, self.y = np.linspace(0, camsize, self.intensity.shape[-2]), np.linspace(0, camsize, self.intensity.shape[-1])
            return

        # ---- 1D apparatus ----
        p = self.p
        self.raw = np.array(jb.Parallel(n_jobs=-1, prefer="threads")(
            jb.delayed(Apparatus.get_rawintensity)(pos, xi, camsize) for xi in xbatches))

        print(self.raw.shape, "after parallel")
        # Stitch the x batches back into one contiguous last axis.
        self.raw = np.moveaxis(self.raw, 0, -2)
        self.raw = np.reshape(self.raw, (*self.raw.shape[:-2], self.raw.shape[-2] * self.raw.shape[-1]))
        print(self.raw.shape, "after reshape")

        # Multiple wavelengths: weight by plam (presumably the spectral weights) and sum that axis.
        if len(self.plam.flatten()) > 1:
            self.raw, self.plam = force_broadcast(self.raw, self.plam)
            self.raw = np.sum(self.raw * self.plam, axis=self.raw.shape.index(len(self.plam.squeeze())))

        if self.slitxflag:
            self.raw, self.stepslitx = force_equal_dims(self.raw, self.stepslitx)

        self.raw, self.stepresx = force_equal_dims(self.raw, self.stepresx)
        print(self.raw.shape, "before convolution")

        if self.slitxflag:
            self.intensity = oaconvolve(oaconvolve(self.raw, self.stepslitx, mode=self.convmode, axes=-1),
                                        self.stepresx, mode=self.convmode, axes=-1)
        else:
            self.intensity = oaconvolve(self.raw, self.stepresx, mode=self.convmode, axes=-1)

        # Crop the last axis to the number of camera bins (camsize / xbin).
        # NOTE(review): this keeps the leading samples; confirm a centred crop is
        # not intended when convmode produces a longer ("full") output.
        n_cam_bins = int(camsize / self.xbin)
        if self.intensity.shape[-1] > n_cam_bins:
            self.intensity = self.intensity[..., :n_cam_bins]

        self.x = np.linspace(0, camsize, self.intensity.shape[-1])

        self.FT = fft(self.intensity, axis=-1)
        self.absFT = abs(self.FT)

        print("Finding contrast")

        if self.intensity.ndim < 3:
            if self.intensity.ndim > 1:
                # One fringe profile per grating separation d.
                fitparams = [best_fit_moire_period(cosine_func, self.x, self.intensity[i], p, L, di)
                             for i, di in enumerate(d)]
            else:
                fitparams = best_fit_moire_period(cosine_func, self.x, self.intensity, p, L, d)

            self.fitparams = np.array(fitparams)

        else:
            # TODO: replace the nested loops (e.g. np.ndindex or recursion).
            fitparams = np.empty((*self.intensity.shape[:-1], 4))

            for i in range(self.intensity.shape[0]):
                clear_output(wait=True)
                print("Fitting")
                print("%.2f %% done" % (i/self.intensity.shape[0] * 1e2))
                for j in range(self.intensity.shape[1]):
                    if d.squeeze().ndim > 1:
                        params = best_fit_moire_period(cosine_func, self.x, self.intensity[i, j], p, L, d[i, j])
                    else:
                        params = best_fit_moire_period(cosine_func, self.x, self.intensity[i, j], p, L, d[i])

                    fitparams[i, j] = params

            self.fitparams = fitparams

        if self.Ioffset is not None:
            # Add the intensity offset to the fitted mean A before forming B/A.
            self.Ioffset, fitparams_offset = force_broadcast(self.Ioffset, self.fitparams[..., 0])
            fitparams_offset = fitparams_offset + self.Ioffset
            fitparams_offset, unmodified_B = force_broadcast(fitparams_offset, self.fitparams[..., 1])
            self.contrast = np.abs(unmodified_B/fitparams_offset)

        else:
            self.contrast = np.abs(self.fitparams[..., 1]/self.fitparams[..., 0])
                    
                    
# #     def cont_density(self):
# #         pos,mt,k_l = self.pos, self.mt, self.k_l
# #            need to replicate 2PGMI simulations density plots
        

    
class RectGrating(PGMI):
    """1D binary (rectangular) phase grating acting on a Fourier-space state.

    The state is ``pos = [k, a]``: wave vectors ``k`` and complex amplitudes
    ``a``. ``propagate`` combines each component with the grating's
    diffraction orders, scales by the magnification ``self.M`` and applies a
    propagation phase. ``get_rawintensity`` evaluates |psi(x)|^2 on a batch of
    x positions.

    NOTE(review): subclasses PGMI but does not call PGMI.__init__. ``isFork``
    is never set; PGMI.propagate_to only reads it when ``is2d`` is True.
    """
   
    def __init__(self,init_values):
        """Read grating parameters from ``init_values``.

        Keys: "p" (period used for the diffraction orders), "phi" (phase scale),
        "L1" and "L2", which are broadcast against each other with
        force_broadcast; presumably the distances around the grating (TODO
        confirm). Optional key: "phase_offset".
        ``self.M = 1/(1 + L2/L1)`` is the scale factor applied in propagate.

        NOTE(review): ``self.phase`` (from "phase_offset") is stored but not
        used; propagate calls rect_spectrum without a phase_offset.
        """
        self.is2d = False
        self.p = init_values["p"]
        self.phi = init_values["phi"]
         
        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
        self.M = 1/(1 + self.L2/self.L1).squeeze()
        if init_values.get("phase_offset") is not None:
            self.phase = init_values["phase_offset"]
        
    
    
    def propagate(self,pos,p,phi,mt,k_l,L1,L2, x0):
        """Apply this grating to ``pos = [k, a]`` and propagate over L2.

        The grating spectrum (orders -mt..mt) comes from rect_spectrum on a
        10-period, 1001-point grid shifted by ``x0``. Every incoming component
        is combined with every order (k + kg, a * ag) and the order axis is
        flattened into the last axis. The result is scaled by ``self.M`` and
        multiplied by exp(-i*B*M*L2), with B = -k**2/(2*k_l).

        Args:
            pos: [k, a] arrays of the incoming state.
            p, phi, mt, k_l, L2: period, phase scale, max order, wave number, distance.
            L1: unused here (already contained in self.M).
            x0: lateral grating offset(s); also stored as ``self.x0``.

        Returns:
            [nextk, nexta].

        NOTE: ``self.M`` is overwritten with its broadcast version.
        """
        def B(k_l,k):
            return -k**2/(2*k_l)
        self.x0 = x0
        # TODO: change this when x0 becomes a grating attribute rather than a PGMI
        # attribute; the spectrum should probably also be computed in __init__.

        xsize = 10*p
        xft = np.linspace(0,xsize,1001)
        phi,x0, xft = force_broadcast(phi,x0,xft)
        # Spectrum of the grating: order wave vectors kg and amplitudes ag.
        kg, ag = rect_spectrum(xft - x0,xsize,p, phi, mt).squeeze()

        # Current state in Fourier space.
        k = pos[0]
        a = pos[1]


        # If some axis of k is a multiple of the number of orders, the incoming
        # state presumably already carries an order axis; keep the new order
        # axis separate (length allowed twice) and flatten the last two axes below.
        if np.any(np.array(k.shape) % kg.shape[-1] == 0):
            mt_ind = -2
            k, kg = force_broadcast(k,kg, nonunique_lengths = [kg.shape[-1]], nonunique_occurences = [2])    
            a, ag = force_broadcast(a,ag, nonunique_lengths = [ag.shape[-1]], nonunique_occurences = [2])    
            
        else:
            mt_ind = -1
            k, kg, a, ag = force_broadcast(k,kg,a,ag)    
            

        k = k + kg
        a = a * ag
        # Not written as += / *=: the right-hand side broadcasts to a larger shape.
    
        # Flatten the trailing order axes so the last dimension grows without adding dimensions.
        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))


        self.M, L2, k_l, k, a = force_broadcast(self.M, L2, k_l, k, a )
        
        
        # Scale by the magnification factor and apply the propagation phase.
        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)

        return [nextk,nexta]
    
    def get_rawintensity(self,pos,xarr,camsize):
        """Return |psi(x)|^2 for a batch of positions ``xarr``.

        Prints progress (first x of the batch as % of ``camsize``), then calls
        clear_output(wait=True) so the next output replaces it. k, a and xarr
        are aligned with force_broadcast and their last two axes swapped before
        get_modsquared_psi sums over the last axis.
        """
        k = pos[0]
        a = pos[1]
        
        
        print("%.2f %% done" % (xarr[...,0]/camsize*100))
        clear_output(wait = True)
        
        # TODO(review): confirm the axis swap is necessary.
        a,k,xarr = [np.swapaxes(arr,-1,-2) for arr in force_broadcast(a,k,xarr)]
        
        return get_modsquared_psi(k,a,xarr)    
    

class RectGrating2D(PGMI):
    """2D binary phase grating, profile sign(cos(2*pi*x/px) + cos(2*pi*y/py) - 1).

    The state is ``pos = [kx, ky, a]``. ``propagate`` combines the state with
    the grating's 2D diffraction orders, scales by ``self.M`` and applies a
    propagation phase. ``get_rawintensity`` evaluates |psi(x, y)|^2 on a batch
    of (x, y) positions.

    NOTE(review): subclasses PGMI but does not call PGMI.__init__.
    """

    def __init__(self,init_values):
        """Read "px", "py", "phi", "L1", "L2" from ``init_values``.

        L1 and L2 are broadcast against each other with force_broadcast.
        ``self.M = 1/(1 + L2/L1)`` is the scale factor applied in propagate.
        """
        self.px, self.py = init_values["px"],init_values["py"]
        self.is2d = True
        self.isFork = False
        self.phi = init_values["phi"]
        self.L1 = init_values["L1"]
        self.L2 = init_values["L2"]
        self.L2, self.L1 = force_broadcast(self.L2,self.L1)
        
        self.M = 1/(1 + self.L2/self.L1)

    
    def propagate(self,pos,px, py, phi,mt,k_l,L1,L2):
        """Apply this grating to ``pos = [kx, ky, a]`` and propagate over L2.

        The order grid comes from rect_spectrum2d on a 10-period, 1001x1001
        grid. Each branch adds the order wave vectors, multiplies the
        amplitudes, scales kx/ky by ``self.M`` and multiplies by
        exp(-i*B*M*L2), with B = -(kx**2 + ky**2)/(2*k_l). L1 is unused here.

        Returns:
            [nextkx, nextky, nexta].

        NOTE: only works for (at most) two gratings: the first branch handles a
        state that already carries an order axis, the second an incoming state
        without one. The nonunique_occurences counts (2 / 4) are hard-coded
        accordingly. ``self.M`` is overwritten with its broadcast version.
        """
        # Will currently only work for two gratings; hard-coded due to time constraints.
        def B(k_l,kx,ky):
            return -(kx**2 + ky**2)/(2*k_l)
        
        xlen, ylen = 10*px, 10*py
        x, y =  np.linspace(0,xlen,1001), np.linspace(0,ylen,1001)
        phi, x, y = force_broadcast(phi,x,y, nonunique_lengths= [x.shape[-1]], nonunique_occurences= [2])
        
        # Spectrum of the grating: order wave vectors and flattened amplitudes.
        kgx, kgy, ag = rect_spectrum2d(x,xlen,y,ylen,px,py,phi,mt)
        
        
        # rect_spectrum2d returns the amplitudes flattened; restore a square order grid.
        ag = ag.reshape((kgx.shape[-1], kgy.shape[-1]))
        
        
        

        # Current state in Fourier space.
        kx, ky, a = pos


        # Some axis of kx is a multiple of the order count: presumably the state
        # already went through a grating (second grating case).
        if np.any(np.array(kx.shape) % kgx.shape[-1] == 0):
            kx, kgx = force_broadcast(kx, kgx, nonunique_lengths= [kgx.shape[-1]], nonunique_occurences= [2])
            ky, kgy = force_broadcast(ky, kgy, nonunique_lengths= [kgy.shape[-1]], nonunique_occurences= [2])
            
            
            a, ag = force_broadcast(a, ag, nonunique_lengths= [ag.shape[-1]], nonunique_occurences= [4])
            kx = kx + kgx
            ky = ky + kgy
            a = a * ag
            a = np.swapaxes(a, -2,-3)

            self.M, L2, kx, ky, k_l = force_broadcast(self.M, L2, kx, ky, k_l, nonunique_lengths= [kx.shape[-1],ky.shape[-1]],\
                                        nonunique_occurences = [4,4])
            
            nextkx = kx*self.M
            nextky = ky*self.M

            self.M, L2, a, k_l = force_broadcast(self.M, L2, a, k_l, nonunique_lengths= [a.shape[-1]],\
                                                    nonunique_occurences = [4])

            nexta = a*np.exp(-1j*B(k_l,kx,ky)*self.M*L2) 
            
                        
            
        else:
            # First grating: the incoming state has no order axis yet.
            
            kx, kgx = force_broadcast(kx, kgx)
            ky, kgy = force_broadcast(ky, kgy)
                    
            a, ag, = force_broadcast(a, ag, nonunique_lengths= [ag.shape[-1]], nonunique_occurences= [2])
            kx = kx + kgx
            ky = ky + kgy
            a = a * ag
            kx,ky = force_broadcast(kx,ky, nonunique_lengths= [kgx.shape[-1]], nonunique_occurences = [2])

            
            self.M, L2, kx,ky, k_l = force_broadcast(self.M, L2, kx,ky, k_l, nonunique_lengths= [kx.shape[-1], ky.shape[-1]],\
                                        nonunique_occurences = [2,2])
            nextkx = kx*self.M
            nextky = ky*self.M

            self.M, L2, a, k_l = force_broadcast(self.M, L2, a, k_l, nonunique_lengths= [a.shape[-1]],\
                                                    nonunique_occurences = [2])

            nexta = a*np.exp(-1j*B(k_l,kx,ky)*self.M*L2) 
        

        return [nextkx,nextky, nexta]
    
    def get_rawintensity(self,pos,xarr,yarr,camsize):
        """Return |psi(x, y)|^2 for a batch of positions.

        psi is the sum of a*exp(-i*(kx*x + ky*y)) over the last four axes after
        aligning xarr, yarr, kx and ky with force_broadcast. Prints progress
        (first x of the batch as % of ``camsize``), then calls
        clear_output(wait=True).
        """
        kx = pos[0]
        ky = pos[1]
        a = pos[2]
        
        
        print("%.2f %% done" % (xarr[...,0]/camsize*100))        
        clear_output(wait = True)
        xarr,yarr,kx,ky = force_broadcast(xarr,yarr,kx,ky,\
                        nonunique_lengths= [xarr.shape[-1], kx.shape[-1], ky.shape[-1]], nonunique_occurences=[2, 4, 4])
        
        kxarr, kyarr = kx*xarr, ky*yarr
        
        # Sum over the four trailing axes left by the two gratings' order grids
        # (presumably; TODO confirm the axis bookkeeping).
        psi = np.sum(a*np.exp(-1j*(kxarr + kyarr)), axis = (-1,-2,-3,-4))
        return np.abs(psi)*np.abs(psi) 
    
    

    
class GratingFromProfile(PGMI):
    """1-D grating whose transmission is defined by a sampled ``profile``.

    The profile is passed to ``rect_spectrum`` (as ``image_profile``) to obtain the
    grating's Fourier components, which are then applied to the incoming wave.
    """

    def __init__(self,init_values, profile_size,  profile):
        """Store grating parameters and the sampled profile.

        Parameters
        ----------
        init_values : mapping
            Must provide ``"p"``, ``"phi"``, ``"L2"`` and ``"L1"``.
        profile_size : float
            Extent covered by ``profile``; the sample grid used in ``propagate`` is
            ``np.linspace(0, profile_size, len(profile))``.
        profile : array-like
            Sampled grating profile.
        """
        self.is2d = False
        self.p = init_values["p"]
        self.phi = init_values["phi"]

        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
        # Magnification factor M = 1 / (1 + L2/L1).
        self.M = 1/(1 + self.L2/self.L1).squeeze()

        self.profile = profile
        self.profile_size = profile_size

    def propagate(self,pos,p,phi,mt,k_l,L1,L2,x0):
        """Apply the grating spectrum to ``pos`` and propagate over ``L2``.

        Parameters
        ----------
        pos : sequence
            ``[k, a]`` -- current wavevectors and complex amplitudes.
        p, phi, mt
            Forwarded to ``rect_spectrum`` (presumably period, phase and maximum
            diffraction order -- confirm against ``rect_spectrum``).
        k_l
            Used in ``B = -k**2 / (2*k_l)`` (presumably the longitudinal wavenumber).
        L1
            Unused here; kept for interface compatibility with other elements.
        L2
            Distance used in the propagation phase factor.
        x0
            Offset subtracted from the profile sample grid.

        Returns
        -------
        list
            ``[k * M, a * exp(-1j * B(k_l, k) * M * L2)]``.
        """
        def B(k_l,k):
            return -k**2/(2*k_l)

        # Current position in Fourier space.
        k = pos[0]
        a = pos[1]

        xft = np.linspace(0,self.profile_size,len(self.profile))
        phi,x0, xft = force_broadcast(phi,x0,xft)

        kg, ag = rect_spectrum(xft - x0,self.profile_size, p, phi, mt, image_profile=self.profile).squeeze()

        phi,kg,ag = force_broadcast(phi,kg,ag)

        if np.any((np.array(k.shape) % kg.shape[-1] == 0) | (np.array(k.shape) % (2*mt + 1) == 0)):
            # Incoming arrays already carry an order axis: keep the grating orders
            # on a separate trailing axis and merge the last two axes below.
            mt_ind = -2
            k, kg = force_broadcast(k,kg, nonunique_lengths = [kg.shape[-1]], nonunique_occurences = [2])
            a, ag = force_broadcast(a,ag, nonunique_lengths = [ag.shape[-1]], nonunique_occurences = [2])
        else:
            mt_ind = -1
            k, a, kg, ag = force_broadcast(k,a,kg,ag)

        # Out-of-place on purpose: broadcasting can change the shapes, which
        # in-place += / *= cannot do.
        k = k + kg
        a = a * ag

        # Merge the trailing order axes so the last dimension grows in length
        # instead of the array gaining dimensions.
        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))

        # Debug shape prints removed: they spammed output on every call and the
        # first one raised AttributeError when L2 / k_l were plain scalars.
        # NOTE(review): self.M is rebound to its broadcast form on every call.
        self.M, L2, k, a, k_l = force_broadcast(self.M, L2, k, a, k_l)

        # Scale by the magnification factor.
        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)

        return [nextk,nexta]

    def get_rawintensity(self,pos,xarr,camsize):
        """Return ``get_modsquared_psi`` of the propagated wave at ``xarr``.

        ``pos`` is ``[k, a]``.
        After broadcasting, arrays with more than one dimension have their
        last two axes swapped before evaluation.
        """
        k = pos[0]
        a = pos[1]
        print("%.2f %% done" % (xarr[...,0]/camsize*100))
        clear_output(wait = True)

        # NOTE(review): original author questioned whether this swap is necessary.
        k,a,xarr = [np.swapaxes(arr,-1,-2) if arr.ndim > 1 else arr for arr in force_broadcast(k,a,xarr)]

        return get_modsquared_psi(k,a,xarr)
    
    
class ForkGrating(PGMI):
    """2-D fork grating: a rectangular grating whose diffraction orders also
    carry an azimuthal phase ``m * OAM * arctan2(y, x)`` at the camera.
    """

    def __init__(self,init_values):
        """Read grating parameters from ``init_values``.

        Required keys: ``"p"``, ``"phi"``, ``"L2"``, ``"L1"``.
        Optional keys: ``"OAM"`` (defaults to 0) and ``"mt"``.

        NOTE(review): ``self.mt`` is only set when ``"mt"`` is supplied, but
        ``get_rawintensity`` reads it unconditionally -- confirm callers always pass it.
        """
        self.is2d = True
        self.isFork = True
        self.p = init_values["p"]
        self.phi = init_values["phi"]

        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])
        distance_ratio = self.L2/self.L1
        # Magnification factor M = 1 / (1 + L2/L1).
        self.M = 1/(1 + distance_ratio).squeeze()

        oam_value = init_values.get("OAM")
        self.OAM = 0 if oam_value is None else oam_value

        max_order = init_values.get("mt")
        if max_order is not None:
            self.mt = max_order

    def propagate(self,pos,p,phi,mt,k_l,L1,L2, x0):
        """Apply the grating spectrum to ``pos = [k, a]`` and propagate over ``L2``.

        The spectrum comes from ``rect_spectrum`` evaluated on a 1001-point grid
        spanning ten periods (``10*p``), shifted by ``x0``.
        ``L1`` is unused here.
        Also stores ``x0`` on the instance.

        Returns ``[k * M, a * exp(-1j * B(k_l, k) * M * L2)]`` with
        ``B(k_l, k) = -k**2 / (2*k_l)``.
        """
        def quad_phase(long_k, wavevec):
            return -wavevec**2/(2*long_k)

        # TODO (from original): change once x0 is a grating attribute rather than
        # a PGMI attribute; the spectrum should probably be built in __init__.
        self.x0 = x0

        # Grating spectrum over ten periods.
        window = 10*p
        sample_x = np.linspace(0, window, 1001)
        phi, x0, sample_x = force_broadcast(phi, x0, sample_x)
        grating_k, grating_a = rect_spectrum(sample_x - x0, window, p, phi, mt).squeeze()

        # Current position in Fourier space.
        wavevec, amp = pos[0], pos[1]

        if np.any(np.array(wavevec.shape) % grating_k.shape[-1] == 0):
            # An existing axis matches the order count: keep grating orders on a
            # separate trailing axis, merged with the previous one below.
            flatten_from = -2
            wavevec, grating_k = force_broadcast(wavevec, grating_k, nonunique_lengths = [grating_k.shape[-1]], nonunique_occurences = [2])
            amp, grating_a = force_broadcast(amp, grating_a, nonunique_lengths = [grating_a.shape[-1]], nonunique_occurences = [2])
        else:
            flatten_from = -1
            wavevec, grating_k, amp, grating_a = force_broadcast(wavevec, grating_k, amp, grating_a)

        # Out-of-place: broadcasting may change shapes.
        wavevec = wavevec + grating_k
        amp = amp * grating_a

        def merge_trailing(arr):
            # Collapse axes from `flatten_from` onward into one longer last axis.
            return np.reshape(arr, (*arr.shape[:flatten_from], np.prod(arr.shape[flatten_from:])))

        wavevec, amp = merge_trailing(wavevec), merge_trailing(amp)

        self.M, L2, wavevec, amp, k_l = force_broadcast(self.M, L2, wavevec, amp, k_l)

        # Scale by the magnification factor and apply the propagation phase.
        next_k = wavevec*self.M
        next_a = amp*np.exp(-1j*quad_phase(k_l, wavevec)*self.M*L2)
        return [next_k, next_a]

    def get_rawintensity(self,pos,xarr,yarr,camsize):
        """Return ``|psi|**2`` of the fork-grating field on the camera grid.

        The flattened last axis of ``pos = [k, a]`` is split into
        ``(..., n, 2*mt + 1)`` so each diffraction order ``m`` in ``[-mt, mt]``
        receives the extra phase ``m * OAM * arctan2(y, x)``.
        The field is then summed over the last two axes.
        """
        k_flat, a_flat = pos[0], pos[1]

        # NOTE(review): the +50 offset presumably assumes xarr spans
        # [-camsize/2, camsize/2] -- confirm.
        percent = xarr[...,0]/camsize*100  + 50
        print("%.2f %% done" % percent)
        clear_output(wait = True)

        n_orders = 2*self.mt + 1

        def split_orders(arr):
            return np.reshape(arr, (*arr.shape[:-1], arr.shape[-1] // n_orders, n_orders))

        k, a = split_orders(k_flat), split_orders(a_flat)

        oam_orders = np.arange(-self.mt,self.mt+1,1) * self.OAM
        phase = np.arctan2(yarr[None,...],xarr[...,None])

        same_length = len(xarr) == len(yarr)
        n_x = len(xarr)

        def bcast(*arrays):
            # Equal x/y lengths are declared as a repeated length to force_broadcast.
            if same_length:
                return force_broadcast(*arrays, nonunique_lengths = [n_x], nonunique_occurences=[2])
            return force_broadcast(*arrays)

        phase, oam_orders = bcast(phase, oam_orders)
        fork_phase = phase * oam_orders
        k, fork_phase = bcast(k, fork_phase)
        a, fork_phase = bcast(a, fork_phase)

        # Move the order axis (first axis of length 2*mt+1) to position -2.
        order_axis_k = k.shape.index(n_orders)
        order_axis_a = a.shape.index(n_orders)
        k = np.moveaxis(k, order_axis_k, -2)
        a = np.moveaxis(a, order_axis_a, -2)
        fork_phase = np.expand_dims(fork_phase, -2)

        x_aligned = force_equal_dims(xarr, k)[0]
        x_aligned = np.moveaxis(x_aligned, 0, k.shape.index(1))

        field = a * np.exp(-1j * k * x_aligned) * np.exp(-1j*fork_phase)
        psi = np.sum(field, axis = (-1,-2))

        magnitude = np.abs(psi)
        return magnitude * magnitude
        
        
        
    
class Sample(PGMI):
    """1-D sample described by a sampled profile whose spectrum is produced by
    ``sample_spectrum``.
    """

    def __init__(self,init_values, profile_size, profile):
        """Store sample parameters and the sampled profile.

        Parameters
        ----------
        init_values : mapping
            Must provide ``"p"``, ``"phi"``, ``"spectrum_spacing"``, ``"mt"``,
            ``"L2"`` and ``"L1"``.
        profile_size : float
            Extent covered by ``profile``; forwarded to ``sample_spectrum``.
        profile : array-like
            Sampled sample profile.
            It is kept unmodified; ``propagate`` broadcasts a local copy.
        """
        self.is2d = False
        self.p = init_values["p"]
        self.phi = init_values["phi"]
        self.spectrum_spacing = init_values["spectrum_spacing"]
        self.mt = init_values["mt"]
        self.L2, self.L1 = force_broadcast(init_values["L2"], init_values["L1"])

        # Magnification factor M = 1 / (1 + L2/L1).
        self.M = 1/(1 + self.L2/self.L1).squeeze()

        self.profile = profile
        self.profile_size = profile_size

    def propagate(self,pos,p,phi,mt,k_l,L1,L2,x0):
        """Apply the sample spectrum to ``pos = [k, a]`` and propagate over ``L2``.

        The ``mt`` argument and ``L1`` are unused; the order count comes from
        ``self.mt``.
        ``x0`` only takes part in broadcasting.

        Returns ``[k * M, a * exp(-1j * B(k_l, k) * M * L2)]`` with
        ``B(k_l, k) = -k**2 / (2*k_l)``.
        """
        def B(k_l,k):
            return -k**2/(2*k_l)

        # Current position in Fourier space.
        k = pos[0]
        a = pos[1]

        # Broadcast into a local: rebinding self.profile here made every call
        # re-broadcast an already-broadcast profile (hidden state across calls).
        phi, x0, profile = force_broadcast(phi, x0, self.profile)

        k_s, a_s = sample_spectrum(self.profile_size, p, self.mt, self.spectrum_spacing, profile, phi).squeeze()

        mt_ind = -2

        k, a, k_s, a_s = force_broadcast(k,a,k_s, a_s)

        if k_s.squeeze().ndim > 1:
            k_s, a_s = np.moveaxis(k_s, -2,-3), np.moveaxis(a_s, -2,-3)
            k, a = np.moveaxis(k, -2,-3), np.moveaxis(a, -2,-3)

        # Out-of-place on purpose: broadcasting can change the shapes.
        k = k + k_s
        a = a * a_s

        if isinstance(self.p, np.ndarray):
            # Move the axis whose length matches the number of periods to axis 1.
            k = np.moveaxis(k, k.shape.index(len(self.p)), 1)
            a = np.moveaxis(a, a.shape.index(len(self.p)), 1)

        # np.ndim guard: len() on a scalar k_l raised TypeError.
        if isinstance(self.phi, np.ndarray) and np.ndim(k_l) > 0 and len(k_l) > 1:
            # NOTE(review): only k is swapped; the corresponding transform of `a`
            # was commented out by the author, so k/a pairing may differ -- confirm.
            k = np.swapaxes(k, 1, -2)
            # a = np.moveaxis(a, a.shape.index(len(self.phi)), 1)

        # Merge the last two axes so the last dimension grows in length instead
        # of the array gaining dimensions.
        k = np.reshape(k,(*k.shape[:mt_ind],np.prod(k.shape[mt_ind:])))
        a = np.reshape(a,(*a.shape[:mt_ind],np.prod(a.shape[mt_ind:])))

        # NOTE(review): self.M is rebound to its broadcast form on every call.
        self.M, L2, k_l, k, a = force_broadcast(self.M, L2, k_l, k, a)

        # Scale by the magnification factor and apply the propagation phase.
        nextk = k*self.M
        nexta = a*np.exp(-1j*B(k_l,k)*self.M*L2)

        return [nextk,nexta]

    def get_rawintensity(self,pos,xarr,camsize):
        """Return ``get_modsquared_psi`` of the propagated wave at ``xarr``.

        ``pos`` is ``[k, a]``.
        After broadcasting, arrays with more than one dimension have their
        last two axes swapped before evaluation.
        """
        k = pos[0]
        a = pos[1]
        print("%.2f %% done" % (xarr[...,0]/camsize*100))
        clear_output(wait = True)

        # NOTE(review): original author questioned whether this swap is necessary.
        a,k,xarr = [np.swapaxes(arr,-1,-2) if arr.ndim > 1 else arr for arr in force_broadcast(a,k,xarr)]

        return get_modsquared_psi(k,a,xarr)
    