Skip to content

Commit

Permalink
separate layers for different frameworks into individual files
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Oct 16, 2018
1 parent d1534f4 commit a95d509
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 72 deletions.
90 changes: 18 additions & 72 deletions layers.py → layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@

import functools

from einops import rearrange, TransformRecipe, _prepare_transformation_recipe, EinopsError
from einops import TransformRecipe, _prepare_transformation_recipe, EinopsError


# TODO tests for serialization / deserialization inside the model
# TODO docstrings
# TODO make imports like from einops.torch import ...

class RearrangeMixin:
"""
Rearrange layer behaves identically to einops.rearrange operation.
:param pattern: str, rearrangement pattern
:param axes_lengths: any additional specification of dimensions
See einops.rearrange for examples.
"""
def __init__(self, pattern, **axes_lengths):
super().__init__()
self.pattern = pattern
Expand Down Expand Up @@ -38,6 +44,15 @@ def _apply_recipe(self, x):


class ReduceMixin:
"""
Reduce layer behaves identically to einops.reduce operation.
:param pattern: str, rearrangement pattern
:param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
:param axes_lengths: any additional specification of dimensions
See einops.reduce for examples.
"""
def __init__(self, pattern, reduction, **axes_lengths):
super().__init__()
self.pattern = pattern
Expand Down Expand Up @@ -66,74 +81,5 @@ def _apply_recipe(self, x):
raise EinopsError(' Error while computing {!r}\n {}'.format(self, e))


import torch


class TorchRearrange(RearrangeMixin, torch.nn.Module):
    """Torch module applying an einops rearrange pattern (see RearrangeMixin)."""

    def forward(self, input):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        output = self._apply_recipe(input)
        return output


class TorchReduce(ReduceMixin, torch.nn.Module):
    """Torch module applying an einops reduce pattern (see ReduceMixin)."""

    def forward(self, input):
        # Delegate the actual reduction to the shared mixin recipe.
        output = self._apply_recipe(input)
        return output


import chainer


class ChainerRearrange(RearrangeMixin, chainer.Link):
    """Chainer link applying an einops rearrange pattern (see RearrangeMixin)."""

    def __call__(self, x):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        output = self._apply_recipe(x)
        return output


class ChainerReduce(ReduceMixin, chainer.Link):
    """Chainer link applying an einops reduce pattern (see ReduceMixin)."""

    def __call__(self, x):
        # Delegate the actual reduction to the shared mixin recipe.
        output = self._apply_recipe(x)
        return output


import mxnet


# TODO symbolic is not working right now

class GluonRearrange(RearrangeMixin, mxnet.gluon.HybridBlock):
    """Gluon block applying an einops rearrange pattern (see RearrangeMixin)."""

    def hybrid_forward(self, F, x):
        # F (symbolic/ndarray backend) is unused; the mixin recipe handles x directly.
        output = self._apply_recipe(x)
        return output


class GluonReduce(ReduceMixin, mxnet.gluon.HybridBlock):
    """Gluon block applying an einops reduce pattern (see ReduceMixin)."""

    def hybrid_forward(self, F, x):
        # F (symbolic/ndarray backend) is unused; the mixin recipe handles x directly.
        output = self._apply_recipe(x)
        return output


from keras.engine.topology import Layer


class KerasRearrange(RearrangeMixin, Layer):
    """Keras layer applying an einops rearrange pattern (see RearrangeMixin)."""

    def compute_output_shape(self, input_shape):
        # Normalize dimensions to int, keeping None for unknown (e.g. batch) axes.
        shape = tuple(int(d) if d is not None else None for d in input_shape)
        # Only the resulting shapes are needed here; other recipe outputs are unused.
        _, _, _, final_shapes = self.recipe().reconstruct_from_shape(shape)
        return final_shapes

    def call(self, inputs):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        return self._apply_recipe(inputs)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created on load.
        config = {'pattern': self.pattern}
        config.update(self.axes_lengths)
        return config


class KerasReduce(ReduceMixin, Layer):
    """Keras layer applying an einops reduce pattern (see ReduceMixin)."""

    def compute_output_shape(self, input_shape):
        # Normalize dimensions to int, keeping None for unknown (e.g. batch) axes.
        shape = tuple(int(d) if d is not None else None for d in input_shape)
        # Only the resulting shapes are needed here; other recipe outputs are unused.
        _, _, _, final_shapes = self.recipe().reconstruct_from_shape(shape)
        return final_shapes

    def call(self, inputs):
        # Delegate the actual reduction to the shared mixin recipe.
        return self._apply_recipe(inputs)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created on load.
        config = {'pattern': self.pattern, 'reduction': self.reduction}
        config.update(self.axes_lengths)
        return config


# Mapping for keras.models.load_model(..., custom_objects=keras_custom_objects).
keras_custom_objects = {cls.__name__: cls for cls in (KerasRearrange, KerasReduce)}
15 changes: 15 additions & 0 deletions layers/chainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import chainer

from layers import RearrangeMixin, ReduceMixin

__author__ = 'Alex Rogozhnikov'


class Rearrange(RearrangeMixin, chainer.Link):
    """Chainer link applying an einops rearrange pattern (see RearrangeMixin)."""

    def __call__(self, x):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        output = self._apply_recipe(x)
        return output


class Reduce(ReduceMixin, chainer.Link):
    """Chainer link applying an einops reduce pattern (see ReduceMixin)."""

    def __call__(self, x):
        # Delegate the actual reduction to the shared mixin recipe.
        output = self._apply_recipe(x)
        return output
16 changes: 16 additions & 0 deletions layers/gluon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import mxnet

from layers import RearrangeMixin, ReduceMixin

__author__ = 'Alex Rogozhnikov'


# TODO symbolic is not working right now
class Rearrange(RearrangeMixin, mxnet.gluon.HybridBlock):
    """Gluon block applying an einops rearrange pattern (see RearrangeMixin)."""

    def hybrid_forward(self, F, x):
        # F (symbolic/ndarray backend) is unused; the mixin recipe handles x directly.
        output = self._apply_recipe(x)
        return output


class Reduce(ReduceMixin, mxnet.gluon.HybridBlock):
    """Gluon block applying an einops reduce pattern (see ReduceMixin)."""

    def hybrid_forward(self, F, x):
        # F (symbolic/ndarray backend) is unused; the mixin recipe handles x directly.
        output = self._apply_recipe(x)
        return output
34 changes: 34 additions & 0 deletions layers/keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from keras.engine import Layer

from layers import RearrangeMixin, ReduceMixin

__author__ = 'Alex Rogozhnikov'


class Rearrange(RearrangeMixin, Layer):
    """Keras layer applying an einops rearrange pattern (see RearrangeMixin)."""

    def compute_output_shape(self, input_shape):
        # Normalize dimensions to int, keeping None for unknown (e.g. batch) axes.
        shape = tuple(int(d) if d is not None else None for d in input_shape)
        # Only the resulting shapes are needed here; other recipe outputs are unused.
        _, _, _, final_shapes = self.recipe().reconstruct_from_shape(shape)
        return final_shapes

    def call(self, inputs):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        return self._apply_recipe(inputs)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created on load.
        config = {'pattern': self.pattern}
        config.update(self.axes_lengths)
        return config


class Reduce(ReduceMixin, Layer):
    """Keras layer applying an einops reduce pattern (see ReduceMixin)."""

    def compute_output_shape(self, input_shape):
        # Normalize dimensions to int, keeping None for unknown (e.g. batch) axes.
        shape = tuple(int(d) if d is not None else None for d in input_shape)
        # Only the resulting shapes are needed here; other recipe outputs are unused.
        _, _, _, final_shapes = self.recipe().reconstruct_from_shape(shape)
        return final_shapes

    def call(self, inputs):
        # Delegate the actual reduction to the shared mixin recipe.
        return self._apply_recipe(inputs)

    def get_config(self):
        # Serialize constructor arguments so the layer can be re-created on load.
        config = {'pattern': self.pattern, 'reduction': self.reduction}
        config.update(self.axes_lengths)
        return config


# Mapping for keras.models.load_model(..., custom_objects=keras_custom_objects).
keras_custom_objects = {cls.__name__: cls for cls in (Rearrange, Reduce)}
15 changes: 15 additions & 0 deletions layers/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

from layers import RearrangeMixin, ReduceMixin

__author__ = 'Alex Rogozhnikov'


class Rearrange(RearrangeMixin, torch.nn.Module):
    """Torch module applying an einops rearrange pattern (see RearrangeMixin)."""

    def forward(self, input):
        # Delegate the actual axis manipulation to the shared mixin recipe.
        output = self._apply_recipe(input)
        return output


class Reduce(ReduceMixin, torch.nn.Module):
    """Torch module applying an einops reduce pattern (see ReduceMixin)."""

    def forward(self, input):
        # Delegate the actual reduction to the shared mixin recipe.
        output = self._apply_recipe(input)
        return output

0 comments on commit a95d509

Please sign in to comment.