In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Load IMDB reviews as integer token ids (ids limited via num_words=vocab_size)
# and pad every review to a fixed length of num_tokens_per_example ids.
vocab_size=20000
num_tokens_per_example=200

(x_train,y_train),(x_val,y_val)=tf.keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train),'Training sentences')
print(len(x_val),'Validation sentences')

# Identical padding for both splits; padding='post' pads at the end of a review.
x_train,x_val=[
    tf.keras.preprocessing.sequence.pad_sequences(
        split,maxlen=num_tokens_per_example,padding='post')
    for split in (x_train,x_val)
]


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz


  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])


25000 Training sentences
25000 Validation sentences


In [None]:
# Model / training hyperparameters.
embed_dim=32          # token / position embedding width
num_heads=2           # attention heads in the Transformer block
ff_dim=32             # feed-forward width used in create_classifier
num_experts=10        # experts in the Switch (Mixture-of-Experts) layer
batch_size=50
learning_rate=0.001
dropout_rate=0.25     # dropout used in the classifier head
num_epochs=3

# The Switch layer reshapes each batch to exactly this many tokens.
num_tokens_per_batch=batch_size*num_tokens_per_example
print(f'Number of tokens per batch:{num_tokens_per_batch}')

Number of tokens per batch:10000


In [None]:
class TokenAndPositionEmbedding(layers.Layer):
  """Sum of a learned token embedding and a learned absolute-position embedding.

  Args:
    maxlen: number of rows in the position-embedding table.
    vocab_size: number of rows in the token-embedding table.
    embed_dim: width of both tables, and therefore of the output.
  """

  def __init__(self,maxlen,vocab_size,embed_dim):
    super().__init__()
    self.token_emb=layers.Embedding(vocab_size,embed_dim)
    self.pos_emb=layers.Embedding(maxlen,embed_dim)

  def call(self,x):
    # Look up one vector per position 0..seq_len-1 along the last axis; the
    # result broadcasts over the leading (batch) axes of the token vectors.
    seq_len=tf.shape(x)[-1]
    position_vectors=self.pos_emb(tf.range(seq_len))
    token_vectors=self.token_emb(x)
    return token_vectors+position_vectors

In [None]:
#Mixture of Experts
def create_feedforward_network(ff_dim,name=None):
  """Build a two-layer feed-forward network (used as one Mixture-of-Experts expert).

  Args:
    ff_dim: width of both Dense layers, so the output width is also ff_dim.
    name: optional name for the returned Sequential model.

  Returns:
    tf.keras.Sequential computing Dense(ff_dim)(relu(Dense(ff_dim)(x))).
  """
  hidden=layers.Dense(ff_dim,activation='relu')
  projection=layers.Dense(ff_dim)
  return tf.keras.Sequential([hidden,projection],name=name)

In [None]:
#This is an auxiliary loss to encourage a balanced load across experts
def load_balanced_loss(router_probs,expert_mask):
  """Auxiliary loss that encourages a balanced load across experts.

  Args:
    router_probs: router softmax probabilities; tokens on axis 0, experts on
      the last axis.
    expert_mask: one-hot expert assignment per token; tokens on axis 0,
      experts on the last axis.

  Returns:
    Scalar float32 tensor equal to
    num_experts**2 * mean(token_fraction_per_expert * mean_prob_per_expert),
    which is 1.0 when both per-expert quantities are uniform (1/num_experts).
  """
  # Fraction of tokens actually assigned to each expert.
  token_fraction=tf.reduce_mean(expert_mask,axis=0)
  # Mean router probability per expert (the differentiable counterpart).
  mean_prob=tf.reduce_mean(router_probs,axis=0)
  scale=tf.cast(tf.shape(expert_mask)[-1]**2,tf.float32)
  return tf.reduce_mean(mean_prob*token_fraction)*scale

In [None]:
class Router(layers.Layer):
  """Top-1 (Switch) router that assigns each token to one slot of one expert.

  Args:
    num_experts: number of experts a token can be routed to.
    expert_capacity: number of slots per expert; tokens arriving after an
      expert's slots are full are dropped (their routing weights become 0).
  """

  def __init__(self,num_experts,expert_capacity):
    # Initialise the base Layer before assigning attributes so Keras tracks
    # the Dense sublayer normally.
    super(Router,self).__init__()
    self.num_experts=num_experts
    self.route=layers.Dense(num_experts)
    self.expert_capacity=expert_capacity

  def call(self,inputs,training=False):
    """Route a flat batch of tokens.

    Args:
      inputs: [tokens_per_batch, embed_dim] token representations.
      training: when true, the router logits are jittered for exploration.

    Returns:
      (dispatch_tensor, combined_tensor), both float32 with shape
      [tokens_per_batch, num_experts, expert_capacity]. dispatch_tensor is a
      0/1 mask placing each kept token in its expert slot; combined_tensor has
      the same nonzero pattern, weighted by the chosen expert's probability.
    """
    # router_logits: [tokens_per_batch, num_experts]
    router_logits=self.route(inputs)

    if training:
      # Add noise for exploration across experts. tf.shape (instead of the
      # static .shape) also works when the token count is not known statically.
      router_logits+=tf.random.uniform(shape=tf.shape(router_logits),minval=0.9,maxval=1.1)

    router_probs=keras.activations.softmax(router_logits,axis=-1)
    # Top-1 routing: expert_gate and expert_index are [tokens_per_batch, 1].
    expert_gate,expert_index=tf.math.top_k(router_probs,k=1)
    # expert_mask: [tokens_per_batch, 1, num_experts]
    expert_mask=tf.one_hot(expert_index,depth=self.num_experts)
    # Load-balancing auxiliary loss, computed before capacity dropping.
    self.add_loss(load_balanced_loss(router_probs,expert_mask))

    # 0-based arrival order of each token inside its chosen expert. The
    # exclusive cumsum puts the first token in slot 0, so all expert_capacity
    # slots are usable (an inclusive cumsum starts at 1 and leaves slot 0
    # empty). Entries for non-chosen experts are zeroed by the mask.
    position_in_expert=tf.cast(
        tf.math.cumsum(expert_mask,axis=0,exclusive=True)*expert_mask,tf.int32)
    # Keep only tokens that fit within expert capacity.
    expert_mask*=tf.cast(tf.math.less(position_in_expert,self.expert_capacity),tf.float32)
    # [tokens_per_batch, 1]: 1.0 if the token kept its slot, 0.0 if dropped.
    expert_mask_flat=tf.reduce_sum(expert_mask,axis=-1)
    # Mask out tokens whose expert overflowed its capacity.
    expert_gate*=expert_mask_flat

    # [tokens_per_batch, num_experts, 1]: gate weight on the chosen expert only.
    expert_weights=tf.expand_dims(
        expert_gate*tf.squeeze(tf.one_hot(expert_index,depth=self.num_experts),1),-1)
    # [tokens_per_batch, num_experts, expert_capacity]: one-hot of the slot.
    slot_one_hot=tf.squeeze(tf.one_hot(position_in_expert,depth=self.expert_capacity),1)
    combined_tensor=expert_weights*slot_one_hot
    # Dispatch must be a binary mask: experts receive the raw token, and the
    # router probability is applied once, when outputs are combined.
    # (Gates of kept tokens are softmax probabilities, hence never 0.)
    dispatch_tensor=tf.cast(tf.not_equal(combined_tensor,0.0),tf.float32)

    return dispatch_tensor,combined_tensor
                               

In [None]:
# Toy routing example: 4 tokens routed to experts [1, 2, 1, 3] out of 4.
# expert_mask shape: (4, 4) -- one row per token, one column per expert.
expert_mask=tf.one_hot([1,2,1,3],depth=4)
print(expert_mask)

tf.Tensor(
[[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)


In [None]:
# Running count of tokens per expert, kept only where the token chose that
# expert: 1 for the first token an expert receives, 2 for the second, ...
position_in_expert=tf.multiply(tf.math.cumsum(expert_mask,axis=0),expert_mask)
print(position_in_expert)

tf.Tensor(
[[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 2. 0. 0.]
 [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)


In [None]:
# Toy capacity of 2. Positions above are 1-based, so only position 1 passes
# `< 2`: the third token (the second one routed to expert 1) is dropped, i.e.
# this check keeps at most capacity-1 tokens per expert.
expert_mask=expert_mask*tf.cast(
    tf.math.less(tf.cast(position_in_expert,tf.int32),2),tf.float32)
print(expert_mask)

tf.Tensor(
[[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)


In [None]:
# Per-token flag: 1.0 if the token kept a slot in its expert, 0.0 if dropped.
expert_mask_flat=tf.math.reduce_sum(expert_mask,axis=-1)
print(expert_mask_flat)

tf.Tensor([1. 1. 0. 1.], shape=(4,), dtype=float32)


In [None]:
class Switch(layers.Layer):
  """Switch (top-1 Mixture-of-Experts) replacement for a feed-forward layer.

  A Router assigns every token to one slot of one expert. Each expert is a
  feed-forward network applied to its slots. Expert outputs are weighted by
  the router probability and scattered back to their tokens. Tokens that
  overflow an expert's capacity get an all-zero output.

  Args:
    num_experts: number of feed-forward experts.
    embed_dim: token width; each expert maps embed_dim -> embed_dim.
    num_tokens_per_batch: batch_size * tokens_per_example. `call` reshapes its
      input to exactly this many tokens, so every batch must match it.
    capacity_factor: multiplier on the even per-expert share of tokens:
      expert_capacity = int(num_tokens_per_batch * capacity_factor) // num_experts.
      The default of 1 gives num_tokens_per_batch // num_experts.

  Raises:
    ValueError: if the resulting expert capacity is less than 1.
  """

  def __init__(self,num_experts,embed_dim,num_tokens_per_batch,capacity_factor=1):
    # Initialise the base Layer before assigning sublayers so Keras tracks them.
    super(Switch,self).__init__()
    self.num_experts=num_experts
    self.embed_dim=embed_dim
    self.experts=[
        create_feedforward_network(embed_dim) for _ in range(num_experts)
    ]
    self.expert_capacity=int(num_tokens_per_batch*capacity_factor)//self.num_experts
    if self.expert_capacity<1:
      raise ValueError(
          f'expert_capacity must be >= 1, got {self.expert_capacity} '
          f'(num_tokens_per_batch={num_tokens_per_batch}, '
          f'num_experts={num_experts}, capacity_factor={capacity_factor})')
    self.num_tokens_per_batch=num_tokens_per_batch
    self.router=Router(self.num_experts,self.expert_capacity)

  def call(self,inputs,training=None):
    """Apply the Mixture-of-Experts feed-forward to every token.

    Args:
      inputs: [batch_size, num_tokens_per_example, embed_dim], where
        batch_size * num_tokens_per_example == num_tokens_per_batch.
      training: forwarded to the router (enables logit jitter when true).

    Returns:
      Tensor with the same shape as `inputs`.
    """
    batch_size=tf.shape(inputs)[0]
    num_tokens_per_example=tf.shape(inputs)[1]

    # inputs: [num_tokens_per_batch, embed_dim]
    inputs=tf.reshape(inputs,[self.num_tokens_per_batch,self.embed_dim])
    # dispatch_tensor and combine_tensor:
    #   [num_tokens_per_batch, num_experts, expert_capacity]
    dispatch_tensor,combine_tensor=self.router(inputs,training=training)

    # expert_inputs: [num_experts, expert_capacity, embed_dim]
    expert_inputs=tf.einsum("ab,acd->cdb",inputs,dispatch_tensor)
    # Pin the static shape so tf.unstack knows the number of experts.
    expert_inputs=tf.reshape(expert_inputs,
                             [self.num_experts,self.expert_capacity,self.embed_dim])

    # Dispatch each capacity-sized slice to its own expert.
    expert_output_list=[
        expert(expert_input)
        for expert,expert_input in zip(self.experts,tf.unstack(expert_inputs,axis=0))
    ]

    # expert_outputs: [expert_capacity, num_experts, embed_dim]
    expert_outputs=tf.stack(expert_output_list,axis=1)
    # expert_outputs_combined: [num_tokens_per_batch, embed_dim]
    expert_outputs_combined=tf.einsum("abc,xba->xc",expert_outputs,combine_tensor)

    # outputs: [batch_size, num_tokens_per_example, embed_dim]
    return tf.reshape(expert_outputs_combined,
                      [batch_size,num_tokens_per_example,self.embed_dim])


In [None]:
# tf.unstack demo: axis=0 yields the 3 rows of a (3, 4) tensor, axis=1 its 4 columns.
x=tf.reshape(tf.range(12),shape=(3,4))
print('x:',x)
print(tf.unstack(value=x,axis=0))
print(tf.unstack(value=x,axis=1))

x: tf.Tensor(
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]], shape=(3, 4), dtype=int32)
[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 8,  9, 10, 11], dtype=int32)>]
[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 4, 8], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 5, 9], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 2,  6, 10], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 3,  7, 11], dtype=int32)>]


In [None]:
class TransformerBlock(layers.Layer):
  """Post-LayerNorm Transformer encoder block with a pluggable feed-forward part.

  Computes
    out1 = LayerNorm(inputs + Dropout(MultiHeadAttention(inputs, inputs)))
    out  = LayerNorm(out1 + Dropout(ffn(out1)))

  Args:
    embed_dim: passed to MultiHeadAttention as `key_dim`.
    num_heads: number of attention heads.
    ffn: layer applied after attention. It can be either a standard
      feed-forward network or a Switch layer with a Mixture of Experts; its
      output is added to its input (residual), so the shapes must be compatible.
    dropout_rate: rate for both Dropout layers.
  """

  def __init__(self,embed_dim,num_heads,ffn,dropout_rate=0.1):
    super(TransformerBlock,self).__init__()
    self.att=layers.MultiHeadAttention(num_heads=num_heads,key_dim=embed_dim)
    # The ffn can be either a standard feedforward network or a switch
    # layer with a Mixture of Experts.
    self.ffn=ffn
    self.layernorm1=layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2=layers.LayerNormalization(epsilon=1e-6)
    self.dropout1=layers.Dropout(dropout_rate)
    self.dropout2=layers.Dropout(dropout_rate)

  def call(self,inputs,training=None):
    """Apply the block.

    Args:
      inputs: sequence representations fed to self-attention.
      training: toggles dropout; defaults to None so Keras can supply it.
    """
    attn_output=self.att(inputs,inputs)
    attn_output=self.dropout1(attn_output,training=training)
    out1=self.layernorm1(inputs+attn_output)
    ffn_output=self.ffn(out1)
    ffn_output=self.dropout2(ffn_output,training=training)
    return self.layernorm2(out1+ffn_output)


In [None]:
def create_classifier():
  """Build the Switch-Transformer sentiment classifier.

  Architecture: token+position embedding -> one TransformerBlock whose
  feed-forward part is a Switch (Mixture-of-Experts) layer -> global average
  pooling -> dropout -> Dense(ff_dim, relu) -> dropout -> 2-way softmax.

  Reads the module-level hyperparameters defined earlier in the notebook.
  Because the Switch layer reshapes to a fixed num_tokens_per_batch, the model
  only accepts batches of exactly batch_size examples.

  Returns:
    An uncompiled tf.keras.Model mapping [batch, num_tokens_per_example]
    token ids to [batch, 2] class probabilities.
  """
  switch=Switch(num_experts,embed_dim,num_tokens_per_batch)
  # TransformerBlock's first parameter is embed_dim; pass embed_dim, not ff_dim,
  # so the attention size does not silently follow the classifier-head width.
  transformer_block=TransformerBlock(embed_dim,num_heads,switch)

  inputs=layers.Input(shape=(num_tokens_per_example,))
  embedding_layer=TokenAndPositionEmbedding(num_tokens_per_example,vocab_size,embed_dim)

  x=embedding_layer(inputs)
  x=transformer_block(x)
  x=layers.GlobalAveragePooling1D()(x)
  x=layers.Dropout(dropout_rate)(x)
  x=layers.Dense(ff_dim,activation='relu')(x)
  x=layers.Dropout(dropout_rate)(x)
  outputs=layers.Dense(2,activation='softmax')(x)

  classifier=tf.keras.Model(inputs=inputs,outputs=outputs)
  return classifier

In [None]:
def run_experiment(classifier):
  """Compile `classifier` with Adam and train it on the padded IMDB splits.

  Uses the module-level learning_rate, batch_size, num_epochs and the
  x_train/y_train (training) and x_val/y_val (validation) arrays.

  Returns:
    The History object returned by `classifier.fit`.
  """
  optimizer=tf.keras.optimizers.Adam(learning_rate)
  classifier.compile(
      optimizer=optimizer,
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
  )
  return classifier.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=num_epochs,
      validation_data=(x_val,y_val),
  )

In [None]:
# Build and train the classifier. Keep the History so the per-epoch
# loss/accuracy values can be inspected, and display them instead of the repr.
classifier=create_classifier()
history=run_experiment(classifier)
history.history

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f20d38adad0>