# A short tutorial on processing sample NIRISS AMI simulations

* fit fringes for a simulated target and calibrator sequence (no WFE evolution between them)
* calibrate target closure phases with the calibrator
* fit for a binary

In [None]:
import glob
import os, sys, time
from astropy.io import fits
import numpy as np

from nrm_analysis import nrm_core, InstrumentData


import matplotlib.pyplot as plt
%matplotlib inline

### Where the data lives:

In [None]:
datadir = "../example_data/example_niriss/"
# Target ("t_") and calibrator ("c_") simulations share the rest of the file name.
sim_name = "binary_s198.3_p143.9_cr0.01__nispsf_jit7.0_F480M_15x_ov__00_lgpp_short.fits"
test_tar = datadir + "t_" + sim_name
test_cal = datadir + "c_" + sim_name


In [None]:
# Load the raw target cube from the path defined above (instead of re-typing the long
# file name) and show its shape. A distinct name avoids clashing with the `data`
# image loaded later for the diagnostic plots.
tar_cube = fits.getdata(test_tar)
print(tar_cube.shape)

### First we specify the instrument & filter (F480M in this case):

(Default: the spectral type is set to A0V.)

In [None]:
# NIRISS instrument model for the F480M filter (spectral type left at its default).
filter_name = "F480M"
niriss = InstrumentData.NIRISS(filter_name)

In [None]:
#print(niriss.nwav)


### Next: get fringe observables via image plane fringe-fitting
* Pass the InstrumentData object and a few keywords.
* Output files are saved in the specified `savedir`, inside a new subdirectory named after each input file.


In [None]:
# Settings shared by the target and calibrator fitters.
#  - oversample was originally 7; it is reduced to 3 here to speed up debugging.
#  - interactive=False is the recommended setting; only turn it on if you specifically
#    need interactive mode.
shared_kw = dict(datadir=datadir, oversample=3, interactive=False)
ff_t = nrm_core.FringeFitter(niriss, savedir="targ/", **shared_kw)
ff_c = nrm_core.FringeFitter(niriss, savedir="cal/", **shared_kw)
                                                        

In [None]:
# Fringe fitting can take a while; a parallelization option exists (threads=n_threads).
# The output is long, so you may prefer to run this step as a script rather than in
# the notebook; the output is left off in this example.
for fitter, fname in [(ff_t, test_tar), (ff_c, test_cal)]:
    fitter.fit_fringes(fname)

You'll find some new files. Text files store the observables you are trying to measure, and some diagnostic FITS files are also written: centered_X contain the cropped/centered data, modelsolution_XX the best-fit model to the data, and residual_XX the difference between the two.

Coming soon: propagating errors from fringe fitting to the observables.

In [None]:
# Compare the centered data, the best-fit model and the residual for ONE integration.
# FringeFitter writes these files into "targ/<input file name without .fits>/".
# The previous version loaded centered_0 together with modelsolution_01/residual_01.
# The padding pattern suggests centered_ uses a plain slice index and the model and
# residual files a two-digit one, so that pairing mixed slice 0 with slice 1.
# Build all three names from the same slice index (TODO confirm against nrm_core).
slc = 0
tar_outdir = "targ/" + os.path.splitext(os.path.basename(test_tar))[0] + "/"
data = fits.getdata(tar_outdir + "centered_{0}.fits".format(slc))
fmodel = fits.getdata(tar_outdir + "modelsolution_{0:02d}.fits".format(slc))
res = fits.getdata(tar_outdir + "residual_{0:02d}.fits".format(slc))

# Everything is normalised by the data peak; data and model get a square-root stretch.
peak = data.max()
panels = [("Input data", pow(data / peak, 0.5)),
          ("best model", pow(fmodel / peak, 0.5)),
          ("residual", res / peak)]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (title, img) in zip(axes, panels):
    ax.set_title(title)
    im = ax.imshow(img)
    ax.axis("off")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


If you don't want to clog up your hard drive with FITS files, you can initialize FringeFitter with the keyword `save_txt_only=True`. However, you may want to save everything the first time you reduce the data so that you can check it. Above we can see a pretty good fit, although the magnification of the model is a bit off; this shows up as a radial pattern in the residual. The exact magnification and rotation should be fitted finely before fringe fitting.

### Calibration is simple: point to the data

The most important thing is to pass the right InstrumentData object with correct parameters so wavelength, pixelscale, etc. can be interpreted into on-sky spatial frequency. This can write out an oifits file.

In [None]:
# Fresh instrument object. This is a temporary fix to reset nwav to 1 before
# calibration (the earlier steps appear to change it; TODO confirm).
niriss = InstrumentData.NIRISS("F480M")
# FringeFitter saved each file's outputs in <savedir>/<input file name without .fits>/.
# Derive the directories from the input paths so they cannot drift out of sync with
# test_tar / test_cal.
tdir = "targ/" + os.path.splitext(os.path.basename(test_tar))[0] + "/"
cdir = "cal/" + os.path.splitext(os.path.basename(test_cal))[0] + "/"
calib = nrm_core.Calibrate([tdir, cdir], niriss, savedir="calibrated_example/", interactive=False)

In [None]:
# The file is written into the Calibrate object's savedir. It holds, over the sequence
# of observations, the average fringe phases, closure phases, visibility amplitudes,
# closure amplitudes and their errors, plus wavelength and baseline information.
oifits_name = "exampleoifitsfiles.oifits"
calib.save_to_oifits(oifits_name)

### Now what to do with an oifits file? Example: fit a binary

In [None]:
# nrm_core.BinaryAnalyze is a convenient loader for the calibrated oifits file.
calibrated_dir = "calibrated_example/"
ba = nrm_core.BinaryAnalyze(calibrated_dir + "exampleoifitsfiles.oifits", savedir=calibrated_dir)

In [None]:
# Quick look at the calibrated observables. Closure phases and bispectrum amplitudes
# are different quantities (a phase vs. an amplitude), so give each its own labeled
# panel instead of overlaying them on one unlabeled axis.
fig, (ax_cp, ax_t3) = plt.subplots(1, 2, figsize=(10, 4))
ax_cp.plot(ba.cp, '.')
ax_cp.set(title="Closure phases", xlabel="closure-phase index", ylabel="closure phase")
ax_t3.plot(ba.t3amp, '.')
ax_t3.set(title="Bispectrum amplitudes", xlabel="triangle index", ylabel="amplitude");

In [None]:
# Coarse grid search for the binary parameters (maxsep in mas; clims presumably bounds
# the contrast). The simulated companion has contrast 0.01 at about 200 mas, well
# inside this search range, so it should be easy to recover.
n_grid = 45  # number of grid steps per axis
coarse_params = ba.chi2map(maxsep=400, clims=[1e-4, 0.9], nstep=n_grid)

In [None]:
# Display sqrt(significance) from the coarse chi2map search over the sky-offset grid.
sig_map = np.sqrt(ba.significance)
ny, nx = sig_map.shape
n_ticks = 5

fig, ax = plt.subplots()
# White star at the grid centre (presumably zero offset, i.e. the primary).
# Pixel centres run 0..n-1, so the centre sits at (n-1)/2.
ax.plot((nx - 1) / 2.0, (ny - 1) / 2.0, marker="*", color='w', markersize=20)
im = ax.imshow(sig_map, cmap="cubehelix", interpolation="nearest")
ax.set_xlabel("RA (mas)")
ax.set_ylabel("DEC (mas)")
# The first/last pixel centres (0 and n-1) correspond to the grid min/max. The old
# tick positions linspace(0, 45, 5) were shifted by one pixel at the far end.
ax.set_xticks(np.linspace(0, nx - 1, n_ticks))
ax.set_xticklabels(["{:.0f}".format(v) for v in np.linspace(ba.ras.min(), ba.ras.max(), n_ticks)])
ax.set_yticks(np.linspace(0, ny - 1, n_ticks))
ax.set_yticklabels(["{:.0f}".format(v) for v in np.linspace(ba.decs.min(), ba.decs.max(), n_ticks)])
ax.invert_yaxis()
fig.colorbar(im, ax=ax);

## We can now do a finer fit with this coarse search as a starting point

For a finer fit you can use the `run_emcee` method or, if you want more control, define your own likelihood function. Both approaches are shown below.

In [None]:
# Start the MCMC from the coarse grid result, unpacked as (contrast, separation, PA).
con, sep, pa = np.array(coarse_params)
# Wrap the position angle (presumably in [0, 360) deg) into the (-180, 180] prior range.
# Subtracting 360 keeps the same direction on the sky; the previous `360 - pa`
# mirrored it (e.g. 270 became 90 instead of -90).
if pa > 180.0:
    pa -= 360.0
guess = {"con": con, "sep": sep, "pa": pa}
# Prior bounds in the same order as `guess`: contrast, separation (mas), PA (deg).
priors = [(1e-5, 0.99), (20.0, 400.0), (-180.0, 180.0)]
# `scale` accounts for redundancy among closure-phase baselines. For a 7-hole mask
# only 15 of the 35 closure phases are independent, and sqrt(35/15) = sqrt(7/3).
ba.run_emcee(guess, nwalkers=100, niter=1000, priors=priors,
             threads=4, burnin=100,
             scale=np.sqrt(7/3.0))
# Show the corner plot in the notebook in addition to writing it to disk.
ba.plot = "on"
ba.corner_plot("test_mcmc.pdf")

In [None]:
from nrm_analysis.modeling.binarymodel import model_cp_uv, model_t3amp_uv

def logl(data, err, model):
    """
    Gaussian log-likelihood (up to an additive constant) of data given a model.

    data, err, model : arrays of matching/broadcastable shape, expected to be
        (nobservable, nwav).
    Returns -chi2 / 2, where chi2 is the sum over all elements of
    ((data - model) / err)**2.
    """
    normalized_residual = (data - model) / err
    chi_squared = np.sum(normalized_residual**2)
    return -0.5 * chi_squared

def cp_binary_model(params, binset, priors):
    """
    Closure-phase log-likelihood of a binary model (usable as an emcee log-probability).

    params : [sep, pa, con_1, ..., con_k]. Separation and position angle come first,
        and every remaining entry is a contrast, so len(params) == n_contrasts + 2.
    binset : object exposing uvcoords, wavls, cp and cperr (e.g. a BinaryAnalyze).
    priors : sequence of (low, high) bounds, one per entry of params.

    Returns -inf when any parameter lies outside its bounds. Otherwise returns the
    chi^2 log-likelihood of binset.cp against the model closure phases.
    """
    for idx, value in enumerate(params):
        bounds = priors[idx]
        if value < bounds[0] or value > bounds[1]:
            return -np.inf

    sep, pa = params[0], params[1]
    model = model_cp_uv(binset.uvcoords, params[2:], sep, pa, 1.0 / binset.wavls)
    return logl(binset.cp, binset.cperr, model)

def all_binary_model(params, binset, priors):
    """
    Joint closure-phase + bispectrum-amplitude log-likelihood of a binary model.

    params : [sep, pa, con_1, ..., con_k], the same layout as in cp_binary_model.
    binset : object exposing uvcoords, wavls, cp, cperr, t3amp and t3amperr.
    priors : sequence of (low, high) bounds, one per entry of params.

    Returns -inf when any parameter lies outside its bounds. Otherwise returns the
    chi^2 log-likelihood of the concatenated (cp, t3amp) observables.
    """
    for idx, value in enumerate(params):
        bounds = priors[idx]
        if value < bounds[0] or value > bounds[1]:
            return -np.inf

    sep, pa, contrasts = params[0], params[1], params[2:]
    inv_wavl = 1.0 / binset.wavls
    # Stack closure phases and triple amplitudes so one chi^2 covers both.
    model_all = np.concatenate((
        model_cp_uv(binset.uvcoords, contrasts, sep, pa, inv_wavl),
        model_t3amp_uv(binset.uvcoords, contrasts, sep, pa, inv_wavl),
    ))
    obs_all = np.concatenate((binset.cp, binset.t3amp))
    obserr_all = np.concatenate((binset.cperr, binset.t3amperr))
    return logl(obs_all, obserr_all, model_all)


In [None]:
import emcee

# Hand-rolled MCMC on the closure phases only, using cp_binary_model as log-probability.
con, sep, pa = np.array(coarse_params)
# Wrap PA into (-180, 180]. Subtracting 360 keeps the same sky direction; the previous
# `360 - pa` mirrored it.
if pa > 180.0:
    pa -= 360.0
params = np.array([sep, pa, con])  # layout expected by cp_binary_model
ndim = len(params)
nwalkers = 100
nburn = 100
niter = 1000

# Seeded initial ball of walkers (10% multiplicative jitter) so the start is reproducible.
init_rng = np.random.RandomState(12345)
p0 = [params + 0.1 * params * init_rng.rand(ndim) for _ in range(nwalkers)]
# Bounds in `params` order: sep (mas), pa (deg), contrast. The separation bound matches
# the maxsep=400 coarse search (it was 300), so walkers seeded from a wide coarse
# solution do not start outside the prior.
priors = [(0.0, 400.0), (-180.0, 180.0), (1e-4, .99)]
# `threads` is omitted: 1 is the emcee 2 default, and the argument is deprecated in emcee 3.
sampler = emcee.EnsembleSampler(nwalkers, ndim, cp_binary_model, args=[ba, priors])

t0 = time.time()
pos, prob, state = sampler.run_mcmc(p0, nburn)
sampler.reset()  # drop the burn-in steps
t1 = time.time()
print("burn in complete, took ", t1-t0, "s")
pos, prob, state = sampler.run_mcmc(pos, niter)
t2 = time.time()
print("Mean acceptance fraction: {0:.3f}".format(np.mean(sampler.acceptance_fraction)))
print("This number should be between ~ 0.25 and 0.5 if everything went as planned.")
print("ran mcmc, took", t2 - t1, "s")
chain = sampler.flatchain
fullchain = sampler.chain


In [None]:
import corner

param_names = ["sep", "pa", "con"]  # order of `params` in the sampler cell
n_discard = 50  # extra leading steps dropped after the burn-in/reset

# Post-burn-in samples flattened to (nsamples, ndim). emcee 3 provides get_chain();
# emcee 2 only has `chain`, shaped (nwalkers, nsteps, ndim).
if hasattr(sampler, "get_chain"):
    samples = sampler.get_chain(discard=n_discard, flat=True)
else:
    samples = sampler.chain[:, n_discard:, :].reshape((-1, len(params)))

# Median with the 16th/84th percentiles (68% interval, i.e. +/- 1 sigma for a Gaussian).
# The previous 1st/99th percentiles spanned ~98% but were printed as "+/-".
# Each pq entry is (median, upper - median, median - lower).
lo, med, hi = np.percentile(samples, [16, 50, 84], axis=0)
pq = list(zip(med, hi - med, med - lo))
print("================")
print("Recovered (median +upper/-lower, 68% interval):")
print("---------")
for idx in (2, 0, 1):  # con, sep, pa
    print(param_names[idx], ":", pq[idx][0], "+/-", pq[idx][1], pq[idx][2])
print("================")

# Plot the same discarded-start samples the statistics were computed from.
fig = corner.corner(samples, bins=50, labels=param_names)