In [7]:
import random

import numpy as np
import torch
from datasets import Dataset, load_from_disk
from ltp import LTP
from tqdm import tqdm
from transformers import (
    BertTokenizer,
    DataCollatorForWholeWordMask,
)

from get_chinese_ref import prepare_ref

In [8]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    The same value is passed to Python's `random`, NumPy's global RNG,
    PyTorch's CPU generator and all CUDA generators.

    Args:
        seed: Integer seed handed to each generator.

    Returns:
        The seed that was applied, unchanged.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed

In [9]:
# Reproducibility: seed all RNGs before any stochastic step runs.
seed = 3407
seed_everything(seed)

# Checkpoint name; the tokenizer vocabulary is loaded from the same location.
model_path = 'bert-base-chinese'
vocab_path = model_path

# Plain-text corpus that the loading cell reads and splits into lines.
file_path = "./shortStn.txt"

# Tokenization / masking settings.
max_length = 128          # passed to the tokenizer as the truncation length
mlm_probability = 0.15    # passed to the data collator
wwm = True                # when True, the `chinese_ref` column is computed

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [10]:
# Collect every non-blank line of the corpus file as one training sentence.
train_data = {"sentence": []}
with open(file_path, 'r', encoding="utf-8") as f:
    # `line.strip()` is falsy exactly when the line is empty or whitespace-only.
    lines = [line for line in f.read().splitlines() if line.strip()]
    # tqdm only draws a progress bar while the sentences are copied over.
    train_data["sentence"].extend(tqdm(lines))

100%|██████████| 222/222 [00:00<?, ?it/s]


In [11]:
# Wrap the collected sentences in a `Dataset`; `train_data` has a single "sentence" key,
# so the result has one column of that name.
dataset = Dataset.from_dict(train_data)
dataset

Dataset({
    features: ['sentence'],
    num_rows: 222
})

In [12]:
# Load the BERT tokenizer from `vocab_path` (the same value as `model_path`).
# NOTE(review): the recorded output shows a connection timeout to huggingface.co, yet the
# following cells ran; presumably the files came from a local cache — confirm.
tokenizer = BertTokenizer.from_pretrained(vocab_path)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /bert-base-chinese/resolve/main/vocab.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x000001C5F586A940>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt


In [13]:
def _tokenize_batch(batch):
    """Tokenize a batch of sentences and keep only the resulting token ids.

    Sequences are truncated to at most `max_length` tokens.
    """
    encoded = tokenizer(batch["sentence"], truncation=True, max_length=max_length)
    return {"input_ids": encoded.input_ids}


# Add an "input_ids" column; batched=True passes rows to the function in batches.
dataset = dataset.map(_tokenize_batch, batched=True)
dataset

  0%|          | 0/1 [00:00<?, ?ba/s]


Dataset({
    features: ['sentence', 'input_ids'],
    num_rows: 222
})

In [14]:
# Checkpoint the tokenized dataset on disk; the next cell reloads it from this same path.
dataset.save_to_disk("./dataset_temp_short")

In [15]:
# Reload the tokenized dataset from the checkpoint directory written by the previous cell.
# `load_from_disk` is imported in the notebook's import cell alongside `Dataset`.
dataset = load_from_disk("./dataset_temp_short")
dataset

Dataset({
    features: ['sentence', 'input_ids'],
    num_rows: 222
})

In [16]:
# `prepare_ref` and `LTP` are imported in the notebook's import cell.
if wwm:
    # Instantiating LTP loads and initializes a roughly 180 MB model, so this line is
    # slow (about five seconds).
    ltp = LTP().to(device)

    def _add_chinese_ref(batch):
        """Compute the `chinese_ref` column for a batch of sentences via `prepare_ref`."""
        return {"chinese_ref": prepare_ref(batch["sentence"], ltp, tokenizer)}

    # Add a "chinese_ref" column computed from the raw sentences.
    dataset = dataset.map(_add_chinese_ref, batched=True)
dataset

Loading weights from local directory


  0%|          | 0/1 [00:02<?, ?ba/s]


Dataset({
    features: ['sentence', 'input_ids', 'chinese_ref'],
    num_rows: 222
})

In [17]:
# Sanity check: show the raw text of the first example.
# Index the row first so only that row is read instead of the whole column.
dataset[0]["sentence"]

'谷物联合收获机自动测产系统设计-基于变权分层激活扩散模型'

In [18]:
# Decode the first example's token ids back to text to verify the tokenization.
# Row-first indexing avoids reading the entire `input_ids` column for a single row.
tokenizer.decode(dataset[0]["input_ids"])

'[CLS] 谷 物 联 合 收 获 机 自 动 测 产 系 统 设 计 - 基 于 变 权 分 层 激 活 扩 散 模 型 [SEP]'

In [19]:
# Inspect the `chinese_ref` entry computed for the first example (row-first indexing
# reads only that row).
# NOTE(review): presumably these are token positions that continue a segmented word —
# confirm against `prepare_ref`.
dataset[0]["chinese_ref"]

[2, 4, 6, 7, 9, 11, 13, 15, 18, 20, 22, 24, 26, 28]

In [20]:
# Create the whole-word-mask collator with MLM enabled, using the tokenizer loaded
# earlier and the `mlm_probability` from the config cell.
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability)

In [21]:
# Sanity check: collate the first example alone and decode the resulting "labels" row.
# NOTE(review): the many [UNK] tokens in the recorded output presumably correspond to
# positions the collator left unmasked (ignore label -100 in the Hugging Face MLM
# convention), which `decode` cannot map to a vocabulary entry — confirm.
tokenizer.decode(data_collator([dataset[0]])["labels"][0])

'[UNK] 谷 物 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] 扩 散 [UNK] [UNK] [UNK]'

In [22]:
# Drop the raw text column before the final save.
# The guard keeps this cell safe to re-run: removing a column that is already gone
# would raise an error.
if "sentence" in dataset.column_names:
    dataset = dataset.remove_columns("sentence")

In [23]:
# Write the final dataset (after the "sentence" column has been removed) to disk.
dataset.save_to_disk("./dataset_short")
