
[Unity] nn.Module external modules#15487

Merged
cyx-6 merged 3 commits into apache:unity from cyx-6:extern-module
Aug 20, 2023

Conversation

@cyx-6
Contributor

@cyx-6 cyx-6 commented Aug 4, 2023

This PR introduces the feature of importing external *.o modules into our nn.Module frontend.

@tvm-bot
Collaborator

tvm-bot commented Aug 4, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Hzfengsy
Member

cc @junrushao. Could you review this PR?

@junrushao
Member

I'd love to review later this week. I meant to review on Monday but got distracted for quite a while.

@junrushao
Member

Otherwise LGTM

cyx-6 and others added 2 commits August 18, 2023 10:51
This PR introduces the feature of importing external `*.o` modules into our `nn.Module` frontend.
@cyx-6 cyx-6 merged commit b959645 into apache:unity Aug 20, 2023
@Cydia2018

Thank you very much for your work.

I would like to ask whether the external-module feature is fully developed yet. I tried to embed a handwritten attention operator; it compiled normally, but an error was reported at runtime: `Cannot find PackedFunc attention in either Relax VM kernel library`.

Looking forward to your reply!

@cyx-6
Contributor Author

cyx-6 commented Sep 22, 2023

@Cydia2018 Thanks for reaching out! Have you wrapped your operator with `TVM_DLL_EXPORT_TYPED_FUNC` or another TVM registration functor? Feel free to create a new issue with more context, and it would be great if you could provide a reproducible script for us. :)

junrushao added a commit to junrushao/tvm that referenced this pull request Oct 30, 2023
Following apache#15487, this PR introduces `nn.SourceModule` as a subclass of
`nn.ExternModule` for more convenient handling of externally implemented
operators, subgraphs, and other components.

**What is `nn.ExternModule` designed for?** It is a generic design that
accepts any object file (`.o` on Linux) and combines the symbols within it
into the TVM-generated shared/static library. This way, the TVM-generated
library is able to call into the symbols provided by the object files,
for example, calling into `cutlass_fmha` in TVM Relax.

**What is `nn.SourceModule`?** It is a subclass built on top of
`nn.ExternModule` that helps with the specific case where the external
implementation is provided as C++/CUDA source code; `SourceModule`
conveniently takes care of invoking a C++/CUDA compiler to convert that
source into object files.

**C++/CUDA Calling Convention.** An exported symbol should be explicitly
declared with the macro `TVM_DLL_EXPORT_TYPED_FUNC($SYMBOL, $CPP_FUNC)`,
and it is recommended to always hide the symbols that we don't wish to
export. Multiple files should never define the same symbol; otherwise the
behavior is undefined (UB).

**Specifying functions on `nn.Module`.** The input/output shapes and
dtypes must be defined using `nn.spec.ExternFunctionSpec`. Symbolic/dynamic
shapes are supported, but there are a few limitations to note:

1) Multi-output is currently not supported, meaning the return value has
to be a single `nn.Tensor`. There is no technical challenge here that we
are aware of, and the interface could be extended to support it in the
future if the need arises.
2) Symbolic dtypes are not supported, meaning one has to export a separate
symbol per dtype combination even if the compute is mathematically
identical, e.g. `matmul_f16_f16_f16` and `matmul_f32_f32_f32`. This could
be alleviated if customization of dtype deduction were introduced.
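
Because of limitation (2), a frontend typically picks the dtype-specialized symbol by name at trace time. A hypothetical helper illustrating that naming scheme (not part of the TVM API):

```python
# Hypothetical sketch: map a dtype to a dtype-specialized exported symbol,
# since symbolic dtypes are not supported by nn.spec.ExternFunctionSpec.
_SHORT = {"float16": "f16", "float32": "f32"}

def matmul_symbol(dtype: str) -> str:
    """Return the exported symbol name for a dtype-specialized matmul."""
    if dtype not in _SHORT:
        raise ValueError(f"no specialized matmul exported for {dtype}")
    s = _SHORT[dtype]
    return f"matmul_{s}_{s}_{s}"
```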

**Example.** Take the C++ code below as an example:

```C++
#include <dlpack/dlpack.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/data_type.h>

namespace {

int _scalar_add(DLTensor* a, DLTensor* b, DLTensor* c) {
  using namespace tvm::runtime;
  ICHECK(a->ndim == 0);
  ICHECK(b->ndim == 0);
  ICHECK(c->ndim == 0);
  ICHECK(DataType(a->dtype) == DataType::Float(32));
  ICHECK(DataType(b->dtype) == DataType::Float(32));
  ICHECK(DataType(c->dtype) == DataType::Float(32));
  float* a_data = static_cast<float*>(a->data);
  float* b_data = static_cast<float*>(b->data);
  float* c_data = static_cast<float*>(c->data);
  *c_data = *a_data + *b_data;
  return 0;
}

int _test_sym(DLTensor* a, DLTensor* b, DLTensor* c) {
  using namespace tvm::runtime;
  ICHECK(a->ndim == 3);
  ICHECK(b->ndim == 3);
  ICHECK(c->ndim == 4);
  ICHECK(DataType(a->dtype) == DataType::Float(32));
  ICHECK(DataType(b->dtype) == DataType::Float(32));
  ICHECK(DataType(c->dtype) == DataType::Float(32));
  int64_t x = a->shape[0];
  int64_t y = a->shape[1];
  int64_t z = b->shape[1];
  ICHECK(a->shape[2] == 1);
  ICHECK(b->shape[0] == y);
  ICHECK(b->shape[2] == 5);
  ICHECK(c->shape[0] == x);
  ICHECK(c->shape[1] == y);
  ICHECK(c->shape[2] == z);
  ICHECK(c->shape[3] == 9);
  return 0;
}

}  // namespace
TVM_DLL_EXPORT_TYPED_FUNC(ext_scalar_add, _scalar_add);
TVM_DLL_EXPORT_TYPED_FUNC(ext_test_sym, _test_sym);
```
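
The `DLTensor`-based convention above can be exercised without TVM installed; below is a minimal, self-contained C sketch with a mock tensor struct standing in for `DLTensor` (illustrative only, not the real dlpack header):

```c
#include <assert.h>
#include <stdint.h>

/* Minimal stand-in for DLTensor: only the fields the example uses. */
typedef struct {
    void *data;
    int32_t ndim;
    int64_t *shape;
} MockTensor;

/* Same contract as _scalar_add above: 0-d float32 tensors, *c = *a + *b.
   Returns 0 on success, -1 on a rank mismatch. */
static int scalar_add(MockTensor *a, MockTensor *b, MockTensor *c) {
    if (a->ndim != 0 || b->ndim != 0 || c->ndim != 0) return -1;
    *(float *)c->data = *(float *)a->data + *(float *)b->data;
    return 0;
}
```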

The C++ code above exposes two symbols, `ext_scalar_add` and `ext_test_sym`. In
`nn.Module`, their shape/dtype deduction rules can be described as:

```python
dtype = "float32"
functions = {
  "ext_scalar_add": spec.ExternFunctionSpec(
    args=[
      spec.Tensor((), dtype),
      spec.Tensor((), dtype),
    ],
    ret=spec.Tensor((), dtype),
  ),
  "ext_test_sym": spec.ExternFunctionSpec(
    args=[
      spec.Tensor(("x", "y", 1), dtype),
      spec.Tensor(("y", "z", 5), dtype),
    ],
    ret=spec.Tensor(("x", "y", "z", 9), dtype),
  ),
}
```

and thus the external module could be defined as:

```python
class MyExtMod(nn.SourceModule):
    def __init__(self):
        super().__init__(
            source_code=SOURCE_CODE,
            source_format="cpp",
            functions=functions,
        )
    def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        return self.get_extern_func("ext_scalar_add")(a, b)
    def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        return self.get_extern_func("ext_test_sym")(a, b)
```

Any `nn.Module` can use this `nn.SourceModule` as part of its computation
and export it into a TVM IRModule:

```python
my_ext_mod = MyExtMod()

class TestModule(nn.Module):
    def __init__(self) -> None:
        self.extern_matmul = my_ext_mod

    def scalar_add(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        return self.extern_matmul.scalar_add(a, b)

    def test_sym(self, a: nn.Tensor, b: nn.Tensor):  # pylint: disable=invalid-name
        return self.extern_matmul.test_sym(a, b)

model = TestModule()
shape_a = ("x", "y", 1)  # symbolic shapes matching the ExternFunctionSpec above
shape_b = ("y", "z", 5)
ir_module, _ = model.export_tvm(
    spec={
        "scalar_add": {
            "a": spec.Tensor((), dtype),
            "b": spec.Tensor((), dtype),
        },
        "test_sym": {
            "a": spec.Tensor(shape_a, dtype),
            "b": spec.Tensor(shape_b, dtype),
        },
    }
)
```
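As a sanity check on the symbolic shapes used above, the deduction the spec drives can be sketched in plain Python (this is an illustrative sketch, not the TVM implementation):

```python
# Hypothetical sketch of symbolic shape deduction, e.g. for ext_test_sym:
# args (x, y, 1) and (y, z, 5) produce output (x, y, z, 9).
def deduce(spec_args, spec_ret, concrete_args):
    """Bind symbolic dims from concrete shapes, then evaluate the return shape."""
    bindings = {}
    for spec_shape, shape in zip(spec_args, concrete_args):
        assert len(spec_shape) == len(shape), "rank mismatch"
        for dim, concrete in zip(spec_shape, shape):
            if isinstance(dim, str):  # symbolic dim: bind or check consistency
                assert bindings.setdefault(dim, concrete) == concrete, f"{dim} mismatch"
            else:  # static dim: must match exactly
                assert dim == concrete, "static dim mismatch"
    return tuple(bindings[d] if isinstance(d, str) else d for d in spec_ret)
```

For example, `deduce([("x", "y", 1), ("y", "z", 5)], ("x", "y", "z", 9), [(2, 3, 1), (3, 4, 5)])` binds `x=2, y=3, z=4` and returns `(2, 3, 4, 9)`.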
junrushao added a commit to junrushao/tvm that referenced this pull request Oct 30, 2023
junrushao added a commit to junrushao/tvm that referenced this pull request Oct 30, 2023
junrushao added a commit that referenced this pull request Oct 31, 2023