<a href="https://colab.research.google.com/github/ashaduzzaman-sarker/Text-classification-Sentiment-Analysis/blob/main/Text_Sentiment_classification_with_Switch_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implement a Switch Transformer for text classification.

## Introduction
The [Switch Transformer](https://doi.org/10.48550/arXiv.2101.03961) is a variant of the Transformer architecture that increases model capacity without a matching increase in per-example computation. In brief:

- **Switch Transformer Overview:** A sparsely activated variant of the Transformer model. This notebook applies it to text classification (sentiment analysis of IMDB movie reviews).
  
- **Key Modification:** Replaces the standard feedforward network (FFN) layer in Transformers with a Mixture of Experts (MoE) routing layer.
  
- **MoE Layer:** Involves multiple experts (sub-networks) that independently process tokens, enabling an increase in model size without proportional increases in computation for each example.
  
- **Parallelism Requirement:** Efficient training requires data and model parallelism, with experts running on separate accelerators simultaneously.

- **Distributed Training:** The full implementation, as described in the original paper, uses TensorFlow Mesh for distributed training, though the provided example is a simpler, non-distributed version for demonstration purposes.
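
To make the routing idea in the list above concrete, here is a minimal sketch of top-1 routing in plain Python (no Keras required). The helper names `softmax` and `route_top1` and the logits are hypothetical and chosen only for illustration. Each token's router logits are turned into probabilities, and the token is sent to the single most probable expert, with that probability kept as the gate value.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def route_top1(router_logits):
    """Return (expert_index, gate) for one token's router logits."""
    probs = softmax(router_logits)
    expert_index = max(range(len(probs)), key=probs.__getitem__)
    return expert_index, probs[expert_index]


# Hypothetical router logits for three tokens and four experts.
token_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [0.0, 0.0, 3.0, 0.0],
    [-0.5, 1.5, 1.0, 0.2],
]
assignments = [route_top1(logits) for logits in token_logits]
for token, (expert, gate) in enumerate(assignments):
    print(f"token {token} -> expert {expert} (gate {gate:.3f})")
```

Because only one expert runs per token, adding experts grows the parameter count while the work done for each token stays roughly constant.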

[Mixture of Experts Explained...](https://huggingface.co/blog/moe)

![](https://production-media.paperswithcode.com/methods/a316fa39-5d0e-4058-88a0-31007cbbb44a.png)

## Imports


In [1]:
!pip install --upgrade keras tensorflow



In [2]:
import keras
from keras import ops
from keras import layers

## Download and prepare Dataset

In [3]:
vocab_size = 20000 # Only the 20k most frequent words are considered
num_tokens_per_example = 200 # Each review is padded or truncated to 200 tokens

(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)

print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17464789/17464789 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
25000 Training sequences
25000 Validation sequences


In [4]:
x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
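
With its default arguments (`padding="pre"`, `truncating="pre"`), `keras.utils.pad_sequences` pads short sequences on the left and, for long sequences, keeps the last `maxlen` tokens rather than the first. The sketch below mimics that behaviour for a single sequence in plain Python. `pad_pre` is a hypothetical helper written for illustration, not part of Keras, and it assumes `maxlen` is positive.

```python
def pad_pre(sequence, maxlen, value=0):
    """Mimic the pad_sequences defaults for one sequence.

    Keeps the last `maxlen` items and left-pads with `value`.
    Assumes maxlen > 0.
    """
    truncated = sequence[-maxlen:]
    return [value] * (maxlen - len(truncated)) + truncated


short_review = pad_pre([5, 6, 7], maxlen=5)
long_review = pad_pre([1, 2, 3, 4, 5, 6, 7], maxlen=5)
print(short_review)  # [0, 0, 5, 6, 7]
print(long_review)   # [3, 4, 5, 6, 7]
```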

## Define hyperparameters

In [5]:
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
num_experts = 10  # Number of experts used in MoE
batch_size = 50  # Batch size
learning_rate = 0.001  # Learning rate
num_epochs = 5  # Number of epochs
dropout_rate = 0.25  # Dropout rate
num_tokens_per_batch = (
    batch_size * num_tokens_per_example
)  # Total number of tokens per batch

print(f'Number of tokens per batch: {num_tokens_per_batch}')

Number of tokens per batch: 10000
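
The number of tokens per batch matters because it fixes how many tokens each expert may process. In the Switch Transformer paper, expert capacity is defined as (tokens per batch / number of experts) × capacity factor. The sketch below works through that arithmetic with this notebook's values. `expert_capacity_for` is a hypothetical helper, and the capacity factor of 1.25 is only an illustrative value. The Switch layer later in this notebook uses plain integer division, which matches a capacity factor of 1.

```python
def expert_capacity_for(tokens_per_batch, num_experts, capacity_factor=1.0):
    """Maximum number of tokens a single expert may process per batch."""
    return int(tokens_per_batch / num_experts * capacity_factor)


tokens_per_batch = 50 * 200  # batch_size * num_tokens_per_example
default_capacity = expert_capacity_for(tokens_per_batch, num_experts=10)
relaxed_capacity = expert_capacity_for(
    tokens_per_batch, num_experts=10, capacity_factor=1.25
)
print(default_capacity)  # 1000
print(relaxed_capacity)  # 1250
```

A capacity factor above 1 leaves headroom for uneven routing at the cost of extra computation and memory.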


## Implement Token & Position embedding layer

In [6]:
class TokenAndPositionEmbedding(layers.Layer):
  def __init__(self, maxlen, vocab_size, embed_dim):
    super().__init__()
    self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
    self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

  def call(self, x):
    maxlen = ops.shape(x)[-1]
    positions = ops.arange(start=0, stop=maxlen, step=1)
    positions = self.pos_emb(positions)
    x = self.token_emb(x)
    return x + positions

### Implement the feed forward network

In the Switch Transformer, this feedforward network serves as a single expert in the Mixture of Experts (MoE) layer.

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/moe/01_moe_layer.png)

In [7]:
def create_feedforward_network(ff_dim, embed_dim, name=None):
  return keras.Sequential(
      [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),], name=name
  )

### Implement the load-balanced loss

**Load-balanced loss** is a technique used in Mixture of Experts (MoE) models, such as the Switch Transformer, to ensure that the computational load is evenly distributed across the experts (sub-networks) during training. It is added as an auxiliary loss, so the router is penalized when it sends most tokens to only a few experts.

In [8]:
def load_balanced_loss(router_probs, expert_mask):
    # Get the number of experts from the last dimension of the expert mask
    num_experts = ops.shape(expert_mask)[-1]

    # Calculate the density, which is the mean of the expert mask across the batch dimension
    # This represents the actual load distribution across the experts
    density = ops.mean(expert_mask, axis=0)

    # Calculate the density proxy, which is the mean of the router probabilities across the batch
    # This represents the intended or predicted load distribution across the experts
    density_proxy = ops.mean(router_probs, axis=0)

    # Compute the load-balanced loss
    # The loss is calculated by taking the element-wise product of density and density proxy,
    # followed by averaging and scaling by the square of the number of experts
    loss = ops.mean(density_proxy * density) * ops.cast((num_experts ** 2), 'float32')

    # Return the calculated load-balanced loss
    return loss
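
To get a feel for the values this loss takes, the following plain-Python sketch repeats the same computation on two hand-made routing patterns. `load_balanced_loss_py` and the example inputs are hypothetical and exist only for this illustration. With perfectly uniform routing the loss is 1, and it grows as routing collapses onto one expert.

```python
def load_balanced_loss_py(router_probs, expert_mask):
    """Same formula as load_balanced_loss, written with Python lists."""
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])
    # Fraction of tokens actually dispatched to each expert
    density = [
        sum(row[e] for row in expert_mask) / num_tokens for e in range(num_experts)
    ]
    # Mean router probability assigned to each expert
    density_proxy = [
        sum(row[e] for row in router_probs) / num_tokens for e in range(num_experts)
    ]
    mean_product = sum(f * p for f, p in zip(density, density_proxy)) / num_experts
    return mean_product * num_experts**2


# Four tokens, four experts: one token per expert, uniform probabilities.
balanced = load_balanced_loss_py(
    router_probs=[[0.25, 0.25, 0.25, 0.25]] * 4,
    expert_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
)
# Four tokens all routed to expert 0 with probability 0.7.
collapsed = load_balanced_loss_py(
    router_probs=[[0.7, 0.1, 0.1, 0.1]] * 4,
    expert_mask=[[1, 0, 0, 0]] * 4,
)
print(f"balanced: {balanced:.2f}, collapsed: {collapsed:.2f}")
```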

### Implement the router as a layer

In [9]:
class Router(layers.Layer):
    def __init__(self, num_experts, expert_capacity):
        super().__init__()  # Initialize the base layer before assigning sub-layers
        self.num_experts = num_experts  # Number of experts available
        self.route = layers.Dense(units=num_experts)  # Dense layer to produce routing logits
        self.expert_capacity = expert_capacity  # Maximum capacity for each expert

    def call(self, inputs, training=False):
        # inputs shape: [tokens_per_batch, embed_dim]
        # router_logits shape: [tokens_per_batch, num_experts]
        router_logits = self.route(inputs)  # Compute routing logits for each token

        if training:
            # Add random noise during training to encourage exploration of different experts
            router_logits += keras.random.uniform(
                shape=router_logits.shape,
                minval=0.9,
                maxval=1.1,
            )

        # Convert logits to probabilities using softmax
        router_probs = keras.activations.softmax(router_logits, axis=-1)

        # Select the top expert for each token
        expert_gate, expert_index = ops.top_k(router_probs, k=1)

        # Create a binary mask indicating the selected expert for each token
        expert_mask = ops.one_hot(expert_index, self.num_experts)

        # Calculate the auxiliary load-balancing loss to distribute the load evenly across experts
        aux_loss = load_balanced_loss(router_probs, expert_mask)
        self.add_loss(aux_loss)  # Add the auxiliary loss to the layer's loss

        # Calculate the position of each token within its selected expert
        position_in_expert = ops.cast(ops.cumsum(expert_mask, axis=0) * expert_mask, 'int32')

        # Mask out tokens that exceed the expert's capacity
        expert_mask *= ops.cast(ops.less(ops.cast(position_in_expert, 'int32'), self.expert_capacity), 'float32')

        # Flatten the expert mask to determine if a token was assigned to an expert
        expert_mask_flat = ops.sum(expert_mask, axis=-1)

        # Adjust the gating values by the flattened expert mask
        expert_gate *= expert_mask_flat

        # Build the combine tensor [tokens_per_batch, num_experts, expert_capacity] from the
        # gating value, the selected expert, and the token's position within that expert
        combined_tensor = ops.expand_dims(
            expert_gate
            * expert_mask_flat
            * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
            -1,
        ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)

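        # The dispatch tensor has the same [tokens_per_batch, num_experts, expert_capacity]
        # layout; in this simplified version it is the combine tensor cast to float32
        dispatch_tensor = ops.cast(combined_tensor, 'float32')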

        return dispatch_tensor, combined_tensor  # Return the dispatch and combined tensors
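
The capacity check in the router can be hard to follow in tensor form. Below is a minimal plain-Python sketch of the same idea, assuming top-1 assignments are already known. `keep_within_capacity` and the example assignments are hypothetical. As in the layer above, the running position is 1-based (a cumulative sum) and is compared with a strict less-than, so a token whose position equals the capacity is dropped.

```python
def keep_within_capacity(expert_index, num_experts, capacity):
    """Return one boolean per token: True if the token is kept by its expert."""
    counts = [0] * num_experts
    kept = []
    for expert in expert_index:
        counts[expert] += 1        # 1-based position of this token within its expert
        position = counts[expert]
        kept.append(position < capacity)  # strict comparison, as with ops.less
    return kept


# Six tokens routed to two experts, with a capacity of 3.
kept = keep_within_capacity([0, 0, 1, 0, 1, 0], num_experts=2, capacity=3)
print(kept)  # [True, True, True, False, True, False]
```

Dropped tokens receive a gate of zero, so the expert contributes nothing for them.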


### Implement a Switch layer

In [10]:
class Switch(layers.Layer):
  def __init__(
      self,
      num_experts,
      embed_dim,
      ff_dim,
      num_tokens_per_batch,
      capacity_factor=1,
  ):
      super().__init__()
      self.num_experts = num_experts
      self.embed_dim = embed_dim
      self.num_tokens_per_batch = num_tokens_per_batch
      self.experts = [
          create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
      ]
      # Each expert processes at most this many tokens per batch
      # (capacity_factor is accepted but not applied in this simplified version)
      self.expert_capacity = num_tokens_per_batch // self.num_experts
      self.router = Router(self.num_experts, self.expert_capacity)

  def call(self, inputs):
      batch_size = ops.shape(inputs)[0]
      num_tokens_per_example = ops.shape(inputs)[1]

      # inputs shape: [num_tokens_per_batch, embed_dim]
      inputs = ops.reshape(inputs, [self.num_tokens_per_batch, self.embed_dim])

      # dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
      # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
      dispatch_tensor, combine_tensor = self.router(inputs)

      # expert_inputs shape: [num_experts, expert_capacity, embed_dim]
      expert_inputs = ops.einsum('ab,acd->cdb', inputs, dispatch_tensor)
      expert_inputs = ops.reshape(expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim])

      # Dispatch to experts
      expert_input_list = ops.unstack(expert_inputs, axis=0)
      expert_output_list = [
          self.experts[idx](expert_input)
          for idx, expert_input in enumerate(expert_input_list)
      ]

      # expert_outputs shape: [expert_capacity, num_experts, embed_dim]
      expert_outputs = ops.stack(expert_output_list, axis=1)

      # expert_outputs_combined shape: [tokens_per_batch, embed_dim]
      expert_outputs_combined = ops.einsum(
          'abc,xba->xc', expert_outputs, combine_tensor
      )

      # output shape: [batch_size, num_tokens_per_example, embed_dim]
      outputs = ops.reshape(
          expert_outputs_combined,
          [batch_size, num_tokens_per_example, self.embed_dim],
      )
      return outputs
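
The two einsum calls above implement a dispatch step (scatter tokens into per-expert buffers) and a combine step (gather expert outputs back into token order, scaled by the gate). The sketch below shows the same flow conceptually with scalar "tokens" in plain Python. `switch_forward`, the toy experts, and the inputs are hypothetical, and the sketch does not reproduce the layer's exact tensor arithmetic. A token that arrives after its expert is full gets an output of zero, and in the Transformer block its representation is then carried by the residual connection alone.

```python
def switch_forward(tokens, expert_index, gates, experts, capacity):
    """Dispatch each token to its expert, then combine outputs scaled by the gate."""
    outputs = [0.0] * len(tokens)
    load = [0] * len(experts)
    for t, (x, e, g) in enumerate(zip(tokens, expert_index, gates)):
        if load[e] < capacity:     # expert still has room for this token
            load[e] += 1
            outputs[t] = g * experts[e](x)
        # otherwise the token is dropped and its output stays 0.0
    return outputs


# Two toy experts operating on scalar tokens.
toy_experts = [lambda x: 2.0 * x, lambda x: x + 10.0]
result = switch_forward(
    tokens=[1.0, 2.0, 3.0, 4.0],
    expert_index=[0, 0, 1, 0],
    gates=[0.5, 1.0, 0.5, 1.0],
    experts=toy_experts,
    capacity=2,
)
print(result)  # [1.0, 4.0, 6.5, 0.0]
```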

### Implement a Transformer block layer

In [11]:
class TransformerBlock(layers.Layer):
  def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
    super().__init__()
    self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
    # ffn can be either a standard feedforward network or a switch layer with a Mixture of Experts
    self.ffn = ffn
    self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = layers.Dropout(dropout_rate)
    self.dropout2 = layers.Dropout(dropout_rate)

  def call(self, inputs, training=False):
    attn_output = self.att(inputs, inputs)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    ffn_output = self.dropout2(ffn_output, training=training)
    return self.layernorm2(out1 + ffn_output)

## Implement the classifier

- **TransformerBlock Layer**: Outputs one vector for each time step in the input sequence.
- **Mean Pooling**: Takes the average of all vectors across time steps.
- **Feedforward Network**: Applied on the mean-pooled vector to classify the text.

In [12]:
def create_classifier():
  switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
  transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)

  inputs = layers.Input(shape=(num_tokens_per_example,))
  embedding_layer = TokenAndPositionEmbedding(
      num_tokens_per_example, vocab_size, embed_dim
  )
  x = embedding_layer(inputs)
  x = transformer_block(x)
  x = layers.GlobalAveragePooling1D()(x)
  x = layers.Dropout(dropout_rate)(x)
  x = layers.Dense(ff_dim, activation="relu")(x)
  x = layers.Dropout(dropout_rate)(x)
  outputs = layers.Dense(2, activation="softmax")(x)

  classifier = keras.Model(inputs=inputs, outputs=outputs)
  return classifier

## Train and evaluate the model

In [13]:
def run_experiment(classifier):
  classifier.compile(
      optimizer=keras.optimizers.Adam(learning_rate),
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"],
  )
  history = classifier.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=num_epochs,
      validation_data=(x_val, y_val),
  )
  return history

classifier = create_classifier()
run_experiment(classifier)

Epoch 1/5
500/500 ━━━━━━━━━━━━━━━━━━━━ 2534s 5s/step - accuracy: 0.7058 - loss: 1.5540 - val_accuracy: 0.8753 - val_loss: 1.2915
Epoch 2/5
500/500 ━━━━━━━━━━━━━━━━━━━━ 2534s 5s/step - accuracy: 0.9215 - loss: 1.2154 - val_accuracy: 0.8691 - val_loss: 1.3053
Epoch 3/5
500/500 ━━━━━━━━━━━━━━━━━━━━ 2560s 5s/step - accuracy: 0.9584 - loss: 1.1268 - val_accuracy: 0.8654 - val_loss: 1.3515
Epoch 4/5
500/500 ━━━━━━━━━━━━━━━━━━━━ 2602s 5s/step - accuracy: 0.9738 - loss: 1.0825 - val_accuracy: 0.8545 - val_loss: 1.5022
Epoch 5/5
500/500 ━━━━━━━━━━━━━━━━━━━━ 2561s 5s/step - accuracy: 0.9833 - loss: 1.0536 - val_accuracy: 0.8486 - val_loss: 1.6242


<keras.src.callbacks.history.History at 0x7d3553321de0>