In [1]:
from whitespace_correction import WhitespaceCorrector
from tqdm.notebook import tqdm

from transformers import AutoTokenizer, AutoModel
import os
from sentence_transformers.util import cos_sim
import pandas as pd

# Send HTTP and HTTPS traffic from this process through the proxy listening on
# the loopback address, port 7890.
os.environ.update({
    'http_proxy': 'http://127.0.0.1:7890',
    'https_proxy': 'http://127.0.0.1:7890',
})

In [2]:
# Load the NV-Embed-v2 embedding model from a local checkpoint directory.
# trust_remote_code=True allows the custom model code shipped with the checkpoint to run;
# device_map="auto" leaves device placement to accelerate.
model_semantic_similarity = AutoModel.from_pretrained('/home/css/models/NV-Embed-v2', trust_remote_code=True, device_map="auto")

def calculate_semantic_similarity(sentence1, sentence2, max_length=32768):
    """Return the cosine similarity between the embeddings of two sentences.

    Both sentences are encoded in a single ``model_semantic_similarity.encode``
    call using the instruction "Retrieve semantically similar text.".

    Args:
        sentence1: First sentence to compare.
        sentence2: Second sentence to compare.
        max_length: Forwarded unchanged to ``encode`` as ``max_length``.

    Returns:
        The ``cos_sim`` result for the two embeddings, converted to a Python
        scalar with ``.item()``.
    """
    sentence_pair = [sentence1, sentence2]
    encoded = model_semantic_similarity.encode(
        sentence_pair,
        instruction="Retrieve semantically similar text.",
        max_length=max_length,
    )
    # Row 0 belongs to sentence1 and row 1 to sentence2.
    return cos_sim(encoded[0], encoded[1]).item()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


In [3]:
# Sanity check: the two strings differ only in the space before "'s" and the
# space before the final period.
text_1="The lecturer went against the authour 's insistence for the following reasons ."
text_2="The lecturer went against the authour's insistence for the following reasons."
# The last expression of the cell is displayed as its output.
calculate_semantic_similarity(text_1,text_2)

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
  self.gen = func(*args, **kwds)


0.9861305952072144

In [4]:
# Display the whitespace-correction models the library reports as available.
WhitespaceCorrector.available_models()

[ModelInfo(name='eo_large_byte', description='Byte-level model combining fast inference and good quality', tags=['default', 'lang::en', 'arch::encoder-only', 'input::byte']),
 ModelInfo(name='eo_large_char', description='Character-level model combining fast inference and good quality', tags=['lang::en', 'arch::encoder-only', 'input::char']),
 ModelInfo(name='eo_large_char_v1', description='Character-level model combining fast inference and good quality, trained with a different loss than eo_large_char', tags=['lang::en', 'arch::encoder-only', 'input::char']),
 ModelInfo(name='eo_larger_byte', description='Larger and slower than eo_large_byte, but also more accuracte', tags=['lang::en', 'arch::encoder-only', 'input::byte']),
 ModelInfo(name='eo_medium_byte', description='Smaller and faster than eo_large_byte, but less accurate', tags=['lang::en', 'arch::encoder-only', 'input::byte']),
 ModelInfo(name='eo_medium_char', description='Smaller and faster than eo_large_char, but less accurate

In [5]:
# Load the "eo_larger_byte" whitespace corrector on the first CUDA device;
# download_dir is the directory where the model files are stored.
cor = WhitespaceCorrector.from_pretrained(model="eo_larger_byte", device="cuda:0", download_dir="/home/css/models/wsc")

2024-11-12 17:17:06,477 [WHITESPACE CORRECTION DOWNLOAD] [INFO] eo_larger_byte is already downloaded to download directory /home/css/models/wsc
  return torch.load(checkpoint_path, map_location=device)
2024-11-12 17:17:07,407 [WHITESPACE CORRECTION] [INFO] running eo_huge_byte whitespace corrector on device NVIDIA GeForce RTX 4090 (24,217MiB memory, 8.9 compute capability, 128 multiprocessors)


## 1.1 读取 txt 文件，进行空白纠正（WSC），并保存为 txt 文件

In [None]:
# Run whitespace correction (WSC) on every line of the JFLEG text files and
# save the corrected lines, one per line, under jfleg_corrected/.
file_name = ["sources.txt", "corrections.txt"]  # loop-invariant, so built once
tt_lst = [0, 1]
for tt in tt_lst:
    # Text file paths
    input_file_path = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_txt/jfleg/" + file_name[tt]
    output_file_path = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_txt/jfleg_corrected/" + file_name[tt]

    # Read the file into a list, one stripped line per entry. The encoding is
    # explicit so the result does not depend on the machine's locale.
    with open(input_file_path, "r", encoding="utf-8") as file:
        jfleg_lst = [line.strip() for line in file]

    # Correct each line individually, with a progress bar.
    jfleg_lst_output = [cor.correct_text(item) for item in tqdm(jfleg_lst)]

    # Create the output directory if it is missing, so the write cannot fail
    # after all corrections have already been computed.
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

    # Joining with "\n" writes every line without a newline after the last one.
    with open(output_file_path, "w", encoding="utf-8") as file:
        file.write("\n".join(jfleg_lst_output))

  0%|          | 0/1601 [00:00<?, ?it/s]

  with amp.autocast(enabled=False):


  0%|          | 0/1601 [00:00<?, ?it/s]

## 1.2 txt 文件：按语义相似度阈值回滚（rollback）WSC 纠正

In [10]:
# Run whitespace correction (WSC) on the JFLEG text files, but roll a line back
# to its original text when the correction drifts too far semantically.
score = 0.98    # Threshold: a correction whose similarity to the original is below it is rolled back
file_name = ["sources.txt", "corrections.txt"]  # loop-invariant, so built once
tt_lst = [0, 1]
for tt in tt_lst:
    # Text file paths
    input_file_path = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_txt/jfleg/" + file_name[tt]
    output_file_path_rollback = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_txt/jfleg_rollback/" + file_name[tt]

    # Read the file into a list, one stripped line per entry. The encoding is
    # explicit so the result does not depend on the machine's locale.
    with open(input_file_path, "r", encoding="utf-8") as file:
        jfleg_lst = [line.strip() for line in file]

    jfleg_lst_output_rollback = []
    nn = 0  # number of lines rolled back in this file
    for item in tqdm(jfleg_lst):
        ori_text = item
        cor_text = cor.correct_text(item)
        # When the corrector returned the text unchanged, either branch would
        # store the same string, so the expensive embedding call is skipped.
        if cor_text != ori_text and calculate_semantic_similarity(ori_text, cor_text) < score:
            jfleg_lst_output_rollback.append(ori_text)
            nn += 1
        else:
            jfleg_lst_output_rollback.append(cor_text)

    # Create the output directory if it is missing, so the write cannot fail
    # after all the model inference has already been done.
    os.makedirs(os.path.dirname(output_file_path_rollback), exist_ok=True)

    # Joining with "\n" writes every line without a newline after the last one.
    with open(output_file_path_rollback, "w", encoding="utf-8") as file:
        file.write("\n".join(jfleg_lst_output_rollback))

    print(nn)

  0%|          | 0/1601 [00:00<?, ?it/s]

  with amp.autocast(enabled=False):
  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
  self.gen = func(*args, **kwds)


325


  0%|          | 0/1601 [00:00<?, ?it/s]

405


## 2.1 将原始文本、WSC 纠正结果和回滚结果分别保存为 Excel 文件

In [8]:
# Export three Excel files per JFLEG text file: the original lines, the
# whitespace-corrected (WSC) lines, and the rollback result in which a
# correction is kept only when it stays semantically close to the original.
score = 0.98  # Threshold: a correction whose similarity to the original is below it is rolled back
file_name = ["sources", "corrections"]  # loop-invariant, so built once


def _save_excel(ids, texts, path):
    """Write parallel id/text lists to an Excel file with columns "id" and "text"."""
    # Create the target directory if it is missing so the write cannot fail
    # after the model inference has already been done.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    pd.DataFrame({"id": ids, "text": texts}).to_excel(path, index=False)


tt_lst = [0, 1]
for tt in tt_lst:
    # Text file paths
    input_file_path = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_txt/jfleg/" + file_name[tt] + ".txt"
    output_file_path_1 = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_excel/jfleg/" + file_name[tt] + ".xlsx"
    output_file_path_2 = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_excel/jfleg_corrected/" + file_name[tt] + ".xlsx"
    output_file_path_3 = "/home/liujunhui/workspace/proj/WSC/dataset/jfleg_excel/jfleg_rollback/" + file_name[tt] + ".xlsx"

    # Read the file into a list, one stripped line per entry. The encoding is
    # explicit so the result does not depend on the machine's locale.
    with open(input_file_path, "r", encoding="utf-8") as file:
        jfleg_lst = [line.strip() for line in file]

    # Correct every line once and reuse the result for all outputs, instead of
    # re-running the corrector for each output file.
    corrected_lst = [cor.correct_text(item) for item in tqdm(jfleg_lst)]

    # Ids are 1-based and zero-padded to four digits, e.g. "a_0001".
    plain_ids = [f"a_{counter:04d}" for counter in range(1, len(jfleg_lst) + 1)]

    # output_1: the original, uncorrected lines.
    # NOTE(review): the previous code wrote the corrected text here as well,
    # making this file a duplicate of output_2. The target directory parallels
    # the uncorrected input directory, so the original text is written instead
    # -- confirm this matches the intended dataset layout.
    _save_excel(plain_ids, jfleg_lst, output_file_path_1)

    # output_2: the whitespace-corrected lines.
    _save_excel(plain_ids, corrected_lst, output_file_path_2)

    # output_3: rollback result. The id prefix records which text was kept:
    # "a_" for the original (rolled back), "b_" for the correction.
    id_lst = []
    jfleg_lst_output = []
    line_pairs = zip(jfleg_lst, corrected_lst)
    for counter, (ori_text, cor_text) in enumerate(tqdm(line_pairs, total=len(jfleg_lst)), start=1):
        # When the corrector returned the text unchanged, the correction branch
        # is taken directly and the expensive embedding call is skipped.
        if cor_text != ori_text and calculate_semantic_similarity(ori_text, cor_text) < score:
            id_lst.append(f"a_{counter:04d}")
            jfleg_lst_output.append(ori_text)
        else:
            id_lst.append(f"b_{counter:04d}")
            jfleg_lst_output.append(cor_text)

    _save_excel(id_lst, jfleg_lst_output, output_file_path_3)

  0%|          | 0/1601 [00:00<?, ?it/s]

  0%|          | 0/1601 [00:00<?, ?it/s]

  0%|          | 0/1601 [00:00<?, ?it/s]

  0%|          | 0/1601 [00:00<?, ?it/s]

  0%|          | 0/1601 [00:00<?, ?it/s]

  0%|          | 0/1601 [00:00<?, ?it/s]