# Keras 内置 Flatten（展平）层

In [4]:
from keras import backend as K
from keras.engine.topology import Layer
 
from keras.engine.base_layer import InputSpec

import numpy as np



class Flatten(Layer):
    """Flatten every non-batch axis of the input into a single axis.

    Re-implementation of the Keras ``Flatten`` layer: axis 0 (the batch axis)
    is kept and all remaining axes are collapsed into one.

    # Arguments
        data_format: ``'channels_last'``, ``'channels_first'`` or ``None``.
            Normalised through ``K.normalize_data_format`` (``None`` is
            resolved by that backend helper, typically to the configured
            ``image_data_format()``). When the result is
            ``'channels_first'``, axis 1 is moved to the end before
            flattening, so the element order matches a channels-last layout.

    # Input shape
        A tensor of rank >= 3 (enforced via ``InputSpec(min_ndim=3)``).

    # Output shape
        ``(batch_size, product_of_all_non_batch_dims)``.
    """

    def __init__(self, data_format=None, **kwargs):
        super(Flatten, self).__init__(**kwargs)
        # Rank-2 inputs are already flat; require at least rank 3.
        self.input_spec = InputSpec(min_ndim=3)
        self.data_format = K.normalize_data_format(data_format)

    def compute_output_shape(self, input_shape):
        """Return ``(input_shape[0], product of input_shape[1:])``.

        # Raises
            ValueError: if any non-batch dimension is ``None`` (unknown) or 0,
                since the flattened size cannot be computed statically.
        """
        import functools
        import operator

        feature_dims = input_shape[1:]
        if not all(feature_dims):
            raise ValueError('The shape of the input to "Flatten" '
                             'is not fully defined '
                             '(got ' + str(feature_dims) + '). '
                             'Make sure to pass a complete "input_shape" '
                             'or "batch_input_shape" argument to the first '
                             'layer in your model.')
        # Multiply as Python ints instead of using np.prod: np.prod yields a
        # NumPy scalar in the platform's default integer dtype (32-bit on
        # Windows with NumPy < 2), which can silently overflow for large
        # feature maps. The Python product is exact and a plain ``int``.
        flat_dim = functools.reduce(operator.mul,
                                    (int(d) for d in feature_dims), 1)
        return (input_shape[0], flat_dim)

    def call(self, inputs):
        """Flatten ``inputs`` with ``K.batch_flatten``.

        For ``'channels_first'`` data the channel axis is first moved to the
        end, so the flattened order is the same as for channels-last data.
        """
        if self.data_format == 'channels_first':
            # Works for any rank: rank 4 -> [0, 2, 3, 1],
            # rank 5 -> [0, 2, 3, 4, 1], ...
            rank = K.ndim(inputs)
            inputs = K.permute_dimensions(
                inputs, [0] + list(range(2, rank)) + [1])
        return K.batch_flatten(inputs)

    def get_config(self):
        """Return the base layer config extended with ``data_format``."""
        merged = dict(super(Flatten, self).get_config())
        merged.update({'data_format': self.data_format})
        return merged

In [5]:
from keras.models import Sequential
from keras import layers

# NOTE: Conv2D is built with the backend's default data format, which in this
# notebook is 'channels_last' (see the summary output below). So
# input_shape=(3, 32, 32) is read as height=3, width=32, channels=32 -- not as
# a 3-channel 32x32 image. This is why the conv layer has
# 3*3*32*64 + 64 = 18496 params. The (None, 64, 32, 32) / (None, 65536) shapes
# quoted in the Keras docs only hold with data_format='channels_first'.
model = Sequential()
model.add(layers.Conv2D(64, (3, 3), input_shape=(3, 32, 32), padding='same'))
# now: model.output_shape == (None, 3, 32, 64)
model.add(Flatten())
# now: model.output_shape == (None, 6144)   (3 * 32 * 64)
model.summary()
    

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            (None, 3, 32, 64)         18496     
_________________________________________________________________
flatten_2 (Flatten)          (None, 6144)              0         
Total params: 18,496
Trainable params: 18,496
Non-trainable params: 0
_________________________________________________________________
