Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [1]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Toy two-input model exercising BRT scatter/gather routing on 2-D tensors.

    A shared gate network (``Linear(10, 2)`` followed by ``ReLU``) scores each
    input. Each input is then split by its own ``ScatterRouter`` configured
    with the "threshold" protocol (threshold 0.0, residual_path 0) and the
    "dispatch" fabric. ``forward`` reads two routed outputs per input: output 0
    goes to a 10 -> 10 expert and output 1 to a 10 -> 20 expert. The two
    output-0 results (from x and y) go to one ``GatherRouter`` ("combine"
    fabric) and the two output-1 results go to another.
    """

    def __init__(self):
        super().__init__()

        def threshold_scatter():
            # Each call builds its own kwargs dicts, so the two routers never
            # share mutable configuration objects.
            return ScatterRouter(
                protocol_type="threshold",
                protocol_kwargs={"threshold": 0.0, "residual_path": 0},
                fabric_type="dispatch",
                fabric_kwargs={"route_logic": "1d", "transform": False},
            )

        # Submodules are created in the same order as before. This keeps both
        # the parameter-initialization order under a given RNG state and the
        # module registration order unchanged.
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = threshold_scatter()
        self.scatter_router_1 = threshold_scatter()
        self.expert1 = nn.Linear(10, 10)  # output 0 of x: keeps width 10
        self.expert2 = nn.Linear(10, 20)  # output 1 of x: widens to 20
        self.expert3 = nn.Linear(10, 10)  # output 0 of y: keeps width 10
        self.expert4 = nn.Linear(10, 20)  # output 1 of y: widens to 20
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        """Route ``x`` and ``y``, run the experts, and gather by routed output.

        Returns:
            A pair ``(out_0, out_1)``. ``out_0`` is the gather of the 10-wide
            expert results for routed output 0 of x and y. ``out_1`` is the
            gather of the 20-wide expert results for routed output 1.
        """
        gates_x, gates_y = self.route_func(x), self.route_func(y)
        paths_x = self.scatter_router_0(x, gates_x)
        paths_y = self.scatter_router_1(y, gates_y)

        keep_x, wide_x = self.expert1(paths_x[0]), self.expert2(paths_x[1])
        keep_y, wide_y = self.expert3(paths_y[0]), self.expert4(paths_y[1])

        return (
            self.gather_router_0([keep_x, keep_y]),
            self.gather_router_1([wide_x, wide_y]),
        )


# Seed before building the model so that the random weight initialization,
# the random inputs, and therefore the printed outputs are reproducible.
torch.manual_seed(0)

dy_model = DynamicRouting().cpu()

# A single CPU forward pass. Outputs get their own names so the inputs stay
# inspectable after the cell runs.
x = torch.randn((4, 10))
y = torch.randn((4, 10))
out_x, out_y = dy_model(x, y)
print(out_x)
print(out_y)
    




tensor([[0, 0],
        [1, 0],
        [2, 0],
        [3, 0]])
tensor([[1, 0],
        [2, 1],
        [3, 3],
        [0, 0]])
tensor([[-0.1337, -0.5725,  0.2809, -0.5756, -0.7206,  1.3379,  0.4429, -1.5247,
         -0.6831, -0.3665],
        [ 0.4355, -1.1128,  0.6187, -0.1854, -0.0917,  1.4852,  0.6469,  0.3147,
         -0.1680,  0.4829],
        [-0.7603,  0.8093, -0.9373, -0.2413, -0.2688,  0.9660,  0.4258,  1.2660,
         -0.1794, -0.5589],
        [-0.0656, -0.0997, -0.1527, -0.3565, -0.6989,  0.2345,  0.0725, -0.8098,
         -0.0324,  0.1228]], grad_fn=<AliasBackward0>)
tensor([[ 0.4401,  0.4553,  0.1578, -1.0557,  0.0467,  0.9038,  0.4522,  0.0498,
         -0.0961, -0.1590,  2.7585, -0.6058, -0.5113, -1.0705,  0.4971, -1.8115,
          0.0257, -1.0393,  0.3191, -0.9247],
        [-0.7178, -0.3076, -0.0678,  0.6714, -0.1254,  0.5345, -0.5017, -1.2882,
          0.6223, -0.1943,  0.4505, -0.0122,  0.7367,  0.4821,  0.6080, -0.1119,
          0.6372,  0.1176,  0.2704, -

In [2]:
# GPU run of the model from the previous cell, skipped when no CUDA device exists.
# The saved output below shows this run failing with "CUDA error: an illegal
# memory access was encountered". To get an accurate stack trace, set
# CUDA_LAUNCH_BLOCKING=1 in the environment before CUDA is initialized
# (e.g. before starting the kernel).
if torch.cuda.is_available():
    dy_model.cuda()
    x = torch.randn((4, 10)).cuda()
    y = torch.randn((4, 10)).cuda()
    x, y = dy_model(x, y)
    print(x)
    print(y)
else:
    print("CUDA is not available; skipping the GPU run.")

# NOTE(review): `load_history` is read from the router as-is; confirm that the
# installed brt version exposes this attribute.
print(dy_model.scatter_router_0.load_history)

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Dynamic Routing with a Residual Router and 2-D Tensors

In [None]:
import torch

from brt.router import ScatterRouter, GatherRouter
from brt.router.proto_tensor import ProtoTensor

# This cell uses `nn` but previously relied on an earlier cell having imported
# it; import it here so the cell runs on a fresh kernel.
import torch.nn as nn

# Shared gate network: maps each 10-feature row to 2 non-negative (ReLU) scores.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Two-input residual-routing toy model on 2-D tensors.

    Each input is scored by the shared ``route_func`` gate network and split by
    its own ``ScatterRouter`` (route_method "threshold", threshold 0,
    residual_dst 0). ``forward`` reads destinations 0 and 1 of each router:
    destination 0 feeds a 10 -> 10 expert and destination 1 a 10 -> 20 expert.
    The destination-0 results of x and y are passed to one ``GatherRouter``,
    and the destination-1 results to another.

    NOTE(review): these routers are built with the ``dst_num``/``route_method``/
    ``threshold``/``residual_dst`` keywords, while the first example in this
    notebook uses ``protocol_type``/``fabric_type``. Confirm which API the
    installed brt version accepts.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Register the module-level gate network as a submodule. This puts its
        # weights in parameters()/state_dict() and makes it move with
        # .cuda()/.cpu(), instead of living outside the model as a hidden global.
        self.route_func = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route x and y, apply the per-destination experts, and gather by destination.

        Returns:
            ``(out_0, out_1)``. ``out_0`` gathers the 10-wide destination-0
            expert outputs of x and y; ``out_1`` gathers the 20-wide
            destination-1 outputs.
        """
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten forward passes on fresh random (3, 10) CUDA inputs; only the output
# shapes are printed. Module.cuda() returns the module itself.
dy_model = DynamicRouting(2).cuda()
for i in range(10):
    x, y = (torch.randn((3, 10)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


Dynamic Routing with a Residual Router and 2-D Tensors (Routing the Gates Alongside the Data)

In [12]:
import torch
from brt.router import ScatterRouter, GatherRouter

# This cell uses `nn` but previously relied on an earlier cell having imported
# it; import it here so the cell runs on a fresh kernel.
import torch.nn as nn

# Shared gate network: maps each 10-feature row to 2 non-negative (ReLU) scores.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Two-input residual-routing toy model on 2-D tensors, with gates routed for x.

    ``scatter_router_0`` is built with ``routing_gates=True``, so ``forward``
    unpacks it into ``(routed_tensors, routed_gates)``. Each expert output on
    the x side is multiplied by its routed gates. ``scatter_router_1`` returns
    only routed tensors. Destination-0 expert outputs (10-wide) of x and y are
    gathered together, as are destination-1 outputs (20-wide).

    NOTE(review): these routers use the ``dst_num``/``route_method``/
    ``threshold``/``residual_dst`` keyword API, not the ``protocol_type``/
    ``fabric_type`` API of the first example; confirm against the installed brt.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Register the module-level gate network as a submodule. This puts its
        # weights in parameters()/state_dict() and makes it move with
        # .cuda()/.cpu(), instead of living outside the model as a hidden global.
        self.route_func = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
            routing_gates=True,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route x (with gates) and y, apply experts, and gather by destination.

        Returns:
            ``(out_0, out_1)``: gathered destination-0 (10-wide) and
            destination-1 (20-wide) results.
        """
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        # routing_gates=True: this router also returns the gates routed per destination.
        route_results_x, route_gates_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = route_gates_x[0] * self.expert1(route_results_x[0])
        x_1 = route_gates_x[1] * self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten forward passes on fresh random (3, 10) CUDA inputs, printing only the
# output shapes. Module.cuda() returns the module itself, so it can be chained.
dy_model = DynamicRouting(2).cuda()
for i in range(10):
    x = torch.randn((3, 10)).cuda()
    y = torch.randn((3, 10)).cuda()
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([3, 10])
torch.Size([1, 20])


KeyboardInterrupt: 

Dynamic Routing with a Residual Router and 4-D Tensors

In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Gate network for 4-channel 4-D inputs. It is built once, at module level, so
# that every call uses the same weights. Building it inside each call would
# re-initialize it randomly every time, making the gate scores pure noise.
_route_net = nn.Sequential(
    nn.Conv2d(4, 2, 1), nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(2, 2, 1),
)


def route_func(x):
    """Return gate scores of shape ``(N, 2)`` for ``x`` of shape ``(N, 4, H, W)``.

    The shared gate network is moved to ``x``'s device on demand (cheap once
    it is already there), so this works for both CPU and CUDA inputs.
    """
    # (N, 4, H, W) -> conv -> (N, 2, H, W) -> pool -> (N, 2, 1, 1)
    #              -> conv -> (N, 2, 1, 1) -> view -> (N, 2)
    return _route_net.to(x.device)(x).view(-1, 2)


class DynamicRouting(nn.Module):
    """Two-input residual-routing toy model on 4-D (N, C, H, W) tensors.

    Each input is split by its own ``ScatterRouter`` (route_method "threshold",
    threshold 0, residual_dst 0) using ``route_func`` scores of that same
    input. Destination 0 feeds a 4 -> 4 channel 1x1 conv expert and
    destination 1 a 4 -> 8 channel one. Destination-0 results of x and y are
    gathered together, as are destination-1 results (the latter with
    ``sparse=True``).

    NOTE(review): ``route_func`` is passed to the routers at construction AND
    its scores are passed explicitly in ``forward``; confirm which one the
    installed brt version actually uses.

    Args:
        dst_num: number of destinations passed to every scatter/gather router.
    """

    def __init__(self, dst_num):
        super().__init__()
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
            routing_gates=True
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=True)

    def forward(self, x, y):
        """Route x and y, apply per-destination conv experts, and gather by destination.

        Returns:
            ``(out_0, out_1)``: gathered destination-0 (4-channel) and
            destination-1 (8-channel) results.
        """
        # scatter_router_0 was built with routing_gates=True, so it returns
        # (routed_tensors, routed_gates); the 2-D gate-routing example
        # unpacks it the same way. Indexing the pair directly would send the
        # tensor list and the gate list to the experts.
        # TODO(review): the routed gates are not applied here, because their
        # shape relative to 4-D expert outputs is not visible in this notebook.
        route_results_x, _route_gates_x = self.scatter_router_0(x, route_func(x))
        # y must be routed by its own scores, not by x's.
        route_results_y = self.scatter_router_1(y, route_func(y))
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Ten forward passes on fresh random (3, 4, 2, 2) CUDA inputs; only the output
# shapes are printed. Module.cuda() returns the module itself.
dy_model = DynamicRouting(2).cuda()
for i in range(10):
    x, y = (torch.randn((3, 4, 2, 2)).cuda() for _ in range(2))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")
