# Generating percentiles for TensorFlow model input features

The current TensorFlow model uses histogram-like percentile features, which are kind of a continuous version of one-hot features.

For example, if key cutoff points are `[-3, -1, 0, 2, 10]`, we might encode a value `x` as `sigma((x - cutoff) / scale)`. If `sigma` is the sigmoid function, `x = 0.1`, and `scale = 0.1`, then we'd get `[1, 1, 0.73, 0, 0]`; in other words, `x` is definitely above the first 2 points, mostly above the third, and below the fourth and fifth. If we increase `scale` to `2.0`, then values are less discrete: `[0.82, 0.63, 0.51, 0.28, 0.01]`.

This notebook generates appropriate cutoff points for these, to reflect most data encountered.

In [None]:
# Different options for soft-onehot function.
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied elementwise."""
    return 1/(1+np.exp(-x))


# Example cutoff point and softness used for both demonstration curves.
cutoff = 1.0
scale = 2.0

# Evaluate on an evenly spaced grid of 100 points spanning [-10, 10].
x = np.linspace(-10, 10, 100)
logit = (x - cutoff) / scale

# Option 1: sigmoid response, a soft "is x above the cutoff?" step.
plt.plot(x, sigmoid(logit))
# Option 2: Gaussian-shaped response, equal to 1 where x == cutoff.
plt.plot(x, np.exp(-logit * logit))

In [None]:
# Number of light curves randomly sampled from the test set further down in this notebook.
NUM_LCS = 10_000  # key parameter, turn it down if you want this notebook to finish faster.

# Settings determining type of features extracted.
# Both values are also written into the output JSON, so a cutoffs file records
# the settings it was generated with.
# window_size is passed to RawValueExtractor and WindowFeatures.
window_size = 10
# band_time_diff is passed to WindowFeatures (units are not visible in this notebook).
band_time_diff = 4.0

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

from justice.datasets import plasticc_data

# Obtain the default PLAsTiCC bcolz data source.
# NOTE(review): `source` and `bcolz_source` both come from get_default(); whether
# they are the same underlying object is not visible here. `source` is used for
# loading light curves, `bcolz_source` for the metadata table.
source = plasticc_data.PlasticcBcolzSource.get_default()
bcolz_source = plasticc_data.PlasticcBcolzSource.get_default()
# Test-set metadata table; used here only to list the available object ids.
meta_table = bcolz_source.get_table('test_set_metadata')
# Read the whole object_id column (wrapped in %time to report how long it takes).
%time all_ids = meta_table['object_id'][:]

In [None]:
%%time
import random
# Dedicated, fixed-seed generator so the same object ids are drawn on every run.
_sampler = random.Random(828372)
sample_ids = _sampler.sample(list(all_ids), NUM_LCS)

# Fetch the sampled light curves in chunks of object ids instead of one big call.
_chunk_sz = 100
lcs = []
for start in range(0, len(sample_ids), _chunk_sz):
    _chunk_ids = sample_ids[start:start + _chunk_sz]
    _chunk_lcs = plasticc_data.PlasticcDatasetLC.bcolz_get_lcs_by_obj_ids(
        bcolz_source=source,
        dataset="test_set",
        obj_ids=_chunk_ids,
    )
    lcs.extend(_chunk_lcs)

In [None]:
%%time

from justice.features import band_settings_params
from justice.features import dense_extracted_features
from justice.features import feature_combinators
from justice.features import metadata_features
from justice.features import per_point_dataset
from justice.features import raw_value_features

# Batch size for the per-point dataset; the same value is handed to the model via `params`.
batch_size = 32
# Raw-value extractor configured with the notebook's window size.
# NOTE(review): the band layout is taken from the first sampled light curve only;
# this assumes every light curve shares lcs[0].expected_bands -- confirm.
rve = raw_value_features.RawValueExtractor(
    window_size=window_size,
    band_settings=band_settings_params.BandSettings(lcs[0].expected_bands)
)
# Metadata extractor (presumably per-object metadata features -- see justice.features.metadata_features).
mve = metadata_features.MetadataValueExtractor()
# Dataset generator whose extract function is the combination of both extractors.
data_gen = per_point_dataset.PerPointDatasetGenerator(
    extract_fcn=feature_combinators.combine([rve.extract, mve.extract]),
    batch_size=batch_size,
)

def input_fn():
    """Estimator input function: build the dataset over the sampled light curves."""
    dataset = data_gen.make_dataset_lcs(lcs)
    return dataset

def per_band_model_fn(band_features, params):
    """Build the tensors exported for a single band.

    Args:
        band_features: Features for one band, as supplied by
            ``BandSettings.per_band_sub_model_fn``.
        params: Estimator params; ``"batch_size"`` and ``"window_size"`` are read.

    Returns:
        Dict with ``"initial_layer"`` (the initial layer after masking with
        ``value_if_masked=0``) and ``"in_window"`` (taken directly from the
        window features).
    """
    window_features = dense_extracted_features.WindowFeatures(
        band_features,
        batch_size=params["batch_size"],
        window_size=params["window_size"],
        # Notebook-level setting, not part of `params`.
        band_time_diff=band_time_diff,
    )
    # NOTE(review): this result is never used; the call is kept in place in case
    # it affects graph construction -- confirm whether it can be dropped.
    _unused_dflux_dt = window_features.dflux_dt(clip_magnitude=None)
    unmasked_layer = dense_extracted_features.initial_layer(
        window_features, include_flux_and_time=True
    )
    masked_layer = window_features.masked(
        unmasked_layer, value_if_masked=0, expected_extra_dims=[3]
    )
    outputs = {}
    outputs["initial_layer"] = masked_layer
    outputs["in_window"] = window_features.in_window
    return outputs

def model_fn(features, labels, mode, params):
    """Estimator model function that only exposes extracted features as predictions.

    Args:
        features: Feature dict from the input function; ``'time'`` and
            ``'object_id'`` are passed through unchanged.
        labels: Unused.
        mode: Estimator mode, forwarded to the returned spec.
        params: Estimator params, used to build the band settings and forwarded
            to ``per_band_model_fn``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` whose predictions hold one
        ``band_<band>.<name>`` entry per band and per tensor returned by
        ``per_band_model_fn``, plus ``time`` and ``object_id``. The loss is a
        constant 0.0 and the train op is a no-op.
    """
    band_settings = band_settings_params.BandSettings.from_params(params)
    per_band_outputs = band_settings.per_band_sub_model_fn(
        per_band_model_fn, features, params=params
    )

    # Flatten the per-band dicts into a single dict keyed by band and tensor name.
    predictions = {}
    for band, band_tensors in zip(band_settings.bands, per_band_outputs):
        for name, tensor in band_tensors.items():
            predictions['band_{}.{}'.format(band, name)] = tensor
    predictions['time'] = features['time']
    predictions['object_id'] = features['object_id']

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.constant(0.0),
        train_op=tf.no_op(),
    )

# Params handed to model_fn, which forwards them to BandSettings.from_params and per_band_model_fn.
# NOTE(review): 'flux_scale_epsilon' is not read by any code visible in this notebook;
# presumably it is consumed inside the justice feature code -- confirm.
params = {
    'batch_size': batch_size,
    'window_size': window_size,
    'flux_scale_epsilon': 0.5,
    'lc_bands': lcs[0].expected_bands,
}
# The estimator is only used in predict mode here, to run feature extraction.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params=params
)
# Materialize all predictions in memory; yield_single_examples=True yields one
# prediction dict per example instead of one per batch.
predictions = list(estimator.predict(input_fn=input_fn, yield_single_examples=True))
print(f"Got {len(predictions)} predictions.")

In [None]:
# Display a single prediction dict (index 4 is arbitrary) to eyeball its contents.
predictions[4]

In [None]:
def get_values_df(band):
    """Collect the initial-layer values of one band into a DataFrame.

    Only predictions whose ``band_<band>.in_window`` entry is truthy contribute.

    Args:
        band: Band name as it appears in the prediction keys.

    Returns:
        DataFrame with columns ``dflux_dt``, ``dflux`` and ``dtime``, built by
        concatenating the selected ``band_<band>.initial_layer`` arrays along
        axis 0.
    """
    layer_key = f"band_{band}.initial_layer"
    window_key = f"band_{band}.in_window"
    selected = [
        prediction[layer_key]
        for prediction in predictions
        if prediction[window_key]
    ]
    stacked = np.concatenate(selected, axis=0)
    return pd.DataFrame(stacked, columns=["dflux_dt", "dflux", "dtime"])
# Plot a 32-bin histogram of each feature column for the first band, as a visual check.
df = get_values_df(lcs[0].expected_bands[0])
df.hist('dflux_dt', bins=32)
df.hist('dflux', bins=32)
df.hist('dtime', bins=32)

## Really messy code to get a histogram with mostly-unique bins.

Because we want fixed-size arrays for TensorFlow code, we want a set of e.g. 32 unique cutoff points that reflect a good distribution of cutoffs. However, this is really messy, because there tend to be strong peaks in the histogram, i.e. values that are repeated frequently, so evenly spaced percentiles often coincide.

In [None]:
import collections
import scipy.optimize

def _some_duplicates(non_unique, unique, num_desired):
    to_duplicate_candidates = non_unique.tolist()
    for x in unique:
        to_duplicate_candidates.remove(x)
    unique = unique.tolist()
    while len(unique) < num_desired:
        assert len(unique) <= num_desired
        to_duplicate = random.choice(to_duplicate_candidates)
        unique.insert(unique.index(to_duplicate), to_duplicate)
    return unique

def unique_percentiles(array, num_desired):
    """Find ``num_desired`` (mostly) distinct percentile values of ``array``.

    Evenly spaced percentiles of data with strong peaks often coincide. This
    function searches, with COBYLA, over two knobs -- a multiplier on the
    number of percentile points and the magnitude of a small random
    perturbation of the percentile positions -- until the number of distinct
    percentile values equals ``num_desired``.

    Args:
        array: 1-D array-like of values to take percentiles of.
        num_desired: Number of cutoff values wanted.

    Returns:
        A sorted numpy array of exactly ``num_desired`` distinct values if the
        search finds one. Otherwise, if some evaluation produced
        ``num_desired - 1`` distinct values, a list of length ``num_desired``
        padded by ``_some_duplicates``.

    Raises:
        ValueError: If neither kind of solution was found.
    """
    partition_size = 100.0 / num_desired
    # Upper bound on the noise scale: 5% of the spacing between ideal percentiles.
    epsilon = 0.05 * partition_size

    solution = None
    optimal_solution = None

    def _actual_unique(vals):
        """Objective: distance of the distinct-value count from ``num_desired``."""
        nonlocal solution, optimal_solution
        if optimal_solution is not None:
            return 0  # stop optimization, or at least return quickly
        num_points_base, perturb = vals
        # The optimizer is unconstrained, so never ask for fewer than one
        # percentile point (zero points would fail on the noise[0] assignment).
        num_points = max(1, int(round(num_desired * num_points_base)))
        perturb = abs(perturb)
        q = np.linspace(0, 100, num_points)
        # Seeding from `perturb` makes the noise a deterministic function of it.
        rng = np.random.RandomState(int(1e6 * perturb))
        noise = rng.normal(loc=0, scale=min(1.0, 10 * perturb) * epsilon, size=q.shape)
        # Keep the end points fixed at the 0th and 100th percentiles.
        noise[0] = 0
        noise[-1] = 0
        # np.percentile rejects q outside [0, 100]; clip in case noise overshoots.
        q = np.clip(q + noise, 0.0, 100.0)
        # Linear interpolation is np.percentile's default. The `interpolation=`
        # keyword is deprecated since NumPy 1.22 in favour of `method=`, so it is
        # omitted to stay compatible with both old and new NumPy.
        non_unique = np.percentile(array, q=q)
        unique = np.unique(non_unique)
        result = abs(num_desired - len(unique))
        if num_desired == len(unique):
            optimal_solution = unique
        elif len(unique) <= num_desired <= len(unique) + 1:
            # One short of the target: keep as a fallback, padded with a repeat.
            solution = _some_duplicates(non_unique, unique, num_desired)
        # Overshooting is penalized 4x; adding `perturb` favours small perturbations.
        return (4 if len(unique) > num_desired else 1) * result + perturb

    # Only the side effects on `solution` / `optimal_solution` matter; the
    # optimizer's own result object is not needed.
    scipy.optimize.minimize(
        _actual_unique,
        x0=[1.0, 0.1],
        options={'maxiter': 1000, 'rhobeg': 0.3},
        tol=1e-6,
        method='COBYLA')
    if optimal_solution is None and solution is None:
        raise ValueError("Could not find deduplicated percentiles!")
    return optimal_solution if optimal_solution is not None else solution

# Number of cutoff points to generate for every (band, feature column) pair.
desired_num_cutoffs = 32
# One dict per (band, feature column) pair, in band order then column order.
all_solutions = []
for band in lcs[0].expected_bands:
    df = get_values_df(band)
    for i, column in enumerate(df.columns):
        # Progress output: which band/column is being processed.
        print(band, column)
        # Cutoffs are stored as a float32 array.
        percentiles = np.array(unique_percentiles(df[column], desired_num_cutoffs), dtype=np.float32)
        # Median gap between consecutive cutoffs.
        median_scale = np.median(percentiles[1:] - percentiles[:-1])
        all_solutions.append({
            'band': band,
            'column_index': i,
            'column': column,
            'median_scale': float(median_scale),
            'cutoffs': percentiles,
        })

# Bundle the cutoffs together with the settings they were generated under.
with_settings = {
    'window_size': window_size,
    'band_time_diff': band_time_diff,
    'desired_num_cutoffs': desired_num_cutoffs,
    'solutions': all_solutions
}

## Save to nicely-formatted JSON

Writes numpy arrays as strings, then rewrites those strings.

In [None]:
import datetime
import json

from justice import path_util

class ArrayPreEncoder(json.JSONEncoder):
    """JSON encoder that renders numpy arrays as marked placeholder strings.

    An array becomes the string ``<<<<v0, v1, ...>>>>`` with every element
    formatted to 8 decimal places, so that the placeholder can later be
    rewritten into a compact single-line JSON list.
    """

    def default(self, obj):
        """Return a serializable stand-in for ``obj``.

        Numpy arrays are turned into a placeholder string. Any other object is
        printed and then passed to the base class, which raises ``TypeError``.
        """
        if not isinstance(obj, np.ndarray):
            # Show the offending object before the base class rejects it.
            print(obj)
            return super().default(obj)
        formatted = [f"{x:.8f}" for x in obj.tolist()]
        return "<<<<{}>>>>".format(", ".join(formatted))

def _encode(x):
    """Serialize ``x`` to indented JSON, with numpy arrays as single-line lists.

    Arrays are first emitted by ``ArrayPreEncoder`` as quoted placeholder
    strings, which are then rewritten into bare ``[...]`` lists. The final text
    is parsed once to make sure it is still valid JSON.
    """
    with_placeholders = json.dumps(x, indent=2, cls=ArrayPreEncoder)
    # Swap each quoted placeholder delimiter for a list bracket.
    result = with_placeholders.replace('"<<<<', '[')
    result = result.replace('>>>>"', ']')
    json.loads(result)  # error if not decodable
    return result

# Current local date/time; only the date part goes into the file name.
now = datetime.datetime.now()
# The output file name records the window size and the generation date.
path = path_util.data_dir / 'tf_align_model' / 'feature_extraction' / (
    f"cutoffs__window_sz-{window_size}__{now.year:04d}-{now.month:02d}-{now.day:02d}.json")
# Create the output directory if needed, then write the encoded settings and cutoffs.
path.parent.mkdir(parents=True, exist_ok=True)
with open(str(path), 'w') as f:
    f.write(_encode(with_settings))