In [None]:
import brt
import brt.nn as nn
import torch


@brt.netlet
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y: int):
        for i in range(y):
            x = self.linear1(x)
            x = self.linear2(x)
        return x


simple_net = SimpleNet()

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = simple_net(x, 10)
print(x)

simple_net.brt(True)
script_simple_net = torch.jit.script(simple_net)
simple_net_inlined_graph = script_simple_net.inlined_graph
print(simple_net_inlined_graph)


In [9]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x, y: int):
        for i in range(y):
            route_results, reverse_indice, origin_shape = self.scatter_router(x)
            x_0 = self.expert1(route_results[0])
            x_1 = self.expert2(route_results[1])
            x = self.gather_router(reverse_indice, origin_shape, x_0, x_1)
        return x


moe = MoE()

x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
y = 10
z = moe(x, y)

print(z)


[tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]]), tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])]
[tensor([[ 3.3947e+00,  3.9147e+00, -3.7942e+00, -3.2952e+00,  8.1701e-01,
          2.0672e+00,  6.8571e+00,  5.6515e+00, -6.8027e-01, -1.8900e+00],
        [ 1.6006e+00,  4.6057e+00, -3.5487e+00, -9.4125e+00, -5.0309e-03,
         -5.7639e-01, -6.2666e-01,  2.1591e-01, -1.9878e+00, -2.3471e-02]],
       grad_fn=<StackBackward0>), None]
[tensor([[-3.1959, -0.9272, -1.1972,  1.8145,  3.0189, -0.2120, -0.2599, -2.8027,
         -2.3451,  1.8039]], grad_fn=<StackBackward0>), tensor([[-3.8645,  2.4000,  0.0152,  1.3356,  2.0964,  2.0406,  2.7486,  0.4004,
          0.8766, -3.3796]], grad_fn=<StackBackward0>)]
[tensor([[ 0.1495, -1.3255, -1.5483,  0.7197,  0.8409, -0.1732, -0.1724,  2.8066,
          0.4115,  0.2538]], grad_fn=<StackBackward0>), tensor([[ 1.3291,  1.2626,  0.3755, -1.3392, -1.1763, -2.0587, -0.0837,  1.5242,
          2.1446, -2.1055]], grad_fn=<Sta

In [10]:
moe.brt(True)
script_moe = torch.jit.script(moe)

print(script_moe.inlined_graph)

graph(%self : __torch__.brt.netlet.___torch_mangle_1.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:17:8
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_2208925/1408058768.py:19:45
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:20:45
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_2208925/1408058768.py:17:8
    block0(%i : int, %x.11 : Tensor):
      %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%self)
      %10 : (Tensor, Tensor, Tensor) = ^forward()(%scatter_router, %x.11) # /tmp/ipykernel_2208925/1408058768.py:18:58
      %route_results.1 : Tensor, %reverse_indice.1 : Tensor, %origin_shape.1 : Tensor = prim::TupleUnpack(%10)
      %expert1 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="expert1"](%self)
      %15 : Tensor = aten::select(%route_results.1, %4, %4) # /tmp/ipykernel_220

In [11]:
moe.brt(True)
script_moe = torch.jit.script(moe)

print(script_moe.inlined_graph)

graph(%self : __torch__.brt.netlet.___torch_mangle_1.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:17:8
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_2208925/1408058768.py:19:45
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:20:45
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_2208925/1408058768.py:17:8
    block0(%i : int, %x.11 : Tensor):
      %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%self)
      %10 : (Tensor, Tensor, Tensor) = ^forward()(%scatter_router, %x.11) # /tmp/ipykernel_2208925/1408058768.py:18:58
      %route_results.1 : Tensor, %reverse_indice.1 : Tensor, %origin_shape.1 : Tensor = prim::TupleUnpack(%10)
      %expert1 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="expert1"](%self)
      %15 : Tensor = aten::select(%route_results.1, %4, %4) # /tmp/ipykernel_220