Skip to content

Commit

Permalink
Fix failing test caused by a change in multivariate_normal sequence.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Jun 13, 2023
1 parent a219252 commit 4dbde09
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions skfda/ml/classification/_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,32 +58,39 @@ class LogisticRegression(
intercept\_: Independent term.
Examples:
>>> from numpy import array
>>> import skfda
>>> from skfda.datasets import make_gaussian_process
>>> from skfda.ml.classification import LogisticRegression
>>> fd1 = make_gaussian_process(
... n_samples=50,
... n_features=100,
... noise=0.7,
... random_state=0,
>>>
>>> n_samples = 10000
>>> n_features = 200
>>>
>>> def mean_1(t):
... return (np.abs(t - 0.25)
... - 2 * np.abs(t - 0.5)
... + np.abs(t - 0.75))
>>>
>>> X_0 = make_gaussian_process(
... n_samples=n_samples // 2,
... n_features=n_features,
... random_state=0,
... )
>>> fd2 = make_gaussian_process(
... n_samples=50,
... n_features = 100,
... mean = array([1]*100),
... noise = 0.7,
... random_state=0
>>> X_1 = make_gaussian_process(
... n_samples=n_samples // 2,
... n_features=n_features,
... mean=mean_1,
... random_state=1,
... )
>>> fd = fd1.concatenate(fd2)
>>> y = 50*[0] + 50*[1]
>>> lr = LogisticRegression()
>>> _ = lr.fit(fd[::2], y[::2])
>>> lr.coef_.round(2)
array([[ 18.91, 19.69, 19.9 , 6.09, 12.49]])
>>> lr.points_.round(2)
array([ 0.11, 0.06, 0.07, 0.02, 0.03])
>>> lr.score(fd[1::2],y[1::2])
0.92
>>> X = skfda.concatenate((X_0, X_1))
>>>
>>> y = np.zeros(n_samples)
>>> y [n_samples // 2:] = 1
>>> lr = LogisticRegression(max_features=3)
>>> _ = lr.fit(X[::2], y[::2])
>>> np.allclose(sorted(lr.points_), [0.25, 0.5, 0.75], rtol=1e-2)
True
>>> lr.score(X[1::2],y[1::2])
0.7498
References:
.. footbibliography::
Expand Down

0 comments on commit 4dbde09

Please sign in to comment.