In [2]:
import brt
import brt.nn as nn
import torch


@brt.netlet
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y: int):
        for i in range(y):
            x = self.linear1(x)
            x = self.linear2(x)
        return x


simple_net = SimpleNet()

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = simple_net(x, 10)
print(x)

simple_net.brt(True)
script_simple_net = torch.jit.script(simple_net)
simple_net_inlined_graph = script_simple_net.inlined_graph
print(simple_net_inlined_graph)


tensor([ 0.1035,  0.3010,  0.2266, -0.2274,  0.3631,  0.1613,  0.2359, -0.0746,
        -0.0900,  0.0652], grad_fn=<AddBackward0>)
graph(%self : __torch__.brt.netlet.___torch_mangle_0.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_1110553/4091440697.py:14:8
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_1110553/4091440697.py:14:8
    block0(%i : int, %x.17 : Tensor):
      %linear1 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="linear1"](%self)
      %11 : Function = prim::Constant[name="linear"]()
      %weight.1 : Tensor = prim::GetAttr[name="weight"](%linear1)
      %bias.1 : Tensor = prim::GetAttr[name="bias"](%linear1)
      %x.5 : Tensor = aten::linear(%x.17, %weight.1, %bias.1) # /state/partition/whcui/tools/pyenv/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/nn/functional.py:1848:11
      %linear2 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="linear2"](%self)
      %15 : Function

In [2]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


@brt.netlet
class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x, y: int):
        for i in range(y):
            route_results, reverse_indice, origin_shape = self.scatter_router(x)
            x_0 = self.expert1(route_results[0])
            x_1 = self.expert2(route_results[1])
            x = self.gather_router(reverse_indice, origin_shape, x_0, x_1)
        return x


moe = MoE()

x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
y = 10
z = moe(x, y)

print(z)


[None, tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])]
[tensor([[ 3.5053e-02, -3.0767e+00, -4.2341e-03, -2.8722e+00, -3.7326e+00,
          3.9108e+00,  1.0917e+00,  5.3074e+00, -1.0823e+00,  2.0326e+00]],
       grad_fn=<StackBackward0>), tensor([[ 3.5053e-02, -3.0767e+00, -4.2341e-03, -2.8722e+00, -3.7326e+00,
          3.9108e+00,  1.0917e+00,  5.3074e+00, -1.0823e+00,  2.0326e+00]],
       grad_fn=<StackBackward0>)]
[None, tensor([[-0.6160,  1.6336,  0.8787, -1.7561,  1.5403, -0.6676,  0.1043, -0.5224,
         -1.5715, -2.3269],
        [-3.3846, -3.6145,  2.3732, -0.7463,  0.9669, -1.4154, -0.4790, -2.5548,
         -3.0576,  2.0615]], grad_fn=<StackBackward0>)]
[tensor([[ 0.8302,  0.3068,  0.3192,  0.2066,  0.2634, -0.2640, -0.9765,  0.4962,
          1.1794,  0.2318]], grad_fn=<StackBackward0>), tensor([[ 2.4163,  0.4960,  2.0154,  0.4898, -2.0935, -0.7519,  0.1408, -2.0699,
          0.5020, -0.4846]], g

In [5]:
moe.brt(True)
script_moe = torch.jit.script(moe)

print(script_moe.inlined_graph)

graph(%self : __torch__.brt.netlet.___torch_mangle_1.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_958749/1823930129.py:17:8
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_958749/1823930129.py:19:45
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_958749/1823930129.py:20:45
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_958749/1823930129.py:17:8
    block0(%i : int, %x.11 : Tensor):
      %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%self)
      %10 : (Tensor, Tensor, Tensor) = ^forward()(%scatter_router, %x.11) # /tmp/ipykernel_958749/1823930129.py:18:58
      %route_results.1 : Tensor, %reverse_indice.1 : Tensor, %origin_shape.1 : Tensor = prim::TupleUnpack(%10)
      %expert1 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="expert1"](%self)
      %15 : Tensor = aten::select(%route_results.1, %4, %4) # /tmp/ipykernel_958749/1

ModuleNotFoundError: No module named '_reduction'

In [11]:
moe.brt(True)
script_moe = torch.jit.script(moe)

print(script_moe.inlined_graph)

graph(%self : __torch__.brt.netlet.___torch_mangle_1.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:17:8
  %4 : int = prim::Constant[value=0]() # /tmp/ipykernel_2208925/1408058768.py:19:45
  %5 : int = prim::Constant[value=1]() # /tmp/ipykernel_2208925/1408058768.py:20:45
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_2208925/1408058768.py:17:8
    block0(%i : int, %x.11 : Tensor):
      %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%self)
      %10 : (Tensor, Tensor, Tensor) = ^forward()(%scatter_router, %x.11) # /tmp/ipykernel_2208925/1408058768.py:18:58
      %route_results.1 : Tensor, %reverse_indice.1 : Tensor, %origin_shape.1 : Tensor = prim::TupleUnpack(%10)
      %expert1 : __torch__.brt.netlet.wrapper = prim::GetAttr[name="expert1"](%self)
      %15 : Tensor = aten::select(%route_results.1, %4, %4) # /tmp/ipykernel_220

In [4]:
import brt
from brt.nn._reduction import *
z = brt.nn._reduction
import torch
z = torch.nn._reduction
