In [4]:
import torch


REL_SIZE = 2

REL_PATH = "/kaggle/input/casrel-adr-data/rel.csv"
TRAIN_PATH = '/kaggle/input/casrel-adr-data/adr-train.csv'
#TEST_PATH = './data/input/adr-test.csv'

BERT_MODEL_NAME = "dmis-lab/biobert-base-cased-v1.2"#'./bert-model/biobert-base-cased-v1.2'

MODEL_DIR = './'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 100  #作为demo batch设为2，在训练时调味100
BERT_DIM = 768  #BERT的输出维数
LR = 1e-3       #学习率
EPOCH = 100

# sub和obj的head与tail的判断阈值
SUB_HEAD_BAR = 0.6
SUB_TAIL_BAR = 0.6

OBJ_HEAD_BAR = 0.6
OBJ_TAIL_BAR = 0.6


#降权
CLS_WEIGHT_COEF = [0.3, 1.0]
SUB_WEIGHT_COEF = 3


EPS = 1e-10

In [5]:
import  pandas as pd


# def process_data(file_path):
#     data = pd.read_csv(file_path)
#     # print(data.columns)
#     processed_data = []
#
#     for index, line in data.iterrows():
#         dct = {
#             "text": [],
#             "spo_list": {
#                 "subject": [],
#                 "object": [],
#                 "predicate": [],
#             },
#         }
#         dct["text"] = line["text"]
#         dct["spo_list"]["object"] = line["effect"]
#         dct["spo_list"]["subject"] = line["drug"]
#         dct["spo_list"]["predicate"] = "causes"
#         processed_data.append(dct)
#
#
#     return processed_data
#
#
#
def merge_data(data):
    name_list = []
    data_list = []

    for each in data:
        if each["text"] not in name_list:
            name_list.append(each["text"])
            data_list.append(each)
        else:
            index = name_list.index(each["text"])
            data_list[index]["spo_list"].append(each["spo_list"][0])

    return data_list


def process_data(file_path):
    data = pd.read_csv(file_path)
    # print(data.columns)
    processed_data = []

    for index, line in data.iterrows():
        dct = {
            "text": [],
            "spo_list": [],
        }
        spo = {
            "subject": "",
            "object": "",
            "predicate": ""
        }
        dct["text"] = line["text"]

        spo["object"] = line["effect"]
        spo["subject"] = line["drug"]
        spo["predicate"] = "causes"

        dct["spo_list"].append(spo)
        processed_data.append(dct)

    return merge_data(processed_data)






In [6]:
import torch.utils.data as data
import pandas as pd
import random
from transformers import BertTokenizerFast

#返回一个rel2id 和id2rel
def get_rel():
    df = pd.read_csv(REL_PATH, names=['rel', 'id'])
    return df['rel'].tolist(), dict(df.values)

#生成长度为len，hot_pos位置为1其余位置为0的独热编码
def multihot(length, hot_pos):
    return [1 if i in hot_pos else 0 for i in range(length)]



class Dataset(data.Dataset):
    def __init__(self, type='train'):
        super().__init__()
        _, self.rel2id = get_rel()

        # 加载文件
        if type == 'train':
            file_path = TRAIN_PATH
        elif type == 'test':
            file_path = TEST_PATH

        self.lines = process_data(file_path)

        # 加载bert
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        info = self.lines[index]
        # print(type(info))
        # exit()
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']

        return self.parse_dict(info)


    def get_pos_id(self, source, elem):
        for head_id in range(len(source)):
            tail_id = head_id + len(elem)
            if source[head_id:tail_id] == elem:
                return head_id, tail_id - 1


    def collate_fn(self,batch):
        #获得最长的句子长度,便于填充pad
        batch.sort(key=lambda x : len(x["input_ids"]),reverse=True)
        max_len = len(batch[0]["input_ids"])

        batch_text = {
            'text': [],
            'input_ids': [],
            'offset_mapping': [],
            'triple_list': [],
        }
        batch_mask = []

        batch_sub = {
            'heads_seq': [],
            'tails_seq': [],
        }
        batch_sub_rnd = {
            'head_seq': [],
            'tail_seq': [],
        }
        batch_obj_rel = {
            'heads_mx': [],
            'tails_mx': [],
        }

        #对batch中的每一个item进行处理
        for item in batch:
            input_ids = item["input_ids"]  #对元素进行pad填充
            item_len = len(input_ids)
            pad_len = max_len - item_len
            input_ids = input_ids + [0] * pad_len
            mask = [1] * item_len + [0] * pad_len
            # print(mask)
            # exit()

            sub_head_seq = multihot(max_len,item["sub_head_ids"])
            sub_tail_seq = multihot(max_len, item["sub_tail_ids"])
            # print(item["sub_head_ids"])
            # print(sub_head_seq)
            # exit()

            if len(item['triple_id_list']) == 0:  #如果没有三元组则continue
                continue

            sub_rnd = random.choice(item['triple_id_list'])[0]
            sub_rnd_head_seq = multihot(max_len, [sub_rnd[0]])
            sub_rnd_tail_seq = multihot(max_len, [sub_rnd[1]])

            #根据随机subject计算relations矩阵
            obj_head_mx = [[0] * REL_SIZE for _ in range(max_len)]   #生成两个二维全0矩阵（一个head矩阵一个tail矩阵）
            obj_tail_mx = [[0] * REL_SIZE for _ in range(max_len)]

            for triple in item["triple_id_list"]:                    #对全0矩阵进行填充获得obj的head和tail的rel矩阵
                rel_id = triple[1]
                head_id, tail_id = triple[2]
                if triple[0] == sub_rnd:       # 对于本课题可以取消这一步，因为数据集中有且仅有唯一的sub
                    obj_head_mx[head_id][rel_id] = 1
                    obj_tail_mx[tail_id][rel_id] = 1

            #重新组装batch，一条item压入一组信息
            batch_text["text"].append(item["text"])
            batch_text["input_ids"].append(input_ids)
            batch_text["offset_mapping"].append(item["offset_mapping"])
            batch_text["triple_list"].append(item["triple_list"])

            batch_mask.append(mask)

            batch_sub["heads_seq"].append(sub_head_seq)
            batch_sub["tails_seq"].append(sub_tail_seq)

            batch_sub_rnd["head_seq"].append(sub_rnd_head_seq)
            batch_sub_rnd["tail_seq"].append(sub_rnd_tail_seq)

            # print(sub_rnd_head_seq)
            # print(sub_head_seq)
            # exit()

            batch_obj_rel["heads_mx"].append(obj_head_mx)
            batch_obj_rel["tails_mx"].append(obj_tail_mx)

        return batch_mask,(batch_text,batch_sub_rnd),(batch_sub,batch_obj_rel)




    def parse_dict(self, info):  #对dict串进行解析
        text = info['text']
        input_ids = info['input_ids']  #整个text的input_ids

        dct = {
            'text': text,
            'input_ids': input_ids,
            'offset_mapping': info['offset_mapping'],
            'sub_head_ids': [],
            'sub_tail_ids': [],
            'triple_list': [],
            'triple_id_list': []
        }

        for spo in info['spo_list']:
            subject = spo['subject']
            object = spo['object']
            predicate = spo['predicate']
            dct['triple_list'].append((subject, predicate, object))
            # 计算 subject 实体位置
            tokenized = self.tokenizer(subject, add_special_tokens=False)
            sub_token = tokenized['input_ids']
            sub_pos_id = self.get_pos_id(input_ids, sub_token)
            if not sub_pos_id:
                continue
            sub_head_id, sub_tail_id = sub_pos_id
            # 计算 object 实体位置
            tokenized = self.tokenizer(object, add_special_tokens=False)
            obj_token = tokenized['input_ids']
            obj_pos_id = self.get_pos_id(input_ids, obj_token)
            if not obj_pos_id:
                continue
            obj_head_id, obj_tail_id = obj_pos_id
            # 数据组装
            dct['sub_head_ids'].append(sub_head_id)
            dct['sub_tail_ids'].append(sub_tail_id)
            dct['triple_id_list'].append((
                [sub_head_id, sub_tail_id],
                self.rel2id[predicate],
                [obj_head_id, obj_tail_id],
            ))
        return dct

        # # 数据预处理部分已经格式化
        # spo = info["spo_list"]
        # subject = spo['subject']
        # object = spo['object']
        # predicate = spo['predicate']
        # dct['triple_list'].append((subject, predicate, object))
        #
        # tokenized = self.tokenizer(subject, add_special_tokens=False)
        # sub_token = tokenized['input_ids']
        # # print(input_ids)
        # # print(sub_token)
        # # exit()
        # sub_pos_id = self.get_pos_id(input_ids, sub_token) #通过编码后的
        #
        # sub_head_id, sub_tail_id = sub_pos_id
        # # 计算 object 实体位置
        # tokenized = self.tokenizer(object, add_special_tokens=False)
        # obj_token = tokenized['input_ids']
        # obj_pos_id = self.get_pos_id(input_ids, obj_token)
        #
        # obj_head_id, obj_tail_id = obj_pos_id
        # # 数据组装
        # dct['sub_head_ids'].append(sub_head_id)
        # dct['sub_tail_ids'].append(sub_tail_id)
        #
        # dct['triple_id_list'].append((
        #     [sub_head_id, sub_tail_id],
        #     self.rel2id[predicate],
        #     [obj_head_id, obj_tail_id],
        # ))
        #
        # return dct

if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset, shuffle=False, batch_size=2, collate_fn=dataset.collate_fn)
    print(next(iter(loader)))
    exit()

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ({'text': ["Immobilization, while Paget's bone disease was present, and perhaps enhanced activation of dihydrotachysterol by rifampicin, could have led to increased calcium-release into the circulation.", 'Intravenous azithromycin-induced ototoxicity.'], 'input_ids': [[101, 13280, 3702, 15197, 2734, 117, 1229, 3674, 1204, 112, 188, 6028, 3653, 1108, 1675, 117, 1105, 3229, 9927, 14915, 1104, 4267, 7889, 23632, 16339, 8992, 4648, 4063, 1118, 187, 8914, 19471, 27989, 1179, 117, 1180, 1138, 1521, 1106, 2569, 15355, 118, 1836, 1154, 1103, 9097, 119, 102], [101, 1107, 4487, 7912, 2285, 170, 5303, 1582, 16071, 1183, 16430, 118, 10645, 184, 2430, 2430, 8745, 9041, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [7]:
import torch.nn as nn
from transformers import BertModel
import torch
import torch.nn.functional as F

# 忽略 transformers 警告
from transformers import logging
logging.set_verbosity_error()


class CasRel(nn.Module):

    #初始化model
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        # 冻结Bert参数，只训练下游模型
        for name, param in self.bert.named_parameters():
            param.requires_grad = False

        #定义CasRel网络
        self.sub_head_linear = nn.Linear(BERT_DIM, 1)   #sub只需要一维即
        self.sub_tail_linear = nn.Linear(BERT_DIM, 1)

        self.obj_head_linear = nn.Linear(BERT_DIM, REL_SIZE)  #预测的obj矩阵需要REL_SIZE维计
        self.obj_tail_linear = nn.Linear(BERT_DIM, REL_SIZE)


    #subject头尾标记预测
    def get_encoded_text(self, input_ids, mask):
        return self.bert(input_ids, attention_mask=mask)[0]

    def get_subs(self, encoded_text):
        #encoded_text(b,c,768) -> (b,c,1)
        #对每个单词进行sigmoid预测
        pred_sub_head = torch.sigmoid(self.sub_head_linear(encoded_text))
        pred_sub_tail = torch.sigmoid(self.sub_tail_linear(encoded_text))

        return pred_sub_head, pred_sub_tail

    def get_objs_for_specific_sub(self, encoded_text, sub_head_seq, sub_tail_seq):  # 获得预测的obj-rel矩阵
        # sub_head_seq.shape (b, c) -> (b, 1, c)
        sub_head_seq = sub_head_seq.unsqueeze(1).float()
        sub_tail_seq = sub_tail_seq.unsqueeze(1).float()

        # encoded_text.shape (b, c, 768)
        sub_head = torch.matmul(sub_head_seq, encoded_text)   #获得head和tail的编码并加在encoded_text中
        sub_tail = torch.matmul(sub_tail_seq, encoded_text)

        encoded_text = encoded_text + (sub_head + sub_tail) / 2

        # encoded_text.shape (b, c, 768)
        pred_obj_head = torch.sigmoid(self.obj_head_linear(encoded_text))
        pred_obj_tail = torch.sigmoid(self.obj_tail_linear(encoded_text))

        # shape (b, c, REL_SIZE)
        return pred_obj_head, pred_obj_tail

    def forward(self, input, mask):

        input_ids, sub_head_seq, sub_tail_seq = input
        encoded_text = self.get_encoded_text(input_ids, mask)
        pred_sub_head, pred_sub_tail = self.get_subs(encoded_text)


        input_ids, sub_head_seq, sub_tail_seq = input
        encoded_text = self.get_encoded_text(input_ids, mask)

        # 预测subject首尾序列
        pred_sub_head, pred_sub_tail = self.get_subs(encoded_text)

        # 预测relation-object矩阵
        pred_obj_head, pred_obj_tail = self.get_objs_for_specific_sub(encoded_text, sub_head_seq, sub_tail_seq)

        return encoded_text, (pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail)

    def loss_fn(self, true_y, pred_y, mask):

        def calc_loss(pred, true, mask):
            true = true.float()

            # pred.shape (b, c, 1) -> (b, c)
            pred = pred.squeeze(-1)
            weight = torch.where(true > 0, CLS_WEIGHT_COEF[1], CLS_WEIGHT_COEF[0])  # 分配权重


            loss = F.binary_cross_entropy(pred, true, weight=weight, reduction='none')


            if loss.shape != mask.shape:
                mask = mask.unsqueeze(-1)

            return torch.sum(loss * mask) / torch.sum(mask)  #通过与mask相乘将pad补充的元素损失进行归0

        pred_sub_head, pred_sub_tail, pred_obj_head, pred_obj_tail = pred_y
        true_sub_head, true_sub_tail, true_obj_head, true_obj_tail = true_y


        return calc_loss(pred_sub_head, true_sub_head, mask) * SUB_WEIGHT_COEF + \
               calc_loss(pred_sub_tail, true_sub_tail, mask) * SUB_WEIGHT_COEF + \
               calc_loss(pred_obj_head, true_obj_head, mask) + \
               calc_loss(pred_obj_tail, true_obj_tail, mask)













huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:

from torch.utils import data

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'


def get_triple_list(sub_head_ids, sub_tail_ids, model, encoded_text, text, mask, offset_mapping):
    id2rel, _ = get_rel()
    triple_list = []
    for sub_head_id in sub_head_ids:
        sub_tail_ids = sub_tail_ids[sub_tail_ids >= sub_head_id]
        if len(sub_tail_ids) == 0:
            continue
        sub_tail_id = sub_tail_ids[0]
        if mask[sub_head_id] == 0 or mask[sub_tail_id] == 0:
            continue
        # 根据位置信息反推出 subject 文本内容
        sub_head_pos_id = offset_mapping[sub_head_id][0]
        sub_tail_pos_id = offset_mapping[sub_tail_id][1]
        subject_text = text[sub_head_pos_id:sub_tail_pos_id]
        # 根据 subject 计算出对应 object 和 relation
        sub_head_seq = torch.tensor(multihot(len(mask), sub_head_id)).to(DEVICE)
        sub_tail_seq = torch.tensor(multihot(len(mask), sub_tail_id)).to(DEVICE)

        pred_obj_head, pred_obj_tail = model.get_objs_for_specific_sub(\
            encoded_text.unsqueeze(0), sub_head_seq.unsqueeze(0), sub_tail_seq.unsqueeze(0))
        # 按分类找对应关系
        pred_obj_head = pred_obj_head[0].T
        pred_obj_tail = pred_obj_tail[0].T
        for j in range(len(pred_obj_head)):
            obj_head_ids = torch.where(pred_obj_head[j] > OBJ_HEAD_BAR)[0]
            obj_tail_ids = torch.where(pred_obj_tail[j] > OBJ_TAIL_BAR)[0]
            for obj_head_id in obj_head_ids:
                obj_tail_ids = obj_tail_ids[obj_tail_ids >= obj_head_id]
                if len(obj_tail_ids) == 0:
                    continue
                obj_tail_id = obj_tail_ids[0]
                if mask[obj_head_id] == 0 or mask[obj_tail_id] == 0:
                    continue
                # 根据位置信息反推出 object 文本内容，mapping中已经有移位，不需要再加1
                obj_head_pos_id = offset_mapping[obj_head_id][0]
                obj_tail_pos_id = offset_mapping[obj_tail_id][1]
                object_text = text[obj_head_pos_id:obj_tail_pos_id]
                triple_list.append((subject_text, id2rel[j], object_text))
    return list(set(triple_list))



def report(model, encoded_text, pred_y, batch_text, batch_mask):
    # 计算三元结构，和统计指标
    pred_sub_head, pred_sub_tail, _, _ = pred_y
    true_triple_list = batch_text['triple_list']
    pred_triple_list = []

    correct_num, predict_num, gold_num = 0, 0, 0

    # 遍历batch
    for i in range(len(pred_sub_head)):
        text = batch_text['text'][i]
        true_triple_item = true_triple_list[i]
        mask = batch_mask[i]
        offset_mapping = batch_text['offset_mapping'][i]

        sub_head_ids = torch.where(pred_sub_head[i] > SUB_HEAD_BAR)[0]
        sub_tail_ids = torch.where(pred_sub_tail[i] > SUB_TAIL_BAR)[0]

        pred_triple_item = get_triple_list(sub_head_ids, sub_tail_ids, model, \
            encoded_text[i], text, mask, offset_mapping)

        # 统计个数
        correct_num += len(set(true_triple_item) & set(pred_triple_item))
        predict_num += len(set(pred_triple_item))
        gold_num += len(set(true_triple_item))

        pred_triple_list.append(pred_triple_item)

    precision = correct_num / (predict_num + EPS)
    recall = correct_num / (gold_num + EPS)
    f1_score = 2 * precision * recall / (precision + recall + EPS)
    print('\tcorrect_num:', correct_num, 'predict_num:', predict_num, 'gold_num:', gold_num)
    print('\tprecision:%.3f' % precision, 'recall:%.3f' % recall, 'f1_score:%.3f' % f1_score)

if __name__ == '__main__':
    model = CasRel().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    dataset = Dataset()

    for e in range(EPOCH):
        loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn)
        for b, (batch_mask, batch_x, batch_y) in enumerate(loader):

            # print(batch_x)
            # exit()

            batch_text, batch_sub_rnd = batch_x
            batch_sub, batch_obj_rel = batch_y

            # 整理input数据并预测
            input_mask = torch.tensor(batch_mask).to(DEVICE)

            input = (
                torch.tensor(batch_text['input_ids']).to(DEVICE),
                torch.tensor(batch_sub_rnd['head_seq']).to(DEVICE),
                torch.tensor(batch_sub_rnd['tail_seq']).to(DEVICE),
            )
            encoded_text, pred_y = model(input, input_mask)

            # 整理target数据并计算损失
            true_y = (
                torch.tensor(batch_sub['heads_seq']).to(DEVICE),
                torch.tensor(batch_sub['tails_seq']).to(DEVICE),
                torch.tensor(batch_obj_rel['heads_mx']).to(DEVICE),
                torch.tensor(batch_obj_rel['tails_mx']).to(DEVICE),
            )


            loss = model.loss_fn(true_y, pred_y, input_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if b % 50 == 0:
                print('>> epoch:', e, 'batch:', b, 'loss:', loss.item())
            # print('>> epoch:', e, 'batch:', b, 'loss:', loss.item())
            if b % 500 == 0:
                report(model, encoded_text, pred_y, batch_text, batch_mask)

        if e % 10 == 0:
            torch.save(model, MODEL_DIR + f'model_{e}.pth')

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

>> epoch: 0 batch: 0 loss: 2.484619617462158
	correct_num: 0 predict_num: 975 gold_num: 150
	precision:0.000 recall:0.000 f1_score:0.000
