
Commit ab7b4b7

committed Jul 18, 2020
add afm
1 parent fa151e5 commit ab7b4b7

File tree

  • submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr

1 file changed: +97 −0 lines

 
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from submarine.ml.pytorch.layers.core import (FeatureEmbedding, FeatureLinear)
from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel


class AttentionalFM(BasePyTorchModel):

    def model_fn(self, params):
        super().model_fn(params)
        return _AttentionalFM(**self.params['model']['kwargs'])


class _AttentionalFM(nn.Module):

    def __init__(self, num_features: int, embedding_dim: int,
                 attention_dim: int, out_features: int, dropout_rate: float,
                 **kwargs):
        super().__init__()
        self.feature_linear = FeatureLinear(num_features=num_features,
                                            out_features=out_features)
        self.feature_embedding = FeatureEmbedding(num_features=num_features,
                                                  embedding_dim=embedding_dim)
        self.attentional_interaction = AttentionalInteraction(
            embedding_dim=embedding_dim,
            attention_dim=attention_dim,
            out_features=out_features,
            dropout_rate=dropout_rate)

    def forward(self, feature_idx: torch.LongTensor,
                feature_value: torch.LongTensor):
        """
        :param feature_idx: torch.LongTensor (batch_size, num_fields)
        :param feature_value: torch.LongTensor (batch_size, num_fields)
        """
        return self.feature_linear(
            feature_idx, feature_value) + self.attentional_interaction(
                self.feature_embedding(feature_idx, feature_value))


class AttentionalInteraction(nn.Module):

    def __init__(self, embedding_dim: int, attention_dim: int,
                 out_features: int, dropout_rate: float):
        super().__init__()
        self.attention_score = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=attention_dim),
            nn.ReLU(), nn.Linear(in_features=attention_dim, out_features=1),
            nn.Softmax(dim=1))
        self.pairwise_product = PairwiseProduct()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(in_features=embedding_dim,
                            out_features=out_features)

    def forward(self, x: torch.FloatTensor):
        """
        :param x: torch.FloatTensor (batch_size, num_fields, embedding_dim)
        """
        x = self.pairwise_product(x)
        score = self.attention_score(x)
        attentioned = torch.sum(score * x, dim=1)
        return self.fc(self.dropout(attentioned))


class PairwiseProduct(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.FloatTensor):
        """
        :param x: torch.FloatTensor (batch_size, num_fields, embedding_dim)
        """
        batch_size, num_fields, embedding_dim = x.size()

        all_pairs_product = x.unsqueeze(dim=1) * x.unsqueeze(dim=2)
        idx_row, idx_col = torch.unbind(torch.triu_indices(num_fields,
                                                           num_fields,
                                                           offset=1),
                                        dim=0)
        return all_pairs_product[:, idx_row, idx_col]
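
The commit itself carries no usage notes, so here is a minimal standalone sketch of what the pairwise-product and attention-pooling steps above compute. It assumes only a stock torch install; the batch size, field count, and embedding/attention dimensions are made-up illustration values, not taken from this commit or from Submarine's default configs.

import torch

# Made-up illustration sizes (not from the commit).
batch_size, num_fields, embedding_dim, attention_dim = 4, 10, 8, 16
x = torch.randn(batch_size, num_fields, embedding_dim)

# PairwiseProduct: element-wise product of every field pair (i, j) with i < j.
all_pairs = x.unsqueeze(dim=1) * x.unsqueeze(dim=2)        # (4, 10, 10, 8)
idx_row, idx_col = torch.unbind(
    torch.triu_indices(num_fields, num_fields, offset=1), dim=0)
pairs = all_pairs[:, idx_row, idx_col]                     # (4, 45, 8); 45 = 10 * 9 / 2

# AttentionalInteraction: score each pair, softmax over the pair axis,
# then pool with a weighted sum before the final linear projection.
score_net = torch.nn.Sequential(
    torch.nn.Linear(embedding_dim, attention_dim), torch.nn.ReLU(),
    torch.nn.Linear(attention_dim, 1), torch.nn.Softmax(dim=1))
score = score_net(pairs)                                   # (4, 45, 1)
pooled = torch.sum(score * pairs, dim=1)                   # (4, 8)
print(pairs.shape, score.shape, pooled.shape)

The nn.Softmax(dim=1) in attention_score normalizes the scores across the num_fields * (num_fields - 1) / 2 pairwise interactions, so the pooled vector is a learned weighted combination of pair products rather than the uniform sum a plain FM second-order term would use.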
