## GShard Implementation

### Group-level top-2 gating with auxiliary loss

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# set up basic params

# S = N/G, where N is the number of tokens in a training batch,
# G is the number of groups, S is the number of tokens per group
num_groups = 4 # G
num_experts = 4 # E
num_features = 128 # M
capacity_factor = 1.2
batch_size = 128
num_devices = 2 # D

# The batch must split evenly into groups, otherwise the reshape below fails.
assert batch_size % num_groups == 0, "batch_size must be divisible by num_groups"
tokens_per_group = batch_size // num_groups # S

key = jax.random.PRNGKey(0)
# JAX PRNG keys must not be reused: derive one independent sub-key per random
# tensor so the inputs and the gating weights are not drawn from the same stream.
inputs_key, wg_key = jax.random.split(key)
batched_inputs = jax.random.normal(inputs_key, shape=(batch_size, num_features), dtype=jnp.float32)

## sharding inputs
# (N, M) -> (G, S, M). The target shape is passed positionally as plain Python
# ints (the `newshape=` keyword is deprecated in recent JAX releases).
sharded_inputs = jnp.reshape(batched_inputs,
                             (num_groups, tokens_per_group, num_features))
print("sharded input shape: {}".format(sharded_inputs.shape))

## init wg
# gates per token per expert: projects M features onto E expert logits
wg = jax.random.normal(wg_key, shape=(num_features, num_experts), dtype=jnp.float32)
print("wg shape: {}".format(wg.shape))

sharded input shape: (4, 32, 128)
wg shape: (128, 4)


## gates

In [2]:
## gates = softmax(einsum("GSM,ME->GSE", inputs, wg))

# Work on a single group (group 0) so the G axis disappears.
local_inputs = sharded_inputs[0]

# make group-local by removing G: per-token, per-expert logits
output = jnp.einsum("SM,ME->SE", local_inputs, wg)
# pass through softmax over the expert axis
gates = jax.nn.softmax(output, axis=1)
print(gates.shape)
# sanity check: each token's gate probabilities must sum to 1
assert jnp.allclose(jnp.sum(gates, axis=1), 1.0, atol=1e-5)

(32, 4)


## Top2Gating

In [6]:
## combine_weights, dispatch_mask = Top2Gating(gates)

def _capacity(gates, capacity_factor, min_capacity=0):
    # gates with shape SE as above
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    capacity = jnp.ceil((num_tokens / num_experts) * capacity_factor)
#     if capacity < min_capacity:
#         capacity = min_capacity
    return capacity

def Top2Gating(gates, capacity_factor=1.2):
    """Compute group-local top-2 routing tensors from gate probabilities.

    Each token is routed to its two highest-probability experts.  Every expert
    has a buffer of `capacity` slots; first choices fill the buffers in token
    order, second choices queue behind all first choices, and any assignment
    that lands beyond the capacity is dropped.

    Args:
        gates: float array of shape (S, E) with per-token expert
            probabilities (softmax output).
        capacity_factor: Python number scaling the even-split capacity S / E.
            It must be a concrete value at trace time; when jitting with a
            non-default value, mark it static (e.g. `static_argnames`).

    Returns:
        (combine_weights, dispatch_mask), both of shape (S, E, C):
        combine_weights[s, e, c] is the normalised gate weight of token s in
        slot c of expert e (0 where unassigned); dispatch_mask is 1 where
        combine_weights is positive and 0 elsewhere.
    """
    num_experts = gates.shape[1]

    # First choice: the highest-probability expert of every token.
    indices1_s = jnp.argmax(gates, axis=1)
    mask1 = jax.nn.one_hot(indices1_s, num_experts)
    # Tokens per expert among first choices, counted before any capacity drop.
    exp_counts1 = jnp.sum(mask1, axis=0)

    # Second choice: zero out the first-choice probability and take argmax again.
    masked_gates2 = gates * (1 - mask1)
    indices2_s = jnp.argmax(masked_gates2, axis=1)
    mask2 = jax.nn.one_hot(indices2_s, num_experts)

    # Compute locations in capacity buffer (running count per expert, 0-based).
    locations1 = jnp.cumsum(mask1, axis=0) - 1
    locations2 = jnp.cumsum(mask2, axis=0) - 1
    # Update 2nd's location by accounting for locations of 1st
    locations2 = locations2 + exp_counts1

    # jax.nn.one_hot needs a static Python int for num_classes.
    capacity = int(_capacity(gates, capacity_factor=capacity_factor))

    # Remove locations outside capacity from mask
    mask1 = mask1 * (locations1 < capacity)
    mask2 = mask2 * (locations2 < capacity)

    # Store the capacity location for each token (integer slot index; dropped
    # tokens collapse to slot 0 but carry a zero gate weight below).
    locations1_s = jnp.sum(locations1 * mask1, axis=1).astype(jnp.int32)
    locations2_s = jnp.sum(locations2 * mask2, axis=1).astype(jnp.int32)

    # normalize gate probs over the experts that survived the capacity check
    gates1_s = jnp.einsum("se,se->s", gates, mask1)
    gates2_s = jnp.einsum("se,se->s", gates, mask2)
    denom_s = gates1_s + gates2_s
    # avoid divide by 0: clamp instead of adding a constant, so the weights of
    # a routed token still sum to 1
    denom_s = jnp.maximum(denom_s, jnp.finfo(denom_s.dtype).eps)
    gates1_s = gates1_s / denom_s
    gates2_s = gates2_s / denom_s

    # get combine_weights and dispatch_mask
    gates1 = jnp.einsum("s,se->se", gates1_s, mask1)
    gates2 = jnp.einsum("s,se->se", gates2_s, mask2)
    locations1_sc = jax.nn.one_hot(locations1_s, capacity)
    locations2_sc = jax.nn.one_hot(locations2_s, capacity)
    combine1_sec = jnp.einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = jnp.einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = jnp.where(combine_weights > 0, 1, 0)
    return combine_weights, dispatch_mask
    
# Wrap the gating function with jax.jit and run it on the group-local gates.
# Everything used as a shape inside Top2Gating (e.g. the one-hot width) has to
# be a static Python value at trace time.
Top2Gating_jit = jax.jit(Top2Gating)
combine_weights, dispatch_mask = Top2Gating_jit(gates)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The error arose in jax.nn.one_hot argument `num_classes`.
While tracing the function Top2Gating at /tmp/ipykernel_2310/2983666350.py:12 for jit, this value became a tracer due to JAX operations on these lines:

  operation a:f32[] = ceil b
    from line /tmp/ipykernel_2310/2983666350.py:7 (_capacity)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [4]:
## dispatched_expert_inputs = einsum(
##     "GSEC,GSM->EGCM", dispatch_mask, reshaped_inputs)
## convert to local: "SEC,SM->ECM"
print(local_inputs.shape)
# Uses `dispatch_mask` from the Top2Gating cell and `local_inputs` from the
# gates cell, so both must have run successfully before this one.
# Contracts over the token axis s, producing one (c, m) buffer per expert e.
dispatched_input = jnp.einsum("sec,sm->ecm", dispatch_mask, local_inputs)

(32, 128)
