In [None]:
# Widen the notebook container via injected CSS (targets the `.container` class).
# NOTE(review): this class exists in the classic Notebook UI; it probably has no
# effect in JupyterLab / Notebook 7.
# `display` is imported from IPython.display: importing it from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:99% !important; }</style>"))

In [None]:
%matplotlib inline
#%config InlineBackend.figure_format = 'retina'
%load_ext autoreload

%autoreload 2

import numpy as np

import matplotlib
matplotlib.rcParams.update({'font.size': 19})

import matplotlib.pyplot as plt
import os, glob, json
import scipy.linalg as sl
import enterprise
from enterprise.pulsar import Pulsar
import enterprise.signals.parameter as parameter
from enterprise.signals import utils
from enterprise.signals import signal_base
from enterprise.signals import selections
from enterprise.signals.selections import Selection
from enterprise.signals import white_signals
from enterprise.signals import gp_signals
from enterprise.signals import deterministic_signals
import enterprise.constants as const

import enterprise_extensions
from enterprise_extensions import blocks

from QuickBurst import QuickBurst_MCMC as QB_MCMC

import healpy as hp

import libstempo as T2
import libstempo.toasim as LT
import libstempo.plot as LP
import re

import pickle

from QuickBurst import tau_scans_pta

# Setup

In [None]:
parDir = '/home/user/.../Pars/'
timDir =  "/home/user/.../Tims/"

parfiles = sorted(glob.glob(parDir + '*.par'))
timfiles = sorted(glob.glob(timDir + "*.tim"))

# .par and .tim files are paired purely by sorted order, so both directories must
# hold exactly one file per pulsar. zip() would otherwise silently drop the extra
# files and misalign every subsequent pair.
if len(parfiles) != len(timfiles):
    raise ValueError(
        "Found {} .par files but {} .tim files; cannot pair them by sorted order".format(
            len(parfiles), len(timfiles)))

psrs = [Pulsar(p, t, ephem=None, clk=None) for p, t in zip(parfiles, timfiles)]

In [None]:
# Optionally save the generated pulsars as a pickle for future use.
# (Written as a comment rather than a bare string so the cell shows no output.)
# pkl_path = "/home/user/filepath.../"
# os.makedirs(pkl_path,exist_ok=True)
# with open(pkl_path + "Psrs_pkl.pkl", 'wb') as f: #_old_pars
#     pickle.dump(psrs, f)

In [None]:
# Load in pulsar pickle files: a previously pickled list of Pulsar objects.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
pkl_path = "/home/reyna/15yr-v1.1/Data/Simple_test_data/20_pulsars/WN_only/Psrs.pkl"
with open(pkl_path , 'rb') as pkl_file:
    psrs = pickle.load(pkl_file)

In [None]:
### Visualize data
# One figure per pulsar: timing residuals vs. time since the earliest TOA in the
# whole array. TOAs and residuals are stored in seconds. Residuals are scaled to
# microseconds so the values agree with the axis label.
RESID_ERR_US = 0.5  # constant error bar in μs; NOTE(review): psr.toaerrs*1e6 would give per-TOA errors
min_toa = np.min([p.toas.min() for p in psrs])
for psr in psrs:
    fig, ax = plt.subplots()
    ax.errorbar((psr.toas - min_toa) / 86400 / 365,
                psr.residuals * 1e6,
                yerr=RESID_ERR_US,
                markersize=8, ls='', marker='x', alpha=0.5)
    ax.set_xlabel('Time since first TOA [years]')
    ax.set_ylabel(r'Residuals [$\mu$s]')
    ax.set_title('Pulsar {}'.format(psr.name))
    plt.show()

In [None]:
#Quantifying data gaps in pulsars
# For each pulsar, record every pair of consecutive TOAs separated by more than
# GAP_THRESHOLD_DAYS. Result: one list per pulsar of [toa_before_gap, toa_after_gap]
# pairs, in the same units as psr.toas (seconds).
GAP_THRESHOLD_DAYS = 110
data_gaps = []
for psr in psrs:
    toas = np.asarray(psr.toas)
    # np.diff compares each TOA only with its predecessor, so the first TOA is
    # never compared against the last one (no wrap-around at index 0).
    gap_idx = np.flatnonzero(np.diff(toas) / 3600 / 24 > GAP_THRESHOLD_DAYS)
    data_gaps.append([[toas[k], toas[k + 1]] for k in gap_idx])
        

In [None]:
# Number of data gaps found in each pulsar. A fixed index such as data_gaps[20]
# raises IndexError when the array has 20 or fewer pulsars.
{psr.name: len(gaps) for psr, gaps in zip(psrs, data_gaps)}

In [None]:
# Fractions of a year (0.2 and 0.38) expressed in days.
print(365*0.2, 365*0.38)

In [None]:
# Print every data gap as (start, end), converting TOAs from seconds to days.
for psr_gaps in data_gaps:
    for gap_start, gap_end in psr_gaps:
        print(gap_start/3600/24, gap_end/3600/24)

## Generate a shortened list of pulsars (Tspan > 10 years)

In [None]:
#Truncate PTA based on Tspan > 10 years
# Print the name and span (years) of each kept pulsar, then the number kept.
# NOTE: `psrs` itself is not modified; keep_list is only used by the
# (commented-out) save cell below.
keep_list = []
for pulsar in psrs:
    span_yr = (pulsar.toas.max() - pulsar.toas.min())/86400/365
    if not span_yr > 10:
        continue
    print(pulsar.name)
    print(span_yr)
    keep_list.append(pulsar)
print(len(keep_list))

In [None]:
# Optionally save the >10-year subset for future use.
# NOTE: `pkl_path` (set in the pickle-loading cell) points at a *file*, so save
# next to it instead of treating it as a directory.
# pkl_dir = os.path.dirname(pkl_path)
# os.makedirs(pkl_dir, exist_ok=True)
# with open(os.path.join(pkl_dir, "Psrs_pkl_10yr.pkl"), 'wb') as f:
#     pickle.dump(keep_list, f)

In [None]:
# Optionally reload the >10-year subset (replaces `psrs`). The path must be
# absolute and the filename must match the one used when saving.
# with open("/home/user/filepath.../Psrs_pkl_10yr.pkl", 'rb') as f:
#     psrs = pickle.load(f)

In [None]:
#plotting for more complicated dataset pulsars
# Residuals (μs) vs. TOA (s) for each pulsar, coloured by receiver front end
# ('fe' flag). Assumes psr.flags maps flag name -> per-TOA array — TODO confirm.
ng_frontends=['327', '430', 'Rcvr_800', 'Rcvr1_2', 'L-wide', 'S-wide', '1.5GHz', '3GHz']

# Shared x-limits across all pulsars. They are computed here so this cell does
# not depend on `tref`/`maximum`, which are only defined in a later cell.
toa_lo = min((p.toas.min() for p in psrs), default=None)
toa_hi = max((p.toas.max() for p in psrs), default=None)

for psr in psrs:
    if 'fe' not in psr.flags:
        # Simple/simulated datasets may carry no front-end flag at all.
        print('{}: no "fe" flag, skipping'.format(psr.name))
        continue
    fe_flags = np.asarray(psr.flags['fe'])
    fig, ax = plt.subplots()
    for fe in ng_frontends:
        mk = fe_flags == fe
        if not mk.any():
            continue
        ax.errorbar(psr.toas[mk],
                    psr.residuals[mk]*1e6,
                    yerr=0.0,
                    markersize=8, ls='', marker='x', label=fe, alpha=0.5)
    ax.set_xlim(toa_lo, toa_hi)
    ax.set_title(psr.name)
    ax.set_xlabel('TOA [s]')
    ax.set_ylabel(r'Residuals [$\mu$s]')
    # Only draw a legend when at least one front end was plotted.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    plt.show()

In [None]:
# Load the noise dictionary (parameter name -> value) for the dataset.
noise_file = "/home/reyna/15yr-v1.1/Data/Simple_test_data/20_pulsars/WN_only/noise_dict.json"
with open(noise_file) as noise_fh:
    noise_params = json.load(noise_fh)

### RUN IF INCLUDING CURN
# gw_gamma = 13/3 
# gw_amp = -14.6
# noise_params['gw_crn_gamma'] = gw_gamma
# noise_params['gw_crn_log10_A'] = gw_amp

In [None]:
# Inspect the noise dictionary as loaded from the JSON file.
noise_params

In [None]:
#Edit equad to have keys including 't2equad'
#only needed for converting 12.5yr data to current enterprise structure
# Keys ending in "_equad" (e.g. "..._log10_equad") become "..._log10_t2equad".
# Only the "_equad" suffix is matched. Keys that already end in "t2equad" or
# "tnequad" are left untouched; a plain substring test would rename them to
# "..._t2t2equad". Because of this, re-running the cell is harmless.
noise_dict = {
    (k[:-len('equad')] + 't2equad' if k.endswith('_equad') else k): v
    for k, v in noise_params.items()
}
noise_params = noise_dict

In [None]:
#check noise parameters are what they should be
# (shows the dictionary after the EQUAD key conversion in the previous cell)
noise_params

# Individual glitch tau scans

In [None]:
# Find the time span of the whole array: the latest and earliest TOA (seconds)
# over all pulsars. Seeding the builtin max/min with 0 and inf reproduces a
# running max/min that starts from those values.
maximum = max([0, *(p.toas.max() for p in psrs)])
minimum = min([np.inf, *(p.toas.min() for p in psrs)])

# Reference time: later times are measured from the earliest TOA.
tref = minimum

# Total span of the data set in years.
t0_max = (maximum - minimum)/365/24/3600
print(t0_max)

In [None]:
# Search ranges for the tau scans.

# Frequency range (Hz)
f_min, f_max = 3.5e-9, 1e-7

# Central-time range (years, measured from tref)
t_min, t_max = 0.0, t0_max

# Tau range (envelope width, years)
tau_min, tau_max = 0.05, 5.0

In [None]:
# Frequency (Hz) of a wave whose period equals the widest envelope, tau_max
# (converted from years to seconds first).
tau_max_seconds = tau_max*365*24*3600
1/tau_max_seconds

In [None]:
TS_all = {} # per-pulsar tau scans, keyed 'tau_scan<i>', plus the shared grid edges
for i, psr in enumerate(psrs):
    # Build a single-pulsar PTA: white noise fixed from noise_params, no red noise.
    pta_out = QB_MCMC.get_pta([psr], vary_white_noise=False, include_equad=True,
                              include_ecorr = False, include_efac = True, 
                              wn_backend_selection=False, noisedict=noise_params, include_rn=False,
                              vary_rn=False, include_per_psr_rn=False, vary_per_psr_rn=False,
                              max_n_wavelet=5, efac_start = None, rn_amp_prior='log-uniform',
                              rn_log_amp_range=[-18,-11], per_psr_rn_amp_prior='log-uniform',
                              per_psr_rn_log_amp_range=[-18,-11], wavelet_amp_prior='uniform',
                              wavelet_log_amp_range=[-10.0,-5], prior_recovery=False,
                              max_n_glitch=1, glitch_amp_prior='uniform', glitch_log_amp_range=[-10.0,-5],
                              t0_min=0.0, t0_max=t0_max, f0_min=f_min, f0_max=f_max,
                              TF_prior=None, tref=tref)

    # Tau scan for this pulsar alone (noise transients). pta_out[0] is passed as
    # the PTA object.
    TauScan = tau_scans_pta.TauScan([psr], params=noise_params, pta=pta_out[0])
    TS_dict = tau_scans_pta.make_tau_scan_map(TauScan, f_min=f_min, f_max=f_max, t_min=t_min, t_max=t_max,
                                              tau_min=tau_min, tau_max=tau_max)

    # Every pulsar is scanned with the same f/t/tau ranges, so the grid edges are
    # presumably identical each iteration (TODO confirm). They are overwritten,
    # leaving the last pulsar's edges in TS_all.
    TS_all['tau_edges'] = TS_dict['tau_edges']
    TS_all['t0_edges'] = TS_dict['t0_edges']
    TS_all['f0_edges'] = TS_dict['f0_edges']
    TS_all['tau_scan'+str(i)] = TS_dict['tau_scan']

## Saving individual glitch tau scans

In [None]:
# Persist the per-pulsar (noise-transient) tau scans computed above. The target
# directory is created first so the first run does not fail with FileNotFoundError.
ts_all_path = "/home/reyna/15yr-v1.1/Data/Simple_test_data/20_pulsars/WN_only/Tau_scans/noise_transient.pkl"
os.makedirs(os.path.dirname(ts_all_path), exist_ok=True)
with open(ts_all_path, 'wb') as f:
    pickle.dump(TS_all, f)

In [None]:
# Reload per-pulsar tau scans from disk (replaces the TS_all computed above).
# NOTE(review): this reads a different file than the one the previous cell
# saved; confirm which dataset is intended.
import warnings

with open("/home/reyna/15yr-v1.1/Script/QuickBurst/data/glitch_tau_scan_SNR99p.pkl", 'rb') as f:
    TS_all = pickle.load(f)

# The plotting/stitching cells index 'tau_scan0' .. 'tau_scan<len(psrs)-1>', so
# flag a file whose number of per-pulsar scans does not match the loaded pulsars.
n_scans = sum(1 for key in TS_all if key.startswith('tau_scan'))
if n_scans != len(psrs):
    warnings.warn("Loaded {} per-pulsar tau scans but {} pulsars are loaded; "
                  "the scans may belong to a different dataset.".format(n_scans, len(psrs)))

## Plotting glitch tau scans

In [None]:
# Data gaps of the first pulsar: list of [start, end] TOA pairs (seconds).
data_gaps[0]

In [None]:
# Start of the first data gap of pulsar 0, converted from seconds to days
# (None when that pulsar has no gaps instead of raising IndexError).
data_gaps[0][0][0]/24/3600 if data_gaps[0] else None

# Plot noise transient tau scans with data gaps

In [None]:
import random
# Per-pulsar tau scans (sqrt of the scan value), one figure per tau bin, with
# each pulsar's data gaps marked as vertical line pairs.
tau_edges = TS_all['tau_edges']
T0_list = TS_all['t0_edges']
F0_list = TS_all['f0_edges']
# Deterministic gap colours (cycled) so re-running gives identical figures.
gap_colors = plt.cm.tab10.colors

for i, psr in enumerate(psrs):
    print(i)
    tau_scan = TS_all['tau_scan'+str(i)]
    scan_max = max([np.nanmax(x) for x in tau_scan])
    print(scan_max)
    # Gap boundaries in years since `minimum` (earliest TOA, seconds).
    # Assumes the t0 edges are also measured from tref == minimum — TODO confirm.
    gaps_yr = [((start - minimum)/3600/24/365, (end - minimum)/3600/24/365)
               for start, end in data_gaps[i]]
    # os.makedirs("/home/reyna/15yr-v1.1/Script/QuickBurst/data/Tau_scan_plots/99p_SNR/CURN_included/glitch_tau_scans/pngs/{0}/".format(psrs[i].name), exist_ok=True)
    for l in range(tau_edges.size-1):
        fig, ax = plt.subplots()
        c = ax.pcolormesh(T0_list[l]/24/3600/365, F0_list[l]/1e-9, np.sqrt(tau_scan[l]),
                          vmax=np.sqrt(scan_max), vmin=0.0)
        for j, (gap_start, gap_end) in enumerate(gaps_yr):
            color = gap_colors[j % len(gap_colors)]
            ax.axvline(gap_start, label='Data gap {}'.format(j), color=color)
            ax.axvline(gap_end, color=color)
        fig.colorbar(c, ax=ax)
        ax.set_title("PSR{2:s} -- tau = {0:.2f} - {1:.2f} years".format(tau_edges[l], tau_edges[l+1], psr.name), size=15)
        # The t0 axis was converted to years above, so label it accordingly.
        ax.set_xlabel("t [years]")
        ax.set_ylabel("f [nHz]")
        # Only draw a legend when there are gap lines to label.
        if gaps_yr:
            legend = ax.legend()
            legend.get_frame().set_alpha(0.5)
        fig.tight_layout()
        # fig.savefig("/home/reyna/15yr-v1.1/Script/QuickBurst/data/Tau_scan_plots/99p_SNR/CURN_included/glitch_tau_scans/pngs/{0}/tau_bin_{1:.2f}-{2:.2f}_years.png".format(psrs[i].name, tau_edges[l], tau_edges[l+1]), dpi = 600)
        plt.show()

# Plot noise transient tau scans for all pulsars (without data-gap markers)

In [None]:
# Per-pulsar tau scans (sqrt of the scan value), one figure per tau bin,
# with the central-time axis in days.
tau_edges = TS_all['tau_edges']
T0_list = TS_all['t0_edges']
F0_list = TS_all['f0_edges']

for i, psr in enumerate(psrs):
    print(i)
    tau_scan = TS_all['tau_scan'+str(i)]
    scan_max = max([np.nanmax(x) for x in tau_scan])
    print(scan_max)
    for l in range(tau_edges.size-1):
        # A fresh figure per panel avoids reusing figure numbers from other cells.
        fig, ax = plt.subplots()
        c = ax.pcolormesh(T0_list[l]/24/3600, F0_list[l]/1e-9, np.sqrt(tau_scan[l]),
                          vmax=np.sqrt(scan_max), vmin=0.0)
        fig.colorbar(c, ax=ax)
        # Title names the pulsar, consistent with the data-gap plots.
        ax.set_title("PSR{2:s} -- tau = {0:.2f} - {1:.2f} years".format(tau_edges[l], tau_edges[l+1], psr.name))
        ax.set_xlabel("t [days]")
        ax.set_ylabel("f [nHz]")
        plt.show()

In [None]:
# Re-display the last figure created by the preceding plotting cell.
fig

# Stitching together individual glitch tau scans and saving combined tau scan

In [None]:
# Stitch the per-pulsar tau scans into one PTA-wide ("wavelet") tau scan by
# summing them bin by bin. The grid edges are shared by all pulsars, so they
# are copied once from TS_all.
TS_dict = {} #holds the wavelet tau scans (combination of all pulsars)
TS_dict['tau_edges'] = TS_all['tau_edges']
TS_dict['f0_edges'] = TS_all['f0_edges']
TS_dict['t0_edges'] = TS_all['t0_edges']

stitched_scan = None
for i in range(len(psrs)):
    print('Pulsar {}'.format(i))
    # Per-tau-bin, NaN-free copy of this pulsar's scan. np.where returns a new
    # array, so accumulating into it never modifies the arrays stored in TS_all.
    # NaNs are zeroed for *every* pulsar (including the first), so a NaN in one
    # pulsar cannot turn the PTA sum into NaN.
    glitch_tau_scan_data = [np.where(np.isnan(ts), 0.0, ts)
                            for ts in TS_all['tau_scan{}'.format(i)]]
    if stitched_scan is None:
        stitched_scan = glitch_tau_scan_data
    else:
        for j in range(len(stitched_scan)):
            stitched_scan[j] += glitch_tau_scan_data[j]
TS_dict['tau_scan'] = stitched_scan

## Saving the stitched-together tau scans (i.e. wavelet tau scans)

In [None]:
# Uncomment to save the stitched (PTA-wide wavelet) tau scans held in TS_dict.
# NOTE(review): the next cell reloads TS_dict from a different file, replacing
# the scans stitched above.
# with open("/home/reyna/15yr-v1.1/Data/Simple_test_data/20_pulsars/WN_only/Tau_scans/GW_wavelet.pkl", 'wb') as f:
#     pickle.dump(TS_dict, f)

In [None]:
# Reload PTA-wide (wavelet) tau scans from disk; this replaces the TS_dict built above.
wavelet_ts_path = "/home/reyna/15yr-v1.1/Script/QuickBurst/data/wavelet_tau_scan_SNR99p.pkl"
with open(wavelet_ts_path, 'rb') as wavelet_fh:
    TS_dict = pickle.load(wavelet_fh)

## Plotting wavelet tau scan

In [None]:
# PTA-wide (wavelet) tau scan (sqrt of the scan value), one figure per tau bin.
wavelet_tau_edges = TS_dict['tau_edges']
wavelet_T0_list = TS_dict['t0_edges']
wavelet_F0_list = TS_dict['f0_edges']
wavelet_tau_scan = TS_dict['tau_scan']

wavelet_scan_max = max([np.nanmax(x) for x in wavelet_tau_scan])
print(wavelet_scan_max)
for l in range(wavelet_tau_edges.size-1):
    # A fresh figure per panel replaces the magic figure-number offset (i = 29)
    # that was used to avoid colliding with the per-pulsar figures.
    fig, ax = plt.subplots()
    c = ax.pcolormesh(wavelet_T0_list[l]/24/3600/365, wavelet_F0_list[l]/1e-9, np.sqrt(wavelet_tau_scan[l]),
                      vmax=np.sqrt(wavelet_scan_max), vmin=0.0)
    fig.colorbar(c, ax=ax)
    ax.set_title("{2} -- tau = {0:.2f} - {1:.2f} years".format(wavelet_tau_edges[l], wavelet_tau_edges[l+1], "PTA tau scan"), size = 15)
    ax.set_xlabel("t [years]")
    ax.set_ylabel("f [nHz]")
    # fig.savefig("/home/reyna/15yr-v1.1/Script/QuickBurst/data/Tau_scan_plots/99p_SNR/CURN_included/wavelet_tau_scans/pngs/tau_bin_{0:.2f}-{1:.2f}_years.png".format(wavelet_tau_edges[l], wavelet_tau_edges[l+1]))
    plt.show()