In [None]:
from google.colab import drive

# Mount Google Drive so later cells can read/write files under /content/drive/MyDrive.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install -U datasets



In [None]:

from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from transformers import DataCollatorWithPadding
from sklearn.metrics import accuracy_score

# 1. Load Dataset (GLUE MNLI)
dataset = load_dataset(path="glue", name="mnli")

# 2. Tokenizer matching the bert-base-uncased checkpoint used by the model
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

def preprocess(example):
    """Tokenize MNLI rows as (premise, hypothesis) sentence pairs.

    ``example`` is the batch dict handed over by ``dataset.map(..., batched=True)``;
    over-long pairs are truncated (``truncation=True``).
    """
    premises = example["premise"]
    hypotheses = example["hypothesis"]
    return tokenizer(premises, hypotheses, truncation=True)

# Tokenize every split, rename the target column to "labels" and expose the
# model inputs as torch tensors.
encoded_dataset = dataset.map(preprocess, batched=True).rename_column("label", "labels")
encoded_dataset.set_format(
    type="torch",
    columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from sklearn.metrics import accuracy_score

class SharedMoE(nn.Module):
    """Top-1 (Switch-style) mixture of FFN experts, meant to be shared by all encoder layers.

    Each token goes to the expert with the highest router probability. That
    expert's output is scaled by its softmax gate probability. Without that
    scaling (a bare ``argmax``), no gradient reaches the router and it never trains.

    Args:
        model_dim: hidden size D of the incoming activations.
        num_experts: number of ``Expert`` FFNs in the pool.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.router = nn.Linear(model_dim, num_experts)
        # Keep the router on the same device as the experts (``Expert`` places itself)
        # instead of hard-coding CUDA, which fails on CPU-only runtimes.
        expert_device = next((p.device for p in self.experts.parameters()), None)
        if expert_device is not None:
            self.router.to(expert_device)

    def forward(self, x):
        """Route each token of ``x`` (B, T, D) to one expert; returns a (B, T, D) tensor."""
        B, T, D = x.shape
        x_flat = x.reshape(-1, D)  # (B*T, D); reshape also copes with non-contiguous input
        gate_probs = F.softmax(self.router(x_flat), dim=-1)  # (B*T, E)
        top_prob, indices = gate_probs.max(dim=-1)  # (B*T,), (B*T,)
        output = torch.zeros_like(x_flat)

        for i, expert in enumerate(self.experts):
            mask = indices == i
            if mask.any():
                selected = x_flat[mask]  # only tokens routed to expert i
                expert_out = expert(selected) * top_prob[mask].unsqueeze(-1)
                output[mask] = expert_out.to(output.dtype)

        return output.view(B, T, D)


def patch_bert_with_shared_pool_moe(model, shared_moe):
    """Replace every encoder layer's FFN with one MoE module shared by all layers.

    Each layer in ``model.encoder.layer`` gets a new ``forward`` that:
    - runs that layer's own self-attention;
    - passes the attention output through ``shared_moe``;
    - returns ``(layer.output.LayerNorm(moe_out + attention_output),)`` followed
      by any extra attention outputs.

    The layer's original intermediate/output dense projections are bypassed.

    Args:
        model: BERT-style model exposing ``encoder.layer``; each layer must have
            ``.attention`` and ``.output.LayerNorm``.
        shared_moe: callable mapping the attention output to a same-shaped tensor.
    """
    import types

    def new_forward(self,
                    hidden_states,
                    attention_mask=None,
                    head_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    past_key_value=None,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=False, **kwargs):
        # Use ``self`` (the bound layer), never the loop variable: a closure over
        # ``layer`` is late-bound, so every layer would run the LAST layer's weights.
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask,
            output_attentions=output_attentions, **kwargs)
        attention_output = self_attention_outputs[0]
        moe_out = shared_moe(attention_output)
        output = self.output.LayerNorm(moe_out + attention_output)
        return (output,) + self_attention_outputs[1:]

    for layer in model.encoder.layer:
        layer.forward = types.MethodType(new_forward, layer)


# FFN Expert
class Expert(nn.Module):
    """Feed-forward expert: Linear(dim, 4*dim) -> ReLU -> Linear(4*dim, dim).

    Args:
        dim: input/output feature size.
        device: where to place the weights. Defaults to CUDA when it is
            available and CPU otherwise; the previous hard-coded CUDA placement
            crashed on CPU-only runtimes.
    """

    def __init__(self, dim, device=None):
        super().__init__()
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        ).to(device)

    def forward(self, x):
        """Apply the FFN over the last dimension of ``x``; the output has the same shape as ``x``."""
        return self.ffn(x)

# Shared Depth-Aware MoE
class SharedDepthAwareMoE(nn.Module):
    """Top-k mixture of FFN experts shared by all layers, with depth-conditioned routing.

    The router scores each token from the concatenation of the token vector and
    a learned embedding of the caller's ``layer_id``. This lets one expert pool
    route differently at different depths.

    Each token's output is the sum of its ``top_k`` experts' outputs, weighted by
    their softmax gate probabilities (not renormalised over the chosen k).

    Args:
        model_dim: hidden size D of the activations.
        num_experts: number of experts E in the shared pool.
        num_layers: number of valid ``layer_id`` values (rows of the depth embedding).
        top_k: experts combined per token.
    """

    def __init__(self, model_dim, num_experts, num_layers, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.layer_emb = nn.Embedding(num_layers, model_dim)
        self.router = nn.Linear(model_dim * 2, num_experts)
        self.top_k = top_k

    def forward(self, x, layer_id):
        """Return the depth-aware MoE output for ``x`` (B, T, D) at depth ``layer_id``; shape (B, T, D)."""
        B, T, D = x.shape
        depth_vec = self.layer_emb(torch.tensor(layer_id, device=x.device)).unsqueeze(0).unsqueeze(1).expand(B, T, -1)
        router_input = torch.cat([x, depth_vec], dim=-1)  # (B, T, 2D)
        gate_logits = self.router(router_input)  # (B, T, E)
        topk_vals, topk_idx = torch.topk(F.softmax(gate_logits, dim=-1), self.top_k, dim=-1)

        x_flat = x.reshape(-1, D)                      # (N, D), N = B*T
        vals_flat = topk_vals.reshape(-1, self.top_k)  # (N, k)
        idx_flat = topk_idx.reshape(-1, self.top_k)    # (N, k)
        out = torch.zeros_like(x_flat)
        # Run each expert only on the tokens routed to it. The previous version ran
        # every expert on the whole zero-masked batch once per k (E*k dense FFN
        # passes per call), and most of those results were multiplied by zero.
        for i in range(self.top_k):
            for j, expert in enumerate(self.experts):
                rows = (idx_flat[:, i] == j).nonzero(as_tuple=True)[0]
                if rows.numel() == 0:
                    continue
                contrib = expert(x_flat[rows]) * vals_flat[rows, i].unsqueeze(-1)
                out.index_add_(0, rows, contrib.to(out.dtype))
        return out.view(B, T, D)

# Local top-1 routing (Switch-style)
class SwitchMoE(nn.Module):
    """Per-layer top-1 (Switch-style) mixture of FFN experts.

    Each token goes to the expert with the highest router probability, and that
    expert's output is scaled by its softmax gate probability. The previous
    version used a bare ``argmax`` with no scaling, which blocks all gradient
    to the router, so the router never trained.

    Args:
        model_dim: hidden size D.
        num_experts: number of experts E.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.router = nn.Linear(model_dim, num_experts)

    def forward(self, x):
        """Route each token of ``x`` (..., D) to one expert; returns a tensor shaped like ``x``."""
        D = x.shape[-1]
        x_flat = x.reshape(-1, D)
        gate_probs = F.softmax(self.router(x_flat), dim=-1)  # (N, E)
        top_prob, indices = gate_probs.max(dim=-1)  # (N,), (N,)
        out = torch.zeros_like(x_flat)
        # Dispatch only the routed tokens to each expert. The previous version ran
        # every expert on the whole zero-masked batch.
        for j, expert in enumerate(self.experts):
            rows = (indices == j).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            contrib = expert(x_flat[rows]) * top_prob[rows].unsqueeze(-1)
            out.index_add_(0, rows, contrib.to(out.dtype))
        return out.view_as(x)

# Local top-k routing
class TopKMoE(nn.Module):
    """Per-layer top-k mixture of FFN experts.

    Each token is sent to its ``top_k`` highest-probability experts. Their
    outputs are summed, weighted by the softmax gate probabilities (not
    renormalised over the chosen k).

    Args:
        model_dim: hidden size D.
        num_experts: number of experts E.
        top_k: experts per token (``torch.topk`` requires top_k <= num_experts).
    """

    def __init__(self, model_dim, num_experts, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.router = nn.Linear(model_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        """Mix expert outputs for ``x`` (..., D); returns a tensor shaped like ``x``."""
        D = x.shape[-1]
        x_flat = x.reshape(-1, D)  # (N, D)
        gate_probs = F.softmax(self.router(x_flat), dim=-1)  # (N, E)
        topk_vals, topk_idx = torch.topk(gate_probs, self.top_k, dim=-1)  # (N, k)
        out = torch.zeros_like(x_flat)
        # Dispatch only the routed tokens to each expert. The previous version ran
        # every expert on the whole zero-masked batch once per k (E*k dense FFN passes).
        for i in range(self.top_k):
            for j, expert in enumerate(self.experts):
                rows = (topk_idx[:, i] == j).nonzero(as_tuple=True)[0]
                if rows.numel() == 0:
                    continue
                contrib = expert(x_flat[rows]) * topk_vals[rows, i].unsqueeze(-1)
                out.index_add_(0, rows, contrib.to(out.dtype))
        return out.view_as(x)

def patch_bert_with_shared_moe(model, moe_layers):
    """Replace each encoder layer's FFN with a depth-aware MoE call.

    Layer ``i`` of ``model.encoder.layer`` gets a new ``forward`` that:
    - runs that layer's own self-attention;
    - calls ``moe_layers[i](attention_output, layer_id=i)``;
    - returns ``(layer.output.LayerNorm(moe_out + attention_output),)`` followed
      by any extra attention outputs.

    Args:
        model: BERT-style model exposing ``encoder.layer``; each layer must have
            ``.attention`` and ``.output.LayerNorm``.
        moe_layers: indexable with one entry per layer (entries may all be the
            same shared module); each is called as ``moe(x, layer_id=i)``.
    """
    import types

    def make_forward(moe, layer_id):
        # Factory binds ``moe``/``layer_id`` per layer. A closure over the loop
        # variables is late-bound, so every layer saw only the final loop values
        # (the last layer_id, and the last layer's attention via ``layer``).
        def new_forward(self,
                        hidden_states,
                        attention_mask=None,
                        head_mask=None,
                        encoder_hidden_states=None,
                        encoder_attention_mask=None,
                        past_key_value=None,
                        output_attentions=False,
                        output_hidden_states=False,
                        return_dict=False, **kwargs):
            self_attention_outputs = self.attention(
                hidden_states, attention_mask, head_mask,
                output_attentions=output_attentions, **kwargs)
            attention_output = self_attention_outputs[0]
            moe_out = moe(attention_output, layer_id=layer_id)
            output = self.output.LayerNorm(moe_out + attention_output)
            return (output,) + self_attention_outputs[1:]
        return new_forward

    for i, layer in enumerate(model.encoder.layer):
        layer.forward = types.MethodType(make_forward(moe_layers[i], i), layer)


def patch_bert_with_local_moe(model, moe_layers):
    """Replace each encoder layer's FFN with that layer's own MoE module.

    Layer ``i`` of ``model.encoder.layer`` gets a new ``forward`` that:
    - runs that layer's own self-attention;
    - passes the attention output through ``moe_layers[i]``;
    - returns ``(layer.output.LayerNorm(moe_output + attention_output),)``
      followed by any extra attention outputs.

    Args:
        model: BERT-style model exposing ``encoder.layer``; each layer must have
            ``.attention`` and ``.output.LayerNorm``.
        moe_layers: indexable with one callable per layer, each mapping the
            attention output to a same-shaped tensor.
    """
    import types

    def make_forward(moe):
        # Factory so each layer keeps its own ``moe``. A closure over the loop
        # variable is late-bound, so every layer used the last MoE (and, via
        # ``layer``, the last layer's attention and LayerNorm).
        def new_forward(self,
                        hidden_states,
                        attention_mask=None,
                        head_mask=None,
                        encoder_hidden_states=None,
                        encoder_attention_mask=None,
                        past_key_value=None,
                        output_attentions=False,
                        output_hidden_states=False,
                        return_dict=False, **kwargs):
            self_attention_outputs = self.attention(
                hidden_states, attention_mask, head_mask,
                output_attentions=output_attentions, **kwargs)
            attention_output = self_attention_outputs[0]
            moe_output = moe(attention_output)
            output = self.output.LayerNorm(moe_output + attention_output)
            return (output,) + self_attention_outputs[1:]
        return new_forward

    for i, layer in enumerate(model.encoder.layer):
        layer.forward = types.MethodType(make_forward(moe_layers[i]), layer)


class BertWithMoEClassifier(nn.Module):
    """Sequence-pair classifier: BERT-base whose encoder FFNs are replaced by MoE modules.

    The pretrained ``bert-base-uncased`` encoder is patched in place so every
    encoder layer sends its attention output through a mixture of experts. The
    position-0 ([CLS]) hidden state of the last layer feeds a linear head.

    Args:
        moe_type: one of
            - 'shared': one depth-aware top-k pool used by all layers;
            - 'top1': one Switch-style MoE per layer;
            - 'topk': one top-k MoE per layer;
            - 'shared_pool': one top-1 pool used by all layers.
        num_experts: experts per MoE module.
        top_k: experts per token for 'shared' and 'topk'; ignored otherwise.
        num_labels: number of output classes.

    Raises:
        ValueError: for an unknown ``moe_type``.
    """

    def __init__(self, moe_type='shared', num_experts=4, top_k=1, num_labels=3):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.config = self.bert.config
        self.moe_type = moe_type
        self.num_layers = self.config.num_hidden_layers
        # Holds every MoE module so its parameters are trained, saved and moved with the model.
        self.moe_layers = nn.ModuleList()
        self.num_labels = num_labels

        if moe_type == 'shared':
            shared_moe = SharedDepthAwareMoE(
                model_dim=self.config.hidden_size,
                num_experts=num_experts,
                num_layers=self.num_layers,
                top_k=top_k
            )
            # NOTE(review): the same module is appended once per layer, so
            # state_dict() lists its tensors under every index. Checkpoint
            # formats that reject shared tensors may fail; confirm.
            for _ in range(self.num_layers):
                self.moe_layers.append(shared_moe)
            patch_bert_with_shared_moe(self.bert, self.moe_layers)
        elif moe_type == 'top1':
            for _ in range(self.num_layers):
                self.moe_layers.append(SwitchMoE(self.config.hidden_size, num_experts))
            patch_bert_with_local_moe(self.bert, self.moe_layers)
        elif moe_type == 'topk':
            for _ in range(self.num_layers):
                self.moe_layers.append(TopKMoE(self.config.hidden_size, num_experts, top_k))
            patch_bert_with_local_moe(self.bert, self.moe_layers)
        elif moe_type == 'shared_pool':
            shared_moe = SharedMoE(
                model_dim=self.config.hidden_size,
                num_experts=num_experts,
            )
            # Register the pool (once). Previously it was not a submodule, so its
            # experts and router were missing from parameters()/state_dict():
            # never updated by the optimizer, never saved, not moved by .to(device).
            self.moe_layers.append(shared_moe)
            patch_bert_with_shared_pool_moe(self.bert, shared_moe)
        else:
            raise ValueError("Invalid moe_type. Choose from: 'shared', 'top1', 'topk', 'shared_pool'")

        self.classifier = nn.Linear(self.config.hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: token ids, passed to ``self.bert``.
            attention_mask: padding mask, passed to ``self.bert``.
            labels: optional class targets. When given, cross-entropy loss is computed.
            token_type_ids: optional segment ids, passed to ``self.bert``.
                Declaring this parameter matters for the Hugging Face Trainer:
                with ``remove_unused_columns`` it drops dataset columns that are
                not ``forward`` parameters. Without it, sentence pairs fell back
                to BERT's default segment ids.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``
            of shape (batch, num_labels).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        cls = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls)

        if labels is None:
            # Inference/evaluation without targets: return the bare logits tensor.
            return logits
        loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return (loss, logits)


# # Final wrapper
# class BertWithMoEClassifier(nn.Module):
#     def __init__(self, moe_type='shared', num_experts=4, top_k=2):
#         super().__init__()
#         self.bert = BertModel.from_pretrained("bert-base-uncased")
#         self.config = self.bert.config
#         if moe_type == 'shared':
#             moe = SharedDepthAwareMoE(self.config.hidden_size, num_experts, self.config.num_hidden_layers, top_k)
#             patch_bert_with_shared_moe(self.bert, moe)
#         elif moe_type == 'top1':
#             patch_bert_with_local_moe(self.bert, SwitchMoE, num_experts)
#         elif moe_type == 'topk':
#             patch_bert_with_local_moe(self.bert, TopKMoE, num_experts, top_k)
#         else:
#             raise ValueError("Invalid moe_type. Choose from: 'shared', 'top1', 'topk'")

#         self.classifier = nn.Linear(self.config.hidden_size, 2)

#     def forward(self, input_ids, attention_mask):
#         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
#         cls = outputs.last_hidden_state[:, 0, :]
#         return self.classifier(cls)


In [None]:
# One 48-expert top-1 pool shared by every encoder layer.
model = BertWithMoEClassifier(
    moe_type='shared_pool',
    num_experts=48,
    top_k=1,  # not passed on by the 'shared_pool' branch (its router is always top-1)
)

In [None]:
import os

# Checkpoint directory for the Trainer below, on the mounted Google Drive.
os.makedirs(name="/content/drive/MyDrive/bert-base-moe-mnli-shared", exist_ok=True)

# 4. Metrics
def compute_metrics(eval_pred):
    """Return ``{"accuracy": ...}`` for an evaluation ``(logits, labels)`` pair.

    The predicted class is the arg-max over the last axis of the logits.
    """
    prediction_logits, reference_labels = eval_pred
    predicted_classes = torch.as_tensor(prediction_logits).argmax(dim=-1)
    return {"accuracy": accuracy_score(reference_labels, predicted_classes)}

# 5. Data Collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 6. Training Arguments
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/bert-base-moe-mnli-shared",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)

# 7. Trainer
# `processing_class` replaces Trainer's deprecated `tokenizer=` argument
# (presumably the source of the warning echoed under this cell).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)



  trainer = Trainer(


In [None]:
# Fine-tune from scratch (no checkpoint resume); the cell displays the TrainOutput.
trainer.train(resume_from_checkpoint=None)

Epoch,Training Loss,Validation Loss


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
import os

# Directory on the mounted Drive that the weight-saving cell below writes into.
os.makedirs(name="/content/drive/MyDrive/bert-base-mnli-moe-model", exist_ok=True)

In [None]:
# Persist only the model's weights (its state dict) as sharedpool.pt.
torch.save(obj=model.state_dict(), f="/content/drive/MyDrive/bert-base-mnli-moe-model/sharedpool.pt")



In [None]:
# Load the saved state dict onto the CPU first so this also works on CPU-only
# runtimes; model.load_state_dict copies the tensors onto the model's device.
# weights_only=True avoids unpickling arbitrary objects from the file.
# NOTE(review): the save cell above writes sharedpool.pt but this reads model.pt;
# confirm which checkpoint is intended before trusting the evaluation below.
state = torch.load(
    "/content/drive/MyDrive/bert-base-mnli-moe-model/model.pt",
    map_location="cpu",
    weights_only=True,
)

In [None]:
# Strict load: every key must match; the result is displayed as the cell output.
model.load_state_dict(state_dict=state, strict=True)

<All keys matched successfully>

In [None]:
# Rebuild the Trainer around the (re)loaded `model` for evaluation.
# `processing_class` replaces Trainer's deprecated `tokenizer=` argument.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

  trainer = Trainer(


In [None]:
# Evaluate on the Trainer's eval_dataset; metric keys are prefixed with "eval_".
trainer.evaluate(metric_key_prefix="eval")

{'eval_loss': 0.8428729176521301,
 'eval_model_preparation_time': 0.0065,
 'eval_accuracy': 0.6197656647987774,
 'eval_runtime': 34.6588,
 'eval_samples_per_second': 283.189,
 'eval_steps_per_second': 4.443}

In [None]:
# BertWithMoEClassifier is a plain nn.Module, so it has no `save_pretrained`
# (the original call raised AttributeError). Save its state dict plus the BERT
# backbone config instead; rebuilding it requires the MoE class definitions.
import os

MOE_MODEL_DIR = "/content/drive/MyDrive/bert-base-mnli-moe-model"
os.makedirs(MOE_MODEL_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(MOE_MODEL_DIR, "moe_classifier_state_dict.pt"))
model.config.save_pretrained(MOE_MODEL_DIR)
tokenizer.save_pretrained("/content/drive/MyDrive/bert-base-mnli-moe-tokenizer")

AttributeError: 'BertWithMoEClassifier' object has no attribute 'save_pretrained'

# Top-1 (Switch-style) MoE results: 3 epochs, evaluated on MNLI validation_matched

Epoch	Training Loss	Validation Loss	Accuracy
1	0.926600	0.905462	0.564748
2	0.859100	0.856526	0.609883
3	0.800300	0.842873	0.619766
TrainOutput(global_step=36816, training_loss=0.8897658285499914, metrics={'train_runtime': 11351.5313, 'train_samples_per_second': 103.784, 'train_steps_per_second': 3.243, 'total_flos': 0.0, 'train_loss': 0.8897658285499914, 'epoch': 3.0})