# Mixture-of-Experts Sparse Gate in PyTorch

In [2]:
import torch 
from torch import nn 
from torch.distributions.normal import Normal
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## 1. Define Experts

In [2]:
class MLP(nn.Module):
    """Single expert network: Linear -> ReLU -> Linear -> ReLU.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_size) -> None:
        super().__init__()
        # Layers are registered in the same order as they are applied.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # Kept as an attribute of the module; forward() does not call it.
        self.softmax = nn.Softmax()

    def forward(self, x):
        """Return the ReLU-activated output of the two linear layers for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.relu(self.fc2(hidden))


## SparseDispatcher

In [None]:
class SparseDispatcher(object):
    """
    Create per-expert input minibatches and combine the expert results into a unified output.

    Two main functions:
        1. dispatch - take an input tensor and create one input tensor per expert.
        2. combine  - take the output tensors from each expert and form a combined output tensor.
                      Outputs from different experts for the same batch element are summed
                      together, weighted by the provided gates.

    The class is initialized with a "gates" tensor, which specifies which batch elements go to
    which experts and the weights to use when combining the outputs.
    A batch element is sent to an expert if gates[index_b, index_e] != 0.
    The inputs and outputs are 2D, shape = (batch, depth).

    The caller is responsible for collapsing additional dimensions prior to calling this class
    and for reshaping the output back to the original shape.
    See common_layers.reshape_like()

    Example:
        gates:   a float32 Tensor with shape (batch_size, num_experts)
        inputs:  a float32 Tensor with shape (batch_size, input_size)
        experts: a list of length num_experts containing the expert networks

        dispatcher = SparseDispatcher(num_experts, gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
        outputs = dispatcher.combine(expert_outputs)

    Note that combine() exponentiates the expert outputs before the weighted sum and returns the
    logarithm of that sum, i.e. for a batch element b:
        outputs[b] = log(Sum_i(gates[b, i] * exp(experts[i](inputs[b]))))

    This class takes advantage of sparsity in the gate matrix by including in the tensors for
    expert i only the batch elements for which gates[b, i] != 0.
    """

    def __init__(self, num_experts, gates) -> None:
        """
        Args:
            num_experts: number of experts (stored; not otherwise used by this class).
            gates: Tensor of shape (batch_size, num_experts) holding the gate weights.
        """
        self._gates = gates
        self._num_experts = num_experts
        # (batch, expert) coordinates of every non-zero gate, computed once and reused below.
        nonzero_positions = torch.nonzero(gates)
        # Column-wise sort; column 1 of the result groups the entries by expert.
        sorted_experts, index_sorted_experts = nonzero_positions.sort(0)
        # Keep only the expert column, shape (num_nonzero, 1).
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch index that belongs to each entry of the expert-sorted list.
        self._batch_index = nonzero_positions[index_sorted_experts[:, 1], 0]
        # Number of samples each expert receives. Counted with `!= 0` so that it always
        # agrees with torch.nonzero above and torch.split never sees mismatching sizes.
        self._part_sizes = (gates != 0).sum(0).tolist()
        # Expand the gates so that row j corresponds to self._batch_index[j] ...
        gates_exp = gates[self._batch_index.flatten()]
        # ... and pick the gate value of the expert that entry is routed to.
        self._non_zero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """
        Split ``inp`` into one tensor per expert.

        Args:
            inp: Tensor whose first dimension is the batch dimension.
        Returns:
            A tuple with one tensor per expert; tensor i holds the rows of ``inp`` routed to
            expert i (possibly zero rows).
        """
        # self._batch_index is 1-D, so indexing keeps the remaining dimensions of `inp` as-is.
        # No squeeze is applied: squeezing dim 1 would drop the feature axis when depth == 1.
        inp_exp = inp[self._batch_index]
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """
        Merge the per-expert outputs back into one tensor with one row per batch element.

        Args:
            expert_out: sequence of per-expert output tensors, in the order returned by dispatch().
            multiply_by_gates: if True, weight each row by its gate value before summing.
        Returns:
            Tensor of shape (batch_size, expert_out[-1].size(1)) containing the logarithm of the
            (optionally gate-weighted) sum of the exponentiated expert outputs.
        """
        # apply exp to expert outputs so that we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()
        if multiply_by_gates:
            stitched = stitched.mul(self._non_zero_gates)
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)
        # sum the rows that belong to the same batch element
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nan when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        return combined.log()

    def expert_to_gates(self):
        """Return the non-zero gate values split into one tensor per expert."""
        return torch.split(self._non_zero_gates, self._part_sizes, dim=0)


## 2. MoE

In [None]:
class MoE(nn.Module):
    """Sparsely gated mixture-of-experts layer with (optionally noisy) top-k gating."""

    def __init__(self, input_size, output_size, num_experts, hidden_size, noisy_gating=True, k=4) -> None:
        """
        The Module of MoE
        Args:
        input_size: size of input
        output_size: size of output
        num_experts: number of experts to be trained
        hidden_size: the size of the hidden layer of each expert
        noisy_gating: boolean, whether noise is added to the gate logits during training
        k: int, number of experts used per sample, must be <= num_experts
        """
        super(MoE, self).__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.k = k

        # Create one expert network per expert slot
        self.experts = nn.ModuleList(
            [MLP(self.input_size, self.output_size, self.hidden_size) for _ in range(self.num_experts)]
        )

        # Gate and noise projections, both initialized to zero
        self.w_gate = nn.Parameter(torch.zeros(self.input_size, self.num_experts), requires_grad=True)
        self.w_noise = nn.Parameter(torch.zeros(self.input_size, self.num_experts), requires_grad=True)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(dim=1)
        # Parameters of the standard normal used by _prob_in_top_k; registered as buffers so
        # they are part of the module state.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))

        assert self.k <= self.num_experts, "k must be <= num_experts"

    def cv_squared(self, x):
        """
        Compute the squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.

        Args:
        x: Tensor
        Returns:
        A scalar tensor var(x) / (mean(x)**2 + eps), or a one-element zero tensor when x has
        a single entry along its first dimension.
        """
        eps = 1e-10
        # if only 1 expert
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """
        Compute the true load per expert, given the gates:
        the number of batch elements with a gate > 0 for each expert.
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values,
                       noise_stddev, noisy_top_values):
        """
        Computes the probability that each value is in the top k, given different random noise.
        This gives a way to backpropagate from a loss that balances the number of samples
        each expert receives.
        Args:
        clean_values: Tensor, size = (batch, n)
        noisy_values: Tensor, size = (batch, n)
                        Equal to clean_values + normally distributed noise with std = noise_stddev
        noise_stddev: Tensor, size = (batch, n)
        noisy_top_values: Tensor, size = (batch, m),
                            the values output of topk(noisy_values, m), where m >= k+1

        Returns:
        A Tensor of shape (batch, n)
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Flat position of the (k+1)-th largest noisy value of every row: the value an entry
        # that is currently in the top k has to beat.
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        # Boolean tensor: is the noisy value above that threshold (i.e. currently in the top k)?
        is_in = torch.gt(noisy_values, threshold_if_in)

        # Flat position of the k-th largest noisy value: the value an entry that is currently
        # outside the top k has to beat.
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)

        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """
        Noisy top-k gating
        Args:
            x: input tensor with shape [batch_size, input_size]
            train: A boolean, noise is added only when it is True (and noisy_gating is enabled)
            noise_epsilon: float, lower bound added to the noise standard deviation
        Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: a Tensor with shape [num_experts]
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            # The noise must be standard normal (randn), not uniform (rand): _prob_in_top_k
            # evaluates a Normal CDF scaled by noise_stddev.
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits

        # calculate top k+1 values, the extra one is needed for the noisy load estimate
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)  # dim, index, src

        if self.noisy_gating and self.k < self.num_experts and train:
            # top_logits holds k+1 values per row here because k < num_experts.
            load = self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits).sum(0)
        else:
            load = self._gates_to_load(gates)

        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """
        Args:
        x: Tensor, shape = (batch_size, input_size)
        loss_coef: a scalar - multiplier on the load-balancing losses

        Whether gating noise is applied is taken from self.training.

        Returns:
        y: a tensor with shape (batch_size, output_size)
        extra_training_loss: extra loss encouraging all experts to be approximately equally
                             used across a batch
        """
        gates, load = self.noisy_top_k_gating(x, self.training)

        importance = gates.sum(0)

        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss *= loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        expert_outputs = [expert(expert_input) for expert, expert_input in zip(self.experts, expert_inputs)]

        y = dispatcher.combine(expert_outputs)
        return y, loss

