In [13]:
import torch
from brt._C.router import (
    generate_indices_and_loads,
    dispatch_with_indices_and_loads,
    split_fused_cells_to_paths,
    fuse_split_cells_from_paths,
    combine_with_indices_and_loads,
)

# Routing setup: each of the 2048 samples picks its top-2 of `path_num` paths
# from randomly drawn gate scores.
path_num = 4
gates = torch.randn(2048, path_num).cuda()
topk_indices = gates.topk(k=2, dim=1).indices

# Multi-hot int32 mask shaped like `gates` (and on the same device):
# 1 at every selected (sample, path) position, 0 elsewhere.
hot_mask = torch.zeros_like(gates, dtype=torch.int32)
hot_mask.scatter_(1, topk_indices, 1)

# Capacities handed to the router alongside capacity_padding=True.
# The author also kept `supported_capacities = None` around as a commented-out alternative.
supported_capacities = torch.tensor(
    [64, 128, 256, 512, 1024, 2048], dtype=torch.int32
).cuda()

route_indices, seat_loads = generate_indices_and_loads(
    hot_mask,
    supported_capacities,
    capacity_padding=True,
    is_tag_index=False,
)
print(route_indices, seat_loads, sep="\n")


tensor([[   0,    1,    1,    0],
        [   1,    0,    2,    0],
        [   0,    0,    3,    1],
        ...,
        [1036,  979,    0,    0],
        [1037,    0,    0, 1018],
        [1038,    0,    0, 1019]], device='cuda:0', dtype=torch.int32)
tensor([2048, 1024, 2048, 1024], device='cuda:0', dtype=torch.int32)


In [14]:
# Random payload: one (4, 3, 2, 3) cell per gated sample, moved to the GPU.
in_data = torch.randn(gates.size(0), 4, 3, 2, 3).cuda()
print(route_indices, seat_loads, sep="\n")

# Dispatch the cells along the computed routes; with tag_generating=True the
# call returns the dispatched data together with a tag tensor.
out_data_1, tags = dispatch_with_indices_and_loads(
    in_data,
    route_indices,
    seat_loads,
    tag_generating=True,
)
print(tags)

# Split the fused dispatch result into per-path pieces, splitting the loads
# and the tags alongside the data.
split_data, split_loads, split_tags = split_fused_cells_to_paths(
    out_data_1,
    seat_loads,
    is_load_split=True,
    is_tag_split=True,
    tags=tags,
)
print(split_tags)


tensor([[   0,    1,    1,    0],
        [   1,    0,    2,    0],
        [   0,    0,    3,    1],
        ...,
        [1036,  979,    0,    0],
        [1037,    0,    0, 1018],
        [1038,    0,    0, 1019]], device='cuda:0', dtype=torch.int32)
tensor([2048, 1024, 2048, 1024], device='cuda:0', dtype=torch.int32)
tensor([2, 5, 7,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32)
[tensor([2, 5, 7,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32), tensor([1, 5, 6,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32), tensor([1, 2, 3,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32), tensor([3, 4, 6,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32)]


In [15]:
# Re-fuse the per-path cells using their tags; with is_tag_fuse=True the call
# returns three values, unpacked here as (fused_data, new_tags, global_seat_indices).
fused_data, new_tags, global_seat_indices = fuse_split_cells_from_paths(
    split_data, is_tag_fuse=True, tags=split_tags
)

print(new_tags)
print(global_seat_indices)

# The loads argument used to be `dst_loads`, a name that is never assigned in
# this notebook (it only survived in a stale kernel), so the cell failed with
# NameError on Restart & Run All. `seat_loads` is the tensor produced by
# generate_indices_and_loads and used for dispatch/split above.
# NOTE(review): assumes the old `dst_loads` held the same values — confirm.
final_data = combine_with_indices_and_loads(
    fused_data, global_seat_indices, seat_loads, tags=new_tags, is_tag_index=True
)

# Every sample was routed to its top-2 paths, so the round trip is expected to
# give back twice the input (presumably the two copies are summed on combine).
print(torch.allclose(final_data, in_data * 2))





tensor([   0,    1,    2,  ..., 2046, 2047, 2048], device='cuda:0',
       dtype=torch.int32)
tensor([2, 5, 7,  ..., 0, 0, 0], device='cuda:0')
True


In [16]:
# Same round trip as the tag-based one, but driven by loads: with
# is_load_fuse=True the fuse call returns the fused data and the fused loads.
fused_data, fused_loads = fuse_split_cells_from_paths(
    split_data,
    is_load_fuse=True,
    loads=split_loads,
)

# NOTE(review): `new_tags` and `global_seat_indices` are not produced by this
# cell; they are whatever the preceding tag-based cell left in the namespace.
print(new_tags, global_seat_indices, sep="\n")

# Combine with the original route indices and the re-fused loads, then check
# that the result equals twice the input.
final_data = combine_with_indices_and_loads(fused_data, route_indices, fused_loads)
print(torch.allclose(final_data, in_data * 2))


tensor([   0,    1,    2,  ..., 2046, 2047, 2048], device='cuda:0',
       dtype=torch.int32)
tensor([2, 5, 7,  ..., 0, 0, 0], device='cuda:0')
True
