## Reranking an existing model

Can we improve on the existing LR model?

Ideally we'd first know how it's poor, then impose constraints to correct for that.

In [1]:
import numpy as np
import scipy.stats
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.linear_model import LogisticRegression
import maxentropy
import maxentropy.utils as utils

import plotly.io as pio
pio.renderers.default = 'plotly_mimetype'

import plotly.express as px

from sklearn.datasets import load_iris, load_breast_cancer

# Load the scikit-learn toy datasets used below.
# NOTE(review): `iris` is not referenced anywhere in the visible part of this
# notebook — confirm it is still needed.
iris = load_iris()
# as_frame=True: the cell that builds X_cancer calls `.values` on cancer['data'],
# i.e. it relies on the data being delivered as a pandas object.
cancer = load_breast_cancer(as_frame=True)

## First example: find the model with minimum relative entropy to some prior model subject to a non-negativity constraint

In [2]:
cancer['feature_names']

array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',
       'mean smoothness', 'mean compactness', 'mean concavity',
       'mean concave points', 'mean symmetry', 'mean fractal dimension',
       'radius error', 'texture error', 'perimeter error', 'area error',
       'smoothness error', 'compactness error', 'concavity error',
       'concave points error', 'symmetry error',
       'fractal dimension error', 'worst radius', 'worst texture',
       'worst perimeter', 'worst area', 'worst smoothness',
       'worst compactness', 'worst concavity', 'worst concave points',
       'worst symmetry', 'worst fractal dimension'], dtype='<U23')

In [3]:
# Keep two views of the breast-cancer features: the labelled frame and a bare
# array derived from it, plus the target labels.
df_cancer = cancer.data
X_cancer = df_cancer.to_numpy()
y_cancer = cancer.target

### Question: Can we fit a neural network for classification, remove the final softmax layer, and then apply this?

In [4]:
from sklearn.neural_network import MLPClassifier

In [5]:
# Fit a single-hidden-layer MLP on the full breast-cancer data. The seed is
# pinned so that weight initialisation (and therefore the score below) is
# reproducible on a fresh kernel; 7 matches the seed used for the wine
# network later in this notebook.
net = MLPClassifier(hidden_layer_sizes=(100,), random_state=7)

net.fit(X_cancer, y_cancer)

# Accuracy on the same data the network was fitted on (no hold-out split here).
net.score(X_cancer, y_cancer)

0.9560632688927944

In [6]:
net._predict??

[0;31mSignature:[0m [0mnet[0m[0;34m.[0m[0m_predict[0m[0;34m([0m[0mX[0m[0;34m,[0m [0mcheck_input[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0m_predict[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mX[0m[0;34m,[0m [0mcheck_input[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""Private predict method with optional input validation"""[0m[0;34m[0m
[0;34m[0m        [0my_pred[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_forward_pass_fast[0m[0;34m([0m[0mX[0m[0;34m,[0m [0mcheck_input[0m[0;34m=[0m[0mcheck_input[0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mn_outputs_[0m [0;34m==[0m [0;36m1[0m[0;34m:[0m[0;34m[0m
[0;34m[0m            [0my_pred[0m [0;34m=[0m [0my_pred[0m[0;34m.[0m[0mravel[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m       

In [7]:
# outputs = net._forward_pass_fast(X_cancer, check_input=True)

In [8]:
import toolz as tz

In [9]:
@tz.curry
def forward_pass(net, X, target=slice(None), apply_output_activation=True):
    """Run a fitted scikit-learn MLP forward by hand and return its outputs.

    The layers are applied one after another using ``net.coefs_`` and
    ``net.intercepts_``; the hidden activation named by ``net.activation`` is
    applied after every layer except the last.

    Parameters
    ----------
    net : fitted MLP exposing ``coefs_``, ``intercepts_``, ``n_layers_``,
        ``activation`` and ``out_activation_``.
    X : 2-D array of inputs, one row per sample.
    target : column selector applied to the final output. The default
        ``slice(None)`` keeps every output column; an integer picks one.
    apply_output_activation : bool, default True
        When True the output activation named by ``net.out_activation_`` is
        applied, so the result holds the network's final outputs. Pass False
        to get the pre-activation values of the last layer instead (i.e. to
        "remove the final softmax/logistic layer").

    Returns
    -------
    ``activation[:, target]`` for the last layer.

    Notes
    -----
    ``ACTIVATIONS`` lives in a private scikit-learn module, so this helper may
    break across scikit-learn releases.
    """
    from sklearn.neural_network._base import ACTIVATIONS
    from sklearn.utils.extmath import safe_sparse_dot

    hidden_activation = ACTIVATIONS[net.activation]
    last_layer = net.n_layers_ - 2

    # Initialize first layer
    activation = X

    # Forward propagate
    for i in range(net.n_layers_ - 1):
        # The product is a new array, so the in-place updates that follow do
        # not modify the caller's X.
        activation = safe_sparse_dot(activation, net.coefs_[i])
        activation += net.intercepts_[i]
        if i != last_layer:
            # The ACTIVATIONS callables are used for their in-place effect;
            # their return value is deliberately ignored.
            hidden_activation(activation)

    if apply_output_activation:
        ACTIVATIONS[net.out_activation_](activation)
    return activation[:, target]

In [10]:
forward_pass(net, X_cancer, 0)[:10]

array([3.75287115e-22, 6.57484442e-10, 2.00950344e-07, 4.25536423e-01,
       3.05404238e-03, 1.49599863e-02, 1.89810044e-08, 1.39019140e-03,
       1.43466827e-01, 5.61311556e-03])

## Now: apply it to a multi-class classification problem (n_classes > 2)

In [11]:
from sklearn.datasets import load_wine

In [12]:
wine = load_wine(as_frame=True)

In [13]:
# Labelled feature frame, the same features as a bare array, and the class labels.
df_wine = wine.data
X_wine = df_wine.to_numpy()
y_wine = wine.target

In [14]:
# wine['feature_names']

In [15]:
df_wine[:3]

Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline
0,14.23,1.71,2.43,15.6,127.0,2.8,3.06,0.28,2.29,5.64,1.04,3.92,1065.0
1,13.2,1.78,2.14,11.2,100.0,2.65,2.76,0.26,1.28,4.38,1.05,3.4,1050.0
2,13.16,2.36,2.67,18.6,101.0,2.8,3.24,0.3,2.81,5.68,1.03,3.17,1185.0


In [16]:
X_wine[:3]

array([[1.423e+01, 1.710e+00, 2.430e+00, 1.560e+01, 1.270e+02, 2.800e+00,
        3.060e+00, 2.800e-01, 2.290e+00, 5.640e+00, 1.040e+00, 3.920e+00,
        1.065e+03],
       [1.320e+01, 1.780e+00, 2.140e+00, 1.120e+01, 1.000e+02, 2.650e+00,
        2.760e+00, 2.600e-01, 1.280e+00, 4.380e+00, 1.050e+00, 3.400e+00,
        1.050e+03],
       [1.316e+01, 2.360e+00, 2.670e+00, 1.860e+01, 1.010e+02, 2.800e+00,
        3.240e+00, 3.000e-01, 2.810e+00, 5.680e+00, 1.030e+00, 3.170e+00,
        1.185e+03]])

In [17]:
# Baseline: logistic regression fitted directly on the wine feature array.
# NOTE(review): max_iter is set far above scikit-learn's default — presumably
# so the solver converges on these unscaled features; confirm.
model_lr = LogisticRegression(max_iter=5_000)
model_lr.fit(X_wine, y_wine)
# Accuracy on the same data the model was fitted on.
model_lr.score(X_wine, y_wine)

0.9943820224719101

In [18]:
# Re-bind `net` to a new MLP trained on the wine data; from here on `net` no
# longer refers to the breast-cancer network. random_state pins the fit.
net = MLPClassifier(hidden_layer_sizes=(100,), learning_rate_init=0.01, max_iter=1000, random_state=7)
net.fit(X_wine, y_wine)
# Accuracy on the same data the network was fitted on.
net.score(X_wine, y_wine)

0.8820224719101124

In [19]:
forward_pass(net, X_wine)[:10]

array([[9.99997359e-01, 1.74069242e-06, 9.00339952e-07],
       [9.99999998e-01, 1.14744691e-09, 4.92073562e-10],
       [1.00000000e+00, 1.53418384e-11, 2.32317785e-11],
       [1.00000000e+00, 3.20726835e-19, 3.52267199e-17],
       [9.35703608e-01, 5.87869811e-02, 5.50941106e-03],
       [1.00000000e+00, 8.85365192e-19, 4.49256502e-17],
       [1.00000000e+00, 1.93530313e-16, 8.48028138e-16],
       [1.00000000e+00, 4.77580015e-12, 8.62673306e-11],
       [9.99999990e-01, 7.44782829e-09, 2.49191611e-09],
       [9.99999958e-01, 9.93821203e-09, 3.16033288e-08]])

In [20]:
net.predict_proba(X_wine)[:10]

array([[9.99997359e-01, 1.74069242e-06, 9.00339952e-07],
       [9.99999998e-01, 1.14744691e-09, 4.92073562e-10],
       [1.00000000e+00, 1.53418384e-11, 2.32317785e-11],
       [1.00000000e+00, 3.20726835e-19, 3.52267199e-17],
       [9.35703608e-01, 5.87869811e-02, 5.50941106e-03],
       [1.00000000e+00, 8.85365192e-19, 4.49256502e-17],
       [1.00000000e+00, 1.93530313e-16, 8.48028138e-16],
       [1.00000000e+00, 4.77580015e-12, 8.62673306e-11],
       [9.99999990e-01, 7.44782829e-09, 2.49191611e-09],
       [9.99999958e-01, 9.93821203e-09, 3.16033288e-08]])

In [21]:
# net.predict_log_proba(X_wine[:10])

In [22]:
# net.predict_proba(X_wine[:10])

#### Now define a sampler

In [23]:
# auxiliary = scipy.stats.uniform(-0.2, 1.2)   # i.e. from -0.2 to 1.0

# sampler = maxentropy.utils.auxiliary_sampler_scipy(auxiliary, n_samples=10_000)

In [86]:
# Build the auxiliary sampler that the density models draw from.
# NOTE(review): `bounds_stretched` is assumed to return per-feature lower and
# upper bounds widened by the given factor — confirm against maxentropy.utils.
stretched_minima, stretched_maxima = utils.bounds_stretched(X_wine, 10.0)
# scipy.stats.uniform takes (loc, scale), i.e. (lower bound, width).
uniform_dist = scipy.stats.uniform(
    stretched_minima, stretched_maxima - stretched_minima
)
# Consumed with next(...) in the cells below; one dimension per wine feature.
sampler = utils.auxiliary_sampler_scipy(
    uniform_dist, n_dims=len(wine["feature_names"]), n_samples=100_000
)

In [87]:
np.mean(next(sampler)[0] < 0)

0.4060492307692308

In [88]:
@tz.curry
def non_neg(column, x):
    """Indicator feature: which rows of ``x`` are non-negative in ``column``.

    Curried, so ``non_neg(j)`` gives a one-argument function of ``x``.
    Returns the elementwise result of ``x[:, column] >= 0``.
    """
    selected = x[:, column]
    return selected >= 0

In [89]:
# def non_neg(x):
#     return x >= 0

In [90]:
def scalar(x):
    """Constant feature function: disregards ``x`` and always yields ``1.0``."""
    constant = 1.0
    return constant

In [91]:
# One curried non-negativity indicator per wine feature column.
feature_functions = list(map(non_neg, range(len(wine['feature_names']))))

In [92]:
from maxentropy.utils import feature_sampler

In [93]:
# Wrap the raw sampler together with the feature functions. The cells below
# inspect the three items each draw yields; the first has one column per feature.
sampleFgen = feature_sampler(
    feature_functions,
    sampler,
    matrix_format='ndarray',
    vectorized=True,
)

In [94]:
next(sampleFgen)[0].shape

(100000, 13)

In [95]:
next(sampleFgen)[0].mean()

0.5951792307692307

In [96]:
next(sampleFgen)[1].shape

(100000,)

In [97]:
next(sampleFgen)[2].shape

(100000, 13)

#### The neural network has fit K different models for the K different target classes.

Here we just twiddle the density for the first target class:

In [98]:
outputs = forward_pass(net, X_wine)
outputs[:3]

array([[9.99997359e-01, 1.74069242e-06, 9.00339952e-07],
       [9.99999998e-01, 1.14744691e-09, 4.92073562e-10],
       [1.00000000e+00, 1.53418384e-11, 2.32317785e-11]])

In [162]:
outputs = net.predict_proba(X_wine)
outputs[:3]

array([[9.99997359e-01, 1.74069242e-06, 9.00339952e-07],
       [9.99999998e-01, 1.14744691e-09, 4.92073562e-10],
       [1.00000000e+00, 1.53418384e-11, 2.32317785e-11]])

In [163]:
outputs.mean(axis=0)

array([0.42208937, 0.28603808, 0.29187255])

In [171]:
np.unique(y_wine, return_counts=True)[1]

array([59, 71, 48])

In [173]:
np.bincount(y_wine)

array([59, 71, 48])

In [164]:
centered_outputs = outputs - outputs.mean(axis=0)

In [165]:
np.round(outputs.mean(axis=0), 2)

array([0.42, 0.29, 0.29])

In [166]:
centered_outputs[:3]

array([[ 0.57790799, -0.28603634, -0.29187165],
       [ 0.57791063, -0.28603808, -0.29187255],
       [ 0.57791063, -0.28603808, -0.29187255]])

In [167]:
@tz.curry
def forward_pass_centered(net, target_class, xs):
    """Predicted probability of ``target_class`` at ``xs``, shifted by a global mean.

    The shift is the mean of column ``target_class`` of the module-level
    ``outputs`` array, which must already exist when this is called.
    Curried: ``forward_pass_centered(net, k)`` gives a function of ``xs`` alone.
    """
    class_probability = net.predict_proba(xs)[:, target_class]
    baseline = outputs[:, target_class].mean()
    return class_probability - baseline

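The above seems to work, but the logic is wrong: it returns a mean-shifted probability, whereas it is passed to the model as `prior_log_pdf`, which should be a log-density.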

#### Now do the sensible thing.

In [128]:
@tz.curry
def log_p_x_given_k(net, class_probabilities, target_class, xs):
    """
    Log of p(x | k = target_class), up to an additive term shared by every class k.

    By Bayes' rule:

        p(x | k) = p(k | x) * p(x) / p(k)

    so:

        log p(x | k) = log p(k | x) - log p(k) + log p(x)

    The log p(x) term does not depend on k, so it cancels when the classes are
    compared against each other and is left out here.

    Parameters
    ----------
    net : fitted classifier exposing ``predict_log_proba(xs)`` with one column
        per class.
    class_probabilities : prior class probabilities p(k). Either an array-like
        indexed by class column, or a single scalar already giving
        p(k = target_class).
    target_class : int
        Column index of the class of interest.
    xs : 2-D array of inputs, one row per sample.

    Returns
    -------
    1-D array with one value per row of ``xs``.

    Curried: supplying the first three arguments gives a function of ``xs``.
    """
    priors = np.asarray(class_probabilities, dtype=float)
    # Accept either the full vector of priors or just the target class's prior.
    prior = priors if priors.ndim == 0 else priors[target_class]
    log_posterior = net.predict_log_proba(xs)[:, target_class]
    return log_posterior - np.log(prior)

In [131]:
forward_pass_centered(net, 0, X_wine)[:3]

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


array([0.99999736, 1.        , 1.        ])

In [132]:
# Class whose density is adjusted first; the cells below reuse this name to
# select the matching rows of X_wine.
target_class = 0

# NOTE(review): `prior_log_pdf` is given `forward_pass_centered`, which returns a
# mean-shifted probability rather than a log-density; `log_p_x_given_k` looks
# like the intended replacement — confirm before relying on these results.
model0 = maxentropy.SamplingMinKLDensity(
    feature_functions,
    sampler,
    prior_log_pdf = forward_pass_centered(net, target_class),
    matrix_format='ndarray',
    vectorized=True
)

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


In [133]:
np.array([X_wine.mean()])

array([69.13366292])

In [134]:
# X_wine[y_wine==target_class]

In [135]:
X_wine_subset = X_wine[y_wine == target_class]
X_wine_subset.shape

(59, 13)

In [136]:
k = model0.features(X_wine_subset).mean(axis=0)

In [137]:
model0.fit(k)

In [138]:
model0.feature_expectations()

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [139]:
model0.params

array([18.04244082, 18.88670615, 18.81037928, 19.03515928, 18.82852139,
       18.86219035, 18.84327364, 19.02970859, 19.12122566, 18.89553858,
       18.93644025, 18.67136627, 19.1651229 ])

In [140]:
model0.predict_log_proba(X_wine)[:5]

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


array([-49.50167847, -49.50167583, -49.50167583, -49.50167583,
       -49.56597222])

#### Very low values! Let's proceed anyway. These will be compared against other very low values (for the other classes).

In [143]:
# Repeat the class-0 recipe for the two remaining classes: build a density
# model whose prior comes from the network's output for that class, then fit
# it to the mean feature vector of the rows carrying that label.
target_class = 1

model1 = maxentropy.SamplingMinKLDensity(
    feature_functions,
    sampler,
    prior_log_pdf = forward_pass_centered(net, target_class),
    matrix_format='ndarray',
    vectorized=True
)
X_wine_subset = X_wine[y_wine == target_class]
# Evaluate the features with this model's own evaluator (previously model0's
# was reused here, a copy-paste slip that only worked because all models share
# the same feature functions).
k1 = model1.features(X_wine_subset).mean(axis=0)
model1.fit(k1)

target_class = 2

model2 = maxentropy.SamplingMinKLDensity(
    feature_functions,
    sampler,
    prior_log_pdf = forward_pass_centered(net, target_class),
    matrix_format='ndarray',
    vectorized=True
)
X_wine_subset = X_wine[y_wine == target_class]
k2 = model2.features(X_wine_subset).mean(axis=0)
model2.fit(k2)

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


In [145]:
# Score every wine under each per-class density; after the transpose there is
# one row per wine and one column per class.
log_scores = np.array(
    [density.predict_log_proba(X_wine) for density in (model0, model1, model2)]
).T
log_scores.shape

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  c


(178, 3)

In [146]:
from scipy.special import softmax

In [147]:
# NOTE(review): despite its name, `log_proba` holds probabilities here, not
# logs — softmax exponentiates and normalises each row of `log_scores`.
log_proba = softmax(log_scores, axis=1)

In [148]:
log_proba[:5]

array([[0.31708033, 0.39567387, 0.28724581],
       [0.3170812 , 0.39567322, 0.28724558],
       [0.3170812 , 0.39567322, 0.28724558],
       [0.3170812 , 0.39567322, 0.28724558],
       [0.29562128, 0.41721152, 0.2871672 ]])

In [149]:
net.n_outputs_

3

In [150]:
# Map each row's highest-scoring column back to its class label. This uses the
# public `classes_` attribute rather than the private `_label_binarizer`, whose
# inverse transform picks the same per-row argmax for a multiclass problem.
pred = net.classes_[np.argmax(log_proba, axis=1)]
pred

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

In [151]:
from sklearn.metrics import accuracy_score

In [152]:
accuracy_score(y_wine, pred)

0.651685393258427

### Can we do it just using the neural network's `predict_proba` outputs?

In [153]:
np.sort([4, 1, 2, 3])

array([1, 2, 3, 4])

In [154]:
def thing1(xs):
    """Class-0 output of ``forward_pass_centered`` for the module-level ``net`` at ``xs``."""
    centred_class0 = forward_pass_centered(net, 0)
    return centred_class0(xs)

In [155]:
thing1(X_wine)[:5]

> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(6)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  n


> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(7)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  n


> [0;32m/var/folders/yv/wb9c1bwx4r91d458q1c8ybhm0000gn/T/ipykernel_79125/3802623020.py[0m(8)[0;36mforward_pass_centered[0;34m()[0m
[0;32m      4 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m    [0;31m# return forward_pass_without_output_layer(net, xs)[:, target_class] - outputs[:, target_class].mean()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m    [0moutput1[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m [0;34m-[0m [0moutputs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;34m][0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0moutput2[0m [0;34m=[0m [0mnet[0m[0;34m.[0m[0mpredict_proba[0m[0;34m([0m[0mxs[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mtarget_class[0m[0;

ipdb>  output1


array([-1.76012888e-06,  8.79263968e-07,  8.80864915e-07,  8.80903489e-07,
       -6.42955113e-02,  8.80903489e-07,  8.80903488e-07,  8.80812446e-07,
        8.70963744e-07,  8.39361948e-07,  8.80903489e-07,  8.80903480e-07,
        8.80903489e-07,  8.80903235e-07,  8.80903489e-07,  8.80898807e-07,
        8.79823998e-07, -8.89184092e-07,  8.80903489e-07, -3.15174551e-05,
       -2.25735595e-03, -8.13690342e-04,  5.40839582e-07,  6.79175969e-07,
       -5.79020039e-03, -2.35278079e-01,  8.80903310e-07,  8.80903486e-07,
       -6.02173393e-04,  8.38470526e-07,  8.80902590e-07,  8.80903489e-07,
       -3.13636539e-05, -2.33867324e-07,  1.65968950e-07, -6.99041607e-04,
       -1.84969197e-04,  8.78467750e-07,  7.67935198e-07, -2.62857346e-05,
       -7.15534825e-04,  8.68110801e-07,  8.79861368e-07, -2.61525700e-03,
       -1.87494263e-04, -6.38816349e-06,  8.55978661e-07, -2.79227015e-06,
        4.35627563e-08,  8.80873568e-07,  8.80902910e-07,  8.80903480e-07,
        8.80826624e-07,  

ipdb>  output2


array([9.99997359e-01, 9.99999998e-01, 1.00000000e+00, 1.00000000e+00,
       9.35703608e-01, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
       9.99999990e-01, 9.99999958e-01, 1.00000000e+00, 1.00000000e+00,
       1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
       9.99999999e-01, 9.99998230e-01, 1.00000000e+00, 9.99967602e-01,
       9.97741763e-01, 9.99185429e-01, 9.99999660e-01, 9.99999798e-01,
       9.94208919e-01, 7.64721040e-01, 1.00000000e+00, 1.00000000e+00,
       9.99396946e-01, 9.99999958e-01, 1.00000000e+00, 1.00000000e+00,
       9.99967755e-01, 9.99998885e-01, 9.99999285e-01, 9.99300077e-01,
       9.99814150e-01, 9.99999998e-01, 9.99999887e-01, 9.99972833e-01,
       9.99283584e-01, 9.99999987e-01, 9.99999999e-01, 9.97383862e-01,
       9.99811625e-01, 9.99992731e-01, 9.99999975e-01, 9.99996327e-01,
       9.99999163e-01, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
       1.00000000e+00, 1.00000000e+00, 9.99980215e-01, 9.99992664e-01,
      

ipdb>  output1


array([-1.76012888e-06,  8.79263968e-07,  8.80864915e-07,  8.80903489e-07,
       -6.42955113e-02,  8.80903489e-07,  8.80903488e-07,  8.80812446e-07,
        8.70963744e-07,  8.39361948e-07,  8.80903489e-07,  8.80903480e-07,
        8.80903489e-07,  8.80903235e-07,  8.80903489e-07,  8.80898807e-07,
        8.79823998e-07, -8.89184092e-07,  8.80903489e-07, -3.15174551e-05,
       -2.25735595e-03, -8.13690342e-04,  5.40839582e-07,  6.79175969e-07,
       -5.79020039e-03, -2.35278079e-01,  8.80903310e-07,  8.80903486e-07,
       -6.02173393e-04,  8.38470526e-07,  8.80902590e-07,  8.80903489e-07,
       -3.13636539e-05, -2.33867324e-07,  1.65968950e-07, -6.99041607e-04,
       -1.84969197e-04,  8.78467750e-07,  7.67935198e-07, -2.62857346e-05,
       -7.15534825e-04,  8.68110801e-07,  8.79861368e-07, -2.61525700e-03,
       -1.87494263e-04, -6.38816349e-06,  8.55978661e-07, -2.79227015e-06,
        4.35627563e-08,  8.80873568e-07,  8.80902910e-07,  8.80903480e-07,
        8.80826624e-07,  

ipdb>  q


In [158]:
def thing2(xs):
    """Class-0 predicted probability at ``xs`` minus the mean of ``outputs[:, 0]``.

    Written out directly against the module-level ``net`` and ``outputs``,
    without going through the curried helper.
    """
    class0_probability = net.predict_proba(xs)[:, 0]
    class0_baseline = outputs[:, 0].mean()
    return class0_probability - class0_baseline

In [159]:
thing2(X_wine)[:5]

array([-1.76012888e-06,  8.79263968e-07,  8.80864915e-07,  8.80903489e-07,
       -6.42955113e-02])

#### By hand ...

In [66]:
# Fit one density model per class label and collect them in a dict keyed by
# label (insertion order follows the sorted labels).
models = {}
for target_class in np.sort(np.unique(y_wine)):
    print(f'Target class {target_class}')
    model = maxentropy.SamplingMinKLDensity(
        feature_functions,
        sampler,
        # Doesn't work:
        # prior_log_pdf = lambda xs: net.predict_log_proba(xs)[:, target_class] - outputs[:, target_class].mean(),
        # NOTE(review): a lambda here would read `target_class` when it is
        # called, not when it is created, so once the loop has finished every
        # stored model would see the last class — possibly part of why it
        # failed. The curried call below fixes the class at construction time.
        prior_log_pdf = forward_pass_centered(net, target_class),
        # prior_log_pdf = lambda xs: net.predict_proba(xs)[:, target_class],
        matrix_format='ndarray',
        vectorized=True
    )
    # Constraint targets: mean of each feature over the rows with this label.
    X_wine_subset = X_wine[y_wine == target_class]
    k = model.features(X_wine_subset).mean(axis=0)
    print(k)
    model.fit(k)
    print(model.params)
    models[target_class] = model

Target class 0
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[7.43029761e-07 1.77536620e+01 7.43029761e-07 7.43029761e-07
 7.43029761e-07 1.69891836e+01 1.79195615e+01 1.74790192e+01
 1.78988756e+01 1.79444659e+01 1.64701302e+01 1.53573105e+01
 1.77197138e+01]
Target class 1
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1.56779767e-06 1.79119799e+01 1.56779767e-06 1.56779767e-06
 1.56779767e-06 1.69126670e+01 1.80730683e+01 1.73732919e+01
 1.79330970e+01 1.81208369e+01 1.65728848e+01 1.53448655e+01
 1.75967791e+01]
Target class 2
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[3.75686978e-06 1.78918647e+01 3.75686978e-06 3.75686978e-06
 3.75686978e-06 1.69863851e+01 1.83348892e+01 1.74767552e+01
 1.79699826e+01 1.76866877e+01 1.67143691e+01 1.54995871e+01
 1.74733545e+01]


In [67]:
# Column j holds the scores from the j-th model stored in `models`; this
# relies on the dict preserving the order in which the classes were inserted.
log_scores = np.array(
    [class_model.predict_log_proba(X_wine) for class_model in models.values()]
).T
log_scores.shape

(178, 3)

In [68]:
log_scores[:5]

array([[-32.67274158, -32.89903275, -33.03076447],
       [-32.67273895, -32.89903449, -33.03076537],
       [-32.67273894, -32.89903449, -33.03076537],
       [-32.67273894, -32.89903449, -33.03076537],
       [-32.73703534, -32.84024751, -33.02525596]])

In [69]:
from scipy.special import logsumexp

In [70]:
# Normalise each row in log space: subtracting the row's log-sum-exp makes
# exp(log_proba) sum to 1 across the class columns.
log_proba = log_scores - logsumexp(log_scores, axis=1, keepdims=True)
log_proba[:5]

array([[-0.91490695, -1.14119812, -1.27292984],
       [-0.91490456, -1.14120011, -1.27293099],
       [-0.91490456, -1.14120011, -1.27293099],
       [-0.91490456, -1.14120011, -1.27293099],
       [-0.9751375 , -1.07834968, -1.26335813]])

In [71]:
np.exp(log_proba)[:5]

array([[0.4005539 , 0.31943607, 0.28001003],
       [0.40055485, 0.31943543, 0.28000971],
       [0.40055485, 0.31943543, 0.28000971],
       [0.40055485, 0.31943543, 0.28000971],
       [0.37714049, 0.34015643, 0.28270308]])

In [72]:
# Cross-check: normalising the log-scores with softmax gives the same values
# as np.exp(log_proba) in the previous cell (compare the two outputs).
proba = softmax(log_scores, axis=1)
proba[:5]

array([[0.4005539 , 0.31943607, 0.28001003],
       [0.40055485, 0.31943543, 0.28000971],
       [0.40055485, 0.31943543, 0.28000971],
       [0.40055485, 0.31943543, 0.28000971],
       [0.37714049, 0.34015643, 0.28270308]])

In [73]:
np.argmax(log_proba, axis=1)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

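Choose the offset (25.3 below) that makes each row sum of `np.exp(log_scores + offset)` roughly 1. (With the scores shown above, 25.3 gives row sums of only about 0.0016, so this value needs re-tuning.)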

In [74]:
# np.exp(log_scores + 25.3).sum(axis=1)

In [75]:
np.exp(log_scores + 25.3)[:5]

array([[0.00062814, 0.00050094, 0.00043911],
       [0.00062815, 0.00050093, 0.00043911],
       [0.00062815, 0.00050093, 0.00043911],
       [0.00062815, 0.00050093, 0.00043911],
       [0.00058903, 0.00053127, 0.00044153]])

In [76]:
np.argmax(log_scores, axis=1)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

In [77]:
# np.max(log_scores, axis=1) - np.min(log_scores, axis=1)

In [78]:
# softmax(np.exp(log_scores + 25.3), axis=1)[:5]

In [79]:
# log_scores

In [80]:
# log_proba = softmax(log_scores, axis=1)
# log_proba

In [81]:
# Map each row's highest-scoring column back to its class label.
# This goes through the public `classes_` attribute instead of the private
# `_label_binarizer` helper, which is an sklearn implementation detail.
# Like the original, it assumes the columns of log_proba follow the same
# (sorted) label order as net.classes_.
pred = net.classes_[np.argmax(log_proba, axis=1)]

In [82]:
pred

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

In [83]:
accuracy_score(y_wine, pred)

0.9438202247191011

In [84]:
net.predict(X_wine)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 1, 1, 1, 0,
       0, 1, 0, 1, 2, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0,
       0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

### Using MinKLClassifier

In [85]:
# Same idea as the per-class loop, but through the single MinKLClassifier
# estimator, using the trained net's log-probabilities as the prior.
# NOTE(review): the recorded output for this cell is a bare AssertionError;
# the output does not show which statement raised it — investigate.
clf = maxentropy.MinKLClassifier(
    feature_functions,
    sampler,
    prior_log_proba_fn=net.predict_log_proba,
    # Disabled alternative. Note it passes three arguments to
    # forward_pass_centered, whereas the per-class loop calls it with two.
    # prior_log_proba_fn=lambda xs: forward_pass_centered(net, slice(None), xs),
    matrix_format='ndarray',
    vectorized=True
)
clf.fit(X_wine, y_wine)

AssertionError: 

In [None]:
# clf.predict_proba(X_wine)

In [None]:
clf.predict(X_wine)

In [None]:
# Same per-class fitting loop as in the "By hand" section, minus the
# diagnostic prints of k and model.params. It rebinds `models` from scratch.
# NOTE(review): copy of an earlier cell — consider a shared helper function.
models = {}
for target_class in np.sort(np.unique(y_wine)):
    print(f'Target class {target_class}')
    model = maxentropy.SamplingMinKLDensity(
        feature_functions,
        sampler,
        # Alternative prior that was tried and did not work:
        # prior_log_pdf = lambda xs: net.predict_log_proba(xs)[:, target_class] - outputs[:, target_class].mean(),
        # Prior actually used: built from the trained net for this class by
        # forward_pass_centered (defined outside this cell).
        prior_log_pdf = forward_pass_centered(net, target_class),
        # Another alternative prior, left disabled:
        # prior_log_pdf = lambda xs: net.predict_proba(xs)[:, target_class],
        matrix_format='ndarray',
        vectorized=True
    )
    # Rows of X_wine whose label is the current class.
    X_wine_subset = X_wine[y_wine == target_class]
    # Target values for the fit: the feature functions averaged over those rows.
    k = model.features(X_wine_subset).mean(axis=0)
    model.fit(k)
    models[target_class] = model

### Ideas for improving the usability

##### Current API

In [None]:
# NOTE(review): presumably a deliberately undefined name, so that "Run All"
# stops here with a NameError before the API sketches below — confirm.
BREAK

In [None]:
def non_neg(x):
    """Indicator feature: true where ``x`` is zero or positive.

    The comparison is applied as-is, so array-like inputs that support
    elementwise comparison (e.g. NumPy arrays) yield an elementwise result,
    while plain Python numbers yield a single bool.
    """
    return 0 <= x

# Fit a normal distribution to the observed 'mean concavity' values, then
# freeze it so that its logpdf can be passed as the prior log-density below.
# (Previously only the fitted parameters existed; `prior_model` was undefined.)
prior_model_params = scipy.stats.norm.fit(df_cancer['mean concavity'])
prior_model = scipy.stats.norm(*prior_model_params)

# Auxiliary distribution the sampler draws from.
auxiliary = scipy.stats.uniform(-0.2, 1.2)   # i.e. from -0.2 to 1.0

sampler = maxentropy.utils.auxiliary_sampler_scipy(auxiliary, n_samples=10_000)

model = maxentropy.SamplingMinKLDensity(
    [non_neg], sampler, prior_log_pdf=prior_model.logpdf, matrix_format='ndarray',
)

# X_cancer is a plain ndarray (cancer['data'].values) and cannot be indexed
# by column name, so the column mean is taken from the DataFrame instead.
k = model.features(np.array([df_cancer['mean concavity'].mean()]))

model.fit(k)

##### Desired API

In [None]:
# Sketch of the desired constructor: the sampler is selected by name and
# configured through keyword arguments; no feature functions are passed here.
model = maxentropy.SamplingMinKLDensity(
    sampler='uniform',
    matrix_format='ndarray',
    sampling_stretch_factor=0.1,
    n_samples=10_000,
)

In [None]:
# One non-negativity constraint per column of X_cancer.
feature_functions = [non_neg for _ in range(X_cancer.shape[1])]

# In the desired API the data and the feature functions are given at fit time.
model.fit(X_cancer, feature_functions=feature_functions)

In [None]:
def non_neg(x):
    """Return whether ``x`` is greater than or equal to zero.

    The comparison is delegated to ``x``, so array-like inputs that compare
    elementwise (e.g. NumPy arrays) give an elementwise boolean result.
    """
    return 0 <= x