ModelCheckPoint problem #127

Open
redaelhail opened this issue May 14, 2024 · 2 comments

@redaelhail

Hello,

Thank you for your work.

I am doing domain adaptation with DANN. I would like to save the best model using model checkpoint based on the loss value of the task network:

    chk = ModelCheckpoint(os.path.join(model_directory_name,'Model'),
                          monitor="loss",
                          verbose=1,
                          save_best_only=True,
                          save_weights_only=False,
                          mode='min',
                          save_freq=1)

During training, I keep receiving this warning:

WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.
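
For context, the warning refers to the Keras convention that a subclassed model exposes its constructor arguments through `get_config()`, so that the object can be rebuilt from plain data. The following is a minimal sketch of that pattern in plain Python, with no TensorFlow involved; `TinyModel` and its arguments are hypothetical and not part of Keras or ADAPT. Constructor arguments such as whole networks or data arrays cannot be written into a config like this, which is presumably why the warning appears for DANN.

```python
import json


class TinyModel:
    """Hypothetical stand-in for a subclassed model (not a Keras or ADAPT class)."""

    def __init__(self, units=10, lambda_=0.1):
        self.units = units
        self.lambda_ = lambda_

    def get_config(self):
        # Return only plain, JSON-serializable constructor arguments.
        return {"units": self.units, "lambda_": self.lambda_}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent object from the saved configuration.
        return cls(**config)


original = TinyModel(units=32, lambda_=0.5)
payload = json.dumps(original.get_config())            # what would be stored on disk
restored = TinyModel.from_config(json.loads(payload))  # what loading would do
```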

This is the training code:

    # define callbacks
    callbacks_list = [chk]

    # Build model
    model = DANN(encoder=encoder(), task=task(), discriminator=discriminitor(),
                 Xt=Xt, lambda_=0.1, metrics=["acc"], random_state=0)
    # start training
    model_log = model.fit(Xs, ys, epochs=2, callbacks=callbacks_list, verbose=1, class_weight=class_weights)

@antoinedemathelin
Collaborator

Hi @redaelhail,
Thank you for reporting the issue.

Yes, saving TensorFlow objects is not easy... One suggestion is to save only the weights of the networks.

Here is a little example that should work:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.callbacks import ModelCheckpoint
    from adapt.feature_based import DANN

    np.random.seed(0)
    Xt = np.random.randn(100, 2)
    Xs = np.random.randn(100, 2)
    ys = np.random.choice(2, 100)

    def encoder():
        mod = tf.keras.Sequential()
        mod.add(tf.keras.layers.Dense(10, activation="relu", input_shape=(2,)))
        return mod

    def task():
        mod = tf.keras.Sequential()
        mod.add(tf.keras.layers.Dense(1, activation="sigmoid"))
        return mod

    def discriminitor():
        mod = tf.keras.Sequential()
        mod.add(tf.keras.layers.Dense(1, activation="sigmoid"))
        return mod

    chk = ModelCheckpoint('Model.hdf5',
                          monitor="loss",
                          verbose=1,
                          save_best_only=True,
                          save_weights_only=True,
                          mode='min',
                          save_freq=1)

    # define callbacks
    callbacks_list = [chk]

    # Build model
    model = DANN(encoder=encoder(), task=task(), discriminator=discriminitor(),
                 Xt=Xt, lambda_=0.1, loss="bce", metrics=["acc"], random_state=0)
    # start training
    model_log = model.fit(Xs, ys, epochs=2, callbacks=callbacks_list, verbose=1)

    # Load saved weights
    model.load_weights('Model.hdf5')
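
To make the `save_best_only=True`, `mode='min'` behaviour concrete, here is a rough sketch of the decision the callback makes each time it checks the monitored value. It is plain Python and not the Keras implementation; the function name `best_only_saves` and the loss values are made up for illustration.

```python
def best_only_saves(losses):
    """Return the check indices at which a 'min'-mode, best-only checkpoint would save."""
    best = float("inf")
    saved_at = []
    for step, loss in enumerate(losses):
        # A checkpoint is written only when the monitored value strictly improves.
        if loss < best:
            best = loss
            saved_at.append(step)
    return saved_at


monitored = [0.90, 0.70, 0.75, 0.60, 0.65]  # made-up loss values
saves = best_only_saves(monitored)
print(saves)  # [0, 1, 3]
```

Note that in tf.keras an integer `save_freq` counts batches, so `save_freq=1` runs this check after every batch, whereas `save_freq="epoch"` (the default) runs it once per epoch.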

If you want to load the weights in a new model, you still have to call the fit function to instantiate the variables:

    new_model = DANN(encoder=encoder(), task=task(), discriminator=discriminitor(),
                     Xt=Xt, lambda_=0.1, metrics=["acc"], random_state=0)
    new_model.fit(Xs, ys, epochs=0)
    new_model.load_weights('Model.hdf5')

You can set epochs=0 to avoid the training process.

PS: I see that your code uses the accuracy metric, so I guess you are solving a classification problem? Be careful with the default parameters of DANN: the default loss is "mse", but I think you need the binary cross-entropy ("bce"), or the cross-entropy for multiclass. You can also change the default optimizer through the DANN arguments.
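
As a small illustration of why the choice of loss matters for classification, the sketch below compares the two losses on confidently wrong predictions. It is plain Python rather than the Keras or ADAPT implementations; the helper names `mse` and `bce` and the sample values are chosen only for this example.

```python
import math


def mse(y_true, y_pred):
    # Mean squared error between labels and predicted probabilities.
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def bce(y_true, y_pred, eps=1e-7):
    # Binary cross-entropy; predictions are clipped to avoid log(0).
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


y_true = [1, 0]
confident_wrong = [0.01, 0.99]  # made-up predictions, both far from the labels

mse_val = mse(y_true, confident_wrong)
bce_val = bce(y_true, confident_wrong)
print(mse_val, bce_val)
```

For probabilities in [0, 1] the squared error can never exceed 1, while the cross-entropy grows without bound as a prediction approaches the wrong label, so it penalises confident mistakes much more strongly.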

@redaelhail
Author

Yes, that worked. I can now save the best run. Thank you @antoinedemathelin.

Regarding your remark, I am using the cross-entropy since I am dealing with multiclass classification; it was clear from the documentation that it should be specified within the DANN class. Thank you.
