In [16]:
import torch
import brt.frontend.symbolic
import brt.frontend.nn as nn
from brt.routers.app import RandScatterRouter, RandGatherRouter
from brt.frontend import symbolize


class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandGatherRouter(route_num=2)

    def forward(self, x):
        route_results, route_tags, loads = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x = self.gather_router([x_0, x_1], route_tags, loads)
        return x


class ThorMoE(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


thor_moe = ThorMoE()
thor_moe.eval()
x = torch.arange(0, 20, dtype=torch.float32).view(2, 10)
print(x)
y = thor_moe(x)

script_moe_model = torch.jit.script(symbolize(thor_moe))
print(script_moe_model.moe.graph)

torch.onnx.export(
    script_moe_model,
    x,
    "moe.onnx",
    verbose=True,
    custom_opsets={"brt": 2},
    opset_version=11,
)


tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])
graph(%self : __torch__.___torch_mangle_256.MoE,
      %x.1 : Tensor):
  %10 : int = prim::Constant[value=0]() # /tmp/ipykernel_1526473/1513214202.py:18:41
  %15 : int = prim::Constant[value=1]() # /tmp/ipykernel_1526473/1513214202.py:19:41
  %scatter_router : __torch__.brt.router.scatter.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%self)
  %4 : (Tensor[], Tensor[], int) = prim::CallMethod[name="forward"](%scatter_router, %x.1) # /tmp/ipykernel_1526473/1513214202.py:17:43
  %route_results.1 : Tensor[], %route_tags.1 : Tensor[], %loads.1 : int = prim::TupleUnpack(%4)
  %expert1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="expert1"](%self)
  %11 : Tensor = aten::__getitem__(%route_results.1, %10) # /tmp/ipykernel_1526473/1513214202.py:18:27
  %x_0.1 : Tensor = prim::CallMethod[name="forward"](%expert1, %11) # /tmp/ipykernel_1526473/15



In [3]:
from brt.runtime import find_lib_path
from typing import List, Tuple
import torch
import torch.nn as nn

libpath = find_lib_path("libbrt_torchscript.so")
torch.ops.load_library(find_lib_path("libbrt_torchscript.so")[0])


class Model(nn.Module):
    def __init__(self, y, z):
        super().__init__()
        self.y = y
        self.z = z

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (
            route_results,
            route_indices,
            reverse_shape,
        ) = torch.ops.brt.symbolic_scatter_route(x, self.y, self.z)
        results = torch.stack(route_results, dim=0)
        indices = torch.stack(route_indices, dim=0)
        return results, indices, reverse_shape


model = Model(0, 2)
script_model = torch.jit.script(model)

x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

torch.onnx.export(
    script_model, x, "scatter_route.onnx", custom_opsets={"brt": 2}, opset_version=14
)


[tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]), tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]), tensor([1, 1])]


