In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertConfig, BertForPreTraining
from datasets import load_dataset
from tqdm import tqdm
import torch.nn.functional as F
import random

In [2]:
import logging
logging.disable(logging.WARNING)

In [3]:
# !pip freeze > requirements.txt

In [4]:
# 1. 加载数据集（wikitext-2，简化版Wikipedia）
# from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
train_texts = [text for text in dataset["train"]["text"] if len(text.strip()) > 0]  # 过滤空文本
# print("\n".join(train_texts[:5]))

In [5]:
# 2. 初始化BERT tokenizer（使用预训练的tokenizer，保证词汇表一致性）
# from transformers import BertTokenizer, BertConfig, BertForPreTraining
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
MAX_SEQ_LENGTH = 128  # 句子最大长度（BERT-base默认512，这里简化为128）
MASK_PROB = 0.15  # MLM任务的掩盖概率

In [6]:
# 3. 构建NSP句子对数据集
# from torch.utils.data import Dataset, DataLoader
class BertPretrainDataset(Dataset):
    def __init__(self, texts, tokenizer, max_seq_length, mask_prob):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mask_prob = mask_prob
        self.sentences = self._split_into_sentences()  # 拆分所有文本为句子列表

    def _split_into_sentences(self):
        """将文本拆分为句子（简单按句号分割，实际可用spaCy等工具优化）"""
        sentences = []
        for text in self.texts:
            # 分割句子并过滤空句子
            sents = [sent.strip() for sent in text.split(".") if len(sent.strip()) > 5]
            sentences.extend(sents)
        return sentences

    def _create_nsp_example(self, idx):
        """创建NSP任务的句子对：50%正例（连续句子），50%负例（随机句子）"""
        # 正例：取第idx句和第idx+1句（确保不越界）
        if random.random() < 0.5 and idx < len(self.sentences) - 1:
            sentence1 = self.sentences[idx]
            sentence2 = self.sentences[idx + 1]
            is_next = 1  # 下一句标记
        # 负例：取第idx句和随机一句（排除连续句）
        else:
            sentence1 = self.sentences[idx]
            random_idx = random.choice([i for i in range(len(self.sentences)) if i != idx and i != idx + 1])
            sentence2 = self.sentences[random_idx]
            is_next = 0  # 非下一句标记
        return sentence1, sentence2, is_next

    def _apply_mlm(self, input_ids):
        """对输入token应用MLM掩码策略"""
        input_ids = input_ids.clone()  # 避免修改原数据
        vocab_size = self.tokenizer.vocab_size

        # 随机选择要掩盖的位置（排除[CLS]、[SEP]、[PAD]）
        mask_positions = torch.bernoulli(torch.full(input_ids.shape, self.mask_prob)).bool()
        mask_positions = mask_positions & (input_ids != self.tokenizer.cls_token_id) & (input_ids != self.tokenizer.sep_token_id) & (input_ids != self.tokenizer.pad_token_id)

        for pos in torch.where(mask_positions)[0]:
            if random.random() < 0.8:
                # 80%替换为[MASK]
                input_ids[pos] = self.tokenizer.mask_token_id
            elif random.random() < 0.5:  # 剩下的20%中，50%替换为随机token
                input_ids[pos] = random.randint(0, vocab_size - 1)
            # 10%保持原token（不做处理）

        # 生成MLM任务的标签：仅掩盖位置为原token，其他位置为-100（PyTorch忽略-100的损失）
        mlm_labels = input_ids.clone()
        mlm_labels[~mask_positions] = -100
        return input_ids, mlm_labels

    def __len__(self):
        return len(self.sentences) - 1  # 避免idx+1越界

    def __getitem__(self, idx):
        # 1. 生成NSP句子对
        sentence1, sentence2, is_next = self._create_nsp_example(idx)

        # 2. Tokenize句子对（BERT要求格式：[CLS] sent1 [SEP] sent2 [SEP]）
        encoded = self.tokenizer(
            sentence1, sentence2,
            padding="max_length",
            truncation="longest_first",
            # truncation="only_first",
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"].flatten()  # (max_seq_length,)
        attention_mask = encoded["attention_mask"].flatten()  # (max_seq_length,)
        token_type_ids = encoded["token_type_ids"].flatten()  # (max_seq_length,)：0=sent1，1=sent2

        # 3. 应用MLM掩码
        input_ids_mlm, mlm_labels = self._apply_mlm(input_ids)

        return {
            "input_ids": input_ids_mlm,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "mlm_labels": mlm_labels,
            "nsp_label": torch.tensor(is_next, dtype=torch.long)
        }

In [7]:
# 构建数据集和数据加载器
train_dataset = BertPretrainDataset(train_texts, tokenizer, MAX_SEQ_LENGTH, MASK_PROB)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

In [8]:
# from transformers import BertConfig, BertForPreTraining
# 配置BERT模型参数（简化版BERT-base，实际BERT-base为12层、768隐藏层、12头注意力）
config = BertConfig(
    vocab_size=tokenizer.vocab_size,  # 词汇表大小（bert-base-uncased为30522）
    hidden_size=768,  # 隐藏层维度
    num_hidden_layers=6,  # 编码器层数（简化为6层，原12层）
    num_attention_heads=12,  # 注意力头数
    intermediate_size=3072,  # 前馈网络隐藏层维度（768*4）
    max_position_embeddings=MAX_SEQ_LENGTH,  # 最大序列长度
    type_vocab_size=2,  # token_type_ids的类别数（0和1）
    hidden_dropout_prob=0.1,  # Dropout概率
    attention_probs_dropout_prob=0.1
)

# 初始化预训练模型（BertForPreTraining已包含MLM和NSP头）
model = BertForPreTraining(config)
print(f"模型参数量：{sum(p.numel() for p in model.parameters()):,}")  # 约8000万参数（简化版）

模型参数量：67,284,284


In [10]:
# import torch
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 优化器：AdamW（权重衰减防止过拟合）
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,  # BERT预训练常用学习率
    weight_decay=0.01  # 权重衰减
)

# 训练参数
EPOCHS = 1  # 预训练轮数（实际需10+轮，这里简化为3轮）
accumulation_steps = 4  # 梯度累积（模拟更大batch_size，如32*4=128）

In [11]:
model.train()
total_step = 0

for epoch in range(EPOCHS):
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch_idx, batch in enumerate(progress_bar):
        # 1. 数据移到GPU/CPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        mlm_labels = batch["mlm_labels"].to(device)
        nsp_label = batch["nsp_label"].to(device)

        # 2. 前向传播：模型输出MLM logits和NSP logits
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=mlm_labels,  # MLM任务标签
            next_sentence_label=nsp_label  # NSP任务标签
        )
        # 旧版outputs包含：total_loss, mlm_loss, nsp_loss, mlm_logits, nsp_logits
        total_loss = outputs.loss
        mlm_logits = outputs.prediction_logits  # (batch_size, seq_len, vocab_size)
        mlm_loss = F.cross_entropy(
            mlm_logits.reshape(-1, tokenizer.vocab_size),  # (batch*seq_len, vocab_size)
            mlm_labels.reshape(-1),  # (batch*seq_len,)
            ignore_index=-100  # 忽略非掩码位置
        )
        nsp_logits = outputs.seq_relationship_logits  # (batch_size, 2)
        nsp_loss = F.cross_entropy(nsp_logits, nsp_label)

        # 3. 反向传播（梯度累积）
        total_loss = total_loss / accumulation_steps  # 累积步长归一化
        total_loss.backward()

        # 4. 梯度更新（每accumulation_steps步更新一次）
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            total_step += 1

        # 5. 统计损失
        epoch_loss += total_loss.item() * accumulation_steps  # 还原真实损失
        progress_bar.set_postfix({
            "Total Loss": f"{total_loss.item()*accumulation_steps:.4f}",
            "MLM Loss": f"{mlm_loss.item():.4f}",
            "NSP Loss": f"{nsp_loss.item():.4f}"
        })

    # 打印每轮平均损失
    avg_epoch_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_epoch_loss:.4f}")

# 保存预训练模型（供下游任务微调使用）
model.save_pretrained("./bert_pretrained_simple")
tokenizer.save_pretrained("./bert_pretrained_simple")
print("预训练完成，模型已保存到 ./bert_pretrained_simple")

Epoch 1/1: 100%|██████████| 2788/2788 [07:01<00:00,  6.61it/s, Total Loss=1.0720, MLM Loss=0.3783, NSP Loss=0.6937]


Epoch 1 Average Loss: 1.7253
预训练完成，模型已保存到 ./bert_pretrained_simple


## TEST


#### TEST my model

In [21]:
model.eval()
test_sentence = "I [MASK] a book yesterday."
encoded = tokenizer(test_sentence, return_tensors="pt").to(device)
# 2. 前向传播预测[MASK]（无需计算损失）
with torch.no_grad():  # 禁用梯度计算，节省内存
    outputs = model(**encoded)
    mlm_logits = outputs.prediction_logits  # (1, seq_len, vocab_size)

    # 3. 找到[MASK]的位置
    mask_token_id = tokenizer.mask_token_id
    mask_positions = torch.where(encoded["input_ids"][0] == mask_token_id)[0]
    
    # 确保找到至少一个[MASK]
    if len(mask_positions) == 0:
        print("未找到[MASK]标记！")
    else:
        mask_pos = mask_positions[0]  # 取第一个[MASK]的位置
        # 4. 预测[MASK]位置的token（取logits最大值）
        predicted_token_id = mlm_logits[0, mask_pos].argmax(dim=-1)
        predicted_token = tokenizer.decode(predicted_token_id)
        
        # 5. 输出结果
        print(f"原句子：{test_sentence}")
        print(f"预测结果：I {predicted_token} a book yesterday.")

原句子：I [MASK] a book yesterday.
预测结果：I [MASK] a book yesterday.


In [18]:
test_sent1 = "I like playing basketball."
test_sent2_pos = "It is my favorite sport."  # 正例
test_sent2_neg = "The sky is blue today."  # 负例

# 正例预测
encoded_pos = tokenizer(test_sent1, test_sent2_pos, return_tensors="pt").to(device)
with torch.no_grad():
    outputs_pos = model(**encoded_pos)
    nsp_logits_pos = outputs_pos.seq_relationship_logits
    nsp_pred_pos = nsp_logits_pos.argmax(dim=-1).item()
print(f"句子对1：{test_sent1} | {test_sent2_pos} → {'连贯' if nsp_pred_pos == 1 else '不连贯'}")

# 负例预测
encoded_neg = tokenizer(test_sent1, test_sent2_neg, return_tensors="pt").to(device)
with torch.no_grad():
    outputs_neg = model(**encoded_neg)
    nsp_logits_neg = outputs_neg.seq_relationship_logits
    nsp_pred_neg = nsp_logits_neg.argmax(dim=-1).item()
print(f"句子对2：{test_sent1} | {test_sent2_neg} → {'连贯' if nsp_pred_neg == 1 else '不连贯'}")

句子对1：I like playing basketball. | It is my favorite sport. → 连贯
句子对2：I like playing basketball. | The sky is blue today. → 连贯


#### Test Official Model

In [22]:
from transformers import BertTokenizer, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction
import torch

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载官方预训练模型和Tokenizer
model_name = "bert-base-uncased"  # 官方预训练模型
tokenizer = BertTokenizer.from_pretrained(model_name)

# 按需加载不同任务的模型（官方模型已包含所有预训练权重）
# 1. 同时支持MLM+NSP的模型（BertForPreTraining）
model_pretrain = BertForPreTraining.from_pretrained(model_name).to(device)
# 2. 仅MLM的轻量模型（BertForMaskedLM，推荐用于掩码预测）
model_mlm = BertForMaskedLM.from_pretrained(model_name).to(device)
# 3. 仅NSP的模型（BertForNextSentencePrediction）
model_nsp = BertForNextSentencePrediction.from_pretrained(model_name).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [23]:
with torch.no_grad():  # 禁用梯度计算，节省内存
    outputs = model_pretrain(**encoded)
    mlm_logits = outputs.prediction_logits  # (1, seq_len, vocab_size)

    # 3. 找到[MASK]的位置
    mask_token_id = tokenizer.mask_token_id
    mask_positions = torch.where(encoded["input_ids"][0] == mask_token_id)[0]
    
    # 确保找到至少一个[MASK]
    if len(mask_positions) == 0:
        print("未找到[MASK]标记！")
    else:
        mask_pos = mask_positions[0]  # 取第一个[MASK]的位置
        # 4. 预测[MASK]位置的token（取logits最大值）
        predicted_token_id = mlm_logits[0, mask_pos].argmax(dim=-1)
        predicted_token = tokenizer.decode(predicted_token_id)
        
        # 5. 输出结果
        print(f"原句子：{test_sentence}")
        print(f"预测结果：I {predicted_token} a book yesterday.")

原句子：I [MASK] a book yesterday.
预测结果：I read a book yesterday.


In [27]:
test_sent1 = "I like playing basketball."
test_sent2_pos = "It is my favorite sport."  # 正例
test_sent2_neg = "The sky is blue today."  # 负例

# 正例预测
encoded_pos = tokenizer(test_sent1, test_sent2_pos, return_tensors="pt").to(device)
with torch.no_grad():
    outputs_pos = model_pretrain(**encoded_pos)
    nsp_logits_pos = outputs_pos.seq_relationship_logits
    nsp_pred_pos = nsp_logits_pos.argmax(dim=-1).item()
print(f"句子对1：{test_sent1} | {test_sent2_pos} → {'连贯' if nsp_pred_pos == 1 else '不连贯'}")

# 负例预测
encoded_neg = tokenizer(test_sent1, test_sent2_neg, return_tensors="pt").to(device)
with torch.no_grad():
    outputs_neg = model_pretrain(**encoded_neg)
    nsp_logits_neg = outputs_neg.seq_relationship_logits
    nsp_pred_neg = nsp_logits_neg.argmax(dim=-1).item()
print(f"句子对2：{test_sent1} | {test_sent2_neg} → {'连贯' if nsp_pred_neg == 1 else '不连贯'}")

句子对1：I like playing basketball. | It is my favorite sport. → 不连贯
句子对2：I like playing basketball. | The sky is blue today. → 不连贯
