In [14]:
import pandas as pd

# Load the MBTI dataset from a CSV file (relative path)
df = pd.read_csv("../data/mbti_1.csv")

# Preview the first few rows; as the last expression of the cell it is displayed
df.head()

Unnamed: 0,type,posts
0,INFJ,'http://www.youtube.com/watch?v=qsXHcwe3krw|||...
1,ENTP,'I'm finding the lack of me in these posts ver...
2,INTP,'Good one _____ https://www.youtube.com/wat...
3,INTJ,"'Dear INTP, I enjoyed our conversation the o..."
4,ENTJ,'You're fired.|||That's another silly misconce...


In [15]:
#2️⃣ 处理数据（dataset.py）
from transformers import BertTokenizer
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder

# 预处理数据
class MBTIDataset(Dataset):
    """Torch dataset pairing BERT-tokenized texts with integer class labels.

    Every text is tokenized once, eagerly, in the constructor; ``__getitem__``
    only slices the pre-computed tensors.

    Args:
        texts: List of raw text strings to tokenize.
        labels: Integer class ids, one per text (list or NumPy array).
        tokenizer: Optional ready-made tokenizer to reuse. When omitted, the
            "bert-base-uncased" ``BertTokenizer`` is loaded.
        max_length: Truncation limit passed to the tokenizer (default 512).

    Raises:
        ValueError: If the number of encoded rows differs from the number of
            labels.
    """

    def __init__(self, texts, labels, tokenizer=None, max_length=512):
        # Imported locally: this cell does not import torch at module level.
        import torch

        # Reuse a caller-supplied tokenizer so several datasets (e.g. train and
        # eval) do not each reload the same vocabulary.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = tokenizer
        self.encodings = self.tokenizer(
            texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
        )
        # Store labels as a long tensor so every item is made of tensors only.
        self.labels = torch.tensor(labels, dtype=torch.long)

        # Fail fast instead of raising IndexError (or silently dropping rows)
        # much later, during batching.
        n_encoded = len(self.encodings["input_ids"])
        if n_encoded != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: {n_encoded} texts vs {len(self.labels)} labels"
            )

    def __len__(self):
        """Return the number of (text, label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the tokenizer fields for row ``idx`` plus its label under "labels"."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

# Encode the type strings in df["type"] as integer class ids (stored in df["label"])
label_encoder = LabelEncoder()
df["label"] = label_encoder.fit_transform(df["type"])

# Build the dataset from the post texts and their encoded labels, then report its size
dataset = MBTIDataset(df["posts"].tolist(), df["label"].values)
print(f"数据集大小: {len(dataset)}")


数据集大小: 8675


In [29]:
import torch
import pandas as pd
from transformers import BertForSequenceClassification, Trainer, TrainingArguments, BertTokenizer
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# 读取数据
def load_data(file_path):
    """Read the CSV at ``file_path`` and return its contents as a DataFrame."""
    return pd.read_csv(file_path)

# 自定义数据集
class MBTIDataset(Dataset):
    """Torch dataset pairing BERT-tokenized texts with integer class labels.

    Every text is tokenized once, eagerly, in the constructor; ``__getitem__``
    only slices the pre-computed tensors.

    Args:
        texts: List of raw text strings to tokenize.
        labels: Integer class ids, one per text (list or NumPy array).
        tokenizer: Optional ready-made tokenizer to reuse. When omitted, the
            "bert-base-uncased" ``BertTokenizer`` is loaded. Passing one avoids
            reloading the vocabulary for each dataset that is built.
        max_length: Truncation limit passed to the tokenizer (default 512).

    Raises:
        ValueError: If the number of encoded rows differs from the number of
            labels.
    """

    def __init__(self, texts, labels, tokenizer=None, max_length=512):
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = tokenizer
        self.encodings = self.tokenizer(
            texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

        # Fail fast instead of raising IndexError (or silently dropping rows)
        # much later, during batching.
        n_encoded = len(self.encodings["input_ids"])
        if n_encoded != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: {n_encoded} texts vs {len(self.labels)} labels"
            )

    def __len__(self):
        """Return the number of (text, label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the tokenizer fields for row ``idx`` plus its label under "labels"."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

# Load the dataset
df = load_data("../data/mbti_1.csv")

# Encode the MBTI type strings as integer class ids
label_encoder = LabelEncoder()
df["label"] = label_encoder.fit_transform(df["type"])

# Split into training and evaluation sets (80% train, 20% eval)
train_texts, eval_texts, train_labels, eval_labels = train_test_split(
    df["posts"].tolist(), df["label"].values, test_size=0.2, random_state=42
)

# Build the training and evaluation datasets
train_dataset = MBTIDataset(train_texts, train_labels)
eval_dataset = MBTIDataset(eval_texts, eval_labels)

# Load BERT with a classification head sized to the number of encoded classes.
# The class names are stored in the model config (id2label / label2id) so the
# saved model carries the mapping from class id back to MBTI type.
class_names = [str(name) for name in label_encoder.classes_]
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(class_names),
    id2label=dict(enumerate(class_names)),
    label2id={name: idx for idx, name in enumerate(class_names)},
)

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`;
# use whichever name the installed version exposes.
import dataclasses

_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
eval_strategy_key = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)

# Training configuration
training_args = TrainingArguments(
    # NOTE(review): checkpoint directory chosen during reconstruction of this
    # cell — adjust to the project's layout if a different one is expected.
    output_dir="../models/mbti_bert_checkpoints",
    save_strategy="epoch",          # save a checkpoint every epoch
    per_device_train_batch_size=8,  # training batch size
    per_device_eval_batch_size=8,   # evaluation batch size
    num_train_epochs=3,             # train for 3 epochs
    logging_dir="../logs",          # where training logs are written
    load_best_model_at_end=True,    # reload the best checkpoint when training ends
    save_total_limit=2,             # keep at most 2 checkpoints on disk
    # Evaluate every epoch; must match save_strategy for load_best_model_at_end.
    **{eval_strategy_key: "epoch"},
)

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # evaluation dataset
)

# Start training
trainer.train()

# Save the model once training has finished
model.save_pretrained("../models/mbti_bert")
print("✅ 训练完成，最佳模型已保存到 ../models/mbti_bert")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss
1,2.2308,1.757558
2,1.592,1.532281
3,1.2125,1.483758


✅ 训练完成，最佳模型已保存到 ../models/mbti_bert


In [1]:
#测试模型
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Load the fine-tuned classifier from the directory the training cell saved it to.
# The tokenizer is loaded from the base "bert-base-uncased" checkpoint rather than
# from model_path.
model_path = "../models/mbti_bert"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(model_path)

# 预测函数
def predict_mbti(text):
    """Predict the encoded MBTI class id(s) for ``text``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: A single string, or a list of strings to classify as one batch.

    Returns:
        int: The predicted class id when ``text`` is a single string.
        list[int]: One predicted class id per input when ``text`` is a list.
    """
    is_single = isinstance(text, str)
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the argmax over the class dimension only. Without `dim`, argmax runs
    # over the flattened (batch * classes) tensor and is wrong for batched input.
    predicted = torch.argmax(logits, dim=-1)
    if is_single:
        return predicted.item()
    return predicted.tolist()

# Smoke test on a single sentence.
# NOTE(review): predict_mbti returns the integer class id, so the value printed
# here is the encoded label rather than the MBTI type string; mapping it back
# needs the LabelEncoder fitted in the training cell, which is not loaded here.
test_text = "I love exploring new technologies and solving problems with AI."
prediction = predict_mbti(test_text)
print(f"预测的 MBTI 类型: {prediction}")


预测的 MBTI 类型: 9
