In [None]:
import os

import tensorflow as tf

from astronet.ops import dataset_ops
from astronet.util import config_util

# Test-split TFRecords to score. The '+' in the directory name is written as
# '\+' (presumably so the file-pattern matcher treats it literally); the raw
# string keeps that backslash without relying on Python passing through the
# unrecognised '\+' escape, which newer Python versions warn about.
data_files = r'/mnt/tess/astronet/tfrecords-new\+old/test-0000[6-7]*'

# The checkpoint directory is expected to contain exactly one exported model.
parent = '/mnt/tess/astronet/checkpoints/local_global_new_old_2/10'
all_dirs = os.listdir(parent)
if len(all_dirs) != 1:
    raise ValueError(
        'Expected exactly one model directory in %s, found %d: %r'
        % (parent, len(all_dirs), all_dirs))
d = all_dirs[0]
model_dir = os.path.join(parent, d)

model = tf.saved_model.load(model_dir)
config = config_util.load_config(model_dir)

# One example per batch, files read in order, no time-series reversal, a
# single pass, and identifiers attached so later cells can select a TIC.
ds = dataset_ops.build_dataset(
    file_pattern=data_files,
    input_config=config['inputs'],
    batch_size=1,
    include_labels=False,
    reverse_time_series_prob=0,
    shuffle_filenames=False,
    repeat=1,
    use_tpu=False,
    one_hot_labels=(config['hparams']['output_dim'] > 1),
    include_identifiers=True)

In [None]:
def normalize(input_config, feature, val):
    """Convert a raw feature value into the model's normalized units.

    Args:
      input_config: the config's `inputs` section. The full config (which
        has an `inputs` key) is also accepted.
      feature: name of an entry under `input_config['features']`; that entry
        provides `mean` and `std`.
      val: raw value(s): a scalar, array or tensor.

    Returns:
      `(val - mean) / std`.
    """
    # Read statistics from the argument (not a notebook global) and tolerate
    # being handed the whole config rather than its `inputs` section.
    if 'inputs' in input_config:
        input_config = input_config['inputs']
    stats = input_config['features'][feature]
    return (val - stats['mean']) / stats['std']


def denormalize(input_config, feature, val):
    """Convert a normalized feature value back into raw units.

    Inverse of `normalize`: `val * std + mean`.

    Args:
      input_config: the config's `inputs` section. The full config (which
        has an `inputs` key) is also accepted, since some callers pass it.
      feature: name of an entry under `input_config['features']`; that entry
        provides `mean` and `std`.
      val: normalized value(s): a scalar, array or tensor.

    Returns:
      `val * std + mean`.
    """
    # Read statistics from the argument (not a notebook global) and tolerate
    # being handed the whole config rather than its `inputs` section.
    if 'inputs' in input_config:
        input_config = input_config['inputs']
    stats = input_config['features'][feature]
    return val * stats['std'] + stats['mean']


def sweep_inputs(tic_id, ds, config, feature, min_val, max_val, n, n2):
    """Repeat one TIC's example with `feature` swept over a linear grid.

    The example(s) for `tic_id` are repeated `n + 1` times and enumerated.
    Element k gets `feature` overwritten with the raw value
    `min_val + (k // (n2 + 1)) * (max_val - min_val) / n`, converted to
    normalized units via `normalize`. The grid index only advances every
    `n2 + 1` elements, so sweeps nest: call once with `n2=0`, then again on
    the result with `n2` set to the first call's `n`.

    NOTE(review): assumes `ds` holds exactly one example for `tic_id` (or,
    when nesting, exactly `n2 + 1`); otherwise the grid indices will not line
    up — confirm TIC ids are unique in the input files.

    Args:
      tic_id: TIC identifier whose example(s) are kept.
      ds: dataset of `(features, identifiers)` pairs.
      config: model config; `config.inputs` supplies the feature mean/std.
      feature: name of the feature to overwrite.
      min_val: first grid value, in raw units.
      max_val: last grid value, in raw units.
      n: number of grid intervals (`n + 1` values). `n == 0` yields a single
        point at `min_val`.
      n2: number of intervals of an already-applied inner sweep, or 0.

    Returns:
      `ds` filtered to `tic_id`, repeated, enumerated and overridden.
    """
    # Grid spacing; guarding n == 0 avoids the 0/0 that would yield NaN.
    step = (max_val - min_val) / n if n else 0.0

    def override(index, example):
        # Consecutive runs of `n2 + 1` elements share one grid value.
        grid_index = tf.cast(index // (n2 + 1), tf.float32)
        features = example[0]
        raw_value = min_val + grid_index * step
        # `* 0 +` broadcasts the scalar to the existing feature's shape.
        features[feature] = normalize(
            config.inputs, feature, features[feature] * 0 + raw_value)
        return example

    # Cache the filtered example(s) so each repeat does not re-read and
    # re-filter every input file.
    ds = ds.filter(lambda _, ids: tf.squeeze(ids['tic_id'] == tic_id)).cache()
    return ds.repeat(n + 1).enumerate().map(override)

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import ticker

def plot_sweep(sweep_predictions, label, n, features, fig, subplot):
    """Draw one model output over a 2-D feature grid as a heat map.

    Args:
      sweep_predictions: DataFrame with `(n + 1) ** 2` rows, ordered so that
        `features[0]` varies fastest (as produced by `sweep`).
      label: column holding the output to plot; colours span 0 to 1.
      n: grid intervals per feature.
      features: pair of feature column names; the first is drawn along x
        (columns) and the second along y (rows).
      fig: figure to add the axes to.
      subplot: subplot spec passed to `fig.add_subplot` (e.g. 231).

    Returns:
      The image returned by `imshow`, usable for a shared colour bar.
    """
    assert len(features) == 2
    grid_shape = (n + 1, n + 1)
    grid = sweep_predictions[label].values.reshape(grid_shape)
    x_values = sweep_predictions[features[0]].values.reshape(grid_shape)[0, :]
    y_values = sweep_predictions[features[1]].values.reshape(grid_shape)[:, 0]

    ax = fig.add_subplot(subplot)
    im = ax.imshow(grid, vmin=0, vmax=1, cmap='RdYlGn')
    # Pin one tick per grid cell so every label sits on the cell it names;
    # relying on the default locator misaligns labels for larger grids.
    ax.set_xticks(range(len(x_values)))
    ax.set_xticklabels(['%3.3f' % v for v in x_values])
    ax.set_yticks(range(len(y_values)))
    ax.set_yticklabels(['%3.3f' % v for v in y_values])
    ax.set_xlabel(features[0])
    ax.set_ylabel(features[1])
    ax.set_title(label)
    return im
    
def sweep(tic_id, ds, config, n, featurs_and_ranges):
    """Score one TIC over a 2-D grid of two feature values and plot it.

    Uses the notebook-level `model` loaded in the setup cell.

    Args:
      tic_id: TIC identifier of the example to perturb.
      ds: dataset of `(features, identifiers)` pairs.
      config: model config; `config.inputs` supplies feature mean/std.
      n: grid intervals per feature; the grid has `(n + 1) ** 2` points.
      featurs_and_ranges: two `(feature_name, min_value, max_value)` tuples
        in raw units. The first varies along x (columns), the second along
        y (rows).

    Returns:
      DataFrame with one row per grid point: the model outputs keyed by the
      disposition labels, plus the two denormalized feature values.
    """
    (name_1, min_1, max_1), (name_2, min_2, max_2) = featurs_and_ranges

    # Inner sweep over the first feature, outer sweep over the second:
    # element k has name_1 at grid index k % (n + 1), name_2 at k // (n + 1).
    grid_ds = sweep_inputs(tic_id, ds, config, name_1, min_1, max_1, n, 0)
    grid_ds = sweep_inputs(tic_id, grid_ds, config, name_2, min_2, max_2, n, n)

    labels = ["disp_E", "disp_N", "disp_J", "disp_S", "disp_B"]

    rows = []
    for features, _ in grid_ds:
        preds = model(features)
        row = dict(zip(labels, preds.numpy()[0]))
        for name in (name_1, name_2):
            row[name] = denormalize(
                config.inputs, name, features[name].numpy().item())
        rows.append(row)
    sweep_predictions = pd.DataFrame(rows)

    fig = plt.figure(figsize=(20, 14))
    # 2x3 layout: one heat map per label in slots 1-5, colour bar in slot 6.
    for slot, label in enumerate(labels, start=1):
        im = plot_sweep(
            sweep_predictions, label, n, (name_1, name_2), fig, 230 + slot)
    # Host the shared colour bar on a hidden sixth axes so it fills the free
    # slot instead of shrinking the last heat map.
    colorbar_host = fig.add_subplot(230 + len(labels) + 1)
    colorbar_host.axis('off')
    fig.colorbar(im, ax=colorbar_host)

    return sweep_predictions

In [None]:
# star_mass x star_rad grid with n=4 (5x5 points) for TIC 290460897.
sweep_predictions = sweep(
    290460897,
    ds,
    config,
    4,
    (
        ('star_mass', 0.14, 4),
        ('star_rad', 0.17, 12),
    ),
)

In [None]:
# star_rad x Period grid with n=4 (5x5 points) for TIC 160991022.
sweep_predictions = sweep(
    160991022,
    ds,
    config,
    4,
    (
        ('star_rad', 0.17, 50),
        ('Period', 0.02, 50),
    ),
)

In [None]:
# Score one TIC before and after Gaussian-smoothing its input views,
# replacing the views cumulatively: local, then global, then secondary.
filtered_ds = ds.filter(
        lambda _, ids: tf.squeeze(ids['tic_id'] == 407146723)
    )

from matplotlib import pyplot as plt
from scipy import ndimage

for features, _ in filtered_ds:
    fig = plt.figure(figsize=(12, 3))

    # Each panel overlays a view with its smoothed copy.
    # NOTE(review): assumes each view is (1, length) so `[0]` is a 1-D curve
    # — confirm. With a trailing channel axis, gaussian_filter would also
    # smooth across channels and gaussian_filter1d (axis=-1) would smooth
    # along channels instead of time.
    ax = fig.add_subplot(131)
    gv = features['global_view'][0].numpy()
    ax.plot(gv, label='raw')
    gv = ndimage.gaussian_filter(gv, 0.1)
    ax.plot(gv, label='smoothed')
    ax.set_title('global_view')
    ax.legend()

    ax = fig.add_subplot(132)
    lv = features['local_view'][0].numpy()
    ax.plot(lv, label='raw')
    lv = ndimage.gaussian_filter1d(lv, 2)
    ax.plot(lv, label='smoothed')
    ax.set_title('local_view')
    ax.legend()

    ax = fig.add_subplot(133)
    sv = features['secondary_view'][0].numpy()
    ax.plot(sv, label='raw')
    sv = ndimage.gaussian_filter(sv, 1.0)
    ax.plot(sv, label='smoothed')
    ax.set_title('secondary_view')
    ax.legend()

    # The printed score is the first model output (disp_E in this notebook).
    preds = model(features)
    print('No filter:     ', preds[0][0].numpy().item())

    features['local_view'] = tf.expand_dims(lv, 0)
    preds = model(features)
    print('Local:         ', preds[0][0].numpy().item())

    features['global_view'] = tf.expand_dims(gv, 0)
    preds = model(features)
    print('Global + Local:', preds[0][0].numpy().item())

    features['secondary_view'] = tf.expand_dims(sv, 0)
    preds = model(features)
    print('All:           ', preds[0][0].numpy().item())
    # Convert to plain floats so the pairs print as numbers, not tensor reprs.
    print(list(zip(["disp_E", "disp_N", "disp_J", "disp_S", "disp_B"],
                   preds[0].numpy().tolist())))