
Commit fc8d586

update to new tf version
1 parent 3d06483 commit fc8d586

File tree

12 files changed: +283 −293 lines changed

adapt/base.py

Lines changed: 134 additions & 104 deletions
Large diffs are not rendered by default.
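
The collapsed adapt/base.py diff is where the _update_logs helper called from every train_step below is introduced. Its body is not visible in this view, so the following is only a hypothetical sketch, inferred from the call sites and from the commented-out code left behind in adapt/feature_based/_dann.py:

    # Hypothetical reconstruction of the _update_logs helper from adapt/base.py
    # (its real body is not rendered in this commit view). Assumed behavior:
    # update each metric with the source predictions and return a logs dict
    # mapping metric names to their current values.
    def _update_logs(self, ys, ys_pred):
        for metric in self.metrics:
            metric.update_state(ys, ys_pred)
        return {m.name: m.result() for m in self.metrics}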

adapt/feature_based/_adda.py

Lines changed: 10 additions & 14 deletions
@@ -183,14 +183,11 @@ def pretrain_step(self, data):
         gradients_enc = enc_tape.gradient(enc_loss, trainable_vars_enc)

         # Update weights
-        self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
-        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+        self.pretrain_optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
+        self.pretrain_optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))

         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         return logs

@@ -211,6 +208,8 @@ def train_step(self, data):
         else:
             # encoder src is not needed if pretrain=False
             Xs_enc = Xs
+
+        ys_pred = self.task_(Xs_enc, training=False)

         ys_disc = self.discriminator_(Xs_enc, training=True)

@@ -245,7 +244,8 @@ def train_step(self, data):
         # self.compiled_loss(ys, ys_pred)
         # Return a dict mapping metric names to current value
         # logs = {m.name: m.result() for m in self.metrics}
-        logs = self._get_disc_metrics(ys_disc, yt_disc)
+        logs = self._update_logs(ys, ys_pred)
+        logs.update(self._get_disc_metrics(ys_disc, yt_disc))
         return logs

@@ -262,15 +262,11 @@ def _get_disc_metrics(self, ys_disc, yt_disc):
         ))
         return disc_dict

-
-    def _initialize_weights(self, shape_X):
-        # Init weights encoder
-        self(np.zeros((1,) + shape_X))
-
-        # Set same weights to encoder_src
+
+    def _initialize_networks(self):
+        super()._initialize_networks()
         if self.pretrain:
             # encoder src is not needed if pretrain=False
-            self.encoder_(np.zeros((1,) + shape_X))
             self.encoder_src_ = check_network(self.encoder_,
                                               copy=True,
                                               name="encoder_src")

adapt/feature_based/_ccsa.py

Lines changed: 1 addition & 4 deletions
@@ -190,9 +190,6 @@ def train_step(self, data):
         self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))

         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         logs.update({"contrast": contrastive_loss})
         return logs

adapt/feature_based/_cdan.py

Lines changed: 11 additions & 14 deletions
@@ -282,10 +282,7 @@ def train_step(self, data):
         self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))

         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
-        # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        logs = self._update_logs(ys, ys_pred)
         disc_metrics = self._get_disc_metrics(ys_disc, yt_disc)
         logs.update({"disc_loss": disc_loss})
         logs.update(disc_metrics)

@@ -303,19 +300,19 @@ def _get_disc_metrics(self, ys_disc, yt_disc):


     def _initialize_weights(self, shape_X):
-        self(np.zeros((1,) + shape_X))
-        Xs_enc = self.encoder_(np.zeros((1,) + shape_X), training=True)
-        ys_pred = self.task_(Xs_enc, training=True)
-        if Xs_enc.get_shape()[1] * ys_pred.get_shape()[1] > self.max_features:
+        self.encoder_.build((None,) + shape_X)
+        self.task_.build(self.encoder_.output_shape)
+        if self.encoder_.output_shape[1] * self.task_.output_shape[1] > self.max_features:
             self.is_overloaded_ = True
-            self._random_task = tf.random.normal([ys_pred.get_shape()[1],
-                                                  self.max_features])
-            self._random_enc = tf.random.normal([Xs_enc.get_shape()[1],
-                                                 self.max_features])
-            self.discriminator_(np.zeros((1, self.max_features)))
+            self._random_task = tf.random.normal([self.task_.output_shape[1],
+                                                  self.max_features])
+            self._random_enc = tf.random.normal([self.encoder_.output_shape[1],
+                                                 self.max_features])
+            self.discriminator_.build((None, self.max_features))
         else:
             self.is_overloaded_ = False
-            self.discriminator_(np.zeros((1, Xs_enc.get_shape()[1] * ys_pred.get_shape()[1])))
+            self.discriminator_.build((None, self.encoder_.output_shape[1] * self.task_.output_shape[1]))
+        self.build((None,) + shape_X)


     def _initialize_networks(self):
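
The rewritten _initialize_weights builds each sub-network from an input shape instead of tracing it on an np.zeros batch, so no eager forward pass runs at initialization time. A minimal standalone sketch of the pattern (toy models, illustrative names):

    # build() creates a model's weights from a shape without pushing data
    # through it; output_shape is then available for chaining sub-networks.
    import tensorflow as tf

    encoder = tf.keras.Sequential([tf.keras.layers.Dense(16)])
    encoder.build((None, 8))           # weights created, no forward pass
    task = tf.keras.Sequential([tf.keras.layers.Dense(3)])
    task.build(encoder.output_shape)   # chain the builds via (None, 16)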

adapt/feature_based/_dann.py

Lines changed: 7 additions & 3 deletions
@@ -2,6 +2,7 @@
 DANN
 """

+import inspect
 import warnings
 import numpy as np
 import tensorflow as tf

@@ -170,10 +171,13 @@ def train_step(self, data):
         self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))

         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
+        # for metric in self.metrics:
+        #     metric.update_state(ys, ys_pred)
+        # self.compiled_loss(ys, ys_pred)
         # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
+        # logs = {m.name: m.result() for m in self.metrics}
+
+        logs = self._update_logs(ys, ys_pred)
         disc_metrics = self._get_disc_metrics(ys_disc, yt_disc)
         logs.update({"disc_loss": disc_loss})
         logs.update(disc_metrics)

adapt/utils.py

Lines changed: 30 additions & 76 deletions
@@ -16,7 +16,6 @@
 except:
     from scikeras.wrappers import KerasClassifier, KerasRegressor
 import tensorflow as tf
-import tensorflow.keras.backend as K
 from tensorflow.keras import Sequential, Model
 from tensorflow.keras.layers import Layer, Dense, Flatten, Input
 from tensorflow.keras.models import clone_model

@@ -88,24 +87,25 @@ def accuracy(y_true, y_pred):
    Boolean Tensor
    """
    # TODO: accuracy can't handle 1D ys.
-    multi_columns_t = K.cast(K.greater(K.shape(y_true)[1], 1),
-                             "float32")
-    binary_t = K.reshape(K.sum(K.cast(K.greater(y_true, 0.5),
-                               "float32"), axis=-1), (-1,))
-    multi_t = K.reshape(K.cast(K.argmax(y_true, axis=-1),
-                        "float32"), (-1,))
+    dtype = y_pred.dtype
+    multi_columns_t = tf.cast(tf.greater(tf.shape(y_true)[1], 1),
+                              dtype)
+    binary_t = tf.reshape(tf.reduce_sum(tf.cast(tf.greater(y_true, 0.5),
+                                        dtype), axis=-1), (-1,))
+    multi_t = tf.reshape(tf.cast(tf.math.argmax(y_true, axis=-1),
+                         dtype), (-1,))
    y_true = ((1 - multi_columns_t) * binary_t +
              multi_columns_t * multi_t)

-    multi_columns_p = K.cast(K.greater(K.shape(y_pred)[1], 1),
-                             "float32")
-    binary_p = K.reshape(K.sum(K.cast(K.greater(y_pred, 0.5),
-                               "float32"), axis=-1), (-1,))
-    multi_p = K.reshape(K.cast(K.argmax(y_pred, axis=-1),
-                        "float32"), (-1,))
+    multi_columns_p = tf.cast(tf.greater(tf.shape(y_pred)[1], 1),
+                              dtype)
+    binary_p = tf.reshape(tf.reduce_sum(tf.cast(tf.greater(y_pred, 0.5),
+                                        dtype), axis=-1), (-1,))
+    multi_p = tf.reshape(tf.cast(tf.math.argmax(y_pred, axis=-1),
+                         dtype), (-1,))
    y_pred = ((1 - multi_columns_p) * binary_p +
-              multi_columns_p * multi_p)
-    return tf.keras.metrics.get("acc")(y_true, y_pred)
+              multi_columns_p * multi_p)
+    return tf.cast(tf.math.equal(y_true, y_pred), dtype)


 def predict(self, x, **kwargs):
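
The rewritten accuracy helper now returns a per-sample correctness tensor, cast to the prediction dtype, instead of delegating to tf.keras.metrics.get("acc"). A small usage sketch with toy tensors:

    import tensorflow as tf
    from adapt.utils import accuracy  # the helper patched above

    y_true = tf.constant([[0., 1.], [1., 0.]])
    y_pred = tf.constant([[0.2, 0.8], [0.9, 0.1]])
    # Both rows agree after argmax, so every sample is correct:
    print(accuracy(y_true, y_pred))  # tf.Tensor([1. 1.], shape=(2,), dtype=float32)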
@@ -259,11 +259,11 @@ def check_network(network, copy=True,
     # but no input_shape
     if hasattr(network, "input_shape"):
         shape = network.input_shape[1:]
-        new_network = clone_model(network, input_tensors=Input(shape))
+        new_network = clone_model(network)
         new_network.set_weights(network.get_weights())
     elif network.built:
         shape = network._build_input_shape[1:]
-        new_network = clone_model(network, input_tensors=Input(shape))
+        new_network = clone_model(network)
         new_network.set_weights(network.get_weights())
     else:
         new_network = clone_model(network)

@@ -284,7 +284,7 @@ def check_network(network, copy=True,
         new_network._name = name

     # Override the predict method to speed the prediction for small dataset
-    new_network.predict = predict.__get__(new_network)
+    # new_network.predict = predict.__get__(new_network)
     return new_network


@@ -366,62 +366,6 @@ def get_default_discriminator(name=None, state=None):
     return model


-@tf.custom_gradient
-def _grad_handler(x, lambda_):
-    y = tf.identity(x)
-    def custom_grad(dy):
-        return (lambda_ * dy, 0. * lambda_)
-    return y, custom_grad
-
-class GradientHandler(Layer):
-    """
-    Multiply gradients with a scalar during backpropagation.
-
-    Act as identity in forward step.
-
-    Parameters
-    ----------
-    lambda_init : float (default=1.)
-        Scalar multiplier
-    """
-    def __init__(self, lambda_init=1., name="g_handler"):
-        super().__init__(name=name)
-        self.lambda_init=lambda_init
-        self.lambda_ = tf.Variable(lambda_init,
-                                   trainable=False,
-                                   dtype="float32")
-
-    def call(self, x):
-        """
-        Call gradient handler.
-
-        Parameters
-        ----------
-        x: object
-            Inputs
-
-        Returns
-        -------
-        x, custom gradient function
-        """
-        return _grad_handler(x, self.lambda_)
-
-
-    def get_config(self):
-        """
-        Return config dictionnary.
-
-        Returns
-        -------
-        dict
-        """
-        config = super().get_config().copy()
-        config.update({
-            'lambda_init': self.lambda_init
-        })
-        return config
-
-
 def make_classification_da(n_samples=100,
                            n_features=2,
                            random_state=2):
@@ -638,8 +582,18 @@ def check_fitted_network(estimator):
     if isinstance(estimator, Model):
         estimator.__deepcopy__ = __deepcopy__.__get__(estimator)
     return estimator


+def check_if_compiled(network):
+    """
+    Check if the network is compiled.
+    """
+    if hasattr(network, "compiled") and network.compiled:
+        return True
+    elif hasattr(network, "_is_compiled") and network._is_compiled:
+        return True
+    else:
+        return False

 # Try to save the initial estimator if it is a Keras Model
 # This is required for cloning the adapt method.
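
The new check_if_compiled helper covers both the newer public compiled attribute and the older private _is_compiled flag, so it should behave the same across Keras versions. A usage sketch with a toy model:

    import tensorflow as tf
    from adapt.utils import check_if_compiled  # the helper added above

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    print(check_if_compiled(model))              # False
    model.compile(optimizer="adam", loss="mse")
    print(check_if_compiled(model))              # True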

tests/test_adda.py

Lines changed: 1 addition & 4 deletions
@@ -8,10 +8,7 @@
 from tensorflow.keras import Sequential, Model
 from tensorflow.keras.layers import Dense
 from tensorflow.keras.initializers import GlorotUniform
-try:
-    from tensorflow.keras.optimizers.legacy import Adam
-except:
-    from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.optimizers import Adam

 from adapt.feature_based import ADDA
