# Data Processing


**Agend：**

1. Covert row data to be JSON/将原始数据转化为json格式

2. Load the pretrained word vectors/预加载词向量

3. Replation Encoding/关系编码

4. Training Sample Generation 样本生成


In [1]:
import os
import torch
import numpy as np
import json
from tqdm import tqdm
import re
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader

## 1 Covert row data to be JSON
input /输入：
```
8001	"The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.
```
output /输出:
```
{"id": "8001", "relation": "Message-Topic(e1,e2)", "head": "audits", "tail": "waste", "subj_start": 3, "subj_end": 3, "obj_start": 6, "obj_end": 6, "sentence": ["The", "most", "common", "audits", "were", "about", "waste", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
```

　

In [2]:
sentence = "The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."

In [3]:
e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]

In [4]:
sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)

In [5]:
sentence

'The most common  <e1> audits </e1>  were about  <e2> waste </e2>  and recycling.'

In [6]:
sentence = word_tokenize(sentence)

In [7]:
# sentence

In [8]:
class processor(object):
    def __init__(self):
        pass
    '''text cleansing'''
    def search_entity(self,sentence):
        # extract the entities
        e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
        e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
        
        sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
        sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
        
        sentence = word_tokenize(sentence)
        sentence = ' '.join(sentence)
        sentence = sentence.replace('< e1 >', '<e1>')
        sentence = sentence.replace('< e2 >', '<e2>')
        sentence = sentence.replace('< /e1 >', '</e1>')
        sentence = sentence.replace('< /e2 >', '</e2>')
        sentence = sentence.split()

        assert '<e1>' in sentence
        assert '<e2>' in sentence
        assert '</e1>' in sentence
        assert '</e2>' in sentence
        
        ## two entiti location index finding
        subj_start = subj_end = obj_start = obj_end = 0
        
        pure_sentence = []
        for i, word in enumerate(sentence):
            if '<e1>' == word:
                subj_start = len(pure_sentence)
                continue
            if '</e1>' == word:
                subj_end = len(pure_sentence) - 1
                continue
            if '<e2>' == word:
                obj_start = len(pure_sentence)
                continue
            if '</e2>' == word:
                obj_end = len(pure_sentence) - 1
                continue
            pure_sentence.append(word)
        return e1, e2, subj_start, subj_end, obj_start, obj_end, pure_sentence
    
    '''covert to be json format'''
    
    def convert(self,path_src, path_des):
        with open(path_src, 'r', encoding='utf-8') as fr:
            data = fr.readlines()
        with open(path_des, 'w', encoding='utf-8') as fw:
            for i in tqdm(range(0, len(data), 4)):
                id_s, sentence = data[i].strip().split('\t')
                #每三行为一整个
                sentence = sentence[1:-1]
                e1, e2, subj_start, subj_end, obj_start, obj_end, sentence = self.search_entity(sentence)
                meta = dict(
                    id=id_s,
                    relation=data[i+1].strip(),
                    head=e1,
                    tail=e2,
                    subj_start=subj_start,
                    subj_end=subj_end,
                    obj_start=obj_start,
                    obj_end=obj_end,
                    sentence=sentence,
                    comment=data[i+2].strip()[8:]
                )
                json.dump(meta, fw, ensure_ascii=False)
                fw.write('\n')

In [9]:
data_processor=processor()
s="The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
print(data_processor.search_entity(s))

('audits', 'waste', 3, 3, 6, 6, ['The', 'most', 'common', 'audits', 'were', 'about', 'waste', 'and', 'recycling', '.'])


In [10]:
#train data
data_processor.convert("./data/TRAIN_FILE.TXT",'./data/train.json')

100%|████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:02<00:00, 2826.53it/s]


In [11]:
#train data
data_processor.convert("./data/TEST_FILE_FULL.TXT",'./data/test.json')

100%|████████████████████████████████████████████████████████████████████████████| 2717/2717 [00:01<00:00, 2697.28it/s]


# 2 Load the pretrained word vectors
### 1)
input /输入：
```
word vectors file path /词向量文件地址
```
output /输出:
```
{'PAD': 0, 'sigarms': 1, 'katuna': 2, 'aqm': 3, '1.3775': 4, 'corythosaurus': 5, 'chanty': 6, 'kronik': 7, 'rolonda': 8, 'zsombor': 9, 'sandberger': 10, '*UNKNOWN*': 11}

[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 2.6818e-01,  1.4346e-01, -2.7877e-01,  1.6257e-02,  1.1384e-01,
          6.9923e-01, -5.1332e-01, -4.7368e-01, -3.3075e-01, -1.3834e-01,
          2.7020e-01,  3.0938e-01, -4.5012e-01, -4.1270e-01, -9.9320e-02,
          3.8085e-02,  2.9749e-02,  1.0076e-01, -2.5058e-01, -5.1818e-01,
          3.4558e-01,  4.4922e-01,  4.8791e-01, -8.0866e-02, -1.0121e-01,
         -1.3777e+00, -1.0866e-01, -2.3201e-01,  1.2839e-02, -4.6508e-01,
          3.8463e+00,  3.1362e-01,  1.3643e-01, -5.2244e-01,  3.3020e-01,
          3.3707e-01, -3.5601e-01,  3.2431e-01,  1.2041e-01,  3.5120e-01,
         -6.9043e-02,  3.6885e-01,  2.5168e-01, -2.4517e-01,  2.5381e-01,
          1.3670e-01, -3.1178e-01, -6.3210e-01, -2.5028e-01, -3.8097e-01],
        [ 3.3042e-01,  2.4995e-01, -6.0874e-01,  1.0923e-01,  3.6372e-02,
          1.5100e-01, -5.5083e-01, -7.4239e-02, -9.2307e-02, -3.2821e-01,
          9.5980e-02, -8.2269e-01, -3.6717e-01, -6.7009e-01,  4.2909e-01,
          1.6496e-02, -2.3573e-01,  1.2864e-01, -1.0953e+00,  4.3334e-01,
          5.7067e-01, -1.0360e-01,  2.0422e-01,  7.8308e-02, -4.2795e-01,
         -1.7984e+00, -2.7865e-01,  1.1954e-01, -1.2689e-01,  3.1744e-02,
          3.8631e+00, -1.7786e-01, -8.2434e-02, -6.2698e-01,  2.6497e-01,
         -5.7185e-02, -7.3521e-02,  4.6103e-01,  3.0862e-01,  1.2498e-01,
         -4.8609e-01, -8.0272e-03,  3.1184e-02, -3.6576e-01, -4.2699e-01,
          4.2164e-01, -1.1666e-01, -5.0703e-01, -2.7273e-02, -5.3285e-01],
        [ 2.1705e-01,  4.6515e-01, -4.6757e-01,  1.0082e-01,  1.0135e+00,
          7.4845e-01, -5.3104e-01, -2.6256e-01,  1.6812e-01,  1.3182e-01,
         -2.4909e-01, -4.4185e-01, -2.1739e-01,  5.1004e-01,  1.3448e-01,
         -4.3141e-01, -3.1230e-02,  2.0674e-01, -7.8138e-01, -2.0148e-01,
         -9.7401e-02,  1.6088e-01, -6.1836e-01, -1.8504e-01, -1.2461e-01,
         -2.2526e+00, -2.2321e-01,  5.0430e-01,  3.2257e-01,  1.5313e-01,
          3.9636e+00, -7.1365e-01, -6.7012e-01,  2.8388e-01,  2.1738e-01,
          1.4433e-01,  2.5926e-01,  2.3434e-01,  4.2740e-01, -4.4451e-01,
          1.3813e-01,  3.6973e-01, -6.4289e-01,  2.4142e-02, -3.9315e-02,
         -2.6037e-01,  1.2017e-01, -4.3782e-02,  4.1013e-01,  1.7960e-01],
        [ 2.5769e-01,  4.5629e-01, -7.6974e-01, -3.7679e-01,  5.9272e-01,
         -6.3527e-02,  2.0545e-01, -5.7385e-01, -2.9009e-01, -1.3662e-01,
          3.2728e-01,  1.4719e+00, -7.3681e-01, -1.2036e-01,  7.1354e-01,
         -4.6098e-01,  6.5248e-01,  4.8887e-01, -5.1558e-01,  3.9951e-02,
         -3.4307e-01, -1.4087e-02,  8.6488e-01,  3.5460e-01,  7.9990e-01,
         -1.4995e+00, -1.8153e+00,  4.1128e-01,  2.3921e-01, -4.3139e-01,
          3.6623e+00, -7.9834e-01, -5.4538e-01,  1.6943e-01, -8.2017e-01,
         -3.4610e-01,  6.9495e-01, -1.2256e+00, -1.7992e-01, -5.7474e-02,
          3.0498e-02, -3.9543e-01, -3.8515e-01, -1.0002e+00,  8.7599e-02,
         -3.1009e-01, -3.4677e-01, -3.1438e-01,  7.5004e-01,  9.7065e-01],
        [ 2.3727e-01,  4.0478e-01, -2.0547e-01,  5.8805e-01,  6.5533e-01,
          3.2867e-01, -8.1964e-01, -2.3236e-01,  2.7428e-01,  2.4265e-01,
          5.4992e-02,  1.6296e-01, -1.2555e+00, -8.6437e-02,  4.4536e-01,
          9.6561e-02, -1.6519e-01,  5.8378e-02, -3.8598e-01,  8.6977e-02,
          3.3869e-03,  5.5095e-01, -7.7697e-01, -6.2096e-01,  9.2948e-02,
         -2.5685e+00, -6.7739e-01,  1.0151e-01, -4.8643e-01, -5.7805e-02,
          3.1859e+00, -1.7554e-02, -1.6138e-01,  5.5486e-02, -2.5885e-01,
         -3.3938e-01, -1.9928e-01,  2.6049e-01,  1.0478e-01, -5.5934e-01,
         -1.2342e-01,  6.5961e-01, -5.1802e-01, -8.2995e-01, -8.2739e-02,
          2.8155e-01, -4.2300e-01, -2.7378e-01, -7.9010e-03, -3.0231e-02],
        [ 4.4130e-01, -6.6569e-01,  5.3693e-01, -2.8451e-01, -4.9432e-01,
          3.7877e-01,  2.9462e-01, -4.2007e-01, -3.3072e-01, -1.4959e-01,
          1.7734e-02, -6.9029e-01, -4.1045e-01, -1.7590e-01,  3.5267e-01,
          4.5984e-02, -5.5600e-01,  7.3799e-01,  1.3054e-01,  4.5681e-01,
         -1.0072e-01,  9.9244e-01,  3.2525e-01, -4.8654e-01, -9.0020e-01,
         -5.8867e-01,  3.7007e-01,  7.8460e-01,  2.5547e-01,  9.4015e-01,
          5.3539e-01,  3.4617e-01,  1.6523e-01,  1.3020e-01, -6.2095e-01,
         -1.2875e-01, -7.9159e-01, -7.6295e-01,  5.2197e-01, -7.5438e-02,
         -3.8554e-01,  5.8315e-01, -7.1882e-02,  5.6333e-02, -2.8585e-01,
         -3.5239e-01,  1.4035e-01,  8.2411e-01, -6.6001e-01,  1.9785e-01]]

```

In [12]:
# Load the word embeeding
class WordEmbeddingLoader(object):
    """
    A loader for pre-trained word embedding
    """

    def __init__(self, config):
        self.path_word = config.embedding_path  # path of pre-trained word embedding
        self.word_dim = config.word_dim  # dimension of word embedding
    
    def load_embedding(self):
        word2id = dict()  # word to wordID
        word_vec = list()  # wordID to word embedding

        word2id['PAD'] = len(word2id)  # PAD character
        #
        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                if len(line) != self.word_dim + 1:
                    continue
                word2id[line[0]] = len(word2id)
                word_vec.append(np.asarray(line[1:], dtype=np.float32))
        if("*UNKNOWN*" not in word2id):
            word2id['*UNKNOWN*'] = len(word2id)
            unk_emb= np.random.uniform(-1, 1,self.word_dim)
            word_vec.append(unk_emb)
        pad_emb = np.zeros([1, self.word_dim], dtype=np.float32)  # <pad> is initialize as zero
        word_vec = np.concatenate((pad_emb, word_vec), axis=0)
        word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
        word_vec = torch.from_numpy(word_vec)
        return word2id, word_vec

In [13]:
from collections import namedtuple
conf = namedtuple('conf',['embedding_path','word_dim'])
config = conf("embedding/glove.6B.50d.txt",50)

In [14]:
config

conf(embedding_path='embedding/glove.6B.50d.txt', word_dim=50)

In [15]:
emd_loader=WordEmbeddingLoader(config)
word2id, word_vec=emd_loader.load_embedding()

# 3 Relation Encoding
input /输入：
```
relation encoding file path /关系编码文件地址
```
output /输出:
```
rel2id:

{'Other': 0, 'Cause-Effect(e1,e2)': 1, 'Cause-Effect(e2,e1)': 2, 'Component-Whole(e1,e2)': 3, 'Component-Whole(e2,e1)': 4, 'Content-Container(e1,e2)': 5, 'Content-Container(e2,e1)': 6, 'Entity-Destination(e1,e2)': 7, 'Entity-Destination(e2,e1)': 8, 'Entity-Origin(e1,e2)': 9, 'Entity-Origin(e2,e1)': 10, 'Instrument-Agency(e1,e2)': 11, 'Instrument-Agency(e2,e1)': 12, 'Member-Collection(e1,e2)': 13, 'Member-Collection(e2,e1)': 14, 'Message-Topic(e1,e2)': 15, 'Message-Topic(e2,e1)': 16, 'Product-Producer(e1,e2)': 17, 'Product-Producer(e2,e1)': 18}

id2rel:

{0: 'Other', 1: 'Cause-Effect(e1,e2)', 2: 'Cause-Effect(e2,e1)', 3: 'Component-Whole(e1,e2)', 4: 'Component-Whole(e2,e1)', 5: 'Content-Container(e1,e2)', 6: 'Content-Container(e2,e1)', 7: 'Entity-Destination(e1,e2)', 8: 'Entity-Destination(e2,e1)', 9: 'Entity-Origin(e1,e2)', 10: 'Entity-Origin(e2,e1)', 11: 'Instrument-Agency(e1,e2)', 12: 'Instrument-Agency(e2,e1)', 13: 'Member-Collection(e1,e2)', 14: 'Member-Collection(e2,e1)', 15: 'Message-Topic(e1,e2)', 16: 'Message-Topic(e2,e1)', 17: 'Product-Producer(e1,e2)', 18: 'Product-Producer(e2,e1)'}

class_num:
19
```

In [16]:
class RelationLoader(object):
    def __init__(self, config):
        self.data_dir = config.data_dir

    def __load_relation(self):
        relation_file = os.path.join(self.data_dir, 'relation2id.txt')
        rel2id = {}
        id2rel = {}
        with open(relation_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                relation, id_s = line.strip().split()
                id_d = int(id_s)
                rel2id[relation] = id_d
                id2rel[id_d] = relation
        return rel2id, id2rel, len(rel2id)

    def get_relation(self):
        return self.__load_relation()

In [17]:
# ! head data/relation2id.txt
from collections import namedtuple
conf = namedtuple('conf',['data_dir'])
config = conf("data")
rel_loader=RelationLoader(config)
rel2id, id2rel, class_num=rel_loader.get_relation()
# print(rel_loader.get_relation()[2])

# 4 training sample generation

### 1 defind the dataset class

In [18]:
class SemEvalDateset(Dataset):
    def __init__(self, filename, rel2id, word2id, config):
        self.filename = filename
        self.rel2id = rel2id
        self.word2id = word2id
        self.max_len = config.max_len
        self.pos_dis = config.pos_dis
        self.data_dir = config.data_dir
        self.dataset, self.label = self.__load_data()
        
    # position encoding
    # pos_dis: parameter for position encoding: the length to shift the position
    def __get_pos_index(self, x):
        if x < -self.pos_dis:
            return 0
        if x >= -self.pos_dis and x <= self.pos_dis:
            return x + self.pos_dis + 1
        if x > self.pos_dis:
            return 2 * self.pos_dis + 2
        
    # relative position encoding
    def __get_relative_pos(self, x, entity_pos):
        # entity_pos[0] -> begin
        # entity_pos[1] -> end
        if x < entity_pos[0]:
            return self.__get_pos_index(x-entity_pos[0])
        elif x > entity_pos[1]:
            return self.__get_pos_index(x-entity_pos[1])
        else:
            return self.__get_pos_index(0)
        
    #sentence feature
    def _symbolize_sentence(self, e1_pos, e2_pos, sentence):
        """
            Args:
                e1_pos (tuple) span of e1
                e2_pos (tuple) span of e2
                sentence (list)
        """
        
        mask = [1] * len(sentence)
        # for exmaple  [1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3]
        # 1 for the positions before the first entity
        # 2 for the positions bwt first ane second entity
        # 3 for the positions after the second entity
        
        if e1_pos[0] < e2_pos[0]:
            for i in range(e1_pos[0], e2_pos[1]+1):
                mask[i] = 2
            for i in range(e2_pos[1]+1, len(sentence)):
                mask[i] = 3
        else:
            for i in range(e2_pos[0], e1_pos[1]+1):
                mask[i] = 2
            for i in range(e1_pos[1]+1, len(sentence)):
                mask[i] = 3

        words = [] # words id list
        pos1 = [] # word position relative to es1
        pos2 = [] # word pisition relative to es2
        
        length = min(self.max_len, len(sentence))
        mask = mask[:length]

        for i in range(length):
            words.append(self.word2id.get(sentence[i], self.word2id['*UNKNOWN*']))
            pos1.append(self.__get_relative_pos(i, e1_pos))
            pos2.append(self.__get_relative_pos(i, e2_pos))
        
        # PADDING
        if length < self.max_len:
            for i in range(length, self.max_len):
                mask.append(0)  # 'PAD' mask is zero
                words.append(self.word2id['PAD'])

                pos1.append(self.__get_relative_pos(i, e1_pos))
                pos2.append(self.__get_relative_pos(i, e2_pos))
        unit = np.asarray([words, pos1, pos2, mask], dtype=np.int64)
        unit = np.reshape(unit, newshape=(1, 4, self.max_len))
        return unit
    
    # lexical feature
    def _lexical_feature(self,e1_idx,e2_idx, sent):
        
        def _entity_context(e_idx,sent):
            ''' return [w(e-1), w(e), w(e+1)]
            '''
            context = []
            context.append(sent[e_idx])
            if e_idx >= 1:
                context.append(sent[e_idx-1])
            else:
                context.append(sent[e_idx])

            if e_idx < len(sent)-1:
                context.append(sent[e_idx+1])
            else:
                context.append(sent[e_idx])
            return context
        
        # to find the right and left word
        context1 = _entity_context(e1_idx[0], sent)
        context2 = _entity_context(e2_idx[0], sent)
        
        # ignore WordNet hypernyms in paper
        lexical = context1 + context2
#         print(sent)
#         print(lexical)
        lexical_ids=[self.word2id.get(word, self.word2id['*UNKNOWN*']) for word in lexical]
        lexical_ids=np.asarray(lexical_ids, dtype=np.int64)
#         print(lexical_ids)
        return np.reshape(lexical_ids, newshape=(1, 6))
    
    def __load_data(self):
        path_data_file = os.path.join(self.data_dir, self.filename)
        data = []
        labels = []
        with open(path_data_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                label = line['relation']
                sentence = line['sentence']
                e1_pos = (line['subj_start'], line['subj_end'])
                e2_pos = (line['obj_start'], line['obj_end'])
                label_idx = self.rel2id[label]

                one_sentence = self._symbolize_sentence(e1_pos, e2_pos, sentence)
                
                lexical = self._lexical_feature(e1_pos, e2_pos, sentence)
                
                temp = (one_sentence, lexical)
                data.append(temp)
                #data.append(one_sentence)
                labels.append(label_idx)
        return data, labels

    def __getitem__(self, index):
        data = self.dataset[index]
        label = self.label[index]
        return data, label

    def __len__(self):
        return len(self.label)

In [19]:
# ! head data/relation2id.txt
from collections import namedtuple
conf = namedtuple('conf',['max_len','pos_dis','data_dir'])
config = conf(100,50,"data")
filename='train.json'
data_loader=SemEvalDateset(filename, rel2id, word2id,config)

In [20]:
# example for lexical features encoded
e1_idx = (3,3)
e2_idx = (6,6)
sent=['The', 'most', 'common', 'audits', 'were', 'about', 'waste', 'and', 'recycling', '.']
data_loader._lexical_feature(e1_idx,e2_idx,sent)

array([[19894,   862,    36,  3632,    60,     6]], dtype=int64)

In [21]:
# example for sentence with position features encoding
e1_idx = (3,3)
e2_idx = (6,6)
sent=['The', 'most', 'common', 'audits', 'were', 'about', 'waste', 'and', 'recycling', '.']
sentence_features = data_loader._symbolize_sentence(e1_idx,e2_idx,sent)
# return  [words id, pos enc relative to  entity 1, pos enc relative to entity 2, mask]

In [22]:
sentence_features

array([[[400001,     97,    862,  19894,     36,     60,   3632,      6,
          12521,      3,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0],
        [    48,     49,     50,     51,     52,     53,     54,     55,
         

In [23]:
# sentence_features # [words id, pos enc for entity 1, pos enc for entity 2, mask]

In [24]:
# dataloader output
data,label = next(iter(data_loader))

In [25]:
data # ([words id, pos enc relative to  entity 1, pos enc relative to entity 2, mask],lexical feature)
# lexical feature = (left of entity1 ,entity1,right of entity 1 )

(array([[[400001,    279,     20,    980,   1070,     32,     48,   2606,
            3251,      7,     30,  40634,  11465,      4,  13874,   2623,
               3,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0],
         [    39,     40,     41,     42,     43,     44,     45,    

In [26]:
data[0].shape

(1, 4, 100)

### 4) Defind Data Loader
Define loader.__collate_fn convert data to tensor

In [27]:
#自定义loader
class SemEvalDataLoader(object):
    def __init__(self, rel2id, word2id, config):
        self.rel2id = rel2id
        self.word2id = word2id
        self.config = config

    def __collate_fn(self, batch):
        data, label = zip(*batch)  # unzip the batch data
        data = list(data)
        label = list(label)
        # cover array to be tensor
        sentence_feat = torch.from_numpy(np.concatenate([x[0] for x in data], axis=0)) # word id
        lexical_feat = torch.from_numpy(np.concatenate([x[1] for x in data], axis=0)) # pos for entity_1
        label = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return (sentence_feat,lexical_feat),label

    def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDateset(filename, self.rel2id, self.word2id, self.config)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=2,
            collate_fn=self.__collate_fn
        )
        return loader

    def get_train(self):
        return self.__get_data('train.json', shuffle=True)

    def get_dev(self):
        return self.__get_data('test.json', shuffle=False)

    def get_test(self):
        return self.__get_data('test.json', shuffle=False)