# Max-value example kernel

This example extracts windowed features, and then computes a soft score indicating whether each point (in each band) is a local maximum.

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from justice import lightcurve
from justice import visualize
from justice.align_model import max_model_kernel
from justice.datasets import plasticc_data
from justice.features import per_point_dataset, raw_value_features, band_settings_params

# Open the default Plasticc bcolz source and fetch a single training-set
# light curve by object id.
source = plasticc_data.PlasticcBcolzSource.get_default()
# Unpacking into a one-element target fails loudly unless the lookup yields
# exactly one light curve.
[lc] = plasticc_data.PlasticcDatasetLC.bcolz_get_lcs_by_obj_ids(
    bcolz_source=source, dataset="training_set", obj_ids=[1598]
)

In [None]:
def model_fn(features, labels, mode, params):
    """Estimator model function wrapping the max-model kernel.

    Runs ``max_model_kernel.feature_model_fn`` on ``features`` and ``params``,
    attaches the input ``time`` feature to the resulting predictions dict, and
    wraps everything in an ``EstimatorSpec``.

    ``labels`` is accepted to satisfy the Estimator signature but is unused.
    The loss is a constant zero and the train op is a no-op.
    """
    outputs = max_model_kernel.feature_model_fn(features, params)
    # Carry the timestamps through so each prediction can be plotted against time.
    outputs['time'] = features['time']

    placeholder_loss = tf.constant(0.0)
    placeholder_train_op = tf.no_op()
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=outputs,
        loss=placeholder_loss,
        train_op=placeholder_train_op,
    )


# Shared settings: the same window size and batch size are handed to both the
# feature extraction / dataset side and the estimator params, so they are
# defined once here to keep the two sides from drifting apart.
window_size = 10
batch_size = 5

# Extractor for raw windowed values, configured for this light curve's bands.
rve = raw_value_features.RawValueExtractor(
    window_size=window_size,
    band_settings=band_settings_params.BandSettings(lc.expected_bands)
)
# Per-point dataset generator that feeds the extractor's output in batches.
data_gen = per_point_dataset.PerPointDatasetGenerator(
    extract_fcn=rve.extract,
    batch_size=batch_size,
)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        'batch_size': batch_size,
        'window_size': window_size,
        'flux_scale_epsilon': 0.5,
        'lc_bands': lc.expected_bands,
    }
)
# Materialize every prediction dict for the single light curve.
predictions = list(data_gen.predict_single_lc(estimator, lc))

In [None]:
# Show the raw light curve first, for reference.
visualize.plot_lcs(lcs=[lc])

# The time axis is the same for every band, so build it once outside the loop.
times = [x['time'] for x in predictions]

# One scatter plot per band of the soft "is local maximum" score over time.
# NOTE(review): assumes index i of 'is_max_soft' corresponds to the i-th entry
# of lc.expected_bands -- confirm against max_model_kernel.
for i, band in enumerate(lc.expected_bands):
    is_max_fv = [x['is_max_soft'][i] for x in predictions]
    plt.figure(figsize=(9.5, 2))
    plt.scatter(times, is_max_fv)
    # Label the figure so the per-band plots can be told apart.
    plt.title(f"is_max_soft, band {band}")