In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


In [62]:
class Experts(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts):
        super(Experts, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        W1 = torch.empty(num_experts, dim_in, dim_hidden)
        b1 = torch.empty(num_experts, dim_hidden)
        W2 = torch.empty(num_experts, dim_hidden, dim_in)
        b2 = torch.empty(num_experts, dim_in)

        std = 1 / math.sqrt(self.dim_in)
        W1.uniform_(-std, std)
        b1.uniform_(-std, std)
        W2.uniform_(-std, std)
        b2.uniform_(-std, std)

        self.W1 = nn.Parameter(W1)
        self.b1 = nn.Parameter(b1)
        self.W2 = nn.Parameter(W2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x):
        # x, weights, experts_indices = input_and_weights
        # batch, context_length, _ = x.shape
        # experts_mask = torch.zeros( (batch, context_length, self.num_experts), device = x.device, dtype = int) # x.shape[:-1] = batch, context_length

        # experts_mask.scatter_(-1, experts_indices, torch.ones_like(experts_indices, device = x.device))
        a = torch.einsum('bcd,ndh->bcnh', x, self.W1) + self.b1  # pass x to every expert
        z = F.relu(a)
        y = torch.einsum('bcnh,nhd->bcnd', z, self.W2) + self.b2
        return y


class GatingNetwork(nn.Module):
    def __init__(self, dim_in, num_experts, top_k, utilization_factor=1e-2):
        super(GatingNetwork, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.top_k = top_k
        self.Wg = nn.Linear(dim_in, num_experts, bias=False)
        self.Wnoise = nn.Linear(dim_in, num_experts, bias=False)
        self.utilization_factor = utilization_factor

    def forward(self, x):
        noise = F.softplus(self.Wnoise(x))
        noise *= torch.randn_like(noise).to(x.device)
        logits = self.Wg(x)
        logits += noise
        mask = torch.full_like(logits, -float('inf')).to(x.device)
        selected_logits, selected_indices = torch.topk(logits, self.top_k, dim=-1)
        mask.scatter_(-1, selected_indices, selected_logits)
        weights = F.softmax(mask, dim=-1)
        return weights, self.utilization_loss(weights)

    def utilization_loss(self, weights):
        importance = weights.reshape(-1, self.num_experts).sum(dim=0)
        square_cv = torch.var(importance) / importance.mean().pow(2)
        return self.utilization_factor * square_cv

class MoE(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts, top_k):
        super(MoE, self).__init__()
        # no need for dropout because it's already sparse?
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_experts = num_experts
        self.top_k = top_k
        self.gating = GatingNetwork(dim_in, num_experts, top_k)
        self.experts = Experts(dim_in, dim_hidden, num_experts)
        self.utilization_loss = 0 # REMEMBER TO CLEAR THIS AFTER EACH WEIGHT UPDATE!!!
    def forward(self, x):
        weights, loss = self.gating(x)
        self.utilization_loss += loss
        expert_results = self.experts(x)
        return torch.einsum('bcn,bcnd->bcd', weights, expert_results)
   

In [63]:
def weight_select(weights, indices):
    _, D, H = weights.shape
    B, C, K = indices.shape
    return weights[indices.flatten()].reshape(B,C, K, D, H)

def bias_select(biases, indices):
    _, D = biases.shape
    B, C, K = indices.shape
    return biases[indices.flatten()].reshape(B,C, K, D)
def mask_to_indices(mask, k):
    B, C, D = mask.shape
    indices = torch.arange(D).to(mask.device)
    indices = indices.reshape(1, 1, D).expand(B, C, D)
    selected = torch.masked_select(indices,mask.bool())
    return selected.reshape(B,C,k)
    
def masked_linear(x, weights, biases, mask, k):
    idxs = mask_to_indices(mask, k)
    selected_weights = weight_select(weights, idxs)
    selected_biases = bias_select(biases, idxs)
    xW = torch.einsum('bcd,kdh->bckh', x, selected_weights)
    return xW + selected_biases

    return torch.einsum('bcd,bckdh->bckh', X, selected_weights)

In [67]:
class FastExperts(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts, top_k):
        super(FastExperts, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.top_k = top_k

        W1 = torch.empty(num_experts, dim_in, dim_hidden)
        b1 = torch.empty(num_experts, dim_hidden)
        W2 = torch.empty(num_experts, dim_hidden, dim_in)
        b2 = torch.empty(num_experts, dim_in)

        std = 1 / math.sqrt(self.dim_in)
        W1.uniform_(-std, std)
        b1.uniform_(-std, std)
        W2.uniform_(-std, std)
        b2.uniform_(-std, std)

        self.W1 = nn.Parameter(W1)
        self.b1 = nn.Parameter(b1)
        self.W2 = nn.Parameter(W2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x_mask):
        x, mask = x_mask
        idxs = mask_to_indices(mask, self.top_k)
        selected_W1 = weight_select(self.W1, idxs)
        selected_b1 = bias_select(self.b1, idxs)
        selected_W2 = weight_select(self.W2, idxs)
        selected_b2 = bias_select(self.b2, idxs)
        a = torch.einsum('bcd,bckdh->bckh', x, selected_W1) + selected_b1
        z = F.relu(a) # bckh
        y = torch.einsum('bckh,bckhd->bckd',z, selected_W2) + selected_b2
        return y

class FastMoE(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts, top_k):
        super(FastMoE, self).__init__()
        # no need for dropout because it's already sparse?
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_experts = num_experts
        self.top_k = top_k
        self.gating = GatingNetwork(dim_in, num_experts, top_k)
        self.experts = FastExperts(dim_in, dim_hidden, num_experts, top_k)
        self.utilization_loss = 0 # REMEMBER TO CLEAR THIS AFTER EACH WEIGHT UPDATE!!!
    def forward(self, x):
        B, C, D = x.shape
        weights, loss = self.gating(x)
        self.utilization_loss += loss
        expert_results = self.experts((x, weights))
        condensed_weights = torch.masked_select(weights, weights!=0).reshape(B,C, self.top_k)
        return torch.einsum('bck,bckd->bcd', condensed_weights, expert_results)


In [71]:
torch.manual_seed(123)
moe = MoE(16,64, 4, 2)
torch.manual_seed(123)
fmoe = FastMoE(16,64, 4, 2)

x = torch.rand(8,10,16)
y1 = moe(x)
y2 = fmoe(x)
print(y1)
print(y2)
torch.allclose(y1,y2)

tensor([[[ 3.7545e-01, -3.4478e-01, -4.0842e-02,  ..., -1.4285e-01,
          -3.7767e-01, -2.4461e-01],
         [ 3.6083e-01, -1.9923e-01,  3.1582e-01,  ...,  6.3526e-02,
          -3.1739e-01,  1.0903e-01],
         [ 6.4246e-01, -1.9165e-01, -1.5668e-01,  ...,  2.2759e-01,
          -2.1124e-01,  1.4033e-01],
         ...,
         [ 6.4189e-01, -1.9953e-01, -1.5698e-01,  ...,  2.3685e-01,
          -2.6781e-01, -4.0507e-03],
         [-1.5182e-01, -5.7137e-01,  4.6534e-01,  ..., -9.6387e-02,
          -5.3043e-01, -1.7141e-01],
         [ 3.6920e-01, -2.7935e-01, -2.9404e-02,  ..., -2.1070e-02,
          -4.0898e-01, -2.2914e-01]],

        [[-6.4496e-02, -4.4203e-01,  3.2405e-01,  ..., -2.3962e-01,
          -3.5366e-01, -3.4330e-01],
         [ 2.5191e-01, -3.6255e-01,  8.3020e-02,  ..., -2.9544e-01,
          -3.0794e-01, -2.4780e-01],
         [ 3.4759e-01, -2.8837e-01,  2.7238e-01,  ...,  1.2699e-01,
          -2.3180e-01, -3.0344e-02],
         ...,
         [ 1.8931e-01, -4

False

In [72]:
def dummy_loss(x):
    return (x-x).sum()

moe = MoE(10,20, 4, 3)
data = torch.rand(2,4,10)
out = moe(data)
l = dummy_loss(out)
print("utilization loss", moe.gating.utilization_loss)
print(l.backward())


utilization loss tensor(7.1993e-05, grad_fn=<AddBackward0>)
gradients before utilization loss:
gradients after:
parameter torch.Size([4, 10])
grad tensor([[ 9.1050e-06, -1.6015e-05,  5.2508e-06, -4.5756e-06,  2.2626e-05,
          3.9798e-06, -8.7628e-07,  9.6896e-06,  1.3756e-06,  1.2081e-05],
        [-1.2657e-04, -1.7113e-04, -1.2502e-04, -1.8104e-04, -1.8223e-04,
         -2.0755e-04, -1.4976e-04, -1.9422e-04, -1.6006e-04, -1.6159e-04],
        [ 2.7330e-04,  3.2446e-04,  2.2848e-04,  3.4808e-04,  3.6828e-04,
          4.3990e-04,  3.1163e-04,  3.4347e-04,  3.2618e-04,  4.2787e-04],
        [-1.5583e-04, -1.3731e-04, -1.0871e-04, -1.6246e-04, -2.0867e-04,
         -2.3633e-04, -1.6099e-04, -1.5894e-04, -1.6749e-04, -2.7836e-04]])
parameter torch.Size([4, 10])
grad tensor([[ 1.8870e-06, -3.1763e-06,  3.6462e-06, -2.8652e-06,  8.8879e-06,
          2.5956e-06,  1.3048e-06,  6.7388e-06,  4.8423e-06,  5.9528e-06],
        [-6.6041e-05, -7.1722e-05, -2.5824e-05, -5.8050e-05, -8.6739e-05

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [69]:
from torchviz import make_dot

In [71]:
W = nn.Linear(5,2, bias = False)
x = torch.rand(10,5)
l2_reg = 0
for p in W.parameters():
    l2_reg += p.pow(2).sum()
make_dot(l2_reg, params = dict(W.parameters()))

ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x2360d00cc90>

In [28]:
def weight_select(weights, indices):
    _, D, H = weights.shape
    B, C, K = indices.shape
    return weights[indices.flatten()].reshape(B,C, K, D, H)

def bias_select(biases, indices):
    _, D = biases.shape
    B, C, K = indices.shape
    return biases[indices.flatten()].reshape(B,C, K, D)

In [47]:
def mask_to_indices(mask, k):
    B, C, D = mask.shape
    indices = torch.arange(D).to(mask.device)
    indices = indices.reshape(1, 1, D).expand(B, C, D)
    selected = torch.masked_select(indices,mask.bool())
    return selected.reshape(B,C,k)

In [None]:
def masked_forward(x, weights, biases, mask, k):
    idxs = mask_to_indices(mask, k)
    selected_weights = weight_select(weights, idxs)
    selected_biases = bias_select(biases, idxs)
    xW = torch.einsum('bcd,kdh->bckh', x, selected_weights)
    return xW + selected_biases

    return torch.einsum('bcd,bckdh->bckh', X, selected_weights)


In [49]:
mask_to_indices(
    torch.tensor(
        [
            [
                [0,0.1,1],
                [1,0,1]
            ],
            [
                [1,1,0],
                [1,0,1]
            ]
        ]
    ),
    k=2
)

tensor([[[1, 2],
         [0, 2]],

        [[0, 1],
         [0, 2]]])

In [29]:
W = torch.randn(3,2,2)

print(W)
M = torch.tensor([
    [[0,1],[1,2]],
    [[2,1],[0,1]]
])
print(weight_select(W,M))

B = torch.randn(3,2)
print(B)
print(bias_select(B,M))

tensor([[[-1.3343, -1.5194],
         [ 0.5537, -1.8515]],

        [[ 1.5177, -0.8600],
         [-0.7913,  0.7807]],

        [[-0.9847,  0.2793],
         [ 1.4003,  0.8517]]])
tensor([[[[[-1.3343, -1.5194],
           [ 0.5537, -1.8515]],

          [[ 1.5177, -0.8600],
           [-0.7913,  0.7807]]],


         [[[ 1.5177, -0.8600],
           [-0.7913,  0.7807]],

          [[-0.9847,  0.2793],
           [ 1.4003,  0.8517]]]],



        [[[[-0.9847,  0.2793],
           [ 1.4003,  0.8517]],

          [[ 1.5177, -0.8600],
           [-0.7913,  0.7807]]],


         [[[-1.3343, -1.5194],
           [ 0.5537, -1.8515]],

          [[ 1.5177, -0.8600],
           [-0.7913,  0.7807]]]]])
tensor([[-1.2498, -0.7375],
        [ 0.5627,  0.6696],
        [ 0.4355, -0.4725]])
tensor([[[[-1.2498, -0.7375],
          [ 0.5627,  0.6696]],

         [[ 0.5627,  0.6696],
          [ 0.4355, -0.4725]]],


        [[[ 0.4355, -0.4725],
          [ 0.5627,  0.6696]],

         [[-1.2498, -0.73

In [75]:
x = torch.tensor([0,1])
x.to_sparse()
W = torch.randn(2,10)
print(x)

tensor([0, 1])
