import os
import pickle
import random

import numpy as np

## tag set, BIO scheme: B-/I- mark the beginning/inside of an entity span,
## O marks tokens outside any span
tag2label = {"B-ENT": 0, "I-ENT": 1,
             "B-HYPER": 2, "I-HYPER": 3,
             "O": 4}
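
# Added sketch (not in the original file): models that consume tag2label
# usually also need the inverse mapping when decoding predicted label ids
# back into tag strings. The name 'label2tag' is our assumption.
label2tag = {label: tag for tag, label in tag2label.items()}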
def read_corpus_1(corpus_path):
    """Read an unlabelled corpus: each non-blank line yields one list of
    whitespace-separated tokens."""
    data = []
    with open(corpus_path, encoding='utf-8') as fr:
        lines = fr.readlines()
    for line in lines:
        if line != '\n':
            # one sentence per line, tokens separated by whitespace
            data.append(line.strip().split())
    return data
def read_corpus(corpus_path):
    """Read a labelled corpus in CoNLL-style format: one `char label` pair per
    line, sentences separated by blank lines. Returns a list of (chars, tags)."""
    data = []
    with open(corpus_path, encoding='utf-8') as fr:
        lines = fr.readlines()
    sent_, tag_ = [], []
    for line in lines:
        if line != '\n':
            [char, label] = line.strip().split()
            sent_.append(char)
            tag_.append(label)
        else:
            data.append((sent_, tag_))
            sent_, tag_ = [], []
    if sent_:
        # keep the last sentence when the file does not end with a blank line
        data.append((sent_, tag_))
    return data
def vocab_build(vocab_path, corpus_path, min_count):
    """Build a word->id mapping from the corpus and pickle it to vocab_path.
    Digit strings collapse to <NUM>, ASCII letters to <ENG>; words rarer than
    min_count are dropped. Id 0 is reserved for <PAD>."""
    data = read_corpus_1(corpus_path)
    word2id = {}
    for sent_ in data:
        for word in sent_:
            if word.isdigit():
                word = '<NUM>'
            elif ('\u0041' <= word <= '\u005a') or ('\u0061' <= word <= '\u007a'):
                # single ASCII letters (char-level input) collapse to <ENG>
                word = '<ENG>'
            if word not in word2id:
                word2id[word] = [len(word2id) + 1, 1]  # [provisional id, frequency]
            else:
                word2id[word][1] += 1

    # drop low-frequency words, but always keep the special tokens
    low_freq_words = []
    for word, [word_id, word_freq] in word2id.items():
        if word_freq < min_count and word != '<NUM>' and word != '<ENG>':
            low_freq_words.append(word)
    for word in low_freq_words:
        del word2id[word]

    # reassign dense ids starting from 1; 0 stays reserved for <PAD>
    new_id = 1
    for word in word2id.keys():
        word2id[word] = new_id
        new_id += 1
    word2id['<UNK>'] = new_id
    word2id['<PAD>'] = 0

    print('vocab_size:', len(word2id))
    with open(vocab_path, 'wb') as fw:
        pickle.dump(word2id, fw)
def sentence2id(sent, word2id):
    """Map a sentence (a sequence of characters) to a list of vocabulary ids."""
    sentence_id = []
    for word in sent:
        if word.isdigit():
            # digit strings collapse to the <NUM> token
            word = '<NUM>'
        elif ('\u0041' <= word <= '\u005a') or ('\u0061' <= word <= '\u007a'):
            # ASCII letters collapse to the <ENG> token
            word = '<ENG>'
        if word not in word2id:
            # out-of-vocabulary (mostly rare Chinese characters)
            word = '<UNK>'
        sentence_id.append(word2id[word])
    return sentence_id
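
# Worked example (hedged; the toy vocabulary below is ours, not the original's):
# with word2id = {'<PAD>': 0, '我': 1, '<NUM>': 2, '<ENG>': 3, '<UNK>': 4},
# sentence2id(['我', '7', 'a', '猫'], word2id) returns [1, 2, 3, 4].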
def read_dictionary(vocab_path):
    """Load the pickled word->id mapping produced by vocab_build."""
    with open(vocab_path, 'rb') as fr:
        word2id = pickle.load(fr)
    print('vocab_size:', len(word2id))
    return word2id
def random_embedding(vocab, embedding_dim):
    """Return a float32 (len(vocab), embedding_dim) matrix sampled uniformly
    from [-0.25, 0.25)."""
    embedding_mat = np.random.uniform(-0.25, 0.25, (len(vocab), embedding_dim))
    return np.float32(embedding_mat)
def pad_sequences(sequences, pad_mark=0):
    """Pad every sequence in the batch with pad_mark up to the batch maximum
    length. Returns the padded sequences and the original lengths."""
    max_len = max(map(lambda x: len(x), sequences))
    seq_list, seq_len_list = [], []
    for seq in sequences:
        seq = list(seq)
        seq_ = seq[:max_len] + [pad_mark] * max(max_len - len(seq), 0)
        seq_list.append(seq_)
        seq_len_list.append(min(len(seq), max_len))  # true (unpadded) length
    return seq_list, seq_len_list
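
# Worked example (hedged): pad_sequences([[1, 2, 3], [4, 5]]) returns
# ([[1, 2, 3], [4, 5, 0]], [3, 2]) -- each sequence padded to the batch maximum.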
def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
    """Yield (seqs, labels) batches of size batch_size; a smaller final batch
    is yielded when len(data) is not a multiple of batch_size."""
    if shuffle:
        random.shuffle(data)

    seqs, labels = [], []
    for (sent_, tag_) in data:
        sent_ = sentence2id(sent_, vocab)
        label_ = [tag2label[tag] for tag in tag_]

        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []
        seqs.append(sent_)
        labels.append(label_)

    if len(seqs) != 0:
        # flush the last, possibly smaller, batch
        yield seqs, labels
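
# Hedged end-to-end sketch (not part of the original module): how these helpers
# are typically chained together. The paths 'train_data' and 'word2id.pkl' and
# the values min_count=5, embedding_dim=300, batch_size=64 are illustrative
# assumptions, not values the original file defines.
if __name__ == '__main__':
    corpus_path = 'train_data'   # assumed: one `char label` pair per line
    vocab_path = 'word2id.pkl'   # assumed output location for the vocabulary

    vocab_build(vocab_path, corpus_path, min_count=5)
    word2id = read_dictionary(vocab_path)
    embeddings = random_embedding(word2id, embedding_dim=300)

    train_data = read_corpus(corpus_path)
    for seqs, labels in batch_yield(train_data, batch_size=64,
                                    vocab=word2id, tag2label=tag2label,
                                    shuffle=True):
        seqs_padded, seq_lens = pad_sequences(seqs, pad_mark=0)
        labels_padded, _ = pad_sequences(labels, pad_mark=0)
        break  # one batch is enough for a smoke test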