In [11]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# 代码复现
https://github.com/terrifyzhao/spo_extract

In [20]:
## train 代码
 
import json
from tqdm import tqdm   # 进度条
import os
import numpy as np
from transformers import BertTokenizer, AdamW
import torch
# from model import ObjectModel, SubjectModel
 
GPU_NUM = 0
 
#device = torch.device(f'cuda:{GPU_NUM}') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cpu')
 
vocab = {}
with open('vocab.txt')as file:  # 使用with open 的方法读取词典
    for l in file.readlines():
        vocab[len(vocab)] = l.strip() # 根据key读取词典
 
 
def load_data(filename):  # 中文解码加载数据
    """加载数据
    单条格式：{'text': text, 'spo_list': [[s, p, o]]}
    """
    with open(filename) as f:
        json_list = json.load(f)
    return json_list
 
 
# 加载数据集
train_data = load_data('train_fulltext.json')
valid_data = load_data('dev_fulltext.json')

In [21]:
train_data[0],len(train_data),valid_data[0],len(valid_data)

({'text': 'The Birth of a Nation, originally called The Clansman, is a 1915 American silent epic drama film directed by D. W. Griffith and starring Lillian Gish. The screenplay is adapted from Thomas Dixon Jr.\'s 1905 novel and play "". Griffith co-wrote the screenplay with Frank E. Woods and produced the film with Harry Aitken.\n"The Birth of a Nation" is a landmark of film history, lauded for its technical virtuosity. It was the first American 12-reel film ever made and, at three hours, also the longest up to that point. Its plot, part fiction and part history, chronicles the assassination of Abraham Lincoln by John Wilkes Booth and the relationship of two families in the Civil War and Reconstruction eras over the course of several years—the pro-Union (Northern) Stonemans and the pro-Confederacy (Southern) Camerons. It was originally shown in two parts separated by an intermission, and it was the first American-made film to have a musical score for an orchestra. It pioneered closeups

In [22]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")   # 调用分词器
 
with open('schemas.json') as f: # 读取predicate
    json_list = json.load(f)
    id2predicate = json_list[0]
    predicate2id = json_list[1]

In [23]:
id2predicate,predicate2id,len(id2predicate),len(predicate2id)

({'0': 'unknown', '1': 'direct', '2': 'act', '3': 'star'},
 {'unknown': 0, 'direct': 1, 'act': 2, 'star': 3},
 4,
 4)

In [24]:
from transformers import BertModel, BertPreTrainedModel

import torch.nn as nn
import torch


class SubjectModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, attention_mask=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        subject_out = self.dense(output[0])
        subject_out = torch.sigmoid(subject_out)

        return output[0], subject_out


class ObjectModel(nn.Module):
    def __init__(self, subject_model):
        super().__init__()
        self.encoder = subject_model
        self.dense_subject_position = nn.Linear(2, 768)
        self.dense_object = nn.Linear(768, 4 * 2)

    def forward(self,input_ids,subject_position,attention_mask=None):
        output, subject_out = self.encoder(input_ids, attention_mask)

        subject_position = self.dense_subject_position(subject_position).unsqueeze(1)
        object_out = output + subject_position
        # [bs, 768] -> [bs, 98]
        object_out = self.dense_object(object_out)
        # [bs, 98] -> [bs, 4, 2]
        object_out = torch.reshape(object_out, (object_out.shape[0], object_out.shape[1], 4, 2))
        object_out = torch.sigmoid(object_out)
        object_out = torch.pow(object_out, 4)
        return subject_out, object_out

In [25]:
def search(pattern, sequence):
    """从sequence中寻找子串pattern
    如果找到，返回第一个下标；否则返回-1。
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1
 
 
def sequence_padding(inputs, length=None, padding=0, mode='post'):
    """Numpy函数，将序列padding到同一长度
    """
    if length is None:
        length = max([len(x) for x in inputs])
 
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]
    outputs = []
    for x in inputs:
        x = x[:length]
        if mode == 'post':
            pad_width[0] = (0, length - len(x))
        elif mode == 'pre':
            pad_width[0] = (length - len(x), 0)
        else:
            raise ValueError('"mode" argument must be "post" or "pre".')
        x = np.pad(x, pad_width, 'constant', constant_values=padding)
        outputs.append(x)
 
    return np.array(outputs)

In [26]:
def data_generator(data, batch_size=3):  #  数据迭代器/数据生成器
 
    batch_input_ids, batch_attention_mask = [], [] #  输出给模型（object）的变量，通过调用bert分词器得到
    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
    texts = []
    for i, d in enumerate(data): #  数据来自dataloader i = 数据索引 d = text
        text = d['text'] # 从train 中取出text
        texts.append(text)   # text 贴入元组
        encoding = tokenizer(text=text) # 使用bert 分词
        input_ids, attention_mask = encoding.input_ids, encoding.attention_mask  # 分词后对应“bert词典下标”和mask
        # 整理三元组 {s: [(o, p)]}
        spoes = {}
        for s, p, o in d['spo_list']: # 遍历三元组
            # [cls] XXX [sep]
            s_encoding = tokenizer(text=s).input_ids[1:-1]  # 将s，o编码成对应的下标
            o_encoding = tokenizer(text=o).input_ids[1:-1]  # [1:-1] 去除cls sep
 
            s_idx = search(s_encoding, input_ids) # 从text的input_ids 寻找s的下标
            o_idx = search(o_encoding, input_ids) # 从text的input_ids 寻找o的下标
 
            p = predicate2id[p]  # 的到predicate的下标
 
            if s_idx != -1 and o_idx != -1: # 做判断没有反应的返回-1
                s = (s_idx, s_idx + len(s_encoding) - 1) # s保存subject的起始位置，起始位置加上长度 -1
                o = (o_idx, o_idx + len(o_encoding) - 1, p)# 同上 s,o 是一个元组保存着起始位置和终止位置的下标 以及 p
                if s not in spoes:
                    spoes[s] = []
                spoes[s].append(o) # 将 下标加入 spoes 字典当中去
        if spoes:
            # subject标签
            subject_labels = np.zeros((len(input_ids), 2)) # 生成一个input长度的二维向量/ s头s尾
            for s in spoes:
                # 注意要+1，因为有cls符号
                subject_labels[s[0], 0] = 1 # 第一行 = ‘0’ 的起始 = s[0] 等于1
                subject_labels[s[1], 1] = 1  # 第二行 = ‘1’ 的终止 =s[1] 等于1
            # 一个s对应多个o时，随机选一个subject
            start, end = np.array(list(spoes.keys())).T
            start = np.random.choice(start)
            end = np.random.choice(end[end >= start])
            subject_ids = (start, end)
            # 对应的object标签
            object_labels = np.zeros((len(input_ids), len(predicate2id), 2)) # 序列长度 x predicate长度 x 2
            for o in spoes.get(subject_ids, []): # 通过subject 拿出对应的 o
                object_labels[o[0], o[2], 0] = 1 # 对应 起始位置，predicate ， 第一维度/头（取字o元组）
                object_labels[o[1], o[2], 1] = 1 # 同上
            # 构建batch
            batch_input_ids.append(input_ids)  # 将上述值加入batch
            batch_attention_mask.append(attention_mask)
            batch_subject_labels.append(subject_labels)
            batch_subject_ids.append(subject_ids)
            batch_object_labels.append(object_labels)
            if len(batch_subject_labels) == batch_size or i == len(data) - 1: # 没有补偿
                batch_input_ids = sequence_padding(batch_input_ids)
                batch_attention_mask = sequence_padding(batch_attention_mask)
                batch_subject_labels = sequence_padding(batch_subject_labels)
                batch_subject_ids = np.array(batch_subject_ids)
                batch_object_labels = sequence_padding(batch_object_labels)
                yield [
                          torch.from_numpy(batch_input_ids).long(), torch.from_numpy(batch_attention_mask).long(),
                          torch.from_numpy(batch_subject_labels), torch.from_numpy(batch_subject_ids),
                          torch.from_numpy(batch_object_labels)
                      ], None
                batch_input_ids, batch_attention_mask = [], [] # 清空进入下个batch
                batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
 
 
if os.path.exists('graph_model.bin'):  # 加载模型 保存档将graph model 加载过来
    print('load model')
    model = torch.load('graph_model.bin').to(device)
    subject_model = model.encoder
else:

    #subject_model = SubjectModel.from_pretrained('facebook/bart-base') # 没有使用bert train
    subject_model = SubjectModel.from_pretrained('bert-base-cased') # 没有使用bert train
    subject_model.to(device)
 
    model = ObjectModel(subject_model)
    model.to(device)
 
train_loader = data_generator(train_data, batch_size=8) # dataloader = 8
 
optim = AdamW(model.parameters(), lr=5e-6) # 加速器 adamw 学习率 5e-5
loss_func = torch.nn.BCELoss() # cross binary loss
 
model.train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SubjectModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing SubjectModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SubjectModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SubjectModel were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['dense.weight', 'dense.bi

ObjectModel(
  (encoder): SubjectModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
              

In [27]:
class SPO(tuple):
    def __init__(self, spo):
        self.spox = (
            spo[0],
            spo[1],
            spo[2],
        )
 
    def __hash__(self):
        return self.spox.__hash__()
 
    def __eq__(self, spo):
        return self.spox == spo.spox
 
 
def train_func():
    train_loss = 0
    pbar = tqdm(train_loader) # 开启进度条并遍历 train_loader
    for step, batch in enumerate(pbar): # 遍历每个step 和 batch
        optim.zero_grad()  # 将每个梯度清零
        batch = batch[0] # 将batch 数据取出来第一个维度
        input_ids = batch[0].to(device) # text对应bert词典的下标
        attention_mask = batch[1].to(device) # mask
        subject_labels = batch[2].to(device) # subject对应bert词典的下标
        subject_ids = batch[3].to(device) # subject 在句子中id
        object_labels = batch[4].to(device) # object对应bert词典的下标
 
        subject_out, object_out = model(input_ids, subject_ids.float(), attention_mask) # 拿到subject和object输出
        subject_out = subject_out * attention_mask.unsqueeze(-1) # 将输入中补长的位置变成 0 / input当中的padding
        object_out = object_out * attention_mask.unsqueeze(-1).unsqueeze(-1)#  同上
 
        subject_loss = loss_func(subject_out, subject_labels.float()) # 识别subject的损失函数
        object_loss = loss_func(object_out, object_labels.float()) #  识别object的损失函数
 
        # subject_loss = torch.mean(subject_loss, dim=2)
        # subject_loss = torch.sum(subject_loss * attention_mask) / torch.sum(attention_mask)
 
        loss = subject_loss + object_loss # 将loss进行相加 根据实际情况添加超参数
 
        train_loss += loss.item() # 累加到train loss
        loss.backward() # 反向传播
        optim.step() # 更新参数
 
        pbar.update()
        pbar.set_description(f'train loss:{loss.item()}')  # 显示更新参数
 
        if step % 1000 == 0:  # 每跑1000个step 保存模型
            torch.save(model, 'graph_model.bin')
 
        if step % 5 == 0 and step != 0:  # 每跑100步在验证集当中检验效果
            with torch.no_grad():
                # texts = ['The film stars William Boyd, Russell Hayden, Andy Clyde, Eleanor Stewart, Morris Ankrum and William Haade.',
                #          'Texas Rangers Ride Again is a 1940 American Western film directed by James P. Hogan, written by William R. Lipman and Horace McCoy, and starring Ellen Drew, John Howard, Akim Tamiroff, May Robson, Broderick Crawford, Charley Grapewin, and John Miljan.']
                X, Y, Z = 1e-10, 1e-10, 1e-10
                pbar = tqdm()
                spo = []
                for data in valid_data[0:10]: # 遍历验证集
                # for text in texts:
                    text = data['text'] # 取出text
                    spo_ori = data['spo_list'] # 去除三元组
                    en = tokenizer(text=text, return_tensors='pt') # 将text分词
                    _, subject_preds = subject_model(en.input_ids.to(device), en.attention_mask.to(device)) # 检验阶段需要预测subject的下标
                    subject_preds = subject_preds.cpu().data.numpy() # 将下标转换成numpy数组
                    start = np.where(subject_preds[0, :, 0] > 0.5)[0] # 阈值，大于0.5判断为start
                    end = np.where(subject_preds[0, :, 1] > 0.4)[0] # 阈值 大于0.4判断为end # 阈值自己设定
                    subjects = []
                    for i in start: # 遍历start 用来应对多个start的情况
                        j = end[end >= i] # 只取大于start的end 否则会出现逻辑错误
                        if len(j) > 0: # 如果 end 大于0将 start end 成对加入subject
                            j = j[0]
                            subjects.append((i, j))
                    # print(subjects)
                    if subjects:
                        for s in subjects: # 遍历每个s
                            index = en.input_ids.cpu().data.numpy().squeeze(0)[s[0]:s[1] + 1] # 根据输入的下标
                            subject = ''.join([vocab[i] for i in index]) # 将bert的vcab里的汉字映射出来
                            # print(subject)
 
                            _, object_preds = model(en.input_ids.to(device), # 将input分词的结果添加进去
                                                    torch.from_numpy(np.array([s])).float().to(device), # s的下标添加进去
                                                    en.attention_mask.to(device)) # 将mask添加进去
                            object_preds = object_preds.cpu().data.numpy() # 转换成numpy数组
                            for object_pred in object_preds:  # 遍历所有的object
                                start = np.where(object_pred[:, :, 0] > 0.2) # object的阈值大于0.2取start
                                end = np.where(object_pred[:, :, 1] > 0.2) # 同上
                                for _start, predicate1 in zip(*start): # 星号zip代表把两个值解开 两行对应的元组 # 遍历start取 s 和 p
                                    for _end, predicate2 in zip(*end): # 遍历end 取 e 和 p
                                        if _start <= _end and predicate1 == predicate2: # 判断是否复合逻辑 spo
                                            index = en.input_ids.cpu().data.numpy().squeeze(0)[_start:_end + 1] # 从输入中找到对应下标
                                            object = ''.join([vocab[i] for i in index]) # 从bert词典中映射成中文
                                            predicate = id2predicate[str(predicate1)] # 找predicate下标返回predicate
                                            # print(object, '\t', predicate)
                                            spo.append([subject, predicate, object])  # 三元组放到数组当中
                    # 预测结果
                    R = set([SPO(_spo) for _spo in spo]) # 预测去重
                    # print(R)
                    # 真实结果
                    T = set([SPO(_spo) for _spo in spo_ori]) # 真是去重
                    # R = set(spo_ori)
                    # T = set(spo)
                    # 交集
                    X += len(R ) #& T R & T 交集长度
                    Y += len(R)  # R 长度
                    Z+= len(T)#=1 # T 长度
                    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z # f1 精准度 召回率
                    pbar.update()  # 把代码更新到pbar
                    pbar.set_description(
                        'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
                    )
                pbar.close()
                print('f1:', f1, 'precision:', precision, 'recall:', recall)
 
 
for epoch in range(20):
    print('************start train************')
    train_func()

************start train************


train loss:0.608623743057251: : 9it [01:58, 12.27s/it] 
0it [00:00, ?it/s][A
1it [00:02,  2.58s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 1it [00:02,  2.58s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:06,  3.52s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:06,  3.52s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:12,  4.31s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:12,  4.31s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:19,  5.41s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:19,  5.41s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:31,  7.79s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:31,  7.79s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:44,  9.66s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:44,  9.66s/it][A
f1: 0.00000, precision: 1.00000, recal

f1: 1.9999999999600002e-11 precision: 1.0 recall: 9.999999999900001e-12


train loss:0.608623743057251: : 10it [04:02, 38.94s/it]
0it [00:00, ?it/s][A
1it [00:01,  1.00s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 1it [00:01,  1.00s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.39it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.39it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:02,  1.16it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:02,  1.16it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:04,  1.17s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:04,  1.17s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:05,  1.05s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:05,  1.05s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:06,  1.25s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:06,  1.25s/it][A
f1: 0.00000, precision: 1.00000, recal

f1: 1.9999999999600002e-11 precision: 1.0 recall: 9.999999999900001e-12


train loss:0.608623743057251: : 15it [06:29, 26.80s/it]
0it [00:00, ?it/s][A
1it [00:01,  1.01s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 1it [00:01,  1.01s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.36it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.36it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:02,  1.57it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:02,  1.57it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:02,  1.39it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:02,  1.39it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:03,  1.31it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:03,  1.31it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:05,  1.06s/it][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:05,  1.06s/it][A
f1: 0.00000, precision: 1.00000, recal

f1: 1.9999999999600002e-11 precision: 1.0 recall: 9.999999999900001e-12


train loss:0.608623743057251: : 20it [08:18, 21.26s/it]
0it [00:00, ?it/s][A
1it [00:00,  1.87it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 1it [00:00,  1.87it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.85it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 2it [00:01,  1.85it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:01,  1.85it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 3it [00:01,  1.85it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:02,  1.50it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 4it [00:02,  1.50it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:03,  1.37it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 5it [00:03,  1.37it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:04,  1.31it/s][A
f1: 0.00000, precision: 1.00000, recall: 0.00000: : 6it [00:04,  1.31it/s][A
f1: 0.00000, precision: 1.00000, recal

f1: 1.9999999999600002e-11 precision: 1.0 recall: 9.999999999900001e-12


train loss:0.608623743057251: : 24it [10:01, 25.08s/it]


KeyboardInterrupt: ignored