### `gotorb` example notebook
* Reproduces some of the plots from the paper using a small (~10%) subset of the paper's test set.


In [None]:
import os
import tensorflow as tf
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import gridspec
from matplotlib.ticker import FuncFormatter

from multiprocessing import cpu_count
from scipy.stats import gaussian_kde
from astropy.visualization import ZScaleInterval

from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import hickle

from gotorb.active_learning import binary_entropy
from gotorb.classifier import dropout_pred, latentvec_pred

# Cap TensorFlow's intra-op thread pool at half the reported CPU count
# (presumably one thread per physical core on hyperthreaded machines).
tf.config.threading.set_intra_op_parallelism_threads(cpu_count() // 2)

In [None]:
# Unpack the pre-made dataset: `stamps` holds the image cutouts and `meta`
# the matching metadata table (it exposes `metalabel` and `label` columns).
stamps, meta = hickle.load("../data/datapack.hkl")

# Summarise how many samples each metalabel class contributes.
print("### Dataset composition ###")
label_names, label_counts = np.unique(meta.metalabel.values, return_counts=True)
for name, count in zip(label_names, label_counts):
    print(f"{name}: {count} of {len(meta)}")

# Restore the trained Keras model from its HDF5 file.
model = tf.keras.models.load_model("../data/gotorb_valmodel_BALDflip_20201030-170220.h5")

In [None]:
# Deterministic scores: a single forward pass through `keras` predict,
# keeping only the first output column.
classifier_preds = model.predict(stamps, verbose=1)[:, 0]

# Bayesian scores: `dropout_pred` keeps dropout active, so each stamp gets
# several differing predictions (one row per stamp, judging by the axis=1
# reduction below).
posterior_preds = dropout_pred(model, stamps, batchsize=4096, verbose=1)

# Collapse each stamp's posterior samples to a single mean score.
posterior_scores = np.mean(posterior_preds, axis=1)

### FPR/FNR curves

In [None]:
fig, ax = plt.subplots(dpi=120)

# Boolean subset masks as plain NumPy arrays (True = sample is kept).
mp_mask = ~meta.metalabel.isin(["syntransient", "marshall", "glxresid"]).values
trn_mask = ~meta.metalabel.isin(["mp", "marshall", "randjunk"]).values

# Mean posterior score per sample and the fine threshold grid used for the curves.
mean_scores = posterior_preds.mean(axis=1)
rbbins = np.arange(-0.0001, 1.0001, 0.0001)


def _cumulative_fraction(scores):
    """Return the cumulative histogram of `scores` over `rbbins`, scaled so its maximum is 1."""
    hist, _ = np.histogram(scores, bins=rbbins, density=True)
    cumulative = np.cumsum(hist)
    return cumulative / np.max(cumulative)


for subset_idx, (name, msk) in enumerate(zip(["MP only", "Transient only"], [mp_mask, trn_mask])):
    preds = mean_scores[msk]
    labels = meta.label.values[msk]

    # Fraction of real (label == 1) samples scoring below each threshold.
    real_cdf = _cumulative_fraction(preds[labels == 1])
    rb_thres = np.arange(len(real_cdf)) / len(real_cdf)
    ax.plot(rb_thres, real_cdf,
            label='{} FNR'.format(name), linewidth=1)

    # The FPR and MMCE curves are drawn once, from the first subset only.
    # NOTE(review): the legend says "All-data" but the curve is computed from
    # the "MP only" subset's bogus samples -- confirm this is intended.
    if subset_idx == 0:
        bogus_cdf = _cumulative_fraction(preds[labels == 0])
        ax.plot(rb_thres, 1 - bogus_cdf,
                label='All-data FPR', linewidth=1)

        mmce = (real_cdf + 1 - bogus_cdf) / 2
        ax.plot(rb_thres, mmce, '--',
                label='MMCE', color='gray', linewidth=1)


ax.set_xlim([-0.03, 1.03])
ax.set_xticks(np.arange(0, 1.1, 0.1))

# No linear y-ticks are set here: switching to a log scale installs a fresh
# log locator, so any ticks set beforehand would be discarded.
ax.set_yscale('log')
ax.set_ylim([5e-4, 1])


def _percent_label(value, _pos):
    """Format a tick value as a percentage, one decimal place below 1%."""
    text = '{:,.1%}'.format(value) if value < 0.01 else '{:,.0%}'.format(value)
    # latex fudge - need to encode percentages differently since % is reserved in TeX.
    return text.replace("%", r"$\%$")


# A value-based formatter keeps labels correct whichever ticks the log locator
# picks, instead of pairing a fixed label list with ticks by position.
ax.yaxis.set_major_formatter(FuncFormatter(_percent_label))

ax.grid()
ax.set_xlabel("Real-bogus threshold")
ax.set_ylabel("Cumulative percentage")
ax.legend(loc="upper right", bbox_to_anchor=(0.85, 1), edgecolor='k')

axins = inset_axes(ax, width="70%", height="70%",
                   bbox_transform=ax.transAxes, loc=3,
                   bbox_to_anchor=(0.2, 0.6, 0.3, 0.55))

# Confusion matrix over the union of both subsets, thresholding the mean score
# by rounding to the nearest integer; computed once and reused for the
# row-normalised percentages.
either_mask = mp_mask | trn_mask
cm = confusion_matrix(meta.label[either_mask], np.rint(mean_scores[either_mask]))
cm_norm = 100 * cm / cm.sum(axis=1)[:, np.newaxis]  # normalise each true-class row to percent

axins.imshow(cm_norm, cmap='terrain_r')
axins.set_xlabel("Predicted", fontsize=8)
axins.set_ylabel("True", fontsize=8)

# Fix the tick positions first, then label them, so the labels and their font
# size are attached to the final ticks.
axins.set(xticks=np.arange(cm.shape[1]),
          yticks=np.arange(cm.shape[0]),
          ylim=(1.5, -0.5),
          )
axins.set_xticklabels(["b", "r"], fontsize=8)
axins.set_yticklabels(["b", "r"], fontsize=8)

# Annotate each cell, switching text colour for contrast on dark cells.
thresh = cm_norm.max() / 2.
for i in range(cm_norm.shape[0]):
    for j in range(cm_norm.shape[1]):
        axins.text(j, i, '{:.2f}% \n({:d})'.format(cm_norm[i, j], cm[i, j]),
                   ha="center", va="center",
                   color="white" if cm_norm[i, j] > thresh else "black",
                   size=8)

plt.show()

### What do the resultant posteriors look like?

In [None]:
def plot_example(stamps, posterior, meta):
    """Plot one sample's three image stamps beside a KDE of its posterior scores.

    Parameters
    ----------
    stamps : array
        Image cube for a single sample, indexed as ``stamps[:, :, k]`` with
        k = 0, 1, 2 shown as SCIENCE, TEMPLATE and DIFFERENCE respectively.
    posterior : array
        1-D collection of posterior score samples for this sample.
    meta : row-like
        Metadata for this sample; must expose ``label`` and ``metalabel``.

    A new 1x4 figure is created on every call; the function returns nothing.
    """
    titles = ["SCIENCE", "TEMPLATE", "DIFFERENCE"]
    fig, ax = plt.subplots(1, 4, dpi=120)
    for a, label in enumerate(titles):
        # ZScale gives per-panel display limits.
        scaler = ZScaleInterval()
        vlo, vhi = scaler.get_limits(stamps[:, :, a])
        ax[a].imshow(stamps[:, :, a], aspect=1, cmap='Greys_r', vmin=vlo, vmax=vhi)
        ax[a].set_title(label)
        ax[a].set_frame_on(False)
        ax[a].tick_params(top=False, bottom=False, left=False, right=False,
                          labeltop=False, labelbottom=False, labelleft=False, labelright=False)

    kde_ax = ax[3]
    samplerange = np.linspace(0, 1, 1000)
    if np.ptp(posterior) > 0:
        kde = gaussian_kde(posterior, bw_method="silverman")
        kde_ax.fill_between(samplerange, 0, kde(samplerange))
    else:
        # All samples identical: a KDE has no spread to estimate a bandwidth
        # from, so mark the single score value instead of failing.
        kde_ax.axvline(np.mean(posterior), c='C0', lw=2)
        kde_ax.set_xlim(0, 1)
    kde_ax.set_ylim(bottom=0)
    # Make the density panel square so it matches the stamp panels.
    kde_ax.set_aspect(1. / kde_ax.get_data_ratio())
    kde_ax.spines['right'].set_visible(False)
    kde_ax.spines['top'].set_visible(False)
    kde_ax.spines['left'].set_visible(False)
    kde_ax.tick_params(left=False, labelleft=False, top=False, right=False)
    # Dashed guide at the 0.5 decision threshold, drawn behind the density.
    kde_ax.axvline(0.5, c='k', lw=1, zorder=-1, ls='--')

    confidence = 1 - np.mean(binary_entropy(posterior))
    # int(...) keeps the label rendering as a plain integer regardless of
    # whether it is stored as a Python or NumPy scalar.
    fig.suptitle(f"Score: {posterior.mean():.2f}  Label: {(int(meta.label), meta.metalabel)}  Conf.: {confidence:.2f}", y=0.80)

In [None]:
# Per-sample confidence: one minus the mean binary entropy of its posterior samples.
confidence = 1 - np.mean(binary_entropy(posterior_preds), axis=1)

# Walk the least-confident fifth of the data set, showing every 50th sample.
ascending_confidence = np.argsort(confidence)
bottom_fifth = int(len(confidence) / 5)
for idx in ascending_confidence[:bottom_fifth:50]:
    plot_example(stamps[idx], posterior_preds[idx], meta.iloc[idx])

### t-SNE embedding
* Interactive version of the figure from the paper - click through the flattened latent space to identify clusters of unusual detections.

In [None]:
# Latent feature vector for every stamp, as produced by `latentvec_pred`.
fvec = latentvec_pred(model, stamps, batchsize=2048, verbose=True)

# Fixed seed so the PCA/t-SNE layouts are repeatable between runs.
EMBEDDING_SEED = 0

# PCA reduce first to 50ish dimensions for convergence help
print("Computing PCA reduction")
PCA_red = PCA(n_components=50, random_state=EMBEDDING_SEED).fit_transform(fvec)
PCA_complete = PCA(n_components=2, random_state=EMBEDDING_SEED).fit_transform(fvec)

# compute t-SNE embedding on half the reported cores, but never fewer than one
# (n_jobs=0 is not a valid worker count).
nphyscores = max(1, cpu_count() // 2)
print("Starting t-SNE embedding task -- this might take a while!")
TSNE_red = TSNE(n_components=2, verbose=1, n_jobs=nphyscores,
                random_state=EMBEDDING_SEED).fit_transform(PCA_red)

In [None]:
def plot_latentspace(fig, x, y, color_by):
    """Draw the t-SNE scatter plus the stamps of the sample nearest to (x, y).

    Reads the module-level ``TSNE_red`` embedding and ``stamps`` array.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure to draw into; all axes are created on this figure.
    x, y : float
        Position in the 2-D embedding; marked with a red star and used to
        look up the nearest embedded sample.
    color_by : array
        Per-sample values mapped through the 'viridis' colormap.
    """
    latent_testvec = np.array([x, y])

    # Top two rows: the embedding itself.
    ax_main = plt.subplot2grid((3, 3), (0, 0), colspan=3, rowspan=2, fig=fig)
    embedding = ax_main.scatter(TSNE_red[:, 0], TSNE_red[:, 1], cmap='viridis', c=color_by, s=0.1)
    ax_main.scatter(*latent_testvec, marker='*', c='r')
    ax_main.set_xlabel("latent vector 1")
    ax_main.set_ylabel("latent vector 2")
    ax_main.set_title("t-SNE latent space plot")
    # Tie the colourbar to the coloured scatter explicitly; a bare
    # plt.colorbar() would pick up the most recent artist (the star marker).
    fig.colorbar(embedding, ax=ax_main)

    # Nearest embedded sample to the requested point (Euclidean distance).
    nn_idx = np.argmin(np.linalg.norm(TSNE_red - latent_testvec, axis=1))
    stampchoice = stamps[nn_idx]

    labels = ["SCI", "TEMP", "DIFF"]

    # Bottom row: one panel per stamp channel.
    for i, label in enumerate(labels):
        ax_stamp = plt.subplot2grid((3, 3), (2, i), fig=fig)
        scaler = ZScaleInterval()
        vlo, vhi = scaler.get_limits(stampchoice[:, :, i])
        ax_stamp.imshow(stampchoice[:, :, i], aspect=1, cmap='Greys_r', vmin=vlo, vmax=vhi)
        ax_stamp.set_xticks([])
        ax_stamp.set_yticks([])
        ax_stamp.set_xlabel(label)
    fig.tight_layout()
    
    
# make class-based colormap: one boolean membership flag per metalabel class
names = ["syntransient", "mp", "randjunk", "marshall", "glxresid"]
metalabel_values = meta.metalabel.values
transient_flag = metalabel_values == "syntransient"
mp_flag = metalabel_values == "mp"
randjunk_flag = metalabel_values == "randjunk"
marshall_flag = metalabel_values == "marshall"
resid_flag = metalabel_values == "glxresid"
flags = [transient_flag, mp_flag, randjunk_flag, marshall_flag, resid_flag]

# Colour value assigned to each class, in the same order as `flags`.
colours = [1, 0.75, 0.5, 0.25, 0]

# Every sample receives the colour of the class whose flag is set;
# samples matching no listed class stay at 0.
class_cmap = np.zeros(len(meta))
for flg, coef in zip(flags, colours):
    class_cmap = class_cmap + np.where(flg, coef, 0.0)

### Interactive figure
* Click anywhere to move to that point in latent space - the stamp previews will update.
* Switch colourmap by changing the cmap variable to one of [`confidence`, `posterior_scores`, `class_cmap`]

In [None]:
%matplotlib notebook
fig = plt.figure(dpi=120)

cmap = class_cmap

plot_latentspace(fig, 0, 0, cmap) # init plot at (0,0)


def onclick(event):
    cx, cy = event.xdata, event.ydata
    fig.clf()
    plot_latentspace(fig, cx, cy, cmap)
    

cid = fig.canvas.mpl_connect('button_press_event', onclick)