In [13]:
import tensorflow as tf
import numpy as np

from tensorflow.python.keras.initializers import (Zeros, glorot_normal,
                                                  glorot_uniform)
from tensorflow.python.keras.regularizers import l2
from tensorflow.python.keras import backend as K

In [28]:
class CrossNet(tf.keras.layers.Layer):
    """Cross network layer that builds explicit feature crosses on a 2D input.

    Each of the ``layer_num`` stacked cross layers computes::

        x_{l+1} = x_0 @ (x_l . w_l) + b_l + x_l

    where ``x_0`` is the layer input with a trailing axis of size 1 added,
    and ``w_l`` / ``b_l`` are that layer's ``(dim, 1)`` kernel and bias.

    Input shape:
        2D tensor ``(batch_size, dim)``.

    Output shape:
        2D tensor ``(batch_size, dim)``, identical to the input shape.

    Arguments:
        layer_num: number of cross layers to stack.
        l2_reg: factor of the L2 regularizer attached to every kernel.
        seed: seed handed to the glorot-normal kernel initializer.
        **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, layer_num=1, l2_reg=0, seed=1024, **kwargs):
        super(CrossNet, self).__init__(**kwargs)
        self.layer_num = layer_num
        self.l2_reg = l2_reg
        self.seed = seed

    def build(self, input_shape):
        """Create one ``(dim, 1)`` kernel and one ``(dim, 1)`` bias per cross layer.

        Raises:
            ValueError: if ``input_shape`` does not have exactly 2 dimensions,
                or if its last dimension cannot be converted to an int
                (for example because it is ``None``).
        """
        # Shape trace for the notebook walkthrough.
        print("build input_shape", input_shape)

        if len(input_shape) != 2:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))

        # The weights are sized from the feature dimension, so it has to be
        # known here; report an undefined one with the same exception type
        # as the rank check instead of leaking a TypeError from int().
        try:
            dim = int(input_shape[-1])
        except TypeError:
            raise ValueError(
                "The last dimension of the inputs to CrossNet should be defined, found %s"
                % (input_shape[-1],))

        self.kernels = [
            self.add_weight(name='kernel' + str(i),
                            shape=(dim, 1),
                            initializer=glorot_normal(seed=self.seed),
                            regularizer=l2(self.l2_reg),
                            trainable=True)
            for i in range(self.layer_num)
        ]
        self.bias = [
            self.add_weight(name='bias' + str(i),
                            shape=(dim, 1),
                            initializer=Zeros(),
                            trainable=True)
            for i in range(self.layer_num)
        ]
        # Be sure to call this somewhere!
        super(CrossNet, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Apply the stacked cross layers to ``inputs``.

        Arguments:
            inputs: 2D tensor ``(batch_size, dim)``.

        Returns:
            2D tensor with the same shape as ``inputs``.

        Raises:
            ValueError: if ``inputs`` is not 2-dimensional.
        """
        if K.ndim(inputs) != 2:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 2 dimensions" % (K.ndim(inputs),))

        # Shape trace for the notebook walkthrough.
        print("call inputs.shape", inputs.shape)
        # Add a trailing axis: (batch, dim) -> (batch, dim, 1).
        x_0 = tf.expand_dims(inputs, axis=2)
        x_l = x_0
        print("call x_0.shape:", x_0.shape)
        for i in range(self.layer_num):
            # Contract axis 1 of x_l with axis 0 of the kernel (multiply, then
            # sum): (batch, dim, 1) . (dim, 1) -> (batch, 1, 1).
            xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
            print("self.kernel shape:", self.kernels[i].shape)
            print("xl_w shape:", xl_w.shape)
            # Batched matmul: (batch, dim, 1) @ (batch, 1, 1) -> (batch, dim, 1).
            dot_ = tf.matmul(x_0, xl_w)
            print("dot_ shape:", dot_.shape)
            # Add the bias and the residual connection to the previous layer.
            x_l = dot_ + self.bias[i] + x_l
        # Drop the trailing axis again: (batch, dim, 1) -> (batch, dim).
        x_l = tf.squeeze(x_l, axis=2)
        return x_l

    def get_config(self, ):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = {'layer_num': self.layer_num,
                  'l2_reg': self.l2_reg, 'seed': self.seed}
        base_config = super(CrossNet, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """The layer preserves its input shape, so return ``input_shape`` unchanged."""
        return input_shape

In [59]:
# Two symbolic inputs with 3 features each, created in order (input1, then input2).
input1, input2 = [tf.keras.layers.Input(shape=(3,)) for _ in range(2)]

# Join them along the last axis (the default), giving a 6-feature tensor.
concatenate_layer = tf.keras.layers.Concatenate()([input1, input2])

# Run a CrossNet with default settings on the concatenated features.
output = CrossNet()(concatenate_layer)

build input_shape (None, 6)
call inputs.shape (None, 6)
call x_0.shape: (None, 6, 1)
self.kernel shape: (6, 1)
xl_w shape: (None, 1, 1)
dot_ shape: (None, 6, 1)
