# In [1]:
import tensorflow as tf
import json


def load_model_from_file(file_path):
    """Deserialize a saved Keras model from disk.

    Args:
        file_path: Path to the saved model file (this script passes an .h5 file).

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for that path.
    """
    return tf.keras.models.load_model(file_path)


def _shape_of(tensors):
    """Return a JSON-friendly shape for a tensor or a container of tensors.

    A single tensor becomes a tuple of its dimensions. Lists/tuples (layers
    with several inputs or outputs) become a list of such shapes, and dicts
    map each key to its shape, so multi-input layers no longer crash.
    """
    if isinstance(tensors, dict):
        return {key: _shape_of(value) for key, value in tensors.items()}
    if isinstance(tensors, (list, tuple)):
        return [_shape_of(t) for t in tensors]
    return tuple(tensors.shape)


def _layer_io_shape(layer, attr):
    """Return the shape of ``layer.<attr>`` ("input" or "output"), or None.

    None is returned when reading the attribute fails. This mirrors the
    original ``hasattr`` fallback, but also tolerates ValueError.
    """
    try:
        tensors = getattr(layer, attr)
    except (AttributeError, ValueError):
        # NOTE(review): depending on the Keras version, a layer without a
        # single well-defined input/output raises AttributeError or
        # ValueError here; either way we record "unknown" instead of
        # aborting the whole export.
        return None
    return _shape_of(tensors)


def extract_layer_details(model):
    """Collect a JSON-serializable description of every layer in ``model``.

    Args:
        model: A loaded Keras model; only ``model.layers`` is read.

    Returns:
        A list with one dict per layer. Each dict holds ``name``, ``type``,
        ``input_shape`` and ``output_shape``. When the layer has weights it
        also holds ``weights``, optionally ``biases``, and ``extra_weights``
        for any further arrays (all as nested lists). Conv2D layers
        additionally get ``kernel_size``, ``filters``, ``strides`` and
        ``padding``.
    """
    layer_details = []

    for layer in model.layers:
        layer_info = {"name": layer.name, "type": type(layer).__name__}

        # Prefer the layer's own shape attributes. When they are unavailable
        # (they raise AttributeError, e.g. on newer Keras versions), fall back
        # to the shapes of the symbolic input/output tensors.
        try:
            layer_info["input_shape"] = layer.input_shape
            layer_info["output_shape"] = layer.output_shape
        except AttributeError:
            layer_info["input_shape"] = _layer_io_shape(layer, "input")
            layer_info["output_shape"] = _layer_io_shape(layer, "output")

        weights = layer.get_weights() if hasattr(layer, "get_weights") else []
        if weights:
            # Labels are positional: weights[0] -> "weights" and
            # weights[1] -> "biases". NOTE(review): that matches kernel/bias
            # for layers such as Dense/Conv2D, but other layer types order
            # their variables differently, so the labels are nominal there.
            # .tolist() converts numpy arrays for JSON serialization.
            layer_info["weights"] = weights[0].tolist()
            if len(weights) > 1:
                layer_info["biases"] = weights[1].tolist()
            if len(weights) > 2:
                # Keep the remaining arrays (e.g. normalization statistics)
                # rather than silently dropping them from the export.
                layer_info["extra_weights"] = [w.tolist() for w in weights[2:]]

        # Additional layer-specific attributes.
        if isinstance(layer, tf.keras.layers.Conv2D):
            layer_info.update(
                kernel_size=layer.kernel_size,
                filters=layer.filters,
                strides=layer.strides,
                padding=layer.padding,
            )

        layer_details.append(layer_info)

    return layer_details


def save_to_json(layer_details, filename="config.json"):
    """Write ``layer_details`` to ``filename`` as JSON.

    The details are wrapped in a top-level object under the "layers" key and
    pretty-printed with a 4-space indent. An existing file is overwritten.
    """
    document = {"layers": layer_details}
    with open(filename, "w") as out_file:
        json.dump(document, fp=out_file, indent=4)


# Driver: export the layer summary of a trained model.
# Change the path below to your own .h5 file. The JSON output goes to
# save_to_json's default filename.
model = load_model_from_file(file_path="mnist_1.h5")
layer_details = extract_layer_details(model=model)
save_to_json(layer_details=layer_details)


