In [13]:
import os
import desispec.io
import fitsio

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from astropy.io import fits
from astropy import modeling

from scipy.signal import medfilt

import emcee
import corner

tile='68000'
date='20200314'

In [14]:
%set_env DESI_SPECTRO_REDUX=/global/cfs/cdirs/desi/spectro/redux
%set_env SPECPROD=andes
reduxdir = desispec.io.specprod_root()

#read-in redrock templates
import redrock.templates
from desispec.interpolation import resample_flux
from desispec.resolution import Resolution

templates = dict()
for filename in redrock.templates.find_templates():
    t = redrock.templates.Template(filename)
    templates[(t.template_type, t.sub_type)] = t

env: DESI_SPECTRO_REDUX=/global/cfs/cdirs/desi/spectro/redux
env: SPECPROD=andes
DEBUG: Read templates from /global/common/software/desi/cori/desiconda/20190804-1.3.0-spec/code/redrock-templates/master
DEBUG: Using default redshift range -0.0050-1.6997 for rrtemplate-galaxy.fits
DEBUG: Using default redshift range 0.0500-5.9934 for rrtemplate-qso.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-A.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-B.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-CV.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-F.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-G.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-K.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-M.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-WD.fits


In [15]:
#get spectrum from targetid
def get_spec(targetid,tile,date,to_print=False):
    dirn = os.path.join(reduxdir, "tiles", tile, date)
    for spectrograph in range(10):
        fn = "coadd-{}-{}-{}.fits".format(spectrograph, tile, date)
        file = os.path.join(dirn, fn)
        fmap = fitsio.read(file, "FIBERMAP")
        for i in range(500):
            if(str(fmap['TARGETID'][i]) == targetid):
                specnum = i
                fnstore=fn
                spectrographstore = spectrograph
                if(to_print):
                    print(fn,spectrograph,i)
    specfn = os.path.join(dirn, fnstore)
    specobj = desispec.io.read_spectra(specfn)
    
    if("brz" in specobj.wave):
        x_spc = specobj.wave["brz"]
        y_flx = specobj.flux["brz"][specnum]
        y_err=1/np.sqrt(specobj.ivar["brz"][specnum])
    #else combine into "brz" using helper fnc.
    else:
        x_spc,y_flx,y_err=quick_brz(specobj,specnum)
    
    return(x_spc,y_flx,y_err)

In [16]:
#Turn 'b', 'r', 'z' into 'brz'
def quick_brz(specobj,spectrum):
    #pull wavelength sections
    bw=np.round(specobj.wave['b'],3)
    rw=np.round(specobj.wave['r'],3)
    zw=np.round(specobj.wave['z'],3)
    
    #find overlapping arrays in wave
    br_overlap=np.intersect1d(bw,rw)
    rz_overlap=np.intersect1d(rw,zw)

    #find indices of overlapping regions
    br_start=int(np.where(bw==br_overlap[0])[0])
    br_end=int(len(bw))

    rz_start=int(np.where(rw==rz_overlap[0])[0])
    rz_end=int(len(rw))

    #pull flux
    bf=specobj.flux['b'][spectrum]
    rf=specobj.flux['r'][spectrum]
    zf=specobj.flux['z'][spectrum]
    #pull error
    be=1/np.sqrt(specobj.ivar['b'][spectrum])
    re=1/np.sqrt(specobj.ivar['r'][spectrum])
    ze=1/np.sqrt(specobj.ivar['z'][spectrum])
    #turn into 'brz'
    x_spc=np.concatenate((bw[:br_start],(bw[br_start:br_end]+rw[:br_end-br_start])/2,rw[br_end-br_start:rz_start],(rw[rz_start:rz_end]+zw[:rz_end-rz_start])/2,zw[rz_end-rz_start:]))
    y_flx=np.concatenate((bf[:br_start],(bf[br_start:br_end]+rf[:br_end-br_start])/2,rf[br_end-br_start:rz_start],(rf[rz_start:rz_end]+zf[:rz_end-rz_start])/2,zf[rz_end-rz_start:]))
    y_err=np.concatenate((be[:br_start],(be[br_start:br_end]+re[:br_end-br_start])/2,re[br_end-br_start:rz_start],(re[rz_start:rz_end]+ze[:rz_end-rz_start])/2,ze[rz_end-rz_start:]))
    
    return(x_spc,y_flx,y_err)

In [17]:
def MgII_Model(theta,x):
    z,a1,a2,s1,s2,m,b=theta
    #determine peak centers
    m1=(z+1)*2795.5301
    m2=(z+1)*2802.7056
    
    #Generate Model
    model = m*(x-x[0]) +b + a1*np.exp((-(x-m1)**2)/(2*s1**2))+a2*np.exp((-(x-m2)**2)/(2*s2**2))
    return model

In [18]:
#likelihood fnc
def log_likelihood(theta, x, y, yerr):
    #generative model
    model = MgII_Model(theta,x)
    #error into variance
    sigma2 = yerr ** 2
    #Actual Likelihood fnc
    return -0.5 * np.sum((y - model) ** 2 / sigma2 + np.log(sigma2))

In [19]:
#prior fnc, could contain more info on reasonable redshifts, heights and widths
def log_prior(theta,z_low,z_high):
    z,a1,a2,s1,s2,m,b=theta
    #if -100 < a1 < 100 and  -100 < a2 < 100 and 0 < s1  and  0 < s2  and z_low < z < z_high:
    if 0 < s1 and 0 < s2 and z_low < z < z_high:
        return 0.0
    return -np.inf

In [20]:
#probability fnc
def log_probability(theta, x, y, yerr, z_low, z_high):
    lp = log_prior(theta,z_low,z_high)
    if not np.isfinite(lp):
        return -np.inf
    return lp + log_likelihood(theta, x, y, yerr)

In [21]:
results=[]
max_extra_runs=50

in_str='MgII_Doublets_'+tile+'.csv'
feature_table=np.genfromtxt(in_str,delimiter=',',dtype=str)

#intial fitting setup
fitter = modeling.fitting.LevMarLSQFitter()
#200 is fairly arbitrary but it allows z to potenitally vary by a couple hundreths (is this sufficent?)
sub_region_size=50
rest_frame_sep=7.1755
    
#MCMC setup
ndim=7
nwalkers = 32
    
for feature in feature_table:
    #will need a better way to record date and tile, probably in Doublet.csv
    x_spc,y_flx,y_err=get_spec(feature[0],tile,date)

    #determine redshift and appropriate line_sep
    z=float(feature[1])
    line_sep=rest_frame_sep*(1+z)
    peak=int(feature[2])

    #set sub region values, bounded by [0,x_spc-1]
    srh=min(len(x_spc)-1,peak+sub_region_size)
    srl=max(0,peak-sub_region_size)

    #determine max and min z in window (or lowest/highest values possible if at edge of wavelength space)
    z_low=x_spc[srl]/2795.5301-1
    z_high=x_spc[srh]/2795.5301-1

    #define subregion in x and y
    reg_wave=x_spc[srl:srh]
    reg_flx=y_flx[srl:srh]
    reg_err=y_err[srl:srh]
    
    #initial line guesses
    init_m=(reg_flx[0]-reg_flx[-1])/(reg_wave[0]-reg_wave[-1])
    init_b=reg_flx[0]
    
    init_Amp1= -float(feature[5])
    init_Amp2= -float(feature[7])
    
    init_StdDev1= float(feature[6])
    init_StdDev2= float(feature[8])

    #define initial theta (TODO: Reconsider guesses for m,b)
    initial=[z,init_Amp1,init_Amp2,init_StdDev1,init_StdDev2,init_m,init_b]
    #could widen this inital guess range, don't think it matters though
    
    p0 = initial + 1e-4 * np.random.randn(nwalkers, ndim)
    #run sampler (region)
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, args=[reg_wave,reg_flx,reg_err,z_low,z_high])
    #burn-in
    state = sampler.run_mcmc(p0, 100)
    sampler.reset()
    #initial production
    state=sampler.run_mcmc(state, 5000)
    
    #extra production runs if we haven't reached autocorrelation
    extra_runs=0
    while(extra_runs<max_extra_runs):
        #should probably make this not a try/except but no idea how. MCMC docs
        try:
            sampler.get_autocorr_time(discard=1000)
            break
        except:
            extra_runs+=1
            try:
                state=sampler.run_mcmc(state,1000)
            except:
                extra_runs=max_extra_runs
                print('Large Condition Number',feature[0])
    if(extra_runs==max_extra_runs):
        print('Autocorrelation Failure:',feature[0])
        '''fig, axes = plt.subplots(7, figsize=(10, 7), sharex=True)
        samples = sampler.get_chain()
        labels = ["z", "Amp1", "Amp2","StdDev1","StdDev2","m","b"]
        for i in range(ndim):
            ax = axes[i]
            ax.plot(samples[:, :, i], "k", alpha=0.3)
            ax.set_xlim(0, len(samples))
            ax.set_ylabel(labels[i])
            ax.yaxis.set_label_coords(-0.1, 0.5)

        axes[-1].set_xlabel("step number")''';
        continue

    #extract MCMC info, discard 500 is fairly random
    flat_samples = sampler.get_chain(flat=True,discard=1000)
    mean_accept_frac=np.mean(sampler.acceptance_fraction)
    
    out_str='MgII_Candidate_Chains/'+feature[0]+'_'+str(np.percentile(flat_samples[:, 0],50))+'.csv'
    #Save Chain to output csv for plotting
    if(mean_accept_frac>0.0):
        np.savetxt(out_str, flat_samples, delimiter=",")
    else:
        print('Outside initial priors',feature[0])

Autocorrelation Failure: 35191251972131971
Large Condition Number 35191251972131971
Autocorrelation Failure: 35191251972131971


  ze=1/np.sqrt(specobj.ivar['z'][spectrum])


Large Condition Number 35191259383465992
Autocorrelation Failure: 35191259383465992
Autocorrelation Failure: 35191255675700437
Autocorrelation Failure: 35191259379272137
Large Condition Number 35191259383466412
Autocorrelation Failure: 35191259383466412
Large Condition Number 35191266723497299
Autocorrelation Failure: 35191266723497299


  re=1/np.sqrt(specobj.ivar['r'][spectrum])


Large Condition Number 35191259370884513
Autocorrelation Failure: 35191259370884513
Outside initial priors 35191273979642609
Large Condition Number 35191284691895977
Autocorrelation Failure: 35191284691895977


  be=1/np.sqrt(specobj.ivar['b'][spectrum])


Large Condition Number 35191291734132813
Autocorrelation Failure: 35191291734132813
Large Condition Number 35191281160291793
Autocorrelation Failure: 35191281160291793
Autocorrelation Failure: 35191291746714758
Large Condition Number 35191291746714911
Autocorrelation Failure: 35191291746714911
Large Condition Number 35191277586744103
Autocorrelation Failure: 35191277586744103
Large Condition Number 35191281168681049
Autocorrelation Failure: 35191281168681049
Autocorrelation Failure: 35191281168681049
Large Condition Number 35191281168681049
Autocorrelation Failure: 35191281168681049
Autocorrelation Failure: 35191281168681049
Outside initial priors 35191281168681049
Large Condition Number 35191274004808475
Autocorrelation Failure: 35191274004808475
Large Condition Number 35191274004808475
Autocorrelation Failure: 35191274004808475
Autocorrelation Failure: 35191277603523775
Large Condition Number 35191281177068220
Autocorrelation Failure: 35191281177068220
Large Condition Number 35191270