In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
import os
import shutil

%matplotlib notebook
from __future__ import division, print_function
import numpy as np
import matplotlib.pyplot as plt

import bilby
from bilby.core.prior import Uniform
from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters, generate_all_bbh_parameters

from gwpy.timeseries import TimeSeries

import lal
import lalsimulation as lalsim

from bilby.core import utils
from bilby.core.utils import logger
from bilby.gw.conversion import bilby_to_lalsimulation_spins
from bilby.gw.utils import (lalsim_GetApproximantFromString,
                    lalsim_SimInspiralFD,
                    lalsim_SimInspiralChooseFDWaveform,
                    lalsim_SimInspiralWaveformParamsInsertTidalLambda1,
                    lalsim_SimInspiralWaveformParamsInsertTidalLambda2,
                    lalsim_SimInspiralChooseFDWaveformSequence,
                    convert_args_list_to_float, 
                    _get_lalsim_approximant)

import copy

In [3]:
import numpy as np
import sklearn
from sklearn.preprocessing import PolynomialFeatures

import scipy
from scipy import signal
from scipy.interpolate import interp1d
from scipy.signal import butter, filtfilt

In [4]:
import collections
import numbers
from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

import numpy as np
from scipy import sparse
from scipy.interpolate import BSpline
from scipy.special import comb

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array
from sklearn.utils.deprecation import deprecated
from sklearn.utils.validation import check_is_fitted, FLOAT_DTYPES, _check_sample_weight
from sklearn.utils.validation import _check_feature_names_in
from sklearn.utils.stats import _weighted_percentile

# Scipy Polynomial Transform Function

In [5]:
def transform(self, X):
        """Transform data to polynomial features.
        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The data to transform, row by row.
            Prefer CSR over CSC for sparse input (for speed), but CSC is
            required if the degree is 4 or higher. If the degree is less than
            4 and the input format is CSC, it will be converted to CSR, have
            its polynomial features generated, then converted back to CSC.
            If the degree is 2 or 3, the method described in "Leveraging
            Sparsity to Speed Up Polynomial Feature Expansions of CSR Matrices
            Using K-Simplex Numbers" by Andrew Nystrom and John Hughes is
            used, which is much faster than the method used on CSC input. For
            this reason, a CSC input will be converted to CSR, and the output
            will be converted back to CSC prior to being returned, hence the
            preference of CSR.
        Returns
        -------
        XP : {ndarray, sparse matrix} of shape (n_samples, NP)
            The matrix of features, where `NP` is the number of polynomial
            features generated from the combination of inputs. If a sparse
            matrix is provided, it will be converted into a sparse
            `csr_matrix`.
        """
        check_is_fitted(self)

        X = self._validate_data(
            X, order="F", dtype=FLOAT_DTYPES, reset=False, accept_sparse=("csr", "csc")
        )

        n_samples, n_features = X.shape

        if sparse.isspmatrix_csr(X):
            if self._max_degree > 3:
                return self.transform(X.tocsc()).tocsr()
            to_stack = []
            if self.include_bias:
                to_stack.append(
                    sparse.csc_matrix(np.ones(shape=(n_samples, 1), dtype=X.dtype))
                )
            if self._min_degree <= 1 and self._max_degree > 0:
                to_stack.append(X)
            for deg in range(max(2, self._min_degree), self._max_degree + 1):
                Xp_next = _csr_polynomial_expansion(
                    X.data, X.indices, X.indptr, X.shape[1], self.interaction_only, deg
                )
                if Xp_next is None:
                    break
                to_stack.append(Xp_next)
            if len(to_stack) == 0:
                # edge case: deal with empty matrix
                XP = sparse.csr_matrix((n_samples, 0), dtype=X.dtype)
            else:
                XP = sparse.hstack(to_stack, format="csr")
        elif sparse.isspmatrix_csc(X) and self._max_degree < 4:
            return self.transform(X.tocsr()).tocsc()
        elif sparse.isspmatrix(X):
            combinations = self._combinations(
                n_features=n_features,
                min_degree=self._min_degree,
                max_degree=self._max_degree,
                interaction_only=self.interaction_only,
                include_bias=self.include_bias,
            )
            columns = []
            for combi in combinations:
                if combi:
                    out_col = 1
                    for col_idx in combi:
                        out_col = X[:, col_idx].multiply(out_col)
                    columns.append(out_col)
                else:
                    bias = sparse.csc_matrix(np.ones((X.shape[0], 1)))
                    columns.append(bias)
            XP = sparse.hstack(columns, dtype=X.dtype).tocsc()
        else:
            # Do as if _min_degree = 0 and cut down array after the
            # computation, i.e. use _n_out_full instead of n_output_features_.
            XP = np.empty(
                shape=(n_samples, self._n_out_full), dtype=X.dtype, order=self.order
            )

            # What follows is a faster implementation of:
            # for i, comb in enumerate(combinations):
            #     XP[:, i] = X[:, comb].prod(1)
            # This implementation uses two optimisations.
            # First one is broadcasting,
            # multiply ([X1, ..., Xn], X1) -> [X1 X1, ..., Xn X1]
            # multiply ([X2, ..., Xn], X2) -> [X2 X2, ..., Xn X2]
            # ...
            # multiply ([X[:, start:end], X[:, start]) -> ...
            # Second optimisation happens for degrees >= 3.
            # Xi^3 is computed reusing previous computation:
            # Xi^3 = Xi^2 * Xi.

            # degree 0 term
            if self.include_bias:
                XP[:, 0] = 1
                current_col = 1
            else:
                current_col = 0

            if self._max_degree == 0:
                return XP

            # degree 1 term
            XP[:, current_col : current_col + n_features] = X
            index = list(range(current_col, current_col + n_features))
            current_col += n_features
            index.append(current_col)

            # loop over degree >= 2 terms
            for _ in range(2, self._max_degree + 1):
                new_index = []
                end = index[-1]
                for feature_idx in range(n_features):
                    start = index[feature_idx]
                    new_index.append(current_col)
                    if self.interaction_only:
                        start += index[feature_idx + 1] - index[feature_idx]
                    next_col = current_col + end - start
                    if next_col <= current_col:
                        break
                    # XP[:, start:end] are terms of degree d - 1
                    # that exclude feature #feature_idx.
                    np.multiply(
                        XP[:, start:end],
                        X[:, feature_idx : feature_idx + 1],
                        out=XP[:, current_col:next_col],
                        casting="no",
                    )
                    current_col = next_col

                new_index.append(current_col)
                index = new_index

            if self._min_degree > 1:
                n_XP, n_Xout = self._n_out_full, self.n_output_features_
                if self.include_bias:
                    Xout = jnp.empty(
                        shape=(n_samples, n_Xout), dtype=XP.dtype, order=self.order
                    )
                    Xout[:, 0] = 1
                    Xout[:, 1:] = XP[:, n_XP - n_Xout + 1 :]
                else:
                    Xout = XP[:, n_XP - n_Xout :].copy()
                XP = Xout
        return XP

In [7]:
XD = np.arange(6).reshape(3, 2)
poly = PolynomialFeatures(2)
poly.fit_transform(XD)
poly = PolynomialFeatures(interaction_only=True)
poly.fit_transform(XD)

array([[ 1.,  0.,  1.,  0.],
       [ 1.,  2.,  3.,  6.],
       [ 1.,  4.,  5., 20.]])

In [8]:
XP1 = transform(poly,XD)

# Scipy Result

In [9]:
XP1

array([[ 1.,  0.,  1.,  0.],
       [ 1.,  2.,  3.,  6.],
       [ 1.,  4.,  5., 20.]])

In [10]:
X = np.arange(6).reshape(3, 2)
poly = PolynomialFeatures(2)
poly.fit_transform(X)
poly = PolynomialFeatures(interaction_only=True)
poly.fit_transform(X)

array([[ 1.,  0.,  1.,  0.],
       [ 1.,  2.,  3.,  6.],
       [ 1.,  4.,  5., 20.]])

In [11]:
poly.order

'C'

In [12]:
poly._n_out_full

4

In [13]:
poly._min_degree

0

In [14]:
poly.interaction_only

True

# JAX transform function

In [15]:
def transform3(self, X):
    n_samples, n_features = X.shape
    XP = jnp.empty(
        shape=(n_samples, self._n_out_full), dtype=X.dtype
    )
    if self.include_bias:
        XP = XP.at[:,0].set(1)
        current_col = 1
    else:
        current_col = 0

    if self._max_degree == 0:
        return XP

    XP = XP.at[:, current_col : current_col + n_features].set(X)
    index = list(range(current_col, current_col + n_features))
    current_col += n_features
    index.append(current_col)

    for _ in range(2, self._max_degree + 1):
        new_index = []
        end = index[-1]
        for feature_idx in range(n_features):
            start = index[feature_idx]
            new_index.append(current_col)
            if self.interaction_only:
                start += index[feature_idx + 1] - index[feature_idx]
            next_col = current_col + end - start
            if next_col <= current_col:
                break
            XP = XP.at[:, current_col:next_col].set(jnp.multiply(
                XP[:, start:end],
                X[:, feature_idx : feature_idx + 1],
            ))
            current_col = next_col

        new_index.append(current_col)
        index = new_index
    print(XP)
    if self._min_degree > 1:
        n_XP, n_Xout = self._n_out_full, self.n_output_features_
        if self.include_bias:
            Xout = jnp.empty(
                shape=(n_samples, n_Xout), dtype=XP.dtype
            )
            Xout = Xout.at[:,0].set(1)
            #Xout[:, 0] = 1
            Xout = Xout.at[:, 1:].set(XP[:, n_XP - n_Xout + 1 :])
            #Xout[:, 1:] = XP[:, n_XP - n_Xout + 1 :]
        else:
            Xout = XP[:, n_XP - n_Xout :].copy()
        print(Xout)
        XP = Xout.copy()
    return XP

In [16]:
XP = transform3(poly, X)

  lax_internal._check_user_dtype_supported(dtype, "zeros")


[[ 1  0  1  0]
 [ 1  2  3  6]
 [ 1  4  5 20]]


# JAX Result

In [17]:
XP

DeviceArray([[ 1,  0,  1,  0],
             [ 1,  2,  3,  6],
             [ 1,  4,  5, 20]], dtype=int32)