In [1]:
import torch
from brt.common import log
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter
from brt.frontend import build_graph

# log.set_level("frontend", "DEBUG")
# log.set_level("backend", "DEBUG")


@brt.netlet
class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x


@brt.netlet
class Normal(nn.Module):
    def __init__(self):
        super().__init__()
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.expert1(x)
        x = self.expert2(x)
        return x


@brt.domain
class MoEModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


@brt.domain
class NormalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.normal = Normal()

    def forward(self, x):
        return self.normal(x)


normal_model = NormalModel()
moe_model = MoEModel()
# normal_model_ir = build_graph(normal_model)
moe_model_ir = build_graph(moe_model)
for graph_name, graph in moe_model_ir.graphs.items():
    for node in graph.topo_sort():
        print(node)


graph(%self : __torch__.MoEModel,
      %x.1 : Tensor):
  %moe : __torch__.MoE = prim::GetAttr[name="moe"](%self)
  %4 : Function = prim::Constant[name="linear"]()
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_434546/3540491993.py:24:41
  %6 : int = prim::Constant[value=0]() # /tmp/ipykernel_434546/3540491993.py:23:41
  %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%moe)
  %8 : (Tensor[], Tensor[], Tensor) = ^forward()(%scatter_router, %x.1) # /tmp/ipykernel_434546/3540491993.py:22:55
  %route_results.1 : Tensor[], %reverse_indice.1 : Tensor[], %reverse_shape.1 : Tensor = prim::TupleUnpack(%8)
  %expert1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="expert1"](%moe)
  %13 : Tensor = aten::__getitem__(%route_results.1, %6) # /tmp/ipykernel_434546/3540491993.py:23:27
  %weight.1 : Tensor = prim::GetAttr[name="weight"](%expert1)
  %bias.1 : Tensor = prim::GetAttr[name="bias"](%expert1)
  %x_0.1 

In [11]:
from brt.backend.pytorch import model_to_script

model_script = model_to_script(normal_model_ir)
print("---------------------------------------------")
model_script = model_to_script(moe_model_ir)

[2022-04-25 16:11:13] DEBUG (brainstorm.backend/MainThread) sorted_incoming_edges: [Edge(head=(Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs")), 0), tail=(Node(id=57, name=_model__normal__expert1, python_name=normal.expert1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10)), None))]
[2022-04-25 16:11:13] DEBUG (brainstorm.backend/MainThread) all tail_slots are None: [None]
[2022-04-25 16:11:13] DEBUG (brainstorm.backend/MainThread) submodule_name: _expert1, node_name: _expert1, inputs: ['x__1'], inputs_value: [None]
[2022-04-25 16:11:13] DEBUG (brainstorm.backend/MainThread) sorted_incoming_edges: [Edge(head=(Node(id=57, name=_model__normal__expert1, python_name=normal.expert1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10)), None), tail=(Node(id=59, name=_model__normal__expert2, python_name=no

RuntimeError: unsupported operation type: prim::PythonOp ? None