In [1]:
import torch
from tqdm import tqdm
from torch import nn
from model import PinyinBertForMaskedLM, PinyinBertCorrectModel
from transformers import BertForTokenClassification, BertTokenizer
from torch.utils.data import DataLoader
from CSCDatasets import DetectionDataset, CorrectDataset

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained("../pretrained_models/bert-base-chinese")
detection = BertForTokenClassification.from_pretrained("../pretrained_models/bert-base-chinese", use_safetensors=True).to(device)
corrector = PinyinBertForMaskedLM.from_pretrained("../pretrained_models/bert-base-chinese",  use_safetensors=True).to(device)
model = PinyinBertCorrectModel(detection, corrector).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../pretrained_models/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of PinyinBertForMaskedLM were not initialized from the model checkpoint at ../pretrained_models/bert-base-chinese and are newly initialized: ['bert.embeddings.pinyin_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
train_dataset = CorrectDataset("../dataset/correct_train.tsv")
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=4)

learning_rate = 1e-5
loss_fn = nn.CrossEntropyLoss(ignore_index=0).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
total = len(train_dataloader)

with tqdm(total=total) as progress:
    for original_text, correct_text, pinyin_ids in train_dataloader:
        optimizer.zero_grad()
    
        inputs = tokenizer(original_text, return_tensors="pt", max_length=256, truncation=True, padding=True).to(device)
        label = tokenizer(correct_text, return_tensors="pt", max_length=256, truncation=True, padding=True)['input_ids'].to(device)
        pinyin_ids = pinyin_ids[:, :inputs["input_ids"].size(-1)].to(device)
    
        output = model(**inputs, pinyin_ids=pinyin_ids)[0]
        loss = loss_fn(output.permute(0, 2, 1), label)
        loss.backward()
        optimizer.step()
    
        # Update the progress bar with the current loss
        progress.set_postfix(loss=loss.item())
        progress.update(1)

 71%|███████   | 5548/7869 [28:27<11:54,  3.25it/s, loss=0.0839]


RuntimeError: Expected target size [32, 254], got [32, 256]

In [9]:
inputs = tokenizer(original_text, return_tensors="pt", max_length=256, truncation=True, padding=True).to(device)
label = tokenizer(correct_text, return_tensors="pt", max_length=256, truncation=True, padding=True)['input_ids'].to(device)
pinyin_ids = pinyin_ids[:, :inputs["input_ids"].size(-1)].to(device)

In [12]:
output = model(**inputs, pinyin_ids=pinyin_ids)[0].permute(0, 2, 1)
loss_fn(output, label)

tensor(1.6400, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [16]:
a = torch.FloatTensor(1, 256, 723)
b = torch.FloatTensor(1, 256)

In [18]:
a * b.unsqueeze(-1)

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.7249e-41, -0.0000e+00, -5.6507e-41,  ..., -6.9635e-41,
          -0.0000e+00, -4.3020e-42],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]])

In [19]:
import sys
sys.path.append('..')

from build.utils import *

def predict(text, pinyin_ids):
    input = tokenizer(text, return_tensors='pt').to(device)
    output = model(**input, pinyin_ids=pinyin_ids)[0]
    output_ids = output.argmax(dim=-1)
    return "".join(tokenizer.convert_ids_to_tokens(output_ids[0]))

def get_pinyin_ids(text, pinyin_vocab, device):
    pinyin_tokens = text2pinyintoken(text)
    pinyin_ids = [0] + list(map(lambda x : pinyin_vocab[x], pinyin_tokens)) +[0]
    pinyin_ids = torch.tensor(pinyin_ids, device=device)
    return pinyin_ids

In [21]:
pinyin_vocab = {}
with open("../build/pinyin_vocab.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()
    for i, line in enumerate(lines):
        pinyin_vocab[line.strip("\n")] = i

In [123]:
text = "我们"
predict(text, get_pinyin_ids(text, pinyin_vocab, device))

'[CLS]我的父母都在国企上班[SEP]'

In [None]:
torch.save(model.state_dict()， ""