In [14]:
import pandas as pd

# Load the MBTI dataset from the CSV file on disk
df = pd.read_csv(filepath_or_buffer="../data/mbti_1.csv")

# Preview the first five rows; as the cell's last expression it is rendered as a table
df.head(n=5)

Unnamed: 0,type,posts
0,INFJ,'http://www.youtube.com/watch?v=qsXHcwe3krw|||...
1,ENTP,'I'm finding the lack of me in these posts ver...
2,INTP,'Good one _____ https://www.youtube.com/wat...
3,INTJ,"'Dear INTP, I enjoyed our conversation the o..."
4,ENTJ,'You're fired.|||That's another silly misconce...


In [15]:
#2️⃣ 处理数据（dataset.py）
from transformers import BertTokenizer
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder

# Dataset wrapper: tokenizes every text once, then serves per-sample tensors
class MBTIDataset(Dataset):
    """Map-style dataset pairing BERT-tokenized texts with their labels.

    All texts are tokenized eagerly in ``__init__`` (padded to the longest
    sequence in the whole collection and truncated to ``max_length``), so
    ``__getitem__`` only slices the pre-computed tensors.

    Args:
        texts: Sequence of raw input strings.
        labels: Sequence of labels, index-aligned with ``texts``.
        tokenizer_name: Name or path of the pretrained tokenizer to load.
            Defaults to ``"bert-base-uncased"``.
        max_length: Maximum number of tokens kept per text. Defaults to 512.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer_name="bert-base-uncased", max_length=512):
        # __len__ is driven by the labels while indexing hits the encodings, so a
        # length mismatch would silently drop samples or fail late with an
        # opaque index error. Fail early with a clear message instead.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} texts and {len(labels)} labels"
            )

        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        self.labels = labels

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the tokenizer outputs for sample ``idx`` plus its label under ``"labels"``."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

# Turn the MBTI type strings into integer class ids and keep them on the frame
label_encoder = LabelEncoder()
encoded_types = label_encoder.fit_transform(df["type"])
df["label"] = encoded_types

# Build the tokenized dataset from the raw posts and their encoded labels
post_texts = df["posts"].tolist()
dataset = MBTIDataset(post_texts, df["label"].values)
print(f"数据集大小: {len(dataset)}")


数据集大小: 8675


In [23]:
#3️⃣ 训练模型（train.py）
import torch
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

# Load pre-trained BERT with a classification head sized to the encoded labels.
# `label_encoder` and `dataset` come from the data-preparation cell. The
# id <-> MBTI type mapping is stored in the model config so the saved
# checkpoint reports type names rather than bare class ids.
mbti_types = [str(name) for name in label_encoder.classes_]
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(mbti_types),
    id2label=dict(enumerate(mbti_types)),
    label2id={name: idx for idx, name in enumerate(mbti_types)},
)

# Hold out part of the data for evaluation. Per-epoch evaluation is enabled
# below, and Trainer raises "evaluation requires an eval_dataset" when none
# is supplied. The split is seeded so re-running the cell is reproducible.
EVAL_FRACTION = 0.1
SPLIT_SEED = 42
n_eval = max(1, int(len(dataset) * EVAL_FRACTION))
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_eval, n_eval],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training configuration: evaluate and checkpoint once per epoch.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy`; switch if the installed version rejects this argument.
training_args = TrainingArguments(
    output_dir="../models",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    logging_dir="../logs",
)

# Train with the Hugging Face `Trainer`
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Start training
trainer.train()

# Save the fine-tuned model
model.save_pretrained("../models/mbti_bert")
print("✅ 模型训练完成，已保存到 models/mbti_bert")


ModuleNotFoundError: No module named 'torc'