In [1]:
import io
import os
import re
import tempfile

import pandas as pd
import tensorflow as tf
import wandb
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from keras.models import Model
from keras.models import Sequential
from PIL import Image


def _is_rule_line(line: str, chars: str = "=_-") -> bool:
    """Return True if *line* is a non-empty horizontal rule drawn only with *chars*."""
    stripped = line.strip()
    return bool(stripped) and set(stripped) <= set(chars)


def model_summary_to_df(model) -> pd.DataFrame:
    """
    Convert a Keras model summary to a pandas DataFrame.

    The text printed by ``model.summary()`` is captured and parsed by column
    position. The header row (containing "Layer", "Output Shape" and
    "Param #") fixes where each column starts. Every layer row of the table
    is sliced at those offsets. The "label: value" lines after the table's
    closing "=" rule (e.g. the parameter totals) are appended as extra rows
    with an empty "Output Shape".

    Blank spacer lines and horizontal rules ("=", "_", "-") inside the table
    are skipped, so they never appear as rows.

    Args:
        model: The Keras model to convert the summary of. Only its
            ``summary(print_fn=...)`` method is used.

    Returns:
        A pandas DataFrame with columns ``["Layer", "Output Shape", "Param #"]``
        whose cells are all strings.

    Raises:
        ValueError: If the summary text has no header row containing the
            "Layer", "Output Shape" and "Param #" column titles.
    """
    # Capture the summary text instead of letting it go to stdout.
    with io.StringIO() as stream:
        model.summary(print_fn=lambda x: stream.write(x + "\n"))
        lines = stream.getvalue().split("\n")

    col_names = ["Layer", "Output Shape", "Param #"]

    # Locate the header row anywhere in the output (not just the first few
    # lines) and fail loudly instead of silently parsing line 0.
    header_index = next(
        (i for i, line in enumerate(lines) if "Output Shape" in line and "Param #" in line),
        None,
    )
    if header_index is None:
        raise ValueError(
            "Could not find the 'Output Shape' / 'Param #' header in the model summary."
        )
    header_line = lines[header_index]
    col_starts = [header_line.index(col_name) for col_name in col_names]

    body = lines[header_index + 1 :]
    # The last "=" rule separates the layer table from the totals footer.
    # If no such rule exists, everything after the header is treated as table.
    footer_start = max(
        (i for i, line in enumerate(body) if _is_rule_line(line, "=")),
        default=len(body),
    )

    data = []
    # Parse layer rows by slicing at the header's column offsets.
    for line in body[:footer_start]:
        # Skip spacer lines and the rules drawn under the header / between layers.
        if not line.strip() or _is_rule_line(line):
            continue
        data.append(
            [
                line[: col_starts[1]].strip(),
                line[col_starts[1] : col_starts[2]].strip(),
                line[col_starts[2] :].strip(),
            ]
        )

    # Footer rows look like "Total params: 1226"; keep them with an empty shape.
    for line in body[footer_start + 1 :]:
        label, sep, value = line.partition(":")
        if sep and label.strip():
            data.append([label.strip(), "", value.strip()])

    return pd.DataFrame(data, columns=col_names)


def wandb_log_model_summary_and_architecture(
    model, log_summary: bool = True, log_architecture_plot: bool = False
) -> None:
    """
    Log the model summary and/or architecture diagram to the active wandb run.

    Args:
        model: The Keras model to log.
        log_summary: If True (default), log the parsed ``model.summary()``
            as a ``wandb.Table`` under the key "Model Summary".
        log_architecture_plot: If True, render the model graph with
            ``tf.keras.utils.plot_model`` (shapes and layer names shown) and
            log it as a ``wandb.Image`` under the key "Model Architecture".
            Default is False.

    Returns:
        None

    Raises:
        RuntimeError: If ``plot_model`` returned without writing the image
            file.
    """
    if log_summary:
        df = model_summary_to_df(model)
        wandb.log({"Model Summary": wandb.Table(dataframe=df)})

    if log_architecture_plot:
        # plot_model takes a file *path* (its extension selects the output
        # format); an in-memory BytesIO is rejected, so render into a temp dir.
        with tempfile.TemporaryDirectory() as tmp_dir:
            plot_path = os.path.join(tmp_dir, "model_architecture.png")
            tf.keras.utils.plot_model(
                model,
                to_file=plot_path,
                show_shapes=True,
                show_layer_names=True,
            )
            if not os.path.isfile(plot_path):
                raise RuntimeError(
                    "tf.keras.utils.plot_model did not produce an image file; "
                    "check that its plotting dependencies are installed."
                )
            # Copy the pixels into memory so nothing refers to the temp file
            # once the directory is cleaned up.
            with Image.open(plot_path) as plot_image:
                wandb.log({"Model Architecture": wandb.Image(plot_image.copy())})

# Start a new wandb run; everything logged below is attached to it.
wandb.init(project="test")

# Define an example CNN model
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(64, 64, 3), name='conv2d_1'),
    MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_1'),
    Conv2D(64, (3, 3), activation='relu', name='conv2d_2'),
    MaxPooling2D(pool_size=(2, 2), name='max_pooling2d_2'),
    Flatten(name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dropout(0.5, name='dropout'),
    Dense(10, activation='softmax', name='output')  # Assuming 10 classes for classification
])

# Print the summary. model.summary() prints it itself and returns None,
# so wrapping the call in print() would emit a stray "None" line.
model.summary()

# Parse the summary into a DataFrame for display below.
df = model_summary_to_df(model)

# Log the summary table to the run started by wandb.init() above.
wandb_log_model_summary_and_architecture(model)

# Display the DataFrame (`display` is not imported here; it is expected to be
# provided by the IPython/Jupyter environment this cell runs in).
display(df)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


ModuleNotFoundError: No module named 'keras'