# TensorFlow 自定义层

自定义的层需要继承自 `keras.layers.Layer`;
自定义的网络需要继承自 `keras.Model`。

两者都需要实现的方法:

- `__init__`:创建该层(或网络)所需的参数和子层
- `call`:定义前向计算

In [7]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential,layers,metrics,losses,optimizers,datasets

In [8]:
# Mini-batch size used by the training and test pipelines.
batch_size = 100

# Load the Fashion-MNIST training and test splits that ship with Keras.
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
def pre_process(x, y):
    """Flatten one image, scale it by 1/255, and cast its label to int32.

    Args:
        x: A single image tensor holding exactly 28*28 elements; it is
            reshaped to a flat vector of length 784.
        y: The label paired with the image.

    Returns:
        A ``(features, label)`` tuple: float32 features divided by 255.0 and
        the label cast to int32.
    """
    flat = tf.reshape(x, [28 * 28])
    features = tf.cast(flat, dtype=tf.float32) / 255.0
    label = tf.cast(y, dtype=tf.int32)
    return features, label


# Training pipeline: preprocess every sample, shuffle with a 5000-element
# buffer, then group into mini-batches.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(pre_process)
db = db.shuffle(5000).batch(batch_size)

# Evaluation pipeline: same preprocessing and batching, but no shuffle.
# Sample order does not change the aggregated validation loss/accuracy, so
# shuffling here would only add buffering work.
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(pre_process)
db_test = db_test.batch(batch_size)

In [15]:
 # 自定义层
class MyDense(layers.Layer):
    """Fully connected layer computing ``inputs @ kernel + bias``.

    Args:
        inp_dim: Size of the last dimension of the input.
        outp_dim: Number of output units.
    """

    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        # `add_variable` is a deprecated alias of `add_weight`, so call
        # `add_weight` directly. Keyword arguments keep the call independent
        # of the positional parameter order.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        self.bias = self.add_weight(name='b', shape=[outp_dim])

    def call(self, inputs, training=None):
        """Apply the affine transform; `training` is accepted but unused."""
        return inputs @ self.kernel + self.bias
    
# 自定义网络
class MyModel(keras.Model):
    """Three stacked MyDense layers (784 -> 256 -> 128 -> 10) with softmax."""

    def __init__(self):
        super().__init__()
        # Layers are created in the order they are applied in `call`.
        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 10)

    def call(self, inputs, training=None):
        """Run the forward pass and return the softmax of the last layer.

        The two hidden layers use a sigmoid activation. `training` is
        accepted but not used.
        """
        hidden = tf.nn.sigmoid(self.fc1(inputs))
        hidden = tf.nn.sigmoid(self.fc2(hidden))
        scores = self.fc3(hidden)
        return tf.nn.softmax(scores)
    

        

In [16]:
# Build an instance of the custom MyModel network.
network = MyModel()

In [17]:
# Plain SGD with sparse categorical cross-entropy and an accuracy metric.
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
network.compile(
    optimizer=optimizers.SGD(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [18]:
# Train on `db` for 5 epochs; validate on `db_test` with validation_freq=2.
network.fit(
    db,
    epochs=5,
    validation_data=db_test,
    validation_freq=2,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x22082f9b438>

In [19]:
# Print the per-layer summary and parameter counts of the model.
network.summary()

Model: "my_model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
my_dense_6 (MyDense)         multiple                  200960    
_________________________________________________________________
my_dense_7 (MyDense)         multiple                  32896     
_________________________________________________________________
my_dense_8 (MyDense)         multiple                  1290      
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
_________________________________________________________________
