cnndcf.py
import torch
import torch.nn.functional as F
from .oncf import ONCF
from .layers import MultilayerPerceptrons

class CNNDCF(ONCF):
    """
    Outer Product-based Neural Collaborative Filtering (ONCF) extended with a
    residual connection over the outer-product interaction map.

    - field_dims (list): List of feature dimensions.
    - embed_dim (int): Embedding dimension. Default: 32.
    - filter_size (int): Convolution filter/depth/channel size. Default: 32.
    - kernel_size (int): Convolution square-window or convolving square-kernel size. Default: 2.
    - stride (int): Convolution stride. Default: 2.
    - activation (str): Activation function applied across the convolution layers. Default: "relu".
    - dropout (float): Dropout rate applied across the convolution layers. Default: 0.5.
    - batch_norm (bool): If True, apply batch normalization between the activation and dropout layers across the convolution layers. Default: True.
    - lr (float): Learning rate. Default: 1e-3.
    - weight_decay (float): L2 regularization rate. Default: 1e-3.
    - criterion: Criterion, objective, or loss function. Default: F.mse_loss.
    """

    def __init__(self, field_dims: list,
                 embed_dim: int = 32,    # or 64
                 filter_size: int = 32,  # or 64
                 kernel_size: int = 2,
                 stride: int = 2,
                 activation: str = "relu",
                 dropout: float = 0.5,
                 batch_norm: bool = True,
                 lr: float = 1e-3,
                 weight_decay: float = 1e-3,
                 criterion=F.mse_loss):
        super().__init__(field_dims, embed_dim, filter_size, kernel_size, stride,
                         activation, dropout, batch_norm, lr, weight_decay, criterion)
        self.save_hyperparameters()

        # Residual connection layer: maps the flattened (embed_dim x embed_dim)
        # interaction map to a tensor of the same size so it can be added back
        # element-wise.
        self.residual = MultilayerPerceptrons(input_dim=embed_dim ** 2,
                                              hidden_dims=[embed_dim ** 2],
                                              remove_last_dropout=True,
                                              remove_last_batch_norm=True,
                                              output_layer=None)

    def forward(self, x):
        # Outer product between user and item embeddings:
        # (batch_size, embed_dim, 1) x (batch_size, 1, embed_dim)
        # -> (batch_size, embed_dim, embed_dim) interaction map.
        x_embed = self.features_embedding(x.int())
        user_embed, item_embed = x_embed[:, 0], x_embed[:, 1]
        outer = torch.bmm(user_embed.unsqueeze(2), item_embed.unsqueeze(1))

        # Residual connection: pass the flattened interaction map through an MLP
        # and add the result back to the original map.
        outer_flatten = torch.flatten(outer, start_dim=1)
        residual = self.residual(outer_flatten)
        residual = residual.view(outer.shape)
        outer_residual = outer + residual

        # Unsqueeze so each interaction map has a single channel for convolution.
        outer_residual = torch.unsqueeze(outer_residual, 1)

        # Convolution layers over the residual-enhanced interaction map.
        y = self.cnn(outer_residual)
        return y
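

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how
# CNNDCF might be instantiated and called, assuming `field_dims` holds the
# user and item vocabulary sizes and the input is a (batch_size, 2) tensor of
# [user_index, item_index] pairs, as implied by the x_embed[:, 0] /
# x_embed[:, 1] indexing above. The ONCF base class and MultilayerPerceptrons
# layer are defined elsewhere in this package and are not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_users, num_items = 1000, 2000              # hypothetical vocabulary sizes
    model = CNNDCF(field_dims=[num_users, num_items], embed_dim=32)

    # Two (user, item) index pairs.
    x = torch.tensor([[1, 42], [7, 99]])
    y = model(x)
    print(y.shape)  # one prediction per (user, item) pair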