In [1]:
import os

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from tqdm.notebook import tqdm
from TorchCRF import CRF

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.tensorboard import SummaryWriter
from collections import deque, defaultdict
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. 加载BERT-large模型

In [2]:
tokenizer = BertTokenizer.from_pretrained(
    "bert-large-cased", cache_dir="../../../BERT/large"
)

model = BertModel.from_pretrained("bert-large-cased", cache_dir="../../../BERT/large")
model = prepare_model_for_kbit_training(model)

In [3]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inpl

In [4]:
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return list(lora_module_names)

In [5]:
modules = find_all_linear_names(model)
config = LoraConfig(r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=['query', 'value'],
        task_type="TOKEN_CLS",
    )
lora_model = get_peft_model(model, config)
lora_model.to(device)

PeftModelForTokenClassification(
  (base_model): LoraModel(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-23): 24 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): lora.Linear(
                  (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=16, bias=False)
                  )
      

In [6]:
# 模型大小
print(sum(i.numel() for i in lora_model.parameters() if i.requires_grad) / 1000000)

1.572864


# 2. 获取数据

In [7]:
df = pd.read_csv("./relations/sample.csv")
df = df[df['entity'].str[:]!='[]']
df["label"] = df["label"].apply(
    lambda x: x[8:-2].replace("\n", "").replace(" ", "").split(",")
)
df["label"] = df["label"].apply(lambda x: [int(i) for i in x])
df["text"] = df["text"].apply(lambda x: x[2:-2].split("', '"))

FileNotFoundError: [Errno 2] No such file or directory: './relations/sample.csv'

# 3. 创建dataset

In [None]:
class NERDataset(Dataset):
    def __init__(self, dataframe, tokenizer):
        self.data = dataframe
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        text = self.data["text"][index]
        labels = self.data["label"][index][1:-1]

        # Convert tokens to token IDs
        tokens = [i.replace(" ", "") for i in text][1:-1]

        return tokens, labels

## 3.1. 创建数据整理函数

In [None]:
def batch_tokenizer(input_text, max_len=50):
    res = defaultdict(list)
    max_len -= 2
    for text in input_text:
        ids = tokenizer.convert_tokens_to_ids(text)
        valid_len = len(ids) + 2
        if len(ids) > max_len:
            ids = ids[:max_len]
            ids = [101] + ids + [102]
            attention_mask = [1] * (max_len + 2)
        else:
            ids = [101] + ids + [102] + [0] * (max_len - len(ids))
            attention_mask = [1] * valid_len + [0] * (max_len - valid_len + 2)
        res['input_ids'].append(ids)
        res['attention_mask'].append(attention_mask)
    res['input_ids'] = torch.tensor(res['input_ids']).to(device)
    res['attention_mask'] = torch.tensor(res['attention_mask']).to(device)
    return res


In [None]:
# 数据整理函数
def collate_fn(data):
    tokens = [i[0] for i in data]
    labels = [i[1] for i in data]
    inputs = batch_tokenizer(tokens)

    lens = inputs["input_ids"].shape[1]

    for i in range(len(labels)):
        labels[i] = [3] + labels[i]
        labels[i] += [3] * lens
        labels[i] = labels[i][:lens]

    return inputs, torch.LongTensor(labels)

## 3.2. 拆分数据集

In [None]:
# Set batch size
batch_size = 32

# Create the DataLoader
# Split the dataset into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# Create the training dataset and dataloader
train_dataset = NERDataset(train_df.reset_index(), tokenizer)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

# Create the validation dataset and dataloader
val_dataset = NERDataset(val_df.reset_index(), tokenizer)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

In [12]:
# 模型试算
for i, j in train_dataloader:
    print(lora_model(i['input_ids'], i['attention_mask']).last_hidden_state.shape)
    break

torch.Size([32, 50, 1024])


# 4. 搭建微调模型

In [13]:
# 定义下游模型
class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super().__init__()
        self.pretrained = pretrained
        self.hidden_size = deque(pretrained.parameters())[-1].shape[0]
        self.fc1 = torch.nn.Linear(self.hidden_size, int(self.hidden_size / 2))
        self.swish = torch.nn.SiLU()
        self.fc2 = torch.nn.Linear(int(self.hidden_size / 2), 4)
        self.crf = CRF(num_tags=4, batch_first=True)

    def forward(self, inputs):
        with torch.no_grad():
            out = self.pretrained(**inputs).last_hidden_state

        out = self.fc2(self.swish(self.fc1(out))).softmax(dim=2)

        return out

    def loss(self, inputs, labels):
        with torch.no_grad():
            out = self.pretrained(**inputs).last_hidden_state

        out = self.fc2(self.swish(self.fc1(out))).softmax(dim=2)
        loss = -self.crf(
            out, labels, inputs["attention_mask"].to(torch.bool), reduction="mean"
        )

        return loss


mymodel = Model(lora_model)
mymodel = mymodel.to(device)

In [14]:
print(sum(i.numel() for i in mymodel.parameters() if i.requires_grad) / 1000000)

2.09974


In [15]:
# 对计算结果和label变形,并且移除pad
def reshape_and_remove_pad(outs, labels, attention_mask):
    # 变形,便于计算loss
    # [b, lens, 8] -> [b*lens, 8]
    outs = outs.reshape(-1, 4)
    # [b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    # 忽略对pad的计算结果
    # [b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1
    outs = outs[select]
    labels = labels[select]

    return outs, labels

In [16]:
# 获取正确数量和总数
def get_correct_and_total_count(labels, outs):
    # [b*lens, 8] -> [b*lens]
    outs = outs.argmax(dim=1)
    correct = (outs == labels).sum().item()
    total = len(labels)

    # 计算除了0以外元素的正确率,因为0太多了,包括的话,正确率很容易虚高
    select = labels != 0
    outs = outs[select]
    labels = labels[select]
    correct_content = (outs == labels).sum().item()
    total_content = len(labels)

    return correct, total, correct_content, total_content

# 5. 训练

In [17]:
# 训练
def train(loader, epochs, model=None, optimizer=None, scheduler=None):
    writer = SummaryWriter()

    # 训练
    criterion = torch.nn.CrossEntropyLoss()
    mymodel.train()
    i = 0
    for epoch in range(epochs):

        for step, (inputs, labels) in tqdm(
            enumerate(loader), desc="Epoch" + str(epoch + 1), total=len(loader)
        ):
            i += 1
            optimizer.zero_grad()
            labels = labels.to(device)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outs = model(inputs)
            loss = model.loss(inputs, labels)
            writer.add_scalar("Loss/train", loss, i)
            loss.backward()
            optimizer.step()
            outs, labels = reshape_and_remove_pad(
                outs, labels=labels, attention_mask=inputs["attention_mask"]
            )
            counts = get_correct_and_total_count(labels, outs)

            accuracy = counts[0] / counts[1]
            accuracy_content = counts[2] / counts[3]

            writer.add_scalar("accuracy/train", accuracy, i)
            writer.add_scalar("true_accuracy/train", accuracy_content, i)
        if scheduler is not None:
            print("lr: ", optimizer.param_groups[0]["lr"])
            scheduler.step()
        torch.save(model.state_dict(), "../../../BERT/NER_FT/NER_FT.pkl")
        torch.save(optimizer.state_dict(), "../../../BERT/NER_FT/NER_FT_optimizer.pkl")
    writer.flush()
    writer.close()

In [18]:
# 检查训练数据
for i, j in train_dataloader:
    print(tokenizer.decode(i['input_ids'][0]))
    list_a = []
    for t in range(len(j[0])):
        if j[0][t]!=0:
            list_a.append(i['input_ids'][0][t])
        else:
            list_a.append(119)
    print(tokenizer.decode(list_a))
    print('===================')

[CLS] Hessian H that we introduced for our analysis of L1 regularization, we find that [UNK] = H i, i w. If [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS].......... L1 regularization............. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS] ( Krizhevsky et al., 2012 ; Ioffe and Szegedy, 2015 ). Object recognition is the same basic technology [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS]............................... [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[CLS] decrease by various amounts. However, empirical risk minimization is prone to overfitting. Models [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [25]:
lr = 1e-3
epochs = 100
optimizer = torch.optim.AdamW(mymodel.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, epochs, eta_min=0, last_epoch=-1
)
if os.path.exists("../../../BERT/NER_FT/NER_FT.pkl"):
    mymodel.load_state_dict(torch.load("../../../BERT/NER_FT/NER_FT.pkl"))
    optimizer.load_state_dict(torch.load("../../../BERT/NER_FT/NER_FT_optimizer.pkl"))

train(
    epochs=epochs,
    model=mymodel,
    loader=train_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
)

Epoch1:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.001


Epoch2:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009997532801828658


Epoch3:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009990133642141358


Epoch4:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00099778098230154


Epoch5:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000996057350657239


Epoch6:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009938441702975688


Epoch7:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009911436253643444


Epoch8:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009879583809693736


Epoch9:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009842915805643154


Epoch10:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009801468428384714


Epoch11:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009755282581475767


Epoch12:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009704403844771127


Epoch13:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009648882429441257


Epoch14:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009588773128419905


Epoch15:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009524135262330098


Epoch16:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009455032620941839


Epoch17:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009381533400219318


Epoch18:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009303710135019719


Epoch19:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009221639627510076


Epoch20:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000913540287137281


Epoch21:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0009045084971874739


Epoch22:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008950775061878453


Epoch23:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008852566213878948


Epoch24:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008750555348152299


Epoch25:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008644843137107058


Epoch26:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008535533905932738


Epoch27:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008422735529643445


Epoch28:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000830655932661826


Epoch29:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000818711994874345


Epoch30:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0008064535268264884


Epoch31:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007938926261462368


Epoch32:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007810416889260655


Epoch33:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007679133974894984


Epoch34:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007545207078751858


Epoch35:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007408768370508578


Epoch36:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007269952498697736


Epoch37:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0007128896457825365


Epoch38:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006985739453173904


Epoch39:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006840622763423392


Epoch40:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006693689601226459


Epoch41:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006545084971874739


Epoch42:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006394955530196148


Epoch43:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006243449435824275


Epoch44:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0006090716206982715


Epoch45:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005936906572928626


Epoch46:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005782172325201157


Epoch47:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005626666167821524


Epoch48:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005470541566592573


Epoch49:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000531395259764657


Epoch50:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005157053795390644


Epoch51:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0005000000000000002


Epoch52:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00048429462046093607


Epoch53:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0004686047402353435


Epoch54:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000452945843340743


Epoch55:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00043733338321784806


Epoch56:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0004217827674798847


Epoch57:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00040630934270713783


Epoch58:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.000390928379301729


Epoch59:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0003756550564175727


Epoch60:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00036050444698038553


Epoch61:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00034549150281252655


Epoch62:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0003306310398773544


Epoch63:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00031593772365766127


Epoch64:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0003014260546826097


Epoch65:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0002871103542174637


Epoch66:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0002730047501302267


Epoch67:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00025912316294914234


Epoch68:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0002454792921248144


Epoch69:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00023208660251050164


Epoch70:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00021895831107393473


Epoch71:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00020610737385376356


Epoch72:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00019354647317351177


Epoch73:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0001812880051256552


Epoch74:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00016934406733817422


Epoch75:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0001577264470356557


Epoch76:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00014644660940672634


Epoch77:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0001355156862892944


Epoch78:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.0001249444651847703


Epoch79:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00011474337861210548


Epoch80:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  0.00010492249381215483


Epoch81:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  9.549150281252637e-05


Epoch82:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  8.645971286271923e-05


Epoch83:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  7.78360372489926e-05


Epoch84:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  6.962898649802815e-05


Epoch85:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  6.184665997806824e-05


Epoch86:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  5.449673790581613e-05


Epoch87:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  4.758647376699034e-05


Epoch88:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  4.112268715800956e-05


Epoch89:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  3.511175705587434e-05


Epoch90:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  2.9559615522887284e-05


Epoch91:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  2.447174185242324e-05


Epoch92:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  1.9853157161528526e-05


Epoch93:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  1.570841943568452e-05


Epoch94:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  1.204161903062634e-05


Epoch95:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  8.85637463565564e-06


Epoch96:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  6.155829702431171e-06


Epoch97:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  3.942649342761118e-06


Epoch98:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  2.2190176984600023e-06


Epoch99:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  9.866357858642206e-07


Epoch100:   0%|          | 0/27 [00:00<?, ?it/s]

lr:  2.467198171342e-07


In [26]:
# 测试
def test(loader_test=val_dataloader):
    mymodel.load_state_dict(torch.load("../../../BERT/NER_FT/NER_FT.pkl"))
    mymodel.eval()

    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    for step, (inputs, labels) in enumerate(loader_test):
        labels = labels.to(device)
        print(step)

        with torch.no_grad():
            # [b, lens] -> [b, lens, 8] -> [b, lens]
            outs = mymodel(inputs)

        # 对outs和label变形,并且移除pad
        # outs -> [b, lens, 8] -> [c, 8]
        # labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels, inputs["attention_mask"])

        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

        print(correct / total, correct_content / total_content)


test(val_dataloader)

0
0.8875338753387534 0.8354430379746836
1
0.8782263401720715 0.7580645161290323
2
0.8779946761313221 0.7814207650273224
3
0.879124469127736 0.7879213483146067
4
0.8721234309623431 0.7956043956043956
5
0.8703301476976543 0.799625468164794


In [27]:
# 测试
def predict(loader_test=val_dataloader):
    mymodel.load_state_dict(torch.load("../../../BERT/NER_FT/NER_FT.pkl"))
    mymodel.eval()

    for i, (inputs, labels) in enumerate(loader_test):

        inputs["input_ids"] = inputs["input_ids"].to(device)
        inputs["attention_mask"] = inputs["attention_mask"].to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outs = torch.tensor(mymodel.crf.decode(mymodel(inputs))).to(device)
        for i in range(5):
            # 移除pad
            select = inputs["attention_mask"][i] == 1
            input_id = inputs["input_ids"][i, select]
            out = outs[i, select]
            label = labels[i, select]

            # 输出原句子
            print(tokenizer.decode(input_id))

            # 输出tag
            for index, tag in enumerate([label, out]):
                prefix = "label" if index == 0 else "predict"
                s = ""
                for j in range(len(tag)):
                    if tag[j] == 0:
                        s += "·"
                        continue
                    s += (
                        tokenizer.decode(input_id[j]).replace(" ", "").replace("##", "")
                        + " "
                    )
                    s += " "

                print(prefix + ": " + s)
            print("==========================")


predict(train_dataloader)

[CLS] “ batch gradient descent ” implies the use of the full training set, while the use of the term [SEP]
label: [CLS]  ··········full  training  set  ·······[SEP]  
predict: [CLS]  ·batch  gradient  descent  ·················
[CLS] about the whole pastsequence. Recurrent neural networks can be built in many different ways. Much [SEP]
label: [CLS]  ·······Re  current  neural  networks  ·········[SEP]  
predict: [CLS]  ·······Re  current  neural  networks  ··········
[CLS] of the ∈ ∞ norm penalty term, Ω, relative to the standard objective function J. Setting α to 0 [SEP]
label: [CLS]  ········Ω  ··············[SEP]  
predict: [CLS]  ························
[CLS] disadvantage of having two separate training phases is that each phase has its own hyperparameters. [SEP]
label: [CLS]  ··············h  yper  par  ame  ters  ·[SEP]  
predict: [CLS]  ··············h  yper  par  ame  ters  ··
[CLS] + + bbxx + + bbww + + ccxx + + ccww + + ddxx + + eeyy + + ffzz ffyy + + ggzz ggy [SEP]
label: [