In [1]:
import torch
from brt.common import log
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter
from brt.frontend import build_graph, flatten_model_graph
log.set_level("frontend", "DEBUG")
log.set_level("backend", "DEBUG")
log.set_level("ir", "DEBUG")


@brt.netlet
class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x

@brt.domain
class MoEModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


moe_model = MoEModel()
script_moe_model = torch.jit.script(moe_model)
sm_graph = script_moe_model.moe.graph
# print(sm_graph)
# for node in sm_graph.nodes():
#     if node.kind() == "prim::PythonOp":
#         print(node, node.inputsAt(0).debugName())
    # print(node.s("name"))
    # print(node.inputsAt(0).debugName())
    # _val = getattr(moe_model, node.s("name"))
    # print(_val)
    # print(type(_val))
# normal_model_ir = build_graph(normal_model)
ir_moe_model = build_graph(moe_model)
flattened_ir_moe_model = flatten_model_graph(ir_moe_model)


setting logger for brainstorm.frontend to DEBUG level
setting logger for brainstorm.backend to DEBUG level
setting logger for brainstorm.ir to DEBUG level
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: prim::GetAttr
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: prim::Constant
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: prim::Constant
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: prim::GetAttr
[2022-05-02 12:21:03] DEBUG (brainstorm.frontend/MainThread) building brt.router RandomScatterRouter, m_attrs: {'route_num': 2}
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: __torch__.brt.router.scatter_router.RandomScatterRouter
[2022-05-02 12:21:03] DEBUG (brainstorm.ir/MainThread) find subclass of pytorch operation: prim::TupleUnpack
[2022-05-02 12:21:03] DEBUG (brainstorm

In [2]:
from brt.backend.pytorch import model_to_script

model_script = model_to_script(ir_moe_model)
print(model_script)

[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) sorted_incoming_edges: []
[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) submodule_name: _Constant2, node_name: _Constant2, inputs: [], inputs_value: []
[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) sorted_incoming_edges: []
[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) submodule_name: _Constant3, node_name: _Constant3, inputs: [], inputs_value: []
[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) sorted_incoming_edges: [Edge(head=(Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs")), 0), tail=(Node(id=5, name=_model__moe__scatter_router, python_name=moe.scatter_router, label=None, operation=ModuleOperator(type="__torch__.brt.router.scatter_router.RandomScatterRouter", route_num=2)), None))]
[2022-05-02 12:21:48] DEBUG (brainstorm.backend/MainThread) all tail_slots are None: [None]
[2022-05-02 12:21:48] DEBUG (brainstorm.backend

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import brt.nn

import torch
import brt


class m_model(nn.Module):
    def __init__(self):
        super().__init__()
        self._moe__scatter_router = brt.router.scatter_router.RandomScatterRouter(route_num=2)
        self._moe__expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._moe__expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._moe__gather_router = brt.router.gather_router.RandomGatherRouter(route_num=2)
        self._mapping_ = {'_moe__scatter_router': None, '_moe__expert2': None, '_moe__expert1': None, '_moe__gather_router': None}

    def forward(self, x__1):
        _moe__Constant2 = 0
        _moe__Constant3 = 1
        _moe__scatter_router = self._moe__scatter_router(x__1)
        _moe__TupleUnpack6 = _moe__scatter_router
        _moe__aten____getitem__10 = _moe__TupleUnpack6[0][_moe__Constant3]
        _moe__aten____getitem__8 = _moe__TupleUnpack6[0][_moe__Constant2]
        _moe__expert2 = self._moe__expert2(_moe__aten____getitem__10)
        _moe__expert1 = self._moe__expert1(_moe__aten____getitem__8)
        _moe__ListConstruct12 = [_moe__expert1, _moe__expert2]
        _moe__gather_router = self._moe__gather_router(_moe__ListConstruct12, _moe__TupleUnpack6[1], _moe__TupleUnpack6[2])
        return _moe__gather_router
    
moe = m_model()
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
x = moe(x)

[tensor([[-1.1668,  4.0991, -2.0578,  4.3479,  5.3001, -0.1832, -1.3333,  9.3435,
         -6.1432, -0.5831]], grad_fn=<AddmmBackward0>), tensor([[ 3.9222,  1.4431, -0.2387,  1.0776,  0.0810,  1.5783,  1.2412,  0.8928,
          1.4979, -1.1245]], grad_fn=<AddmmBackward0>)]


In [36]:
import torch.nn as nn
import torch
from typing import List, Tuple
import inspect
import re

def fwd(x: int) -> Tuple[int, int, int]:
    return 1, 2, 3


fwd_sig = inspect.signature(nn.MultiheadAttention.forward)
x = str(fwd_sig.return_annotation)
print(x)
if x.startswith("typing.Tuple"):
    x = x[len("typing.Tuple") + 1 : -1]
    print
    s = re.split(r',\s*(?=[^]]*(?:[|$])', x) 
    ',(?=[^}]*(?:{|$))'
    print(s)
    ret_n = len(s)
else:
    ret_n = 1
print(ret_n)


typing.Tuple[torch.Tensor, typing.Union[torch.Tensor, NoneType]]


error: missing ), unterminated subpattern at position 4