# Continuous Waves - Why do we do that??

In addition to the stochastic gravitational-wave background (GWB), we expect a number of closer, louder sources to stand out individually. These are high-mass, inspiraling black hole binaries with orbital frequencies between 1 and 300 nHz that evolve very slowly over the course of PTA observations. Because PTA experiments see these binaries change very little in orbital frequency, we call these signals "continuous gravitational waves" (CWs).

Continuous-wave searches differ from PTA searches for the GWB because we can model the entire waveform of a single binary emitting GWs (this is why CWs are a type of "deterministic" signal!). In principle this makes the search simpler; in practice, the model needs a lot of parameters to work.

In this exercise, we'll start by injecting a continuous wave signal into some simulated pulsar timing data. First, we'll inspect the residuals by eye and investigate how each parameter affects the delays induced by a CW. Then, we'll set up an `enterprise` run to conduct a full MCMC and see how much information we can recover. By repeating this with a few different setups, you should finish this tutorial with a good understanding not only of the many steps needed to run a CW search, but also of *why* they're necessary!

In [None]:
# Useful imports
%matplotlib inline

import numpy as np
import libstempo.toasim as LT
import libstempo.plot as LP
import matplotlib.pyplot as plt
from enterprise.pulsar import Pulsar



import astropy.constants as const
from libstempo import eccUtils as eu
from enterprise import constants as econst

import astropy.units as u
from astropy.coordinates import SkyCoord

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

### 0.0 Make personalized directory 

In [None]:
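import os

INIT = 'INIT'
DIR = './cw_test'
# create the output directory; exist_ok=True makes this cell safe to re-run
os.makedirs(DIR, exist_ok=True)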

## 1.0 Simulate Pulsar Data
### 1.1 Generate fake `.par` files

In [None]:
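# this function creates a parameter file for a simulated pulsar
# it randomizes the sky location (and therefore the name) and the pulse frequency;
# the distance is fixed at 1 kpc through the parallax

def make_fake_pulsar(DIR):
    '''
    Makes a fake pulsar par file in DIR and returns its path (as bytes)
    '''
    output = "MODE 1\n"
    
    # Sphere point picking: uniform over the sky
    # (rand_u and rand_v avoid shadowing `astropy.units as u`)
    rand_u = np.random.uniform()
    rand_v = np.random.uniform()
    phi = 2*np.pi*rand_u #using standard physics notation
    # arccos(2v-1) is a colatitude in [0, pi]; subtracting pi/2 maps it into [-pi/2, pi/2]
    # (the distribution is symmetric, so the sign flip relative to latitude doesn't matter)
    theta = np.arccos(2*rand_v-1) - np.pi/2

    c = SkyCoord(phi,theta,frame='icrs',unit='rad')
    cstr = c.to_string('hmsdms')
    RAJ = cstr.split(" ")[0].replace("h",":").replace("m",":")[:-1]
    DECJ = cstr.split(" ")[1].replace("d",":").replace("m",":")[:-1]
    name = "J"+RAJ[0:2]+RAJ[3:5]+DECJ[0]+DECJ[1:3]+DECJ[4:6]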

    output += "PSR      %s\n"%name

    
    output += "PEPOCH   50000.0\n"    
    output += "POSEPOCH   50000.0\n"

    period = 0.001*np.random.uniform(1,10) #seconds
    output += "F0       %0.10f 1\n"%(1.0/period)

    output += "RAJ      %s 1\n"%RAJ
    output += "DECJ     %s 1\n"%DECJ

    dist = 1.0 #np.random.uniform(0.1,5) #kpc
    output += "PX       %0.5f 1\n"%(1.0/dist) #mas

    filename = "%s/%s.par"%(DIR,name)
    print(filename)
    with open(filename,'w') as FILE:
        FILE.write(output)

    return filename.encode('ascii','ignore')

Let's start by making 3 pulsars. Each call to `make_fake_pulsar` creates one pulsar placed randomly on the sky at a distance of 1 kpc, so we call it once per pulsar.
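The "Sphere Point Picking" lines in `make_fake_pulsar` draw the azimuth uniformly but take the polar angle as `arccos(2v - 1)` instead of drawing it uniformly. The standalone NumPy sketch below illustrates why: drawing cos(θ) uniformly spreads points evenly over the sphere's area, while drawing θ uniformly piles them up near the poles. `sample_sky` is a hypothetical helper written for this illustration, not part of the tutorial's code.

```python
import numpy as np

def sample_sky(n, rng):
    """Draw n directions uniformly distributed over the sphere.

    Returns colatitude theta in [0, pi] and azimuth phi in [0, 2*pi).
    """
    u = rng.uniform(size=n)
    v = rng.uniform(size=n)
    phi = 2 * np.pi * u            # azimuth: uniform
    theta = np.arccos(2 * v - 1)   # colatitude: uniform in cos(theta)
    return theta, phi

rng = np.random.default_rng(42)
theta, phi = sample_sky(200_000, rng)

# For an isotropic distribution, cos(theta) is uniform on [-1, 1],
# so the two polar caps with |cos(theta)| > 0.9 hold ~10% of the points.
polar_frac = np.mean(np.abs(np.cos(theta)) > 0.9)

# Naive alternative: uniform theta over-samples the poles (~29% in the same caps).
theta_naive = rng.uniform(0, np.pi, size=200_000)
polar_frac_naive = np.mean(np.abs(np.cos(theta_naive)) > 0.9)

print(f"uniform-cos: {polar_frac:.3f}, uniform-theta: {polar_frac_naive:.3f}")
```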

In [None]:
n_psrs = 3 #try changing this later! 

pars =[]
for p in range(n_psrs):
    pars.append(make_fake_pulsar(DIR))

### 1.2 Generate fake `.tim` files with an injected single GW source


In [None]:
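#input GW parameters - useful for later! 

# Let's add a source at the Virgo cluster: 12h27m +12d43', D_L = 15 Mpc
gwtheta_in = np.pi/2-(12+43.0/60)*np.pi/180 #theta = pi/2 - dec [rad]
gwphi_in = (12+27.0/60)*15*np.pi/180 #phi = RA; 1 hour of RA = 15 degrees [rad]
DL_in = 15 #luminosity distance [Mpc]
phase0_in = 0 #initial GW phase [rad]
psi_in = 0 #GW polarization angle [rad]
inc_in = 0 #binary inclination [rad]
fgw_in = 1e-8 #GW frequency [Hz]
mass_in = 5e9 #chirp mass [solar masses]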
strain = (2*(const.G*mass_in*u.Msun/const.c**3)**(5./3)*(np.pi*fgw_in*u.Hz)**(2./3)*const.c/(DL_in*u.Mpc)).to(u.dimensionless_unscaled)
print(np.log10(strain))

In [None]:
# this function simulates a pulsar's timing file
# and injects the TOAs with a CW source of given chirp and frequency [Hz]

def observe(par,noise=0.5,mass=5e9,fgw=1e-8, psrTerm = True, evolve = True):
    ''' Noise in microseconds, (chirp) mass in solar masses, fgw in Hz'''
    # let's set up some fake TOAs
    t = np.arange(53000,56650,30.0) #observing dates for 10 years
    t += np.random.randn(len(t)) #observe every 30+/-1 days
    psr = LT.fakepulsar(parfile=par,
                        obstimes=t,
                        toaerr=noise)
    LT.add_equad(psr,equad=noise*1e-6) #add uncorrelated white (EQUAD) noise; libstempo expects seconds
    
    
    LT.add_cgw(psr, gwtheta=gwtheta_in, gwphi=gwphi_in, mc=mass, dist=DL_in, 
               fgw=fgw, phase0=phase0_in, psi=psi_in, inc=inc_in, pdist=1.0,
               pphase=None, psrTerm=psrTerm, evolve=evolve,
               phase_approx=False, tref=psr.toas()[-1]*86400)
#     psr.fit()
    # swap the .par extension for .tim (rsplit, because the path itself starts with './')
    timfile = par.rsplit('.', 1)[0] + '.tim'
    psr.savetim(timfile)
    print(timfile)
#     psr.savepar(par)

    return psr


In [None]:
tims = []

for par in pars:
    tims.append( observe(par.decode('utf-8'), fgw = fgw_in))

### 1.3 What does this look like?
Let's look at the timing residuals containing this injected signal using the `LP.plotres` function.

In [None]:
# some space to work

for tim in tims:
    LP.plotres(tim)

Notice that these residuals are not just a sinusoid! There are a few interesting features of these waveforms that are important to disentangle. To do this, let's first turn our simulated pulsars into `enterprise` `Pulsar` objects, since they're a bit easier to work with.

In [None]:
psrs = []
# note: `tims` holds the libstempo pulsar objects returned by observe(), not .tim file names
for par, tim in zip(pars, tims):
    psrs.append(Pulsar(par, tim))


Since these are fake pulsars, `enterprise` has no catalog distance measurement for them. 
Instead, it falls back on a default of 1 ± 0.2 kpc; since `make_fake_pulsar` placed our pulsars at 1 kpc, this assumption happens to be correct!

To understand how to set up a CW search and why each piece is necessary, let's look at a few features that make CW signals so distinctive.

First is the antenna pattern response, which sets how strongly a pulsar at a given position on the sky responds to a CW from a source at a given location.
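For reference, the sketch below writes out the plus and cross antenna patterns that `create_gw_antenna_pattern` evaluates in the cells below, in plain NumPy. It assumes the convention we understand `enterprise` to use: a GW propagation direction $\hat{\Omega}$ pointing away from the source, polarization basis vectors $\hat{m}$ and $\hat{n}$, and $F^{+} = \tfrac{1}{2}\frac{(\hat{m}\cdot\hat{p})^2 - (\hat{n}\cdot\hat{p})^2}{1+\hat{\Omega}\cdot\hat{p}}$, $F^{\times} = \frac{(\hat{m}\cdot\hat{p})(\hat{n}\cdot\hat{p})}{1+\hat{\Omega}\cdot\hat{p}}$ for a pulsar in direction $\hat{p}$. Treat it as an illustration; `antenna_pattern` is a hypothetical name, not an `enterprise` function.

```python
import numpy as np

def antenna_pattern(pos, gwtheta, gwphi):
    """Plus/cross antenna patterns for a pulsar at unit vector `pos`
    and a GW source at colatitude `gwtheta`, azimuth `gwphi`.

    Returns (fplus, fcross, cosMu), where cosMu is the cosine of the
    angle between the pulsar and the GW source on the sky.
    """
    st, ct = np.sin(gwtheta), np.cos(gwtheta)
    sp, cp = np.sin(gwphi), np.cos(gwphi)
    m = np.array([sp, -cp, 0.0])                 # polarization basis vector
    n = np.array([-ct * cp, -ct * sp, st])       # polarization basis vector
    omhat = np.array([-st * cp, -st * sp, -ct])  # GW propagation direction

    denom = 1.0 + np.dot(omhat, pos)
    fplus = 0.5 * (np.dot(m, pos) ** 2 - np.dot(n, pos) ** 2) / denom
    fcross = np.dot(m, pos) * np.dot(n, pos) / denom
    cosMu = -np.dot(omhat, pos)
    return fplus, fcross, cosMu

gwtheta, gwphi = np.pi / 3, 1.0
st, ct = np.sin(gwtheta), np.cos(gwtheta)
omhat = np.array([-st * np.cos(gwphi), -st * np.sin(gwphi), -ct])
m = np.array([np.sin(gwphi), -np.cos(gwphi), 0.0])

# Pulsar directly opposite the source on the sky: no response.
print(antenna_pattern(omhat, gwtheta, gwphi))
# Pulsar 90 degrees from the source, along m: pure plus response of 1/2.
print(antenna_pattern(m, gwtheta, gwphi))
```

These two limiting cases are a quick way to sanity-check the `Plus` and `Cross` panels produced by `plot_apf`.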

In [None]:
from enterprise.signals.utils import create_gw_antenna_pattern
import healpy as hp

In [None]:
def plot_apf(psrs, gwtheta, gwphi, nside=8):
    npix = hp.nside2npix(nside)

    data_p_pos = np.zeros(npix)
    data_x_pos = np.zeros(npix)
    data_m_pos = np.zeros(npix)
    # evaluate the plus, cross and cos(mu) patterns at the center of every healpix pixel
    for pix in range(npix):
        theta, phi = hp.pix2ang(nside, pix)
        pos = np.array([np.sin(theta)*np.cos(phi), 
                        np.sin(theta)*np.sin(phi), 
                        np.cos(theta)])
        data_p_pos[pix], data_x_pos[pix], data_m_pos[pix] = create_gw_antenna_pattern(pos, gwtheta, gwphi)
            
    names = ['Plus', 'Cross', 'Cos Mu']
    fig, axes = plt.subplots(1,3,figsize = (18,3))

    # keep the data in the same order as the panel titles
    for d, dat in enumerate([data_p_pos, data_x_pos, data_m_pos]):
        plt.sca(axes[d])
        mv = hp.mollview(dat, rot = 180, title = names[d], hold = True)

        # overplot the pulsars as stars
        for i, psr in enumerate(psrs):
            hp.visufunc.projscatter(psr.theta, psr.phi, marker = '*', s = 300, edgecolor = 'w', color = colors[i])

        #healpy isn't great at axis labels, so we have to do it by hand :( 
        for i in range(2,24,2):
            text = hp.projtext( i*180/12+3, 4,  str(i)+'h', lonlat=True, coord='G', fontsize = 'large', fontweight = 100, zorder = 1, color = 'w')

        for i in list(range(-75,0,15)) + list(range(15,90,15)):
            text = hp.projtext( 180,i,   str(i)+r'$^\circ$', lonlat=True, coord='G', fontsize = 'large', fontweight = 100, zorder = 10, color = 'w')
        hp.graticule(15, 30)

In [None]:
plot_apf(psrs, gwtheta_in, gwphi_in)
#try changing the input theta and phi values to see how the antenna pattern changes! 

Next, let's take a closer look at the form of a CW signal.

In [None]:
from enterprise_extensions.deterministic import cw_delay

In [None]:
tmin = [p.toas.min() for p in psrs]
tmax = [p.toas.max() for p in psrs]
Tspan = np.max(tmax) - np.min(tmin) # total observing span [s]
tref = max(tmax)
# It's important to choose a sensible reference time for your CW model! The default in most functions is t = 0,
# which isn't a useful epoch for comparing with binary orbital phases measured at present-day MJDs.
# Therefore, we use the last TOA in the data set.

The strain amplitude is a function of a few of the source parameters (written here in units where $G = c = 1$):

$$
\begin{equation}
h_{0} = \dfrac{2 \mathcal{M}_c^{5/3}(\pi f_{\rm GW})^{2/3}}{d_{L}}
\end{equation}
$$

You'd anticipate that more massive systems are "louder," as are closer systems. This is reflected in the strain's dependence on the binary's effective mass, called the chirp mass ($\mathcal{M}_{c}$), and its inverse dependence on a particular measure of distance to the binary, called the luminosity distance ($d_L$).

If we have no prior knowledge about our continuous wave source, the only amplitude information we can recover from our search is the strain, $h_{0}$. Since $h_0$ mixes chirp mass and luminosity distance, we need some other information to estimate these two values independently. Fortunately, there is a way to untangle them: including the so-called pulsar term.
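The relations above can be checked numerically. The sketch below restores the factors of $G$ and $c$ that the equation suppresses, uses the standard chirp-mass definition $\mathcal{M}_c = (m_1 m_2)^{3/5}/(m_1+m_2)^{1/5}$, and shows the degeneracy directly: doubling the chirp mass and scaling the distance by $2^{5/3}$ leaves $h_0$ unchanged. The constants are hard-coded SI values, and `chirp_mass` and `strain_amplitude` are illustrative names, not library functions.

```python
import numpy as np

# Hard-coded constants (SI)
T_SUN = 4.925490947641267e-6   # G * M_sun / c**3, in seconds
C = 299792458.0                # speed of light, m/s
MPC = 3.0856775814913673e22    # one megaparsec, m

def chirp_mass(m1, m2):
    """Chirp mass, in the same units as m1 and m2."""
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2

def strain_amplitude(mc_msun, fgw_hz, dl_mpc):
    """h0 = 2 Mc^(5/3) (pi f)^(2/3) / d_L, with G and c restored."""
    mc_s = mc_msun * T_SUN            # chirp mass expressed in seconds
    dl_s = dl_mpc * MPC / C           # distance expressed in light-seconds
    return 2 * mc_s ** (5 / 3) * (np.pi * fgw_hz) ** (2 / 3) / dl_s

# The source injected in this tutorial
h_in = strain_amplitude(5e9, 1e-8, 15.0)
print(np.log10(h_in))

# Same h0 from a different (Mc, d_L) pair: double Mc, scale d_L by 2^(5/3)
h_alt = strain_amplitude(2 * 5e9, 1e-8, 15.0 * 2 ** (5 / 3))
print(np.isclose(h_in, h_alt))

# Equal-mass binary: Mc = m * 2^(-1/5), about 0.87 m
print(chirp_mass(1.0, 1.0))
```

The printed $\log_{10} h_0$ should agree with the `strain` computed from the injected parameters in section 1.2, and the equality check shows that $h_0$ alone cannot tell the two $(\mathcal{M}_c, d_L)$ pairs apart.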

The full influence of a gravitational wave passing through our pulsar array combines its effect at the Earth with its effect at each individual pulsar. If the wave jiggles the Earth, then the times of arrival of all the pulsars in the array change together; we call this part of the signal the "Earth term." But every pulsar experiences a jiggle too, and since each pulsar's pulses take time to reach us, we see that jiggle as it happened in the past, when the pulsar was hit by an earlier part of the wave than the one now passing Earth. This is the "pulsar term" in our residuals.

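Because the binary's orbital frequency increases at a rate set by its chirp mass, the pulsar term (which samples the binary at an earlier time) sits at a lower frequency than the Earth term. That frequency difference carries information about $\mathcal{M}_c$ that is independent of $d_L$, so by including both Earth and pulsar terms in our search, we're better able to recover the source properties.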
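A back-of-the-envelope sketch shows how large this effect can be. The pulsar term samples the wave at a retarded time that is earlier by $\tau = L(1-\cos\mu)/c$, where $L$ is the pulsar distance and $\mu$ the pulsar-source angle. We evolve the frequency back with the leading-order (quadrupole) relation $\omega(t) = \omega_0\left[1 - \tfrac{256}{5}\mathcal{M}_c^{5/3}\omega_0^{8/3}t\right]^{-3/8}$ ($G = c = 1$), where $\omega_0 = \pi f_{\rm GW}$ is the orbital angular frequency at the reference time; to our understanding this is the same form `cw_delay` uses with `evolve=True`. The function name `pulsar_term_fgw` and the choice $\cos\mu = 0$ are illustrative assumptions.

```python
import numpy as np

T_SUN = 4.925490947641267e-6                        # G * M_sun / c**3, seconds
KPC_LIGHT_S = 3.0856775814913673e19 / 299792458.0   # 1 kpc in light-seconds

def pulsar_term_fgw(mc_msun, fgw_hz, pdist_kpc, cos_mu):
    """GW frequency seen in the pulsar term (leading-order evolution)."""
    mc = mc_msun * T_SUN                          # chirp mass in seconds
    w0 = np.pi * fgw_hz                           # orbital angular frequency at tref
    tau = pdist_kpc * KPC_LIGHT_S * (1 - cos_mu)  # extra light-travel time, s
    # evaluate omega(t) at t = -tau, i.e. in the past
    w_p = w0 * (1 + 256 / 5 * mc ** (5 / 3) * w0 ** (8 / 3) * tau) ** (-3 / 8)
    return w_p / np.pi                            # back to GW frequency

f_earth = 1e-8
for mc in (1e8, 1e9, 5e9):
    f_psr = pulsar_term_fgw(mc, f_earth, 1.0, 0.0)
    print(f"Mc = {mc:.0e} Msun: pulsar-term f_gw = {f_psr * 1e9:.2f} nHz")
```

Heavier binaries evolve faster, so the gap between the Earth-term and pulsar-term frequencies grows with $\mathcal{M}_c$, which is what lets a search with both terms separate $\mathcal{M}_c$ from $d_L$.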


In [None]:
#check out cw_delay


In [None]:
psr_num = 0
names = ['earth', 'psr', 'total']
fig, axes = plt.subplots(3,1, figsize = (5, 10))
psr = psrs[psr_num]
days = (psr.toas*u.second).to(u.day) # TOAs as MJD

# total signal: Earth term + pulsar term
cw1 = cw_delay(psr.toas, psr.pos, pdist = (1.0, 0.2),
               cos_gwtheta=np.cos(gwtheta_in), 
               gwphi=gwphi_in, log10_mc=np.log10(mass_in), log10_dist=np.log10(DL_in), 
               log10_fgw=np.log10(fgw_in), phase0=phase0_in, psi=psi_in, cos_inc=np.cos(inc_in), 
               psrTerm = True, evolve = True, tref =tref, p_dist = 0)
axes[2].plot(days, cw1 , color = colors[psr_num], marker = '', alpha = 1)

# Earth term only
cw1_e = cw_delay(psr.toas, psr.pos, pdist = (1.0, 0.2),
                 cos_gwtheta=np.cos(gwtheta_in), 
                 gwphi=gwphi_in, log10_mc=np.log10(mass_in), log10_dist=np.log10(DL_in), 
                 log10_fgw=np.log10(fgw_in), phase0=phase0_in, psi=psi_in, cos_inc=np.cos(inc_in), 
                 psrTerm = False, evolve = True, tref =tref, p_dist = 0)
axes[0].plot(days, cw1_e , color = colors[psr_num], marker = '', alpha = 1)

# pulsar term = total - Earth term
cwp = cw1 - cw1_e
axes[1].plot(days, cwp , color = colors[psr_num], marker = '', alpha = 1)

# sanity check: Earth term + pulsar term should trace the total exactly
axes[2].plot(days, cwp + cw1_e , color = 'k', ls = '--', marker = '', alpha = 0.5)
for i,ax in enumerate(axes):
    ax.set_ylim(-2.6e-5, 2.6e-5)
    ax.set_xlabel('MJD')
    ax.set_ylabel('Residual (s)')
    ax.set_title(names[i], y=1.0, pad=-25)
    


In addition to the pulsar's location on the sky, the distance to the pulsar has a big effect on the residuals when we model the pulsar term. Because the wavelength of a typical PTA GW is much smaller than the uncertainties on our pulsar distances, small changes in the assumed pulsar distance shift the phase of the pulsar term and can substantially change the shape of the residuals. In the block below, we'll draw a few values from a normal distribution of pulsar distances and plot the output of `cw_delay` to compare with the input signal.
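A quick calculation makes the point about wavelengths concrete. For $f_{\rm GW} = 10$ nHz the wavelength is $\lambda = c/f_{\rm GW}$, roughly a parsec, while the distance uncertainty assumed in `pdist = (1.0, 0.2)` is 0.2 kpc. Ignoring frequency evolution, the pulsar-term phase changes by about $2\pi\,\Delta L(1-\cos\mu)/\lambda$, so a 1σ distance shift moves it through many full cycles. The `cos_mu = 0` geometry is an illustrative choice.

```python
import numpy as np

C = 299792458.0                 # speed of light, m/s
PC = 3.0856775814913673e16      # one parsec, m

fgw = 1e-8                      # Hz, same as fgw_in
wavelength_pc = C / fgw / PC    # GW wavelength in parsecs

sigma_L_pc = 0.2 * 1000         # 0.2 kpc distance uncertainty, in parsecs
cos_mu = 0.0                    # pulsar 90 degrees from the source

# Pulsar-term phase shift (in GW cycles) from a 1-sigma distance change,
# neglecting the binary's frequency evolution
dphase_cycles = sigma_L_pc * (1 - cos_mu) / wavelength_pc

print(f"wavelength = {wavelength_pc:.3f} pc")
print(f"1-sigma distance shift = {dphase_cycles:.0f} GW cycles")
```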


In [None]:
#calculate wavelength here


In [None]:
N = 5
dists = np.random.randn(N,n_psrs)*0.2+1
plt.figure(figsize=(25, 3.5*n_psrs))
for i, psr in enumerate(psrs):
    ax = plt.subplot(n_psrs, 3, i+1)
    cw1 = cw_delay(psr.toas, psr.pos, pdist = (1.0, 0.2),
                   cos_gwtheta=np.cos(gwtheta_in), 
                   gwphi=gwphi_in, log10_mc=np.log10(mass_in), log10_dist=np.log10(DL_in), 
                   log10_fgw=np.log10(fgw_in), phase0=phase0_in, psi=psi_in, cos_inc=np.cos(inc_in), 
                  psrTerm = True, evolve = True, tref =tref, p_dist = 0)
    for d, dist in enumerate(dists):
        cwd = cw_delay(psr.toas, psr.pos, pdist = (dist[i], 0.2),
                       cos_gwtheta=np.cos(gwtheta_in), 
                       gwphi=gwphi_in, log10_mc=np.log10(mass_in), log10_dist=np.log10(DL_in), 
                       log10_fgw=np.log10(fgw_in), phase0=phase0_in, psi=psi_in, cos_inc=np.cos(inc_in), 
                      psrTerm = True, evolve = True, tref =tref, p_dist = 0)

        ax.plot((psr.toas*u.second).to(u.day), cwd , color = colors[d], marker = '', alpha = 1/N)


    ax.plot((psr.toas*u.second).to(u.day), cw1 , color = colors[i], marker = '', alpha = 1)
    ax.set_ylim(-1.8e-5, 1.8e-5)
    
# Try choosing more distance values (or even some specific values of the pulsar distance) to see how much 
# of an effect this can have! Since everyone's random pulsars are in different locations, compare with friends too.

## 3.0 Construct CW Signal Search

Now that we have a good understanding of why a CW looks the way it does, let's look for one in our simulated PTA with `enterprise`!

In [None]:
# More imports for GW detection codes
import sys
# Enterprise
# sys.path.append("/home/jovyan/work/shared/enterprise/")
import enterprise.signals.parameter as parameter
from enterprise.signals import utils
from enterprise.signals import signal_base
from enterprise.signals import selections
from enterprise.signals.selections import Selection
from enterprise.signals import white_signals
from enterprise.signals import gp_signals
import corner
from PTMCMCSampler.PTMCMCSampler import PTSampler as ptmcmc
import enterprise_extensions.model_utils as model_utils # renamed so it doesn't shadow enterprise.signals.utils
import enterprise_extensions.models as models

from enterprise_extensions.sampler import get_parameter_groups, JumpProposal

from enterprise_extensions.deterministic import cw_delay, CWSignal


### 3.1 Add pulsars to model

In [None]:
# We've already created our enterprise pulsars, but to do so, remember:
# psrs = []
# for par, tim in zip(pars, tims):
#     psrs.append(Pulsar(par, tim))


### 3.2 Form timing model

In [None]:
# Create white noise parameter priors
efac = parameter.Constant(1)
equad = parameter.Uniform(-8.5,-5.0)

##### Signals below #####

# white noise parameters
ef = white_signals.MeasurementNoise(efac=efac)
eq = white_signals.EquadNoise(log10_equad=equad)

# timing model
tm = gp_signals.TimingModel()

### 3.3 Initialize CW source parameter priors & add to signal model

In [None]:
# continuous GW parameters
# note that we are pre-initializing them with names here so that they will be shared
# across all pulsars in the PTA

# Our standard CW search looks for a GW at a specific frequency, holding log10_fgw constant
# as in the commented-out lines below. Here, however, we will also search over frequency.
#freq = 8e-09
#log10_fgw = parameter.Constant(np.log10(freq))('log10_fgw')


cos_gwtheta = parameter.Uniform(-1, 1)('cos_gwtheta') #position of source
gwphi = parameter.Uniform(0, 2*np.pi)('gwphi') #position of source
log10_mc = parameter.Uniform(7, 10)('log10_mc') #chirp mass of binary
log10_fgw = parameter.Uniform(-9,-7)('log10_fgw') #gw frequency 
phase0 = parameter.Uniform(0, 2*np.pi)('phase0') #gw phase
psi = parameter.Uniform(0, np.pi)('psi') #gw polarization 
cos_inc = parameter.Uniform(-1, 1)('cos_inc') #inclination of binary with respect to Earth 

log10_h = parameter.Uniform(-18, -11)('log10_h') #gw strain amplitude (log-uniform prior; use parameter.LinearExp instead for an upper limit calculation)


psrTerm = False
# define CGW waveform and signal
cw_wf = cw_delay(cos_gwtheta=cos_gwtheta, gwphi=gwphi, log10_mc=log10_mc, 
                 log10_h=log10_h, log10_fgw=log10_fgw, phase0=phase0, 
                 psi=psi, cos_inc=cos_inc, 
                 tref=tref, evolve=True, psrTerm=psrTerm, p_dist=0)
cw = CWSignal(cw_wf, ecc=False, psrTerm=psrTerm)

### 3.4 Define the full model (includes noise + GW signals)

In [None]:
# full signal
s = ef + eq + tm + cw

### 3.5 With model and pulsars, create `enterprise` PTA object

In [None]:
# initialize PTA
model = [s(psr) for psr in psrs]
pta = signal_base.PTA(model)

### 3.6 Prepare the MCMC sampler

In [None]:
# Prepare sampler initial condition
x0 = np.hstack([p.sample() for p in pta.params]) # np.hstack needs a list, not a generator
ndim = len(x0)

# initial jump covariance matrix
cov = np.diag(np.ones(ndim) * 0.1**2)

# parameter groupings
groups = get_parameter_groups(pta)

# define where you want to put the chains from the MCMC
chaindir = DIR+'/chains/'

# set up the parallel-tempering sampler; `groups` sets which parameters are updated together in jumps
sampler = ptmcmc(ndim, pta.get_lnlikelihood, pta.get_lnprior, cov, groups=groups, 
                 outDir=chaindir, resume=False)

# write parameter file for convenience
filename = chaindir + '/params.txt'
np.savetxt(filename,list(map(str, pta.param_names)), fmt='%s')

In [None]:
# add prior draws to the proposal cycle -- this helps prevent the sampler's walkers 
# from getting trapped in local maxima of the posterior
jp = JumpProposal(pta)
sampler.addProposalToCycle(jp.draw_from_prior, 50)
sampler.addProposalToCycle(jp.draw_from_cw_log_uniform_distribution, 10)

#these might be enough for our 3 pulsar PTA, but more might be needed for more complex searches. 
#Which ones would you add? Try adjusting these weights to see what happens!

### 3.7 Sample!

In [None]:
N = 50000
sampler.sample(x0, N, SCAMweight=30, AMweight=15, DEweight=5)

## 4.0 Visualize Results

In [None]:
#input values to compare
noises_in = np.log10(np.ones(n_psrs)*0.5*1e-6) #injected EQUAD, log10(seconds)

if psrTerm:
    # note: used as a placeholder to visualize likelihoods and create truth vector for corner plots.
    # Homework! Can you use what you learned to calculate what this phase should be?
    dists_in = np.zeros(n_psrs)
    phases_in = np.zeros(n_psrs)

psr_correct = []
for n in range(n_psrs):
    if psrTerm:
        psr_correct.append(dists_in[n])
        psr_correct.append(phases_in[n])
    psr_correct.append(noises_in[n])

psr_correct = np.array(psr_correct)
# the order must match pta.param_names (check with print(pta.param_names))
correct = np.hstack((psr_correct, np.array([np.cos(gwtheta_in), 
                   np.cos(inc_in), gwphi_in, np.log10(fgw_in), np.log10(strain.value), 
                   np.log10(mass_in), phase0_in, psi_in])))

In [None]:
# Load the MCMC chains and parameter names (if you need them)

chain = np.loadtxt(chaindir + '/chain_1.txt')
params = list(np.loadtxt(chaindir + '/params.txt', dtype='str'))

# Value of burn-in to apply to the chains.
burn = int(0.25*chain.shape[0])

In [None]:
# Convenience function to help plot the marginalized parameter distributions
def plot_param(name, hist = True, correct = correct):
    '''Given one of the CW parameter names in `params`, plot its distribution as a histogram or as a trace plot. 
    In traces, the burn-in period is shown in red.'''
    
    if hist:
        plt.hist(chain[burn:,params.index(name)], 50, density=True, lw=2, color='C0', histtype='step')
        plt.axvline(correct[params.index(name)])
    else:
        plt.plot(chain[:,params.index(name)], lw=2, color='C0')
        plt.plot(chain[:burn,params.index(name)], lw=2, color='C3')
        plt.axhline(correct[params.index(name)])

    if "log10" in name:
        plt.xlabel(r"$\log_{10} \mathrm{%s}$"%(name.split('_')[-1]))
    else:
        plt.xlabel(name)

In [None]:
# some space to work

In [None]:
plot_param('log10_h', hist = False)

In [None]:
# corner plot of the GW parameters
if psrTerm:
    pars_per_pulsar = 3
else:
    pars_per_pulsar = 1
corner.corner(chain[burn:,n_psrs*pars_per_pulsar:-4], labels=params[n_psrs*pars_per_pulsar:], 
              truths = correct[n_psrs*pars_per_pulsar:]);

In [None]:
#try making a corner plot of the per-pulsar parameters too!

In [None]:
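#collect output values
# index of the maximum-posterior sample after burn-in
# (the last four columns of a PTMCMCSampler chain are ln posterior, ln likelihood, acceptance rate, PT acceptance)
ind_ml = np.argmax(chain[burn:,-4])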
cos_gwtheta_out = chain[burn+ind_ml, params.index('cos_gwtheta')]
cos_inc_out = chain[burn+ind_ml, params.index('cos_inc')]
gwphi_out = chain[burn+ind_ml, params.index('gwphi')]
log10_fgw_out = chain[burn+ind_ml, params.index('log10_fgw')]
log10_h_out = chain[burn+ind_ml, params.index('log10_h')]
log10_mc_out = chain[burn+ind_ml, params.index('log10_mc')]
phase0_out = chain[burn+ind_ml, params.index('phase0')]
psi_out = chain[burn+ind_ml, params.index('psi')]


In [None]:
# invert h0 = 2 Mc^(5/3) (pi f)^(2/3) / d_L for the luminosity distance (with G and c restored)
log10_dl_out = np.log10((2*(const.G*10**log10_mc_out*u.Msun/(const.c**3))**(5./3)*
                         (np.pi*10**log10_fgw_out*u.Hz)**(2./3)*const.c/(10**log10_h_out)).to(u.Mpc).value)
print(10**log10_dl_out) #[Mpc]
#is this close to the input value?
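The cell above inverts the strain relation for the luminosity distance, $d_L = 2\mathcal{M}_c^{5/3}(\pi f_{\rm GW})^{2/3}/h_0$ (with $G = c = 1$). A round trip through the forward and inverse relations is a cheap way to catch a dropped factor, such as the 2. This is a plain-NumPy sketch with hard-coded SI constants; `strain_from_params` and `distance_from_strain` are illustrative names.

```python
import numpy as np

T_SUN = 4.925490947641267e-6    # G * M_sun / c**3, seconds
C = 299792458.0                 # speed of light, m/s
MPC = 3.0856775814913673e22     # one megaparsec, m

def strain_from_params(log10_mc, log10_fgw, dl_mpc):
    """h0 = 2 Mc^(5/3) (pi f)^(2/3) / d_L, with G and c restored."""
    mc_s = 10 ** log10_mc * T_SUN
    return 2 * mc_s ** (5 / 3) * (np.pi * 10 ** log10_fgw) ** (2 / 3) * C / (dl_mpc * MPC)

def distance_from_strain(log10_mc, log10_fgw, log10_h):
    """Invert the strain relation for d_L, returned in Mpc."""
    mc_s = 10 ** log10_mc * T_SUN
    dl_m = 2 * mc_s ** (5 / 3) * (np.pi * 10 ** log10_fgw) ** (2 / 3) * C / 10 ** log10_h
    return dl_m / MPC

h = strain_from_params(np.log10(5e9), -8.0, 15.0)
# should recover the injected 15 Mpc
print(distance_from_strain(np.log10(5e9), -8.0, np.log10(h)))
```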

In [None]:
#needed if using pulsar term
dist_pars = [p for p in params if "dist" in p]
dists_out = [chain[burn+ind_ml, params.index(p)] for p in dist_pars]

phase_pars = [p for p in params if "p_phase" in p]
phases_out = [chain[burn+ind_ml, params.index(p)] for p in phase_pars]

In addition to the MCMC chains themselves, it can be helpful to visualize our results by comparing the signal produced by the injected parameters (solid lines) with the signal produced by the recovered maximum-posterior parameters (dashed lines).

In [None]:
plt.figure(figsize=(25, 3.5*n_psrs))

for i, psr in enumerate(psrs):
    ax = plt.subplot(n_psrs, 3, i+1)
    cw1 = cw_delay(psr.toas, psr.pos, pdist = (1.0, 0.2),
               cos_gwtheta=np.cos(gwtheta_in), 
               gwphi=gwphi_in, log10_mc=np.log10(mass_in), log10_dist=np.log10(DL_in), 
               log10_fgw=np.log10(fgw_in), phase0=phase0_in, psi=psi_in, cos_inc=np.cos(inc_in), 
              psrTerm = True, evolve = True, tref =tref, p_dist = 0)
    
    cw_out = cw_delay(psr.toas, psr.pos, pdist = (1.0, 0.2),
               cos_gwtheta=cos_gwtheta_out, 
               gwphi=gwphi_out, log10_mc=log10_mc_out, log10_h = log10_h_out, 
               log10_fgw=log10_fgw_out, phase0=phase0_out, psi=psi_out, cos_inc=cos_inc_out, 
              psrTerm = False, evolve = True, tref =tref, p_dist = 0) #dists_out[i], p_phase = phases_out[i])

    ax.plot((psr.toas*u.second).to(u.day), cw1 , color = colors[i], marker = '', alpha = 1)
    ax.plot((psr.toas*u.second).to(u.day), cw_out , color = colors[i], marker = '', alpha = 1, ls = '--')

    ax.set_ylim(-1.8e-5, 1.8e-5)

## 5.0 Likelihood Visualization 

Did your CW search work perfectly? If not, it can help to understand why the MCMC returned what it did. The PTA likelihood is complex and VERY difficult for a sampler to explore. Let's take a look!
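To build some intuition for why, consider a toy problem far simpler than the PTA likelihood: a single sinusoid in white noise, sampled every 30 days for 10 years, with only the frequency left free. Even here the log-likelihood as a function of frequency has a narrow main peak surrounded by many side lobes, each a local maximum that can trap a sampler. This is an illustrative sketch, not the `enterprise` likelihood, and every value in it is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
yr = 365.25 * 86400.0
t = np.arange(0.0, 10 * yr, 30 * 86400.0)     # 10 yr at a 30-day cadence, seconds

f_true, amp, sigma = 1e-8, 1.0, 0.5           # toy signal and white-noise level
data = amp * np.sin(2 * np.pi * f_true * t) + sigma * rng.standard_normal(t.size)

def lnlike(f):
    """Gaussian log-likelihood with amplitude and phase fixed at their true values."""
    resid = data - amp * np.sin(2 * np.pi * f * t)
    return -0.5 * np.sum(resid ** 2) / sigma ** 2

fgrid = np.logspace(-9, -7, 2000)
ll = np.array([lnlike(f) for f in fgrid])

# Count interior local maxima along the frequency grid
is_peak = (ll[1:-1] > ll[:-2]) & (ll[1:-1] > ll[2:])
n_peaks = int(np.sum(is_peak))
f_best = fgrid[np.argmax(ll)]

print(f"best-fit f = {f_best:.3e} Hz, local maxima on grid = {n_peaks}")
```

The real CW likelihood has the same kind of structure in frequency, and in several other parameters at once, which is one reason the extra jump proposals above help.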

In [None]:
def like_plotter_1d(xpar, xgrid, pars = pta.param_names, correct = correct, like = pta.get_lnlikelihood):
    #function to visualize the likelihood along one parameter (xpar), evaluated on xgrid,
    #with all other parameters fixed at their injected values
    ind = pars.index(xpar)
    l= []
    c = correct.copy()
    for x in xgrid:
        c[ind] = x
        l.append(like(c))
    plt.plot(xgrid, l)
    plt.xlabel(xpar)
    plt.ylabel('lnlike')
    plt.axvline(correct[ind], color = 'C3', alpha = 0.5)
    i = np.argmax(l)
    plt.axvline(xgrid[i], color = 'C1', alpha = 0.5)
    return l
                


In [None]:
l = like_plotter_1d('log10_fgw', np.linspace(-9, -7, 1000))
#try a few!

In [None]:
def like_plotter_2d(xpar, xgrid, ypar, ygrid, pars = pta.param_names, correct = correct, like = pta.get_lnlikelihood):
    xind = pars.index(xpar)
    yind = pars.index(ypar)
    l= []
    c = correct.copy()
    for x in xgrid:
        for y in ygrid:
            c[xind] = x
            c[yind] = y
            l.append(like(c))
    l = np.array(l)
    l = l.reshape(len(xgrid), len(ygrid))
    fig = plt.figure()
    plt.pcolormesh(xgrid, ygrid, l.transpose(), shading = 'auto', vmin = np.percentile(l, 15), vmax = np.percentile(l, 99.5))
    plt.xlabel(xpar)
    plt.ylabel(ypar)
    plt.colorbar(label = 'lnlike')
    plt.scatter(correct[xind], correct[yind], marker ='o', facecolor ='none', 
                edgecolors = 'C3', s =200, linewidths = 2, label = 'Input')
    i = np.unravel_index(np.argmax(l.transpose()), l.transpose().shape)
    plt.scatter(xgrid[i[1]], ygrid[i[0]] , marker ='o', facecolor ='none', 
                edgecolors = 'C1', s =150, linewidths = 2, label = 'Grid Max')
    plt.legend(loc = 'lower right', bbox_to_anchor=(1.7, 0.8))
    return fig

In [None]:
fig = like_plotter_2d('gwphi', np.linspace(0,2*np.pi, 100), 'cos_gwtheta', np.linspace(-1, 1, 100))

In [None]:
fig = like_plotter_2d('log10_fgw', np.linspace(-9,-7, 1000), 'log10_h', np.linspace(-18, -12, 50))

In [None]:
#try a few!

## 6+ And beyond:

Now that you are an expert, re-run the code in the following ways and see how your answers change:

* Use the pulsar term in the `enterprise` search (necessary code snippet below)

In [None]:
#define the pulsar distance and pulsar phase in your enterprise model, and add to cw_delay
p_phase = parameter.Uniform(0, 2*np.pi)
p_dist = parameter.Normal(0, 1)


* Reduce the simulated mass of the system in `observe()` to make the observed strain even weaker. How well can you do?


* Try only injecting noise into your simulated pulsars. Then, take an upper limit on the GW strain at a constant frequency (code below). This is similar to how current CW searches operate.


In [None]:
from enterprise_extensions.model_utils import bayes_fac, ul
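bf = bayes_fac(chain[burn:, params.index('log10_h')], logAmin=-18, logAmax=-11) #Savage-Dickey Bayes factor; assumes a log-uniform (Uniform) log10_h prior with these bounds
upper = ul(chain[burn:, params.index('log10_h')]) #upper limit; this should be done on a chain run with a LinearExp log10_h prior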

* Constrain some parameters of the system to create a multi-messenger search (such as `cos_gwtheta`, `gwphi`, and `log10_dist`, the distance parameter of `cw_delay`, which can be used instead of `log10_h` if the distance is known)


* Add some extra jump proposals to your MCMC to help the parameters move around. Try adjusting the weights of the jump proposals to see how they change your MCMC results. (examples below)


In [None]:
sampler.addProposalToCycle(jp.draw_from_par_prior(['log10_mc']), 5) 
#this jump takes a list of parameters and draws new values for only those parameters from their priors


* Add red noise into the `observe()` function using `LT.add_rednoise(psr,A,gamma)`. $A$ is in GW strain amplitude units, and the spectral index $\gamma$ sets the steepness of the spectrum; it is typically somewhere between 1 and 5. How does this affect the results? What new correlations do you see amongst the parameters? 


Remember: try adding more pulsars into your array if you need some additional signal boost! However, you might need more samples in each of these tasks, especially if you have added pulsars. 