In [1]:
import brt
import brt.nn as nn
import torch


@brt.netlet
class SimpleNet(nn.Module):
    """Two stacked 10->10 linear layers, decorated as a brt netlet.

    Instantiated below as the ``net`` child of ``SimpleNet2``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)
    
    def forward(self, x):
        # Kept deliberately simple: this body is compiled by torch.jit.script and
        # walked by brt's build_graph, so its variable names surface in the IR.
        # The last dimension of `x` must be 10 (in_features of linear1).
        x = self.linear1(x)
        x = self.linear2(x)
        return x


@brt.domain
class SimpleNet2(nn.Module):
    """Domain-level wrapper whose only child is a ``SimpleNet`` stored as ``net``.

    This is the model that gets scripted and lowered to IR in the cells below.
    """

    def __init__(self):
        super().__init__()
        self.net = SimpleNet()
    
    def forward(self, x):
        # Pure delegation to the wrapped netlet.
        x = self.net(x)
        return x

# Instantiate the standalone netlet and the domain model that wraps it.
# NOTE(review): `simple_net` is not used by the visible cells, and scripting it
# directly was previously commented out — presumably only the @brt.domain-wrapped
# model is meant to be scripted; confirm before scripting a bare netlet.
simple_net = SimpleNet()
simple_net2 = SimpleNet2()

# Compile the domain model with TorchScript. To inspect the compiled result,
# print `script_simple_net2.inlined_graph` (whole model) or
# `script_simple_net2.net.inlined_graph` (the wrapped netlet only).
script_simple_net2 = torch.jit.script(simple_net2)


In [2]:
from brt.frontend import build_graph
from brt.backend.pytorch import model_to_script
# Lower the SimpleNet2 domain model into brt's graph IR.
model_ir = build_graph(simple_net2)

# Dump every IR node. The captured output shows two groups, each with its own
# _inputs/_outputs pseudo-nodes: one holding the two linear layers and one
# holding the `net` cell.
for node in model_ir.get_nodes():
    print(node)

print("-----------------")

# Dump only the cell nodes (here, the wrapped `net` sub-module).
for cell_node in model_ir.get_cell_nodes():
    print(cell_node)

# Generate PyTorch source from the IR and show it.
# NOTE(review): for this model the printed source contains only import lines and
# no module class — confirm whether model_to_script should emit a class here.
model_code = model_to_script(model_ir)
print(model_code)

Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=3, name=_model__net__linear1, python_name=net.linear1, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=5, name=_model__net__linear2, python_name=net.linear2, label=None, operation=ModuleOperator(type="__torch__.torch.nn.modules.linear.Linear", in_features=10, out_features=10))
Node(id=-1, name=_inputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_inputs"))
Node(id=-2, name=_outputs, python_name=None, label=None, operation=_IOPseudoOperation(type="_outputs"))
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
-----------------
Node(id=6, name=_model__net, python_name=net, label=None, operation=Cell(type="_cell"))
import torch
import torch.nn

In [8]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter

@brt.netlet
class MoE(nn.Module):
    """Two-expert mixture-of-experts netlet built from brt routers.

    ``scatter_router`` (a RandomScatterRouter) splits the input into two routes,
    each route is fed to its own 10->10 linear expert, and ``gather_router``
    (a RandomGatherRouter) combines the two expert outputs.
    """

    def __init__(self):
        super().__init__()
        # Both routers use route_num=2, matching the two experts that forward()
        # indexes (routes 0 and 1).
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)
    
    def forward(self, x):
        # In the scripted graph the scatter router yields
        # (Tensor[] route_results, Tensor[] reverse_indice, Tensor reverse_shape).
        route_results, reverse_indice, reverse_shape = self.scatter_router(x)
        # Route 0 goes to expert1, route 1 to expert2.
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        # NOTE(review): reverse_indice / reverse_shape presumably let the gather
        # router restore the original ordering and shape of `x` — confirm.
        x = self.gather_router([x_0, x_1], reverse_indice, reverse_shape)
        return x


@brt.domain
class MoEModel(nn.Module):
    """Domain-level wrapper whose only child is a ``MoE`` netlet stored as ``moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        # Pure delegation to the wrapped MoE netlet.
        return self.moe(x)

# Seed torch's global RNG so the experts' nn.Linear weight initialisation (and,
# presumably, any random choices made by the Random*Router objects) is
# reproducible across Restart & Run All.
torch.manual_seed(0)
moe_model = MoEModel()

# NOTE(review): after optimize() the scripted graph calls
# brt::symbolic_scatter_route, so this appears to switch the routers to their
# symbolic form — confirm against brt's documentation.
moe_model.optimize()

# Compile with TorchScript and show the fully inlined graph.
script_moe_model = torch.jit.script(moe_model)
print(script_moe_model.inlined_graph)


graph(%self : __torch__.___torch_mangle_32.MoEModel,
      %x.1 : Tensor):
  %moe : __torch__.___torch_mangle_31.MoE = prim::GetAttr[name="moe"](%self)
  %4 : Function = prim::Constant[name="linear"]()
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_1337437/314304295.py:19:41
  %6 : int = prim::Constant[value=0]() # /tmp/ipykernel_1337437/314304295.py:18:41
  %scatter_router : __torch__.brt.router.scatter.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%moe)
  %route_num.1 : int = prim::GetAttr[name="route_num"](%scatter_router)
  %9 : Tensor[], %10 : Tensor[], %11 : Tensor = brt::symbolic_scatter_route(%x.1, %6, %route_num.1) # /home/whcui/brainstorm_project/brainstorm/python/brt/router/scatter.py:97:15
  %12 : NamedTuple(_0 : Tensor[], _1 : Tensor[], _2 : Tensor) = prim::TupleConstruct(%9, %10, %11)
  %route_results.1 : Tensor[], %reverse_indice.1 : Tensor[], %reverse_shape.1 : Tensor = prim::TupleUnpack(%12)
  %expert1 : __torch__.torch.nn.modules.linear.Linear = 

In [9]:
from brt.frontend.builder import build_graph
from brt.backend.pytorch import model_to_script
# Lower the optimized MoE domain model into brt's graph IR.
model_ir = build_graph(moe_model)

# Generate PyTorch source from the IR; the printed module shows how the routers
# and experts were lowered. For a node-by-node view, iterate
# `model_ir.get_nodes()` as in the SimpleNet2 example above.
model_code = model_to_script(model_ir)
print(model_code)


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import torch
import brt


class _model__moe(nn.Module):
    def __init__(self):
        super().__init__()
        self._scatter_router = brt.router.scatter.RandomScatterRouter(route_num=2)
        self._expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._gather_router = brt.router.gather.RandomGatherRouter(route_num=2)
        self._mapping_ = {'_scatter_router': 'moe.scatter_router', '_expert1': 'moe.expert1', '_expert2': 'moe.expert2', '_gather_router': 'moe.gather_router'}

    def forward(self, x__1):
        _Constant2 = 0
        _Constant3 = 1
        _scatter_router = self._scatter_router(x__1)
        _TupleUnpack5 = _scatter_router
        _aten____getitem__7 = _TupleUnpack5[0][_Constant2]
        _aten____getitem__9 = _TupleUnpack5[0][_Constan