# In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
from torch.utils.data import Dataset, DataLoader

# In [2]:
# Scratch cell: instantiates a BertModel directly from an explicit BertConfig
# (no pretrained checkpoint is loaded here) and discards the result.
# NOTE(review): `type_of_class` is not defined anywhere in the visible code, so
# this statement raises NameError as written. Presumably it should be the number
# of target classes -- define it before running this cell (TODO confirm value).
BertModel(BertConfig(
                vocab_size=1000,
                hidden_size=32,
                num_hidden_layers=4,
                num_attention_heads=2,
                intermediate_size=64,
                max_position_embeddings=5,
                num_labels=type_of_class
            ))

# Output: NameError: name 'type_of_class' is not defined

# In [None]:


# 定义联合训练模型
class JointModel(nn.Module):
    """Shared BERT encoder with classification, MLM and SOP heads.

    Args:
        bert_model_name: Identifier handed to ``BertModel.from_pretrained``.
        num_labels: Output width of the classification head.
        mlm_vocab_size: Output width of the MLM head. The default is kept for
            backward compatibility; callers should normally pass the vocabulary
            size of the tokenizer that produced the MLM labels.
    """

    # Label value marking a position / sample that must not contribute to a loss.
    IGNORE_INDEX = -100

    def __init__(self, bert_model_name, num_labels, mlm_vocab_size=30522):
        super().__init__()
        # Encoder shared by all three tasks.
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.config = self.bert.config
        hidden_size = self.config.hidden_size

        # Sequence-level classification head (fed with the pooled output).
        self.classifier = nn.Linear(hidden_size, num_labels)

        # Token-level masked-language-model head.
        self.mlm_head = nn.Linear(hidden_size, mlm_vocab_size)

        # Sentence-order-prediction head: binary "is the order correct".
        self.sop_head = nn.Linear(hidden_size, 2)

    def _cross_entropy_or_zero(self, logits, labels):
        """Cross-entropy that honours ``IGNORE_INDEX`` and never returns NaN.

        With mean reduction, a batch in which every label is ignored averages
        over zero elements and yields NaN. In that case a zero that is still
        attached to ``logits`` is returned so ``backward()`` remains valid.
        """
        if not (labels != self.IGNORE_INDEX).any():
            return logits.sum() * 0.0
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        return loss_fct(logits, labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None,
                mlm_labels=None, sop_labels=None):
        """Run the encoder once and evaluate every head.

        Args:
            input_ids: Token ids, ``[batch, seq_len]``.
            attention_mask: Attention mask matching ``input_ids``.
            token_type_ids: Optional segment ids forwarded to the encoder.
            mlm_labels: Optional ``[batch, seq_len]`` MLM targets; positions
                equal to ``IGNORE_INDEX`` are skipped.
            sop_labels: Optional ``[batch]`` SOP targets; samples equal to
                ``IGNORE_INDEX`` are skipped.

        Returns:
            ``(cls_logits, losses)`` where ``cls_logits`` is
            ``[batch, num_labels]`` and ``losses`` holds ``"mlm"`` and/or
            ``"sop"`` only for the label tensors that were supplied.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = outputs.last_hidden_state  # [batch, seq_len, hidden]
        pooled_output = outputs.pooler_output        # [batch, hidden]

        cls_logits = self.classifier(pooled_output)  # [batch, num_labels]
        mlm_logits = self.mlm_head(sequence_output)  # [batch, seq_len, mlm_vocab]
        sop_logits = self.sop_head(pooled_output)    # [batch, 2]

        losses = {}
        if mlm_labels is not None:
            # Flatten using the head's real output width. Reshaping with
            # config.vocab_size fails (or silently scrambles rows) whenever
            # mlm_vocab_size differs from the encoder's vocabulary size.
            losses["mlm"] = self._cross_entropy_or_zero(
                mlm_logits.reshape(-1, mlm_logits.size(-1)),
                mlm_labels.reshape(-1)
            )

        if sop_labels is not None:
            losses["sop"] = self._cross_entropy_or_zero(
                sop_logits.reshape(-1, 2),
                sop_labels.reshape(-1)
            )

        return cls_logits, losses

# 自定义数据集
class JointDataset(Dataset):
    """Dataset yielding classification, MLM and SOP inputs for each text.

    Args:
        texts: List of raw text strings.
        labels: Classification label for each text.
        tokenizer: BERT-style tokenizer; it is called with ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` and must expose
            ``cls_token_id``, ``sep_token_id``, ``mask_token_id`` and
            ``vocab_size``.
        max_len: Length every encoded sequence is padded / truncated to.
        mlm_prob: Probability of selecting a real token as an MLM target.
    """

    # Label value that the loss functions skip.
    IGNORE_INDEX = -100

    def __init__(self, texts, labels, tokenizer, max_len=128, mlm_prob=0.15):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mlm_prob = mlm_prob

    def __len__(self):
        return len(self.texts)

    def _encode(self, text):
        """Tokenize ``text`` into fixed-length ``[1, max_len]`` tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def _build_mlm_example(self, input_ids, attention_mask):
        """Return ``(mlm_input_ids, mlm_labels)`` for one encoded sequence.

        Only real tokens are eligible: [CLS], [SEP] and every padded position
        (attention mask 0) are never selected, so padding cannot become an MLM
        target. Selected tokens are replaced by [MASK] 80% of the time, by a
        random id 10% of the time, and left unchanged otherwise.
        """
        mlm_input_ids = input_ids.clone()
        mlm_labels = torch.full_like(input_ids, self.IGNORE_INDEX)

        special = (
            (input_ids == self.tokenizer.cls_token_id)
            | (input_ids == self.tokenizer.sep_token_id)
        )
        eligible = attention_mask.bool() & ~special
        selected = (torch.rand(input_ids.shape) < self.mlm_prob) & eligible

        to_mask = selected & (torch.rand(input_ids.shape) < 0.8)
        # Half of the remaining 20% -> 10% of all selected tokens.
        to_random = selected & ~to_mask & (torch.rand(input_ids.shape) < 0.5)

        mlm_input_ids[to_mask] = self.tokenizer.mask_token_id
        num_random = int(to_random.sum())
        mlm_input_ids[to_random] = torch.randint(
            0, self.tokenizer.vocab_size, (num_random,), dtype=input_ids.dtype
        )
        mlm_labels[selected] = input_ids[selected]
        return mlm_input_ids, mlm_labels

    @staticmethod
    def _split_sentences(text):
        """Split on ASCII '.' and the CJK full stop, dropping empty pieces.

        This is a deliberately simple splitter; fragments left after a trailing
        period are discarded so an empty string is never used as a sentence.
        """
        pieces = (piece.strip() for piece in text.replace('。', '.').split('.'))
        return [piece for piece in pieces if piece]

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` for
            classification; ``mlm_input_ids`` / ``mlm_labels`` for MLM; and
            ``sop_input_ids`` / ``sop_attention_mask`` / ``sop_labels`` for SOP.
            ``sop_labels`` is 1 for the original order, 0 for swapped, and
            ``IGNORE_INDEX`` when the text has fewer than two sentences.
        """
        text = self.texts[idx]
        label = self.labels[idx]

        # 1. Plain encoding used by the classification task.
        encoding = self._encode(text)
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # 2. Masked-language-model view of the same sequence.
        mlm_input_ids, mlm_labels = self._build_mlm_example(input_ids, attention_mask)

        # 3. Sentence-order prediction on the first two sentences.
        sentences = self._split_sentences(text)
        if len(sentences) >= 2:
            if torch.rand(1).item() < 0.5:
                sent1, sent2 = sentences[0], sentences[1]
                sop_label = 1  # original order
            else:
                sent1, sent2 = sentences[1], sentences[0]
                sop_label = 0  # swapped order

            sop_encoding = self._encode(f"{sent1} [SEP] {sent2}")
            sop_input_ids = sop_encoding["input_ids"].squeeze(0)
            sop_attention_mask = sop_encoding["attention_mask"].squeeze(0)
        else:
            # No second sentence: reuse the plain encoding and mark the sample
            # as ignored for the SOP loss.
            sop_input_ids = input_ids
            sop_attention_mask = attention_mask
            sop_label = self.IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.long),
            "mlm_input_ids": mlm_input_ids,
            "mlm_labels": mlm_labels,
            "sop_input_ids": sop_input_ids,
            "sop_attention_mask": sop_attention_mask,
            "sop_labels": torch.tensor(sop_label, dtype=torch.long)
        }

# 训练循环示例
def train_epoch(model, dataloader, optimizer, device, loss_weights=None):
    """Train ``model`` for one pass over ``dataloader`` on all three tasks.

    Args:
        model: Module whose call returns ``(cls_logits, losses)`` and accepts
            ``input_ids``, ``attention_mask``, ``mlm_labels`` and
            ``sop_labels`` keyword arguments.
        dataloader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask``, ``labels``, ``mlm_input_ids``, ``mlm_labels``,
            ``sop_input_ids``, ``sop_attention_mask`` and ``sop_labels``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        loss_weights: Mapping with ``'cls'``, ``'mlm'`` and ``'sop'`` weights.
            Defaults to ``{'cls': 1.0, 'mlm': 0.5, 'sop': 0.5}``.

    Returns:
        Mean weighted loss over the batches as a float, or ``0.0`` when the
        dataloader yields no batches.
    """
    if loss_weights is None:
        # Built per call so no mutable default is shared between calls.
        loss_weights = {'cls': 1.0, 'mlm': 0.5, 'sop': 0.5}

    ignore_index = -100  # label value excluded from the MLM / SOP losses
    cls_criterion = nn.CrossEntropyLoss()

    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        optimizer.zero_grad()
        attention_mask = batch['attention_mask'].to(device)

        # Classification task.
        cls_output, _ = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=attention_mask
        )
        batch_loss = loss_weights['cls'] * cls_criterion(
            cls_output, batch['labels'].to(device)
        )

        # MLM task. Skipped when the batch has no masked token: a mean
        # cross-entropy over zero targets is NaN and would poison the update.
        mlm_labels = batch['mlm_labels'].to(device)
        if (mlm_labels != ignore_index).any():
            _, mlm_losses = model(
                input_ids=batch['mlm_input_ids'].to(device),
                attention_mask=attention_mask,
                mlm_labels=mlm_labels
            )
            batch_loss = batch_loss + loss_weights['mlm'] * mlm_losses['mlm']

        # SOP task. Skipped for the same reason when every sample is ignored.
        sop_labels = batch['sop_labels'].to(device)
        if (sop_labels != ignore_index).any():
            _, sop_losses = model(
                input_ids=batch['sop_input_ids'].to(device),
                attention_mask=batch['sop_attention_mask'].to(device),
                sop_labels=sop_labels
            )
            batch_loss = batch_loss + loss_weights['sop'] * sop_losses['sop']

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else 0.0

# 使用示例
if __name__ == "__main__":
    # Demo settings.
    BERT_MODEL = "bert-base-chinese"
    NUM_CLASSES = 14
    BATCH_SIZE = 16
    MAX_LEN = 128

    # Choose the device up front, then load the tokenizer and the joint model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
    model = JointModel(BERT_MODEL, NUM_CLASSES)
    model.to(device)

    # Toy data -- replace with a real corpus.
    train_texts = [
        "这是一个正样本。包含领域相关术语。",
        "这是负样本。数据增强很重要。",
    ]
    train_labels = [1, 0]

    dataset = JointDataset(train_texts, train_labels, tokenizer, max_len=MAX_LEN)
    dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # Run a single epoch and report the loss it returns.
    loss = train_epoch(model, dataloader, optimizer, device)
    print(f"Training loss: {loss:.4f}")