In [3]:
import torch
from brt.common import log
import brt.nn as nn
from brt.routers.app import RandScatterRouter
from brt.routers import GatherRouter


# Switch the three brt log channels used in this notebook to INFO.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


class MoE(nn.Module):
    """Minimal two-expert block placed between a scatter and a gather router.

    The input is handed to a ``RandScatterRouter`` built with ``dst_num=2``.
    Entry 0 of its result is fed to ``expert1`` and entry 1 to ``expert2``
    (both ``Linear(10, 10)``); the two expert outputs are then passed as a
    list to a ``GatherRouter`` whose result is returned.

    NOTE(review): how samples are assigned to the two destinations is decided
    inside ``RandScatterRouter`` and is not visible here -- presumably at
    random, going by the class name; confirm against the brt router docs.
    """

    def __init__(self):
        super().__init__()
        # dst_num=2 matches the two experts below: forward() reads exactly
        # entries 0 and 1 of the scatter result.
        self.scatter_router = RandScatterRouter(dst_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = GatherRouter(dst_num=2)

    def forward(self, x):
        """Scatter ``x``, apply one expert per destination, gather the results."""
        # The scatter result is indexed by destination (0 and 1 here).
        route_results = self.scatter_router(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        # The gather router receives the expert outputs in destination order.
        x = self.gather_router([x_0, x_1])
        return x


class MoEModel(nn.Module):
    """Top-level wrapper that owns a single ``MoE`` block as ``self.moe``."""

    def __init__(
        self,
    ):
        super().__init__()
        # Stored as a submodule, so the routers and experts live under the
        # ``moe.`` prefix of this model.
        self.moe = MoE()

    def forward(self, x):
        """Delegate to the wrapped ``MoE`` block and return its output."""
        return self.moe(x)


moe_model = MoEModel()
# Three samples of ten features each, filled with the values 0..29.
x = torch.arange(30, dtype=torch.float32).reshape(3, 10)
# Rebind x to the model output before printing it.
x = moe_model(x)
print(x)


tensor([[ 0.4654, -1.0616,  1.6243,  5.2247, -4.3258,  1.1867,  2.2375, -0.1648,
          2.0015, -6.2896],
        [ 0.3991, 21.2339,  2.9036, 12.8604, -4.4032, -2.1136,  6.1745, -0.9672,
         -0.2255,  1.8724],
        [ 0.5039, 35.7569,  4.3346, 20.2232, -7.7174, -6.6343,  9.9190,  0.0566,
         -2.7642,  3.6952]], grad_fn=<AliasBackward0>)


In [4]:
from brt.frontend import build_graph
from brt.backend.pytorch import model_to_script
from brt.primitive.helper import symbolize

# Compile whatever symbolize() returns for the model with TorchScript and keep
# a handle on the resulting graph.
# NOTE(review): what symbolize() changes on the module is not visible here --
# confirm whether build_graph() below relies on it having run first.
script_moe_model = torch.jit.script(symbolize(moe_model))
sm_graph = script_moe_model.graph
# Build the brt graph/IR object from the model, then have the PyTorch backend
# render it as Python source text, which is printed.
ir_moe_model = build_graph(moe_model)
model_script = model_to_script(ir_moe_model)
print(model_script)


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import torch
import brt


class MoEModel_model__moe(nn.Module):
    def __init__(self):
        super().__init__()
        self._scatter_router = brt.routers.rand.RandomScatterRouter(dst_num=2)
        self._expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._gather_router = brt.routers.rand.RandomGatherRouter(dst_num=2)
        self._mapping_ = {'_scatter_router': 'moe.scatter_router', '_expert1': 'moe.expert1', '_expert2': 'moe.expert2', '_gather_router': 'moe.gather_router'}

    def forward(self, x__1):
        _Constant2 = 0
        _Constant3 = 1
        _scatter_router = self._scatter_router(x__1)
        _aten____getitem__6 = _scatter_router[_Constant2]
        _aten____getitem__8 = _scatter_router[_Constant3]
        _expert1 = self._expert1(_ate

In [7]:
import torch
from brt.common import log
import brt.nn as nn
from brt.routers import ScatterRouter, GatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize
from brt.backend.pytorch import model_to_script

# Switch the three brt log channels used in this notebook to INFO.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


# Module passed as ``route_func`` to the scatter routers defined below:
# a 10 -> 2 linear layer followed by ReLU.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input routing example with two scatter routers and two gather routers.

    Each input (``x`` and ``y``) goes through its own ``ScatterRouter``; both
    routers are built with ``route_method="threshold"`` and are given the same
    module-level ``route_func`` object. No ``threshold`` or ``residual_dst`` is
    passed, so whatever defaults ``ScatterRouter`` defines apply.

    The gather stage pairs results by destination index rather than by input:
    ``gather_router_0`` receives the destination-0 outputs of both inputs
    (``expert1`` and ``expert3``, 10 output features each) and
    ``gather_router_1`` receives the destination-1 outputs (``expert2`` and
    ``expert4``, 20 output features each).

    Args:
        dst_num: Forwarded as ``dst_num`` to every router. ``forward`` reads
            destinations 0 and 1 only.
    """

    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
        )
        # Same ``route_func`` object as scatter_router_0 receives.
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
        )
        # expert1/expert2 consume the routed parts of x, expert3/expert4 those
        # of y. Output widths are 10 for destination 0 and 20 for destination 1.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route ``x`` and ``y`` separately, then gather per destination.

        Returns:
            A pair ``(x, y)``: the result of ``gather_router_0`` over the
            destination-0 expert outputs and the result of ``gather_router_1``
            over the destination-1 expert outputs.
        """
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        # Gathers are grouped by destination so that each gather router only
        # sees outputs of matching width (10 and 10, then 20 and 20).
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(dst_num=2)

# Two random batches, each 2 samples x 10 features.
x = torch.randn(2, 10)
y = torch.randn(2, 10)

# Rebind x and y to the two gathered outputs of the model.
x, y = dy_model(x, y)

# One shape per line: first output, then second output.
print(x.shape, y.shape, sep="\n")


torch.Size([2, 10])
torch.Size([1, 20])


In [8]:
# Compile whatever symbolize() returns for the model with TorchScript and print
# the resulting TorchScript graph.
# NOTE(review): what symbolize() changes on the module is not visible here --
# confirm whether build_graph() below relies on it having run first.
script_dy_model = torch.jit.script(symbolize(dy_model))
sm_graph = script_dy_model.graph
print(sm_graph)
# Build the brt graph/IR object from the model.
ir_dy_model = build_graph(dy_model)

# Have the PyTorch backend render that object as source text and print it.
model_script = model_to_script(ir_dy_model)
print(model_script)


graph(%self : __torch__.DynamicRouting,
      %x.1 : Tensor,
      %y.1 : Tensor):
  %11 : int = prim::Constant[value=0]() # /tmp/ipykernel_1193600/3920826766.py:40:43
  %16 : int = prim::Constant[value=1]() # /tmp/ipykernel_1193600/3920826766.py:41:43
  %scatter_router_0 : __torch__.brt.routers.base.ScatterRouter = prim::GetAttr[name="scatter_router_0"](%self)
  %route_results_x.1 : Tensor[] = prim::CallMethod[name="forward"](%scatter_router_0, %x.1) # /tmp/ipykernel_1193600/3920826766.py:38:26
  %scatter_router_1 : __torch__.brt.routers.base.ScatterRouter = prim::GetAttr[name="scatter_router_1"](%self)
  %route_results_y.1 : Tensor[] = prim::CallMethod[name="forward"](%scatter_router_1, %y.1) # /tmp/ipykernel_1193600/3920826766.py:39:26
  %expert1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="expert1"](%self)
  %12 : Tensor = aten::__getitem__(%route_results_x.1, %11) # /tmp/ipykernel_1193600/3920826766.py:40:27
  %x_0.1 : Tensor = prim::CallMethod[name="forward"](

In [10]:
import torch
from brt.common import log
import brt.nn as nn
from brt.routers import ScatterRouter, GatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize
from brt.backend.pytorch import model_to_script

# Switch the three brt log channels used in this notebook to INFO.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


# Module passed as ``route_func`` to the scatter routers defined below:
# a 10 -> 2 linear layer followed by ReLU.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input routing example with an explicit threshold and residual destination.

    Each input (``x`` and ``y``) goes through its own ``ScatterRouter``; both
    routers use ``route_method="threshold"`` with ``threshold=0`` and
    ``residual_dst=0``, and both are given the same module-level
    ``route_func`` object.

    NOTE(review): presumably ``residual_dst=0`` sends samples that pass no
    threshold to destination 0 instead of dropping them -- confirm against the
    ``ScatterRouter`` documentation.

    The gather stage pairs results by destination index rather than by input:
    ``gather_router_0`` receives the destination-0 outputs of both inputs
    (``expert1`` and ``expert3``, 10 output features each) and
    ``gather_router_1`` receives the destination-1 outputs (``expert2`` and
    ``expert4``, 20 output features each).

    Args:
        dst_num: Forwarded as ``dst_num`` to every router. ``forward`` reads
            destinations 0 and 1 only.
    """

    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # Same ``route_func`` object and same settings as scatter_router_0.
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # expert1/expert2 consume the routed parts of x, expert3/expert4 those
        # of y. Output widths are 10 for destination 0 and 20 for destination 1.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route ``x`` and ``y`` separately, then gather per destination.

        Returns:
            A pair ``(x, y)``: the result of ``gather_router_0`` over the
            destination-0 expert outputs and the result of ``gather_router_1``
            over the destination-1 expert outputs.
        """
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        # Gathers are grouped by destination so that each gather router only
        # sees outputs of matching width (10 and 10, then 20 and 20).
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(dst_num=2)

# Two random batches, each 2 samples x 10 features.
x = torch.randn(2, 10)
y = torch.randn(2, 10)

# Rebind x and y to the two gathered outputs of the model.
x, y = dy_model(x, y)

# Print the first output, then the second, each starting on its own line.
print(x, y, sep="\n")


tensor([[-3.3219e-01, -4.3970e-01,  2.5187e-01, -7.0467e-02, -2.9024e-01,
         -3.7078e-01,  1.5290e-01, -3.6852e-01,  6.7264e-01,  3.7140e-04],
        [-7.1383e-02, -1.7643e-01, -3.1747e-01, -2.7569e-01, -2.8982e-01,
          1.7913e-01, -2.6814e-01, -1.9270e-01,  2.8713e-01, -1.6198e-01]],
       grad_fn=<AliasBackward0>)
tensor([[ 0.4941,  0.0527,  0.2407,  0.3436, -0.9369, -0.2936,  0.3314, -0.8531,
         -0.1485, -0.0379, -1.1913, -0.9412, -0.0833,  0.3690, -0.2636,  0.4481,
         -0.5323, -0.9983, -0.3011,  0.0529],
        [-0.1380,  0.1222, -0.1333,  0.9078, -0.5864,  0.8863,  0.4465,  0.0145,
          0.3410, -0.4837, -0.4201, -1.1782,  0.7786, -0.2429,  1.2011,  0.3658,
          0.0459,  0.5267,  0.3156, -0.8280]], grad_fn=<AliasBackward0>)
