Dynamic Routing with Default Dispatcher and 2-D tensor

In [3]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt.router.flow_tensor import FlowTensor

route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num        ,    sparse=False)
        self.gather_router_1 = GatherRouter(dst_num=dst_num,            sparse=False)

    def forward(self, x, y):
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(2)
# dy_model.cuda()
# for i in range(1):
#     x = torch.randn((4, 10)).cuda()
#     y = torch.randn((4, 10)).cuda()

#     x, y = dy_model(x, y)

#     print(x.shape)
#     print(y)
#     print(type(y))
    
dy_model.cpu()
for i in range(10):
    x = torch.randn((4, 10)).cpu()
    y = torch.randn((4, 10)).cpu()

    x, y = dy_model(x, y)
    print(x)
    
    # print(x.shape)
    # print(y.shape)


tensor([[-1.6271, -0.6383, -1.2221,  0.3257, -0.0352, -1.4077, -1.3284,  0.6198,
         -0.1626,  0.9552],
        [-1.2683, -1.2079, -0.6410,  0.5701,  0.2645, -0.8459, -0.6069, -1.1710,
          0.3517, -0.0354],
        [-0.2773,  0.1909, -0.2333, -0.6264, -0.5244,  0.0751,  0.0542, -0.0203,
         -0.3374,  0.4982],
        [-0.6350,  0.3245, -0.0794,  0.1198, -0.3305, -1.0191, -0.2120, -0.2445,
         -1.1816, -0.0537]], grad_fn=<AliasBackward0>)
tensor([[-0.0831,  0.1028, -0.1646, -0.1229, -0.3900, -0.0480,  0.0405,  0.2906,
         -0.2759,  0.3182],
        [-0.4368,  0.1810, -0.2997, -0.4503, -0.5143, -0.8797,  1.0962, -0.3353,
         -0.7917,  1.0673],
        [ 0.0320, -0.6266, -0.0114, -0.3400, -0.2685, -0.4233, -0.4727, -0.1379,
         -0.8922,  0.1393],
        [-0.7869,  0.0192,  0.6330,  2.6323,  0.8123, -4.0420, -0.2710, -2.2398,
         -1.1509, -0.0267]], grad_fn=<AliasBackward0>)
tensor([[-0.4396, -1.0582, -0.4558,  1.3950,  0.5716, -0.4968, -0.5918, -0

Dynamic Routing with Residual Router and 2-D tensor

In [12]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter

route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x = torch.randn((2, 10)).cuda()
    y = torch.randn((2, 10)).cuda()

    x, y = dy_model(x, y)

    print(x.shape)
    print(y.shape)


torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])


Dynamic Routing with Residual Router and 4-D tensor

In [13]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter


route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(2, 2, 1),
)


class DynamicRouting(nn.Module):
    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num,sparse=True)

    def forward(self, x, y):
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x = torch.randn((2, 4, 2, 2)).cuda()
    y = torch.randn((2, 4, 2, 2)).cuda()
    x, y = dy_model(x, y)
    print(x.shape)
    print(y.shape)
    


torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
