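"""
Text dataset utilities: load {split}.jsonl files of 'src'/'trg' pairs,
tokenize them with a caller-supplied vocabulary, merge each pair into a
single [src SEP trg] sequence with a conditioning mask, pad to seq_len,
and serve batches of (embedding array, {'input_ids', 'input_mask'}) pairs
through a PyTorch DataLoader.
"""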
# import blobfile as bf
import json

import numpy as np
import psutil
import torch
from torch.utils.data import DataLoader, Dataset

import datasets
from datasets import Dataset as Dataset2


def load_data_text(
    batch_size,
    seq_len,
    deterministic=False,
    data_args=None,
    model_emb=None,
    split='train',
    loaded_vocab=None,
    loop=True,
):
    """
    For a dataset, create a generator over (seqs, kwargs) pairs.

    Each seq is a (bsz, len, h) float tensor, and the kwargs dict contains
    zero or more keys, each of which maps to a batched Tensor of its own.
    The kwargs dict can be used for some meta information.

    :param batch_size: the batch size of each returned pair.
    :param seq_len: the max sequence length (one side).
    :param deterministic: if True, yield results in a deterministic order.
    :param data_args: dataset directory, dataset size, basic settings, etc.
    :param model_emb: loaded word embeddings.
    :param loaded_vocab: loaded word vocabulary (tokenizer).
    :param loop: whether to loop over the data indefinitely.
    """
    print('#' * 30, '\nLoading text data...')

    training_data = get_corpus(data_args, seq_len, split=split, loaded_vocab=loaded_vocab)

    dataset = TextDataset(
        training_data,
        data_args,
        model_emb=model_emb,
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        # drop_last=True,
        shuffle=not deterministic,
        num_workers=0,
    )
    if loop:
        return infinite_loader(data_loader)
    else:
        return iter(data_loader)


def infinite_loader(data_loader):
    # Re-iterate the DataLoader forever so callers can keep drawing batches.
    while True:
        yield from data_loader
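

# Minimal usage sketch (illustrative only; `args` and `my_tokenizer` are
# hypothetical stand-ins for objects the training script loads elsewhere).
# `loaded_vocab` must expose encode_token(), sep_token_id and pad_token_id;
# `model_emb` maps token ids to vectors, e.g. a frozen torch.nn.Embedding:
#
#   data = load_data_text(
#       batch_size=32,
#       seq_len=128,
#       data_args=args,              # must provide .dataset and .data_dir
#       model_emb=torch.nn.Embedding(vocab_size, 128),
#       split='train',
#       loaded_vocab=my_tokenizer,
#   )
#   arr, cond = next(data)           # arr: (32, 128, emb_dim) embedding batch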


def helper_tokenize(sentence_lst, vocab_dict, seq_len):
    # Process.memory_info is expressed in bytes, so convert to megabytes
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    raw_datasets = Dataset2.from_dict(sentence_lst)
    print(raw_datasets)
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    def tokenize_function(examples):
        input_id_x = vocab_dict.encode_token(examples['src'])
        input_id_y = vocab_dict.encode_token(examples['trg'])
        result_dict = {'input_id_x': input_id_x, 'input_id_y': input_id_y}
        return result_dict

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=['src', 'trg'],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print('### tokenized_datasets', tokenized_datasets)
    print('### tokenized_datasets...example', tokenized_datasets['input_id_x'][0])
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    def merge_and_mask(group_lst):
        lst = []
        mask = []
        for i in range(len(group_lst['input_id_x'])):
            end_token = group_lst['input_id_x'][i][-1]
            src = group_lst['input_id_x'][i][:-1]
            trg = group_lst['input_id_y'][i][:-1]
            # Trim the longer side first until src + trg fit in seq_len - 3,
            # leaving room for the two end tokens and the [SEP] token.
            while len(src) + len(trg) > seq_len - 3:
                if len(src) > len(trg):
                    src.pop()
                elif len(src) < len(trg):
                    trg.pop()
                else:
                    src.pop()
                    trg.pop()
            src.append(end_token)
            trg.append(end_token)
            lst.append(src + [vocab_dict.sep_token_id] + trg)
            mask.append([0] * (len(src) + 1))
        group_lst['input_ids'] = lst
        group_lst['input_mask'] = mask
        return group_lst

    tokenized_datasets = tokenized_datasets.map(
        merge_and_mask,
        batched=True,
        num_proc=1,
        desc="merge and mask",
    )

    def pad_function(group_lst):
        max_length = seq_len
        group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict.pad_token_id, max_length)
        group_lst['input_mask'] = _collate_batch_helper(group_lst['input_mask'], 1, max_length)
        return group_lst

    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    lm_datasets = tokenized_datasets.map(
        pad_function,
        batched=True,
        num_proc=1,
        desc="padding",
    )
    print(lm_datasets, 'padded dataset')
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    raw_datasets = datasets.DatasetDict()
    raw_datasets['train'] = lm_datasets
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    return raw_datasets
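

# Worked example of merge_and_mask above (hypothetical token ids): with
# seq_len = 10, sep_token_id = 3, and 1 as the end token,
#
#   input_id_x = [11, 12, 1]      ->  src = [11, 12]
#   input_id_y = [21, 22, 23, 1]  ->  trg = [21, 22, 23]
#   2 + 3 <= seq_len - 3, so nothing is trimmed
#   input_ids  = [11, 12, 1, 3, 21, 22, 23, 1]
#   input_mask = [0, 0, 0, 0]     # zeros cover src + [SEP]; pad_function
#                                 # later pads the mask with 1s up to seq_len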


def get_corpus(data_args, seq_len, split='train', loaded_vocab=None):
    print('#' * 30, '\nLoading dataset {} from {}...'.format(data_args.dataset, data_args.data_dir))

    sentence_lst = {'src': [], 'trg': []}

    if split == 'train':
        print('### Loading from the TRAIN set...')
        path = f'{data_args.data_dir}/train.jsonl'
    elif split == 'valid':
        print('### Loading from the VALID set...')
        path = f'{data_args.data_dir}/valid.jsonl'
    elif split == 'test':
        print('### Loading from the TEST set...')
        path = f'{data_args.data_dir}/test.jsonl'
    else:
        assert False, "invalid split for dataset"

    with open(path, 'r') as f_reader:
        for row in f_reader:
            content = json.loads(row)  # parse each line once instead of twice
            sentence_lst['src'].append(content['src'].strip())
            sentence_lst['trg'].append(content['trg'].strip())

    print('### Data samples...\n', sentence_lst['src'][:2], sentence_lst['trg'][:2])

    # get tokenizer.
    vocab_dict = loaded_vocab

    train_dataset = helper_tokenize(sentence_lst, vocab_dict, seq_len)
    return train_dataset
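

# Expected on-disk format (as read by get_corpus): one JSON object per line
# of {data_dir}/{split}.jsonl, each with 'src' and 'trg' string fields.
# An illustrative line:
#   {"src": "what is the capital of france ?", "trg": "it is paris ."}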


class TextDataset(Dataset):
    def __init__(self, text_datasets, data_args, model_emb=None):
        super().__init__()
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.data_args = data_args
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with torch.no_grad():
            input_ids = self.text_datasets['train'][idx]['input_ids']
            hidden_state = self.model_emb(torch.tensor(input_ids))

            # obtain the input vectors; only used when the word embedding is
            # fixed (not trained end-to-end)
            arr = np.array(hidden_state, dtype=np.float32)

            out_kwargs = {}
            out_kwargs['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
            out_kwargs['input_mask'] = np.array(self.text_datasets['train'][idx]['input_mask'])

            return arr, out_kwargs
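

# Illustrative sketch (toy data, not part of the pipeline): TextDataset
# expects the DatasetDict produced by helper_tokenize plus an id->vector
# module such as a frozen torch.nn.Embedding.
def _demo_text_dataset():
    toy = datasets.DatasetDict()
    toy['train'] = Dataset2.from_dict({
        'input_ids': [[5, 6, 2, 7, 0, 0, 0, 0]],
        'input_mask': [[0, 0, 0, 1, 1, 1, 1, 1]],
    })
    emb = torch.nn.Embedding(10, 16)  # 10-word toy vocab, 16-dim embeddings
    arr, cond = TextDataset(toy, data_args=None, model_emb=emb)[0]
    print(arr.shape)  # (8, 16): one embedding vector per token position
    print(cond['input_ids'], cond['input_mask'])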


def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    # Left-align each example into a (batch, max_length) grid filled with
    # pad_token_id, truncating anything longer; the optional mask marks the
    # copied (real) positions with 1.
    result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result
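

# Illustrative sketch: the padding helper on a toy batch. With pad_token_id=0
# the untouched mask positions also read 0, since mask_ is initialised with
# the pad id.
def _demo_collate():
    examples = [[5, 6, 7], [8, 9]]
    padded, mask = _collate_batch_helper(examples, pad_token_id=0, max_length=4, return_mask=True)
    print(padded)  # [[5, 6, 7, 0], [8, 9, 0, 0]]
    print(mask)    # [[1, 1, 1, 0], [1, 1, 0, 0]]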