In [1]:
import torch
scores= torch.randn((8, 4))
print(scores)

tensor([[ 2.6653, -0.7239,  1.1037, -0.0944],
        [-0.3736,  0.4243, -1.3772,  1.2093],
        [-2.3181,  1.5101,  0.3743,  0.3359],
        [-0.1201, -0.3560, -0.4822,  1.0486],
        [ 0.1518,  0.4519,  1.0166,  0.4574],
        [-1.1799, -1.2427,  0.7379, -0.2425],
        [ 0.6018,  0.1740,  1.2797,  1.8234],
        [ 0.8516, -0.2564, -0.3058, -0.5935]])


In [2]:
topk_logits, topk_indices = torch.topk(scores, 2, dim=1)
print(topk_logits )
print(topk_indices)

tensor([[ 2.6653,  1.1037],
        [ 1.2093,  0.4243],
        [ 1.5101,  0.3743],
        [ 1.0486, -0.1201],
        [ 1.0166,  0.4574],
        [ 0.7379, -0.2425],
        [ 1.8234,  1.2797],
        [ 0.8516, -0.2564]])
tensor([[0, 2],
        [3, 1],
        [1, 2],
        [3, 0],
        [2, 3],
        [2, 3],
        [3, 2],
        [0, 1]])


In [3]:
def one_hot_with_dtype(data, num_classes, dtype):
    result = torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype)
    result.scatter_(1, data.unsqueeze(-1), 1)
    return result


indices_s = [x.view(-1) for x in topk_indices.chunk(2, dim=1)]
print(indices_s)
mask_se = [one_hot_with_dtype(x, num_classes=4, dtype=x.dtype) for x in indices_s]
for mask in mask_se:
    print(mask)


hot_mask = torch.zeros_like(scores, dtype=torch.int64, device=scores.device).scatter_(
    1, topk_indices, 1
)
print(hot_mask)

[tensor([0, 3, 1, 3, 2, 2, 3, 0]), tensor([2, 1, 2, 0, 3, 3, 2, 1])]
tensor([[1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0]])
tensor([[0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 1, 0, 0]])
tensor([[1, 0, 1, 0],
        [0, 1, 0, 1],
        [0, 1, 1, 0],
        [1, 0, 0, 1],
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [1, 1, 0, 0]])


In [4]:
importance_scores = -1 * scores.max(dim=1)[0]
print(importance_scores)
print(importance_scores.argsort(dim=0))

tensor([-2.6653, -1.2093, -1.5101, -1.0486, -1.0166, -0.7379, -1.8234, -0.8516])
tensor([0, 6, 2, 1, 3, 4, 7, 5])


In [6]:
from tutel_ea.jit_kernels.gating import fast_cumsum_sub_one
from brt.router.utils import generate_dst_indices


def tutel_compute_location(scores, mask):
    sorted_mask = mask[scores.argsort(dim=0)]
    sorted_cumsum = (fast_cumsum_sub_one(sorted_mask) + 1) * sorted_mask
    return sorted_cumsum[scores.argsort(dim=0).argsort(dim=0)]


def brt_compute_location(scores, mask):
    sorted_mask = mask[scores.argsort(dim=0)]
    sorted_cumsum, loads = generate_dst_indices(sorted_mask)
    return sorted_cumsum[scores.argsort(dim=0).argsort(dim=0)], loads


locations = [tutel_compute_location(importance_scores, mask) for mask in mask_se]
print(locations[0])
print(locations[1])

acc_base = torch.sum(mask_se[0], dim=0, keepdim=True)

locations[1] = (acc_base + locations[1]) * mask_se[1]
print(acc_base)
print(locations[0])
print(locations[1])
print(locations[0] + locations[1])

locations = []
location_base = None

for mask in mask_se:
    location, loads = brt_compute_location(importance_scores, mask)
    location = location if location_base is None else (location_base + location) * mask
    location_base = (
        loads.unsqueeze(0)
        if location_base is None
        else location_base + loads.unsqueeze(0)
    )
    locations.append(location)

print(locations)

tensor([[1, 0, 0, 0],
        [0, 0, 0, 2],
        [0, 1, 0, 0],
        [0, 0, 0, 3],
        [0, 0, 1, 0],
        [0, 0, 2, 0],
        [0, 0, 0, 1],
        [2, 0, 0, 0]])
tensor([[0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 3, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 2],
        [0, 0, 2, 0],
        [0, 2, 0, 0]])
tensor([[2, 1, 2, 3]])
tensor([[1, 0, 0, 0],
        [0, 0, 0, 2],
        [0, 1, 0, 0],
        [0, 0, 0, 3],
        [0, 0, 1, 0],
        [0, 0, 2, 0],
        [0, 0, 0, 1],
        [2, 0, 0, 0]])
tensor([[0, 0, 3, 0],
        [0, 2, 0, 0],
        [0, 0, 5, 0],
        [3, 0, 0, 0],
        [0, 0, 0, 4],
        [0, 0, 0, 5],
        [0, 0, 4, 0],
        [0, 3, 0, 0]])
tensor([[1, 0, 3, 0],
        [0, 2, 0, 2],
        [0, 1, 5, 0],
        [3, 0, 0, 3],
        [0, 0, 1, 4],
        [0, 0, 2, 5],
        [0, 0, 4, 1],
        [2, 3, 0, 0]])
[tensor([[1, 0, 0, 0],
        [0, 0, 0, 2],
        [0, 1, 0, 0],
        [0, 0, 0, 3],
   

In [15]:
gates_s = [(scores * x).sum(dim=1) for x in mask_se]
print(gates_s)
gates = torch.zeros_like(scores, dtype=scores.dtype, device=scores.device)
print(indices_s[0])
print(indices_s[1])
gates.scatter_(1, indices_s[0].unsqueeze(-1), gates_s[0].unsqueeze(-1))
gates.scatter(1, indices_s[1].unsqueeze(-1), gates_s[1].unsqueeze(-1))

[tensor([2.6653, 1.2093, 1.5101, 1.0486, 1.0166, 0.7379, 1.8234, 0.8516]), tensor([ 1.1037,  0.4243,  0.3743, -0.1201,  0.4574, -0.2425,  1.2797, -0.2564])]
tensor([0, 3, 1, 3, 2, 2, 3, 0])
tensor([2, 1, 2, 0, 3, 3, 2, 1])


tensor([[ 2.6653,  0.0000,  1.1037,  0.0000],
        [ 0.0000,  0.4243,  0.0000,  1.2093],
        [ 0.0000,  1.5101,  0.3743,  0.0000],
        [-0.1201,  0.0000,  0.0000,  1.0486],
        [ 0.0000,  0.0000,  1.0166,  0.4574],
        [ 0.0000,  0.0000,  0.7379, -0.2425],
        [ 0.0000,  0.0000,  1.2797,  1.8234],
        [ 0.8516, -0.2564,  0.0000,  0.0000]])