# Layer

In [1]:
from keras.engine import Layer, InputSpec
from keras.layers import Flatten
import tensorflow as tf
import keras.backend as K
from keras import Model

class KMaxPooling(Layer):
    """
    Global k-max pooling layer (TensorFlow backend).

    Flattens every non-batch dimension of a 4D input into a single axis and
    keeps the `k` largest activations of each sample. Only the values are
    returned (the indices from `tf.nn.top_k` are dropped), so the output has
    a fixed width of `k` whatever the spatial size of the input was.

    # Arguments
        k: int, number of activations to keep per sample. Each sample must
            provide at least `k` values after flattening, otherwise
            `tf.nn.top_k` fails at run time.
            NOTE(review): the default 204800 presumably equals 10 * 10 * 2048,
            the Xception feature-map size for 299x299 inputs; confirm.

    # Input shape
        4D tensor (enforced through `InputSpec(ndim=4)`).

    # Output shape
        2D tensor with shape `(batch_size, k)`; with `sorted=True` the values
        come back in descending order.
    """
    def __init__(self, k=204800, **kwargs):
        super().__init__(**kwargs)
        self.input_spec = InputSpec(ndim=4)
        self.k = k

    def compute_output_shape(self, input_shape):
        """Return `(batch_size, k)`; the width does not depend on the input size."""
        return (input_shape[0], self.k)

    def call(self, inputs):
        """Flatten `inputs` per sample and return its `k` largest values."""
        # Collapse all non-batch dimensions so top_k (which works on the last
        # axis) ranks every activation of a sample together.
        flattened_input = K.batch_flatten(inputs)

        # top_k returns (values, indices); only the values are needed.
        return tf.nn.top_k(flattened_input, k=self.k, sorted=True)[0]

    def get_config(self):
        """Include `k` in the layer config so a saved model restores it.

        Without this, re-creating the layer from its config would silently
        fall back to the default `k`.
        """
        config = super().get_config()
        config['k'] = self.k
        return config

Using TensorFlow backend.


# Usage example

In [0]:
from keras.applications import xception
from keras.layers import Flatten, Dense

num_classes = 15

# input_shape=(None, None, 3) leaves the spatial size open, so images of
# different sizes can be fed to the same model.
base_xception_model = xception.Xception(weights='imagenet', include_top=False, input_shape=(None, None, 3))
x_xception = base_xception_model.output
# KMaxPooling is used where Flatten() would normally go: it yields a fixed
# number of features even though the feature-map size is unknown here.
x_xception = KMaxPooling()(x_xception)
x_xception = Dense(512, activation='relu')(x_xception)
predictions_xception = Dense(num_classes, activation='softmax')(x_xception)

model_xception = Model(inputs=base_xception_model.input, outputs=predictions_xception)
# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit an extra "None" line).
model_xception.summary()

## How to use the KMaxPooling layer
The KMaxPooling layer replaces the Flatten layer in the model.
It flattens the output of the preceding convolutional network (Xception in the example)
and returns the k (default=204800) highest values.
Because the output always has exactly k values, the network can accept inputs of different sizes (as long as the images are big enough to yield at least k activations). The example shows this by using the input shape (None, None, 3).