In [21]:
from transformers import AutoTokenizer
import transformers
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoModel
from transformers import AdamW
import random
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm
from torchcrf import CRF  # 引入 CRF

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device  # bare expression so the notebook displays the selected device

device(type='cuda', index=0)

In [22]:
# Load the tokenizer and the pretrained encoder from a local checkpoint
# (local_files_only=True: never reach out to the Hugging Face hub).
model_path = "../model/GujiRoBERTa_jian_fan"
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

# `Model.forward` and `Model.fine_tuneing` look the encoder up through the
# module-level name `pretrained`, so it has to be bound here; otherwise the
# frozen-encoder path raises NameError on a fresh kernel.
pretrained = AutoModel.from_pretrained(model_path, local_files_only=True).to(device)

# Keep the historical `model` binding for backward compatibility. It is
# rebound to the downstream `Model` instance further down in the notebook.
model = pretrained

Some weights of BertModel were not initialized from the model checkpoint at ../model/GujiRoBERTa_jian_fan and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
class TextDataset(Dataset):
    """Character-level dataset read from a UTF-8 text file, one sentence per line.

    Every line is stripped and split into a list of single characters.
    Sentences longer than ``max_length`` characters are dropped; ``len()`` and
    indexing both operate on the remaining (filtered) sentences.

    Args:
        text_file: Path of the text file to read.
        max_length: Maximum number of characters a sentence may have to be kept.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, text_file, max_length=510):
        self.text_file = text_file
        self.max_length = max_length
        # All sentences as read from disk (unfiltered).
        self.texts = self._load_data()
        # Sentences that passed the length filter; this is what gets served.
        self.dataset = self._filter_long_sentences()

    def _filter_long_sentences(self):
        """Return the loaded sentences that have at most ``max_length`` characters."""
        return [text for text in self.texts if len(text) <= self.max_length]

    def _load_data(self):
        """Read ``text_file`` and return one list of characters per line."""
        with open(self.text_file, 'r', encoding='utf-8') as f_text:
            return [list(line.strip()) for line in f_text]

    def __len__(self):
        # Must report the size of the collection that __getitem__ indexes.
        # Reporting len(self.texts) made samplers request indices past the end
        # of the filtered list whenever a sentence had been dropped.
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``idx``-th kept sentence as a list of characters."""
        return self.dataset[idx]

# Build the character-level dataset from the raw text file; TextDataset drops
# sentences longer than its default max_length.
text_file = '../data/text_B.txt'
dataset = TextDataset(text_file)

def collate_fn(data):
    """Tokenize a batch of pre-split sentences and move the result to ``device``.

    Args:
        data: A batch of sentences, each already split into a list of tokens
            (the dataset in this notebook yields lists of single characters).

    Returns:
        The padded encoding produced by the module-level ``tokenizer``, as
        PyTorch tensors placed on the module-level ``device``.
    """
    encode_options = dict(
        truncation=True,
        padding=True,
        return_tensors='pt',
        is_split_into_words=True,
    )
    batch = tokenizer.batch_encode_plus(data, **encode_options)
    return batch.to(device)

# Downstream model definition
class Model(torch.nn.Module):
    """Sequence tagger: encoder hidden states -> LSTM -> linear layer -> CRF.

    The encoder is reached through the module-level name ``pretrained``. It is
    only stored on the instance (``self.pretrained``) while fine-tuning is
    switched on via ``fine_tuneing(True)``.
    """

    def __init__(self):
        super().__init__()
        # Fine-tuning is off by default: the encoder is used through the
        # module-level `pretrained` and is not attached to this module.
        self.tuneing = False
        self.pretrained = None

        hidden_size = 768
        num_tags = 14
        self.rnn = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, inputs, labels=None):
        """Compute the CRF loss when ``labels`` is given, otherwise decode tags.

        Args:
            inputs: Tokenizer output; it is unpacked into the encoder call and
                its ``'attention_mask'`` entry is used as the CRF mask.
            labels: Optional gold tags. When provided, the loss is returned.

        Returns:
            With ``labels``: the negated value returned by the CRF layer called
            with ``reduction='mean'``. Without ``labels``: the result of
            ``self.crf.decode``.
        """
        if self.tuneing:
            hidden = self.pretrained(**inputs).last_hidden_state
        else:
            # Encoder is frozen: skip gradient tracking for its forward pass.
            with torch.no_grad():
                hidden = pretrained(**inputs).last_hidden_state

        lstm_out, _ = self.rnn(hidden)
        emissions = self.fc(lstm_out)
        mask = inputs['attention_mask'].bool()

        if labels is None:
            # No labels supplied: decode with the CRF.
            return self.crf.decode(emissions, mask=mask)

        # Labels supplied: compute the CRF loss.
        return -self.crf(emissions, labels, mask=mask, reduction='mean')

    def fine_tuneing(self, tuneing):
        """Switch fine-tuning of the module-level ``pretrained`` encoder on or off.

        When switched on, the encoder's parameters become trainable, it is put
        in train mode and attached as ``self.pretrained``. When switched off,
        the parameters are frozen, the encoder is put in eval mode and
        ``self.pretrained`` is reset to ``None``.
        """
        self.tuneing = tuneing

        if not tuneing:
            for param in pretrained.parameters():
                param.requires_grad_(False)
            pretrained.eval()
            self.pretrained = None
            return

        for param in pretrained.parameters():
            param.requires_grad = True
        pretrained.train()
        self.pretrained = pretrained

# Instantiate the downstream tagger on the selected device. This rebinds the
# name `model`, which an earlier cell also assigned.
model = Model().to(device)

# loader = torch.utils.data.DataLoader(dataset=dataset,
#                                      batch_size=1,
#                                      collate_fn=collate_fn,
#                                      shuffle=False)

# def predict():
#     model_load = torch.load('../model/NER_crf_attention_C.model', weights_only=False)
#     model_load.eval()

#     loader_test = torch.utils.data.DataLoader(dataset=TextDataset(text_file),
#                                               batch_size=2,
#                                               collate_fn=collate_fn,
#                                               shuffle=False)

#     for i, inputs in enumerate(loader_test):
#         break

#     with torch.no_grad():
#         outs = model_load(inputs)

#     for i in range(2):
#         select = inputs['attention_mask'][i] == 1
#         input_id = inputs['input_ids'][i, select]
#         out = outs[i]

#         print(tokenizer.decode(input_id).replace(' ', ''))

#         s = ''
#         for j in range(len(out)):
#             if out[j] == 0:
#                 s += '·'
#                 continue
#             s += tokenizer.decode(input_id[j])
#             s += str(out[j])
#         print("Out:", s)
#         print('==========================')
# predict()

In [24]:
def predict(input_file='../data/TestSet/test_B.txt',
            model_file='../model/NER_crf_lstm_B.model',
            output_file='../data/TestSet/output_B.txt'):
    """Tag every non-empty line of ``input_file`` and write the label ids to ``output_file``.

    Each non-empty input line is split into single characters, encoded with the
    module-level ``tokenizer`` and decoded by the model stored in
    ``model_file``. One line per sentence is written, formatted as
    ``[id, id, ...]``.

    Args:
        input_file: UTF-8 text file with one sentence per line.
        model_file: Path of the serialized model object to load.
        output_file: Destination file for the predicted label ids.

    The defaults reproduce the paths that were previously hard-coded.
    """
    # Read the text file.
    with open(input_file, 'r', encoding='utf-8') as f:
        text = f.read()

    # Empty lines are skipped, so output line k corresponds to the k-th
    # NON-EMPTY input line.
    sentences = [list(line) for line in text.split("\n") if line]

    if not sentences:
        # Nothing to tag: emit an empty result file instead of feeding an
        # empty batch to the tokenizer/model.
        with open(output_file, 'w', encoding='utf-8') as f:
            pass
        return

    # Load the model.
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load model files that come from a trusted source.
    model_load = torch.load(model_file, map_location=device, weights_only=False)
    model_load.eval()

    # Encode all sentences as one padded batch.
    # NOTE(review): the recorded run warns that no maximum length is configured,
    # so `truncation=True` appears to have no effect here — confirm.
    inputs = tokenizer(sentences,
                       truncation=True,
                       padding=True,
                       return_tensors='pt',
                       is_split_into_words=True).to(device)

    with torch.no_grad():
        outs = model_load(inputs)

    # Drop the first and last label of every decoded sequence (assumed to be
    # the positions of the special tokens the tokenizer adds — TODO confirm)
    # and format the rest as a bracketed, comma-separated list.
    results = [
        "[" + ", ".join(str(label) for label in out[1:-1]) + "]"
        for out in outs
    ]

    # Write the results to the output file.
    with open(output_file, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(result + "\n")

# Run inference on the test set and write the predicted label ids to disk.
predict()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [25]:
import os

def denumericalize_labels(numericalized_labels, label_map):
    """Convert lists of numeric label ids back into lists of label names.

    Args:
        numericalized_labels: Iterable of id sequences, one per sentence.
        label_map: Mapping from label name to numeric id.

    Returns:
        A list with one list of label names per input sequence.

    Raises:
        KeyError: If an id does not occur among the values of ``label_map``.
    """
    id_to_name = {idx: name for name, idx in label_map.items()}
    return [[id_to_name[idx] for idx in ids] for ids in numericalized_labels]

def create_original_format(text_file, label_file, output_file, label_to_id=None):
    """Combine a text file and a label-id file into ``char<TAB>label`` lines.

    Line k of ``label_file`` must hold the Python-literal list of label ids for
    the k-th NON-EMPTY line of ``text_file``. Empty text lines are skipped,
    mirroring how ``predict`` builds its batch; without that, one blank line
    would pair every following sentence with the wrong labels.

    Args:
        text_file: UTF-8 text file with one sentence per line.
        label_file: UTF-8 file with one list of numeric label ids per line.
        output_file: Destination file; one ``character<TAB>label`` pair per line.
        label_to_id: Optional mapping from label name to numeric id. Defaults
            to the module-level ``label_map``.

    Raises:
        KeyError: If a label id is not present in the mapping.
        ValueError, SyntaxError: If a label line is not a valid Python literal.
    """
    import ast  # local import keeps this block self-contained

    if label_to_id is None:
        label_to_id = label_map
    # Build the id -> name lookup once instead of once per sentence.
    id_to_label = {idx: name for name, idx in label_to_id.items()}

    with open(text_file, 'r', encoding='utf-8') as text_f, open(label_file, 'r', encoding='utf-8') as label_f:
        texts = [line for line in text_f.read().split("\n") if line]
        label_lines = label_f.readlines()

    with open(output_file, 'w', encoding='utf-8') as output_f:
        for text, label_line in zip(texts, label_lines):
            # literal_eval only accepts Python literals, unlike eval(), which
            # would execute arbitrary code found in the label file.
            label_ids = ast.literal_eval(label_line.strip())
            tags = [id_to_label[label_id] for label_id in label_ids]
            for char, tag in zip(text.strip(), tags):
                output_f.write(f"{char}\t{tag}\n")

# label_map = {
#     "O": 0,
#     "B-NR": 1,
#     "M-NR": 2,
#     "E-NR": 3,
#     "S-NR": 4,
#     "B-NS": 5,
#     "M-NS": 6,
#     "E-NS": 7,
#     "S-NS": 8,
#     "B-NB": 9,
#     "M-NB": 10,
#     "E-NB": 11,
#     "S-NB": 12,
#     "B-NO": 13,
#     "M-NO": 14,
#     "E-NO": 15,
#     "S-NO": 16,
#     "B-NG": 17,
#     "M-NG": 18,
#     "E-NG": 19,
#     "S-NG": 20,
#     "B-T": 21,
#     "M-T": 22,
#     "E-T": 23,
#     "S-T": 24,
# }

# Active label scheme: maps each tag name to its numeric id.
# "O" plus B-/M-/E-/S- prefixed tags for the NR, NS and T categories.
label_map = {
    "O": 0,
    "B-NR": 1,
    "M-NR": 2,
    "E-NR": 3,
    "S-NR": 4,
    "B-NS": 5,
    "M-NS": 6,
    "E-NS": 7,
    "S-NS": 8,
    "B-T": 9,
    "M-T": 10,
    "E-T": 11,
    "S-T": 12,
}

# label_map = {
#     "O": 0,
#     "B-ZD": 1,
#     "M-ZD": 2,
#     "E-ZD": 3,
#     "S-ZD": 4,
#     "B-ZZ": 5,
#     "M-ZZ": 6,
#     "E-ZZ": 7,
#     "S-ZZ": 8,
#     "B-ZF": 9,
#     "M-ZF": 10,
#     "E-ZF": 11,
#     "S-ZF": 12,
#     "B-ZP": 13,
#     "M-ZP": 14,
#     "E-ZP": 15,
#     "S-ZP": 16,
#     "B-ZS": 17,
#     "M-ZS": 18,
#     "E-ZS": 19,
#     "S-ZS": 20,
#     "B-ZA": 21,
#     "M-ZA": 22,
#     "E-ZA": 23,
#     "S-ZA": 24,
# }

# Input text, predicted label ids (written by predict()) and the destination
# for the combined character/label file.
text_test_file = '../data/TestSet/test_B.txt'
label_test_file = '../data/TestSet/output_B.txt'
output_test_file = '../data/TestSet/predicted_B.txt'

# Merge the text and the predicted ids into character<TAB>label lines.
create_original_format(text_test_file, label_test_file, output_test_file)

In [None]:
# def find_first_difference(file1, file2):
#     with open(file1, 'r', encoding='utf-8') as f1, open(file2, 'r', encoding='utf-8') as f2:
#         lines1 = [line.strip().split()[0] for line in f1 if line.strip()]  # 提取第一列内容
#         lines2 = [line.strip().split()[0] for line in f2 if line.strip()]  # 提取第一列内容

#     # 比较两个文本内容
#     for i, (char1, char2) in enumerate(zip(lines1, lines2)):
#         if char1 != char2:
#             return i + 1  # 返回第一个不同的位置（从1开始计数）

#     # 如果一个文件比另一个文件长，返回第一个多出字符的位置
#     if len(lines1) != len(lines2):
#         return min(len(lines1), len(lines2)) + 1

#     # 如果完全相同，返回-1
#     return -1

# file1 = "../data/EvaHan2025_traingdata/trainset_C.txt"
# file2 = "../data/reconstructed_testset_C.txt"
# result = find_first_difference(file1, file2)
# print(f"第一个不同的位置是：{result}")