In [1]:
import brt
import brt.nn as nn
import torch


@brt.netlet
class SimpleNet(nn.Module):
    """Netlet that chains two 10-in / 10-out linear layers."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x):
        """Pass ``x`` through ``linear1`` and then ``linear2``; return the result."""
        hidden = self.linear1(x)
        return self.linear2(hidden)


@brt.domain
class SimpleNet2(nn.Module):
    """Domain-level module whose only child is a ``SimpleNet`` netlet."""

    def __init__(self):
        super().__init__()
        self.net = SimpleNet()

    def forward(self, x):
        """Delegate to the wrapped ``SimpleNet`` and return its output unchanged."""
        out = self.net(x)
        return out

# Instantiate both modules. `simple_net2` is read again by the next cell
# (build_graph); `simple_net` is only referenced by the disabled line below.
simple_net = SimpleNet()
simple_net2 = SimpleNet2()

# Disabled: scripting the netlet on its own.
# script_simple_net = torch.jit.script(simple_net)

# Compile the domain module (and, through it, the nested netlet) to TorchScript.
script_simple_net2 = torch.jit.script(simple_net2)

# Disabled debugging aids: run the scripted module on a sample input and dump
# the inlined graphs of the domain module and of its nested netlet.
# x = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# print(script_simple_net2(x))
# print("-----------------------------------------------------")
# print(script_simple_net2.inlined_graph)
# print("-----------------------------------------------------")
# print(script_simple_net2.net.inlined_graph)
# print("-----------------------------------------------------")


In [2]:
from brt.frontend import build_graph
from brt.backend.pytorch import model_to_script
# Build the brt graph IR from the eager (un-scripted) domain module defined in
# the previous cell.
model_ir = build_graph(simple_net2)

# Dump every node returned by the IR's get_nodes().
for node in model_ir.get_nodes():
    print(node)

# Visual separator between the two node listings.
print("-----------------")

# Dump only the nodes returned by get_cell_nodes().
for cell_node in model_ir.get_cell_nodes():
    print(cell_node)

# NOTE(review): model_to_script presumably emits Python source for the IR
# (the captured output starts with import lines) -- confirm against brt.backend.
model_code = model_to_script(model_ir)
print(model_code)

Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=3, name=_model__net__linear1, python_name=net.linear1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=5, name=_model__net__linear2, python_name=net.linear2, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
-----------------
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
import torch
import torch.nn

In [1]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    """Netlet with two 10->10 linear experts between a scatter and a gather router."""

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is: scatter, expert1, expert2, gather.
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        """Scatter ``x``, run one expert per route, and return both expert outputs.

        Returns:
            A 2-tuple ``(expert1_output, expert2_output)``.
        """
        routed, reverse_indice, reverse_shape = self.scatter_router(x)
        out_0 = self.expert1(routed[0])
        out_1 = self.expert2(routed[1])
        # NOTE(review): the gather router is invoked but its result is discarded;
        # the per-expert outputs are returned instead. Confirm whether the
        # gathered tensor was meant to be the return value.
        gathered = self.gather_router([out_0, out_1], reverse_indice, reverse_shape)
        return out_0, out_1


@brt.domain
class MoEModel(nn.Module):
    """Domain-level module that owns a single ``MoE`` netlet."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Forward ``x`` to the ``MoE`` netlet and return whatever it returns."""
        result = self.moe(x)
        return result

# Instantiate the domain model; the next cell passes it to build_graph.
moe_model = MoEModel()

# Two identical rows of ten features each, matching the experts' in_features=10.
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# NOTE(review): `y` is not referenced anywhere in the visible cells.
y = 10
# Eager forward pass; MoE.forward returns the two expert outputs as a tuple.
z = moe_model(x)
print(z)

moe_model.optimize()
# NOTE(review): this prints the bound method object itself (the captured output
# is "<bound method RandomScatterRouter.symbolic_route ...>"); symbolic_route is
# not called here -- confirm whether that is intended.
print(moe_model.moe.scatter_router.symbolic_route)

# Compile the model to TorchScript and dump its fully inlined graph.
script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.inlined_graph)


tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
[tensor([[-0.2873, -0.5483, -3.2447,  0.9373,  5.4349,  4.4751, -3.5992, -2.3106,
         -7.5724,  6.2787]], grad_fn=<AddmmBackward0>), tensor([[ 6.4272,  1.7441, -3.1521,  0.5942,  1.8151,  3.6928, -3.8299,  4.1981,
         -0.0100,  3.8856]], grad_fn=<AddmmBackward0>)]
(tensor([[-0.2873, -0.5483, -3.2447,  0.9373,  5.4349,  4.4751, -3.5992, -2.3106,
         -7.5724,  6.2787]], grad_fn=<AddmmBackward0>), tensor([[ 6.4272,  1.7441, -3.1521,  0.5942,  1.8151,  3.6928, -3.8299,  4.1981,
         -0.0100,  3.8856]], grad_fn=<AddmmBackward0>))
<bound method RandomScatterRouter.symbolic_route of RandomScatterRouter()>
graph(%self : __torch__.MoEModel,
      %x.1 : Tensor):
  %moe : __torch__.MoE = prim::GetAttr[name="moe"](%self)
  %4 : int = prim::Constant[value=1]() # /tmp/ipykernel_16103/2413130693.py:19:41
  %5 : int = prim::Constant[value=0]() # /tmp/ipykernel_16103/2413130693.py:18:41
  %scatter_router : __torch__.brt.rout

In [2]:
from brt.frontend.builder import build_graph
from brt.backend.pytorch import model_to_script
# Build the brt graph IR from the MoE domain model created in the previous cell.
model_ir = build_graph(moe_model)

# Dump every node returned by the IR's get_nodes().
for node in model_ir.get_nodes():
    print(node)

# NOTE(review): model_to_script presumably emits Python source for the IR --
# confirm against brt.backend.pytorch.
model_code = model_to_script(model_ir)
print(model_code)


Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=2, name=_model__moe__Constant2, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=0))
Node(id=3, name=_model__moe__Constant3, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=1))
Node(id=5, name=_model__moe__scatter_router, python_name=moe.scatter_router, label=None, operation=ModuleOperator(type="__torch__.brt.router.scatter_router.RandomScatterRouter", route_num=2))
Node(id=6, name=_model__moe__TupleUnpack5, python_name=None, label=None, operation=PrimTupleUnpack(type="prim::TupleUnpack"))
Node(id=8, name=_model__moe__aten____getitem__7, python_name=moe.__getitem__, label=None, operation=AtenGetitem(type="aten::__getitem__"))
Node(id=9, name=_model__moe__expert1, python_name=moe.expert1