-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathtest_iwc.py
46 lines (38 loc) · 1.32 KB
/
test_iwc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
Test functions for the IWC module.
"""
import numpy as np
from sklearn.linear_model import RidgeClassifier
from adapt.utils import make_classification_da
from adapt.instance_based import IWC
from adapt.utils import get_default_discriminator
from tensorflow.keras.optimizers import Adam
# Shared dataset, built once at import time and reused by every test below.
# NOTE(review): presumably (Xs, ys) is the labeled source domain and
# (Xt, yt) the target domain -- confirm against make_classification_da.
Xs, ys, Xt, yt = make_classification_da()
def test_iwn():
    """Fit IWC with an explicit ``RidgeClassifier`` passed as ``classifier``.

    Runs ``predict`` and ``score`` on the target data as smoke checks, then
    asserts that ``predict_weights()`` and ``predict_weights(Xs)`` agree:
    the summed absolute difference must stay below ``10**-5``.
    """
    estimator = RidgeClassifier(0.)
    ridge_classifier = RidgeClassifier(0.)
    iwc = IWC(
        estimator,
        classifier=ridge_classifier,
        Xt=Xt,
        random_state=0,
    )
    iwc.fit(Xs, ys)

    # Return values are discarded: these calls only need to complete.
    iwc.predict(Xt)
    iwc.score(Xt, yt)

    default_weights = iwc.predict_weights()
    source_weights = iwc.predict_weights(Xs)
    total_gap = np.abs(default_weights - source_weights).sum()
    tolerance = 10 ** -5
    assert total_gap < tolerance
def test_default_classif():
    """Fit IWC with ``classifier=None`` and check weight consistency.

    NOTE(review): passing ``None`` presumably makes IWC fall back to its
    built-in default classifier -- confirm against the IWC implementation.

    Runs ``predict`` and ``score`` on the target data as smoke checks, then
    asserts that ``predict_weights()`` and ``predict_weights(Xs)`` agree:
    the summed absolute difference must stay below ``10**-5``.
    """
    estimator = RidgeClassifier(0.)
    iwc = IWC(
        estimator,
        classifier=None,
        Xt=Xt,
        random_state=0,
    )
    iwc.fit(Xs, ys)

    # Return values are discarded: these calls only need to complete.
    iwc.predict(Xt)
    iwc.score(Xt, yt)

    default_weights = iwc.predict_weights()
    source_weights = iwc.predict_weights(Xs)
    total_gap = np.abs(default_weights - source_weights).sum()
    tolerance = 10 ** -5
    assert total_gap < tolerance
def test_nn_classif():
    """Fit IWC with the network from ``get_default_discriminator()``.

    The network is passed as ``classifier`` and its fit/compile settings
    (10 epochs, a fresh ``Adam`` optimizer, ``"bce"`` loss, silent output)
    are forwarded through ``cl_params``.

    Runs ``predict`` and ``score`` on the target data as smoke checks, then
    asserts that ``predict_weights()`` and ``predict_weights(Xs)`` agree:
    the summed absolute difference must stay below ``10**-5``.
    """
    estimator = RidgeClassifier(0.)
    discriminator = get_default_discriminator()
    classifier_settings = {
        "epochs": 10,
        "optimizer": Adam(),
        "loss": "bce",
        "verbose": 0,
    }
    iwc = IWC(
        estimator,
        classifier=discriminator,
        cl_params=classifier_settings,
        Xt=Xt,
        random_state=0,
    )
    iwc.fit(Xs, ys)

    # Return values are discarded: these calls only need to complete.
    iwc.predict(Xt)
    iwc.score(Xt, yt)

    default_weights = iwc.predict_weights()
    source_weights = iwc.predict_weights(Xs)
    total_gap = np.abs(default_weights - source_weights).sum()
    tolerance = 10 ** -5
    assert total_gap < tolerance