In [None]:
# Install dependencies
!pip install transformers --quiet


In [None]:
!pip install -U datasets

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency r

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel, BertConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score


In [None]:
# Alternative task (previously kept as commented-out code): MNLI via
# load_dataset("glue", "mnli"), tokenizing (premise, hypothesis) pairs,
# evaluating on the "validation_matched" split, with a 3-way classifier head.

SEED = 42
MAX_LENGTH = 128        # tokens per example after truncation/padding
TRAIN_SUBSET = 10_000   # small training subset for a quick comparison
TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 16

# Seed the global torch RNG so the randomly initialised MoE/classifier weights
# created in later cells are reproducible on Restart & Run All.
torch.manual_seed(SEED)

# Load SST-2 from the GLUE benchmark
dataset = load_dataset("glue", "sst2")

# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_fn(example):
    """Tokenize a batch of SST-2 sentences into fixed-length (MAX_LENGTH) sequences."""
    return tokenizer(example["sentence"], truncation=True, padding="max_length", max_length=MAX_LENGTH)


encoded = dataset.map(tokenize_fn, batched=True)
encoded.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

train_loader = DataLoader(
    encoded["train"].select(range(TRAIN_SUBSET)),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),  # reproducible batch order
)
val_loader = DataLoader(encoded["validation"], batch_size=EVAL_BATCH_SIZE)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [None]:

# Expert block
class Expert(nn.Module):
    """Feed-forward expert: Linear(dim -> 4*dim) -> ReLU -> Linear(4*dim -> dim).

    The sub-layers live in ``self.ffn`` (an ``nn.Sequential``) at indices 0, 1, 2.
    """

    def __init__(self, dim):
        super().__init__()
        hidden_dim = 4 * dim
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the expert along the last dimension of ``x`` (size ``dim``)."""
        projected = self.ffn(x)
        return projected

# Shared Depth-Aware MoE
class SharedDepthAwareMoE(nn.Module):
    """One pool of experts shared by all encoder layers, with depth-conditioned routing.

    The router sees each token's hidden state concatenated with a learned embedding of
    the calling layer's index. That lets the same experts be routed differently at
    different depths. Each token goes to its ``top_k`` experts (by softmax
    probability), and their outputs are summed, weighted by those probabilities.

    NOTE(review): the top-k probabilities are not renormalised to sum to 1, and no
    load-balancing auxiliary loss is computed. Both are kept as in the original design.
    """

    def __init__(self, model_dim, num_experts, num_layers, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.layer_emb = nn.Embedding(num_layers, model_dim)
        self.router = nn.Linear(model_dim * 2, num_experts)
        self.top_k = top_k

    def forward(self, x, layer_id):
        """Mix expert outputs for ``x`` of shape (B, T, D) as called from layer ``layer_id``.

        Returns a tensor of shape (B, T, D).
        """
        B, T, D = x.shape
        # Row ``layer_id`` of the depth table, broadcast to every token position.
        depth_vec = self.layer_emb.weight[layer_id].expand(B, T, D)
        gate_logits = self.router(torch.cat([x, depth_vec], dim=-1))
        topk_vals, topk_idx = torch.topk(F.softmax(gate_logits, dim=-1), self.top_k, dim=-1)

        # Sparse dispatch: each expert runs once, only on the tokens routed to it,
        # instead of top_k * num_experts dense passes over the whole batch.
        flat_x = x.reshape(B * T, D)
        flat_idx = topk_idx.reshape(B * T, self.top_k)
        flat_w = topk_vals.reshape(B * T, self.top_k)
        out = torch.zeros_like(flat_x)
        for j, expert in enumerate(self.experts):
            # topk indices are distinct per token, so each token appears at most once here.
            token_pos, slot = (flat_idx == j).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue
            contrib = expert(flat_x[token_pos]) * flat_w[token_pos, slot].unsqueeze(-1)
            out = out.index_add(0, token_pos, contrib)
        return out.view(B, T, D)

# Switch-style MoE
class SwitchMoE(nn.Module):
    """Switch-Transformer-style top-1 mixture of feed-forward experts.

    Each token goes to the single expert with the highest router probability, and
    that expert's output is scaled by the probability. The scaling is what lets the
    router learn: ``argmax`` alone has no gradient, so an unscaled output would leave
    ``self.router`` at its random initialisation.

    NOTE(review): no auxiliary load-balancing loss is computed, so routing may
    collapse onto a few experts during training.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        # Parameters are placed eagerly so the module still works when it is not a
        # registered submodule of the model that later gets ``.to(device)``.
        # Falls back to CPU instead of crashing on machines without CUDA.
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)]).to(target)
        self.router = nn.Linear(model_dim, num_experts).to(target)

    def forward(self, x):
        """Route every position of ``x`` (last dim = model_dim) to its top-1 expert.

        Returns a tensor with the same shape as ``x``.
        """
        probs = F.softmax(self.router(x), dim=-1)
        gate, expert_idx = probs.max(dim=-1)

        # Flatten the leading dims so each expert runs only on the tokens routed to it.
        flat_x = x.reshape(-1, x.shape[-1])
        flat_idx = expert_idx.reshape(-1)
        flat_gate = gate.reshape(-1, 1)

        out = torch.zeros_like(flat_x)
        for j, expert in enumerate(self.experts):
            token_pos = (flat_idx == j).nonzero(as_tuple=True)[0]
            if token_pos.numel() == 0:
                continue
            out = out.index_add(0, token_pos, expert(flat_x[token_pos]) * flat_gate[token_pos])
        return out.view_as(x)

# Patch BERT layers
def patch_bert_with_moe(model, moe):
    """Replace each encoder layer's FFN with a call to the shared depth-aware ``moe``.

    Every layer in ``model.encoder.layer`` gets a new ``forward`` that:
    runs the layer's own attention, passes the attention output and the layer's index
    to ``moe``, and applies the layer's own ``output.LayerNorm`` to the residual sum.
    The original ``intermediate`` / ``output.dense`` sub-modules stay registered but
    are no longer called.

    ``moe`` is only called, not registered on ``model``; the caller must keep it
    somewhere its parameters are trained. Returns ``model`` (patched in place).
    """
    import types

    def _make_forward(layer_id):
        # A factory binds ``layer_id`` per layer. A closure over the loop variable
        # would see only its final value (Python closures bind late).
        def new_forward(self,
                        hidden_states,
                        attention_mask=None,
                        head_mask=None,
                        encoder_hidden_states=None,
                        encoder_attention_mask=None,
                        past_key_value=None,
                        output_attentions=False,
                        output_hidden_states=False,
                        return_dict=False, **kwargs):
            # ``self`` is the layer this method is bound to, not the last loop item.
            self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask, **kwargs)
            attention_output = self_attention_outputs[0]
            # Apply shared depth-aware MoE instead of intermediate+output FFN
            moe_out = moe(attention_output, layer_id=layer_id)
            output = self.output.LayerNorm(moe_out + attention_output)
            return (output,) + self_attention_outputs[1:]

        return new_forward

    for i, layer in enumerate(model.encoder.layer):
        layer.forward = types.MethodType(_make_forward(i), layer)
    return model


def patch_bert_with_local_moe(model, num_experts=4):
    """Give each encoder layer its own ``SwitchMoE`` in place of the FFN.

    For every layer in ``model.encoder.layer`` this:
    - attaches a fresh ``SwitchMoE`` as the submodule ``layer.moe``. Because it is
      registered, its parameters appear in ``model.parameters()`` (and so are trained
      by the optimizer), and ``model.to(device)`` moves them.
    - replaces ``forward`` with one that runs the layer's attention, its own
      ``self.moe``, and its own ``output.LayerNorm`` on the residual sum.

    The original ``intermediate`` / ``output.dense`` sub-modules stay registered but
    are no longer called. Returns ``model`` (patched in place).
    """
    import types

    # Bypass both intermediate and output projections
    def new_forward(self,
                    hidden_states,
                    attention_mask=None,
                    head_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    past_key_value=None,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=False, **kwargs):
        # Everything is read from ``self`` (the bound layer). Closing over the loop
        # variables would make every layer use the last layer's modules.
        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask, **kwargs)
        attention_output = self_attention_outputs[0]
        # Apply MoE instead of FFN
        moe_output = self.moe(attention_output)
        output = self.output.LayerNorm(moe_output + attention_output)
        return (output,) + self_attention_outputs[1:]

    for layer in model.encoder.layer:
        # nn.Module attribute assignment registers ``moe`` as a child of ``layer``.
        layer.moe = SwitchMoE(model.config.hidden_size, num_experts)
        layer.forward = types.MethodType(new_forward, layer)
    return model


# Model wrapper
class BertWithMoEClassifier(nn.Module):
    """Pretrained BERT whose encoder FFNs are replaced by MoE blocks, plus a linear head.

    Args:
        shared: if True, one ``SharedDepthAwareMoE`` (stored as ``self.moe``) serves all
            encoder layers. Otherwise each layer gets its own ``SwitchMoE``.
        num_experts: experts per MoE block.
        top_k: experts mixed per token. Only used when ``shared`` is True (the per-layer
            Switch variant is always top-1).
        num_labels: output classes of the linear head (2 for SST-2; MNLI would need 3).
        pretrained_name: checkpoint passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, shared=True, num_experts=4, top_k=2, num_labels=2,
                 pretrained_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.config = self.bert.config
        if shared:
            self.moe = SharedDepthAwareMoE(self.config.hidden_size, num_experts,
                                           self.config.num_hidden_layers, top_k)
            patch_bert_with_moe(self.bert, self.moe)
        else:
            patch_bert_with_local_moe(self.bert, num_experts)

        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch, num_labels) from the first-token ([CLS]) state."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls)

# Training and evaluation
def train_and_evaluate(model, optimizer, train_loader, val_loader, device, epochs=3):
    """Fine-tune ``model`` on ``train_loader``, then report accuracy on ``val_loader``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)`` that
            returns logits whose argmax over the last dim is the predicted class.
        optimizer: optimizer over the trainable parameters of ``model``.
        train_loader, val_loader: iterables of dict batches holding "input_ids",
            "attention_mask" and "label" tensors.
        device: torch device the model and batches are moved to.
        epochs: number of passes over ``train_loader`` (default 3).

    Prints the mean training loss per epoch and the validation accuracy.

    Returns:
        Validation accuracy as a float in [0, 1], or NaN if ``val_loader`` yields no
        examples.
    """
    model.to(device)

    def _unpack(batch):
        # Move the three fields the model and loss need onto ``device``.
        return (batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
                batch["label"].to(device))

    for epoch in range(epochs):
        model.train()
        epoch_loss, num_batches = 0.0, 0
        for batch in train_loader:
            input_ids, attention_mask, labels = _unpack(batch)
            optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = F.cross_entropy(logits, labels)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            num_batches += 1

        # Count batches instead of calling len(), which also avoids dividing by zero.
        mean_loss = epoch_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1} Loss: {mean_loss:.4f}")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attention_mask, labels = _unpack(batch)
            preds = model(input_ids=input_ids, attention_mask=attention_mask).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.numel()

    acc = correct / total if total else float("nan")
    print(f"Validation Accuracy: {acc:.4f}")
    return acc



In [None]:
# Pick the compute device once; both training runs below reuse it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torch.optim import AdamW

# Baseline: a separate Switch-style MoE in every encoder layer.
switch_model = BertWithMoEClassifier(shared=False)
switch_optim = AdamW(params=switch_model.parameters(), lr=2e-5)
switch_acc = train_and_evaluate(
    model=switch_model,
    optimizer=switch_optim,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
)

print(f"Switch MoE Accuracy: {switch_acc:.4f}")

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Epoch 1 Loss: 0.4820
Epoch 2 Loss: 0.3048
Epoch 3 Loss: 0.2244
Validation Accuracy: 0.7764
Switch MoE Accuracy: 0.7764


In [None]:
from torch.optim import AdamW

# Depth-aware variant: one MoE shared by all encoder layers, routed with a per-layer embedding.
depth_model = BertWithMoEClassifier(shared=True)
depth_optim = AdamW(params=depth_model.parameters(), lr=2e-5)
depth_acc = train_and_evaluate(
    model=depth_model,
    optimizer=depth_optim,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
)
print(f"Depth-Aware MoE Accuracy: {depth_acc:.4f}")



Epoch 1 Loss: 0.4847
Epoch 2 Loss: 0.3064
Epoch 3 Loss: 0.2252
Validation Accuracy: 0.7936
Depth-Aware MoE Accuracy: 0.7936
