In [7]:
# -*- coding: utf-8 -*-

# @Author  : xmh
# @Time    : 2021/3/11 16:53
# @File    : config_rel.py

"""
file description: hyper-parameters, file paths and label sets for the relation-classification models (ConfigRel).

"""
import torch

# Use the GPU whenever PyTorch can see one, and announce it once at import time.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    print("USE_CUDA....")


class ConfigRel:
    """Hyper-parameters, file paths and label inventories for relation classification.

    Only the optimisation / model-size settings are constructor arguments; the
    label lists, paths, dropout rates and thresholds are fixed in ``__init__``.
    """

    def __init__(self,
                 lr=0.001,
                 epochs=100,
                 vocab_size=22000,  # 22000,
                 embedding_dim=100,
                 hidden_dim_lstm=128,
                 num_layers=3,
                 batch_size=32,
                 layer_size=128,
                 token_type_dim=8
                 ):
        # Tunable settings, stored exactly as given.
        self.lr, self.epochs, self.vocab_size = lr, epochs, vocab_size
        self.embedding_dim, self.hidden_dim_lstm = embedding_dim, hidden_dim_lstm
        self.num_layers, self.batch_size = num_layers, batch_size
        self.layer_size, self.token_type_dim = layer_size, token_type_dim

        # Relation label inventory: the ADR relation "causes" first (index 0),
        # followed by Chinese relation names.
        self.relations = [
            "causes",
            '丈夫', '上映时间', '专业代码', '主持人', '主演', '主角', '人口数量', '作曲', '作者', '作词',
            '修业年限', '出品公司', '出版社', '出生地', '出生日期', '创始人', '制片人', '占地面积', '号',
            '嘉宾', '国籍', '妻子', '字', '官方语言', '导演', '总部地点', '成立日期', '所在城市', '所属专辑',
            '改编自', '朝代', '歌手', '母亲', '毕业院校', '民族', '气候', '注册资本', '海拔', '父亲', '目',
            '祖籍', '简称', '编剧', '董事长', '身高', '连载网站', '邮政编码', '面积', '首都',
        ]
        self.num_relations = len(self.relations)

        # Entity types; get_token_types() expands each into a B-/I- tag pair.
        self.token_types_origin = [
            'Date', 'Number', 'Text', '书籍', '人物', '企业', '作品', '出版社', '历史人物', '国家',
            '图书作品', '地点', '城市', '学校', '学科专业', '影视作品', '景点', '机构', '歌曲', '气候',
            '生物', '电视综艺', '目', '网站', '网络小说', '行政区', '语言', '音乐专辑', "drug", "adverse",
        ]
        self.token_types = self.get_token_types()
        self.num_token_type = len(self.token_types)

        self.vocab_file = '../data/vocab.txt'
        self.max_seq_length = 256
        self.num_sample = 1480

        self.dropout_embedding = 0.1  # changed from 0.2 to 0.1
        self.dropout_lstm = 0.1
        # NOTE(review): nn.Dropout(p) drops each unit with probability p, so 0.9 discards
        # 90% of activations — confirm this is not meant as a keep-probability.
        self.dropout_lstm_output = 0.9
        self.dropout_head = 0.9  # the only parameter being tuned (tried 0.9 -> 0.5)
        self.dropout_ner = 0.8
        self.use_dropout = True
        self.threshold_rel = 0.95  # changed from 0.7 to 0.95
        self.teach_rate = 0.2
        self.ner_checkpoint_path = '../models/rel_cls/'
        self.pretrained = False
        self.pad_token_id = 0
        # Cap on positive samples per relation (read by DataPreparationRel).
        self.rel_num = 500

        self.pos_dim = 32

    def get_token_types(self):
        """Return BIO tags for ``token_types_origin``: 'B-x', 'I-x' for each type, then 'O'."""
        tags = [f'{prefix}-{token_type}'
                for token_type in self.token_types_origin
                for prefix in ('B', 'I')]
        tags.append('O')
        return tags



USE_CUDA....


In [8]:
import  pandas as pd


# def process_data(file_path):
#     data = pd.read_csv(file_path)
#     # print(data.columns)
#     processed_data = []
#
#     for index, line in data.iterrows():
#         dct = {
#             "text": [],
#             "spo_list": {
#                 "subject": [],
#                 "object": [],
#                 "predicate": [],
#             },
#         }
#         dct["text"] = line["text"]
#         dct["spo_list"]["object"] = line["effect"]
#         dct["spo_list"]["subject"] = line["drug"]
#         dct["spo_list"]["predicate"] = "causes"
#         processed_data.append(dct)
#
#
#     return processed_data
#
#
#
def merge_data(data):
    """Group records that share the same ``text`` into a single record.

    The first record seen for each text is kept (as a shallow copy with its own
    ``spo_list``), and the ``spo_list`` entries of every later record with the same
    text are appended to it. Output order follows the first occurrence of each text.

    Args:
        data: iterable of dicts with at least ``"text"`` and ``"spo_list"`` keys.

    Returns:
        list of merged dicts. The input dicts are not modified.
    """
    # text -> merged record. Dict lookup is O(1); the former list.index scan made the
    # merge quadratic in the number of CSV rows.
    merged = {}
    for each in data:
        key = each["text"]
        if key in merged:
            # Keep every triple of the duplicate, not just its first one.
            merged[key]["spo_list"].extend(each["spo_list"])
        else:
            # Copy so that extending spo_list never mutates the caller's record.
            merged[key] = {**each, "spo_list": list(each["spo_list"])}
    return list(merged.values())


def process_data(file_path):
    """Load an ADR CSV file and group its drug -> effect pairs by sentence.

    Only the ``text``, ``drug`` and ``effect`` columns are read. Every row becomes one
    triple ``{"subject": drug, "predicate": "causes", "object": effect,
    "subject_type": "drug", "object_type": "adverse"}``; rows sharing the same
    ``text`` are then merged into one record by ``merge_data``.

    Args:
        file_path: anything ``pandas.read_csv`` accepts.

    Returns:
        list of ``{"text": ..., "spo_list": [triple, ...]}`` dicts.
    """
    data = pd.read_csv(file_path)
    # Iterate the three needed columns directly: cheaper than DataFrame.iterrows,
    # and each triple is built once instead of via a throw-away intermediate dict.
    processed_data = [
        {
            "text": text,
            "spo_list": [{
                "subject": drug,
                "predicate": "causes",
                "object": effect,
                "subject_type": "drug",
                "object_type": "adverse",
            }],
        }
        for text, drug, effect in zip(data["text"], data["drug"], data["effect"])
    ]
    return merge_data(processed_data)






In [9]:
# -*- coding: utf-8 -*-

# @Author  : xmh
# @Time    : 2021/3/11 15:20
# @File    : process_rel.py

"""
file description: turns the ADR CSV files into relation-classification DataLoaders (DataPreparationRel, Dataset).

"""
import json
import torch

import copy
from transformers import BertTokenizer
import codecs
from collections import defaultdict


class DataPreparationRel:
    """Turn ADR CSV files into DataLoaders of relation-classification samples.

    For every (subject, object) pair of a sentence a positive sample is built, and
    for every sentence one reversed (object, subject) pair is added as a negative
    sample labelled ``'N'``.
    """

    # Separator between subject, object and sentence in 'sentence_cls'. Positives and
    # negatives must share it: positives used to be joined with '' and negatives with
    # '$', so the classifier could read the label off the separator instead of
    # learning the argument direction.
    SENTENCE_SEP = '$'

    def __init__(self, config):
        self.config = config
        # self.get_token2id()
        # Positive samples kept per relation by the most recent get_data call.
        self.rel_cnt = defaultdict(int)

    def _build_samples(self, datas, is_test=False):
        """Convert merged records (``{'text', 'spo_list'}``) into flat sample dicts.

        Args:
            datas: records as returned by ``process_data``.
            is_test: if True, positives get ``relation = []`` and no per-relation cap
                is applied.

        Returns:
            list of dicts with keys 'sentence_cls', 'relation', 'text', 'subject',
            'object'.
        """
        # Reset the cap counter for every file. Previously it was shared across calls,
        # so after the training file filled the cap every positive of the dev file
        # was dropped and the dev set consisted of negatives only.
        self.rel_cnt = defaultdict(int)
        sep = self.SENTENCE_SEP
        samples = []
        for data_item in datas:
            spo_list = data_item['spo_list']
            if not spo_list:
                # Nothing to pair; previously the negative reused stale variables here.
                continue
            text = data_item['text'].lower()
            for spo_item in spo_list:
                subj = spo_item["subject"].lower()
                obj = spo_item["object"].lower()
                if is_test:
                    relation = []
                else:
                    relation = spo_item['predicate']
                    # '>=' keeps at most rel_num samples ('>' let rel_num + 1 through).
                    if self.rel_cnt[relation] >= self.config.rel_num:
                        continue
                    self.rel_cnt[relation] += 1
                samples.append({'sentence_cls': sep.join([subj, obj, text]), 'relation': relation,
                                'text': text, 'subject': subj, 'object': obj})
            # One negative per sentence: the last pair with its arguments swapped.
            samples.append({'sentence_cls': sep.join([obj, subj, text]), 'relation': 'N',
                            'text': text, 'subject': obj, 'object': subj})
        return samples

    def get_data(self, file_path, is_test=False):
        """Build a DataLoader over the samples of one CSV file.

        Training/dev loaders are shuffled; the test loader keeps file order so that
        predictions can be matched back to the input rows. ``drop_last=True`` is kept
        because AttBiLSTM allocates its initial hidden state for ``config.batch_size``.
        """
        samples = self._build_samples(process_data(file_path), is_test=is_test)
        dataset = Dataset(samples)
        if is_test:
            dataset.is_test = True
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=not is_test,
            drop_last=True
        )

    def get_token2id(self):
        """Map the first whitespace-separated token of each line of ../data/vec.txt to its line index."""
        self.word2id = {}
        with codecs.open('../data/vec.txt', 'r', encoding='utf-8') as f:
            for cnt, line in enumerate(f):
                self.word2id[line.split()[0]] = cnt

    def get_train_dev_data(self, path_train=None, path_dev=None, path_test=None, is_test=False):
        """Return ``(train_loader, dev_loader, test_loader)``; a loader is None when its path is None."""
        train_loader, dev_loader, test_loader = None, None, None
        if path_train is not None:
            train_loader = self.get_data(path_train)
        if path_dev is not None:
            dev_loader = self.get_data(path_dev)
        if path_test is not None:
            test_loader = self.get_data(path_test, is_test=True)

        return train_loader, dev_loader, test_loader
    
    
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of relation samples plus a BERT-tokenising ``collate_fn``.

    Each sample is a dict with the keys 'sentence_cls', 'relation', 'text',
    'subject' and 'object'. ``collate_fn`` tokenises 'sentence_cls', pads the batch
    and maps 'relation' through ``rel2id`` (skipped when ``is_test`` is True).
    """

    # Sample fields exposed by __getitem__.
    _ITEM_KEYS = ('sentence_cls', 'relation', 'text', 'subject', 'object')

    def __init__(self, data,
                 rel2id_path='/kaggle/input/lstm-crf-adr/rel2id.json',
                 vocab_file='dmis-lab/biobert-base-cased-v1.2',
                 max_length=512):
        """
        Args:
            data: list of sample dicts; deep-copied so later caller edits do not leak in.
            rel2id_path: JSON file mapping relation name -> integer label id.
            vocab_file: name or path passed to ``BertTokenizer.from_pretrained``.
            max_length: encoded sentences are truncated to this many token ids
                (512 is the standard BERT-base position limit).
        """
        self.data = copy.deepcopy(data)
        self.is_test = False
        with open(rel2id_path, 'r', encoding='utf-8') as f:
            self.rel2id = json.load(f)
        self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file)
        self.max_length = max_length

    def __getitem__(self, index):
        """Return the sample fields of ``self.data[index]``, in the key order of the first sample."""
        item = self.data[index]
        return {key: item[key] for key in self.data[0] if key in self._ITEM_KEYS}

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _pad(sequences):
        """Right-pad token-id lists with 0 into an int64 tensor and build a 0/1 mask.

        Returns:
            ``(padded, mask)``, both of shape ``(len(sequences), longest_length)``.
        """
        max_length = max(len(seq) for seq in sequences)
        padded = torch.zeros(len(sequences), max_length, dtype=torch.int64)
        mask = torch.zeros(len(sequences), max_length, dtype=torch.int64)
        for i, seq in enumerate(sequences):
            if len(seq) != 0:
                padded[i, :len(seq)] = torch.as_tensor(seq, dtype=torch.int64)
                mask[i, :len(seq)] = 1
        return padded, mask

    def collate_fn(self, data_batch):
        """Tokenise and pad a list of samples into a batch dict.

        Returns a dict with 'mask_tokens' (uint8), 'text', 'subject', 'object'
        (lists of str), 'sentence_cls' (int64 token ids) and, unless ``is_test``,
        'relation' (int64 label ids). Tensors are moved to the GPU when USE_CUDA.
        """
        item_info = {key: [d[key] for d in data_batch] for key in data_batch[0]}

        encoded = [self.bert_tokenizer.encode(sentence, add_special_tokens=True,
                                              truncation=True, max_length=self.max_length)
                   for sentence in item_info['sentence_cls']]
        sentence_cls, mask_tokens = self._pad(encoded)

        # Labels only exist outside test mode; the old code converted `relation`
        # unconditionally and raised UnboundLocalError for test batches.
        relation = None
        if not self.is_test:
            relation = torch.tensor([self.rel2id[rel] for rel in item_info['relation']],
                                    dtype=torch.int64)

        if USE_CUDA:
            sentence_cls = sentence_cls.cuda()
            mask_tokens = mask_tokens.cuda()
            if relation is not None:
                relation = relation.cuda()

        data_info = {
            "mask_tokens": mask_tokens.to(torch.uint8),
            "text": item_info['text'],
            "subject": item_info['subject'],
            "object": item_info['object'],
            "sentence_cls": sentence_cls,
        }
        if relation is not None:
            data_info['relation'] = relation
        return data_info
        
        
if __name__ == '__main__':
    # Smoke check for this cell: build the default config and the data-preparation
    # helper. Building the loaders is left commented out; get_train_dev_data takes
    # separate train / dev / test paths.
    config = ConfigRel()
    process = DataPreparationRel(config)
    #train_loader, dev_loader, test_loader = process.get_train_dev_data('adr-train.csv')
    



In [None]:
# -*- coding: utf-8 -*-

# @Author  : xmh
# @Time    : 2021/3/11 16:36
# @File    : model_rel.py

"""
file description: attention-pooled bidirectional GRU relation classifier (Attention, AttBiLSTM).

"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def setup_seed(seed):
    """Seed PyTorch's CPU generator and the generators of all CUDA devices with ``seed``.

    Using one fixed seed makes weight initialisation identical across runs, which
    makes debugging easier.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


class Attention(nn.Module):
    """Dot-product attention pooling whose query is the sum of the final RNN hidden states."""

    def __init__(self, config):
        super().__init__()
        # Re-seeds the global torch RNG (kept for reproducibility of the original setup).
        setup_seed(1)

    def forward(self, output_lstm, hidden_lstm, mask=None):
        """Pool ``output_lstm`` over its sequence axis.

        Args:
            output_lstm: ``(batch, seq_len, hidden)`` per-token features.
            hidden_lstm: final hidden states stacked on dim 0, each ``(batch, hidden)``;
                they are summed to form the query.
            mask: optional ``(batch, seq_len)`` tensor, non-zero for real tokens.
                Positions where it is 0 (padding) get zero attention weight. ``None``
                keeps the original behaviour of attending to every position.

        Returns:
            ``(batch, hidden)`` attention-weighted sum of ``output_lstm``.
        """
        query = torch.sum(hidden_lstm, dim=0)
        scores = torch.matmul(output_lstm, query.unsqueeze(2)).squeeze(2)  # (batch, seq_len)
        if mask is not None:
            # Use the dtype's minimum instead of -inf so a fully masked row yields a
            # uniform distribution rather than NaNs.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        alpha = F.softmax(scores, dim=1)
        return torch.matmul(output_lstm.transpose(-1, -2), alpha.unsqueeze(2)).squeeze(2)

class AttBiLSTM(nn.Module):
    """Bidirectional-GRU encoder with attention pooling for relation classification.

    Despite the name the recurrent layer is an ``nn.GRU``. Relation scores are dot
    products between the pooled sentence vector and one learned embedding per relation.
    """

    def __init__(self, config, embedding_pre=None):
        super().__init__()
        setup_seed(1)
        self.embedding_dim = config.embedding_dim
        self.vocab_size = config.vocab_size
        self.hidden_dim = config.hidden_dim_lstm
        self.num_layers = config.num_layers
        self.batch_size = config.batch_size
        self.embed_dropout = nn.Dropout(config.dropout_embedding)
        self.lstm_dropout = nn.Dropout(config.dropout_lstm_output)
        self.pretrained = config.pretrained
        self.config = config
        self.relation_embed_layer = nn.Embedding(config.num_relations, self.hidden_dim)
        # Ids 0..num_relations-1, used to fetch all relation embeddings at once.
        self.relations = torch.Tensor([i for i in range(config.num_relations)])
        if USE_CUDA:
            self.relations = self.relations.cuda()
        # Currently unused in forward(); kept so existing checkpoints still load.
        self.relation_bias = nn.Parameter(torch.randn(config.num_relations))

        assert (self.pretrained is True and embedding_pre is not None) or \
               (self.pretrained is False and embedding_pre is None), "预训练必须有训练好的embedding_pre"
        # Word embeddings (the encoder is shared between NER and relation extraction).
        if self.pretrained:
            self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_pre), freeze=False)
        else:
            self.word_embedding = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=config.pad_token_id)

        # Input width is embedding_dim only. The positional embeddings that used to be
        # concatenated (2 * pos_dim extra features) are disabled in forward(), so sizing
        # the GRU for embedding_dim + 2 * pos_dim made every forward pass fail with an
        # input-size mismatch. Add the pos_dim term back if positions are re-enabled.
        self.gru = nn.GRU(config.embedding_dim, config.hidden_dim_lstm, num_layers=config.num_layers,
                          batch_first=True, bidirectional=True, dropout=config.dropout_lstm)
        self.attention_layer = Attention(config)

        # Class weights for the loss: 6 for every relation except index 0
        # (config.relations[0]), which gets 1.
        if USE_CUDA:
            self.weights_rel = (torch.ones(self.config.num_relations) * 6).cuda()
        else:
            self.weights_rel = torch.ones(self.config.num_relations) * 6
        self.weights_rel[0] = 1
        # Random initial hidden state, (2 * num_layers, batch_size, hidden_dim).
        self.hidden_init = torch.randn(2 * self.num_layers, self.batch_size, self.hidden_dim)
        if USE_CUDA:
            self.hidden_init = self.hidden_init.cuda()

    def forward(self, data_item, is_test=False):
        """Score every relation for a batch.

        Args:
            data_item: batch dict; reads 'sentence_cls' (token ids) and, when not
                ``is_test``, 'relation' (gold label ids).
            is_test: if True, skip the loss.

        Returns:
            ``pred`` when ``is_test``, else ``(loss, pred)``; ``pred`` is the arg-max
            relation id per sample.
        """
        embeddings = self.word_embedding(data_item['sentence_cls'].to(torch.int64))
        if self.config.use_dropout:
            embeddings = self.embed_dropout(embeddings)

        # hidden_init is sized for config.batch_size; slice it so a smaller final batch
        # still runs (identical to the full tensor when the batch is full).
        h0 = self.hidden_init[:, :embeddings.size(0)].contiguous()
        output, h_n = self.gru(embeddings, h0)
        if self.config.use_dropout:
            output = self.lstm_dropout(output)
        # Sum the forward and backward directions -> (batch, seq, hidden_dim).
        attention_input = output[:, :, :self.hidden_dim] + output[:, :, self.hidden_dim:]
        attention_output = self.attention_layer(attention_input, h_n)
        relation_embeds = self.relation_embed_layer(self.relations.to(torch.int64))
        res = torch.matmul(attention_output, relation_embeds.transpose(-1, -2))

        if not is_test:
            loss = F.cross_entropy(res, data_item['relation'], self.weights_rel)
        res = F.softmax(res, -1)
        pred = res.argmax(dim=-1)
        if is_test:
            return pred
        return loss, pred


In [None]:
# -*- coding: utf-8 -*-

# @Author  : xmh
# @Time    : 2021/3/11 16:38
# @File    : trainer_rel.py

"""
file description: training, evaluation and prediction loops for the relation classifiers (Trainer).

"""
!pip install neptune
# coding=utf-8
import sys


import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import torch
import torch.nn as nn
from tqdm import tqdm

import numpy as np
import codecs
from transformers import BertForSequenceClassification
import neptune


class Trainer:
    """Training, evaluation and inference loops for the relation classifiers.

    ``train`` / ``evaluate`` / ``predict`` expect a model called as
    ``model(data_item)`` -> ``(loss, pred)`` (``pred`` only with ``is_test=True``),
    like AttBiLSTM. The ``bert_*`` methods expect a sequence-classification model
    called with token ids, ``attention_mask`` and optional ``labels`` that returns
    ``(loss, logits, ...)`` or ``(logits, ...)``.
    """

    def __init__(self,
                 model,
                 config,
                 train_dataset=None,
                 dev_dataset=None,
                 test_dataset=None
                 ):
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        if USE_CUDA:
            self.model = self.model.cuda()

        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config.lr)
        # NOTE(review): this scheduler is never stepped in this class, so it currently
        # has no effect on the learning rate.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5,
                                                                    patience=8, min_lr=1e-5, verbose=True)
        self.get_id2rel()

    @staticmethod
    def _log_metric(name, value):
        """Log through the legacy ``neptune.log_metric`` API when the module provides it.

        The module-level ``log_metric`` is missing in current ``neptune`` releases; the
        unguarded call raised AttributeError and aborted training after one epoch.
        """
        log_metric = getattr(neptune, 'log_metric', None)
        if log_metric is not None:
            log_metric(name, value)

    def _show_example(self, data_item, pred_rel):
        """Print one sample of the last evaluated batch with its gold and predicted relation."""
        if data_item is None or pred_rel is None or len(pred_rel) == 0:
            return
        # Index 1 as before, falling back to 0 for single-sample batches.
        idx = 1 if len(pred_rel) > 1 else 0
        print(data_item['text'][idx])
        print("subject: {0}, object：{1}".format(data_item['subject'][idx], data_item['object'][idx]))
        print("object rel: {}".format(self.id2rel[int(data_item['relation'][idx])]))
        print("predict rel: {}".format(self.id2rel[int(pred_rel[idx])]))

    def train(self):
        """Train for ``config.epochs`` epochs, evaluating on the dev set after each one.

        Every second epoch a checkpoint is written under ``config.ner_checkpoint_path``
        if the dev loss improved.
        """
        print('STARTING TRAIN...')
        self.num_sample_total = len(self.train_dataset) * self.config.batch_size
        loss_eval_best = 1e8
        for epoch in range(self.config.epochs):
            print("Epoch: {}".format(epoch))
            pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset))
            loss_rel_total = 0.0
            num_batches = 0
            for i, data_item in pbar:
                # Clear gradients every step; zeroing once per epoch made each update
                # use the sum of the gradients of all earlier batches.
                self.optimizer.zero_grad()
                loss_rel, pred_rel = self.model(data_item)
                loss_rel.backward()
                self.optimizer.step()
                # .item() releases the autograd graph instead of holding it all epoch.
                loss_rel_total += loss_rel.item()
                num_batches += 1
            # The model returns a per-batch mean loss, so average over batches
            # (dividing by the sample count under-reported it by batch_size).
            loss_rel_train_ave = loss_rel_total / max(num_batches, 1)
            print("train rel loss: {0}".format(loss_rel_train_ave))
            self._log_metric("train rel loss", loss_rel_train_ave)
            loss_rel_ave = self.evaluate()

            if (epoch + 1) % 2 == 0 and loss_rel_ave < loss_eval_best:
                loss_eval_best = loss_rel_ave
                if self.config.ner_checkpoint_path:
                    os.makedirs(self.config.ner_checkpoint_path, exist_ok=True)
                torch.save({
                    'epoch': epoch + 1, 'state_dict': self.model.state_dict(), 'loss_rel_best': loss_eval_best,
                    'optimizer': self.optimizer.state_dict(),
                },
                    self.config.ner_checkpoint_path + str(epoch) + 'm-' + 'loss' +
                    str("%.2f" % loss_rel_ave) + 'ccks2019_rel.pth'
                )

    def evaluate(self):
        """Return the mean dev loss per batch (``inf`` for an empty dev set)."""
        print('STARTING EVALUATION...')
        self.model.train(False)
        loss_rel_total = 0.0
        num_batches = 0
        data_item, pred_rel = None, None
        try:
            pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset))
            with torch.no_grad():
                for i, data_item in pbar_dev:
                    loss_rel, pred_rel = self.model(data_item)
                    loss_rel_total += loss_rel.item()
                    num_batches += 1
        finally:
            self.model.train(True)
        loss_rel_ave = loss_rel_total / num_batches if num_batches else float('inf')
        print("eval rel loss: {0}".format(loss_rel_ave))
        self._show_example(data_item, pred_rel)
        return loss_rel_ave

    def get_id2rel(self):
        """Build ``self.id2rel`` (label index -> relation name) from ``config.relations``.

        NOTE(review): labels are encoded with rel2id.json by the Dataset; this assumes
        that file maps config.relations[i] to i — confirm.
        """
        self.id2rel = dict(enumerate(self.config.relations))

    def predict(self):
        """Predict every test batch; returns one ``[relation_name]`` list per sample."""
        print('STARTING PREDICTING...')
        self.model.train(False)
        rel_pred = []
        try:
            pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset))
            with torch.no_grad():
                for i, data_item in pbar:
                    pred_rel = self.model(data_item, is_test=True)
                    # Accumulate across batches (previously only the last batch was returned).
                    rel_pred.extend([self.id2rel[int(p)]] for p in pred_rel)
        finally:
            self.model.train(True)
        return rel_pred

    def bert_train(self):
        """Fine-tune the sequence-classification model, selecting checkpoints by dev accuracy."""
        print('STARTING TRAIN...')
        self.num_sample_total = len(self.train_dataset) * self.config.batch_size
        acc_best = 0.0
        for epoch in range(self.config.epochs):
            print("Epoch: {}".format(epoch))
            pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset))
            loss_rel_total = 0.0
            num_batches = 0
            num_seen = 0
            correct = 0
            for i, data_item in pbar:
                self.optimizer.zero_grad()
                output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens'], labels=data_item['relation'])
                loss_rel, logits = output[0], output[1]
                loss_rel.backward()
                self.optimizer.step()

                _, pred_rel = torch.max(logits.detach(), 1)
                correct += pred_rel.eq(data_item['relation']).sum().item()
                num_seen += pred_rel.numel()

                loss_rel_total += loss_rel.item()
                num_batches += 1
            loss_rel_train_ave = loss_rel_total / max(num_batches, 1)
            print("train rel loss: {0}".format(loss_rel_train_ave))
            # The value reported as "precision_score" is accuracy over the samples seen.
            print("precision_score: {0}".format(correct / num_seen if num_seen else 0.0))
            acc_eval = self.bert_evaluate()

            if (epoch + 1) % 2 == 0 and acc_eval > acc_best:
                acc_best = acc_eval
                torch.save({
                    'epoch': epoch + 1, 'state_dict': self.model.state_dict(), 'acc_best': acc_best,
                    'optimizer': self.optimizer.state_dict(),
                },
                    str(epoch) + 'm-' + 'acc' +
                    str("%.2f" % acc_best) + 'ccks2019_rel.pth'
                )

    def bert_evaluate(self):
        """Return dev accuracy; also prints the mean dev loss per batch."""
        print('STARTING EVALUATION...')
        self.model.train(False)
        loss_rel_total = 0.0
        num_batches = 0
        num_seen = 0
        correct = 0
        data_item, pred_rel = None, None
        try:
            pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset))
            with torch.no_grad():
                for i, data_item in pbar_dev:
                    output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens'], labels=data_item['relation'])
                    loss_rel, logits = output[0], output[1]
                    _, pred_rel = torch.max(logits, 1)
                    loss_rel_total += loss_rel.item()
                    num_batches += 1
                    correct += pred_rel.eq(data_item['relation']).sum().item()
                    num_seen += pred_rel.numel()
        finally:
            self.model.train(True)
        loss_rel_ave = loss_rel_total / max(num_batches, 1)
        correct_ave = correct / num_seen if num_seen else 0.0
        print("eval rel loss: {0}".format(loss_rel_ave))
        print("precision_score: {0}".format(correct_ave))
        self._show_example(data_item, pred_rel)
        return correct_ave

    def bert_predict(self):
        """Predict every test batch; returns one relation name per sample."""
        print('STARTING PREDICTING...')
        self.model.train(False)
        rel_pred = []
        try:
            pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset))
            with torch.no_grad():
                for i, data_item in pbar:
                    output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens'])
                    logits = output[0]
                    _, pred_rel = torch.max(logits, 1)
                    # Accumulate across batches (previously only the last batch was returned).
                    rel_pred.extend(self.id2rel[int(p)] for p in pred_rel)
        finally:
            self.model.train(True)
        return rel_pred





if __name__ == '__main__':
    # Entry point: fine-tune BioBERT as a sequence classifier over config.relations.

    print("Run EntityRelationExtraction REL BERT ...")
    config = ConfigRel()
    # NOTE(review): num_labels is len(config.relations), which does not include the
    # negative label 'N' that DataPreparationRel emits; labels come from rel2id.json,
    # so confirm every id in that file is < num_labels.
    model = BertForSequenceClassification.from_pretrained('dmis-lab/biobert-base-cased-v1.2', num_labels=config.num_relations)
    data_processor = DataPreparationRel(config)
    # NOTE(review): adr-test.csv is used both as the dev set (checkpoint selection in
    # bert_train) and as the test set — confirm this is intended.
    train_loader, dev_loader, test_loader = data_processor.get_train_dev_data(
        '/kaggle/input/lstm-crf-adr/adr-train.csv',
    '/kaggle/input/lstm-crf-adr/adr-test.csv',
    '/kaggle/input/lstm-crf-adr/adr-test.csv',)
    # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data('../data/train_data_small.json')
    trainer = Trainer(model, config, train_loader, dev_loader, test_loader)
    trainer.bert_train()

