Skip to content
Permalink
master
Go to file
 
 
Cannot retrieve contributors at this time
54 lines (44 sloc) 1.69 KB
import numpy as np
from keras import backend as K
from keras.engine import Layer, InputSpec
class Slice(Layer):
def __init__(self, indices, axis=1, **kwargs):
self.supports_masking = True
self.axis = axis
if isinstance(indices, slice):
self.indices = (indices.start, indices.stop, indices.step)
else:
self.indices = indices
self.slices = [ slice(None) ] * self.axis
if isinstance(self.indices, int):
self.slices.append(self.indices)
elif isinstance(self.indices, (list, tuple)):
self.slices.append(slice(*self.indices))
else:
raise TypeError("indices must be int or slice")
super(Slice, self).__init__(**kwargs)
def call(self, inputs, mask=None):
return inputs[self.slices]
def compute_output_shape(self, input_shape):
input_shape = list(input_shape)
for i, slice in enumerate(self.slices):
if i == self.axis:
continue
start = slice.start or 0
stop = slice.stop or input_shape[i]
step = slice.step or 1
input_shape[i] = None if stop is None else (stop - start) // step
del input_shape[self.axis]
return tuple(input_shape)
def compute_mask(self, x, mask=None):
if mask is None:
return mask
if self.axis == 1:
return mask[self.slices]
else:
return mask
def get_config(self):
config = {'axis': self.axis,
'indices': self.indices}
base_config = super(Slice, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
You can’t perform that action at this time.