Add arm support (#7500)
Fixes # .

### Description

Add arm support

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
KumoLiu and ericspod committed Mar 3, 2024
1 parent 02c7f53 commit e9e2738
Showing 8 changed files with 41 additions and 14 deletions.
4 changes: 4 additions & 0 deletions Dockerfile
@@ -16,6 +16,10 @@ FROM ${PYTORCH_IMAGE}

LABEL maintainer="monai.contact@gmail.com"

+# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
+WORKDIR /opt
+RUN git clone --recursive https://github.com/zarr-developers/numcodecs.git && pip wheel numcodecs

WORKDIR /opt/monai

# install full deps
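The numcodecs step above is worth a note: prebuilt aarch64 wheels for numcodecs may be unavailable (an inference from the linked issue, not stated in the diff), so the image builds a wheel up front rather than compiling from source at `pip install` time. A hypothetical sanity check inside the built image could look like:

```python
# Hypothetical post-build check (not part of this commit): confirm that
# numcodecs imports cleanly and report the machine architecture in use.
import platform

import numcodecs

print(platform.machine(), numcodecs.__version__)
```
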
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -26,7 +26,7 @@ mypy>=1.5.0
ninja
torchvision
psutil
-cucim>=23.2.0; platform_system == "Linux"
+cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
openslide-python
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
tifffile; platform_system == "Linux" or platform_system == "Darwin"
@@ -46,7 +46,7 @@ pynrrd
pre-commit
pydicom
h5py
-nni; platform_system == "Linux"
+nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
onnx>=1.13.0
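As an aside (not part of the diff), the new guards on `nni` and `cucim-cu12` are standard PEP 508 environment markers; pip evaluates them against the running interpreter, so `nni` is simply skipped on ARM hosts. A minimal sketch of that evaluation, using the third-party `packaging` library:

```python
# Minimal sketch (not part of this commit) of how pip evaluates the nni marker.
from packaging.markers import Marker

nni_marker = Marker(
    'platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine'
)

# True on x86_64 Linux, False on aarch64, where pip therefore skips nni.
print(nni_marker.evaluate({"platform_system": "Linux", "platform_machine": "x86_64"}))
print(nni_marker.evaluate({"platform_system": "Linux", "platform_machine": "aarch64"}))
```
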
4 changes: 2 additions & 2 deletions setup.cfg
@@ -59,7 +59,7 @@ all =
tqdm>=4.47.0
lmdb
psutil
-cucim>=23.2.0
+cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
openslide-python
tifffile
imagecodecs
@@ -111,7 +111,7 @@ lmdb =
psutil =
psutil
cucim =
-cucim>=23.2.0
+cucim-cu12
openslide =
openslide-python
tifffile =
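For reference (not part of the diff), these extras are what `pip install 'monai[all]'` or `'monai[cucim]'` resolve. A hedged sketch of inspecting them from an installed MONAI via `importlib.metadata`:

```python
# Hedged sketch (assumes a MONAI installation is present): list the declared
# extras and show which requirements the "cucim" extra now pulls in.
from importlib import metadata

print(metadata.metadata("monai").get_all("Provides-Extra"))
print([req for req in (metadata.requires("monai") or []) if "cucim" in req])
```
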
19 changes: 13 additions & 6 deletions tests/test_convert_to_onnx.py
@@ -12,6 +12,7 @@
from __future__ import annotations

import itertools
+import platform
import unittest

import torch
@@ -29,6 +30,12 @@
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))

+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+    rtol, atol = 1e-1, 1e-2
+else:
+    rtol, atol = 1e-3, 1e-4

onnx, _ = optional_import("onnx")


@@ -56,8 +63,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
-rtol=1e-3,
-atol=1e-4,
+rtol=rtol,
+atol=atol,
)
else:
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
@@ -72,8 +79,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
-rtol=1e-3,
-atol=1e-4,
+rtol=rtol,
+atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

@@ -107,8 +114,8 @@ def test_seg_res_net(self, device, use_ort):
device=device,
use_ort=use_ort,
use_trace=True,
-rtol=1e-3,
-atol=1e-4,
+rtol=rtol,
+atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

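The block above (repeated with different values in the other test files) keys the tolerances off `platform.machine()`, loosening them on aarch64 where kernels can differ numerically. As a rough illustration, not part of the commit: allclose-style checks pass when `|actual - expected| <= atol + rtol * |expected|`, so the aarch64 values admit noticeably larger deviations.

```python
# Illustration only (not part of this commit): the same drift that fails the
# x86_64 defaults (rtol=1e-3, atol=1e-4) passes the aarch64 tolerances
# (rtol=1e-1, atol=1e-2).
import numpy as np

expected = np.array([1.0, 10.0])
actual = expected + np.array([0.005, 0.5])  # hypothetical platform-dependent drift

print(np.allclose(actual, expected, rtol=1e-3, atol=1e-4))  # False
print(np.allclose(actual, expected, rtol=1e-1, atol=1e-2))  # True
```
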
9 changes: 8 additions & 1 deletion tests/test_dynunet.py
@@ -11,6 +11,7 @@

from __future__ import annotations

+import platform
import unittest
from typing import Any, Sequence

@@ -24,6 +25,12 @@

InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+    rtol, atol = 1e-2, 1e-2
+else:
+    rtol, atol = 1e-4, 1e-4

device = "cuda" if torch.cuda.is_available() else "cpu"

strides: Sequence[Sequence[int] | int]
@@ -159,7 +166,7 @@ def test_consistency(self, input_param, input_shape, _):
with eval_mode(net_fuser):
result_fuser = net_fuser(input_tensor)

-assert_allclose(result, result_fuser, rtol=1e-4, atol=1e-4)
+assert_allclose(result, result_fuser, rtol=rtol, atol=atol)


class TestDynUNetDeepSupervision(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/test_rand_affine.py
@@ -147,7 +147,7 @@ def test_rand_affine(self, input_param, input_data, expected_val):
g.set_random_state(123)
result = g(**input_data)
g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine
-test_resampler_lazy(g, result, input_param, input_data, seed=123)
+test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol)
if input_param.get("cache_grid", False):
self.assertTrue(g._cached_grid is not None)
assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor")
4 changes: 3 additions & 1 deletion tests/test_rand_affined.py
@@ -234,7 +234,9 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
lazy_init_param["keys"], lazy_init_param["mode"] = key, mode
resampler = RandAffined(**lazy_init_param).set_random_state(123)
expected_output = resampler(**call_param)
-test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key)
+test_resampler_lazy(
+    resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol
+)
resampler.lazy = False

if input_param.get("cache_grid", False):
9 changes: 8 additions & 1 deletion tests/test_spatial_resampled.py
@@ -11,6 +11,7 @@

from __future__ import annotations

+import platform
import unittest

import numpy as np
@@ -23,6 +24,12 @@
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.utils import TEST_DEVICES, assert_allclose

+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+    rtol, atol = 1e-1, 1e-2
+else:
+    rtol, atol = 1e-3, 1e-4

TESTS = []

destinations_3d = [
@@ -104,7 +111,7 @@ def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):

# check lazy
lazy_xform = SpatialResampled(**init_param)
-test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img")
+test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img", rtol=rtol, atol=atol)

# check inverse
inverted = xform.inverse(output_data)["img"]
