In [None]:
# Create the output directories (results/ and results/plots/).  makedirs
# creates the intermediate 'results' directory too, and exist_ok=True makes
# this idempotent without a race-prone exists-then-create check.
os.makedirs('results/plots', exist_ok=True)

In [None]:
# Object IDs to process, read as strings (presumably one ID per line).
# ndmin=1 keeps the result 1-D even when the file holds a single ID; without
# it np.loadtxt returns a 0-d array, which cannot be iterated.
interesting = np.loadtxt('golden_original.txt', dtype=str, ndmin=1)

In [None]:
def get_data(objname, plotdir=None):
    """Fetch NSC DR2 measurements for one object and estimate its period.

    For each filter in u, g, r, i, z, measurements with fwhm <= 4.0 are kept.
    Filters with fewer than 25 such measurements are skipped.  For every
    remaining filter a Lomb-Scargle period is estimated with get_ls_period,
    which also saves its periodogram plot.  The per-filter periods are then
    averaged.

    Parameters
    ----------
    objname : str
        NSC object id; it is interpolated into the SQL query text.
    plotdir : str, optional
        Directory for the periodogram plots.  Defaults to the module-level
        ``outdir`` variable, which must then be defined elsewhere.

    Returns
    -------
    crv : list of ndarray
        One (N, 3) array per kept filter with columns
        (mjd, mag_auto, magerr_auto), sorted by mjd.
    period : float
        Mean of the per-filter Lomb-Scargle periods, or 0 if no filter had
        enough points.
    fltrs : list of str
        Filter names, aligned entry-for-entry with ``crv``.
    """
    if plotdir is None:
        # Preserve the original behaviour of reading the notebook-level
        # `outdir` global (NameError if it was never defined).
        plotdir = outdir
    # NOTE(review): objname is pasted into the SQL text unescaped.  That is
    # acceptable for IDs from a trusted local list; validate/quote it if IDs
    # ever come from an outside source.
    df = qc.query(sql="""SELECT meas.* 
                     FROM nsc_dr2.meas
                     WHERE objectid='{:s}'""".format(objname),
                  fmt='pandas',
                  profile='db01')
    good_seeing = df['fwhm'] <= 4.0  # same cut for every filter; compute once
    best_periods = []
    crv = []
    fltrs = []
    for f in ['u', 'g', 'r', 'i', 'z']:
        sel = (df['filter'] == f) & good_seeing
        t = df['mjd'][sel].values
        if len(t) < 25:
            continue  # too few points for a meaningful periodogram
        y = df['mag_auto'][sel].values
        dy = df['magerr_auto'][sel].values

        best_periods.append(get_ls_period(t, y, objname=objname + '_' + f,
                                          outdir=plotdir))
        lc = np.column_stack((t, y, dy))
        crv.append(lc[np.argsort(lc[:, 0])])
        fltrs.append(f)
    period = sum(best_periods) / len(best_periods) if best_periods else 0
    return crv, period, fltrs
        
def get_ls_period(t, y, min_freq=1./1., max_freq=1./0.1, objname='_', outdir='results'):
    """Estimate the dominant period of (t, y) with a Lomb-Scargle periodogram.

    The automatic frequency grid is limited to [min_freq, max_freq], in
    inverse units of ``t``.  With the defaults this covers periods from 0.1
    to 1 of those units.  The periodogram is drawn and saved through
    plot_periodogram using ``objname`` and ``outdir``.

    Returns the period at which the periodogram power peaks.
    """
    periodogram = stats.LombScargle(t, y)
    freqs, powers = periodogram.autopower(minimum_frequency=min_freq,
                                          maximum_frequency=max_freq)
    periods = 1. / freqs  # convert the frequency grid to periods
    peak_idx = np.argmax(powers)
    best = periods[peak_idx]

    plot_periodogram(periods, powers, best, objname=objname, outdir=outdir)
    return best

def plot_periodogram(period, power, best_period=None, objname='', ax=None, outdir='results'):
    """Plot power against period and save it as ``<outdir>/<objname>_periodogram.png``.

    Parameters
    ----------
    period, power : array_like
        Periodogram abscissa (labelled in days) and power values.
    best_period : float, optional
        If given, marked with a red vertical line and annotated in the corner.
    objname : str
        Used as the plot title and in the output file name.
    ax : matplotlib Axes, optional
        Axes to draw into.  When omitted, a new 10x7 figure is created,
        saved and then closed.  When given, the Axes' parent figure is saved
        but left open, because it belongs to the caller.
    outdir : str
        Directory the PNG file is written to.
    """
    created_here = ax is None
    if created_here:
        fig, ax = plt.subplots(figsize=(10, 7))
    else:
        # Previously `fig` was never bound on this path, so passing `ax`
        # raised NameError at savefig.
        fig = ax.figure

    ax.plot(period, power, lw=0.1)
    ax.set_xlabel('period (days)')
    ax.set_ylabel('relative power')
    ax.set_title(objname)

    if best_period is not None:
        ax.axvline(best_period, color='r')
        ax.text(0.03, 0.93, 'period = {:.3f} days'.format(best_period),
                transform=ax.transAxes, color='r')
    fig.savefig(os.path.join(outdir, '{}_periodogram.png'.format(objname)))
    if created_here:
        plt.close(fig)  # only release figures this function created

def get_tmps(fltrs):
    """Load the light-curve templates for every filter in ``fltrs``.

    Templates are the ``*<filter>.dat`` files in the directory returned by
    ``tempdir()``.  Each file is read with ``pd.read_csv(sep=' ')``, so its
    first line is treated as a header.  Each template is then padded with a
    (0, 0) row before its data and a (1, 0) row after it.

    Returns
    -------
    tmps : list of ndarray
        Per filter, an array of shape (n_templates, 501, 2).
    names : list of list of str
        Per filter, the template file paths, in the same order as ``tmps``.
    typs : list of list
        Per filter, the type inferred from each file name's length: 'RRab'
        for 8 characters, 'RRc' for 6, otherwise None.
    """
    templatedir = tempdir()
    tmps = []
    typs = []
    names = []
    for fltr in fltrs:
        # glob() order depends on the filesystem; sort so that a template
        # index stored in the fit parameters names the same file every run.
        templets = sorted(glob(templatedir + '/*{}.dat'.format(fltr)))
        # NOTE(review): assumes every template file yields 499 data rows so
        # that, with the two padding rows, it fills the fixed 501-row slot.
        tmp = np.zeros((len(templets), 501, 2))
        typ = []
        for i, path in enumerate(templets):
            data = np.array(pd.read_csv(path, sep=' '))
            # Pin the endpoints at phase 0 and 1 with value 0.
            tmp[i] = np.concatenate((np.array([[0, 0]]),
                                     data,
                                     np.array([[1, 0]])))
            # Type is inferred from the file-name length (adjust if the
            # template naming changes).  Unknown lengths map to None so that
            # `typ` stays aligned with `templets`.
            typ.append({8: 'RRab', 6: 'RRc'}.get(len(os.path.basename(path))))
        typs.append(typ)
        names.append(templets)
        tmps.append(tmp)
    return tmps, names, typs

def double_tmps(tmps):
    """Return copies of each filter's templates repeated over two cycles.

    Each entry (expected to be the 3-D arrays produced by get_tmps) is tiled
    twice along its second axis.  The phase column (index 0) of the second
    copy is then shifted by +1.  The input arrays are not modified.
    """
    doubled = []
    for per_filter in tmps:
        tiled = np.tile(per_filter, (2, 1))
        half = len(tiled[0]) // 2  # start of the repeated copy
        tiled[:, half:, 0] += 1
        doubled.append(tiled)
    return doubled

def get_pinit(crv, period):
    """Build the initial parameter tuple for tmpfit.

    Returns one (t0, amplitude, yoffset) guess per light curve and then
    ``period``.  Each guess is (0.0, peak-to-peak spread of column 1, 0.0).
    """
    guesses = tuple((0.0, max(lc[:, 1]) - min(lc[:, 1]), 0.0) for lc in crv)
    return guesses + (period,)

def update_pinit(pars, period):
    """Turn fitted parameters back into a tmpfit initial-guess tuple.

    ``pars`` is a 2-D array whose last column is dropped from every row.
    Each row's remaining values become one tuple, and ``period`` is appended
    as the final element.
    """
    return tuple(map(tuple, pars[:, :-1])) + (period,)

def RemoveOutliers(crv, tmps, pars, period):
    """Drop light-curve points that sit far from the fitted template.

    For light curve ``i``, template ``pars[i, -1]`` of ``tmps[i]`` is scaled
    by ``pars[i, 1]`` and offset by ``pars[i, 2]``.  The times are folded at
    ``period`` and shifted by ``pars[i, 0]``.  Points are kept only when their
    absolute residual is below 5 * ``utils.mad`` of the residuals
    (``utils.mad`` presumably a median-absolute-deviation helper; confirm).
    Returns the filtered light curves as a new list.
    """
    idx = pars[:, -1].astype(int)
    survivors = []
    for i, lc in enumerate(crv):
        xs = tmps[i][idx[i], :, 0]
        ys = tmps[i][idx[i], :, 1] * pars[i, 1] + pars[i, 2]
        model = interp1d(xs, ys)
        phase = (lc[:, 0] / period - pars[i, 0]) % 1
        resid = abs(lc[:, 1] - model(phase))
        threshold = utils.mad(resid) * 5
        survivors.append(lc[resid < threshold])
    return survivors

def double_period(crv, pars, period):
    """Fold each light curve at ``period`` and repeat it over two cycles.

    For light curve ``i`` the fitted offset ``pars[i, 2]`` is subtracted from
    column 1.  Column 0 is replaced by the phase ``(t / period - pars[i, 0]) % 1``.
    The rows are then duplicated with the copy's phase shifted by +1, and the
    result is sorted by phase.  The input arrays are not modified.
    """
    doubled = []
    for i, lc in enumerate(crv):
        folded = lc.copy()
        folded[:, 1] -= pars[i, 2]
        folded[:, 0] = (folded[:, 0] / period - pars[i, 0]) % 1

        both = np.vstack((folded, folded))
        both[len(folded):, 0] += 1  # second cycle
        doubled.append(both[both[:, 0].argsort()])
    return doubled

In [None]:
class tmpfitter:
    """Stateful template model whose ``model`` method is handed to curve_fit.

    Set ``fltr`` (filter index into ``tmps``) and ``n`` (template index
    within that filter) before fitting.
    """

    def __init__(self, tmps):
        self.tmps = tmps  # per-filter template arrays, indexed [fltr][n, point, column]
        self.fltr = 0     # filter selected for model()
        self.n = 0        # template selected for model()

    def model(self, t, t0, amplitude, yoffset):
        """Evaluate the selected template at phases ``t``.

        The template's column 1 is scaled by ``amplitude`` and offset by
        ``yoffset``.  The inputs ``t`` are shifted by ``t0`` and wrapped into
        [0, 1).  The result is obtained by linear interpolation (interp1d)
        against the template's column 0.
        """
        chosen = self.tmps[self.fltr]
        xs = chosen[self.n, :, 0]
        ys = chosen[self.n, :, 1] * amplitude + yoffset
        wrapped = (t - t0) % 1
        return interp1d(xs, ys)(wrapped)


def _trial_periods(center, w, steps):
    """Return up to ``steps`` positive trial periods in [center - w, center + w].

    The first entry is ``center``; later entries alternate above and below
    it, moving outward.  Non-positive periods are dropped.
    """
    n_left = int(steps / 2 + .5)  # the centre plus the periods below it
    n_right = steps - n_left      # periods above the centre
    # Descend from the centre directly.  The old flip(linspace(center - w,
    # center, n_left)) gave [center - w] when n_left == 1 (steps=1 or 2),
    # so the centre itself was never tried.
    below = np.linspace(center, center - w, n_left)
    above = np.flip(np.linspace(center + w, center, n_right, endpoint=False))
    grid = np.zeros(below.size + above.size)
    grid[0::2] = below
    grid[1::2] = above
    return grid[grid > 0]


def _fit_templates(fitter, f, n_templates, phase, mags, errs, p0, bounds):
    """Fit every template of filter ``f``; return (chi2, params, index) of the best.

    ``params`` is None, and chi2 stays at the 10**50 sentinel, when no template
    fit converged.
    """
    fitter.fltr = f
    best_x2, best_pars, best_idx = 10**50, None, 0
    for i in range(n_templates):
        fitter.n = i
        try:
            tpars, _ = curve_fit(fitter.model, phase, mags, bounds=bounds,
                                 sigma=errs, p0=p0, maxfev=500)
        except RuntimeError:
            continue  # no convergence within maxfev; try the next template
        x2 = np.sum((fitter.model(phase, *tpars) - mags)**2 / errs**2)
        if x2 < best_x2:
            best_x2, best_pars, best_idx = x2, tpars, i
    return best_x2, best_pars, best_idx


def tmpfit(crv, tmps, pinit, w=.1, steps=21, n=1):
    """Grid-search the period and pick the best-fitting template per filter.

    Trial periods come from ``_trial_periods(pinit[-1], w, steps)``.  For
    every trial period and filter ``f``, each template in ``tmps[f]`` is
    fitted with curve_fit for (t0, amplitude, yoffset), and the template with
    the lowest chi^2 is kept.  The period with the lowest chi^2 summed over
    filters wins.  A period is abandoned early once its running sum exceeds
    the best so far.

    Parameters
    ----------
    crv : list of ndarray
        Per-filter light curves with columns (time, mag, mag error).
    tmps : list of ndarray
        Per-filter template arrays of shape (n_templates, n_points, 2).
    pinit : tuple
        One (t0, amplitude, yoffset) initial guess per filter, followed by
        the central period (as built by get_pinit / update_pinit).  Guesses
        outside the fit bounds are clipped into them.
    w : float
        Half-width of the period search window.
    steps : int
        Number of trial periods.
    n : int
        Times are folded as ``time / period % n``.  NOTE(review):
        tmpfitter.model wraps ``(t - t0) % 1`` again, so n=2 currently yields
        the same model as n=1; confirm the intent.

    Returns
    -------
    pars : ndarray, shape (len(tmps), 4)
        Per filter (t0, amplitude, yoffset, template index) at the best
        period; all zeros if no period was accepted.
    minp : float
        Best trial period (0 if none was accepted).
    chi2 : float
        Summed chi^2 at ``minp`` divided by the total number of data points.
    """
    # Lower/upper bounds on (t0, amplitude, yoffset).
    bounds = ((-.5, 0, -50), (.5, 10, 50))
    lower = np.asarray(bounds[0], dtype=float)
    upper = np.asarray(bounds[1], dtype=float)
    fitter = tmpfitter(tmps)

    pars = np.zeros((len(tmps), 4))
    minsumx2 = 10**50
    minp = 0
    for p in _trial_periods(pinit[-1], w, steps):
        sumx2 = 0
        ppars = np.zeros((len(tmps), 4))
        for f in range(len(tmps)):
            phase = crv[f][:, 0] / p % n  # 1 for one period, 2 for two periods
            # scipy raises ValueError ("x0 is infeasible") for a start point
            # outside the bounds, e.g. an outlier-inflated amplitude guess > 10.
            p0 = np.clip(pinit[f], lower, upper)
            x2, tpars, idx = _fit_templates(fitter, f, len(tmps[f]), phase,
                                            crv[f][:, 1], crv[f][:, 2],
                                            p0, bounds)
            if tpars is not None:
                ppars[f, :-1] = tpars
                ppars[f, -1] = idx
            sumx2 += x2
            if sumx2 > minsumx2:
                break  # already worse than the best period; skip the other filters
        if sumx2 < minsumx2:
            minsumx2, minp, pars = sumx2, p, ppars
    npoints = sum(len(lc) for lc in crv)
    return pars, minp, minsumx2 / npoints