In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [59]:

class Experts(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts):
        super(Experts, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        W1 = torch.empty(num_experts, dim_in, dim_hidden)
        b1 = torch.empty(num_experts, dim_hidden)
        W2 = torch.empty(num_experts, dim_hidden, dim_in)
        b2 = torch.empty(num_experts, dim_in)

        std = 1 / math.sqrt(self.dim_in)
        W1.uniform_(-std, std)
        b1.uniform_(-std, std)
        W2.uniform_(-std, std)
        b2.uniform_(-std, std)

        self.W1 = nn.Parameter(W1)
        self.b1 = nn.Parameter(b1)
        self.W2 = nn.Parameter(W2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x):
        # x, weights, experts_indices = input_and_weights
        # batch, context_length, _ = x.shape
        # experts_mask = torch.zeros( (batch, context_length, self.num_experts), device = x.device, dtype = int) # x.shape[:-1] = batch, context_length

        # experts_mask.scatter_(-1, experts_indices, torch.ones_like(experts_indices, device = x.device))
        a = torch.einsum('bcd,ndh->bcnh', x, self.W1) + self.b1  # pass x to every expert
        z = F.relu(a)
        y = torch.einsum('bcnh,nhd->bcnd', z, self.W2) + self.b2
        return y


class GatingNetwork(nn.Module):
    def __init__(self, dim_in, num_experts, top_k, utilization_factor=1e-2):
        super(GatingNetwork, self).__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.top_k = top_k
        self.Wg = nn.Linear(dim_in, num_experts, bias=False)
        self.Wnoise = nn.Linear(dim_in, num_experts, bias=False)
        self.utilization_factor = utilization_factor
        self.utilization_loss = 0 # REMEMBER TO CLEAR THIS AFTER EACH WEIGHT UPDATE!!!

    def forward(self, x):
        noise = F.softplus(self.Wnoise(x))
        noise *= torch.randn_like(noise).to(x.device)
        logits = self.Wg(x)
        logits += noise
        mask = torch.full_like(logits, -float('inf')).to(x.device)
        selected_logits, selected_indices = torch.topk(logits, self.top_k, dim=-1)
        mask.scatter_(-1, selected_indices, selected_logits)
        weights = F.softmax(mask, dim=-1)
        self.utilization_loss += self.compute_utilization_loss(weights)
        return weights

    def compute_utilization_loss(self, weights):
        importance = weights.reshape(-1, self.num_experts).sum(dim=0)
        square_cv = importance.var(correction=0) / importance.mean().pow(2)
        return self.utilization_factor * square_cv

def utilization_loss_pre_hook(module, grad_out):
    print("gradients before utilization loss:")

    module.utilization_loss.backward()
    print("gradients after:")
    for p in module.parameters():
        print("parameter",p.shape)
        print("grad",p.grad)
    # print("utilization loss backward", util_loss_back)
    # new_grad_out = grad_out + util_loss_back
    # print("grad with utilization loss", new_grad_out)
    # return new_grad_out
    pass

class MoE(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_experts, top_k):
        super(MoE, self).__init__()
        # no need for dropout because it's already sparse?
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_experts = num_experts
        self.top_k = top_k
        self.gating = GatingNetwork(dim_in, num_experts, top_k)
        self.gating.register_full_backward_pre_hook(utilization_loss_pre_hook)
        self.experts = Experts(dim_in, dim_hidden, num_experts)
        
    def forward(self, x):
        weights = self.gating(x)
        expert_results = self.experts(x)
        return torch.einsum('bcn,bcnd->bcd', weights, expert_results)
        # this implementation probably activates all the parameters, so no computational speed up. But that's not important for this RQ


In [72]:
def dummy_loss(x):
    return (x-x).sum()

moe = MoE(10,20, 4, 3)
data = torch.rand(2,4,10)
out = moe(data)
l = dummy_loss(out)
print("utilization loss", moe.gating.utilization_loss)
print(l.backward())


utilization loss tensor(7.1993e-05, grad_fn=<AddBackward0>)
gradients before utilization loss:
gradients after:
parameter torch.Size([4, 10])
grad tensor([[ 9.1050e-06, -1.6015e-05,  5.2508e-06, -4.5756e-06,  2.2626e-05,
          3.9798e-06, -8.7628e-07,  9.6896e-06,  1.3756e-06,  1.2081e-05],
        [-1.2657e-04, -1.7113e-04, -1.2502e-04, -1.8104e-04, -1.8223e-04,
         -2.0755e-04, -1.4976e-04, -1.9422e-04, -1.6006e-04, -1.6159e-04],
        [ 2.7330e-04,  3.2446e-04,  2.2848e-04,  3.4808e-04,  3.6828e-04,
          4.3990e-04,  3.1163e-04,  3.4347e-04,  3.2618e-04,  4.2787e-04],
        [-1.5583e-04, -1.3731e-04, -1.0871e-04, -1.6246e-04, -2.0867e-04,
         -2.3633e-04, -1.6099e-04, -1.5894e-04, -1.6749e-04, -2.7836e-04]])
parameter torch.Size([4, 10])
grad tensor([[ 1.8870e-06, -3.1763e-06,  3.6462e-06, -2.8652e-06,  8.8879e-06,
          2.5956e-06,  1.3048e-06,  6.7388e-06,  4.8423e-06,  5.9528e-06],
        [-6.6041e-05, -7.1722e-05, -2.5824e-05, -5.8050e-05, -8.6739e-05

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [69]:
from torchviz import make_dot

In [71]:
W = nn.Linear(5,2, bias = False)
x = torch.rand(10,5)
l2_reg = 0
for p in W.parameters():
    l2_reg += p.pow(2).sum()
make_dot(l2_reg, params = dict(W.parameters()))

ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x2360d00cc90>