Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [4]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter

# Gate network shared by both scatter routers of the DynamicRouting class
# below: a Linear(10 -> 2) layer followed by ReLU, i.e. one gate score per
# destination for each 10-feature sample.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Route two inputs through threshold-gated experts and regroup them by destination.

    ``x`` and ``y`` are each split into ``dst_num`` destinations by their own
    ``ScatterRouter``. Both routers use the module-level ``route_func`` gate.
    Destination 0 of ``x`` goes through ``expert1`` and destination 0 of ``y``
    through ``expert3``; both are Linear(10 -> 10). Destination 1 goes through
    ``expert2`` (for ``x``) and ``expert4`` (for ``y``); both are
    Linear(10 -> 20). ``gather_router_0`` merges the two destination-0
    outputs and ``gather_router_1`` merges the two destination-1 outputs.

    NOTE(review): ``forward`` only uses destinations 0 and 1, so this appears
    to assume ``dst_num == 2``.
    """

    def __init__(self, dst_num):
        super().__init__()

        def threshold_scatter():
            # Both scatter routers share exactly this configuration.
            return ScatterRouter(
                dst_num=dst_num,
                route_func=route_func,
                route_method="threshold",
                threshold=0,
            )

        self.scatter_router_0 = threshold_scatter()
        self.scatter_router_1 = threshold_scatter()
        # Experts for x (destination 0 keeps 10 features, destination 1 -> 20).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for y, with the same widths as the x experts.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        # One gather router per destination.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        # Split each input into its per-destination parts.
        x_routes = self.scatter_router_0(x)
        y_routes = self.scatter_router_1(y)

        # Apply the destination-specific experts.
        x_0 = self.expert1(x_routes[0])
        x_1 = self.expert2(x_routes[1])
        y_0 = self.expert3(y_routes[0])
        y_1 = self.expert4(y_routes[1])

        # Merge per destination: dst 0 has 10 features, dst 1 has 20.
        merged_dst0 = self.gather_router_0([x_0, y_0])
        merged_dst1 = self.gather_router_1([x_1, y_1])
        return merged_dst0, merged_dst1


# Prefer the GPU, but fall back to the CPU so the example does not crash on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    x = torch.randn((4, 10)).to(device)
    y = torch.randn((4, 10)).to(device)

    x, y = dy_model(x, y)

    # ``.shape`` is sufficient; the legacy ``.data`` attribute is not needed.
    print(x.shape)
    print(y.shape)

# Run the same model on the CPU with a smaller batch.
dy_model.cpu()
for _ in range(10):
    # torch.randn already allocates on the CPU, so no ``.cpu()`` call is needed.
    x = torch.randn((2, 10))
    y = torch.randn((2, 10))

    x, y = dy_model(x, y)

    print(x.shape)
    print(y.shape)


torch.Size([3, 10])
torch.Size([4, 20])
torch.Size([3, 10])
torch.Size([3, 20])
torch.Size([3, 10])
torch.Size([4, 20])
torch.Size([4, 10])
torch.Size([4, 20])
torch.Size([4, 10])
torch.Size([3, 20])
torch.Size([4, 10])
torch.Size([3, 20])
torch.Size([4, 10])
torch.Size([4, 20])
torch.Size([3, 10])
torch.Size([4, 20])
torch.Size([4, 10])
torch.Size([4, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([0, 10])
torch.Size([2, 20])
torch.Size([1, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([0, 10])
torch.Size([1, 20])
torch.Size([1, 10])
torch.Size([2, 20])
torch.Size([1, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])


Dynamic Routing with a Residual Router and 2-D Tensors

In [3]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Gate network for this cell: Linear(10 -> 2) followed by ReLU, giving one
# score per destination. The same module instance is shared by both scatter
# routers of the residual DynamicRouting class below.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Route ``x`` and ``y`` through threshold-gated experts that have a residual destination.

    Each input is split into ``dst_num`` destinations by its own
    ``ScatterRouter``. Both routers share the module-level ``route_func``
    gate and are built with ``residual_dst=0``. Presumably samples that pass
    no threshold fall back to destination 0; confirm against brt.

    The scatter routers return ``(route_results, route_tags, loads)``. The
    tags and loads are handed to the ``GatherRouter``s, which merge:

    - the destination-0 outputs of ``expert1``/``expert3`` (10 features);
    - the destination-1 outputs of ``expert2``/``expert4`` (20 features).
    """

    def __init__(self, dst_num):
        super().__init__()

        def residual_threshold_scatter():
            # Both scatter routers share exactly this configuration.
            return ScatterRouter(
                dst_num=dst_num,
                route_func=route_func,
                route_method="threshold",
                threshold=0,
                residual_dst=0,
            )

        self.scatter_router_0 = residual_threshold_scatter()
        self.scatter_router_1 = residual_threshold_scatter()
        # Experts for x (destination 0 keeps 10 features, destination 1 -> 20).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for y, with the same widths as the x experts.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        # One gather router per destination.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        # NOTE(review): the ``loads`` from scatter_router_0 is discarded, and
        # both gathers receive scatter_router_1's ``loads``. This is only
        # harmless if both routers report the same loads (e.g. x and y share
        # a batch size). Confirm this is intended.
        route_results_x, route_tags_x, _ = self.scatter_router_0(x)
        route_results_y, route_tags_y, loads = self.scatter_router_1(y)

        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])

        dst0_tags = [route_tags_x[0], route_tags_y[0]]
        merged_dst0 = self.gather_router_0([x_0, y_0], dst0_tags, loads)
        dst1_tags = [route_tags_x[1], route_tags_y[1]]
        merged_dst1 = self.gather_router_1([x_1, y_1], dst1_tags, loads)
        return merged_dst0, merged_dst1


# Prefer the GPU, but fall back to the CPU so the example does not crash on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    x = torch.randn((2, 10)).to(device)
    y = torch.randn((2, 10)).to(device)

    x, y = dy_model(x, y)

    print(x.shape)
    print(y.shape)


torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])


Dynamic Routing with a Residual Router and 4-D Tensors

In [5]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Gate network for 4-channel image-like inputs (the demo below feeds
# (N, 4, H, W) tensors). It applies a 1x1 conv down to 2 channels, global
# average pooling to 1x1, then a 1x1 conv that yields 2 gate scores.
# NOTE(review): the result keeps shape (N, 2, 1, 1) and is never flattened
# to (N, 2). Presumably ScatterRouter accepts that; confirm.
route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),  # 4 -> 2 channels, kernel size 1
    nn.AdaptiveAvgPool2d((1, 1)),  # pool every channel to a 1x1 map
    nn.Conv2d(2, 2, 1),  # 2 -> 2 channels, kernel size 1
)


class DynamicRouting(nn.Module):
    """Convolutional variant: route ``x`` and ``y`` through threshold-gated conv experts.

    Each input is split into ``dst_num`` destinations by its own
    ``ScatterRouter``. Both routers share the module-level ``route_func``
    gate and are built with ``residual_dst=0``. Presumably samples that pass
    no threshold fall back to destination 0; confirm against brt.

    The tags and loads returned by the scatter routers are handed to the
    ``GatherRouter``s, which merge:

    - the destination-0 outputs of ``expert1``/``expert3`` (4 channels);
    - the destination-1 outputs of ``expert2``/``expert4`` (8 channels).
    """

    def __init__(self, dst_num):
        super().__init__()

        def residual_threshold_scatter():
            # Both scatter routers share exactly this configuration.
            return ScatterRouter(
                dst_num=dst_num,
                route_func=route_func,
                route_method="threshold",
                threshold=0,
                residual_dst=0,
            )

        self.scatter_router_0 = residual_threshold_scatter()
        self.scatter_router_1 = residual_threshold_scatter()
        # 1x1-conv experts for x (destination 0 keeps 4 channels, destination 1 -> 8).
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        # 1x1-conv experts for y, with the same widths as the x experts.
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        # One gather router per destination.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        # NOTE(review): the ``loads`` from scatter_router_0 is discarded, and
        # both gathers receive scatter_router_1's ``loads``. This is only
        # harmless if both routers report the same loads (e.g. x and y share
        # a batch size). Confirm this is intended.
        route_results_x, route_tags_x, _ = self.scatter_router_0(x)
        route_results_y, route_tags_y, loads = self.scatter_router_1(y)

        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])

        dst0_tags = [route_tags_x[0], route_tags_y[0]]
        merged_dst0 = self.gather_router_0([x_0, y_0], dst0_tags, loads)
        dst1_tags = [route_tags_x[1], route_tags_y[1]]
        merged_dst1 = self.gather_router_1([x_1, y_1], dst1_tags, loads)
        return merged_dst0, merged_dst1


# Prefer the GPU, but fall back to the CPU so the example does not crash on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    # Batches of 2 samples, each with 4 channels and a 2x2 spatial size.
    x = torch.randn((2, 4, 2, 2)).to(device)
    y = torch.randn((2, 4, 2, 2)).to(device)
    x, y = dy_model(x, y)
    print(x.shape)
    print(y.shape)
    


torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])
torch.Size([0, 4, 2, 2])
torch.Size([2, 8, 2, 2])


Sparse routing

In [6]:
from brt.router import TagRouter, SparseGatherRouter, SparseScatterRouter
import brt.nn as nn
import torch
class SparseRouterModel(nn.Module):
    """Tag the input, sparsely scatter it to two identity experts, then gather it back.

    Args:
        route_func: gate function passed to ``SparseScatterRouter``.
        route_method: routing method name passed to ``SparseScatterRouter``
            (e.g. ``"topk"``).

    ``forward`` returns a 3-tuple: the output of ``expert1``, the output of
    ``expert2``, and the result of ``gather_router`` (its returned tags are
    dropped).
    """

    def __init__(self, route_func, route_method):
        super().__init__()
        num_dst = 2  # two destinations, one per expert
        self.tag_router = TagRouter()
        self.scatter_router = SparseScatterRouter(
            dst_num=num_dst, route_func=route_func, route_method=route_method
        )
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = SparseGatherRouter(dst_num=num_dst)

    def forward(self, x):
        tagged, tags = self.tag_router(x)
        routed, routed_tags = self.scatter_router(tagged, tags)
        out_0 = self.expert1(routed[0])
        out_1 = self.expert2(routed[1])
        gathered, _ = self.gather_router([out_0, out_1], routed_tags)
        return out_0, out_1, gathered

def route_func(inputs):
    """Gate that sends every sample to destination 0.

    Returns an int64 tensor of shape ``(inputs.shape[0], 2)`` on the same
    device as ``inputs``. Every row is ``[1, 0]``.
    """
    batch_size = inputs.shape[0]
    one_hot_row = torch.tensor([1, 0], dtype=torch.int64, device=inputs.device)
    return one_hot_row.repeat(batch_size, 1)

# route_func sends every row to destination 0 and both experts are
# identities, so the gathered output (third element) is expected to
# reproduce ``x`` exactly.
sparse_routing_model = SparseRouterModel(route_func, "topk")

x = torch.arange(30, dtype=torch.float32).reshape(3, 10)

y = sparse_routing_model(x)
print(y)
print(torch.allclose(x, y[2]))

(tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]]), tensor([], size=(0, 10)), tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]]))
True
