In [8]:
import tensorflow as tf

def print_tflite_model_summary(model_path):
    # 创建 TFLite 解释器
    interpreter = tf.lite.Interpreter(model_path=model_path)

    # 分配张量内存
    interpreter.allocate_tensors()

    # 获取输入和输出张量的详细信息
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # 打印模型的输入信息
    print("Input Shape:")
    for input_tensor in input_details:
        print(f"  {input_tensor['name']}: {input_tensor['shape']}")

    # 打印模型的输出信息
    print("\nOutput Shape:")
    for output_tensor in output_details:
        print(f"  {output_tensor['name']}: {output_tensor['shape']}")

    # 打印完整的网络结构
    print("\nNetwork Structure:")
    for i, layer in enumerate(interpreter.get_tensor_details()):
        print(f"Layer {i} - {layer['name']} - Shape: {layer['shape']}")

# 用法示例
# 定义 TFLite 模型文件路径
tflite_model_path = '/home/zhangyouan/桌面/zya/NN_net/network/SSD/IMX_681_ssd_mobilenet_git/keras/detection/Quantization/quantized_model_1214.tflite'
print_tflite_model_summary(tflite_model_path)


Input Shape:
  serving_default_input_1:0: [  1 120 160   1]

Output Shape:
  StatefulPartitionedCall:0: [   1 1242    7]

Network Structure:
Layer 0 - serving_default_input_1:0 - Shape: [  1 120 160   1]
Layer 1 - model/mbox_loc_final/strided_slice/stack - Shape: [1]
Layer 2 - model/mbox_loc_final/strided_slice/stack_1 - Shape: [1]
Layer 3 - model/reshape_8/Reshape/shape/1 - Shape: []
Layer 4 - model/reshape/Reshape/shape/2 - Shape: []
Layer 5 - model/reshape_6/Reshape/shape/1 - Shape: []
Layer 6 - model/reshape_4/Reshape/shape/1 - Shape: []
Layer 7 - model/reshape_2/Reshape/shape/1 - Shape: []
Layer 8 - model/reshape/Reshape/shape/1 - Shape: []
Layer 9 - model/mbox_loc_final/Reshape/shape/2 - Shape: []
Layer 10 - model/Conv2D_loc_DD5_2/BiasAdd/ReadVariableOp - Shape: [24]
Layer 11 - model/Conv2D_loc_DD5_2/Conv2D - Shape: [24  1  1 64]
Layer 12 - model/DepthwiseConv2D_loc_DD5_1/BiasAdd/ReadVariableOp - Shape: [64]
Layer 13 - model/DepthwiseConv2D_loc_DD5_1/depthwise - Shape: [ 1  3  3 

In [2]:
from keras.models import load_model

# 加载模型
h5_model_path = '/home/zhangyouan/桌面/zya/NN_net/network/SSD/IMX_681_ssd_mobilenet_git/keras/detection/SSD_ipynb_two_objects/output/20231206/20231213_good_detection_test_callback.h5'

# 注册自定义对象
custom_objects = {'compute_loss': None}

# 加载模型时传递 custom_objects 参数
loaded_model = load_model(h5_model_path, custom_objects=custom_objects)

# 打印模型结构
loaded_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 120, 160, 1)]        0         []                            
                                                                                                  
 DepthWiseConv2D_layer1 (De  (None, 120, 160, 8)          80        ['input_1[0][0]']             
 pthwiseConv2D)                                                                                   
                                                                                                  
 re_lu (ReLU)                (None, 120, 160, 8)          0         ['DepthWiseConv2D_layer1[0][0]
                                                                    ']                            
                                                                                              