# @Time : 2020/7/10
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn
# UPDATE
# @Time : 2022/7/8, 2020/9/15, 2020/9/16, 2020/8/12
# @Author : Zhen Tian, Yupeng Hou, Yushuo Chen, Xingyu Pan
# @email : chenyuwuxinn@gmail.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, panxy@ruc.edu.cn
"""
recbole.data.interaction
############################
"""
import numpy as np
import pandas as pd
import torch
import torch.nn.utils.rnn as rnn_utils
def _convert_to_tensor(data):
"""This function can convert common data types (list, pandas.Series, numpy.ndarray, torch.Tensor) into torch.Tensor.
Args:
data (list, pandas.Series, numpy.ndarray, torch.Tensor): Origin data.
Returns:
torch.Tensor: Converted tensor from `data`.
"""
elem = data[0]
if isinstance(elem, (float, int, np.float, np.int64)):
new_data = torch.as_tensor(data)
elif isinstance(elem, (list, tuple, pd.Series, np.ndarray, torch.Tensor)):
seq_data = [torch.as_tensor(d) for d in data]
new_data = rnn_utils.pad_sequence(seq_data, batch_first=True)
else:
raise ValueError(f"[{type(elem)}] is not supported!")
if new_data.dtype == torch.float64:
new_data = new_data.float()
return new_data
class Interaction(object):
    """The basic class representing a batch of interaction records.

    Note:
        While training, there is no strict rules for data in one Interaction object.

        While testing, it should be guaranteed that all interaction records of one single
        user will not appear in different Interaction object, and records of the same user
        should be continuous. Meanwhile, the positive cases of one user always need to occur
        **earlier** than this user's negative cases.

        A correct example:

        =======     =======     =======
        user_id     item_id     label
        =======     =======     =======
        1           2           1
        1           6           1
        1           3           1
        1           1           0
        2           3           1
        ...         ...         ...
        =======     =======     =======

        Some wrong examples for Interaction objects used in testing:

        1.
            =======     =======     =======     ============
            user_id     item_id     label
            =======     =======     =======     ============
            1           2           1
            1           6           0           # positive cases of one user always need to
                                                  occur earlier than this user's negative cases
            1           3           1
            1           1           0
            2           3           1
            ...         ...         ...
            =======     =======     =======     ============

        2.
            =======     =======     =======     ========
            user_id     item_id     label
            =======     =======     =======     ========
            1           2           1
            1           6           1
            1           3           1
            2           3           1           # records of the same user should be continuous.
            1           1           0
            ...         ...         ...
            =======     =======     =======     ========

    Attributes:
        interaction (dict or pandas.DataFrame): keys are meaningful str (also can be called field name),
            and values are Torch Tensor or numpy Array with shape (batch_size, \\*).
    """

    def __init__(self, interaction):
        """Build the internal ``{field: tensor}`` mapping from a dict or DataFrame.

        Args:
            interaction (dict or pandas.DataFrame): field name -> column values.
                Dict values may be list, numpy.ndarray or torch.Tensor;
                DataFrame columns are converted via :func:`_convert_to_tensor`.

        Raises:
            ValueError: If ``interaction`` or one of its values has an unsupported type.
        """
        self.interaction = dict()
        if isinstance(interaction, dict):
            for key, value in interaction.items():
                if isinstance(value, (list, np.ndarray)):
                    self.interaction[key] = _convert_to_tensor(value)
                elif isinstance(value, torch.Tensor):
                    # Tensors are stored as-is (no copy, no dtype change).
                    self.interaction[key] = value
                else:
                    raise ValueError(
                        f"The type of {key}[{type(value)}] is not supported!"
                    )
        elif isinstance(interaction, pd.DataFrame):
            for key in interaction:
                value = interaction[key].values
                self.interaction[key] = _convert_to_tensor(value)
        else:
            raise ValueError(
                f"[{type(interaction)}] is not supported for initialize `Interaction`!"
            )
        # Batch size = largest first-dim among all fields. unsqueeze(-1)
        # guarantees even a 0-d tensor reports a length of 1.
        self.length = -1
        for k in self.interaction:
            self.length = max(self.length, self.interaction[k].unsqueeze(-1).shape[0])

    def __iter__(self):
        # Iterating an Interaction yields its field names, like a dict.
        return self.interaction.__iter__()

    def __getattr__(self, item):
        # Guard: during unpickling/copying __getattr__ can run before
        # __init__ sets self.interaction; without this check the lookup
        # below would recurse infinitely.
        if "interaction" not in self.__dict__:
            raise AttributeError(f"'Interaction' object has no attribute 'interaction'")
        if item in self.interaction:
            return self.interaction[item]
        raise AttributeError(f"'Interaction' object has no attribute '{item}'")

    def __getitem__(self, index):
        # String index -> the field's tensor; anything else is treated as a
        # row index and applied to every field, producing a new Interaction.
        if isinstance(index, str):
            return self.interaction[index]
        if isinstance(index, (np.ndarray, torch.Tensor)):
            index = index.tolist()
        ret = {}
        for k in self.interaction:
            ret[k] = self.interaction[k][index]
        return Interaction(ret)

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise KeyError(f"{type(key)} object does not support item assigment")
        self.interaction[key] = value

    def __delitem__(self, key):
        if key not in self.interaction:
            raise KeyError(f"{type(key)} object does not in this interaction")
        del self.interaction[key]

    def __contains__(self, item):
        return item in self.interaction

    def __len__(self):
        # The batch size computed in __init__ (max first-dim over fields).
        return self.length

    def __str__(self):
        info = [f"The batch_size of interaction: {self.length}"]
        for k in self.interaction:
            inter = self.interaction[k]
            temp_str = f"    {k}, {inter.shape}, {inter.device.type}, {inter.dtype}"
            info.append(temp_str)
        info.append("\n")
        return "\n".join(info)

    def __repr__(self):
        return self.__str__()

    @property
    def columns(self):
        """
        Returns:
            list of str: The columns of interaction.
        """
        return list(self.interaction.keys())

    def to(self, device, selected_field=None):
        """Transfer Tensors in this Interaction object to the specified device.

        Args:
            device (torch.device): target device.
            selected_field (str or iterable object, optional): if specified, only Tensors
                with keys in selected_field will be sent to device.

        Returns:
            Interaction: a copied Interaction object with Tensors which are sent to
            the specified device.
        """
        ret = {}
        if isinstance(selected_field, str):
            selected_field = [selected_field]

        if selected_field is not None:
            selected_field = set(selected_field)
            for k in self.interaction:
                if k in selected_field:
                    ret[k] = self.interaction[k].to(device)
                else:
                    # Fields not selected keep their current device.
                    ret[k] = self.interaction[k]
        else:
            for k in self.interaction:
                ret[k] = self.interaction[k].to(device)
        return Interaction(ret)

    def cpu(self):
        """Transfer Tensors in this Interaction object to cpu.

        Returns:
            Interaction: a copied Interaction object with Tensors which are sent to cpu.
        """
        ret = {}
        for k in self.interaction:
            ret[k] = self.interaction[k].cpu()
        return Interaction(ret)

    def numpy(self):
        """Transfer Tensors to numpy arrays.

        Returns:
            dict: keys the same as Interaction object, are values are corresponding numpy
            arrays transformed from Tensor.
        """
        ret = {}
        for k in self.interaction:
            ret[k] = self.interaction[k].numpy()
        return ret

    def repeat(self, sizes):
        """Repeats each tensor along the batch dim.

        Args:
            sizes (int): repeat times.

        Example:
            >>> a = Interaction({'k': torch.zeros(4)})
            >>> a.repeat(3)
            The batch_size of interaction: 12
                k, torch.Size([12]), cpu

            >>> a = Interaction({'k': torch.zeros(4, 7)})
            >>> a.repeat(3)
            The batch_size of interaction: 12
                k, torch.Size([12, 7]), cpu

        Returns:
            a copyed Interaction object with repeated Tensors.
        """
        ret = {}
        for k in self.interaction:
            # Repeat only along dim 0; trailing dims stay untouched.
            ret[k] = self.interaction[k].repeat(
                [sizes] + [1] * (len(self.interaction[k].shape) - 1)
            )
        return Interaction(ret)

    def repeat_interleave(self, repeats, dim=0):
        """Similar to repeat_interleave of PyTorch.

        Details can be found in:

            https://pytorch.org/docs/stable/tensors.html?highlight=repeat#torch.Tensor.repeat_interleave

        Note:
            ``torch.repeat_interleave()`` is supported in PyTorch >= 1.2.0.
        """
        ret = {}
        for k in self.interaction:
            ret[k] = self.interaction[k].repeat_interleave(repeats, dim=dim)
        return Interaction(ret)

    def update(self, new_inter):
        """Similar to ``dict.update()``

        Args:
            new_inter (Interaction): current interaction will be updated by new_inter.
        """
        for k in new_inter.interaction:
            self.interaction[k] = new_inter.interaction[k]

    def drop(self, column):
        """Drop column in interaction.

        Args:
            column (str): the column to be dropped.

        Raises:
            ValueError: If ``column`` is not a field of this interaction.
        """
        if column not in self.interaction:
            raise ValueError(f"Column [{column}] is not in [{self}].")
        del self.interaction[column]

    def _reindex(self, index):
        """Reset the index of interaction inplace.

        Args:
            index: the new index of current interaction.
        """
        for k in self.interaction:
            self.interaction[k] = self.interaction[k][index]

    def shuffle(self):
        """Shuffle current interaction inplace."""
        index = torch.randperm(self.length)
        self._reindex(index)

    def sort(self, by, ascending=True):
        """Sort the current interaction inplace.

        Args:
            by (str or list of str): Field that as the key in the sorting process.
            ascending (bool or list of bool, optional): Results are ascending if ``True``, otherwise descending.
                Defaults to ``True``

        Raises:
            ValueError: If a key in ``by`` is not a field, or lengths of ``by``
                and ``ascending`` cannot be reconciled.
            TypeError: If ``by`` or ``ascending`` has a wrong type.
        """
        if isinstance(by, str):
            if by not in self.interaction:
                raise ValueError(f"[{by}] is not exist in interaction [{self}].")
            by = [by]
        elif isinstance(by, (list, tuple)):
            for b in by:
                if b not in self.interaction:
                    raise ValueError(f"[{b}] is not exist in interaction [{self}].")
        else:
            raise TypeError(f"Wrong type of by [{by}].")

        if isinstance(ascending, bool):
            ascending = [ascending]
        elif isinstance(ascending, (list, tuple)):
            for a in ascending:
                if not isinstance(a, bool):
                    raise TypeError(f"Wrong type of ascending [{ascending}].")
        else:
            raise TypeError(f"Wrong type of ascending [{ascending}].")

        if len(by) != len(ascending):
            if len(ascending) == 1:
                # A single flag applies to every sort key.
                ascending = ascending * len(by)
            else:
                raise ValueError(
                    f"by [{by}] and ascending [{ascending}] should have same length."
                )

        # Multi-key sort: apply keys from least to most significant; each
        # pass uses a *stable* argsort so earlier orderings are preserved.
        for b, a in zip(by[::-1], ascending[::-1]):
            if len(self.interaction[b].shape) == 1:
                key = self.interaction[b]
            else:
                # Multi-dim field: sort by its first column.
                key = self.interaction[b][..., 0]
            index = np.argsort(key, kind="stable")
            if not a:
                # Descending = reversed stable-ascending order. np.array(...)
                # makes the reversed view contiguous before tensor conversion.
                index = torch.tensor(np.array(index)[::-1])
            self._reindex(index)

    def add_prefix(self, prefix):
        """Add prefix to current interaction's columns.

        Args:
            prefix (str): The prefix to be added.
        """
        self.interaction = {
            prefix + key: value for key, value in self.interaction.items()
        }
def cat_interactions(interactions):
    """Concatenate a list of interactions into a single interaction.

    Every interaction must have exactly the same set of columns; each column's
    tensors are concatenated along the batch dimension (dim 0).

    Args:
        interactions (list of :class:`Interaction`): List of interactions to be concatenated.

    Returns:
        :class:`Interaction`: Concatenated interaction.

    Raises:
        TypeError: If ``interactions`` is not a list or tuple.
        ValueError: If ``interactions`` is empty, or the interactions do not
            all share the same columns.
    """
    if not isinstance(interactions, (list, tuple)):
        raise TypeError(f"Interactions [{interactions}] should be list or tuple.")
    if len(interactions) == 0:
        raise ValueError(
            f"Interactions [{interactions}] should have some interactions."
        )

    columns_set = set(interactions[0].columns)
    # Start at index 1: element 0 trivially equals its own column set.
    for inter in interactions[1:]:
        if columns_set != set(inter.columns):
            # Fixed message: this branch previously reused the empty-list
            # wording ("should have some interactions"), which hid the
            # actual failure cause (mismatched columns).
            raise ValueError(
                f"Interactions [{interactions}] should have the same columns."
            )

    new_inter = {
        col: torch.cat([inter[col] for inter in interactions]) for col in columns_set
    }
    return Interaction(new_inter)