In [1]:
#Imports
import pandas as pd
import os
import tensorflow as tf

from utils.modelgenerator import *
from utils.modelhandler import *
from utils.datahandler import *

#Get data 
cwd = os.path.normpath(os.getcwd())
df = pd.read_csv(cwd+'/data/df_with_final_features.csv', index_col='Date') #df = pd.read_csv('user5.csv')
df.index = pd.to_datetime(df.index)
#df = df[['User5', 'temp', 'rhum']]
df.fillna(0, inplace=True)

df_array = []
for idx in range(2):
    df_array.append(df[[f'User{idx+1}', 'temp', 'rhum', 'wspd', 'PC1', 'hour sin', 'hour cos', f'User{idx+1}_lag_24hrs']])

#df_array[3].head(3)

#Train, Validation and Test datasets
sequence_length = 25
batch_size = 16
num_features = df_array[0].shape[1]

dh = Datahandler()

X_train, y_train, X_val, y_val, X_test, y_test = {}, {}, {}, {}, {}, {}

#Create Train, Validation and Test datasets
for idx, df in enumerate(df_array):
    n = len(df)
    train_df = df[0:int(n*0.7)]
    val_df = df[int(n*0.7):int(n*0.9)]
    test_df = df[int(n*0.9):]

    # Min max sclaing
    train_df = dh.min_max_scaling(train_df)
    val_df = dh.min_max_scaling(val_df)
    test_df = dh.min_max_scaling(test_df)

    # Sequencing
    train_sequences = dh.create_sequences(train_df, sequence_length)
    val_sequences = dh.create_sequences(val_df, sequence_length)
    test_sequences = dh.create_sequences(test_df, sequence_length)

    #Split into feature and label
    X_train[f'user{idx+1}'], y_train[f'user{idx+1}'] = dh.prepare_data(train_sequences, batch_size)
    X_val[f'user{idx+1}'], y_val[f'user{idx+1}'] = dh.prepare_data(val_sequences, batch_size)
    X_test[f'user{idx+1}'], y_test[f'user{idx+1}'] = dh.prepare_data(test_sequences, batch_size)

#General Hyperparameters
# #All models
horizon = 1
max_epochs = 100
m1 = ModelGenerator()
mh = Modelhandler()

loss = tf.keras.losses.MeanSquaredError()
metrics=[
    tf.keras.metrics.RootMeanSquaredError(), 
    tf.keras.metrics.MeanAbsolutePercentageError(),
    tf.keras.metrics.MeanAbsoluteError(),
]

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=10,mode='min')
timing_callback = TimingCallback()
custom_callback = CustomCallback()
#model_checkpoint = ModelCheckpoint('models/best_model.h5', save_best_only=True, monitor='val_loss', mode='min')
callbacks=[early_stopping, timing_callback, custom_callback] #model_checkpoint


# Dense Model

In [2]:
dense_results = pd.DataFrame(columns=['architecture', 'train_time', 'avg_time_epoch', 'mse','mse_std', 'rmse','rmse_std','mape','mape_std','mae','mae_std'])

In [3]:
def build_dense_model(X_train, horizon, num_layers, units, batch_size):

    input_data = layers.Input(shape=(X_train.shape[1], X_train.shape[2]), batch_size=batch_size) 
    x =  layers.Dense(units, activation='relu')(input_data)
    for _ in range(num_layers-1):
      x = layers.Dense(units, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Flatten()(x)
    output = layers.Dense(horizon)(x) 

    dense_model = tf.keras.Model(inputs=input_data, outputs=output, name="Dense_model")

    return dense_model

In [4]:
#Dense Model

#Dense Hyperparameter
dense_architecture = "L3_U16"
dense_layers = 3
dense_units = 16
dense_all_results = pd.DataFrame(columns=["user", "architecture", "train_time", "avg_time_epoch", "mse", "rmse", "mape", "mae"])

#For each of the 3 user
for idx in range(len(df_array)):
    print("User: ", idx+1)
    for round in range(3):
        #print("Round: ", round)
        dense_model = build_dense_model(X_train[f'user{idx+1}'], horizon, num_layers=dense_layers, units=dense_units, batch_size=batch_size)
        dense_histroy, dense_user_results = mh.compile_fit_evaluate_model(
            model=dense_model, 
            loss=loss, 
            metrics=metrics, 
            X_train=X_train[f'user{idx+1}'],
            y_train = y_train[f'user{idx+1}'], 
            max_epochs = max_epochs, 
            batch_size=batch_size, 
            X_val=X_val[f'user{idx+1}'], 
            y_val=y_val[f'user{idx+1}'], 
            X_test=X_test[f'user{idx+1}'], 
            y_test=y_test[f'user{idx+1}'], 
            callbacks=callbacks, 
            user=f'user{idx+1}', 
            hyper=dense_architecture,
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
        )
        # Add the 'architecture' column from dense_user_results to dense_results
        dense_all_results = pd.merge(dense_all_results, dense_user_results, how='outer')   

for idx in range(len(df_array)):
    new_row = {
        'architecture': dense_all_results["architecture"][0],
        'train_time': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["train_time"].mean(), 
        'avg_time_epoch' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["avg_time_epoch"].mean(),
        'mse': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mse"].mean(),
        'mse_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mse"].std(),
        'rmse': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["rmse"].mean(),
        'rmse_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["rmse"].std(),
        'mape': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mape"].mean(),
        'mape_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mape"].std(),
        'mae': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mae"].mean(),
        'mae_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mae"].std(),
    }
    dense_results.loc[len(dense_results)] = new_row

User:  1
User:  2


In [5]:
print("Mean total train time: ", dense_results["train_time"].mean())
print("Mean train time epoch: ", dense_results["avg_time_epoch"].mean())
print("Mean mse: ", dense_results["mse"].mean())
print("Mean rmse: ", dense_results["rmse"].mean())
print("Mean mape: ", dense_results["mape"].mean())
print("Mean mae: ", dense_results["mae"].mean())

Mean total train time:  8.16236412525177
Mean train time epoch:  0.34206747128921855
Mean mse:  0.02608363050967455
Mean rmse:  0.15948817382256192
Mean mape:  135305.29296875
Mean mae:  0.10177501415212949


# MoE top k

In [2]:
dense_moe_results = pd.DataFrame(columns=['architecture', 'train_time', 'avg_time_epoch', 'mse','mse_std', 'rmse','rmse_std','mape','mape_std','mae','mae_std'])

In [11]:
#Builds the expert models for the MoE Layer
def build_expert_network(expert_units):
    expert = keras.Sequential([
            layers.Dense(expert_units, activation="relu"), 
            ])
    return expert


#Builds a MoE model with top_k gating
def build_topk_dense_moe_model(X_train, batch_size, horizon, dense_units, num_experts, top_k, expert_units, m1):
    #Input of shape (batch_size, sequence_length, features)
    inputs = layers.Input(shape=(X_train.shape[1], X_train.shape[2]), batch_size=batch_size, name='input_layer') 
    x = layers.Dense(dense_units, activation="relu")(inputs)
    

    #EMBEDDED MOE LAYER
    # ROUTER
    router_inputs = inputs 
    print("router_inputs: ", router_inputs) #(16, 24, 8)
    router_probs = layers.Dense(num_experts, activation='softmax')(router_inputs)
    print("router_probs: ", router_probs) #(16, 24, 5)
    expert_gate, expert_index = tf.math.top_k(router_probs, k=top_k)
    print("expert_gate: ", expert_gate) #(16, 24, 2)
    print("expert_index: ", expert_index) #(16, 24, 2)
    expert_idx_mask = tf.one_hot(expert_index, depth=num_experts)
    print("expert_idx_mask: ", expert_idx_mask) #(16, 24, 2, 5)
    combined_tensor = tf.einsum('abc,abcd->abd', expert_gate, expert_idx_mask)
    print("combined_tensor: ", combined_tensor) #(16, 24, 5)

    #expert_inputs = tf.einsum("abc,abd->dabc", router_inputs, combined_tensor) # Instead of (3,4) -> (3, 16, 24, 4)
    #print("expert_inputs: ", expert_inputs) (5, 16, 24, 8)


    #expert_input_list = tf.unstack(expert_inputs, axis=0)
    #print("expert_input_list: ", expert_input_list) #[(16, 24, 8), (16, 24, 8), (16, 24, 8), (16, 24, 8), (16, 24, 8)]

    expert_inputs = tf.einsum("abc,abd->dac", router_inputs, combined_tensor) # Instead of (3,4) -> (3, 16, 24, 4)
    print("expert_inputs: ", expert_inputs) #(5, 8)
    
    expert_input_list = tf.unstack(expert_inputs, axis=0)
    print("expert_input_list: ", expert_input_list) #[(16, 8)]



    expert_output_list = [
            [build_expert_network(expert_units=expert_units) for _ in range(num_experts)][idx](expert_input)
            for idx, expert_input in enumerate(expert_input_list)
        ]
    
    expert_outputs = tf.stack(expert_output_list, axis=1)
    expert_outputs_combined = tf.einsum(
            "abcd,ace->acd", expert_outputs, combined_tensor #(16, 2, 24, 4) and (16, 24, 3)
        )    
    moe_output = expert_outputs_combined
    #END MOE LAYER

    #BOTTOM Model
    x = layers.Dropout(0.2)(moe_output)
    x = layers.Flatten()(x)
    outputs = layers.Dense(horizon)(x)
    topk_moe_model = models.Model(inputs=inputs, outputs=outputs, name="topk_moe")

    return topk_moe_model

In [12]:
mh = Modelhandler()
dense_moe_architecture = "top2_exp5_d8" #top2_exp5_d8
dense_units = 16

num_experts = 5
expert_units = 8
top_k = 1

dense_all_results = pd.DataFrame(columns=["user", "architecture", "train_time", "avg_time_epoch", "mse", "rmse", "mape", "mae"])
#For each of the 3 user
for idx in range(len(df_array)):
    print("User: ", idx+1)
    for round in range(1):
        #print("Round: ", round)
        dense_model = build_topk_dense_moe_model(X_train[f'user{idx+1}'], batch_size, horizon, dense_units, num_experts, top_k, expert_units, m1)
        dense_histroy, dense_user_results = mh.compile_fit_evaluate_model(
            model=dense_model, 
            loss=loss, 
            metrics=metrics, 
            X_train=X_train[f'user{idx+1}'],
            y_train = y_train[f'user{idx+1}'], 
            max_epochs = max_epochs, 
            batch_size=batch_size, 
            X_val=X_val[f'user{idx+1}'], 
            y_val=y_val[f'user{idx+1}'], 
            X_test=X_test[f'user{idx+1}'], 
            y_test=y_test[f'user{idx+1}'], 
            callbacks=callbacks, 
            user=f'user{idx+1}', 
            hyper=dense_moe_architecture,
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
        )
        # Add the 'architecture' column from dense_user_results to dense_results
        dense_all_results = pd.merge(dense_all_results, dense_user_results, how='outer')   
        mh.plot_model_predictions(dense_model, dense_histroy, y_test[f'user{idx+1}'], X_test[f'user{idx+1}'], batch_size)

for idx in range(len(df_array)):
    new_row = {
        'architecture': dense_all_results["architecture"][0],
        'train_time': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["train_time"].mean(), 
        'avg_time_epoch' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["avg_time_epoch"].mean(),
        'mse': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mse"].mean(),
        'mse_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mse"].std(),
        'rmse': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["rmse"].mean(),
        'rmse_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["rmse"].std(),
        'mape': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mape"].mean(),
        'mape_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mape"].std(),
        'mae': dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mae"].mean(),
        'mae_std' : dense_all_results[dense_all_results["user"]==f"user{idx+1}"]["mae"].std(),
    }
    dense_moe_results.loc[len(dense_moe_results)] = new_row


User:  1
router_inputs:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 8), dtype=tf.float32, name='input_layer'), name='input_layer', description="created by layer 'input_layer'")
router_probs:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 5), dtype=tf.float32, name=None), name='dense_44/Softmax:0', description="created by layer 'dense_44'")
expert_gate:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 2), dtype=tf.float32, name=None), name='tf.math.top_k_4/TopKV2:0', description="created by layer 'tf.math.top_k_4'")
expert_index:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 2), dtype=tf.int32, name=None), name='tf.math.top_k_4/TopKV2:1', description="created by layer 'tf.math.top_k_4'")
expert_idx_mask:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 2, 5), dtype=tf.float32, name=None), name='tf.one_hot_4/one_hot:0', description="created by layer 'tf.one_hot_4'")
combined_tensor:  KerasTensor(type_spec=TensorSpec(shape=(16, 24, 5), dtype=tf.float32, name=None), name='tf.eins

ValueError: Exception encountered when calling layer "tf.einsum_11" (type TFOpLambda).

Shape must be rank 4 but is rank 3
	 for 0th input and equation: abcd,ace->acd for '{{node tf.einsum_11/einsum/Einsum}} = Einsum[N=2, T=DT_FLOAT, equation="abcd,ace->acd"](Placeholder, Placeholder_1)' with input shapes: [16,5,8], [16,24,5].

Call arguments received by layer "tf.einsum_11" (type TFOpLambda):
  • equation='abcd,ace->acd'
  • inputs=('tf.Tensor(shape=(16, 5, 8), dtype=float32)', 'tf.Tensor(shape=(16, 24, 5), dtype=float32)')
  • kwargs=<class 'inspect._empty'>

In [18]:
#Test 1: 1 expert (16), top 1
print("Dense Model -----------------------------------")
print("Mean total train time: ", dense_results["train_time"].mean())
print("Mean train time epoch: ", dense_results["avg_time_epoch"].mean())
print("Mean mse: ", dense_results["mse"].mean())
print("Mean rmse: ", dense_results["rmse"].mean())
print("Mean mape: ", dense_results["mape"].mean())
print("Mean mae: ", dense_results["mae"].mean())

print("Mixture of Experts Dense Model -----------------------------------")
print("Mean total train time: ", dense_moe_results["train_time"].mean())
print("Mean train time epoch: ", dense_moe_results["avg_time_epoch"].mean())
print("Mean mse: ", dense_moe_results["mse"].mean())
print("Mean rmse: ", dense_moe_results["rmse"].mean())
print("Mean mape: ", dense_moe_results["mape"].mean())
print("Mean mae: ", dense_moe_results["mae"].mean())

Dense Model -----------------------------------
Mean total train time:  8.16236412525177
Mean train time epoch:  0.34206747128921855
Mean mse:  0.02608363050967455
Mean rmse:  0.15948817382256192
Mean mape:  135305.29296875
Mean mae:  0.10177501415212949
Mixture of Experts Dense Model -----------------------------------
Mean total train time:  26.797215541203816
Mean train time epoch:  0.6380768816364883
Mean mse:  0.02796006730447213
Mean rmse:  0.1648175356288751
Mean mape:  150801.8330078125
Mean mae:  0.1062221818914016


In [None]:
#[batch, group, experts, expert_capacity]
combined_tensor

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

   # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  hidden_output = mtf.layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      activation=mtf.relu, use_bias=False, master_dtype=master_dtype,
      slice_dtype=slice_dtype, name="expert0")
  expert_output = mtf.layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      use_bias=False, master_dtype=master_dtype, slice_dtype=slice_dtype,
      name="expert1")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

  # Reshape the combined tensor from inner level to now contain outer_batch_dim
  # a0 and group_dim g
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim
  # b1 and the original sequence length
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef

In [None]:
def _top_2_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, importance=None):
  
  # Inputs: [<batch_dims>, group_size_dim, input_dim]
  # expert_dim: a Dimension (the number of experts)
  # expert_capacity_dim: a Dimension (number of examples per group per expert)
  # hparams:
  # train:  boolean
  
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:] #Last two dimensions -> sequence_length, dense_units

  raw_gates = mtf.softmax(mtf.layers.dense(
      inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims), experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITON
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  density_1 = mtf.Print(
      density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
      "density_1", summarize=1000)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  policy = (
      hparams.moe_second_policy_train if train else
      hparams.moe_second_policy_eval)
  threshold = (
      hparams.moe_second_threshold_train if train else
      hparams.moe_second_threshold_eval)
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probablity min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)
  mask_2 = mtf.Print(
      mask_2, [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
      "density_2", summarize=1000)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
