-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy path test_fmmd.py
63 lines (46 loc) · 2.01 KB
/
test_fmmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pytest
import numpy as np
import tensorflow as tf
from adapt.feature_based import fMMD
from adapt.feature_based._fmmd import _get_optim_function
# Synthetic domain-shift data shared by the tests below: source (Xs) and
# target (Xt) use the same Gaussian noise (std 0.1), but every source row is
# shifted by +2 on its last two features only.
np.random.seed(0)  # seeds numpy's *global* RNG so Xs / Xt are reproducible
n, m, p = 50, 50, 6  # n: target rows, m: source rows, p: features
Xs = 0.1 * np.random.randn(m, p) + np.r_[np.zeros(p - 2), 2., 2.]
Xt = 0.1 * np.random.randn(n, p)
def _assert_shift_detected(fmmd):
    """Assert the two shifted (last) features dominate ``features_scores_``."""
    scores = fmmd.features_scores_
    assert scores[-2:].sum() > 10 * scores[:-2].sum()


def _check_fmmd(Xt_target):
    """Fit fMMD on ``(Xs, Xt_target)`` with several kernels and check results.

    Only the last two columns of ``Xs`` are shifted relative to the target,
    so for every kernel tried (default, "rbf", "poly") those two columns
    must score more than ten times the combined score of the others. For the
    default kernel, the selection mask and ``transform`` output are also
    checked: the four unshifted columns are kept, the two shifted ones dropped.

    Parameters
    ----------
    Xt_target : numpy array
        Target-domain sample passed to ``fit_transform`` alongside ``Xs``.
    """
    fmmd = fMMD()
    fmmd.fit_transform(Xs, Xt_target)
    _assert_shift_detected(fmmd)
    # assert_array_equal also checks shape, unlike ``np.all(a == b)`` or a
    # summed absolute difference, which can silently broadcast.
    np.testing.assert_array_equal(
        fmmd.selected_features_, [True] * 4 + [False] * 2
    )
    np.testing.assert_array_equal(fmmd.transform(Xs), Xs[:, :4])

    # Parameters accumulate on the same estimator, exactly as set_params
    # was called sequentially in the original tests.
    for params in (
        {"kernel": "rbf"},
        {"kernel": "poly", "degree": 2, "gamma": 0.1},
    ):
        fmmd.set_params(**params)
        fmmd.fit_transform(Xs, Xt_target)
        _assert_shift_detected(fmmd)


def test_fmmd():
    """fMMD detects the shifted features when both domains have 50 rows."""
    _check_fmmd(Xt)


def test_fmmd_diff_size():
    """fMMD still works when the target sample (40 rows) is smaller than the source (50)."""
    _check_fmmd(Xt[:40])
def test_kernel_fct():
    """Calling the fMMD optimisation function on ``tf.identity(np.ones(6))`` raises.

    For each kernel ("linear", "rbf", "poly"), build the function with
    ``_get_optim_function(Xs, Xt, kernel=...)`` and check that calling it on
    a tensor of six ones raises an exception.

    NOTE(review): which exception is raised, and why (dtype? shape?), is not
    visible from this file, so ``pytest.raises(Exception)`` is kept as broad
    as the original. Narrow it once the expected error is confirmed.
    """
    # Force eager execution for this test only, then restore the previous
    # setting so the global TF configuration does not leak into other tests.
    # (``experimental_run_functions_eagerly`` is the deprecated alias.)
    previously_eager = tf.config.functions_run_eagerly()
    tf.config.run_functions_eagerly(True)
    try:
        for kernel in ("linear", "rbf", "poly"):
            fct = _get_optim_function(Xs, Xt, kernel=kernel)
            with pytest.raises(Exception):
                fct(tf.identity(np.ones(6)))
    finally:
        tf.config.run_functions_eagerly(previously_eager)