/
ComiRec.py
93 lines (79 loc) · 3.76 KB
/
ComiRec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: UTF-8 -*-
# @Author : Chenyang Wang
# @Email : THUwangcy@gmail.com
""" ComiRec
Reference:
"Controllable Multi-Interest Framework for Recommendation"
Cen et al., KDD'2020.
CMD example:
python main.py --model_name ComiRec --emb_size 64 --lr 1e-3 --l2 1e-6 --attn_size 8 --K 4 --add_pos 1 \
--history_max 20 --dataset 'Grocery_and_Gourmet_Food'
"""
import torch
import torch.nn as nn
import numpy as np
from models.BaseModel import SequentialModel
from utils import layers
class ComiRec(SequentialModel):
    """Controllable Multi-Interest framework for recommendation (ComiRec).

    Reference: Cen et al., "Controllable Multi-Interest Framework for
    Recommendation", KDD'2020. Extracts K interest vectors from the user's
    item history via a self-attentive read-out, then scores candidate items
    against the best-matching interest.
    """
    reader = 'SeqReader'
    runner = 'BaseRunner'
    extra_log_args = ['emb_size', 'attn_size', 'K']

    @staticmethod
    def parse_model_args(parser):
        """Register ComiRec-specific command-line arguments."""
        parser.add_argument('--emb_size', type=int, default=64,
                            help='Size of embedding vectors.')
        parser.add_argument('--attn_size', type=int, default=8,
                            help='Size of attention vectors.')
        parser.add_argument('--K', type=int, default=2,
                            help='Number of hidden intent.')
        parser.add_argument('--add_pos', type=int, default=1,
                            help='Whether add position embedding.')
        return SequentialModel.parse_model_args(parser)

    def __init__(self, args, corpus):
        super().__init__(args, corpus)
        self.emb_size = args.emb_size
        self.attn_size = args.attn_size
        self.K = args.K
        self.add_pos = args.add_pos
        self.max_his = args.history_max
        # Precomputed position indices [0, 1, ..., max_his - 1] on the model device.
        self.len_range = torch.from_numpy(np.arange(self.max_his)).to(self.device)
        self._define_params()
        self.apply(self.init_weights)

    def _define_params(self):
        """Build item/position embeddings and the two-layer attention MLP."""
        self.i_embeddings = nn.Embedding(self.item_num, self.emb_size)
        if self.add_pos:
            self.p_embeddings = nn.Embedding(self.max_his + 1, self.emb_size)
        self.W1 = nn.Linear(self.emb_size, self.attn_size)
        self.W2 = nn.Linear(self.attn_size, self.K)

    def forward(self, feed_dict):
        """Score candidate items for each user in the batch.

        Returns a dict with 'prediction' of shape [batch_size, n_candidates].
        During training the first candidate is the positive target and a
        single best-matching interest is used; at evaluation every candidate
        is scored against all K interests and the max is kept.
        """
        self.check_list = []
        item_ids = feed_dict['item_id']        # [batch_size, n_candidates]
        history = feed_dict['history_items']   # [batch_size, history_max]
        lengths = feed_dict['lengths']         # [batch_size]
        batch_size, seq_len = history.shape

        valid_mask = (history > 0).long()      # 1 on real items, 0 on padding
        seq_emb = self.i_embeddings(history)

        # Optionally inject positional information: position counts the
        # distance from the sequence end, zeroed out on padding slots.
        if self.add_pos:
            position = (lengths[:, None] - self.len_range[None, :seq_len]) * valid_mask
            seq_emb_pos = seq_emb + self.p_embeddings(position)
        else:
            seq_emb_pos = seq_emb

        # Self-attentive read-out with K heads over the history.
        logits = self.W2(self.W1(seq_emb_pos).tanh())  # [bsz, seq_len, K]
        logits = logits.masked_fill(valid_mask.unsqueeze(-1) == 0, -np.inf)
        logits = logits.transpose(-1, -2)              # [bsz, K, seq_len]
        # Global-max shift for numerical stability (softmax is shift-invariant).
        attn_prob = (logits - logits.max()).softmax(dim=-1)
        # A fully-padded history softmaxes to NaN; zero those rows out.
        attn_prob = attn_prob.masked_fill(torch.isnan(attn_prob), 0)
        # Weighted pooling -> K interest vectors per user: [bsz, K, emb].
        interest_vectors = (seq_emb[:, None, :, :] * attn_prob[:, :, :, None]).sum(-2)

        i_vectors = self.i_embeddings(item_ids)
        if feed_dict['phase'] == 'train':
            # Route through the single interest that best matches the target.
            target_vector = i_vectors[:, 0]  # [bsz, emb]
            target_pred = (interest_vectors * target_vector[:, None, :]).sum(-1)  # [bsz, K]
            chosen = target_pred.max(-1)[1]  # [bsz]
            user_vector = interest_vectors[torch.arange(batch_size), chosen, :]  # [bsz, emb]
            prediction = (user_vector[:, None, :] * i_vectors).sum(-1)
        else:
            # Score each candidate against all K interests; keep the best match.
            prediction = (interest_vectors[:, None, :, :] * i_vectors[:, :, None, :]).sum(-1)
            prediction = prediction.max(-1)[0]  # [bsz, n_candidates]
        return {'prediction': prediction.view(batch_size, -1)}