In [1]:
import torch

import brt
import brt.nn as nn
from brt.backend.pytorch import model_to_script
from brt.common import log
from brt.frontend import build_graph
from brt.router import RandomScatterRouter, RandomGatherRouter

# Per-component log levels for brainstorm; uncomment an entry to trace that
# stage as well.
LOG_LEVELS = {
    "frontend": "DEBUG",
    # "backend": "DEBUG",
    # "ir": "DEBUG",
}
for _component, _level in LOG_LEVELS.items():
    log.set_level(_component, _level)


@brt.netlet
class MoE(nn.Module):
    """Two-expert mixture-of-experts block built from brainstorm routers.

    ``forward`` splits its input into two routes with ``RandomScatterRouter``,
    applies one ``Linear(10, 10)`` expert per route, and passes both expert
    outputs to ``RandomGatherRouter``, whose result is returned. Each routed
    tensor must therefore have a last dimension of 10.

    This notebook compiles the block with ``torch.jit.script`` (through
    ``MoEModel.moe``) and converts it with ``build_graph``, so ``forward``
    must stay TorchScript-compatible.
    """

    def __init__(self):
        super().__init__()
        # route_num=2 matches the two experts indexed in forward().
        self.scatter_router = RandomScatterRouter(route_num=2)
        # Keep this construction order: each Linear draws its initial
        # weights from the global RNG at creation time.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        # The scatter router returns a 3-tuple: the per-route inputs plus
        # bookkeeping (reverse_indice, reverse_shape) that is forwarded
        # unchanged to the gather router -- presumably so it can restore the
        # original order/shape of x; TODO confirm against brt.router.
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        # Expert outputs are listed in the same route order as route_results.
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x


@brt.netlet
class Normal(nn.Module):
    """Counterpart to ``MoE`` without routers: two chained Linear(10, 10) layers."""

    def __init__(self):
        super().__init__()
        # Registration order is expert1 then expert2; each Linear draws its
        # initial weights from the global RNG in that order.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)

    def forward(self, x):
        """Apply ``expert1`` and feed its result through ``expert2``."""
        hidden = self.expert1(x)
        return self.expert2(hidden)


@brt.domain
class MoEModel(nn.Module):
    """Top-level model holding a single ``MoE`` netlet under ``self.moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Delegate directly to the wrapped ``MoE`` block."""
        moe_out = self.moe(x)
        return moe_out


@brt.domain
class NormalModel(nn.Module):
    """Top-level model holding a single ``Normal`` netlet under ``self.normal``."""

    def __init__(self):
        super().__init__()
        self.normal = Normal()

    def forward(self, x):
        """Delegate directly to the wrapped ``Normal`` block."""
        normal_out = self.normal(x)
        return normal_out


# Seed torch's global RNG before building any modules so the Linear weight
# init is reproducible on Restart & Run All (and, presumably, the Random*
# routers too if they draw from torch's RNG -- confirm in brt.router).
SEED = 0
torch.manual_seed(SEED)

normal_model = NormalModel()
moe_model = MoEModel()

# Script the MoE model up front; sm_graph is the TorchScript graph of the
# MoE netlet, kept for inspection (e.g. print(sm_graph) and look for the
# prim::PythonOp nodes).
script_moe_model = torch.jit.script(moe_model)
sm_graph = script_moe_model.moe.graph

# Build the brainstorm IR for the MoE model. To build the router-free
# variant instead, pass normal_model to build_graph.
ir_moe_model = build_graph(moe_model)


setting logger for brainstorm.frontend to DEBUG level
[2022-04-26 23:02:40] DEBUG (brainstorm.frontend/MainThread) building brt.router RandomScatterRouter, m_attrs: {'route_num': 2}
[2022-04-26 23:02:40] DEBUG (brainstorm.frontend/MainThread) building brt.router RandomGatherRouter, m_attrs: {'route_num': 2}


In [2]:
# Regenerate PyTorch source from the brainstorm IR built in the previous cell
# and print it below a separator line. model_to_script is imported in the
# top-of-notebook import cell.
print("---------------------------------------------")
model_script = model_to_script(ir_moe_model)
print(model_script)

---------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import brt
import torch


class _model__moe(nn.Module):
    def __init__(self):
        super().__init__()
        self._scatter_router = brt.router.scatter_router.RandomScatterRouter(route_num=2)
        self._expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._gather_router = brt.router.gather_router.RandomGatherRouter(route_num=2)
        self._mapping_ = {'_scatter_router': 'moe.scatter_router', '_expert2': 'moe.expert2', '_expert1': 'moe.expert1', '_gather_router': 'moe.gather_router'}

    def forward(self, x__1):
        _Constant2 = 0
        _Constant3 = 1
        _scatter_router = self._scatter_router(x__1)
        _TupleUnpack6 = _scatter_router
        _aten____getitem__10 = _TupleUnpack6[0][_Const