In [None]:
import math
import torch
import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertSelfAttention

class MoEAttentionExpert(nn.Module):
    """A single multi-head self-attention expert.

    Each expert owns its own query/key/value projections and computes scaled
    dot-product self-attention over the sequence it is given.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        # Integer division is exact here because divisibility was checked above.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads.

        Reshapes ``[batch, seq, all_head_size]`` into
        ``[batch, heads, seq, head_size]``.
        """
        x = x.view(*x.shape[:-1], self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, **kwargs):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``[batch, seq, hidden_size]``.
            attention_mask: Optional additive mask that is added to the raw
                attention scores before the softmax (must broadcast to
                ``[batch, heads, seq, seq]``).
            head_mask: Optional multiplicative mask applied to the attention
                probabilities (must broadcast to ``[batch, heads, seq, seq]``).
            **kwargs: Accepted and ignored so callers may pass extra
                attention-layer arguments.

        Returns:
            A 1-tuple ``(context_layer,)`` where ``context_layer`` has shape
            ``[batch, seq, all_head_size]``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores: [batch, heads, seq, seq]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: [batch, heads, seq, head_size] -> [batch, seq, all_head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.shape[:-2], self.all_head_size)

        return (context_layer,)

class MoEAttention(BertSelfAttention):
    """Mixture-of-experts replacement for BERT's self-attention module.

    A linear gate scores ``num_experts`` attention experts from the first
    token of each sequence; the ``top_k`` best-scoring experts are run and
    their outputs are combined using the re-normalized gate probabilities.
    Routing is therefore per sequence, not per token.

    Args:
        config: BERT configuration forwarded to the parent class and to every
            expert.
        num_experts: Number of attention experts to create.
        top_k: Number of experts combined for each sequence.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, config, num_experts=4, top_k=2):
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k ({top_k}) must be between 1 and num_experts ({num_experts})"
            )
        super().__init__(config)
        self.num_experts = num_experts
        self.top_k = top_k

        # Pool of experts
        self.experts = nn.ModuleList([MoEAttentionExpert(config) for _ in range(num_experts)])

        # Gating network: one logit per expert
        self.gate = nn.Linear(config.hidden_size, num_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Route each sequence to its top-k experts and mix their outputs.

        Args:
            hidden_states: Tensor of shape ``[batch, seq, hidden_size]``.
            attention_mask: Optional additive mask handed to every expert.
            head_mask: Optional per-head multiplier with one entry per
                attention head (e.g. shape ``[heads]`` or
                ``[batch or 1, heads, 1, 1]``).
            encoder_hidden_states: Must be ``None``; cross-attention is not
                supported.
            encoder_attention_mask: Unused.
            past_key_value: Unused.
            output_attentions: When true, an empty tuple is returned in the
                attention slot because experts do not expose their
                attention probabilities.
            **kwargs: Ignored. Accepted so that extra keyword arguments a
                caller may supply do not raise ``TypeError``.

        Returns:
            ``(context_layer,)`` or ``(context_layer, ())`` when
            ``output_attentions`` is true; ``context_layer`` has shape
            ``[batch, seq, all_head_size]``.

        Raises:
            ValueError: If ``encoder_hidden_states`` is provided.
        """
        # Cross-attention is not supported by the MoE layer.
        if encoder_hidden_states is not None:
            raise ValueError("MoEAttention does not support cross-attention")

        # Gate on the first token of every sequence ([CLS] for BERT inputs).
        # Alternative: hidden_states.mean(dim=1) for mean pooling.
        gate_input = hidden_states[:, 0, :]
        gate_probs = self.softmax(self.gate(gate_input))  # [batch, num_experts]

        # Keep the top-k experts per sequence and re-normalize their weights to sum to 1.
        top_k_gate_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_gate_probs = top_k_gate_probs / top_k_gate_probs.sum(dim=-1, keepdim=True)

        batch_size, seq_length, _ = hidden_states.shape
        context_layer = torch.zeros(
            (batch_size, seq_length, self.all_head_size),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_index, expert in enumerate(self.experts):
            selected = top_k_indices == expert_index  # [batch, top_k]
            if not selected.any():
                # No sequence in this batch routed to this expert.
                continue

            expert_output = expert(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
            )[0]

            # Per-sequence mixing weight; it is exactly zero for sequences that
            # did not select this expert, so no separate mask is required.
            weights = (top_k_gate_probs * selected.to(top_k_gate_probs.dtype)).sum(dim=1)
            context_layer = context_layer + expert_output * weights.view(-1, 1, 1)

        if head_mask is not None:
            # Scaling a head's attention probabilities by a constant scales that
            # head's context slice by the same constant, so the mask can be
            # applied to the per-head view of the mixed context.
            per_head = context_layer.reshape(
                batch_size, seq_length, self.num_attention_heads, self.attention_head_size
            )
            per_head = per_head * head_mask.reshape(-1, 1, self.num_attention_heads, 1)
            context_layer = per_head.reshape(batch_size, seq_length, self.all_head_size)

        # Same tuple layout as the stock BERT self-attention output.
        outputs = (context_layer,)
        if output_attentions:
            outputs += ((),)

        return outputs

In [None]:
from transformers import BertPreTrainedModel, BertModel

class BertWithMoEAttention(BertPreTrainedModel):
    """BERT encoder whose self-attention modules are replaced by ``MoEAttention``.

    The number of experts and the routing top-k are read from the optional
    config attributes ``moe_num_experts`` and ``moe_top_k``; when absent they
    default to 4 and 2.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Stock BERT backbone
        self.bert = BertModel(config)

        # Swap every layer's self-attention for the MoE variant
        num_experts = getattr(config, "moe_num_experts", 4)
        top_k = getattr(config, "moe_top_k", 2)
        for layer in self.bert.encoder.layer:
            layer.attention.self = MoEAttention(config, num_experts=num_experts, top_k=top_k)

        # Run the library's weight initialization and final setup so the newly
        # added expert/gate layers are initialized like the rest of the model.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        """Delegate to the wrapped ``BertModel`` and return its output unchanged."""
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

In [None]:
from transformers import BertTokenizer

# Build the MoE model from the stock bert-base-uncased configuration
config = BertConfig.from_pretrained("bert-base-uncased")
model = BertWithMoEAttention(config)

# Tokenize a sample sentence
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello world!", return_tensors="pt")

# Smoke-test forward pass: disable dropout and gradient tracking so the
# result is deterministic and no autograd graph is kept around.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
from transformers import BertForPreTraining

# Fresh MoE model built from the bert-base-uncased configuration
config = BertConfig.from_pretrained("bert-base-uncased")
model = BertWithMoEAttention(config)

# Reference checkpoint that supplies the pretrained parameters
pretrained_model = BertForPreTraining.from_pretrained("bert-base-uncased")
pretrained_state_dict = pretrained_model.state_dict()

# Parameters currently held by the custom model
model_state_dict = model.state_dict()

# Keep only checkpoint entries whose name also exists in the custom model,
# skipping anything that belongs to the attention experts. Parameters with
# no counterpart in the checkpoint keep their current values.
loadable_weights = {}
for param_name, param_value in pretrained_state_dict.items():
    if "attention.self.experts" in param_name:
        continue
    if param_name in model_state_dict:
        loadable_weights[param_name] = param_value

# Overlay the selected entries and load the merged state dict
model_state_dict.update(loadable_weights)
model.load_state_dict(model_state_dict)

print("部分预训练权重加载完成（跳过自定义Attention部分）")

In [None]:
# Re-run the sample forward pass on the model that now holds the loaded weights.
# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU instead of raising.
model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from datasets import load_dataset

# Load the "sst2" configuration of the "glue" dataset
dataset = load_dataset("glue", "sst2")

# Print the dataset structure, a header line, then the first training example
print(dataset, "\n样例:", dataset["train"][0], sep="\n")
from transformers import BertTokenizer

# Tokenizer for the bert-base-uncased checkpoint; preprocess_function reads this module-level name
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def preprocess_function(examples, max_length=128):
    """Tokenize the ``"sentence"`` field of a batch of examples.

    Args:
        examples: Mapping with a ``"sentence"`` entry holding the text(s) to
            tokenize.
        max_length: Length every sequence is truncated/padded to.

    Returns:
        Whatever the module-level ``tokenizer`` returns for these inputs.
    """
    return tokenizer(
        examples["sentence"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )

# Tokenize the dataset in batches, then rename "label" to "labels"
encoded_dataset = dataset.map(preprocess_function, batched=True).rename_column("label", "labels")

# Return the listed columns in torch format
tensor_columns = ["input_ids", "attention_mask", "labels"]
encoded_dataset.set_format("torch", columns=tensor_columns)

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm

def train_model(
    model,
    train_dataset,
    val_dataset,
    model_name="model",
    num_epochs=3,
    batch_size=16,
    learning_rate=2e-5,
    max_train_samples=5000,
    max_val_samples=872,
):
    """Fine-tune ``model`` and print per-epoch train/validation metrics.

    Args:
        model: Module whose forward accepts the batch columns as keyword
            arguments and returns an object with ``loss`` and ``logits``.
        train_dataset: Dataset supporting ``len()`` and ``select(indices)``;
            batches must contain a ``"labels"`` entry.
        val_dataset: Same requirements as ``train_dataset``.
        model_name: Label used in the printed report.
        num_epochs: Number of passes over the training subset.
        batch_size: Batch size for both loaders.
        learning_rate: Initial AdamW learning rate.
        max_train_samples: Use at most this many leading training rows
            (``None`` uses the dataset as is).
        max_val_samples: Use at most this many leading validation rows
            (``None`` uses the dataset as is).

    Returns:
        The same ``model`` instance, trained in place.
    """

    def _head(dataset, limit):
        # Clamp to the dataset size so small datasets do not raise on select().
        if limit is None:
            return dataset
        return dataset.select(range(min(limit, len(dataset))))

    train_dataloader = DataLoader(
        _head(train_dataset, max_train_samples),
        batch_size=batch_size,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        _head(val_dataset, max_val_samples),
        batch_size=batch_size,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are pinned to what that class is believed to have defaulted
    # to (TODO confirm), so the update rule stays close to the previous one.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
    )
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({"loss": loss_value})

        # max(..., 1) guards against an empty loader
        avg_train_loss = total_loss / max(len(train_dataloader), 1)

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)

                val_loss += outputs.loss.item()
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct += (predictions == batch["labels"]).sum().item()
                total += len(batch["labels"])

        val_accuracy = correct / max(total, 1)
        avg_val_loss = val_loss / max(len(val_dataloader), 1)

        print(f"{model_name} - Epoch {epoch+1}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")
        print(f"  Val Acc: {val_accuracy:.4f}\n")

    return model

In [None]:
from transformers import BertForSequenceClassification

# Fine-tune the stock BERT classifier as the baseline
print("=== 训练原始BERT模型 ===")
original_bert = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels=2
)
# train_model returns the model it was given; rebinding the name keeps the
# cell from ending in a bare expression that would print the whole module repr.
original_bert = train_model(
    original_bert, encoded_dataset["train"], encoded_dataset["validation"], "Original BERT"
)


In [None]:
from transformers import BertPreTrainedModel, BertModel, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput

class BertWithMoEAttentionForSequenceClassification(BertPreTrainedModel):
    """Sequence classifier on top of a BERT encoder with MoE self-attention.

    The number of experts and the routing top-k are read from the optional
    config attributes ``moe_num_experts`` and ``moe_top_k``; when absent they
    default to 4 and 2.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # BERT backbone (no classification head of its own)
        self.bert = BertModel(config)

        # Swap every layer's self-attention for the MoE variant
        num_experts = getattr(config, "moe_num_experts", 4)
        top_k = getattr(config, "moe_top_k", 2)
        for layer in self.bert.encoder.layer:
            layer.attention.self = MoEAttention(config, num_experts=num_experts, top_k=top_k)

        # Classification head
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify the input sequences.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, output_attentions,
            output_hidden_states: Forwarded unchanged to the BERT backbone.
            labels: Optional class indices; when given, a cross-entropy loss
                over ``num_labels`` classes is computed.
            return_dict: Whether to return a ``SequenceClassifierOutput``;
                defaults to ``config.use_return_dict``.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` is true,
            otherwise a tuple ``(loss?, logits, *extra_backbone_outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Second backbone output: the pooled sequence representation
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Fine-tune the custom MoE-attention BERT classifier.
# NOTE(review): the expert and gate layers are defined only by MoEAttention, so
# the bert-base-uncased checkpoint presumably has no weights for them and they
# start from their initialization rather than pretrained values - confirm
# before comparing against the baseline.
print("\n=== 训练自定义MoE Attention BERT模型 ===")
moe_bert = BertWithMoEAttentionForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels=2
)
# train_model returns the model it was given; rebinding the name keeps the
# cell from ending in a bare expression that would print the whole module repr.
moe_bert = train_model(
    moe_bert, encoded_dataset["train"], encoded_dataset["validation"], "MoE BERT"
)

In [None]:
import matplotlib.pyplot as plt

def plot_results(original_results, moe_results):
    """Plot training loss and validation accuracy of both models side by side.

    Args:
        original_results: Mapping with ``"train_loss"`` and ``"val_acc"``
            sequences (one value per epoch) for the original BERT model.
        moe_results: Same structure for the MoE model.

    Each curve gets its own 1-based epoch axis, so the two models may have
    been trained for a different number of epochs.
    """
    series = (("Original BERT", original_results), ("MoE BERT", moe_results))
    panels = (
        ("train_loss", "Training Loss Comparison", "Loss"),
        ("val_acc", "Validation Accuracy Comparison", "Accuracy"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for ax, (metric_key, title, ylabel) in zip(axes, panels):
        for label, results in series:
            values = results[metric_key]
            ax.plot(range(1, len(values) + 1), values, label=label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.tight_layout()
    plt.show()

# Hard-coded example metrics standing in for values recorded during training
original_metrics = {
    "train_loss": [0.45, 0.32, 0.28],
    "val_acc": [0.85, 0.87, 0.88],
}
moe_metrics = {
    "train_loss": [0.43, 0.30, 0.25],
    "val_acc": [0.86, 0.88, 0.89],
}
example_results = {"original": original_metrics, "moe": moe_metrics}

plot_results(example_results["original"], example_results["moe"])