In [1]:
import os
import pickle
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from transformers import BertModel, BertTokenizer

In [2]:
# Load the pretrained uncased BERT encoder together with its tokenizer.
# Both are loaded from the same checkpoint name, so the tokenizer's vocabulary
# matches the encoder's embeddings.
model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

In [20]:
def _timestamp():
    """Return the current wall-clock time formatted as ``[HH:MM:SS]`` for log lines."""
    return datetime.now().strftime("[%H:%M:%S]")


def extract_features(batch_size, text_file, feature_file, bert_model=None, bert_tokenizer=None):
    """Encode each description in ``text_file`` with BERT and pickle the pooled vectors.

    The input is read as a headerless tab-separated file, and only its second
    column (index 1) is used as text. Texts are encoded ``batch_size`` at a time,
    and each batch's ``pooler_output`` is collected. The pickled ``numpy`` array
    therefore has one row per input line, in file order.

    Args:
        batch_size: Number of texts per forward pass; must be >= 1.
        text_file: Path to the tab-separated description file.
        feature_file: Destination path for the pickled feature array. Its parent
            directory is created if it does not already exist.
        bert_model: Encoder to run. Defaults to the notebook-level ``model``.
        bert_tokenizer: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if bert_model is None:
        bert_model = model
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    bert_model.eval()  # Evaluation mode: disables dropout so the features are deterministic.

    data = pd.read_csv(text_file, sep='\t', encoding='utf-8', header=None, usecols=[1])
    # A missing description is read as NaN (a float), which the tokenizer rejects.
    # Embed it as an empty string instead so the rows stay aligned with the file.
    texts = data[1].fillna('').astype(str).tolist()
    num_rows = len(texts)
    # Ceiling division. `num_rows // batch_size + 1` produced an extra, empty batch
    # whenever num_rows was an exact multiple of batch_size.
    num_batches = (num_rows + batch_size - 1) // batch_size
    # Run the inputs on whatever device the model's weights live on (CPU or GPU).
    device = next(bert_model.parameters()).device
    features = []

    print(f'{_timestamp()} 开始从{text_file}提取特征')

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, num_rows)
        batch_texts = texts[start_idx:end_idx]
        encoded_inputs = bert_tokenizer(batch_texts, padding='longest', truncation=True, return_tensors='pt')
        encoded_inputs = {name: tensor.to(device) for name, tensor in encoded_inputs.items()}

        with torch.no_grad():
            outputs = bert_model(**encoded_inputs)
        # .cpu() is a no-op for CPU tensors and required before .numpy() for GPU tensors.
        features.append(outputs.pooler_output.cpu().numpy())

        if (i + 1) % 10 == 0 or i == num_batches - 1:
            # end_idx is exactly the number of rows processed so far.
            print(
                f'{_timestamp()} 已完成{i + 1}/{num_batches}批次，{end_idx}/{num_rows}行')

    text_features = np.concatenate(features)
    out_dir = os.path.dirname(feature_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(feature_file, 'wb') as fh:  # `with` guarantees the handle is flushed and closed
        pickle.dump(text_features, fh)
    print(f'{_timestamp()} 文本特征提取完成，已保存至{feature_file}')

In [22]:
# Run configuration: which dataset to embed, how many descriptions per forward
# pass, where its descriptions are read from and where the pickled features go.
# Both paths are relative to the notebook's working directory.
dataset_name = 'DB15K'
batch_size = 16
text_file = f'../IMF-Pytorch/datasets/{dataset_name}/entity_description.txt'
feature_file = f'{dataset_name}/text_features.pkl'

extract_features(batch_size=batch_size, text_file=text_file, feature_file=feature_file)

[22:46:21] 开始从../IMF-Pytorch/datasets/DB15K/entity_description.txt提取特征
[22:47:09] 已完成10/803批次，160/12842行
[22:47:55] 已完成20/803批次，320/12842行
[22:48:39] 已完成30/803批次，480/12842行
[22:49:28] 已完成40/803批次，640/12842行
[22:50:19] 已完成50/803批次，800/12842行
[22:51:09] 已完成60/803批次，960/12842行
[22:52:00] 已完成70/803批次，1120/12842行
[22:52:51] 已完成80/803批次，1280/12842行
[22:53:40] 已完成90/803批次，1440/12842行
[22:54:29] 已完成100/803批次，1600/12842行
[22:55:21] 已完成110/803批次，1760/12842行
[22:56:14] 已完成120/803批次，1920/12842行
[22:57:04] 已完成130/803批次，2080/12842行
[22:57:52] 已完成140/803批次，2240/12842行
[22:58:47] 已完成150/803批次，2400/12842行
[22:59:48] 已完成160/803批次，2560/12842行
[23:00:42] 已完成170/803批次，2720/12842行
[23:01:39] 已完成180/803批次，2880/12842行
[23:02:38] 已完成190/803批次，3040/12842行
[23:03:31] 已完成200/803批次，3200/12842行
[23:04:24] 已完成210/803批次，3360/12842行
[23:05:18] 已完成220/803批次，3520/12842行
[23:06:06] 已完成230/803批次，3680/12842行
[23:07:01] 已完成240/803批次，3840/12842行
[23:07:54] 已完成250/803批次，4000/12842行
[23:08:46] 已完成260/803批次，4160/12842行
[23:09:3