In [3]:
import torch

import brt
import brt.nn as nn
import brt.backend.internal.symbolic
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    """Toy mixture-of-experts block: scatter -> three linear experts -> gather.

    A ``RandomScatterRouter`` (``route_num=3``) splits the input into three
    routes. Each route is processed by its own ``nn.Linear(10, 10)`` expert.
    A ``RandomGatherRouter`` (``route_num=3``) recombines the expert outputs
    using the reverse indices/shape returned by the scatter step.
    """

    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=3)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.expert3 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=3)

    def forward(self, x):
        # route_results[i] is presumably the slice of ``x`` dispatched to route i;
        # reverse_indice / reverse_shape are handed back to the gather router,
        # presumably so it can restore the original row order/shape.
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        # One dedicated expert per route (route i -> expert{i+1}).
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x_2 = self.expert3(route_results[2])
        x = self.gather_router([x_0, x_1, x_2], reverse_indice, reverse_shape)
        return x


@brt.netlet
class MoEModel(nn.Module):
    """Netlet that owns a single ``MoE`` submodule and forwards its input to it."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        # Pure delegation: the MoE block does all the routing work.
        moe_out = self.moe(x)
        return moe_out


@brt.domain
class Model(nn.Module):
    """Top-level brt domain that owns a single ``MoEModel`` and forwards to it."""

    def __init__(self):
        super().__init__()
        self.moe_model = MoEModel()

    def forward(self, x):
        # Pure delegation to the wrapped netlet.
        routed = self.moe_model(x)
        return routed


moe_model = Model()

# Sample batch: two identical rows of 10 features, matching the experts' nn.Linear(10, 10).
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

# Run brt's optimize() pass on the domain before compiling it with TorchScript.
moe_model.optimize()

script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.inlined_graph)

# The scripted graph contains brt:: custom ops, so the "brt" opset version must be declared.
torch.onnx.export(
    script_moe_model, x, "moe.onnx", custom_opsets={"brt": 2}, opset_version=14
)


graph(%self : __torch__.___torch_mangle_40.Model,
      %x.1 : Tensor):
  %moe_model : __torch__.___torch_mangle_39.MoEModel = prim::GetAttr[name="moe_model"](%self)
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_2363175/1213982270.py:21:41
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_2363175/1213982270.py:22:41
  %6 : int = prim::Constant[value=2]() # /tmp/ipykernel_2363175/1213982270.py:23:41
  %7 : Function = prim::Constant[name="linear"]()
  %moe : __torch__.___torch_mangle_38.MoE = prim::GetAttr[name="moe"](%moe_model)
  %scatter_router : __torch__.brt.router.scatter.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%moe)
  %route_num.1 : int = prim::GetAttr[name="route_num"](%scatter_router)
  %11 : Tensor[], %12 : Tensor[], %13 : Tensor = brt::symbolic_scatter_route(%x.1, %4, %route_num.1) # /home/whcui/brainstorm_project/brainstorm/python/brt/router/scatter.py:97:15
  %14 : NamedTuple(_0 : Tensor[], _1 : Tensor[], _2 : Tensor) = prim::TupleConstruct



In [3]:
from brt.common import find_lib_path
from typing import List, Tuple
import torch
import torch.nn as nn
import brt.backend.internal.symbolic

# Locate the brt TorchScript extension and load it so torch.ops.brt.* becomes available.
# find_lib_path appears to return a sequence of candidate paths; the first one is loaded.
libpath = find_lib_path("libbrt_torchscript.so")
torch.ops.load_library(libpath[0])


class Model(nn.Module):
    """Minimal wrapper around ``torch.ops.brt.symbolic_scatter_route`` for scripting/export.

    Args:
        y: forwarded as the op's second positional argument (``0`` in this notebook).
        z: forwarded as the op's third positional argument; presumably the number
            of routes — TODO confirm against the brt op schema.
    """

    def __init__(self, y, z):
        super().__init__()
        self.y = y
        self.z = z

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scatter_out = torch.ops.brt.symbolic_scatter_route(x, self.y, self.z)
        route_results, route_indices, reverse_shape = scatter_out
        # The op yields lists of tensors for the first two outputs; stack each list
        # along a new leading dim so the module returns plain tensors.
        return (
            torch.stack(route_results, dim=0),
            torch.stack(route_indices, dim=0),
            reverse_shape,
        )


# Sample input: two identical rows of 10 features.
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

model = Model(0, 2)
script_model = torch.jit.script(model)

# Declare the "brt" custom opset so the exporter accepts the brt:: op in the graph.
torch.onnx.export(
    script_model,
    x,
    "scatter_route.onnx",
    custom_opsets={"brt": 2},
    opset_version=14,
)


[tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]), tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]), tensor([1, 1])]


