-
Notifications
You must be signed in to change notification settings - Fork 29
/
DataIter.py
54 lines (48 loc) · 1.85 KB
/
DataIter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys
import linecache
import mxnet as mx
class Batch(object):
    """Container pairing named data arrays with named label arrays.

    Args:
        data_names: names of the data inputs, in the same order as ``data``.
        data: sequence of array-like objects; each must expose ``.shape``.
        label_names: names of the labels, in the same order as ``label``.
        label: sequence of array-like objects; each must expose ``.shape``.
    """

    def __init__(self, data_names, data, label_names, label):
        self.data_names = data_names
        self.data = data
        self.label_names = label_names
        self.label = label

    @staticmethod
    def _describe(names, arrays):
        """Return ``[(name, array.shape), ...]`` for the paired inputs.

        Pairing stops at the shorter of the two sequences.
        """
        described = []
        for name, array in zip(names, arrays):
            described.append((name, array.shape))
        return described

    @property
    def provide_data(self):
        """List of ``(name, shape)`` tuples describing ``self.data``."""
        return self._describe(self.data_names, self.data)

    @property
    def provide_label(self):
        """List of ``(name, shape)`` tuples describing ``self.label``."""
        return self._describe(self.label_names, self.label)
class DataIter(mx.io.DataIter):
    """Iterate over a ``::``-separated ratings file in fixed-size batches.

    Each usable line of the file must split on ``'::'`` into exactly four
    fields; the first three are taken as user, item and rate, and the fourth
    is ignored. Lines are read through ``linecache``, so the position in the
    file is tracked as a 1-based line number in ``self.index_start``.

    Args:
        fname: path of the text file to read.
        batch_size: number of file lines consumed per batch.
    """

    def __init__(self, fname, batch_size):
        super(DataIter, self).__init__()
        self.batch_size = batch_size
        self.fname = fname
        # linecache numbers lines from 1, so the first batch starts at line 1.
        self.index_start = 1
        self.provide_data = [('user', (batch_size, )), ('item', (batch_size, ))]
        self.provide_label = [('rate', (batch_size, ))]

    def __iter__(self):
        """Yield one ``Batch`` per ``batch_size`` consecutive file lines.

        Iteration stops as soon as the file does not contain a complete
        window of ``batch_size`` lines starting at ``self.index_start``, so a
        trailing partial window is never yielded. Lines that do not split
        into exactly four fields are skipped without being replaced, which
        means a yielded batch can hold fewer than ``batch_size`` records.

        ``self.index_start`` is advanced after each batch is built; call
        ``reset()`` before iterating again from the top of the file.
        """
        data_names = ['user', 'item']
        label_names = ['rate']
        while True:
            # Probe the LAST line of the current window. linecache.getline
            # returns '' for a line number outside the file, so an empty
            # result means the window is incomplete. (Probing the line after
            # the window would wrongly drop the final full batch whenever the
            # file length is an exact multiple of batch_size.)
            window_end = self.index_start + self.batch_size
            if not linecache.getline(self.fname, window_end - 1):
                return

            buser = []
            bitem = []
            brate = []
            for line_no in range(self.index_start, window_end):
                fields = linecache.getline(self.fname, line_no).strip().split('::')
                if len(fields) != 4:
                    # Malformed or blank line: skip it.
                    continue
                line_user, line_item, line_rate, _ = fields
                buser.append(line_user)
                bitem.append(line_item)
                brate.append(line_rate)

            # Field values are handed to mx.nd.array as the strings read from
            # the file; any numeric conversion is left to mx.nd.array.
            data_all = [mx.nd.array(buser), mx.nd.array(bitem)]
            label_all = [mx.nd.array(brate)]
            self.index_start += self.batch_size
            yield Batch(data_names, data_all, label_names, label_all)

    def reset(self):
        """Rewind to the first line of the file."""
        self.index_start = 1