Dynamic Routing with Default Dispatcher and 2-D tensor

In [5]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
class DynamicRouting(nn.Module):
    """Toy two-input network: gate, scatter, run experts, gather.

    Both inputs share one gating network (``route_func``). Each input has
    its own ``ScatterRouter`` and is processed by its own pair of experts:
    one keeping the feature width at 10 and one widening it to 20. The
    10-wide expert outputs of both inputs go to ``gather_router_0`` and the
    20-wide outputs go to ``gather_router_1``.
    """

    def __init__(self):
        super().__init__()
        # Shared gate: 10 features -> 2 scores, passed through ReLU.
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())

        # Both scatter routers are configured identically.
        scatter_cfg = {"protocol_type": "threshold", "fabric_type": "dispatch"}
        self.scatter_router_0 = ScatterRouter(**scatter_cfg)
        self.scatter_router_1 = ScatterRouter(**scatter_cfg)

        # Experts for the first input (x): keep width / upsample width.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for the second input (y), same shapes as above.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)

        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Route ``x`` and ``y`` through their experts and gather the results.

        Args:
            x: First input; fed to ``route_func`` and ``scatter_router_0``.
            y: Second input; fed to ``route_func`` and ``scatter_router_1``.

        Returns:
            A 2-tuple ``(gather_router_0([...]), gather_router_1([...]))``
            combining the expert1/expert3 outputs and the expert2/expert4
            outputs respectively.
        """
        gates_for_x = self.route_func(x)
        gates_for_y = self.route_func(y)

        x_branches = self.scatter_router_0(x, gates_for_x)
        y_branches = self.scatter_router_1(y, gates_for_y)

        # Branch 0 goes to the width-preserving expert, branch 1 to the
        # widening expert.
        x_keep = self.expert1(x_branches[0])
        x_wide = self.expert2(x_branches[1])
        y_keep = self.expert3(y_branches[0])
        y_wide = self.expert4(y_branches[1])

        return (
            self.gather_router_0([x_keep, y_keep]),
            self.gather_router_1([x_wide, y_wide]),
        )
# Run the same random inputs through the model on CPU and then on CUDA and
# print both pairs of outputs so they can be compared by eye.
dy_model = DynamicRouting()
for i in range(1):
    in_x, in_y = torch.randn((4, 10)), torch.randn((4, 10))

    # CPU pass (Module.cpu() moves the model in place and returns it).
    cpu_x, cpu_y = dy_model.cpu()(in_x.cpu(), in_y.cpu())
    print(cpu_x)
    print(cpu_y)

    # CUDA pass on copies of the very same inputs.
    cuda_x, cuda_y = dy_model.cuda()(in_x.cuda(), in_y.cuda())
    print(cuda_x)
    print(cuda_y)


tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.0282,  0.8077, -0.1290, -0.0281, -0.2477, -1.0221,  0.3454, -0.3232,
          0.2726,  0.3481],
        [ 0.2142, -0.4677, -0.0294, -1.2462, -0.0145,  0.1233, -0.1467, -0.0757,
         -0.1168, -0.4898],
        [-0.5658,  0.2220, -0.7049,  0.0997, -0.4098, -0.1703, -0.1509,  0.0133,
          0.6791, -0.2452]], grad_fn=<AliasBackward0>)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [-1.2465, -1.7525, -0.9535, -0.8834, -0.8782, -0.1326, -1.4295, -0.1853,
         -0.1577,  0.3481, -0.8457, -0.0894,  1.0547,  1.5808, -0.0763, -0.9335,
         -1.6403,  0.9497,  0.8951, -0.1301],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0

Dynamic Routing with Residual Router and 2-D tensor

In [6]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Gate-driven scatter/expert/gather pipeline over two inputs.

    A single gating network scores both inputs. Every input is dispatched
    by its own threshold-protocol ``ScatterRouter`` to two linear experts
    (10 -> 10 and 10 -> 20). Outputs of equal width are then merged across
    inputs by the two ``GatherRouter`` instances.
    """

    def __init__(self):
        super().__init__()
        # Gate shared by x and y: Linear(10 -> 2) followed by ReLU.
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())

        self.scatter_router_0 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch"
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch"
        )

        # expert1/expert2 serve x, expert3/expert4 serve y.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)

        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Dispatch both inputs to their experts and combine per output width.

        Args:
            x: Input handled by ``scatter_router_0`` / ``expert1`` / ``expert2``.
            y: Input handled by ``scatter_router_1`` / ``expert3`` / ``expert4``.

        Returns:
            Tuple of the ``gather_router_0`` result (expert1 + expert3
            outputs) and the ``gather_router_1`` result (expert2 + expert4
            outputs).
        """
        scores_x = self.route_func(x)
        scores_y = self.route_func(y)

        dispatched_x = self.scatter_router_0(x, scores_x)
        dispatched_y = self.scatter_router_1(y, scores_y)

        narrow_from_x = self.expert1(dispatched_x[0])
        wide_from_x = self.expert2(dispatched_x[1])
        narrow_from_y = self.expert3(dispatched_y[0])
        wide_from_y = self.expert4(dispatched_y[1])

        narrow = self.gather_router_0([narrow_from_x, narrow_from_y])
        wide = self.gather_router_1([wide_from_x, wide_from_y])
        return narrow, wide


# Instantiate the model directly on the GPU (Module.cuda() returns the module).
dy_model = DynamicRouting().cuda()

for i in range(1):
    # Inputs are sampled on the CPU and then copied to the GPU.
    x = torch.randn((3, 10)).cuda()
    y = torch.randn((3, 10)).cuda()

    # The input names are reused for the two gathered outputs.
    x, y = dy_model(x, y)

    print(x, y, sep="\n")


tensor([[-4.5302e-01,  2.9965e-01,  5.8934e-01,  8.7531e-01, -1.5401e-03,
         -5.6871e-01, -3.6920e-01, -1.4634e+00, -1.8272e+00, -7.1656e-01],
        [-2.0284e-01, -8.0398e-01, -7.8178e-01,  2.0861e-01,  1.2111e-03,
          2.2141e-01,  9.5521e-01, -7.1936e-01, -8.9217e-01, -6.1342e-01],
        [ 3.6026e-01, -5.1783e-01,  6.6238e-02, -1.1183e-01, -1.2101e+00,
         -1.1991e+00,  1.2025e-01, -8.2009e-01,  5.7306e-02, -6.1690e-01]],
       device='cuda:0', grad_fn=<AliasBackward0>)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.4210,  0.3431,  0.5573, -0.1178, -0.0333, -0.8537, -0.2588,  0.7704,
     

Dynamic Routing with Residual Router and 2-D tensor, also routing the gate scores

In [7]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Scatter/expert/gather pipeline that also uses dispatched gate scores.

    Same layout as a plain two-input routing model (shared gate, one
    ``ScatterRouter`` per input, two linear experts per input, two
    ``GatherRouter`` instances), except that ``scatter_router_0`` is built
    with ``dispatch_score=True``. Its call result is unpacked into routed
    inputs and routed gates, and each x expert output is multiplied by the
    matching routed gate before gathering.
    """

    def __init__(self):
        super().__init__()
        # Gate shared by both inputs: Linear(10 -> 2) followed by ReLU.
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())

        threshold_dispatch = {
            "protocol_type": "threshold",
            "fabric_type": "dispatch",
        }
        # Only the router for x is asked to dispatch scores as well.
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, **threshold_dispatch
        )
        self.scatter_router_1 = ScatterRouter(**threshold_dispatch)

        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)

        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Route both inputs, weight the x branch by its gates, and gather.

        Args:
            x: Input whose expert outputs are scaled by the routed gates.
            y: Input routed and transformed without gate scaling.

        Returns:
            Tuple ``(gather_router_0([x_0, y_0]), gather_router_1([x_1, y_1]))``.
        """
        gates_for_x = self.route_func(x)
        gates_for_y = self.route_func(y)

        x_parts, x_part_gates = self.scatter_router_0(x, gates_for_x)
        y_parts = self.scatter_router_1(y, gates_for_y)

        # x branch: expert output multiplied by its routed gate (gate on the left).
        keep_out = self.expert1(x_parts[0])
        x_keep = x_part_gates[0] * keep_out
        wide_out = self.expert2(x_parts[1])
        x_wide = x_part_gates[1] * wide_out

        # y branch: plain expert outputs.
        y_keep = self.expert3(y_parts[0])
        y_wide = self.expert4(y_parts[1])

        return (
            self.gather_router_0([x_keep, y_keep]),
            self.gather_router_1([x_wide, y_wide]),
        )


# Build the model on the GPU and push ten random batches through it.
dy_model = DynamicRouting().cuda()
for i in range(10):
    # Sample on the CPU, then copy each batch to the GPU.
    x = torch.randn((3, 10)).cuda()
    y = torch.randn((3, 10)).cuda()

    x, y = dy_model(x, y)

    # Uncomment to inspect the gathered output shapes:
    # print(x.shape)
    # print(y.shape)


Dynamic Routing with Residual Router and 4-D tensor

In [8]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


def route_func(x):
    """Score ``x`` with a freshly built gating network and flatten to 2 columns.

    A new ``Conv2d(4, 2, 1) -> AdaptiveAvgPool2d((1, 1)) -> Conv2d(2, 2, 1)``
    stack is created on every call, moved to CUDA, applied to ``x`` and the
    result is reshaped with ``view(-1, 2)``.

    NOTE(review): because the network is rebuilt per call, its weights are
    re-initialised each time; nothing in the code visible here calls this
    helper. Original author's note on the output layout:
    [bs x dst_num x 1 x 1] keep up down -> keep down
    """
    gate_net = nn.Sequential(
        nn.Conv2d(4, 2, 1),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Conv2d(2, 2, 1),
    )
    gate_net = gate_net.cuda()
    return gate_net(x).view(-1, 2)


class DynamicRouting(nn.Module):
    """Scatter/expert/gather pipeline built from 1x1 convolutions.

    The gate is ``Conv2d(4, 2, 1) -> AdaptiveAvgPool2d((1, 1)) ->
    Conv2d(2, 2, 1)``; its output is reshaped to two columns before being
    handed to the scatter routers. Each input has a channel-preserving
    expert (4 -> 4) and a channel-doubling expert (4 -> 8).

    NOTE(review): ``dst_num`` is accepted by the constructor but never
    read; the number of routes is fixed at 2 throughout.
    """

    def __init__(self, dst_num):
        super().__init__()
        self.route_func = nn.Sequential(
            nn.Conv2d(4, 2, 1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(2, 2, 1),
        )

        threshold_dispatch = {
            "protocol_type": "threshold",
            "fabric_type": "dispatch",
        }
        # Router 0 is built with dispatch_score=True; its call result is
        # unpacked as a pair in forward().
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, **threshold_dispatch
        )
        self.scatter_router_1 = ScatterRouter(**threshold_dispatch)

        # expert1/expert2 serve x, expert3/expert4 serve y.
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)

        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Route both inputs through their conv experts and gather.

        Args:
            x: Input for ``scatter_router_0`` / ``expert1`` / ``expert2``.
            y: Input for ``scatter_router_1`` / ``expert3`` / ``expert4``.

        Returns:
            Tuple of the ``gather_router_0`` result (expert1 + expert3
            outputs) and the ``gather_router_1`` result (expert2 + expert4
            outputs).
        """
        # Flatten the gate output to two columns, one per route.
        x_scores = self.route_func(x).view(-1, 2)
        y_scores = self.route_func(y).view(-1, 2)

        # The second element returned by router 0 is not used here.
        x_parts, _ = self.scatter_router_0(x, x_scores)
        y_parts = self.scatter_router_1(y, y_scores)

        x_same = self.expert1(x_parts[0])
        x_double = self.expert2(x_parts[1])
        y_same = self.expert3(y_parts[0])
        y_double = self.expert4(y_parts[1])

        same_channels = self.gather_router_0([x_same, y_same])
        double_channels = self.gather_router_1([x_double, y_double])
        return same_channels, double_channels


# Build the model on the GPU and print the output shapes for ten random
# batches of shape (3, 4, 2, 2).
dy_model = DynamicRouting(2).cuda()
for i in range(10):
    # Sample on the CPU, then copy each batch to the GPU.
    x = torch.randn((3, 4, 2, 2)).cuda()
    y = torch.randn((3, 4, 2, 2)).cuda()
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
