Skip to content

Commit

Permalink
Merge pull request #231 from arogozhnikov/remove-mxnet
Browse files Browse the repository at this point in the history
Remove mxnet
  • Loading branch information
arogozhnikov committed Dec 27, 2022
2 parents 2f96272 + 76e05c3 commit 05a8643
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 75 deletions.
66 changes: 1 addition & 65 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,70 +248,6 @@ def layers(self):
return gluon


class MXNetBackend(AbstractBackend):
framework_name = 'mxnet.symbol'

def __init__(self):
import mxnet
self.mx = mxnet

def is_appropriate_type(self, tensor):
return isinstance(tensor, self.mx.symbol.Symbol)

def create_symbol(self, shape, dtype='float32'):
# mxnet accepts zeros as undefined dimensions
shape = tuple(0 if d is None else d for d in shape)
var = self.mx.symbol.Variable('input', shape=shape, dtype=dtype)
return var

def eval_symbol(self, symbol, input_dict):
args = {var.name: self.mx.nd.array(val) for var, val in input_dict}
ex = symbol.bind(ctx=self.mx.cpu(), args=args)
ex.forward()
return ex.outputs[0].asnumpy()

def shape(self, x):
# mxnet has problems with shape inference - it does not provide shape symbols
# shape_array seems to be impossible to use in shape inference
# infer_shape_partial returns empty tuple if was not able to infer shape
# reductions such as sum can't return scalars, but return 1-element vectors
shape = x.infer_shape_partial()[1][0]
if len(shape) == 0:
warnings.warn('mxnet inferred shape to be (), which probably means it could not be inferred')
shape = tuple(UnknownSize() if d == 0 else d for d in shape)
return shape

def reshape(self, x, shape):
if len(shape) == 0:
return x # poor support of scalars in mxnet
if any(isinstance(dimension, UnknownSize) for dimension in shape):
from einops import EinopsError
raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths")
return x.reshape(shape)

def arange(self, start, stop):
return self.mx.symbol.arange(start, stop)

def stack_on_zeroth_dimension(self, tensors: list):
return self.mx.symbol.stack(*tensors)

def tile(self, x, repeats):
return self.mx.symbol.tile(x, repeats)

def concat(self, tensors, axis: int):
return self.mx.symbol.concat(tensors, dim=axis)

def add_axis(self, x, new_position):
return self.mx.symbol.expand_dims(x, new_position)

def is_float_type(self, x):
return 'float' in str(x.infer_type()[1][0])

def layers(self):
from .layers import gluon
return gluon


class TorchBackend(AbstractBackend):
framework_name = 'torch'

Expand Down Expand Up @@ -677,4 +613,4 @@ def layers(self):
return oneflow

def einsum(self, pattern, *x):
return self.flow.einsum(pattern, *x)
return self.flow.einsum(pattern, *x)
6 changes: 3 additions & 3 deletions einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
```
Parameters:
tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray).
tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, reduction pattern
reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
Expand Down Expand Up @@ -467,7 +467,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
Find more examples in einops tutorial.
Parameters:
tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray).
tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
Expand Down Expand Up @@ -521,7 +521,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
Find more examples in einops tutorial.
Parameters:
tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch, mxnet.ndarray).
tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
Expand Down
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Usage: python test.py
1. Installs part of dependencies (make sure `which pip` points to correct location)
2. Installs current version of einops in editable mode
3. Runs tests
3. Runs the tests
"""

import os
Expand Down Expand Up @@ -38,7 +38,6 @@ def run(cmd, **env):
# install dependencies
dependencies = [
'numpy',
'mxnet==1.*',
'torch',
'tensorflow',
'chainer',
Expand All @@ -58,6 +57,7 @@ def run(cmd, **env):
# oneflow provides wheels for linux, but not mac, so it is tested only on linux
skip_oneflow = 'linux' not in sys.platform
skip_oneflow = True

if not skip_oneflow:
# oneflow installation: https://github.com/Oneflow-Inc/oneflow#install-with-pip-package
assert 0 == run('pip install -f https://release.oneflow.info oneflow==0.7.0+cpu --user')
Expand Down
5 changes: 1 addition & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,9 @@ def collect_test_backends(symbolic=False, layers=False):
]
else:
if not layers:
backend_types = [
_backends.MXNetBackend,
]
backend_types = []
else:
backend_types = [
_backends.MXNetBackend,
_backends.KerasBackend,
]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_all_notebooks():

def test_dl_notebook_with_all_backends():
notebook, = Path(__file__).parent.with_name('docs').glob('2-*.ipynb')
backends = ['chainer', 'gluon', 'pytorch']
backends = ['chainer', 'pytorch']
if 'tensorflow' in collect_test_backends(symbolic=False, layers=False):
backends += ['tensorflow']
for backend in backends:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def test_backends_installed():
# instantiate
backend_type()
except Exception as e:
if 'mxnet' in backend_type.framework_name:
# test mxnet only if it is installed. Not complain if not installed
continue
errors.append(e)
assert len(errors) == 0, errors

Expand Down

0 comments on commit 05a8643

Please sign in to comment.