In [13]:
import torch
from pathlib import Path

model = "flame-moe-290m"
runid, epoch, layer = 31066, 1080, 2
shard = "0-0.pt"

# On-disk layout read by this cell:
#   samples/<model>/<runid>/<shard>                  -> tokens fed to the model
#   actives/<model>/<runid>/<epoch>/<layer>/<shard>  -> routing logs for those tokens
samples_path = Path("samples") / model / str(runid) / shard
actives_path = Path("actives") / model / str(runid) / str(epoch) / str(layer) / shard

# NOTE: torch.load unpickles the file -- only load shards produced by your own runs.
samples = torch.load(samples_path, map_location="cpu")
actives = torch.load(actives_path, map_location="cpu")

# `samples` is a 1D tensor of token ids: the tokens whose routing was logged.
#
# `actives` unpacks into two parallel 2D tensors of shape
# (num_tokens, num_topk_experts):
#   scores[i, k]  -- routing score of token i for its k-th active expert
#   indices[i, k] -- id of that expert
# The model activates 8 experts per token, 2 of which are shared experts,
# so the last dimension is 6.
scores, indices = actives

# Print a dashed banner, the shape, and the contents of each tensor.
# Shapes seen in the recorded run:
#   samples -> torch.Size([16384])
#   scores  -> torch.Size([16384, 6])
#   indices -> torch.Size([16384, 6])
for _label, _tensor in (("samples", samples), ("scores", scores), ("indices", indices)):
    print(_label.center(40, "-"))
    print(_tensor.shape)
    print(_tensor)


----------------samples-----------------
torch.Size([16384])
tensor([1512, 3206,  342,  ..., 2644,  273,  253])
-----------------scores-----------------
torch.Size([16384, 6])
tensor([[0.0887, 0.0660, 0.0499, 0.0405, 0.0404, 0.0393],
        [0.0677, 0.0631, 0.0536, 0.0428, 0.0315, 0.0309],
        [0.0855, 0.0776, 0.0564, 0.0433, 0.0392, 0.0388],
        ...,
        [0.0194, 0.0194, 0.0193, 0.0190, 0.0189, 0.0188],
        [0.0357, 0.0355, 0.0246, 0.0225, 0.0221, 0.0220],
        [0.0218, 0.0207, 0.0202, 0.0200, 0.0192, 0.0189]])
----------------indices-----------------
torch.Size([16384, 6])
tensor([[43, 47,  3, 21, 16, 34],
        [ 1, 21,  5, 16, 20, 34],
        [20, 28,  4, 34,  9, 33],
        ...,
        [52, 59, 53, 39, 18, 48],
        [20, 36, 45,  2,  7, 38],
        [52, 26, 20, 10, 59, 29]])
