In [None]:
import os
import sys

try:
    import ipywidgets
except:
    !{sys.executable} -m pip install ipywidgets
    
from ipywidgets import Button, HBox #<----- Add HBox for displaying multiple buttons
import pandas as pd
pd.set_option('display.max_columns', 51)
pd.set_option('display.max_rows', 5)
pd.options.mode.chained_assignment = None  # default='warn'
from glob import glob
import requests
import itertools
import numpy as np
import matplotlib.pyplot as plt
import re
import math
import importlib
from bokeh.io import output_notebook
from bokeh.plotting import show
from io import BytesIO

from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.coordinates import Angle, SkyCoord
from astropy import units as u
from mocpy import World2ScreenMPL, MOC
import healpy as hp
from ligo.skymap.io import read_sky_map
from ligo.skymap.postprocess import find_greedy_credible_levels

from vasttools.pipeline import Pipeline
from vasttools.moc import VASTMOCS
from vasttools.tools import add_credible_levels, find_in_moc, skymap2moc
import annotator

from IPython.display import display
%matplotlib inline
output_notebook()

!jupyter nbextension enable --py widgetsnbextension

In [None]:
# Load the 'combined' pipeline run and keep only the sources that satisfy
# every cut in the query string below.
pipe = Pipeline()
piperun = pipe.load_run('combined')
# Adjacent string literals are concatenated into one expression that is
# passed to `piperun.sources.query`; every fragment after the first starts
# with '&', so all cuts are ANDed together.
query = (
    "n_measurements > 1 "
    "& n_neighbour_dist > 1/60. "
    "& avg_compactness < 1.5 "
    "& n_relations == 0 "
    "& n_siblings == 0"
    "& n_selavy >= 2"
    "& avg_flux_int/avg_flux_peak < 1.5" 
    "& v_peak >= 0"
    "& max_snr > 7")
sources = piperun.sources.query(query)
display(sources)

In [None]:
# Run the eta-V analysis restricted to the filtered `sources`. The call
# returns two cutoffs, `interest` (whose index is used by later cells to
# select sources) and a plot object.
# NOTE(review): the two leading 2s look like sigma thresholds for eta and V
# respectively — confirm against the vasttools documentation.
eta_cutoff, v_cutoff, interest, plot = piperun.run_eta_v_analysis(2, 2, df=sources)
print(eta_cutoff, v_cutoff)

In [None]:
# Build a table holding the earliest non-forced measurement time per source.
meas = piperun.measurements
# Keep only the rows whose 'forced' flag is False.
meas = meas[meas['forced'] == False]
grouped_meas = meas.groupby('source')
# One entry per source: the minimum of its 'time' values.
min_time_df = grouped_meas.agg({'time': 'min'})
display(min_time_df)

In [None]:
def name_to_time(name):
    """Convert a GW event name into the timestamp encoded in it.

    Parameters
    ----------
    name : str
        Event name of the form ``GWyymmdd`` or ``GWyymmdd_hhmmss``.

    Returns
    -------
    numpy.datetime64
        The date (and time of day, when the name carries one) parsed from
        the digits that follow the leading ``G``/``W`` characters.
    """
    parts = name.lstrip('GW').split('_')
    if len(parts) == 1:
        # Date-only name: six digits, two-digit year.
        stamp, fmt = parts[0], '%y%m%d'
    else:
        # Date and time of day are separated by '_' in the name.
        stamp, fmt = parts[0] + parts[1], '%y%m%d%H%M%S'
    return pd.to_datetime(stamp, format=fmt).to_datetime64()

In [None]:
def show_interesting_sources(interest, event_time=0, save_fig=False):
    """Print, plot and show cutouts for every source indexed by ``interest``.

    Parameters
    ----------
    interest : pandas.DataFrame
        Frame whose index holds the source ids to look up in the
        module-level ``piperun``.
    event_time : optional
        Value passed through ``pd.Timestamp`` and used as ``start_date`` of
        each lightcurve. Defaults to 0.
    save_fig : bool, optional
        Forwarded as ``save`` to ``plot_lightcurve``.

    Returns
    -------
    None
    """
    for source_id in interest.index:
        print(source_id)
        vast_source = piperun.get_source(source_id)
        #display(vast_source.ned_search())
        lightcurve = vast_source.plot_lightcurve(
            start_date=pd.Timestamp(event_time),
            save=save_fig,
        )
        # Log-log axes with the y axis clipped at 0.1.
        axes = lightcurve.gca()
        axes.set_xscale('log')
        axes.set_yscale('log')
        axes.set_ylim(bottom=0.1)
        vast_source.show_all_png_cutouts(size=Angle(2*u.arcmin))
        # Render what was drawn for this source, then release every figure.
        plt.show()
        plt.close(fig='all')

In [None]:
# Read the VAST footprint HEALPix map (NESTED ordering) and derive its
# resolution; the event-search cell resamples each GW skymap to this NSIDE.
footprint_file = 'full_VAST_footprint_1024.fits'
coverage_map = hp.read_map(footprint_file, nest=True)
coverage_NPIX= len(coverage_map)
coverage_NSIDE = hp.npix2nside(coverage_NPIX)

# Credible-level threshold: skymap pixels below it define each event region.
cutoff = 0.9
coverages = pd.read_csv('coverages.csv')
# glob() returns paths in arbitrary, filesystem-dependent order. Sort each
# catalogue's listing so events are processed in the same order on every
# run: `candidates` keeps the first occurrence of each source id, so that
# order decides which event a source ends up attributed to.
files = (
    sorted(glob('./GWTC1/*'))
    + sorted(glob('./GWTC2/*PublicationSamples.fits'))
    + sorted(glob('./GWTC3/*Mixed.fits'))
)

In [None]:
%matplotlib inline

# Build (once) the table of candidate sources per GW event and cache it as
# 'gw_event_sources.csv'; on later runs only the cached CSV is read.
if not os.path.isfile('gw_event_sources.csv'):
    # Collect one frame per event and concatenate once at the end:
    # DataFrame.append was removed in pandas 2.0, and appending inside the
    # loop re-copies the accumulated frame on every iteration.
    event_frames = []
    for file_idx, f in enumerate(files):
        event_name = re.findall("(GW[0-9]{6}_[0-9]{6}|GW[0-9]{6})", f)[0]
        print(f'{event_name}: {file_idx+1}/{len(files)}')
        # Skip events whose recorded 90% coverage is zero.
        if coverages[coverages['Event'] == event_name]['90_percent_coverage'].values[0] == 0:
            continue

        event_time = name_to_time(event_name)
        skymap, history = read_sky_map(f, nest=True)
        # Resample the skymap to the resolution of the footprint map.
        skymap = hp.ud_grade(skymap, coverage_NSIDE, order_in='NESTED', power=-2)
        nside = hp.get_nside(skymap)
        level = np.log2(nside)
        credible_levels = find_greedy_credible_levels(skymap)
        # Pixels inside the `cutoff` credible region. Kept under its own
        # name so the enumerate() counter above is not shadowed.
        pixel_idx = np.where(credible_levels < cutoff)[0]
        levels = np.ones(len(pixel_idx)) * level
        moc = MOC.from_healpix_cells(pixel_idx, depth=levels)
        sources_in_moc = sources.iloc[find_in_moc(moc, sources)]

    #     theta = 0.5 * np.pi - np.deg2rad(sources_in_moc['wavg_dec'].values)
    #     phi = np.deg2rad(sources_in_moc['wavg_ra'].values)
    #     ipix = hp.ang2pix(nside, theta, phi, nest=True)
    #     sources_in_moc.loc[:, 'credible_level'] = credible_levels[ipix]
    #     sources_in_moc = sources_in_moc.sort_values('credible_level', ascending=True)

        # Sources whose earliest measurement is at or after the event time.
        # NOTE(review): this relies on `min_time_df` exposing `source` as an
        # attribute; for a pandas groupby/agg result 'source' would be the
        # index instead — confirm the frame type returned by vasttools.
        sources_after_event = min_time_df[min_time_df['time'] >= event_time].source.values
        possible_sources = sources_in_moc[
            (sources_in_moc.index.isin(sources_after_event))
            & (sources_in_moc.index.isin(interest.index.values))]

        if not possible_sources.empty:
            # Explicit copy so the new columns are not written to a slice.
            event_df = possible_sources[
                ['wavg_ra', 'wavg_dec', 'v_peak', 'eta_peak', 'min_snr', 'max_snr']
            ].copy()
            event_df['link'] = [
                "https://dev.pipeline.vast-survey.org/sources/" + str(source_id)
                for source_id in event_df.index
            ]
            event_df['event'] = event_name
            event_frames.append(event_df)
    #     show_interesting_sources(possible_sources, event_time)
    # With no matches at all, fall back to an empty frame, as the original did.
    full_df = pd.concat(event_frames) if event_frames else pd.DataFrame()
    full_df.to_csv('gw_event_sources.csv', index_label='id')
full_df = pd.read_csv('gw_event_sources.csv', index_col='id')
full_df

In [None]:
mocs = VASTMOCS()
def plot_sources(sources):
    """Scatter source positions over the pilot epoch '1' MOC on an all-sky map.

    Parameters
    ----------
    sources : pandas.DataFrame
        Must provide 'wavg_ra' and 'wavg_dec' columns; both are multiplied
        by ``u.deg`` before plotting.

    Returns
    -------
    None
        The figure is shown and then closed.
    """
    canvas = plt.figure(figsize=(24,12))

    pilot_moc = mocs.load_pilot_epoch_moc('1')

    projection_ctx = World2ScreenMPL(
        canvas,
        fov=320 * u.deg,
        center=SkyCoord(0, 0, unit='deg', frame='icrs'),
        coordsys="icrs",
        rotation=Angle(0, u.degree),
    )
    with projection_ctx as projection:
        sky_ax = canvas.add_subplot(111, projection=projection, frame_class=EllipticalFrame)
        sky_ax.set_title("Event Sources")
        sky_ax.grid(color="black", linestyle="dotted")
        # Footprint: filled region first, then its outline.
        pilot_moc.fill(ax=sky_ax, wcs=projection, alpha=0.5, fill=True, linewidth=0, color="#00bb00")
        pilot_moc.border(ax=sky_ax, wcs=projection, alpha=0.5, color="black")
        # zorder=10 is passed so the points sit above the footprint patches.
        ra_deg = sources['wavg_ra']*u.deg
        dec_deg = sources['wavg_dec']*u.deg
        sky_ax.scatter(
            ra_deg,
            dec_deg,
            transform=sky_ax.get_transform('world'),
            zorder=10,
            s=2,
        )
    plt.show()
    plt.close()

# AMON

In [None]:
# %matplotlib inline
# for f in glob(f"AMON/*.fits")[:]:
#     print(f)
#     moc = MOC.from_fits(f)
#     idx = find_in_moc(moc, sources, pipe=True)
#     sources_in_moc = sources.iloc[idx]
#     display(sources_in_moc)
#     show_interesting_sources(sources_in_moc)

In [None]:
# Keep a single row per source id: the first occurrence in `full_df` wins,
# later rows with the same index are dropped.
first_occurrence = ~full_df.index.duplicated(keep='first')
candidates = full_df[first_occurrence]
candidates

In [None]:
%matplotlib inline

## Create pngs
create_pngs = True

out_dir = 'gw_search'
source_df = candidates

# The existence of `out_dir` doubles as the "already generated" marker, so
# the loop only runs when the directory is missing.
if not os.path.isdir(out_dir):
    if create_pngs:
        # Nothing else in this cell creates the directory, so the first
        # save would fail without it.
        os.makedirs(out_dir)
    for idx, (i, source) in enumerate(source_df.iterrows()):
        print(f'{idx+1}/{source_df.shape[0]}')
        source_name = annotator.get_source_name(source).replace(' ','_')
        candidate_source = piperun.get_source(i)
        event_time = name_to_time(source.event)
        if create_pngs:
            # Small lightcurve thumbnail starting at the event time.
            candidate_source.plot_lightcurve(
                save=True,
                start_date=pd.Timestamp(event_time),
                outfile='{}/{}_lc.png'.format(out_dir,source_name),
                figsize=(4,2),
                hide_legend=True,
                plot_dpi=100,
            )

    #         num_columns = 4
    #         cutout_width = 8
    #         row_height = 2
    #         num_rows = math.ceil(len(candidate_source.measurements)/num_columns)

    #         num_rows = 2
    #         row_height = 2
    #         column_width = 2
    #         cutout_height = num_rows*row_height
    #         num_columns = math.ceil(len(candidate_source.measurements)/num_rows)
    #         cutout_width = num_columns*column_width
    #         cutout_height = row_height*num_rows

    #         candidate_source.show_all_png_cutouts(columns=num_columns, figsize=(cutout_width, cutout_height), hide_epoch_labels=True, force=True)
    #         plt.tight_layout()
    #         plt.savefig('{}/{}_cutouts.png'.format(out_dir,source_name), dpi=100)
    #         plt.close()
    #         plt.clf()

    #         candidate_source.show_all_png_cutouts(columns=num_columns, hide_epoch_labels=True, size=Angle(1*u.arcmin), figsize=(cutout_width, cutout_height), force=True)
    #         plt.tight_layout()
    #         plt.savefig('{}/{}_cutouts_zoom.png'.format(out_dir,source_name), dpi=100)
    #         plt.close()
    #         plt.clf()

In [None]:
def save_annotations(annotations, fname):
    """Write annotation records to CSV, indexed by item.

    Parameters
    ----------
    annotations
        Anything ``pd.DataFrame`` accepts that yields 'item', 'reason' and
        'label' columns (for example a list of dicts).
    fname
        Path or buffer handed to ``DataFrame.to_csv``.

    Returns
    -------
    None
        Only 'reason' and 'label' are written, with 'item' as the index.
    """
    table = pd.DataFrame(annotations)
    table = table.set_index('item')
    table[['reason', 'label']].to_csv(fname)
    
def load_saved_annotations(fnames):
    """Read several annotation CSVs and stack them into one frame.

    Parameters
    ----------
    fnames : iterable
        Paths or buffers, each read with 'item' as the index column.

    Returns
    -------
    pandas.DataFrame
        The frames concatenated in the order given.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` when ``fnames`` is empty.
    """
    return pd.concat([pd.read_csv(path, index_col='item') for path in fnames])

def filter_prev(classifications, source_df):
    """Return the rows of ``source_df`` that have not been classified yet.

    Parameters
    ----------
    classifications : pandas.DataFrame
        Earlier classifications; only its index (the item ids) is used.
    source_df : pandas.DataFrame
        Candidate sources indexed by the same ids.

    Returns
    -------
    pandas.DataFrame
        A new frame with every row removed whose index value appears in
        ``classifications.index``. Row order is preserved and ``source_df``
        is left unmodified.
    """
    # A membership mask is used rather than DataFrame.drop so that ids which
    # were classified in an earlier session but are absent from the current
    # candidate list are ignored instead of raising KeyError.
    already_classified = source_df.index.isin(classifications.index)
    filtered_df = source_df[~already_classified]
    return filtered_df

In [None]:
# Drop the candidates that already appear in the saved classification files
# and start the annotator on the remainder.
label_options = ['artefact', 'no', 'unlikely', 'interesting', 'very interesting']
prev_classifications = load_saved_annotations([
    '2022-01-24-classifications-0.csv',
    '2022-01-24-classifications-1.csv',
])
filtered_df = filter_prev(prev_classifications, candidates)
stage_1_annotations, _annotator = annotator.annotate_sources(
    filtered_df,
    label_options,
    folder=out_dir,
)

In [None]:
# Load both saved classification files into `df`.
df = load_saved_annotations(['2022-01-24-classifications-0.csv','2022-01-24-classifications-1.csv'])

In [None]:
# Download the published classification spreadsheet as CSV.
r = requests.get(
    'https://docs.google.com/spreadsheets/d/e/2PACX-1vTPTtxWq4mVNiM5eKL_98a53O6-gQteS7Ab7kdIUqtwxsThLIR7yh60kPTTiwbw0pE45mXoZUYeBCWA/pub?output=csv',
    timeout=60,  # do not hang the notebook forever on a stalled connection
)
# Fail loudly on an HTTP error instead of parsing an error page as CSV.
r.raise_for_status()
df = pd.read_csv(BytesIO(r.content), index_col=0)
# Keep only the rows labelled as (very) interesting.
interesting = df[df.label.isin(['very interesting', 'interesting'])]
interesting

In [None]:
%matplotlib inline
# Save one log-log lightcurve per (interesting source, event) pair.
# The output directory is created up front because savefig does not create
# missing parent directories.
os.makedirs('gw_lightcurves', exist_ok=True)
for idx, s in enumerate(interesting.index.values):
    print(f'{idx+1}/{len(interesting.index.values)}')
    # All rows of `df` for this source id, one per associated event.
    new_df = df[df.index == s]
    for i, row in new_df.iterrows():
        event = row.event
        event_time = name_to_time(event)
        source = piperun.get_source(i)
        a = source.plot_lightcurve(start_date=pd.Timestamp(event_time))
        ax = a.gca()
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_ylim(bottom=0.1)
        # Save through the figure object so the right figure is written even
        # if another one has become current.
        a.savefig(f'gw_lightcurves/{s}_{event}', dpi=100)
        # Show, then release the figure, so a long candidate list does not
        # keep every figure open until the cell finishes.
        plt.show()
        plt.close(a)

In [None]:
# Column names, in file order, assigned to 'GLADE+.txt' when it is read with
# header=None (the file itself is read without a header row).
glade_columns = ['GLADE no',
 'PGC no',
 'GWGC name',
 'HyperLEDA name',
 '2MASS name',
 'WISExSCOS name',
 'SDSS-DR16Q name',
 'Object type flag',
 'RA',
 'Dec',
 'B',
 'B_err',
 'B flag',
 'B_Abs',
 'J',
 'J_err',
 'H',
 'H_err',
 'K',
 'K_err',
 'W1',
 'W1_err',
 'W2',
 'W2_err',
 'W1 flag',
 'B_J',
 'B_J err',
 'z_helio',
 'z_cmb',
 'z flag',
 'v_err',
 'z_err',
 'd_L',
 'd_L err',
 'dist flag',
 'M*',
 'M*_err',
 'Merger rate',
 'Merger rate error']


In [None]:
coords1_sc = SkyCoord(ra=interesting['wavg_ra'], dec=interesting['wavg_dec'], unit=(u.deg, u.deg), frame='icrs')
def crossmatch(chunk, max_sep, c2ra, c2dec):
    """Find all pairs of interesting sources and chunk rows within ``max_sep``.

    Parameters
    ----------
    chunk : pandas.DataFrame
        Catalogue rows to match against the module-level ``coords1_sc``.
    max_sep
        Maximum on-sky separation, passed to ``search_around_sky``.
    c2ra, c2dec : str
        Names of the columns of ``chunk`` holding RA and Dec in degrees.

    Returns
    -------
    tuple
        ``(idxc, idxcatalog, d2d)``: positional indices into ``coords1_sc``,
        positional indices into ``chunk``, and the on-sky separation of each
        matched pair.
    """
    # Catalogue positions as astropy coordinates.
    catalogue_coords = SkyCoord(
        ra=chunk[c2ra],
        dec=chunk[c2dec],
        unit=(u.deg, u.deg),
        frame='icrs',
    )

    # The fourth return value is not needed and is discarded.
    vast_idx, catalogue_idx, separation, _ = catalogue_coords.search_around_sky(coords1_sc, max_sep)

    return (vast_idx, catalogue_idx, separation)

In [None]:
# Columns of GLADE+ that are actually loaded; 'GLADE no' becomes the index.
usecols=[
    'GLADE no',
    'PGC no',
    'GWGC name',
    'HyperLEDA name',
    '2MASS name',
    'WISExSCOS name',
    'SDSS-DR16Q name',
    'RA',
    'Dec',
    'z_helio',
    'z_cmb',
    'z flag',
    'z_err',
    'd_L',
]


# One list per match: [source id, GLADE no, <remaining usecols values>, sep].
data = []
# `data` is only consumed when 'glade_matches.csv' does not exist yet, so the
# expensive scan over the whole catalogue is skipped once that cache is there.
if not os.path.isfile('glade_matches.csv'):
    with pd.read_csv('GLADE+.txt', sep=' ', chunksize=1000000, header=None, names=glade_columns, usecols=usecols, index_col='GLADE no') as f:
        for idx, chunk in enumerate(f):
            print(idx)
            vast_match_idxs, chunk_match_idxs, d2d = crossmatch(chunk, 20*u.arcsec, 'RA', 'Dec')
            # zip() yields nothing when there are no matches in this chunk.
            for vast_i, chunk_i, pair_sep in zip(vast_match_idxs, chunk_match_idxs, d2d):
                # Look the catalogue row up once and reuse it.
                glade_row = chunk.iloc[chunk_i]
                res = [interesting.iloc[vast_i].name]
                res += [glade_row.name]
                res.extend(glade_row.values)
                res += [pair_sep.value]
                data.append(res)
                print(res)

In [None]:
# Write the match table once; afterwards `df` always comes from the cached
# CSV so its contents are the same on first and later runs.
if not os.path.isfile('glade_matches.csv'):
    match_columns = ['id', *usecols, 'sep']
    pd.DataFrame(data, columns=match_columns).to_csv('glade_matches.csv', index=False)
df = pd.read_csv('glade_matches.csv')

In [None]:
# 'GLADE no' values of `df` as an array.
glade_ids = df['GLADE no'].values


In [None]:
# Tally the 5th-from-last field of every GLADE+ row whose 6th-from-last field
# is the string 'null'. Values '1', '2' and '3' are counted in o, t and th;
# anything else is counted in z.
# NOTE(review): by position in `glade_columns` these two fields look like
# 'd_L err' and 'dist flag' — confirm against the catalogue description.
c = 0
z = o = t = th = 0
with open('GLADE+.txt', 'r') as f:
    for line in f:
        fields = line.split()
        if fields[-6] != 'null':
            continue
        flag = fields[-5]
        if flag == '1':
            o += 1
        elif flag == '2':
            t += 1
        elif flag == '3':
            th += 1
        else:
            z += 1
            

In [None]:
# Draw the four tallies as a histogram: one unit-wide bin centred on each of
# 0-3, with each bin weighted by its tally.
plt.hist([0,1,2,3], weights=[z,o,t,th], bins=[-.5,0.5,1.5,2.5,3.5])


In [None]:
# Display the raw tallies.
z,o,t,th