<a href="https://colab.research.google.com/github/Hou-jing/paper_record_public/blob/main/E2EM_NYT_11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install transformers

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 4.0 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 30.1 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 68.0 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 7.7 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 56.8 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
  

In [None]:
import numpy as np
import re, os, json
from random import choice
from tqdm import tqdm
import torch
import transformers
from transformers import BertModel,BertTokenizer

In [None]:
# Hugging Face checkpoint name; passed to BertTokenizer.from_pretrained and
# BertModel.from_pretrained further down in the notebook.
model_name='bert-base-cased'
# Fixed token length: token sequences are truncated to this size and the
# inputs / label tensors are padded or allocated to exactly this size.
BERT_MAX_LEN = 128
# Only referenced from commented-out shuffling code in the visible notebook.
RANDOM_SEED = 2019

In [None]:
def to_tuple(sent):
    """Convert every triple in ``sent['triple_list']`` to a tuple, in place.

    The list stored under ``'triple_list'`` is replaced by a new list of
    ``tuple(triple)`` items. Nothing is returned.
    """
    sent['triple_list'] = [tuple(item) for item in sent['triple_list']]

def seq_padding(batch, padding=0, max_length=None):
    """Pad (and truncate) every sequence in ``batch`` to one common length.

    Args:
        batch: iterable of sequences (lists or 1-D arrays).
        padding: value appended to sequences shorter than ``max_length``.
        max_length: target length of every row. Defaults to ``BERT_MAX_LEN``
            so existing two-argument calls behave as before.

    Returns:
        ``np.ndarray`` with one row per input sequence, each of length
        ``max_length``.
    """
    if max_length is None:
        max_length = BERT_MAX_LEN

    rows = []
    for seq in batch:
        # Truncate first: an over-long sequence used to be passed through
        # unchanged, which produced ragged rows that cannot form a 2-D array.
        row = list(seq[:max_length])
        row.extend([padding] * (max_length - len(row)))
        rows.append(row)
    return np.array(rows)

def load_data(test_path, rel_dict_path):
    """Load a triple-annotated dataset and the relation vocabulary from JSON.

    Args:
        test_path: path to a JSON file holding a list of samples; each sample
            must have a ``'triple_list'`` entry (converted to tuples here).
        rel_dict_path: path to a JSON file holding the pair
            ``[id2rel, rel2id]``.

    Returns:
        ``(data, id2rel, rel2id, num_rels)`` where ``id2rel`` has its keys
        converted to ``int`` and ``num_rels == len(id2rel)``.
    """
    # Context managers close the files; the bare open() calls leaked handles.
    with open(test_path, encoding='utf-8') as data_file:
        test_data = json.load(data_file)
    with open(rel_dict_path, encoding='utf-8') as rel_file:
        id2rel, rel2id = json.load(rel_file)

    # JSON object keys are always strings; restore the integer relation ids.
    id2rel = {int(i): j for i, j in id2rel.items()}
    num_rels = len(id2rel)

    for sent in test_data:
        to_tuple(sent)

    print("test_data len:", len(test_data))

    return test_data, id2rel, rel2id, num_rels

def find_head_idx(source, target):
    """Return the first index where ``target`` occurs as a slice of ``source``.

    Returns -1 when no position of ``source`` starts a slice equal to
    ``target``.
    """
    span = len(target)
    candidates = (
        start for start in range(len(source))
        if source[start:start + span] == target
    )
    return next(candidates, -1)

### preprocess

In [None]:
class data_generator:
    """Turn triple-annotated sentences into padded BERT inputs and tag targets.

    Args:
        data: list of samples, each with a ``'text'`` string and a
            ``'triple_list'`` of (subject, relation, object) items.
        tokenizer: object providing ``tokenize(text)`` and a ``__call__`` that
            returns ``'input_ids'`` and ``'attention_mask'`` tensors.
        rel2id: mapping from relation name to relation index.
        num_rels: number of relation classes (width of the object targets).
        maxlen: maximum number of whitespace-separated words kept per text.
    """

    def __init__(self, data, tokenizer, rel2id, num_rels, maxlen):
        self.data = data
        # The whole dataset is produced by a single generator() call.
        self.batch_size = len(self.data)
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        self.rel2id = rel2id
        self.num_rels = num_rels

    def __len__(self):
        return self.batch_size

    def generator(self):
        """Encode every sample and return eight parallel lists.

        Returns:
            ``(tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,
            sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch)``.
            ``segments_batch`` holds the tokenizer's attention mask.
            ``sub_head_batch`` and ``sub_tail_batch`` are always empty and are
            kept only so the tuple keeps its original arity.
        """
        tokens_batch, segments_batch = [], []
        sub_heads_batch, sub_tails_batch = [], []
        sub_head_batch, sub_tail_batch = [], []
        obj_heads_batch, obj_tails_batch = [], []

        for line in self.data:
            text = ' '.join(line['text'].split()[:self.maxlen])
            tokens = self.tokenizer.tokenize(text)[:BERT_MAX_LEN]

            # Map (subject head, subject tail) -> [(object head, object tail, relation id)].
            s2ro_map = {}
            for triple in line['triple_list']:
                sub_tokens = self.tokenizer.tokenize(triple[0])
                obj_tokens = self.tokenizer.tokenize(triple[2])
                if not sub_tokens or not obj_tokens:
                    # An empty entity would "match" at index 0 with tail -1.
                    continue
                sub_head_idx = find_head_idx(tokens, sub_tokens)
                obj_head_idx = find_head_idx(tokens, obj_tokens)
                # find_head_idx signals "not found" with -1. The previous test
                # against 0 discarded entities at the first token and let
                # missing entities write labels at index -1 (last position).
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(sub_tokens) - 1)
                    s2ro_map.setdefault(sub, []).append((
                        obj_head_idx,
                        obj_head_idx + len(obj_tokens) - 1,  # inclusive object tail
                        self.rel2id[triple[1]],
                    ))

            text_len = BERT_MAX_LEN
            # add_special_tokens=False keeps the ids aligned with `tokens`
            # above, so the entity indices address the same positions.
            inputs = self.tokenizer(text, return_tensors='pt', add_special_tokens=False,
                                    truncation=True, padding=True, max_length=BERT_MAX_LEN)
            token_ids, segment_ids = inputs['input_ids'], inputs['attention_mask']
            # Right-pad ids and mask with zeros up to the fixed length.
            pad_len = BERT_MAX_LEN - token_ids.shape[1]
            pad_seq = torch.zeros(1, pad_len)
            token_ids = torch.cat((token_ids, pad_seq), dim=1)
            segment_ids = torch.cat((segment_ids, pad_seq), dim=1)
            tokens_batch.append(token_ids)
            segments_batch.append(segment_ids)

            sub_heads, sub_tails = torch.zeros(text_len), torch.zeros(text_len)
            obj_heads = torch.zeros((text_len, self.num_rels))
            obj_tails = torch.zeros((text_len, self.num_rels))
            for (sub_head, sub_tail), ro_list in s2ro_map.items():
                sub_heads[sub_head] = 1
                sub_tails[sub_tail] = 1
                for obj_head, obj_tail, rel_id in ro_list:
                    obj_heads[obj_head][rel_id] = 1
                    obj_tails[obj_tail][rel_id] = 1

            sub_heads_batch.append(sub_heads)
            sub_tails_batch.append(sub_tails)
            obj_heads_batch.append(obj_heads)
            obj_tails_batch.append(obj_tails)

        return (tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,
                sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch)

### Download NYT

In [None]:
# Dropbox
!wget https://www.dropbox.com/s/u5u173tlze2er6z/Desktop.zip?dl=0 -O desktop.zip

# Unzip the dataset.
# This may take some time.
!unzip -q desktop.zip
# !gdown--https://www.dropbox.com/s/u5u173tlze2er6z/Desktop.zip?dl=0


--2022-03-12 13:16:35--  https://www.dropbox.com/s/u5u173tlze2er6z/Desktop.zip?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/u5u173tlze2er6z/Desktop.zip [following]
--2022-03-12 13:16:36--  https://www.dropbox.com/s/raw/u5u173tlze2er6z/Desktop.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uce43f6a6bd473a44688ccfcaec4.dl.dropboxusercontent.com/cd/0/inline/BhW3ZNAKG1rjwt1GR0mNhxqM0ultVaDWZ7XRU-OYSB5Kd_LKBDgLZRFlJhvW1xcD9ponnIZpNMhH1sUMjOznYuydDG70nnhAq5wdy14h1-mgSdyztciCESmkAH_1Ax1AM5IviPIuLZCQkIJ_GbdeNOlBVu-c445em0ubWWe00nBtDw/file# [following]
--2022-03-12 13:16:36--  https://uce43f6a6bd473a44688ccfcaec4.dl.dropboxusercontent.com/cd/0/inline/BhW3ZNAKG1rjwt1GR0mNhxqM0ultVaDWZ7XRU-OYSB5Kd_LKBDgLZRFlJhvW1xcD9p

In [None]:
!wget https://www.dropbox.com/s/fcovcanizu6me70/dev_triples.json?dl=0  -O dev_triples.json

--2022-03-12 13:16:38--  https://www.dropbox.com/s/fcovcanizu6me70/dev_triples.json?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/fcovcanizu6me70/dev_triples.json [following]
--2022-03-12 13:16:38--  https://www.dropbox.com/s/raw/fcovcanizu6me70/dev_triples.json
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucedc6a70f9755354fb670722770.dl.dropboxusercontent.com/cd/0/inline/BhUQX64KVaTbqJcbs8l89lKWOfKb1BFhMd0Yb_P9MX5tSicjwpquQ9MFsbFy9wu25f5OHQu69JQ0j-04HAlTzdGKVn3eHe0PybCHTN-sVkOWvkXgRhBKJIdotCl78y6rY0N8noMFZSk0vg6sLD40W1O3c4wML-uYvVGz4ovRw5fNGg/file# [following]
--2022-03-12 13:16:38--  https://ucedc6a70f9755354fb670722770.dl.dropboxusercontent.com/cd/0/inline/BhUQX64KVaTbqJcbs8l89lKWOfKb1BFhMd0Yb_P9MX5tSicjwpq

In [None]:
!wget https://www.dropbox.com/s/vkz83jfc9ekx1i5/rel2id.json?dl=0  -O rel2id.json

--2022-03-12 13:16:39--  https://www.dropbox.com/s/vkz83jfc9ekx1i5/rel2id.json?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/vkz83jfc9ekx1i5/rel2id.json [following]
--2022-03-12 13:16:39--  https://www.dropbox.com/s/raw/vkz83jfc9ekx1i5/rel2id.json
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc194fc3c3a022ed61492d857cbd.dl.dropboxusercontent.com/cd/0/inline/BhXglNkX29AOoguRvmJH_aOdZiqGqTOw8MCiA-YPJJImlmV6N-n58UDATTtTfxFbz4e4vTD8uSjK5Yesayhv56h3Cns8qwFgZbD678YpCXNLYNme2K5nxyrpd7UKXYk_4A0r8YvXkYfr5utP4mf582T2Kem23DWNw0P4VlS4zt6Z9w/file# [following]
--2022-03-12 13:16:39--  https://uc194fc3c3a022ed61492d857cbd.dl.dropboxusercontent.com/cd/0/inline/BhXglNkX29AOoguRvmJH_aOdZiqGqTOw8MCiA-YPJJImlmV6N-n58UDATTtTfxFbz4

### Download NYT11

In [None]:
!wget https://www.dropbox.com/s/kj5f1vcomr21xyv/NYT11.zip?dl=0  -O NYT11.zip 
!unzip -q NYT11.zip

--2022-03-12 13:16:40--  https://www.dropbox.com/s/kj5f1vcomr21xyv/NYT11.zip?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/kj5f1vcomr21xyv/NYT11.zip [following]
--2022-03-12 13:16:40--  https://www.dropbox.com/s/raw/kj5f1vcomr21xyv/NYT11.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc67352ba0fb49de598d86077b6c.dl.dropboxusercontent.com/cd/0/inline/BhXO1eiiIJpaZcHwAKnJtDR_3hcJHj653piyLB3Qy6U5yM5TR4wUpxV0UxxBKy8H2cRDmgc22E727ZKEK4fPmsd3U7k5dCshLUkNqNxojYG4iHoCAr8ObpxDXL_7551HUUO7JEpT3JuchiU2HlcQDrhwcduOJ4bF-1d1D4ZO8lWOmw/file# [following]
--2022-03-12 13:16:40--  https://uc67352ba0fb49de598d86077b6c.dl.dropboxusercontent.com/cd/0/inline/BhXO1eiiIJpaZcHwAKnJtDR_3hcJHj653piyLB3Qy6U5yM5TR4wUpxV0UxxBKy8H2cRDmgc2

### 数据

In [None]:
# Load both splits, build the tokenizer and pre-compute every input/label
# tensor list in one pass.
if __name__=='__main__':
    # NOTE: despite its name, `test_data` holds the samples read from
    # ./train_triples.json; it is the data later wrapped into the training set.
    test_data, id2rel, rel2id, num_rels=load_data('./train_triples.json',rel_dict_path='./rel2id.json')
    tokenizer=BertTokenizer.from_pretrained(model_name)
    # Maximum number of whitespace-separated words kept per sentence.
    maxlen=100
    # The dev_* relation mappings come from the same rel2id.json; the calls
    # below keep using rel2id / num_rels from the first load.
    val_data,dev_id2rel, dev_rel2id, dev_num_rels=load_data('./dev_triples.json',rel_dict_path='./rel2id.json')
    # generator() returns eight parallel lists with one entry per sentence.
    tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch, sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch=data_generator(test_data, tokenizer, rel2id, num_rels, maxlen)\
        .generator()
    dev_tokens,dev_segs,dev_heads,dev_tails,dev_head,dev_tail,dev_oheads,dev_otails=data_generator(val_data, tokenizer, rel2id, num_rels, maxlen).generator()



test_data len: 62335


Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

test_data len: 313


In [None]:
sub_tails_batch[:5]

[tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [None]:
len(tokens_batch)

62335

In [None]:
# Stack the per-sentence lists produced by data_generator.generator() into
# single tensors.
# NOTE(review): every name is rebound from a list to a tensor, so this cell is
# not idempotent - re-running it without re-running the generator cell
# concatenates tensor rows instead of the original list items.
tokens_batch=torch.cat([l for l in tokens_batch]).int()
print(tokens_batch.shape)
segments_batch=torch.cat([l for l in segments_batch]).int()  # attention masks, (num_sentences, 128)
# The subject targets are 1-D per sentence; concatenating flattens them, so
# reshape back to one row per sentence.
sub_heads_batch=torch.cat([l for l in sub_heads_batch]).reshape(tokens_batch.shape[0],-1).float()  # (num_sentences, 128)
sub_tails_batch=torch.cat([l for l in sub_tails_batch]).reshape(tokens_batch.shape[0],-1).float()
# The object targets are (128, num_rels) per sentence; the literal 128 matches
# BERT_MAX_LEN and the last axis is inferred.
obj_heads_batch=torch.cat([l for l in obj_heads_batch]).reshape(tokens_batch.shape[0],128,-1).float()
obj_tails_batch=torch.cat([l for l in obj_tails_batch]).reshape(tokens_batch.shape[0],128,-1).float()
print(type(segments_batch))
print(segments_batch.shape)
print(obj_heads_batch.shape)
print(type(obj_heads_batch))
# Same stacking for the dev split.
dev_tokens=torch.cat([l for l in dev_tokens]).int()
dev_segs=torch.cat([l for l in dev_segs]).int()
dev_heads=torch.cat([l for l in dev_heads]).reshape(dev_tokens.shape[0],-1).float()
dev_tails=torch.cat([l for l in dev_tails]).reshape(dev_tokens.shape[0],-1).float()
dev_oheads=torch.cat([l for l in dev_oheads]).reshape(dev_tokens.shape[0],128,-1).float()
dev_otails=torch.cat([l for l in dev_otails]).reshape(dev_tokens.shape[0],128,-1).float()

torch.Size([62335, 128])
<class 'torch.Tensor'>
torch.Size([62335, 128])
torch.Size([62335, 128, 12])
<class 'torch.Tensor'>


In [None]:
tokens_batch.shape[0]

62335

In [None]:
len(test_data)

62335

In [None]:
len(val_data)

313

In [None]:
#@title
# tokens = tokenizer.tokenize('Jackie')
# tokens

In [None]:
#@title
# line=test_data[0]
# line

In [None]:
#@title
# text = ' '.join(line['text'].split()[:100])
# tokens =tokenizer.tokenize(text)
# print(text,'\n',tokens)

In [None]:
#@title
# inputs=tokenizer(text,return_tensors='pt',add_special_tokens=False,truncation=True,padding=True,max_length=BERT_MAX_LEN)
# token_ids, segment_ids=inputs['input_ids'],inputs['attention_mask']
# print(token_ids)

In [None]:
#@title
# tokens_batch[0]#10*

In [None]:
#@title
# segments_batch[0]#46
# sub_heads_batch[0]#18*
# sub_tails_batch[0]
# obj_heads_batch[0]
# obj_tails_batch[[0]]

In [None]:

print(type(segments_batch))
print(segments_batch.shape)
print(obj_heads_batch.shape)
print(type(obj_heads_batch))

<class 'torch.Tensor'>
torch.Size([62335, 128])
torch.Size([62335, 128, 12])
<class 'torch.Tensor'>


In [None]:
sub_heads_batch[:3]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0

In [None]:
sub_tails_batch.shape

torch.Size([62335, 128])

In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset,TensorDataset
# Each dataset item is the 6-tuple
# (token ids, attention mask, subject heads, subject tails, object heads, object tails).
train_set=TensorDataset(tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,obj_heads_batch,obj_tails_batch)
from torch.nn import functional as F
# Dev split in the same field order as train_set.
dev_set=TensorDataset(dev_tokens,dev_segs,dev_heads,dev_tails,dev_oheads,dev_otails)

### seed

In [None]:
def same_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator.
    """
    # Local import: the notebook only imports `choice` from `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Benchmark mode lets cuDNN pick kernels by timing them, which can differ
    # from run to run; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
same_seeds(0)

### Model

In [None]:
class E2EModel(nn.Module):
    """BERT encoder with sigmoid taggers for subjects and relation-specific objects.

    This variant averages the binary cross-entropy over every position,
    including padding. The number of relations is read from the module-level
    ``rel2id`` and the encoder is loaded from ``model_name``.
    """

    def __init__(self):
        super().__init__()
        self.label_num = len(rel2id)
        self.encode = BertModel.from_pretrained(model_name)
        # The heads are created in the original order so that seeded
        # parameter initialisation is unchanged.
        self.subject_h_model = self._sigmoid_head(1)
        self.subject_t_model = self._sigmoid_head(1)
        self.object_h_model = self._sigmoid_head(self.label_num)
        self.object_t_model = self._sigmoid_head(self.label_num)

    @staticmethod
    def _sigmoid_head(out_features):
        """Linear projection from the 768-d encoder output followed by a sigmoid."""
        return nn.Sequential(nn.Linear(768, out_features), nn.Sigmoid())

    def forward(self,inputs_id,att_mask,sub_tar,tail_tar,obj_htar,obj_ttar):
        """Run the encoder and all four taggers and compute the training losses.

        Returns:
            ``(sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss)``: the
            four tagger outputs followed by the subject loss, the object loss
            and their sum.
        """
        hidden = self.encode(inputs_id, att_mask)[0]  # B*L*768
        head_gold = sub_tar.unsqueeze(-1)
        tail_gold = tail_tar.unsqueeze(-1)

        head_prob = self.subject_h_model(hidden)  # B*L*1
        tail_prob = self.subject_t_model(hidden)
        head_bce = F.binary_cross_entropy(head_prob, head_gold)
        tail_bce = F.binary_cross_entropy(tail_prob, tail_gold)
        sub_loss = head_bce + tail_bce

        # The subject probabilities gate the encoder output before the object
        # taggers; the author notes this differs somewhat from the paper.
        gate = head_prob + tail_prob
        obj_input = gate * hidden + hidden
        obj_h = self.object_h_model(obj_input)
        obj_t = self.object_t_model(obj_input)
        obj_head_bce = F.binary_cross_entropy(obj_h, obj_htar)
        obj_tail_bce = F.binary_cross_entropy(obj_t, obj_ttar)
        obj_loss = obj_head_bce + obj_tail_bce

        total_loss = obj_loss + sub_loss
        return head_prob, tail_prob, obj_h, obj_t, sub_loss, obj_loss, total_loss

### 修改

In [None]:
def loss_fn(pred,gold,mask):
  """Binary cross-entropy averaged over the positions selected by ``mask``.

  Args:
    pred: predicted probabilities.
    gold: target values with the same shape as ``pred``.
    mask: weights with the shape of ``pred``, or with that shape minus the
      trailing axis (a trailing axis is then added so it broadcasts).

  Returns:
    Scalar tensor ``sum(loss * mask) / sum(mask)``.
  """
  loss = F.binary_cross_entropy(pred, gold, reduction='none')
  if loss.shape != mask.shape:
    mask = mask.unsqueeze(-1)
  # The reduction used to sit inside the `if`, so a mask that already matched
  # the loss shape made the function fall through and return None.
  return torch.sum(loss * mask) / torch.sum(mask)

In [None]:
class E2EModel(nn.Module):
    """BERT encoder with sigmoid taggers, trained with attention-masked losses.

    Every loss term goes through ``loss_fn`` with the attention mask, so
    padded positions do not contribute. The number of relations is read from
    the module-level ``rel2id`` and the encoder is loaded from ``model_name``.
    """

    def __init__(self):
        super().__init__()
        self.label_num = len(rel2id)
        self.encode = BertModel.from_pretrained(model_name)
        # Creation order is unchanged so seeded initialisation stays the same.
        self.subject_h_model = nn.Sequential(nn.Linear(768, 1), nn.Sigmoid())
        self.subject_t_model = nn.Sequential(nn.Linear(768, 1), nn.Sigmoid())
        self.object_h_model = nn.Sequential(nn.Linear(768, self.label_num), nn.Sigmoid())
        self.object_t_model = nn.Sequential(nn.Linear(768, self.label_num), nn.Sigmoid())

    def forward(self,inputs_id,att_mask,sub_tar,tail_tar,obj_htar,obj_ttar):
        """Run the encoder and all four taggers and compute the masked losses.

        Returns:
            ``(sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss)``.
        """
        encoded = self.encode(inputs_id, att_mask)[0]  # B*L*768
        gold_head = sub_tar.unsqueeze(-1)
        gold_tail = tail_tar.unsqueeze(-1)

        sub = self.subject_h_model(encoded)  # B*L*1
        tail = self.subject_t_model(encoded)
        masked_head = loss_fn(sub, gold_head, att_mask)
        masked_tail = loss_fn(tail, gold_tail, att_mask)
        sub_loss = masked_head + masked_tail

        # Subject evidence gates the encoder output for the object taggers;
        # the author notes this differs somewhat from the paper.
        subject_gate = sub + tail
        object_features = subject_gate * encoded + encoded
        obj_h = self.object_h_model(object_features)
        obj_t = self.object_t_model(object_features)
        masked_obj_head = loss_fn(obj_h, obj_htar, att_mask)
        masked_obj_tail = loss_fn(obj_t, obj_ttar, att_mask)
        obj_loss = masked_obj_head + masked_obj_tail

        total_loss = obj_loss + sub_loss
        return sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss

### 修改——苏剑林（model_cnn）

In [None]:
# class subModel(nn.Module):
#     def __init__(self):
#         super(subModel, self).__init__()
#         self.label_num=len(rel2id)
#         self.encode=BertModel.from_pretrained(model_name)
#         self.sub_conv_1=nn.Conv1d(768,256,3,1,padding=1)
#         self.sub_conv_2=nn.Conv1d(768,256,3,1,padding=1)
#         self.obj_conv_1=nn.Conv1d(768,256,3,1,padding=1)
#         self.obj_conv_2=nn.Conv1d(768,256,3,1,padding=1)
#         self.subject_h_model=nn.Sequential(
#             nn.Linear(256, 1),
#             nn.Sigmoid()
#         )
#         self.subject_t_model=nn.Sequential(
#             nn.Linear(256, 1),
#             nn.Sigmoid()
#         )
#         self.object_h_model = nn.Sequential(
#             nn.Linear(256, self.label_num),
#             nn.Sigmoid()
#         )
#         self.object_t_model = nn.Sequential(
#             nn.Linear(256, self.label_num),
#             nn.Sigmoid()
#         )



#     def forward(self, inputs_id, att_mask, sub_tar, tail_tar, obj_htar, obj_ttar):
#             x = self.encode(inputs_id, att_mask)[0]  # B*L*768
#             sub_tar = sub_tar.unsqueeze(-1)
#             tail_tar = tail_tar.unsqueeze(-1)
#             x=x.transpose(2,1)#B*768*L
#             sub_x_1=self.sub_conv_1(x)
#             sub_x_2=self.sub_conv_2(x)
#             sub = self.subject_h_model(sub_x_1.transpose(2,1))  # B*L*1
#             tail = self.subject_t_model(sub_x_2.transpose(2,1))
#             sub_loss = F.binary_cross_entropy(sub, sub_tar) + F.binary_cross_entropy(tail, tail_tar)
#             subfea = sub + tail  # 这里与原文有些不符
#             objinput = subfea * x.transpose(2,1) + x.transpose(2,1)
#             objinput=objinput.transpose(2,1)
#             obj_x_1=self.obj_conv_1(objinput)
#             obj_x_2=self.obj_conv_2(objinput)
#             obj_h = self.object_h_model(obj_x_1.transpose(2,1))
#             obj_t = self.object_t_model(obj_x_2.transpose(2,1))
#             obj_loss = F.binary_cross_entropy(obj_h, obj_htar) + F.binary_cross_entropy(obj_t, obj_ttar)
#             total_loss = obj_loss + sub_loss
#             return sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss

# model=subModel()
# model_path='./submodel.ckpt'

In [None]:

# # model.bert_model(torch.from_numpy(np.array(train_id[:4])),torch.from_numpy(np.array(train_mask[:4])))

# optim=torch.optim.Adam(model.parameters(),lr=1e-3)
# #
# import transformers

# # for epoch in range(EPOCHS):
# #     step=1
# #     model.train()
# #     for i,data in tqdm(enumerate(train_loader)):
# #         input_id,input_mask,label=data
# #         pre=model.forward2(input_id,input_mask)
# #         loss=criterier(pre,label)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

### 文中的模型

In [None]:
# Instantiate whichever E2EModel definition was executed last and set the
# checkpoint path and compute device used by the training code.
model=E2EModel()
model_path='./model.ckpt'
# model.bert_model(torch.from_numpy(np.array(train_id[:4])),torch.from_numpy(np.array(train_mask[:4])))

# NOTE(review): the visible part of train() builds its own optimizer, so this
# one appears to be unused - confirm before relying on its learning rate.
optim=torch.optim.Adam(model.parameters(),lr=1e-3)

# Redundant: transformers is already imported in the first import cell.
import transformers

# for epoch in range(EPOCHS):
#     step=1
#     model.train()
#     for i,data in tqdm(enumerate(train_loader)):
#         input_id,input_mask,label=data
#         pre=model.forward2(input_id,input_mask)
#         loss=criterier(pre,label)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Train

In [None]:
def partial_match(pred_set, gold_set):
    """Reduce the subject and object of every triple to their first word.

    Each triple ``(subject, relation, object)`` becomes
    ``(subject.split(' ')[0], relation, object.split(' ')[0])``.

    Returns:
        ``(pred, gold)`` as two sets of reduced triples.
    """
    def first_word(text):
        # str.split(' ') always yields at least one item, so index 0 is safe.
        return text.split(' ')[0]

    def reduce_triples(triples):
        return {(first_word(t[0]), t[1], first_word(t[2])) for t in triples}

    return reduce_triples(pred_set), reduce_triples(gold_set)

In [None]:
#@title
def metric(eval_data,exact_match=False, output_path=None):
    """Score the predictions stored in ``triple_list.json`` against gold triples.

    Args:
        eval_data: iterable of samples with ``'text'`` and ``'triple_list'``.
        exact_match: compare full triples when true; otherwise compare the
            first-word reduction produced by ``partial_match``.
        output_path: optional path; when given, one indented JSON record per
            sample (gold, predicted, extra and missing triples) is written.

    Returns:
        ``(precision, recall, f1_score)``.
    """
    orders = ['subject', 'relation', 'object']
    # Tiny non-zero start values keep the ratios defined when a count is zero.
    correct_num, predict_num, gold_num = 1e-10, 1e-10, 1e-10

    # Predictions are one JSON object on the first line, keyed by the sample
    # index as a string.
    with open('triple_list.json', encoding='utf-8') as pred_file:
        dict_ = json.loads(pred_file.readline())

    # The report file was previously opened and never closed.
    out_file = open(output_path, 'w+', encoding='utf-8') if output_path else None
    try:
        for i, line in enumerate(eval_data):
            # Only the first dev_tokens.shape[0] samples are scored.
            if i >= dev_tokens.shape[0]:
                break
            Pred_triples = set(tuple(l) for l in dict_[str(i)])
            Gold_triples = set(tuple(l) for l in line['triple_list'])
            if exact_match:
                Pred_triples_eval, Gold_triples_eval = Pred_triples, Gold_triples
            else:
                Pred_triples_eval, Gold_triples_eval = partial_match(Pred_triples, Gold_triples)

            correct_num += len(Pred_triples_eval & Gold_triples_eval)
            predict_num += len(Pred_triples_eval)
            gold_num += len(Gold_triples_eval)

            if out_file is not None:
                result = json.dumps({
                    'text': line['text'],
                    'triple_list_gold': [
                        dict(zip(orders, triple)) for triple in Gold_triples
                    ],
                    'triple_list_pred': [
                        dict(zip(orders, triple)) for triple in Pred_triples
                    ],
                    'new': [
                        dict(zip(orders, triple)) for triple in Pred_triples - Gold_triples
                    ],
                    'lack': [
                        dict(zip(orders, triple)) for triple in Gold_triples - Pred_triples
                    ]
                }, ensure_ascii=False, indent=4)
                out_file.write(result + '\n')
    finally:
        if out_file is not None:
            out_file.close()

    precision = correct_num / predict_num
    recall = correct_num / gold_num
    f1_score = 2 * precision * recall / (precision + recall)

    print(f'correct_num:{correct_num}\npredict_num:{predict_num}\ngold_num:{gold_num}')
    return precision, recall, f1_score
def partial_match(pred_set, gold_set):
    """Keep only the first space-separated word of each subject and object.

    Returns:
        ``(pred, gold)``: two sets of ``(subject_word, relation, object_word)``
        triples built from ``pred_set`` and ``gold_set``.
    """
    reduced = []
    for triples in (pred_set, gold_set):
        # split(' ') never returns an empty list, so [0] always exists.
        reduced.append({
            (item[0].split(' ')[0], item[1], item[2].split(' ')[0])
            for item in triples
        })
    pred, gold = reduced
    return pred, gold

In [None]:
def metric(eval_data,exact_match=False, output_path=None):
    """Score the predictions stored in ``val_triples.json`` against gold triples.

    Args:
        eval_data: iterable of samples with ``'text'`` and ``'triple_list'``.
        exact_match: compare full triples when true; otherwise compare the
            first-word reduction produced by ``partial_match``.
        output_path: optional path; when given, one single-line JSON record
            per sample (gold, predicted, extra and missing triples) is written.

    Returns:
        ``(precision, recall, f1_score)``.
    """
    orders = ['subject', 'relation', 'object']
    # Tiny non-zero start values keep the ratios defined when a count is zero.
    correct_num, predict_num, gold_num = 1e-10, 1e-10, 1e-10

    # Predictions are one JSON object on the first line, keyed by the sample
    # index as a string.
    with open('val_triples.json', encoding='utf-8') as pred_file:
        dict_ = json.loads(pred_file.readline())

    def as_records(triples):
        return [dict(zip(orders, triple)) for triple in triples]

    # The report file was previously opened and never closed.
    out_file = open(output_path, 'w+', encoding='utf-8') if output_path else None
    try:
        for i, line in enumerate(eval_data):
            # Only the first tokens_batch.shape[0] samples are scored.
            if i >= tokens_batch.shape[0]:
                break
            Pred_triples = set(tuple(l) for l in dict_[str(i)])
            Gold_triples = set(tuple(l) for l in line['triple_list'])
            if exact_match:
                Pred_triples_eval, Gold_triples_eval = Pred_triples, Gold_triples
            else:
                Pred_triples_eval, Gold_triples_eval = partial_match(Pred_triples, Gold_triples)

            correct_num += len(Pred_triples_eval & Gold_triples_eval)
            predict_num += len(Pred_triples_eval)
            gold_num += len(Gold_triples_eval)

            if out_file is not None:
                result = json.dumps({
                    'text': line['text'],
                    'triple_list_gold': as_records(Gold_triples),
                    'triple_list_pred': as_records(Pred_triples),
                    'new': as_records(Pred_triples - Gold_triples),
                    'lack': as_records(Gold_triples - Pred_triples),
                }, ensure_ascii=False)
                out_file.write(result + '\n')
    finally:
        if out_file is not None:
            out_file.close()

    precision = correct_num / predict_num
    recall = correct_num / gold_num
    f1_score = 2 * precision * recall / (precision + recall)

    print(f'correct_num:{correct_num}\npredict_num:{predict_num}\ngold_num:{gold_num}')
    return precision, recall, f1_score



In [None]:
def extrac_triple(sub_heads_logits, sub_tails_logits, obj_heads_logits, obj_tails_logits, input_ids):
    """Decode the (subject, relation, object) triples predicted for one sentence.

    Args:
        sub_heads_logits, sub_tails_logits: per-token subject head/tail scores;
            positions scoring above 0.5 are selected along the first axis.
        obj_heads_logits, obj_tails_logits: 2-D scores indexed as
            (token position, relation id); entries above 0.5 are selected.
        input_ids: token ids of the sentence, converted back to tokens with
            the module-level ``tokenizer``.

    Returns:
        List of unique ``(subject_text, relation_name, object_text)`` tuples,
        or an empty list when no subject is found. Every decoded subject is
        paired with every decoded object.
    """
    h_bar = 0.5
    t_bar = 0.5

    def pieces_to_text(pieces):
        # Glue "##" continuation pieces onto the previous piece and separate
        # whole words with one space. The old lstrip("##") removed every
        # leading '#' and joined all pieces without any spaces.
        text = ''
        for piece in pieces:
            if piece.startswith('##'):
                text += piece[2:]
            elif text:
                text += ' ' + piece
            else:
                text = piece
        return text

    sub_heads_logits = np.array(sub_heads_logits)
    sub_tails_logits = np.array(sub_tails_logits)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    sub_heads = np.where(sub_heads_logits > h_bar)[0]
    sub_tails = np.where(sub_tails_logits > t_bar)[0]
    subjects = []
    for sub_head in sub_heads:
        # Pair each head with the nearest tail at or after it.
        later_tails = sub_tails[sub_tails >= sub_head]
        if len(later_tails) > 0:
            # The training labels mark the tail token itself, so the slice
            # must include it (the old `head:tail` slice dropped the last token).
            subjects.append(pieces_to_text(tokens[sub_head: later_tails[0] + 1]))
    if not subjects:
        return []

    # Object spans do not depend on the subject, so decode them once.
    tail_hits = list(zip(*np.where(obj_tails_logits > t_bar)))
    objects = []
    for obj_head, rel_head in zip(*np.where(obj_heads_logits > h_bar)):
        for obj_tail, rel_tail in tail_hits:
            if obj_head <= obj_tail and rel_head == rel_tail:
                objects.append((id2rel[int(rel_head)],
                                pieces_to_text(tokens[obj_head: obj_tail + 1])))
                # Keep only the first matching tail for this head.
                break

    return list({(sub, rel, obj) for sub in subjects for rel, obj in objects})

In [None]:
import numpy as np
# Simulate a learning-rate schedule: linear warm-up followed by a slow decay.
warmup_steps = 2500
init_lr = 0.1
# Simulate 15000 training steps.
max_steps = 15000
# Start from init_lr so the decay branch has a value to work on even when
# warmup_steps is 0 (otherwise learning_rate would be unbound there).
learning_rate = init_lr
for train_steps in range(max_steps):
    if warmup_steps and train_steps < warmup_steps:
        warmup_percent_done = train_steps / warmup_steps
        warmup_learning_rate = init_lr * warmup_percent_done  # gradual warm-up
        learning_rate = warmup_learning_rate
    else:
        # learning_rate = np.sin(learning_rate)  # alternative: sine-shaped decay after warm-up
        # After warm-up, shrink the rate a little each step (rough stand-in for
        # exponential decay; only shrinks while learning_rate < 1).
        learning_rate = learning_rate**1.0001


In [None]:
# from sys import last_traceback

def _count_matches(probs, targets, threshold=0.5):
    """Return ``(n_correct, n_compared)`` for thresholded ``probs`` versus 0/1 ``targets``."""
    hits = probs.ge(threshold).int() == targets
    return hits.sum().item(), hits.numel()


def train(net, train_set, val_set, num_epochs, learning_rate, batch_size):
    """Train ``net`` and run a validation pass after every epoch.

    Each epoch: iterate ``train_set`` in shuffled mini-batches, optimise the
    total loss returned by ``net`` with Adam, print running statistics every
    800 steps, and save ``net.state_dict()`` to the module-level ``model_path``.
    Then ``val_set`` is evaluated one example at a time, the decoded triples
    are written to ``val_triples.json`` and ``metric`` is called.

    Args:
        net: model called as ``net(ids, mask, sub_htar, sub_ttar, obj_htar, obj_ttar)``
            and returning ``(sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss)``.
        train_set: dataset yielding the six tensors above, used for training.
        val_set: dataset with the same item layout, used for validation.
        num_epochs: number of passes over ``train_set``.
        learning_rate: Adam learning rate.
        batch_size: training mini-batch size (validation always uses 1).

    Relies on module-level ``device``, ``model_path``, ``DataLoader``, ``tqdm``,
    ``extrac_triple``, ``metric`` and ``val_data``.
    """
    print('执行次数为：{}'.format(1))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    net = net.to(device)
    # Build the optimizer once: re-creating Adam for every batch discards its
    # running moment estimates, so it never behaves like Adam.
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.9))

    for epoch in range(num_epochs):
        # Validation below switches to eval mode, so re-enter train mode each epoch.
        net.train()
        train_loss = 0
        # Running counts, in order: subject head, subject tail, object head, object tail.
        correct, seen = [0, 0, 0, 0], [0, 0, 0, 0]
        for step, data in enumerate(tqdm(train_loader), start=1):
            text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar = (t.to(device) for t in data)
            optimizer.zero_grad()
            sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = net(
                text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar)
            total_loss.backward()
            optimizer.step()

            # Subject targets get a trailing axis so they line up element-wise
            # with the model's subject outputs instead of broadcasting.
            pairs = ((sub, sub_htar.unsqueeze(-1)), (tail, sub_ttar.unsqueeze(-1)),
                     (obj_h, obj_htar), (obj_t, obj_ttar))
            for k, (probs, target) in enumerate(pairs):
                hit, total = _count_matches(probs, target)
                correct[k] += hit
                seen[k] += total
            train_loss += total_loss.item()

            if step % 800 == 0:
                sub_hacc, sub_tacc, obj_hacc, obj_tacc = (c / n for c, n in zip(correct, seen))
                print('train_epoch|{},sub_loss={},sub_h_acc={}，sub_t_acc={}'.format(
                    epoch + 1, sub_loss, sub_hacc, sub_tacc))
                print('obj_loss={},obj_h_acc={}，obj_t_acc={}'.format(obj_loss, obj_hacc, obj_tacc))
                print('epoch{}|train_loss={},total_loss={}'.format(epoch + 1, train_loss, train_loss / step))
        torch.save(net.state_dict(), model_path)

        # ---- validation ----
        net.eval()
        dict_t = {}
        val_loss, vstep = 0, 0
        correct, seen = [0, 0, 0, 0], [0, 0, 0, 0]
        sub_loss = obj_loss = None  # stay defined if val_set is empty
        with torch.no_grad():
            for j, data in enumerate(val_loader):
                inputs_id = data[0]
                text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar = (t.to(device) for t in data)
                # Evaluate the network being trained, not a global model.
                sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = net(
                    text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar)
                obj_h = obj_h.ge(0.5).int()
                obj_t = obj_t.ge(0.5).int()

                # Same target reshaping as in training; without it the comparison
                # broadcasts and the subject "accuracy" exceeds 1.
                pairs = ((sub, sub_htar.unsqueeze(-1)), (tail, sub_ttar.unsqueeze(-1)),
                         (obj_h, obj_htar), (obj_t, obj_ttar))
                for k, (probs, target) in enumerate(pairs):
                    hit, total = _count_matches(probs, target)
                    correct[k] += hit
                    seen[k] += total
                val_loss += total_loss.item()
                vstep += 1

                # Validation batch size is 1, so decode the single example at index 0.
                dict_t[j] = extrac_triple(sub[0].cpu().numpy(), tail[0].cpu().numpy(),
                                          obj_h[0].cpu().numpy(), obj_t[0].cpu().numpy(),
                                          inputs_id[0].cpu())

        with open('val_triples.json', mode='w+', encoding='utf_8') as f:
            json.dump(dict_t, f)

        sub_hacc, sub_tacc, obj_hacc, obj_tacc = (c / max(n, 1) for c, n in zip(correct, seen))
        print('-' * 50)
        print('val_train_epoch|{},val_sub_loss={},sub_h_acc={}，sub_t_acc={}'.format(
            epoch + 1, sub_loss, sub_hacc, sub_tacc))
        print('obj_loss={},obj_h_acc={}，obj_t_acc={}'.format(obj_loss, obj_hacc, obj_tacc))
        print('val_epoch{}|total_loss={}'.format(epoch + 1, val_loss / max(vstep, 1)))
        # NOTE(review): metric is evaluated on the global val_data with its own
        # output path, not on val_set / val_triples.json -- confirm they correspond.
        precision, recall, f1_score = metric(val_data, exact_match=False, output_path='valoutputs.json')
        print(f'precision={precision},recall={recall},f1_score={f1_score}')


In [None]:
#@title

# def train(net, train_set,val_set, num_epochs, learning_rate, batch_size):
#     i = 1
#     print('执行次数为：{}'.format(i))
#     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
#     t_total = len(train_loader) // num_epochs
#     optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
#     # scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=t_total)

#     net = net.to(device)
#     net.train()
#     # optimizer = AdamW(net.parameters(), lr=learning_rate)
#     # criterier = nn.CrossEntropyLoss()
#     # criterier=nn.MSELoss()
#     for epoch in range(num_epochs):
#         train_loss, sub_hacc,sub_tacc,obj_hacc,obj_tacc = 0, 0,0,0,0
#         accum_iter = 4
#         step = 1
#         for batch_idx, data in enumerate(tqdm(train_loader)):
#             inputs_id,att_mask,sub_htar,sub_ttar,obj_htar,obj_ttar = data
#             text, mask, sub_htar,sub_ttar,obj_htar,obj_ttar =inputs_id.to(device),att_mask.to(device),sub_htar.to(device),sub_ttar.to(device),obj_htar.to(device),obj_ttar.to(device)
#             optimizer.zero_grad()
#             # with torch.set_grad_enabled(True):
#             sub,tail,obj_h,obj_t, sub_loss,obj_loss,total_loss= net(text, mask, sub_htar,sub_ttar,obj_htar,obj_ttar)
#             sub_htar,sub_ttar=sub_htar.unsqueeze(-1),sub_ttar.unsqueeze(-1)
#             # label = label.float()
#             # loss = criterier(pre, label)
#             # obj_loss.backward()
#             total_loss.backward()
#             optimizer.step()
#             # scheduler.step()
#             # print(sub.ge(0.5).int())
#             # print(subtar)
#             sub_hacc+=(sub.ge(0.5).int()==sub_htar).sum().item()
#             sub_tacc += (tail.ge(0.5).int() == sub_ttar).sum().item()
#             obj_hacc+=(obj_h.ge(0.5).int()==obj_htar).sum().item()
#             obj_tacc+=(obj_t.ge(0.5).int()==obj_ttar).sum().item()
#             train_loss += total_loss.item()

#             step += 1
#             if step % 100 == 0:
#                 print('train_epoch|{},sub_loss={},sub_h_acc={}，sub_t_acc={}'.format(epoch+1, sub_loss,sub_hacc/(step*batch_size*128),
#                                                                                 sub_tacc/(step*batch_size*128)))
#                 print('obj_loss={},obj_h_acc={}，obj_t_acc={}'.format(obj_loss,obj_hacc/(step*batch_size*128*len(rel2id)),
#                                                                                 obj_tacc/(step*batch_size*128*len(rel2id))))
#                 print('epoch{}|total_loss={}'.format(epoch+1,train_loss/step))
#         torch.save(net.state_dict(), model_path)

#         net.eval()
#         val_loader = DataLoader(val_set,batch_size=1,shuffle=False)
#         val_batch_size=1
#         with torch.no_grad():
#             dict_t = {}
#             for j, data in enumerate(val_loader):
#                 f=open('triple_list.json',mode='w+',encoding='utf_8')
#                 train_loss, sub_hacc, sub_tacc, obj_hacc, obj_tacc = 0, 0, 0, 0, 0
#                 inputs_id, att_mask, sub_htar, sub_ttar, obj_htar, obj_ttar = data
#                 text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar = inputs_id.to(device), att_mask.to(device), sub_htar.to(
#                     device), sub_ttar.to(device), obj_htar.to(device), obj_ttar.to(device)
#                 sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = model(text, mask, sub_htar, sub_ttar, obj_htar,                                                                    obj_ttar)
#                 sub_h=sub.ge(0.5).int()
#                 sub_t=tail.ge(0.5).int()
#                 obj_h=obj_h.ge(0.5).int()
#                 obj_t=obj_t.ge(0.5).int()
#                 sub_hacc += (sub.ge(0.5).int() == sub_htar).sum().item()
#                 sub_tacc += (tail.ge(0.5).int() == sub_ttar).sum().item()
#                 obj_hacc += (obj_h.ge(0.5).int() == obj_htar).sum().item()
#                 obj_tacc += (obj_t.ge(0.5).int() == obj_ttar).sum().item()
#                 train_loss += total_loss.item()
#                 # print('sub_h', sub_h)
#                 # print('sub_t',sub_t)
#                 # print('obj_h',obj_h)
#                 # print('obj_t',obj_t)
#                 # print(sub_hacc)
#                 # print(sub_tacc)
#                 # print(obj_hacc)
#                 # print(obj_tacc)
#                 sub, tail, obj_h, obj_t=sub.cpu(),tail.cpu(),obj_h.cpu(),obj_t.cpu()
#                 sub_heads_logits,sub_tails_logits,obj_heads_logits,obj_tails_logits=np.array(sub),np.array(tail),np.array(obj_h),np.array(obj_t)
#                 inputs_id=inputs_id.cpu()
#                 def extrac_triple(sub_heads_logits,sub_tails_logits,obj_heads_logits,obj_tails_logits,input_ids):  # 抽取一个triple
#                     h_bar=0.5
#                     t_bar=0.5
#                     sub_heads_logits = np.array(sub_heads_logits)
#                     sub_tails_logits = np.array(sub_tails_logits)
#                     tokens = tokenizer.convert_ids_to_tokens(input_ids)
#                     sub_heads, sub_tails = np.where(sub_heads_logits > h_bar)[0], np.where(sub_tails_logits > t_bar)[0]
#                     subjects = []
#                     for sub_head in sub_heads:
#                         sub_tail = sub_tails[sub_tails >= sub_head]
#                         if len(sub_tail) > 0:
#                             sub_tail = sub_tail[0]
#                             if sub_tail==sub_head:
#                                 subject = tokens[input_ids[sub_head]]
#                                 subjects.append((subject, sub_head, sub_tail))
#                             else:
#                                 subject = tokens[input_ids[sub_head: sub_tail]]
#                                 subjects.append((subject, sub_head, sub_tail))
#                     if subjects:
#                         triple_list = []
#                         sub_heads, sub_tails = np.array([sub[1:] for sub in subjects]).T.reshape((2, -1, 1))
#                         for i, subject in enumerate(subjects):
#                             sub = subject[0]
#                             sub = ''.join([i.lstrip("##") for i in sub])
#                             sub = ' '.join(sub.split('[unused1]'))
#                             obj_heads, obj_tails = np.where(obj_heads_logits > h_bar), np.where(obj_tails_logits > t_bar)
#                             for obj_head, rel_head in zip(*obj_heads):
#                                 for obj_tail, rel_tail in zip(*obj_tails):
#                                     if obj_head <= obj_tail and rel_head == rel_tail:
#                                         rel = id2rel[rel_head]
#                                         if obj_head==obj_tail:
#                                             obj=tokens[input_ids[obj_head]]
#                                         else:
#                                             obj = tokens[input_ids[obj_head: obj_tail]]
#                                         obj = ''.join([i.lstrip("##") for i in obj])
#                                         obj = ' '.join(obj.split('[unused1]'))
#                                         triple_list.append((sub, rel, obj))
#                                         break
#                         triple_set = set()
#                         for s, r, o in triple_list:
#                             triple_set.add((s, r, o))
#                         return list(triple_set)
#                     else:
#                         return []
#                 for i in range(val_batch_size):
#                     sub_heads_logits, sub_tails_logits, obj_heads_logits, obj_tails_logits=sub_heads_logits[i],sub_tails_logits[i],obj_heads_logits[i],obj_tails_logits[i]
#                     input_ids=inputs_id[i]
#                     triple_list=extrac_triple(sub_heads_logits,sub_tails_logits,obj_heads_logits,obj_tails_logits,input_ids)
#                     dict_t[j]=triple_list
#             # print(dict_t)
#             json.dump(dict_t,f)
#             f.close()
#         precision, recall, f1_score = metric(val_data, exact_match=False, output_path='output.json')
#         print(f'precision={precision},recall={recall},f1_score={f1_score}')

### train_epoch

### 文中模型训练

In [None]:
# Train for up to 30 epochs with lr=1e-5 and mini-batches of 4.
EPOCHS = 30
train(net=model, train_set=train_set, val_set=dev_set,
      num_epochs=EPOCHS, learning_rate=1e-5, batch_size=4)

执行次数为：1


  5%|▌         | 802/15584 [01:05<20:08, 12.23it/s]

train_epoch|1,sub_loss=0.039343155920505524,sub_h_acc=0.99696044921875，sub_t_acc=0.99761474609375
obj_loss=0.06399550288915634,obj_h_acc=0.9995572916666666，obj_t_acc=0.9997037760416667
epoch1|train_loss=71.43427553056972,total_loss=0.08929284441321216


 10%|█         | 1602/15584 [02:11<19:07, 12.18it/s]

train_epoch|1,sub_loss=0.06530480831861496,sub_h_acc=0.99699462890625，sub_t_acc=0.99758056640625
obj_loss=0.06935432553291321,obj_h_acc=0.9995830281575521，obj_t_acc=0.9997177124023438
epoch1|train_loss=145.77316790743498,total_loss=0.09110822994214686


 15%|█▌        | 2402/15584 [03:16<17:57, 12.24it/s]

train_epoch|1,sub_loss=0.015951575711369514,sub_h_acc=0.997109375，sub_t_acc=0.997607421875
obj_loss=0.05534094199538231,obj_h_acc=0.9996099175347222，obj_t_acc=0.9997227647569444
epoch1|train_loss=213.1974768108048,total_loss=0.088832282004502


 21%|██        | 3202/15584 [04:22<16:54, 12.21it/s]

train_epoch|1,sub_loss=0.0005129882483743131,sub_h_acc=0.997008056640625，sub_t_acc=0.997581787109375
obj_loss=0.00258204760029912,obj_h_acc=0.999605458577474，obj_t_acc=0.9997221374511719
epoch1|train_loss=286.7658813186572,total_loss=0.08961433791208037


 26%|██▌       | 4002/15584 [05:27<15:59, 12.08it/s]

train_epoch|1,sub_loss=0.026952456682920456,sub_h_acc=0.99699365234375，sub_t_acc=0.99760595703125
obj_loss=0.04722091555595398,obj_h_acc=0.9996066487630209，obj_t_acc=0.9997274983723958
epoch1|train_loss=353.88088289472216,total_loss=0.08847022072368053


 31%|███       | 4802/15584 [06:33<14:44, 12.19it/s]

train_epoch|1,sub_loss=0.17299556732177734,sub_h_acc=0.9969698079427083，sub_t_acc=0.9976021321614583
obj_loss=0.2508247196674347,obj_h_acc=0.9996112738715278，obj_t_acc=0.9997311062282986
epoch1|train_loss=420.32134558673715,total_loss=0.08756694699723691


 36%|███▌      | 5602/15584 [07:39<13:39, 12.18it/s]

train_epoch|1,sub_loss=0.0010802315082401037,sub_h_acc=0.9969447544642858，sub_t_acc=0.9976402064732143
obj_loss=0.04475496709346771,obj_h_acc=0.9996099853515625，obj_t_acc=0.9997315906343006
epoch1|train_loss=491.8765814851213,total_loss=0.0878351038366288


 41%|████      | 6402/15584 [08:44<12:35, 12.15it/s]

train_epoch|1,sub_loss=0.009528722614049911,sub_h_acc=0.9969143676757812，sub_t_acc=0.9976321411132812
obj_loss=0.003369607264176011,obj_h_acc=0.9996122741699218，obj_t_acc=0.9997341156005859
epoch1|train_loss=562.4309317299449,total_loss=0.08787983308280388


 46%|████▌     | 7202/15584 [09:50<11:31, 12.12it/s]

train_epoch|1,sub_loss=0.04460179805755615,sub_h_acc=0.99685546875，sub_t_acc=0.9976356336805555
obj_loss=0.035871587693691254,obj_h_acc=0.999610279224537，obj_t_acc=0.9997357177734375
epoch1|train_loss=636.4132966666475,total_loss=0.0883907356481455


 51%|█████▏    | 8002/15584 [10:55<10:24, 12.13it/s]

train_epoch|1,sub_loss=0.0006047900533303618,sub_h_acc=0.996903564453125，sub_t_acc=0.9976533203125
obj_loss=0.014913925901055336,obj_h_acc=0.9996141967773438，obj_t_acc=0.9997389322916667
epoch1|train_loss=705.6868214588103,total_loss=0.08821085268235129


 56%|█████▋    | 8802/15584 [12:01<09:17, 12.17it/s]

train_epoch|1,sub_loss=0.0040470086969435215,sub_h_acc=0.9969475763494318，sub_t_acc=0.9976808860085228
obj_loss=0.015921585261821747,obj_h_acc=0.9996139433889678，obj_t_acc=0.9997404341264204
epoch1|train_loss=779.1007997647284,total_loss=0.08853418179144641


 62%|██████▏   | 9602/15584 [13:06<08:13, 12.12it/s]

train_epoch|1,sub_loss=5.697547021554783e-05,sub_h_acc=0.9969600423177083，sub_t_acc=0.9976778157552083
obj_loss=0.00015990171232260764,obj_h_acc=0.9996153598361545，obj_t_acc=0.999741465250651
epoch1|train_loss=850.5764707550697,total_loss=0.0886017157036531


 67%|██████▋   | 10402/15584 [14:12<07:04, 12.20it/s]

train_epoch|1,sub_loss=0.020247969776391983,sub_h_acc=0.9970010141225961，sub_t_acc=0.99770263671875
obj_loss=0.04385776072740555,obj_h_acc=0.9996192971254007，obj_t_acc=0.9997437149439102
epoch1|train_loss=921.8450137902328,total_loss=0.08863894363367623


 72%|███████▏  | 11202/15584 [15:17<06:01, 12.12it/s]

train_epoch|1,sub_loss=0.010269559919834137,sub_h_acc=0.9970295061383928，sub_t_acc=0.997718505859375
obj_loss=0.051650598645210266,obj_h_acc=0.9996196637834821，obj_t_acc=0.9997433326357887
epoch1|train_loss=993.4648753024267,total_loss=0.08870222100914524


 77%|███████▋  | 12002/15584 [16:23<04:55, 12.11it/s]

train_epoch|1,sub_loss=2.606430098239798e-05,sub_h_acc=0.9970675455729167，sub_t_acc=0.9977190755208334
obj_loss=1.5336900105467066e-05,obj_h_acc=0.999623779296875，obj_t_acc=0.999744615342882
epoch1|train_loss=1064.1892497704685,total_loss=0.08868243748087237


 82%|████████▏ | 12802/15584 [17:29<03:48, 12.16it/s]

train_epoch|1,sub_loss=0.0044414810836315155,sub_h_acc=0.997080078125，sub_t_acc=0.9977107238769531
obj_loss=0.027900777757167816,obj_h_acc=0.9996263885498047，obj_t_acc=0.9997452290852865
epoch1|train_loss=1139.4066093578658,total_loss=0.08901614135608327


 87%|████████▋ | 13602/15584 [18:34<02:43, 12.11it/s]

train_epoch|1,sub_loss=0.0016903569921851158,sub_h_acc=0.9971057846966912，sub_t_acc=0.9977023494944853
obj_loss=0.015968888998031616,obj_h_acc=0.999627565870098，obj_t_acc=0.9997446576286765
epoch1|train_loss=1216.726150203407,total_loss=0.0894651581031917


 92%|█████████▏| 14402/15584 [19:40<01:36, 12.19it/s]

train_epoch|1,sub_loss=0.0005905277794227004,sub_h_acc=0.997125244140625，sub_t_acc=0.9976987033420139
obj_loss=0.04364532232284546,obj_h_acc=0.9996311442057292，obj_t_acc=0.9997463424117476
epoch1|train_loss=1285.5508386289312,total_loss=0.08927436379367577


 98%|█████████▊| 15202/15584 [20:45<00:31, 12.17it/s]

train_epoch|1,sub_loss=0.008302723988890648,sub_h_acc=0.9971321186266447，sub_t_acc=0.9977040501644737
obj_loss=0.13331863284111023,obj_h_acc=0.9996320757949562，obj_t_acc=0.9997471002946821
epoch1|train_loss=1360.0363155761897,total_loss=0.08947607339317037


100%|██████████| 15584/15584 [21:16<00:00, 12.20it/s]


--------------------------------------------------
val_train_epoch|1,val_sub_loss=5.405672709457576e-06,sub_h_acc=125.70866613418531，sub_t_acc=125.80880591054313
obj_loss=1.894318120321259e-05,obj_h_acc=0.999621439030884，obj_t_acc=0.9997940794728435
val_epoch1|total_loss=0.08710413585104891
correct_num:176.0000000001
predict_num:497.0000000001
gold_num:390.0000000001
precision=0.3541247484910756,recall=0.45128205128219195,f1_score=0.3968432919956264


  5%|▌         | 802/15584 [01:05<20:06, 12.25it/s]

train_epoch|2,sub_loss=0.05906885862350464,sub_h_acc=0.9976513671875，sub_t_acc=0.99842041015625
obj_loss=0.001001895871013403,obj_h_acc=0.9996732584635417，obj_t_acc=0.9998264567057291
epoch2|train_loss=54.656187960225,total_loss=0.06832023495028125


 10%|█         | 1602/15584 [02:10<18:57, 12.29it/s]

train_epoch|2,sub_loss=0.00011282834748271853,sub_h_acc=0.9976611328125，sub_t_acc=0.9983935546875
obj_loss=4.390787216834724e-05,obj_h_acc=0.9995419311523438，obj_t_acc=0.9996854654947916
epoch2|train_loss=179.06074381446797,total_loss=0.11191296488404248


 15%|█▌        | 2402/15584 [03:15<17:52, 12.29it/s]

train_epoch|2,sub_loss=1.929621021190542e-06,sub_h_acc=0.9978426106770834，sub_t_acc=0.9984415690104167
obj_loss=0.11820939928293228,obj_h_acc=0.9996153428819444，obj_t_acc=0.9997391086154513
epoch2|train_loss=235.00741282325123,total_loss=0.09791975534302134


 21%|██        | 3202/15584 [04:20<16:51, 12.24it/s]

train_epoch|2,sub_loss=0.00012784528371412307,sub_h_acc=0.997943115234375，sub_t_acc=0.9984454345703125
obj_loss=0.016591859981417656,obj_h_acc=0.999643300374349，obj_t_acc=0.9997656758626302
epoch2|train_loss=297.18694411733577,total_loss=0.09287092003666743


 26%|██▌       | 4002/15584 [05:25<15:46, 12.24it/s]

train_epoch|2,sub_loss=0.0231076180934906,sub_h_acc=0.9980439453125，sub_t_acc=0.99852099609375
obj_loss=0.015330925583839417,obj_h_acc=0.9996640625，obj_t_acc=0.9997810872395834
epoch2|train_loss=354.8754397269197,total_loss=0.08871885993172993


 31%|███       | 4802/15584 [06:30<14:46, 12.16it/s]

train_epoch|2,sub_loss=5.4546114824916e-07,sub_h_acc=0.9980802408854167，sub_t_acc=0.998531494140625
obj_loss=4.428205784279271e-07,obj_h_acc=0.9996758015950521，obj_t_acc=0.9997904798719618
epoch2|train_loss=414.2302943135103,total_loss=0.08629797798198131


 36%|███▌      | 5602/15584 [07:36<13:37, 12.20it/s]

train_epoch|2,sub_loss=6.143033260741504e-06,sub_h_acc=0.9981044224330358，sub_t_acc=0.9985484095982143
obj_loss=0.08743640780448914,obj_h_acc=0.9996658470517114，obj_t_acc=0.9997771344866071
epoch2|train_loss=478.0987109346366,total_loss=0.08537476980975653


 40%|███▉      | 6200/15584 [08:24<12:44, 12.28it/s]


KeyboardInterrupt: ignored

In [None]:
# Train for up to 30 epochs with lr=1e-5 and mini-batches of 2.
EPOCHS = 30
train(net=model, train_set=train_set, val_set=dev_set,
      num_epochs=EPOCHS, learning_rate=1e-5, batch_size=2)

In [None]:
# Train for up to 10 epochs with lr=1e-5 and mini-batches of 16.
EPOCHS = 10
train(net=model, train_set=train_set, val_set=dev_set,
      num_epochs=EPOCHS, learning_rate=1e-5, batch_size=16)

In [None]:
# Same configuration as the previous cell: 10 epochs, lr=1e-5, mini-batches of 16.
EPOCHS = 10
train(net=model, train_set=train_set, val_set=dev_set,
      num_epochs=EPOCHS, learning_rate=1e-5, batch_size=16)

In [None]:
# Show how many items train_set holds.
len(train_set)

In [None]:

# Stand-alone validation pass over dev_set with the current `model`:
# accumulates loss/accuracy, decodes triples per example, writes them to
# val_traintriple_list.json and then calls metric().
# Put the model in eval mode first; an interrupted train() leaves it in train mode.
model.eval()
val_loader = DataLoader(dev_set, batch_size=1, shuffle=False)
val_batch_size = 1
dict_t = {}
vstep = 0
train_loss, sub_hacc, sub_tacc, obj_hacc, obj_tacc = 0, 0, 0, 0, 0
with torch.no_grad(), open('val_traintriple_list.json', mode='w+', encoding='utf_8') as f:
    for j, data in enumerate(val_loader):
        inputs_id, att_mask, sub_htar, sub_ttar, obj_htar, obj_ttar = data
        text, mask = inputs_id.to(device), att_mask.to(device)
        sub_htar, sub_ttar = sub_htar.to(device), sub_ttar.to(device)
        obj_htar, obj_ttar = obj_htar.to(device), obj_ttar.to(device)
        sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = model(
            text, mask, sub_htar, sub_ttar, obj_htar, obj_ttar)
        # Give the subject targets a trailing axis (as the training loop does) so the
        # comparison is element-wise; without it the result broadcasts and the
        # reported subject accuracy exceeds 1.
        sub_htar, sub_ttar = sub_htar.unsqueeze(-1), sub_ttar.unsqueeze(-1)
        sub_h = sub.ge(0.5).int()
        sub_t = tail.ge(0.5).int()
        obj_h = obj_h.ge(0.5).int()
        obj_t = obj_t.ge(0.5).int()
        sub_hacc += (sub_h == sub_htar).sum().item()
        sub_tacc += (sub_t == sub_ttar).sum().item()
        obj_hacc += (obj_h == obj_htar).sum().item()
        obj_tacc += (obj_t == obj_ttar).sum().item()
        train_loss += total_loss.item()
        vstep += 1
        sub, tail, obj_h, obj_t = sub.cpu(), tail.cpu(), obj_h.cpu(), obj_t.cpu()
        sub_heads_logits, sub_tails_logits = np.array(sub), np.array(tail)
        obj_heads_logits, obj_tails_logits = np.array(obj_h), np.array(obj_t)
        inputs_id = inputs_id.cpu()
        # Batch size is 1, so decode the single example at index 0.
        triple_list = extrac_triple(sub_heads_logits[0], sub_tails_logits[0],
                                    obj_heads_logits[0], obj_tails_logits[0], inputs_id[0])
        dict_t[j] = triple_list
        # Debug output for the first few examples only.
        if vstep < 20:
            tokens = tokenizer.convert_ids_to_tokens(inputs_id[0])
            print(tokens)
            print(sub_h, sub_t)
            print('\n')
            print(triple_list)
    json.dump(dict_t, f)
print('-' * 50)
# The 128 below is the hard-coded sequence length -- presumably BERT_MAX_LEN; confirm.
print('val_train_epoch|{},val_sub_loss={},sub_h_acc={}，sub_t_acc={}'.format(
    1, sub_loss,
    sub_hacc / (len(dev_set) * val_batch_size * 128),
    sub_tacc / (len(dev_set) * val_batch_size * 128)))
print('obj_loss={},obj_h_acc={}，obj_t_acc={}'.format(
    obj_loss,
    obj_hacc / (len(dev_set) * val_batch_size * 128 * len(rel2id)),
    obj_tacc / (len(dev_set) * val_batch_size * 128 * len(rel2id))))
print('val_epoch{}|total_loss={}'.format(1, train_loss / vstep))
precision, recall, f1_score = metric(val_data, exact_match=False, output_path='val_output.json')
print(f'precision={precision},recall={recall},f1_score={f1_score}')

In [None]:
# Show how many items dev_set holds.
len(dev_set)

In [None]:
# Show how many items dev_set holds (same check as the previous cell).
len(dev_set)

In [None]:
# Show how many entries val_data holds (the data passed to metric()).
len(val_data)

In [None]:
# Scratch check: prints the numbers 1 through 5, one per line.
for i in range(1, 6):
    print(i)

### Test

In [None]:
# Configuration for the test section.
BERT_MAX_LEN = 128  # token sequences are truncated / padded to this length
RANDOM_SEED = 2019  # seed constant; its only visible use (shuffling in data_generator) is commented out
model_path='./model.ckpt'  # file that train() saves net.state_dict() to


test数据

In [None]:
def to_tuple(sent):
    """Replace ``sent['triple_list']`` with a list of the same triples as tuples (in place)."""
    sent['triple_list'] = [tuple(item) for item in sent['triple_list']]
def load_data(test_path, rel_dict_path):
    """Load the test sentences and the relation vocabulary from JSON files.

    Args:
        test_path: path to a JSON list of sentence dicts, each with a
            ``'triple_list'`` entry.
        rel_dict_path: path to a JSON pair ``[id2rel, rel2id]``.

    Returns:
        ``(test_data, id2rel, rel2id, num_rels)`` where every triple in
        ``test_data`` has been converted to a tuple, ``id2rel`` has its keys
        cast to ``int`` (JSON object keys are strings), and ``num_rels`` is
        ``len(id2rel)``.
    """
    # Context managers close both files even if parsing fails.
    with open(test_path, encoding='utf-8') as test_file:
        test_data = json.load(test_file)
    with open(rel_dict_path, encoding='utf-8') as rel_file:
        id2rel, rel2id = json.load(rel_file)

    id2rel = {int(i): j for i, j in id2rel.items()}
    num_rels = len(id2rel)

    for sent in test_data:
        to_tuple(sent)

    print("test_data len:", len(test_data))

    return test_data, id2rel, rel2id, num_rels

def find_head_idx(source, target):
    """Return the first index at which ``target`` occurs as a contiguous run in ``source``, or -1."""
    width = len(target)
    candidates = (start for start in range(len(source)) if source[start:start + width] == target)
    return next(candidates, -1)
class data_generator:
    """Build per-sentence model inputs and tagging targets from annotated data.

    Only sentences in which at least one (subject, relation, object) triple
    can be located in the tokenized text are emitted.

    Args:
        data: list of dicts with a ``'text'`` string and a ``'triple_list'``
            of (subject, relation, object) entries.
        tokenizer: object providing ``tokenize(text)`` and a ``__call__`` that
            returns ``'input_ids'`` / ``'attention_mask'`` tensors.
        rel2id: mapping from relation name to relation index.
        num_rels: number of relation types (width of the object targets).
        maxlen: maximum number of whitespace-separated words kept per text.
    """

    def __init__(self, data, tokenizer, rel2id, num_rels, maxlen):
        self.data = data
        self.batch_size = len(self.data)
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        self.rel2id = rel2id
        self.num_rels = num_rels

    def __len__(self):
        return self.batch_size

    def generator(self):
        """Process every sentence once and return eight parallel lists.

        Returns:
            ``(tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,
            sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch)``.
            Per emitted sentence: token ids and attention mask, each of shape
            ``(1, BERT_MAX_LEN)`` and zero-padded; 0/1 vectors of length
            ``BERT_MAX_LEN`` marking subject heads / tails; ``[index]`` of the
            head / tail of one randomly chosen subject; and 0/1 matrices of
            shape ``(BERT_MAX_LEN, num_rels)`` marking object heads / tails of
            that chosen subject.
        """
        tokens_batch, segments_batch = [], []
        sub_heads_batch, sub_tails_batch = [], []
        sub_head_batch, sub_tail_batch = [], []
        obj_heads_batch, obj_tails_batch = [], []
        text_len = BERT_MAX_LEN

        for line in self.data:
            text = ' '.join(line['text'].split()[:self.maxlen])
            tokens = self.tokenizer.tokenize(text)[:BERT_MAX_LEN]

            # Map each subject span to the (obj_head, obj_tail, rel_id) entries
            # that belong to it; span tails are inclusive.
            s2ro_map = {}
            for triple in line['triple_list']:
                sub_tokens = self.tokenizer.tokenize(triple[0])
                obj_tokens = self.tokenizer.tokenize(triple[2])
                sub_head_idx = find_head_idx(tokens, sub_tokens)
                obj_head_idx = find_head_idx(tokens, obj_tokens)
                # find_head_idx reports "not found" as -1. Index 0 is a valid
                # position here because no special token is prepended, so the
                # guard must compare against -1, not 0.
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(sub_tokens) - 1)
                    s2ro_map.setdefault(sub, []).append((
                        obj_head_idx,
                        obj_head_idx + len(obj_tokens) - 1,
                        self.rel2id[triple[1]],
                    ))

            if not s2ro_map:
                # No triple could be located in this sentence: skip it.
                continue

            inputs = self.tokenizer(text, return_tensors='pt', add_special_tokens=False,
                                    truncation=True, padding=True, max_length=BERT_MAX_LEN)
            token_ids, segment_ids = inputs['input_ids'], inputs['attention_mask']
            # Right-pad ids and mask with zeros up to BERT_MAX_LEN. pad_seq is a
            # float tensor, so the concatenated results are float as well.
            pad_len = BERT_MAX_LEN - token_ids.shape[1]
            pad_seq = torch.zeros(1, pad_len)
            token_ids = torch.cat((token_ids, pad_seq), dim=1)
            segment_ids = torch.cat((segment_ids, pad_seq), dim=1)
            tokens_batch.append(token_ids)
            segments_batch.append(segment_ids)

            sub_heads, sub_tails = torch.zeros(text_len), torch.zeros(text_len)
            for s in s2ro_map:
                sub_heads[s[0]] = 1
                sub_tails[s[1]] = 1

            # Object targets are built for one randomly chosen subject only.
            sub_head, sub_tail = choice(list(s2ro_map.keys()))
            obj_heads = torch.zeros((text_len, self.num_rels))
            obj_tails = torch.zeros((text_len, self.num_rels))
            for ro in s2ro_map.get((sub_head, sub_tail), []):
                obj_heads[ro[0]][ro[2]] = 1
                obj_tails[ro[1]][ro[2]] = 1

            sub_heads_batch.append(sub_heads)
            sub_tails_batch.append(sub_tails)
            sub_head_batch.append([sub_head])
            sub_tail_batch.append([sub_tail])
            obj_heads_batch.append(obj_heads)
            obj_tails_batch.append(obj_tails)

        return tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch, sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch
# from changshi2 import data_generator

# Load the test split and relation vocabulary, then build the per-sentence
# model inputs and tagging targets.
test_data, id2rel, rel2id, num_rels = load_data(
    './test_triples.json',
    rel_dict_path='./rel2id.json',
)
tokenizer = BertTokenizer.from_pretrained(model_name)
maxlen = 100  # cap on whitespace-separated words kept per sentence
(
    test_tokens_batch,
    segments_batch,
    sub_heads_batch,
    sub_tails_batch,
    sub_head_batch,
    sub_tail_batch,
    obj_heads_batch,
    obj_tails_batch,
) = data_generator(test_data, tokenizer, rel2id, num_rels, maxlen).generator()

In [None]:
class data_generator:
    """Build per-sentence model inputs and tagging targets for evaluation.

    Every sentence is emitted, including those in which no triple can be
    located (their targets stay all-zero), so output position ``k`` always
    corresponds to ``data[k]``.

    Args:
        data: list of dicts with a ``'text'`` string and a ``'triple_list'``
            of (subject, relation, object) entries.
        tokenizer: object providing ``tokenize(text)`` and a ``__call__`` that
            returns ``'input_ids'`` / ``'attention_mask'`` tensors.
        rel2id: mapping from relation name to relation index.
        num_rels: number of relation types (width of the object targets).
        maxlen: maximum number of whitespace-separated words kept per text.
    """

    def __init__(self, data, tokenizer, rel2id, num_rels, maxlen):
        self.data = data
        self.batch_size = len(self.data)
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        self.rel2id = rel2id
        self.num_rels = num_rels

    def __len__(self):
        return self.batch_size

    def generator(self):
        """Process every sentence once and return eight parallel lists.

        Returns:
            ``(tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,
            sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch)``.
            Per sentence: token ids and attention mask, each of shape
            ``(1, BERT_MAX_LEN)`` and zero-padded; 0/1 vectors of length
            ``BERT_MAX_LEN`` marking subject heads / tails; and 0/1 matrices of
            shape ``(BERT_MAX_LEN, num_rels)`` marking the object heads / tails
            of all located subjects. ``sub_head_batch`` and ``sub_tail_batch``
            are never filled and are returned empty; they only keep the
            8-slot return shape.
        """
        tokens_batch, segments_batch = [], []
        sub_heads_batch, sub_tails_batch = [], []
        sub_head_batch, sub_tail_batch = [], []
        obj_heads_batch, obj_tails_batch = [], []
        text_len = BERT_MAX_LEN

        for line in self.data:
            text = ' '.join(line['text'].split()[:self.maxlen])
            tokens = self.tokenizer.tokenize(text)[:BERT_MAX_LEN]

            # Map each subject span to the (obj_head, obj_tail, rel_id) entries
            # that belong to it; span tails are inclusive.
            s2ro_map = {}
            for triple in line['triple_list']:
                sub_tokens = self.tokenizer.tokenize(triple[0])
                obj_tokens = self.tokenizer.tokenize(triple[2])
                sub_head_idx = find_head_idx(tokens, sub_tokens)
                obj_head_idx = find_head_idx(tokens, obj_tokens)
                # find_head_idx reports "not found" as -1. Index 0 is a valid
                # position here because no special token is prepended, so the
                # guard must compare against -1, not 0.
                if sub_head_idx != -1 and obj_head_idx != -1:
                    sub = (sub_head_idx, sub_head_idx + len(sub_tokens) - 1)
                    s2ro_map.setdefault(sub, []).append((
                        obj_head_idx,
                        obj_head_idx + len(obj_tokens) - 1,
                        self.rel2id[triple[1]],
                    ))

            inputs = self.tokenizer(text, return_tensors='pt', add_special_tokens=False,
                                    truncation=True, padding=True, max_length=BERT_MAX_LEN)
            token_ids, segment_ids = inputs['input_ids'], inputs['attention_mask']
            # Right-pad ids and mask with zeros up to BERT_MAX_LEN. pad_seq is a
            # float tensor, so the concatenated results are float as well.
            pad_len = BERT_MAX_LEN - token_ids.shape[1]
            pad_seq = torch.zeros(1, pad_len)
            token_ids = torch.cat((token_ids, pad_seq), dim=1)
            segment_ids = torch.cat((segment_ids, pad_seq), dim=1)
            tokens_batch.append(token_ids)
            segments_batch.append(segment_ids)

            sub_heads, sub_tails = torch.zeros(text_len), torch.zeros(text_len)
            obj_heads = torch.zeros((text_len, self.num_rels))
            obj_tails = torch.zeros((text_len, self.num_rels))
            # Mark every located subject and the objects of all subjects.
            for (sub_head, sub_tail), ro_list in s2ro_map.items():
                sub_heads[sub_head] = 1
                sub_tails[sub_tail] = 1
                for ro in ro_list:
                    obj_heads[ro[0]][ro[2]] = 1
                    obj_tails[ro[1]][ro[2]] = 1

            sub_heads_batch.append(sub_heads)
            sub_tails_batch.append(sub_tails)
            obj_heads_batch.append(obj_heads)
            obj_tails_batch.append(obj_tails)

        return tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch, sub_head_batch, sub_tail_batch, obj_heads_batch, obj_tails_batch

In [None]:
# Load the test split and relation vocabulary, then build the per-sentence
# model inputs and tagging targets with the data_generator defined above.
test_data, id2rel, rel2id, num_rels = load_data(
    './test_triples.json',
    rel_dict_path='./rel2id.json',
)
tokenizer = BertTokenizer.from_pretrained(model_name)
maxlen = 100  # cap on whitespace-separated words kept per sentence
(
    test_tokens_batch,
    segments_batch,
    sub_heads_batch,
    sub_tails_batch,
    sub_head_batch,
    sub_tail_batch,
    obj_heads_batch,
    obj_tails_batch,
) = data_generator(test_data, tokenizer, rel2id, num_rels, maxlen).generator()

In [None]:

# Concatenate the per-sentence tensors produced by the generator into batched
# tensors. NOTE: this cell rebinds segments_batch / sub_*_batch / obj_*_batch
# from lists to tensors, so re-run the generator cell before re-running it.
tokens_batch = torch.cat(list(test_tokens_batch)).int()
print(tokens_batch.shape)
# presumably (num_sentences, 128) -- each list item is one padded (1, 128) row
segments_batch = torch.cat(list(segments_batch)).int()
sub_heads_batch = (
    torch.cat(list(sub_heads_batch))
    .reshape(tokens_batch.shape[0], -1)
    .float()
)
sub_tails_batch = (
    torch.cat(list(sub_tails_batch))
    .reshape(tokens_batch.shape[0], -1)
    .float()
)
# The trailing -1 axis is inferred by reshape (one column per relation).
obj_heads_batch = (
    torch.cat(list(obj_heads_batch))
    .reshape(tokens_batch.shape[0], 128, -1)
    .float()
)
obj_tails_batch = (
    torch.cat(list(obj_tails_batch))
    .reshape(tokens_batch.shape[0], 128, -1)
    .float()
)
print(type(segments_batch))
print(segments_batch.shape)
print(obj_heads_batch.shape)
print(type(obj_heads_batch))

In [None]:
from torch.utils.data import DataLoader,Dataset,TensorDataset
# Run on the GPU when one is available, otherwise fall back to the CPU.
# ('gpu' is not a valid torch device string; the CPU device is spelled 'cpu'.)
device='cuda' if torch.cuda.is_available() else 'cpu'
model=E2EModel().to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
# This section only runs inference, so switch layers such as dropout to
# evaluation behaviour (assumes E2EModel is a torch.nn.Module).
model.eval()
optim=torch.optim.Adam(model.parameters(),lr=1e-5)
# Despite its name, train_set wraps the *test* tensors built in the previous cell.
train_set=TensorDataset(tokens_batch, segments_batch, sub_heads_batch, sub_tails_batch,obj_heads_batch,obj_tails_batch)
def partial_match(pred_set, gold_set):
    """Reduce triples to the first word of their subject and object.

    Each ``(subject, relation, object)`` triple in both inputs is mapped to
    ``(first word of subject, relation, first word of object)``, where words
    are separated by a single space.

    Returns:
        ``(pred, gold)`` as two sets of reduced triples.
    """
    def _first_word(text):
        # str.split(' ') always yields at least one item, so [0] is safe.
        return text.split(' ')[0]

    def _reduce(triples):
        return {(_first_word(t[0]), t[1], _first_word(t[2])) for t in triples}

    return _reduce(pred_set), _reduce(gold_set)

### batch_size

In [None]:
# One sentence per batch, in dataset order (no shuffling).
batch_size = 1
test_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=False)

### test测试版

In [None]:
def _decode_span(tokens, head, tail):
    """Rebuild the text of the WordPiece tokens ``tokens[head..tail]`` (inclusive).

    Pieces starting with ``##`` are glued onto the preceding piece; all other
    pieces are separated by a single space. Any literal ``[unused1]`` marker
    is replaced by a space, as the previous decoding did.
    """
    words = []
    for piece in tokens[head:tail + 1]:
        if piece.startswith('##'):
            piece = piece[2:]
            if words:
                # Continuation piece: it belongs to the word before it.
                words[-1] += piece
                continue
        words.append(piece)
    return ' '.join(' '.join(words).split('[unused1]'))


def extrac_triple(sub_heads_logits, sub_tails_logits, obj_heads_logits, obj_tails_logits,
                  input_ids):
    """Extract the (subject, relation, object) triples of one sentence.

    Args:
        sub_heads_logits / sub_tails_logits: per-token subject head / tail
            scores; positions above 0.5 count as predicted.
        obj_heads_logits / obj_tails_logits: 2-D arrays indexed by
            (token, relation); entries above 0.5 count as predicted.
        input_ids: token ids of the sentence, converted back to tokens with
            the notebook-level ``tokenizer``.

    Returns:
        A de-duplicated list of ``(subject, relation, object)`` string tuples;
        empty when no subject is found. Every decoded subject is combined
        with every decoded (relation, object) pair.
    """
    h_bar = 0.5
    t_bar = 0.5
    sub_heads_logits = np.array(sub_heads_logits)
    sub_tails_logits = np.array(sub_tails_logits)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    sub_heads = np.where(sub_heads_logits > h_bar)[0]
    sub_tails = np.where(sub_tails_logits > t_bar)[0]

    subjects = []
    for sub_head in sub_heads:
        # Pair each head with the nearest tail at or after it.
        later_tails = sub_tails[sub_tails >= sub_head]
        if len(later_tails) > 0:
            # Span tails are inclusive (data_generator labels the tail as
            # head + len - 1), so the span is decoded through later_tails[0].
            subjects.append(_decode_span(tokens, sub_head, later_tails[0]))
    if not subjects:
        return []

    # The object predictions do not depend on the subject, so decode them once.
    obj_heads = np.where(obj_heads_logits > h_bar)
    obj_tails = np.where(obj_tails_logits > t_bar)
    rel_objects = []
    for obj_head, rel_head in zip(*obj_heads):
        for obj_tail, rel_tail in zip(*obj_tails):
            if obj_head <= obj_tail and rel_head == rel_tail:
                rel_objects.append((id2rel[rel_head], _decode_span(tokens, obj_head, obj_tail)))
                break  # only the first matching tail closes this head

    return list({(sub, rel, obj) for sub in subjects for rel, obj in rel_objects})


val_batch_size = 1
dict_t = {}
vstep = 0
train_loss, sub_hacc, sub_tacc, obj_hacc, obj_tacc = 0, 0, 0, 0, 0
with torch.no_grad():
    for data in test_loader:
        inputs_id, att_mask, sub_htar, sub_ttar, obj_htar, obj_ttar = data
        text, mask = inputs_id.to(device), att_mask.to(device)
        sub_htar, sub_ttar = sub_htar.to(device), sub_ttar.to(device)
        obj_htar, obj_ttar = obj_htar.to(device), obj_ttar.to(device)
        sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = model(text, mask, sub_htar, sub_ttar,
                                                                        obj_htar, obj_ttar)
        # Threshold the scores at 0.5 to obtain 0/1 predictions.
        sub_h = sub.ge(0.5).int()
        sub_t = tail.ge(0.5).int()
        obj_h = obj_h.ge(0.5).int()
        obj_t = obj_t.ge(0.5).int()
        sub_hacc += (sub_h == sub_htar).sum().item()
        sub_tacc += (sub_t == sub_ttar).sum().item()
        obj_hacc += (obj_h == obj_htar).sum().item()
        obj_tacc += (obj_t == obj_ttar).sum().item()
        train_loss += total_loss.item()
        vstep += 1

        # Subject scores stay raw; object arrays are the thresholded 0/1 values.
        batch_sub_heads = sub.cpu().numpy()
        batch_sub_tails = tail.cpu().numpy()
        batch_obj_heads = obj_h.cpu().numpy()
        batch_obj_tails = obj_t.cpu().numpy()
        inputs_id = inputs_id.cpu()

        # Iterate over the samples actually present in this batch, without
        # overwriting the batch arrays, and key each result by its running
        # sentence index so metric() can look it up by position.
        for i in range(inputs_id.shape[0]):
            triple_list = extrac_triple(batch_sub_heads[i], batch_sub_tails[i],
                                        batch_obj_heads[i], batch_obj_tails[i],
                                        inputs_id[i])
            dict_t[len(dict_t)] = triple_list

# Write the predictions once, after the loop; the context manager closes the
# file (previously it was reopened and truncated for every batch).
with open('triple_list_test.json', mode='w', encoding='utf_8') as f:
    json.dump(dict_t, f)
print('-' * 50)
if vstep:
    n_samples = len(dict_t)
    # sub_loss / obj_loss printed here are the values of the last batch only.
    print('val_train_epoch|{},val_sub_loss={},sub_h_acc={}，sub_t_acc={}'.format(
        1, sub_loss, sub_hacc / (n_samples * 128), sub_tacc / (n_samples * 128)))
    print('obj_loss={},obj_h_acc={}，obj_t_acc={}'.format(
        obj_loss,
        obj_hacc / (n_samples * 128 * len(rel2id)),
        obj_tacc / (n_samples * 128 * len(rel2id))))
    print('val_epoch{}|total_loss={}'.format(1, train_loss / vstep))


In [None]:
def metric(eval_data, exact_match=False, output_path=None, pred_path='triple_list_test.json'):
    """Score saved triple predictions against the gold triples.

    Args:
        eval_data: sequence of dicts with ``'text'`` and ``'triple_list'``;
            entry ``i`` is compared with the prediction stored under key
            ``str(i)``. Entries without a stored prediction are skipped.
        exact_match: if False, both sides are first reduced with
            ``partial_match``; if True, triples are compared as they are.
        output_path: optional file that receives one indented JSON report per
            scored sentence (gold, predicted, spurious and missing triples).
        pred_path: JSON file mapping sentence index to a list of predicted
            triples. Defaults to the file written by the test cell.

    Returns:
        ``(precision, recall, f1_score)``. The counters start at 1e-10 so the
        ratios never divide by zero.
    """
    orders = ['subject', 'relation', 'object']
    correct_num, predict_num, gold_num = 1e-10, 1e-10, 1e-10
    with open(pred_path, encoding='utf_8') as pred_file:
        predictions = json.load(pred_file)

    report = open(output_path, 'w', encoding='utf_8') if output_path else None
    try:
        for i, line in enumerate(eval_data):
            pre = predictions.get(str(i))
            if pre is None:
                # No prediction was stored for this position.
                continue
            Pred_triples = set(tuple(l) for l in pre)
            Gold_triples = set(tuple(l) for l in line['triple_list'])
            if exact_match:
                Pred_triples_eval, Gold_triples_eval = Pred_triples, Gold_triples
            else:
                Pred_triples_eval, Gold_triples_eval = partial_match(Pred_triples, Gold_triples)

            correct_num += len(Pred_triples_eval & Gold_triples_eval)
            predict_num += len(Pred_triples_eval)
            gold_num += len(Gold_triples_eval)

            if report is not None:
                result = json.dumps({
                    'text': line['text'],
                    'triple_list_gold': [
                        dict(zip(orders, triple)) for triple in Gold_triples
                    ],
                    'triple_list_pred': [
                        dict(zip(orders, triple)) for triple in Pred_triples
                    ],
                    'new': [
                        dict(zip(orders, triple)) for triple in Pred_triples - Gold_triples
                    ],
                    'lack': [
                        dict(zip(orders, triple)) for triple in Gold_triples - Pred_triples
                    ]
                }, ensure_ascii=False, indent=4)
                report.write(result + '\n')
    finally:
        # Always flush and release the report file, even if scoring fails.
        if report is not None:
            report.close()

    precision = correct_num / predict_num
    recall = correct_num / gold_num
    f1_score = 2 * precision * recall / (precision + recall)

    print(f'correct_num:{correct_num}\npredict_num:{predict_num}\ngold_num:{gold_num}')
    return precision, recall, f1_score

# Score the saved test predictions with partial matching and write the
# per-sentence report to output.json.
precision, recall, f1_score = metric(
    test_data,
    exact_match=False,
    output_path='output.json',
)
print(precision, recall, f1_score)

In [None]:
# Same scoring run as above; the per-sentence report goes to output_test.json.
precision, recall, f1_score = metric(
    test_data,
    exact_match=False,
    output_path='output_test.json',
)
print(f'precision={precision},recall={recall},f1_score={f1_score}')

### test_pro

In [None]:
def _decode_span(tokens, head, tail):
    """Rebuild the text of the WordPiece tokens ``tokens[head..tail]`` (inclusive).

    Pieces starting with ``##`` are glued onto the preceding piece; all other
    pieces are separated by a single space. Any literal ``[unused1]`` marker
    is replaced by a space, as the previous decoding did.
    """
    words = []
    for piece in tokens[head:tail + 1]:
        if piece.startswith('##'):
            piece = piece[2:]
            if words:
                # Continuation piece: it belongs to the word before it.
                words[-1] += piece
                continue
        words.append(piece)
    return ' '.join(' '.join(words).split('[unused1]'))


def extrac_triple(sub_heads_logits,sub_tails_logits,obj_heads_logits,obj_tails_logits,input_ids):
    """Extract the (subject, relation, object) triples of one sentence.

    Args:
        sub_heads_logits / sub_tails_logits: per-token subject head / tail
            scores; positions above 0.5 count as predicted.
        obj_heads_logits / obj_tails_logits: 2-D arrays indexed by
            (token, relation); entries above 0.5 count as predicted.
        input_ids: token ids of the sentence, converted back to tokens with
            the notebook-level ``tokenizer``.

    Returns:
        A de-duplicated list of ``(subject, relation, object)`` string tuples;
        empty when no subject is found. Every decoded subject is combined
        with every decoded (relation, object) pair.
    """
    h_bar = 0.5
    t_bar = 0.5
    sub_heads_logits = np.array(sub_heads_logits)
    sub_tails_logits = np.array(sub_tails_logits)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    sub_heads = np.where(sub_heads_logits > h_bar)[0]
    sub_tails = np.where(sub_tails_logits > t_bar)[0]

    subjects = []
    for sub_head in sub_heads:
        # Pair each head with the nearest tail at or after it.
        later_tails = sub_tails[sub_tails >= sub_head]
        if len(later_tails) > 0:
            # Span tails are inclusive (data_generator labels the tail as
            # head + len - 1), so the span is decoded through later_tails[0].
            subjects.append(_decode_span(tokens, sub_head, later_tails[0]))
    if not subjects:
        return []

    # The object predictions do not depend on the subject, so decode them once.
    obj_heads = np.where(obj_heads_logits > h_bar)
    obj_tails = np.where(obj_tails_logits > t_bar)
    rel_objects = []
    for obj_head, rel_head in zip(*obj_heads):
        for obj_tail, rel_tail in zip(*obj_tails):
            if obj_head <= obj_tail and rel_head == rel_tail:
                rel_objects.append((id2rel[rel_head], _decode_span(tokens, obj_head, obj_tail)))
                break  # only the first matching tail closes this head

    return list({(sub, rel, obj) for sub in subjects for rel, obj in rel_objects})


dict_t = {}
train_loss, sub_hacc, sub_tacc, obj_hacc, obj_tacc = 0, 0, 0, 0, 0
with torch.no_grad():
    for data in test_loader:
        inputs_id, att_mask, sub_htar, sub_ttar, obj_htar, obj_ttar = data
        text, mask = inputs_id.to(device), att_mask.to(device)
        sub_htar, sub_ttar = sub_htar.to(device), sub_ttar.to(device)
        obj_htar, obj_ttar = obj_htar.to(device), obj_ttar.to(device)
        sub, tail, obj_h, obj_t, sub_loss, obj_loss, total_loss = model(text, mask, sub_htar, sub_ttar,
                                                                        obj_htar, obj_ttar)
        # Threshold the scores at 0.5 to obtain 0/1 predictions.
        sub_h = sub.ge(0.5).int()
        sub_t = tail.ge(0.5).int()
        obj_h = obj_h.ge(0.5).int()
        obj_t = obj_t.ge(0.5).int()
        sub_hacc += (sub_h == sub_htar).sum().item()
        sub_tacc += (sub_t == sub_ttar).sum().item()
        obj_hacc += (obj_h == obj_htar).sum().item()
        obj_tacc += (obj_t == obj_ttar).sum().item()
        train_loss += total_loss.item()

        # Subject scores stay raw; object arrays are the thresholded 0/1 values.
        batch_sub_heads = sub.cpu().numpy()
        batch_sub_tails = tail.cpu().numpy()
        batch_obj_heads = obj_h.cpu().numpy()
        batch_obj_tails = obj_t.cpu().numpy()
        inputs_id = inputs_id.cpu()

        # Iterate over the samples actually present in this batch, without
        # overwriting the batch arrays, and key each result by its running
        # sentence index.
        for i in range(inputs_id.shape[0]):
            triple_list = extrac_triple(batch_sub_heads[i], batch_sub_tails[i],
                                        batch_obj_heads[i], batch_obj_tails[i],
                                        inputs_id[i])
            dict_t[len(dict_t)] = triple_list

# Write the predictions once, after the loop; the context manager closes the
# file (previously it was reopened and truncated for every batch).
with open('triple_list.json', mode='w', encoding='utf_8') as f:
    json.dump(dict_t, f)


In [None]:
# metric() opens output_path for writing, so passing 'triple_list.json' here
# would overwrite the predictions saved by the cell above with report text and
# break the json.load() in the next cell. Write the report to its own file.
# NOTE: metric() scores the predictions stored in 'triple_list_test.json'.
precision, recall, f1_score=metric(test_data,exact_match=False, output_path='output_test_pro.json')
print(precision, recall, f1_score)

In [None]:
# Read the saved predictions back; the context manager closes the file handle.
with open('triple_list.json', encoding='utf_8') as fh:
    f = json.load(fh)

In [None]:
def load_data(test_path, rel_dict_path):
    """Load the annotated sentences and the relation vocabulary.

    This redefinition shadows the earlier ``load_data`` in the notebook, so
    it must return the same 4-tuple; without a return statement every later
    ``a, b, c, d = load_data(...)`` call would fail on ``None``.

    Args:
        test_path: JSON file holding a list of sentence dicts, each with a
            ``'triple_list'`` entry.
        rel_dict_path: JSON file holding the pair ``[id2rel, rel2id]``.

    Returns:
        ``(test_data, id2rel, rel2id, num_rels)`` where every triple in
        ``test_data`` has been converted to a tuple, ``id2rel`` is keyed by
        int and ``num_rels`` is ``len(id2rel)``.
    """
    # Context managers close both files instead of leaking the handles.
    with open(test_path, encoding='utf-8') as data_file:
        test_data = json.load(data_file)
    with open(rel_dict_path, encoding='utf-8') as rel_file:
        id2rel, rel2id = json.load(rel_file)

    # JSON object keys are always strings; restore the integer relation ids.
    id2rel = {int(i): j for i, j in id2rel.items()}
    num_rels = len(id2rel)

    for sent in test_data:
        to_tuple(sent)

    print("test_data len:", len(test_data))

    return test_data, id2rel, rel2id, num_rels

In [None]:
for i,line in enumerate(test_data):
    if i<5:
      # Entries of test_data were already parsed by json.load() in load_data;
      # json.loads() only accepts str/bytes and raises TypeError on them, so
      # print each entry directly.
      print(line)