-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathtest_wdgrl.py
84 lines (71 loc) · 2.7 KB
/
test_wdgrl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Test functions for wdgrl module.
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import GlorotUniform
from adapt.feature_based import WDGRL
# Synthetic two-domain toy problem: both domains share an identical 1D
# grid in column 0; column 1 is a domain indicator (0 = source, 1 = target).
# Labels are the same linear function of column 0 in both domains.
_grid = np.linspace(0, 1, 100)
Xs = np.column_stack((_grid, np.zeros(100)))
Xt = np.column_stack((_grid, np.ones(100)))
ys = 0.2 * Xs[:, 0].ravel()
yt = 0.2 * Xt[:, 0].ravel()
def _get_encoder(input_shape=Xs.shape[1:]):
    """Build a trivial linear encoder.

    A single Dense(1) unit with no bias and all-ones kernel, so at
    initialization it simply sums the input columns (feature + domain
    flag). The tests below inspect this kernel after fitting.
    """
    encoder = Sequential([
        Dense(1,
              input_shape=input_shape,
              kernel_initializer="ones",
              use_bias=False),
    ])
    encoder.compile(loss="mse", optimizer="adam")
    return encoder
def _get_discriminator(input_shape=(1,)):
    """Build the critic network for the Wasserstein discrepancy.

    Two-layer MLP: 10 elu units followed by a single linear output.
    Glorot initializers are seeded so test runs are reproducible.
    """
    critic = Sequential([
        Dense(10,
              input_shape=input_shape,
              kernel_initializer=GlorotUniform(seed=0),
              activation="elu"),
        Dense(1,
              kernel_initializer=GlorotUniform(seed=0),
              activation=None),
    ])
    critic.compile(loss="mse", optimizer="adam")
    return critic
def _get_task(input_shape=(1,), output_shape=(1,)):
    """Build the task (regressor) head.

    A single bias-free linear layer mapping encoded features to
    np.prod(output_shape) outputs, with a seeded Glorot initializer
    and a relatively large learning rate (0.1) for fast convergence
    in the short test runs.
    """
    head = Sequential([
        Dense(np.prod(output_shape),
              input_shape=input_shape,
              kernel_initializer=GlorotUniform(seed=0),
              use_bias=False),
    ])
    head.compile(loss="mse", optimizer=Adam(0.1))
    return head
def test_fit_lambda_zero():
    """With lambda_=0 the adversarial term is off.

    The encoder has no incentive to discard the domain-flag column
    (its weight stays at the initial 1.0), so the model fits the
    source domain well but fails to transfer to the target domain.
    """
    tf.random.set_seed(1)
    np.random.seed(1)
    wdgrl = WDGRL(_get_encoder(), _get_task(), _get_discriminator(),
                  lambda_=0, loss="mse", optimizer=Adam(0.01),
                  metrics=["mse"], random_state=0)
    wdgrl.fit(Xs, ys, Xt, yt, epochs=300, verbose=0)
    assert isinstance(wdgrl, Model)
    # Domain-flag weight untouched: no adversarial pressure to drop it.
    assert wdgrl.encoder_.get_weights()[0][1][0] == 1.0
    assert np.sum(np.abs(wdgrl.predict(Xs).ravel() - ys)) < 0.01
    assert np.sum(np.abs(wdgrl.predict(Xt).ravel() - yt)) > 10
def test_fit_lambda_one():
    """With lambda_=1 the adversarial term is active.

    The encoder is pushed to shrink the domain-flag weight relative
    to the feature weight (ratio below 0.3), yielding reasonable
    predictions on both source and target domains.
    """
    tf.random.set_seed(1)
    np.random.seed(1)
    wdgrl = WDGRL(_get_encoder(), _get_task(), _get_discriminator(),
                  lambda_=1, gamma=0, loss="mse", optimizer=Adam(0.01),
                  metrics=["mse"], random_state=0)
    wdgrl.fit(Xs, ys, Xt, yt, epochs=300, verbose=0)
    assert isinstance(wdgrl, Model)
    kernel = wdgrl.encoder_.get_weights()[0]
    # Domain-flag weight suppressed relative to the feature weight.
    assert np.abs(kernel[1][0] / kernel[0][0]) < 0.3
    assert np.sum(np.abs(wdgrl.predict(Xs).ravel() - ys)) < 2
    assert np.sum(np.abs(wdgrl.predict(Xt).ravel() - yt)) < 5