In [1]:
import io
import pprint as pp
import pandas as pd
from keras.models import Model
import wandb
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import re
import tempfile
import os


def model_summary_to_df(model) -> pd.DataFrame:
    """
    Convert a Keras model summary to a pandas DataFrame.

    Args:
        model: The Keras model to convert the summary of.

    Returns:
        A pandas DataFrame containing the model summary information.

    """
    # Capture the model summary
    stream = io.StringIO()
    model.summary(print_fn=lambda x: stream.write(x + "\n"))
    summary_string = stream.getvalue()
    stream.close()

    # Split summary into lines
    lines = summary_string.split("\n")

    # Dynamically find the header line and summary line index
    header_line_index = 0
    summary_line_index = 0
    for i, line in enumerate(lines):
        if "Output Shape" in line and "Param #" in line:
            header_line_index = i
        if "Total params:" in line:
            summary_line_index = i

    # Use the found header line to determine column start positions
    header_line = lines[header_line_index]
    col_names = ["Layer", "Output Shape", "Param #"]
    col_starts = [header_line.index(col_name) for col_name in col_names]

    # Parse the model summary string
    model_lines = lines[header_line_index:summary_line_index]
    summary_lines = lines[-summary_line_index:]

    data = []
    # Parse layer information using column start positions
    for i, line in enumerate(model_lines):
        if len(line.strip()) > 0:
            layer = line[: col_starts[1]].strip()
            output_shape = line[col_starts[1] : col_starts[2]].strip()
            params = line[col_starts[2] :].strip()
            data.append([layer, output_shape, params])

    # Parse and add summary info
    for info in summary_lines:
        if info:  # Check if the line is not empty
            parts = info.split(":")
            if len(parts) == 2:
                info_label = parts[0].strip()
                info_value = parts[1].strip()
                data.append([info_label, "", info_value])

    df = pd.DataFrame(data, columns=col_names)

    # Initialize an empty list to store the processed rows
    processed_rows = []

    for i in range(len(df)):
        row = df.iloc[i]
        # Check if the next row needs to be merged with the current row
        if (
            i + 1 < len(df)
            and df.iloc[i + 1]["Layer"].endswith(")")
            and df.iloc[i + 1]["Output Shape"].strip() == ""
            and df.iloc[i + 1]["Param #"].strip() == ""
        ):
            # Merge current and next row
            next_row = df.iloc[i + 1]
            merged_row = [
                row["Layer"] + next_row["Layer"],
                row["Output Shape"],
                row["Param #"],
            ]
            processed_rows.append(merged_row)
        elif (
            row["Layer"].endswith(")")
            and row["Output Shape"] == ""
            and row["Param #"] == ""
        ):
            continue
        else:
            processed_rows.append(row.values.tolist())

    # Convert the processed rows back into a DataFrame
    df = pd.DataFrame(processed_rows, columns=df.columns)

    return df



def wandb_log_model_summary_and_architecture(
    model, log_summary: bool = True, log_architecture_plot: bool = False
) -> None:
    """
    Logs the model summary and architecture to wandb.

    Args:
        model: The model to log the summary and architecture for.
        log_summary: A boolean indicating whether to log the model summary. Default is True.
        log_architecture_plot: A boolean indicating whether to log the model architecture plot. Default is False.

    Returns:
        None

    """
    if log_summary:
        df = model_summary_to_df(model)
        wandb.log({"Model Summary": wandb.Table(dataframe=df)})
    if log_architecture_plot:
        # Use a temporary file to save the model plot
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".png", dir="."
        ) as tmpfile:
            tf.keras.utils.plot_model(
                model, to_file=tmpfile.name, show_shapes=True, show_layer_names=True
            )
            # Open the temporary file and log it to wandb
            wandb.log({"Model Architecture": wandb.Image(tmpfile.name)})
        # Optionally, delete the temporary file if you don't want it to remain
        os.remove(tmpfile.name)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd





In [2]:
# start a new run
wandb.init(project="test", reinit=True)

# Define a complex model (example)
model = Sequential(
    [
        Conv2D(
            32,
            kernel_size=(3, 3),
            activation="relu",
            input_shape=(64, 64, 3),
            name="conv2d_1",
        ),
        MaxPooling2D(pool_size=(2, 2), name="max_pooling2d_1"),
        Conv2D(64, (3, 3), activation="relu", name="conv2d_2"),
        MaxPooling2D(pool_size=(2, 2), name="max_pooling2d_2"),
        Flatten(name="flatten"),
        Dense(128, activation="relu", name="dense_1"),
        Dropout(0.5, name="dropout"),
        Dense(
            10, activation="softmax", name="output"
        ),  # Assuming 10 classes for classification
    ]
)

# print summary
print(model.summary())

# Get the model summary DataFrame
df = model_summary_to_df(model)

# log to wandb
# assumes you have already done `wandb.init()`
wandb_log_model_summary_and_architecture(
    model, log_summary=True, log_architecture_plot=True
)

# Display the DataFrame
display(df)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandrewm4894[0m. Use [1m`wandb login --relogin`[0m to force relogin




Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_1 (Conv2D)           (None, 62, 62, 32)        896       
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 31, 31, 32)        0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 29, 29, 64)        18496     
                                                                 
 max_pooling2d_2 (MaxPoolin  (None, 14, 14, 64)        0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 12544)             0         
                                                                 
 dense_1 (Dense)             (None, 128)              

Unnamed: 0,Layer,Output Shape,Param #
0,Layer (type),Output Shape,Param #
1,=============================,==========================,==========
2,conv2d_1 (Conv2D),"(None, 62, 62, 32)",896
3,max_pooling2d_1 (MaxPooling2D),"(None, 31, 31, 32)",0
4,conv2d_2 (Conv2D),"(None, 29, 29, 64)",18496
5,max_pooling2d_2 (MaxPooling2D),"(None, 14, 14, 64)",0
6,flatten (Flatten),"(None, 12544)",0
7,dense_1 (Dense),"(None, 128)",1605760
8,dropout (Dropout),"(None, 128)",0
9,output (Dense),"(None, 10)",1290
