### 1. 提取可训练参数
### model.trainable_variables 返回模型中可训练的参数
### 2. 设置print输出格式 (因为直接print会有很多数据被省略号替换掉)
### np.set_printoptions(threshold=超过多少省略显示)

In [None]:
np.set_printoptions(threshold=np.inf)  # np.inf表示无限大

In [None]:
#在程序末端加上print打印所有可训练参数
print(model.trainable_variables)
#把所有可训练参数存入文本
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

### 例

In [1]:
import tensorflow as tf
import os
import numpy as np

In [2]:
np.set_printoptions(threshold=np.inf)

In [3]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [4]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

In [5]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

In [6]:
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

-------------load the model-----------------


In [7]:
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [8]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  100480    
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [9]:
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

[<tf.Variable 'sequential/dense/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[ 4.30281013e-02,  3.85793075e-02, -1.44901127e-02,
         8.06135908e-02,  6.97688758e-03, -6.76954612e-02,
         7.73446485e-02,  7.70648047e-02,  5.28961420e-04,
        -7.58391842e-02, -1.26056373e-03, -2.84259468e-02,
        -6.72744811e-02, -1.51694864e-02,  5.65186515e-02,
        -5.97575530e-02,  4.81315330e-02, -3.55114229e-02,
         7.75742531e-03, -2.82817222e-02,  7.44178370e-02,
        -1.94587372e-02,  5.19953445e-02, -1.40086561e-03,
         5.04470691e-02, -2.86271237e-02,  7.55075365e-03,
        -2.40655243e-03, -1.96086876e-02, -7.32530132e-02,
         5.86279407e-02, -1.25264823e-02, -2.85029151e-02,
         8.02404657e-02,  7.13820085e-02,  4.67967913e-02,
        -7.49285072e-02, -7.36024380e-02, -7.90561885e-02,
         5.71924224e-02, -2.81935185e-03,  6.12368435e-03,
        -7.46383741e-02, -7.90733993e-02,  6.11763820e-02,
         3.68576199e-02,  2.612956