In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from pypinyin import lazy_pinyin
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import os
import json

device = 'cuda' if torch.cuda.is_available() else 'cpu'


print("环境信息")
print(f"PyTorch: {torch.__version__}")
print(f"设备: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print("\n环境配置完成！\n")

环境信息
PyTorch: 2.0.0+cu118
设备: cuda
GPU: NVIDIA GeForce RTX 4090
VRAM: 23.5 GB

环境配置完成！



In [2]:
# Cell 6: 定义数据集和模型类
print("定义训练组件...")

from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

# ============ 使用本地模型 ============
MODEL_NAME = '/root/bert-base-chinese'
# ====================================

class PairDataset(Dataset):
    """配对数据集 - 支持字符级和拼音级"""
    def __init__(self, data, tokenizer, max_length=128, use_pinyin=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_pinyin = use_pinyin
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # 根据 use_pinyin 选择数据源
        if self.use_pinyin:
            # 拼音级：使用转换后的拼音文本
            query_text = item['query']  # 在 pinyin_data 中，这已经是拼音了
            doc_text = item['document']
        else:
            # 字符级：使用原始中文
            query_text = item['query']
            doc_text = item['document']
        
        # Tokenize
        query_enc = self.tokenizer(
            query_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        doc_enc = self.tokenizer(
            doc_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'query_input_ids': query_enc['input_ids'].squeeze(0),
            'query_attention_mask': query_enc['attention_mask'].squeeze(0),
            'doc_input_ids': doc_enc['input_ids'].squeeze(0),
            'doc_attention_mask': doc_enc['attention_mask'].squeeze(0)
        }


class EmbeddingModel(nn.Module):
    """基于 BERT 的 Sentence Embedding 模型"""
    def __init__(self, model_name=MODEL_NAME):
        super().__init__()
        print(f"初始化模型: {model_name}")
        self.encoder = BertModel.from_pretrained(model_name)
        print(f"✓ 模型加载完成")
        
    def mean_pooling(self, token_embeddings, attention_mask):
        """平均池化"""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            token_embeddings.size()
        ).float()
        sum_embeddings = torch.sum(
            token_embeddings * input_mask_expanded, 1
        )
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
    
    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        embeddings = self.mean_pooling(
            outputs.last_hidden_state,
            attention_mask
        )
        # L2 归一化
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings


def cosine_similarity_loss(query_emb, doc_emb):
    """余弦相似度损失"""
    similarity = F.cosine_similarity(query_emb, doc_emb, dim=1)
    loss = 1 - similarity.mean()
    return loss


# 测试加载
print("\n测试模型加载...")
try:
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    print(f"Tokenizer 加载成功，词汇表大小: {len(tokenizer)}")
    
    # 测试字符级分词
    test_text = "糖尿病的症状"
    tokens = tokenizer.tokenize(test_text)
    print(f"字符级测试: '{test_text}' -> {tokens}")
    
    # 测试拼音级分词
    from pypinyin import lazy_pinyin
    pinyin_text = ' '.join(lazy_pinyin(test_text))
    pinyin_tokens = tokenizer.tokenize(pinyin_text)
    print(f"拼音级测试: '{pinyin_text}' -> {pinyin_tokens}")
    
    print("\n所有组件定义完成！")
except Exception as e:
    print(f"加载失败: {e}")
    print("请检查模型路径是否正确")

定义训练组件...

测试模型加载...
Tokenizer 加载成功，词汇表大小: 21128
字符级测试: '糖尿病的症状' -> ['糖', '尿', '病', '的', '症', '状']
拼音级测试: 'tang niao bing de zheng zhuang' -> ['tan', '##g', 'ni', '##ao', 'bing', 'de', 'zh', '##eng', 'zh', '##uan', '##g']

所有组件定义完成！


In [3]:
# 验证本地模型
print("验证模型加载...")

from transformers import BertTokenizer, BertModel

MODEL_PATH = '/root/bert-base-chinese'

# 1. 检查文件是否存在
import os
required_files = [
    'config.json',
    'pytorch_model.bin',
    'tokenizer_config.json',
    'vocab.txt'
]

print("检查必需文件:")
for file in required_files:
    path = os.path.join(MODEL_PATH, file)
    exists = os.path.exists(path)
    status = "✓" if exists else ""
    print(f"  {status} {file}")

# 2. 加载并测试
print("\n加载模型...")
try:
    tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
    model = BertModel.from_pretrained(MODEL_PATH)
    
    print(f"Tokenizer 词汇表大小: {len(tokenizer)}")
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
    
    # 3. 测试推理
    print("\n测试推理...")
    text = "糖尿病患者应该注意饮食"
    inputs = tokenizer(text, return_tensors='pt')
    outputs = model(**inputs)
    
    print(f"输入 shape: {inputs['input_ids'].shape}")
    print(f"输出 shape: {outputs.last_hidden_state.shape}")
    print(f"分词结果: {tokenizer.tokenize(text)}")
    
    print("\n模型加载和推理测试成功！")
    
except Exception as e:
    print(f"\n错误: {e}")

验证模型加载...
检查必需文件:
  ✓ config.json
  ✓ pytorch_model.bin
  ✓ tokenizer_config.json
  ✓ vocab.txt

加载模型...


Some weights of the model checkpoint at /root/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Tokenizer 词汇表大小: 21128
模型参数量: 102.3M

测试推理...
输入 shape: torch.Size([1, 13])
输出 shape: torch.Size([1, 13, 768])
分词结果: ['糖', '尿', '病', '患', '者', '应', '该', '注', '意', '饮', '食']

模型加载和推理测试成功！


In [4]:
# Cell 7: 训练三个版本的模型
print("训练三个版本的模型...")

from torch.utils.data import DataLoader
import torch.optim as optim

tokenizer = BertTokenizer.from_pretrained('/root/bert-base-chinese')

# 1. 字符级
char_dataset = PairDataset(char_data, tokenizer, max_length=128)
char_loader = DataLoader(char_dataset, batch_size=16, shuffle=True)
char_model = EmbeddingModel('/root/bert-base-chinese')
char_model = train_model(char_model, char_loader, epochs=3, model_name="字符级")
torch.save(char_model.state_dict(), 'char_model.pt')

# 2. 拼音级
pinyin_dataset = PairDataset(pinyin_data, tokenizer, max_length=128)
pinyin_loader = DataLoader(pinyin_dataset, batch_size=16, shuffle=True)
pinyin_model = EmbeddingModel('/root/bert-base-chinese')
pinyin_model = train_model(pinyin_model, pinyin_loader, epochs=3, model_name="拼音级")
torch.save(pinyin_model.state_dict(), 'pinyin_model.pt')

# 3. 笔画级（StrokeNet 风格）
stroke_dataset = PairDataset(stroke_data, tokenizer, max_length=128)
stroke_loader = DataLoader(stroke_dataset, batch_size=16, shuffle=True)
stroke_model = EmbeddingModel('/root/bert-base-chinese')
stroke_model = train_model(stroke_model, stroke_loader, epochs=3, model_name="笔画级")
torch.save(stroke_model.state_dict(), 'stroke_model.pt')

print("\n✓ 三个模型训练完成！")

训练三个版本的模型...


NameError: name 'char_data' is not defined