In [1]:
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util

In [2]:
import matplotlib.pylab as plt

In [3]:
from tensorflow_probability.python.internal import tensorshape_util

In [4]:
import pandas as pd

# Helper Functions

In [5]:
import functools

def debug(func):
    """Print the function signature and return value"""
    @functools.wraps(func)
    def wrapper_debug(*args, **kwargs):
        args_repr = [repr(a) for a in args]                      # 1
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]  # 2
        signature = ", ".join(args_repr + kwargs_repr)           # 3
        print(f"Calling {func.__name__}({signature})")
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")           # 4
        return value
    return wrapper_debug

In [6]:
def for_all_methods(decorator):
    def decorate(cls):
        for attr in cls.__dict__: # there's propably a better way to do this
            if callable(getattr(cls, attr)):
                setattr(cls, attr, decorator(getattr(cls, attr)))
        return cls
    return decorate

# Quantile Distribution Wrapper

In [11]:
import numpy as np

import scipy.interpolate as I

import tensorflow as tf

from tensorflow_probability import distributions as tfd

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import reparameterization

#from ..losses import PinballLoss

class QuantileRegressionDistributionWrapper(tfd.Distribution):

    def __init__(self,
                 quantiles,
                 constrain_quantiles=tf.identity,#PinballLoss.constrain_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='QuantileDistributionWrapper'):

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype(
                [quantiles], dtype_hint=tf.float32)

            self.quantiles = tensor_util.convert_nonref_to_tensor(
                quantiles, dtype=dtype, name='quantiles')

            assert self.quantiles.shape[-1] == 100, '100 Qunatiles reqired'

            self.quantiles = constrain_quantiles(self.quantiles)

            self._pdf_sp, self._cdf_sp, self._quantile_sp = self.make_interp_spline()

            super().__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)

    def make_interp_spline(self):
        """
        Generates the Spline Interpolation.
        """
        percentiles = np.linspace(0., 1., 100, dtype=np.float32)
        quantiles = self.quantiles.numpy().copy()

        float_min = np.finfo(np.float32).min * np.ones_like(quantiles[..., :1])
        float_max = np.finfo(np.float32).max * \
            np.ones_like(quantiles[..., -1:])

        min_q = quantiles[..., :1] - 15 * np.diff(quantiles)[..., :1]
        max_q = quantiles[..., -1:] + 15 * np.diff(quantiles)[..., -1:]

        x = np.concatenate([float_min, quantiles, float_max], axis=-1)
        y = np.concatenate(
            [percentiles[..., :1], percentiles, percentiles[..., -1:]], axis=-1)

        x = quantiles
        y = percentiles

        x = x.reshape(-1, x.shape[-1])

        x_min = np.min(x, axis=-1)  # [shape]
        x_max = np.max(x, axis=-1)  # [shape]

        quantile_sp = [I.make_interp_spline(
            y=np.squeeze(x[i]),
            x=np.squeeze(y),
            k=3,
            bc_type='clamped',
            # assume_sorted=True
        ) for i in range(x.shape[0])]

        cdf_sp = [I.make_interp_spline(
            y=np.squeeze(y),
            x=np.squeeze(x[i]),
            k=3,
            bc_type='clamped',
            # assume_sorted=True
        ) for i in range(x.shape[0])]
        pdf_sp = [s.derivative(1) for s in cdf_sp]

        def pdf_sp_fn(x):
            y = []
            z_clip = np.clip(x, x_min, x_max)
            for i, ip in enumerate(pdf_sp):
                y.append(ip(z_clip[..., i]).astype(np.float32))
            y = np.stack(y, axis=-1)
            return y

        def cdf_sp_fn(x):
            y = []
            z_clip = np.clip(x, x_min, x_max)
            for i, ip in enumerate(cdf_sp):
                y.append(ip(z_clip[..., i]).astype(np.float32))
            y = np.stack(y, axis=-1)
            return y

        def quantile_sp_fn(p):
            q = []
            p_clip = np.clip(p, np.zeros_like(x_min), np.ones_like(x_max))
            for i, ip in enumerate(quantile_sp):
                q.append(ip(p_clip[..., i]).astype(np.float32))
            q = np.stack(q, axis=-1)
            return q

        return pdf_sp_fn, cdf_sp_fn, quantile_sp_fn

    def reshape_out(self, sample_shape, y):
        output_shape = prefer_static.broadcast_shape(
            sample_shape, self.batch_shape)
        return tf.reshape(y, output_shape)

    def _eval_spline(self, x, attr):
        x = np.asarray(x, dtype=np.float32)
        batch_rank = tensorshape_util.rank(self.batch_shape)
        sample_shape = x.shape

        if x.shape[-batch_rank:] == self.batch_shape:
            shape = list(x.shape[:-batch_rank]) + [-1]
            x = tf.reshape(x, shape)
        else:
            x = x[..., None]

        return self.reshape_out(sample_shape, getattr(self, attr)(x))

    def _batch_shape(self):
        shape = tf.TensorShape(prefer_static.shape(self.quantiles)[:-1])
        return tf.broadcast_static_shape(shape, tf.TensorShape([1]))

    def _event_shape(self):
        return tf.TensorShape([])

    def _log_prob(self, x):
        return np.log(self.prob(x))

    def _prob(self, x):
        return self._eval_spline(x, '_pdf_sp')

    def _log_cdf(self, x):
        return np.log(self.cdf(x))

    def _cdf(self, x):
        return self._eval_spline(x, '_cdf_sp')

    def _mean(self):
        return self._quantile(0.5)

    def _quantile(self, p):
        # input_shape = p.shape
        # q = self.quantiles
        # perm = tf.concat([[q.ndim - 1], tf.range(0, q.ndim - 1)], 0)
        # q = tfp.math.interp_regular_1d_grid(
        #     p,
        #     x_ref_min=0.,
        #     x_ref_max=1.,
        #     y_ref=tf.transpose(q, perm),
        #     axis=0)
        # return self.reshape_out(input_shape, q)
        return self._eval_spline(p, '_quantile_sp')


# Normal vs Wrapper

In [12]:
def gen_dist (batch_shape):
    order=5
    if batch_shape != []:
        n=tfd.Normal(loc=tf.zeros((batch_shape)), scale=tf.ones((batch_shape)))
        bs = QuantileRegressionDistributionWrapper(tf.broadcast_to(tf.linspace(0.1,2,100),batch_shape + [100]))
    else:
        n=tfd.Normal(loc=tf.zeros((1)), scale=tf.ones((1)))
        bs = QuantileRegressionDistributionWrapper(tf.linspace(0.1,2,100))
    return n, bs

In [13]:
result={}
batch_sizes=[[], [1], [10], [10,2]]
inputs=np.concatenate([[b,[1]+b,b+[1]] for b in batch_sizes if b != []]).tolist()

for batch_size in batch_sizes:
    n, bs = gen_dist(batch_size)

    print(f'n: {n}')
    print(f'bs: {bs}')

    tmp_res={}

    for i in inputs:
        try:
            tmp_res[f'n.prob({i})'] = str(n.prob(tf.ones(i)).shape)
        except Exception as e:
            tmp_res[f'n.prob({i})'] = 'E'#str(e)
        try:
            tmp_res[f'n.quantile({i})'] = str(n.quantile(0.5).shape)
        except Exception as e:
            tmp_res[f'n.quantile({i})'] = 'E'#str(e)
        try:
            tmp_res[f'bs.prob({i})'] = str(bs.prob(tf.ones(i)).shape)
        except Exception as e:
            tmp_res[f'bs.prob({i})'] = 'E'#str(e)
        try:
            tmp_res[f'bs.quantile({i})'] = str(bs.quantile(0.5).shape)
        except Exception as e:
            tmp_res[f'bs.quantile({i})'] = 'E'#str(e)
    result[str(batch_size)]=tmp_res

n: tfp.distributions.Normal("Normal", batch_shape=[1], event_shape=[], dtype=float32)
bs: tfp.distributions.QuantileRegressionDistributionWrapper("QuantileDistributionWrapper", batch_shape=[1], event_shape=[], dtype=float32)
n: tfp.distributions.Normal("Normal", batch_shape=[1], event_shape=[], dtype=float32)
bs: tfp.distributions.QuantileRegressionDistributionWrapper("QuantileDistributionWrapper", batch_shape=[1], event_shape=[], dtype=float32)
n: tfp.distributions.Normal("Normal", batch_shape=[10], event_shape=[], dtype=float32)
bs: tfp.distributions.QuantileRegressionDistributionWrapper("QuantileDistributionWrapper", batch_shape=[10], event_shape=[], dtype=float32)
n: tfp.distributions.Normal("Normal", batch_shape=[10, 2], event_shape=[], dtype=float32)
bs: tfp.distributions.QuantileRegressionDistributionWrapper("QuantileDistributionWrapper", batch_shape=[10, 2], event_shape=[], dtype=float32)


In [14]:
df=pd.DataFrame(result)

In [15]:
df.index=pd.MultiIndex.from_frame(df.index.str.extract('(n|bs).*(prob|quantile)\((.*)\)'),
                                  names=['dist','op','shape'])
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,[],[1],[10],"[10, 2]"
dist,op,shape,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
n,prob,[1],"(1,)","(1,)","(10,)","(10, 2)"
n,quantile,[1],"(1,)","(1,)","(10,)","(10, 2)"
bs,prob,[1],"(1,)","(1,)","(10,)","(10, 2)"
bs,quantile,[1],"(1,)","(1,)","(10,)","(10, 2)"
n,prob,"[1, 1]","(1, 1)","(1, 1)","(1, 10)","(10, 2)"
n,quantile,"[1, 1]","(1,)","(1,)","(10,)","(10, 2)"
bs,prob,"[1, 1]","(1, 1)","(1, 1)","(1, 10)","(10, 2)"
bs,quantile,"[1, 1]","(1,)","(1,)","(10,)","(10, 2)"
n,prob,[10],"(10,)","(10,)","(10,)",E
n,quantile,[10],"(1,)","(1,)","(10,)","(10, 2)"


In [16]:
df.loc[('n','prob')] == df.loc[('bs','prob')]

  df.loc[('n','prob')] == df.loc[('bs','prob')]


Unnamed: 0_level_0,[],[1],[10],"[10, 2]"
shape,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
[1],True,True,True,True
"[1, 1]",True,True,True,True
[10],True,True,True,True
"[1, 10]",True,True,True,True
"[10, 1]",True,True,True,False
"[10, 2]",True,True,True,True
"[1, 10, 2]",True,True,True,True
"[10, 2, 1]",True,True,True,True


In [17]:
df.loc[('n','quantile')] == df.loc[('bs','quantile')]

  df.loc[('n','quantile')] == df.loc[('bs','quantile')]


Unnamed: 0_level_0,[],[1],[10],"[10, 2]"
shape,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
[1],True,True,True,True
"[1, 1]",True,True,True,True
[10],True,True,True,True
"[1, 10]",True,True,True,True
"[10, 1]",True,True,True,True
"[10, 2]",True,True,True,True
"[1, 10, 2]",True,True,True,True
"[10, 2, 1]",True,True,True,True


In [18]:
n, tq = gen_dist ([32,48])
n, tq

(<tfp.distributions.Normal 'Normal' batch_shape=[32, 48] event_shape=[] dtype=float32>,
 <tfp.distributions.QuantileRegressionDistributionWrapper 'QuantileDistributionWrapper' batch_shape=[32, 48] event_shape=[] dtype=float32>)

In [19]:
x=tf.ones((100,32,48))
tq.prob(x).shape,n.prob(x).shape

(TensorShape([100, 32, 48]), TensorShape([100, 32, 48]))