Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Toy two-input model built from brt scatter/gather routers.

    Both inputs are scored by one shared gate network. Each input is then
    handed to its own threshold/dispatch ``ScatterRouter``. The router's
    output 0 feeds a 10 -> 10 linear expert and its output 1 feeds a
    10 -> 20 expert. ``gather_router_0`` merges the width-10 expert outputs
    of both inputs, and ``gather_router_1`` merges the width-20 ones.
    """

    def __init__(self):
        super().__init__()
        # Shared gate network: 10 input features -> 2 non-negative scores (ReLU).
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        scatter_cfg = dict(protocol_type="threshold", fabric_type="dispatch")
        self.scatter_router_0 = ScatterRouter(**scatter_cfg)
        self.scatter_router_1 = ScatterRouter(**scatter_cfg)
        # Experts for input x: branch 0 keeps width 10, branch 1 widens to 20.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for input y, with the same widths as for x.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Route ``x`` and ``y`` through their experts.

        Returns a pair: the gathered width-10 expert outputs of both inputs,
        then the gathered width-20 expert outputs.
        """
        gates_x, gates_y = self.route_func(x), self.route_func(y)
        branches_x = self.scatter_router_0(x, gates_x)
        branches_y = self.scatter_router_1(y, gates_y)

        narrow_x = self.expert1(branches_x[0])
        wide_x = self.expert2(branches_x[1])
        narrow_y = self.expert3(branches_y[0])
        wide_y = self.expert4(branches_y[1])

        merged_narrow = self.gather_router_0([narrow_x, narrow_y])
        merged_wide = self.gather_router_1([wide_x, wide_y])
        return merged_narrow, merged_wide


dy_model = DynamicRouting()

# Run the same random inputs through the model on CPU and, when a GPU is
# present, on CUDA, so the two sets of outputs can be compared side by side.
for _ in range(1):
    in_x = torch.randn((4, 10))
    in_y = torch.randn((4, 10))

    dy_model.cpu()
    cpu_x, cpu_y = dy_model(in_x.cpu(), in_y.cpu())
    print(cpu_x)
    print(cpu_y)

    if torch.cuda.is_available():
        dy_model.cuda()
        cuda_x, cuda_y = dy_model(in_x.cuda(), in_y.cuda())
        print(cuda_x)
        print(cuda_y)
    else:
        # The CUDA half of the comparison needs a GPU; skip it rather than
        # crashing on a CPU-only machine.
        print("CUDA is not available; skipping the GPU run.")


Dynamic Routing with a Residual Router and 2-D Tensors

In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Two-input routing demo on 2-D inputs with 10 features.

    A shared ReLU gate scores each input. Each input goes to its own
    threshold/dispatch scatter router. The routed pieces pass through
    per-branch linear experts (10 -> 10 for branch 0, 10 -> 20 for branch 1).
    Two combine-fabric gather routers then regroup the results by branch
    across the two inputs.
    """

    def __init__(self):
        super().__init__()

        def make_scatter():
            # Both scatter routers share the same threshold/dispatch config.
            return ScatterRouter(
                protocol_type="threshold",
                fabric_type="dispatch",
            )

        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = make_scatter()
        self.scatter_router_1 = make_scatter()
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Return ``(gathered branch-0 outputs, gathered branch-1 outputs)``."""
        score_x = self.route_func(x)
        score_y = self.route_func(y)
        parts_x = self.scatter_router_0(x, score_x)
        parts_y = self.scatter_router_1(y, score_y)

        out_x0 = self.expert1(parts_x[0])
        out_x1 = self.expert2(parts_x[1])
        out_y0 = self.expert3(parts_y[0])
        out_y1 = self.expert4(parts_y[1])

        first = self.gather_router_0([out_x0, out_y0])
        second = self.gather_router_1([out_x1, out_y1])
        return first, second


# Prefer the GPU, but fall back to CPU so the cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dy_model = DynamicRouting()
dy_model.to(device)

for _ in range(1):
    # Draw on CPU, then move: keeps the same RNG stream as the original
    # `.cuda()` version on GPU machines.
    x = torch.randn((3, 10)).to(device)
    y = torch.randn((3, 10)).to(device)

    x, y = dy_model(x, y)

    print(x)
    print(y)


Dynamic Routing with a Residual Router and 2-D Tensors, Also Routing the Gate Scores

In [20]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Two-input routing demo in which router 0 also returns its gate scores.

    ``scatter_router_0`` is constructed with ``dispatch_score=True``, and its
    call is unpacked into ``(routed inputs, routed scores)``. Each x-branch
    expert output is multiplied by the matching routed score before gathering.
    ``scatter_router_1`` is a plain threshold/dispatch router, and its expert
    outputs are gathered unweighted.
    """

    def __init__(self):
        super().__init__()
        # Shared gate network: 10 features -> 2 non-negative scores (ReLU).
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, protocol_type="threshold", fabric_type="dispatch"
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch"
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Return ``(gathered width-10 outputs, gathered width-20 outputs)``."""
        gate_x = self.route_func(x)
        gate_y = self.route_func(y)
        parts_x, scores_x = self.scatter_router_0(x, gate_x)
        parts_y = self.scatter_router_1(y, gate_y)

        # Weight each x-branch expert output by its routed score (score on the left).
        expert_out = self.expert1(parts_x[0])
        x_narrow = scores_x[0] * expert_out
        expert_out = self.expert2(parts_x[1])
        x_wide = scores_x[1] * expert_out

        y_narrow = self.expert3(parts_y[0])
        y_wide = self.expert4(parts_y[1])

        return (
            self.gather_router_0([x_narrow, y_narrow]),
            self.gather_router_1([x_wide, y_wide]),
        )


# Prefer the GPU, but fall back to CPU so the cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dy_model = DynamicRouting()
dy_model.to(device)

# Smoke test: push ten random batches through the model. Results are not
# printed; add prints of x.shape / y.shape here when inspecting outputs.
for _ in range(10):
    # Draw on CPU, then move, so the RNG stream matches the `.cuda()` version.
    x = torch.randn((3, 10)).to(device)
    y = torch.randn((3, 10)).to(device)

    x, y = dy_model(x, y)


[ProtoTensor(
data: ProtoTensor([[0.4156]], device='cuda:0', grad_fn=<AliasBackward0>)
tag_stack: [tensor([[1]], device='cuda:0')]
load stack: [3]), tensor([[0.6150],
        [0.1975],
        [0.4762]], device='cuda:0', grad_fn=<AliasBackward0>)]
ProtoTensor(
data: ProtoTensor([[ 0.4268,  0.6731, -0.2888,  0.0502,  0.5056, -0.8375, -0.1827,
               0.6613, -1.1024,  0.9359]], device='cuda:0',
            grad_fn=<AliasBackward0>)
tag_stack: [tensor([[1]], device='cuda:0')]
load stack: [3])
tensor([[ 0.1345, -0.5591, -0.2105,  0.0021,  0.4726,  0.8786,  0.0605, -0.4822,
         -0.5286, -0.4091,  0.7765, -0.7589, -0.0979, -0.1256,  0.6138,  0.3466,
          0.8299,  0.3985, -0.1366,  0.8685],
        [-0.0185,  0.9469, -0.2337, -0.1099,  0.0874,  0.7555,  0.9220,  0.4358,
          0.7234, -0.0283, -0.0552,  0.2293,  0.1887, -0.7704, -0.5916, -1.0441,
          0.7610,  0.2705,  0.5875,  0.8707],
        [-0.3191, -1.0070,  0.0735, -1.3915,  0.1924,  0.4870,  0.9591,  0.2915,


Dynamic Routing with a Residual Router and 4-D Tensors

In [22]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Stand-alone gate network, built ONCE so its weights persist across calls.
# The previous lambda constructed a freshly (randomly) initialised network on
# every call, so the same input produced different gates each time.
_route_net = nn.Sequential(
    nn.Conv2d(4, 2, 1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(2, 2, 1),
)


def route_func(x):
    """Compute 2-way gate logits for a batch of 4-channel feature maps.

    Args:
        x: tensor of shape [bs x 4 x H x W] (4 input channels required by the
            first 1x1 conv).

    Returns:
        Tensor of shape [bs x 2]: the network output is [bs x 2 x 1 x 1]
        after the 1x1 adaptive pool, flattened by ``view(-1, 2)``.
    """
    # Follow the input's device instead of hard-coding CUDA; Module.to moves
    # the cached network in place, so this is a no-op after the first call.
    net = _route_net.to(x.device)
    return net(x).view(-1, 2)


class DynamicRouting(nn.Module):
    """Two-input convolutional routing demo for 4-channel feature maps.

    A shared conv gate network scores each input. ``scatter_router_0`` (built
    with ``dispatch_score=True``; its scores are discarded here) and
    ``scatter_router_1`` dispatch the inputs to two 1x1-conv experts each:
    4 -> 4 channels for branch 0 and 4 -> 8 channels for branch 1. The gather
    routers then merge the branch-0 outputs of both inputs and the branch-1
    outputs of both inputs.

    Args:
        dst_num: number of routing destinations per scatter router. ``forward``
            wires exactly two experts behind each router, so only ``2`` is
            supported. Previously the argument was silently ignored; other
            values now raise ``ValueError``.
    """

    def __init__(self, dst_num):
        super().__init__()
        if dst_num != 2:
            raise ValueError(
                f"DynamicRouting wires exactly 2 experts per router; got dst_num={dst_num}"
            )
        self.dst_num = dst_num
        # Gate network: [bs x 4 x H x W] -> [bs x dst_num x 1 x 1].
        self.route_func = nn.Sequential(
            nn.Conv2d(4, 2, 1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(2, dst_num, 1),
        )
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True,
            protocol_type="threshold",
            fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold",
            fabric_type="dispatch",
        )
        self.expert1 = nn.Conv2d(4, 4, 1)  # branch 0 of x: keeps 4 channels
        self.expert2 = nn.Conv2d(4, 8, 1)  # branch 1 of x: widens to 8 channels
        self.expert3 = nn.Conv2d(4, 4, 1)  # branch 0 of y
        self.expert4 = nn.Conv2d(4, 8, 1)  # branch 1 of y
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Return ``(gathered 4-channel outputs, gathered 8-channel outputs)``."""
        # Flatten the 1x1 gate map [bs x dst_num x 1 x 1] to [bs x dst_num].
        gates_x = self.route_func(x).view(-1, self.dst_num)
        gates_y = self.route_func(y).view(-1, self.dst_num)
        # scatter_router_0 returns (routed tensors, scores); scores are unused here.
        route_results_x, _ = self.scatter_router_0(x, gates_x)
        route_results_y = self.scatter_router_1(y, gates_y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Prefer the GPU, but fall back to CPU so the cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    # Draw on CPU, then move, so the RNG stream matches the `.cuda()` version.
    x = torch.randn((3, 4, 2, 2)).to(device)
    y = torch.randn((3, 4, 2, 2)).to(device)
    x, y = dy_model(x, y)
    print(x.shape)
    print(y.shape)


torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
torch.Size([3, 4, 2, 2])
torch.Size([3, 8, 2, 2])
