# In [None]:
import wandb
import tensorflow as tf
import wandb
from wandb.keras import WandbCallback

class WandbManager:
    """Thin wrapper around the module-level ``wandb`` API for one project.

    Stores the project/entity/config used to start runs and forwards logging
    calls to the module-level ``wandb`` functions (``log``, ``watch``,
    ``finish``), i.e. to whichever run is currently active.
    """

    def __init__(self, project_name, entity=None, config=None):
        """
        Initialize wandb setup and connect to a specific project.

        No wandb call is made here; the run itself is created by
        ``start_experiment``.

        :param project_name: Name of the wandb project to connect to.
        :param entity: Optional. The team or user who owns the project.
        :param config: Optional. Mapping of configuration parameters for the
            experiment. It is copied, so later ``config_updates`` never modify
            the dictionary the caller passed in.
        """
        self.project_name = project_name
        self.entity = entity
        # Copy rather than alias: start_experiment() updates self.config in
        # place, which would otherwise leak into the caller's dict.
        self.config = dict(config) if config else {}

    def start_experiment(self, experiment_name, config_updates=None):
        """
        Start a new experiment/run in wandb.

        :param experiment_name: Unique name for the experiment/run.
        :param config_updates: Optional. Dictionary of configuration parameters
            merged into this manager's config; the merged values persist for
            any later runs started by the same manager.
        :return: The run object returned by ``wandb.init``.
        """
        if config_updates:
            self.config.update(config_updates)

        # reinit=True allows starting another run from the same process
        # (e.g. re-executing a notebook cell).
        return wandb.init(
            project=self.project_name,
            entity=self.entity,
            config=self.config,
            name=experiment_name,
            reinit=True,
        )

    def log_metrics(self, metrics):
        """
        Log metrics for the current experiment.

        :param metrics: Dictionary of metric names and their values.
        """
        wandb.log(metrics)

    def log_model(self, model, model_name="model"):
        """
        Log model architecture and parameters via ``wandb.watch``.

        :param model: The model to log.
        :param model_name: Optional. Currently unused; only referenced by the
            commented-out explicit save below.
        """
        # NOTE(review): wandb documents ``watch`` as hooking PyTorch modules;
        # for Keras models, logging via WandbCallback is likely needed instead
        # — confirm before calling this with a Keras model.
        wandb.watch(model, log="all")

        # Optionally, save and log the model explicitly if needed
        # model.save(model_name)
        # wandb.save(model_name)

    def finish_experiment(self):
        """
        Finish the current experiment/run.
        """
        wandb.finish()

# In [None]:
import tensorflow as tf
import wandb
from wandb.keras import WandbCallback

class WandbKerasModel(WandbManager):
    """Keras model wrapper that trains with wandb logging enabled.

    Attributes not defined on this class are forwarded to the wrapped Keras
    model, so e.g. ``wandb_model.summary()`` works as on the model itself.
    """

    # Keyword arguments routed to ``Model.compile``. Every other keyword
    # (except ``callbacks`` and ``experiment_name``) is routed to ``Model.fit``.
    # Names unsupported by the installed Keras version are simply never passed.
    _COMPILE_KWARGS = frozenset({
        "optimizer",
        "loss",
        "metrics",
        "loss_weights",
        "weighted_metrics",
        "run_eagerly",
        "steps_per_execution",
        "jit_compile",
        "sample_weight_mode",
        "target_tensors",
        "pss_evaluation_shards",
        "auto_scale_loss",
    })

    def __init__(self, model, project_name, entity=None, config=None):
        """
        Initialize the WandbManager and set up the Keras model.

        :param model: The Keras model to be used.
        :param project_name: Name of the wandb project to connect to.
        :param entity: Optional. The team or user who owns the project.
        :param config: Optional. Configuration parameters for the experiment.
        """
        super().__init__(project_name, entity, config)
        self.model = model

    def compile_and_fit(self, *args, **kwargs):
        """
        Compile and fit the Keras model, automatically logging to wandb using the WandbCallback.

        Keyword arguments are split by name: ``Model.compile`` options
        (``optimizer``, ``loss``, ``metrics``, ...) go to ``compile``; all
        others (``x``, ``y``, ``epochs``, ``validation_data``, ...) go to
        ``fit``. Positional arguments are forwarded to ``compile`` as before,
        so training data must be passed by keyword (``x=...``, ``y=...``).

        Special keywords consumed here:
        - ``experiment_name``: wandb run name (default ``'Unnamed Experiment'``).
        - ``callbacks``: extra Keras callbacks. A ``WandbCallback`` is appended
          to a copy, so the caller's list is left unchanged.

        The wandb run is left open; call ``finish_experiment()`` when done.

        :return: The object returned by ``Model.fit``.
        """
        # Remove wrapper-only keywords so neither compile() nor fit() sees them.
        experiment_name = kwargs.pop("experiment_name", "Unnamed Experiment")
        user_callbacks = kwargs.pop("callbacks", None)

        compile_kwargs = {k: v for k, v in kwargs.items() if k in self._COMPILE_KWARGS}
        fit_kwargs = {k: v for k, v in kwargs.items() if k not in self._COMPILE_KWARGS}

        self.model.compile(*args, **compile_kwargs)

        # Start the run before constructing WandbCallback: the legacy callback
        # attaches to the active run and raises if wandb.init() was not called.
        self.start_experiment(experiment_name=experiment_name)

        callbacks = list(user_callbacks) if user_callbacks else []
        callbacks.append(WandbCallback())

        return self.model.fit(callbacks=callbacks, **fit_kwargs)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying Keras model if not found in this class.

        :raises AttributeError: if the wrapper has no ``model`` yet, or the
            model lacks the attribute.
        """
        # __getattr__ only runs after normal lookup fails. Read ``model`` from
        # the instance dict directly: going through ``self.model`` when it is
        # not set yet (before __init__, or during copy/unpickling) would
        # re-enter __getattr__ and recurse forever.
        try:
            model = self.__dict__["model"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(model, name)

# Define your Keras model
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

# Training and validation data. The 28x28 input and 10 output logits match
# MNIST, so load it and scale pixel values to [0, 1]. Substitute your own
# dataset here if it differs.
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.mnist.load_data()
x_train, x_val = x_train / 255.0, x_val / 255.0

# Wrap your model with WandbKerasModel (replace the placeholder project/entity)
wandb_model = WandbKerasModel(
    model=model,
    project_name='your_project_name',
    entity='your_wandb_entity',
    config={'learning_rate': 0.001},
)

# Compile and fit your model, automatically logging to wandb.
# Training data is passed by keyword so it is routed to fit(), not compile().
try:
    history = wandb_model.compile_and_fit(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        experiment_name='my_keras_experiment',
        x=x_train,
        y=y_train,
        epochs=5,
        validation_data=(x_val, y_val),
    )
finally:
    # Close the wandb run even if training raises.
    wandb_model.finish_experiment()
