In [4]:
import brt
import brt.nn as nn
import torch


@brt.top_graph
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y: int):
        for i in range(y):
            x = self.linear1(x)
            x = self.linear2(x)
        return x


simple_net = SimpleNet()

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = simple_net(x, 10)
print(x)
x = simple_net(None, 10)
print(x)
script_simple_net = torch.jit.script(simple_net)
simple_net_inlined_graph = script_simple_net.inlined_graph
print(simple_net_inlined_graph)



tensor([-0.2486,  0.3125,  0.1019, -0.0980,  0.1306,  0.2705,  0.0085,  0.0107,
         0.0762, -0.2261], grad_fn=<AddBackward0>)
None
graph(%self : __torch__.brt.graph.___torch_mangle_1.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %3 : bool = prim::Constant[value=1]() # /tmp/ipykernel_1007005/2018613437.py:14:8
  %x : Tensor = prim::Loop(%y.1, %3, %x.1) # /tmp/ipykernel_1007005/2018613437.py:14:8
    block0(%i : int, %x.17 : Tensor):
      %linear1 : __torch__.brt.graph.wrapper = prim::GetAttr[name="linear1"](%self)
      %11 : Function = prim::Constant[name="linear"]()
      %weight.1 : Tensor = prim::GetAttr[name="weight"](%linear1)
      %bias.1 : Tensor = prim::GetAttr[name="bias"](%linear1)
      %x.5 : Tensor = aten::linear(%x.17, %weight.1, %bias.1) # /state/partition/whcui/tools/pyenv/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/nn/functional.py:1848:11
      %linear2 : __torch__.brt.graph.wrapper = prim::GetAttr[name="linear2"](%self)
      %15 : Functi

In [5]:
script_simple_net.save("simple_net.pt")

In [6]:
import torch
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, RandomGatherRouter


class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x, y: int):
        for i in range(y):
            route_results, reverse_indice, origin_shape = self.scatter_router(x)
            x_0 = self.expert1(route_results[0])
            x_1 = self.expert2(route_results[1])
            x = self.gather_router(
                reverse_indice,
                origin_shape,
                x_0,
                x_1,
            )
        return x


@brt.top_graph
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x, y: int):
        return self.moe(x, y)


model = Model()

x = torch.Tensor(
    [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
).cuda()
model.cuda()
y=10
z = model(x, y)
print(z)


# model.brt_script(True)
script_model = torch.jit.script(model)
print(script_model.inlined_graph)


tensor([[ 0.2073, -0.2659, -0.1750,  0.2358,  0.2630,  0.3643,  0.0676,  0.2210,
         -0.2627,  0.1155],
        [-0.0605, -0.3800,  0.0009, -0.0936, -0.1959,  0.1131, -0.3850, -0.0184,
         -0.4683, -0.0590]], device='cuda:0', grad_fn=<ViewBackward0>)
graph(%self : __torch__.brt.graph.___torch_mangle_3.wrapper,
      %x.1 : Tensor,
      %y.1 : int):
  %moe : __torch__.___torch_mangle_2.MoE = prim::GetAttr[name="moe"](%self)
  %5 : Function = prim::Constant[name="linear"]()
  %6 : int = prim::Constant[value=1]() # /tmp/ipykernel_1007005/1820598631.py:19:45
  %7 : int = prim::Constant[value=0]() # /tmp/ipykernel_1007005/1820598631.py:18:45
  %8 : bool = prim::Constant[value=1]() # /tmp/ipykernel_1007005/1820598631.py:16:8
  %x : Tensor = prim::Loop(%y.1, %8, %x.1) # /tmp/ipykernel_1007005/1820598631.py:16:8
    block0(%i : int, %x.11 : Tensor):
      %scatter_router : __torch__.brt.router.scatter_router.RandomScatterRouter = prim::GetAttr[name="scatter_router"](%moe)
      %13 

In [8]:
script_model.save("model.pt")

RuntimeError: 
Could not export Python function call 'forward'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
  File "/tmp/ipykernel_1007005/1820598631.py", line 17
    def forward(self, x, y: int):
        for i in range(y):
            route_results, reverse_indice, origin_shape = self.scatter_router(x)
                                                          ~~~~~~~~~~~~~~~~~~~ <--- HERE
            x_0 = self.expert1(route_results[0])
            x_1 = self.expert2(route_results[1])


In [None]:
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, use_memory_efficient):
        super(MyModule, self).__init__()
        self.use_memory_efficient = use_memory_efficient

    @torch.jit.ignore(drop=True)
    def memory_efficient(self, x):
        import pdb

        pdb.set_trace()
        return x + 10

    def forward(self, x):
        # Use not-yet-scriptable memory efficient mode
        if self.use_memory_efficient:
            return self.memory_efficient(x)
        else:
            return x + 10


m = torch.jit.script(MyModule(use_memory_efficient=False))
# m.save("m.pt")

m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
print(m.inlined_graph)
# m.save("m.pt")
# m(torch.rand(100))
