Dynamic Routing with the Default Dispatcher and a 2-D Tensor

In [9]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Route two 2-D inputs through threshold scatter routers and recombine them.

    A single gate network (``route_func``) scores the rows of both ``x`` and
    ``y``. Each input goes through its own threshold/dispatch ``ScatterRouter``.
    Branches 0 and 1 of each routed result are fed to four linear experts. The
    two ``GatherRouter``s then combine the branch-0 outputs of ``x`` and ``y``,
    and the branch-1 outputs of ``x`` and ``y``, respectively.
    """

    def __init__(self):
        super().__init__()
        # Gate network: 10 input features -> 2 gate scores, clipped at 0 by ReLU.
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = self._threshold_scatter()
        self.scatter_router_1 = self._threshold_scatter()
        # Branch-0 experts keep the width at 10; branch-1 experts upsample to 20.
        self.expert1, self.expert2 = nn.Linear(10, 10), nn.Linear(10, 20)
        self.expert3, self.expert4 = nn.Linear(10, 10), nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    @staticmethod
    def _threshold_scatter():
        """Build one threshold-protocol ScatterRouter using the dispatch fabric."""
        return ScatterRouter(
            protocol_type="threshold",
            protocol_kwargs={"threshold": 0.0, "residual_path": 0},
            fabric_type="dispatch",
            fabric_kwargs={"route_logic": "1d", "transform": False},
        )

    def forward(self, x, y):
        """Return the gathered (branch-0, branch-1) outputs for inputs ``x`` and ``y``."""
        x_gates, y_gates = self.route_func(x), self.route_func(y)
        x_routed = self.scatter_router_0(x, x_gates)
        y_routed = self.scatter_router_1(y, y_gates)
        x_branch0 = self.expert1(x_routed[0])
        x_branch1 = self.expert2(x_routed[1])
        y_branch0 = self.expert3(y_routed[0])
        y_branch1 = self.expert4(y_routed[1])
        return (
            self.gather_router_0([x_branch0, y_branch0]),
            self.gather_router_1([x_branch1, y_branch1]),
        )


# One forward pass on the CPU with two fresh (4, 10) random batches,
# printing both gathered outputs.
dy_model = DynamicRouting()
dy_model.cpu()
for i in range(1):
    x, y = (torch.randn((4, 10)).cpu() for _ in range(2))
    x, y = dy_model(x, y)
    print(x, y, sep="\n")
    




tensor([[ 0.3166, -0.5643, -1.1399,  0.0916,  0.0229,  0.8864, -1.1528, -1.9537,
         -0.0671, -1.0921],
        [-0.0823,  0.0750,  0.1320, -0.1134, -0.3419,  0.3165, -1.1521, -0.5937,
         -0.3584, -0.6543],
        [ 0.1310, -0.2195,  0.0992,  1.0090,  0.1904, -0.1563, -1.3074,  0.7276,
         -0.1603, -0.1828],
        [ 0.8414,  0.3183, -0.2880, -0.3768, -0.5600, -0.3882,  0.5404,  0.8075,
          0.2602,  0.2580]], grad_fn=<AliasBackward0>)
tensor([[-0.3912, -0.0040,  0.2605,  0.2304, -0.5349,  0.1048, -0.9600,  0.2457,
         -1.3399,  0.2648, -1.1139,  0.6045,  0.8503,  0.8554,  0.6182,  1.2150,
         -0.1765, -0.3273,  0.0260,  0.8498],
        [-0.4919, -0.9002, -0.6617,  0.6329, -1.0451,  0.1458, -0.3138,  0.2437,
         -1.1097,  0.3452,  0.1887,  0.4526,  1.6244, -0.1277,  0.8320,  0.5059,
          0.2183, -0.2257,  0.3238, -0.5783],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0

In [10]:
# The same single forward pass, now with the model and inputs on the GPU,
# followed by the first scatter router's ``load_history`` attribute.
dy_model.cuda()
for i in range(1):
    x, y = (torch.randn((4, 10)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x, y, sep="\n")

print(dy_model.scatter_router_0.load_history)

AttributeError: module 'brt._C' has no attribute 'generate_src_indices'

: 

Dynamic Routing with a Residual Router and a 2-D Tensor

In [None]:
import torch
# `nn` is used below; import it here so this cell runs on a fresh kernel
# instead of relying on an import from an earlier cell.
import torch.nn as nn

from brt.router import ScatterRouter, GatherRouter
from brt.router.proto_tensor import ProtoTensor

# Module-level gate network: 10 input features -> 2 gate scores (ReLU-clipped),
# created directly on the GPU.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Two residual threshold scatter routers on 2-D inputs, merged by two gather routers.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
            ``forward`` uses branches 0 and 1, so this example expects 2.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Own the gate network as a registered submodule instead of reading the
        # notebook-global ``route_func``. That global is rebound by later cells,
        # and as a non-submodule its weights are missing from parameters() and
        # state_dict() and are not moved by .cuda()/.cpu().
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # Branch-0 experts keep width 10; branch-1 experts upsample to 20.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route ``x`` and ``y`` and return the gathered (branch-0, branch-1) outputs."""
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten GPU forward passes with fresh (3, 10) inputs; print the output shapes
# (the routed batch sizes can vary between passes).
dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x, y = (torch.randn((3, 10)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


Dynamic Routing with a Residual Router and a 2-D Tensor, Also Routing the Gates

In [12]:
import torch
# `nn` is used below; import it here so this cell runs on a fresh kernel
# instead of relying on an import from an earlier cell.
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter

# Module-level gate network: 10 input features -> 2 gate scores (ReLU-clipped),
# created directly on the GPU.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Residual threshold routing on 2-D inputs where router 0 also returns its gates.

    ``scatter_router_0`` is built with ``routing_gates=True``, and its result is
    unpacked into routed tensors and routed gates. Each ``x``-side expert output
    is scaled by the matching routed gates before gathering.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
            ``forward`` uses branches 0 and 1, so this example expects 2.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Own the gate network as a registered submodule instead of reading the
        # notebook-global ``route_func``. That global is rebound by later cells,
        # and as a non-submodule its weights are missing from parameters() and
        # state_dict() and are not moved by .cuda()/.cpu().
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
            routing_gates=True,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # Branch-0 experts keep width 10; branch-1 experts upsample to 20.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route ``x`` (gate-weighted) and ``y``; return gathered (branch-0, branch-1) outputs."""
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        # routing_gates=True: router 0 returns (routed tensors, routed gates).
        route_results_x, route_gates_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        # NOTE(review): assumes route_gates_x[i] broadcasts against the expert
        # output of branch i -- confirm the routed-gate shape.
        x_0 = self.expert1(route_results_x[0])
        x_0 = route_gates_x[0] * x_0
        x_1 = self.expert2(route_results_x[1])
        x_1 = route_gates_x[1] * x_1
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten GPU forward passes with fresh (3, 10) inputs; print the output shapes
# (the routed batch sizes can vary between passes).
dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x, y = (torch.randn((3, 10)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([3, 10])
torch.Size([1, 20])


KeyboardInterrupt: 

Dynamic Routing with a Residual Router and a 4-D Tensor

In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Gate network for 4-D inputs with 4 channels, built ONCE. The previous lambda
# constructed a new nn.Sequential (with fresh random weights) on every call,
# so its gates were non-reproducible noise that could never be trained.
route_net = nn.Sequential(
    nn.Conv2d(4, 2, 1), nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(2, 2, 1),
)


def route_func(x):
    """Return 2 gate scores per sample for a 4-D input with 4 channels.

    The network output is ``[bs x dst_num x 1 x 1]`` (dst_num = 2), flattened
    here to ``[bs x 2]``. The network is moved to ``x``'s device instead of a
    hard-coded ``.cuda()``, so CPU inputs and non-default GPUs also work.
    """
    return route_net.to(x.device)(x).view(-1, 2)


class DynamicRouting(nn.Module):
    """Residual threshold scatter/gather routing over 4-D feature maps.

    Inputs go through ``nn.Conv2d(4, ...)`` experts, so they are expected to be
    ``(N, 4, H, W)``. Branch-0 experts keep 4 channels and branch-1 experts
    produce 8. The gather routers then merge branch 0 of ``x`` and ``y``, and
    branch 1 of ``x`` and ``y``.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
            ``forward`` uses branches 0 and 1, so this example expects 2.
    """

    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
            routing_gates=True
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=True)

    def forward(self, x, y):
        """Route ``x`` and ``y`` by their own gates; return gathered (branch-0, branch-1) outputs."""
        # routing_gates=True makes router 0 return a (routed tensors, routed
        # gates) pair, which the gated 2-D example unpacks the same way. The
        # gates are not applied in this example.
        route_results_x, _route_gates_x = self.scatter_router_0(x, route_func(x))
        # Gate y with its own scores (previously route_func(x) was passed here).
        route_results_y = self.scatter_router_1(y, route_func(y))
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten GPU forward passes with fresh (3, 4, 2, 2) feature maps; print the
# output shapes of both gathered results.
dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x, y = (torch.randn((3, 4, 2, 2)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")
