In [1]:
import torch
from brt.common import logging
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter

# Switch brt's own logger to debug level for this session. `logging` here is
# brt.common.logging, not the standard-library module.
logging.set_level_to_debug()

@brt.netlet
class MoE(nn.Module):
    """Two-expert mixture-of-experts layer built on brt's random routers.

    A ``RandomScatterRouter`` produces one routed input per expert
    (``route_num=2``). Each routed input goes through its own
    ``Linear(10, 10)`` expert, and a ``RandomGatherRouter`` recombines
    the expert outputs.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: each Linear draws from the RNG when it is
        # initialised. The attribute names also matter, because the scripted
        # graph fetches them by name (prim::GetAttr).
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        """Scatter ``x`` to both experts and return their individual outputs.

        Returns:
            A tuple ``(expert1 output, expert2 output)``.
        """
        routed, reverse_idx, reverse_shp = self.scatter_router(x)
        out_first = self.expert1(routed[0])
        out_second = self.expert2(routed[1])
        # NOTE(review): the gathered tensor is computed but never returned.
        # forward hands back the per-expert outputs instead. Confirm whether
        # the gather result was meant to be the return value.
        gathered = self.gather_router(
            [out_first, out_second], reverse_idx, reverse_shp
        )
        return out_first, out_second


@brt.netlet
class MoEModel(nn.Module):
    """Container that holds a single ``MoE`` layer under ``self.moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Run ``x`` through the wrapped ``MoE`` layer and return its result as-is."""
        moe_out = self.moe(x)
        return moe_out

@brt.domain
class Model(nn.Module):
    """Top-level brt domain module wrapping a ``MoEModel`` as ``self.moe_model``."""

    def __init__(self):
        super().__init__()
        self.moe_model = MoEModel()

    def forward(self, x):
        """Forward ``x`` to the inner ``MoEModel`` and pass its output through."""
        inner_out = self.moe_model(x)
        return inner_out


moe_model = Model()

# Two identical rows holding 1..10, as a default-dtype (float) tensor of shape (2, 10).
x = torch.Tensor([list(range(1, 11)) for _ in range(2)])
# Everything runs on CPU here. For a GPU run, both the module and the input
# would need moving (e.g. moe_model.cuda() and x = x.cuda()). brt router
# support for that is not verified in this cell.

# NOTE(review): `y` is not used in this cell; kept in case a later cell reads it.
y = 10
# MoE.forward returns the per-expert outputs, so `z` is a tuple of tensors.
z = moe_model(x)
print(z)

# Compile with TorchScript and dump the graph with submodule calls inlined.
script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.inlined_graph)

[tensor([[ 6.6156,  3.2507, -1.6666,  3.4580, -0.5150,  0.5712,  1.2547,  7.4214,
          5.4138, -0.9039]], grad_fn=<AddmmBackward0>), tensor([[ 2.3959,  1.9059,  3.2039,  1.8616, -0.2561, -1.4001,  1.8435, -0.1045,
          1.6839, -2.8091]], grad_fn=<AddmmBackward0>)]
(tensor([[ 6.6156,  3.2507, -1.6666,  3.4580, -0.5150,  0.5712,  1.2547,  7.4214,
          5.4138, -0.9039]], grad_fn=<AddmmBackward0>), tensor([[ 2.3959,  1.9059,  3.2039,  1.8616, -0.2561, -1.4001,  1.8435, -0.1045,
          1.6839, -2.8091]], grad_fn=<AddmmBackward0>))
graph(%self : __torch__.Model,
      %x.1 : Tensor):
  %moe_model : __torch__.MoEModel = prim::GetAttr[name="moe_model"](%self)
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_3578182/4025828174.py:20:41
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_3578182/4025828174.py:21:41
  %6 : Function = prim::Constant[name="linear"]()
  %moe : __torch__.MoE = prim::GetAttr[name="moe"](%moe_model)
  %scatter_router : __torch__.brt.router.