In [1]:
import brt
import brt.nn as nn
import torch


@brt.netlet
class SimpleNet(nn.Module):
    """Two stacked ``nn.Linear(10, 10)`` layers, decorated with ``brt.netlet``."""

    def __init__(self):
        super().__init__()
        # Created in the same order in which ``forward`` applies them.
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x):
        """Feed ``x`` through ``linear1`` and then ``linear2``; return the result."""
        hidden = self.linear1(x)
        return self.linear2(hidden)


@brt.domain
class SimpleNet2(nn.Module):
    """Wrapper decorated with ``brt.domain`` that owns a single ``SimpleNet``."""

    def __init__(self):
        super().__init__()
        self.net = SimpleNet()

    def forward(self, x):
        """Delegate to the wrapped ``SimpleNet`` and return its output as-is."""
        return self.net(x)

# Instantiate the bare netlet and the domain-decorated wrapper around it.
simple_net = SimpleNet()
simple_net2 = SimpleNet2()

# Disabled: scripting the bare netlet on its own.
# script_simple_net = torch.jit.script(simple_net)

# Only the wrapper is compiled with TorchScript.
script_simple_net2 = torch.jit.script(simple_net2)

# Disabled debugging aid: run the scripted wrapper on a 10-element input and
# print the inlined graph of the wrapper and of its inner ``net`` submodule.
# x = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# print(script_simple_net2(x))
# print("-----------------------------------------------------")
# print(script_simple_net2.inlined_graph)
# print("-----------------------------------------------------")
# print(script_simple_net2.net.inlined_graph)
# print("-----------------------------------------------------")


In [2]:
from brt.frontend import build_graph
from brt.backend.pytorch import model_to_script
# Build the brt IR from the wrapper module instance (``simple_net2``), not from
# the TorchScript-compiled ``script_simple_net2``.
model_ir = build_graph(simple_net2)

# Print every node returned by ``get_nodes()``.
for node in model_ir.get_nodes():
    print(node)

print("-----------------")

# Print the nodes returned by ``get_cell_nodes()``.
for cell_node in model_ir.get_cell_nodes():
    print(cell_node)

# Turn the IR back into source text and show it.
# NOTE(review): the recorded output appears to contain only import lines --
# confirm whether the generated script is expected to be this short.
model_code = model_to_script(model_ir)
print(model_code)

Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=3, name=_model__net__linear1, python_name=net.linear1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=5, name=_model__net__linear2, python_name=net.linear2, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
-----------------
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
import torch
import torch.nn

In [1]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    """Scatter router, two ``nn.Linear(10, 10)`` experts, and a gather router.

    NOTE(review): ``forward`` invokes the gather router but returns the two
    per-expert outputs rather than the gathered tensor -- confirm this is the
    intended behavior and not a debugging leftover.
    """

    def __init__(self):
        super().__init__()
        # Both routers are configured for two routes, one per expert.
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        """Scatter ``x``, run each expert on its slice, and return ``(x_0, x_1)``."""
        # The scatter router returns three values; the first is indexed per
        # expert, the other two are passed back to the gather router.
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        # NOTE(review): the gathered result bound to ``x`` is never used below.
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x_0, x_1


@brt.netlet
class MoEModel(nn.Module):
    """Thin wrapper decorated with ``brt.netlet`` that holds one ``MoE`` layer."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Run the wrapped ``MoE`` layer on ``x`` and pass its result through."""
        moe_output = self.moe(x)
        return moe_output

@brt.domain
class Model(nn.Module):
    """Top-level module decorated with ``brt.domain`` that wraps ``MoEModel``."""

    def __init__(self):
        super().__init__()
        self.moe_model = MoEModel()

    def forward(self, x):
        """Forward ``x`` to the wrapped ``MoEModel`` and return what it produces."""
        result = self.moe_model(x)
        return result


# Instantiate the domain-level model (Model -> MoEModel -> MoE).
moe_model = Model()

# Input batch: two identical rows of ten values.
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# Disabled: move the model to the GPU (the instance is named ``moe_model``).
# moe_model.cuda()
# NOTE(review): ``y`` is not referenced in any visible cell -- candidate for removal.
y = 10
# ``z`` is the pair of per-expert outputs returned by ``MoE.forward``.
z = moe_model(x)
print(z)

# Compile the model with TorchScript and show the top-level graph.
script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.graph)


tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
[tensor([[-0.5677,  1.4194,  4.3959, -2.1718, -2.7670,  6.5326,  1.7647, -1.5553,
         -5.6792,  3.9044],
        [-0.5677,  1.4194,  4.3959, -2.1718, -2.7670,  6.5326,  1.7647, -1.5553,
         -5.6792,  3.9044]], grad_fn=<AddmmBackward0>), tensor([], size=(0, 0))]
(tensor([[-0.5677,  1.4194,  4.3959, -2.1718, -2.7670,  6.5326,  1.7647, -1.5553,
         -5.6792,  3.9044],
        [-0.5677,  1.4194,  4.3959, -2.1718, -2.7670,  6.5326,  1.7647, -1.5553,
         -5.6792,  3.9044]], grad_fn=<AddmmBackward0>), tensor([], size=(0, 0)))
graph(%self : __torch__.Model,
      %x.1 : Tensor):
  %moe_model : __torch__.MoEModel = prim::GetAttr[name="moe_model"](%self)
  %4 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%moe_model, %x.1) # /tmp/ipykernel_4014/1649510055.py:40:15
  return (%4)



In [2]:
from brt.frontend.builder import build_graph
from brt.backend.pytorch import model_to_script
# Build the brt IR from the module instance (``moe_model``), not from the
# TorchScript-compiled ``script_moe_model``.
model_ir = build_graph(moe_model)

# Print every node returned by ``get_nodes()``.
for node in model_ir.get_nodes():
    print(node)

# Turn the IR back into source text; the result is kept but not printed here.
model_code = model_to_script(model_ir)


Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=3, name=_model__moe_model__moe__Constant3, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=0))
Node(id=4, name=_model__moe_model__moe__Constant4, python_name=None, label=None, operation=PrimConstant(type="prim::Constant", type='int', value=1))
Node(id=6, name=_model__moe_model__moe__scatter_router, python_name=moe_model.moe.scatter_router, label=None, operation=ModuleOperator(type="__torch__.brt.router.scatter_router.RandomScatterRouter", route_num=2))
Node(id=7, name=_model__moe_model__moe__TupleUnpack7, python_name=None, label=None, operation=PrimTupleUnpack(type="prim::TupleUnpack"))
Node(id=9, name=_model__moe_model__moe__aten____getitem__9, python_name=moe_model.moe.__getitem__, label=None, operation=AtenGetitem(type="aten::_