In [1]:
import torch
from brt.common import log
import brt.nn as nn
from brt.router.rand import RandomScatterRouter, RandomGatherRouter


# Raise brt's frontend, backend and IR loggers to INFO so graph building and
# code generation report what they are doing.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


class MoE(nn.Module):
    """Mixture-of-experts block with two ``Linear(10, 10)`` experts.

    ``scatter_router`` (a ``RandomScatterRouter``) splits the input into two
    routes. Each route is processed by its own expert, and ``gather_router``
    (a ``RandomGatherRouter``) combines the expert outputs using the route tags
    and loads produced by the scatter step.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept stable: the generated
        # script maps its submodules back to these names.
        self.scatter_router = RandomScatterRouter(route_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = RandomGatherRouter(route_num=2)

    def forward(self, x):
        # The scatter router yields (per-route inputs, per-route tags, loads).
        routed, tags, loads = self.scatter_router(x)
        first_out = self.expert1(routed[0])
        second_out = self.expert2(routed[1])
        return self.gather_router([first_out, second_out], tags, loads)


class MoEModel(nn.Module):
    """Top-level wrapper that holds a single ``MoE`` block under the name ``moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Run the input through the wrapped ``MoE`` block and return its output."""
        moe_out = self.moe(x)
        return moe_out


# Seed torch's global RNG before building the model so the Linear weight
# initialization is reproducible across Restart & Run All. Presumably this also
# fixes RandomScatterRouter's choices if it draws from torch's RNG — verify.
torch.manual_seed(0)

moe_model = MoEModel()
x = torch.arange(0, 30, dtype=torch.float32).view(3, 10)
# Keep the input `x` intact; store the model output under its own name.
moe_output = moe_model(x)
print(moe_output)


tensor([[  1.4818,  -0.8407,  -0.7055,   0.4145,   0.1442,   0.1853,   7.6691,
           9.0105,   0.6117,  -2.2429],
        [  0.6658,  -4.9182,   0.5323,  -5.2771,  -0.5259,  -0.6493,  16.3698,
          29.6599,   4.1720,  -3.9005],
        [ -9.9166,  18.6627,  -4.6567, -12.5889,  -9.7236, -10.6290,   2.4189,
           7.1485,  15.3497, -12.4852]], grad_fn=<ScatterAddBackward0>)


In [14]:
from brt.frontend import build_graph
from brt.backend.pytorch import model_to_script
from brt.primitive.helper import symbolize

# Compile the model with TorchScript after passing it through brt's `symbolize`
# helper; this fails loudly if the model is not scriptable.
# NOTE(review): `sm_graph` is not used in this cell — presumably kept only for
# interactive inspection.
script_moe_model = torch.jit.script(symbolize(moe_model))
sm_graph = script_moe_model.graph
# Build brt's graph representation of the eager model, then generate Python
# source code for it and print that source.
ir_moe_model = build_graph(moe_model)
model_script = model_to_script(ir_moe_model)
print(model_script)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import brt.nn

import brt
import torch


class MoEModel_model__moe(nn.Module):
    def __init__(self):
        super().__init__()
        self._scatter_router = brt.router.scatter.RandomScatterRouter(route_num=2)
        self._expert1 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._expert2 = torch.nn.modules.linear.Linear(in_features=10, out_features=10)
        self._gather_router = brt.router.gather.GatherRouter(route_num=2)
        self._mapping_ = {'_scatter_router': 'moe.scatter_router', '_expert1': 'moe.expert1', '_expert2': 'moe.expert2', '_gather_router': 'moe.gather_router'}

    def forward(self, x__1):
        _Constant2 = 0
        _Constant3 = 1
        _scatter_router = self._scatter_router(x__1)
        _TupleUnpack5 = _scatter_router
        _aten____getitem__7 = _TupleUnpack5[0][_Constant2]
        _aten____getitem__9 = _TupleUnpack5[0][_Const

In [11]:
import torch
from brt.common import log
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize
from brt.backend.pytorch import model_to_script

# Raise brt's frontend, backend and IR loggers to INFO.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


# Gate: maps 10 input features to 2 non-negative scores (Linear then ReLU).
# This is a single module instance, so every router it is passed to shares
# the same gate parameters.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input routing demo built from brt scatter/gather routers.

    ``x`` and ``y`` each pass through their own threshold ``ScatterRouter``;
    both routers are given the module-level ``route_func``. Route 0 of each
    input feeds a ``Linear(10, 10)`` expert and route 1 feeds a
    ``Linear(10, 20)`` expert. ``gather_router_0`` merges the route-0 outputs
    of both inputs, and ``gather_router_1`` merges the route-1 outputs.

    Args:
        route_num: number of routes for every scatter/gather router.
    """

    def __init__(self, route_num):
        super().__init__()

        def threshold_scatter():
            # Both scatter routers are configured identically and reuse the
            # same global `route_func` module.
            return ScatterRouter(
                route_num=route_num,
                route_func=route_func,
                route_method="threshold",
                route_method_args=0,
            )

        self.scatter_router_0 = threshold_scatter()
        self.scatter_router_1 = threshold_scatter()
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(route_num=route_num)
        self.gather_router_1 = GatherRouter(route_num=route_num)

    def forward(self, x, y):
        x_routes, x_tags, x_loads = self.scatter_router_0(x)
        y_routes, y_tags, loads = self.scatter_router_1(y)

        x_route0_out = self.expert1(x_routes[0])
        x_route1_out = self.expert2(x_routes[1])
        y_route0_out = self.expert3(y_routes[0])
        y_route1_out = self.expert4(y_routes[1])

        # NOTE(review): only the loads from scatter_router_1 reach the gather
        # routers; `x_loads` is discarded. Confirm this is intended when x and
        # y may differ in batch size.
        merged_route0 = self.gather_router_0(
            [x_route0_out, y_route0_out], [x_tags[0], y_tags[0]], loads
        )
        merged_route1 = self.gather_router_1(
            [x_route1_out, y_route1_out], [x_tags[1], y_tags[1]], loads
        )
        return merged_route0, merged_route1


dy_model = DynamicRouting(2)

x = torch.randn((2, 10))
y = torch.randn((2, 10))

# Keep the inputs intact; bind the gathered outputs to their own names.
x_out, y_out = dy_model(x, y)

# expert1/expert3 produce 10 features and expert2/expert4 produce 20, so the
# gathered outputs must have those widths.
assert x_out.shape[-1] == 10, x_out.shape
assert y_out.shape[-1] == 20, y_out.shape

print(x_out.shape)
print(y_out.shape)


torch.Size([2, 10])
torch.Size([2, 20])


In [12]:

# Compile the model with TorchScript after passing it through brt's `symbolize`
# helper, then print the TorchScript graph for inspection.
script_dy_model = torch.jit.script(symbolize(dy_model))
sm_graph = script_dy_model.graph
print(sm_graph)
# Build brt's graph representation of the eager model.
ir_dy_model = build_graph(dy_model)

# Generate Python source code from that graph and print it.
model_script = model_to_script(ir_dy_model)
print(model_script)

graph(%self : __torch__.___torch_mangle_35.DynamicRouting,
      %x.1 : Tensor,
      %y.1 : Tensor):
  %17 : int = prim::Constant[value=0]() # /tmp/ipykernel_999342/3213744601.py:48:43
  %22 : int = prim::Constant[value=1]() # /tmp/ipykernel_999342/3213744601.py:49:43
  %scatter_router_0 : __torch__.brt.router.scatter.ScatterRouter = prim::GetAttr[name="scatter_router_0"](%self)
  %5 : (Tensor[], Tensor[], int) = prim::CallMethod[name="forward"](%scatter_router_0, %x.1) # /tmp/ipykernel_999342/3213744601.py:46:47
  %route_results_x.1 : Tensor[], %route_tags_x.1 : Tensor[], %loads.1 : int = prim::TupleUnpack(%5)
  %scatter_router_1 : __torch__.brt.router.scatter.ScatterRouter = prim::GetAttr[name="scatter_router_1"](%self)
  %11 : (Tensor[], Tensor[], int) = prim::CallMethod[name="forward"](%scatter_router_1, %y.1) # /tmp/ipykernel_999342/3213744601.py:47:47
  %route_results_y.1 : Tensor[], %route_tags_y.1 : Tensor[], %loads.3 : int = prim::TupleUnpack(%11)
  %expert1 : __torch__.torch

In [29]:
import torch
from brt.common import log
import brt.nn as nn
from brt.router.dispatcher import ResidualDispatcher
from brt.router import ScatterRouter, GatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize
from brt.backend.pytorch import model_to_script

# Raise brt's frontend, backend and IR loggers to INFO.
for component in ("frontend", "backend", "ir"):
    log.set_level(component, "INFO")


# Gate: maps 10 input features to 2 non-negative scores (Linear then ReLU).
# NOTE(review): this rebinds the global `route_func` with a freshly
# initialized module; every router it is passed to shares this one instance.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input routing demo whose scatter routers use ``ResidualDispatcher``.

    ``x`` and ``y`` each pass through their own threshold ``ScatterRouter``;
    both routers are given the module-level ``route_func``, with
    ``dispatcher_cls=ResidualDispatcher`` and ``residual_route=0``. Route 0 of
    each input feeds a ``Linear(10, 10)`` expert and route 1 feeds a
    ``Linear(10, 20)`` expert. ``gather_router_0`` merges the route-0 outputs
    of both inputs, and ``gather_router_1`` merges the route-1 outputs. The
    route tags of each input are printed during ``forward`` for inspection.

    NOTE(review): this class name was already defined in an earlier cell;
    this definition shadows it.

    Args:
        route_num: number of routes for every scatter/gather router.
    """

    def __init__(self, route_num):
        super().__init__()

        def residual_scatter():
            # Both scatter routers are configured identically and reuse the
            # same global `route_func` module.
            return ScatterRouter(
                route_num=route_num,
                route_func=route_func,
                route_method="threshold",
                route_method_args=0,
                dispatcher_cls=ResidualDispatcher,
                residual_route=0,
            )

        self.scatter_router_0 = residual_scatter()
        self.scatter_router_1 = residual_scatter()
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(route_num=route_num)
        self.gather_router_1 = GatherRouter(route_num=route_num)

    def forward(self, x, y):
        # Each input's tags are printed right after its own scatter step, so
        # the output order interleaves with anything the routers print.
        x_routes, x_tags, x_loads = self.scatter_router_0(x)
        print(x_tags)
        y_routes, y_tags, loads = self.scatter_router_1(y)
        print(y_tags)

        x_route0_out = self.expert1(x_routes[0])
        x_route1_out = self.expert2(x_routes[1])
        y_route0_out = self.expert3(y_routes[0])
        y_route1_out = self.expert4(y_routes[1])

        # NOTE(review): only the loads from scatter_router_1 reach the gather
        # routers; `x_loads` is discarded. Confirm this is intended.
        merged_route0 = self.gather_router_0(
            [x_route0_out, y_route0_out], [x_tags[0], y_tags[0]], loads
        )
        merged_route1 = self.gather_router_1(
            [x_route1_out, y_route1_out], [x_tags[1], y_tags[1]], loads
        )
        return merged_route0, merged_route1


dy_model = DynamicRouting(2)

x = torch.randn((2, 10))
y = torch.randn((2, 10))

# Keep the inputs intact; bind the gathered outputs to their own names.
x_out, y_out = dy_model(x, y)

# expert1/expert3 produce 10 features and expert2/expert4 produce 20, so the
# gathered outputs must have those widths.
assert x_out.shape[-1] == 10, x_out.shape
assert y_out.shape[-1] == 20, y_out.shape

print(x_out)
print(y_out)


tags tensor([[1]])
tags tensor([[0],
        [1]])
[tensor([[1]]), tensor([[0],
        [1]])]
tags tensor([[0]])
tags tensor([[1]])
[tensor([[0]]), tensor([[1]])]
tensor([[-0.0315, -0.1203, -0.5995, -1.0298,  0.4242, -0.3756,  0.7069, -0.6183,
         -0.4038, -0.4125],
        [-0.2291, -0.2067,  0.0180, -0.0154,  0.3695,  0.0626, -0.1208, -0.3797,
         -0.1901, -0.1971]], grad_fn=<ScatterAddBackward0>)
tensor([[ 0.0745, -0.1664,  0.2018,  0.1628,  0.5579, -0.4704, -0.0179,  0.3484,
          0.1139, -0.3057, -0.0162, -0.0616,  0.2247, -0.4605,  0.0990, -0.0172,
          0.0906,  0.0738,  0.0926,  0.0758],
        [-0.4129,  0.0098, -0.6666,  0.0378,  0.4107, -0.6506,  0.0775, -0.5123,
          0.0186, -0.3639, -0.5886, -0.1097,  0.0104, -0.5220, -0.3602, -0.6615,
         -0.2163, -0.2218,  0.5290,  0.0580]], grad_fn=<ScatterAddBackward0>)
