Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unittest][Metal] Add minimal metal functionality test to CI #15756

Merged
merged 6 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ jobs:
shell: bash -l {0}
run: >-
python -m pytest -v tests/python/all-platform-minimal-test
- name: Minimal Metal Compile-Only
shell: bash -l {0}
run: >-
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
- name: Minimal Metal Compile-and-Run
shell: bash -l {0}
run: >-
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum[dims0-metal]'
- name: Test iOS RPC
shell: bash -l {0}
run: >-
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
from tvm.script import tir as T

import pytest


@T.prim_func
def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
Expand Down Expand Up @@ -82,6 +84,48 @@ def test_allreduce_sum(dims, target, dev):
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)


define_metal_compile_callback = tvm.testing.parameter(True, False)


@pytest.fixture
def optional_metal_compile_callback(define_metal_compile_callback):
name = "tvm_callback_metal_compile"
cached = tvm.get_global_func(name, allow_missing=True)

if define_metal_compile_callback:

@tvm.register_func(name, override=True)
def compile_metal(src, target):
return tvm.contrib.xcode.compile_metal(src, sdk="macosx")

yield

if define_metal_compile_callback:
if cached is None:
tvm._ffi.registry.remove_global_func(name)
else:
tvm.register_func(name, cached, override=True)


@tvm.testing.requires_metal(support_required="compile-only")
def test_allreduce_sum_compile(optional_metal_compile_callback):
# Disable the parametrization over dims, at least for now
dims = (1, 1, 2)
target = "metal"

d1, d2, d3 = dims
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.tir.Schedule(mod)
blk = sch.get_block("reduce")
i, j, k, l = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
tvm.build(sch.mod["main"], target=target)


@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_max(dims, target, dev):
d1, d2, d3 = dims
Expand Down