In [1]:
from brt.cpp.router import generate_local_indices
import torch

dst_num = 4
gates = torch.randn((4, 4)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

route_indices = (
    torch.zeros(
        gates.size(0),
        dst_num,
        dtype=torch.int64,
        device=gates.device,
    )
    .scatter_(1, topk_indices, 1)
    .cuda()
)

supported_capacities = torch.Tensor([1, 2, 4, 8, 16]).int().cuda()

local_indices, dst_loads = generate_local_indices(
    route_indices, supported_capacities
)
print(local_indices)
print(dst_loads)




tensor([[0, 1, 0, 0],
        [0, 2, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 0, 1]], device='cuda:0', dtype=torch.int32)
tensor([[0, 1, 0, 0],
        [0, 2, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]], device='cuda:0', dtype=torch.int32)
tensor([1, 2, 1, 1], device='cuda:0', dtype=torch.int32)


In [5]:
from brt.cpp.router import route_with_local_indices
import torch

in_data = torch.randn(4, 4, 4, dtype=torch.float32, requires_grad=True).cuda()

total_load = torch.sum(dst_loads).item()
out_data = torch.zeros((total_load, 4), dtype=in_data.dtype).cuda()

new_gates = torch.ones_like(gates, dtype=torch.float32).cuda()

out_data = route_with_local_indices(
    in_data,
    local_indices,
    dst_loads,
)

print(in_data)
print(out_data)


tensor([[[-1.3002, -1.0905, -0.2003, -0.2435],
         [-0.8435,  1.3128, -2.1729, -0.9251],
         [-0.4823, -0.3195,  0.8688,  0.6669],
         [-3.0042,  0.8174,  1.1353,  0.3623]],

        [[ 0.3643,  0.1735, -1.7406,  0.6933],
         [-0.4151, -0.6063,  0.3414,  0.2058],
         [-1.5642, -1.0758, -0.3469, -0.0738],
         [-1.1588,  1.6699,  1.4229,  1.0638]],

        [[-0.3819,  0.7073,  0.1184, -1.1172],
         [-0.2838,  0.0434, -0.5198,  0.8079],
         [ 0.7819, -1.8814, -0.1268,  1.5259],
         [ 1.6006, -1.1591,  0.2308,  0.1008]],

        [[-1.1017, -1.2753, -1.3597, -0.5143],
         [-0.5345, -0.9381,  0.4791, -0.4463],
         [ 0.7022, -1.3897, -1.2662,  0.3815],
         [ 0.6348,  0.5748,  0.8092, -0.5442]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)
tensor([[[-0.3819,  0.7073,  0.1184, -1.1172],
         [-0.2838,  0.0434, -0.5198,  0.8079],
         [ 0.7819, -1.8814, -0.1268,  1.5259],
         [ 1.6006, -1.1591,  0.2308,  0.1008]],


In [6]:
from brt.cpp.router import route_back_with_local_indices

final_data = route_back_with_local_indices(out_data, local_indices, dst_loads)

print(final_data)
print(torch.allclose(final_data, in_data))
print(final_data.shape)

tensor([[[-1.3002, -1.0905, -0.2003, -0.2435],
         [-0.8435,  1.3128, -2.1729, -0.9251],
         [-0.4823, -0.3195,  0.8688,  0.6669],
         [-3.0042,  0.8174,  1.1353,  0.3623]],

        [[ 0.3643,  0.1735, -1.7406,  0.6933],
         [-0.4151, -0.6063,  0.3414,  0.2058],
         [-1.5642, -1.0758, -0.3469, -0.0738],
         [-1.1588,  1.6699,  1.4229,  1.0638]],

        [[-0.3819,  0.7073,  0.1184, -1.1172],
         [-0.2838,  0.0434, -0.5198,  0.8079],
         [ 0.7819, -1.8814, -0.1268,  1.5259],
         [ 1.6006, -1.1591,  0.2308,  0.1008]],

        [[-1.1017, -1.2753, -1.3597, -0.5143],
         [-0.5345, -0.9381,  0.4791, -0.4463],
         [ 0.7022, -1.3897, -1.2662,  0.3815],
         [ 0.6348,  0.5748,  0.8092, -0.5442]]], device='cuda:0')
True
torch.Size([4, 4, 4])
