<a href="https://colab.research.google.com/github/Saoudyahya/switchTransformer/blob/main/switchTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no dropout).

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        AssertionError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Kept on the instance so forward() never depends on a module-level
        # variable that happens to be called `hidden_size`.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        assert (
            self.head_dim * num_heads == hidden_size
        ), "hidden_size must be divisible by num_heads"

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.out_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, _ = x.size()

        # Linear projections, each (batch_size, seq_length, hidden_size)
        queries = self.q_linear(x)
        keys = self.k_linear(x)
        values = self.v_linear(x)

        # Split into heads: (batch_size, num_heads, seq_length, head_dim)
        split_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
        queries = queries.view(*split_shape).transpose(1, 2)
        keys = keys.view(*split_shape).transpose(1, 2)
        values = values.view(*split_shape).transpose(1, 2)

        # Scaled dot-product attention; scores are
        # (batch_size, num_heads, seq_length, seq_length)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, values)  # (batch_size, num_heads, seq_length, head_dim)

        # Merge heads back; contiguous() is required before view() after transpose
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        return self.out_linear(output)

class SwitchTransformer(nn.Module):
    """Softly gated mixture of self-attention experts.

    Every expert is evaluated on the full input and the results are blended
    per token with softmax gate probabilities (a dense mixture; no top-1
    routing is performed here).

    Args:
        num_experts: Number of attention experts.
        hidden_size: Size of the last dimension of the input and output.
        num_heads: Number of heads in each attention expert.
    """

    def __init__(self, num_experts, hidden_size, num_heads):
        super().__init__()
        self.num_experts = num_experts

        # Expert layers (each one is a self-attention module)
        self.experts = nn.ModuleList(
            [MultiHeadSelfAttention(hidden_size, num_heads) for _ in range(num_experts)]
        )

        # Gate: one logit per expert for every token
        self.gate = nn.Linear(hidden_size, num_experts)

    def forward(self, x):
        """Blend the expert outputs for ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length, hidden_size).
        """
        # Per-token probability distribution over experts: (batch, seq, num_experts)
        gate_probs = torch.softmax(self.gate(x), dim=-1)

        # (num_experts, batch, seq, hidden)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=0)

        # Weighted sum over experts only. The batch axis of expert_outputs must
        # share the label `b` with the gate; labelling it with a separate,
        # contracted index would broadcast against a size-1 axis and sum the
        # expert outputs over the whole batch, leaking data between samples.
        return torch.einsum('bsn,nbsh->bsh', gate_probs, expert_outputs)

# Example usage: run the mixture on random data and check the output shape
if __name__ == "__main__":
    # Demo configuration (hidden size matches BERT base)
    hidden_size, num_heads, num_experts = 768, 8, 4
    batch_size, seq_length = 2, 10

    model = SwitchTransformer(num_experts, hidden_size=hidden_size, num_heads=num_heads)

    # Random stand-in for real embeddings: (batch_size, seq_length, hidden_size)
    input_tensor = torch.rand((batch_size, seq_length, hidden_size))
    output = model(input_tensor)

    # Expected: (batch_size, seq_length, hidden_size)
    print("Output shape:", output.shape)


Output shape: torch.Size([2, 10, 768])


In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking, no dropout).

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        AssertionError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Kept on the instance so forward() never depends on a module-level
        # variable that happens to be called `hidden_size`.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        assert (
            self.head_dim * num_heads == hidden_size
        ), "hidden_size must be divisible by num_heads"

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.out_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length, hidden_size).
        """
        batch_size, seq_length, _ = x.size()

        # Linear projections, each (batch_size, seq_length, hidden_size)
        queries = self.q_linear(x)
        keys = self.k_linear(x)
        values = self.v_linear(x)

        # Split into heads: (batch_size, num_heads, seq_length, head_dim)
        split_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
        queries = queries.view(*split_shape).transpose(1, 2)
        keys = keys.view(*split_shape).transpose(1, 2)
        values = values.view(*split_shape).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, values)

        # Merge heads back; contiguous() is required before view() after transpose
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size)
        return self.out_linear(output)

class SwitchTransformer(nn.Module):
    """Softly gated mixture of self-attention experts with a vocabulary head.

    Every expert is evaluated on the full input, the results are blended per
    token with softmax gate probabilities, and the blend is projected to
    vocabulary logits.

    Args:
        num_experts: Number of attention experts.
        hidden_size: Size of the last dimension of the input.
        num_heads: Number of heads in each attention expert.
        vocab_size: Size of the output vocabulary. When omitted, falls back to
            ``tokenizer.vocab_size`` of a module-level ``tokenizer`` for
            backward compatibility; prefer passing it explicitly.
    """

    def __init__(self, num_experts, hidden_size, num_heads, vocab_size=None):
        super().__init__()
        self.num_experts = num_experts

        if vocab_size is None:
            # Legacy behaviour: rely on a global `tokenizer` having been
            # created before this class is instantiated.
            vocab_size = tokenizer.vocab_size

        # Expert layers (each one is a self-attention module)
        self.experts = nn.ModuleList(
            [MultiHeadSelfAttention(hidden_size, num_heads) for _ in range(num_experts)]
        )

        # Gate: one logit per expert for every token
        self.gate = nn.Linear(hidden_size, num_experts)
        # Output layer mapping hidden states to vocabulary logits
        self.output_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Compute vocabulary logits for ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length, vocab_size).
        """
        # Per-token probability distribution over experts: (batch, seq, num_experts)
        gate_probs = torch.softmax(self.gate(x), dim=-1)

        # (num_experts, batch, seq, hidden)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=0)

        # Weighted sum over experts only. The batch axis of expert_outputs must
        # share the label `b` with the gate; labelling it with a separate,
        # contracted index would broadcast against a size-1 axis and sum the
        # expert outputs over the whole batch, leaking data between samples.
        output = torch.einsum('bsn,nbsh->bsh', gate_probs, expert_outputs)

        # Project to vocabulary logits: (batch, seq, vocab_size)
        return self.output_layer(output)

# Example usage: feed BERT embeddings through the mixture and decode the argmax tokens
if __name__ == "__main__":
    # The pretrained checkpoint supplies both the tokenizer and the embeddings
    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformer_model = AutoModel.from_pretrained(model_name)

    # Demo configuration (hidden size matches BERT base)
    hidden_size, num_experts, num_heads = 768, 4, 8
    model = SwitchTransformer(num_experts, hidden_size=hidden_size, num_heads=num_heads)

    input_text = "The quick brown fox jumps over the lazy dog."
    tokens = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)

    # BERT is only used as a feature extractor here, so skip autograd for it
    with torch.no_grad():
        bert_output = transformer_model(**tokens)
        embeddings = bert_output.last_hidden_state  # (batch_size, seq_length, hidden_size)

    # Logits over the vocabulary: (batch_size, seq_length, vocab_size)
    output = model(embeddings)

    # Greedy decode: most likely token id at every position
    decoded_ids = output.argmax(dim=-1)
    first_sequence_ids = decoded_ids[0].tolist()
    decoded_output = tokenizer.decode(first_sequence_ids, skip_special_tokens=True)

    print("Output shape:", output.shape)
    # NOTE: the SwitchTransformer is never trained in this script, so the
    # decoded tokens are not expected to form meaningful text.
    print("Decoded output:", decoded_output)


Output shape: torch.Size([1, 12, 30522])
Decoded output: usa articles usa usa usa usa articles articles usa usa regiment regiment
