
Commit 7441db8

Fix RegularTransfer for Tensorflow 2.3.1
1 parent e18e4d2 commit 7441db8

7 files changed: +33 -15 lines

adapt/base.py

Lines changed: 2 additions & 1 deletion
@@ -1351,7 +1351,8 @@ def train_step(self, data):
         loss = tf.reduce_mean(loss)

         # Run backwards pass.
-        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+        gradients = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
         self.compiled_metrics.update_state(ys, y_pred)
         # Collect metrics to return
         return_metrics = {}
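
The one-call form `optimizer.minimize(loss, variables, tape=tape)` only works on newer TensorFlow releases (the `tape` keyword is not available in TF 2.3.x), so splitting it into an explicit `tape.gradient` / `apply_gradients` pair keeps `train_step` compatible with TensorFlow 2.3.1. A minimal, self-contained sketch of the same pattern; the toy model, data, and loss are placeholders, not from the commit:

import tensorflow as tf

# Toy model and batch, for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))

with tf.GradientTape() as tape:
    y_pred = model(x, training=True)
    loss = tf.reduce_mean(tf.square(y - y_pred))

# Equivalent to optimizer.minimize(loss, vars, tape=tape),
# but also runs on TF 2.3.x:
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))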

adapt/feature_based/_coral.py

Lines changed: 1 addition & 2 deletions
@@ -86,8 +86,7 @@ class CORAL(BaseAdaptEstimator):
     See also
     --------
     DeepCORAL
-    FE
-    mSDA
+    FA

     References
     ----------

adapt/feature_based/_dann.py

Lines changed: 7 additions & 1 deletion
@@ -2,6 +2,7 @@
 DANN
 """

+import warnings
 import numpy as np
 import tensorflow as tf

@@ -104,12 +105,17 @@ def __init__(self,
                  Xt=None,
                  yt=None,
                  lambda_=0.1,
-                 gamma=10.,
                  verbose=1,
                  copy=True,
                  random_state=None,
                  **params):

+        if "gamma" in params:
+            warnings.warn("the `gamma` argument has been removed from DANN."
+                          " If you want to use the lambda update process, please"
+                          " use the `UpdateLambda` callback from adapt.utils")
+            params.pop("gamma")
+
         names = self._get_param_names()
         kwargs = {k: v for k, v in locals().items() if k in names}
         kwargs.update(params)
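
A hedged sketch of the migration path the warning points to: instead of passing `gamma` to `DANN`, the lambda update schedule is handled by the `UpdateLambda` callback from `adapt.utils`. The data arrays below are toy placeholders, the callback is used with its default schedule (its constructor parameters are not shown in this commit), and the example assumes `fit` forwards Keras arguments such as `callbacks`:

import numpy as np
from adapt.feature_based import DANN
from adapt.utils import UpdateLambda

# Toy source/target data, for illustration only.
np.random.seed(0)
Xs, ys = np.random.randn(100, 2), np.random.randint(2, size=100)
Xt = np.random.randn(100, 2)

# lambda_ starts at 0 and is updated during training by the callback.
model = DANN(Xt=Xt, lambda_=0., random_state=0, verbose=0)
model.fit(Xs, ys, epochs=5, verbose=0, callbacks=[UpdateLambda()])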

adapt/feature_based/_fa.py

Lines changed: 0 additions & 1 deletion
@@ -69,7 +69,6 @@ class FA(BaseAdaptEstimator):
     See also
     --------
     CORAL
-    mSDA

     Examples
     --------

adapt/feature_based/_fmmd.py

Lines changed: 5 additions & 5 deletions
@@ -37,7 +37,7 @@ def func(W):
         Kxy = tf.matmul(tf.matmul(Xs, tf.linalg.diag(W**1)), tf.transpose(Xt))

         K = tf.concat((Kxx, Kxy), axis=1)
-        K = tf.concat((K, tf.concat((Kyy, tf.transpose(Kxy)), axis=1)), axis=0)
+        K = tf.concat((K, tf.concat((tf.transpose(Kxy), Kyy), axis=1)), axis=0)

         f = -tf.linalg.trace(tf.matmul(K, L))
         Df = tf.gradients(f, W)
@@ -53,7 +53,7 @@ def func(W):
         Kxy = pairwise_X(tf.matmul(Xs, tf.linalg.diag(W**1)), Xt)

         K = tf.concat((Kxx, Kxy), axis=1)
-        K = tf.concat((K, tf.concat((Kyy, tf.transpose(Kxy)), axis=1)), axis=0)
+        K = tf.concat((K, tf.concat((tf.transpose(Kxy), Kyy), axis=1)), axis=0)
         K = tf.exp(-gamma * K)

         f = -tf.linalg.trace(tf.matmul(K, L))
@@ -70,7 +70,7 @@ def func(W):
         Kxy = tf.matmul(tf.matmul(Xs, tf.linalg.diag(W**1)), tf.transpose(Xt))

         K = tf.concat((Kxx, Kxy), axis=1)
-        K = tf.concat((K, tf.concat((Kyy, tf.transpose(Kxy)), axis=1)), axis=0)
+        K = tf.concat((K, tf.concat((tf.transpose(Kxy), Kyy), axis=1)), axis=0)
         K = (gamma * K + coef)**degree

         f = -tf.linalg.trace(tf.matmul(K, L))
@@ -144,7 +144,7 @@ class fMMD(BaseAdaptEstimator):
     See also
     --------
     CORAL
-    FE
+    FA

     Examples
     --------
@@ -155,7 +155,7 @@ class fMMD(BaseAdaptEstimator):
     >>> model = fMMD(RidgeClassifier(), Xt=Xt, kernel="linear", random_state=0, verbose=0)
     >>> model.fit(Xs, ys)
     >>> model.score(Xt, yt)
-    0.45
+    0.92

     References
     ----------
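
The three `func` fixes are the same bug: the bottom block row of the full kernel matrix was assembled as (Kyy, Kxy^T) instead of (Kxy^T, Kyy), so K was not the symmetric block matrix [[Kxx, Kxy], [Kxy^T, Kyy]] and trace(KL) mixed up the source/target blocks. That is also why the doctest score jumps from 0.45 to 0.92. A small NumPy sketch of the corrected assembly, with toy data not taken from the repository:

import numpy as np

rng = np.random.default_rng(0)
Xs = rng.normal(size=(5, 3))   # 5 source samples
Xt = rng.normal(size=(4, 3))   # 4 target samples

Kxx = Xs @ Xs.T                # (5, 5) source/source block
Kyy = Xt @ Xt.T                # (4, 4) target/target block
Kxy = Xs @ Xt.T                # (5, 4) cross block

top = np.concatenate((Kxx, Kxy), axis=1)       # (5, 9)
bottom = np.concatenate((Kxy.T, Kyy), axis=1)  # (4, 9), fixed ordering
K = np.concatenate((top, bottom), axis=0)      # (9, 9)

# With the corrected ordering, K is symmetric, as a kernel matrix must be.
assert np.allclose(K, K.T)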

adapt/instance_based/_kliep.py

Lines changed: 10 additions & 4 deletions
@@ -349,6 +349,7 @@ def _fit(self, Xs, Xt, kernel_params):
         b = b.reshape(-1, 1)

         alpha = np.ones((len(centers), 1)) / len(centers)
+        alpha = self._projection(alpha, b)
         previous_objective = -np.inf
         objective = np.mean(np.log(np.dot(A, alpha) + EPS))
         if self.verbose > 1:
@@ -360,10 +361,7 @@ def _fit(self, Xs, Xt, kernel_params):
             alpha += self.lr * np.dot(
                 np.transpose(A), 1./(np.dot(A, alpha) + EPS)
             )
-            alpha += b * ((((1-np.dot(np.transpose(b), alpha)) /
-                           (np.dot(np.transpose(b), b) + EPS))))
-            alpha = np.maximum(0, alpha)
-            alpha /= (np.dot(np.transpose(b), alpha) + EPS)
+            alpha = self._projection(alpha, b)
             objective = np.mean(np.log(np.dot(A, alpha) + EPS))
             k += 1

@@ -374,6 +372,14 @@ def _fit(self, Xs, Xt, kernel_params):
         return alpha, centers


+    def _projection(self, alpha, b):
+        alpha += b * ((((1-np.dot(np.transpose(b), alpha)) /
+                       (np.dot(np.transpose(b), b) + EPS))))
+        alpha = np.maximum(0, alpha)
+        alpha /= (np.dot(np.transpose(b), alpha) + EPS)
+        return alpha
+
+
     def _cross_val_jscore(self, Xs, Xt, kernel_params, cv):
         split = int(len(Xt) / cv)
         cv_scores = []
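
This refactor pulls KLIEP's feasibility projection out into `_projection` and also applies it to the initial uniform `alpha` (previously only the iterates were projected, so the first objective was computed on an unprojected point). The projection does three things: move `alpha` back onto the constraint hyperplane b^T alpha = 1, clip negative coefficients, and renormalize so the constraint holds after clipping. A standalone sketch of the same steps; the `EPS` here is a hypothetical stand-in for the module-level constant:

import numpy as np

EPS = 1e-12  # stand-in for the module's EPS constant

def projection(alpha, b):
    # (1) Project alpha onto the hyperplane b^T alpha = 1.
    alpha = alpha + b * (1. - b.T @ alpha) / (b.T @ b + EPS)
    # (2) Clip negative coefficients (alpha must stay non-negative).
    alpha = np.maximum(0., alpha)
    # (3) Renormalize so b^T alpha = 1 holds again after clipping.
    return alpha / (b.T @ alpha + EPS)

rng = np.random.default_rng(0)
b = np.full((10, 1), 0.1)
alpha = projection(rng.normal(size=(10, 1)), b)
assert np.all(alpha >= 0) and np.isclose(float(b.T @ alpha), 1.)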

tests/test_dann.py

Lines changed: 8 additions & 1 deletion
@@ -2,6 +2,7 @@
 Test functions for dann module.
 """

+import pytest
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras import Sequential, Model
@@ -111,4 +112,10 @@ def test_optimizer_enc_disc():
               epochs=10, batch_size=32, verbose=0)
     assert np.all(model.encoder_.get_weights()[0] == encoder.get_weights()[0])
     assert np.any(model.task_.get_weights()[0] != task.get_weights()[0])
-    assert np.any(model.discriminator_.get_weights()[0] != disc.get_weights()[0])
+    assert np.any(model.discriminator_.get_weights()[0] != disc.get_weights()[0])
+
+
+def test_warnings():
+    with pytest.warns() as record:
+        model = DANN(gamma=10.)
+    assert len(record) == 1
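
Note that with recent pytest versions, a bare `pytest.warns()` matches any warning. A stricter variant (a sketch, not part of the commit) would pin down the category and message raised by the DANN `__init__` above:

import pytest
from adapt.feature_based import DANN

def test_gamma_warning_message():
    # warnings.warn defaults to UserWarning; match is a regex search
    # against the message text.
    with pytest.warns(UserWarning, match="`gamma` argument has been removed"):
        DANN(gamma=10.)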
