From 466a1a60c062845b8fc511927c36ba8212a53260 Mon Sep 17 00:00:00 2001
From: Kulin Seth
Date: Mon, 4 Jul 2022 06:41:39 +0000
Subject: [PATCH] [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79532
Approved by: https://github.com/albanD, https://github.com/malfet
---
 .github/workflows/_mac-test-arm64.yml |   2 +-
 test/test_mps.py                      | 683 +++++++++++++++++++++++++-
 2 files changed, 669 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/_mac-test-arm64.yml b/.github/workflows/_mac-test-arm64.yml
index 8eba5fb56e42..9bbaffc1b80b 100644
--- a/.github/workflows/_mac-test-arm64.yml
+++ b/.github/workflows/_mac-test-arm64.yml
@@ -40,7 +40,7 @@ jobs:
           # shellcheck disable=SC1090
           . ~/miniconda3/etc/profile.d/conda.sh
           set -ex
-          conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest
+          conda create -yp "${ENV_NAME}" "python=${PY_VERS}" numpy expecttest pyyaml
           # As wheels are cross-compiled they are reported as x86_64 ones
           ORIG_WHLNAME=$(ls -1 dist/*.whl); ARM_WHLNAME=${ORIG_WHLNAME/x86_64/arm64}; mv ${ORIG_WHLNAME} ${ARM_WHLNAME}
           conda run -p "${ENV_NAME}" python3 -mpip install dist/*.whl
diff --git a/test/test_mps.py b/test/test_mps.py
index c6be5c258d07..d66c0d9fc301 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -13,15 +13,21 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import itertools
+from collections import defaultdict
 
 from torch._six import inf
 from torch.nn import Parameter
-from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN, gradcheck, gradgradcheck
+from torch.testing._internal.common_utils import \
+    (gradcheck, gradgradcheck, run_tests, TestCase, download_file,
+     TEST_WITH_UBSAN)
 from torch.testing import make_tensor
 from torch.testing._comparison import TensorLikePair
+from torch.testing._internal.common_dtype import get_all_dtypes
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
 from functools import partial
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
 import numpy as np
 import torch
@@ -790,7 +796,6 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt
         helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
         helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
 
-
     def test_instance_norm(self):
         def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
 
@@ -3261,6 +3266,14 @@ def helper(shape_x, shape_y, shape_z):
         # Empty test - Currently failing! Empty tensor not handled!
         # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
 
+    def test_constant_pad(self):
+        m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
+        input_cpu = torch.randn(1, 16, 16, 16)
+        input_mps = input_cpu.detach().clone().to("mps")
+        r_cpu = m(input_cpu)
+        r_mps = m(input_mps)
+        self.assertEqual(r_cpu, r_mps.to("cpu"))
+
     def test_pad(self):
         def helper(shape, padding, op):
             inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
@@ -4832,6 +4845,10 @@ def test_inplace_scatter(self):
         a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
         self.assertEqual(a_cpu, a_mps)
 
+# These tests were taken from test/test_view_ops.py
+# They are a subset of those tests, as currently only this subset is working.
+# This whole `class` will be removed when we add generic device testing. There
+# are no additional tests added beyond what is already in test_view_ops.py.
 class TestViewOpsMPS(TestCase):
     exact_dtype = True
 
@@ -4843,7 +4860,7 @@ def is_view_of(self, base, other):
             return False
         # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
-        if base.device.type == 'cpu' or base.device.type == 'cuda':
+        if base.device.type == 'mps':
             if base.storage().data_ptr() != other.storage().data_ptr():
                 return False
 
@@ -4998,7 +5015,7 @@ def test_squeeze_view(self, device="mps"):
         v = torch.squeeze(t)
         self.assertTrue(self.is_view_of(t, v))
         v[0, 1] = 0
-        self.assertEqual(t, v._base)
+        self.assertTrue(t is v._base)
 
     def test_squeeze_inplace_view(self, device="mps"):
         t = torch.ones(5, 5, device=device)
@@ -5006,7 +5023,7 @@ def test_squeeze_inplace_view(self, device="mps"):
         v = v.squeeze_()
         self.assertTrue(self.is_view_of(t, v))
         v[0, 1] = 0
-        self.assertEqual(t, v._base)
+        self.assertTrue(t is v._base)
 
     def test_unsqueeze_view(self, device="mps"):
         t = torch.ones(5, 5, device=device)
@@ -5590,14 +5607,14 @@ def test_view(self, device="mps"):
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
 
     # RuntimeError: Invalid device for storage: mps
-    # def test_contiguous(self, device="mps"):
-    #     x = torch.randn(1, 16, 5, 5, device=device)
-    #     self.assertTrue(x.is_contiguous())
-    #     stride = list(x.stride())
-    #     stride[0] = 20
-    #     # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
-    #     x.set_(x.storage(), 0, x.size(), stride)
-    #     self.assertTrue(x.is_contiguous())
+    def test_contiguous(self, device="mps"):
+        x = torch.randn(1, 16, 5, 5, device=device)
+        self.assertTrue(x.is_contiguous())
+        stride = list(x.stride())
+        stride[0] = 20
+        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+        x.set_(x.storage(), 0, x.size(), stride)
+        self.assertTrue(x.is_contiguous())
 
     def test_resize_all_dtypes_and_devices(self, device="mps"):
         shape = (2, 2)
@@ -5767,8 +5784,11 @@ def test_assert_close(self):
         with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
             torch.testing.assert_close(a, inf)
 
-        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
-            torch.testing.assert_close(a, nan)
+        # TODO: The NaN test is failing when all the tests in test_mps are run
+        # together but passes when run separately. There seems to be memory
+        # corruption which needs to be fixed for this test to be enabled.
+        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
+        #     torch.testing.assert_close(a, nan)
 
     def test_double_error(self):
         with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
@@ -5783,7 +5803,640 @@ def test_legacy_constructor(self):
         b = a.new(1)
 
+    def test_serialization_map_location(self):
+
+        # Ensures that cpu Tensor can be loaded on mps
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2)
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f, map_location="mps")
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "mps")
+
+        # Ensures that mps Tensors can be loaded on mps
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2, device="mps")
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f)
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "mps")
+
+        # Ensures that mps Tensors can be loaded on cpu
+        with tempfile.NamedTemporaryFile() as f:
+            x = torch.rand(2, device="mps")
+            torch.save(x, f)
+
+            f.seek(0)
+            x2 = torch.load(f, map_location="cpu")
+
+            self.assertEqual(x, x2)
+            self.assertEqual(x2.device.type, "cpu")
+
+
+MPS_DTYPES = get_all_dtypes()
+for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
+    del MPS_DTYPES[MPS_DTYPES.index(t)]
+
+class TestConsistency(TestCase):
+    # TODO: This is only used while some ops are being added.
+    # This list should contain all ops and dtypes eventually.
+    # It can be generated automatically in the `new_mps_allowlist.txt` file
+    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
+    # You most likely do NOT want to modify this manually
+    ALLOWLIST_OP = {
+        '__radd__': ['torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__rand__': ['torch.bool',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__rmul__': ['torch.bool',
+                     'torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__ror__': ['torch.bool',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        '__rxor__': ['torch.bool',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '_masked.normalize': ['torch.float32'],
+        'abs': ['torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.uint8'],
+        'add': ['torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'addcdiv': ['torch.float32'],
+        'addcmul': ['torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        'addmv': ['torch.float32'],
+        'addr': ['torch.float32'],
+        'all': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'any': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'argmax': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'asin': ['torch.float32'],
+        'asinh': ['torch.float32'],
+        'atan': ['torch.float32'],
+        'atan2': ['torch.float32'],
+        'atanh': ['torch.float32'],
+        'atleast_1d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'atleast_2d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'atleast_3d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'baddbmm': ['torch.float32'],
+        'bitwise_and': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bitwise_left_shift': ['torch.int16',
+                               'torch.int32',
+                               'torch.int64',
+                               'torch.uint8'],
+        'bitwise_not': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bitwise_or': ['torch.bool',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'bitwise_right_shift': ['torch.int16',
+                                'torch.int32',
+                                'torch.int64',
+                                'torch.uint8'],
+        'bitwise_xor': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bmm': ['torch.float32'],
+        'ceil': ['torch.float32'],
+        'chunk': ['torch.float16', 'torch.float32', 'torch.int64'],
+        'clone': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'column_stack': ['torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64',
+                         'torch.uint8'],
+        'conj': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'conj_physical': ['torch.bool',
+                          'torch.float16',
+                          'torch.float32',
+                          'torch.int16',
+                          'torch.int32',
+                          'torch.int64',
+                          'torch.uint8'],
+        'contiguous': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'corrcoef': ['torch.float32'],
+        'deg2rad': ['torch.float32'],
+        'diag': ['torch.float32', 'torch.int32'],
+        'diagflat': ['torch.int32'],
+        'diff': ['torch.float32'],
+        'dist': ['torch.float32'],
+        'dot': ['torch.float32', 'torch.int32'],
+        'einsum': ['torch.float32'],
+        'erf': ['torch.float32'],
+        'fill': ['torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64'],
+        'flatten': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        'floor': ['torch.float32'],
+        'hstack': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'index_select': ['torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64'],
+        'isinf': ['torch.float16', 'torch.float32'],
+        'isnan': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'kron': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'linalg.norm': ['torch.float16', 'torch.float32'],
+        'linalg.svd': ['torch.float32'],
+        'linalg.vector_norm': ['torch.float16'],
+        'log1p': ['torch.float32'],
+        'log_softmax': ['torch.float32'],
+        'logaddexp': ['torch.float32'],
+        'logaddexp2': ['torch.float32'],
+        'masked_select': ['torch.bool',
+                          'torch.float16',
+                          'torch.float32',
+                          'torch.int16',
+                          'torch.int32',
+                          'torch.int64',
+                          'torch.uint8'],
+        'mm': ['torch.float32'],
+        'mv': ['torch.float32'],
+        'neg': ['torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32'],
+        'nn.functional.adaptive_max_pool1d': ['torch.float32'],
+        'nn.functional.adaptive_max_pool2d': ['torch.float32'],
+        'nn.functional.binary_cross_entropy': ['torch.float32'],
+        'nn.functional.celu': ['torch.float32'],
+        'nn.functional.elu': ['torch.float32'],
+        'nn.functional.embedding': ['torch.float16', 'torch.float32'],
+        'nn.functional.feature_alpha_dropout': ['torch.bool',
+                                                'torch.float16',
+                                                'torch.float32',
+                                                'torch.int16',
+                                                'torch.int32',
+                                                'torch.int64',
+                                                'torch.uint8'],
+        'nn.functional.hardtanh': ['torch.float32',
+                                   'torch.int16',
+                                   'torch.int32',
+                                   'torch.int64'],
+        'nn.functional.hinge_embedding_loss': ['torch.float32'],
+        'nn.functional.kl_div': ['torch.float32'],
+        'nn.functional.l1_loss': ['torch.float32'],
+        'nn.functional.leaky_relu': ['torch.float32'],
+        'nn.functional.mse_loss': ['torch.float16', 'torch.float32'],
+        'nn.functional.relu': ['torch.float32',
+                               'torch.int16',
+                               'torch.int32',
+                               'torch.int64',
+                               'torch.uint8'],
+        'nn.functional.relu6': ['torch.float32',
+                                'torch.int16',
+                                'torch.int32',
+                                'torch.int64',
+                                'torch.uint8'],
+        'nn.functional.selu': ['torch.float32'],
+        'nn.functional.silu': ['torch.float32'],
+        'nn.functional.smooth_l1_loss': ['torch.float32'],
+        'nn.functional.softmin': ['torch.float32'],
+        'nn.functional.threshold': ['torch.float32',
+                                    'torch.int16',
+                                    'torch.int32',
+                                    'torch.int64',
+                                    'torch.uint8'],
+        'nn.functional.upsample_bilinear': ['torch.float32'],
+        'norm': ['torch.float32', 'torch.float16'],
+        'positive': ['torch.float16',
+                     'torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64',
+                     'torch.uint8'],
+        'rad2deg': ['torch.float32'],
+        'ravel': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'real': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'repeat_interleave': ['torch.bool',
+                              'torch.float16',
+                              'torch.float32',
+                              'torch.int16',
+                              'torch.int32',
+                              'torch.int64',
+                              'torch.uint8'],
+        'resize_': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'resize_as_': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'resolve_conj': ['torch.bool',
+                         'torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64',
+                         'torch.uint8'],
+        'resolve_neg': ['torch.bool',
+                        'torch.float16',
+                        'torch.float32',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'round': ['torch.float32'],
+        'sgn': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64',
+                'torch.uint8'],
+        'sign': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.uint8'],
+        'sin': ['torch.float32'],
+        'sinh': ['torch.float32'],
+        'softmax': ['torch.float32'],
+        'split': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'sqrt': ['torch.float32'],
+        'square': ['torch.float32'],
+        'squeeze': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'stack': ['torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'sub': ['torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'sum_to_size': ['torch.bool',
+                        'torch.float16',
+                        'torch.float32',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'svd': ['torch.float32'],
+        't': ['torch.bool',
+              'torch.float16',
+              'torch.float32',
+              'torch.int16',
+              'torch.int32',
+              'torch.int64',
+              'torch.uint8'],
+        'tanh': ['torch.float32'],
+        'tensordot': ['torch.float32'],
+        'topk': ['torch.float32'],
+        'tril': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'triu': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'true_divide': ['torch.float32'],
+        'trunc': ['torch.float32'],
+        'unsqueeze': ['torch.bool',
+                      'torch.float16',
+                      'torch.float32',
+                      'torch.int16',
+                      'torch.int32',
+                      'torch.int64',
+                      'torch.uint8'],
+        'view': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'view_as': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'vsplit': ['torch.bool',
+                   'torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64',
+                   'torch.uint8'],
+        'vstack': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'zero_': ['torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8']}
+
+    # These ops are problematic, so never run them, even when generating
+    # the new allowlist.
+    # If the dtype list is None, all dtypes are excluded.
+    # All the entries in this list should eventually be removed.
+    BLOCKLIST = {
+        # Functions that hang
+        'masked_fill': [torch.bool, torch.uint8, torch.float32],
+        'where': [torch.bool],
+        # Functions that hard crash
+        'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
+        'nn.functional.nll_loss': [torch.float32],
+        'nn.functional.padreflect': [torch.float32],
+        'nn.functional.padreplicate': [torch.float32],
+        'nn.functional.smooth_l1_loss': [torch.float16],
+        'std': [torch.float16],
+        'stft': [torch.float32],
+        'var': [torch.float16],
+
+        # These were moved from ALLOWLIST to BLOCKLIST as they are not
+        # working locally
+        'tile': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'repeat': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        '__radd__': [torch.bool, torch.uint8],
+        '__rmul__': [torch.uint8],
+        'add': [torch.bool, torch.uint8],
+        'square': [torch.int32, torch.int64, torch.uint8],
+        'addr': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'diag': [torch.int64],
+        'diagflat': [torch.int64],
+
+        # Functions that are flaky
+        # These are detected as "ok" by the expect case but actually fail to run sometimes
+        'H': None,
+        'T': None,
+        'as_strided': None,
+        'broadcast_tensors': None,
+        'broadcast': None,
+        'broadcast_to': None,
+        'diagonal': None,
+        'divfloor_rounding': None,
+        'divno_rounding_mode': None,
+        'divtrunc_rounding': None,
+        'dsplit': None,
+        'hsplit': None,
+        'empty': None,
+        'expand_as': None,
+        'expand': None,
+        'ge': None,
+        'ne': None,
+        'le': None,
+        'lt': None,
+        'gt': None,
+        'transpose': None,
+        'splitlist_args': None,
+        'select': None,
+        'reshape': None,
+        'reshape_as': None,
+        'permute': None,
+        'norm': None,
+        'nn.functional.pixel_unshuffle': None,
+        'nn.functional.pixel_shuffle': None,
+        'nn.functional.cross_entropy': None,
+        'nn.functional.one_hot': None,
+        'narrow': None,
+        'movedim': None,
+        'minreduction_with_dim': None,
+        'minreduction_no_dim': None,
+        'minbinary': None,
+        'meshgridvariadic_tensors': None,
+        'meshgridlist_of_tensors': None,
+        'maxreduction_with_dim': None,
+        'maxreduction_no_dim': None,
+        'maxbinary': None,
+        'maximum': None,
+        'minimum': None,
+        'mT': None,
+        'mH': None,
+        'outer': None,
+        'softmaxwith_dtype': None,
+        'rounddecimals_neg_3': None,
+        'rounddecimals_3': None,
+        'rounddecimals_0': None,
+        'normnuc': None,
+        'nn.functional.softminwith_dtype': None,
+        'nn.functional.feature_alpha_dropoutwith_train': None,
+        'log_softmaxdtype': None,
+        'split_with_sizes': None,
+        'trapezoid': None,
+        'eq': None,
+        'mul': None,
+        'cartesian_prod': None,
+        'nonzero': None,
+        'bool': None,
+        'inner': None,
+        'dstack': None,
+        'take_along_dim': None,
+    }
+
+    # Used for accept mode only
+    NEW_ALLOW_LIST = defaultdict(list)
+
+    @ops(op_db, allowed_dtypes=MPS_DTYPES)
+    def test_output_match(self, device, dtype, op):
+        self.assertEqual(device, "cpu")
+        if not torch.backends.mps.is_available():
+            self.skipTest("MPS is not available")
+
+        key = op.name + op.variant_test_name
+        if key in self.BLOCKLIST:
+            if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]:
+                self.skipTest(f"Running test with {op.name} hangs so skipping")
+
+        # Make this an expecttest manually
+        # When this env variable is set, generate a new ALLOWLIST_OP
+        # that reflects the current state of what passes or not
+        if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
+            generate_new_truth = True
+        else:
+            generate_new_truth = False
+
+        if not generate_new_truth:
+            if op.name not in self.ALLOWLIST_OP:
+                self.skipTest(f"{op.name} is not in the allow list for test on MPS")
+            else:
+                if str(dtype) not in self.ALLOWLIST_OP[op.name]:
+                    self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded")
+        try:
+            cpu_samples = op.sample_inputs(device, dtype)
+
+            for cpu_sample in cpu_samples:
+                mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x)
+
+                # TODO: This checks only the function variant. We should also check the
+                # method and inplace versions when they exist.
+                cpu_args = [cpu_sample.input] + list(cpu_sample.args)
+                cpu_kwargs = cpu_sample.kwargs
+                mps_args = [mps_sample.input] + list(mps_sample.args)
+                mps_kwargs = mps_sample.kwargs
+
+                cpu_out = op(*cpu_args, **cpu_kwargs)
+                mps_out = op(*mps_args, **mps_kwargs)
+                self.assertEqual(cpu_out, mps_out)
+        except Exception as e:
+            if not generate_new_truth:
+                raise e
+
+        if generate_new_truth:
+            self.NEW_ALLOW_LIST[op.name].append(str(dtype))
+
+            # We could write this file only once, but there is no way to detect that the
+            # current test is the last one, so each test appends to the dict and writes it.
+            with open("new_mps_allowlist.txt", "w") as f:
+                pprint.pprint(self.NEW_ALLOW_LIST, stream=f)
+
+# TODO: Actually instantiate this test for the "mps" device to better reflect what it is doing.
+# This requires mps to be properly registered in the device generic test framework, which is
+# not the case right now.
+instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
 
 if __name__ == "__main__":
     run_tests()
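
-- 
For reference, the CPU-vs-MPS comparison that test_output_match performs on each
OpInfo sample reduces to the pattern below. This is a minimal standalone sketch,
not part of the patch; the compare_cpu_mps helper is a hypothetical name, and it
assumes the op returns a single tensor:

    import torch

    def compare_cpu_mps(op, *cpu_args, **cpu_kwargs):
        # Mirror every tensor argument onto the MPS device; leave
        # non-tensor arguments untouched.
        def to_mps(x):
            return x.to("mps") if isinstance(x, torch.Tensor) else x
        mps_args = [to_mps(a) for a in cpu_args]
        mps_kwargs = {k: to_mps(v) for k, v in cpu_kwargs.items()}
        cpu_out = op(*cpu_args, **cpu_kwargs)
        mps_out = op(*mps_args, **mps_kwargs)
        # Bring the MPS result back to the CPU and compare; this raises
        # an AssertionError if the two backends disagree.
        torch.testing.assert_close(cpu_out, mps_out.cpu())

    if torch.backends.mps.is_available():
        compare_cpu_mps(torch.add, torch.randn(3, 4), torch.randn(3, 4))

As the comment above ALLOWLIST_OP notes, a fresh allowlist reflecting the current
pass/fail state can be regenerated with
EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU.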