forked from lukun199/SemanticRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
78 lines (61 loc) · 3.03 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
lukun199@gmail.com
19th Feb., 2021
# data_loader.py
"""
import os, pickle, torch
from torch.utils.data import Dataset
import numpy as np
class Dataset_sentence(Dataset):
def __init__(self, _path, use_sos=False):
if not _path: _path = r'H:\Europarl' # change to your own path
self._path = os.path.join(_path, 'english_vocab.pkl')
self.dict = {}
tmp = pickle.load(open(self._path, 'rb'))
for kk,vv in tmp['voc'].items(): self.dict[kk] = vv+3
# add sos, eos, and pad.
self.dict['PAD'], self.dict['SOS'], self.dict['EOS'] = 0, 1, 2
self.len_range = tmp['len_range']
self.rev_dict = {vv: kk for kk, vv in self.dict.items()}
sos_head = [1] if use_sos else []
self.data_num = torch.tensor([sos_head + list(map(lambda t:self.dict[t], x.split(' '))) + [2]
+ (self.len_range[1]-len(x.split(' ')))*[0]
for idx, x in enumerate(tmp['sent_str']) if idx%5!=0]) # use tmp['sent_str'][:1000] for debugging
self.test_data_num = [sos_head + list(map(lambda t:self.dict[t], x.split(' '))) + [2]
+ (self.len_range[1]-len(x.split(' ')))*[0]
for idx, x in enumerate(tmp['sent_str']) if idx%5==0] # 20% of data
self.data_len = np.array(list(map(lambda s: sum(s != 0), self.data_num)))
print('[*]------------vocabulary size is:----', self.get_dict_len())
print('[*]------------sentences size is:----', self.__len__())
#print('[*]------------test sentences size is:----', len(self.test_data_num))
def __getitem__(self, index):
return self.data_num[index], self.data_len[index]
def __len__(self):
return len(self.data_num)
def get_dict_len(self):
return len(self.dict)
class Dataset_sentence_test(Dataset):
def __init__(self, _path):
if not _path: _path = r'H:\Europarl'
self._path = os.path.join(_path, 'english_vocab.pkl')
self.dict = {}
tmp = pickle.load(open(self._path, 'rb'))
for kk,vv in tmp['voc'].items(): self.dict[kk] = vv+3
# add sos, eos, and pad.
self.dict['PAD'], self.dict['SOS'], self.dict['EOS'] = 0, 1, 2
self.len_range = tmp['len_range']
self.rev_dict = {vv: kk for kk, vv in self.dict.items()}
self.data_num = [[1] + list(map(lambda t:self.dict[t], x.split(' '))) + [2]
+ (self.len_range[1]-len(x.split(' ')))*[0]
for idx, x in enumerate(tmp['sent_str']) if idx%5==0]
print('[*]------------vocabulary size is:----', self.get_dict_len())
print('[*]------------sentences size is:----', len(self.data_num))
def __getitem__(self, index):
return torch.tensor(self.data_num[index])
def __len__(self):
return len(self.data_num)
def get_dict_len(self):
return len(self.dict)
def collate_func(in_data):
    """Collate (tensor, length) pairs into one batch, longest sentence first.

    Args:
        in_data: iterable of (padded id tensor, unpadded length) pairs.

    Returns:
        A (batch, width) stacked tensor and the matching tuple of lengths,
        both ordered by descending length.
    """
    ordered = sorted(in_data, key=lambda pair: pair[1], reverse=True)
    tensors = tuple(pair[0] for pair in ordered)
    lengths = tuple(pair[1] for pair in ordered)
    return torch.stack(tensors, dim=0), lengths