In [1]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertForTokenClassification, BertTokenizer

from CSCDatasets import DetectionDataset

In [2]:
# Setup: device, training data, tokenizer and token-classification model.
SEED = 42  # seeds the DataLoader shuffle order and the random init of the new classifier head
PRETRAINED_DIR = "../pretrained_models/bert-base-chinese"
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_dataset = DetectionDataset("../dataset/detect_train.tsv")
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_DIR)
# num_labels=2 is the config default; it is spelled out because later cells read
# label index 1 of the softmax as the "error" probability.
model = BertForTokenClassification.from_pretrained(
    PRETRAINED_DIR, num_labels=2, use_safetensors=True
).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Fine-tune the detector for one pass over the training data.
lr = 1e-5
loss_fn = nn.CrossEntropyLoss().to(device)  # default ignore_index=-100 is used for padding below
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# from_pretrained() returns the model in eval mode (dropout disabled); switch to training mode.
model.train()

total = len(train_dataloader)
with tqdm(train_dataloader, total=total) as progress:
    for text, label in progress:
        inputs = tokenizer(text, return_tensors='pt', max_length=256, truncation=True, padding='max_length').to(device)
        # The loss below forces label to be (batch, 256), the same shape as attention_mask.
        # Padding positions are set to -100 so CrossEntropyLoss ignores them instead of
        # averaging them into the loss (a no-op if the dataset already pads with -100).
        label = label.to(device).masked_fill(inputs["attention_mask"] == 0, -100)
        outputs = model(**inputs).logits  # (batch, 256, num_labels)
        # CrossEntropyLoss expects the class dimension second: (batch, num_labels, seq_len).
        loss = loss_fn(outputs.permute(0, 2, 1), label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.set_postfix(loss=loss.item())

100%|██████████| 7869/7869 [44:37<00:00,  2.94it/s, loss=0.004]   


In [22]:
# Restore fine-tuned weights into `model`.
# map_location remaps tensors saved from a CUDA run onto the current device, so this also
# works on a CPU-only kernel.
# NOTE(review): this reads from "../", while the save cell writes to "../check_point/" —
# confirm which directory is intended.
state_dict = torch.load("../bert_base_cn_detect_0.ckpt", map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [89]:
# Inference: report characters whose predicted "error" probability exceeds a threshold.
ERROR_THRESHOLD = 0.1  # probability above which a character is reported
text = ""  # TODO: sentence to check; an empty string produces no output

model.eval()  # disable dropout so predictions are deterministic
with torch.no_grad():
    # Same truncation length as training; also avoids crashing on inputs longer than BERT's limit.
    inputs = tokenizer(text, return_tensors='pt', max_length=256, truncation=True).to(device)
    output = model(**inputs).logits
# Column 1 of the softmax is read as the "error" class; [1:-1] drops the first and last
# positions (the special tokens BertTokenizer adds around the text by default).
error_prob = torch.nn.functional.softmax(output, dim=-1)[0, :, 1].tolist()[1:-1]
# NOTE(review): zip pairs characters with tokens 1:1, which holds for plain Chinese text but not
# for whitespace, Latin words or digits (wordpieces / dropped spaces shift the alignment).
for c, l in zip(text, error_prob):
    if l > ERROR_THRESHOLD:
        print(f"error: {c}")

In [46]:
# Persist the current weights; create the checkpoint directory first because torch.save
# does not create missing parent directories.
CHECKPOINT_PATH = Path("../check_point/bert_base_cn_detect_1.ckpt")
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)

In [60]:
# Per-token "error" probabilities (label index 1) from the last inference, without the
# first/last special-token positions.
output.softmax(dim=-1)[0, :, 1][1:-1].tolist()

[8.605405309936032e-05,
 0.0016445911023765802,
 0.3076367676258087,
 0.9885391592979431,
 0.0012921460438519716,
 0.0002344208478461951,
 8.365231769857928e-05,
 0.00010163187107536942,
 0.0005347515107132494,
 0.0008677940932102501,
 0.00024000425764825195,
 7.675557571928948e-05,
 7.319291034946218e-05,
 7.76168963056989e-05]