# Max-value example kernel

<span style="color: red;">TODO: Switch to the improved bcolz API once https://github.com/aimalz/justice/pull/75 is merged.</span>

This example extracts windowed features, and then computes a soft score indicating whether each point (in each band) is a local maximum.

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from justice import lightcurve
from justice import visualize
from justice.align_model import max_model_kernel
from justice.datasets import plasticc_bcolz
from justice.datasets import plasticc_data
from justice.features import per_point_dataset, raw_value_features, band_settings_params

# Object whose light curve this example inspects.
OBJECT_ID = 1598

bcolz_table = plasticc_bcolz.BcolzDataset(
    plasticc_bcolz._root_dir / "training_set"
).read_table()
bcolz_map = bcolz_table.where("object_id == {}".format(OBJECT_ID))
df = pd.DataFrame.from_records(bcolz_map, columns=bcolz_table.names)

# Build one list of BandData per object. Each list has one slot per expected
# band and is indexed by the raw `passband` value (assumes passband values are
# integers in range(len(expected_bands)) -- TODO confirm). A slot that is still
# None afterwards means the table had no rows for that band.
num_bands = len(plasticc_data.PlasticcDatasetLC.expected_bands)
groupby = df.groupby(['object_id', 'passband'])
all_raw_bands = {}
for (object_id, passband), df_chunk in groupby:
    bands = all_raw_bands.setdefault(object_id, [None] * num_bands)
    bands[passband] = lightcurve.BandData(
        np.array(df_chunk['mjd']),
        np.array(df_chunk['flux']),
        np.array(df_chunk['flux_err']),
        np.array(df_chunk['detected']),
    )

# Assemble light curves, mapping each expected band name to its BandData.
lcs = {}
for object_id, bands in all_raw_bands.items():
    # Explicit check rather than `assert`: it survives `python -O`, and the
    # identity test avoids invoking BandData's `==` against None.
    if any(band is None for band in bands):
        raise ValueError(
            "If raw data is missing whole bands, then we have to rethink things"
        )
    lcs[object_id] = plasticc_data.PlasticcDatasetLC(
        **dict(zip(plasticc_data.PlasticcDatasetLC.expected_bands, bands))
    )
lc = lcs[OBJECT_ID]

In [None]:
def model_fn(features, labels, mode, params):
    """Prediction-only Estimator model function for the max-value kernel.

    Runs `max_model_kernel.feature_model_fn` on the features and adds the
    per-point `time` feature to its outputs, so the times can be plotted
    alongside them. The loss is a constant zero and training is a no-op,
    so this spec is only meant for `predict`.

    Args:
        features: feature dict from the input pipeline; must contain 'time'.
        labels: ignored.
        mode: the tf.estimator mode key, passed through unchanged.
        params: hyperparameter dict forwarded to `feature_model_fn`.

    Returns:
        A `tf.estimator.EstimatorSpec` carrying the predictions dict.
    """
    outputs = max_model_kernel.feature_model_fn(features, params)
    outputs['time'] = features['time']
    zero_loss = tf.constant(0.0)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=outputs,
        loss=zero_loss,
        train_op=tf.no_op(),
    )


window_size = 10
# One batch size shared by the input generator and the model params, so the
# two settings cannot drift apart.
batch_size = 5

rve = raw_value_features.RawValueExtractor(
    window_size=window_size,
    band_settings=band_settings_params.BandSettings(lc.expected_bands),
)
data_gen = per_point_dataset.PerPointDatasetGenerator(
    extract_fcn=rve.extract,
    batch_size=batch_size,
)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        'batch_size': batch_size,
        'window_size': window_size,
        # NOTE(review): consumed inside max_model_kernel; its exact meaning is
        # not visible from this notebook.
        'flux_scale_epsilon': 0.5,
        'lc_bands': lc.expected_bands,
    },
)
# Materialize every per-point prediction for the single light curve `lc`.
predictions = list(data_gen.predict_single_lc(estimator, lc))

In [None]:
visualize.plot_lcs(lcs=[lc])

# The per-point times are the same for every band, so collect them once.
times = [x['time'] for x in predictions]
for i, band in enumerate(lc.expected_bands):
    # Column i of each point's 'is_max_soft' output corresponds to band i.
    is_max_fv = [x['is_max_soft'][i] for x in predictions]
    plt.figure(figsize=(9.5, 2))
    plt.scatter(times, is_max_fv)
    plt.title("is_max_soft, band {}".format(band))
    plt.xlabel("time")