In [None]:
import numpy, healpy, requests
import tensorflow
from ligo.skymap import io
from astropy.io import fits
from reproject import reproject_from_healpix

In [None]:
##########################################
# Load model and pre-processing functions
##########################################

In [None]:
# Load GWSkyNet: the architecture is stored as JSON and the weights as HDF5,
# both named after the model.
model_name = 'GWSkyNet_v1'
architecture_path = model_name + '.json'
weights_path = model_name + '.h5'
with open(architecture_path, 'r') as json_file:
    json_model = json_file.read()
# Rebuild the network from its JSON description, then restore the weights
model = tensorflow.keras.models.model_from_json(json_model)
model.load_weights(weights_path)

In [None]:
# Target header for reproject_from_healpix: a 360 x 180 pixel RA/Dec grid in
# the CAR projection with one-degree pixels (CDELT1 is negative).
_target_header_cards = """
NAXIS   =                    2
NAXIS1  =                  360
NAXIS2  =                  180
CTYPE1  = 'RA---CAR'
CRPIX1  =                180.5
CRVAL1  =                180.0
CDELT1  =                   -1
CUNIT1  = 'deg     '
CTYPE2  = 'DEC--CAR'
CRPIX2  =                 90.5
CRVAL2  =                  0.0
CDELT2  =                    1
CUNIT2  = 'deg     '
COORDSYS= 'icrs    '
"""
# One card per line, hence the newline separator
target_header = fits.Header.fromstring(_target_header_cards, sep='\n')

# Normalization constants taken from the training set, keyed by quantity.
# They are tied to the model version and might change in future versions.
training_norms = {
    'distance': 10320,
    'skymap': 0.005,
    'distmu': 11553,
    'distnorm': 863800,
    'distsigma': 12065,
}

def nan_invalid(data, invalid_value):
    """Turn invalid values into numpy.nan

    Keyword arguments:
    data: numpy array to clean; it is modified in place
    invalid_value: sentinel marking invalid entries (compared with ==)

    Returns the same array object that was passed in.
    """
    # A boolean mask selects exactly the matching elements whatever the number
    # of dimensions. Looping over the per-axis index arrays returned by
    # numpy.where would instead blank whole rows of a multi-dimensional array.
    data[data == invalid_value] = numpy.nan
    return data

def prepare_data(fits_file):
    """Turn a sky-map FITS file into the list of inputs passed to GWSkyNet

    The probability map and the three distance layers are read from the file,
    reprojected onto target_header, scaled to a peak of one and 2x2 max-pooled.

    Keyword arguments:
    fits_file: FITS sky map, passed to ligo.skymap.io.read_sky_map

    Returns a list of eight entries, in this order: the stacked distance
    images (distmu, distsigma, distnorm along the last axis), the sky-map
    image, the H1/L1/V1 multi-hot detector vector with shape (1, 3), the
    normalized mean distance with shape (1, 1), and the normalized peak values
    of skymap, distmu, distsigma and distnorm, each with shape (1, 1).
    """
    skymap, metadata = io.read_sky_map(fits_file, distances=True, nest=None)

    # Mean distance, scaled by the maximum seen in the training set
    distance = metadata['distmean'] / training_norms['distance']

    # Multi-hot encoding of the detector network, always in H1, L1, V1 order
    network = metadata['instruments']
    dets = [1 if ifo in network else 0 for ifo in ['H1', 'L1', 'V1']]

    # Columns read from the FITS file. Pixels holding the "invalid" sentinel
    # (Distmu: inf, Distsigma: 1, Distnorm: 0) are turned into NaN first;
    # convention described in Table 1 of https://arxiv.org/pdf/1605.04242.pdf
    fits_cols = {'skymap':skymap[0],
                 'distmu':nan_invalid(skymap[1], numpy.inf),
                 'distsigma':nan_invalid(skymap[2], 1.),
                 'distnorm':nan_invalid(skymap[3], 0.)}

    # The pooling layer holds no state, so a single instance serves every column
    maxpool = tensorflow.keras.layers.MaxPooling2D(pool_size=(2, 2))

    img_data, norms = {}, {}
    for column, healpix_values in fits_cols.items():
        # Reproject the HEALPix column onto the target grid
        with numpy.errstate(invalid='ignore'):
            img, _ = reproject_from_healpix((healpix_values, 'ICRS'),
                                            target_header, nested=metadata['nest'], hdu_in=None,
                                            order='bilinear', field=0)

        # NaN pixels become zero, then the image is scaled to a peak of one
        img = numpy.nan_to_num(img)
        peak = numpy.max(img)
        img = img / peak
        # The peak itself is kept, scaled by the training-set maximum
        norms[column] = peak / training_norms[column]

        # Add batch and channel axes, then downsample with max pooling
        n_rows, n_cols = len(img), len(img[0])
        batch = tensorflow.cast(numpy.reshape(img, (1, n_rows, n_cols, 1)), tensorflow.float32)
        img_data[column] = maxpool(batch)

    # Stack the three distance images along a trailing channel axis
    volume_layers = [numpy.reshape(img_data[name], (1, 90, 180))
                     for name in ['distmu', 'distsigma', 'distnorm']]
    stacked_volume = numpy.stack(volume_layers, axis=-1)

    model_inputs = [stacked_volume, img_data['skymap'],
                    numpy.reshape(dets, (1,3)), numpy.reshape(distance, (1,1))]
    for name in ['skymap', 'distmu', 'distsigma', 'distnorm']:
        model_inputs.append(numpy.reshape(norms[name], (1,1)))
    return model_inputs

def predict(loaded_model, data, threshold):
    """Use loaded model to predict result

    The probability and the verdict are printed for every candidate in the
    batch; nothing is returned.

    Keyword arguments:
    loaded_model: machine-learning model to use for prediction
    data: pre-processed data from FITS file
    threshold: real-noise threshold to predict real events (typically 0.5)
    """
    # Drop the trailing singleton axis of the model output and flatten, so that
    # each entry is the score of one candidate.
    probabilities = tensorflow.squeeze(loaded_model(data), [-1]).numpy().reshape(-1)
    for probability in probabilities:
        print('Predicted probability: {:.2f}%'.format(probability*100))
        # Compare the scalar score: comparing the whole array is ambiguous as
        # soon as the batch holds more than one candidate.
        if probability >= threshold:
            print('This candidate is likely an astrophysical signal.')
        else:
            print('This candidate is NOT astrophysical.')

In [None]:
##########################################
# Make predictions
##########################################

In [None]:
# Choose candidate from GraceDB and download corresponding FITS file
# See https://gracedb.ligo.org/superevents/public/O3/ for a list of candidates
event_name = 'S200316bj'
event_url = 'https://gracedb.ligo.org/apiweb/superevents/{}/files/'.format(event_name)

# Probe for the multi-order sky map: the file is taken to exist when the HEAD
# response carries a Content-Disposition header. The timeout keeps the cell
# from hanging on an unresponsive server.
r = requests.head(event_url + 'bayestar.multiorder.fits', timeout=30)
if 'Content-Disposition' in r.headers:
    fits_url = event_url + 'bayestar.multiorder.fits'
else:
    # Older events do not have bayestar.multiorder.fits file
    fits_url = event_url + 'bayestar.fits'
fits_name = '{}.fits'.format(event_name)

# Download in Python instead of shelling out to curl: an HTTP error raises here
# rather than being saved to disk as if it were the sky map, and no variable is
# interpolated into a shell command.
response = requests.get(fits_url, timeout=60)
response.raise_for_status()
with open(fits_name, 'wb') as fits_out:
    fits_out.write(response.content)

In [None]:
# Real-noise threshold: scores at or above it are reported as astrophysical
RN_threshold = 0.5
# Pre-process the downloaded sky map, then classify the candidate
data = prepare_data(fits_name)
predict(model, data, RN_threshold)