In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the random input and the randomly initialised gate
# (and therefore every value printed further down) are reproducible under
# Restart Kernel -> Run All.
torch.manual_seed(0)

# Toy input: 5 sequences of 100 tokens, each token a 4-dim embedding.
x = torch.randn(5, 100, 4)
B, T, C = x.shape
expert_cap = 1.0  # capacity factor scaling how many tokens each expert takes
num_experts = 5

# Router producing one logit per expert for every token. Sized from C and
# num_experts so it cannot drift out of sync with the input / expert count.
gate = nn.Linear(C, num_experts)

# Number of tokens each expert will pick: an even split of T, scaled by the
# capacity factor and truncated to an int.
tok_per_expert = int((T / num_experts) * expert_cap)

# Routing logits, then per-token probabilities over the experts.
scores = gate(x)  # (B, T, E)
probs = F.softmax(scores, dim=-1)  # (B, T, E)
probs.shape

torch.Size([5, 100, 5])

In [2]:
# Swap the token and expert axes so each expert has a row of its probability
# for every token: (B, T, E) -> (B, E, T).
expert_looking_at_all_tokens = probs.transpose(1, 2)  # (B, E, T)

# Every expert keeps its tok_per_expert highest-probability tokens; topk
# returns both the probabilities and the token indices along the T axis.
expert_specific_token_probs, expert_specific_tokens = expert_looking_at_all_tokens.topk(
    k=tok_per_expert, dim=-1
)
expert_looking_at_all_tokens.shape

torch.Size([5, 5, 100])

In [3]:
# Both top-k results (probabilities and token indices) should be (B, E, l).
tuple(t.shape for t in (expert_specific_token_probs, expert_specific_tokens))

(torch.Size([5, 5, 20]), torch.Size([5, 5, 20]))

In [4]:
# One-hot encode the token indices chosen by each expert along a new T axis,
# then cast to float so the mask can take part in matmul / einsum.
extract_from_BTC_one_hot = F.one_hot(expert_specific_tokens, num_classes=T).float()
extract_from_BTC_one_hot.shape  # (B, E, l, T)

# Goal: pull the selected tokens (B, E, l, C) out of x (B, T, C) using the
# (B, E, l, T) mask.
# Shape reasoning per batch element:
#   x as (1, T, C) against the mask (E, l, T)
#   transposed view: (1, C, T) @ (E, T, l) -> (E, C, l)
#   swap the last two axes                 -> (E, l, C)

torch.Size([5, 5, 20, 100])

- Simplified mixture of experts.
- `answer` and `answer2` are equal; either one is the input (`x_in`) fed to the experts in the mixture of experts.

In [5]:
# Gather the tokens each expert selected, computed two equivalent ways
# (broadcast matmul vs. einsum) and checked against each other.
x_og = x  # the (B, T, C) input, used directly by the einsum path

# Bind the broadcastable view to a new name instead of overwriting `x`:
# rebinding `x` made this cell fail when re-run, because `x_og` would then
# pick up the already-unsqueezed 4-D tensor.
x_expanded = x_og.unsqueeze(1)  # (B, 1, T, C)

# Selection mask rebuilt from the top-k indices for the einsum path.
extract_for_einsum = F.one_hot(expert_specific_tokens, T).float()  # (B, E, l, T)

answer2 = torch.einsum('BElT, BTC -> BElC', extract_for_einsum, x_og)
answer = extract_from_BTC_one_hot @ x_expanded  # (B, E, l, T) @ (B, 1, T, C) -> (B, E, l, C)

# Fail loudly on a shape mismatch instead of computing and discarding the comparison.
assert answer2.shape == answer.shape
torch.allclose(answer, answer2)


True

In [6]:
# Hidden width of each expert's two-layer MLP.
n_hidden = 16

# Per-expert weights stacked along a leading expert axis. Shapes are derived
# from C and n_hidden rather than literals so they follow the input size.
# Each expert goes from n_embd (C) to n_hidden ...
w1 = nn.Parameter(torch.ones(num_experts, C, n_hidden))  # (E, C, H)
# ... and projects back to n_embd (C).
w2 = nn.Parameter(torch.ones(num_experts, n_hidden, C))  # (E, H, C)
gelu = nn.GELU()

# Apply expert e's weights only to the tokens gathered for expert e.
activation = torch.einsum("BElC, ECH -> BElH", answer2, w1)
activation = gelu(activation)
activation = torch.einsum("BElH, EHC -> BElC", activation, w2)
activation.shape


torch.Size([5, 5, 20, 4])

In [7]:
# Shape of the one-hot selection mask at this point in the notebook.
extract_from_BTC_one_hot.size()

torch.Size([5, 5, 20, 100])

In [8]:


# Scale each expert output by the router probability of the token it came from.
# NOTE: this cell rebinds `activation` and `extract_from_BTC_one_hot`, so it is
# meant to be run exactly once after the expert MLP cell.
activation = activation * expert_specific_token_probs[..., None]

# Keep 4-D copies so the einsum formulation below can be compared against the
# flattened matmul formulation.
activation2 = activation.clone()  # (B, E, l, C)
extract_from_BTC_one_hot2 = extract_from_BTC_one_hot.clone()  # (B, E, l, T)

# Record the axis sizes of both tensors before they are reshaped.
B1, E1, l1, C1 = activation.shape  # (B, E, l, C)
B2, E2, l2, T2 = extract_from_BTC_one_hot.shape  # (B, E, l, T)

# Merge the expert and slot axes, then put T in front of them: (B, T, E*l).
extract_from_BTC_one_hot = extract_from_BTC_one_hot.view(B2, E2 * l2, T2).transpose(1, 2)
# Merge the same two axes on the expert outputs: (B, E*l, C).
activation = activation.contiguous().view(B1, E1 * l1, C1)

# (B, T, E*l) @ (B, E*l, C) -> (B, T, C): each token position sums the outputs
# of the expert slots whose one-hot row marks that token.
final_out = torch.matmul(extract_from_BTC_one_hot, activation)
final_out2 = torch.einsum('BElC, BElT -> BTC', activation2, extract_from_BTC_one_hot2)
torch.allclose(final_out, final_out2)




True

In [13]:
# Flatten the expert and slot axes of the expert outputs to (B, E*l, C), using
# the recorded axis sizes instead of hard-coded literals.
activation = activation.view(B1, E1 * l1, C1)

In [None]:
# Current shape of the expert outputs.
activation.shape

In [None]:
# View the expert outputs as (B, E*l, C), using the recorded axis sizes rather
# than the literals 5, 100 and 4.
activation = activation.view(B1, E1 * l1, C1)

In [12]:
# Inspect one row of the selection mask after viewing it as (B, E1*l1, T) and
# swapping the last two axes.
# NOTE(review): if the cell that rebinds extract_from_BTC_one_hot to (B, T, E*l)
# has already run, this view reinterprets that layout and only succeeds because
# T == E1*l1 here -- confirm which layout this cell is meant to start from.
extract_from_BTC_one_hot.contiguous().view(B, E1*l1, T).permute(0, 2, 1)[0, 0, :] # (B, T, El)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [21]:
# Combined output for the first batch element: one row per token position, (T, C).
final_out[0]

tensor([[ 2.4027e+01,  2.4027e+01,  2.4027e+01,  2.4027e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.6611e-01, -1.6611e-01, -1.6611e-01, -1.6611e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 2.5229e+01,  2.5229e+01,  2.5229e+01,  2.5229e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.9628e+01,  1.9628e+01,  1.9628e+01,  1.9628e+01],
        [ 2.6508e+00,  2.6508e+00,  2.6508e+00,  2.6508e+00],
        [ 1.3475e+01,  1.3475e+01,  1.3475e+01,  1.3475e+01],
        [-4.0679e-01, -4.0679e-01, -4.0679e-01, -4.0679e-01],
        [-1.0665e+00, -1.0665e+00, -1.0665e+00, -1.0665e+00],
        [ 1.3877e-01,  1.3877e-01,  1.3877e-01,  1.3877e-01],
        [ 1.5181e+00,  1.5181e+00,  1.5181e+00,  1.5181e+00],
        

In [None]:
# Display the probability-weighted expert outputs that were cloned before the
# expert and slot axes were flattened.
activation2 # (B, E, l, C)


tensor([[[[ 7.6331e+00,  7.6331e+00,  7.6331e+00,  7.6331e+00],
          [ 4.0356e-01,  4.0356e-01,  4.0356e-01,  4.0356e-01],
          [ 8.1093e+00,  8.1093e+00,  8.1093e+00,  8.1093e+00],
          ...,
          [-9.7446e-02, -9.7446e-02, -9.7446e-02, -9.7446e-02],
          [ 9.6621e-01,  9.6621e-01,  9.6621e-01,  9.6621e-01],
          [-2.1677e-01, -2.1677e-01, -2.1677e-01, -2.1677e-01]],

         [[ 2.3569e+01,  2.3569e+01,  2.3569e+01,  2.3569e+01],
          [ 1.5714e+01,  1.5714e+01,  1.5714e+01,  1.5714e+01],
          [ 3.0672e+01,  3.0672e+01,  3.0672e+01,  3.0672e+01],
          ...,
          [ 2.1371e+01,  2.1371e+01,  2.1371e+01,  2.1371e+01],
          [ 1.4707e+00,  1.4707e+00,  1.4707e+00,  1.4707e+00],
          [ 7.8431e+00,  7.8431e+00,  7.8431e+00,  7.8431e+00]],

         [[-1.8080e-02, -1.8080e-02, -1.8080e-02, -1.8080e-02],
          [-3.0998e-02, -3.0998e-02, -3.0998e-02, -3.0998e-02],
          [-2.1409e-01, -2.1409e-01, -2.1409e-01, -2.1409e-01],
      

In [37]:
# Mask layout is (B, E, l, T): show the one-hot row over T for batch 0,
# expert 0, slot 0.
extract_from_BTC_one_hot2[0, 0, 0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])