In [1]:
import torch
from brt.common import log
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter
from brt.frontend import build_graph, flatten_model_graph
# DEBUG logging for the brt frontend reports each module as it is built
# (e.g. "building brt.router RandomScatterRouter, m_attrs: {...}" in the
# cell output). Uncomment the lines below to also log the backend / IR stages.
log.set_level("frontend", "DEBUG")
# log.set_level("backend", "DEBUG")
# log.set_level("ir", "DEBUG")


@brt.netlet
class MoE(nn.Module):
    """Two-expert mixture-of-experts netlet built from brt routers.

    A ``RandomScatterRouter`` splits the input into two route results, each
    fed to its own ``Linear(10, 10)`` expert; a ``RandomGatherRouter`` then
    receives the expert outputs together with the ``reverse_indice`` /
    ``reverse_shape`` produced by the scatter step (presumably to restore the
    original layout -- TODO confirm against brt.router).

    ``forward`` must stay TorchScript-compatible: it is compiled when the
    enclosing ``MoEModel`` is passed to ``torch.jit.script``.
    """

    def __init__(self):
        super().__init__()
        # route_num=2 matches the two experts below (route_results[0], [1]).
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        # Scatter returns the per-route inputs plus metadata that is handed
        # unchanged to the gather router below.
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])  # route 0 -> expert1
        x_1 = self.expert2(route_results[1])  # route 1 -> expert2
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x


@brt.netlet
class Normal(nn.Module):
    """Dense counterpart of ``MoE``: two ``Linear(10, 10)`` layers applied in
    sequence to the whole input, with no routing."""

    def __init__(self):
        super().__init__()
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)

    def forward(self, x):
        # Every row goes through both layers, expert1 then expert2.
        x = self.expert1(x)
        x = self.expert2(x)
        return x


@brt.domain
class MoEModel(nn.Module):
    """Top-level brt domain wrapping a single ``MoE`` netlet.

    This is the model that is scripted with ``torch.jit.script`` and handed
    to ``build_graph`` in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Attribute name "moe" becomes the "_moe__" prefix of the flattened
        # submodule names seen in the generated script.
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


@brt.domain
class NormalModel(nn.Module):
    """Top-level brt domain wrapping a single dense ``Normal`` netlet
    (the no-routing counterpart of ``MoEModel``)."""

    def __init__(self):
        super().__init__()
        self.normal = Normal()

    def forward(self, x):
        return self.normal(x)


# Seed torch's global RNG so the Linear weight init is reproducible under
# Restart & Run All (the Random* routers presumably draw from the same RNG --
# TODO confirm against brt.router).
torch.manual_seed(0)

# NOTE(review): normal_model is not used by the visible cells; kept in case
# later cells compare against the dense baseline.
normal_model = NormalModel()
moe_model = MoEModel()

# TorchScript graph of the MoE netlet -- the representation the brt frontend
# works from. Inspect with `print(sm_graph)` or iterate `sm_graph.nodes()`.
script_moe_model = torch.jit.script(moe_model)
sm_graph = script_moe_model.moe.graph

# Build brt IR for the domain model, then flatten it so the netlet's
# submodules live at the top level (they appear with a "_moe__" prefix in
# the script generated from this graph).
ir_moe_model = build_graph(moe_model)
flattened_ir_moe_model = flatten_model_graph(ir_moe_model)


setting logger for brainstorm.frontend to DEBUG level
[2022-04-27 12:47:38] DEBUG (brainstorm.frontend/MainThread) building brt.router RandomScatterRouter, m_attrs: {'route_num': 2}
[2022-04-27 12:47:38] DEBUG (brainstorm.frontend/MainThread) building brt.router RandomGatherRouter, m_attrs: {'route_num': 2}


In [5]:
from brt.backend.pytorch import model_to_script

# Render the flattened IR back into PyTorch source. print() (rather than a
# bare expression) so the generated code's newlines render.
# NOTE(review): in the captured output the generated forward calls
# `self._moe__scatter_router()` with no argument, so the emitted script never
# feeds the model input to the router; the hand-written `m_model` cell passes
# `x__1` explicitly. Looks like a backend codegen bug -- TODO confirm/report.
model_script = model_to_script(flattened_ir_moe_model)
print(model_script)

---------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import torch
import brt


class _model(nn.Module):
    def __init__(self):
        super().__init__()
        self._moe__scatter_router = brt.router.scatter_router.RandomScatterRouter(route_num=2)
        self._moe__expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._moe__expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._moe__gather_router = brt.router.gather_router.RandomGatherRouter(route_num=2)
        self._mapping_ = {'_moe__scatter_router': None, '_moe__expert2': None, '_moe__expert1': None, '_moe__gather_router': None}

    def forward(self, x__1):
        _moe__Constant2 = 0
        _moe__Constant3 = 1
        _moe__scatter_router = self._moe__scatter_router()
        _moe__TupleUnpack6 = _moe__scatter_router
        _moe__aten____getitem__10 = _moe__

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import brt.nn

import torch
import brt


class m_model(nn.Module):
    """Hand-written counterpart of the script emitted by ``model_to_script``.

    Submodule attribute names (and therefore state_dict keys) match the
    generated code, but ``forward`` actually passes the input to the scatter
    router.
    """

    def __init__(self):
        super().__init__()
        router_pkg = brt.router
        linear = torch.nn.modules.linear.Linear
        # Creation order matches the generated script on purpose: Linear draws
        # its initial weights from torch's RNG, so reordering would change the
        # sampled parameters.
        self._moe__scatter_router = router_pkg.scatter_router.RandomScatterRouter(route_num=2)
        self._moe__expert2 = linear(in_features=10, out_features=10)
        self._moe__expert1 = linear(in_features=10, out_features=10)
        self._moe__gather_router = router_pkg.gather_router.RandomGatherRouter(route_num=2)
        self._mapping_ = dict.fromkeys(
            ("_moe__scatter_router", "_moe__expert2", "_moe__expert1", "_moe__gather_router")
        )

    def forward(self, x__1):
        """Scatter ``x__1`` over the two experts and gather their outputs.

        The scatter result is indexed as (route_results, reverse_indice,
        reverse_shape); route 0 goes to expert1 and route 1 to expert2.
        """
        scattered = self._moe__scatter_router(x__1)
        routes = scattered[0]
        expert_outputs = [
            self._moe__expert1(routes[0]),
            self._moe__expert2(routes[1]),
        ]
        return self._moe__gather_router(expert_outputs, scattered[1], scattered[2])
    
# Seed torch's global RNG so the expert weights (and, assuming the Random*
# routers draw from the same RNG -- TODO confirm, the routing) repeat on rerun.
torch.manual_seed(0)

moe = m_model()
# Two identical rows [1, ..., 10] in float32, the same values/dtype the legacy
# torch.Tensor(list) constructor produced (torch.tensor on ints would be int64).
sample_input = torch.arange(1, 11, dtype=torch.float32).repeat(2, 1)
# Keep the input under its own name; `x` stays bound to the output for any
# later cells that use it.
x = moe(sample_input)
x

[tensor([[-1.1668,  4.0991, -2.0578,  4.3479,  5.3001, -0.1832, -1.3333,  9.3435,
         -6.1432, -0.5831]], grad_fn=<AddmmBackward0>), tensor([[ 3.9222,  1.4431, -0.2387,  1.0776,  0.0810,  1.5783,  1.2412,  0.8928,
          1.4979, -1.1245]], grad_fn=<AddmmBackward0>)]
